submit ocr module (#2815)

This commit is contained in:
yiming-alicloud
2024-09-27 16:07:28 +08:00
committed by shilin66
parent 0e6877b0a1
commit 850382af7d
4 changed files with 283 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime
# The model weights are copied into the image at build time. Download them
# beforehand from https://huggingface.co/vikp/surya_det3 and
# https://huggingface.co/vikp/surya_rec2, and put them in the directory vikp/
# of the build context.
COPY ./vikp ./vikp
COPY requirements.txt .
RUN python3 -m pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Replace the regular OpenCV wheel with the headless build
# (presumably because the runtime image ships no GUI libraries - confirm).
RUN python3 -m pip uninstall opencv-python -y
RUN python3 -m pip install opencv-python-headless -i https://pypi.tuna.tsinghua.edu.cn/simple
COPY app.py Dockerfile ./
# Exec form: python3 runs as PID 1 and receives SIGTERM from `docker stop`
# directly, instead of being hidden behind an intermediate `/bin/sh -c`.
ENTRYPOINT ["python3", "app.py"]

120
python/suryaocr/README.md Normal file
View File

@@ -0,0 +1,120 @@
# 接入Surya OCR文本检测
## 源码部署
### 1. 安装环境
- Python 3.9+
- CUDA 11.8
- 科学上网环境
### 2. 安装依赖
```bash
pip install -r requirements.txt
```
### 3. 下载模型
代码首次运行时会自动从 Hugging Face 下载模型,可跳过以下步骤。
也可以手动下载模型:在对应代码目录下 clone 模型:
```sh
mkdir vikp && cd vikp
git lfs install
git clone https://huggingface.co/vikp/surya_det3
# 镜像下载 https://hf-mirror.com/vikp/surya_det3
git clone https://huggingface.co/vikp/surya_rec2
# 镜像下载 https://hf-mirror.com/vikp/surya_rec2
```
最终手动下载的目录结构如下:
```
vikp/surya_det3
vikp/surya_rec2
app.py
Dockerfile
requirements.txt
```
### 4. 运行代码
```bash
python app.py
```
对应请求地址为
`http://0.0.0.0:7230/v1/surya_ocr`
### 5. 测试
```python
import requests
import base64
IMAGE_PATH = "your/path/to/image.png"
ACCESS_TOKEN = "your_access_token"
with open(IMAGE_PATH, 'rb') as img_file:
    encoded_string = base64.b64encode(img_file.read())
    encoded_image = encoded_string.decode('utf-8')
data = {"images": [encoded_image], "sorted": True}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {ACCESS_TOKEN}"
}
res = requests.post(url="http://0.0.0.0:7230/v1/surya_ocr",
                    headers=headers,
                    json=data)
print(res.text)
```
## docker部署
### 镜像获取
**本地编译镜像:**
```bash
docker build -t surya_ocr:v0.1 .
```
**或拉取线上镜像:**
Todo:待发布
### docker-compose.yml示例
```yaml
version: '3'
services:
  surya-ocr:
    image: surya_ocr:v0.1
    container_name: surya-ocr
    # GPU运行环境:如果宿主机未安装,将deploy配置隐藏(注释掉)即可
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    ports:
      - 7230:7230
    environment:
      - BATCH_SIZE=32
      - ACCESS_TOKEN=YOUR_ACCESS_TOKEN
      - LANGS='["zh","en"]'
```
**环境变量:**
```
BATCH_SIZE:根据实际内存/显存情况配置,每个batch约占用40MB的VRAM;cpu默认32,mps默认64,cuda默认512
ACCESS_TOKEN:服务的access_token
LANGS:支持的语言列表,默认["zh","en"]
```
## 接入FastGPT
Todo: 待补充

143
python/suryaocr/app.py Normal file
View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import io
import json
import logging
import os
from typing import List, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile
from pydantic import BaseModel
from surya.model.detection.model import load_model as load_det_model
from surya.model.detection.model import load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
from surya.schema import OCRResult
# FastAPI application object that the OCR route is registered on.
app = FastAPI()
# Extracts the "Authorization: Bearer <token>" credentials from each request.
security = HTTPBearer()
# Expected bearer token, or None when ACCESS_TOKEN is unset (no token check).
# Read at import time so the check is also active when the module is served
# through an external ASGI runner (e.g. `uvicorn app:app`) and the
# `__main__` block never executes.
env_bearer_token = os.getenv("ACCESS_TOKEN")
# Reclaim GPU memory.
def torch_gc():
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    # Empty the CUDA caching allocator.
    torch.cuda.empty_cache()
    # Collect CUDA IPC memory that is no longer referenced.
    torch.cuda.ipc_collect()
class ImageReq(BaseModel):
    """JSON request body accepted by the OCR endpoint."""

    # Base64-encoded images; each entry is decoded and OCR'd separately.
    images: List[str]
    # When truthy, each image's text is returned as one string arranged by
    # bounding box instead of a list of {"text", "bbox"} entries.
    sorted: Optional[bool] = False
class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call to the class constructs the instance (forwarding the
    arguments to ``__init__``); every later call returns that same object and
    ignores its arguments.
    """

    def __call__(cls, *args, **kwargs):
        # Look only at the class's own namespace. ``hasattr`` would also see
        # an ``_instance`` inherited from a parent class, which made a
        # subclass hand back its parent's instance instead of its own.
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
class Surya(metaclass=Singleton):
    """Process-wide holder of the Surya detection and recognition models.

    Configuration is read from the environment at construction time:
    ``LANGS`` is a JSON list of language codes (default ``["zh", "en"]``) and
    ``BATCH_SIZE`` is an optional integer passed through to ``run_ocr``.
    """

    def __init__(self):
        self.langs = json.loads(os.getenv("LANGS", '["zh", "en"]'))

        # Stays None when BATCH_SIZE is unset; otherwise parsed as an int.
        raw_batch_size = os.getenv("BATCH_SIZE")
        self.batch_size = (None if raw_batch_size is None
                           else int(raw_batch_size))

        # Text-detection processor and model.
        self.det_processor = load_det_processor()
        self.det_model = load_det_model()
        # Text-recognition model and processor.
        self.rec_model = load_rec_model()
        self.rec_processor = load_rec_processor()

    def run(self, image: ImageFile.ImageFile) -> List[OCRResult]:
        """Run OCR on one image and return the ``run_ocr`` predictions."""
        return run_ocr([image], [self.langs], self.det_model,
                       self.det_processor, self.rec_model,
                       self.rec_processor, self.batch_size)
class Chat(object):
    """Decodes base64 images, runs them through Surya and formats the text."""

    def __init__(self):
        self.surya = Surya()

    @staticmethod
    def base64_to_image(base64_string: str) -> "ImageFile.ImageFile":
        """Decode a base64 string and open the bytes as a PIL image."""
        image_data = base64.b64decode(base64_string)
        return Image.open(io.BytesIO(image_data))

    @staticmethod
    def sort_text_by_bbox(original_data: List[dict]) -> str:
        """Arrange OCR entries top-to-bottom, left-to-right as text lines.

        Args:
            original_data: Entries with a ``"text"`` string and a ``"bbox"``
                sequence (assumes bbox is [x_min, y_min, x_max, y_max] -
                TODO confirm against surya). The list is not modified.

        Returns:
            One line per detected row; every entry's text is followed by a
            space and every row by a newline. Empty input yields ``"\n"``.
        """
        rows, row = [], []
        # Walk entries by their top edge. An entry joins the current row when
        # its vertical midpoint lies inside the vertical extent of the row's
        # first entry; otherwise it starts a new row.
        for item in sorted(original_data, key=lambda entry: entry["bbox"][1]):
            mid_h = (item["bbox"][1] + item["bbox"][3]) / 2
            if not row or row[0]["bbox"][1] <= mid_h <= row[0]["bbox"][3]:
                row.append(item)
            else:
                rows.append(row)
                row = [item]
        rows.append(row)

        # Within each row, order entries by their left edge and build the text.
        text_lines = []
        for row in rows:
            ordered = sorted(row, key=lambda entry: entry["bbox"][0])
            text_lines.append(
                "".join(item["text"] + " " for item in ordered) + "\n")
        return "".join(text_lines)

    def query_ocr(self, image_base64: str,
                  sorted: bool) -> Union[List[dict], str]:
        """Run OCR on one base64-encoded image.

        Args:
            image_base64: Base64 image data; ``None`` or an empty string
                yields an empty list.
            sorted: When truthy, return the text as a single string arranged
                by ``sort_text_by_bbox`` instead of a list of entries.

        Returns:
            A list of ``{"text": ..., "bbox": ...}`` dicts, or the arranged
            string when ``sorted`` is truthy.
        """
        if not image_base64:
            return []
        image = Chat.base64_to_image(image_base64)
        try:
            ocr_result = self.surya.run(image)
            result = [{"text": text_line.text, "bbox": text_line.bbox}
                      for text_line in ocr_result[0].text_lines]
            if sorted:
                return Chat.sort_text_by_bbox(result)
            return result
        finally:
            # Release cached GPU memory even when recognition fails.
            torch_gc()
@app.post('/v1/surya_ocr')
async def handle_post_request(
        image_req: ImageReq,
        credentials: HTTPAuthorizationCredentials = Security(security)):
    """Run OCR on every image of the request and return the results.

    Args:
        image_req: Request body with the base64 images and the ``sorted`` flag.
        credentials: Bearer credentials extracted from the request.

    Returns:
        ``{"error": "success", "results": [...]}`` with one result per image,
        or a dict carrying only an ``"error"`` message when recognition fails.

    Raises:
        HTTPException: 401 when a token is configured and the supplied bearer
            token does not match it.
    """
    import hmac  # local import keeps this block self-contained

    token = credentials.credentials
    # Constant-time comparison so response timing does not leak how much of a
    # guessed token is correct. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if env_bearer_token is not None and not hmac.compare_digest(
            token.encode("utf-8"), env_bearer_token.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()
    try:
        results = [
            chat.query_ocr(image_base64, image_req.sorted)
            for image_base64 in image_req.images
        ]
        return {"error": "success", "results": results}
    except Exception as e:
        # logging.exception records the traceback alongside the message.
        logging.exception(f"识别报错:{e}")
        return {"error": "识别出错"}
if __name__ == "__main__":
    # Script entry point: take the expected bearer token from ACCESS_TOKEN
    # (None when unset), then serve the API on all interfaces, port 7230.
    env_bearer_token = os.getenv("ACCESS_TOKEN")
    try:
        uvicorn.run(app, port=7230, host='0.0.0.0')
    except Exception as err:
        # Log the startup failure instead of letting it propagate.
        logging.error(f"API启动失败报错{err}")

View File

@@ -0,0 +1,3 @@
surya-ocr==0.5.0
fastapi==0.104.1
uvicorn==0.17.6