open-webui-filters/memories.py

793 lines
27 KiB
Python
Raw Normal View History

2024-07-15 14:39:54 +00:00
"""
title: Memory Filter
author: projectmoon
author_url: https://git.agnos.is/projectmoon/open-webui-filters
2024-07-18 20:13:00 +00:00
version: 0.0.2
required_open_webui_version: 0.3.9
2024-07-15 14:39:54 +00:00
"""
2024-07-18 20:13:00 +00:00
# Documentation: https://git.agnos.is/projectmoon/open-webui-filters
#
# Changelog:
# 0.0.1 - Initial release, proof of concept
# 0.0.2 - Slightly less hacky (but still hacky) way of getting chat IDs
2024-07-15 14:39:54 +00:00
# System imports
import asyncio
import hashlib
import uuid
import json
from typing import Optional, List, Dict, Callable, Any, NewType, Tuple, Awaitable
from typing_extensions import TypedDict, NotRequired
# Libraries available to OpenWebUI
import markdown
from bs4 import BeautifulSoup
from pydantic import BaseModel as PydanticBaseModel, Field
import chromadb
from chromadb import Collection as ChromaCollection
from chromadb.api.types import Document as ChromaDocument
# OpenWebUI imports
from config import CHROMA_CLIENT
from apps.rag.main import app
from utils.misc import get_last_user_message, get_last_assistant_message
from main import generate_chat_completions
# OpenWebUI aliases
EMBEDDING_FUNCTION = app.state.EMBEDDING_FUNCTION
# Custom type declarations
EmbeddingFunc = NewType('EmbeddingFunc', Callable[[str], List[Any]])
# Prompts
ENRICHMENT_SUMMARY_PROMPT = """
You are tasked with analyzing the following Characters and Plot Details
sections and reducing this set of information into lists of the most
important points needed for the continuation of the narrative you are
writing. Remove duplicate or conflicting information. If there is conflicting
information, decide on something consistent and interesting for the story.
Your reply must consist of two sections: Characters and Plot Details. These
sections must be markdown ### Headers. Under each header, respond with a
list of bullet points. Each bullet point must be one piece of relevant information.
Limit each bullet point to one sentence. Respond ONLY with the Characters and
Plot Details sections, with the bullet points under them, and nothing else.
Do not respond with any commentary. ONLY respond with the bullet points.
""".replace("\n", " ").strip()
QUERY_PROMPT = """
You are tasked with generating questions for a vector database
about the narrative presented below. The queries must be questions about
parts of the story that you need more details on. The questions must be
about past events in the story, or questions about the characters involved
or mentioned in the scene (their appearance, mental state, past actions, etc).
Your reply must consist of two sections: Characters and Plot Details. These
sections must be markdown ### Headers. Under each header, respond with a
list of bullet points. Each bullet point must be a single question or sentence
that will be given to the vector database. Generate a maximum of 5 Character
queries and 5 Plot Detail queries.
Limit each bullet point to one sentence. Respond ONLY with the Characters and
Plot Details sections, with the bullet points under them, and nothing else.
Do not respond with any commentary. ONLY respond with the bullet points.
""".replace("\n", " ").strip()
SUMMARIZER_PROMPT = """
You are a narrative summarizer. Summarize the given message as if it's
part of a story. Your response must have two separate sections: Characters
and Plot Details. These sections should be markdown ### Headers. Under each
section, respond with a list of bullet points. This knowledge will be stored
in a vector database for your future use.
The Characters section should note any characters in the scene, and important
things that happen to them. Describe the characters' appearances, actions,
mental states, and emotional states. The Plot Details section should have a
list of important plot details in this scene.
The bullet points you generate must be in the context of storing future
knowledge about the story. Do not focus on useless details: only focus on
information that you could lose in the future as your context window shifts.
Limit each bullet point to one sentence. The sentence MUST be in the PAST TENSE.
Respond ONLY with the Characters and Plot Details sections, with the bullet points
under them, and nothing else. Do not respond with any commentary. ONLY respond with
the bullet points.
""".replace("\n", " ").strip()
class Message(TypedDict):
id: NotRequired[str]
role: str
content: str
class MessageInsertMetadata(TypedDict):
role: str
chapter: str
class MessageInsert(TypedDict):
message_id: str
content: str
metadata: MessageInsertMetadata
embeddings: List[Any]
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class SummarizerResponse(BaseModel):
characters: List[str]
plot: List[str]
class Summarizer(BaseModel):
message: str
model: str
prompt: str = SUMMARIZER_PROMPT
def extract_section(self, soup: BeautifulSoup, section_name: str) -> List[str]:
for h3 in soup.find_all('h3'):
heading = h3.get_text().strip()
if heading != section_name:
continue
# Find the next sibling which should be a <ul> or <ol>
ul = h3.find_next_sibling('ul')
ol = h3.find_next_sibling('ol')
list_items = []
if ul:
list_items = [li.get_text().strip() for li in ul.find_all('li')]
elif ol:
list_items = [li.get_text().strip() for li in ol.find_all('li')]
return list_items
return []
def sanitize_section(self, bullet_points: List[str]) -> List[str]:
return [
bullet.strip().lstrip("-*•123456789").strip() for bullet in bullet_points
]
async def summarize(self) -> SummarizerResponse:
messages: List[Message] = [
{ "role": "system", "content": SUMMARIZER_PROMPT },
{ "role": "user", "content": self.message }
]
request = {
"model": self.model,
"messages": messages,
"stream": False,
"keep_alive": "10s"
}
resp = await generate_chat_completions(request)
if "choices" in resp and len(resp["choices"]) > 0:
content: str = resp["choices"][0]["message"]["content"]
html = markdown.markdown(content)
soup = BeautifulSoup(html, "html.parser")
character_results = self.extract_section(soup, "Characters")
character_results = self.sanitize_section(character_results)
plot_points = self.extract_section(soup, "Plot Details")
plot_points = self.sanitize_section(plot_points)
return SummarizerResponse(characters=character_results, plot=plot_points)
else:
return SummarizerResponse(characters=[], plot=[])
class Chapter(BaseModel):
"""
Focuses on a single 'chapter,' or chunk of a conversation. Provides methods to
search for data in this section of conversational story history.
"""
convo_id: Optional[str]
client: chromadb.ClientAPI
chapter_id: str
messages: List[Message]
embedding_func: EmbeddingFunc
def create_metadata(self) -> Dict:
return { "convo_id": self.convo_id, "chapter": self.chapter_id }
def get_collection(self) -> Optional[ChromaCollection]:
try:
coll = self.client.get_collection("stories")
if not self.convo_id:
self.convo_id = (
coll.metadata["current_convo_id"] if "current_convo_id" in coll.metadata else None
)
return coll
except ValueError as e:
return None
def _create_inserts(self, summary: SummarizerResponse) -> List[MessageInsert]:
inserts = []
plot_points = summary.plot
character_points = summary.characters
for plot_point in plot_points:
inserts.append({
'id': str(uuid.uuid4()),
'content': plot_point,
'metadata': {
"convo_id": self.convo_id,
"chapter": self.chapter_id,
"type": "plot"
},
'embedding': self.embedding_func(plot_point)
})
for character_point in character_points:
inserts.append({
'id': str(uuid.uuid4()),
'content': character_point,
'metadata': {
"convo_id": self.convo_id,
"chapter": self.chapter_id,
"type": "character"
},
'embedding': self.embedding_func(character_point)
})
return inserts
def chapter_state(self) -> dict:
"""Useful for storing current place in chapter, and convo switching."""
coll = self.get_collection()
result = coll.get(ids=f"chapter-{self.chapter_id}", include=["metadatas"])
if len(result.metadatas) > 0:
return result.metadatas[0]
else:
return {}
def embed(self, summary: SummarizerResponse):
"""
Store plot points for this chapter in ChromaDB.
"""
coll = self.get_collection()
if not self.convo_id:
return
inserts = self._create_inserts(summary)
if len(inserts) > 0:
documents = [entry['content'] for entry in inserts]
metadatas = [entry['metadata'] for entry in inserts]
ids = [entry['id'] for entry in inserts]
embeddings = [entry['embedding'] for entry in inserts]
coll.upsert(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
def query_plot(self, search_term):
return self.query(search_term, "plot")
def query_characters(self, search_term):
return self.query(search_term, "character")
def query(self, search_term: str, type: str) -> List[ChromaDocument]:
coll = self.get_collection()
if coll and self.convo_id:
term_embedding = self.embedding_func(search_term)
results = coll.query(
query_embeddings=[term_embedding],
include=["documents", "metadatas"],
where={
"$and": [
{ "convo_id": self.convo_id },
{ "chapter": self.chapter_id },
{ "type": type }
]
},
n_results = 5
)
# flatten out list of list of documents
# because chroma returns a List[List[Document]] for some reason.
if 'documents' in results:
docs = [
doc
for doc_list in results['documents']
for doc in doc_list
]
metadatas = [
md
for md_list in results['metadatas']
for md in md_list
]
results = []
for (doc, metadata) in zip(docs, metadatas):
results.append({ "doc": doc, "metadata": metadata })
return results
else:
return []
else:
return []
class Story(BaseModel):
"""Container for chapters. Manages an entire conversation."""
convo_id: Optional[str] = None
client: chromadb.ClientAPI
messages: List[Message]
embedding_func: EmbeddingFunc
def _collection_name(self):
return f"stories"
def create_metadata(self):
try:
coll = self.client.get_collection(self._collection_name())
if coll:
# If we have pre-specified a convo id, update metadata
# of collection accordingly.
if self.convo_id:
metadata = coll.metadata
metadata['current_convo_id'] = self.convo_id
metadata["hnsw:space"] = "cosine"
coll = self.client.get_or_create_collection(
name=self._collection_name(), metadata=metadata
)
else: # Otherwise pull it out of the database.
self.convo_id = (
coll.metadata['current_convo_id'] if 'current_convo_id' in coll.metadata else None
)
return coll.metadata
except ValueError:
return { "current_convo_id": "<unset>", "current_chapter": 1 }
def convo_state(self) -> dict:
"""Retrieve information about the current conversation."""
if not self.convo_id or self.convo_id == "<unset>":
return {}
convo_state_id = f"convo-{self.convo_id}"
coll = self.get_collection()
result = coll.get(ids=[convo_state_id], include=["metadatas"])
if len(result.metadatas) > 0:
return result.metadatas[0]
else:
# insert convo state
# TODO do something useful with convo summary
convo_summary = f"State for convo {self.convo_id}"
convo_metadata = { "current_chapter": 1 }
coll.add(
ids=[convo_state_id],
documents=[convo_summary], # maybe store convo summary here?
embeddings=self.embedding_func(convo_summary),
metadatas=[convo_metadata]
)
return convo_metadata
def switch_convo(self):
"""Force a switch of current conversation."""
if not self.convo_id:
# If we have only a user message (i.e. start of
# conversation), forcibly set to <unset>
if len(self.messages) < 2:
self.convo_id = "<unset>"
else:
# Otherwise attempt to get the cllection, which forces
# metatada creation and updates.
self.get_collection()
def get_collection(self):
"""Retrieve the collection, with its context set to the current convo ID."""
try:
coll = self.client.get_collection(self._collection_name())
if coll:
# If we have pre-specified a convo id, update metadata
# of collection accordingly.
if self.convo_id:
metadata = coll.metadata
metadata['current_convo_id'] = self.convo_id
metadata["hnsw:space"] = "cosine"
return self.client.get_or_create_collection(
name=self._collection_name(), metadata=metadata
)
else: # Otherwise pull existing convo id out of the database.
self.convo_id = (
coll.metadata['current_convo_id'] if 'current_convo_id' in coll.metadata else None
)
return coll
except ValueError:
# if the stories collection does not exist, create it
# completely from scratch.
metadata = { "current_convo_id": "<unset>", "hnsw:space": "cosine" }
return self.client.get_or_create_collection(self._collection_name(), metadata=metadata)
def _current_chapter(self) -> int:
try:
return self.convo_state()["current_chapter"]
except:
return 1
def _current_chapter_object(self) -> Chapter:
return Chapter(
convo_id = self.convo_id, chapter_id=str(self._current_chapter()),
messages=self.messages, client=self.client, embedding_func=self.embedding_func
)
def embed_summary(self, summary: SummarizerResponse):
self._current_chapter_object().embed(summary)
def query_plot(self, term: str) -> List[ChromaDocument]:
return self._current_chapter_object().query_plot(term)
def query_characters(self, term: str) -> List[ChromaDocument]:
return self._current_chapter_object().query_characters(term)
# Utils
2024-07-18 20:13:00 +00:00
class SessionInfo(BaseModel):
chat_id: str
message_id: str
session_id: str
def extract_session_info(event_emitter) -> Optional[SessionInfo]:
"""The latest innovation in hacky workarounds."""
try:
info = event_emitter.__closure__[0].cell_contents
return SessionInfo(
chat_id=info["chat_id"],
message_id=info["message_id"],
session_id=info["session_id"]
)
except:
return None
2024-07-15 14:39:54 +00:00
def create_enrichment_summary_prompt(
narrative: str,
character_details: List[str],
plot_details: List[str]
) -> str:
prompt = ENRICHMENT_SUMMARY_PROMPT
prompt += "Here are the original Character and Plot Details sections."
prompt += " Summarize them according to the instructions.\n\n"
snippets = "## Character Details:\n"
for character_detail in character_details:
snippets += f"- {character_detail}\n"
snippets = snippets.strip()
snippets += "\n"
snippets += "\n\n## Plot Details:\n"
for plot_point in plot_details:
snippets += f"- {plot_point}\n"
snippets = snippets.strip()
snippets += "\n"
snippets = snippets.strip()
prompt += snippets + "\n\n"
prompt += "Additionally, the narrative you must continue is provided below."
prompt += "\n\n-----\n\n"
prompt += narrative
return prompt.strip()
def create_context(results: SummarizerResponse) -> Optional[str]:
if not results:
return None
character_details = results.characters
plot_details = results.plot
snippets = "## Relevant Character Details:\n"
snippets += "These are relevant bits of information about characters in the story.\n"
for character_detail in character_details:
snippets += f"- {character_detail}\n"
snippets = snippets.strip()
snippets += "\n"
snippets += "\n\n## Relevant Plot Details:\n"
snippets += "These are relevant plot details that happened earlier in the story.\n"
for plot_point in plot_details:
snippets += f"- {plot_point}\n"
snippets = snippets.strip()
snippets += "\n"
message = (
"\n\nUse the following context as information about the story, inside <context></context> XML tags.\n\n"
f"<context>\n{snippets}</context>\n"
"When answering to user:\n"
"- Use the context to enhance your knowledge of the story.\n"
"- If you don't know, do not ask for clarification.\n"
"Do not mention that you obtained the information from the context.\n"
"Do not mention the context.\n"
f"Continue the story according to the user's directions."
)
return message
def split_messages(messages, keep_amount):
if len(messages) <= keep_amount:
return messages[:], []
recent_messages = messages[-keep_amount:]
old_messages = messages[:-keep_amount]
return recent_messages, old_messages
def chunk_messages(messages, chunk_size):
return [messages[i:i + chunk_size] for i in range(0, len(messages), chunk_size)]
def llm_messages_to_user_messages(messages):
return [
{'role': 'user', 'content': msg['content']}
for msg in messages if msg['role'] == 'assistant'
]
# Das Filter
class Filter:
class Valves(BaseModel):
def summarizer_model(self, body):
if self.summarizer_model_id == "":
# This will be the model ID in the convo. If not base
# model, it will cause problems.
return body["model"]
else:
return self.summarizer_model_id
summarizer_model_id: str = Field(
default="",
description="Model used to summarize the conversation. Must be a base model.",
)
n_last_messages: int = Field(
default=4, description="Number of last messages to retain."
)
pass
class UserValves(BaseModel):
pass
def __init__(self):
self.valves = self.Valves()
pass
def extract_convo_id(self, messages):
"""Extract ID of first message to use as conversation ID."""
if len(messages) > 0:
first_user_message = next(
(message for message in messages if message.get("role") == "user"), None
)
if first_user_message and 'id' in first_user_message:
return first_user_message['id']
else:
raise ValueError("No messages found to extract conversation ID")
else:
raise ValueError("No messages found to extract conversation ID")
async def summarize(self, messages) -> Optional[SummarizerResponse]:
message_to_summarize = get_last_assistant_message(messages)
if message_to_summarize:
summarizer = Summarizer(model=self.summarizer_model_id, message=message_to_summarize)
return await summarizer.summarize()
else:
return None
async def send_outlet_status(self, event_emitter, done: bool):
description = (
"Analyzing Narrative (do not reply until this is done)" if not done else
"Narrative analysis complete (you may now reply)."
)
await event_emitter({
"type": "status",
"data": {
"description": description,
"done": done,
},
})
async def set_enriching_status(self, state: str):
if not self.event_emitter:
return
done = state == "done"
description = "Enriching Narrative"
if state == "init": description = f"{description}: Initializing..."
if state == "searching": description = f"{description}: Searching..."
if state == "analyzing": description = f"{description}: Analyzing..."
description = (
description if not done else
"Enrichment Complete"
)
await self.event_emitter({
"type": "status",
"data": {
"description": description,
"done": done,
},
})
async def outlet(
self,
body: dict,
__user__: Optional[dict],
__event_emitter__: Callable[[Any], Awaitable[None]],
) -> dict:
# Useful things to have around.
2024-07-18 20:13:00 +00:00
self.session_info = extract_session_info(__event_emitter__)
2024-07-15 14:39:54 +00:00
self.event_emitter = __event_emitter__
self.summarizer_model_id = self.valves.summarizer_model(body)
await self.send_outlet_status(__event_emitter__, False)
messages = body['messages']
# summarize into plot points.
summary = await self.summarize(messages)
story = Story(
2024-07-18 20:13:00 +00:00
convo_id=self.session_info.chat_id,
client=CHROMA_CLIENT,
2024-07-15 14:39:54 +00:00
embedding_func=EMBEDDING_FUNCTION,
messages=messages
)
story.switch_convo()
if summary:
story.embed_summary(summary)
await self.send_outlet_status(__event_emitter__, True)
return body
async def generate_enrichment_queries(self, messages) -> SummarizerResponse:
last_response = get_last_assistant_message(messages)
user_input = get_last_user_message(messages)
query_message = ""
if last_response: query_message += f"## Assistant\n\n{last_response}\n\n"
if user_input: query_message += f"## User\n\n{user_input}\n\n"
query_message = query_message.strip()
summarizer = Summarizer(
model=self.summarizer_model_id,
message=query_message,
prompt=QUERY_PROMPT
)
return await summarizer.summarize()
async def summarize_enrichment(
self,
messages,
character_results: List[ChromaDocument],
plot_results: List[ChromaDocument]
) -> SummarizerResponse:
last_response = get_last_assistant_message(messages)
user_input = get_last_user_message(messages)
character_details = [r['doc'] for r in character_results]
plot_details = [r['doc'] for r in plot_results]
narrative_message = ""
if last_response: narrative_message += f"## Assistant\n\n{last_response}\n\n"
if user_input: narrative_message += f"## User\n\n{user_input}\n\n"
narrative_message = narrative_message.strip()
summarization_prompt = create_enrichment_summary_prompt(
narrative=narrative_message,
plot_details=plot_details,
character_details=character_details
)
summarizer = Summarizer(
model=self.summarizer_model_id,
message=narrative_message,
prompt=summarization_prompt
)
return await summarizer.summarize()
2024-07-18 20:13:00 +00:00
async def enrich(self, story: Story, messages) -> Optional[SummarizerResponse]:
if len(messages) < 2:
return None
2024-07-15 14:39:54 +00:00
await self.set_enriching_status("searching")
query_generation_result = await self.generate_enrichment_queries(messages)
character_results = [result
for query in query_generation_result.characters
for result in story.query_characters(query)]
plot_results = [result
for query in query_generation_result.plot
for result in story.query_plot(query)]
await self.set_enriching_status("analyzing")
return await self.summarize_enrichment(messages, character_results, plot_results)
async def update_system_message(self, messages, system_message):
story = Story(
2024-07-18 20:13:00 +00:00
convo_id=self.session_info.chat_id,
client=CHROMA_CLIENT,
2024-07-15 14:39:54 +00:00
embedding_func=EMBEDDING_FUNCTION,
messages=messages
)
story.switch_convo()
if story.convo_id == "<unset>":
return
2024-07-18 20:13:00 +00:00
enrichment_summary: Optional[SummarizerResponse] = await self.enrich(story, messages)
if enrichment_summary:
context = create_context(enrichment_summary)
else:
context = None
2024-07-15 14:39:54 +00:00
if context:
system_message["content"] += context
async def inlet(
self,
body: dict,
__user__: Optional[dict],
__event_emitter__: Callable[[Any], Awaitable[None]]
) -> dict:
# Useful properties to have around.
2024-07-18 20:13:00 +00:00
self.session_info = extract_session_info(__event_emitter__)
2024-07-15 14:39:54 +00:00
self.event_emitter = __event_emitter__
self.summarizer_model_id = self.valves.summarizer_model(body)
2024-07-18 20:13:00 +00:00
2024-07-15 14:39:54 +00:00
await self.set_enriching_status("init")
messages = body["messages"]
# Ensure we always keep the system prompt
system_prompt = next(
(message for message in messages if message.get("role") == "system"), None
)
if system_prompt:
all_messages = [
message for message in messages if message.get("role") != "system"
]
recent_messages, old_messages = split_messages(all_messages, self.valves.n_last_messages)
most_recent_messages = messages[-self.valves.n_last_messages :]
else:
system_prompt = { "id": str(uuid.uuid4()), "role": "system", "content": "" }
recent_messages, old_messages = split_messages(messages, self.valves.n_last_messages)
await self.update_system_message(messages, system_prompt)
recent_messages.insert(0, system_prompt)
body["messages"] = recent_messages
await self.set_enriching_status("done")
return body