aidial_adapter_vertexai/chat/gemini/processor.py (221 lines of code) (raw):

from abc import ABC from dataclasses import dataclass from logging import DEBUG from typing import ( Any, Callable, Coroutine, Dict, Generic, List, Optional, ParamSpec, Set, Union, assert_never, ) from aidial_sdk.chat_completion import ( Message, MessageContentImagePart, MessageContentTextPart, ) from google.genai.types import Part as GenAIPart from pydantic.v1 import BaseModel, Field from vertexai.preview.generative_models import Part from aidial_adapter_vertexai.chat.errors import ValidationError from aidial_adapter_vertexai.chat.gemini.conversation_factory import ( ConversationFactoryBase, PartT, ) from aidial_adapter_vertexai.dial_api.request import get_attachments from aidial_adapter_vertexai.dial_api.resource import ( AttachmentResource, DialResource, URLResource, ) from aidial_adapter_vertexai.dial_api.resource import ( ValidationError as ResourceValidationError, ) from aidial_adapter_vertexai.dial_api.storage import FileStorage from aidial_adapter_vertexai.utils.json import json_dumps_short from aidial_adapter_vertexai.utils.log_config import app_logger as log from aidial_adapter_vertexai.utils.pdf import get_pdf_page_count from aidial_adapter_vertexai.utils.resource import Resource from aidial_adapter_vertexai.utils.text import decapitalize FileTypes = Dict[str, Union[str, List[str]]] Coro = Coroutine[None, None, None] InitValidator = Callable[[], Coro] PostValidator = Callable[[Resource], Coro] class AttachmentProcessor(BaseModel): file_types: FileTypes init_validator: InitValidator | None = None post_validator: PostValidator | None = None @property def mime_types(self) -> List[str]: return list(self.file_types.keys()) @property def file_exts(self) -> List[str]: def to_list(value: Union[str, List[str]]) -> List[str]: return value if isinstance(value, list) else [value] return [ ext for exts in self.file_types.values() for ext in to_list(exts) ] async def process( self, file_storage: FileStorage | None, dial_resource: DialResource ) -> Optional[Resource | str]: try: type = await dial_resource.get_content_type() if type not in self.mime_types: return None if self.init_validator is not None: await self.init_validator() resource = await dial_resource.download(file_storage) if self.post_validator is not None: await self.post_validator(resource) return resource except Exception as e: log.error( f"Failed to download {dial_resource.entity_name}: {str(e)}" ) if isinstance(e, ResourceValidationError): return e.message return f"Failed to download {dial_resource.entity_name}" @dataclass(order=True, frozen=True) class ProcessingError: name: str message: str class AttachmentProcessorsBase(BaseModel, ABC, Generic[PartT]): class Config: arbitrary_types_allowed = True # for errors processors: List[AttachmentProcessor] file_storage: FileStorage | None conversation_factory: ConversationFactoryBase[PartT, Any, Any] errors: Set[ProcessingError] = Field(default_factory=set) resource_count: int = 0 def get_error_message(self) -> str | None: error_list = sorted(list(self.errors)) if error_list: msg = "The following files failed to process:\n" msg += "\n".join( f"{idx}. {error.name}: {decapitalize(error.message)}" for idx, error in enumerate(self.errors, start=1) ) return msg return None def get_file_exts(self) -> List[str]: return sorted({ext for p in self.processors for ext in p.file_exts}) def get_mime_types(self) -> List[str]: return sorted({ty for p in self.processors for ty in p.mime_types}) async def _collect_resource( self, dial_resource: DialResource, resource: Resource | str ) -> Resource | None: if log.isEnabledFor(DEBUG): log.debug(f"resource reference: {json_dumps_short(dial_resource)}") log.debug(f"resource content: {json_dumps_short(resource)}") if isinstance(resource, str): name = await dial_resource.get_resource_name(self.file_storage) self.errors.add(ProcessingError(name=name, message=resource)) return None else: self.resource_count += 1 return resource async def process_resource( self, dial_resource: DialResource ) -> Resource | None: if not self.processors: raise ValidationError("The attachments aren't supported") for processor in self.processors: resource = await processor.process(self.file_storage, dial_resource) if resource is not None: return await self._collect_resource(dial_resource, resource) return await self._collect_resource( dial_resource, f"The {dial_resource.entity_name} isn't one of the supported types", ) async def process_message(self, message: Message) -> List[PartT]: ret: List[PartT] = [] async def collect_resource(dial_resource: DialResource): resource = await self.process_resource(dial_resource) if resource is not None: ret.append( self.conversation_factory.create_multi_modal_part( resource.data, resource.type ) ) def collect_text(text: str): ret.append(self.conversation_factory.create_text_part(text)) # Placing Images/Video parts before the text as per # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts?authuser=1#image_best_practices for attachment in get_attachments(message): await collect_resource(AttachmentResource(attachment=attachment)) content = message.content match content: case None: pass case str(): if content: collect_text(content) case list(): for part in content: match part: case MessageContentTextPart(text=text): collect_text(text) case MessageContentImagePart(image_url=image_url): await collect_resource( URLResource( url=image_url.url, entity_name="image_url" ) ) case _: assert_never(content) return ret class AttachmentProcessors(AttachmentProcessorsBase[Part]): pass class AttachmentProcessorsGenAI(AttachmentProcessorsBase[GenAIPart]): pass def max_count_validator(limit: int) -> InitValidator: count = 0 async def validator(): nonlocal count count += 1 if count > limit: raise ValidationError( f"The number of files exceeds the limit ({limit})" ) return validator def max_pdf_page_count_validator(limit: int) -> PostValidator: count = 0 async def validator(resource: Resource): nonlocal count try: pages = await get_pdf_page_count(resource.data) log.debug(f"PDF page count: {pages}") count += pages except Exception: log.exception("Failed to get PDF page count") raise ValidationError("Failed to get PDF page count") if count > limit: raise ValidationError( f"The total number of PDF pages exceeds the limit ({limit})" ) return validator P = ParamSpec("P") def seq_validators(*validators: Callable[P, Coro] | None) -> Callable[P, Coro]: async def validator(*args: P.args, **kwargs: P.kwargs) -> None: for v in validators: if v is not None: await v(*args, **kwargs) return validator def exclusive_validator() -> Callable[[str], InitValidator]: first: str | None = None def get_validator(name: str) -> InitValidator: async def validator(): nonlocal first if first is None: first = name elif first != name: raise ValidationError( f"The document type is {name!r}. " f"However, one of the documents processed earlier was of {first!r} type. " "Only one type of document is supported at a time." ) return validator return get_validator