From 65d057a05ec24ae717e97347d9dee4544c2273e1 Mon Sep 17 00:00:00 2001
From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com>
Date: Wed, 30 Oct 2024 10:29:06 +0530
Subject: [PATCH] feat(suggestions): handle custom OpenAI

---
 src/routes/suggestions.ts | 61 ++++++++++++++++++++++++++++++++-------
 1 file changed, 51 insertions(+), 10 deletions(-)

diff --git a/src/routes/suggestions.ts b/src/routes/suggestions.ts
index a75657e..e997b1e 100644
--- a/src/routes/suggestions.ts
+++ b/src/routes/suggestions.ts
@@ -4,14 +4,27 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { getAvailableChatModelProviders } from '../lib/providers';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
+import { ChatOpenAI } from '@langchain/openai';
 
 const router = express.Router();
 
+interface ChatModel {
+  provider: string;
+  model: string;
+  customOpenAIBaseURL?: string;
+  customOpenAIKey?: string;
+}
+
+interface SuggestionsBody {
+  chatHistory: any[];
+  chatModel?: ChatModel;
+}
+
 router.post('/', async (req, res) => {
   try {
-    let { chat_history, chat_model, chat_model_provider } = req.body;
+    let body: SuggestionsBody = req.body;
 
-    chat_history = chat_history.map((msg: any) => {
+    const chatHistory = body.chatHistory.map((msg: any) => {
       if (msg.role === 'user') {
         return new HumanMessage(msg.content);
       } else if (msg.role === 'assistant') {
@@ -19,22 +32,50 @@ router.post('/', async (req, res) => {
       }
     });
 
-    const chatModels = await getAvailableChatModelProviders();
-    const provider = chat_model_provider ?? Object.keys(chatModels)[0];
-    const chatModel = chat_model ?? Object.keys(chatModels[provider])[0];
+    const chatModelProviders = await getAvailableChatModelProviders();
+
+    const chatModelProvider =
+      body.chatModel?.provider || Object.keys(chatModelProviders)[0];
+    const chatModel =
+      body.chatModel?.model ||
+      Object.keys(chatModelProviders[chatModelProvider])[0];
 
     let llm: BaseChatModel | undefined;
 
-    if (chatModels[provider] && chatModels[provider][chatModel]) {
-      llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
+    if (body.chatModel?.provider === 'custom_openai') {
+      if (
+        !body.chatModel?.customOpenAIBaseURL ||
+        !body.chatModel?.customOpenAIKey
+      ) {
+        return res
+          .status(400)
+          .json({ message: 'Missing custom OpenAI base URL or key' });
+      }
+
+      llm = new ChatOpenAI({
+        modelName: body.chatModel.model,
+        openAIApiKey: body.chatModel.customOpenAIKey,
+        temperature: 0.7,
+        configuration: {
+          baseURL: body.chatModel.customOpenAIBaseURL,
+        },
+      }) as unknown as BaseChatModel;
+    } else if (
+      chatModelProviders[chatModelProvider] &&
+      chatModelProviders[chatModelProvider][chatModel]
+    ) {
+      llm = chatModelProviders[chatModelProvider][chatModel]
+        .model as unknown as BaseChatModel | undefined;
     }
 
     if (!llm) {
-      res.status(500).json({ message: 'Invalid LLM model selected' });
-      return;
+      return res.status(400).json({ message: 'Invalid model selected' });
     }
 
-    const suggestions = await generateSuggestions({ chat_history }, llm);
+    const suggestions = await generateSuggestions(
+      { chat_history: chatHistory },
+      llm,
+    );
 
     res.status(200).json({ suggestions: suggestions });
   } catch (err) {
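
For reviewers, a minimal sketch of a request exercising the new `custom_openai` branch. It assumes the router is mounted at `/api/suggestions`; the host, port, model name, and key below are placeholders, not part of this patch:

```ts
// Sketch only: POST to the suggestions route with the new request shape.
// URL, model name, and key are assumed placeholders; adjust per deployment.
async function main() {
  const res = await fetch('http://localhost:3001/api/suggestions', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      chatHistory: [
        { role: 'user', content: 'How do transformers work?' },
        { role: 'assistant', content: 'Transformers use self-attention.' },
      ],
      chatModel: {
        provider: 'custom_openai',
        model: 'gpt-4o-mini',
        customOpenAIBaseURL: 'https://my-proxy.example.com/v1',
        customOpenAIKey: 'sk-placeholder',
      },
    }),
  });

  if (!res.ok) {
    // After this patch, 400 covers both a missing base URL/key pair
    // and an invalid provider/model selection (previously a 500).
    console.error(res.status, (await res.json()).message);
    return;
  }

  const { suggestions } = await res.json();
  console.log(suggestions);
}

main();
```

Omitting `chatModel` entirely still falls back to the first available provider and model, matching the old default behavior.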