Files
FastGPT/plugins/model/pdf-marker/api_mp.py
2025-02-27 10:30:43 +08:00

142 lines
4.9 KiB
Python

import asyncio
import base64
import fitz
import torch.multiprocessing as mp
import shutil
import time
from contextlib import asynccontextmanager
from loguru import logger
from fastapi import HTTPException, FastAPI, UploadFile, File
import multiprocessing
from marker.output import save_markdown
from marker.convert import convert_single_pdf
from marker.models import load_all_models
import torch
from concurrent.futures import ProcessPoolExecutor
import os
# FastAPI application object; the route and lifespan hook defined in this
# file attach to it.
app = FastAPI()

# Model handles for the current process. ``worker_init`` assigns
# ``model_lst`` inside each pool worker; no assignment to ``model_refs`` is
# visible in this file apart from this placeholder.
model_lst = None
model_refs = None

# Scratch directory for uploaded PDFs.
temp_dir = "./temp"

# Worker processes to start per GPU. Default to 2, but keep a value the
# operator already exported: an unconditional assignment here would make the
# environment variable impossible to configure from outside the script.
os.environ.setdefault('PROCESSES_PER_GPU', str(2))
def worker_init(counter, lock):
    """Initialise one pool worker: claim a worker id and load the models.

    Args:
        counter: Shared integer holder (exposes ``.value``) with the next
            unused worker id.
        lock: Context-manager lock that guards ``counter``.

    Raises:
        ValueError: If the claimed id maps to a GPU index that is not
            available on this machine.
    """
    global model_lst

    num_gpus = torch.cuda.device_count()
    per_gpu = int(os.environ.get('PROCESSES_PER_GPU', 1))

    # Read-and-increment under the lock so every worker gets a distinct id.
    with lock:
        worker_id = counter.value
        counter.value += 1

    # Consecutive ids share a GPU: ids [0, per_gpu) go to cuda:0, the next
    # ``per_gpu`` ids to cuda:1, and so on. Without any GPU, run on the CPU.
    device = 'cpu'
    if num_gpus != 0:
        device_id = worker_id // per_gpu
        if device_id >= num_gpus:
            raise ValueError(f"Worker ID {worker_id} exceeds available GPUs ({num_gpus}).")
        device = f'cuda:{device_id}'

    model_lst = load_all_models(device=device, dtype=torch.float32)
    print(f"Worker {worker_id}: Models loaded successfully on {device}!")

    # Some entries of the model list may be None; skip those.
    for loaded in (m for m in model_lst if m is not None):
        loaded.share_memory()
def process_file_with_multiprocessing(temp_file_path):
    """Convert one PDF to Markdown using this process's loaded models.

    Intended to run inside a pool worker, where ``worker_init`` has already
    populated the module-level ``model_lst``.

    Args:
        temp_file_path: Path of the PDF file to convert.

    Returns:
        A ``(markdown, out_meta)`` tuple: the Markdown text after
        ``embed_images_as_base64`` has processed it, and the metadata object
        returned by ``convert_single_pdf``.
    """
    markdown_text, page_images, out_meta = convert_single_pdf(
        temp_file_path, model_lst, batch_multiplier=1
    )

    # save_markdown's return value is used as the directory in which the
    # referenced images are looked up.
    output_dir = save_markdown(
        r'./result',
        os.path.basename(temp_file_path),
        markdown_text,
        page_images,
        out_meta,
    )

    return embed_images_as_base64(markdown_text, output_dir), out_meta
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the worker pool on startup and release resources on shutdown.

    Startup: selects the ``spawn`` start method, creates a manager-backed
    counter and lock for ``worker_init``, and builds the process pool stored
    in the module-level ``my_pool``.

    Shutdown: stops the pool and the manager, removes the temporary upload
    directory and clears the module-level model references.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        RuntimeError: If the multiprocessing start method was already set.
    """
    # All three names are rebound below. Without this declaration the
    # shutdown code would treat the model names as locals and fail with
    # UnboundLocalError.
    global my_pool, model_lst, model_refs

    try:
        mp.set_start_method('spawn')
    except RuntimeError as err:
        raise RuntimeError("Set start method to spawn twice. This may be a temporary issue with the script. Please try running it again.") from err

    manager = multiprocessing.Manager()
    worker_counter = manager.Value('i', 0)
    worker_lock = manager.Lock()

    processes_per_gpu = int(os.environ.get('PROCESSES_PER_GPU', 1))
    # worker_init falls back to the CPU when no GPU is visible, so size the
    # pool as if there were one device in that case; ProcessPoolExecutor
    # rejects max_workers=0.
    device_slots = max(torch.cuda.device_count(), 1)
    my_pool = ProcessPoolExecutor(
        max_workers=device_slots * processes_per_gpu,
        initializer=worker_init,
        initargs=(worker_counter, worker_lock),
    )

    try:
        yield
    finally:
        # Stop the workers first, then the manager process that backs the
        # shared counter and lock.
        my_pool.shutdown(wait=True)
        manager.shutdown()
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        model_lst = None
        model_refs = None
        print("Application shutdown, cleaning up...")
# Register ``lifespan`` as the router's lifespan context so its startup and
# shutdown code runs together with the application.
app.router.lifespan_context = lifespan
@app.post("/v1/parse/file")
async def read_file(
        file: UploadFile = File(...)):
    """Convert an uploaded PDF to Markdown with images inlined as base64.

    The upload is written to a uniquely named file under ``temp_dir``, its
    page count is read, and the conversion is run in the worker pool. The
    temporary file is removed afterwards whether or not conversion succeeded.

    Args:
        file: The uploaded PDF.

    Returns:
        A dict with ``success``, ``message`` and ``data``; ``data`` holds the
        ``markdown`` text, the ``page`` count and the ``duration`` in seconds.

    Raises:
        HTTPException: Status 500 carrying the error text if any step fails.
    """
    import uuid  # local import, like ``uvicorn`` in the script entry point

    # Bound before ``try`` so the ``finally`` clause can always inspect it.
    temp_file_path = None
    try:
        start_time = time.time()
        os.makedirs(temp_dir, exist_ok=True)

        # The filename comes from the client. basename() drops any directory
        # components so the upload cannot be written outside temp_dir, and a
        # random prefix keeps concurrent uploads that share a name from
        # overwriting or deleting each other's file. It is a prefix so the
        # original extension stays at the end of the name.
        original_name = os.path.basename(file.filename or "") or "upload.pdf"
        temp_file_path = os.path.join(temp_dir, f"{uuid.uuid4().hex}_{original_name}")

        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(await file.read())

        # Context manager closes the document even if reading it fails.
        with fitz.open(temp_file_path) as pdf_document:
            total_pages = pdf_document.page_count

        loop = asyncio.get_running_loop()
        md_content_with_base64_images, _out_meta = await loop.run_in_executor(
            my_pool, process_file_with_multiprocessing, temp_file_path
        )

        duration = time.time() - start_time
        # f-string instead of ``+`` so a missing filename cannot raise here.
        print(f"{file.filename}Total time:", duration)
        return {
            "success": True,
            "message": "",
            "data": {
                "markdown": md_content_with_base64_images,
                "page": total_pages,
                "duration": duration
            }
        }
    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=500, detail=f"错误信息: {str(e)}") from e
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
def img_to_base64(img_path):
    """Return the contents of the file at ``img_path`` as a base64 string.

    Args:
        img_path: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file's bytes, as ``str``.
    """
    with open(img_path, "rb") as img_file:
        raw_bytes = img_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def _embed_image_in_line(line, image_dir):
    """Return ``line`` with its image target inlined, or unchanged if not applicable."""
    if not line.startswith("!["):
        return line

    marker = line.find("](")
    if marker == -1:
        return line
    start_idx = marker + 2

    # Look for the closing parenthesis only after the opening one: a ")"
    # earlier in the line (e.g. in the alt text) does not terminate the
    # target, and an unterminated target must not raise.
    end_idx = line.find(")", start_idx)
    if end_idx == -1:
        return line

    # Only the file name is used; the image is looked up in image_dir.
    img_name = os.path.basename(line[start_idx:end_idx])
    img_path = os.path.join(image_dir, img_name)

    # isfile rather than exists: an empty target resolves to image_dir
    # itself, which exists but cannot be read as an image.
    if not os.path.isfile(img_path):
        return line

    img_base64 = img_to_base64(img_path)
    # The data URI is always labelled image/png, whatever the file extension.
    return f'{line[:start_idx]}data:image/png;base64,{img_base64}{line[end_idx:]}'


def embed_images_as_base64(md_content, image_dir):
    """Replace image links in Markdown with inline base64 data URIs.

    Only lines that start with ``![`` are considered. For such a line the
    first ``](...)`` target is resolved by file name inside ``image_dir``; if
    that file exists, the target is replaced by a
    ``data:image/png;base64,...`` URI. Every other line, including malformed
    image lines and links to missing files, is returned unchanged.

    Args:
        md_content: Markdown text.
        image_dir: Directory containing the referenced image files.

    Returns:
        The Markdown text with resolvable image links inlined.
    """
    return '\n'.join(
        _embed_image_in_line(line, image_dir) for line in md_content.split('\n')
    )
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn. uvicorn is imported
    # here so that importing this module does not require it.
    import uvicorn

    # Listen on every network interface, port 7231.
    server_options = {"host": "0.0.0.0", "port": 7231}
    uvicorn.run(app, **server_options)