mirror of https://github.com/labring/FastGPT.git (synced 2025-07-29 09:44:47 +00:00)
python/pdf-marker/api_mp.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import asyncio
import base64
import multiprocessing
import os
import shutil
import time
from concurrent.futures import ProcessPoolExecutor
from contextlib import asynccontextmanager

import fitz
import torch
import torch.multiprocessing as mp
from fastapi import FastAPI, File, HTTPException, UploadFile
from loguru import logger
from marker.convert import convert_single_pdf
from marker.models import load_all_models
from marker.output import save_markdown

app = FastAPI()
model_lst = None
model_refs = None
temp_dir = "./temp"
os.environ['PROCESSES_PER_GPU'] = str(2)

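# Each pool worker loads its own copy of the marker models once, at pool
# start-up. A shared counter hands out sequential worker IDs so workers are
# spread evenly across GPUs: worker_id // PROCESSES_PER_GPU picks the device.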
def worker_init(counter, lock):
    global model_lst

    num_gpus = torch.cuda.device_count()
    processes_per_gpu = int(os.environ.get('PROCESSES_PER_GPU', 1))

    with lock:
        worker_id = counter.value
        counter.value += 1

    if num_gpus == 0:
        device = 'cpu'
    else:
        device_id = worker_id // processes_per_gpu
        if device_id >= num_gpus:
            raise ValueError(f"Worker ID {worker_id} exceeds available GPUs ({num_gpus}).")
        device = f'cuda:{device_id}'

    model_lst = load_all_models(device=device, dtype=torch.float32)
    print(f"Worker {worker_id}: Models loaded successfully on {device}!")

    for model in model_lst:
        if model is None:
            continue
        model.share_memory()


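# Runs inside a pool worker: convert one PDF to markdown with the models loaded
# by worker_init, save the result under ./result, and inline the extracted
# images as base64 data URIs so the response is self-contained.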
def process_file_with_multiprocessing(temp_file_path):
    global model_lst
    full_text, images, out_meta = convert_single_pdf(temp_file_path, model_lst, batch_multiplier=1)
    fname = os.path.basename(temp_file_path)
    subfolder_path = save_markdown('./result', fname, full_text, images, out_meta)
    md_content_with_base64_images = embed_images_as_base64(full_text, subfolder_path)
    return md_content_with_base64_images, out_meta


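# Application lifespan: build the worker pool before serving requests, then
# clean up temp files and model references on shutdown.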
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        raise RuntimeError("The multiprocessing start method was already set. This may be a transient issue; try running the script again.")

    manager = multiprocessing.Manager()
    worker_counter = manager.Value('i', 0)
    worker_lock = manager.Lock()

    global my_pool
    gpu_count = torch.cuda.device_count()
    # max(1, gpu_count) keeps the pool usable on CPU-only machines, where
    # worker_init falls back to the 'cpu' device.
    my_pool = ProcessPoolExecutor(
        max_workers=max(1, gpu_count) * int(os.environ.get('PROCESSES_PER_GPU', 1)),
        initializer=worker_init,
        initargs=(worker_counter, worker_lock),
    )

    yield

    # Declare the module-level names before deleting them; without the global
    # statement, `del model_lst` would raise at shutdown.
    global temp_dir, model_lst, model_refs
    if temp_dir and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    del model_lst
    del model_refs
    print("Application shutdown, cleaning up...")


app.router.lifespan_context = lifespan

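# POST a PDF and get back its markdown (with images inlined), page count, and
# processing time. Conversion runs in the worker pool so the event loop stays
# free while the CPU/GPU-bound marker pipeline works.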
@app.post("/v1/parse/file")
|
||||
async def read_file(
|
||||
file: UploadFile = File(...)):
|
||||
try:
|
||||
start_time = time.time()
|
||||
global temp_dir
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
temp_file_path = os.path.join(temp_dir, file.filename)
|
||||
with open(temp_file_path, "wb") as temp_file:
|
||||
temp_file.write(await file.read())
|
||||
pdf_document = fitz.open(temp_file_path)
|
||||
total_pages = pdf_document.page_count
|
||||
pdf_document.close()
|
||||
global my_pool
|
||||
loop = asyncio.get_event_loop()
|
||||
md_content_with_base64_images, out_meta = await loop.run_in_executor(my_pool, process_file_with_multiprocessing, temp_file_path)
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
print(file.filename+"Total time:", duration)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "",
|
||||
"data": {
|
||||
"markdown": md_content_with_base64_images,
|
||||
"page": total_pages,
|
||||
"duration": duration
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=f"错误信息: {str(e)}")
|
||||
|
||||
finally:
|
||||
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
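# Read an image file and return its contents as a base64 string.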
def img_to_base64(img_path):
    with open(img_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode('utf-8')


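# Rewrite markdown image references (![alt](path)) so the image bytes are
# embedded inline as data URIs, keeping the returned markdown self-contained.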
def embed_images_as_base64(md_content, image_dir):
    lines = md_content.split('\n')
    new_lines = []
    for line in lines:
        if line.startswith("![") and "](" in line:
            start_idx = line.index("](") + 2
            end_idx = line.index(")", start_idx)
            img_rel_path = line[start_idx:end_idx]

            img_name = os.path.basename(img_rel_path)
            img_path = os.path.join(image_dir, img_name)

            if os.path.exists(img_path):
                img_base64 = img_to_base64(img_path)
                new_line = f'{line[:start_idx]}data:image/png;base64,{img_base64}{line[end_idx:]}'
                new_lines.append(new_line)
            else:
                new_lines.append(line)
        else:
            new_lines.append(line)
    return '\n'.join(new_lines)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7231)
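# Example request (a minimal sketch; assumes the server is running locally and
# that "sample.pdf" is a PDF in the current directory):
#
#   curl -X POST http://localhost:7231/v1/parse/file -F "file=@sample.pdf"
#
# On success the response JSON carries the converted markdown, the page count,
# and the wall-clock duration under the "data" key.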