diff --git a/docker-compose.yaml b/docker-compose.yaml
index 5eef31e..f9b3757 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -36,4 +36,4 @@ services:
       - perplexica-network
 
 networks:
-  perplexica-network:
\ No newline at end of file
+  perplexica-network:
diff --git a/package.json b/package.json
index a4af068..c3aa58d 100644
--- a/package.json
+++ b/package.json
@@ -6,7 +6,7 @@
"scripts": {
"start": "node dist/app.js",
"build": "tsc",
- "dev": "nodemon src/app.ts" ,
+ "dev": "nodemon src/app.ts",
"format": "prettier . --check",
"format:write": "prettier . --write"
},
diff --git a/src/agents/suggestionGeneratorAgent.ts b/src/agents/suggestionGeneratorAgent.ts
new file mode 100644
index 0000000..59bd9ea
--- /dev/null
+++ b/src/agents/suggestionGeneratorAgent.ts
@@ -0,0 +1,55 @@
+import { RunnableSequence, RunnableMap } from '@langchain/core/runnables';
+import ListLineOutputParser from '../lib/outputParsers/listLineOutputParser';
+import { PromptTemplate } from '@langchain/core/prompts';
+import formatChatHistoryAsString from '../utils/formatHistory';
+import { BaseMessage } from '@langchain/core/messages';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { ChatOpenAI } from '@langchain/openai';
+
+const suggestionGeneratorPrompt = `
+You are an AI suggestion generator for an AI-powered search engine. You will be given a conversation below. You need to generate 4-5 suggestions based on the conversation. The suggestions should be relevant to the conversation and phrased so that the user can ask the chat model for more information about them.
+Make sure the suggestions are relevant and helpful to the user. Keep in mind that the user might use a suggestion verbatim as their next question to the chat model.
+Make sure the suggestions are medium in length, informative, and relevant to the conversation.
+
+Provide these suggestions separated by newlines between the XML tags <suggestions> and </suggestions>. For example:
+
+<suggestions>
+Suggestion 1
+Suggestion 2
+Suggestion 3
+</suggestions>
+
+Conversation:
+{chat_history}
+`;
+
+type SuggestionGeneratorInput = {
+  chat_history: BaseMessage[];
+};
+
+const outputParser = new ListLineOutputParser({
+  key: 'suggestions',
+});
+
+const createSuggestionGeneratorChain = (llm: BaseChatModel) => {
+  return RunnableSequence.from([
+    RunnableMap.from({
+      chat_history: (input: SuggestionGeneratorInput) =>
+        formatChatHistoryAsString(input.chat_history),
+    }),
+    PromptTemplate.fromTemplate(suggestionGeneratorPrompt),
+    llm,
+    outputParser,
+  ]);
+};
+
+const generateSuggestions = (
+  input: SuggestionGeneratorInput,
+  llm: BaseChatModel,
+) => {
+  (llm as ChatOpenAI).temperature = 0;
+  const suggestionGeneratorChain = createSuggestionGeneratorChain(llm);
+  return suggestionGeneratorChain.invoke(input);
+};
+
+export default generateSuggestions;
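
The agent above composes a RunnableMap (which flattens the message array into a string via formatChatHistoryAsString), the prompt template, the model, and the line-list parser into one RunnableSequence, so invoking it returns the parsed suggestions directly. A minimal usage sketch, assuming an OpenAI-backed model; the model name and direct construction below are illustrative, since the app normally resolves models through getAvailableChatModelProviders:

```ts
import { ChatOpenAI } from '@langchain/openai';
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import generateSuggestions from './agents/suggestionGeneratorAgent';

// Hypothetical model setup, not part of this diff.
const llm = new ChatOpenAI({ modelName: 'gpt-3.5-turbo' });

const suggestions = await generateSuggestions(
  {
    chat_history: [
      new HumanMessage('What is SpaceX working on right now?'),
      new AIMessage('SpaceX is currently focused on Starship test flights.'),
    ],
  },
  llm,
);
// suggestions is a string[], one entry per line the parser extracted.
```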
diff --git a/src/lib/outputParsers/listLineOutputParser.ts b/src/lib/outputParsers/listLineOutputParser.ts
new file mode 100644
index 0000000..4fde080
--- /dev/null
+++ b/src/lib/outputParsers/listLineOutputParser.ts
@@ -0,0 +1,43 @@
+import { BaseOutputParser } from '@langchain/core/output_parsers';
+
+interface LineListOutputParserArgs {
+  key?: string;
+}
+
+class LineListOutputParser extends BaseOutputParser<string[]> {
+  private key = 'questions';
+
+  constructor(args?: LineListOutputParserArgs) {
+    super();
+    this.key = args?.key ?? this.key;
+  }
+
+  static lc_name() {
+    return 'LineListOutputParser';
+  }
+
+  lc_namespace = ['langchain', 'output_parsers', 'line_list_output_parser'];
+
+  async parse(text: string): Promise<string[]> {
+    const regex = /^(\s*(-|\*|\d+\.\s|\d+\)\s|\u2022)\s*)+/;
+    const startKeyIndex = text.indexOf(`<${this.key}>`);
+    const endKeyIndex = text.indexOf(`</${this.key}>`);
+    const questionsStartIndex =
+      startKeyIndex === -1 ? 0 : startKeyIndex + `<${this.key}>`.length;
+    const questionsEndIndex = endKeyIndex === -1 ? text.length : endKeyIndex;
+    const lines = text
+      .slice(questionsStartIndex, questionsEndIndex)
+      .trim()
+      .split('\n')
+      .filter((line) => line.trim() !== '')
+      .map((line) => line.replace(regex, ''));
+
+    return lines;
+  }
+
+  getFormatInstructions(): string {
+    throw new Error('Not implemented.');
+  }
+}
+
+export default LineListOutputParser;
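
For reference, the parser looks for a `<key>` ... `</key>` block (the key defaults to 'questions'), falls back to the whole text when the tags are absent, and strips leading bullet or numbering markers from each non-empty line. A small behavioral sketch; the sample model output is invented for illustration:

```ts
import LineListOutputParser from './lib/outputParsers/listLineOutputParser';

const parser = new LineListOutputParser({ key: 'suggestions' });

// Typical model output: a tagged block containing bulleted lines.
const raw = `Here are some ideas:
<suggestions>
- What rockets does SpaceX currently operate?
- How does Starship differ from Falcon 9?
</suggestions>`;

const lines = await parser.parse(raw);
// ['What rockets does SpaceX currently operate?',
//  'How does Starship differ from Falcon 9?']
```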
diff --git a/src/lib/providers.ts b/src/lib/providers.ts
index d2e40f0..c817f87 100644
--- a/src/lib/providers.ts
+++ b/src/lib/providers.ts
@@ -157,7 +157,6 @@ export const getAvailableEmbeddingModelProviders = async () => {
       });
       return acc;
     }, {});
-
   } catch (err) {
     logger.error(`Error loading Ollama embeddings: ${err}`);
   }
@@ -172,11 +171,11 @@ export const getAvailableEmbeddingModelProviders = async () => {
         modelName: 'Xenova/gte-small',
       }),
       'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
-        modelName: 'Xenova/bert-base-multilingual-uncased'
+        modelName: 'Xenova/bert-base-multilingual-uncased',
       }),
     };
-  } catch(err) {
-    logger.error(`Error loading local embeddings: ${err}`);
+  } catch (err) {
+    logger.error(`Error loading local embeddings: ${err}`);
   }
 
   return models;
diff --git a/src/routes/index.ts b/src/routes/index.ts
index 04390cd..257e677 100644
--- a/src/routes/index.ts
+++ b/src/routes/index.ts
@@ -3,6 +3,7 @@ import imagesRouter from './images';
 import videosRouter from './videos';
 import configRouter from './config';
 import modelsRouter from './models';
+import suggestionsRouter from './suggestions';
 
 const router = express.Router();
@@ -10,5 +11,6 @@ router.use('/images', imagesRouter);
 router.use('/videos', videosRouter);
 router.use('/config', configRouter);
 router.use('/models', modelsRouter);
+router.use('/suggestions', suggestionsRouter);
 
 export default router;
diff --git a/src/routes/suggestions.ts b/src/routes/suggestions.ts
new file mode 100644
index 0000000..10e5715
--- /dev/null
+++ b/src/routes/suggestions.ts
@@ -0,0 +1,46 @@
+import express from 'express';
+import generateSuggestions from '../agents/suggestionGeneratorAgent';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { getAvailableChatModelProviders } from '../lib/providers';
+import { HumanMessage, AIMessage } from '@langchain/core/messages';
+import logger from '../utils/logger';
+
+const router = express.Router();
+
+router.post('/', async (req, res) => {
+  try {
+    let { chat_history, chat_model, chat_model_provider } = req.body;
+
+    chat_history = chat_history.map((msg: any) => {
+      if (msg.role === 'user') {
+        return new HumanMessage(msg.content);
+      } else if (msg.role === 'assistant') {
+        return new AIMessage(msg.content);
+      }
+    }).filter((msg: any) => msg !== undefined);
+
+    const chatModels = await getAvailableChatModelProviders();
+    const provider = chat_model_provider || Object.keys(chatModels)[0];
+    const chatModel = chat_model || Object.keys(chatModels[provider])[0];
+
+    let llm: BaseChatModel | undefined;
+
+    if (chatModels[provider] && chatModels[provider][chatModel]) {
+      llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
+    }
+
+    if (!llm) {
+      res.status(500).json({ message: 'Invalid LLM model selected' });
+      return;
+    }
+
+    const suggestions = await generateSuggestions({ chat_history }, llm);
+
+    res.status(200).json({ suggestions: suggestions });
+  } catch (err) {
+    res.status(500).json({ message: 'An error has occurred.' });
+    logger.error(`Error in generating suggestions: ${err.message}`);
+  }
+});
+
+export default router;
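
The route accepts plain { role, content } messages and optional model overrides, falling back to the first available provider and model. A hedged client sketch; the /api prefix and port 3001 are assumptions about how the index router is mounted, which this diff does not show:

```ts
// Hypothetical client-side call to the new endpoint.
const res = await fetch('http://localhost:3001/api/suggestions', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    chat_history: [
      { role: 'user', content: 'Tell me about SpaceX.' },
      { role: 'assistant', content: 'SpaceX is a private aerospace company.' },
    ],
    chat_model_provider: 'openai', // optional; defaults to the first provider
    chat_model: 'gpt-3.5-turbo', // optional; defaults to the first model
  }),
});

const { suggestions } = await res.json();
// e.g. ['What is Starship?', 'Who founded SpaceX?']
```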