"""
|
|
title: Checkpoint Summary Filter
|
|
author: projectmoon
|
|
author_url: https://git.agnos.is/projectmoon/open-webui-filters
|
|
version: 0.1.0
|
|
license: AGPL-3.0+
|
|
required_open_webui_version: 0.3.9
|
|
"""
|
|
|
|
# Documentation: https://git.agnos.is/projectmoon/open-webui-filters
|
|
|
|
# System imports
|
|
import asyncio
|
|
import hashlib
|
|
import uuid
|
|
import json
|
|
import re
|
|
import logging
|
|
|
|
from typing import Optional, List, Dict, Callable, Any, NewType, Tuple, Awaitable, ClassVar
|
|
from typing_extensions import TypedDict, NotRequired
|
|
from collections import deque
|
|
|
|
# Libraries available to OpenWebUI
|
|
from pydantic import BaseModel as PydanticBaseModel, Field
|
|
import chromadb
|
|
from chromadb import Collection as ChromaCollection
|
|
from chromadb.api.types import Document as ChromaDocument
|
|
|
|
# OpenWebUI imports
|
|
from config import CHROMA_CLIENT
|
|
from apps.rag.main import app as rag_app
|
|
from apps.ollama.main import app as ollama_app
|
|
from apps.ollama.main import show_model_info, ModelNameForm
|
|
from utils.misc import get_last_user_message, get_last_assistant_message
|
|
from main import generate_chat_completions
|
|
|
|
from apps.webui.models.chats import Chats
|
|
from apps.webui.models.models import Models
|
|
from apps.webui.models.users import Users
|
|
|
|
# Embedding (not yet used)
|
|
EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
|
EmbeddingFunc = NewType('EmbeddingFunc', Callable[[str], List[Any]])
|
|
|
|
# Prompts
|
|
SUMMARIZER_PROMPT = """
|
|
### Main Instructions
|
|
|
|
You are a chat conversation summarizer. Your task is to summarize the given
|
|
portion of an ongoing conversation. First, determine if the conversation is
|
|
a regular chat between the user and the assistant, or if the conversation is
|
|
part of a story or role-playing session.
|
|
|
|
Summarize the important parts of the given chat between the user and the
|
|
assistant. Limit your summary to one paragraph. Make sure your summary is
|
|
detailed. Write the summary as if you are summarizing part of a larger
|
|
conversation.
|
|
|
|
### Regular Chat
|
|
|
|
If the conversation is a regular chat, write your summary referring to the
|
|
ongoing conversation as a chat. Refer to the user and the assistant as user
|
|
and assistant. Do not refer to yourself as the assistant.
|
|
|
|
### Story or Role-Playing Session
|
|
|
|
If the conversation is a story or role-playing session, write your summary
|
|
referring to the conversation as an ongoing story. Do not refer to the user
|
|
or assistant in your summary. Only use the names of the characters, places,
|
|
and events in the story.
|
|
""".replace("\n", " ").strip()
|
|
|
|
class Message(TypedDict):
    id: NotRequired[str]
    role: str
    content: str


class MessageInsertMetadata(TypedDict):
    role: str
    chapter: str


class MessageInsert(TypedDict):
    message_id: str
    content: str
    metadata: MessageInsertMetadata
    embeddings: List[Any]


class BaseModel(PydanticBaseModel):
    class Config:
        arbitrary_types_allowed = True


class SummarizerResponse(BaseModel):
    summary: str

class Summarizer(BaseModel):
    messages: List[dict]
    model: str
    prompt: str = SUMMARIZER_PROMPT

    async def summarize(self) -> Optional[SummarizerResponse]:
        sys_message: Message = { "role": "system", "content": SUMMARIZER_PROMPT }
        user_message: Message = {
            "role": "user",
            "content": "Make a detailed summary of everything up to this point."
        }

        messages = [sys_message] + self.messages + [user_message]

        request = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "keep_alive": "10s"
        }

        resp = await generate_chat_completions(request)
        if "choices" in resp and len(resp["choices"]) > 0:
            content: str = resp["choices"][0]["message"]["content"]
            return SummarizerResponse(summary=content)
        else:
            return None
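# Illustrative sketch (not executed): the request handed to
# generate_chat_completions above is a plain OpenAI-style payload, and the
# reply is parsed the same way. Roughly, with a hypothetical model id:
#
#   request = {
#       "model": "llama3:8b",
#       "messages": [
#           {"role": "system", "content": SUMMARIZER_PROMPT},
#           # ... the stripped conversation ...
#           {"role": "user", "content": "Make a detailed summary of everything up to this point."},
#       ],
#       "stream": False,
#       "keep_alive": "10s",
#   }
#   # response shape: {"choices": [{"message": {"content": "<summary text>"}}]}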
class Checkpoint(BaseModel):
    # chat id
    chat_id: str

    # the message ID this checkpoint was created from.
    message_id: str

    # index of the message in the message input array. in the inlet
    # function, we do not have access to incoming message ids for some
    # reason. used as a fallback to drop old context when message ids
    # are not available.
    message_index: int = 0

    # the "slug", or chain of messages, that led to this point.
    slug: str

    # actual summary of messages.
    summary: str

    # if we try to put a type hint on this, it gets mad.
    @staticmethod
    def from_json(obj: dict):
        try:
            return Checkpoint(
                chat_id=obj["chat_id"],
                message_id=obj["message_id"],
                message_index=obj["message_index"],
                slug=obj["slug"],
                summary=obj["summary"]
            )
        except Exception:
            return None

    def to_json(self) -> str:
        return self.model_dump_json()
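# A checkpoint round-trips through to_json()/from_json() as a flat JSON
# document stored verbatim in Chroma. A hypothetical example document:
#
#   {
#       "chat_id": "0f2c...",         # OpenWebUI chat id
#       "message_id": "a81b...",      # last message covered by the summary
#       "message_index": 12,          # fallback split point for the inlet
#       "slug": "3f7a...",            # sha256 over the chain of message ids
#       "summary": "The user and assistant discussed ..."
#   }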
class Checkpointer(BaseModel):
    """Manages summary checkpoints in a single chat."""
    chat_id: str
    summarizer_model: str = ""
    chroma_client: chromadb.ClientAPI
    messages: List[dict] = []       # stripped set of messages
    full_messages: List[dict] = []  # all the messages
    embedding_func: EmbeddingFunc = (lambda a: 0)

    collection_name: ClassVar[str] = "chat_checkpoints"

    def _get_collection(self) -> ChromaCollection:
        return self.chroma_client.get_or_create_collection(
            name=Checkpointer.collection_name
        )

    def _insert_checkpoint(self, checkpoint: Checkpoint):
        coll = self._get_collection()
        checkpoint_doc = checkpoint.to_json()

        # Insert the checkpoint itself with the slug as ID.
        coll.upsert(
            ids=[checkpoint.slug],
            documents=[checkpoint_doc],
            metadatas=[{ "chat_id": self.chat_id, "type": "checkpoint" }],
            embeddings=[self.embedding_func(checkpoint_doc)]
        )

        # Update the chat info doc for this chat.
        coll.upsert(
            ids=[self.chat_id],
            documents=[json.dumps({ "current_checkpoint": checkpoint.slug })],
            embeddings=[self.embedding_func(self.chat_id)]
        )

    def _calculate_slug(self) -> Optional[str]:
        if len(self.messages) == 0:
            return None

        message_ids = [msg["id"] for msg in reversed(self.messages)]
        slug = "|".join(message_ids)
        return hashlib.sha256(slug.encode()).hexdigest()

    def _get_state(self):
        resp = self._get_collection().get(ids=[self.chat_id], include=["documents"])
        state: dict = (json.loads(resp["documents"][0])
                       if resp["documents"] and len(resp["documents"]) > 0
                       else { "current_checkpoint": None })
        return state

    def _find_message_index(self, message_id: str) -> Optional[int]:
        for idx, message in enumerate(self.full_messages):
            if message["id"] == message_id:
                return idx
        return None

    def nuke_checkpoints(self):
        """Delete all checkpoints for this chat."""
        coll = self._get_collection()

        checkpoints = coll.get(
            include=["documents"],
            where={"chat_id": self.chat_id}
        )

        self._get_collection().delete(
            ids=[self.chat_id] + checkpoints["ids"]
        )

    async def create_checkpoint(self) -> Optional[str]:
        summarizer = Summarizer(model=self.summarizer_model, messages=self.messages)
        resp = await summarizer.summarize()
        if resp:
            slug = self._calculate_slug()
            checkpoint_message = self.messages[-1]
            checkpoint_index = self._find_message_index(checkpoint_message["id"])

            checkpoint = Checkpoint(
                chat_id=self.chat_id,
                slug=slug,
                message_id=checkpoint_message["id"],
                message_index=checkpoint_index,
                summary=resp.summary
            )

            self._insert_checkpoint(checkpoint)
            return slug

        return None

    def get_checkpoint(self, slug: Optional[str]) -> Optional[Checkpoint]:
        if not slug:
            return None

        resp = self._get_collection().get(ids=[slug], include=["documents"])
        checkpoint = (resp["documents"][0]
                      if resp["documents"] and len(resp["documents"]) > 0
                      else None)

        if checkpoint:
            return Checkpoint.from_json(json.loads(checkpoint))
        else:
            return None

    def get_current_checkpoint(self) -> Optional[Checkpoint]:
        state = self._get_state()
        return self.get_checkpoint(state["current_checkpoint"])
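# Sketch of how the "chat_checkpoints" collection ends up laid out, given the
# upsert calls above: one document per checkpoint, keyed by its slug, plus one
# small "state" document per chat, keyed by the chat id, pointing at the
# current checkpoint:
#
#   id = <slug>     -> document: Checkpoint JSON,
#                      metadata: {"chat_id": <chat_id>, "type": "checkpoint"}
#   id = <chat_id>  -> document: {"current_checkpoint": "<slug>"}
#
# get_current_checkpoint() therefore does two lookups: chat id -> slug, then
# slug -> checkpoint document.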
#########################
# Utilities
#########################

class SessionInfo(BaseModel):
    chat_id: str
    message_id: str
    session_id: str


def extract_session_info(event_emitter) -> Optional[SessionInfo]:
    """The latest innovation in hacky workarounds."""
    try:
        info = event_emitter.__closure__[0].cell_contents
        return SessionInfo(
            chat_id=info["chat_id"],
            message_id=info["message_id"],
            session_id=info["session_id"]
        )
    except Exception:
        return None
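# Note on the workaround above: OpenWebUI hands filters an __event_emitter__
# coroutine whose closure captures the request metadata. Reaching into
# __closure__[0].cell_contents recovers that dict ("chat_id", "message_id",
# "session_id"). This is brittle by nature and deliberately returns None if
# the closure layout changes in a future OpenWebUI version.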
def predicted_token_use(messages) -> int:
    """Parse the most recent message to calculate estimated token use."""
    if len(messages) == 0:
        return 0

    # Naive assumptions:
    #  - 1 word = 1 token.
    #  - 1 period, comma, or colon = 1 token.
    # messages are chat dicts, so estimate on the text content.
    message = messages[-1]["content"]
    return len(list(filter(None, re.split(r"\s|(;)|(,)|(\.)|(:)|\n", message))))
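# Rough worked example of the naive estimate above (not executed):
#
#   predicted_token_use([{"role": "user", "content": "hello, world."}])
#
# splits into ["hello", ",", "world", "."] once empty pieces and None groups
# are filtered out, so the estimate is 4 "tokens": each word is one token and
# each captured punctuation mark (; , . :) is one token.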
def is_big_convo(messages, num_ctx: int=8192) -> bool:
    """
    Attempt to detect a large pre-existing conversation by looking at
    recent eval counts from messages and comparing against the given
    num_ctx. We check all messages for an eval count that goes above
    the context limit. It doesn't matter where in the message list; if
    it's somewhere in the middle, it means that there was a context
    shift.
    """
    for message in messages:
        if "info" in message:
            tokens_used = (message["info"]["eval_count"] +
                           message["info"]["prompt_eval_count"])
        else:
            tokens_used = 0

        if tokens_used >= num_ctx:
            return True

    return False


def hit_context_limit(
        messages,
        num_ctx: int=8192,
        wiggle_room: int=1000
) -> Tuple[bool, int]:
    """
    Determine if we've hit the context limit, within some reasonable
    estimation. We have a defined 'wiggle room' that is subtracted
    from the num_ctx parameter, in order to capture near-filled
    contexts. We do it this way because we're summarizing on output,
    rather than before input (the inlet function doesn't have enough
    info).
    """
    if len(messages) == 0:
        return False, 0

    last_message = messages[-1]
    tokens_used = 0
    if "info" in last_message:
        tokens_used = (last_message["info"]["eval_count"] +
                       last_message["info"]["prompt_eval_count"])

    if tokens_used >= (num_ctx - wiggle_room):
        amount_over = tokens_used - num_ctx
        amount_over = 0 if amount_over < 0 else amount_over
        return True, amount_over
    else:
        return False, 0
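# Worked example of the check above (not executed), with the defaults
# num_ctx=8192 and wiggle_room=1000, i.e. a trigger threshold of 7192:
#
#   eval_count + prompt_eval_count = 7500  ->  (True, 0)    near the limit
#   eval_count + prompt_eval_count = 8500  ->  (True, 308)  308 tokens over
#   eval_count + prompt_eval_count = 5000  ->  (False, 0)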
def extract_base_model_id(model: dict) -> Optional[str]:
    if "base_model_id" not in model["info"]:
        return None

    base_model_id = model["info"]["base_model_id"]
    if not base_model_id:
        base_model_id = model["id"]

    return base_model_id


def extract_owu_model_param(model_obj: dict, param_name: str):
    """
    Extract a parameter value from the DB definition of a model
    that is based on another model.
    """
    if "params" not in model_obj["info"]:
        return None

    params = model_obj["info"]["params"]
    return params.get(param_name, None)


def extract_owu_base_model_param(base_model_id: str, param_name: str):
    """Extract a parameter value from the DB definition of an ollama base model."""
    base_model = Models.get_model_by_id(base_model_id)

    if not base_model:
        return None

    base_model.params = base_model.params.model_dump()
    return base_model.params.get(param_name, None)


def extract_ollama_response_param(model: Optional[dict], param_name: str):
    """Extract a parameter value from the ollama show API response."""
    if not model or "parameters" not in model:
        return None

    for line in model["parameters"].splitlines():
        if line.startswith(param_name):
            # strip the parameter name as a prefix (lstrip would strip a
            # character set, not the prefix).
            return line[len(param_name):].strip()

    return None


async def get_model_from_ollama(model_id: str, user_id) -> Optional[dict]:
    """Call the ollama show API and return model information."""
    curr_user = Users.get_user_by_id(user_id)
    try:
        return await show_model_info(ModelNameForm(name=model_id), user=curr_user)
    except Exception as e:
        print(f"Could not get model info: {e}")
        return None


async def calculate_num_ctx(chat_id: str, user_id, model: dict) -> int:
    """
    Attempt to discover the current num_ctx parameter in many
    different ways.
    """
    # first check the open-webui chat parameters.
    chat = Chats.get_chat_by_id_and_user_id(chat_id, user_id)
    if chat:
        # this might look odd, but the chat field is a json blob of
        # useful info.
        chat = json.loads(chat.chat)
        if "params" in chat and "num_ctx" in chat["params"]:
            return chat["params"]["num_ctx"]

    # then check the open web ui model def.
    num_ctx = extract_owu_model_param(model, "num_ctx")
    if num_ctx:
        return num_ctx

    # then check the open web ui base model def.
    base_model_id = extract_base_model_id(model)
    if not base_model_id:
        # fall back to the default in case of weirdness.
        return 2048

    num_ctx = extract_owu_base_model_param(base_model_id, "num_ctx")
    if num_ctx:
        return num_ctx

    # THEN check ollama directly.
    base_model = await get_model_from_ollama(base_model_id, user_id)
    num_ctx = extract_ollama_response_param(base_model, "num_ctx")
    if num_ctx:
        return num_ctx

    # finally, return the default.
    return 2048
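# In short, the num_ctx resolution order implemented above is:
#   1. per-chat params stored in the chat's JSON blob,
#   2. the OpenWebUI model definition,
#   3. the OpenWebUI base model definition,
#   4. the ollama show API parameters for the base model,
#   5. a hard-coded default of 2048.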
class Filter:
    class Valves(BaseModel):
        def summarizer_model(self, body):
            if self.summarizer_model_id == "":
                return extract_base_model_id(body["model"])
            else:
                return self.summarizer_model_id

        summarize_large_contexts: bool = Field(
            default=False,
            description=(
                "Whether or not to use a large context model to summarize large "
                "pre-existing conversations."
            )
        )
        wiggle_room: int = Field(
            default=1000,
            description=(
                "Amount of token 'wiggle room' for estimating when a context shift occurs. "
                "Subtracted from num_ctx when checking if summarization is needed."
            )
        )
        summarizer_model_id: str = Field(
            default="",
            description="Model used to summarize the conversation. Must be a base model.",
        )
        large_summarizer_model_id: str = Field(
            default="",
            description=(
                "Model used to summarize large pre-existing contexts. "
                "Must be a base model with a context size large enough "
                "to fit the conversation."
            )
        )

    class UserValves(BaseModel):
        pass

    def __init__(self):
        self.valves = self.Valves()

    def load_current_chat(self) -> dict:
        # the chat property of the model is the json blob that holds
        # all the interesting stuff.
        chat = (Chats
                .get_chat_by_id_and_user_id(self.session_info.chat_id, self.user["id"])
                .chat)

        return json.loads(chat)

    def get_messages_for_checkpointing(self, messages, num_ctx, last_checkpointed_id):
        """
        Assemble the list of messages to checkpoint, based on the
        current state and valve settings.
        """
        message_chain = deque()
        for message in reversed(messages):
            if message["id"] == last_checkpointed_id:
                break
            message_chain.appendleft(message)

        message_chain = list(message_chain)  # the lazy way

        # now we check if we are a big conversation, and if valve
        # settings allow that kind of summarization.
        summarizer_model = self.valves.summarizer_model
        if is_big_convo(messages, num_ctx) and not self.valves.summarize_large_contexts:
            # must summarize using the small model. for now, drop to
            # the last N messages.
            print((
                "Dropping all but last 4 messages to summarize "
                "large convo without large model."
            ))
            message_chain = message_chain[-4:]

        return message_chain

    async def create_checkpoint(
            self,
            messages: List[dict],
            last_checkpointed_id: Optional[str]=None,
            num_ctx: int=8192
    ):
        if len(messages) == 0:
            return

        print(f"[{self.session_info.chat_id}] Detected context shift. Summarizing.")
        await self.set_summarizing_status(done=False)
        last_message = messages[-1]  # should check for role = assistant
        curr_message_id: Optional[str] = (
            last_message["id"] if last_message else None
        )

        if not curr_message_id:
            return

        # strip messages down to what is in the current checkpoint.
        message_chain = self.get_messages_for_checkpointing(
            messages, num_ctx, last_checkpointed_id
        )

        # we should now have a list of messages that is just within
        # the current context limit.
        summarizer_model = self.valves.summarizer_model_id
        if is_big_convo(message_chain, num_ctx) and self.valves.summarize_large_contexts:
            print("Summarizing LARGE context!")
            summarizer_model = self.valves.large_summarizer_model_id

        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            summarizer_model=summarizer_model,
            chroma_client=CHROMA_CLIENT,
            full_messages=messages,
            messages=message_chain
        )

        try:
            slug = await checkpointer.create_checkpoint()
            await self.set_summarizing_status(done=True)
            print(("Summarization checkpoint created in chat "
                   f"'{self.session_info.chat_id}': {slug}"))
        except Exception as e:
            print(f"Error creating summary: {str(e)}")
            await self.set_summarizing_status(
                done=True, message=f"Error summarizing: {str(e)}"
            )

    def update_chat_with_checkpoint(self, messages: List[dict], checkpoint: Checkpoint):
        if len(messages) < checkpoint.message_index:
            # do not mess with anything if the index doesn't even
            # exist anymore. need a new checkpoint.
            return messages

        # proceed with altering the system prompt. keep the system prompt,
        # if it's there, and add the summary to it. the summary becomes the
        # system prompt if there is no system prompt.
        convo_messages = [
            message for message in messages if message.get("role") != "system"
        ]

        system_prompt = next(
            (message for message in messages if message.get("role") == "system"), None
        )

        summary_message = f"Summary of conversation so far:\n\n{checkpoint.summary}"

        if system_prompt:
            system_prompt["content"] += f"\n\n{summary_message}"
        else:
            system_prompt = { "role": "system", "content": summary_message }

        # drop old messages, reapply system prompt.
        messages = self.apply_checkpoint(checkpoint, messages)
        return [system_prompt] + messages
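    # Hypothetical sketch of the rewrite above (not executed): given a system
    # prompt "You are a helpful assistant." and a checkpoint whose summary is
    # "The user asked about X...", the outgoing system message becomes:
    #
    #   "You are a helpful assistant.\n\n"
    #   "Summary of conversation so far:\n\nThe user asked about X..."
    #
    # and every message before the checkpointed message is dropped by
    # apply_checkpoint().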
    async def send_message(self, message: str):
        await self.event_emitter({
            "type": "status",
            "data": {
                "description": message,
                "done": True,
            },
        })

    async def set_summarizing_status(self, done: bool, message: Optional[str]=None):
        if not self.event_emitter:
            return

        if not done:
            description = (
                "Summarizing conversation due to reaching context limit (do not reply yet)."
            )
        else:
            description = (
                "Summarization complete (you may now reply)."
            )

        if message:
            description = message

        await self.event_emitter({
            "type": "status",
            "data": {
                "description": description,
                "done": done,
            },
        })

    def apply_checkpoint(
            self, checkpoint: Checkpoint, messages: List[dict]
    ) -> List[dict]:
        """
        Possibly shorten the message context based on a checkpoint.
        This works two ways: if the messages have IDs (outlet
        filter), split by message ID (very reliable). Otherwise,
        attempt to split on the recorded message index (inlet
        filter; not very reliable).
        """

        # first attempt to drop everything before the checkpointed
        # message id.
        split_point = 0
        for idx, message in enumerate(messages):
            if "id" in message and message["id"] == checkpoint.message_id:
                split_point = idx
                break

        # if we can't find the ID to split on, fall back to the message
        # index if possible. this can happen during message
        # regeneration, for example, or if we're called from the inlet
        # filter, which doesn't have access to message ids.
        if split_point == 0 and checkpoint.message_index <= len(messages):
            split_point = checkpoint.message_index

        orig = len(messages)
        messages = messages[split_point:]
        print((f"[{self.session_info.chat_id}] Dropped context to {len(messages)} "
               f"messages (from {orig})"))
        return messages

    async def handle_nuke(self, body):
        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            chroma_client=CHROMA_CLIENT
        )
        checkpointer.nuke_checkpoints()
        await self.send_message("Deleted all checkpoints for chat.")

        body["messages"][-1]["content"] = (
            "Respond only with: 'Deleted all checkpoints for chat.'"
        )

        body["messages"] = body["messages"][-1:]
        return body
    async def outlet(
            self,
            body: dict,
            __user__: Optional[dict],
            __model__: Optional[dict],
            __event_emitter__: Callable[[Any], Awaitable[None]],
    ) -> dict:
        # Useful things to have around.
        self.user = __user__
        self.model = __model__
        self.session_info = extract_session_info(__event_emitter__)
        self.event_emitter = __event_emitter__
        self.summarizer_model_id = self.valves.summarizer_model(body)

        # global filters apply to requests coming in through the proxied
        # API. If we're not an OpenWebUI chat, abort mission.
        if not self.session_info:
            return body

        if not self.model or self.model["owned_by"] != "ollama":
            return body

        messages = body["messages"]

        num_ctx = await calculate_num_ctx(
            chat_id=self.session_info.chat_id,
            user_id=self.user["id"],
            model=self.model
        )

        # apply the current checkpoint ONLY for purposes of calculating
        # whether we have hit num_ctx within the current checkpoint.
        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            chroma_client=CHROMA_CLIENT
        )

        checkpoint = checkpointer.get_current_checkpoint()
        messages_for_ctx_check = (self.apply_checkpoint(checkpoint, messages)
                                  if checkpoint else messages)

        hit_limit, amount_over = hit_context_limit(
            messages=messages_for_ctx_check,
            num_ctx=num_ctx,
            wiggle_room=self.valves.wiggle_room
        )

        if hit_limit:
            # we need the FULL message list to do proper summarizing,
            # because we might be summarizing a huge context.
            await self.create_checkpoint(
                messages=messages,
                num_ctx=num_ctx,
                last_checkpointed_id=checkpoint.message_id if checkpoint else None
            )

        print(f"[{self.session_info.chat_id}] Done checking for summarization")
        return body

    async def inlet(
            self,
            body: dict,
            __user__: Optional[dict],
            __model__: Optional[dict],
            __event_emitter__: Callable[[Any], Awaitable[None]]
    ) -> dict:
        # Useful properties to have around.
        self.user = __user__
        self.model = __model__
        self.session_info = extract_session_info(__event_emitter__)
        self.event_emitter = __event_emitter__
        self.summarizer_model_id = self.valves.summarizer_model(body)

        # global filters apply to requests coming in through the proxied
        # API. If we're not an OpenWebUI chat, abort mission.
        if not self.session_info:
            return body

        if not self.model or self.model["owned_by"] != "ollama":
            return body

        # super basic external command handling (delete checkpoints).
        user_msg = get_last_user_message(body["messages"])
        if user_msg and user_msg == "!nuke":
            return await self.handle_nuke(body)

        # apply the current checkpoint to the chat: adds the most recent
        # summary to the system prompt, and drops all messages before the
        # checkpoint.
        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            chroma_client=CHROMA_CLIENT
        )

        checkpoint = checkpointer.get_current_checkpoint()
        if checkpoint:
            print((
                f"Using checkpoint {checkpoint.slug} for "
                f"conversation {self.session_info.chat_id}"
            ))

            body["messages"] = self.update_chat_with_checkpoint(body["messages"], checkpoint)

        return body