From 9f45ecb98d8d49a7201c0250c8c88738d621f3dc Mon Sep 17 00:00:00 2001 From: ItzCrazyKns Date: Sat, 4 May 2024 10:51:06 +0530 Subject: [PATCH] feat(providers): separate embedding providers, add custom-openai provider --- src/lib/providers.ts | 63 ++++++--- src/routes/config.ts | 27 ++-- src/routes/images.ts | 4 +- src/routes/models.ts | 12 +- src/routes/videos.ts | 4 +- src/websocket/connectionManager.ts | 55 ++++++-- ui/components/ChatWindow.tsx | 60 ++++++++- ui/components/SettingsDialog.tsx | 207 +++++++++++++++++++++++++---- 8 files changed, 360 insertions(+), 72 deletions(-) diff --git a/src/lib/providers.ts b/src/lib/providers.ts index 9b62ce0..d6904c0 100644 --- a/src/lib/providers.ts +++ b/src/lib/providers.ts @@ -8,7 +8,7 @@ import { } from '../config'; import logger from '../utils/logger'; -export const getAvailableProviders = async () => { +export const getAvailableChatModelProviders = async () => { const openAIApiKey = getOpenaiApiKey(); const groqApiKey = getGroqApiKey(); const ollamaEndpoint = getOllamaApiEndpoint(); @@ -33,10 +33,6 @@ export const getAvailableProviders = async () => { modelName: 'gpt-4-turbo', temperature: 0.7, }), - embeddings: new OpenAIEmbeddings({ - openAIApiKey, - modelName: 'text-embedding-3-large', - }), }; } catch (err) { logger.error(`Error loading OpenAI models: ${err}`); @@ -86,10 +82,6 @@ export const getAvailableProviders = async () => { baseURL: 'https://api.groq.com/openai/v1', }, ), - embeddings: new OpenAIEmbeddings({ - openAIApiKey: openAIApiKey, - modelName: 'text-embedding-3-large', - }), }; } catch (err) { logger.error(`Error loading Groq models: ${err}`); @@ -110,17 +102,56 @@ export const getAvailableProviders = async () => { }); return acc; }, {}); - - if (Object.keys(models['ollama']).length > 0) { - models['ollama']['embeddings'] = new OllamaEmbeddings({ - baseUrl: ollamaEndpoint, - model: models['ollama'][Object.keys(models['ollama'])[0]].model, - }); - } } catch (err) { logger.error(`Error loading Ollama models: ${err}`); } } + models['custom_openai'] = {}; + + return models; +}; + +export const getAvailableEmbeddingModelProviders = async () => { + const openAIApiKey = getOpenaiApiKey(); + const ollamaEndpoint = getOllamaApiEndpoint(); + + const models = {}; + + if (openAIApiKey) { + try { + models['openai'] = { + 'Text embedding 3 small': new OpenAIEmbeddings({ + openAIApiKey, + modelName: 'text-embedding-3-small', + }), + 'Text embedding 3 large': new OpenAIEmbeddings({ + openAIApiKey, + modelName: 'text-embedding-3-large', + }), + }; + } catch (err) { + logger.error(`Error loading OpenAI embeddings: ${err}`); + } + } + + if (ollamaEndpoint) { + try { + const response = await fetch(`${ollamaEndpoint}/api/tags`); + + const { models: ollamaModels } = (await response.json()) as any; + + models['ollama'] = ollamaModels.reduce((acc, model) => { + acc[model.model] = new OllamaEmbeddings({ + baseUrl: ollamaEndpoint, + model: model.model, + }); + return acc; + }, {}); + } catch (err) { + logger.error(`Error loading Ollama embeddings: ${err}`); + } + } + return models; }; diff --git a/src/routes/config.ts b/src/routes/config.ts index 1bb9246..bf13b63 100644 --- a/src/routes/config.ts +++ b/src/routes/config.ts @@ -1,5 +1,8 @@ import express from 'express'; -import { getAvailableProviders } from '../lib/providers'; +import { + getAvailableChatModelProviders, + getAvailableEmbeddingModelProviders, +} from '../lib/providers'; import { getGroqApiKey, getOllamaApiEndpoint, @@ -12,16 +15,24 @@ const router = express.Router(); router.get('/', async (_, res) => { const config = {}; - const providers = await getAvailableProviders(); + const [chatModelProviders, embeddingModelProviders] = await Promise.all([ + getAvailableChatModelProviders(), + getAvailableEmbeddingModelProviders(), + ]); - for (const provider in providers) { - delete providers[provider]['embeddings']; + config['chatModelProviders'] = {}; + config['embeddingModelProviders'] = {}; + + for (const provider in chatModelProviders) { + config['chatModelProviders'][provider] = Object.keys( + chatModelProviders[provider], + ); } - config['providers'] = {}; - - for (const provider in providers) { - config['providers'][provider] = Object.keys(providers[provider]); + for (const provider in embeddingModelProviders) { + config['embeddingModelProviders'][provider] = Object.keys( + embeddingModelProviders[provider], + ); } config['openaiApiKey'] = getOpenaiApiKey(); diff --git a/src/routes/images.ts b/src/routes/images.ts index 3906689..d8ad8e1 100644 --- a/src/routes/images.ts +++ b/src/routes/images.ts @@ -1,7 +1,7 @@ import express from 'express'; import handleImageSearch from '../agents/imageSearchAgent'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { getAvailableProviders } from '../lib/providers'; +import { getAvailableChatModelProviders } from '../lib/providers'; import { HumanMessage, AIMessage } from '@langchain/core/messages'; import logger from '../utils/logger'; @@ -19,7 +19,7 @@ router.post('/', async (req, res) => { } }); - const chatModels = await getAvailableProviders(); + const chatModels = await getAvailableChatModelProviders(); const provider = chat_model_provider || Object.keys(chatModels)[0]; const chatModel = chat_model || Object.keys(chatModels[provider])[0]; diff --git a/src/routes/models.ts b/src/routes/models.ts index f2332f4..36df25a 100644 --- a/src/routes/models.ts +++ b/src/routes/models.ts @@ -1,14 +1,20 @@ import express from 'express'; import logger from '../utils/logger'; -import { getAvailableProviders } from '../lib/providers'; +import { + getAvailableChatModelProviders, + getAvailableEmbeddingModelProviders, +} from '../lib/providers'; const router = express.Router(); router.get('/', async (req, res) => { try { - const providers = await getAvailableProviders(); + const [chatModelProviders, embeddingModelProviders] = await Promise.all([ + getAvailableChatModelProviders(), + getAvailableEmbeddingModelProviders(), + ]); - res.status(200).json({ providers }); + res.status(200).json({ chatModelProviders, embeddingModelProviders }); } catch (err) { res.status(500).json({ message: 'An error has occurred.' }); logger.error(err.message); diff --git a/src/routes/videos.ts b/src/routes/videos.ts index fecd874..e117a5a 100644 --- a/src/routes/videos.ts +++ b/src/routes/videos.ts @@ -1,6 +1,6 @@ import express from 'express'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { getAvailableProviders } from '../lib/providers'; +import { getAvailableChatModelProviders } from '../lib/providers'; import { HumanMessage, AIMessage } from '@langchain/core/messages'; import logger from '../utils/logger'; import handleVideoSearch from '../agents/videoSearchAgent'; @@ -19,7 +19,7 @@ router.post('/', async (req, res) => { } }); - const chatModels = await getAvailableProviders(); + const chatModels = await getAvailableChatModelProviders(); const provider = chat_model_provider || Object.keys(chatModels)[0]; const chatModel = chat_model || Object.keys(chatModels[provider])[0]; diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index c2f3798..88efb6b 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -1,10 +1,14 @@ import { WebSocket } from 'ws'; import { handleMessage } from './messageHandler'; -import { getAvailableProviders } from '../lib/providers'; +import { + getAvailableEmbeddingModelProviders, + getAvailableChatModelProviders, +} from '../lib/providers'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; import type { IncomingMessage } from 'http'; import logger from '../utils/logger'; +import { ChatOpenAI } from '@langchain/openai'; export const handleConnection = async ( ws: WebSocket, @@ -13,18 +17,53 @@ export const handleConnection = async ( const searchParams = new URL(request.url, `http://${request.headers.host}`) .searchParams; - const models = await getAvailableProviders(); - const provider = - searchParams.get('chatModelProvider') || Object.keys(models)[0]; + const [chatModelProviders, embeddingModelProviders] = await Promise.all([ + getAvailableChatModelProviders(), + getAvailableEmbeddingModelProviders(), + ]); + + const chatModelProvider = + searchParams.get('chatModelProvider') || Object.keys(chatModelProviders)[0]; const chatModel = - searchParams.get('chatModel') || Object.keys(models[provider])[0]; + searchParams.get('chatModel') || + Object.keys(chatModelProviders[chatModelProvider])[0]; + + const embeddingModelProvider = + searchParams.get('embeddingModelProvider') || + Object.keys(embeddingModelProviders)[0]; + const embeddingModel = + searchParams.get('embeddingModel') || + Object.keys(embeddingModelProviders[embeddingModelProvider])[0]; let llm: BaseChatModel | undefined; let embeddings: Embeddings | undefined; - if (models[provider] && models[provider][chatModel]) { - llm = models[provider][chatModel] as BaseChatModel | undefined; - embeddings = models[provider].embeddings as Embeddings | undefined; + if ( + chatModelProviders[chatModelProvider] && + chatModelProviders[chatModelProvider][chatModel] && + chatModelProvider != 'custom_openai' + ) { + llm = chatModelProviders[chatModelProvider][chatModel] as + | BaseChatModel + | undefined; + } else if (chatModelProvider == 'custom_openai') { + llm = new ChatOpenAI({ + modelName: chatModel, + openAIApiKey: searchParams.get('openAIApiKey'), + temperature: 0.7, + configuration: { + baseURL: searchParams.get('openAIBaseURL'), + }, + }); + } + + if ( + embeddingModelProviders[embeddingModelProvider] && + embeddingModelProviders[embeddingModelProvider][embeddingModel] + ) { + embeddings = embeddingModelProviders[embeddingModelProvider][ + embeddingModel + ] as Embeddings | undefined; } if (!llm || !embeddings) { diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx index 68a2ba0..f8298b1 100644 --- a/ui/components/ChatWindow.tsx +++ b/ui/components/ChatWindow.tsx @@ -22,11 +22,23 @@ const useSocket = (url: string) => { const connectWs = async () => { let chatModel = localStorage.getItem('chatModel'); let chatModelProvider = localStorage.getItem('chatModelProvider'); + let embeddingModel = localStorage.getItem('embeddingModel'); + let embeddingModelProvider = localStorage.getItem( + 'embeddingModelProvider', + ); - if (!chatModel || !chatModelProvider) { - const chatModelProviders = await fetch( + if ( + !chatModel || + !chatModelProvider || + !embeddingModel || + !embeddingModelProvider + ) { + const providers = await fetch( `${process.env.NEXT_PUBLIC_API_URL}/models`, - ).then(async (res) => (await res.json())['providers']); + ).then(async (res) => await res.json()); + + const chatModelProviders = providers.chatModelProviders; + const embeddingModelProviders = providers.embeddingModelProviders; if ( !chatModelProviders || @@ -34,16 +46,52 @@ const useSocket = (url: string) => { ) return console.error('No chat models available'); + if ( + !embeddingModelProviders || + Object.keys(embeddingModelProviders).length === 0 + ) + return console.error('No embedding models available'); + chatModelProvider = Object.keys(chatModelProviders)[0]; chatModel = Object.keys(chatModelProviders[chatModelProvider])[0]; + embeddingModelProvider = Object.keys(embeddingModelProviders)[0]; + embeddingModel = Object.keys( + embeddingModelProviders[embeddingModelProvider], + )[0]; + localStorage.setItem('chatModel', chatModel!); localStorage.setItem('chatModelProvider', chatModelProvider); + localStorage.setItem('embeddingModel', embeddingModel!); + localStorage.setItem( + 'embeddingModelProvider', + embeddingModelProvider, + ); } - const ws = new WebSocket( - `${url}?chatModel=${chatModel}&chatModelProvider=${chatModelProvider}`, - ); + const wsURL = new URL(url); + const searchParams = new URLSearchParams({}); + + searchParams.append('chatModel', chatModel!); + searchParams.append('chatModelProvider', chatModelProvider); + + if (chatModelProvider === 'custom_openai') { + searchParams.append( + 'openAIApiKey', + localStorage.getItem('openAIApiKey')!, + ); + searchParams.append( + 'openAIBaseURL', + localStorage.getItem('openAIBaseURL')!, + ); + } + + searchParams.append('embeddingModel', embeddingModel!); + searchParams.append('embeddingModelProvider', embeddingModelProvider); + + wsURL.search = searchParams.toString(); + + const ws = new WebSocket(wsURL.toString()); ws.onopen = () => { console.log('[DEBUG] open'); setWs(ws); diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 16e57de..41c10cf 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -3,7 +3,10 @@ import { CloudUpload, RefreshCcw, RefreshCw } from 'lucide-react'; import React, { Fragment, useEffect, useState } from 'react'; interface SettingsType { - providers: { + chatModelProviders: { + [key: string]: string[]; + }; + embeddingModelProviders: { [key: string]: string[]; }; openaiApiKey: string; @@ -25,6 +28,17 @@ const SettingsDialog = ({ const [selectedChatModel, setSelectedChatModel] = useState( null, ); + const [selectedEmbeddingModelProvider, setSelectedEmbeddingModelProvider] = + useState(null); + const [selectedEmbeddingModel, setSelectedEmbeddingModel] = useState< + string | null + >(null); + const [customOpenAIApiKey, setCustomOpenAIApiKey] = useState( + null, + ); + const [customOpenAIBaseURL, setCustomOpenAIBaseURL] = useState( + null, + ); const [isLoading, setIsLoading] = useState(false); const [isUpdating, setIsUpdating] = useState(false); @@ -46,6 +60,12 @@ const SettingsDialog = ({ useEffect(() => { setSelectedChatModelProvider(localStorage.getItem('chatModelProvider')); setSelectedChatModel(localStorage.getItem('chatModel')); + setSelectedEmbeddingModelProvider( + localStorage.getItem('embeddingModelProvider'), + ); + setSelectedEmbeddingModel(localStorage.getItem('embeddingModel')); + setCustomOpenAIApiKey(localStorage.getItem('openAIApiKey')); + setCustomOpenAIBaseURL(localStorage.getItem('openAIBaseUrl')); }, []); const handleSubmit = async () => { @@ -62,6 +82,13 @@ const SettingsDialog = ({ localStorage.setItem('chatModelProvider', selectedChatModelProvider!); localStorage.setItem('chatModel', selectedChatModel!); + localStorage.setItem( + 'embeddingModelProvider', + selectedEmbeddingModelProvider!, + ); + localStorage.setItem('embeddingModel', selectedEmbeddingModel!); + localStorage.setItem('openAIApiKey', customOpenAIApiKey!); + localStorage.setItem('openAIBaseURL', customOpenAIBaseURL!); } catch (err) { console.log(err); } finally { @@ -107,7 +134,7 @@ const SettingsDialog = ({ {config && !isLoading && (
- {config.providers && ( + {config.chatModelProviders && (

Chat model Provider @@ -116,36 +143,47 @@ const SettingsDialog = ({ onChange={(e) => { setSelectedChatModelProvider(e.target.value); setSelectedChatModel( - config.providers[e.target.value][0], + config.chatModelProviders[e.target.value][0], ); }} className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" > - {Object.keys(config.providers).map((provider) => ( - - ))} + {Object.keys(config.chatModelProviders).map( + (provider) => ( + + ), + )}

)} - {selectedChatModelProvider && ( -
-

Chat Model

- + setSelectedChatModel(e.target.value) + } + className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" + > + {config.chatModelProviders[ + selectedChatModelProvider + ] ? ( + config.chatModelProviders[ + selectedChatModelProvider + ].length > 0 ? ( + config.chatModelProviders[ + selectedChatModelProvider + ].map((model) => ( - ), + )) + ) : ( + ) ) : ( + )} + +
+ )} + {selectedChatModelProvider && + selectedChatModelProvider === 'custom_openai' && ( + <> +
+

Model name

+ + setSelectedChatModel(e.target.value) + } + className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" + /> +
+
+

+ Custom OpenAI API Key (optional) +

+ + setCustomOpenAIApiKey(e.target.value) + } + className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" + /> +
+
+

+ Custom OpenAI Base URL +

+ + setCustomOpenAIBaseURL(e.target.value) + } + className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" + /> +
+ + )} + {/* Embedding models */} + {config.chatModelProviders && ( +
+

+ Embedding model Provider +

+ +
+ )} + {selectedEmbeddingModelProvider && ( +
+

Embedding Model

+