aidial_adapter_dial/utils/storage.py (136 lines of code) (raw):

import io
import logging
import mimetypes
from typing import Mapping, Optional, TypedDict
from urllib.parse import urljoin

import aiohttp
from pydantic import BaseModel

log = logging.getLogger(__name__)


class FileMetadata(TypedDict):
    """JSON document returned by the file storage after an upload."""

    name: str
    parentPath: str
    bucket: str
    url: str


class Bucket(TypedDict):
    """JSON document returned by the ``/v1/bucket`` endpoint."""

    bucket: str
    appdata: str | None  # Users do not have appdata


class AccessDeniedError(Exception):
    """Raised when the metadata endpoint answers with HTTP 403."""

    def __init__(self, url: str):
        super().__init__(f"Access denied: {url!r}")
        # Kept so that handlers can tell which file was refused.
        self.url = url


class FileStorage(BaseModel):
    """Client for the file endpoints exposed under ``{dial_url}/v1/``.

    Every request is authenticated with the ``api-key`` header. The HTTP
    session is always supplied by the caller; this class never owns one.
    """

    dial_url: str
    api_key: str
    # Filled in lazily by ``get_bucket`` and reused afterwards.
    bucket: Optional[Bucket] = None

    @property
    def headers(self) -> Mapping[str, str]:
        """Headers attached to every request issued by this client."""
        return {"api-key": self.api_key}

    async def get_bucket(self, session: aiohttp.ClientSession) -> Bucket:
        """Return the bucket of the current API key.

        The bucket is fetched from ``/v1/bucket`` on the first call and
        cached on the instance for subsequent calls.

        Raises:
            aiohttp.ClientResponseError: the endpoint returned an error status.
        """
        if self.bucket is not None:
            return self.bucket

        log.debug(f"retrieving bucket for {self.dial_url!r}")
        async with session.get(
            f"{self.dial_url}/v1/bucket",
            headers=self.headers,
        ) as response:
            response.raise_for_status()
            bucket = await response.json()
            self.bucket = bucket

        log.debug(f"bucket: {self.bucket}")
        return bucket

    @staticmethod
    def _to_form_data(
        filename: str, content_type: str | None, content: bytes
    ) -> aiohttp.FormData:
        """Wrap ``content`` into a multipart form with a single ``file`` field."""
        data = aiohttp.FormData()
        data.add_field(
            "file",
            io.BytesIO(content),
            filename=filename,
            content_type=content_type,
        )
        return data

    async def is_accessible(
        self, url: str, session: aiohttp.ClientSession
    ) -> bool:
        """Tell whether the current API key may read the file at ``url``.

        Returns ``False`` only when the metadata request is answered with
        HTTP 403; any other failure propagates to the caller.
        """
        try:
            await self._get_metadata(url, session)
        except AccessDeniedError:
            log.debug(f"file isn't accessible: url={url!r}")
            return False

        log.debug(f"file is accessible: url={url!r}")
        return True

    def to_metadata_url(self, url: str) -> str:
        """
        The file metadata URL given url like:
            files/BUCKET/foo/baz/document.pdf
        """
        metadata_url = f"{self.dial_url}/v1/metadata/"
        return urljoin(metadata_url, url, allow_fragments=True)

    async def _get_metadata(
        self,
        url: str,
        session: aiohttp.ClientSession,
    ) -> dict:
        """Fetch the metadata document of the file at ``url``.

        Raises:
            AccessDeniedError: the endpoint answered with HTTP 403.
            aiohttp.ClientResponseError: any other error status.
        """
        metadata_url = self.to_metadata_url(url)
        log.debug(f"retrieving metadata: file={url!r}, url={metadata_url!r}")

        async with session.get(metadata_url, headers=self.headers) as response:
            # 403 gets a dedicated exception so that ``is_accessible`` can
            # tell "forbidden" apart from every other failure.
            if response.status == 403:
                raise AccessDeniedError(url)
            response.raise_for_status()

            metadata = await response.json()

        log.debug(
            f"retrieved metadata: file={url!r}, url={metadata_url!r}, metadata={metadata}"
        )
        return metadata

    async def upload_file(
        self,
        filename: str,
        content_type: str | None,
        content: bytes,
        session: aiohttp.ClientSession,
    ) -> FileMetadata:
        """Upload ``content`` into the appdata location of the current bucket.

        An extension guessed from ``content_type`` is appended to
        ``filename``; nothing is appended when the type is missing or
        unknown to ``mimetypes``.

        Raises:
            ValueError: the bucket has no appdata location, so there is no
                valid destination for the file.
            aiohttp.ClientResponseError: the storage returned an error status.
        """
        bucket = await self.get_bucket(session)

        appdata = bucket.get("appdata")
        if appdata is None:
            # Without this guard the literal text "None" would end up in the
            # upload path.
            raise ValueError(
                f"No appdata is available for the bucket at {self.dial_url!r}: "
                f"unable to upload {filename!r}"
            )

        ext = (content_type and mimetypes.guess_extension(content_type)) or ""
        url = f"{self.dial_url}/v1/files/{appdata}/{filename}{ext}"

        return await self.upload(url, content_type, content, session)

    async def upload(
        self,
        url: str,
        content_type: str | None,
        content: bytes,
        session: aiohttp.ClientSession,
    ) -> FileMetadata:
        """PUT ``content`` to ``url`` and return the metadata sent back.

        ``url`` may be absolute or relative to ``{dial_url}/v1/``.

        Raises:
            ValueError: ``url`` does not resolve under ``{dial_url}/v1/``.
            aiohttp.ClientResponseError: the storage returned an error status.
        """
        log.debug(f"uploading file {url!r}")

        if self.to_dial_url(url) is None:
            raise ValueError(f"URL isn't DIAL url: {url!r}")
        url = self.to_abs_url(url)

        # NOTE(review): the absolute URL is what gets passed as the multipart
        # filename here — confirm that this is what the server expects.
        data = FileStorage._to_form_data(url, content_type, content)

        async with session.put(
            url=url,
            data=data,
            headers=self.headers,
        ) as response:
            response.raise_for_status()
            meta = await response.json()

        log.debug(f"uploaded file: url={url!r}, metadata={meta}")
        return meta

    def to_dial_url(self, link: str) -> str | None:
        """Return ``link`` relative to ``{dial_url}/v1/``.

        Returns ``None`` when the resolved link does not start with that
        prefix.
        """
        url = self.to_abs_url(link)
        base_url = f"{self.dial_url}/v1/"
        if url.startswith(base_url):
            return url.removeprefix(base_url)
        return None

    def to_abs_url(self, link: str) -> str:
        """Resolve ``link`` against ``{dial_url}/v1/`` using ``urljoin``."""
        base_url = f"{self.dial_url}/v1/"
        return urljoin(base_url, link)

    async def download(self, url: str, session: aiohttp.ClientSession) -> bytes:
        """Download the file at ``url`` and return its raw bytes.

        ``url`` may be absolute or relative to ``{dial_url}/v1/``. The
        request is sent through the caller's ``session``.

        Raises:
            ValueError: ``url`` does not resolve under ``{dial_url}/v1/``.
            aiohttp.ClientResponseError: the storage returned an error status.
        """
        log.debug(f"downloading file {url!r}")

        if self.to_dial_url(url) is None:
            raise ValueError(f"URL isn't DIAL url: {url!r}")
        url = self.to_abs_url(url)

        # Reuse the session handed in by the caller rather than opening a
        # throwaway one for a single request.
        async with session.get(url, headers=self.headers) as response:
            response.raise_for_status()
            return await response.read()