mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-27 00:17:31 +00:00
update reranker source code (#1082)
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
@Time: 2023/11/7 22:45
|
||||
@Author: zhidong
|
||||
@File: reranker.py
|
||||
@Desc:
|
||||
@Desc:
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
@@ -30,17 +30,11 @@ def response(code, msg, data=None):
|
||||
return result
|
||||
|
||||
def success(data=None, msg=''):
|
||||
return
|
||||
|
||||
|
||||
class Inputs(BaseModel):
|
||||
id: str
|
||||
text: Optional[str]
|
||||
|
||||
return
|
||||
|
||||
class QADocs(BaseModel):
|
||||
query: Optional[str]
|
||||
inputs: Optional[List[Inputs]]
|
||||
documents: Optional[List[str]]
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
@@ -72,23 +66,24 @@ class Chat(object):
|
||||
self.reranker = Reranker(rerank_model_path)
|
||||
|
||||
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
|
||||
if query_docs is None or len(query_docs.inputs) == 0:
|
||||
if query_docs is None or len(query_docs.documents) == 0:
|
||||
return []
|
||||
new_docs = []
|
||||
pair = []
|
||||
for answer in query_docs.inputs:
|
||||
pair.append([query_docs.query, answer.text])
|
||||
for answer in query_docs.documents:
|
||||
pair.append([query_docs.query, answer])
|
||||
scores = self.reranker.compute_score(pair)
|
||||
for index, score in enumerate(scores):
|
||||
new_docs.append({"id": query_docs.inputs[index].id, "score": 1 / (1 + np.exp(-score))})
|
||||
new_docs = list(sorted(new_docs, key=lambda x: x["score"], reverse=True))
|
||||
return new_docs
|
||||
new_docs.append({"index": index, "text": query_docs.documents[index], "score": 1 / (1 + np.exp(-score))})
|
||||
#results = [{"document": {"text": documents["text"]}, "index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
||||
results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
||||
return {"results": results}
|
||||
|
||||
app = FastAPI()
|
||||
security = HTTPBearer()
|
||||
env_bearer_token = 'ACCESS_TOKEN'
|
||||
|
||||
@app.post('/api/v1/rerank')
|
||||
@app.post('/v1/rerank')
|
||||
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
|
||||
token = credentials.credentials
|
||||
if env_bearer_token is not None and token != env_bearer_token:
|
||||
@@ -104,4 +99,4 @@ if __name__ == "__main__":
|
||||
try:
|
||||
uvicorn.run(app, host='0.0.0.0', port=6006)
|
||||
except Exception as e:
|
||||
print(f"API启动失败!\n报错:\n{e}")
|
||||
print(f"API启动失败!\n报错:\n{e}")
|
||||
|
Reference in New Issue
Block a user