aidial_adapter_vertexai/chat/imagen/adapter.py (117 lines of code) (raw):

from typing import List, Optional from aidial_sdk.chat_completion import Attachment, Message from PIL import Image as PIL_Image from typing_extensions import override from vertexai.preview.vision_models import ( GeneratedImage, ImageGenerationModel, ImageGenerationResponse, ) from aidial_adapter_vertexai.chat.chat_completion_adapter import ( ChatCompletionAdapter, ) from aidial_adapter_vertexai.chat.consumer import Consumer from aidial_adapter_vertexai.chat.errors import ValidationError from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig from aidial_adapter_vertexai.chat.tools import ToolsConfig from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt from aidial_adapter_vertexai.dial_api.request import ( ModelParameters, collect_text_content, ) from aidial_adapter_vertexai.dial_api.storage import ( FileStorage, compute_hash_digest, ) from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log from aidial_adapter_vertexai.utils.timer import Timer from aidial_adapter_vertexai.vertex_ai import get_image_generation_model ImagenPrompt = str class ImagenChatCompletionAdapter(ChatCompletionAdapter[ImagenPrompt]): def __init__( self, file_storage: Optional[FileStorage], model: ImageGenerationModel, ): self.file_storage = file_storage self.model = model @override async def parse_prompt( self, tools: ToolsConfig, static_tools: StaticToolsConfig, messages: List[Message], ) -> ImagenPrompt: tools.not_supported() static_tools.not_supported() if len(messages) == 0: raise ValidationError("The list of messages must not be empty") content = messages[-1].content if content is None: raise ValidationError("The last message must have content") return collect_text_content(content) @override async def truncate_prompt( self, prompt: ImagenPrompt, max_prompt_tokens: int ) -> TruncatedPrompt[ImagenPrompt]: return TruncatedPrompt(discarded_messages=[], prompt=prompt) @staticmethod def get_image_type(image: PIL_Image.Image) -> str: match image.format: case "JPEG": return "image/jpeg" case "PNG": return "image/png" case _: raise ValueError(f"Unknown image format: {image.format}") @override async def chat( self, params: ModelParameters, consumer: Consumer, prompt: ImagenPrompt ) -> None: prompt_tokens = await self.count_prompt_tokens(prompt) with Timer("predict timing: {time}", log.debug): response: ImageGenerationResponse = self.model.generate_images( prompt, number_of_images=1, seed=None ) if len(response.images) == 0: raise RuntimeError("Expected 1 image in response, but got none") image: GeneratedImage = response[0] type: str = self.get_image_type(image._pil_image) data: bytes = image._image_bytes base64_data: str = image._as_base64_string() attachment: Attachment = Attachment( title="Image", type=type, data=base64_data ) if self.file_storage is not None: with Timer("upload to file storage: {time}", log.debug): filename = "images/" + compute_hash_digest(base64_data) meta = await self.file_storage.upload( filename=filename, content_type=type, content=data ) attachment.data = None attachment.url = meta["url"] await consumer.add_attachment(attachment) # Avoid generating empty content completion = " " await consumer.append_content(completion) completion_tokens = await self.count_completion_tokens(completion) await consumer.set_usage( TokenUsage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) ) @override async def count_prompt_tokens(self, prompt: ImagenPrompt) -> int: return 0 @override async def count_completion_tokens(self, string: str) -> int: return 1 @classmethod async def create( cls, file_storage: Optional[FileStorage], model_id: str, ) -> "ImagenChatCompletionAdapter": model = await get_image_generation_model(model_id) return cls(file_storage, model)