# aidial_adapter_openai/gpt4_multi_modal/transformation.py (133 lines of code) (raw):
from dataclasses import dataclass
from typing import List, Optional, Set, cast
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from pydantic import BaseModel, Field
from aidial_adapter_openai.dial_api.resource import (
AttachmentResource,
DialResource,
URLResource,
ValidationError,
parse_attachment,
)
from aidial_adapter_openai.dial_api.storage import FileStorage
from aidial_adapter_openai.utils.image import ImageDetail, ImageMetadata
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.multi_modal_message import (
MultiModalMessage,
create_image_content_part,
create_text_content_part,
)
from aidial_adapter_openai.utils.resource import Resource
from aidial_adapter_openai.utils.text import decapitalize
# MIME types of images officially supported by GPT-4 Vision and GPT-4o.
SUPPORTED_IMAGE_TYPES = [
    "image/jpeg",
    "image/png",
    "image/webp",
    "image/gif",
]

# File extensions matching the image MIME types listed above.
SUPPORTED_FILE_EXTS = [
    "jpg",
    "jpeg",
    "png",
    "webp",
    "gif",
]
@dataclass(frozen=True, order=True)
class TransformationError:
    """Describe a single resource that could not be processed.

    ``frozen=True`` makes instances immutable and hashable, so they can be
    stored in a set. ``order=True`` makes them comparable field-by-field
    (``name`` first, then ``message``), so a collection of them can be
    sorted deterministically.
    """

    # Display name of the resource that failed.
    name: str
    # Explanation of why the resource failed.
    message: str
class ResourceProcessor(BaseModel):
    """Download images referenced by chat messages and inline them.

    Images are looked up in two places of a message:

    * ``image_url`` parts of a list-valued ``content``;
    * ``custom_content["attachments"]``.

    Download failures are not raised one by one. Each is recorded in
    ``errors`` and reported together by ``transform_messages``.
    """

    class Config:
        arbitrary_types_allowed = True  # for errors

    # Storage handle passed to every download and name lookup (may be None).
    file_storage: FileStorage | None
    # Failures accumulated over every message processed by this instance.
    errors: Set[TransformationError] = Field(default_factory=set)

    def collect_resource(
        self,
        meta: List[ImageMetadata],
        result: Resource | TransformationError,
        detail: Optional[ImageDetail],
    ):
        """Record the outcome of one download attempt.

        A ``TransformationError`` is added to ``self.errors``. A downloaded
        ``Resource`` is converted to ``ImageMetadata`` using ``detail`` and
        appended to ``meta`` in place.
        """
        if isinstance(result, TransformationError):
            self.errors.add(result)
        else:
            meta.append(ImageMetadata.from_resource(result, detail))

    async def try_download_resource(
        self, dial_resource: DialResource
    ) -> Resource | TransformationError:
        """Download ``dial_resource``, turning any failure into a value.

        Returns:
            The downloaded ``Resource`` on success. On failure, a
            ``TransformationError`` whose message is the ``ValidationError``
            message for validation failures, or a generic
            "Failed to download the <entity>" message otherwise.

        Note: ``get_resource_name`` is awaited inside the handler, so an
        exception raised by it still propagates to the caller.
        """
        try:
            resource = await dial_resource.download(self.file_storage)
        except Exception as e:
            # Broad catch: every download failure becomes a reportable
            # TransformationError instead of aborting the whole request.
            logger.error(
                f"Failed to download {dial_resource.entity_name}: {str(e)}"
            )
            name = await dial_resource.get_resource_name(self.file_storage)
            message = (
                e.message
                if isinstance(e, ValidationError)
                else f"Failed to download the {dial_resource.entity_name}"
            )
            return TransformationError(name=name, message=message)
        return resource

    async def download_attachment_images(
        self, attachments: List[dict]
    ) -> List[ImageMetadata]:
        """Download every attachment as an image, in order.

        Attachments that fail are recorded in ``self.errors`` and omitted
        from the result. Attachment images carry no ``detail`` (``None``).
        """
        if attachments:
            logger.debug(f"original attachments: {attachments}")

        ret: List[ImageMetadata] = []
        for attachment in attachments:
            dial_resource = AttachmentResource(
                attachment=parse_attachment(attachment),
                entity_name="image attachment",
                supported_types=SUPPORTED_IMAGE_TYPES,
            )
            result = await self.try_download_resource(dial_resource)
            self.collect_resource(ret, result, None)
        return ret

    async def download_content_images(
        self, content: str | list
    ) -> List[ImageMetadata]:
        """Download images referenced by ``image_url`` parts of ``content``.

        String content has no image parts and yields ``[]``. For list
        content, each part with a non-empty ``image_url.url`` is downloaded
        using that part's ``detail``. Failures are recorded in
        ``self.errors`` and omitted from the result.

        Raises:
            ValidationError: if a part's ``detail`` is not ``None``,
                ``"auto"``, ``"low"`` or ``"high"``.
        """
        if isinstance(content, str):
            return []

        ret: List[ImageMetadata] = []
        for content_part in content:
            image_url = content_part.get("image_url")
            if image_url and (url := image_url.get("url")):
                detail = image_url.get("detail")
                if detail not in [None, "auto", "low", "high"]:
                    raise ValidationError("Unexpected image detail")
                dial_resource = URLResource(
                    url=url,
                    entity_name="image",
                    supported_types=SUPPORTED_IMAGE_TYPES,
                )
                result = await self.try_download_resource(dial_resource)
                self.collect_resource(ret, result, detail)
        return ret

    async def transform_message(self, message: dict) -> MultiModalMessage:
        """Resolve the images of one message.

        The message is shallow-copied and ``custom_content`` is dropped from
        the copy, so the caller's dict is left untouched.

        * If no image was downloaded, the copy is returned as-is.
        * Otherwise its ``content`` becomes a list of parts: the original
          parts (or one text part wrapping string content), followed by one
          image part per downloaded attachment.

        Content images are already present in the original parts and are
        not appended again. ``image_metadatas`` lists content images first,
        then attachment images.
        """
        message = message.copy()
        content = message.get("content") or ""
        custom_content = message.pop("custom_content", None) or {}
        attachments = custom_content.get("attachments") or []

        attachment_meta = await self.download_attachment_images(attachments)
        content_meta = await self.download_content_images(content)
        meta = [*content_meta, *attachment_meta]

        if not meta:
            return MultiModalMessage(image_metadatas=[], raw_message=message)

        content_parts = (
            [create_text_content_part(content)]
            if isinstance(content, str)
            else content
        ) + [
            create_image_content_part(image_meta.image, image_meta.detail)
            for image_meta in attachment_meta
        ]

        return MultiModalMessage(
            image_metadatas=meta,
            raw_message={**message, "content": content_parts},
        )

    async def transform_messages(
        self, messages: List[dict]
    ) -> List[MultiModalMessage] | DialException:
        """Transform all messages, or report every failed resource.

        Messages are processed sequentially. If any download failed (now or
        in an earlier call on this instance, since ``errors`` accumulates),
        an ``InvalidRequestError`` listing the failures in sorted order is
        *returned* (not raised). Otherwise the transformed messages are
        returned in input order.
        """
        transformations = [
            await self.transform_message(message) for message in messages
        ]

        if self.errors:
            image_fails = sorted(self.errors)
            msg = "The following files failed to process:\n"
            msg += "\n".join(
                f"{idx}. {error.name}: {decapitalize(error.message)}"
                for idx, error in enumerate(image_fails, start=1)
            )
            return InvalidRequestError(message=msg, display_message=msg)

        return transformations