feat(providers): add optimization modes

This commit is contained in:
ItzCrazyKns 2024-10-11 10:35:59 +05:30
parent 877735b852
commit 7cce853618
9 changed files with 294 additions and 88 deletions

View File

@ -118,7 +118,6 @@ const createBasicAcademicSearchRetrieverChain = (llm: BaseChatModel) => {
engines: [ engines: [
'arxiv', 'arxiv',
'google scholar', 'google scholar',
'internetarchivescholar',
'pubmed', 'pubmed',
], ],
}); });
@ -143,6 +142,7 @@ const createBasicAcademicSearchRetrieverChain = (llm: BaseChatModel) => {
const createBasicAcademicSearchAnsweringChain = ( const createBasicAcademicSearchAnsweringChain = (
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const basicAcademicSearchRetrieverChain = const basicAcademicSearchRetrieverChain =
createBasicAcademicSearchRetrieverChain(llm); createBasicAcademicSearchRetrieverChain(llm);
@ -168,26 +168,33 @@ const createBasicAcademicSearchAnsweringChain = (
(doc) => doc.pageContent && doc.pageContent.length > 0, (doc) => doc.pageContent && doc.pageContent.length > 0,
); );
const [docEmbeddings, queryEmbedding] = await Promise.all([ if (optimizationMode === 'speed') {
embeddings.embedDocuments(docsWithContent.map((doc) => doc.pageContent)), return docsWithContent.slice(0, 15);
embeddings.embedQuery(query), } else if (optimizationMode === 'balanced') {
]); console.log('Balanced mode');
const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(
docsWithContent.map((doc) => doc.pageContent),
),
embeddings.embedQuery(query),
]);
const similarity = docEmbeddings.map((docEmbedding, i) => { const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding); const sim = computeSimilarity(queryEmbedding, docEmbedding);
return { return {
index: i, index: i,
similarity: sim, similarity: sim,
}; };
}); });
const sortedDocs = similarity const sortedDocs = similarity
.sort((a, b) => b.similarity - a.similarity) .sort((a, b) => b.similarity - a.similarity)
.slice(0, 15) .slice(0, 15)
.map((sim) => docsWithContent[sim.index]); .map((sim) => docsWithContent[sim.index]);
return sortedDocs; return sortedDocs;
}
}; };
return RunnableSequence.from([ return RunnableSequence.from([
@ -224,12 +231,17 @@ const basicAcademicSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = new eventEmitter(); const emitter = new eventEmitter();
try { try {
const basicAcademicSearchAnsweringChain = const basicAcademicSearchAnsweringChain =
createBasicAcademicSearchAnsweringChain(llm, embeddings); createBasicAcademicSearchAnsweringChain(
llm,
embeddings,
optimizationMode,
);
const stream = basicAcademicSearchAnsweringChain.streamEvents( const stream = basicAcademicSearchAnsweringChain.streamEvents(
{ {
@ -258,8 +270,15 @@ const handleAcademicSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = basicAcademicSearch(message, history, llm, embeddings); const emitter = basicAcademicSearch(
message,
history,
llm,
embeddings,
optimizationMode,
);
return emitter; return emitter;
}; };

View File

@ -138,6 +138,7 @@ const createBasicRedditSearchRetrieverChain = (llm: BaseChatModel) => {
const createBasicRedditSearchAnsweringChain = ( const createBasicRedditSearchAnsweringChain = (
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const basicRedditSearchRetrieverChain = const basicRedditSearchRetrieverChain =
createBasicRedditSearchRetrieverChain(llm); createBasicRedditSearchRetrieverChain(llm);
@ -163,27 +164,33 @@ const createBasicRedditSearchAnsweringChain = (
(doc) => doc.pageContent && doc.pageContent.length > 0, (doc) => doc.pageContent && doc.pageContent.length > 0,
); );
const [docEmbeddings, queryEmbedding] = await Promise.all([ if (optimizationMode === 'speed') {
embeddings.embedDocuments(docsWithContent.map((doc) => doc.pageContent)), return docsWithContent.slice(0, 15);
embeddings.embedQuery(query), } else if (optimizationMode === 'balanced') {
]); const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(
docsWithContent.map((doc) => doc.pageContent),
),
embeddings.embedQuery(query),
]);
const similarity = docEmbeddings.map((docEmbedding, i) => { const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding); const sim = computeSimilarity(queryEmbedding, docEmbedding);
return { return {
index: i, index: i,
similarity: sim, similarity: sim,
}; };
}); });
const sortedDocs = similarity const sortedDocs = similarity
.filter((sim) => sim.similarity > 0.3) .filter((sim) => sim.similarity > 0.3)
.sort((a, b) => b.similarity - a.similarity) .sort((a, b) => b.similarity - a.similarity)
.slice(0, 15) .slice(0, 15)
.map((sim) => docsWithContent[sim.index]); .map((sim) => docsWithContent[sim.index]);
return sortedDocs; return sortedDocs;
}
}; };
return RunnableSequence.from([ return RunnableSequence.from([
@ -220,12 +227,13 @@ const basicRedditSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = new eventEmitter(); const emitter = new eventEmitter();
try { try {
const basicRedditSearchAnsweringChain = const basicRedditSearchAnsweringChain =
createBasicRedditSearchAnsweringChain(llm, embeddings); createBasicRedditSearchAnsweringChain(llm, embeddings, optimizationMode);
const stream = basicRedditSearchAnsweringChain.streamEvents( const stream = basicRedditSearchAnsweringChain.streamEvents(
{ {
chat_history: history, chat_history: history,
@ -253,8 +261,15 @@ const handleRedditSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = basicRedditSearch(message, history, llm, embeddings); const emitter = basicRedditSearch(
message,
history,
llm,
embeddings,
optimizationMode,
);
return emitter; return emitter;
}; };

View File

@ -216,12 +216,34 @@ const createBasicWebSearchRetrieverChain = (llm: BaseChatModel) => {
await Promise.all( await Promise.all(
docGroups.map(async (doc) => { docGroups.map(async (doc) => {
const res = await llm.invoke(` const res = await llm.invoke(`
You are a text summarizer. You need to summarize the text provided inside the \`text\` XML block. You are a web search summarizer, tasked with summarizing a piece of text retrieved from a web search. Your job is to summarize the
You need to summarize the text into 1 or 2 sentences capturing the main idea of the text. text into a detailed, 2-4 paragraph explanation that captures the main ideas and provides a comprehensive answer to the query.
You need to make sure that you don't miss any point while summarizing the text. If the query is \"summarize\", you should provide a detailed summary of the text. If the query is a specific question, you should answer it in the summary.
You will also be given a \`query\` XML block which will contain the query of the user. Try to answer the query in the summary from the text provided.
If the query says Summarize then you just need to summarize the text without answering the query. - **Journalistic tone**: The summary should sound professional and journalistic, not too casual or vague.
Only return the summarized text without any other messages, text or XML block. - **Thorough and detailed**: Ensure that every key point from the text is captured and that the summary directly answers the query.
- **Not too lengthy, but detailed**: The summary should be informative but not excessively long. Focus on providing detailed information in a concise format.
The text will be shared inside the \`text\` XML tag, and the query inside the \`query\` XML tag.
<example>
<text>
Docker is a set of platform-as-a-service products that use OS-level virtualization to deliver software in packages called containers.
It was first released in 2013 and is developed by Docker, Inc. Docker is designed to make it easier to create, deploy, and run applications
by using containers.
</text>
<query>
What is Docker and how does it work?
</query>
Response:
Docker is a revolutionary platform-as-a-service product developed by Docker, Inc., that uses container technology to make application
deployment more efficient. It allows developers to package their software with all necessary dependencies, making it easier to run in
any environment. Released in 2013, Docker has transformed the way applications are built, deployed, and managed.
</example>
Everything below is the actual data you will be working with. Good luck!
<query> <query>
${question} ${question}
@ -273,6 +295,7 @@ const createBasicWebSearchRetrieverChain = (llm: BaseChatModel) => {
const createBasicWebSearchAnsweringChain = ( const createBasicWebSearchAnsweringChain = (
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const basicWebSearchRetrieverChain = createBasicWebSearchRetrieverChain(llm); const basicWebSearchRetrieverChain = createBasicWebSearchRetrieverChain(llm);
@ -301,27 +324,33 @@ const createBasicWebSearchAnsweringChain = (
(doc) => doc.pageContent && doc.pageContent.length > 0, (doc) => doc.pageContent && doc.pageContent.length > 0,
); );
const [docEmbeddings, queryEmbedding] = await Promise.all([ if (optimizationMode === 'speed') {
embeddings.embedDocuments(docsWithContent.map((doc) => doc.pageContent)), return docsWithContent.slice(0, 15);
embeddings.embedQuery(query), } else if (optimizationMode === 'balanced') {
]); const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(
docsWithContent.map((doc) => doc.pageContent),
),
embeddings.embedQuery(query),
]);
const similarity = docEmbeddings.map((docEmbedding, i) => { const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding); const sim = computeSimilarity(queryEmbedding, docEmbedding);
return { return {
index: i, index: i,
similarity: sim, similarity: sim,
}; };
}); });
const sortedDocs = similarity const sortedDocs = similarity
.filter((sim) => sim.similarity > 0.3) .filter((sim) => sim.similarity > 0.3)
.sort((a, b) => b.similarity - a.similarity) .sort((a, b) => b.similarity - a.similarity)
.slice(0, 15) .slice(0, 15)
.map((sim) => docsWithContent[sim.index]); .map((sim) => docsWithContent[sim.index]);
return sortedDocs; return sortedDocs;
}
}; };
return RunnableSequence.from([ return RunnableSequence.from([
@ -358,6 +387,7 @@ const basicWebSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = new eventEmitter(); const emitter = new eventEmitter();
@ -365,6 +395,7 @@ const basicWebSearch = (
const basicWebSearchAnsweringChain = createBasicWebSearchAnsweringChain( const basicWebSearchAnsweringChain = createBasicWebSearchAnsweringChain(
llm, llm,
embeddings, embeddings,
optimizationMode,
); );
const stream = basicWebSearchAnsweringChain.streamEvents( const stream = basicWebSearchAnsweringChain.streamEvents(
@ -394,8 +425,15 @@ const handleWebSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = basicWebSearch(message, history, llm, embeddings); const emitter = basicWebSearch(
message,
history,
llm,
embeddings,
optimizationMode,
);
return emitter; return emitter;
}; };

View File

@ -138,6 +138,7 @@ const createBasicYoutubeSearchRetrieverChain = (llm: BaseChatModel) => {
const createBasicYoutubeSearchAnsweringChain = ( const createBasicYoutubeSearchAnsweringChain = (
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const basicYoutubeSearchRetrieverChain = const basicYoutubeSearchRetrieverChain =
createBasicYoutubeSearchRetrieverChain(llm); createBasicYoutubeSearchRetrieverChain(llm);
@ -163,27 +164,33 @@ const createBasicYoutubeSearchAnsweringChain = (
(doc) => doc.pageContent && doc.pageContent.length > 0, (doc) => doc.pageContent && doc.pageContent.length > 0,
); );
const [docEmbeddings, queryEmbedding] = await Promise.all([ if (optimizationMode === 'speed') {
embeddings.embedDocuments(docsWithContent.map((doc) => doc.pageContent)), return docsWithContent.slice(0, 15);
embeddings.embedQuery(query), } else {
]); const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(
docsWithContent.map((doc) => doc.pageContent),
),
embeddings.embedQuery(query),
]);
const similarity = docEmbeddings.map((docEmbedding, i) => { const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding); const sim = computeSimilarity(queryEmbedding, docEmbedding);
return { return {
index: i, index: i,
similarity: sim, similarity: sim,
}; };
}); });
const sortedDocs = similarity const sortedDocs = similarity
.filter((sim) => sim.similarity > 0.3) .filter((sim) => sim.similarity > 0.3)
.sort((a, b) => b.similarity - a.similarity) .sort((a, b) => b.similarity - a.similarity)
.slice(0, 15) .slice(0, 15)
.map((sim) => docsWithContent[sim.index]); .map((sim) => docsWithContent[sim.index]);
return sortedDocs; return sortedDocs;
}
}; };
return RunnableSequence.from([ return RunnableSequence.from([
@ -220,12 +227,13 @@ const basicYoutubeSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = new eventEmitter(); const emitter = new eventEmitter();
try { try {
const basicYoutubeSearchAnsweringChain = const basicYoutubeSearchAnsweringChain =
createBasicYoutubeSearchAnsweringChain(llm, embeddings); createBasicYoutubeSearchAnsweringChain(llm, embeddings, optimizationMode);
const stream = basicYoutubeSearchAnsweringChain.streamEvents( const stream = basicYoutubeSearchAnsweringChain.streamEvents(
{ {
@ -254,8 +262,15 @@ const handleYoutubeSearch = (
history: BaseMessage[], history: BaseMessage[],
llm: BaseChatModel, llm: BaseChatModel,
embeddings: Embeddings, embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => { ) => {
const emitter = basicYoutubeSearch(message, history, llm, embeddings); const emitter = basicYoutubeSearch(
message,
history,
llm,
embeddings,
optimizationMode,
);
return emitter; return emitter;
}; };

View File

@ -22,7 +22,7 @@ type Message = {
type WSMessage = { type WSMessage = {
message: Message; message: Message;
copilot: boolean; optimizationMode: string;
type: string; type: string;
focusMode: string; focusMode: string;
history: Array<[string, string]>; history: Array<[string, string]>;
@ -138,6 +138,7 @@ export const handleMessage = async (
history, history,
llm, llm,
embeddings, embeddings,
parsedWSMessage.optimizationMode,
); );
handleEmitterEvents(emitter, ws, id, parsedMessage.chatId); handleEmitterEvents(emitter, ws, id, parsedMessage.chatId);

View File

@ -278,6 +278,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
const [messages, setMessages] = useState<Message[]>([]); const [messages, setMessages] = useState<Message[]>([]);
const [focusMode, setFocusMode] = useState('webSearch'); const [focusMode, setFocusMode] = useState('webSearch');
const [optimizationMode, setOptimizationMode] = useState('speed');
const [isMessagesLoaded, setIsMessagesLoaded] = useState(false); const [isMessagesLoaded, setIsMessagesLoaded] = useState(false);
@ -346,6 +347,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
content: message, content: message,
}, },
focusMode: focusMode, focusMode: focusMode,
optimizationMode: optimizationMode,
history: [...chatHistory, ['human', message]], history: [...chatHistory, ['human', message]],
}), }),
); );
@ -508,6 +510,8 @@ const ChatWindow = ({ id }: { id?: string }) => {
sendMessage={sendMessage} sendMessage={sendMessage}
focusMode={focusMode} focusMode={focusMode}
setFocusMode={setFocusMode} setFocusMode={setFocusMode}
optimizationMode={optimizationMode}
setOptimizationMode={setOptimizationMode}
/> />
)} )}
</div> </div>

View File

@ -4,10 +4,14 @@ const EmptyChat = ({
sendMessage, sendMessage,
focusMode, focusMode,
setFocusMode, setFocusMode,
optimizationMode,
setOptimizationMode,
}: { }: {
sendMessage: (message: string) => void; sendMessage: (message: string) => void;
focusMode: string; focusMode: string;
setFocusMode: (mode: string) => void; setFocusMode: (mode: string) => void;
optimizationMode: string;
setOptimizationMode: (mode: string) => void;
}) => { }) => {
return ( return (
<div className="relative"> <div className="relative">
@ -19,6 +23,8 @@ const EmptyChat = ({
sendMessage={sendMessage} sendMessage={sendMessage}
focusMode={focusMode} focusMode={focusMode}
setFocusMode={setFocusMode} setFocusMode={setFocusMode}
optimizationMode={optimizationMode}
setOptimizationMode={setOptimizationMode}
/> />
</div> </div>
</div> </div>

View File

@ -3,15 +3,20 @@ import { useEffect, useRef, useState } from 'react';
import TextareaAutosize from 'react-textarea-autosize'; import TextareaAutosize from 'react-textarea-autosize';
import CopilotToggle from './MessageInputActions/Copilot'; import CopilotToggle from './MessageInputActions/Copilot';
import Focus from './MessageInputActions/Focus'; import Focus from './MessageInputActions/Focus';
import Optimization from './MessageInputActions/Optimization';
const EmptyChatMessageInput = ({ const EmptyChatMessageInput = ({
sendMessage, sendMessage,
focusMode, focusMode,
setFocusMode, setFocusMode,
optimizationMode,
setOptimizationMode,
}: { }: {
sendMessage: (message: string) => void; sendMessage: (message: string) => void;
focusMode: string; focusMode: string;
setFocusMode: (mode: string) => void; setFocusMode: (mode: string) => void;
optimizationMode: string;
setOptimizationMode: (mode: string) => void;
}) => { }) => {
const [copilotEnabled, setCopilotEnabled] = useState(false); const [copilotEnabled, setCopilotEnabled] = useState(false);
const [message, setMessage] = useState(''); const [message, setMessage] = useState('');
@ -66,14 +71,13 @@ const EmptyChatMessageInput = ({
placeholder="Ask anything..." placeholder="Ask anything..."
/> />
<div className="flex flex-row items-center justify-between mt-4"> <div className="flex flex-row items-center justify-between mt-4">
<div className="flex flex-row items-center space-x-1 -mx-2"> <div className="flex flex-row items-center space-x-4">
<Focus focusMode={focusMode} setFocusMode={setFocusMode} /> <Focus focusMode={focusMode} setFocusMode={setFocusMode} />
{/* <Attach /> */}
</div> </div>
<div className="flex flex-row items-center space-x-4 -mx-2"> <div className="flex flex-row items-center space-x-1 sm:space-x-4">
<CopilotToggle <Optimization
copilotEnabled={copilotEnabled} optimizationMode={optimizationMode}
setCopilotEnabled={setCopilotEnabled} setOptimizationMode={setOptimizationMode}
/> />
<button <button
disabled={message.trim().length === 0} disabled={message.trim().length === 0}

View File

@ -0,0 +1,104 @@
import { ChevronDown, Sliders, Star, Zap } from 'lucide-react';
import { cn } from '@/lib/utils';
import {
Popover,
PopoverButton,
PopoverPanel,
Transition,
} from '@headlessui/react';
import { Fragment } from 'react';
/** Entry for the `speed` mode. */
const speedMode = {
  key: 'speed',
  title: 'Speed',
  description: 'Prioritize speed and get the quickest possible answer.',
  icon: <Zap size={20} className="text-[#FF9800]" />,
};

/** Entry for the `balanced` mode. */
const balancedMode = {
  key: 'balanced',
  title: 'Balanced',
  description: 'Find the right balance between speed and accuracy',
  icon: <Sliders size={20} className="text-[#4CAF50]" />,
};

/** Entry for the `quality` mode; its title marks it as not yet available. */
const qualityMode = {
  key: 'quality',
  title: 'Quality (Soon)',
  description: 'Get the most thorough and accurate answer',
  icon: (
    <Star
      size={16}
      className="text-[#2196F3] dark:text-[#BBDEFB] fill-[#BBDEFB] dark:fill-[#2196F3]"
    />
  ),
};

/**
 * Selectable optimization modes, in display order.
 *
 * Every entry carries a `key` identifying the mode, a `title` and
 * `description` shown to the user, and an `icon` element.
 */
const OptimizationModes = [speedMode, balancedMode, qualityMode];
/**
 * Popover control for choosing the optimization mode.
 *
 * The trigger button shows the icon and title of the currently selected mode;
 * the panel lists every entry of `OptimizationModes`. Clicking an entry calls
 * `setOptimizationMode` with that entry's key. The `quality` entry is rendered
 * disabled and dimmed.
 *
 * @param optimizationMode - Key of the currently selected mode.
 * @param setOptimizationMode - Callback invoked with the key of the clicked mode.
 */
const Optimization = ({
  optimizationMode,
  setOptimizationMode,
}: {
  optimizationMode: string;
  setOptimizationMode: (mode: string) => void;
}) => {
  // Resolve the active mode once per render instead of once per field. It is
  // undefined when the key matches no entry, in which case the trigger renders
  // without an icon or title.
  const activeMode = OptimizationModes.find(
    (mode) => mode.key === optimizationMode,
  );

  return (
    <Popover className="relative w-full max-w-[15rem] md:max-w-md lg:max-w-lg">
      <PopoverButton
        type="button"
        className="p-2 text-black/50 dark:text-white/50 rounded-xl hover:bg-light-secondary dark:hover:bg-dark-secondary active:scale-95 transition duration-200 hover:text-black dark:hover:text-white"
      >
        <div className="flex flex-row items-center space-x-1">
          {activeMode?.icon}
          <p className="text-xs font-medium">{activeMode?.title}</p>
          <ChevronDown size={20} />
        </div>
      </PopoverButton>
      <Transition
        as={Fragment}
        enter="transition ease-out duration-150"
        enterFrom="opacity-0 translate-y-1"
        enterTo="opacity-100 translate-y-0"
        leave="transition ease-in duration-150"
        leaveFrom="opacity-100 translate-y-0"
        leaveTo="opacity-0 translate-y-1"
      >
        <PopoverPanel className="absolute z-10 w-64 md:w-[250px] right-0">
          <div className="flex flex-col gap-2 bg-light-primary dark:bg-dark-primary border rounded-lg border-light-200 dark:border-dark-200 w-full p-4 max-h-[200px] md:max-h-none overflow-y-auto">
            {OptimizationModes.map((mode) => {
              // The quality mode is listed but cannot be picked from the menu.
              const isUnavailable = mode.key === 'quality';

              return (
                <PopoverButton
                  onClick={() => setOptimizationMode(mode.key)}
                  // Mode keys are unique and stable, unlike array positions.
                  key={mode.key}
                  disabled={isUnavailable}
                  className={cn(
                    'p-2 rounded-lg flex flex-col items-start justify-start text-start space-y-1 duration-200 cursor-pointer transition',
                    optimizationMode === mode.key
                      ? 'bg-light-secondary dark:bg-dark-secondary'
                      : 'hover:bg-light-secondary dark:hover:bg-dark-secondary',
                    isUnavailable && 'opacity-50 cursor-not-allowed',
                  )}
                >
                  <div className="flex flex-row items-center space-x-1 text-black dark:text-white">
                    {mode.icon}
                    <p className="text-sm font-medium">{mode.title}</p>
                  </div>
                  <p className="text-black/70 dark:text-white/70 text-xs">
                    {mode.description}
                  </p>
                </PopoverButton>
              );
            })}
          </div>
        </PopoverPanel>
      </Transition>
    </Popover>
  );
};
export default Optimization;