diff --git a/gpu_layer_scaler.py b/gpu_layer_scaler.py index 359ea15..e91e21e 100644 --- a/gpu_layer_scaler.py +++ b/gpu_layer_scaler.py @@ -6,6 +6,9 @@ version: 0.1.0 required_open_webui_version: 0.3.9 """ +# Documentation: https://git.agnos.is/projectmoon/open-webui-filters + +# System Imports import chromadb from chromadb import ClientAPI as ChromaAPI from chromadb import Collection as ChromaCollection diff --git a/memories.py b/memories.py index 3936466..71d30e4 100644 --- a/memories.py +++ b/memories.py @@ -2,10 +2,16 @@ title: Memory Filter author: projectmoon author_url: https://git.agnos.is/projectmoon/open-webui-filters -version: 0.0.1 -required_open_webui_version: 0.3.8 +version: 0.0.2 +required_open_webui_version: 0.3.9 """ +# Documentation: https://git.agnos.is/projectmoon/open-webui-filters +# +# Changelog: +# 0.0.1 - Initial release, proof of concept +# 0.0.2 - Slightly less hacky (but still hacky) way of getting chat IDs + # System imports import asyncio import hashlib @@ -429,6 +435,24 @@ class Story(BaseModel): # Utils +class SessionInfo(BaseModel): + chat_id: str + message_id: str + session_id: str + +def extract_session_info(event_emitter) -> Optional[SessionInfo]: + """The latest innovation in hacky workarounds.""" + try: + info = event_emitter.__closure__[0].cell_contents + return SessionInfo( + chat_id=info["chat_id"], + message_id=info["message_id"], + session_id=info["session_id"] + ) + except: + return None + + def create_enrichment_summary_prompt( narrative: str, character_details: List[str], @@ -501,11 +525,6 @@ def create_context(results: SummarizerResponse) -> Optional[str]: return message -def write_log(text): - with open(f"/tmp/test-memories", "a") as file: - file.write(text + "\n") - - def split_messages(messages, keep_amount): if len(messages) <= keep_amount: return messages[:], [] @@ -621,17 +640,18 @@ class Filter: __event_emitter__: Callable[[Any], Awaitable[None]], ) -> dict: # Useful things to have around. + self.session_info = extract_session_info(__event_emitter__) self.event_emitter = __event_emitter__ self.summarizer_model_id = self.valves.summarizer_model(body) await self.send_outlet_status(__event_emitter__, False) messages = body['messages'] - convo_id = self.extract_convo_id(messages) # summarize into plot points. summary = await self.summarize(messages) story = Story( - convo_id=convo_id, client=CHROMA_CLIENT, + convo_id=self.session_info.chat_id, + client=CHROMA_CLIENT, embedding_func=EMBEDDING_FUNCTION, messages=messages ) @@ -693,7 +713,10 @@ class Filter: return await summarizer.summarize() - async def enrich(self, story: Story, messages) -> SummarizerResponse: + async def enrich(self, story: Story, messages) -> Optional[SummarizerResponse]: + if len(messages) < 2: + return None + await self.set_enriching_status("searching") query_generation_result = await self.generate_enrichment_queries(messages) character_results = [result @@ -710,7 +733,8 @@ class Filter: async def update_system_message(self, messages, system_message): story = Story( - convo_id=None, client=CHROMA_CLIENT, + convo_id=self.session_info.chat_id, + client=CHROMA_CLIENT, embedding_func=EMBEDDING_FUNCTION, messages=messages ) @@ -720,8 +744,11 @@ class Filter: if story.convo_id == "": return - enrichment_summary: SummarizerResponse = await self.enrich(story, messages) - context = create_context(enrichment_summary) + enrichment_summary: Optional[SummarizerResponse] = await self.enrich(story, messages) + if enrichment_summary: + context = create_context(enrichment_summary) + else: + context = None if context: system_message["content"] += context @@ -734,8 +761,10 @@ class Filter: __event_emitter__: Callable[[Any], Awaitable[None]] ) -> dict: # Useful properties to have around. + self.session_info = extract_session_info(__event_emitter__) self.event_emitter = __event_emitter__ self.summarizer_model_id = self.valves.summarizer_model(body) + await self.set_enriching_status("init") messages = body["messages"]