mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-22 12:20:34 +00:00
update reranker source code (#1082)
This commit is contained in:
@@ -3,8 +3,10 @@ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
|
||||
# please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as Dockerfile
|
||||
COPY ./bge-reranker-base ./bge-reranker-base
|
||||
|
||||
COPY app.py Dockerfile requirement.txt .
|
||||
COPY requirement.txt .
|
||||
|
||||
RUN python3 -m pip install -r requirement.txt
|
||||
RUN python3 -m pip install -r requirement.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
COPY app.py Dockerfile .
|
||||
|
||||
ENTRYPOINT python3 app.py
|
||||
|
@@ -38,7 +38,7 @@
|
||||
|
||||
**镜像和端口**
|
||||
|
||||
+ 镜像名: `luanshaotong/reranker:v0.1`
|
||||
+ 镜像名: `luanshaotong/reranker:v0.2`
|
||||
+ 端口号: 6006
|
||||
|
||||
```
|
||||
|
@@ -4,7 +4,7 @@
|
||||
@Time: 2023/11/7 22:45
|
||||
@Author: zhidong
|
||||
@File: reranker.py
|
||||
@Desc:
|
||||
@Desc:
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
@@ -30,17 +30,11 @@ def response(code, msg, data=None):
|
||||
return result
|
||||
|
||||
def success(data=None, msg=''):
|
||||
return
|
||||
|
||||
|
||||
class Inputs(BaseModel):
|
||||
id: str
|
||||
text: Optional[str]
|
||||
|
||||
return
|
||||
|
||||
class QADocs(BaseModel):
|
||||
query: Optional[str]
|
||||
inputs: Optional[List[Inputs]]
|
||||
documents: Optional[List[str]]
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
@@ -72,23 +66,24 @@ class Chat(object):
|
||||
self.reranker = Reranker(rerank_model_path)
|
||||
|
||||
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
|
||||
if query_docs is None or len(query_docs.inputs) == 0:
|
||||
if query_docs is None or len(query_docs.documents) == 0:
|
||||
return []
|
||||
new_docs = []
|
||||
pair = []
|
||||
for answer in query_docs.inputs:
|
||||
pair.append([query_docs.query, answer.text])
|
||||
for answer in query_docs.documents:
|
||||
pair.append([query_docs.query, answer])
|
||||
scores = self.reranker.compute_score(pair)
|
||||
for index, score in enumerate(scores):
|
||||
new_docs.append({"id": query_docs.inputs[index].id, "score": 1 / (1 + np.exp(-score))})
|
||||
new_docs = list(sorted(new_docs, key=lambda x: x["score"], reverse=True))
|
||||
return new_docs
|
||||
new_docs.append({"index": index, "text": query_docs.documents[index], "score": 1 / (1 + np.exp(-score))})
|
||||
#results = [{"document": {"text": documents["text"]}, "index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
||||
results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
||||
return {"results": results}
|
||||
|
||||
app = FastAPI()
|
||||
security = HTTPBearer()
|
||||
env_bearer_token = 'ACCESS_TOKEN'
|
||||
|
||||
@app.post('/api/v1/rerank')
|
||||
@app.post('/v1/rerank')
|
||||
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
|
||||
token = credentials.credentials
|
||||
if env_bearer_token is not None and token != env_bearer_token:
|
||||
@@ -104,4 +99,4 @@ if __name__ == "__main__":
|
||||
try:
|
||||
uvicorn.run(app, host='0.0.0.0', port=6006)
|
||||
except Exception as e:
|
||||
print(f"API启动失败!\n报错:\n{e}")
|
||||
print(f"API启动失败!\n报错:\n{e}")
|
||||
|
@@ -1,4 +1,5 @@
|
||||
fastapi==0.104.1
|
||||
transformers[sentencepiece]
|
||||
FlagEmbedding==1.1.5
|
||||
pydantic==1.10.13
|
||||
uvicorn==0.17.6
|
Reference in New Issue
Block a user