aidial_adapter_dial/transformer.py (175 lines of code) (raw):
import logging
import re
from typing import Any, Callable, Coroutine, Self
import aiohttp
from pydantic import BaseModel
from aidial_adapter_dial.utils.app_data import AppData
from aidial_adapter_dial.utils.storage import FileStorage
# Module-level logger, named after this module so its records can be filtered per module.
log = logging.getLogger(__name__)
class AttachmentTransformer(BaseModel):
    """Translate attachment URLs between a local and a remote file storage.

    Request attachments are rewritten from local to remote URLs, response
    attachments from remote to local URLs; in both directions the file behind
    the ``url`` field is copied to the other storage.
    """

    # Storage on the local side and the user bucket parsed from its appdata.
    local_storage: FileStorage
    local_user_bucket: str
    # Raw "appdata" value reported by the local storage bucket endpoint.
    local_appdata: str

    # Storage on the remote side and the user bucket resolved for it.
    remote_storage: FileStorage
    remote_user_bucket: str

    # True when both storages share the same DIAL URL and API key.
    proxy_mode: bool

    @classmethod
    async def create(
        cls, remote_storage: FileStorage, local_storage: FileStorage
    ) -> Self:
        """Build a transformer by querying the buckets of both storages.

        The local bucket must report an ``appdata`` entry; the local user
        bucket is parsed out of it. For the remote storage the user bucket is
        parsed from ``appdata`` when present, otherwise the plain ``bucket``
        entry is used.

        Raises:
            ValueError: if the local bucket has no ``appdata`` entry.
        """
        async with aiohttp.ClientSession() as session:
            local_info = await local_storage.get_bucket(session)
            local_appdata = local_info.get("appdata")
            if local_appdata is None:
                raise ValueError(
                    "The local appdata bucket is expected to be set"
                )
            local_user_bucket = AppData.parse(local_appdata).user_bucket

            remote_info = await remote_storage.get_bucket(session)
            remote_appdata = remote_info.get("appdata")
            remote_user_bucket = (
                remote_info["bucket"]
                if remote_appdata is None
                else AppData.parse(remote_appdata).user_bucket
            )

            # Same endpoint and same credentials: both sides address one storage.
            proxy_mode = (
                remote_storage.dial_url == local_storage.dial_url
                and remote_storage.api_key == local_storage.api_key
            )
            log.debug(f"proxy_mode: {proxy_mode}")

            return cls(
                remote_storage=remote_storage,
                remote_user_bucket=remote_user_bucket,
                local_storage=local_storage,
                local_user_bucket=local_user_bucket,
                local_appdata=local_appdata,
                proxy_mode=proxy_mode,
            )

    def get_remote_url(self, local_url: str) -> str:
        """Map a local file URL to the URL it gets on the remote storage.

        proxy mode:
            the local URL is returned unchanged

        otherwise:
            < files/LOCAL_USER_BUCKET/PATH
            > files/REMOTE_USER_BUCKET/LOCAL_USER_BUCKET/PATH

        Raises:
            ValueError: outside proxy mode, if the URL is not located in the
                local user bucket.
        """
        if self.proxy_mode:
            return local_url

        local_prefix = f"files/{self.local_user_bucket}/"
        if local_url.startswith(local_prefix):
            # Keep the local bucket as the first path segment under the remote bucket.
            bucket_and_path = local_url.removeprefix("files/")
            return f"files/{self.remote_user_bucket}/{bucket_and_path}"

        raise ValueError(f"Unexpected local URL: {local_url!r}")

    def get_local_url(self, remote_url: str) -> str:
        """Map a remote file URL to the URL it gets on the local storage.

        files created by the remote side under its appdata folder:
            < files/REMOTE_USER_BUCKET/appdata/REMOTE_APP_NAME/PATH
            > files/LOCAL_APPDATA/PATH

        files uploaded from local to remote earlier (reverse of get_remote_url):
            proxy mode:
                the remote URL is returned unchanged
            otherwise:
                < files/REMOTE_USER_BUCKET/LOCAL_USER_BUCKET/PATH
                > files/LOCAL_USER_BUCKET/PATH

        Any URL outside of the remote user bucket (e.g. one pointing into a
        remote application bucket) is rejected right away.

        Raises:
            ValueError: if the URL is outside the remote user bucket, is a
                malformed appdata path, or (outside proxy mode) is neither an
                appdata path nor a local-user-bucket subpath.
        """
        remote_prefix = f"files/{self.remote_user_bucket}/"
        if not remote_url.startswith(remote_prefix):
            raise ValueError(
                f"The remote file ({remote_url!r}) is expected "
                f"to be uploaded to the remote user bucket ({self.remote_user_bucket!r})"
            )
        remote_path = remote_url.removeprefix(remote_prefix)

        # The appdata check comes first and applies in both modes.
        if remote_path.startswith("appdata/"):
            appdata_match = re.match(r"appdata/([^/]+)/(.+)", remote_path)
            if appdata_match is None:
                raise ValueError(f"Invalid remote appdata path: {remote_url!r}")
            # Group 1 is the remote application name; only the trailing path is reused.
            path = appdata_match.group(2)
            return f"files/{self.local_appdata}/{path}"

        if self.proxy_mode:
            return remote_url

        local_bucket_prefix = f"{self.local_user_bucket}/"
        if remote_path.startswith(local_bucket_prefix):
            path = remote_path.removeprefix(local_bucket_prefix)
            return f"files/{self.local_user_bucket}/{path}"

        raise ValueError(
            f"The remote file ({remote_url!r}) is expected to be uploaded either "
            "to remote appdata path or "
            "to a local user bucket subpath of remote user bucket."
        )

    async def modify_request_attachment(self, attachment: dict) -> None:
        """Point a request attachment at the remote storage, in place.

        ``reference_url`` is only rewritten; the file behind ``url`` is also
        copied from the local to the remote storage. Fields that are missing,
        empty, or not recognised by ``local_storage.to_dial_url`` are left as
        they are.
        """
        ref_url = attachment.get("reference_url")
        if ref_url:
            local_ref_url = self.local_storage.to_dial_url(ref_url)
            if local_ref_url:
                attachment["reference_url"] = self.get_remote_url(local_ref_url)

        url = attachment.get("url")
        if not url:
            return
        local_url = self.local_storage.to_dial_url(url)
        if not local_url:
            return

        remote_url = self.get_remote_url(local_url)
        await download_and_upload_file(
            self.local_storage,
            local_url,
            self.remote_storage,
            remote_url,
            attachment.get("type"),
        )
        attachment["url"] = remote_url
        log.debug(
            f"uploaded from local to remote: from {local_url!r} to {remote_url!r}"
        )

    async def modify_response_attachment(self, attachment: dict) -> None:
        """Point a response attachment at the local storage, in place.

        ``reference_url`` is only rewritten; the file behind ``url`` is also
        copied from the remote to the local storage. Fields that are missing,
        empty, or not recognised by ``remote_storage.to_dial_url`` are left as
        they are.
        """
        ref_url = attachment.get("reference_url")
        if ref_url:
            remote_ref_url = self.remote_storage.to_dial_url(ref_url)
            if remote_ref_url:
                attachment["reference_url"] = self.get_local_url(remote_ref_url)

        url = attachment.get("url")
        if not url:
            return
        remote_url = self.remote_storage.to_dial_url(url)
        if not remote_url:
            return

        local_url = self.get_local_url(remote_url)
        await download_and_upload_file(
            self.remote_storage,
            remote_url,
            self.local_storage,
            local_url,
            attachment.get("type"),
        )
        attachment["url"] = local_url
        log.debug(
            f"uploaded from remote to local: from {remote_url!r} to {local_url!r}"
        )

    async def modify_request(self, request: dict) -> dict:
        """Rewrite the attachments of every message in ``request`` and return it.

        The request dict is modified in place; a request without a
        ``messages`` key is returned untouched.
        """
        for message in request.get("messages", ()):
            await modify_message(message, self.modify_request_attachment)
        return request

    async def modify_response_chunk(self, response: dict) -> dict:
        """Rewrite the attachments found in the ``delta`` of each choice.

        The response dict is modified in place and returned.
        """
        choices = response.get("choices")
        if choices is not None:
            for choice in choices:
                if "delta" not in choice:
                    continue
                await modify_message(
                    choice["delta"], self.modify_response_attachment
                )
        return response

    async def modify_response(self, response: dict) -> dict:
        """Rewrite the attachments found in the ``message`` of each choice.

        The response dict is modified in place and returned.
        """
        choices = response.get("choices")
        if choices is not None:
            for choice in choices:
                if "message" not in choice:
                    continue
                await modify_message(
                    choice["message"], self.modify_response_attachment
                )
        return response
async def download_and_upload_file(
    src_storage: FileStorage,
    src_url: str,
    dest_storage: FileStorage,
    dest_url: str,
    content_type: str | None,
):
    """Copy one file from ``src_storage`` to ``dest_storage``.

    Nothing is transferred when the source and destination URLs are equal.

    Raises:
        ValueError: if ``src_url`` denotes a directory (ends with ``/``).
    """
    log.debug(f"downloading from {src_url!r} and uploading to {dest_url!r}")

    # Identical URLs mean there is nothing to copy.
    if src_url == dest_url:
        return

    if _is_directory(src_url):
        raise ValueError("Directories aren't yet supported")

    # One HTTP session serves both the download and the upload.
    async with aiohttp.ClientSession() as session:
        payload = await src_storage.download(src_url, session)
        await dest_storage.upload(dest_url, content_type, payload, session)
async def modify_message(
    message: dict,
    modify_attachment: Callable[[dict], Coroutine[Any, Any, None]],
) -> None:
    """Await ``modify_attachment`` on each attachment of ``message``, in order.

    Attachments are read from ``message["custom_content"]["attachments"]``.
    A message where either level is missing or ``None`` is left untouched.
    """
    custom_content = message.get("custom_content")
    if custom_content is not None:
        attachment_list = custom_content.get("attachments")
        if attachment_list is not None:
            for item in attachment_list:
                await modify_attachment(item)
def _is_directory(url: str) -> bool:
    """Return True when ``url`` ends with a slash, i.e. denotes a directory.

    ``str.endswith`` is used rather than indexing the last character so that
    an empty URL yields False instead of raising ``IndexError``.
    """
    return url.endswith("/")