From 3bfaf9be2804387ce91d2f8bd4da5a88c0f78ce8 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns Date: Sat, 18 May 2024 13:10:39 +0530 Subject: [PATCH] feat(app): add suggestion generation --- src/lib/outputParsers/listLineOutputParser.ts | 2 +- ui/components/Chat.tsx | 1 + ui/components/ChatWindow.tsx | 31 +++++++++++++++++-- ui/lib/actions.ts | 22 +++++++++++++ 4 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 ui/lib/actions.ts diff --git a/src/lib/outputParsers/listLineOutputParser.ts b/src/lib/outputParsers/listLineOutputParser.ts index 4fde080..57a9bbc 100644 --- a/src/lib/outputParsers/listLineOutputParser.ts +++ b/src/lib/outputParsers/listLineOutputParser.ts @@ -9,7 +9,7 @@ class LineListOutputParser extends BaseOutputParser { constructor(args?: LineListOutputParserArgs) { super(); - this.key = args.key || this.key; + this.key = args.key ?? this.key; } static lc_name() { diff --git a/ui/components/Chat.tsx b/ui/components/Chat.tsx index ddd2957..7b0c1b3 100644 --- a/ui/components/Chat.tsx +++ b/ui/components/Chat.tsx @@ -63,6 +63,7 @@ const Chat = ({ dividerRef={isLast ? dividerRef : undefined} isLast={isLast} rewrite={rewrite} + sendMessage={sendMessage} /> {!isLast && msg.role === 'assistant' && (
diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx index 1cc6ae0..5f266b5 100644 --- a/ui/components/ChatWindow.tsx +++ b/ui/components/ChatWindow.tsx @@ -1,18 +1,20 @@ 'use client'; -import { useEffect, useState } from 'react'; +import { useEffect, useRef, useState } from 'react'; import { Document } from '@langchain/core/documents'; import Navbar from './Navbar'; import Chat from './Chat'; import EmptyChat from './EmptyChat'; import { toast } from 'sonner'; import { useSearchParams } from 'next/navigation'; +import { getSuggestions } from '@/lib/actions'; export type Message = { id: string; createdAt: Date; content: string; role: 'user' | 'assistant'; + suggestions?: string[]; sources?: Document[]; }; @@ -145,10 +147,15 @@ const ChatWindow = () => { const [chatHistory, setChatHistory] = useState<[string, string][]>([]); const [messages, setMessages] = useState([]); + const messagesRef = useRef([]); const [loading, setLoading] = useState(false); const [messageAppeared, setMessageAppeared] = useState(false); const [focusMode, setFocusMode] = useState('webSearch'); + useEffect(() => { + messagesRef.current = messages; + }, [messages]); + const sendMessage = async (message: string) => { if (loading) return; setLoading(true); @@ -177,7 +184,7 @@ const ChatWindow = () => { }, ]); - const messageHandler = (e: MessageEvent) => { + const messageHandler = async (e: MessageEvent) => { const data = JSON.parse(e.data); if (data.type === 'error') { @@ -239,8 +246,28 @@ const ChatWindow = () => { ['human', message], ['assistant', recievedMessage], ]); + ws?.removeEventListener('message', messageHandler); setLoading(false); + + const lastMsg = messagesRef.current[messagesRef.current.length - 1]; + + if ( + lastMsg.role === 'assistant' && + lastMsg.sources && + lastMsg.sources.length > 0 && + !lastMsg.suggestions + ) { + const suggestions = await getSuggestions(messagesRef.current); + setMessages((prev) => + prev.map((msg) => { + if (msg.id === lastMsg.id) { + return { ...msg, suggestions: suggestions }; + } + return msg; + }), + ); + } } }; diff --git a/ui/lib/actions.ts b/ui/lib/actions.ts new file mode 100644 index 0000000..d7eb71f --- /dev/null +++ b/ui/lib/actions.ts @@ -0,0 +1,22 @@ +import { Message } from '@/components/ChatWindow'; + +export const getSuggestions = async (chatHisory: Message[]) => { + const chatModel = localStorage.getItem('chatModel'); + const chatModelProvider = localStorage.getItem('chatModelProvider'); + + const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/suggestions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + chat_history: chatHisory, + chat_model: chatModel, + chat_model_provider: chatModelProvider, + }), + }); + + const data = (await res.json()) as { suggestions: string[] }; + + return data.suggestions; +};