# Source code for steamship.plugin.streaming_generator

from abc import ABC, abstractmethod
from typing import List

from steamship.data.block import Block, StreamState
from steamship.invocable import InvocableResponse, post
from steamship.invocable.plugin_service import PluginRequest, PluginService
from steamship.plugin.inputs.raw_block_and_tag_plugin_input import RawBlockAndTagPluginInput
from steamship.plugin.inputs.raw_block_and_tag_plugin_input_with_preallocated_blocks import (
    RawBlockAndTagPluginInputWithPreallocatedBlocks,
)
from steamship.plugin.outputs.block_type_plugin_output import BlockTypePluginOutput
from steamship.plugin.outputs.stream_complete_plugin_output import StreamCompletePluginOutput

# Note!
# =====
#
# This is the PLUGIN IMPLEMENTOR's View of a Streaming Generator.
#
# If you are using the Steamship Client, you probably want steamship.client.operations.generator instead
# of this file.
#


class StreamingGenerator(
    PluginService[RawBlockAndTagPluginInputWithPreallocatedBlocks, StreamCompletePluginOutput], ABC
):
    """Base class for plugins that stream generated output into preallocated blocks.

    Subclasses implement :meth:`run` and :meth:`determine_output_block_types`. The two
    ``@post`` endpoints parse the incoming keyword arguments into a ``PluginRequest``,
    attach this plugin's ``client`` to every block in the request, and then delegate to
    the corresponding abstract method.
    """

    @abstractmethod
    def run(
        self, request: PluginRequest[RawBlockAndTagPluginInputWithPreallocatedBlocks]
    ) -> InvocableResponse[StreamCompletePluginOutput]:
        """Generate output for ``request.data.blocks`` into ``request.data.output_blocks``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    @post("streamResultToBlocks")
    def run_endpoint(self, **kwargs) -> InvocableResponse[StreamCompletePluginOutput]:
        """Endpoint for ``streamResultToBlocks``.

        Parses ``kwargs`` into a request, binds ``self.client`` to both the input blocks
        and the preallocated output blocks, and delegates to :meth:`run`. If ``run``
        raises, any output block whose stream is still ``STARTED`` is aborted
        (best-effort) and the original exception is re-raised.
        """
        request = PluginRequest[RawBlockAndTagPluginInputWithPreallocatedBlocks].parse_obj(kwargs)
        self._bind_client(request.data.blocks)
        self._bind_client(request.data.output_blocks)
        try:
            return self.run(request)
        except BaseException:
            # If anything goes wrong, make sure we automatically abort any open
            # output streams so the client can know.
            self._abort_open_block_streams_after_failure(request.data.output_blocks)
            raise

    @post("determineOutputBlockTypes")
    def determine_output_block_types_endpoint(
        self, **kwargs
    ) -> InvocableResponse[BlockTypePluginOutput]:
        """Endpoint for ``determineOutputBlockTypes``.

        Parses ``kwargs`` into a request, binds ``self.client`` to the input blocks, and
        delegates to :meth:`determine_output_block_types`. If that raises, any input
        block whose stream is still ``STARTED`` is aborted (best-effort) and the
        original exception is re-raised.
        """
        request = PluginRequest[RawBlockAndTagPluginInput].parse_obj(kwargs)
        self._bind_client(request.data.blocks)
        try:
            return self.determine_output_block_types(request)
        except BaseException:
            # If anything goes wrong, make sure we automatically abort any open
            # streams so the client can know.
            self._abort_open_block_streams_after_failure(request.data.blocks)
            raise

    @abstractmethod
    def determine_output_block_types(
        self, request: PluginRequest[RawBlockAndTagPluginInput]
    ) -> InvocableResponse[BlockTypePluginOutput]:
        """Report the block types this generator will produce for ``request``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def abort_open_block_streams(self, blocks: List[Block]):
        """Abort the stream of every block in ``blocks`` that is still ``STARTED``.

        Each block is re-fetched with ``Block.get`` (using the block's own ``client``)
        so the check uses the current server-side ``stream_state`` rather than the
        possibly stale local copy. Errors from ``Block.get``/``abort_stream`` propagate.
        """
        for block in blocks:
            refreshed_block = Block.get(block.client, block.id)
            if refreshed_block.stream_state == StreamState.STARTED:
                refreshed_block.abort_stream()

    def _bind_client(self, blocks: List[Block]) -> None:
        """Attach this plugin's client to each block (needed e.g. by ``Block.get`` above)."""
        for block in blocks:
            block.client = self.client

    def _abort_open_block_streams_after_failure(self, blocks: List[Block]) -> None:
        """Run :meth:`abort_open_block_streams` as best-effort cleanup during error handling.

        Any ``Exception`` raised by the cleanup is logged rather than propagated, so it
        cannot replace the exception that triggered the cleanup in the caller.
        """
        import logging

        try:
            self.abort_open_block_streams(blocks)
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to abort open block streams while handling an error"
            )