Source code for steamship.plugin.generator

import logging
import uuid
from abc import ABC, abstractmethod

from steamship import SteamshipError
from steamship.data.block import Block, BlockUploadType
from steamship.data.workspace import SignedUrl
from steamship.invocable import InvocableResponse, post
from steamship.invocable.plugin_service import PluginRequest, PluginService, TrainablePluginService
from steamship.plugin.inputs.raw_block_and_tag_plugin_input import RawBlockAndTagPluginInput
from steamship.plugin.inputs.train_plugin_input import TrainPluginInput
from steamship.plugin.inputs.training_parameter_plugin_input import TrainingParameterPluginInput
from steamship.plugin.outputs.raw_block_and_tag_plugin_output import RawBlockAndTagPluginOutput
from steamship.plugin.outputs.train_plugin_output import TrainPluginOutput
from steamship.plugin.outputs.training_parameter_plugin_output import TrainingParameterPluginOutput
from steamship.plugin.trainable_model import TrainableModel

# Note!
# =====
#
# This is the PLUGIN IMPLEMENTOR's View of a Generator.
#
# If you are using the Steamship Client, you probably want steamship.client.operations.generator instead
# of this file.
#
from steamship.utils.signed_urls import upload_to_signed_url


class Generator(PluginService[RawBlockAndTagPluginInput, RawBlockAndTagPluginOutput], ABC):
    """Plugin implementor's base class for a Generator.

    Subclasses implement `run`. `run_endpoint` exposes it over HTTP and converts
    any FILE-upload blocks in the result into URL blocks before returning.
    """

    @abstractmethod
    def run(
        self, request: PluginRequest[RawBlockAndTagPluginInput]
    ) -> InvocableResponse[RawBlockAndTagPluginOutput]:
        """Run the generator on `request`. Must be implemented by subclasses."""
        raise NotImplementedError()

    @post("generate")
    def run_endpoint(self, **kwargs) -> InvocableResponse[RawBlockAndTagPluginOutput]:
        """Exposes the Generator's `run` operation to the Steamship Engine via the expected HTTP path POST /generate"""
        plugin_request = PluginRequest[RawBlockAndTagPluginInput].parse_obj(kwargs)

        # Attach this service's client to every incoming block before handing them to `run`.
        for block in plugin_request.data.blocks:
            block.client = self.client

        result = self.run(plugin_request)

        # Rewrite block output by changing any blocks with byte content to pass by URL
        if result.data is not None and result.data.blocks is not None:
            result.data.blocks = [
                self.upload_block_content_to_signed_url(block)
                if block.upload_type == BlockUploadType.FILE
                else block
                for block in result.data.blocks
            ]
        return result

    def upload_block_content_to_signed_url(self, block: Block) -> Block:
        """Recreate the block (create request) as a URL request, rather than direct content, since we can't do a
        multipart file upload from here.

        The block's bytes are uploaded to a freshly named file in the PLUGIN_DATA bucket, and a new
        `Block` pointing at a READ signed URL for that file is returned. Only `mime_type`, `tags` and
        `text` are carried over from the original block.

        Raises:
            SteamshipError: if `block.upload_bytes` is None.
        """
        if block.upload_bytes is None:
            raise SteamshipError(
                "There was an error with the plugin. When returning upload type FILE, the content may not be None."
            )

        # A random name is used for the uploaded file.
        filepath = str(uuid.uuid4())

        # Look the workspace up once and reuse it for both signed-URL requests.
        workspace = self.client.get_workspace()

        def signed_url_for(operation) -> str:
            """Request a signed URL on `filepath` in the PLUGIN_DATA bucket for `operation`."""
            return workspace.create_signed_url(
                SignedUrl.Request(
                    bucket=SignedUrl.Bucket.PLUGIN_DATA,
                    filepath=filepath,
                    operation=operation,
                )
            ).signed_url

        signed_url = signed_url_for(SignedUrl.Operation.WRITE)

        # NOTE(review): this writes the full signed (write-capable) URL to the logs; consider
        # whether that is acceptable for your log retention / access policy.
        logging.info(f"Got signed url for uploading block content: {signed_url}")

        upload_to_signed_url(signed_url, block.upload_bytes)

        # The READ url is requested only after the upload has been performed.
        read_signed_url = signed_url_for(SignedUrl.Operation.READ)

        return Block(
            url=read_signed_url,
            upload_type=BlockUploadType.URL,
            mime_type=block.mime_type,
            tags=block.tags,
            text=block.text,
        )
class TrainableGenerator(
    TrainablePluginService[RawBlockAndTagPluginInput, RawBlockAndTagPluginOutput], ABC
):
    """Plugin implementor's base class for a trainable Generator.

    Subclasses implement `run_with_model`. The endpoints below forward to `run`,
    `get_training_parameters`, `train` and `train_status`, none of which are defined in this
    class (presumably inherited from `TrainablePluginService` -- confirm there).
    """

    @abstractmethod
    def run_with_model(
        self, request: PluginRequest[RawBlockAndTagPluginInput], model: TrainableModel
    ) -> InvocableResponse[RawBlockAndTagPluginOutput]:
        """Run the generator on `request` using `model`. Must be implemented by subclasses."""
        raise NotImplementedError()

    # noinspection PyUnusedLocal
    @post("generate")
    def run_endpoint(self, **kwargs) -> InvocableResponse[RawBlockAndTagPluginOutput]:
        """Exposes the Generator's `run` operation to the Steamship Engine via the expected HTTP path POST /generate"""
        # NOTE(review): unlike Generator.run_endpoint, no client is attached to the input blocks
        # and FILE-upload output blocks are not rewritten to URLs here -- confirm this is intended.
        return self.run(PluginRequest[RawBlockAndTagPluginInput].parse_obj(kwargs))

    # noinspection PyUnusedLocal
    @post("getTrainingParameters")
    def get_training_parameters_endpoint(
        self, **kwargs
    ) -> InvocableResponse[TrainingParameterPluginOutput]:
        """Exposes the Service's `get_training_parameters` operation to the Steamship Engine via the expected HTTP path POST /getTrainingParameters"""
        return self.get_training_parameters(PluginRequest[TrainingParameterPluginInput](**kwargs))

    # noinspection PyUnusedLocal
    @post("train")
    def train_endpoint(self, **kwargs) -> InvocableResponse[TrainPluginOutput]:
        """Exposes the Service's `train` operation to the Steamship Engine via the expected HTTP path POST /train

        A request flagged `is_status_check` is routed to `train_status`; any other request is
        routed to `train`. In both cases a fresh model instance, configured with this service's
        `config`, is passed along.
        """
        logging.info("TrainableGenerator:train_endpoint called. Calling train %s", kwargs)
        train_request = PluginRequest[TrainPluginInput].parse_obj(kwargs)

        # `model_cls()` returns the model class; the second call instantiates it.
        model = self.model_cls()()
        model.receive_config(config=self.config)

        if train_request.is_status_check:
            return self.train_status(train_request, model)
        return self.train(train_request, model)