feat(providers): separate each provider

This commit is contained in:
ItzCrazyKns 2024-07-06 14:19:33 +05:30
parent c63c9b5c8a
commit 25b5dbd63e
No known key found for this signature in database
GPG Key ID: 8162927C7CCE3065
8 changed files with 238 additions and 191 deletions

View File

@ -146,7 +146,7 @@ If you find Perplexica useful, consider giving us a star on GitHub. This helps m
We also accept donations to help sustain our project. If you would like to contribute, you can use the following options to donate. Thank you for your support! We also accept donations to help sustain our project. If you would like to contribute, you can use the following options to donate. Thank you for your support!
| Cards | Ethereum | | Cards | Ethereum |
|---|---| | ----------------------------------- | ----------------------------------------------------- |
| https://www.patreon.com/itzcrazykns | Address: `0xB025a84b2F269570Eb8D4b05DEdaA41D8525B6DD` | | https://www.patreon.com/itzcrazykns | Address: `0xB025a84b2F269570Eb8D4b05DEdaA41D8525B6DD` |
## Contribution ## Contribution

View File

@ -1,4 +1,4 @@
FROM node:slim FROM node:buster-slim
ARG SEARXNG_API_URL ARG SEARXNG_API_URL

View File

@ -1,187 +0,0 @@
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { ChatOllama } from '@langchain/community/chat_models/ollama';
import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
import { HuggingFaceTransformersEmbeddings } from './huggingfaceTransformer';
import {
getGroqApiKey,
getOllamaApiEndpoint,
getOpenaiApiKey,
} from '../config';
import logger from '../utils/logger';
export const getAvailableChatModelProviders = async () => {
const openAIApiKey = getOpenaiApiKey();
const groqApiKey = getGroqApiKey();
const ollamaEndpoint = getOllamaApiEndpoint();
const models = {};
if (openAIApiKey) {
try {
models['openai'] = {
'GPT-3.5 turbo': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-3.5-turbo',
temperature: 0.7,
}),
'GPT-4': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-4',
temperature: 0.7,
}),
'GPT-4 turbo': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-4-turbo',
temperature: 0.7,
}),
'GPT-4 omni': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-4o',
temperature: 0.7,
}),
};
} catch (err) {
logger.error(`Error loading OpenAI models: ${err}`);
}
}
if (groqApiKey) {
try {
models['groq'] = {
'LLaMA3 8b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'llama3-8b-8192',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
'LLaMA3 70b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'llama3-70b-8192',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
'Mixtral 8x7b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'mixtral-8x7b-32768',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
'Gemma 7b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'gemma-7b-it',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
};
} catch (err) {
logger.error(`Error loading Groq models: ${err}`);
}
}
if (ollamaEndpoint) {
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`, {
headers: {
'Content-Type': 'application/json',
},
});
const { models: ollamaModels } = (await response.json()) as any;
models['ollama'] = ollamaModels.reduce((acc, model) => {
acc[model.model] = new ChatOllama({
baseUrl: ollamaEndpoint,
model: model.model,
temperature: 0.7,
});
return acc;
}, {});
} catch (err) {
logger.error(`Error loading Ollama models: ${err}`);
}
}
models['custom_openai'] = {};
return models;
};
export const getAvailableEmbeddingModelProviders = async () => {
const openAIApiKey = getOpenaiApiKey();
const ollamaEndpoint = getOllamaApiEndpoint();
const models = {};
if (openAIApiKey) {
try {
models['openai'] = {
'Text embedding 3 small': new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-small',
}),
'Text embedding 3 large': new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-large',
}),
};
} catch (err) {
logger.error(`Error loading OpenAI embeddings: ${err}`);
}
}
if (ollamaEndpoint) {
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`, {
headers: {
'Content-Type': 'application/json',
},
});
const { models: ollamaModels } = (await response.json()) as any;
models['ollama'] = ollamaModels.reduce((acc, model) => {
acc[model.model] = new OllamaEmbeddings({
baseUrl: ollamaEndpoint,
model: model.model,
});
return acc;
}, {});
} catch (err) {
logger.error(`Error loading Ollama embeddings: ${err}`);
}
}
try {
models['local'] = {
'BGE Small': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bge-small-en-v1.5',
}),
'GTE Small': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/gte-small',
}),
'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bert-base-multilingual-uncased',
}),
};
} catch (err) {
logger.error(`Error loading local embeddings: ${err}`);
}
return models;
};

57
src/lib/providers/groq.ts Normal file
View File

@ -0,0 +1,57 @@
import { ChatOpenAI } from '@langchain/openai';
import { getGroqApiKey } from '../../config';
import logger from '../../utils/logger';
export const loadGroqChatModels = async () => {
const groqApiKey = getGroqApiKey();
try {
const chatModels = {
'LLaMA3 8b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'llama3-8b-8192',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
'LLaMA3 70b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'llama3-70b-8192',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
'Mixtral 8x7b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'mixtral-8x7b-32768',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
'Gemma 7b': new ChatOpenAI(
{
openAIApiKey: groqApiKey,
modelName: 'gemma-7b-it',
temperature: 0.7,
},
{
baseURL: 'https://api.groq.com/openai/v1',
},
),
};
return chatModels;
} catch (err) {
logger.error(`Error loading Groq models: ${err}`);
return {};
}
};

View File

@ -0,0 +1,36 @@
import { loadGroqChatModels } from './groq';
import { loadOllamaChatModels } from './ollama';
import { loadOpenAIChatModels, loadOpenAIEmbeddingsModel } from './openai';
import { loadTransformersEmbeddingsModel } from './transformers';
const chatModelProviders = {
openai: loadOpenAIChatModels,
groq: loadGroqChatModels,
ollama: loadOllamaChatModels,
};
const embeddingModelProviders = {
openai: loadOpenAIEmbeddingsModel,
local: loadTransformersEmbeddingsModel,
ollama: loadOllamaChatModels,
};
export const getAvailableChatModelProviders = async () => {
const models = {};
for (const provider in chatModelProviders) {
models[provider] = await chatModelProviders[provider]();
}
return models;
};
export const getAvailableEmbeddingModelProviders = async () => {
const models = {};
for (const provider in embeddingModelProviders) {
models[provider] = await embeddingModelProviders[provider]();
}
return models;
};

View File

@ -0,0 +1,59 @@
import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
import { getOllamaApiEndpoint } from '../../config';
import logger from '../../utils/logger';
import { ChatOllama } from '@langchain/community/chat_models/ollama';
export const loadOllamaChatModels = async () => {
const ollamaEndpoint = getOllamaApiEndpoint();
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`, {
headers: {
'Content-Type': 'application/json',
},
});
const { models: ollamaModels } = (await response.json()) as any;
const chatModels = ollamaModels.reduce((acc, model) => {
acc[model.model] = new ChatOllama({
baseUrl: ollamaEndpoint,
model: model.model,
temperature: 0.7,
});
return acc;
}, {});
return chatModels;
} catch (err) {
logger.error(`Error loading Ollama models: ${err}`);
return {};
}
};
export const loadOpenAIEmbeddingsModel = async () => {
const ollamaEndpoint = getOllamaApiEndpoint();
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`, {
headers: {
'Content-Type': 'application/json',
},
});
const { models: ollamaModels } = (await response.json()) as any;
const embeddingsModels = ollamaModels.reduce((acc, model) => {
acc[model.model] = new OllamaEmbeddings({
baseUrl: ollamaEndpoint,
model: model.model,
});
return acc;
}, {});
return embeddingsModels;
} catch (err) {
logger.error(`Error loading Ollama embeddings model: ${err}`);
return {};
}
};

View File

@ -0,0 +1,59 @@
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { getOpenaiApiKey } from '../../config';
import logger from '../../utils/logger';
export const loadOpenAIChatModels = async () => {
const openAIApiKey = getOpenaiApiKey();
try {
const chatModels = {
'GPT-3.5 turbo': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-3.5-turbo',
temperature: 0.7,
}),
'GPT-4': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-4',
temperature: 0.7,
}),
'GPT-4 turbo': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-4-turbo',
temperature: 0.7,
}),
'GPT-4 omni': new ChatOpenAI({
openAIApiKey,
modelName: 'gpt-4o',
temperature: 0.7,
}),
};
return chatModels;
} catch (err) {
logger.error(`Error loading OpenAI models: ${err}`);
return {};
}
};
export const loadOpenAIEmbeddingsModel = async () => {
const openAIApiKey = getOpenaiApiKey();
try {
const embeddingModels = {
'Text embedding 3 small': new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-small',
}),
'Text embedding 3 large': new OpenAIEmbeddings({
openAIApiKey,
modelName: 'text-embedding-3-large',
}),
};
return embeddingModels;
} catch (err) {
logger.error(`Error loading OpenAI embeddings model: ${err}`);
return {};
}
};

View File

@ -0,0 +1,23 @@
import logger from '../../utils/logger';
import { HuggingFaceTransformersEmbeddings } from '../huggingfaceTransformer';
export const loadTransformersEmbeddingsModel = async () => {
try {
const embeddingModels = {
'BGE Small': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bge-small-en-v1.5',
}),
'GTE Small': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/gte-small',
}),
'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bert-base-multilingual-uncased',
}),
};
return embeddingModels;
} catch (err) {
logger.error(`Error loading Transformers embeddings model: ${err}`);
return {};
}
};