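"""FastAPI service that converts uploaded PDFs to Markdown with the marker
library. Heavy model inference runs in a pool of worker processes pinned to
GPUs (with a CPU fallback), and images extracted from the PDF are inlined
into the Markdown output as base64 data URIs."""
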
import asyncio
import base64
import multiprocessing
import os
import shutil
import time
from concurrent.futures import ProcessPoolExecutor
from contextlib import asynccontextmanager

import fitz
import torch
import torch.multiprocessing as mp
from fastapi import FastAPI, File, HTTPException, UploadFile
from loguru import logger
from marker.convert import convert_single_pdf
from marker.models import load_all_models
from marker.output import save_markdown

app = FastAPI()

model_lst = None
model_refs = None
temp_dir = "./temp"

# Number of worker processes to schedule on each GPU.
os.environ['PROCESSES_PER_GPU'] = str(2)


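# Each pool worker loads its own copy of the marker models. Workers are mapped
# to GPUs round-robin: device_id = worker_id // PROCESSES_PER_GPU, so with two
# processes per GPU, workers 0-1 land on cuda:0, workers 2-3 on cuda:1, etc.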
def worker_init(counter, lock):
    global model_lst

    num_gpus = torch.cuda.device_count()
    processes_per_gpu = int(os.environ.get('PROCESSES_PER_GPU', 1))

    with lock:
        worker_id = counter.value
        counter.value += 1

    if num_gpus == 0:
        device = 'cpu'
    else:
        device_id = worker_id // processes_per_gpu
        if device_id >= num_gpus:
            raise ValueError(f"Worker ID {worker_id} exceeds available GPUs ({num_gpus}).")
        device = f'cuda:{device_id}'

    model_lst = load_all_models(device=device, dtype=torch.float32)
    print(f"Worker {worker_id}: Models loaded successfully on {device}!")
    for model in model_lst:
        if model is None:
            continue
        model.share_memory()


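# Runs inside a pool worker: model_lst here is the per-worker global populated
# by worker_init, not the (None) value held by the parent process.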
def process_file_with_multiprocessing(temp_file_path):
    global model_lst
    full_text, images, out_meta = convert_single_pdf(temp_file_path, model_lst, batch_multiplier=1)
    fname = os.path.basename(temp_file_path)
    subfolder_path = save_markdown(r'./result', fname, full_text, images, out_meta)
    md_content_with_base64_images = embed_images_as_base64(full_text, subfolder_path)
    return md_content_with_base64_images, out_meta


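# CUDA cannot be re-initialized in a forked child process, so the pool must
# use the 'spawn' start method before any workers are created.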
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        raise RuntimeError("The multiprocessing start method was already set to 'spawn'. This may be a temporary issue with the script; please try running it again.")

    manager = multiprocessing.Manager()
    worker_counter = manager.Value('i', 0)
    worker_lock = manager.Lock()

    global my_pool
    gpu_count = torch.cuda.device_count()
    # max(1, ...) keeps at least one worker on CPU-only machines, where
    # gpu_count is 0 and the product would be an invalid pool size.
    my_pool = ProcessPoolExecutor(
        max_workers=max(1, gpu_count * int(os.environ.get('PROCESSES_PER_GPU', 1))),
        initializer=worker_init,
        initargs=(worker_counter, worker_lock),
    )

    yield

    # Declare the globals before `del`, otherwise Python treats the names as
    # unbound locals and the deletes raise at shutdown.
    global temp_dir, model_lst, model_refs
    if temp_dir and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    del model_lst
    del model_refs
    print("Application shutdown, cleaning up...")


app.router.lifespan_context = lifespan


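# Example request (illustrative; the file name is hypothetical):
#   curl -F "file=@./example.pdf" http://localhost:7231/v1/parse/file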
@app.post("/v1/parse/file")
async def read_file(file: UploadFile = File(...)):
    # Initialized before the try block so the finally clause cannot hit a
    # NameError if an error occurs before the path is assigned.
    temp_file_path = None
    try:
        start_time = time.time()
        global temp_dir
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, file.filename)
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(await file.read())

        # Count pages up front so the response can report them.
        pdf_document = fitz.open(temp_file_path)
        total_pages = pdf_document.page_count
        pdf_document.close()

        # Run the heavy conversion in the process pool so the event loop
        # stays responsive while the models work.
        global my_pool
        loop = asyncio.get_running_loop()
        md_content_with_base64_images, out_meta = await loop.run_in_executor(
            my_pool, process_file_with_multiprocessing, temp_file_path)

        duration = time.time() - start_time
        print(f"{file.filename} total time: {duration}")
        return {
            "success": True,
            "message": "",
            "data": {
                "markdown": md_content_with_base64_images,
                "page": total_pages,
                "duration": duration
            }
        }

    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")

    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)


def img_to_base64(img_path):
    with open(img_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode('utf-8')


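# Example transformation (illustrative file names): a Markdown line
#   ![figure](images/fig1.png)
# becomes
#   ![figure](data:image/png;base64,iVBORw0KGgo...)
# provided fig1.png exists in the per-document result folder.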
def embed_images_as_base64(md_content, image_dir):
    lines = md_content.split('\n')
    new_lines = []

    for line in lines:
        # Match Markdown image references of the form ![alt](path).
        if line.startswith("![") and "](" in line:
            start_idx = line.index("](") + 2
            end_idx = line.index(")", start_idx)
            img_rel_path = line[start_idx:end_idx]

            img_name = os.path.basename(img_rel_path)
            img_path = os.path.join(image_dir, img_name)

            if os.path.exists(img_path):
                img_base64 = img_to_base64(img_path)
                new_line = f'{line[:start_idx]}data:image/png;base64,{img_base64}{line[end_idx:]}'
                new_lines.append(new_line)
            else:
                new_lines.append(line)
        else:
            new_lines.append(line)

    return '\n'.join(new_lines)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7231)