mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-23 21:13:50 +00:00
update reranker source code (#1082)
This commit is contained in:
@@ -3,8 +3,10 @@ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
|
|||||||
# please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as Dockerfile
|
# please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as Dockerfile
|
||||||
COPY ./bge-reranker-base ./bge-reranker-base
|
COPY ./bge-reranker-base ./bge-reranker-base
|
||||||
|
|
||||||
COPY app.py Dockerfile requirement.txt .
|
COPY requirement.txt .
|
||||||
|
|
||||||
RUN python3 -m pip install -r requirement.txt
|
RUN python3 -m pip install -r requirement.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
|
COPY app.py Dockerfile .
|
||||||
|
|
||||||
ENTRYPOINT python3 app.py
|
ENTRYPOINT python3 app.py
|
||||||
|
@@ -38,7 +38,7 @@
|
|||||||
|
|
||||||
**镜像和端口**
|
**镜像和端口**
|
||||||
|
|
||||||
+ 镜像名: `luanshaotong/reranker:v0.1`
|
+ 镜像名: `luanshaotong/reranker:v0.2`
|
||||||
+ 端口号: 6006
|
+ 端口号: 6006
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@@ -32,15 +32,9 @@ def response(code, msg, data=None):
|
|||||||
def success(data=None, msg=''):
|
def success(data=None, msg=''):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
class Inputs(BaseModel):
|
|
||||||
id: str
|
|
||||||
text: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
class QADocs(BaseModel):
|
class QADocs(BaseModel):
|
||||||
query: Optional[str]
|
query: Optional[str]
|
||||||
inputs: Optional[List[Inputs]]
|
documents: Optional[List[str]]
|
||||||
|
|
||||||
|
|
||||||
class Singleton(type):
|
class Singleton(type):
|
||||||
@@ -72,23 +66,24 @@ class Chat(object):
|
|||||||
self.reranker = Reranker(rerank_model_path)
|
self.reranker = Reranker(rerank_model_path)
|
||||||
|
|
||||||
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
|
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
|
||||||
if query_docs is None or len(query_docs.inputs) == 0:
|
if query_docs is None or len(query_docs.documents) == 0:
|
||||||
return []
|
return []
|
||||||
new_docs = []
|
new_docs = []
|
||||||
pair = []
|
pair = []
|
||||||
for answer in query_docs.inputs:
|
for answer in query_docs.documents:
|
||||||
pair.append([query_docs.query, answer.text])
|
pair.append([query_docs.query, answer])
|
||||||
scores = self.reranker.compute_score(pair)
|
scores = self.reranker.compute_score(pair)
|
||||||
for index, score in enumerate(scores):
|
for index, score in enumerate(scores):
|
||||||
new_docs.append({"id": query_docs.inputs[index].id, "score": 1 / (1 + np.exp(-score))})
|
new_docs.append({"index": index, "text": query_docs.documents[index], "score": 1 / (1 + np.exp(-score))})
|
||||||
new_docs = list(sorted(new_docs, key=lambda x: x["score"], reverse=True))
|
#results = [{"document": {"text": documents["text"]}, "index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
||||||
return new_docs
|
results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
env_bearer_token = 'ACCESS_TOKEN'
|
env_bearer_token = 'ACCESS_TOKEN'
|
||||||
|
|
||||||
@app.post('/api/v1/rerank')
|
@app.post('/v1/rerank')
|
||||||
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
|
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
|
||||||
token = credentials.credentials
|
token = credentials.credentials
|
||||||
if env_bearer_token is not None and token != env_bearer_token:
|
if env_bearer_token is not None and token != env_bearer_token:
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
fastapi==0.104.1
|
fastapi==0.104.1
|
||||||
|
transformers[sentencepiece]
|
||||||
FlagEmbedding==1.1.5
|
FlagEmbedding==1.1.5
|
||||||
pydantic==1.10.13
|
pydantic==1.10.13
|
||||||
uvicorn==0.17.6
|
uvicorn==0.17.6
|
Reference in New Issue
Block a user