feat(suggestions): handle custom OpenAI

This commit is contained in:
ItzCrazyKns 2024-10-30 10:29:06 +05:30
parent 3e7645614f
commit 65d057a05e
1 changed files with 51 additions and 10 deletions

View File

@ -4,14 +4,27 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableChatModelProviders } from '../lib/providers'; import { getAvailableChatModelProviders } from '../lib/providers';
import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger'; import logger from '../utils/logger';
import { ChatOpenAI } from '@langchain/openai';
const router = express.Router(); const router = express.Router();
interface ChatModel {
provider: string;
model: string;
customOpenAIBaseURL?: string;
customOpenAIKey?: string;
}
interface SuggestionsBody {
chatHistory: any[];
chatModel?: ChatModel;
}
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
try { try {
let { chat_history, chat_model, chat_model_provider } = req.body; let body: SuggestionsBody = req.body;
chat_history = chat_history.map((msg: any) => { const chatHistory = body.chatHistory.map((msg: any) => {
if (msg.role === 'user') { if (msg.role === 'user') {
return new HumanMessage(msg.content); return new HumanMessage(msg.content);
} else if (msg.role === 'assistant') { } else if (msg.role === 'assistant') {
@ -19,22 +32,50 @@ router.post('/', async (req, res) => {
} }
}); });
const chatModels = await getAvailableChatModelProviders(); const chatModelProviders = await getAvailableChatModelProviders();
const provider = chat_model_provider ?? Object.keys(chatModels)[0];
const chatModel = chat_model ?? Object.keys(chatModels[provider])[0]; const chatModelProvider =
body.chatModel?.provider || Object.keys(chatModelProviders)[0];
const chatModel =
body.chatModel?.model ||
Object.keys(chatModelProviders[chatModelProvider])[0];
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) { if (body.chatModel?.provider === 'custom_openai') {
llm = chatModels[provider][chatModel].model as BaseChatModel | undefined; if (
!body.chatModel?.customOpenAIBaseURL ||
!body.chatModel?.customOpenAIKey
) {
return res
.status(400)
.json({ message: 'Missing custom OpenAI base URL or key' });
}
llm = new ChatOpenAI({
modelName: body.chatModel.model,
openAIApiKey: body.chatModel.customOpenAIKey,
temperature: 0.7,
configuration: {
baseURL: body.chatModel.customOpenAIBaseURL,
},
}) as unknown as BaseChatModel;
} else if (
chatModelProviders[chatModelProvider] &&
chatModelProviders[chatModelProvider][chatModel]
) {
llm = chatModelProviders[chatModelProvider][chatModel]
.model as unknown as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {
res.status(500).json({ message: 'Invalid LLM model selected' }); return res.status(400).json({ message: 'Invalid model selected' });
return;
} }
const suggestions = await generateSuggestions({ chat_history }, llm); const suggestions = await generateSuggestions(
{ chat_history: chatHistory },
llm,
);
res.status(200).json({ suggestions: suggestions }); res.status(200).json({ suggestions: suggestions });
} catch (err) { } catch (err) {