feat(providers): add Groq provider

This commit is contained in:
ItzCrazyKns 2024-05-01 19:43:06 +05:30
parent 6e304e7051
commit edc40d8fe6
No known key found for this signature in database
GPG Key ID: 8162927C7CCE3065
5 changed files with 86 additions and 5 deletions

View File

@ -1,11 +1,12 @@
[GENERAL] [GENERAL]
PORT = 3001 # Port to run the server on PORT = 3001 # Port to run the server on
SIMILARITY_MEASURE = "cosine" # "cosine" or "dot" SIMILARITY_MEASURE = "cosine" # "cosine" or "dot"
CHAT_MODEL_PROVIDER = "openai" # "openai" or "ollama" CHAT_MODEL_PROVIDER = "openai" # "openai" or "ollama" or "groq"
CHAT_MODEL = "gpt-3.5-turbo" # Name of the model to use CHAT_MODEL = "gpt-3.5-turbo" # Name of the model to use
[API_KEYS] [API_KEYS]
OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef
GROQ = "" # Groq API key - gsk_1234567890abcdef1234567890abcdef
[API_ENDPOINTS] [API_ENDPOINTS]
SEARXNG = "http://localhost:32768" # SearxNG API URL SEARXNG = "http://localhost:32768" # SearxNG API URL

View File

@ -13,6 +13,7 @@ interface Config {
}; };
API_KEYS: { API_KEYS: {
OPENAI: string; OPENAI: string;
GROQ: string;
}; };
API_ENDPOINTS: { API_ENDPOINTS: {
SEARXNG: string; SEARXNG: string;
@ -41,6 +42,8 @@ export const getChatModel = () => loadConfig().GENERAL.CHAT_MODEL;
export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI; export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;
export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ;
export const getSearxngApiEndpoint = () => loadConfig().API_ENDPOINTS.SEARXNG; export const getSearxngApiEndpoint = () => loadConfig().API_ENDPOINTS.SEARXNG;
export const getOllamaApiEndpoint = () => loadConfig().API_ENDPOINTS.OLLAMA; export const getOllamaApiEndpoint = () => loadConfig().API_ENDPOINTS.OLLAMA;

View File

@ -1,11 +1,16 @@
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'; import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { ChatOllama } from '@langchain/community/chat_models/ollama'; import { ChatOllama } from '@langchain/community/chat_models/ollama';
import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
import { getOllamaApiEndpoint, getOpenaiApiKey } from '../config'; import {
getGroqApiKey,
getOllamaApiEndpoint,
getOpenaiApiKey,
} from '../config';
import logger from '../utils/logger'; import logger from '../utils/logger';
export const getAvailableProviders = async () => { export const getAvailableProviders = async () => {
const openAIApiKey = getOpenaiApiKey(); const openAIApiKey = getOpenaiApiKey();
const groqApiKey = getGroqApiKey();
const ollamaEndpoint = getOllamaApiEndpoint(); const ollamaEndpoint = getOllamaApiEndpoint();
const models = {}; const models = {};
@ -13,17 +18,17 @@ export const getAvailableProviders = async () => {
if (openAIApiKey) { if (openAIApiKey) {
try { try {
models['openai'] = { models['openai'] = {
'gpt-3.5-turbo': new ChatOpenAI({ 'GPT-3.5 turbo': new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-3.5-turbo', modelName: 'gpt-3.5-turbo',
temperature: 0.7, temperature: 0.7,
}), }),
'gpt-4': new ChatOpenAI({ 'GPT-4': new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-4', modelName: 'gpt-4',
temperature: 0.7, temperature: 0.7,
}), }),
'gpt-4-turbo': new ChatOpenAI({ 'GPT-4 turbo': new ChatOpenAI({
openAIApiKey, openAIApiKey,
modelName: 'gpt-4-turbo', modelName: 'gpt-4-turbo',
temperature: 0.7, temperature: 0.7,
@ -38,6 +43,59 @@ export const getAvailableProviders = async () => {
} }
} }
// Groq exposes an OpenAI-compatible chat-completions API, so we reuse
// ChatOpenAI pointed at Groq's baseURL. Registered only when a Groq API
// key is configured.
if (groqApiKey) {
  try {
    models['groq'] = {
      'LLaMA3 8b': new ChatOpenAI(
        {
          openAIApiKey: groqApiKey,
          modelName: 'llama3-8b-8192',
          temperature: 0.7,
        },
        {
          baseURL: 'https://api.groq.com/openai/v1',
        },
      ),
      'LLaMA3 70b': new ChatOpenAI(
        {
          openAIApiKey: groqApiKey,
          modelName: 'llama3-70b-8192',
          temperature: 0.7,
        },
        {
          baseURL: 'https://api.groq.com/openai/v1',
        },
      ),
      // FIX: was 'gemma-7b-it' — the label said Mixtral but requests would
      // have hit Gemma. 'mixtral-8x7b-32768' is Groq's Mixtral model ID.
      'Mixtral 8x7b': new ChatOpenAI(
        {
          openAIApiKey: groqApiKey,
          modelName: 'mixtral-8x7b-32768',
          temperature: 0.7,
        },
        {
          baseURL: 'https://api.groq.com/openai/v1',
        },
      ),
      // FIX: was 'llama3-70b-8192' — duplicated the LLaMA3 70b entry, so
      // selecting "Gemma 7b" never actually reached Gemma.
      'Gemma 7b': new ChatOpenAI(
        {
          openAIApiKey: groqApiKey,
          modelName: 'gemma-7b-it',
          temperature: 0.7,
        },
        {
          baseURL: 'https://api.groq.com/openai/v1',
        },
      ),
      // NOTE(review): embeddings deliberately omit the Groq baseURL and use
      // the OpenAI key — Groq does not serve this embedding model. This
      // entry will fail if OPENAI is unset while GROQ is set; confirm that
      // fallback is the intended behavior.
      embeddings: new OpenAIEmbeddings({
        openAIApiKey: openAIApiKey,
        modelName: 'text-embedding-3-large',
      }),
    };
  } catch (err) {
    logger.error(`Error loading Groq models: ${err}`);
  }
}
if (ollamaEndpoint) { if (ollamaEndpoint) {
try { try {
const response = await fetch(`${ollamaEndpoint}/api/tags`); const response = await fetch(`${ollamaEndpoint}/api/tags`);

View File

@ -3,6 +3,7 @@ import { getAvailableProviders } from '../lib/providers';
import { import {
getChatModel, getChatModel,
getChatModelProvider, getChatModelProvider,
getGroqApiKey,
getOllamaApiEndpoint, getOllamaApiEndpoint,
getOpenaiApiKey, getOpenaiApiKey,
updateConfig, updateConfig,
@ -30,6 +31,7 @@ router.get('/', async (_, res) => {
config['openeaiApiKey'] = getOpenaiApiKey(); config['openeaiApiKey'] = getOpenaiApiKey();
config['ollamaApiUrl'] = getOllamaApiEndpoint(); config['ollamaApiUrl'] = getOllamaApiEndpoint();
config['groqApiKey'] = getGroqApiKey();
res.status(200).json(config); res.status(200).json(config);
}); });
@ -44,6 +46,7 @@ router.post('/', async (req, res) => {
}, },
API_KEYS: { API_KEYS: {
OPENAI: config.openeaiApiKey, OPENAI: config.openeaiApiKey,
GROQ: config.groqApiKey,
}, },
API_ENDPOINTS: { API_ENDPOINTS: {
OLLAMA: config.ollamaApiUrl, OLLAMA: config.ollamaApiUrl,

View File

@ -9,6 +9,7 @@ interface SettingsType {
selectedProvider: string; selectedProvider: string;
selectedChatModel: string; selectedChatModel: string;
openeaiApiKey: string; openeaiApiKey: string;
groqApiKey: string;
ollamaApiUrl: string; ollamaApiUrl: string;
} }
@ -194,6 +195,21 @@ const SettingsDialog = ({
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
/> />
</div> </div>
<div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">GROQ API Key</p>
<input
type="text"
placeholder="GROQ API Key"
defaultValue={config.groqApiKey}
onChange={(e) =>
setConfig({
...config,
groqApiKey: e.target.value,
})
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
/>
</div>
</div> </div>
)} )}
{isLoading && ( {isLoading && (