build: update langchain

Author: Hk-Gosuto
Date: 2024-01-24 20:51:46 +08:00
Parent: 30f9dc756a
Commit: 16e82afaad
16 changed files with 161 additions and 91 deletions

View File

@@ -1,4 +1,4 @@
import { StructuredTool } from "langchain/tools";
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
export class ArxivAPIWrapper extends StructuredTool {
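Note: in LangChain JS 0.1+, the base Tool / StructuredTool abstractions live in @langchain/core, so the tool files in this commit only need their import paths updated. A minimal sketch of the new-style import (the EchoTool class below is illustrative, not part of this commit):

```ts
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";

// Hypothetical tool, shown only to illustrate the new import path and the
// zod-typed schema that StructuredTool expects.
const echoSchema = z.object({ text: z.string() });

class EchoTool extends StructuredTool {
  name = "echo";
  description = "Returns the input text unchanged.";
  schema = echoSchema;

  protected async _call(input: z.infer<typeof echoSchema>): Promise<string> {
    return input.text;
  }
}
```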

View File

@@ -1,6 +1,6 @@
import { decode } from "html-entities";
import { convert as htmlToText } from "html-to-text";
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
import * as cheerio from "cheerio";
import { getRandomUserAgent } from "./ua_tools";

View File

@@ -1,4 +1,4 @@
import { StructuredTool } from "langchain/tools";
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import S3FileStorage from "../../utils/s3_file_storage";

View File

@@ -1,6 +1,6 @@
import { SafeSearchType, search } from "duck-duck-scrape";
import { convert as htmlToText } from "html-to-text";
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
export class DuckDuckGo extends Tool {
name = "duckduckgo_search";

View File

@@ -1,6 +1,6 @@
import { decode } from "html-entities";
import { convert as htmlToText } from "html-to-text";
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
const SEARCH_REGEX =
/DDG\.pageLayout\.load\('d',(\[.+\])\);DDG\.duckbar\.load\('images'/;

View File

@@ -1,6 +1,6 @@
import { decode } from "html-entities";
import { convert as htmlToText } from "html-to-text";
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
import * as cheerio from "cheerio";
import { getRandomUserAgent } from "./ua_tools";

View File

@@ -1,5 +1,5 @@
import { htmlToText } from "html-to-text";
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
export interface Headers {
[key: string]: string;

View File

@@ -0,0 +1,25 @@
export {
SerpAPI,
type SerpAPIParameters,
} from "@langchain/community/tools/serpapi";
export { DadJokeAPI } from "@langchain/community/tools/dadjokeapi";
export { BingSerpAPI } from "@langchain/community/tools/bingserpapi";
export {
Serper,
type SerperParameters,
} from "@langchain/community/tools/serper";
export {
GoogleCustomSearch,
type GoogleCustomSearchParams,
} from "@langchain/community/tools/google_custom_search";
export { AIPluginTool } from "@langchain/community/tools/aiplugin";
export {
WikipediaQueryRun,
type WikipediaQueryRunParams,
} from "@langchain/community/tools/wikipedia_query_run";
export { WolframAlphaTool } from "@langchain/community/tools/wolframalpha";
export { SearxngSearch } from "@langchain/community/tools/searxng_search";
export {
SearchApi,
type SearchApiParameters,
} from "@langchain/community/tools/searchapi";
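This new file appears to act as a barrel module: it re-exports the third-party tools that moved into @langchain/community, so the agent route can keep resolving tools by name from a single namespace import. A hedged usage sketch (the tool name and constructor argument are illustrative):

```ts
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";

// Resolve a tool class by name, mirroring the dynamic lookup in the agent route.
const ToolClass = (langchainTools as Record<string, any>)["WolframAlphaTool"];
const instance = ToolClass ? new ToolClass({ appid: "YOUR_APP_ID" }) : undefined;
```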

View File

@@ -1,12 +1,13 @@
import axiosMod, { AxiosStatic } from "axios";
import { WebPDFLoader } from "langchain/document_loaders/web/pdf";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
import {
RecursiveCharacterTextSplitter,
TextSplitter,
} from "langchain/text_splitter";
import { CallbackManagerForToolRun } from "langchain/callbacks";
import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager";
import { BaseLanguageModel } from "langchain/dist/base_language";
import { formatDocumentsAsString } from "langchain/util/document";
import { Embeddings } from "langchain/dist/embeddings/base.js";

View File

@@ -1,4 +1,4 @@
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
import S3FileStorage from "../../utils/s3_file_storage";
export class StableDiffusionWrapper extends Tool {

View File

@@ -1,4 +1,4 @@
import { Tool } from "langchain/tools";
import { Tool } from "@langchain/core/tools";
export class WolframAlphaTool extends Tool {
name = "wolfram_alpha_llm";

View File

@@ -1,33 +1,39 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { auth } from "../../../auth";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { BaseCallbackHandler } from "langchain/callbacks";
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import {
AgentExecutor,
initializeAgentExecutorWithOptions,
} from "langchain/agents";
import { AgentExecutor } from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
// import * as langchainTools from "langchain/tools";
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { DynamicTool, Tool } from "langchain/tools";
import {
DynamicTool,
Tool,
StructuredToolInterface,
} from "@langchain/core/tools";
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { useAccessStore } from "@/app/store";
import { DynamicStructuredTool, formatToOpenAITool } from "langchain/tools";
import { formatToOpenAIToolMessages } from "langchain/agents/format_scratchpad/openai_tools";
import {
OpenAIToolsAgentOutputParser,
type ToolsAgentStep,
} from "langchain/agents/openai/output_parser";
import { RunnableSequence } from "langchain/schema/runnable";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
import { RunnableSequence } from "@langchain/core/runnables";
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import {
SystemMessage,
HumanMessage,
AIMessage,
} from "@langchain/core/messages";
export interface RequestMessage {
role: string;
@@ -66,23 +72,27 @@ export class AgentApi {
private encoder: TextEncoder;
private transformStream: TransformStream;
private writer: WritableStreamDefaultWriter<any>;
private controller: AbortController;
constructor(
encoder: TextEncoder,
transformStream: TransformStream,
writer: WritableStreamDefaultWriter<any>,
controller: AbortController,
) {
this.encoder = encoder;
this.transformStream = transformStream;
this.writer = writer;
this.controller = controller;
}
async getHandler(reqBody: any) {
var writer = this.writer;
var encoder = this.encoder;
var controller = this.controller;
return BaseCallbackHandler.fromMethods({
async handleLLMNewToken(token: string) {
if (token) {
if (token && !controller.signal.aborted) {
var response = new ResponseBody();
response.message = token;
await writer.ready;
@@ -92,6 +102,11 @@ export class AgentApi {
}
},
async handleChainError(err, runId, parentRunId, tags) {
if (controller.signal.aborted) {
console.warn("[handleChainError]", "abort");
await writer.close();
return;
}
console.log("[handleChainError]", err, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
@@ -112,6 +127,11 @@ export class AgentApi {
// await writer.close();
},
async handleLLMError(e: Error) {
if (controller.signal.aborted) {
console.warn("[handleLLMError]", "abort");
await writer.close();
return;
}
console.log("[handleLLMError]", e, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
@@ -159,6 +179,9 @@ export class AgentApi {
// console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
async handleAgentEnd(action, runId, parentRunId, tags) {
if (controller.signal.aborted) {
return;
}
console.log("[handleAgentEnd]");
await writer.ready;
await writer.close();
@@ -301,14 +324,6 @@ export class AgentApi {
pastMessages.push(new AIMessage(message.content));
});
// const memory = new BufferMemory({
// memoryKey: "chat_history",
// returnMessages: true,
// inputKey: "input",
// outputKey: "output",
// chatHistory: new ChatMessageHistory(pastMessages),
// });
let llm = new ChatOpenAI(
{
modelName: reqBody.model,
@@ -349,7 +364,9 @@ export class AgentApi {
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);
const modelWithTools = llm.bind({ tools: tools.map(formatToOpenAITool) });
const modelWithTools = llm.bind({
tools: tools.map(convertToOpenAITool),
});
const runnableAgent = RunnableSequence.from([
{
input: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
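Here formatToOpenAITool (previously from langchain/tools) is replaced by convertToOpenAITool from @langchain/core/utils/function_calling, which turns a structured tool's zod schema into the JSON schema the OpenAI tools API expects before binding it to the chat model. A small sketch under that assumption (the add tool is illustrative, not from this repo):

```ts
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { z } from "zod";

// Illustrative tool; not part of this commit.
const addTool = new DynamicStructuredTool({
  name: "add",
  description: "Add two numbers.",
  schema: z.object({ a: z.number(), b: z.number() }),
  func: async ({ a, b }) => String(a + b),
});

// Yields { type: "function", function: { name, description, parameters } },
// the shape the OpenAI tools API expects.
const openAiTool = convertToOpenAITool(addTool);
```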
@@ -373,19 +390,21 @@ export class AgentApi {
tools,
});
// const executor = await initializeAgentExecutorWithOptions(tools, llm, {
// agentType: "openai-functions",
// returnIntermediateSteps: reqBody.returnIntermediateSteps,
// maxIterations: reqBody.maxIterations,
// memory: memory,
// });
executor.call(
{
input: reqBody.messages.slice(-1)[0].content,
},
[handler],
);
executor
.call(
{
input: reqBody.messages.slice(-1)[0].content,
signal: this.controller.signal,
},
[handler],
)
.catch((error) => {
if (this.controller.signal.aborted) {
console.warn("[AgentCall]", "abort");
} else {
console.error("[AgentCall]", error);
}
});
return new Response(this.transformStream.readable, {
headers: { "Content-Type": "text/event-stream" },
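The AbortController threaded through this class lets the callback handlers stop writing once the run is cancelled, and executor.call now receives the same signal. A minimal standalone sketch of the abort pattern (model name and prompt are illustrative, not taken from this commit):

```ts
import { ChatOpenAI } from "@langchain/openai";

const controller = new AbortController();
const llm = new ChatOpenAI({ modelName: "gpt-3.5-turbo" });

// Start a call that carries the abort signal.
const pending = llm.invoke("Write a very long story.", {
  signal: controller.signal,
});

// Abort later, e.g. once another tool has already produced the final answer.
controller.abort();

pending.catch((error) => {
  if (controller.signal.aborted) {
    console.warn("[invoke]", "abort");
  } else {
    console.error("[invoke]", error);
  }
});
```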

View File

@@ -2,9 +2,8 @@ import { NextRequest, NextResponse } from "next/server";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { ModelProvider } from "@/app/constant";
import { OpenAI, OpenAIEmbeddings } from "@langchain/openai";
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
@@ -21,7 +20,8 @@ async function handle(req: NextRequest) {
const encoder = new TextEncoder();
const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter();
const agentApi = new AgentApi(encoder, transformStream, writer);
const controller = new AbortController();
const agentApi = new AgentApi(encoder, transformStream, writer, controller);
const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
@@ -52,6 +52,9 @@ async function handle(req: NextRequest) {
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
controller.abort({
reason: "dall-e tool abort",
});
};
var edgeTool = new EdgeTool(
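For reference, a condensed sketch of how the pieces added in this route fit together; the onImageWritten callback name is illustrative (in the diff the abort is issued inside the DALL-E tool's completion callback):

```ts
import { AgentApi } from "../agentapi";

// Stream plumbing plus the new AbortController passed into AgentApi.
const encoder = new TextEncoder();
const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter();
const controller = new AbortController();

const agentApi = new AgentApi(encoder, transformStream, writer, controller);

// Called once the image result has been written to the stream, so the agent
// run is cancelled instead of streaming further tokens.
const onImageWritten = async () => {
  controller.abort({ reason: "dall-e tool abort" });
};
```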

View File

@@ -1,11 +1,9 @@
import { NextRequest, NextResponse } from "next/server";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
import { ModelProvider } from "@/app/constant";
import { OpenAI, OpenAIEmbeddings } from "@langchain/openai";
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
@@ -22,7 +20,8 @@ async function handle(req: NextRequest) {
const encoder = new TextEncoder();
const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter();
const agentApi = new AgentApi(encoder, transformStream, writer);
const controller = new AbortController();
const agentApi = new AgentApi(encoder, transformStream, writer, controller);
const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
@@ -53,6 +52,9 @@ async function handle(req: NextRequest) {
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
controller.abort({
reason: "dall-e tool abort",
});
};
var nodejsTool = new NodeJSTool(