Checkpoint Filter: Fix for refactored OpenWebUI
This commit is contained in:
parent 68e0be9608
commit 72702924be
@@ -2,9 +2,9 @@
 title: Checkpoint Summary Filter
 author: projectmoon
 author_url: https://git.agnos.is/projectmoon/open-webui-filters
-version: 0.1.0
+version: 0.2.0
 license: AGPL-3.0+
-required_open_webui_version: 0.3.9
+required_open_webui_version: 0.3.29
 """

 # Documentation: https://git.agnos.is/projectmoon/open-webui-filters
@@ -28,16 +28,20 @@ from chromadb import Collection as ChromaCollection
 from chromadb.api.types import Document as ChromaDocument

 # OpenWebUI imports
-from config import CHROMA_CLIENT
-from apps.rag.main import app as rag_app
-from apps.ollama.main import app as ollama_app
-from apps.ollama.main import show_model_info, ModelNameForm
-from utils.misc import get_last_user_message, get_last_assistant_message
-from main import generate_chat_completions
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
+from open_webui.apps.rag.main import app as rag_app
+from open_webui.apps.ollama.main import app as ollama_app
+from open_webui.apps.ollama.main import show_model_info, ModelNameForm
+from open_webui.utils.misc import get_last_user_message, get_last_assistant_message
+from open_webui.main import generate_chat_completions

-from apps.webui.models.chats import Chats
-from apps.webui.models.models import Models
-from apps.webui.models.users import Users
+from open_webui.apps.webui.models.chats import Chats
+from open_webui.apps.webui.models.models import Models
+from open_webui.apps.webui.models.users import Users

+# Why refactor when you can janky monkey patch? This will be fixed at
+# some point.
+CHROMA_CLIENT = VECTOR_DB_CLIENT.client
+
 # Embedding (not yet used)
 EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
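The import block above is the core of the fix: after the OpenWebUI refactor, all modules live under the `open_webui` package and the chromadb client is only reachable through `VECTOR_DB_CLIENT`, so the filter rebuilds its old `CHROMA_CLIENT` name from `VECTOR_DB_CLIENT.client`. For a filter that has to load on both layouts, a fallback import along these lines is one option; this is only a sketch and not part of the commit, which targets the new layout exclusively.

```python
# Sketch only (not in this commit): prefer the post-refactor package layout,
# fall back to the old top-level modules on pre-0.3.29 installs.
try:
    from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
    from open_webui.main import generate_chat_completions
    CHROMA_CLIENT = VECTOR_DB_CLIENT.client  # chromadb client now sits behind the vector DB wrapper
except ImportError:
    from config import CHROMA_CLIENT
    from main import generate_chat_completions
```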
@@ -70,6 +74,16 @@ summary. If the conversation is a story or roleplaying sesison, only use the
 names of the characters, places, and events in the story.
 """.replace("\n", " ").strip()

+# yoinked from stack overflow. hack to get user into
+# generate_chat_completions. this is used to turn the __user__ dict
+# given to the filter into a thing that the main OpenWebUI code can
+# understand for calling its chat completion endpoint internally.
+class BlackMagicDictionary(dict):
+    """dot.notation access to dictionary attributes"""
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
 class Message(TypedDict):
     id: NotRequired[str]
     role: str
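`BlackMagicDictionary` is nothing more than a `dict` whose keys can also be read and written as attributes, which is enough for `generate_chat_completions` to treat the filter's `__user__` dict like a user object. Roughly how it behaves (the key names here are only illustrative):

```python
class BlackMagicDictionary(dict):
    """dot.notation access to dictionary attributes (same as the class added above)."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

# Illustrative keys; the real fields come from OpenWebUI's __user__ dict.
user = BlackMagicDictionary({"id": "abc123", "role": "user"})
assert user.id == "abc123"      # attribute access reads dict keys
assert user.missing is None     # unknown attributes fall through to dict.get
user.name = "projectmoon"       # attribute assignment writes a key
assert user["name"] == "projectmoon"
```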
@@ -97,6 +111,7 @@ class SummarizerResponse(BaseModel):
 class Summarizer(BaseModel):
     messages: List[dict]
     model: str
+    user: Any
     prompt: str = SUMMARIZER_PROMPT

     async def summarize(self) -> Optional[SummarizerResponse]:
@@ -107,7 +122,6 @@ class Summarizer(BaseModel):
         }
-
         messages = [sys_message] + self.messages + [user_message]

         request = {
             "model": self.model,
             "messages": messages,
@@ -115,7 +129,7 @@ class Summarizer(BaseModel):
             "keep_alive": "10s"
         }

-        resp = await generate_chat_completions(request)
+        resp = await generate_chat_completions(request, user=self.user)
         if "choices" in resp and len(resp["choices"]) > 0:
             content: str = resp["choices"][0]["message"]["content"]
             return SummarizerResponse(summary=content)
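After the refactor, `generate_chat_completions` needs to know which user is making the request, so `Summarizer` grew a `user` field and passes it through. Driving it end to end looks roughly like this (a sketch: the model name and message content are placeholders, and `__user__` is the dict OpenWebUI hands to filter hooks):

```python
# Sketch only: how the updated Summarizer is constructed and awaited.
async def summarize_example(__user__: dict):
    summarizer = Summarizer(
        model="llama3.1:latest",                                 # placeholder model name
        messages=[{"role": "user", "content": "Hello there."}],  # placeholder chat history
        user=BlackMagicDictionary(__user__),                     # threaded into generate_chat_completions
    )
    resp = await summarizer.summarize()
    return resp.summary if resp else None
```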
@@ -167,6 +181,7 @@ class Checkpointer(BaseModel):
     messages: List[dict]=[] # stripped set of messages
     full_messages: List[dict]=[] # all the messages
     embedding_func: EmbeddingFunc=(lambda a: 0)
+    user: Optional[Any] = None

     collection_name: ClassVar[str] = "chat_checkpoints"

@@ -230,7 +245,7 @@ class Checkpointer(BaseModel):
         )

     async def create_checkpoint(self) -> str:
-        summarizer = Summarizer(model=self.summarizer_model, messages=self.messages)
+        summarizer = Summarizer(model=self.summarizer_model, messages=self.messages, user=self.user)
         resp = await summarizer.summarize()
         if resp:
             slug = self._calculate_slug()
@@ -531,7 +546,6 @@ class Filter:

         return message_chain

-
     async def create_checkpoint(
         self,
         messages: List[dict],
@@ -569,7 +583,8 @@ class Filter:
             summarizer_model=summarizer_model,
             chroma_client=CHROMA_CLIENT,
             full_messages=messages,
-            messages=message_chain
+            messages=message_chain,
+            user=BlackMagicDictionary(self.user)
         )

         try:
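Taken together, the commit threads the user from the filter hook all the way down to the chat-completion call: `self.user` (the `__user__` dict the filter received, wrapped in `BlackMagicDictionary`) goes into the `Checkpointer`, which hands it to the `Summarizer`, which passes it to `generate_chat_completions`. A condensed sketch of that wiring follows; the function signature and the point where the `__user__` dict is captured are assumptions, since they fall outside the lines changed here.

```python
# Condensed sketch of the user's path after this commit:
#   filter hook (__user__ dict)
#     -> Checkpointer(user=BlackMagicDictionary(__user__))
#       -> Summarizer(user=...)
#         -> generate_chat_completions(request, user=...)
async def checkpoint_with_user(message_chain, messages, __user__: dict) -> str:
    checkpointer = Checkpointer(
        summarizer_model="llama3.1:latest",        # placeholder model name
        chroma_client=CHROMA_CLIENT,
        full_messages=messages,                    # everything in the chat
        messages=message_chain,                    # stripped set actually summarized
        user=BlackMagicDictionary(__user__),
    )
    return await checkpointer.create_checkpoint()
```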
@@ -752,7 +767,6 @@ class Filter:
             last_checkpointed_id=checkpoint.message_id if checkpoint else None
         )

         print(f"[{self.session_info.chat_id}] Done checking for summarization")
         return body

-