build: update langchain

This commit is contained in:
Hk-Gosuto
2024-01-24 20:51:46 +08:00
parent 30f9dc756a
commit 16e82afaad
16 changed files with 161 additions and 91 deletions

View File

@@ -1,33 +1,39 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { auth } from "../../../auth";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { BaseCallbackHandler } from "langchain/callbacks";
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import {
AgentExecutor,
initializeAgentExecutorWithOptions,
} from "langchain/agents";
import { AgentExecutor } from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
// import * as langchainTools from "langchain/tools";
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { DynamicTool, Tool } from "langchain/tools";
import {
DynamicTool,
Tool,
StructuredToolInterface,
} from "@langchain/core/tools";
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { useAccessStore } from "@/app/store";
import { DynamicStructuredTool, formatToOpenAITool } from "langchain/tools";
import { formatToOpenAIToolMessages } from "langchain/agents/format_scratchpad/openai_tools";
import {
OpenAIToolsAgentOutputParser,
type ToolsAgentStep,
} from "langchain/agents/openai/output_parser";
import { RunnableSequence } from "langchain/schema/runnable";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
import { RunnableSequence } from "@langchain/core/runnables";
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import {
SystemMessage,
HumanMessage,
AIMessage,
} from "@langchain/core/messages";
export interface RequestMessage {
role: string;
@@ -66,23 +72,27 @@ export class AgentApi {
private encoder: TextEncoder;
private transformStream: TransformStream;
private writer: WritableStreamDefaultWriter<any>;
private controller: AbortController;
constructor(
encoder: TextEncoder,
transformStream: TransformStream,
writer: WritableStreamDefaultWriter<any>,
controller: AbortController,
) {
this.encoder = encoder;
this.transformStream = transformStream;
this.writer = writer;
this.controller = controller;
}
async getHandler(reqBody: any) {
var writer = this.writer;
var encoder = this.encoder;
var controller = this.controller;
return BaseCallbackHandler.fromMethods({
async handleLLMNewToken(token: string) {
if (token) {
if (token && !controller.signal.aborted) {
var response = new ResponseBody();
response.message = token;
await writer.ready;
@@ -92,6 +102,11 @@ export class AgentApi {
}
},
async handleChainError(err, runId, parentRunId, tags) {
if (controller.signal.aborted) {
console.warn("[handleChainError]", "abort");
await writer.close();
return;
}
console.log("[handleChainError]", err, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
@@ -112,6 +127,11 @@ export class AgentApi {
// await writer.close();
},
async handleLLMError(e: Error) {
if (controller.signal.aborted) {
console.warn("[handleLLMError]", "abort");
await writer.close();
return;
}
console.log("[handleLLMError]", e, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
@@ -159,6 +179,9 @@ export class AgentApi {
// console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
async handleAgentEnd(action, runId, parentRunId, tags) {
if (controller.signal.aborted) {
return;
}
console.log("[handleAgentEnd]");
await writer.ready;
await writer.close();
@@ -301,14 +324,6 @@ export class AgentApi {
pastMessages.push(new AIMessage(message.content));
});
// const memory = new BufferMemory({
// memoryKey: "chat_history",
// returnMessages: true,
// inputKey: "input",
// outputKey: "output",
// chatHistory: new ChatMessageHistory(pastMessages),
// });
let llm = new ChatOpenAI(
{
modelName: reqBody.model,
@@ -349,7 +364,9 @@ export class AgentApi {
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);
const modelWithTools = llm.bind({ tools: tools.map(formatToOpenAITool) });
const modelWithTools = llm.bind({
tools: tools.map(convertToOpenAITool),
});
const runnableAgent = RunnableSequence.from([
{
input: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
@@ -373,19 +390,21 @@ export class AgentApi {
tools,
});
// const executor = await initializeAgentExecutorWithOptions(tools, llm, {
// agentType: "openai-functions",
// returnIntermediateSteps: reqBody.returnIntermediateSteps,
// maxIterations: reqBody.maxIterations,
// memory: memory,
// });
executor.call(
{
input: reqBody.messages.slice(-1)[0].content,
},
[handler],
);
executor
.call(
{
input: reqBody.messages.slice(-1)[0].content,
signal: this.controller.signal,
},
[handler],
)
.catch((error) => {
if (this.controller.signal.aborted) {
console.warn("[AgentCall]", "abort");
} else {
console.error("[AgentCall]", error);
}
});
return new Response(this.transformStream.readable, {
headers: { "Content-Type": "text/event-stream" },