mirror of
https://github.com/YuWanTingbb/unofficial-gpt4.git
synced 2025-10-13 13:45:06 +00:00
gpt4-copilot-java v0.2.2
This commit is contained in:
@@ -8,7 +8,7 @@ LABEL maintainer="Yanyutin753"
|
||||
USER root
|
||||
|
||||
# 复制JAR文件到容器的/app目录下
|
||||
COPY /target/gpt-4-copilot-0.2.1.jar app.jar
|
||||
COPY /target/gpt-4-copilot-0.2.2.jar app.jar
|
||||
|
||||
# 声明服务运行在8081端口
|
||||
EXPOSE 8081
|
||||
|
7
pom.xml
7
pom.xml
@@ -10,7 +10,7 @@
|
||||
</parent>
|
||||
<groupId>com.gpt4.copilot</groupId>
|
||||
<artifactId>gpt-4-copilot</artifactId>
|
||||
<version>0.2.1</version>
|
||||
<version>0.2.2</version>
|
||||
<name>native</name>
|
||||
<description>Demo project for Spring Boot with GraalVM Native Image</description>
|
||||
<properties>
|
||||
@@ -54,6 +54,11 @@
|
||||
<artifactId>chatgpt-java</artifactId>
|
||||
<version>1.1.5</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.knuddels</groupId>
|
||||
<artifactId>jtokkit</artifactId>
|
||||
<version>0.6.1</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package com.gpt4.copilot.controller;
|
||||
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import com.alibaba.fastjson2.JSONException;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.alibaba.fastjson2.TypeReference;
|
||||
@@ -11,13 +12,13 @@ import com.gpt4.copilot.pojo.Conversation;
|
||||
import com.gpt4.copilot.pojo.Result;
|
||||
import com.gpt4.copilot.pojo.SystemSetting;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import com.unfbx.chatgpt.utils.TikTokensUtil;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
@@ -52,6 +53,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||
@Data
|
||||
@RestController()
|
||||
public class ChatController {
|
||||
|
||||
public static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
|
||||
/**
|
||||
* 模型
|
||||
@@ -501,29 +503,33 @@ public class ChatController {
|
||||
String model = modelAdjust(conversation);
|
||||
Request streamRequest = getPrompt(conversation, model, headersMap);
|
||||
try (Response resp = client.newCall(streamRequest).execute()) {
|
||||
log.info(resp.toString());
|
||||
if (!resp.isSuccessful()) {
|
||||
if (resp.code() == 429) {
|
||||
return new ResponseEntity<>(Result.error("rate limit exceeded"), HttpStatus.TOO_MANY_REQUESTS);
|
||||
} else if (resp.code() == 400) {
|
||||
return new ResponseEntity<>(Result.error("messages is none or too long and over 32K"), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
} else {
|
||||
String token = getCopilotToken(apiKey);
|
||||
if (token == null) {
|
||||
return new ResponseEntity<>(Result.error("copilot APIKey is wrong"), HttpStatus.UNAUTHORIZED);
|
||||
if (resp.code() == 403) {
|
||||
String token = getCopilotToken(apiKey);
|
||||
if (token == null) {
|
||||
return new ResponseEntity<>(Result.error("copilot APIKey is wrong"), HttpStatus.UNAUTHORIZED);
|
||||
}
|
||||
copilotTokenList.put(apiKey, token);
|
||||
log.info("token过期,Github CopilotToken重置化成功!");
|
||||
return againConversation(response, conversation, token, apiKey, model);
|
||||
} else {
|
||||
return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
copilotTokenList.put(apiKey, token);
|
||||
log.info("token过期,Github CopilotToken重置化成功!");
|
||||
againConversation(response, conversation, token, apiKey, model);
|
||||
}
|
||||
} else {
|
||||
// 流式和非流式输出
|
||||
outPutChat(response, resp, conversation, model);
|
||||
return outPutChat(response, resp, conversation, model);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
return null;
|
||||
}, executor);
|
||||
|
||||
return getObjectResponseEntity(response, future);
|
||||
@@ -542,7 +548,7 @@ public class ChatController {
|
||||
|
||||
private String getEmbRequestApikey(String authorizationHeader, @org.springframework.web.bind.annotation.RequestBody Object conversation) {
|
||||
if (conversation == null) {
|
||||
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Request body is missing or not in JSON format");
|
||||
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Request body is missing or not in JSON format");
|
||||
}
|
||||
String apiKey;
|
||||
if (authorizationHeader != null && authorizationHeader.startsWith("Bearer ")) {
|
||||
@@ -624,17 +630,16 @@ public class ChatController {
|
||||
}
|
||||
coCopilotTokenList.put(apiKey, token);
|
||||
log.info("token过期,coCopilotToken重置化成功!");
|
||||
againConversation(response, conversation, token, apiKey, model);
|
||||
return againConversation(response, conversation, token, apiKey, model);
|
||||
}
|
||||
} else {
|
||||
// 流式和非流式输出
|
||||
outPutChat(response, resp, conversation, model);
|
||||
return outPutChat(response, resp, conversation, model);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
return null;
|
||||
}, executor);
|
||||
|
||||
return getObjectResponseEntity(response, future);
|
||||
@@ -677,7 +682,7 @@ public class ChatController {
|
||||
|
||||
private void checkConversation(Conversation conversation) {
|
||||
if (conversation == null) {
|
||||
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Request body is missing or not in JSON format");
|
||||
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Request body is missing or not in JSON format");
|
||||
}
|
||||
long tokens = conversation.tokens();
|
||||
if (tokens > 32 * 1024) {
|
||||
@@ -722,7 +727,7 @@ public class ChatController {
|
||||
*/
|
||||
private String[] extractEmbApiKeyAndRequestUrl(String authorizationHeader, Object conversation) throws IllegalArgumentException {
|
||||
if (conversation == null) {
|
||||
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Request body is missing or not in JSON format");
|
||||
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Request body is missing or not in JSON format");
|
||||
}
|
||||
return getApiKeyAndRequestUrl(authorizationHeader);
|
||||
}
|
||||
@@ -788,17 +793,16 @@ public class ChatController {
|
||||
}
|
||||
selfTokenList.put(apiKey, token);
|
||||
log.info("token过期,自定义selfToken重置化成功!");
|
||||
againConversation(response, conversation, token, apiKey, model);
|
||||
return againConversation(response, conversation, token, apiKey, model);
|
||||
}
|
||||
} else {
|
||||
// 流式和非流式输出
|
||||
outPutChat(response, resp, conversation, model);
|
||||
return outPutChat(response, resp, conversation, model);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
return null;
|
||||
}, executor);
|
||||
|
||||
return getObjectResponseEntity(response, future);
|
||||
@@ -813,11 +817,11 @@ public class ChatController {
|
||||
* @param token
|
||||
* @return
|
||||
*/
|
||||
public Object againConversation(HttpServletResponse response,
|
||||
@org.springframework.web.bind.annotation.RequestBody Conversation conversation,
|
||||
String token,
|
||||
String apiKey,
|
||||
String model) {
|
||||
public ResponseEntity<Object> againConversation(HttpServletResponse response,
|
||||
@org.springframework.web.bind.annotation.RequestBody Conversation conversation,
|
||||
String token,
|
||||
String apiKey,
|
||||
String model) {
|
||||
try {
|
||||
Map<String, String> headersMap = new HashMap<>();
|
||||
//添加头部
|
||||
@@ -825,15 +829,16 @@ public class ChatController {
|
||||
Request streamRequest = getPrompt(conversation, model, headersMap);
|
||||
try (Response resp = client.newCall(streamRequest).execute()) {
|
||||
if (!resp.isSuccessful()) {
|
||||
return new ResponseEntity<>("copilot/cocopilot/self APIKey is wrong Or your network is wrong", HttpStatus.UNAUTHORIZED);
|
||||
return new ResponseEntity<>("APIKey is wrong!", HttpStatus.UNAUTHORIZED);
|
||||
} else {
|
||||
// 流式和非流式输出
|
||||
outPutChat(response, resp, conversation, model);
|
||||
return outPutChat(response, resp, conversation, model);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
log.info("Exception " + e.getMessage());
|
||||
return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -937,7 +942,7 @@ public class ChatController {
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}, executor);
|
||||
|
||||
@@ -1027,7 +1032,7 @@ public class ChatController {
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}, executor);
|
||||
|
||||
@@ -1103,7 +1108,7 @@ public class ChatController {
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}, executor);
|
||||
return getObjectResponseEntity(future);
|
||||
@@ -1132,7 +1137,7 @@ public class ChatController {
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
|
||||
return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1154,7 +1159,6 @@ public class ChatController {
|
||||
return getToken(request);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private String getToken(Request request) throws IOException {
|
||||
try (Response response = client.newCall(request).execute()) {
|
||||
log.info(response.toString());
|
||||
@@ -1263,40 +1267,116 @@ public class ChatController {
|
||||
* @param resp
|
||||
* @param conversation
|
||||
*/
|
||||
private void outPutChat(HttpServletResponse response, Response resp, Conversation conversation, String model) {
|
||||
try {
|
||||
boolean isStream = conversation.isStream();
|
||||
int sleep_time = calculateSleepTime(model, isStream);
|
||||
if (isStream) {
|
||||
response.setContentType("text/event-stream; charset=UTF-8");
|
||||
} else {
|
||||
response.setContentType("application/json; charset=utf-8");
|
||||
}
|
||||
|
||||
try (PrintWriter out = new PrintWriter(new OutputStreamWriter(response.getOutputStream(), StandardCharsets.UTF_8), true);
|
||||
BufferedReader in = new BufferedReader(new InputStreamReader(resp.body().byteStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
while ((line = in.readLine()) != null) {
|
||||
out.println(line);
|
||||
out.flush();
|
||||
if (sleep_time > 0 && line.startsWith("data")) {
|
||||
Thread.sleep(sleep_time);
|
||||
}
|
||||
}
|
||||
log.info("使用模型:" + model + ",vscode_version:" + systemSetting.getVscode_version() +
|
||||
",copilot_chat_version:" + systemSetting.getCopilot_chat_version()
|
||||
+ ",字符间隔时间:" + sleep_time + "ms,响应:" + resp);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IOException("Thread was interrupted", e);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("IO Exception occurred", e);
|
||||
private ResponseEntity<Object> outPutChat(HttpServletResponse response, Response resp, Conversation conversation, String model) {
|
||||
boolean isStream = conversation.isStream();
|
||||
int sleep_time = calculateSleepTime(model, isStream);
|
||||
if (isStream) {
|
||||
response.setContentType("text/event-stream; charset=UTF-8");
|
||||
return outIsStreamPutChat(response, resp, model, sleep_time);
|
||||
} else {
|
||||
response.setContentType("application/json; charset=utf-8");
|
||||
return outNoStreamPutChat(response, resp, model);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* chat非流式接口的输出
|
||||
*
|
||||
* @param response
|
||||
* @param resp
|
||||
* @param model
|
||||
* @return
|
||||
*/
|
||||
private ResponseEntity<Object> outNoStreamPutChat(HttpServletResponse response, Response resp, String model) {
|
||||
try (PrintWriter out = new PrintWriter(new OutputStreamWriter(response.getOutputStream(), StandardCharsets.UTF_8), true);
|
||||
BufferedReader in = new BufferedReader(new InputStreamReader(resp.body().byteStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
long tokens = 0;
|
||||
while ((line = in.readLine()) != null) {
|
||||
out.println(line);
|
||||
out.flush();
|
||||
if (line.trim().startsWith("{")) {
|
||||
try {
|
||||
JSONObject resJson = com.alibaba.fastjson2.JSON.parseObject(line);
|
||||
JSONObject usageJson = resJson.getJSONObject("usage");
|
||||
if (usageJson != null) {
|
||||
Integer totalTokens = usageJson.getInteger("total_tokens");
|
||||
if (totalTokens != null) {
|
||||
tokens = totalTokens;
|
||||
} else {
|
||||
log.warn("total_tokens field missing in the response");
|
||||
}
|
||||
} else {
|
||||
log.warn("usage field missing in the response");
|
||||
}
|
||||
} catch (JSONException je) {
|
||||
log.error("Error parsing JSON", je);
|
||||
}
|
||||
}
|
||||
}
|
||||
log.info("使用模型:" + model + ",补全tokens:" + tokens + ",vscode_version:" + systemSetting.getVscode_version() +
|
||||
",copilot_chat_version:" + systemSetting.getCopilot_chat_version() + ",响应:" + resp);
|
||||
if (tokens <= 0) {
|
||||
return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
return null;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* chat流式接口的输出
|
||||
*
|
||||
* @param response
|
||||
* @param resp
|
||||
* @param model
|
||||
* @param sleep_time
|
||||
* @return
|
||||
*/
|
||||
private ResponseEntity<Object> outIsStreamPutChat(HttpServletResponse response, Response resp, String model, int sleep_time) {
|
||||
try (PrintWriter out = new PrintWriter(new OutputStreamWriter(response.getOutputStream(), StandardCharsets.UTF_8), true);
|
||||
BufferedReader in = new BufferedReader(new InputStreamReader(resp.body().byteStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
long tokens = 0;
|
||||
while ((line = in.readLine()) != null) {
|
||||
out.println(line);
|
||||
out.flush();
|
||||
if (sleep_time > 0 && line.startsWith("data:")) {
|
||||
try {
|
||||
String res_text = line.replace("data: ", "");
|
||||
if (!"[DONE]".equals(res_text.trim())) {
|
||||
JSONObject resJson = com.alibaba.fastjson2.JSON.parseObject(res_text);
|
||||
JSONArray choicesArray = resJson.getJSONArray("choices");
|
||||
if (choicesArray.size() > 0) {
|
||||
JSONObject firstChoice = choicesArray.getJSONObject(0);
|
||||
String content = firstChoice.getJSONObject("delta").getString("content");
|
||||
tokens += TikTokensUtil.tokens("gpt-4-0613", content);
|
||||
}
|
||||
Thread.sleep(sleep_time);
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.info("使用模型:" + model + ",补全tokens:" + tokens + ",vscode_version:" + systemSetting.getVscode_version() +
|
||||
",copilot_chat_version:" + systemSetting.getCopilot_chat_version()
|
||||
+ ",字符间隔时间:" + sleep_time + "ms,响应:" + resp);
|
||||
if (tokens <= 0) {
|
||||
return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
return null;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* chat接口的每个字的睡眠时间
|
||||
*/
|
||||
private int calculateSleepTime(String model, boolean isStream) {
|
||||
|
@@ -25,9 +25,9 @@ public class CustomErrorController implements ErrorController {
|
||||
" <title>Document</title>\n" +
|
||||
"</head>\n" +
|
||||
"<body>\n" +
|
||||
" <p>Thanks you use gpt4-copilot-java-0.2.1</p>\n" +
|
||||
" <p>Thanks you use gpt4-copilot-java-0.2.2</p>\n" +
|
||||
" <p><a href=\"https://apifox.com/apidoc/shared-4301e565-a8df-48a0-85a5-bda2c4c3965a\">详细使用文档</a></p>\n" +
|
||||
" <p><a href=\"https://github.com/Yanyutin753/gpt4-copilot-java-sh\">项目地址</a></p>\n" +
|
||||
" <p><a href=\"https://github.com/Yanyutin753/unofficial-gpt4-api\">项目地址</a></p>\n" +
|
||||
"</body>\n" +
|
||||
"</html>\n", HttpStatus.OK);
|
||||
}
|
||||
|
@@ -7,6 +7,7 @@ import com.alibaba.fastjson2.JSONException;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.gpt4.copilot.controller.ChatController;
|
||||
import com.gpt4.copilot.pojo.SystemSetting;
|
||||
import com.unfbx.chatgpt.utils.TikTokensUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
@@ -256,6 +257,7 @@ public class copilotApplication {
|
||||
@Scheduled(cron = "0 0 3 1/3 * ?")
|
||||
private static void updateLatestVersion() {
|
||||
try {
|
||||
new TikTokensUtil();
|
||||
String latestVersion = getLatestVSCodeVersion();
|
||||
String latestChatVersion = getLatestExtensionVersion("GitHub", "copilot-chat");
|
||||
if (latestVersion != null && latestChatVersion != null) {
|
||||
@@ -296,14 +298,10 @@ public class copilotApplication {
|
||||
System.out.println("one_selfCopilot_limit:" + ChatController.getSystemSetting().getOne_selfCopilot_limit());
|
||||
System.out.println("gpt4-copilot-java 初始化接口成功!");
|
||||
System.out.println("======================================================");
|
||||
System.out.println("******原神gpt4-copilot-java-native v0.2.1启动成功******");
|
||||
System.out.println("* 对chat接口的模型进行重定向,减少潜在的风险");
|
||||
System.out.println("* 使用ConcurrentHashMap,粗略的对于每个密钥按每分钟进行限速");
|
||||
System.out.println("* 新增环境变量用于对gpt-4*等模型进行系统prompt提示");
|
||||
System.out.println("* 新增url|apikey形式传入/self/*接口,用于自定义地址和密钥");
|
||||
System.out.println("* 新增对于消息报错提醒,减少对于用户的困扰");
|
||||
System.out.println("******原神gpt4-copilot-java v0.2.2启动成功******");
|
||||
System.out.println("* 由于本人略菜,graalvm依赖问题无法解决,之后代码将只通过jar和docker的形式运行");
|
||||
System.out.println("* 修复部分bug,优化读取config.json代码,提升稳定性");
|
||||
System.out.println("* 新增每个密钥对于特定的机器码,且保存在文件中,一秘钥一机器码,减小被查询异常");
|
||||
System.out.println("* 新增token计算,优化报错,支持one_api重试机制");
|
||||
System.out.println("URL地址:http://0.0.0.0:" + config.getServerPort() + config.getPrefix() + "");
|
||||
System.out.println("======================================================");
|
||||
}
|
||||
|
@@ -2,20 +2,23 @@ package com.gpt4.copilot.pojo;
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.knuddels.jtokkit.Encodings;
|
||||
import com.knuddels.jtokkit.api.Encoding;
|
||||
import com.knuddels.jtokkit.api.EncodingRegistry;
|
||||
import com.knuddels.jtokkit.api.ModelType;
|
||||
import com.unfbx.chatgpt.entity.chat.BaseChatCompletion;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import com.unfbx.chatgpt.utils.TikTokensUtil;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
|
||||
/**
|
||||
* 描述: chat模型参数
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* 2023-03-02
|
||||
* @author YANGYANG
|
||||
*/
|
||||
@Data
|
||||
@SuperBuilder
|
||||
@@ -24,6 +27,9 @@ import java.util.List;
|
||||
@AllArgsConstructor
|
||||
public class Conversation implements Serializable {
|
||||
|
||||
private static final Map<String, Encoding> modelMap = new HashMap();
|
||||
private static final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
|
||||
|
||||
/**
|
||||
* 是否流式输出.
|
||||
* default:false
|
||||
@@ -48,9 +54,88 @@ public class Conversation implements Serializable {
|
||||
return 0;
|
||||
}
|
||||
String temModel = this.getModel() == null || !model.startsWith("gpt-4") ? "gpt-3.5-turbo" :"gpt-4-0613";
|
||||
return TikTokensUtil.tokens(temModel, this.messages);
|
||||
return tokens(temModel, this.messages);
|
||||
}
|
||||
|
||||
@Builder.Default
|
||||
private String model = "gpt-3.5-turbo";
|
||||
|
||||
|
||||
public static Encoding getEncoding(@NotNull String modelName) {
|
||||
return modelMap.get(modelName);
|
||||
}
|
||||
|
||||
public static int tokens(@NotNull Encoding enc, String text) {
|
||||
return encode(enc, text).size();
|
||||
}
|
||||
|
||||
public static List encode(@NotNull Encoding enc, String text) {
|
||||
return StrUtil.isBlank(text) ? new ArrayList() : enc.encode(text);
|
||||
}
|
||||
|
||||
public static int tokens(@NotNull String modelName, @NotNull List<Message> messages) {
|
||||
Encoding encoding = getEncoding(modelName);
|
||||
int tokensPerMessage = 0;
|
||||
int tokensPerName = 0;
|
||||
if (!modelName.equals(BaseChatCompletion.Model.GPT_3_5_TURBO_0613.getName()) && !modelName.equals(BaseChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName()) && !modelName.equals(BaseChatCompletion.Model.GPT_4_0314.getName()) && !modelName.equals(BaseChatCompletion.Model.GPT_4_32K_0314.getName()) && !modelName.equals(BaseChatCompletion.Model.GPT_4_0613.getName()) && !modelName.equals(BaseChatCompletion.Model.GPT_4_32K_0613.getName())) {
|
||||
if (modelName.equals(BaseChatCompletion.Model.GPT_3_5_TURBO_0301.getName())) {
|
||||
tokensPerMessage = 4;
|
||||
tokensPerName = -1;
|
||||
} else if (modelName.contains(BaseChatCompletion.Model.GPT_3_5_TURBO.getName())) {
|
||||
log.warn("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.");
|
||||
tokensPerMessage = 3;
|
||||
tokensPerName = 1;
|
||||
} else if (modelName.contains(BaseChatCompletion.Model.GPT_4.getName())) {
|
||||
log.warn("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.");
|
||||
tokensPerMessage = 3;
|
||||
tokensPerName = 1;
|
||||
} else {
|
||||
log.warn("不支持的model {}. See https://github.com/openai/openai-python/blob/main/chatml.md 更多信息.", modelName);
|
||||
}
|
||||
} else {
|
||||
tokensPerMessage = 3;
|
||||
tokensPerName = 1;
|
||||
}
|
||||
|
||||
int sum = 0;
|
||||
Iterator var6 = messages.iterator();
|
||||
|
||||
while(var6.hasNext()) {
|
||||
Message msg = (Message)var6.next();
|
||||
sum += tokensPerMessage;
|
||||
sum += tokens(encoding, msg.getContent());
|
||||
sum += tokens(encoding, msg.getRole());
|
||||
sum += tokens(encoding, msg.getName());
|
||||
if (StrUtil.isNotBlank(msg.getName())) {
|
||||
sum += tokensPerName;
|
||||
}
|
||||
}
|
||||
|
||||
sum += 3;
|
||||
return sum;
|
||||
}
|
||||
|
||||
|
||||
static {
|
||||
ModelType[] var0 = ModelType.values();
|
||||
int var1 = var0.length;
|
||||
|
||||
for(int var2 = 0; var2 < var1; ++var2) {
|
||||
ModelType modelType = var0[var2];
|
||||
modelMap.put(modelType.getName(), registry.getEncodingForModel(modelType));
|
||||
}
|
||||
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_3_5_TURBO_0301.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_3_5_TURBO_0613.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_3_5_TURBO_16K.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_3_5_TURBO_1106.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_32K.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_32K_0314.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_0314.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_0613.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_32K_0613.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
modelMap.put(BaseChatCompletion.Model.GPT_4_VISION_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
||||
}
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
target/gpt-4-copilot-0.2.2.jar.original
Normal file
BIN
target/gpt-4-copilot-0.2.2.jar.original
Normal file
Binary file not shown.
@@ -1,3 +1,3 @@
|
||||
artifactId=gpt-4-copilot
|
||||
groupId=com.gpt4.copilot
|
||||
version=0.2.1
|
||||
version=0.2.2
|
||||
|
Reference in New Issue
Block a user