添加ChatGLM2教程 (#125)

* Create GLM2对接教程.md

* 添加GLM2接入教程

* Delete GLM2对接教程.md

* Delete image.png

* Delete openai_api.py

* Delete openai_api_int4.py

* Delete openai_api_int8.py

* Create GLM2对接教程.md

* 添加ChatGLM2接口

* Delete openai_api_int4.py

* Delete openai_api_int8.py

* Update openai_api.py

* Update GLM2对接教程.md
This commit is contained in:
stakeswky
2023-07-13 09:21:19 +08:00
committed by GitHub
parent 5f784e1def
commit 0be7a3e355
3 changed files with 210 additions and 0 deletions

View File

@@ -0,0 +1,174 @@
# coding=utf-8
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from fastapi import Depends, HTTPException, Request
from starlette.status import HTTP_401_UNAUTHORIZED
import argparse
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
async def verify_token(request: Request):
auth_header = request.headers.get('Authorization')
if auth_header:
token_type, _, token = auth_header.partition(' ')
if token_type.lower() == "bearer" and token == "sk-aaabbbcccdddeeefffggghhhiiijjjkkk": # 这里配置你的token
return True
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authorization credentials",
)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, token: bool = Depends(verify_token)):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
query = prev_messages.pop(0).content + query
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
generate = predict(query, history, request.model)
return EventSourceResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(query: str, history: List[List[str]], model_id: str):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0
for new_response, _ in model.stream_chat(tokenizer, query, history):
if len(new_response) == current_length:
continue
new_text = new_response[current_length:]
current_length = len(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="16", type=str, help="Model name")
args = parser.parse_args()
model_dict = {
"4": "THUDM/chatglm2-6b-int4",
"8": "THUDM/chatglm2-6b-int8",
"16": "THUDM/chatglm2-6b"
}
model_name = model_dict.get(args.model_name, "THUDM/chatglm2-6b")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
model.eval()
uvicorn.run(app, host='0.0.0.0', port=6006, workers=1)