diff --git a/src/agents/wolframAlphaSearchAgent.ts b/src/agents/wolframAlphaSearchAgent.ts index a9a3202..c071ef0 100644 --- a/src/agents/wolframAlphaSearchAgent.ts +++ b/src/agents/wolframAlphaSearchAgent.ts @@ -9,7 +9,7 @@ import { RunnableMap, RunnableLambda, } from '@langchain/core/runnables'; -import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai'; +import { ChatOpenAI, OpenAI } from '@langchain/openai'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { Document } from '@langchain/core/documents'; import { searchSearxng } from '../core/searxng'; diff --git a/src/websocket/messageHandler.ts b/src/websocket/messageHandler.ts index 08bb4c4..83fa50d 100644 --- a/src/websocket/messageHandler.ts +++ b/src/websocket/messageHandler.ts @@ -1,4 +1,4 @@ -import { WebSocket } from 'ws'; +import { EventEmitter, WebSocket } from 'ws'; import { BaseMessage, AIMessage, HumanMessage } from '@langchain/core/messages'; import handleWebSearch from '../agents/webSearchAgent'; import handleAcademicSearch from '../agents/academicSearchAgent'; @@ -15,6 +15,49 @@ type Message = { history: Array<[string, string]>; }; +const searchHandlers = { + webSearch: handleWebSearch, + academicSearch: handleAcademicSearch, + writingAssistant: handleWritingAssistant, + wolframAlphaSearch: handleWolframAlphaSearch, + youtubeSearch: handleYoutubeSearch, + redditSearch: handleRedditSearch, +}; + +const handleEmitterEvents = ( + emitter: EventEmitter, + ws: WebSocket, + id: string, +) => { + emitter.on('data', (data) => { + const parsedData = JSON.parse(data); + if (parsedData.type === 'response') { + ws.send( + JSON.stringify({ + type: 'message', + data: parsedData.data, + messageId: id, + }), + ); + } else if (parsedData.type === 'sources') { + ws.send( + JSON.stringify({ + type: 'sources', + data: parsedData.data, + messageId: id, + }), + ); + } + }); + emitter.on('end', () => { + ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); + }); + emitter.on('error', (data) => { + const parsedData = JSON.parse(data); + ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); + }); +}; + export const handleMessage = async (message: string, ws: WebSocket) => { try { const parsedMessage = JSON.parse(message) as Message; @@ -38,191 +81,12 @@ export const handleMessage = async (message: string, ws: WebSocket) => { }); if (parsedMessage.type === 'message') { - switch (parsedMessage.focusMode) { - case 'webSearch': { - const emitter = handleWebSearch(parsedMessage.content, history); - emitter.on('data', (data) => { - const parsedData = JSON.parse(data); - if (parsedData.type === 'response') { - ws.send( - JSON.stringify({ - type: 'message', - data: parsedData.data, - messageId: id, - }), - ); - } else if (parsedData.type === 'sources') { - ws.send( - JSON.stringify({ - type: 'sources', - data: parsedData.data, - messageId: id, - }), - ); - } - }); - emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); - }); - emitter.on('error', (data) => { - const parsedData = JSON.parse(data); - ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); - }); - break; - } - case 'academicSearch': { - const emitter = handleAcademicSearch(parsedMessage.content, history); - emitter.on('data', (data) => { - const parsedData = JSON.parse(data); - if (parsedData.type === 'response') { - ws.send( - JSON.stringify({ - type: 'message', - data: parsedData.data, - messageId: id, - }), - ); - } else if (parsedData.type === 'sources') { - ws.send( - JSON.stringify({ - type: 'sources', - data: parsedData.data, - messageId: id, - }), - ); - } - }); - emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); - }); - emitter.on('error', (data) => { - const parsedData = JSON.parse(data); - ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); - }); - break; - } - case 'writingAssistant': { - const emitter = handleWritingAssistant( - parsedMessage.content, - history, - ); - emitter.on('data', (data) => { - const parsedData = JSON.parse(data); - if (parsedData.type === 'response') { - ws.send( - JSON.stringify({ - type: 'message', - data: parsedData.data, - messageId: id, - }), - ); - } - }); - emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); - }); - emitter.on('error', (data) => { - const parsedData = JSON.parse(data); - ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); - }); - break; - } - case 'wolframAlphaSearch': { - const emitter = handleWolframAlphaSearch( - parsedMessage.content, - history, - ); - emitter.on('data', (data) => { - const parsedData = JSON.parse(data); - if (parsedData.type === 'response') { - ws.send( - JSON.stringify({ - type: 'message', - data: parsedData.data, - messageId: id, - }), - ); - } else if (parsedData.type === 'sources') { - ws.send( - JSON.stringify({ - type: 'sources', - data: parsedData.data, - messageId: id, - }), - ); - } - }); - emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); - }); - emitter.on('error', (data) => { - const parsedData = JSON.parse(data); - ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); - }); - break; - } - case 'youtubeSearch': { - const emitter = handleYoutubeSearch(parsedMessage.content, history); - emitter.on('data', (data) => { - const parsedData = JSON.parse(data); - if (parsedData.type === 'response') { - ws.send( - JSON.stringify({ - type: 'message', - data: parsedData.data, - messageId: id, - }), - ); - } else if (parsedData.type === 'sources') { - ws.send( - JSON.stringify({ - type: 'sources', - data: parsedData.data, - messageId: id, - }), - ); - } - }); - emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); - }); - emitter.on('error', (data) => { - const parsedData = JSON.parse(data); - ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); - }); - break; - } - case 'redditSearch': { - const emitter = handleRedditSearch(parsedMessage.content, history); - emitter.on('data', (data) => { - const parsedData = JSON.parse(data); - if (parsedData.type === 'response') { - ws.send( - JSON.stringify({ - type: 'message', - data: parsedData.data, - messageId: id, - }), - ); - } else if (parsedData.type === 'sources') { - ws.send( - JSON.stringify({ - type: 'sources', - data: parsedData.data, - messageId: id, - }), - ); - } - }); - emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); - }); - emitter.on('error', (data) => { - const parsedData = JSON.parse(data); - ws.send(JSON.stringify({ type: 'error', data: parsedData.data })); - }); - break; - } + const handler = searchHandlers[parsedMessage.focusMode]; + if (handler) { + const emitter = handler(parsedMessage.content, history); + handleEmitterEvents(emitter, ws, id); + } else { + ws.send(JSON.stringify({ type: 'error', data: 'Invalid focus mode' })); } } } catch (error) {