add reranker (#679)

This commit is contained in:
luanshaotong
2024-01-02 18:10:15 +08:00
committed by GitHub
parent 2e75851b02
commit d5b24eca57
4 changed files with 171 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
# please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as Dockerfile
COPY ./bge-reranker-base ./bge-reranker-base
COPY app.py Dockerfile requirement.txt .
RUN python3 -m pip install -r requirement.txt
ENTRYPOINT python3 app.py

View File

@@ -0,0 +1,48 @@
## 推荐配置
推荐配置如下:
{{< table "table-hover table-striped-columns" >}}
| 类型 | 内存 | 显存 | 硬盘空间 | 启动命令 |
|------|---------|---------|----------|--------------------------|
| base | >=4GB | >=3GB | >=8GB | python app.py |
{{< /table >}}
## 部署
### 环境要求
- Python 3.10.11
- CUDA 11.7
- 科学上网环境
### 源码部署
1. 根据上面的环境配置配置好环境,具体教程自行 GPT
2. 下载 [python 文件](app.py)
3. 在命令行输入命令 `pip install -r requirments.txt`
4. 按照[https://huggingface.co/BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)下载模型仓库到app.py同级目录
5. 添加环境变量 `export ACCESS_TOKEN=XXXXXX` 配置 token这里的 token 只是加一层验证,防止接口被人盗用,默认值为 `ACCESS_TOKEN`
6. 执行命令 `python app.py`
然后等待模型下载,直到模型加载完毕为止。如果出现报错先问 GPT。
启动成功后应该会显示如下地址:
![](/imgs/chatglm2.png)
> 这里的 `http://0.0.0.0:6006` 就是连接地址。
### docker 部署
**镜像和端口**
+ 镜像名: `luanshaotong/reranker:v0.1`
+ 端口号: 6006
```
# 设置安全凭证即oneapi中的渠道密钥
通过环境变量ACCESS_TOKEN引入默认值ACCESS_TOKEN。
有关docker环境变量引入的方法请自寻教程此处不再赘述。
```

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time: 2023/11/7 22:45
@Author: zhidong
@File: reranker.py
@Desc:
"""
import os
import numpy as np
import logging
import uvicorn
import datetime
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from FlagEmbedding import FlagReranker
from pydantic import Field, BaseModel, validator
from typing import Optional, List
def response(code, msg, data=None):
time = str(datetime.datetime.now())
if data is None:
data = []
result = {
"code": code,
"message": msg,
"data": data,
"time": time
}
return result
def success(data=None, msg=''):
return
class Inputs(BaseModel):
id: str
text: Optional[str]
class QADocs(BaseModel):
query: Optional[str]
inputs: Optional[List[Inputs]]
class Singleton(type):
def __call__(cls, *args, **kwargs):
if not hasattr(cls, '_instance'):
cls._instance = super().__call__(*args, **kwargs)
return cls._instance
RERANK_MODEL_PATH = os.path.join(os.path.dirname(__file__), "bge-reranker-base")
class Reranker(metaclass=Singleton):
def __init__(self, model_path):
self.reranker = FlagReranker(model_path,
use_fp16=False)
def compute_score(self, pairs: List[List[str]]):
if len(pairs) > 0:
result = self.reranker.compute_score(pairs)
if isinstance(result, float):
result = [result]
return result
else:
return None
class Chat(object):
def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
self.reranker = Reranker(rerank_model_path)
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
if query_docs is None or len(query_docs.inputs) == 0:
return []
new_docs = []
pair = []
for answer in query_docs.inputs:
pair.append([query_docs.query, answer.text])
scores = self.reranker.compute_score(pair)
for index, score in enumerate(scores):
new_docs.append({"id": query_docs.inputs[index].id, "score": 1 / (1 + np.exp(-score))})
new_docs = list(sorted(new_docs, key=lambda x: x["score"], reverse=True))
return new_docs
app = FastAPI()
security = HTTPBearer()
env_bearer_token = 'ACCESS_TOKEN'
@app.post('/api/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
token = credentials.credentials
if env_bearer_token is not None and token != env_bearer_token:
raise HTTPException(status_code=401, detail="Invalid token")
chat = Chat()
qa_docs_with_rerank = chat.fit_query_answer_rerank(docs)
return response(200, msg="重排成功", data=qa_docs_with_rerank)
if __name__ == "__main__":
token = os.getenv("ACCESS_TOKEN")
if token is not None:
env_bearer_token = token
try:
uvicorn.run(app, host='0.0.0.0', port=6006)
except Exception as e:
print(f"API启动失败\n报错:\n{e}")

View File

@@ -0,0 +1,6 @@
fastapi==0.104.1
FlagEmbedding==1.1.5
pydantic==1.10.13
uvicorn==0.17.6
itsdangerous
protobuf