From 61044715e97c4d47343284e6d63b4d96b4c00972 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns Date: Sat, 29 Jun 2024 11:09:31 +0530 Subject: [PATCH] feat(msg-handler): update message types --- src/websocket/messageHandler.ts | 84 ++++++++++++++++++++++++++++----- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/src/websocket/messageHandler.ts b/src/websocket/messageHandler.ts index 98f67c2..0afda9f 100644 --- a/src/websocket/messageHandler.ts +++ b/src/websocket/messageHandler.ts @@ -9,11 +9,21 @@ import handleRedditSearch from '../agents/redditSearchAgent'; import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; import logger from '../utils/logger'; +import db from '../db'; +import { chats, messages } from '../db/schema'; +import { eq } from 'drizzle-orm'; +import crypto from 'crypto'; type Message = { - type: string; + messageId: string; + chatId: string; content: string; +}; + +type WSMessage = { + message: Message; copilot: boolean; + type: string; focusMode: string; history: Array<[string, string]>; }; @@ -30,8 +40,12 @@ const searchHandlers = { const handleEmitterEvents = ( emitter: EventEmitter, ws: WebSocket, - id: string, + messageId: string, + chatId: string, ) => { + let recievedMessage = ''; + let sources = []; + emitter.on('data', (data) => { const parsedData = JSON.parse(data); if (parsedData.type === 'response') { @@ -39,21 +53,36 @@ const handleEmitterEvents = ( JSON.stringify({ type: 'message', data: parsedData.data, - messageId: id, + messageId: messageId, }), ); + recievedMessage += parsedData.data; } else if (parsedData.type === 'sources') { ws.send( JSON.stringify({ type: 'sources', data: parsedData.data, - messageId: id, + messageId: messageId, }), ); + sources = parsedData.data; } }); emitter.on('end', () => { - ws.send(JSON.stringify({ type: 'messageEnd', messageId: id })); + ws.send(JSON.stringify({ type: 'messageEnd', messageId: messageId })); + + db.insert(messages) + .values({ + content: recievedMessage, + chatId: chatId, + messageId: messageId, + role: 'assistant', + metadata: JSON.stringify({ + createdAt: new Date(), + ...(sources && sources.length > 0 && { sources }), + }), + }) + .execute(); }); emitter.on('error', (data) => { const parsedData = JSON.parse(data); @@ -74,8 +103,10 @@ export const handleMessage = async ( embeddings: Embeddings, ) => { try { - const parsedMessage = JSON.parse(message) as Message; - const id = Math.random().toString(36).substring(7); + const parsedWSMessage = JSON.parse(message) as WSMessage; + const parsedMessage = parsedWSMessage.message; + + const id = crypto.randomBytes(7).toString('hex'); if (!parsedMessage.content) return ws.send( @@ -86,7 +117,7 @@ export const handleMessage = async ( }), ); - const history: BaseMessage[] = parsedMessage.history.map((msg) => { + const history: BaseMessage[] = parsedWSMessage.history.map((msg) => { if (msg[0] === 'human') { return new HumanMessage({ content: msg[1], @@ -98,8 +129,9 @@ export const handleMessage = async ( } }); - if (parsedMessage.type === 'message') { - const handler = searchHandlers[parsedMessage.focusMode]; + if (parsedWSMessage.type === 'message') { + const handler = searchHandlers[parsedWSMessage.focusMode]; + if (handler) { const emitter = handler( parsedMessage.content, @@ -107,7 +139,37 @@ export const handleMessage = async ( llm, embeddings, ); - handleEmitterEvents(emitter, ws, id); + + handleEmitterEvents(emitter, ws, id, parsedMessage.chatId); + + const chat = await db.query.chats.findFirst({ + where: eq(chats.id, parsedMessage.chatId), + }); + + if (!chat) { + await db + .insert(chats) + .values({ + id: parsedMessage.chatId, + title: parsedMessage.content, + createdAt: new Date().toString(), + focusMode: parsedWSMessage.focusMode, + }) + .execute(); + } + + await db + .insert(messages) + .values({ + content: parsedMessage.content, + chatId: parsedMessage.chatId, + messageId: id, + role: 'user', + metadata: JSON.stringify({ + createdAt: new Date(), + }), + }) + .execute(); } else { ws.send( JSON.stringify({