# Source code for steamship.agents.functional.output_parser

import json
import re
import string
from json import JSONDecodeError
from typing import Dict, List, Optional

from steamship import Block, MimeTypes, Steamship, Tag
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.utils.utils import is_valid_uuid4


def is_punctuation(text: str):
    """Report whether ``text`` consists solely of ASCII punctuation.

    Each character is checked for membership in ``string.punctuation``.
    Whitespace and letters make the result False. An empty string yields
    True, because it contains no non-punctuation character.
    """
    return all(char in string.punctuation for char in text)
# Matches a Block id (uppercase hex UUID layout), optionally wrapped in a
# leading "Block", "(Block" or "[Block" and surrounding parentheses/brackets.
# Group 1 is the bare id.
_BLOCK_ID_PATTERN = re.compile(
    r"(?:(?:\[|\()?Block)?\(?([A-F0-9]{8}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{12})\)?(?:(\]|\)))?"
)


class FunctionsBasedOutputParser(OutputParser):
    """Turn the output of a function-calling LLM into an agent Action.

    If the output is a JSON object with a ``function_call`` object, it becomes
    an :class:`Action` for the named tool. Any other output is treated as the
    final assistant reply and becomes a :class:`FinishAction`. Block ids
    mentioned in that reply are fetched as Blocks and interleaved with the
    surrounding text.
    """

    # Tool name -> Tool, built in __init__ from the ``tools`` keyword argument.
    tools_lookup_dict: Optional[Dict[str, Tool]] = None

    def __init__(self, **kwargs):
        # ``tools`` is consumed here (not forwarded to the base class) and
        # converted into the name-keyed lookup. ``tools=None`` is treated as "no tools".
        tools = kwargs.pop("tools", None) or []
        tools_lookup_dict = {tool.name: tool for tool in tools}
        super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs)

    def _extract_action_from_function_call(self, text: str, context: AgentContext) -> Action:
        """Build an Action from a JSON ``{"function_call": {...}}`` payload.

        Args:
            text: Raw LLM output expected to be a JSON object with a
                ``function_call`` object holding ``name`` and optional ``arguments``.
            context: Agent context; its client is used to fetch Blocks
                referenced by uuid.

        Returns:
            An Action targeting the named tool, with input Blocks derived
            from the call's arguments.

        Raises:
            JSONDecodeError: ``text`` is not valid JSON, or is valid JSON but
                not shaped like a function call. ``parse`` treats both cases
                as a regular message.
            RuntimeError: The named tool is not in ``tools_lookup_dict``.
        """
        wrapper = json.loads(text)
        function_call = wrapper.get("function_call") if isinstance(wrapper, dict) else None
        if not isinstance(function_call, dict):
            # Valid JSON that merely mentions "function_call" (e.g. a list or
            # string) is not a function call; signal that the same way as invalid JSON.
            raise JSONDecodeError("Expected a JSON object with a `function_call` object", text, 0)

        name = function_call.get("name") or ""
        if name.startswith("functions."):
            name = name[len("functions.") :]  # occasionally, OpenAI prepends "functions."

        tools_lookup = self.tools_lookup_dict or {}
        tool = tools_lookup.get(name, None)
        if tool is None:
            raise RuntimeError(
                f"Could not find tool from function call: `{name}`. Known tools: {tools_lookup.keys()}"
            )

        input_blocks = self._input_blocks_from_arguments(function_call.get("arguments"), context)
        return Action(tool=tool.name, input=input_blocks, context=context)

    def _input_blocks_from_arguments(self, arguments, context: AgentContext) -> List[Block]:
        """Convert a function call's ``arguments`` value into input Blocks.

        Handles, in order:
          * a JSON object (string-encoded, or an already-decoded dict): its
            ``text`` key yields a new text Block; otherwise its ``uuid`` key
            yields the existing Block with that id;
          * a bare string (non-JSON, or a JSON string literal): a uuid4 yields
            the existing Block, anything else a text Block;
          * any other valid JSON (number, list, ...): the raw argument text
            is passed through as a text Block.
        Empty or missing arguments yield no Blocks.
        """
        if not arguments:
            return []

        if isinstance(arguments, dict):
            parsed = arguments
        else:
            try:
                parsed = json.loads(arguments)
            except JSONDecodeError:
                if not isinstance(arguments, str):
                    return []
                parsed = arguments  # not JSON: use the raw string as the argument

        if isinstance(parsed, dict):
            if text_arg := parsed.get("text"):
                return [self._text_argument_block(text_arg)]
            if uuid_arg := parsed.get("uuid"):
                return [self._tagged_existing_block(uuid_arg, context)]
            return []

        if isinstance(parsed, str):
            if is_valid_uuid4(parsed):
                return [self._tagged_existing_block(parsed, context)]
            return [self._text_argument_block(parsed)]

        if isinstance(arguments, str):
            return [self._text_argument_block(arguments)]
        return []

    @staticmethod
    def _text_argument_block(text: str) -> Block:
        """Create a new plain-text Block tagged as the ``text`` function argument."""
        return Block(
            text=text,
            tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
            mime_type=MimeTypes.TXT,
        )

    @staticmethod
    def _tagged_existing_block(block_id: str, context: AgentContext) -> Block:
        """Fetch Block ``block_id`` and attach a ``uuid`` FUNCTION_ARG tag to it.

        The tag is created via ``Tag.create`` against the Block's file/id and
        also appended to the local ``tags`` list of the returned Block.
        """
        existing_block = Block.get(context.client, _id=block_id)
        tag = Tag.create(
            existing_block.client,
            file_id=existing_block.file_id,
            block_id=existing_block.id,
            kind=TagKind.FUNCTION_ARG,
            name="uuid",
        )
        existing_block.tags.append(tag)
        return existing_block

    @staticmethod
    def _blocks_from_text(client: Steamship, text: str) -> List[Block]:
        """Split the final reply into text Blocks and referenced existing Blocks.

        Only the text after the last ``"AI:"`` marker is considered. Each
        Block id found (see ``_BLOCK_ID_PATTERN``) is fetched with
        ``Block.get``. The text between ids becomes new text Blocks.
        Whitespace-only gaps and a punctuation-only tail are dropped.
        """
        remaining_text = text.split("AI:")[-1].strip()
        result_blocks: List[Block] = []

        while remaining_text and remaining_text.strip():
            if is_punctuation(remaining_text.strip()):
                break  # trailing punctuation alone is not worth a Block

            match = _BLOCK_ID_PATTERN.search(remaining_text)
            if match is None:
                result_blocks.append(Block(text=remaining_text))
                break

            pre_block_text = FunctionsBasedOutputParser._remove_block_prefix(
                candidate=remaining_text[: match.start()]
            )
            # Skip whitespace-only gaps, e.g. between two adjacent block references.
            if pre_block_text.strip():
                result_blocks.append(Block(text=pre_block_text))
            result_blocks.append(Block.get(client, _id=match.group(1)))
            remaining_text = FunctionsBasedOutputParser._remove_block_suffix(
                remaining_text[match.end() :]
            )

        return result_blocks

    @staticmethod
    def _remove_block_prefix(candidate: str) -> str:
        """Strip a trailing ``(Block``, ``[Block`` or ``Block`` marker from ``candidate``.

        This is the text that precedes a Block id, so the marker sits at its END.
        """
        # Longer markers first so "(Block" is not reduced to a dangling "(".
        for marker in ("(Block", "[Block", "Block"):
            if candidate.endswith(marker):
                return candidate[: -len(marker)]
        return candidate

    @staticmethod
    def _remove_block_suffix(candidate: str) -> str:
        """Strip one leading ``)`` or ``]`` left over after a Block id."""
        if candidate.startswith((")", "]")):
            return candidate[1:]
        return candidate

    def parse(self, text: str, context: AgentContext) -> Action:
        """Parse LLM output into a tool Action or a FinishAction.

        Output mentioning ``function_call`` is first tried as a function call.
        If that JSON is invalid or mis-shaped, or the output never mentioned
        ``function_call``, the output becomes the final reply. Every resulting
        Block gets the assistant chat role and the context's request id.
        """
        if "function_call" in text:
            try:
                return self._extract_action_from_function_call(text, context)
            except JSONDecodeError:
                pass  # not a (valid) function call: treat as a "regular" message

        finish_blocks = FunctionsBasedOutputParser._blocks_from_text(context.client, text)
        for finish_block in finish_blocks:
            finish_block.set_chat_role(RoleTag.ASSISTANT)
            finish_block.set_request_id(context.request_id)
        return FinishAction(output=finish_blocks, context=context)