aidial_adapter_vertexai/embedding/multi_modal.py:

from logging import DEBUG
from typing import AsyncIterator, Callable, List, Tuple

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings.request import EmbeddingsRequest
from pydantic.v1 import BaseModel
from vertexai.vision_models import (
    Image,
    MultiModalEmbeddingModel,
    MultiModalEmbeddingResponse,
)

from aidial_adapter_vertexai.chat.errors import UserError, ValidationError
from aidial_adapter_vertexai.dial_api.embedding_inputs import (
    EMPTY_INPUT_LIST_ERROR,
    collect_embedding_inputs,
)
from aidial_adapter_vertexai.dial_api.resource import AttachmentResource
from aidial_adapter_vertexai.dial_api.storage import FileStorage
from aidial_adapter_vertexai.embedding.embeddings_adapter import (
    EmbeddingsAdapter,
)
from aidial_adapter_vertexai.embedding.types import (
    Embedding,
    make_embeddings_response,
    vector_to_embedding,
)
from aidial_adapter_vertexai.utils.concurrency import gather_sync
from aidial_adapter_vertexai.utils.json import json_dumps_short
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
from aidial_adapter_vertexai.vertex_ai import get_multi_modal_embedding_model

# See the documentation:
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api
SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png"]


class ModelRequest(BaseModel):
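    """A single request to the multimodal embedding model:
    an optional contextual text, an optional image, or both at once."""
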
    class Config:
        arbitrary_types_allowed = True

    image: Image | None = None
    contextual_text: str | None = None

    def count_input_tokens(self) -> int:
        # The model doesn't report the number of input tokens.
        # However, one could count it oneself:
        # https://cloud.google.com/vertex-ai/generative-ai/pricing#embedding-models
        # As of 29 Jul 2024, one image costs as much as 500 text input characters.
        ret = len(self.contextual_text or "")
        if self.image:
            ret += 500
        return ret

    def extract_embeddings(
        self, response: MultiModalEmbeddingResponse
    ) -> Tuple[List[float], int]:
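        """Pick the embedding matching the request content: the image
        embedding if an image was sent, the text embedding otherwise."""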
        vector: List[float] | None = None
        if self.image:
            vector = response.image_embedding
        else:
            vector = response.text_embedding

        if vector is None:
            raise ValueError("No embeddings returned")

        return vector, self.count_input_tokens()


def compute_embeddings(
    request: ModelRequest,
    model: MultiModalEmbeddingModel,
    base64_encode: bool,
    dimensions: int | None,
) -> Tuple[Embedding, int]:
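    """Call the model synchronously for a single request and return
    the embedding along with the estimated input token count."""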
    if log.isEnabledFor(DEBUG):
        msg = json_dumps_short(
            {
                "image": request.image,
                "contextual_text": request.contextual_text,
                "dimension": dimensions,
            }
        )
        log.debug(f"request: {msg}")

    response: MultiModalEmbeddingResponse = model.get_embeddings(
        image=request.image,
        contextual_text=request.contextual_text,
        dimension=dimensions,
    )

    if log.isEnabledFor(DEBUG):
        msg = json_dumps_short(response)
        log.debug(f"response: {msg}")

    vec, tokens = request.extract_embeddings(response)
    return vector_to_embedding(base64_encode, vec), tokens


def validate_request(request: EmbeddingsRequest) -> None:
    if request.custom_fields is not None:
        if request.custom_fields.instruction is not None:
            raise ValidationError("Instruction prompt is not supported")

        if request.custom_fields.type is not None:
            raise ValidationError(
                "The embedding model does not support embedding types"
            )


def _validate_content_type(
    content_type: str, supported_types: List[str]
) -> None:
    if content_type not in supported_types:
        raise UserError(
            f"Unsupported attachment content type: {content_type}. "
            f"Supported attachment types: {', '.join(supported_types)}."
        )


async def get_requests(
    storage: FileStorage | None, request: EmbeddingsRequest
) -> AsyncIterator[ModelRequest]:
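    """Turn the DIAL embeddings request into a stream of per-input model
    requests: a string becomes a text-only request, an attachment becomes
    an image-only request, and a [text, image] pair (in either order)
    becomes a combined request."""
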
    async def download_image(attachment: Attachment) -> Image:
        resource = await AttachmentResource(attachment=attachment).download(
            storage
        )
        _validate_content_type(resource.type, SUPPORTED_IMAGE_TYPES)
        return Image(image_bytes=resource.data)

    async def on_text(text: str) -> ModelRequest:
        return ModelRequest(contextual_text=text)

    async def on_attachment(attachment: Attachment) -> ModelRequest:
        return ModelRequest(image=await download_image(attachment))

    async def on_mixed(inputs: List[str | Attachment]) -> ModelRequest:
        if len(inputs) == 0:
            raise EMPTY_INPUT_LIST_ERROR
        elif len(inputs) == 1:
            if isinstance(inputs[0], str):
                return await on_text(inputs[0])
            else:
                return await on_attachment(inputs[0])
        elif len(inputs) == 2:
            if isinstance(inputs[0], str) and isinstance(
                inputs[1], Attachment
            ):
                return ModelRequest(
                    contextual_text=inputs[0],
                    image=await download_image(inputs[1]),
                )
            elif isinstance(inputs[0], Attachment) and isinstance(
                inputs[1], str
            ):
                return ModelRequest(
                    contextual_text=inputs[1],
                    image=await download_image(inputs[0]),
                )
            else:
                raise ValidationError(
                    "A two-element list in custom_input must contain "
                    "a string and an image attachment, in either order"
                )
        else:
            raise ValidationError(
                "An element of the custom_input list must contain "
                "no more than two elements"
            )

    return collect_embedding_inputs(
        request,
        on_text=on_text,
        on_attachment=on_attachment,
        on_mixed=on_mixed,
    )


class MultiModalEmbeddingsAdapter(EmbeddingsAdapter):
    model_id: str
    model: MultiModalEmbeddingModel
    storage: FileStorage | None

    @classmethod
    async def create(
        cls, storage: FileStorage | None, model_id: str
    ) -> "EmbeddingsAdapter":
        model = await get_multi_modal_embedding_model(model_id)
        return cls(model_id=model_id, model=model, storage=storage)

    async def embeddings(
        self, request: EmbeddingsRequest
    ) -> EmbeddingsResponse:
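        """Serve the DIAL embeddings request by issuing one model call
        per input and merging the results into a single response."""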
        validate_request(request)

        base64_encode = request.encoding_format == "base64"

        # NOTE: The model doesn't support batched inputs,
        # so every input is sent in a separate request.
        tasks: List[Callable[[], Tuple[Embedding, int]]] = []
        async for sub_request in await get_requests(self.storage, request):
            # Bind the loop variable via a default argument
            # to avoid the late-binding closure pitfall.
            tasks.append(
                lambda sub_req=sub_request: compute_embeddings(
                    sub_req,
                    self.model,
                    base64_encode=base64_encode,
                    dimensions=request.dimensions,
                )
            )

        embeddings: List[Embedding] = []
        total_tokens = 0
        for embedding, tokens in await gather_sync(tasks):
            embeddings.append(embedding)
            total_tokens += tokens

        return make_embeddings_response(
            model=self.model_id,
            embeddings=embeddings,
            tokens=total_tokens,
        )
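

# A minimal usage sketch (illustrative only, not part of the adapter):
# it assumes DIAL's EmbeddingsRequest accepts a plain-string `input` and
# that "multimodalembedding@001" names the deployed multimodal embedding
# model; the `example()` helper below is hypothetical.
#
#     async def example() -> EmbeddingsResponse:
#         adapter = await MultiModalEmbeddingsAdapter.create(
#             storage=None, model_id="multimodalembedding@001"
#         )
#         request = EmbeddingsRequest(input="A photo of a cat")
#         return await adapter.embeddings(request)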