diff --git a/src/websocket/messageHandler.ts b/src/websocket/messageHandler.ts index d230386..e915b22 100644 --- a/src/websocket/messageHandler.ts +++ b/src/websocket/messageHandler.ts @@ -10,8 +10,8 @@ import type { BaseChatModel } from '@langchain/core/language_models/chat_models' import type { Embeddings } from '@langchain/core/embeddings'; import logger from '../utils/logger'; import db from '../db'; -import { chats, messages } from '../db/schema'; -import { eq } from 'drizzle-orm'; +import { chats, messages as messagesSchema } from '../db/schema'; +import { eq, asc, gt } from 'drizzle-orm'; import crypto from 'crypto'; type Message = { @@ -71,7 +71,7 @@ const handleEmitterEvents = ( emitter.on('end', () => { ws.send(JSON.stringify({ type: 'messageEnd', messageId: messageId })); - db.insert(messages) + db.insert(messagesSchema) .values({ content: recievedMessage, chatId: chatId, @@ -106,7 +106,9 @@ export const handleMessage = async ( const parsedWSMessage = JSON.parse(message) as WSMessage; const parsedMessage = parsedWSMessage.message; - const id = crypto.randomBytes(7).toString('hex'); + const humanMessageId = + parsedMessage.messageId ?? crypto.randomBytes(7).toString('hex'); + const aiMessageId = crypto.randomBytes(7).toString('hex'); if (!parsedMessage.content) return ws.send( @@ -141,7 +143,7 @@ export const handleMessage = async ( parsedWSMessage.optimizationMode, ); - handleEmitterEvents(emitter, ws, id, parsedMessage.chatId); + handleEmitterEvents(emitter, ws, aiMessageId, parsedMessage.chatId); const chat = await db.query.chats.findFirst({ where: eq(chats.id, parsedMessage.chatId), @@ -159,18 +161,29 @@ export const handleMessage = async ( .execute(); } - await db - .insert(messages) - .values({ - content: parsedMessage.content, - chatId: parsedMessage.chatId, - messageId: id, - role: 'user', - metadata: JSON.stringify({ - createdAt: new Date(), - }), - }) - .execute(); + const messageExists = await db.query.messages.findFirst({ + where: eq(messagesSchema.messageId, humanMessageId), + }); + + if (!messageExists) { + await db + .insert(messagesSchema) + .values({ + content: parsedMessage.content, + chatId: parsedMessage.chatId, + messageId: humanMessageId, + role: 'user', + metadata: JSON.stringify({ + createdAt: new Date(), + }), + }) + .execute(); + } else { + await db + .delete(messagesSchema) + .where(gt(messagesSchema.id, messageExists.id)) + .execute(); + } } else { ws.send( JSON.stringify({