mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-23 13:03:50 +00:00
add reranker (#679)
This commit is contained in:
10
python/reranker/bge-reranker-base/Dockerfile
Normal file
10
python/reranker/bge-reranker-base/Dockerfile
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

# Download the model from https://huggingface.co/BAAI/bge-reranker-base and
# put it in the same directory as this Dockerfile before building.
COPY ./bge-reranker-base ./bge-reranker-base

# With more than one source the destination must be a directory ending in "/".
COPY app.py Dockerfile requirement.txt ./

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN python3 -m pip install --no-cache-dir -r requirement.txt

# app.py serves on port 6006.
EXPOSE 6006

# Exec form: python3 runs as PID 1 and receives SIGTERM from `docker stop`
# directly instead of being hidden behind `/bin/sh -c`.
ENTRYPOINT ["python3", "app.py"]
|
48
python/reranker/bge-reranker-base/README.md
Normal file
48
python/reranker/bge-reranker-base/README.md
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
|
||||||
|
## 推荐配置
|
||||||
|
|
||||||
|
推荐配置如下:
|
||||||
|
|
||||||
|
{{< table "table-hover table-striped-columns" >}}
|
||||||
|
| 类型 | 内存 | 显存 | 硬盘空间 | 启动命令 |
|
||||||
|
|------|---------|---------|----------|--------------------------|
|
||||||
|
| base | >=4GB | >=3GB | >=8GB | python app.py |
|
||||||
|
{{< /table >}}
|
||||||
|
|
||||||
|
## 部署
|
||||||
|
|
||||||
|
### 环境要求
|
||||||
|
|
||||||
|
- Python 3.10.11
|
||||||
|
- CUDA 11.7
|
||||||
|
- 科学上网环境
|
||||||
|
|
||||||
|
### 源码部署
|
||||||
|
|
||||||
|
1. 根据上面的环境配置配置好环境,具体教程自行 GPT;
|
||||||
|
2. 下载 [python 文件](app.py)
|
||||||
|
3. 在命令行输入命令 `pip install -r requirement.txt`;
|
||||||
|
4. 按照[https://huggingface.co/BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)下载模型仓库到app.py同级目录
|
||||||
|
5. 添加环境变量 `export ACCESS_TOKEN=XXXXXX` 配置 token,这里的 token 只是加一层验证,防止接口被人盗用,默认值为 `ACCESS_TOKEN` ;
|
||||||
|
6. 执行命令 `python app.py`。
|
||||||
|
|
||||||
|
然后等待模型下载,直到模型加载完毕为止。如果出现报错先问 GPT。
|
||||||
|
|
||||||
|
启动成功后应该会显示如下地址:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
> 这里的 `http://0.0.0.0:6006` 就是连接地址。
|
||||||
|
|
||||||
|
### docker 部署
|
||||||
|
|
||||||
|
**镜像和端口**
|
||||||
|
|
||||||
|
+ 镜像名: `luanshaotong/reranker:v0.1`
|
||||||
|
+ 端口号: 6006
|
||||||
|
|
||||||
|
```
|
||||||
|
# 设置安全凭证(即oneapi中的渠道密钥)
|
||||||
|
通过环境变量ACCESS_TOKEN引入,默认值:ACCESS_TOKEN。
|
||||||
|
有关docker环境变量引入的方法请自寻教程,此处不再赘述。
|
||||||
|
```
|
107
python/reranker/bge-reranker-base/app.py
Normal file
107
python/reranker/bge-reranker-base/app.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
@Time: 2023/11/7 22:45
|
||||||
|
@Author: zhidong
|
||||||
|
@File: app.py
|
||||||
|
@Desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
import uvicorn
|
||||||
|
import datetime
|
||||||
|
from fastapi import FastAPI, Security, HTTPException
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
from FlagEmbedding import FlagReranker
|
||||||
|
from pydantic import Field, BaseModel, validator
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
def response(code, msg, data=None):
    """Build the dict envelope returned by the API.

    Args:
        code: Value stored under ``"code"``.
        msg: Value stored under ``"message"``.
        data: Value stored under ``"data"``; ``None`` is replaced by an
            empty list.

    Returns:
        A dict with the keys ``code``, ``message``, ``data`` and ``time``,
        where ``time`` is ``str(datetime.datetime.now())``.
    """
    payload = [] if data is None else data
    return {
        "code": code,
        "message": msg,
        "data": payload,
        "time": str(datetime.datetime.now()),
    }
|
||||||
|
|
||||||
|
def success(data=None, msg=''):
    """Build a success envelope (code 200) via ``response``.

    The previous body was a bare ``return`` that ignored both arguments and
    always produced ``None``.

    Args:
        data: Payload to place under ``"data"``; ``None`` becomes ``[]``.
        msg: Message to place under ``"message"``.

    Returns:
        The dict produced by ``response(200, msg, data)``.
    """
    return response(200, msg, data)
|
||||||
|
|
||||||
|
|
||||||
|
class Inputs(BaseModel):
    # One candidate document submitted for reranking.

    # Identifier supplied by the caller; it is copied into the result entry
    # that carries this document's score.
    id: str
    # Document text that is paired with the query for scoring (Optional).
    text: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class QADocs(BaseModel):
    # Request body of the rerank endpoint: one query plus candidate documents.

    # Query string every candidate document is scored against (Optional).
    query: Optional[str]
    # Candidate documents to score (Optional).
    inputs: Optional[List[Inputs]]
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton(type):
    """Metaclass that creates at most one instance per class.

    The first call to a class using this metaclass constructs the instance
    with the given arguments; every later call returns that same object and
    ignores its arguments.
    """

    def __call__(cls, *args, **kwargs):
        """Return the class's single instance, creating it on first use."""
        # Check only the class's own namespace.  ``hasattr`` would also see an
        # ``_instance`` inherited from a base class and hand a subclass the
        # parent's instance instead of building its own.
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
|
||||||
|
|
||||||
|
|
||||||
|
# Default model location: a "bge-reranker-base" directory next to this file.
RERANK_MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "bge-reranker-base",
)
|
||||||
|
|
||||||
|
class Reranker(metaclass=Singleton):
    """Singleton wrapper around a ``FlagReranker`` model."""

    def __init__(self, model_path):
        """Load the reranker from ``model_path`` with fp16 disabled."""
        self.reranker = FlagReranker(model_path, use_fp16=False)

    def compute_score(self, pairs: List[List[str]]):
        """Score the given pairs with the wrapped model.

        Args:
            pairs: List of two-element lists passed to the model's
                ``compute_score``.

        Returns:
            The model's scores, always as a list (a bare ``float`` result is
            wrapped), or ``None`` when ``pairs`` is empty.
        """
        if len(pairs) == 0:
            return None
        scores = self.reranker.compute_score(pairs)
        # The model may hand back a bare float; normalise it to a list.
        return [scores] if isinstance(scores, float) else scores
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(object):
    """Reranks candidate documents against a query using ``Reranker``."""

    def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
        """Obtain the ``Reranker`` for ``rerank_model_path``."""
        self.reranker = Reranker(rerank_model_path)

    def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
        """Score every input document against the query and sort by score.

        Args:
            query_docs: Request holding ``query`` and the ``inputs`` to rank.

        Returns:
            A list of ``{"id": ..., "score": ...}`` dicts sorted by descending
            score, where each score is the logistic sigmoid of the model's
            raw score. An empty list is returned when there is nothing to
            rank.
        """
        # ``inputs`` is Optional, so it may be None as well as empty; calling
        # ``len`` on None would raise TypeError.
        if query_docs is None or not query_docs.inputs:
            return []

        docs = query_docs.inputs
        pairs = [[query_docs.query, doc.text] for doc in docs]
        scores = self.reranker.compute_score(pairs)

        ranked = [
            # float() turns the numpy scalar into a plain Python float.
            {"id": doc.id, "score": float(1 / (1 + np.exp(-score)))}
            for doc, score in zip(docs, scores)
        ]
        ranked.sort(key=lambda item: item["score"], reverse=True)
        return ranked
|
||||||
|
|
||||||
|
app = FastAPI()
security = HTTPBearer()
# Bearer token clients must present. It is read from the ACCESS_TOKEN
# environment variable at import time so it is honoured however the app is
# launched (not only via ``python app.py``); the default is the literal
# string 'ACCESS_TOKEN'.
env_bearer_token = os.getenv("ACCESS_TOKEN", 'ACCESS_TOKEN')
|
||||||
|
|
||||||
|
@app.post('/api/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Handle ``POST /api/v1/rerank``.

    Args:
        docs: Request body with the query and the documents to rerank.
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        The ``response`` envelope with code 200 and the reranked documents
        under ``data``.

    Raises:
        HTTPException: 401 when the bearer token does not match
            ``env_bearer_token``.
    """
    import secrets

    token = credentials.credentials
    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed token was correct. Encoding to bytes lets compare_digest
    # accept non-ASCII input as well.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()
    qa_docs_with_rerank = chat.fit_query_answer_rerank(docs)
    return response(200, msg="重排成功", data=qa_docs_with_rerank)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Let the ACCESS_TOKEN environment variable override the default token.
    token = os.getenv("ACCESS_TOKEN")
    if token is not None:
        env_bearer_token = token
    try:
        uvicorn.run(app, host='0.0.0.0', port=6006)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")
        # Exit non-zero so the shell or container runtime can tell that the
        # service failed instead of seeing a clean exit.
        raise SystemExit(1) from e
|
6
python/reranker/bge-reranker-base/requirement.txt
Normal file
6
python/reranker/bge-reranker-base/requirement.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
fastapi==0.104.1
|
||||||
|
FlagEmbedding==1.1.5
|
||||||
|
pydantic==1.10.13
|
||||||
|
uvicorn==0.17.6
|
||||||
|
itsdangerous
|
||||||
|
protobuf
|
Reference in New Issue
Block a user