mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-23 13:03:50 +00:00

* save toast * perf: surya ocr * perf: remove same model name * fix: indexes * perf: ip check * feat: Fixed the version number of the subapplication * feat: simple app get latest child version * perf: update child dispatch variables * feat: variables update doc
158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
import base64
|
||
import io
|
||
import json
|
||
import logging
|
||
import os
|
||
from typing import List, Optional
|
||
|
||
import torch
|
||
import uvicorn
|
||
from fastapi import FastAPI, HTTPException, Security
|
||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||
from PIL import Image, ImageFile
|
||
from pydantic import BaseModel
|
||
from surya.model.detection.model import load_model as load_det_model
|
||
from surya.model.detection.model import load_processor as load_det_processor
|
||
from surya.model.recognition.model import load_model as load_rec_model
|
||
from surya.model.recognition.processor import load_processor as load_rec_processor
|
||
from surya.ocr import run_ocr
|
||
from surya.schema import OCRResult
|
||
import warnings
|
||
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

app = FastAPI()

security = HTTPBearer()

# Bearer token expected on /v1/ocr/text; None (ACCESS_TOKEN unset) disables the
# check. Read at import time so authentication is also enforced when the app is
# served via `uvicorn <module>:app`, where the `__main__` block never runs.
env_bearer_token = os.getenv("ACCESS_TOKEN")
|
||
|
||
|
||
# Reclaim GPU memory.
def torch_gc():
    """Return cached PyTorch CUDA memory to the driver.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    # Release cached-but-unused allocator blocks.
    torch.cuda.empty_cache()
    # Collect GPU memory released by CUDA IPC.
    torch.cuda.ipc_collect()
|
||
|
||
|
||
class ImageReq(BaseModel):
    """Request body for ``POST /v1/ocr/text``."""

    # Base64-encoded image payloads; one OCR text is returned per entry.
    images: List[str]
    # Request reading-order sorting of the recognized lines.
    sorted: Optional[bool] = False
|
||
|
||
|
||
class Singleton(type):
    """Metaclass that gives each class using it at most one instance.

    The first ``Cls(...)`` call constructs the instance. Every later call
    returns that same object and ignores its arguments.
    """

    def __call__(cls, *args, **kwargs):
        # Look only in this class's own namespace. ``hasattr`` would also find an
        # ``_instance`` inherited from a parent singleton and return the
        # parent's object instead of creating one for the subclass.
        if "_instance" not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
|
||
|
||
|
||
class Surya(metaclass=Singleton):
    """Holds the surya detection and recognition models.

    Because of the ``Singleton`` metaclass, the models are loaded only on the
    first construction. Later ``Surya()`` calls return the same object.
    """

    def __init__(self):
        # Recognition languages, read as a JSON list from LANGS.
        self.langs = json.loads(os.getenv("LANGS", '["zh", "en"]'))
        # Optional integer from BATCH_SIZE; None is forwarded to run_ocr as-is.
        raw_batch_size = os.getenv("BATCH_SIZE")
        self.batch_size = None if raw_batch_size is None else int(raw_batch_size)
        # Load order matches the original: detection processor, detection
        # model, recognition model, recognition processor.
        self.det_processor = load_det_processor()
        self.det_model = load_det_model()
        self.rec_model = load_rec_model()
        self.rec_processor = load_rec_processor()

    def run(self, image: ImageFile.ImageFile) -> List[OCRResult]:
        """OCR a single image; returns surya's result list (one entry per image)."""
        return run_ocr(
            [image],
            [self.langs],
            self.det_model,
            self.det_processor,
            self.rec_model,
            self.rec_processor,
            self.batch_size,
        )
|
||
|
||
|
||
class Chat(object):
    """Turns base64-encoded images into OCR text.

    The models live in the shared ``Surya`` object, whose metaclass returns the
    same instance on every construction. Creating a ``Chat`` per request
    therefore does not reload them.
    """

    def __init__(self):
        self.surya = Surya()

    @staticmethod
    def base64_to_image(base64_string: str) -> ImageFile.ImageFile:
        """Decode a base64 string into a PIL image.

        ``@staticmethod`` lets it be called both as ``Chat.base64_to_image(s)``
        and through an instance. Raises if the string is not valid base64 or
        the bytes are not an image PIL can identify.
        """
        # The BytesIO must stay referenced: Image.open reads pixel data lazily.
        return Image.open(io.BytesIO(base64.b64decode(base64_string)))

    @staticmethod
    def _group_into_rows(items, bbox_of):
        """Group items into visual rows in reading order (private helper).

        ``bbox_of(item)`` returns a box indexed as ``[x1, y1, x2, y2]``. Index 0
        is the left edge; indices 1 and 3 bound the vertical extent. Items
        are visited top-to-bottom by ``y1``. An item joins the current row when
        its vertical midpoint lies within the y-range of that row's first
        item; otherwise it starts a new row. Each row is then ordered
        left-to-right by ``x1``. The input is not modified.

        Returns a list of rows, each a list of the original items.
        """
        rows = []
        for item in sorted(items, key=lambda it: bbox_of(it)[1]):
            box = bbox_of(item)
            mid_y = (box[1] + box[3]) / 2
            if rows:
                anchor = bbox_of(rows[-1][0])
                if anchor[1] <= mid_y <= anchor[3]:
                    rows[-1].append(item)
                    continue
            rows.append([item])
        for row in rows:
            row.sort(key=lambda it: bbox_of(it)[0])
        return rows

    @staticmethod
    def sort_text_by_bbox(original_data: List[dict]) -> str:
        """Render ``{"text": ..., "bbox": [x1, y1, x2, y2]}`` dicts in reading order.

        Each visual row becomes one output line: every item's text followed by
        a space, then a newline. Returns an empty string for empty input. The
        input list is left unmodified.
        """
        rows = Chat._group_into_rows(original_data, lambda item: item["bbox"])
        # join instead of repeated += to avoid quadratic string building.
        return "".join(
            "".join(item["text"] + " " for item in row) + "\n" for row in rows
        )

    def query_ocr(self, image_base64: str,
                  sorted: bool) -> str:
        """OCR one base64-encoded image and return its text lines joined by newlines.

        Args:
            image_base64: base64 image payload; ``None`` or empty yields ``""``.
            sorted: when true, order recognized lines top-to-bottom and then
                left-to-right by their bounding boxes. The result then passes
                through ``sort_text_lines``. When false, surya's order is kept.

        Raises:
            HTTPException: status 400 wrapping any decoding or OCR failure.
        """
        if not image_base64:
            return ""
        try:
            image = Chat.base64_to_image(image_base64)
            ocr_result = self.surya.run(image)
            text_lines = ocr_result[0].text_lines
            if sorted:
                # Surya text lines carry a ``bbox``, so reading order can be
                # derived from geometry. Assumes ``bbox`` is [x1, y1, x2, y2];
                # verify against the installed surya version.
                rows = Chat._group_into_rows(text_lines, lambda line: line.bbox)
                texts = self.sort_text_lines(
                    [line.text for row in rows for line in row])
            else:
                texts = [line.text for line in text_lines]
            # Join all text lines into one string, separated by newlines.
            return "\n".join(texts)
        except Exception as e:
            logging.error(f"OCR 处理失败: {e}")
            raise HTTPException(status_code=400, detail=f"OCR 处理失败: {str(e)}") from e
        finally:
            # Free cached GPU memory after failures too, not only on success.
            torch_gc()

    @staticmethod
    def sort_text_lines(text_lines: List[str]) -> List[str]:
        """Hook for custom ordering of recognized text lines.

        ``query_ocr`` calls it when sorting is requested, after the lines have
        been put in reading order by bounding box. The default returns the list
        unchanged.
        """
        return text_lines
|
||
|
||
@app.post('/v1/ocr/text')
async def handle_post_request(
        image_req: ImageReq,
        credentials: HTTPAuthorizationCredentials = Security(security)):
    """OCR every base64 image in the request body.

    Returns ``{"error": None, "results": [...]}`` with one text string per
    image, in request order.

    Raises:
        HTTPException: 401 when a token is configured and does not match.
            400 is propagated from ``Chat.query_ocr``. 500 covers any other
            failure.
    """
    import secrets  # local import: constant-time token comparison

    token = credentials.credentials
    # compare_digest avoids leaking through response timing how many leading
    # characters matched. Comparing bytes keeps non-ASCII tokens from raising
    # TypeError.
    if env_bearer_token is not None and not secrets.compare_digest(
            token.encode("utf-8"), env_bearer_token.encode("utf-8")):
        raise HTTPException(status_code=401, detail="无效的令牌")
    chat = Chat()
    try:
        results = [chat.query_ocr(image_base64, image_req.sorted)
                   for image_base64 in image_req.images]
        return {"error": None, "results": results}
    except HTTPException:
        # Already carries the intended status code; re-raise untouched.
        raise
    except Exception as e:
        logging.error(f"识别报错:{e}")
        raise HTTPException(status_code=500, detail=f"识别出错: {str(e)}") from e
|
||
|
||
if __name__ == "__main__":
    # Bearer token required by /v1/ocr/text; leaving ACCESS_TOKEN unset
    # disables the check.
    env_bearer_token = os.getenv("ACCESS_TOKEN")
    try:
        uvicorn.run(app, host='0.0.0.0', port=7230)
    except Exception as e:
        # Keep the traceback and exit non-zero, so process supervisors
        # (e.g. container restart policies) see a failed start rather than
        # a clean exit.
        logging.exception(f"API启动失败!报错:{e}")
        raise SystemExit(1) from e
|