update reranker source code (#1082)

This commit is contained in:
wikkipedia
2024-03-27 21:43:48 +08:00
committed by GitHub
parent 3f892bd810
commit 00ace0b69c
4 changed files with 18 additions and 20 deletions

View File

@@ -3,8 +3,10 @@ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
# Please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as the Dockerfile.
COPY ./bge-reranker-base ./bge-reranker-base
COPY app.py Dockerfile requirement.txt .
COPY requirement.txt .
RUN python3 -m pip install -r requirement.txt
RUN python3 -m pip install -r requirement.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
COPY app.py Dockerfile .
ENTRYPOINT python3 app.py

View File

@@ -38,7 +38,7 @@
**镜像和端口**
+ 镜像名: `luanshaotong/reranker:v0.1`
+ 镜像名: `luanshaotong/reranker:v0.2`
+ 端口号: 6006
```

View File

@@ -4,7 +4,7 @@
@Time: 2023/11/7 22:45
@Author: zhidong
@File: reranker.py
@Desc:
@Desc:
"""
import os
import numpy as np
@@ -30,17 +30,11 @@ def response(code, msg, data=None):
return result
def success(data=None, msg=''):
return
class Inputs(BaseModel):
id: str
text: Optional[str]
return
class QADocs(BaseModel):
query: Optional[str]
inputs: Optional[List[Inputs]]
documents: Optional[List[str]]
class Singleton(type):
@@ -72,23 +66,24 @@ class Chat(object):
self.reranker = Reranker(rerank_model_path)
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
if query_docs is None or len(query_docs.inputs) == 0:
if query_docs is None or len(query_docs.documents) == 0:
return []
new_docs = []
pair = []
for answer in query_docs.inputs:
pair.append([query_docs.query, answer.text])
for answer in query_docs.documents:
pair.append([query_docs.query, answer])
scores = self.reranker.compute_score(pair)
for index, score in enumerate(scores):
new_docs.append({"id": query_docs.inputs[index].id, "score": 1 / (1 + np.exp(-score))})
new_docs = list(sorted(new_docs, key=lambda x: x["score"], reverse=True))
return new_docs
new_docs.append({"index": index, "text": query_docs.documents[index], "score": 1 / (1 + np.exp(-score))})
#results = [{"document": {"text": documents["text"]}, "index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
return {"results": results}
app = FastAPI()
security = HTTPBearer()
env_bearer_token = 'ACCESS_TOKEN'
@app.post('/api/v1/rerank')
@app.post('/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
token = credentials.credentials
if env_bearer_token is not None and token != env_bearer_token:
@@ -104,4 +99,4 @@ if __name__ == "__main__":
try:
uvicorn.run(app, host='0.0.0.0', port=6006)
except Exception as e:
print(f"API启动失败\n报错:\n{e}")
print(f"API启动失败\n报错:\n{e}")

View File

@@ -1,4 +1,5 @@
fastapi==0.104.1
transformers[sentencepiece]
FlagEmbedding==1.1.5
pydantic==1.10.13
uvicorn==0.17.6