feat(chatModels): load model from localstorage

This commit is contained in:
ItzCrazyKns 2024-05-02 12:14:26 +05:30
parent ed9ff3c20f
commit f618b713af
No known key found for this signature in database
GPG Key ID: 8162927C7CCE3065
16 changed files with 126 additions and 81 deletions

View File

@ -59,13 +59,11 @@ There are mainly 2 ways of installing Perplexica - With Docker, Without Docker.
4. Rename the `sample.config.toml` file to `config.toml`. For Docker setups, you need only fill in the following fields: 4. Rename the `sample.config.toml` file to `config.toml`. For Docker setups, you need only fill in the following fields:
- `CHAT_MODEL`: The name of the LLM to use. Like `llama3:latest` (using Ollama), `gpt-3.5-turbo` (using OpenAI), etc. - `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
- `CHAT_MODEL_PROVIDER`: The chat model provider, either `openai` or `ollama`. Depending upon which provider you use you would have to fill in the following fields: - `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
- `GROQ`: Your Groq API key. **You only need to fill this if you wish to use Groq's hosted models**
- `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**. **Note**: You can change these after starting Perplexica from the settings dialog.
- `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
**Note**: You can change these and use different models after running Perplexica as well from the settings page.
- `SIMILARITY_MEASURE`: The similarity measure to use (This is filled by default; you can leave it as is if you are unsure about it.) - `SIMILARITY_MEASURE`: The similarity measure to use (This is filled by default; you can leave it as is if you are unsure about it.)

View File

@ -1,6 +1,6 @@
{ {
"name": "perplexica-backend", "name": "perplexica-backend",
"version": "1.0.0", "version": "1.1.0",
"license": "MIT", "license": "MIT",
"author": "ItzCrazyKns", "author": "ItzCrazyKns",
"scripts": { "scripts": {

View File

@ -1,8 +1,6 @@
[GENERAL] [GENERAL]
PORT = 3001 # Port to run the server on PORT = 3001 # Port to run the server on
SIMILARITY_MEASURE = "cosine" # "cosine" or "dot" SIMILARITY_MEASURE = "cosine" # "cosine" or "dot"
CHAT_MODEL_PROVIDER = "openai" # "openai" or "ollama" or "groq"
CHAT_MODEL = "gpt-3.5-turbo" # Name of the model to use
[API_KEYS] [API_KEYS]
OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef

View File

@ -8,8 +8,6 @@ interface Config {
GENERAL: { GENERAL: {
PORT: number; PORT: number;
SIMILARITY_MEASURE: string; SIMILARITY_MEASURE: string;
CHAT_MODEL_PROVIDER: string;
CHAT_MODEL: string;
}; };
API_KEYS: { API_KEYS: {
OPENAI: string; OPENAI: string;
@ -35,11 +33,6 @@ export const getPort = () => loadConfig().GENERAL.PORT;
export const getSimilarityMeasure = () => export const getSimilarityMeasure = () =>
loadConfig().GENERAL.SIMILARITY_MEASURE; loadConfig().GENERAL.SIMILARITY_MEASURE;
export const getChatModelProvider = () =>
loadConfig().GENERAL.CHAT_MODEL_PROVIDER;
export const getChatModel = () => loadConfig().GENERAL.CHAT_MODEL;
export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI; export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;
export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ; export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ;
@ -52,21 +45,19 @@ export const updateConfig = (config: RecursivePartial<Config>) => {
const currentConfig = loadConfig(); const currentConfig = loadConfig();
for (const key in currentConfig) { for (const key in currentConfig) {
/* if (currentConfig[key] && !config[key]) { if (!config[key]) config[key] = {};
config[key] = currentConfig[key];
} */
if (currentConfig[key] && typeof currentConfig[key] === 'object') { if (typeof currentConfig[key] === 'object' && currentConfig[key] !== null) {
for (const nestedKey in currentConfig[key]) { for (const nestedKey in currentConfig[key]) {
if ( if (
currentConfig[key][nestedKey] &&
!config[key][nestedKey] && !config[key][nestedKey] &&
currentConfig[key][nestedKey] &&
config[key][nestedKey] !== '' config[key][nestedKey] !== ''
) { ) {
config[key][nestedKey] = currentConfig[key][nestedKey]; config[key][nestedKey] = currentConfig[key][nestedKey];
} }
} }
} else if (currentConfig[key] && !config[key] && config[key] !== '') { } else if (currentConfig[key] && config[key] !== '') {
config[key] = currentConfig[key]; config[key] = currentConfig[key];
} }
} }

View File

@ -1,8 +1,6 @@
import express from 'express'; import express from 'express';
import { getAvailableProviders } from '../lib/providers'; import { getAvailableProviders } from '../lib/providers';
import { import {
getChatModel,
getChatModelProvider,
getGroqApiKey, getGroqApiKey,
getOllamaApiEndpoint, getOllamaApiEndpoint,
getOpenaiApiKey, getOpenaiApiKey,
@ -26,9 +24,6 @@ router.get('/', async (_, res) => {
config['providers'][provider] = Object.keys(providers[provider]); config['providers'][provider] = Object.keys(providers[provider]);
} }
config['selectedProvider'] = getChatModelProvider();
config['selectedChatModel'] = getChatModel();
config['openeaiApiKey'] = getOpenaiApiKey(); config['openeaiApiKey'] = getOpenaiApiKey();
config['ollamaApiUrl'] = getOllamaApiEndpoint(); config['ollamaApiUrl'] = getOllamaApiEndpoint();
config['groqApiKey'] = getGroqApiKey(); config['groqApiKey'] = getGroqApiKey();
@ -40,10 +35,6 @@ router.post('/', async (req, res) => {
const config = req.body; const config = req.body;
const updatedConfig = { const updatedConfig = {
GENERAL: {
CHAT_MODEL_PROVIDER: config.selectedProvider,
CHAT_MODEL: config.selectedChatModel,
},
API_KEYS: { API_KEYS: {
OPENAI: config.openeaiApiKey, OPENAI: config.openeaiApiKey,
GROQ: config.groqApiKey, GROQ: config.groqApiKey,

View File

@ -2,7 +2,6 @@ import express from 'express';
import handleImageSearch from '../agents/imageSearchAgent'; import handleImageSearch from '../agents/imageSearchAgent';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers'; import { getAvailableProviders } from '../lib/providers';
import { getChatModel, getChatModelProvider } from '../config';
import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger'; import logger from '../utils/logger';
@ -10,7 +9,7 @@ const router = express.Router();
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
try { try {
let { query, chat_history } = req.body; let { query, chat_history, chat_model_provider, chat_model } = req.body;
chat_history = chat_history.map((msg: any) => { chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') { if (msg.role === 'user') {
@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
} }
}); });
const models = await getAvailableProviders(); const chatModels = await getAvailableProviders();
const provider = getChatModelProvider(); const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = getChatModel(); const chatModel = chat_model || Object.keys(chatModels[provider])[0];
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (models[provider] && models[provider][chatModel]) { if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = models[provider][chatModel] as BaseChatModel | undefined; llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {

View File

@ -2,11 +2,13 @@ import express from 'express';
import imagesRouter from './images'; import imagesRouter from './images';
import videosRouter from './videos'; import videosRouter from './videos';
import configRouter from './config'; import configRouter from './config';
import modelsRouter from './models';
const router = express.Router(); const router = express.Router();
router.use('/images', imagesRouter); router.use('/images', imagesRouter);
router.use('/videos', videosRouter); router.use('/videos', videosRouter);
router.use('/config', configRouter); router.use('/config', configRouter);
router.use('/models', modelsRouter);
export default router; export default router;

18
src/routes/models.ts Normal file
View File

@ -0,0 +1,18 @@
import express from 'express';
import logger from '../utils/logger';
import { getAvailableProviders } from '../lib/providers';
const router = express.Router();
router.get('/', async (req, res) => {
try {
const providers = await getAvailableProviders();
res.status(200).json({ providers });
} catch (err) {
res.status(500).json({ message: 'An error has occurred.' });
logger.error(err.message);
}
});
export default router;

View File

@ -1,7 +1,6 @@
import express from 'express'; import express from 'express';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers'; import { getAvailableProviders } from '../lib/providers';
import { getChatModel, getChatModelProvider } from '../config';
import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger'; import logger from '../utils/logger';
import handleVideoSearch from '../agents/videoSearchAgent'; import handleVideoSearch from '../agents/videoSearchAgent';
@ -10,7 +9,7 @@ const router = express.Router();
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
try { try {
let { query, chat_history } = req.body; let { query, chat_history, chat_model_provider, chat_model } = req.body;
chat_history = chat_history.map((msg: any) => { chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') { if (msg.role === 'user') {
@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
} }
}); });
const models = await getAvailableProviders(); const chatModels = await getAvailableProviders();
const provider = getChatModelProvider(); const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = getChatModel(); const chatModel = chat_model || Object.keys(chatModels[provider])[0];
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (models[provider] && models[provider][chatModel]) { if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = models[provider][chatModel] as BaseChatModel | undefined; llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {

View File

@ -1,15 +1,23 @@
import { WebSocket } from 'ws'; import { WebSocket } from 'ws';
import { handleMessage } from './messageHandler'; import { handleMessage } from './messageHandler';
import { getChatModel, getChatModelProvider } from '../config';
import { getAvailableProviders } from '../lib/providers'; import { getAvailableProviders } from '../lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings'; import type { Embeddings } from '@langchain/core/embeddings';
import type { IncomingMessage } from 'http';
import logger from '../utils/logger'; import logger from '../utils/logger';
export const handleConnection = async (ws: WebSocket) => { export const handleConnection = async (
ws: WebSocket,
request: IncomingMessage,
) => {
const searchParams = new URL(request.url, `http://${request.headers.host}`)
.searchParams;
const models = await getAvailableProviders(); const models = await getAvailableProviders();
const provider = getChatModelProvider(); const provider =
const chatModel = getChatModel(); searchParams.get('chatModelProvider') || Object.keys(models)[0];
const chatModel =
searchParams.get('chatModel') || Object.keys(models[provider])[0];
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
let embeddings: Embeddings | undefined; let embeddings: Embeddings | undefined;

View File

@ -10,9 +10,7 @@ export const initServer = (
const port = getPort(); const port = getPort();
const wss = new WebSocketServer({ server }); const wss = new WebSocketServer({ server });
wss.on('connection', (ws) => { wss.on('connection', handleConnection);
handleConnection(ws);
});
logger.info(`WebSocket server started on port ${port}`); logger.info(`WebSocket server started on port ${port}`);
}; };

View File

@ -19,14 +19,42 @@ const useSocket = (url: string) => {
useEffect(() => { useEffect(() => {
if (!ws) { if (!ws) {
const ws = new WebSocket(url); const connectWs = async () => {
ws.onopen = () => { let chatModel = localStorage.getItem('chatModel');
console.log('[DEBUG] open'); let chatModelProvider = localStorage.getItem('chatModelProvider');
setWs(ws);
if (!chatModel || !chatModelProvider) {
const chatModelProviders = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/models`,
).then(async (res) => (await res.json())['providers']);
if (
!chatModelProviders ||
Object.keys(chatModelProviders).length === 0
)
return console.error('No chat models available');
chatModelProvider = Object.keys(chatModelProviders)[0];
chatModel = Object.keys(chatModelProviders[chatModelProvider])[0];
localStorage.setItem('chatModel', chatModel!);
localStorage.setItem('chatModelProvider', chatModelProvider);
}
const ws = new WebSocket(
`${url}?chatModel=${chatModel}&chatModelProvider=${chatModelProvider}`,
);
ws.onopen = () => {
console.log('[DEBUG] open');
setWs(ws);
};
}; };
connectWs();
} }
return () => { return () => {
1;
ws?.close(); ws?.close();
console.log('[DEBUG] closed'); console.log('[DEBUG] closed');
}; };

View File

@ -29,6 +29,10 @@ const SearchImages = ({
<button <button
onClick={async () => { onClick={async () => {
setLoading(true); setLoading(true);
const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel');
const res = await fetch( const res = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/images`, `${process.env.NEXT_PUBLIC_API_URL}/images`,
{ {
@ -39,6 +43,8 @@ const SearchImages = ({
body: JSON.stringify({ body: JSON.stringify({
query: query, query: query,
chat_history: chat_history, chat_history: chat_history,
chat_model_provider: chatModelProvider,
chat_model: chatModel,
}), }),
}, },
); );

View File

@ -42,6 +42,10 @@ const Searchvideos = ({
<button <button
onClick={async () => { onClick={async () => {
setLoading(true); setLoading(true);
const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel');
const res = await fetch( const res = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/videos`, `${process.env.NEXT_PUBLIC_API_URL}/videos`,
{ {
@ -52,6 +56,8 @@ const Searchvideos = ({
body: JSON.stringify({ body: JSON.stringify({
query: query, query: query,
chat_history: chat_history, chat_history: chat_history,
chat_model_provider: chatModelProvider,
chat_model: chatModel,
}), }),
}, },
); );

View File

@ -6,8 +6,6 @@ interface SettingsType {
providers: { providers: {
[key: string]: string[]; [key: string]: string[];
}; };
selectedProvider: string;
selectedChatModel: string;
openeaiApiKey: string; openeaiApiKey: string;
groqApiKey: string; groqApiKey: string;
ollamaApiUrl: string; ollamaApiUrl: string;
@ -21,6 +19,12 @@ const SettingsDialog = ({
setIsOpen: (isOpen: boolean) => void; setIsOpen: (isOpen: boolean) => void;
}) => { }) => {
const [config, setConfig] = useState<SettingsType | null>(null); const [config, setConfig] = useState<SettingsType | null>(null);
const [selectedChatModelProvider, setSelectedChatModelProvider] = useState<
string | null
>(null);
const [selectedChatModel, setSelectedChatModel] = useState<string | null>(
null,
);
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const [isUpdating, setIsUpdating] = useState(false); const [isUpdating, setIsUpdating] = useState(false);
@ -39,6 +43,11 @@ const SettingsDialog = ({
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [isOpen]); }, [isOpen]);
useEffect(() => {
setSelectedChatModelProvider(localStorage.getItem('chatModelProvider'));
setSelectedChatModel(localStorage.getItem('chatModel'));
}, []);
const handleSubmit = async () => { const handleSubmit = async () => {
setIsUpdating(true); setIsUpdating(true);
@ -50,6 +59,9 @@ const SettingsDialog = ({
}, },
body: JSON.stringify(config), body: JSON.stringify(config),
}); });
localStorage.setItem('chatModelProvider', selectedChatModelProvider!);
localStorage.setItem('chatModel', selectedChatModel!);
} catch (err) { } catch (err) {
console.log(err); console.log(err);
} finally { } finally {
@ -101,21 +113,19 @@ const SettingsDialog = ({
Chat model Provider Chat model Provider
</p> </p>
<select <select
onChange={(e) => onChange={(e) => {
setConfig({ setSelectedChatModelProvider(e.target.value);
...config, setSelectedChatModel(
selectedProvider: e.target.value, config.providers[e.target.value][0],
selectedChatModel: );
config.providers[e.target.value][0], }}
})
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
> >
{Object.keys(config.providers).map((provider) => ( {Object.keys(config.providers).map((provider) => (
<option <option
key={provider} key={provider}
value={provider} value={provider}
selected={provider === config.selectedProvider} selected={provider === selectedChatModelProvider}
> >
{provider.charAt(0).toUpperCase() + {provider.charAt(0).toUpperCase() +
provider.slice(1)} provider.slice(1)}
@ -124,29 +134,22 @@ const SettingsDialog = ({
</select> </select>
</div> </div>
)} )}
{config.selectedProvider && ( {selectedChatModelProvider && (
<div className="flex flex-col space-y-1"> <div className="flex flex-col space-y-1">
<p className="text-white/70 text-sm">Chat Model</p> <p className="text-white/70 text-sm">Chat Model</p>
<select <select
onChange={(e) => onChange={(e) => setSelectedChatModel(e.target.value)}
setConfig({
...config,
selectedChatModel: e.target.value,
})
}
className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm" className="bg-[#111111] px-3 py-2 flex items-center overflow-hidden border border-[#1C1C1C] text-white rounded-lg text-sm"
> >
{config.providers[config.selectedProvider] ? ( {config.providers[selectedChatModelProvider] ? (
config.providers[config.selectedProvider].length > config.providers[selectedChatModelProvider].length >
0 ? ( 0 ? (
config.providers[config.selectedProvider].map( config.providers[selectedChatModelProvider].map(
(model) => ( (model) => (
<option <option
key={model} key={model}
value={model} value={model}
selected={ selected={model === selectedChatModel}
model === config.selectedChatModel
}
> >
{model} {model}
</option> </option>

View File

@ -1,6 +1,6 @@
{ {
"name": "perplexica-frontend", "name": "perplexica-frontend",
"version": "1.0.0", "version": "1.1.0",
"license": "MIT", "license": "MIT",
"author": "ItzCrazyKns", "author": "ItzCrazyKns",
"scripts": { "scripts": {