diff --git a/checkpoint_summary_filter.py b/checkpoint_summary_filter.py index 40202c1..9a3cbac 100644 --- a/checkpoint_summary_filter.py +++ b/checkpoint_summary_filter.py @@ -2,9 +2,9 @@ title: Checkpoint Summary Filter author: projectmoon author_url: https://git.agnos.is/projectmoon/open-webui-filters -version: 0.1.0 +version: 0.2.0 license: AGPL-3.0+ -required_open_webui_version: 0.3.9 +required_open_webui_version: 0.3.29 """ # Documentation: https://git.agnos.is/projectmoon/open-webui-filters @@ -28,16 +28,20 @@ from chromadb import Collection as ChromaCollection from chromadb.api.types import Document as ChromaDocument # OpenWebUI imports -from config import CHROMA_CLIENT -from apps.rag.main import app as rag_app -from apps.ollama.main import app as ollama_app -from apps.ollama.main import show_model_info, ModelNameForm -from utils.misc import get_last_user_message, get_last_assistant_message -from main import generate_chat_completions +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT +from open_webui.apps.rag.main import app as rag_app +from open_webui.apps.ollama.main import app as ollama_app +from open_webui.apps.ollama.main import show_model_info, ModelNameForm +from open_webui.utils.misc import get_last_user_message, get_last_assistant_message +from open_webui.main import generate_chat_completions -from apps.webui.models.chats import Chats -from apps.webui.models.models import Models -from apps.webui.models.users import Users +from open_webui.apps.webui.models.chats import Chats +from open_webui.apps.webui.models.models import Models +from open_webui.apps.webui.models.users import Users + +# Why refactor when you can janky monkey patch? This will be fixed at +# some point. +CHROMA_CLIENT = VECTOR_DB_CLIENT.client # Embedding (not yet used) EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION @@ -70,6 +74,16 @@ summary. If the conversation is a story or roleplaying sesison, only use the names of the characters, places, and events in the story. """.replace("\n", " ").strip() +# yoinked from stack overflow. hack to get user into +# generate_chat_completions. this is used to turn the __user__ dict +# given to the filter into a thing that the main OpenWebUI code can +# understand for calling its chat completion endpoint internally. +class BlackMagicDictionary(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + class Message(TypedDict): id: NotRequired[str] role: str @@ -97,6 +111,7 @@ class SummarizerResponse(BaseModel): class Summarizer(BaseModel): messages: List[dict] model: str + user: Any prompt: str = SUMMARIZER_PROMPT async def summarize(self) -> Optional[SummarizerResponse]: @@ -107,7 +122,6 @@ class Summarizer(BaseModel): } messages = [sys_message] + self.messages + [user_message] - request = { "model": self.model, "messages": messages, @@ -115,7 +129,7 @@ class Summarizer(BaseModel): "keep_alive": "10s" } - resp = await generate_chat_completions(request) + resp = await generate_chat_completions(request, user=self.user) if "choices" in resp and len(resp["choices"]) > 0: content: str = resp["choices"][0]["message"]["content"] return SummarizerResponse(summary=content) @@ -167,6 +181,7 @@ class Checkpointer(BaseModel): messages: List[dict]=[] # stripped set of messages full_messages: List[dict]=[] # all the messages embedding_func: EmbeddingFunc=(lambda a: 0) + user: Optional[Any] = None collection_name: ClassVar[str] = "chat_checkpoints" @@ -230,7 +245,7 @@ class Checkpointer(BaseModel): ) async def create_checkpoint(self) -> str: - summarizer = Summarizer(model=self.summarizer_model, messages=self.messages) + summarizer = Summarizer(model=self.summarizer_model, messages=self.messages, user=self.user) resp = await summarizer.summarize() if resp: slug = self._calculate_slug() @@ -531,7 +546,6 @@ class Filter: return message_chain - async def create_checkpoint( self, messages: List[dict], @@ -569,7 +583,8 @@ class Filter: summarizer_model=summarizer_model, chroma_client=CHROMA_CLIENT, full_messages=messages, - messages=message_chain + messages=message_chain, + user=BlackMagicDictionary(self.user) ) try: @@ -752,7 +767,6 @@ class Filter: last_checkpointed_id=checkpoint.message_id if checkpoint else None ) - print(f"[{self.session_info.chat_id}] Done checking for summarization") return body