diff --git a/src/lib/providers/anthropic.ts b/src/lib/providers/anthropic.ts index 58cd164..90460c6 100644 --- a/src/lib/providers/anthropic.ts +++ b/src/lib/providers/anthropic.ts @@ -9,26 +9,38 @@ export const loadAnthropicChatModels = async () => { try { const chatModels = { - 'Claude 3.5 Sonnet': new ChatAnthropic({ - temperature: 0.7, - anthropicApiKey: anthropicApiKey, - model: 'claude-3-5-sonnet-20240620', - }), - 'Claude 3 Opus': new ChatAnthropic({ - temperature: 0.7, - anthropicApiKey: anthropicApiKey, - model: 'claude-3-opus-20240229', - }), - 'Claude 3 Sonnet': new ChatAnthropic({ - temperature: 0.7, - anthropicApiKey: anthropicApiKey, - model: 'claude-3-sonnet-20240229', - }), - 'Claude 3 Haiku': new ChatAnthropic({ - temperature: 0.7, - anthropicApiKey: anthropicApiKey, - model: 'claude-3-haiku-20240307', - }), + 'claude-3-5-sonnet-20240620': { + displayName: 'Claude 3.5 Sonnet', + model: new ChatAnthropic({ + temperature: 0.7, + anthropicApiKey: anthropicApiKey, + model: 'claude-3-5-sonnet-20240620', + }), + }, + 'claude-3-opus-20240229': { + displayName: 'Claude 3 Opus', + model: new ChatAnthropic({ + temperature: 0.7, + anthropicApiKey: anthropicApiKey, + model: 'claude-3-opus-20240229', + }), + }, + 'claude-3-sonnet-20240229': { + displayName: 'Claude 3 Sonnet', + model: new ChatAnthropic({ + temperature: 0.7, + anthropicApiKey: anthropicApiKey, + model: 'claude-3-sonnet-20240229', + }), + }, + 'claude-3-haiku-20240307': { + displayName: 'Claude 3 Haiku', + model: new ChatAnthropic({ + temperature: 0.7, + anthropicApiKey: anthropicApiKey, + model: 'claude-3-haiku-20240307', + }), + }, }; return chatModels; diff --git a/src/lib/providers/groq.ts b/src/lib/providers/groq.ts index ffe8f6c..6249267 100644 --- a/src/lib/providers/groq.ts +++ b/src/lib/providers/groq.ts @@ -9,76 +9,97 @@ export const loadGroqChatModels = async () => { try { const chatModels = { - 'Llama 3.1 70B': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'llama-3.1-70b-versatile', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), - 'Llama 3.1 8B': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'llama-3.1-8b-instant', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), - 'LLaMA3 8b': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'llama3-8b-8192', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), - 'LLaMA3 70b': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'llama3-70b-8192', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), - 'Mixtral 8x7b': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'mixtral-8x7b-32768', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), - 'Gemma 7b': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'gemma-7b-it', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), - 'Gemma2 9b': new ChatOpenAI( - { - openAIApiKey: groqApiKey, - modelName: 'gemma2-9b-it', - temperature: 0.7, - }, - { - baseURL: 'https://api.groq.com/openai/v1', - }, - ), + 'llama-3.1-70b-versatile': { + displayName: 'Llama 3.1 70B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'llama-3.1-70b-versatile', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, + 'llama-3.1-8b-instant': { + displayName: 'Llama 3.1 8B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'llama-3.1-8b-instant', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, + 'llama3-8b-8192': { + displayName: 'LLaMA3 8B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'llama3-8b-8192', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, + 'llama3-70b-8192': { + displayName: 'LLaMA3 70B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'llama3-70b-8192', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, + 'mixtral-8x7b-32768': { + displayName: 'Mixtral 8x7B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'mixtral-8x7b-32768', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, + 'gemma-7b-it': { + displayName: 'Gemma 7B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'gemma-7b-it', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, + 'gemma2-9b-it': { + displayName: 'Gemma2 9B', + model: new ChatOpenAI( + { + openAIApiKey: groqApiKey, + modelName: 'gemma2-9b-it', + temperature: 0.7, + }, + { + baseURL: 'https://api.groq.com/openai/v1', + }, + ), + }, }; return chatModels; diff --git a/src/lib/providers/ollama.ts b/src/lib/providers/ollama.ts index b2901ff..ed68bfa 100644 --- a/src/lib/providers/ollama.ts +++ b/src/lib/providers/ollama.ts @@ -18,11 +18,15 @@ export const loadOllamaChatModels = async () => { const { models: ollamaModels } = (await response.json()) as any; const chatModels = ollamaModels.reduce((acc, model) => { - acc[model.model] = new ChatOllama({ - baseUrl: ollamaEndpoint, - model: model.model, - temperature: 0.7, - }); + acc[model.model] = { + displayName: model.name, + model: new ChatOllama({ + baseUrl: ollamaEndpoint, + model: model.model, + temperature: 0.7, + }), + }; + return acc; }, {}); @@ -48,10 +52,14 @@ export const loadOllamaEmbeddingsModels = async () => { const { models: ollamaModels } = (await response.json()) as any; const embeddingsModels = ollamaModels.reduce((acc, model) => { - acc[model.model] = new OllamaEmbeddings({ - baseUrl: ollamaEndpoint, - model: model.model, - }); + acc[model.model] = { + displayName: model.name, + model: new OllamaEmbeddings({ + baseUrl: ollamaEndpoint, + model: model.model, + }), + }; + return acc; }, {}); diff --git a/src/lib/providers/openai.ts b/src/lib/providers/openai.ts index 8673954..3747e37 100644 --- a/src/lib/providers/openai.ts +++ b/src/lib/providers/openai.ts @@ -9,31 +9,46 @@ export const loadOpenAIChatModels = async () => { try { const chatModels = { - 'GPT-3.5 turbo': new ChatOpenAI({ - openAIApiKey, - modelName: 'gpt-3.5-turbo', - temperature: 0.7, - }), - 'GPT-4': new ChatOpenAI({ - openAIApiKey, - modelName: 'gpt-4', - temperature: 0.7, - }), - 'GPT-4 turbo': new ChatOpenAI({ - openAIApiKey, - modelName: 'gpt-4-turbo', - temperature: 0.7, - }), - 'GPT-4 omni': new ChatOpenAI({ - openAIApiKey, - modelName: 'gpt-4o', - temperature: 0.7, - }), - 'GPT-4 omni mini': new ChatOpenAI({ - openAIApiKey, - modelName: 'gpt-4o-mini', - temperature: 0.7, - }), + 'gpt-3.5-turbo': { + displayName: 'GPT-3.5 Turbo', + model: new ChatOpenAI({ + openAIApiKey, + modelName: 'gpt-3.5-turbo', + temperature: 0.7, + }), + }, + 'gpt-4': { + displayName: 'GPT-4', + model: new ChatOpenAI({ + openAIApiKey, + modelName: 'gpt-4', + temperature: 0.7, + }), + }, + 'gpt-4-turbo': { + displayName: 'GPT-4 turbo', + model: new ChatOpenAI({ + openAIApiKey, + modelName: 'gpt-4-turbo', + temperature: 0.7, + }), + }, + 'gpt-4o': { + displayName: 'GPT-4 omni', + model: new ChatOpenAI({ + openAIApiKey, + modelName: 'gpt-4o', + temperature: 0.7, + }), + }, + 'gpt-4o-mini': { + displayName: 'GPT-4 omni mini', + model: new ChatOpenAI({ + openAIApiKey, + modelName: 'gpt-4o-mini', + temperature: 0.7, + }), + }, }; return chatModels; @@ -50,14 +65,20 @@ export const loadOpenAIEmbeddingsModels = async () => { try { const embeddingModels = { - 'Text embedding 3 small': new OpenAIEmbeddings({ - openAIApiKey, - modelName: 'text-embedding-3-small', - }), - 'Text embedding 3 large': new OpenAIEmbeddings({ - openAIApiKey, - modelName: 'text-embedding-3-large', - }), + 'text-embedding-3-small': { + displayName: 'Text Embedding 3 Small', + model: new OpenAIEmbeddings({ + openAIApiKey, + modelName: 'text-embedding-3-small', + }), + }, + 'text-embedding-3-large': { + displayName: 'Text Embedding 3 Large', + model: new OpenAIEmbeddings({ + openAIApiKey, + modelName: 'text-embedding-3-large', + }), + }, }; return embeddingModels; diff --git a/src/lib/providers/transformers.ts b/src/lib/providers/transformers.ts index 0ec7052..8a3417d 100644 --- a/src/lib/providers/transformers.ts +++ b/src/lib/providers/transformers.ts @@ -4,15 +4,24 @@ import { HuggingFaceTransformersEmbeddings } from '../huggingfaceTransformer'; export const loadTransformersEmbeddingsModels = async () => { try { const embeddingModels = { - 'BGE Small': new HuggingFaceTransformersEmbeddings({ - modelName: 'Xenova/bge-small-en-v1.5', - }), - 'GTE Small': new HuggingFaceTransformersEmbeddings({ - modelName: 'Xenova/gte-small', - }), - 'Bert Multilingual': new HuggingFaceTransformersEmbeddings({ - modelName: 'Xenova/bert-base-multilingual-uncased', - }), + 'xenova-bge-small-en-v1.5': { + displayName: 'BGE Small', + model: new HuggingFaceTransformersEmbeddings({ + modelName: 'Xenova/bge-small-en-v1.5', + }), + }, + 'xenova-gte-small': { + displayName: 'GTE Small', + model: new HuggingFaceTransformersEmbeddings({ + modelName: 'Xenova/gte-small', + }), + }, + 'xenova-bert-base-multilingual-uncased': { + displayName: 'Bert Multilingual', + model: new HuggingFaceTransformersEmbeddings({ + modelName: 'Xenova/bert-base-multilingual-uncased', + }), + }, }; return embeddingModels; diff --git a/src/routes/config.ts b/src/routes/config.ts index f255560..f635e4b 100644 --- a/src/routes/config.ts +++ b/src/routes/config.ts @@ -10,38 +10,54 @@ import { getOpenaiApiKey, updateConfig, } from '../config'; +import logger from '../utils/logger'; const router = express.Router(); router.get('/', async (_, res) => { - const config = {}; + try { + const config = {}; - const [chatModelProviders, embeddingModelProviders] = await Promise.all([ - getAvailableChatModelProviders(), - getAvailableEmbeddingModelProviders(), - ]); + const [chatModelProviders, embeddingModelProviders] = await Promise.all([ + getAvailableChatModelProviders(), + getAvailableEmbeddingModelProviders(), + ]); - config['chatModelProviders'] = {}; - config['embeddingModelProviders'] = {}; + config['chatModelProviders'] = {}; + config['embeddingModelProviders'] = {}; - for (const provider in chatModelProviders) { - config['chatModelProviders'][provider] = Object.keys( - chatModelProviders[provider], - ); + for (const provider in chatModelProviders) { + config['chatModelProviders'][provider] = Object.keys( + chatModelProviders[provider], + ).map((model) => { + return { + name: model, + displayName: chatModelProviders[provider][model].displayName, + }; + }); + } + + for (const provider in embeddingModelProviders) { + config['embeddingModelProviders'][provider] = Object.keys( + embeddingModelProviders[provider], + ).map((model) => { + return { + name: model, + displayName: embeddingModelProviders[provider][model].displayName, + }; + }); + } + + config['openaiApiKey'] = getOpenaiApiKey(); + config['ollamaApiUrl'] = getOllamaApiEndpoint(); + config['anthropicApiKey'] = getAnthropicApiKey(); + config['groqApiKey'] = getGroqApiKey(); + + res.status(200).json(config); + } catch (err: any) { + res.status(500).json({ message: 'An error has occurred.' }); + logger.error(`Error getting config: ${err.message}`); } - - for (const provider in embeddingModelProviders) { - config['embeddingModelProviders'][provider] = Object.keys( - embeddingModelProviders[provider], - ); - } - - config['openaiApiKey'] = getOpenaiApiKey(); - config['ollamaApiUrl'] = getOllamaApiEndpoint(); - config['anthropicApiKey'] = getAnthropicApiKey(); - config['groqApiKey'] = getGroqApiKey(); - - res.status(200).json(config); }); router.post('/', async (req, res) => { diff --git a/src/routes/images.ts b/src/routes/images.ts index 6bd43d3..7806ce7 100644 --- a/src/routes/images.ts +++ b/src/routes/images.ts @@ -26,7 +26,7 @@ router.post('/', async (req, res) => { let llm: BaseChatModel | undefined; if (chatModels[provider] && chatModels[provider][chatModel]) { - llm = chatModels[provider][chatModel] as BaseChatModel | undefined; + llm = chatModels[provider][chatModel].model as BaseChatModel | undefined; } if (!llm) { diff --git a/src/routes/suggestions.ts b/src/routes/suggestions.ts index b15ff5f..a75657e 100644 --- a/src/routes/suggestions.ts +++ b/src/routes/suggestions.ts @@ -26,7 +26,7 @@ router.post('/', async (req, res) => { let llm: BaseChatModel | undefined; if (chatModels[provider] && chatModels[provider][chatModel]) { - llm = chatModels[provider][chatModel] as BaseChatModel | undefined; + llm = chatModels[provider][chatModel].model as BaseChatModel | undefined; } if (!llm) { diff --git a/src/routes/videos.ts b/src/routes/videos.ts index 0ffdb2c..9d43fd2 100644 --- a/src/routes/videos.ts +++ b/src/routes/videos.ts @@ -26,7 +26,7 @@ router.post('/', async (req, res) => { let llm: BaseChatModel | undefined; if (chatModels[provider] && chatModels[provider][chatModel]) { - llm = chatModels[provider][chatModel] as BaseChatModel | undefined; + llm = chatModels[provider][chatModel].model as BaseChatModel | undefined; } if (!llm) { diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index 70e20d9..04797c5 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -45,9 +45,8 @@ export const handleConnection = async ( chatModelProviders[chatModelProvider][chatModel] && chatModelProvider != 'custom_openai' ) { - llm = chatModelProviders[chatModelProvider][chatModel] as unknown as - | BaseChatModel - | undefined; + llm = chatModelProviders[chatModelProvider][chatModel] + .model as unknown as BaseChatModel | undefined; } else if (chatModelProvider == 'custom_openai') { llm = new ChatOpenAI({ modelName: chatModel, @@ -65,7 +64,7 @@ export const handleConnection = async ( ) { embeddings = embeddingModelProviders[embeddingModelProvider][ embeddingModel - ] as Embeddings | undefined; + ].model as Embeddings | undefined; } if (!llm || !embeddings) { diff --git a/ui/components/SearchVideos.tsx b/ui/components/SearchVideos.tsx index 74d4381..fec229c 100644 --- a/ui/components/SearchVideos.tsx +++ b/ui/components/SearchVideos.tsx @@ -64,7 +64,7 @@ const Searchvideos = ({ const data = await res.json(); - const videos = data.videos ?? []; + const videos = data.videos ?? []; setVideos(videos); setSlides( videos.map((video: Video) => { diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 171e812..02358c5 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -49,10 +49,10 @@ export const Select = ({ className, options, ...restProps }: SelectProps) => { interface SettingsType { chatModelProviders: { - [key: string]: string[]; + [key: string]: [Record]; }; embeddingModelProviders: { - [key: string]: string[]; + [key: string]: [Record]; }; openaiApiKey: string; groqApiKey: string; @@ -68,6 +68,10 @@ const SettingsDialog = ({ setIsOpen: (isOpen: boolean) => void; }) => { const [config, setConfig] = useState(null); + const [chatModels, setChatModels] = useState>({}); + const [embeddingModels, setEmbeddingModels] = useState>( + {}, + ); const [selectedChatModelProvider, setSelectedChatModelProvider] = useState< string | null >(null); @@ -118,7 +122,7 @@ const SettingsDialog = ({ const chatModel = localStorage.getItem('chatModel') || (data.chatModelProviders && - data.chatModelProviders[chatModelProvider]?.[0]) || + data.chatModelProviders[chatModelProvider]?.[0].name) || ''; const embeddingModelProvider = localStorage.getItem('embeddingModelProvider') || @@ -127,7 +131,7 @@ const SettingsDialog = ({ const embeddingModel = localStorage.getItem('embeddingModel') || (data.embeddingModelProviders && - data.embeddingModelProviders[embeddingModelProvider]?.[0]) || + data.embeddingModelProviders[embeddingModelProvider]?.[0].name) || ''; setSelectedChatModelProvider(chatModelProvider); @@ -136,6 +140,8 @@ const SettingsDialog = ({ setSelectedEmbeddingModel(embeddingModel); setCustomOpenAIApiKey(localStorage.getItem('openAIApiKey') || ''); setCustomOpenAIBaseURL(localStorage.getItem('openAIBaseURL') || ''); + setChatModels(data.chatModelProviders || {}); + setEmbeddingModels(data.embeddingModelProviders || {}); setIsLoading(false); }; @@ -229,7 +235,8 @@ const SettingsDialog = ({ setSelectedChatModel(''); } else { setSelectedChatModel( - config.chatModelProviders[e.target.value][0], + config.chatModelProviders[e.target.value][0] + .name, ); } }} @@ -264,8 +271,8 @@ const SettingsDialog = ({ return chatModelProvider ? chatModelProvider.length > 0 ? chatModelProvider.map((model) => ({ - value: model, - label: model, + value: model.name, + label: model.displayName, })) : [ { @@ -341,7 +348,8 @@ const SettingsDialog = ({ onChange={(e) => { setSelectedEmbeddingModelProvider(e.target.value); setSelectedEmbeddingModel( - config.embeddingModelProviders[e.target.value][0], + config.embeddingModelProviders[e.target.value][0] + .name, ); }} options={Object.keys( @@ -374,8 +382,8 @@ const SettingsDialog = ({ return embeddingModelProvider ? embeddingModelProvider.length > 0 ? embeddingModelProvider.map((model) => ({ - label: model, - value: model, + label: model.displayName, + value: model.name, })) : [ {