feat(video-search): handle custom OpenAI

This commit is contained in:
ItzCrazyKns 2024-10-30 10:28:31 +05:30
parent 540f38ae68
commit 7c6ee2ead1
5 changed files with 94 additions and 25 deletions

View File

@ -4,14 +4,28 @@ import { getAvailableChatModelProviders } from '../lib/providers';
import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger'; import logger from '../utils/logger';
import handleVideoSearch from '../agents/videoSearchAgent'; import handleVideoSearch from '../agents/videoSearchAgent';
import { ChatOpenAI } from '@langchain/openai';
const router = express.Router(); const router = express.Router();
interface ChatModel {
provider: string;
model: string;
customOpenAIBaseURL?: string;
customOpenAIKey?: string;
}
interface VideoSearchBody {
query: string;
chatHistory: any[];
chatModel?: ChatModel;
}
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
try { try {
let { query, chat_history, chat_model_provider, chat_model } = req.body; let body: VideoSearchBody = req.body;
chat_history = chat_history.map((msg: any) => { const chatHistory = body.chatHistory.map((msg: any) => {
if (msg.role === 'user') { if (msg.role === 'user') {
return new HumanMessage(msg.content); return new HumanMessage(msg.content);
} else if (msg.role === 'assistant') { } else if (msg.role === 'assistant') {
@ -19,22 +33,50 @@ router.post('/', async (req, res) => {
} }
}); });
const chatModels = await getAvailableChatModelProviders(); const chatModelProviders = await getAvailableChatModelProviders();
const provider = chat_model_provider ?? Object.keys(chatModels)[0];
const chatModel = chat_model ?? Object.keys(chatModels[provider])[0]; const chatModelProvider =
body.chatModel?.provider || Object.keys(chatModelProviders)[0];
const chatModel =
body.chatModel?.model ||
Object.keys(chatModelProviders[chatModelProvider])[0];
let llm: BaseChatModel | undefined; let llm: BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) { if (body.chatModel?.provider === 'custom_openai') {
llm = chatModels[provider][chatModel].model as BaseChatModel | undefined; if (
!body.chatModel?.customOpenAIBaseURL ||
!body.chatModel?.customOpenAIKey
) {
return res
.status(400)
.json({ message: 'Missing custom OpenAI base URL or key' });
}
llm = new ChatOpenAI({
modelName: body.chatModel.model,
openAIApiKey: body.chatModel.customOpenAIKey,
temperature: 0.7,
configuration: {
baseURL: body.chatModel.customOpenAIBaseURL,
},
}) as unknown as BaseChatModel;
} else if (
chatModelProviders[chatModelProvider] &&
chatModelProviders[chatModelProvider][chatModel]
) {
llm = chatModelProviders[chatModelProvider][chatModel]
.model as unknown as BaseChatModel | undefined;
} }
if (!llm) { if (!llm) {
res.status(500).json({ message: 'Invalid LLM model selected' }); return res.status(400).json({ message: 'Invalid model selected' });
return;
} }
const videos = await handleVideoSearch({ chat_history, query }, llm); const videos = await handleVideoSearch(
{ chat_history: chatHistory, query: body.query },
llm,
);
res.status(200).json({ videos }); res.status(200).json({ videos });
} catch (err) { } catch (err) {

View File

@ -186,10 +186,10 @@ const MessageBox = ({
<div className="lg:sticky lg:top-20 flex flex-col items-center space-y-3 w-full lg:w-3/12 z-30 h-full pb-4"> <div className="lg:sticky lg:top-20 flex flex-col items-center space-y-3 w-full lg:w-3/12 z-30 h-full pb-4">
<SearchImages <SearchImages
query={history[messageIndex - 1].content} query={history[messageIndex - 1].content}
chat_history={history.slice(0, messageIndex - 1)} chatHistory={history.slice(0, messageIndex - 1)}
/> />
<SearchVideos <SearchVideos
chat_history={history.slice(0, messageIndex - 1)} chatHistory={history.slice(0, messageIndex - 1)}
query={history[messageIndex - 1].content} query={history[messageIndex - 1].content}
/> />
</div> </div>

View File

@ -13,10 +13,10 @@ type Image = {
const SearchImages = ({ const SearchImages = ({
query, query,
chat_history, chatHistory,
}: { }: {
query: string; query: string;
chat_history: Message[]; chatHistory: Message[];
}) => { }) => {
const [images, setImages] = useState<Image[] | null>(null); const [images, setImages] = useState<Image[] | null>(null);
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
@ -33,6 +33,9 @@ const SearchImages = ({
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel'); const chatModel = localStorage.getItem('chatModel');
const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
const customOpenAIKey = localStorage.getItem('openAIApiKey');
const res = await fetch( const res = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/images`, `${process.env.NEXT_PUBLIC_API_URL}/images`,
{ {
@ -42,9 +45,15 @@ const SearchImages = ({
}, },
body: JSON.stringify({ body: JSON.stringify({
query: query, query: query,
chat_history: chat_history, chatHistory: chatHistory,
chat_model_provider: chatModelProvider, chatModel: {
chat_model: chatModel, provider: chatModelProvider,
model: chatModel,
...(chatModelProvider === 'custom_openai' && {
customOpenAIBaseURL: customOpenAIBaseURL,
customOpenAIKey: customOpenAIKey,
}),
},
}), }),
}, },
); );

View File

@ -26,10 +26,10 @@ declare module 'yet-another-react-lightbox' {
const Searchvideos = ({ const Searchvideos = ({
query, query,
chat_history, chatHistory,
}: { }: {
query: string; query: string;
chat_history: Message[]; chatHistory: Message[];
}) => { }) => {
const [videos, setVideos] = useState<Video[] | null>(null); const [videos, setVideos] = useState<Video[] | null>(null);
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
@ -46,6 +46,9 @@ const Searchvideos = ({
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel'); const chatModel = localStorage.getItem('chatModel');
const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
const customOpenAIKey = localStorage.getItem('openAIApiKey');
const res = await fetch( const res = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/videos`, `${process.env.NEXT_PUBLIC_API_URL}/videos`,
{ {
@ -55,9 +58,15 @@ const Searchvideos = ({
}, },
body: JSON.stringify({ body: JSON.stringify({
query: query, query: query,
chat_history: chat_history, chatHistory: chatHistory,
chat_model_provider: chatModelProvider, chatModel: {
chat_model: chatModel, provider: chatModelProvider,
model: chatModel,
...(chatModelProvider === 'custom_openai' && {
customOpenAIBaseURL: customOpenAIBaseURL,
customOpenAIKey: customOpenAIKey,
}),
},
}), }),
}, },
); );

View File

@ -4,15 +4,24 @@ export const getSuggestions = async (chatHisory: Message[]) => {
const chatModel = localStorage.getItem('chatModel'); const chatModel = localStorage.getItem('chatModel');
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem('chatModelProvider');
const customOpenAIKey = localStorage.getItem('openAIApiKey');
const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/suggestions`, { const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/suggestions`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
body: JSON.stringify({ body: JSON.stringify({
chat_history: chatHisory, chatHistory: chatHisory,
chat_model: chatModel, chatModel: {
chat_model_provider: chatModelProvider, provider: chatModelProvider,
model: chatModel,
...(chatModelProvider === 'custom_openai' && {
customOpenAIKey,
customOpenAIBaseURL,
}),
},
}), }),
}); });