aidial_adapter_openai/dial_api/storage.py (135 lines of code) (raw):

import base64 import hashlib import io import mimetypes import os from typing import Mapping, Optional, TypedDict from urllib.parse import unquote, urljoin import aiohttp from pydantic import BaseModel from aidial_adapter_openai.utils.auth import Auth from aidial_adapter_openai.utils.env import get_env, get_env_bool from aidial_adapter_openai.utils.log_config import logger as log CORE_API_VERSION = os.getenv("CORE_API_VERSION") class FileMetadata(TypedDict): name: str parentPath: str bucket: str url: str class Bucket(TypedDict): bucket: str appdata: str | None class FileStorage(BaseModel): dial_url: str upload_dir: str auth: Auth bucket: Optional[Bucket] = None async def _get_bucket(self, session: aiohttp.ClientSession) -> Bucket: if self.bucket is None: async with session.get( f"{self.dial_url}/v1/bucket", headers=self.auth.headers, ) as response: response.raise_for_status() self.bucket = await response.json() log.debug(f"bucket: {self.bucket}") return self.bucket async def _get_user_bucket(self, session: aiohttp.ClientSession) -> str: bucket = await self._get_bucket(session) appdata = bucket.get("appdata") if appdata is None: raise ValueError( "Can't retrieve user bucket because appdata isn't available" ) return appdata.split("/", 1)[0] @staticmethod def _to_form_data( filename: str, content_type: str, content: bytes ) -> aiohttp.FormData: data = aiohttp.FormData() data.add_field( "file", io.BytesIO(content), filename=filename, content_type=content_type, ) return data async def upload( self, filename: str, content_type: str, content: bytes ) -> FileMetadata: async with aiohttp.ClientSession() as session: bucket = await self._get_bucket(session) appdata = bucket["appdata"] ext = mimetypes.guess_extension(content_type) or "" url = f"{self.dial_url}/v1/files/{appdata}/{self.upload_dir}/{filename}{ext}" data = FileStorage._to_form_data(filename, content_type, content) async with session.put( url=url, data=data, headers=self.auth.headers, ) as response: response.raise_for_status() meta = await response.json() log.debug(f"Uploaded file: url={url}, metadata={meta}") return meta async def upload_file_as_base64( self, data: str, content_type: str ) -> FileMetadata: filename = _compute_hash_digest(data) content: bytes = base64.b64decode(data) return await self.upload(filename, content_type, content) def attachment_link_to_url(self, link: str) -> str: if CORE_API_VERSION == "0.6": base_url = f"{self.dial_url}/v1/files/" else: base_url = f"{self.dial_url}/v1/" return urljoin(base_url, link) def _url_to_attachment_link(self, url: str) -> str: if CORE_API_VERSION == "0.6": return url.removeprefix(f"{self.dial_url}/v1/files/") else: return url.removeprefix(f"{self.dial_url}/v1/") async def download_file(self, link: str) -> bytes: url = self.attachment_link_to_url(link) headers: Mapping[str, str] = {} if url.lower().startswith(self.dial_url.lower()): headers = self.auth.headers return await download_file(url, headers) async def get_human_readable_name(self, link: str) -> str: url = self.attachment_link_to_url(link) link = self._url_to_attachment_link(url) link = link.removeprefix("files/") if link.startswith("public/"): bucket = "public" else: async with aiohttp.ClientSession() as session: bucket = await self._get_user_bucket(session) link = link.removeprefix(f"{bucket}/") decoded_link = unquote(link) return link if link == decoded_link else repr(decoded_link) async def download_file(url: str, headers: Mapping[str, str] = {}) -> bytes: async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as response: response.raise_for_status() return await response.read() def _compute_hash_digest(file_content: str) -> str: return hashlib.sha256(file_content.encode()).hexdigest() DIAL_USE_FILE_STORAGE = get_env_bool("DIAL_USE_FILE_STORAGE", False) DIAL_URL: Optional[str] = None if DIAL_USE_FILE_STORAGE: DIAL_URL = get_env( "DIAL_URL", "DIAL_URL must be set to use the DIAL file storage" ) def create_file_storage( base_dir: str, headers: Mapping[str, str] ) -> Optional[FileStorage]: if not DIAL_USE_FILE_STORAGE or DIAL_URL is None: return None auth = Auth.from_headers("api-key", headers) if auth is None: log.debug( "The request doesn't have required headers to use the DIAL file storage. " "Fallback to base64 encoding of images." ) return None return FileStorage(dial_url=DIAL_URL, upload_dir=base_dir, auth=auth)