Merge branch 'ItzCrazyKns:master' into master

This commit is contained in:
Chuck 2024-05-10 12:00:18 +08:00 committed by GitHub
commit baef45b456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 151 additions and 6 deletions

View File

@ -0,0 +1,55 @@
import { RunnableSequence, RunnableMap } from '@langchain/core/runnables';
import ListLineOutputParser from '../lib/outputParsers/listLineOutputParser';
import { PromptTemplate } from '@langchain/core/prompts';
import formatChatHistoryAsString from '../utils/formatHistory';
import { BaseMessage } from '@langchain/core/messages';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatOpenAI } from '@langchain/openai';
const suggestionGeneratorPrompt = `
You are an AI suggestion generator for an AI powered search engine. You will be given a conversation below. You need to generate 4-5 suggestions based on the conversation. The suggestion should be relevant to the conversation that can be used by the user to ask the chat model for more information.
You need to make sure the suggestions are relevant to the conversation and are helpful to the user. Keep a note that the user might use these suggestions to ask a chat model for more information.
Make sure the suggestions are medium in length and are informative and relevant to the conversation.
Provide these suggestions separated by newlines between the XML tags <suggestions> and </suggestions>. For example:
<suggestions>
Suggestion 1
Suggestion 2
Suggestion 3
</suggestions>
Conversation:
{chat_history}
`;
type SuggestionGeneratorInput = {
chat_history: BaseMessage[];
};
const outputParser = new ListLineOutputParser({
key: 'suggestions',
});
const createSuggestionGeneratorChain = (llm: BaseChatModel) => {
return RunnableSequence.from([
RunnableMap.from({
chat_history: (input: SuggestionGeneratorInput) =>
formatChatHistoryAsString(input.chat_history),
}),
PromptTemplate.fromTemplate(suggestionGeneratorPrompt),
llm,
outputParser,
]);
};
const generateSuggestions = (
input: SuggestionGeneratorInput,
llm: BaseChatModel,
) => {
(llm as ChatOpenAI).temperature = 0;
const suggestionGeneratorChain = createSuggestionGeneratorChain(llm);
return suggestionGeneratorChain.invoke(input);
};
export default generateSuggestions;

View File

@ -0,0 +1,43 @@
import { BaseOutputParser } from '@langchain/core/output_parsers';
interface LineListOutputParserArgs {
key?: string;
}
class LineListOutputParser extends BaseOutputParser<string[]> {
private key = 'questions';
constructor(args?: LineListOutputParserArgs) {
super();
this.key = args.key || this.key;
}
static lc_name() {
return 'LineListOutputParser';
}
lc_namespace = ['langchain', 'output_parsers', 'line_list_output_parser'];
async parse(text: string): Promise<string[]> {
const regex = /^(\s*(-|\*|\d+\.\s|\d+\)\s|\u2022)\s*)+/;
const startKeyIndex = text.indexOf(`<${this.key}>`);
const endKeyIndex = text.indexOf(`</${this.key}>`);
const questionsStartIndex =
startKeyIndex === -1 ? 0 : startKeyIndex + `<${this.key}>`.length;
const questionsEndIndex = endKeyIndex === -1 ? text.length : endKeyIndex;
const lines = text
.slice(questionsStartIndex, questionsEndIndex)
.trim()
.split('\n')
.filter((line) => line.trim() !== '')
.map((line) => line.replace(regex, ''));
return lines;
}
getFormatInstructions(): string {
throw new Error('Not implemented.');
}
}
export default LineListOutputParser;

View File

@ -157,7 +157,6 @@ export const getAvailableEmbeddingModelProviders = async () => {
}); });
return acc; return acc;
}, {}); }, {});
} catch (err) { } catch (err) {
logger.error(`Error loading Ollama embeddings: ${err}`); logger.error(`Error loading Ollama embeddings: ${err}`);
} }
@ -172,7 +171,7 @@ export const getAvailableEmbeddingModelProviders = async () => {
modelName: 'Xenova/gte-small', modelName: 'Xenova/gte-small',
}), }),
'Bert Multilingual': new HuggingFaceTransformersEmbeddings({ 'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
modelName: 'Xenova/bert-base-multilingual-uncased' modelName: 'Xenova/bert-base-multilingual-uncased',
}), }),
}; };
} catch (err) { } catch (err) {

View File

@ -3,6 +3,7 @@ import imagesRouter from './images';
import videosRouter from './videos'; import videosRouter from './videos';
import configRouter from './config'; import configRouter from './config';
import modelsRouter from './models'; import modelsRouter from './models';
import suggestionsRouter from './suggestions';
const router = express.Router(); const router = express.Router();
@ -10,5 +11,6 @@ router.use('/images', imagesRouter);
router.use('/videos', videosRouter); router.use('/videos', videosRouter);
router.use('/config', configRouter); router.use('/config', configRouter);
router.use('/models', modelsRouter); router.use('/models', modelsRouter);
router.use('/suggestions', suggestionsRouter);
export default router; export default router;

46
src/routes/suggestions.ts Normal file
View File

@ -0,0 +1,46 @@
import express from 'express';
import generateSuggestions from '../agents/suggestionGeneratorAgent';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableChatModelProviders } from '../lib/providers';
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger';
const router = express.Router();
router.post('/', async (req, res) => {
try {
let { chat_history, chat_model, chat_model_provider } = req.body;
chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') {
return new HumanMessage(msg.content);
} else if (msg.role === 'assistant') {
return new AIMessage(msg.content);
}
});
const chatModels = await getAvailableChatModelProviders();
const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = chat_model || Object.keys(chatModels[provider])[0];
let llm: BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
}
if (!llm) {
res.status(500).json({ message: 'Invalid LLM model selected' });
return;
}
const suggestions = await generateSuggestions({ chat_history }, llm);
res.status(200).json({ suggestions: suggestions });
} catch (err) {
res.status(500).json({ message: 'An error has occurred.' });
logger.error(`Error in generating suggestions: ${err.message}`);
}
});
export default router;