aidial_adapter_bedrock/llm/model/stability/v1.py (95 lines of code) (raw):

from enum import Enum
from typing import Any, Dict, List, Optional

from aidial_sdk.chat_completion import Attachment, Message
from pydantic import BaseModel, Field

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import (
    FileStorage,
    create_file_storage,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.stability.message import (
    parse_message,
    validate_last_message,
)
from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages


class StabilityStatus(str, Enum):
    """Value of the ``result`` field of a Stability response."""

    SUCCESS = "success"
    ERROR = "error"


class StabilityError(BaseModel):
    """Error payload carried by a failed Stability response."""

    id: str
    message: str
    name: str


class StabilityArtifact(BaseModel):
    """Single generated artifact; ``base64`` is used as the attachment data."""

    seed: int
    base64: str
    # The response spells this key in camelCase.
    finish_reason: str = Field(alias="finishReason")


class StabilityResponse(BaseModel):
    """Parsed Stability response exposed in chat-completion terms.

    Every public accessor raises if the response reports an error.
    """

    # TODO: Use tagged union artifacts/error
    result: StabilityStatus
    # Explicit ``None`` defaults keep both fields optional regardless of
    # how the installed pydantic version treats a default-less ``Optional``.
    artifacts: Optional[list[StabilityArtifact]] = None
    error: Optional[StabilityError] = None

    def content(self) -> str:
        """Return the placeholder text content of the response.

        Raises:
            Exception: if the response reports an error.
        """
        self._throw_if_error()
        # NOTE: text-to-text models aren't expected to generate empty strings.
        # So since we represent text-to-image model (Stability) as
        # a text-to-text model (via chat completion interface),
        # we need to return something.
        return " "

    def attachments(self) -> list[Attachment]:
        """Return the first generated artifact as a PNG image attachment.

        Raises:
            Exception: if the response reports an error or carries no
                artifacts.
        """
        self._throw_if_error()
        # A "success" response without artifacts used to fail with an
        # opaque TypeError/IndexError; report it explicitly instead.
        if not self.artifacts:
            raise Exception("Stability response doesn't contain any artifacts")
        return [
            Attachment(
                title="Image",
                type="image/png",
                data=self.artifacts[0].base64,
            )
        ]

    def usage(self) -> TokenUsage:
        """Return a fixed token usage: zero prompt tokens, one completion token.

        Raises:
            Exception: if the response reports an error.
        """
        self._throw_if_error()
        return TokenUsage(
            prompt_tokens=0,
            completion_tokens=1,
        )

    def _throw_if_error(self) -> None:
        """Raise ``Exception`` with the reported message on an error result."""
        if self.result == StabilityStatus.ERROR:
            # The error payload is optional in the model, so fall back to a
            # generic message rather than failing with AttributeError.
            if self.error is None:
                raise Exception("Stability response reported an unknown error")
            raise Exception(self.error.message)


def create_request(prompt: str) -> Dict[str, Any]:
    """Build the request body holding a single text prompt."""
    return {"text_prompts": [{"text": prompt}]}


class StabilityV1Adapter(ChatCompletionAdapter):
    """Chat-completion adapter for the Stability text-to-image model."""

    model: str
    client: Bedrock
    # When set, generated attachments are passed through ``save_to_storage``
    # before being handed to the consumer.
    storage: Optional[FileStorage]

    @classmethod
    def create(cls, client: Bedrock, model: str, api_key: str):
        """Create an adapter with file storage derived from ``api_key``."""
        storage: Optional[FileStorage] = create_file_storage(api_key)
        return cls(client=client, model=model, storage=storage)

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        """Return indices of all messages except the last one.

        Only the last message is used as the prompt, so every preceding
        message is reported as discarded.
        """
        # Called for validation only; the returned message isn't needed here.
        validate_last_message(messages)
        return list(range(len(messages) - 1))

    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        """Generate an image from the last message and pass it to ``consumer``.

        Raises:
            UserError: if the last message references image resources.
            ValidationError: if the last message has no text prompt.
            Exception: if the model response reports an error or carries no
                artifacts.
        """
        message = validate_last_message(messages)
        text_prompt, image_resources = parse_message(message, [])

        if image_resources:
            raise UserError("Image-to-Image is not supported")

        if text_prompt is None:
            raise ValidationError("Content of the last message is missing")

        args = create_request(text_prompt)
        response, _headers = await self.client.ainvoke_non_streaming(
            self.model, args
        )

        resp = StabilityResponse.parse_obj(response)

        # Build every part of the answer before touching the consumer, so a
        # malformed response fails before any partial output is emitted.
        content = resp.content()
        usage = resp.usage()
        attachments = resp.attachments()

        consumer.append_content(content)
        consumer.close_content()
        consumer.add_usage(usage)

        for attachment in attachments:
            if self.storage:
                attachment = await save_to_storage(self.storage, attachment)
            consumer.add_attachment(attachment)