import json
import re
import string
from json import JSONDecodeError
from typing import Dict, List, Optional
from steamship import Block, MimeTypes, Steamship, Tag
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.utils.utils import is_valid_uuid4
def is_punctuation(text: str) -> bool:
    """Return True when every character of ``text`` is ASCII punctuation.

    Membership is tested against ``string.punctuation``. An empty string
    yields True because there is no character that fails the check.

    Args:
        text: The string to inspect.

    Returns:
        False as soon as one character outside ``string.punctuation`` is
        found (whitespace, letters, digits, ...), otherwise True.
    """
    return all(char in string.punctuation for char in text)
class FunctionsBasedOutputParser(OutputParser):
    """Turn raw LLM output into an agent ``Action``.

    Output that is a JSON object carrying a ``function_call`` entry becomes an
    ``Action`` invoking the named tool; everything else becomes a
    ``FinishAction`` whose output blocks are built from the text, with any
    referenced block ids resolved through ``Block.get``.
    """

    # Maps tool name -> tool; built in __init__ from the ``tools`` keyword.
    tools_lookup_dict: Optional[Dict[str, Tool]] = None

    def __init__(self, **kwargs):
        """Build the tool lookup table from the optional ``tools`` keyword.

        Args:
            **kwargs: Forwarded to the parent initializer after ``tools`` has
                been removed. ``tools`` may be omitted or None (no tools).
        """
        tools = kwargs.pop("tools", None) or []
        tools_lookup_dict = {tool.name: tool for tool in tools}
        super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs)

    @staticmethod
    def _text_argument_block(value: str) -> Block:
        """Wrap a plain-text function argument in a new tagged text block."""
        return Block(
            text=value,
            tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
            mime_type=MimeTypes.TXT,
        )

    @staticmethod
    def _uuid_argument_block(client: Steamship, block_id: str) -> Block:
        """Fetch an existing block by id and tag it as a ``uuid`` function argument."""
        existing_block = Block.get(client, _id=block_id)
        tag = Tag.create(
            existing_block.client,
            file_id=existing_block.file_id,
            block_id=existing_block.id,
            kind=TagKind.FUNCTION_ARG,
            name="uuid",
        )
        existing_block.tags.append(tag)
        return existing_block

    def _blocks_from_arguments(self, arguments, context: AgentContext) -> List[Block]:
        """Convert the ``arguments`` value of a function call into input blocks.

        ``arguments`` is normally a JSON-encoded object with a ``text`` or
        ``uuid`` key, but a bare (non-object) value is accepted as well and is
        treated as a block id when it is a valid UUID4, otherwise as text.
        """
        if not arguments:
            return []

        parsed = arguments
        if isinstance(arguments, str):
            try:
                parsed = json.loads(arguments)
            except JSONDecodeError:
                # Not JSON at all: the model sent the bare argument value.
                parsed = arguments
            else:
                if not isinstance(parsed, (dict, str)):
                    # A bare JSON scalar/array (e.g. "42"): keep the raw text
                    # instead of failing on a missing ``.get``.
                    parsed = arguments

        if isinstance(parsed, dict):
            if text_arg := parsed.get("text"):
                return [self._text_argument_block(text_arg)]
            if uuid_arg := parsed.get("uuid"):
                return [self._uuid_argument_block(context.client, uuid_arg)]
            return []

        if isinstance(parsed, str):
            if is_valid_uuid4(parsed):
                return [self._uuid_argument_block(context.client, parsed)]
            return [self._text_argument_block(parsed)]

        return []

    def _extract_action_from_function_call(self, text: str, context: AgentContext) -> Action:
        """Build the tool-invoking ``Action`` described by a function-call payload.

        Args:
            text: JSON text expected to be an object with a ``function_call``
                object holding ``name`` and, optionally, ``arguments``.
            context: Agent context; its client is used to fetch referenced blocks.

        Returns:
            An ``Action`` for the named tool with the decoded input blocks.

        Raises:
            JSONDecodeError: If ``text`` is not valid JSON, or is valid JSON
                that does not contain a ``function_call`` object. ``parse``
                treats both cases as a regular message.
            RuntimeError: If the named tool is not registered.
        """
        wrapper = json.loads(text)
        fc = wrapper.get("function_call") if isinstance(wrapper, dict) else None
        if not isinstance(fc, dict):
            # The substring "function_call" appeared in the text, but not as a
            # usable payload; signal it the same way as undecodable JSON.
            raise JSONDecodeError("No `function_call` object found in output", text, 0)

        name = fc.get("name") or ""
        if name.startswith("functions."):
            name = name[len("functions.") :]  # occasionally, OpenAI prepends "functions."

        tool = self.tools_lookup_dict.get(name, None)
        if tool is None:
            raise RuntimeError(
                f"Could not find tool from function call: `{name}`. Known tools: {self.tools_lookup_dict.keys()}"
            )

        input_blocks = self._blocks_from_arguments(fc.get("arguments"), context)
        return Action(tool=tool.name, input=input_blocks, context=context)

    @staticmethod
    def _blocks_from_text(client: Steamship, text: str) -> List[Block]:
        """Split a textual response into text blocks and referenced existing blocks.

        Only the part after the last ``AI:`` marker is considered. Each block
        id found (optionally wrapped like ``[Block(<id>)]``) is replaced by the
        block fetched with ``Block.get``; the text around it becomes new text
        blocks. Leftover text made only of punctuation is discarded.
        """
        last_response = text.split("AI:")[-1].strip()

        block_id_regex = r"(?:(?:\[|\()?Block)?\(?([A-F0-9]{8}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{12})\)?(?:(\]|\)))?"
        remaining_text = last_response
        result_blocks: List[Block] = []

        while remaining_text.strip():
            if is_punctuation(remaining_text.strip()):
                # Only stray punctuation (e.g. a trailing ".") is left over.
                break

            match = re.search(block_id_regex, remaining_text)
            if match is None:
                result_blocks.append(Block(text=remaining_text))
                break

            pre_block_text = FunctionsBasedOutputParser._remove_block_prefix(
                candidate=remaining_text[0 : match.start()]
            )
            if len(pre_block_text) > 0:
                result_blocks.append(Block(text=pre_block_text))
            result_blocks.append(Block.get(client, _id=match.group(1)))
            remaining_text = FunctionsBasedOutputParser._remove_block_suffix(
                remaining_text[match.end() :]
            )

        return result_blocks

    @staticmethod
    def _remove_block_prefix(candidate: str) -> str:
        """Strip a trailing ``(Block``, ``[Block`` or ``Block`` marker from ``candidate``.

        ``candidate`` is the text that precedes a block id, so the marker (if
        any) sits at its end and is removed from there.
        """
        for marker in ("(Block", "[Block", "Block"):
            if candidate.endswith(marker):
                return candidate[: -len(marker)]
        return candidate

    @staticmethod
    def _remove_block_suffix(candidate: str) -> str:
        """Strip one leading ``)`` or ``]`` from ``candidate``.

        ``candidate`` is the text that follows a block id, so a leftover
        closing bracket (if any) is its first character.
        """
        if candidate.startswith((")", "]")):
            return candidate[1:]
        return candidate

    def parse(self, text: str, context: AgentContext) -> Action:
        """Parse LLM output into the next ``Action``.

        Args:
            text: Raw LLM output.
            context: Agent context supplying the client and request id.

        Returns:
            A tool ``Action`` when ``text`` is a decodable function call,
            otherwise a ``FinishAction`` whose blocks carry the assistant
            chat role and the context's request id.
        """
        if "function_call" in text:
            try:
                # catch invalid JSON. If it is not valid JSON, just treat as "regular" message
                return self._extract_action_from_function_call(text, context)
            except JSONDecodeError:
                pass

        finish_blocks = FunctionsBasedOutputParser._blocks_from_text(context.client, text)
        for finish_block in finish_blocks:
            finish_block.set_chat_role(RoleTag.ASSISTANT)
            finish_block.set_request_id(context.request_id)
        return FinishAction(output=finish_blocks, context=context)