diff --git a/docker-compose.yaml b/docker-compose.yaml
index 5eef31e..f9b3757 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -36,4 +36,4 @@ services:
       - perplexica-network
 
 networks:
-  perplexica-network:
\ No newline at end of file
+  perplexica-network:
diff --git a/package.json b/package.json
index a4af068..c3aa58d 100644
--- a/package.json
+++ b/package.json
@@ -6,7 +6,7 @@
   "scripts": {
     "start": "node dist/app.js",
     "build": "tsc",
-    "dev": "nodemon src/app.ts" ,
+    "dev": "nodemon src/app.ts",
     "format": "prettier . --check",
     "format:write": "prettier . --write"
   },
diff --git a/src/agents/suggestionGeneratorAgent.ts b/src/agents/suggestionGeneratorAgent.ts
new file mode 100644
index 0000000..59bd9ea
--- /dev/null
+++ b/src/agents/suggestionGeneratorAgent.ts
@@ -0,0 +1,55 @@
+import { RunnableSequence, RunnableMap } from '@langchain/core/runnables';
+import ListLineOutputParser from '../lib/outputParsers/listLineOutputParser';
+import { PromptTemplate } from '@langchain/core/prompts';
+import formatChatHistoryAsString from '../utils/formatHistory';
+import { BaseMessage } from '@langchain/core/messages';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { ChatOpenAI } from '@langchain/openai';
+
+const suggestionGeneratorPrompt = `
+You are an AI suggestion generator for an AI-powered search engine. You will be given a conversation below. You need to generate 4-5 suggestions based on the conversation. The suggestions should be relevant to the conversation and usable by the user to ask the chat model for more information.
+You need to make sure the suggestions are relevant to the conversation and helpful to the user. Keep in mind that the user might use these suggestions to ask a chat model for more information.
+Make sure the suggestions are medium in length, informative, and relevant to the conversation.
+
+Provide these suggestions separated by newlines between the XML tags <suggestions> and </suggestions>. For example:
+
+<suggestions>
+Suggestion 1
+Suggestion 2
+Suggestion 3
+</suggestions>
+
+Conversation:
+{chat_history}
+`;
+
+type SuggestionGeneratorInput = {
+  chat_history: BaseMessage[];
+};
+
+const outputParser = new ListLineOutputParser({
+  key: 'suggestions',
+});
+
+const createSuggestionGeneratorChain = (llm: BaseChatModel) => {
+  return RunnableSequence.from([
+    RunnableMap.from({
+      chat_history: (input: SuggestionGeneratorInput) =>
+        formatChatHistoryAsString(input.chat_history),
+    }),
+    PromptTemplate.fromTemplate(suggestionGeneratorPrompt),
+    llm,
+    outputParser,
+  ]);
+};
+
+const generateSuggestions = (
+  input: SuggestionGeneratorInput,
+  llm: BaseChatModel,
+) => {
+  (llm as ChatOpenAI).temperature = 0;
+  const suggestionGeneratorChain = createSuggestionGeneratorChain(llm);
+  return suggestionGeneratorChain.invoke(input);
+};
+
+export default generateSuggestions;
diff --git a/src/lib/outputParsers/listLineOutputParser.ts b/src/lib/outputParsers/listLineOutputParser.ts
new file mode 100644
index 0000000..4fde080
--- /dev/null
+++ b/src/lib/outputParsers/listLineOutputParser.ts
@@ -0,0 +1,43 @@
+import { BaseOutputParser } from '@langchain/core/output_parsers';
+
+interface LineListOutputParserArgs {
+  key?: string;
+}
+
+class LineListOutputParser extends BaseOutputParser<string[]> {
+  private key = 'questions';
+
+  constructor(args?: LineListOutputParserArgs) {
+    super();
+    this.key = args?.key ?? this.key;
+  }
+
+  static lc_name() {
+    return 'LineListOutputParser';
+  }
+
+  lc_namespace = ['langchain', 'output_parsers', 'line_list_output_parser'];
+
+  async parse(text: string): Promise<string[]> {
+    const regex = /^(\s*(-|\*|\d+\.\s|\d+\)\s|\u2022)\s*)+/;
+    const startKeyIndex = text.indexOf(`<${this.key}>`);
+    const endKeyIndex = text.indexOf(`</${this.key}>`);
+    const questionsStartIndex =
+      startKeyIndex === -1 ? 0 : startKeyIndex + `<${this.key}>`.length;
+    const questionsEndIndex = endKeyIndex === -1 ? text.length : endKeyIndex;
+    const lines = text
+      .slice(questionsStartIndex, questionsEndIndex)
+      .trim()
+      .split('\n')
+      .filter((line) => line.trim() !== '')
+      .map((line) => line.replace(regex, ''));
+
+    return lines;
+  }
+
+  getFormatInstructions(): string {
+    throw new Error('Not implemented.');
+  }
+}
+
+export default LineListOutputParser;
diff --git a/src/lib/providers.ts b/src/lib/providers.ts
index d2e40f0..c817f87 100644
--- a/src/lib/providers.ts
+++ b/src/lib/providers.ts
@@ -157,7 +157,6 @@ export const getAvailableEmbeddingModelProviders = async () => {
       });
       return acc;
     }, {});
-
   } catch (err) {
     logger.error(`Error loading Ollama embeddings: ${err}`);
   }
@@ -172,11 +171,11 @@ export const getAvailableEmbeddingModelProviders = async () => {
         modelName: 'Xenova/gte-small',
       }),
       'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
-        modelName: 'Xenova/bert-base-multilingual-uncased'
+        modelName: 'Xenova/bert-base-multilingual-uncased',
      }),
     };
-  } catch(err) {
-    logger.error(`Error loading local embeddings: ${err}`);
+  } catch (err) {
+    logger.error(`Error loading local embeddings: ${err}`);
   }
 
   return models;
diff --git a/src/routes/index.ts b/src/routes/index.ts
index 04390cd..257e677 100644
--- a/src/routes/index.ts
+++ b/src/routes/index.ts
@@ -3,6 +3,7 @@ import imagesRouter from './images';
 import videosRouter from './videos';
 import configRouter from './config';
 import modelsRouter from './models';
+import suggestionsRouter from './suggestions';
 
 const router = express.Router();
 
@@ -10,5 +11,6 @@ router.use('/images', imagesRouter);
 router.use('/videos', videosRouter);
 router.use('/config', configRouter);
 router.use('/models', modelsRouter);
+router.use('/suggestions', suggestionsRouter);
 
 export default router;
diff --git a/src/routes/suggestions.ts b/src/routes/suggestions.ts
new file mode 100644
index 0000000..10e5715
--- /dev/null
+++ b/src/routes/suggestions.ts
@@ -0,0 +1,46 @@
+import express from 'express';
+import generateSuggestions from '../agents/suggestionGeneratorAgent';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { getAvailableChatModelProviders } from '../lib/providers';
+import { HumanMessage, AIMessage } from '@langchain/core/messages';
+import logger from '../utils/logger';
+
+const router = express.Router();
+
+router.post('/', async (req, res) => {
+  try {
+    let { chat_history, chat_model, chat_model_provider } = req.body;
+
+    chat_history = chat_history.map((msg: any) => {
+      if (msg.role === 'user') {
+        return new HumanMessage(msg.content);
+      } else if (msg.role === 'assistant') {
+        return new AIMessage(msg.content);
+      }
+    });
+
+    const chatModels = await getAvailableChatModelProviders();
+    const provider = chat_model_provider || Object.keys(chatModels)[0];
+    const chatModel = chat_model || Object.keys(chatModels[provider])[0];
+
+    let llm: BaseChatModel | undefined;
+
+    if (chatModels[provider] && chatModels[provider][chatModel]) {
+      llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
+    }
+
+    if (!llm) {
+      res.status(500).json({ message: 'Invalid LLM model selected' });
+      return;
+    }
+
+    const suggestions = await generateSuggestions({ chat_history }, llm);
+
+    res.status(200).json({ suggestions: suggestions });
+  } catch (err) {
+    res.status(500).json({ message: 'An error has occurred.' });
+    logger.error(`Error in generating suggestions: ${err.message}`);
+  }
+});
+
+export default router;
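
For reviewers, a minimal sketch of how a client might exercise the new route. The base URL is an assumption (it depends on how the router is mounted in `app.ts`, which is not part of this diff), and the provider/model names are placeholders; the request body shape follows the destructuring in `src/routes/suggestions.ts`.

```ts
// Hypothetical client call; base URL and model names are assumptions.
const response = await fetch('http://localhost:3001/api/suggestions', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    chat_history: [
      { role: 'user', content: 'How do transformers work?' },
      { role: 'assistant', content: 'Transformers are neural networks that...' },
    ],
    chat_model_provider: 'openai', // optional; falls back to the first available provider
    chat_model: 'gpt-3.5-turbo', // optional; falls back to that provider's first model
  }),
});

// The route responds with { suggestions: string[] }, where the array is the
// newline-separated list extracted from <suggestions>...</suggestions> by
// LineListOutputParser.
const { suggestions } = await response.json();
```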