aidial_adapter_bedrock/llm/model/stability/v2.py (172 lines of code) (raw):

"""Chat-completion adapter for Stability image-generation models invoked via Bedrock."""

from io import BytesIO
from typing import List, Optional, Tuple

from aidial_sdk.chat_completion import Attachment, Message
from aidial_sdk.exceptions import RequestValidationError
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.resource import (
    DialResource,
    UnsupportedContentType,
)
from aidial_adapter_bedrock.dial_api.storage import (
    FileStorage,
    create_file_storage,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.stability.message import (
    parse_message,
    validate_last_message,
)
from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.resource import Resource

# MIME types passed to `parse_message` when extracting image attachments
# from the last message.
SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
# Extensions listed in the user-facing hint when an unsupported image is attached.
SUPPORTED_IMAGE_EXTENSIONS = ["jpeg", "jpe", "jpg", "png", "webp"]


async def _download_resource(
    dial_resource: DialResource, storage: FileStorage | None
) -> Resource:
    """Download an attached image resource.

    Args:
        dial_resource: The attachment to download.
        storage: File storage passed through to `DialResource.download`
            (may be None).

    Returns:
        The downloaded resource.

    Raises:
        UserError: if the resource has an unsupported content type.
    """
    try:
        return await dial_resource.download(storage)
    except UnsupportedContentType as e:
        # Chain the original error so the underlying cause is kept in tracebacks.
        raise UserError(
            error_message=f"Unsupported image type: {e.type}",
            usage_message=f"Supported image types: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}",
        ) from e


def _invalid_image_error(message: str) -> RequestValidationError:
    """Build a client-facing `invalid_argument` error about the input image."""
    return RequestValidationError(
        message=message,
        display_message=message,
        code="invalid_argument",
    )


def _validate_image_size(
    image: Resource,
    width_constraints: Tuple[int, int] | None,
    height_constraints: Tuple[int, int] | None,
) -> None:
    """Check the image dimensions against optional inclusive (min, max) bounds.

    Args:
        image: The downloaded input image.
        width_constraints: Inclusive (min, max) allowed width, or None to skip.
        height_constraints: Inclusive (min, max) allowed height, or None to skip.

    Raises:
        RequestValidationError: if the image cannot be decoded, or if a
            dimension falls outside its bounds.
    """
    if width_constraints is None and height_constraints is None:
        # Nothing to check, so don't even decode the image.
        return

    try:
        # Image.open parses the header, which is enough to read the size.
        with Image.open(BytesIO(image.data)) as img:
            width, height = img.size
    except UnidentifiedImageError as e:
        # User-supplied bytes that PIL cannot identify are a client error,
        # not a server failure.
        raise _invalid_image_error("Unable to decode the input image") from e

    for constraints, value, name in [
        (width_constraints, width, "width"),
        (height_constraints, height, "height"),
    ]:
        if constraints is None:
            continue

        min_value, max_value = constraints
        if not (min_value <= value <= max_value):
            raise _invalid_image_error(
                f"Image {name} is {value}, but should be "
                f"between {min_value} and {max_value}"
            )


class StabilityV2Response(BaseModel):
    """Parsed model response: generated images plus one finish reason per image."""

    seeds: List[int]
    # Generated images. They are exposed as PNG attachments via `attachments()`.
    images: List[str]
    # A null entry means the corresponding generation succeeded.
    # Known non-null values:
    #   "Filter reason: prompt"
    #   "Filter reason: output image"
    #   "Filter reason: input image"
    #   "Inference error"
    finish_reasons: List[Optional[str]]

    def content(self) -> str:
        """Return placeholder text content (a single space).

        The actual output is delivered through `attachments()`.
        """
        return " "

    def attachments(self) -> List[Attachment]:
        """Wrap every generated image into a PNG attachment titled "Image"."""
        return [
            Attachment(
                title="Image",
                type="image/png",
                data=image,
            )
            for image in self.images
        ]

    def usage(self) -> TokenUsage:
        """Report fixed usage: 0 prompt tokens and 1 completion token."""
        return TokenUsage(prompt_tokens=0, completion_tokens=1)

    def throw_if_error(self):
        """Raise on the first non-null finish reason, if any.

        Raises:
            RuntimeError: if the reason is "Inference error".
            ValidationError: for any other non-null reason (e.g. content filters).
        """
        error = next((reason for reason in self.finish_reasons if reason), None)
        if not error:
            return

        if error == "Inference error":
            raise RuntimeError(error)
        else:
            raise ValidationError(error)


class StabilityV2Adapter(ChatCompletionAdapter):
    """Adapter turning the last chat message into a Stability image request."""

    model: str
    client: Bedrock
    # Used for downloading input images and saving generated ones; may be None.
    storage: Optional[FileStorage]
    # Whether an attached input image (image-to-image mode) is accepted.
    image_to_image_supported: bool
    # Inclusive (min, max) bounds for the input image size; None disables the check.
    width_constraints: Tuple[int, int] | None
    height_constraints: Tuple[int, int] | None

    @classmethod
    def create(
        cls,
        client: Bedrock,
        model: str,
        api_key: str,
        image_to_image_supported: bool,
        image_width_constraints: Tuple[int, int] | None = None,
        image_height_constraints: Tuple[int, int] | None = None,
    ):
        """Build an adapter, creating the file storage from `api_key`."""
        storage: Optional[FileStorage] = create_file_storage(api_key)
        return cls(
            client=client,
            model=model,
            storage=storage,
            image_to_image_supported=image_to_image_supported,
            width_constraints=image_width_constraints,
            height_constraints=image_height_constraints,
        )

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        """Report every message except the last one as discarded.

        Only the last message is used by `chat`. The return value is
        presumably the list of discarded message indices (verify against
        `DiscardedMessages`).
        """
        validate_last_message(messages)
        return list(range(len(messages) - 1))

    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        """Generate images from the last message and send them to `consumer`.

        The last message supplies the text prompt and, optionally, a single
        input image. The input image switches the request to image-to-image.

        Raises:
            UserError: if an image is attached but image-to-image is not
                supported, more than one image is attached, the text prompt
                is empty, or the image type is unsupported.
            RequestValidationError: if the input image cannot be decoded or
                violates the configured size constraints.
            RuntimeError, ValidationError: if the model reports a non-null
                finish reason.
        """
        message = validate_last_message(messages)
        text_prompt, image_resources = parse_message(
            message, SUPPORTED_IMAGE_TYPES
        )

        if not self.image_to_image_supported and image_resources:
            raise UserError("Image-to-Image is not supported")

        if len(image_resources) > 1:
            raise UserError("Only one input image is supported")

        # Check the cheap local condition before paying for an image download.
        if not text_prompt:
            raise UserError("Text prompt is required")

        # Reaching here with an image implies image-to-image is supported.
        image_resource: Optional[Resource] = None
        if image_resources:
            image_resource = await _download_resource(
                image_resources[0], self.storage
            )
            _validate_image_size(
                image_resource, self.width_constraints, self.height_constraints
            )

        request = remove_nones(
            {
                "prompt": text_prompt,
                "image": (
                    image_resource.data_base64 if image_resource else None
                ),
                "mode": (
                    "image-to-image" if image_resource else "text-to-image"
                ),
                "output_format": "png",
                # Controls how much the input image affects generation, from 0 to 1:
                # 0 reproduces the input image, 1 ignores it.
                # There is no recommended default, so 0.5 is used as a middle ground.
                "strength": 0.5 if image_resource else None,
            }
        )
        response, _ = await self.client.ainvoke_non_streaming(
            self.model, request
        )

        stability_response = StabilityV2Response.parse_obj(response)
        stability_response.throw_if_error()

        consumer.append_content(stability_response.content())
        consumer.close_content()

        consumer.add_usage(stability_response.usage())

        for attachment in stability_response.attachments():
            if self.storage:
                attachment = await save_to_storage(self.storage, attachment)
            consumer.add_attachment(attachment)