Source code for steamship.data.embeddings

from __future__ import annotations

import json
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field

from steamship import SteamshipError
from steamship.base import Task
from steamship.base.client import Client
from steamship.base.model import CamelModel
from steamship.base.request import DeleteRequest, ListRequest, Request, SortOrder
from steamship.base.response import ListResponse, Response
from steamship.data.search import Hit
from steamship.utils.metadata import metadata_to_str

MAX_RECOMMENDED_ITEM_LENGTH = 5000


[docs] class EmbedAndSearchRequest(Request): query: str docs: List[str] plugin_instance: str k: int = 1
[docs] class QueryResult(CamelModel): value: Optional[Hit] = None score: Optional[float] = None index: Optional[int] = None id: Optional[str] = None
[docs] class QueryResults(Request): items: List[QueryResult] = None
[docs] class EmbeddedItem(CamelModel): id: str = None index_id: str = None file_id: str = None block_id: str = None tag_id: str = None value: str = None external_id: str = None external_type: str = None metadata: Any = None embedding: List[float] = None
[docs] def clone_for_insert(self) -> EmbeddedItem: """Produces a clone with a string representation of the metadata""" ret = EmbeddedItem( id=self.id, index_id=self.index_id, file_id=self.file_id, block_id=self.block_id, tag_id=self.tag_id, value=self.value, external_id=self.external_id, external_type=self.external_type, metadata=self.metadata, embedding=self.embedding, ) if isinstance(ret.metadata, dict) or isinstance(ret.metadata, list): ret.metadata = json.dumps(ret.metadata) return ret
[docs] class IndexCreateRequest(Request): handle: str = None name: str = None plugin_instance: str = None fetch_if_exists: bool = True external_id: str = None external_type: str = None metadata: Any = None
[docs] class IndexInsertRequest(Request): index_id: str items: List[EmbeddedItem] = None value: str = None file_id: str = None block_type: str = None external_id: str = None external_type: str = None metadata: Any = None reindex: bool = True
[docs] class IndexItemId(CamelModel): index_id: str = None id: str = None
[docs] class IndexInsertResponse(Response): item_ids: List[IndexItemId] = None
[docs] class IndexEmbedRequest(Request): id: str
[docs] class IndexEmbedResponse(Response): id: Optional[str] = None
[docs] class IndexSearchRequest(Request): id: str query: str = None queries: List[str] = None k: int = 1 include_metadata: bool = False
[docs] class ListItemsRequest(ListRequest): id: str = None file_id: str = None block_id: str = None span_id: str = None
[docs] class ListItemsResponse(ListResponse): items: List[EmbeddedItem]
[docs] class EmbeddingIndex(CamelModel): """A persistent, read-optimized index over embeddings.""" client: Client = Field(None, exclude=True) id: str = None handle: str = None name: str = None plugin: str = None external_id: str = None external_type: str = None metadata: str = None
[docs] @classmethod def parse_obj(cls: Type[BaseModel], obj: Any) -> BaseModel: # TODO (enias): This needs to be solved at the engine side if "embeddingIndex" in obj: obj = obj["embeddingIndex"] elif "index" in obj: obj = obj["index"] return super().parse_obj(obj)
[docs] def insert_file( self, file_id: str, block_type: str = None, external_id: str = None, external_type: str = None, metadata: Union[int, float, bool, str, List, Dict] = None, reindex: bool = True, ) -> IndexInsertResponse: if isinstance(metadata, dict) or isinstance(metadata, list): metadata = json.dumps(metadata) req = IndexInsertRequest( index_id=self.id, file_id=file_id, blockType=block_type, external_id=external_id, external_type=external_type, metadata=metadata, reindex=reindex, ) return self.client.post( "embedding-index/item/create", req, expect=IndexInsertResponse, )
def _check_input(self, request: IndexInsertRequest, allow_long_records: bool): if not allow_long_records: if request.value is not None and len(request.value) > MAX_RECOMMENDED_ITEM_LENGTH: raise SteamshipError( f"Inserted item of length {len(request.value)} exceeded maximum recommended length of {MAX_RECOMMENDED_ITEM_LENGTH} characters. You may insert it anyway by passing allow_long_records=True." ) if request.items is not None: for i, item in enumerate(request.items): if item is not None: if isinstance(item, str) and len(item) > MAX_RECOMMENDED_ITEM_LENGTH: raise SteamshipError( f"Inserted item {i} of length {len(item)} exceeded maximum recommended length of {MAX_RECOMMENDED_ITEM_LENGTH} characters. You may insert it anyway by passing allow_long_records=True." ) if ( isinstance(item, EmbeddedItem) and item.value is not None and len(item.value) > MAX_RECOMMENDED_ITEM_LENGTH ): raise SteamshipError( f"Inserted item {i} of length {len(item.value)} exceeded maximum recommended length of {MAX_RECOMMENDED_ITEM_LENGTH} characters. You may insert it anyway by passing allow_long_records=True." )
[docs] def insert_many( self, items: List[Union[EmbeddedItem, str]], reindex: bool = True, allow_long_records=False, ) -> IndexInsertResponse: new_items = [] for item in items: if isinstance(item, str): new_items.append(EmbeddedItem(value=item)) else: new_items.append(item) req = IndexInsertRequest( index_id=self.id, items=[item.clone_for_insert() for item in new_items], reindex=reindex, ) self._check_input(req, allow_long_records) return self.client.post( "embedding-index/item/create", req, expect=IndexInsertResponse, )
[docs] def insert( self, value: str, external_id: str = None, external_type: str = None, metadata: Union[int, float, bool, str, List, Dict] = None, reindex: bool = True, allow_long_records=False, ) -> IndexInsertResponse: req = IndexInsertRequest( index_id=self.id, value=value, external_id=external_id, external_type=external_type, metadata=metadata_to_str(metadata), reindex=reindex, ) self._check_input(req, allow_long_records) return self.client.post( "embedding-index/item/create", req, expect=IndexInsertResponse, )
[docs] def embed( self, ) -> Task[IndexEmbedResponse]: req = IndexEmbedRequest(id=self.id) return self.client.post( "embedding-index/embed", req, expect=IndexEmbedResponse, )
[docs] def list_items( self, file_id: str = None, block_id: str = None, span_id: str = None, page_size: Optional[int] = None, page_token: Optional[str] = None, sort_order: Optional[SortOrder] = SortOrder.DESC, ) -> ListItemsResponse: req = ListItemsRequest( id=self.id, file_id=file_id, block_id=block_id, spanId=span_id, page_size=page_size, page_token=page_token, sort_order=sort_order, ) return self.client.post( "embedding-index/item/list", req, expect=ListItemsResponse, )
[docs] def delete(self) -> EmbeddingIndex: return self.client.post( "embedding-index/delete", DeleteRequest(id=self.id), expect=EmbeddingIndex, )
[docs] def search( self, query: Union[str, List[str]], k: int = 1, include_metadata: bool = False, ) -> Task[QueryResults]: if isinstance(query, list): req = IndexSearchRequest( id=self.id, queries=query, k=k, include_metadata=include_metadata ) else: req = IndexSearchRequest( id=self.id, query=query, k=k, include_metadata=include_metadata ) ret = self.client.post( "embedding-index/search", req, expect=QueryResults, ) return ret
[docs] @staticmethod def create( client: Client, handle: str = None, name: str = None, embedder_plugin_instance_handle: str = None, fetch_if_exists: bool = True, external_id: str = None, external_type: str = None, metadata: Any = None, ) -> EmbeddingIndex: req = IndexCreateRequest( handle=handle, name=name, plugin_instance=embedder_plugin_instance_handle, fetch_if_exists=fetch_if_exists, external_id=external_id, external_type=external_type, metadata=metadata, ) return client.post( "embedding-index/create", req, expect=EmbeddingIndex, )