mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-29 01:40:51 +00:00
add mistral-ocr support as a plugin like pdf-marker (#4284)
Co-authored-by: zhengshuai.li <zhengshuai.li@cloudpense.com>
This commit is contained in:
230
plugins/model/pdf-mistral/api_mp.py
Executable file
230
plugins/model/pdf-mistral/api_mp.py
Executable file
@@ -0,0 +1,230 @@
|
||||
import time
|
||||
import base64
|
||||
import fitz
|
||||
import re
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException, FastAPI, UploadFile, File
|
||||
from fastapi.responses import JSONResponse
|
||||
from mistralai import Mistral
|
||||
import os
|
||||
import shutil
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
temp_dir = "./temp"
|
||||
|
||||
# Initialize Mistral client with API key from environment variable
|
||||
mistral_api_key = os.environ.get("MISTRAL_API_KEY", "")
|
||||
if not mistral_api_key:
|
||||
logger.warning("MISTRAL_API_KEY environment variable not set. PDF processing will fail.")
|
||||
|
||||
mistral_client = Mistral(api_key=mistral_api_key) if mistral_api_key else None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Create temp directory if it doesn't exist
|
||||
global temp_dir
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
print("Application startup, creating temp directory...")
|
||||
yield
|
||||
if temp_dir and os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
print("Application shutdown, cleaning up...")
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
|
||||
@app.post("/v1/parse/file")
|
||||
async def read_file(
|
||||
file: UploadFile = File(...)):
|
||||
temp_file_path = None
|
||||
try:
|
||||
start_time = time.time()
|
||||
global temp_dir
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
temp_file_path = os.path.join(temp_dir, file.filename)
|
||||
with open(temp_file_path, "wb") as temp_file:
|
||||
file_content = await file.read()
|
||||
temp_file.write(file_content)
|
||||
|
||||
# Get page count using PyMuPDF
|
||||
try:
|
||||
pdf_document = fitz.open(temp_file_path)
|
||||
total_pages = pdf_document.page_count
|
||||
pdf_document.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to open PDF file: {str(e)}")
|
||||
return {
|
||||
"pages": 0,
|
||||
"markdown": "",
|
||||
"error": f"Failed to process PDF file: {str(e)}"
|
||||
}
|
||||
|
||||
if mistral_client is None:
|
||||
return {
|
||||
"pages": 0,
|
||||
"markdown": "",
|
||||
"error": "MISTRAL_API_KEY environment variable not set."
|
||||
}
|
||||
|
||||
# Step 1: Upload the file to Mistral's servers
|
||||
logger.info(f"Uploading file {file.filename} to Mistral servers")
|
||||
with open(temp_file_path, "rb") as f:
|
||||
try:
|
||||
uploaded_file = mistral_client.files.upload(
|
||||
file={
|
||||
"file_name": file.filename,
|
||||
"content": f,
|
||||
},
|
||||
purpose="ocr"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Try to parse Mistral API error format
|
||||
try:
|
||||
error_data = json.loads(error_msg)
|
||||
if error_data.get("object") == "error":
|
||||
error_msg = error_data.get("message", error_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"pages": 0,
|
||||
"markdown": "",
|
||||
"error": f"Mistral API upload error: {error_msg}"
|
||||
}
|
||||
|
||||
# Step 2: Get a signed URL for the uploaded file
|
||||
logger.info(f"Getting signed URL for file ID: {uploaded_file.id}")
|
||||
try:
|
||||
signed_url = mistral_client.files.get_signed_url(file_id=uploaded_file.id)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Try to parse Mistral API error format
|
||||
try:
|
||||
error_data = json.loads(error_msg)
|
||||
if error_data.get("object") == "error":
|
||||
error_msg = error_data.get("message", error_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"pages": 0,
|
||||
"markdown": "",
|
||||
"error": f"Mistral API signed URL error: {error_msg}"
|
||||
}
|
||||
|
||||
# Step 3: Process the file using the signed URL
|
||||
logger.info("Processing file with OCR API")
|
||||
try:
|
||||
ocr_response = mistral_client.ocr.process(
|
||||
model="mistral-ocr-latest",
|
||||
document={
|
||||
"type": "document_url",
|
||||
"document_url": signed_url.url,
|
||||
},
|
||||
include_image_base64=True
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Try to parse Mistral API error format
|
||||
try:
|
||||
error_data = json.loads(error_msg)
|
||||
if error_data.get("object") == "error":
|
||||
error_msg = error_data.get("message", error_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"pages": 0,
|
||||
"markdown": "",
|
||||
"error": f"Mistral OCR processing error: {error_msg}"
|
||||
}
|
||||
|
||||
# Combine all pages' markdown content
|
||||
markdown_content = "\n".join(page.markdown for page in ocr_response.pages)
|
||||
|
||||
# Create a dictionary to map image filenames to their base64 data
|
||||
image_map = {}
|
||||
for page in ocr_response.pages:
|
||||
for img in page.images:
|
||||
# Extract the image filename from the image id
|
||||
img_id = img.id
|
||||
img_base64 = img.image_base64
|
||||
|
||||
# Print a sample of the first image base64 data for debugging
|
||||
if len(image_map) == 0 and img_base64:
|
||||
print("Sample image base64 prefix:", img_base64[:50] if len(img_base64) > 50 else img_base64)
|
||||
print("Does base64 already include prefix?", img_base64.startswith("data:image/"))
|
||||
|
||||
# Ensure the base64 data is in the correct format for the upstream system
|
||||
# If it doesn't already have the prefix, add it
|
||||
if not img_base64.startswith("data:image/"):
|
||||
# Assume it's a PNG if we can't determine the type
|
||||
img_base64 = f"data:image/png;base64,{img_base64}"
|
||||
|
||||
# Add both potential formats to the map
|
||||
image_map[f"{img_id}.jpeg"] = img_base64
|
||||
image_map[f"{img_id}.png"] = img_base64
|
||||
image_map[img_id] = img_base64
|
||||
|
||||
# Use regex to find all image references in the markdown content
|
||||
# This will match patterns like 
|
||||
image_pattern = r'!\[(.*?)\]\((.*?)\)'
|
||||
|
||||
def replace_image_with_base64(match):
|
||||
alt_text = match.group(1)
|
||||
img_filename = match.group(2)
|
||||
|
||||
# Extract just the filename without path
|
||||
img_filename_only = os.path.basename(img_filename)
|
||||
|
||||
# Check if we have base64 data for this image
|
||||
if img_filename_only in image_map:
|
||||
return f""
|
||||
else:
|
||||
# If we don't have base64 data, keep the original reference
|
||||
logger.warning(f"No base64 data found for image: {img_filename_only}")
|
||||
return match.group(0)
|
||||
|
||||
# Replace all image references with base64 data
|
||||
markdown_content = re.sub(image_pattern, replace_image_with_base64, markdown_content)
|
||||
|
||||
# Clean up the uploaded file from Mistral's servers
|
||||
try:
|
||||
logger.info(f"Deleting uploaded file from Mistral servers: {uploaded_file.id}")
|
||||
mistral_client.files.delete(file_id=uploaded_file.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete uploaded file: {e}")
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
print(file.filename + " Total time:", duration)
|
||||
|
||||
# Return with format matching client expectations
|
||||
return {
|
||||
"pages": total_pages,
|
||||
"markdown": markdown_content,
|
||||
"duration": duration # Keep this for logging purposes
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return {
|
||||
"pages": 0,
|
||||
"markdown": "",
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}
|
||||
|
||||
finally:
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7231)
|
Reference in New Issue
Block a user