diff --git a/src/routes/images.ts b/src/routes/images.ts
index 7806ce7..c54dc40 100644
--- a/src/routes/images.ts
+++ b/src/routes/images.ts
@@ -4,14 +4,28 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { getAvailableChatModelProviders } from '../lib/providers';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
+import { ChatOpenAI } from '@langchain/openai';
 
 const router = express.Router();
 
+interface ChatModel {
+  provider: string;
+  model: string;
+  customOpenAIBaseURL?: string;
+  customOpenAIKey?: string;
+}
+
+interface ImageSearchBody {
+  query: string;
+  chatHistory: any[];
+  chatModel?: ChatModel;
+}
+
 router.post('/', async (req, res) => {
   try {
-    let { query, chat_history, chat_model_provider, chat_model } = req.body;
+    let body: ImageSearchBody = req.body;
 
-    chat_history = chat_history.map((msg: any) => {
+    const chatHistory = body.chatHistory.map((msg: any) => {
       if (msg.role === 'user') {
         return new HumanMessage(msg.content);
       } else if (msg.role === 'assistant') {
@@ -19,22 +33,50 @@ router.post('/', async (req, res) => {
       }
     });
 
-    const chatModels = await getAvailableChatModelProviders();
-    const provider = chat_model_provider ?? Object.keys(chatModels)[0];
-    const chatModel = chat_model ?? Object.keys(chatModels[provider])[0];
+    const chatModelProviders = await getAvailableChatModelProviders();
+
+    const chatModelProvider =
+      body.chatModel?.provider || Object.keys(chatModelProviders)[0];
+    const chatModel =
+      body.chatModel?.model ||
+      Object.keys(chatModelProviders[chatModelProvider])[0];
 
     let llm: BaseChatModel | undefined;
 
-    if (chatModels[provider] && chatModels[provider][chatModel]) {
-      llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
+    if (body.chatModel?.provider === 'custom_openai') {
+      if (
+        !body.chatModel?.customOpenAIBaseURL ||
+        !body.chatModel?.customOpenAIKey
+      ) {
+        return res
+          .status(400)
+          .json({ message: 'Missing custom OpenAI base URL or key' });
+      }
+
+      llm = new ChatOpenAI({
+        modelName: body.chatModel.model,
+        openAIApiKey: body.chatModel.customOpenAIKey,
+        temperature: 0.7,
+        configuration: {
+          baseURL: body.chatModel.customOpenAIBaseURL,
+        },
+      }) as unknown as BaseChatModel;
+    } else if (
+      chatModelProviders[chatModelProvider] &&
+      chatModelProviders[chatModelProvider][chatModel]
+    ) {
+      llm = chatModelProviders[chatModelProvider][chatModel]
+        .model as unknown as BaseChatModel | undefined;
     }
 
     if (!llm) {
-      res.status(500).json({ message: 'Invalid LLM model selected' });
-      return;
+      return res.status(400).json({ message: 'Invalid model selected' });
     }
 
-    const images = await handleImageSearch({ query, chat_history }, llm);
+    const images = await handleImageSearch(
+      { query: body.query, chat_history: chatHistory },
+      llm,
+    );
 
     res.status(200).json({ images });
   } catch (err) {