submit ocr module (#2815)

Committed by shilin66
Parent 0e6877b0a1
Commit 850382af7d

python/suryaocr/Dockerfile (new file, 17 lines)
@@ -0,0 +1,17 @@
FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime

# Please download the models from https://huggingface.co/vikp/surya_det3
# and https://huggingface.co/vikp/surya_rec2, and put them in the vikp/ directory.
COPY ./vikp ./vikp

COPY requirements.txt .

RUN python3 -m pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

RUN python3 -m pip uninstall opencv-python -y

RUN python3 -m pip install opencv-python-headless -i https://pypi.tuna.tsinghua.edu.cn/simple

COPY app.py Dockerfile ./

ENTRYPOINT python3 app.py

python/suryaocr/README.md (new file, 120 lines)
@@ -0,0 +1,120 @@
# Integrating Surya OCR text detection

## Deploy from source

### 1. Environment

- Python 3.9+
- CUDA 11.8
- Network access to Hugging Face (a proxy or mirror may be required)

### 2. Install dependencies

```bash
pip install -r requirements.txt
```

### 3. Download the models

On the first run the code automatically downloads the models from Hugging Face, so the steps below can be skipped.
You can also download the models manually by cloning them into the corresponding code directory:

```sh
mkdir vikp && cd vikp

git lfs install

git clone https://huggingface.co/vikp/surya_det3
# mirror download: https://hf-mirror.com/vikp/surya_det3

git clone https://huggingface.co/vikp/surya_rec2
# mirror download: https://hf-mirror.com/vikp/surya_rec2
```

After a manual download the final directory structure looks like this:

```
vikp/surya_det3
vikp/surya_rec2
app.py
Dockerfile
requirements.txt
```
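
If you would rather script the download than use git lfs, a minimal sketch using the `huggingface_hub` package (an extra dependency, not listed in requirements.txt) could look like this:

```python
# Hypothetical alternative to the git clone steps above: download both Surya
# models into vikp/ with huggingface_hub (pip install huggingface_hub).
from huggingface_hub import snapshot_download

for repo_id in ("vikp/surya_det3", "vikp/surya_rec2"):
    # local_dir mirrors the layout expected by the Dockerfile (COPY ./vikp ./vikp)
    snapshot_download(repo_id=repo_id, local_dir=repo_id)
```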

### 4. Run the code

```bash
python app.py
```

The request endpoint is
`http://0.0.0.0:7230/v1/surya_ocr`

### 5. Test

```python
import requests
import base64

IMAGE_PATH = "your/path/to/image.png"
ACCESS_TOKEN = "your_access_token"

with open(IMAGE_PATH, 'rb') as img_file:
    encoded_string = base64.b64encode(img_file.read())
encoded_image = encoded_string.decode('utf-8')
data = {"images": [encoded_image], "sorted": True}
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {ACCESS_TOKEN}"
}
res = requests.post(url="http://0.0.0.0:7230/v1/surya_ocr",
                    headers=headers,
                    json=data)

print(res.text)
```
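
For reference, `app.py` wraps the output as `{"error": "success", "results": [...]}`: with `"sorted": true` each entry in `results` is a single string with the text already grouped into lines, otherwise it is a list of `{"text", "bbox"}` objects. A small sketch that continues the test above and unpacks the response:

```python
# Continues the snippet above: unpack the JSON response returned by app.py.
payload = res.json()
if payload.get("error") == "success":
    for result in payload["results"]:
        if isinstance(result, str):
            # "sorted": True -> text already grouped into lines
            print(result)
        else:
            # "sorted": False -> list of {"text": ..., "bbox": [x1, y1, x2, y2]}
            for item in result:
                print(item["bbox"], item["text"])
else:
    print("OCR request failed:", payload)
```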

## Docker deployment

### Getting the image

**Build the image locally:**
```bash
docker build -t surya_ocr:v0.1 .
```

**Or pull the prebuilt image:**
Todo: not yet published

### docker-compose.yml example
```yaml
version: '3'
services:
  surya-ocr:
    image: surya_ocr:v0.1
    container_name: surya-ocr
    # GPU runtime; if the host has no GPU support installed, simply remove the deploy section
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    ports:
      - 7230:7230
    environment:
      - BATCH_SIZE=32
      - ACCESS_TOKEN=YOUR_ACCESS_TOKEN
      - LANGS='["zh","en"]'
```
**Environment variables:**
```
BATCH_SIZE: set according to available RAM/VRAM; each batch uses roughly 40MB of VRAM. Defaults: 32 on CPU, 64 on MPS, 512 on CUDA.
ACCESS_TOKEN: the access_token for the service
LANGS: list of supported languages, defaults to ["zh","en"]
```
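
These variables are consumed in `app.py` roughly as sketched below (see `Surya.__init__` and the `__main__` block); in particular, `LANGS` is parsed with `json.loads`, so its value must be a valid JSON array of language codes:

```python
# Sketch of how app.py reads the environment variables above.
import json
import os

langs = json.loads(os.getenv("LANGS", '["zh", "en"]'))  # must be valid JSON
batch_size = os.getenv("BATCH_SIZE")  # unset -> surya picks its own default
if batch_size is not None:
    batch_size = int(batch_size)
access_token = os.getenv("ACCESS_TOKEN")  # unset -> requests are not auth-checked
```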

## Integrating with FastGPT

Todo: to be added

python/suryaocr/app.py (new file, 143 lines)
@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import io
import json
import logging
import os
from typing import List, Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile
from pydantic import BaseModel
from surya.model.detection.model import load_model as load_det_model
from surya.model.detection.model import load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
from surya.schema import OCRResult

app = FastAPI()
security = HTTPBearer()
env_bearer_token = None


# Reclaim GPU memory
def torch_gc():
    if torch.cuda.is_available():  # check whether CUDA is available
        torch.cuda.empty_cache()  # empty the CUDA cache
        torch.cuda.ipc_collect()  # collect CUDA IPC memory fragments


class ImageReq(BaseModel):
    images: List[str]
    sorted: Optional[bool] = False


class Singleton(type):

    def __call__(cls, *args, **kwargs):
        if not hasattr(cls, '_instance'):
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance


class Surya(metaclass=Singleton):

    def __init__(self):
        self.langs = json.loads(os.getenv("LANGS", '["zh", "en"]'))
        self.batch_size = os.getenv("BATCH_SIZE")
        if self.batch_size is not None:
            self.batch_size = int(self.batch_size)
        self.det_processor, self.det_model = load_det_processor(
        ), load_det_model()
        self.rec_model, self.rec_processor = load_rec_model(
        ), load_rec_processor()

    def run(self, image: ImageFile.ImageFile) -> List[OCRResult]:
        predictions = run_ocr([image], [self.langs], self.det_model,
                              self.det_processor, self.rec_model,
                              self.rec_processor, self.batch_size)
        return predictions


class Chat(object):

    def __init__(self):
        self.surya = Surya()

    def base64_to_image(base64_string: str) -> ImageFile.ImageFile:
        image_data = base64.b64decode(base64_string)
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        return image

    def sort_text_by_bbox(original_data: List[dict]) -> str:
        # Sort the text blocks by bbox, left to right and top to bottom,
        # and return the sorted text as one string per line.
        lines, line = [], []
        original_data.sort(key=lambda item: item["bbox"][1])
        for item in original_data:
            mid_h = (item["bbox"][1] + item["bbox"][3]) / 2
            if len(line) == 0 or (mid_h >= line[0]["bbox"][1]
                                  and mid_h <= line[0]["bbox"][3]):
                line.append(item)
            else:
                lines.append(line)
                line = [item]
        lines.append(line)
        for line in lines:
            line.sort(key=lambda item: item["bbox"][0])
        # build the per-line strings
        string_result = ""
        for line in lines:
            for item in line:
                string_result += item["text"] + " "
            string_result += "\n"
        return string_result

    def query_ocr(self, image_base64: str,
                  sorted: bool) -> List[OCRResult] | str:
        if image_base64 is None or len(image_base64) == 0:
            return []
        image = Chat.base64_to_image(image_base64)

        ocr_result = self.surya.run(image)
        result = []

        for text_line in ocr_result[0].text_lines:
            result.append({"text": text_line.text, "bbox": text_line.bbox})
        if sorted:
            result = Chat.sort_text_by_bbox(result)

        torch_gc()
        return result


@app.post('/v1/surya_ocr')
async def handle_post_request(
        image_req: ImageReq,
        credentials: HTTPAuthorizationCredentials = Security(security)):
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()
    try:
        results = []
        for image_base64 in image_req.images:
            results.append(chat.query_ocr(image_base64, image_req.sorted))
        return {"error": "success", "results": results}
    except Exception as e:
        logging.error(f"OCR error: {e}")
        return {"error": "OCR failed"}


if __name__ == "__main__":
    env_bearer_token = os.getenv("ACCESS_TOKEN")
    try:
        uvicorn.run(app, host='0.0.0.0', port=7230)
    except Exception as e:
        logging.error(f"Failed to start the API! Error: {e}")
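
As a quick illustration of the line grouping performed by `sort_text_by_bbox` above, the following hypothetical snippet (boxes invented for the example; importing `app` also pulls in surya and torch) shows how detections are merged into lines:

```python
# Hypothetical demo of Chat.sort_text_by_bbox; bbox format is [x_min, y_min, x_max, y_max].
from app import Chat

boxes = [
    {"text": "world", "bbox": [120, 10, 180, 30]},
    {"text": "Hello", "bbox": [10, 12, 100, 32]},
    {"text": "second line", "bbox": [10, 50, 150, 70]},
]
# Boxes are sorted top to bottom, grouped into lines by vertical overlap,
# then each line is sorted left to right and joined with spaces.
print(Chat.sort_text_by_bbox(boxes))
# Prints (with a trailing space after each word):
# Hello world
# second line
```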

python/suryaocr/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
surya-ocr==0.5.0
fastapi==0.104.1
uvicorn==0.17.6