# Source code for steamship.agents.tools.classification.zero_shot_classifier_tool
from typing import List, Optional
from steamship import Steamship
from steamship.agents.llms import OpenAI
from steamship.agents.tools.text_generation.text_rewrite_tool import TextRewritingTool
from steamship.agents.utils import with_llm
from steamship.utils.repl import ToolREPL
# Labels offered to the LLM when the caller does not supply its own.
DEFAULT_LABELS = ["question", "complaint", "suggestion"]

# Two-stage prompt template:
#   * `{labels}` is substituted once, at tool construction time (via str.format).
#   * `{{input}}` is brace-escaped so that first format() pass turns it into a
#     literal `{input}` placeholder for the later rewriting step.
DEFAULT_PROMPT = (
    "Instructions:\n"
    "Please classify the following message into one of the following output labels. "
    "Respond with exactly one and only one label.\n"
    "Labels:\n"
    "{labels}\n"
    "Passage:\n"
    "{{input}}\n"
    "Label describing passage:"
)
# [docs]
class ZeroShotClassifierTool(TextRewritingTool):
    """
    Example tool to illustrate how one might classify a user message.

    For example: the agent may wish to know if the user message was a question, complaint, or suggestion.

    TODO: This feels like it wants to emit data to a side channel. Or perhaps it TAGS the user input block?
    """

    name = "ZeroShotClassifierTool"
    human_description = "Classifies a user message."
    agent_description = "Used to classify a user message. The input is a string, and the output is a string with the classification label."
    # Candidate labels the prompt asks the LLM to choose between.
    labels: List[str] = DEFAULT_LABELS
    # Prompt template. After __init__ runs, this holds the template with `{labels}`
    # already substituted; the remaining `{input}` placeholder is presumably filled
    # in by TextRewritingTool (not visible here — confirm against the base class).
    rewrite_prompt: str = DEFAULT_PROMPT

    def __init__(
        self, labels: Optional[List[str]] = None, rewrite_prompt: Optional[str] = None, **kwargs
    ):
        """Build a classifier tool for a fixed set of labels.

        Args:
            labels: Labels the LLM may choose from. ``None`` or an empty list falls
                back to ``DEFAULT_LABELS``. The chosen list is stored on the
                instance as ``labels``.
            rewrite_prompt: Template containing a ``{labels}`` placeholder; defaults
                to ``DEFAULT_PROMPT``. It is passed through ``str.format``, so any
                other literal braces — including the input placeholder — must be
                doubled (``{{input}}``).
            **kwargs: Forwarded unchanged to ``TextRewritingTool.__init__``.

        Raises:
            KeyError / IndexError: from ``str.format`` if ``rewrite_prompt`` contains
                unescaped placeholders other than ``{labels}``.
        """
        # Copy so instances never share (or mutate) the module-level default list.
        chosen_labels = list(labels or DEFAULT_LABELS)
        template = rewrite_prompt or DEFAULT_PROMPT

        # Persist the labels actually used; previously they were only baked into
        # the prompt and `self.labels` always reported DEFAULT_LABELS.
        kwargs["labels"] = chosen_labels
        # `rewrite_prompt` is a named parameter, so it can never also be in kwargs;
        # always store the label-substituted template.
        kwargs["rewrite_prompt"] = template.format(labels=chosen_labels)
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Manual smoke test: run the classifier in an interactive REPL inside a
    # throwaway Steamship workspace, backed by an OpenAI LLM.
    classifier = ZeroShotClassifierTool()
    with Steamship.temporary_workspace() as client:
        repl = ToolREPL(classifier)
        llm_context = with_llm(llm=OpenAI(client=client))
        repl.run_with_client(client=client, context=llm_context)