Checkpoint Filter: Fix for refactored OpenWebUI

This commit is contained in:
projectmoon 2024-09-25 21:49:29 +02:00
parent 68e0be9608
commit 72702924be
1 changed files with 31 additions and 17 deletions

View File

@ -2,9 +2,9 @@
title: Checkpoint Summary Filter
author: projectmoon
author_url: https://git.agnos.is/projectmoon/open-webui-filters
version: 0.1.0
version: 0.2.0
license: AGPL-3.0+
required_open_webui_version: 0.3.9
required_open_webui_version: 0.3.29
"""
# Documentation: https://git.agnos.is/projectmoon/open-webui-filters
@ -28,16 +28,20 @@ from chromadb import Collection as ChromaCollection
from chromadb.api.types import Document as ChromaDocument
# OpenWebUI imports
from config import CHROMA_CLIENT
from apps.rag.main import app as rag_app
from apps.ollama.main import app as ollama_app
from apps.ollama.main import show_model_info, ModelNameForm
from utils.misc import get_last_user_message, get_last_assistant_message
from main import generate_chat_completions
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import show_model_info, ModelNameForm
from open_webui.utils.misc import get_last_user_message, get_last_assistant_message
from open_webui.main import generate_chat_completions
from apps.webui.models.chats import Chats
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from open_webui.apps.webui.models.chats import Chats
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import Users
# Why refactor when you can janky monkey patch? This will be fixed at
# some point.
CHROMA_CLIENT = VECTOR_DB_CLIENT.client
# Embedding (not yet used)
EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
@ -70,6 +74,16 @@ summary. If the conversation is a story or roleplaying sesison, only use the
names of the characters, places, and events in the story.
""".replace("\n", " ").strip()
# yoinked from stack overflow. hack to get user into
# generate_chat_completions. this is used to turn the __user__ dict
# given to the filter into a thing that the main OpenWebUI code can
# understand for calling its chat completion endpoint internally.
class BlackMagicDictionary(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
class Message(TypedDict):
id: NotRequired[str]
role: str
@ -97,6 +111,7 @@ class SummarizerResponse(BaseModel):
class Summarizer(BaseModel):
messages: List[dict]
model: str
user: Any
prompt: str = SUMMARIZER_PROMPT
async def summarize(self) -> Optional[SummarizerResponse]:
@ -107,7 +122,6 @@ class Summarizer(BaseModel):
}
messages = [sys_message] + self.messages + [user_message]
request = {
"model": self.model,
"messages": messages,
@ -115,7 +129,7 @@ class Summarizer(BaseModel):
"keep_alive": "10s"
}
resp = await generate_chat_completions(request)
resp = await generate_chat_completions(request, user=self.user)
if "choices" in resp and len(resp["choices"]) > 0:
content: str = resp["choices"][0]["message"]["content"]
return SummarizerResponse(summary=content)
@ -167,6 +181,7 @@ class Checkpointer(BaseModel):
messages: List[dict]=[] # stripped set of messages
full_messages: List[dict]=[] # all the messages
embedding_func: EmbeddingFunc=(lambda a: 0)
user: Optional[Any] = None
collection_name: ClassVar[str] = "chat_checkpoints"
@ -230,7 +245,7 @@ class Checkpointer(BaseModel):
)
async def create_checkpoint(self) -> str:
summarizer = Summarizer(model=self.summarizer_model, messages=self.messages)
summarizer = Summarizer(model=self.summarizer_model, messages=self.messages, user=self.user)
resp = await summarizer.summarize()
if resp:
slug = self._calculate_slug()
@ -531,7 +546,6 @@ class Filter:
return message_chain
async def create_checkpoint(
self,
messages: List[dict],
@ -569,7 +583,8 @@ class Filter:
summarizer_model=summarizer_model,
chroma_client=CHROMA_CLIENT,
full_messages=messages,
messages=message_chain
messages=message_chain,
user=BlackMagicDictionary(self.user)
)
try:
@ -752,7 +767,6 @@ class Filter:
last_checkpointed_id=checkpoint.message_id if checkpoint else None
)
print(f"[{self.session_info.chat_id}] Done checking for summarization")
return body