diff --git a/Dockerfiles/Dockerfile.jar b/Dockerfiles/Dockerfile.jar index 25a8c1e..4398384 100644 --- a/Dockerfiles/Dockerfile.jar +++ b/Dockerfiles/Dockerfile.jar @@ -8,7 +8,7 @@ LABEL maintainer="Yanyutin753" USER root # 复制JAR文件到容器的/app目录下 -COPY /target/gpt-4-copilot-0.2.1.jar app.jar +COPY /target/gpt-4-copilot-0.2.2.jar app.jar # 声明服务运行在8081端口 EXPOSE 8081 diff --git a/pom.xml b/pom.xml index 506a63b..8a5b6d2 100644 --- a/pom.xml +++ b/pom.xml @@ -10,7 +10,7 @@ com.gpt4.copilot gpt-4-copilot - 0.2.1 + 0.2.2 native Demo project for Spring Boot with GraalVM Native Image @@ -54,6 +54,11 @@ chatgpt-java 1.1.5 + + com.knuddels + jtokkit + 0.6.1 + diff --git a/src/main/java/com/gpt4/copilot/controller/ChatController.java b/src/main/java/com/gpt4/copilot/controller/ChatController.java index d3c8563..df40048 100644 --- a/src/main/java/com/gpt4/copilot/controller/ChatController.java +++ b/src/main/java/com/gpt4/copilot/controller/ChatController.java @@ -1,6 +1,7 @@ package com.gpt4.copilot.controller; import com.alibaba.fastjson.serializer.SerializerFeature; +import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONException; import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.TypeReference; @@ -11,13 +12,13 @@ import com.gpt4.copilot.pojo.Conversation; import com.gpt4.copilot.pojo.Result; import com.gpt4.copilot.pojo.SystemSetting; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.utils.TikTokensUtil; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import lombok.Data; import lombok.extern.slf4j.Slf4j; import okhttp3.*; import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.scheduling.annotation.Scheduled; @@ -52,6 +53,7 @@ import java.util.concurrent.atomic.AtomicInteger; @Data @RestController() public class ChatController { + public static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); /** * 模型 @@ -501,29 +503,33 @@ public class ChatController { String model = modelAdjust(conversation); Request streamRequest = getPrompt(conversation, model, headersMap); try (Response resp = client.newCall(streamRequest).execute()) { + log.info(resp.toString()); if (!resp.isSuccessful()) { if (resp.code() == 429) { return new ResponseEntity<>(Result.error("rate limit exceeded"), HttpStatus.TOO_MANY_REQUESTS); } else if (resp.code() == 400) { return new ResponseEntity<>(Result.error("messages is none or too long and over 32K"), HttpStatus.INTERNAL_SERVER_ERROR); } else { - String token = getCopilotToken(apiKey); - if (token == null) { - return new ResponseEntity<>(Result.error("copilot APIKey is wrong"), HttpStatus.UNAUTHORIZED); + if (resp.code() == 403) { + String token = getCopilotToken(apiKey); + if (token == null) { + return new ResponseEntity<>(Result.error("copilot APIKey is wrong"), HttpStatus.UNAUTHORIZED); + } + copilotTokenList.put(apiKey, token); + log.info("token过期,Github CopilotToken重置化成功!"); + return againConversation(response, conversation, token, apiKey, model); + } else { + return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR); } - copilotTokenList.put(apiKey, token); - log.info("token过期,Github CopilotToken重置化成功!"); - againConversation(response, conversation, token, apiKey, model); } } else { // 流式和非流式输出 - outPutChat(response, resp, conversation, model); + return outPutChat(response, resp, conversation, model); } } } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } - return null; }, executor); return getObjectResponseEntity(response, future); @@ -542,7 +548,7 @@ public class ChatController { private String getEmbRequestApikey(String authorizationHeader, @org.springframework.web.bind.annotation.RequestBody Object conversation) { if (conversation == null) { - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Request body is missing or not in JSON format"); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Request body is missing or not in JSON format"); } String apiKey; if (authorizationHeader != null && authorizationHeader.startsWith("Bearer ")) { @@ -624,17 +630,16 @@ public class ChatController { } coCopilotTokenList.put(apiKey, token); log.info("token过期,coCopilotToken重置化成功!"); - againConversation(response, conversation, token, apiKey, model); + return againConversation(response, conversation, token, apiKey, model); } } else { // 流式和非流式输出 - outPutChat(response, resp, conversation, model); + return outPutChat(response, resp, conversation, model); } } } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } - return null; }, executor); return getObjectResponseEntity(response, future); @@ -677,7 +682,7 @@ public class ChatController { private void checkConversation(Conversation conversation) { if (conversation == null) { - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Request body is missing or not in JSON format"); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Request body is missing or not in JSON format"); } long tokens = conversation.tokens(); if (tokens > 32 * 1024) { @@ -722,7 +727,7 @@ public class ChatController { */ private String[] extractEmbApiKeyAndRequestUrl(String authorizationHeader, Object conversation) throws IllegalArgumentException { if (conversation == null) { - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Request body is missing or not in JSON format"); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Request body is missing or not in JSON format"); } return getApiKeyAndRequestUrl(authorizationHeader); } @@ -788,17 +793,16 @@ public class ChatController { } selfTokenList.put(apiKey, token); log.info("token过期,自定义selfToken重置化成功!"); - againConversation(response, conversation, token, apiKey, model); + return againConversation(response, conversation, token, apiKey, model); } } else { // 流式和非流式输出 - outPutChat(response, resp, conversation, model); + return outPutChat(response, resp, conversation, model); } } } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } - return null; }, executor); return getObjectResponseEntity(response, future); @@ -813,11 +817,11 @@ public class ChatController { * @param token * @return */ - public Object againConversation(HttpServletResponse response, - @org.springframework.web.bind.annotation.RequestBody Conversation conversation, - String token, - String apiKey, - String model) { + public ResponseEntity againConversation(HttpServletResponse response, + @org.springframework.web.bind.annotation.RequestBody Conversation conversation, + String token, + String apiKey, + String model) { try { Map headersMap = new HashMap<>(); //添加头部 @@ -825,15 +829,16 @@ public class ChatController { Request streamRequest = getPrompt(conversation, model, headersMap); try (Response resp = client.newCall(streamRequest).execute()) { if (!resp.isSuccessful()) { - return new ResponseEntity<>("copilot/cocopilot/self APIKey is wrong Or your network is wrong", HttpStatus.UNAUTHORIZED); + return new ResponseEntity<>("APIKey is wrong!", HttpStatus.UNAUTHORIZED); } else { // 流式和非流式输出 - outPutChat(response, resp, conversation, model); + return outPutChat(response, resp, conversation, model); } } - return null; } catch (Exception e) { - throw new RuntimeException(e); + log.info("Exception " + e.getMessage()); + return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR); + } } @@ -937,7 +942,7 @@ public class ChatController { } return null; } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } }, executor); @@ -1027,7 +1032,7 @@ public class ChatController { } return null; } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } }, executor); @@ -1103,7 +1108,7 @@ public class ChatController { } return null; } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } }, executor); return getObjectResponseEntity(future); @@ -1132,7 +1137,7 @@ public class ChatController { } return null; } catch (Exception e) { - return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + return new ResponseEntity<>(e.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR); } } @@ -1154,7 +1159,6 @@ public class ChatController { return getToken(request); } - @Nullable private String getToken(Request request) throws IOException { try (Response response = client.newCall(request).execute()) { log.info(response.toString()); @@ -1263,40 +1267,116 @@ public class ChatController { * @param resp * @param conversation */ - private void outPutChat(HttpServletResponse response, Response resp, Conversation conversation, String model) { - try { - boolean isStream = conversation.isStream(); - int sleep_time = calculateSleepTime(model, isStream); - if (isStream) { - response.setContentType("text/event-stream; charset=UTF-8"); - } else { - response.setContentType("application/json; charset=utf-8"); - } - - try (PrintWriter out = new PrintWriter(new OutputStreamWriter(response.getOutputStream(), StandardCharsets.UTF_8), true); - BufferedReader in = new BufferedReader(new InputStreamReader(resp.body().byteStream(), StandardCharsets.UTF_8))) { - - String line; - while ((line = in.readLine()) != null) { - out.println(line); - out.flush(); - if (sleep_time > 0 && line.startsWith("data")) { - Thread.sleep(sleep_time); - } - } - log.info("使用模型:" + model + ",vscode_version:" + systemSetting.getVscode_version() + - ",copilot_chat_version:" + systemSetting.getCopilot_chat_version() - + ",字符间隔时间:" + sleep_time + "ms,响应:" + resp); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Thread was interrupted", e); - } - } catch (IOException e) { - throw new RuntimeException("IO Exception occurred", e); + private ResponseEntity outPutChat(HttpServletResponse response, Response resp, Conversation conversation, String model) { + boolean isStream = conversation.isStream(); + int sleep_time = calculateSleepTime(model, isStream); + if (isStream) { + response.setContentType("text/event-stream; charset=UTF-8"); + return outIsStreamPutChat(response, resp, model, sleep_time); + } else { + response.setContentType("application/json; charset=utf-8"); + return outNoStreamPutChat(response, resp, model); } } /** + * chat非流式接口的输出 + * + * @param response + * @param resp + * @param model + * @return + */ + private ResponseEntity outNoStreamPutChat(HttpServletResponse response, Response resp, String model) { + try (PrintWriter out = new PrintWriter(new OutputStreamWriter(response.getOutputStream(), StandardCharsets.UTF_8), true); + BufferedReader in = new BufferedReader(new InputStreamReader(resp.body().byteStream(), StandardCharsets.UTF_8))) { + + String line; + long tokens = 0; + while ((line = in.readLine()) != null) { + out.println(line); + out.flush(); + if (line.trim().startsWith("{")) { + try { + JSONObject resJson = com.alibaba.fastjson2.JSON.parseObject(line); + JSONObject usageJson = resJson.getJSONObject("usage"); + if (usageJson != null) { + Integer totalTokens = usageJson.getInteger("total_tokens"); + if (totalTokens != null) { + tokens = totalTokens; + } else { + log.warn("total_tokens field missing in the response"); + } + } else { + log.warn("usage field missing in the response"); + } + } catch (JSONException je) { + log.error("Error parsing JSON", je); + } + } + } + log.info("使用模型:" + model + ",补全tokens:" + tokens + ",vscode_version:" + systemSetting.getVscode_version() + + ",copilot_chat_version:" + systemSetting.getCopilot_chat_version() + ",响应:" + resp); + if (tokens <= 0) { + return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR); + } + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * chat流式接口的输出 + * + * @param response + * @param resp + * @param model + * @param sleep_time + * @return + */ + private ResponseEntity outIsStreamPutChat(HttpServletResponse response, Response resp, String model, int sleep_time) { + try (PrintWriter out = new PrintWriter(new OutputStreamWriter(response.getOutputStream(), StandardCharsets.UTF_8), true); + BufferedReader in = new BufferedReader(new InputStreamReader(resp.body().byteStream(), StandardCharsets.UTF_8))) { + + String line; + long tokens = 0; + while ((line = in.readLine()) != null) { + out.println(line); + out.flush(); + if (sleep_time > 0 && line.startsWith("data:")) { + try { + String res_text = line.replace("data: ", ""); + if (!"[DONE]".equals(res_text.trim())) { + JSONObject resJson = com.alibaba.fastjson2.JSON.parseObject(res_text); + JSONArray choicesArray = resJson.getJSONArray("choices"); + if (choicesArray.size() > 0) { + JSONObject firstChoice = choicesArray.getJSONObject(0); + String content = firstChoice.getJSONObject("delta").getString("content"); + tokens += TikTokensUtil.tokens("gpt-4-0613", content); + } + Thread.sleep(sleep_time); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + log.info("使用模型:" + model + ",补全tokens:" + tokens + ",vscode_version:" + systemSetting.getVscode_version() + + ",copilot_chat_version:" + systemSetting.getCopilot_chat_version() + + ",字符间隔时间:" + sleep_time + "ms,响应:" + resp); + if (tokens <= 0) { + return new ResponseEntity<>(Result.error("HUm... A error occur......"), HttpStatus.INTERNAL_SERVER_ERROR); + } + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + + /** * chat接口的每个字的睡眠时间 */ private int calculateSleepTime(String model, boolean isStream) { diff --git a/src/main/java/com/gpt4/copilot/controller/CustomErrorController.java b/src/main/java/com/gpt4/copilot/controller/CustomErrorController.java index 3db1d88..61660da 100644 --- a/src/main/java/com/gpt4/copilot/controller/CustomErrorController.java +++ b/src/main/java/com/gpt4/copilot/controller/CustomErrorController.java @@ -25,9 +25,9 @@ public class CustomErrorController implements ErrorController { " Document\n" + "\n" + "\n" + - " Thanks you use gpt4-copilot-java-0.2.1\n" + + " Thanks you use gpt4-copilot-java-0.2.2\n" + " 详细使用文档\n" + - " 项目地址\n" + + " 项目地址\n" + "\n" + "
Thanks you use gpt4-copilot-java-0.2.1
Thanks you use gpt4-copilot-java-0.2.2
详细使用文档
项目地址