feat(providers): separate embedding providers, add custom-openai provider

This commit is contained in:
ItzCrazyKns 2024-05-04 10:51:06 +05:30
parent c710f4f88c
commit 9f45ecb98d
No known key found for this signature in database
GPG Key ID: 8162927C7CCE3065
8 changed files with 360 additions and 72 deletions

View File

@ -8,7 +8,7 @@ import {
} from '../config'; } from '../config';
import logger from '../utils/logger'; import logger from '../utils/logger';
export const getAvailableProviders = async () => { export const getAvailableChatModelProviders = async () => {
const openAIApiKey = getOpenaiApiKey(); const openAIApiKey = getOpenaiApiKey();
const groqApiKey = getGroqApiKey(); const groqApiKey = getGroqApiKey();
const ollamaEndpoint = getOllamaApiEndpoint(); const ollamaEndpoint = getOllamaApiEndpoint();
@ -33,10 +33,6 @@ export const getAvailableProviders = async () => {
modelName: 'gpt-4-turbo', modelName: 'gpt-4-turbo',
temperature: 0.7, temperature: 0.7,
}), }),
embeddings: new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-large',
}),
}; };
} catch (err) { } catch (err) {
logger.error(`Error loading OpenAI models: ${err}`); logger.error(`Error loading OpenAI models: ${err}`);
@ -86,10 +82,6 @@ export const getAvailableProviders = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
embeddings: new OpenAIEmbeddings({
openAIApiKey: openAIApiKey,
modelName: 'text-embedding-3-large',
}),
}; };
} catch (err) { } catch (err) {
logger.error(`Error loading Groq models: ${err}`); logger.error(`Error loading Groq models: ${err}`);
@ -110,17 +102,56 @@ export const getAvailableProviders = async () => {
}); });
return acc; return acc;
}, {}); }, {});
if (Object.keys(models['ollama']).length > 0) {
models['ollama']['embeddings'] = new OllamaEmbeddings({
baseUrl: ollamaEndpoint,
model: models['ollama'][Object.keys(models['ollama'])[0]].model,
});
}
} catch (err) { } catch (err) {
logger.error(`Error loading Ollama models: ${err}`); logger.error(`Error loading Ollama models: ${err}`);
} }
} }
models['custom_openai'] = {};
return models;
};
export const getAvailableEmbeddingModelProviders = async () => {
const openAIApiKey = getOpenaiApiKey();
const ollamaEndpoint = getOllamaApiEndpoint();
const models = {};
if (openAIApiKey) {
try {
models['openai'] = {
'Text embedding 3 small': new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-small',
}),
'Text embedding 3 large': new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-large',
}),
};
} catch (err) {
logger.error(`Error loading OpenAI embeddings: ${err}`);
}
}
if (ollamaEndpoint) {
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`);
const { models: ollamaModels } = (await response.json()) as any;
models['ollama'] = ollamaModels.reduce((acc, model) => {
acc[model.model] = new OllamaEmbeddings({
baseUrl: ollamaEndpoint,
model: model.model,
});
return acc;
}, {});
} catch (err) {
logger.error(`Error loading Ollama embeddings: ${err}`);
}
}
return models; return models;
}; };

View File

@ -1,5 +1,8 @@
import express from 'express'; import express from 'express';
import { getAvailableProviders } from '../lib/providers'; import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
} from '../lib/providers';
import { import {
getGroqApiKey, getGroqApiKey,
getOllamaApiEndpoint, getOllamaApiEndpoint,
@ -12,16 +15,24 @@ const router = express.Router();
router.get('/', async (_, res) => { router.get('/', async (_, res) => {
const config = {}; const config = {};
const providers = await getAvailableProviders(); const [chatModelProviders, embeddingModelProviders] = await Promise.all([
getAvailableChatModelProviders(),
getAvailableEmbeddingModelProviders(),
]);
for (const provider in providers) { config['chatModelProviders'] = {};
delete providers[provider]['embeddings']; config['embeddingModelProviders'] = {};
for (const provider in chatModelProviders) {
config['chatModelProviders'][provider] = Object.keys(
chatModelProviders[provider],
);
} }
config['providers'] = {}; for (const provider in embeddingModelProviders) {
config['embeddingModelProviders'][provider] = Object.keys(
for (const provider in providers) { embeddingModelProviders[provider],
config['providers'][provider] = Object.keys(providers[provider]); );
} }
config['openaiApiKey'] = getOpenaiApiKey(); config['openaiApiKey'] = getOpenaiApiKey();

View File

@ -1,7 +1,7 @@
import express from 'express'; import express from 'express';
import handleImageSearch from '../agents/imageSearchAgent'; import handleImageSearch from '../agents/imageSearchAgent';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers'; import { getAvailableChatModelProviders } from '../lib/providers';
import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger'; import logger from '../utils/logger';
@ -19,7 +19,7 @@ router.post('/', async (req, res) => {
} }
}); });
const chatModels = await getAvailableProviders(); const chatModels = await getAvailableChatModelProviders();
const provider = chat_model_provider || Object.keys(chatModels)[0]; const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = chat_model || Object.keys(chatModels[provider])[0]; const chatModel = chat_model || Object.keys(chatModels[provider])[0];

View File

@ -1,14 +1,20 @@
import express from 'express'; import express from 'express';
import logger from '../utils/logger'; import logger from '../utils/logger';
import { getAvailableProviders } from '../lib/providers'; import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
} from '../lib/providers';
const router = express.Router(); const router = express.Router();
router.get('/', async (req, res) => { router.get('/', async (req, res) => {
try { try {
const providers = await getAvailableProviders(); const [chatModelProviders, embeddingModelProviders] = await Promise.all([
getAvailableChatModelProviders(),
getAvailableEmbeddingModelProviders(),
]);
res.status(200).json({ providers }); res.status(200).json({ chatModelProviders, embeddingModelProviders });
} catch (err) { } catch (err) {
res.status(500).json({ message: 'An error has occurred.' }); res.status(500).json({ message: 'An error has occurred.' });
logger.error(err.message); logger.error(err.message);

View File

@ -1,6 +1,6 @@
import express from 'express'; import express from 'express';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers'; import { getAvailableChatModelProviders } from '../lib/providers';
import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger'; import logger from '../utils/logger';
import handleVideoSearch from '../agents/videoSearchAgent'; import handleVideoSearch from '../agents/videoSearchAgent';
@ -19,7 +19,7 @@ router.post('/', async (req, res) => {
} }
}); });
const chatModels = await getAvailableProviders(); const chatModels = await getAvailableChatModelProviders();
const provider = chat_model_provider || Object.keys(chatModels)[0]; const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = chat_model || Object.keys(chatModels[provider])[0]; const chatModel = chat_model || Object.keys(chatModels[provider])[0];

View File

@ -1,10 +1,14 @@
import { WebSocket } from 'ws'; import { WebSocket } from 'ws';
import { handleMessage } from './messageHandler'; import { handleMessage } from './messageHandler';
import { getAvailableProviders } from '../lib/providers'; import {
getAvailableEmbeddingModelProviders,
getAvailableChatModelProviders,
} from '../lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings'; import type { Embeddings } from '@langchain/core/embeddings';
import type { IncomingMessage } from 'http'; import type { IncomingMessage } from 'http';
import logger from '../utils/logger'; import logger from '../utils/logger';
import { ChatOpenAI } from '@langchain/openai';
export const handleConnection = async ( export const handleConnection = async (
ws: WebSocket, ws: WebSocket,
@ -13,18 +17,53 @@ export const handleConnection = async (
const searchParams = new URL(request.url, `http://${request.headers.host}`) const searchParams = new URL(request.url, `http://${request.headers.host}`)
.searchParams; .searchParams;
const models = await getAvailableProviders(); const [chatModelProviders, embeddingModelProviders] = await Promise.all([
const provider = getAvailableChatModelProviders(),
searchParams.get('chatModelProvider') || Object.keys(models)[0]; getAvailableEmbeddingModelProviders(),
]);
const chatModelProvider =
searchParams.get('chatModelProvider') || Object.keys(chatModelProviders)[0];
const chatModel = const chatModel =
searchParams.get('chatModel') || Object.keys(models[provider])[0]; searchParams.get('chatModel') ||
Object.keys(chatModelProviders[chatModelProvider])[0];
const embeddingModelProvider =
searchParams.get('embeddingModelProvider') ||
Object.keys(embeddingModelProviders)[0];
const embeddingModel =
searchParams.get('embeddingModel') ||
Object.keys(embeddingModelProviders[embeddingModelProvider])[0];
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
let embeddings: Embeddings | undefined; let embeddings: Embeddings | undefined;
if (models[provider] && models[provider][chatModel]) { if (
llm = models[provider][chatModel] as BaseChatModel | undefined; chatModelProviders[chatModelProvider] &&
embeddings = models[provider].embeddings as Embeddings | undefined; chatModelProviders[chatModelProvider][chatModel] &&
chatModelProvider != 'custom_openai'
) {
llm = chatModelProviders[chatModelProvider][chatModel] as
| BaseChatModel
| undefined;
} else if (chatModelProvider == 'custom_openai') {
llm = new ChatOpenAI({
modelName: chatModel,
openAIApiKey: searchParams.get('openAIApiKey'),
temperature: 0.7,
configuration: {
baseURL: searchParams.get('openAIBaseURL'),
},
});
}
if (
embeddingModelProviders[embeddingModelProvider] &&
embeddingModelProviders[embeddingModelProvider][embeddingModel]
) {
embeddings = embeddingModelProviders[embeddingModelProvider][
embeddingModel
] as Embeddings | undefined;
} }
if (!llm || !embeddings) { if (!llm || !embeddings) {

View File

@ -22,11 +22,23 @@ const useSocket = (url: string) => {
const connectWs = async () => { const connectWs = async () => {
let chatModel = localStorage.getItem('chatModel'); let chatModel = localStorage.getItem('chatModel');
let chatModelProvider = localStorage.getItem('chatModelProvider'); let chatModelProvider = localStorage.getItem('chatModelProvider');
let embeddingModel = localStorage.getItem('embeddingModel');
let embeddingModelProvider = localStorage.getItem(
'embeddingModelProvider',
);
if (!chatModel || !chatModelProvider) { if (
const chatModelProviders = await fetch( !chatModel ||
!chatModelProvider ||
!embeddingModel ||
!embeddingModelProvider
) {
const providers = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/models`, `${process.env.NEXT_PUBLIC_API_URL}/models`,
).then(async (res) => (await res.json())['providers']); ).then(async (res) => await res.json());
const chatModelProviders = providers.chatModelProviders;
const embeddingModelProviders = providers.embeddingModelProviders;
if ( if (
!chatModelProviders || !chatModelProviders ||
@ -34,16 +46,52 @@ const useSocket = (url: string) => {
) )
return console.error('No chat models available'); return console.error('No chat models available');
if (
!embeddingModelProviders ||
Object.keys(embeddingModelProviders).length === 0
)
return console.error('No embedding models available');
chatModelProvider = Object.keys(chatModelProviders)[0]; chatModelProvider = Object.keys(chatModelProviders)[0];
chatModel = Object.keys(chatModelProviders[chatModelProvider])[0]; chatModel = Object.keys(chatModelProviders[chatModelProvider])[0];
embeddingModelProvider = Object.keys(embeddingModelProviders)[0];
embeddingModel = Object.keys(
embeddingModelProviders[embeddingModelProvider],
)[0];
localStorage.setItem('chatModel', chatModel!); localStorage.setItem('chatModel', chatModel!);
localStorage.setItem('chatModelProvider', chatModelProvider); localStorage.setItem('chatModelProvider', chatModelProvider);
localStorage.setItem('embeddingModel', embeddingModel!);
localStorage.setItem(
'embeddingModelProvider',
embeddingModelProvider,
);
} }
const ws = new WebSocket( const wsURL = new URL(url);
`${url}?chatModel=${chatModel}&chatModelProvider=${chatModelProvider}`, const searchParams = new URLSearchParams({});
searchParams.append('chatModel', chatModel!);
searchParams.append('chatModelProvider', chatModelProvider);
if (chatModelProvider === 'custom_openai') {
searchParams.append(
'openAIApiKey',
localStorage.getItem('openAIApiKey')!,
); );
searchParams.append(
'openAIBaseURL',
localStorage.getItem('openAIBaseURL')!,
);
}
searchParams.append('embeddingModel', embeddingModel!);
searchParams.append('embeddingModelProvider', embeddingModelProvider);
wsURL.search = searchParams.toString();
const ws = new WebSocket(wsURL.toString());
ws.onopen = () => { ws.onopen = () => {
console.log('[DEBUG] open'); console.log('[DEBUG] open');
setWs(ws); setWs(ws);

View File

@ -3,7 +3,10 @@ import { CloudUpload, RefreshCcw, RefreshCw } from 'lucide-react';
import React, { Fragment, useEffect, useState } from 'react'; import React, { Fragment, useEffect, useState } from 'react';
interface SettingsType { interface SettingsType {
providers: { chatModelProviders: {
[key: string]: string[];
};
embeddingModelProviders: {
[key: string]: string[]; [key: string]: string[];
}; };
openaiApiKey: string; openaiApiKey: string;
@ -25,6 +28,17 @@ const SettingsDialog = ({
const [selectedChatModel, setSelectedChatModel] = useState<string | null>( const [selectedChatModel, setSelectedChatModel] = useState<string | null>(
null, null,
); );
const [selectedEmbeddingModelProvider, setSelectedEmbeddingModelProvider] =
useState<string | null>(null);
const [selectedEmbeddingModel, setSelectedEmbeddingModel] = useState<
string | null
>(null);
const [customOpenAIApiKey, setCustomOpenAIApiKey] = useState<string | null>(
null,
);
const [customOpenAIBaseURL, setCustomOpenAIBaseURL] = useState<string | null>(
null,
);
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const [isUpdating, setIsUpdating] = useState(false); const [isUpdating, setIsUpdating] = useState(false);
@ -46,6 +60,12 @@ const SettingsDialog = ({
useEffect(() => { useEffect(() => {
setSelectedChatModelProvider(localStorage.getItem('chatModelProvider')); setSelectedChatModelProvider(localStorage.getItem('chatModelProvider'));
setSelectedChatModel(localStorage.getItem('chatModel')); setSelectedChatModel(localStorage.getItem('chatModel'));
setSelectedEmbeddingModelProvider(
localStorage.getItem('embeddingModelProvider'),
);
setSelectedEmbeddingModel(localStorage.getItem('embeddingModel'));
setCustomOpenAIApiKey(localStorage.getItem('openAIApiKey'));
setCustomOpenAIBaseURL(localStorage.getItem('openAIBaseUrl'));
}, []); }, []);
const handleSubmit = async () => { const handleSubmit = async () => {
@ -62,6 +82,13 @@ const SettingsDialog = ({
localStorage.setItem('chatModelProvider', selectedChatModelProvider!); localStorage.setItem('chatModelProvider', selectedChatModelProvider!);
localStorage.setItem('chatModel', selectedChatModel!); localStorage.setItem('chatModel', selectedChatModel!);
localStorage.setItem(
'embeddingModelProvider',
selectedEmbeddingModelProvider!,
);
localStorage.setItem('embeddingModel', selectedEmbeddingModel!);
localStorage.setItem('openAIApiKey', customOpenAIApiKey!);
localStorage.setItem('openAIBaseURL', customOpenAIBaseURL!);
} catch (err) { } catch (err) {
console.log(err); console.log(err);
} finally { } finally {
@ -107,7 +134,7 @@ const SettingsDialog = ({
</Dialog.Title> </Dialog.Title>
{config && !isLoading && ( {config && !isLoading && (
<div className="flex flex-col space-y-4 mt-6"> <div className="flex flex-col space-y-4 mt-6">
{config.providers && ( {config.chatModelProviders && (
<div className="flex flex-col space-y-1"> <div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm"> <p className="text-white/70 text-sm">
Chat model Provider Chat model Provider
@ -116,36 +143,47 @@ const SettingsDialog = ({
onChange={(e) => { onChange={(e) => {
setSelectedChatModelProvider(e.target.value); setSelectedChatModelProvider(e.target.value);
setSelectedChatModel( setSelectedChatModel(
config.providers[e.target.value][0], config.chatModelProviders[e.target.value][0],
); );
}} }}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
> >
{Object.keys(config.providers).map((provider) => ( {Object.keys(config.chatModelProviders).map(
(provider) => (
<option <option
key={provider} key={provider}
value={provider} value={provider}
selected={provider === selectedChatModelProvider} selected={
provider === selectedChatModelProvider
}
> >
{provider.charAt(0).toUpperCase() + {provider.charAt(0).toUpperCase() +
provider.slice(1)} provider.slice(1)}
</option> </option>
))} ),
)}
</select> </select>
</div> </div>
)} )}
{selectedChatModelProvider && ( {selectedChatModelProvider &&
selectedChatModelProvider != 'custom_openai' && (
<div className="flex flex-col space-y-1"> <div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">Chat Model</p> <p className="text-white/70 text-sm">Chat Model</p>
<select <select
onChange={(e) => setSelectedChatModel(e.target.value)} onChange={(e) =>
setSelectedChatModel(e.target.value)
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
> >
{config.providers[selectedChatModelProvider] ? ( {config.chatModelProviders[
config.providers[selectedChatModelProvider].length > selectedChatModelProvider
0 ? ( ] ? (
config.providers[selectedChatModelProvider].map( config.chatModelProviders[
(model) => ( selectedChatModelProvider
].length > 0 ? (
config.chatModelProviders[
selectedChatModelProvider
].map((model) => (
<option <option
key={model} key={model}
value={model} value={model}
@ -153,8 +191,7 @@ const SettingsDialog = ({
> >
{model} {model}
</option> </option>
), ))
)
) : ( ) : (
<option value="" disabled selected> <option value="" disabled selected>
No models available No models available
@ -168,6 +205,122 @@ const SettingsDialog = ({
</select> </select>
</div> </div>
)} )}
{selectedChatModelProvider &&
selectedChatModelProvider === 'custom_openai' && (
<>
<div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">Model name</p>
<input
type="text"
placeholder="Model name"
defaultValue={selectedChatModel!}
onChange={(e) =>
setSelectedChatModel(e.target.value)
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
/>
</div>
<div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">
Custom OpenAI API Key (optional)
</p>
<input
type="text"
placeholder="Custom OpenAI API Key"
defaultValue={customOpenAIApiKey!}
onChange={(e) =>
setCustomOpenAIApiKey(e.target.value)
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
/>
</div>
<div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">
Custom OpenAI Base URL
</p>
<input
type="text"
placeholder="Custom OpenAI Base URL"
defaultValue={customOpenAIBaseURL!}
onChange={(e) =>
setCustomOpenAIBaseURL(e.target.value)
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
/>
</div>
</>
)}
{/* Embedding models */}
{config.chatModelProviders && (
<div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">
Embedding model Provider
</p>
<select
onChange={(e) => {
setSelectedEmbeddingModelProvider(e.target.value);
setSelectedEmbeddingModel(
config.embeddingModelProviders[e.target.value][0],
);
}}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
>
{Object.keys(config.embeddingModelProviders).map(
(provider) => (
<option
key={provider}
value={provider}
selected={
provider === selectedEmbeddingModelProvider
}
>
{provider.charAt(0).toUpperCase() +
provider.slice(1)}
</option>
),
)}
</select>
</div>
)}
{selectedEmbeddingModelProvider && (
<div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">Embedding Model</p>
<select
onChange={(e) =>
setSelectedEmbeddingModel(e.target.value)
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
>
{config.embeddingModelProviders[
selectedEmbeddingModelProvider
] ? (
config.embeddingModelProviders[
selectedEmbeddingModelProvider
].length > 0 ? (
config.embeddingModelProviders[
selectedEmbeddingModelProvider
].map((model) => (
<option
key={model}
value={model}
selected={model === selectedEmbeddingModel}
>
{model}
</option>
))
) : (
<option value="" disabled selected>
No embedding models available
</option>
)
) : (
<option value="" disabled selected>
Invalid provider, please check backend logs
</option>
)}
</select>
</div>
)}
<div className="flex flex-col space-y-1"> <div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">OpenAI API Key</p> <p className="text-white/70 text-sm">OpenAI API Key</p>
<input <input