from __future__ import annotations
import logging
import uuid
from logging import StreamHandler
from typing import Dict, List, Optional, Union, cast
from fluent.handler import FluentRecordFormatter
from steamship import File, MimeTypes, Steamship, SteamshipError, Task
from steamship.agents.logging import LOGGING_FORMAT, AgentLogging, StreamingOpts
from steamship.agents.schema.message_selectors import MessageSelector
from steamship.agents.schema.text_splitters import FixedSizeTextSplitter, TextSplitter
from steamship.base.client import Client
from steamship.data import TagKind
from steamship.data.block import Block
from steamship.data.plugin.index_plugin_instance import EmbeddingIndexPluginInstance, SearchResults
from steamship.data.tags import Tag
from steamship.data.tags.tag_constants import ChatTag, DocTag, RoleTag, TagValueKey
class ChatHistory:
"""A ChatHistory is a wrapper of a File ideal for ongoing interactions between a user and a virtual assistant.
It also includes vector-backed storage for similarity-based retrieval."""
file: File
embedding_index: EmbeddingIndexPluginInstance
text_splitter: TextSplitter
def __init__(
self,
file: File,
embedding_index: Optional[EmbeddingIndexPluginInstance],
text_splitter: TextSplitter = None,
):
"""This init method is intended only for private use within the class. See `Chat.create()`"""
self.file = file
self.embedding_index = embedding_index
if text_splitter is not None:
self.text_splitter = text_splitter
else:
self.text_splitter = FixedSizeTextSplitter(chunk_size=300)
@staticmethod
def _get_existing_file(client: Client, context_keys: Dict[str, str]) -> Optional[File]:
"""Find an existing File object whose memory Tag matches the passed memory keys"""
file_query = (
f'kind "{TagKind.CHAT}" and name "{ChatTag.CONTEXT_KEYS}"'
+ (" and " if len(context_keys) > 0 else "")
+ " and ".join([f'value("{key}") = "{value}"' for key, value in context_keys.items()])
)
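        # For example, with context_keys={"chat_id": "42"} (an illustrative key), the query
        # reads roughly:
        #   kind "<TagKind.CHAT>" and name "<ChatTag.CONTEXT_KEYS>" and value("chat_id") = "42"
        # i.e. every context key/value pair must match the CONTEXT_KEYS tag on the File.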
file = File.query(client, file_query)
if len(file.files) == 1:
return file.files[0]
elif len(file.files) == 0:
return None
else:
raise SteamshipError(
"Multiple ChatHistory objects have been created in this workspace with these memory keys."
)
@staticmethod
def _get_index_handle_from_file(file: File) -> str:
for tag in file.tags:
if tag.kind == TagKind.CHAT and tag.name == ChatTag.INDEX_HANDLE:
return tag.value[TagValueKey.STRING_VALUE]
raise SteamshipError(f"Could not find index handle on file with id {file.id}")
@staticmethod
def _get_embedding_index(client: Steamship, index_handle: str) -> EmbeddingIndexPluginInstance:
return cast(
EmbeddingIndexPluginInstance,
client.use_plugin(
plugin_handle="embedding-index",
instance_handle=index_handle,
config={
"embedder": {
"plugin_handle": "openai-embedder",
"plugin_instance-handle": "text-embedding-ada-002",
"fetch_if_exists": True,
"config": {"model": "text-embedding-ada-002", "dimensionality": 1536},
}
},
fetch_if_exists=True,
),
)
@staticmethod
def get_or_create(
client: Steamship,
context_keys: Dict[str, str],
tags: List[Tag] = None,
searchable: bool = True,
) -> ChatHistory:
file = ChatHistory._get_existing_file(client, context_keys)
if file is None:
tags = tags or []
index_handle = str(uuid.uuid4())
tags.append(Tag(kind=TagKind.DOCUMENT, name=DocTag.CHAT)) # This is a Chat-related tag
tags.append(Tag(kind=TagKind.CHAT, name=ChatTag.HISTORY)) # This is a ChatHistory file
tags.append(Tag(kind=TagKind.CHAT, name=ChatTag.CONTEXT_KEYS, value=context_keys))
tags.append(
Tag(
kind=TagKind.CHAT,
name=ChatTag.INDEX_HANDLE,
value={TagValueKey.STRING_VALUE: index_handle},
)
)
blocks = []
file = File.create(
client=client,
blocks=blocks,
tags=tags,
)
else:
index_handle = ChatHistory._get_index_handle_from_file(file)
if searchable:
embedding_index = ChatHistory._get_embedding_index(client, index_handle)
else:
embedding_index = None
return ChatHistory(file, embedding_index)
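    # Example usage (a sketch; the workspace and context keys below are illustrative):
    #
    #   client = Steamship(workspace="my-workspace")
    #   history = ChatHistory.get_or_create(client, context_keys={"chat_id": "chat-123"})
    #
    # Calling get_or_create again with the same context_keys returns the same underlying File.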
def append_message_with_role(
self,
text: str = None,
role: RoleTag = RoleTag.USER,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with content provided by the end-user."""
tags = tags or []
tags.append(
Tag(kind=TagKind.CHAT, name=ChatTag.ROLE, value={TagValueKey.STRING_VALUE: role})
)
tags.append(Tag(kind=TagKind.CHAT, name=ChatTag.MESSAGE))
block = self.file.append_block(
text=text, tags=tags, content=content, url=url, mime_type=mime_type
)
# don't index status messages
if self.embedding_index is not None and role not in [
RoleTag.AGENT,
RoleTag.TOOL,
RoleTag.LLM,
]:
chunk_tags = self.text_splitter.chunk_text_to_tags(
block, kind=TagKind.CHAT, name=ChatTag.CHUNK
)
block.tags.extend(chunk_tags)
# Only embed tags that aren't empty space.
non_empty_chunk_tags = [tag for tag in chunk_tags if tag.text.strip()]
self.embedding_index.insert(non_empty_chunk_tags)
return block
def append_user_message(
self,
text: str = None,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with content provided by the end-user."""
return self.append_message_with_role(text, RoleTag.USER, tags, content, url, mime_type)
def append_system_message(
self,
text: str = None,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with content provided by the system, i.e., instructions to the assistant."""
return self.append_message_with_role(text, RoleTag.SYSTEM, tags, content, url, mime_type)
def append_assistant_message(
self,
text: str = None,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with content provided by the agent, i.e., results from the assistant."""
return self.append_message_with_role(text, RoleTag.ASSISTANT, tags, content, url, mime_type)
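    # Example append sequence (a sketch; the texts are illustrative). User, system, and
    # assistant messages are chunked and inserted into the embedding index (when one is
    # attached); agent/tool/LLM status messages are not:
    #
    #   history.append_system_message(text="You are a helpful assistant.")
    #   history.append_user_message(text="What's the weather like?")
    #   history.append_assistant_message(text="Let me check that for you.")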
@property
def last_user_message(self) -> Optional[Block]:
for block in self.file.blocks[::-1]:
if block.chat_role == RoleTag.USER:
return block
return None
@property
def last_system_message(self) -> Optional[Block]:
for block in self.file.blocks[::-1]:
if block.chat_role == RoleTag.SYSTEM:
return block
return None
@property
def last_agent_message(self) -> Optional[Block]:
for block in self.file.blocks[::-1]:
if block.chat_role == RoleTag.ASSISTANT:
return block
return None
@property
def initial_system_prompt(self) -> Optional[Block]:
if len(self.file.blocks) > 0 and self.file.blocks[0].chat_role == RoleTag.SYSTEM:
return self.file.blocks[0]
else:
return None
def refresh(self):
self.file.refresh()
@property
def tags(self) -> List[Tag]:
return self.file.tags
@property
def messages(self) -> List[Block]:
return self.file.blocks
@property
def client(self) -> Client:
return self.file.client
def select_messages(self, selector: MessageSelector) -> List[Block]:
return selector.get_messages(self.messages)
    def search(self, text: str, k: Optional[int] = None) -> Task[SearchResults]:
if len(text.strip()) == 0:
return Task(output=SearchResults(), state="succeeded")
if self.embedding_index is None:
raise SteamshipError("This ChatHistory has no embedding index and is not searchable.")
return self.embedding_index.search(text, k)
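    # Example (a sketch; assumes the history was created with searchable=True):
    #
    #   task = history.search("weather", k=3)
    #   task.wait()
    #   results = task.output  # a SearchResults over previously indexed chat chunks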
def is_searchable(self) -> bool:
return self.embedding_index is not None
def delete_messages(self, selector: MessageSelector):
"""Delete a set of selected messages from the ChatHistory.
If `selector == None`, no messages will be deleted.
NOTES:
- upon deletion, refresh() is called to ensure up-to-date history refs.
- causes a full re-index of chat history if the history is searchable.
"""
if selector:
selected_messages = selector.get_messages(self.messages)
for msg in selected_messages:
msg.delete()
self.refresh()
if self.is_searchable():
self.embedding_index.reset()
for msg in self.messages:
for tag in msg.tags:
if tag.kind == TagKind.CHAT and tag.name == ChatTag.CHUNK:
# TODO(dougreid): figure out why tag.text gets lost.
if not tag.text:
tag.text = msg.text[tag.start_idx : tag.end_idx]
# Only embed it if we've managed to generate a string representation.
if tag.text and tag.text.strip():
self.embedding_index.insert(tag)
self.refresh()
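    # Example (a sketch): any MessageSelector works here; the messages it returns are
    # deleted, and the index (if present) is rebuilt from whatever remains. The selector
    # below is assumed from steamship.agents.schema.message_selectors.
    #
    #   from steamship.agents.schema.message_selectors import MessageWindowMessageSelector
    #   history.delete_messages(MessageWindowMessageSelector(k=2))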
def clear(self):
"""Deletes ALL messages from the ChatHistory (including system).
NOTE: upon deletion, refresh() is called to ensure up-to-date history refs.
"""
for block in self.file.blocks:
block.delete()
if self.is_searchable():
self.embedding_index.reset()
self.refresh()
def append_status_message_with_role(
self,
text: str = None,
role: RoleTag = RoleTag.USER,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with content provided by the end-user."""
tags = tags or []
tags.append(
Tag(
kind=TagKind.STATUS_MESSAGE,
name=ChatTag.ROLE,
value={TagValueKey.STRING_VALUE: role},
)
)
return self.file.append_block(
text=text, tags=tags, content=content, url=url, mime_type=mime_type
)
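    # Note: unlike append_message_with_role, status messages are tagged with
    # TagKind.STATUS_MESSAGE rather than ChatTag.MESSAGE and are never inserted into the
    # embedding index.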
def append_agent_message(
self,
text: str = None,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with status update messages from the Agent."""
return self.append_status_message_with_role(
text, RoleTag.AGENT, tags, content, url, mime_type
)
def append_tool_message(
self,
text: str = None,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with status update messages from the Agent."""
return self.append_status_message_with_role(
text, RoleTag.TOOL, tags, content, url, mime_type
)
def append_llm_message(
self,
text: str = None,
tags: List[Tag] = None,
content: Union[str, bytes] = None,
url: Optional[str] = None,
mime_type: Optional[MimeTypes] = None,
) -> Block:
"""Append a new block to this with status update messages from the Agent."""
return self.append_status_message_with_role(
text, RoleTag.LLM, tags, content, url, mime_type
)
def append_request_complete_message(
self,
) -> Block:
"""Append a new block to this with status update messages from the Agent."""
tags = [
Tag(kind=TagKind.AGENT_STATUS_MESSAGE, name=ChatTag.REQUEST_COMPLETE),
]
return self.append_status_message_with_role("", RoleTag.AGENT, tags, None, None, None)
class ChatHistoryLoggingHandler(StreamHandler):
"""Logs messages emitted by Agents and Tools into a ChatHistory file.
This is a basic mechanism for streaming status messages alongside generated content.
"""
chat_history: ChatHistory
    log_level: int
streaming_opts: StreamingOpts
def __init__(
self,
chat_history: ChatHistory,
        log_level: int = logging.INFO,
streaming_opts: Optional[StreamingOpts] = None,
):
StreamHandler.__init__(self)
formatter = FluentRecordFormatter(LOGGING_FORMAT, fill_missing_fmt_key=True)
self.setFormatter(formatter)
self.chat_history = chat_history
self.log_level = log_level
if streaming_opts is not None:
self.streaming_opts = streaming_opts
else:
self.streaming_opts = StreamingOpts()
def emit(self, record):
if record.levelno < self.log_level:
# don't bother doing anything if level is below logging level
return
message_dict = cast(dict, self.format(record))
is_agent_message = message_dict.get(AgentLogging.MESSAGE_AUTHOR, None) == AgentLogging.AGENT
if self.streaming_opts.include_agent_messages and is_agent_message:
return self._append_message(message_dict, AgentLogging.AGENT)
is_tool_message = message_dict.get(AgentLogging.MESSAGE_AUTHOR, None) == AgentLogging.TOOL
if self.streaming_opts.include_tool_messages and is_tool_message:
return self._append_message(message_dict, AgentLogging.TOOL)
is_llm_message = message_dict.get(AgentLogging.MESSAGE_AUTHOR, None) == AgentLogging.LLM
if self.streaming_opts.include_llm_messages and is_llm_message:
return self._append_message(message_dict, AgentLogging.LLM)
def _append_message(self, message_dict: dict, author_kind: str):
message = message_dict.get("message", None)
message_type = message_dict.get(AgentLogging.MESSAGE_TYPE, AgentLogging.MESSAGE)
if author_kind == AgentLogging.AGENT:
return self.chat_history.append_agent_message(
text=message,
mime_type=MimeTypes.TXT,
)
elif author_kind == AgentLogging.TOOL:
tool_name = message_dict.get(AgentLogging.TOOL_NAME, AgentLogging.TOOL)
return self.chat_history.append_tool_message(
text=message,
tags=[
Tag(
kind=TagKind.TOOL_STATUS_MESSAGE,
name=message_type,
value={TagValueKey.STRING_VALUE: message, "tool": tool_name},
),
],
mime_type=MimeTypes.TXT,
)
elif author_kind == AgentLogging.LLM:
llm_name = message_dict.get(AgentLogging.LLM_NAME, AgentLogging.LLM)
return self.chat_history.append_llm_message(
text=message,
tags=[
Tag(
kind=TagKind.LLM_STATUS_MESSAGE,
name=message_type,
value={TagValueKey.STRING_VALUE: message, "llm": llm_name},
),
],
mime_type=MimeTypes.TXT,
)
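# Example wiring (a sketch; the logger name and StreamingOpts fields are illustrative).
# Attaching the handler streams agent/tool/LLM status logs into the same ChatHistory
# that holds the conversation:
#
#   handler = ChatHistoryLoggingHandler(
#       chat_history=history,
#       log_level=logging.INFO,
#       streaming_opts=StreamingOpts(include_agent_messages=True),
#   )
#   logging.getLogger("steamship.agents").addHandler(handler)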