From 7e6272ca1bc3ffdc9f9e0c02914dff1415fefa18 Mon Sep 17 00:00:00 2001 From: stakeswky <64798754+stakeswky@users.noreply.github.com> Date: Tue, 25 Jul 2023 12:19:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0embedding=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=20(#136)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/zh/examples/ChatGLM2/openai_api.py | 75 ++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/docs/zh/examples/ChatGLM2/openai_api.py b/docs/zh/examples/ChatGLM2/openai_api.py index 40a62778c..d0865394e 100644 --- a/docs/zh/examples/ChatGLM2/openai_api.py +++ b/docs/zh/examples/ChatGLM2/openai_api.py @@ -6,12 +6,16 @@ from pydantic import BaseModel, Field from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from transformers import AutoTokenizer, AutoModel -from sse_starlette.sse import ServerSentEvent, EventSourceResponse +from sse_starlette.sse import EventSourceResponse from fastapi import Depends, HTTPException, Request from starlette.status import HTTP_401_UNAUTHORIZED import argparse +import tiktoken +import numpy as np +from sentence_transformers import SentenceTransformer +from sklearn.preprocessing import PolynomialFeatures @asynccontextmanager @@ -81,6 +85,34 @@ async def verify_token(request: Request): detail="Invalid authorization credentials", ) +class EmbeddingRequest(BaseModel): + input: List[str] + model: str + +class EmbeddingResponse(BaseModel): + data: list + model: str + object: str + usage: dict + +def num_tokens_from_string(string: str) -> int: + """Returns the number of tokens in a text string.""" + encoding = tiktoken.get_encoding('cl100k_base') + num_tokens = len(encoding.encode(string)) + return num_tokens + +def expand_features(embedding, target_length): + poly = PolynomialFeatures(degree=2) + expanded_embedding = poly.fit_transform(embedding.reshape(1, -1)) + expanded_embedding = expanded_embedding.flatten() + if len(expanded_embedding) > target_length: + # 如果扩展后的特征超过目标长度,可以通过截断或其他方法来减少维度 + expanded_embedding = expanded_embedding[:target_length] + elif len(expanded_embedding) < target_length: + # 如果扩展后的特征少于目标长度,可以通过填充或其他方法来增加维度 + expanded_embedding = np.pad(expanded_embedding, (0, target_length - len(expanded_embedding))) + return expanded_embedding + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest, token: bool = Depends(verify_token)): @@ -152,6 +184,43 @@ async def predict(query: str, history: List[List[str]], model_id: str): yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield '[DONE]' +@app.post("/v1/embeddings", response_model=EmbeddingResponse) +async def get_embeddings(request: EmbeddingRequest, token: bool = Depends(verify_token)): + + + # 计算嵌入向量和tokens数量 + embeddings = [embeddings_model.encode(text) for text in request.input] + + # 如果嵌入向量的维度不为1536,则使用插值法扩展至1536维度 + embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings] + + # Min-Max normalization + embeddings = [(embedding - np.min(embedding)) / (np.max(embedding) - np.min(embedding)) if np.max(embedding) != np.min(embedding) else embedding for embedding in embeddings] + + # 将numpy数组转换为列表 + embeddings = [embedding.tolist() for embedding in embeddings] + prompt_tokens = sum(len(text.split()) for text in request.input) + total_tokens = sum(num_tokens_from_string(text) for text in request.input) + + + response = { + "data": [ + { + "embedding": embedding, + "index": index, + "object": "embedding" + } for index, embedding in enumerate(embeddings) + ], + "model": request.model, + "object": "list", + "usage": { + "prompt_tokens": prompt_tokens, + "total_tokens": total_tokens, + } + } + + return response + if __name__ == "__main__": @@ -169,6 +238,6 @@ if __name__ == "__main__": tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda() - model.eval() + embeddings_model = SentenceTransformer('moka-ai/m3e-large',device='cpu') uvicorn.run(app, host='0.0.0.0', port=6006, workers=1)