diff --git a/python/reranker/bge-reranker-base/Dockerfile b/python/reranker/bge-reranker-base/Dockerfile new file mode 100644 index 000000000..3ba1822de --- /dev/null +++ b/python/reranker/bge-reranker-base/Dockerfile @@ -0,0 +1,10 @@ +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime + +# please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as Dockerfile +COPY ./bge-reranker-base ./bge-reranker-base + +COPY app.py Dockerfile requirement.txt . + +RUN python3 -m pip install -r requirement.txt + +ENTRYPOINT python3 app.py diff --git a/python/reranker/bge-reranker-base/README.md b/python/reranker/bge-reranker-base/README.md new file mode 100644 index 000000000..9291c6c22 --- /dev/null +++ b/python/reranker/bge-reranker-base/README.md @@ -0,0 +1,48 @@ + +## 推荐配置 + +推荐配置如下: + +{{< table "table-hover table-striped-columns" >}} +| 类型 | 内存 | 显存 | 硬盘空间 | 启动命令 | +|------|---------|---------|----------|--------------------------| +| base | >=4GB | >=3GB | >=8GB | python app.py | +{{< /table >}} + +## 部署 + +### 环境要求 + +- Python 3.10.11 +- CUDA 11.7 +- 科学上网环境 + +### 源码部署 + +1. 根据上面的环境配置配置好环境,具体教程自行 GPT; +2. 下载 [python 文件](app.py) +3. 在命令行输入命令 `pip install -r requirments.txt`; +4. 按照[https://huggingface.co/BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)下载模型仓库到app.py同级目录 +5. 添加环境变量 `export ACCESS_TOKEN=XXXXXX` 配置 token,这里的 token 只是加一层验证,防止接口被人盗用,默认值为 `ACCESS_TOKEN` ; +6. 执行命令 `python app.py`。 + +然后等待模型下载,直到模型加载完毕为止。如果出现报错先问 GPT。 + +启动成功后应该会显示如下地址: + +![](/imgs/chatglm2.png) + +> 这里的 `http://0.0.0.0:6006` 就是连接地址。 + +### docker 部署 + +**镜像和端口** + ++ 镜像名: `luanshaotong/reranker:v0.1` ++ 端口号: 6006 + +``` +# 设置安全凭证(即oneapi中的渠道密钥) +通过环境变量ACCESS_TOKEN引入,默认值:ACCESS_TOKEN。 +有关docker环境变量引入的方法请自寻教程,此处不再赘述。 +``` diff --git a/python/reranker/bge-reranker-base/app.py b/python/reranker/bge-reranker-base/app.py new file mode 100644 index 000000000..9d7a7ee5d --- /dev/null +++ b/python/reranker/bge-reranker-base/app.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time: 2023/11/7 22:45 +@Author: zhidong +@File: reranker.py +@Desc: +""" +import os +import numpy as np +import logging +import uvicorn +import datetime +from fastapi import FastAPI, Security, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from FlagEmbedding import FlagReranker +from pydantic import Field, BaseModel, validator +from typing import Optional, List + +def response(code, msg, data=None): + time = str(datetime.datetime.now()) + if data is None: + data = [] + result = { + "code": code, + "message": msg, + "data": data, + "time": time + } + return result + +def success(data=None, msg=''): + return + + +class Inputs(BaseModel): + id: str + text: Optional[str] + + +class QADocs(BaseModel): + query: Optional[str] + inputs: Optional[List[Inputs]] + + +class Singleton(type): + def __call__(cls, *args, **kwargs): + if not hasattr(cls, '_instance'): + cls._instance = super().__call__(*args, **kwargs) + return cls._instance + + +RERANK_MODEL_PATH = os.path.join(os.path.dirname(__file__), "bge-reranker-base") + +class Reranker(metaclass=Singleton): + def __init__(self, model_path): + self.reranker = FlagReranker(model_path, + use_fp16=False) + + def compute_score(self, pairs: List[List[str]]): + if len(pairs) > 0: + result = self.reranker.compute_score(pairs) + if isinstance(result, float): + result = [result] + return result + else: + return None + + +class Chat(object): + def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH): + self.reranker = Reranker(rerank_model_path) + + def fit_query_answer_rerank(self, query_docs: QADocs) -> List: + if query_docs is None or len(query_docs.inputs) == 0: + return [] + new_docs = [] + pair = [] + for answer in query_docs.inputs: + pair.append([query_docs.query, answer.text]) + scores = self.reranker.compute_score(pair) + for index, score in enumerate(scores): + new_docs.append({"id": query_docs.inputs[index].id, "score": 1 / (1 + np.exp(-score))}) + new_docs = list(sorted(new_docs, key=lambda x: x["score"], reverse=True)) + return new_docs + +app = FastAPI() +security = HTTPBearer() +env_bearer_token = 'ACCESS_TOKEN' + +@app.post('/api/v1/rerank') +async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)): + token = credentials.credentials + if env_bearer_token is not None and token != env_bearer_token: + raise HTTPException(status_code=401, detail="Invalid token") + chat = Chat() + qa_docs_with_rerank = chat.fit_query_answer_rerank(docs) + return response(200, msg="重排成功", data=qa_docs_with_rerank) + +if __name__ == "__main__": + token = os.getenv("ACCESS_TOKEN") + if token is not None: + env_bearer_token = token + try: + uvicorn.run(app, host='0.0.0.0', port=6006) + except Exception as e: + print(f"API启动失败!\n报错:\n{e}") \ No newline at end of file diff --git a/python/reranker/bge-reranker-base/requirement.txt b/python/reranker/bge-reranker-base/requirement.txt new file mode 100644 index 000000000..cc2adad0f --- /dev/null +++ b/python/reranker/bge-reranker-base/requirement.txt @@ -0,0 +1,6 @@ +fastapi==0.104.1 +FlagEmbedding==1.1.5 +pydantic==1.10.13 +uvicorn==0.17.6 +itsdangerous +protobuf