From 2785cdd97a9e913d335f01eda6a5a7f5e6fe1f06 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:27:48 +0530 Subject: [PATCH] feat(routes): add search route --- src/routes/index.ts | 2 + src/routes/search.ts | 150 ++++++++++++++++++++++++++++++++ src/websocket/messageHandler.ts | 2 +- 3 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 src/routes/search.ts diff --git a/src/routes/index.ts b/src/routes/index.ts index af928ab..6e82e54 100644 --- a/src/routes/index.ts +++ b/src/routes/index.ts @@ -5,6 +5,7 @@ import configRouter from './config'; import modelsRouter from './models'; import suggestionsRouter from './suggestions'; import chatsRouter from './chats'; +import searchRouter from './search'; const router = express.Router(); @@ -14,5 +15,6 @@ router.use('/config', configRouter); router.use('/models', modelsRouter); router.use('/suggestions', suggestionsRouter); router.use('/chats', chatsRouter); +router.use('/search', searchRouter); export default router; diff --git a/src/routes/search.ts b/src/routes/search.ts new file mode 100644 index 0000000..8eb2490 --- /dev/null +++ b/src/routes/search.ts @@ -0,0 +1,150 @@ +import express from 'express'; +import logger from '../utils/logger'; +import { BaseChatModel } from 'langchain/chat_models/base'; +import { Embeddings } from 'langchain/embeddings/base'; +import { ChatOpenAI } from '@langchain/openai'; +import { + getAvailableChatModelProviders, + getAvailableEmbeddingModelProviders, +} from '../lib/providers'; +import { searchHandlers } from '../websocket/messageHandler'; +import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; + +const router = express.Router(); + +interface chatModel { + provider: string; + model: string; + customOpenAIBaseURL?: string; + customOpenAIKey?: string; +} + +interface embeddingModel { + provider: string; + model: string; +} + +interface RequestBody { + focusMode: string; + chatModel?: chatModel; + embeddingModel?: embeddingModel; + query: string; + history: Array<[string, string]>; +} + +router.post('/', async (req, res) => { + try { + const body: RequestBody = req.body; + + if (!body.focusMode || !body.query) { + return res.status(400).json({ message: 'Missing focus mode or query' }); + } + + body.history = body.history || []; + + const history: BaseMessage[] = body.history.map((msg) => { + if (msg[0] === 'human') { + return new HumanMessage({ + content: msg[1], + }); + } else { + return new AIMessage({ + content: msg[1], + }); + } + }); + + const [chatModelProviders, embeddingModelProviders] = await Promise.all([ + getAvailableChatModelProviders(), + getAvailableEmbeddingModelProviders(), + ]); + + const chatModelProvider = + body.chatModel?.provider || Object.keys(chatModelProviders)[0]; + const chatModel = + body.chatModel?.model || + Object.keys(chatModelProviders[chatModelProvider])[0]; + + const embeddingModelProvider = + body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0]; + const embeddingModel = + body.embeddingModel?.model || + Object.keys(embeddingModelProviders[embeddingModelProvider])[0]; + + let llm: BaseChatModel | undefined; + let embeddings: Embeddings | undefined; + + if (body.chatModel?.provider === 'custom_openai') { + if ( + !body.chatModel?.customOpenAIBaseURL || + !body.chatModel?.customOpenAIKey + ) { + return res + .status(400) + .json({ message: 'Missing custom OpenAI base URL or key' }); + } + + llm = new ChatOpenAI({ + modelName: body.chatModel.model, + openAIApiKey: body.chatModel.customOpenAIKey, + temperature: 0.7, + configuration: { + baseURL: body.chatModel.customOpenAIBaseURL, + }, + }) as unknown as BaseChatModel; + } else if ( + chatModelProviders[chatModelProvider] && + chatModelProviders[chatModelProvider][chatModel] + ) { + llm = chatModelProviders[chatModelProvider][chatModel] + .model as unknown as BaseChatModel | undefined; + } + + if ( + embeddingModelProviders[embeddingModelProvider] && + embeddingModelProviders[embeddingModelProvider][embeddingModel] + ) { + embeddings = embeddingModelProviders[embeddingModelProvider][ + embeddingModel + ].model as Embeddings | undefined; + } + + if (!llm || !embeddings) { + return res.status(400).json({ message: 'Invalid model selected' }); + } + + const searchHandler = searchHandlers[body.focusMode]; + + if (!searchHandler) { + return res.status(400).json({ message: 'Invalid focus mode' }); + } + + const emitter = searchHandler(body.query, history, llm, embeddings); + + let message = ''; + let sources = []; + + emitter.on('data', (data) => { + const parsedData = JSON.parse(data); + if (parsedData.type === 'response') { + message += parsedData.data; + } else if (parsedData.type === 'sources') { + sources = parsedData.data; + } + }); + + emitter.on('end', () => { + res.status(200).json({ message, sources }); + }); + + emitter.on('error', (data) => { + const parsedData = JSON.parse(data); + res.status(500).json({ message: parsedData.data }); + }); + } catch (err: any) { + logger.error(`Error in getting search results: ${err.message}`); + res.status(500).json({ message: 'An error has occurred.' }); + } +}); + +export default router; diff --git a/src/websocket/messageHandler.ts b/src/websocket/messageHandler.ts index 0afda9f..332910c 100644 --- a/src/websocket/messageHandler.ts +++ b/src/websocket/messageHandler.ts @@ -28,7 +28,7 @@ type WSMessage = { history: Array<[string, string]>; }; -const searchHandlers = { +export const searchHandlers = { webSearch: handleWebSearch, academicSearch: handleAcademicSearch, writingAssistant: handleWritingAssistant,