feat(providers): add `displayName` property

This commit is contained in:
ItzCrazyKns 2024-09-24 22:34:43 +05:30
parent 40f551c426
commit 1589f16d5a
12 changed files with 277 additions and 183 deletions

View File

@ -9,26 +9,38 @@ export const loadAnthropicChatModels = async () => {
try { try {
const chatModels = { const chatModels = {
'Claude 3.5 Sonnet': new ChatAnthropic({ 'claude-3-5-sonnet-20240620': {
displayName: 'Claude 3.5 Sonnet',
model: new ChatAnthropic({
temperature: 0.7, temperature: 0.7,
anthropicApiKey: anthropicApiKey, anthropicApiKey: anthropicApiKey,
model: 'claude-3-5-sonnet-20240620', model: 'claude-3-5-sonnet-20240620',
}), }),
'Claude 3 Opus': new ChatAnthropic({ },
'claude-3-opus-20240229': {
displayName: 'Claude 3 Opus',
model: new ChatAnthropic({
temperature: 0.7, temperature: 0.7,
anthropicApiKey: anthropicApiKey, anthropicApiKey: anthropicApiKey,
model: 'claude-3-opus-20240229', model: 'claude-3-opus-20240229',
}), }),
'Claude 3 Sonnet': new ChatAnthropic({ },
'claude-3-sonnet-20240229': {
displayName: 'Claude 3 Sonnet',
model: new ChatAnthropic({
temperature: 0.7, temperature: 0.7,
anthropicApiKey: anthropicApiKey, anthropicApiKey: anthropicApiKey,
model: 'claude-3-sonnet-20240229', model: 'claude-3-sonnet-20240229',
}), }),
'Claude 3 Haiku': new ChatAnthropic({ },
'claude-3-haiku-20240307': {
displayName: 'Claude 3 Haiku',
model: new ChatAnthropic({
temperature: 0.7, temperature: 0.7,
anthropicApiKey: anthropicApiKey, anthropicApiKey: anthropicApiKey,
model: 'claude-3-haiku-20240307', model: 'claude-3-haiku-20240307',
}), }),
},
}; };
return chatModels; return chatModels;

View File

@ -9,7 +9,9 @@ export const loadGroqChatModels = async () => {
try { try {
const chatModels = { const chatModels = {
'Llama 3.1 70B': new ChatOpenAI( 'llama-3.1-70b-versatile': {
displayName: 'Llama 3.1 70B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'llama-3.1-70b-versatile', modelName: 'llama-3.1-70b-versatile',
@ -19,7 +21,10 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
'Llama 3.1 8B': new ChatOpenAI( },
'llama-3.1-8b-instant': {
displayName: 'Llama 3.1 8B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'llama-3.1-8b-instant', modelName: 'llama-3.1-8b-instant',
@ -29,7 +34,10 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
'LLaMA3 8b': new ChatOpenAI( },
'llama3-8b-8192': {
displayName: 'LLaMA3 8B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'llama3-8b-8192', modelName: 'llama3-8b-8192',
@ -39,7 +47,10 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
'LLaMA3 70b': new ChatOpenAI( },
'llama3-70b-8192': {
displayName: 'LLaMA3 70B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'llama3-70b-8192', modelName: 'llama3-70b-8192',
@ -49,7 +60,10 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
'Mixtral 8x7b': new ChatOpenAI( },
'mixtral-8x7b-32768': {
displayName: 'Mixtral 8x7B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'mixtral-8x7b-32768', modelName: 'mixtral-8x7b-32768',
@ -59,7 +73,10 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
'Gemma 7b': new ChatOpenAI( },
'gemma-7b-it': {
displayName: 'Gemma 7B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'gemma-7b-it', modelName: 'gemma-7b-it',
@ -69,7 +86,10 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
'Gemma2 9b': new ChatOpenAI( },
'gemma2-9b-it': {
displayName: 'Gemma2 9B',
model: new ChatOpenAI(
{ {
openAIApiKey: groqApiKey, openAIApiKey: groqApiKey,
modelName: 'gemma2-9b-it', modelName: 'gemma2-9b-it',
@ -79,6 +99,7 @@ export const loadGroqChatModels = async () => {
baseURL: 'https://api.groq.com/openai/v1', baseURL: 'https://api.groq.com/openai/v1',
}, },
), ),
},
}; };
return chatModels; return chatModels;

View File

@ -18,11 +18,15 @@ export const loadOllamaChatModels = async () => {
const { models: ollamaModels } = (await response.json()) as any; const { models: ollamaModels } = (await response.json()) as any;
const chatModels = ollamaModels.reduce((acc, model) => { const chatModels = ollamaModels.reduce((acc, model) => {
acc[model.model] = new ChatOllama({ acc[model.model] = {
displayName: model.name,
model: new ChatOllama({
baseUrl: ollamaEndpoint, baseUrl: ollamaEndpoint,
model: model.model, model: model.model,
temperature: 0.7, temperature: 0.7,
}); }),
};
return acc; return acc;
}, {}); }, {});
@ -48,10 +52,14 @@ export const loadOllamaEmbeddingsModels = async () => {
const { models: ollamaModels } = (await response.json()) as any; const { models: ollamaModels } = (await response.json()) as any;
const embeddingsModels = ollamaModels.reduce((acc, model) => { const embeddingsModels = ollamaModels.reduce((acc, model) => {
acc[model.model] = new OllamaEmbeddings({ acc[model.model] = {
displayName: model.name,
model: new OllamaEmbeddings({
baseUrl: ollamaEndpoint, baseUrl: ollamaEndpoint,
model: model.model, model: model.model,
}); }),
};
return acc; return acc;
}, {}); }, {});

View File

@ -9,31 +9,46 @@ export const loadOpenAIChatModels = async () => {
try { try {
const chatModels = { const chatModels = {
'GPT-3.5 turbo': new ChatOpenAI({ 'gpt-3.5-turbo': {
displayName: 'GPT-3.5 Turbo',
model: new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-3.5-turbo', modelName: 'gpt-3.5-turbo',
temperature: 0.7, temperature: 0.7,
}), }),
'GPT-4': new ChatOpenAI({ },
'gpt-4': {
displayName: 'GPT-4',
model: new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-4', modelName: 'gpt-4',
temperature: 0.7, temperature: 0.7,
}), }),
'GPT-4 turbo': new ChatOpenAI({ },
'gpt-4-turbo': {
displayName: 'GPT-4 turbo',
model: new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-4-turbo', modelName: 'gpt-4-turbo',
temperature: 0.7, temperature: 0.7,
}), }),
'GPT-4 omni': new ChatOpenAI({ },
'gpt-4o': {
displayName: 'GPT-4 omni',
model: new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-4o', modelName: 'gpt-4o',
temperature: 0.7, temperature: 0.7,
}), }),
'GPT-4 omni mini': new ChatOpenAI({ },
'gpt-4o-mini': {
displayName: 'GPT-4 omni mini',
model: new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-4o-mini', modelName: 'gpt-4o-mini',
temperature: 0.7, temperature: 0.7,
}), }),
},
}; };
return chatModels; return chatModels;
@ -50,14 +65,20 @@ export const loadOpenAIEmbeddingsModels = async () => {
try { try {
const embeddingModels = { const embeddingModels = {
'Text embedding 3 small': new OpenAIEmbeddings({ 'text-embedding-3-small': {
displayName: 'Text Embedding 3 Small',
model: new OpenAIEmbeddings({
openAIApiKey, openAIApiKey,
modelName: 'text-embedding-3-small', modelName: 'text-embedding-3-small',
}), }),
'Text embedding 3 large': new OpenAIEmbeddings({ },
'text-embedding-3-large': {
displayName: 'Text Embedding 3 Large',
model: new OpenAIEmbeddings({
openAIApiKey, openAIApiKey,
modelName: 'text-embedding-3-large', modelName: 'text-embedding-3-large',
}), }),
},
}; };
return embeddingModels; return embeddingModels;

View File

@ -4,15 +4,24 @@ import { HuggingFaceTransformersEmbeddings } from '../huggingfaceTransformer';
export const loadTransformersEmbeddingsModels = async () => { export const loadTransformersEmbeddingsModels = async () => {
try { try {
const embeddingModels = { const embeddingModels = {
'BGE Small': new HuggingFaceTransformersEmbeddings({ 'xenova-bge-small-en-v1.5': {
displayName: 'BGE Small',
model: new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bge-small-en-v1.5', modelName: 'Xenova/bge-small-en-v1.5',
}), }),
'GTE Small': new HuggingFaceTransformersEmbeddings({ },
'xenova-gte-small': {
displayName: 'GTE Small',
model: new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/gte-small', modelName: 'Xenova/gte-small',
}), }),
'Bert Multilingual': new HuggingFaceTransformersEmbeddings({ },
'xenova-bert-base-multilingual-uncased': {
displayName: 'Bert Multilingual',
model: new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bert-base-multilingual-uncased', modelName: 'Xenova/bert-base-multilingual-uncased',
}), }),
},
}; };
return embeddingModels; return embeddingModels;

View File

@ -10,10 +10,12 @@ import {
getOpenaiApiKey, getOpenaiApiKey,
updateConfig, updateConfig,
} from '../config'; } from '../config';
import logger from '../utils/logger';
const router = express.Router(); const router = express.Router();
router.get('/', async (_, res) => { router.get('/', async (_, res) => {
try {
const config = {}; const config = {};
const [chatModelProviders, embeddingModelProviders] = await Promise.all([ const [chatModelProviders, embeddingModelProviders] = await Promise.all([
@ -27,13 +29,23 @@ router.get('/', async (_, res) => {
for (const provider in chatModelProviders) { for (const provider in chatModelProviders) {
config['chatModelProviders'][provider] = Object.keys( config['chatModelProviders'][provider] = Object.keys(
chatModelProviders[provider], chatModelProviders[provider],
); ).map((model) => {
return {
name: model,
displayName: chatModelProviders[provider][model].displayName,
};
});
} }
for (const provider in embeddingModelProviders) { for (const provider in embeddingModelProviders) {
config['embeddingModelProviders'][provider] = Object.keys( config['embeddingModelProviders'][provider] = Object.keys(
embeddingModelProviders[provider], embeddingModelProviders[provider],
); ).map((model) => {
return {
name: model,
displayName: embeddingModelProviders[provider][model].displayName,
};
});
} }
config['openaiApiKey'] = getOpenaiApiKey(); config['openaiApiKey'] = getOpenaiApiKey();
@ -42,6 +54,10 @@ router.get('/', async (_, res) => {
config['groqApiKey'] = getGroqApiKey(); config['groqApiKey'] = getGroqApiKey();
res.status(200).json(config); res.status(200).json(config);
} catch (err: any) {
res.status(500).json({ message: 'An error has occurred.' });
logger.error(`Error getting config: ${err.message}`);
}
}); });
router.post('/', async (req, res) => { router.post('/', async (req, res) => {

View File

@ -26,7 +26,7 @@ router.post('/', async (req, res) => {
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) { if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = chatModels[provider][chatModel] as BaseChatModel | undefined; llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {

View File

@ -26,7 +26,7 @@ router.post('/', async (req, res) => {
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) { if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = chatModels[provider][chatModel] as BaseChatModel | undefined; llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {

View File

@ -26,7 +26,7 @@ router.post('/', async (req, res) => {
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) { if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = chatModels[provider][chatModel] as BaseChatModel | undefined; llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {

View File

@ -45,9 +45,8 @@ export const handleConnection = async (
chatModelProviders[chatModelProvider][chatModel] && chatModelProviders[chatModelProvider][chatModel] &&
chatModelProvider != 'custom_openai' chatModelProvider != 'custom_openai'
) { ) {
llm = chatModelProviders[chatModelProvider][chatModel] as unknown as llm = chatModelProviders[chatModelProvider][chatModel]
| BaseChatModel .model as unknown as BaseChatModel | undefined;
| undefined;
} else if (chatModelProvider == 'custom_openai') { } else if (chatModelProvider == 'custom_openai') {
llm = new ChatOpenAI({ llm = new ChatOpenAI({
modelName: chatModel, modelName: chatModel,
@ -65,7 +64,7 @@ export const handleConnection = async (
) { ) {
embeddings = embeddingModelProviders[embeddingModelProvider][ embeddings = embeddingModelProviders[embeddingModelProvider][
embeddingModel embeddingModel
] as Embeddings | undefined; ].model as Embeddings | undefined;
} }
if (!llm || !embeddings) { if (!llm || !embeddings) {

View File

@ -49,10 +49,10 @@ export const Select = ({ className, options, ...restProps }: SelectProps) => {
interface SettingsType { interface SettingsType {
chatModelProviders: { chatModelProviders: {
[key: string]: string[]; [key: string]: [Record<string, any>];
}; };
embeddingModelProviders: { embeddingModelProviders: {
[key: string]: string[]; [key: string]: [Record<string, any>];
}; };
openaiApiKey: string; openaiApiKey: string;
groqApiKey: string; groqApiKey: string;
@ -68,6 +68,10 @@ const SettingsDialog = ({
setIsOpen: (isOpen: boolean) => void; setIsOpen: (isOpen: boolean) => void;
}) => { }) => {
const [config, setConfig] = useState<SettingsType | null>(null); const [config, setConfig] = useState<SettingsType | null>(null);
const [chatModels, setChatModels] = useState<Record<string, any>>({});
const [embeddingModels, setEmbeddingModels] = useState<Record<string, any>>(
{},
);
const [selectedChatModelProvider, setSelectedChatModelProvider] = useState< const [selectedChatModelProvider, setSelectedChatModelProvider] = useState<
string | null string | null
>(null); >(null);
@ -118,7 +122,7 @@ const SettingsDialog = ({
const chatModel = const chatModel =
localStorage.getItem('chatModel') || localStorage.getItem('chatModel') ||
(data.chatModelProviders && (data.chatModelProviders &&
data.chatModelProviders[chatModelProvider]?.[0]) || data.chatModelProviders[chatModelProvider]?.[0].name) ||
''; '';
const embeddingModelProvider = const embeddingModelProvider =
localStorage.getItem('embeddingModelProvider') || localStorage.getItem('embeddingModelProvider') ||
@ -127,7 +131,7 @@ const SettingsDialog = ({
const embeddingModel = const embeddingModel =
localStorage.getItem('embeddingModel') || localStorage.getItem('embeddingModel') ||
(data.embeddingModelProviders && (data.embeddingModelProviders &&
data.embeddingModelProviders[embeddingModelProvider]?.[0]) || data.embeddingModelProviders[embeddingModelProvider]?.[0].name) ||
''; '';
setSelectedChatModelProvider(chatModelProvider); setSelectedChatModelProvider(chatModelProvider);
@ -136,6 +140,8 @@ const SettingsDialog = ({
setSelectedEmbeddingModel(embeddingModel); setSelectedEmbeddingModel(embeddingModel);
setCustomOpenAIApiKey(localStorage.getItem('openAIApiKey') || ''); setCustomOpenAIApiKey(localStorage.getItem('openAIApiKey') || '');
setCustomOpenAIBaseURL(localStorage.getItem('openAIBaseURL') || ''); setCustomOpenAIBaseURL(localStorage.getItem('openAIBaseURL') || '');
setChatModels(data.chatModelProviders || {});
setEmbeddingModels(data.embeddingModelProviders || {});
setIsLoading(false); setIsLoading(false);
}; };
@ -229,7 +235,8 @@ const SettingsDialog = ({
setSelectedChatModel(''); setSelectedChatModel('');
} else { } else {
setSelectedChatModel( setSelectedChatModel(
config.chatModelProviders[e.target.value][0], config.chatModelProviders[e.target.value][0]
.name,
); );
} }
}} }}
@ -264,8 +271,8 @@ const SettingsDialog = ({
return chatModelProvider return chatModelProvider
? chatModelProvider.length > 0 ? chatModelProvider.length > 0
? chatModelProvider.map((model) => ({ ? chatModelProvider.map((model) => ({
value: model, value: model.name,
label: model, label: model.displayName,
})) }))
: [ : [
{ {
@ -341,7 +348,8 @@ const SettingsDialog = ({
onChange={(e) => { onChange={(e) => {
setSelectedEmbeddingModelProvider(e.target.value); setSelectedEmbeddingModelProvider(e.target.value);
setSelectedEmbeddingModel( setSelectedEmbeddingModel(
config.embeddingModelProviders[e.target.value][0], config.embeddingModelProviders[e.target.value][0]
.name,
); );
}} }}
options={Object.keys( options={Object.keys(
@ -374,8 +382,8 @@ const SettingsDialog = ({
return embeddingModelProvider return embeddingModelProvider
? embeddingModelProvider.length > 0 ? embeddingModelProvider.length > 0
? embeddingModelProvider.map((model) => ({ ? embeddingModelProvider.map((model) => ({
label: model, label: model.displayName,
value: model, value: model.name,
})) }))
: [ : [
{ {