aidial_adapter_openai/dalle3.py:
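
"""
Adapter for DALL-E 3 style image generation deployments: turns a DIAL chat
completion request into an image generation call and returns the generated
image as a custom-content attachment, optionally uploaded to file storage.
"""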
from typing import Any, AsyncIterator, Optional

import aiohttp
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

from aidial_adapter_openai.dial_api.storage import FileStorage
from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
from aidial_adapter_openai.utils.streaming import build_chunk, generate_id
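
# Fixed usage reported with every image response: one completion token
# stands in for the single generated image.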
IMG_USAGE = {
    "prompt_tokens": 0,
    "completion_tokens": 1,
    "total_tokens": 1,
}


async def generate_image(
    api_url: str, creds: OpenAICreds, user_prompt: str
) -> JSONResponse | Any:
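    """Call the upstream image generation endpoint.

    Returns the parsed payload on success, or a ready-made FastAPI error
    response on failure.
    """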
    async with aiohttp.ClientSession() as session:
        async with session.post(
            api_url,
            json={"prompt": user_prompt, "response_format": "b64_json"},
            headers=get_auth_headers(creds),
        ) as response:
            status_code = response.status
            data = await response.json()

            if status_code == 200:
                return data

            if "error" in data:
                error = data["error"]
                # Normalize upstream content-filtering codes to the
                # OpenAI-style "content_filter" code.
                if error.get("code") in [
                    "content_policy_violation",
                    "contentFilter",
                ]:
                    error["code"] = "content_filter"
                return DIALException(
                    status_code=status_code,
                    message=error.get("message"),
                    type=error.get("type"),
                    param=error.get("param"),
                    code=error.get("code"),
                ).to_fastapi_response()
            else:
                return JSONResponse(content=data, status_code=status_code)


def build_custom_content(base64_image: str, revised_prompt: str) -> Any:
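    """Package the generated image and its revised prompt as DIAL
    custom-content attachments on an empty assistant message."""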
    return {
        "custom_content": {
            "attachments": [
                {"title": "Revised prompt", "data": revised_prompt},
                {"title": "Image", "type": "image/png", "data": base64_image},
            ]
        },
        "content": "",
    }


async def generate_stream(
    id: str, created: int, custom_content: Any
) -> AsyncIterator[dict]:
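    """Emulate streaming: a role chunk, a chunk carrying the attachments,
    then a final "stop" chunk with the usage record."""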
    yield build_chunk(id, None, {"role": "assistant"}, created, True)
    yield build_chunk(id, None, custom_content, created, True)
    yield build_chunk(id, "stop", {}, created, True, usage=IMG_USAGE)


def get_user_prompt(data: Any) -> str:
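    """Extract the prompt from the last message of the request,
    requiring it to be a plain string."""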
    try:
        prompt = data["messages"][-1]["content"]
        if not isinstance(prompt, str):
            raise ValueError("Content isn't a string")
        return prompt
    except Exception as e:
        raise RequestValidationError(
            "Invalid request. Expected a string at path 'messages[-1].content'."
        ) from e


async def move_attachments_data_to_storage(
    custom_content: Any, file_storage: FileStorage
):
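    """Upload inline base64 image attachments to the file storage and
    replace their "data" fields with the resulting URLs."""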
    for attachment in custom_content["custom_content"]["attachments"]:
        if (
            "data" not in attachment
            or "type" not in attachment
            or not attachment["type"].startswith("image/")
        ):
            continue

        file_metadata = await file_storage.upload_file_as_base64(
            attachment["data"], attachment["type"]
        )

        del attachment["data"]
        attachment["url"] = file_metadata["url"]


async def chat_completion(
    data: Any,
    upstream_endpoint: str,
    creds: OpenAICreds,
    is_stream: bool,
    file_storage: Optional[FileStorage],
    api_version: str,
):
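    """Handle a chat completion request by generating a single image and
    returning it either as a stream of chunks or as one final chunk."""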
    if data.get("n", 1) > 1:
        raise RequestValidationError("The deployment doesn't support n > 1")

    api_url = f"{upstream_endpoint}?api-version={api_version}"
    user_prompt = get_user_prompt(data)

    model_response = await generate_image(api_url, creds, user_prompt)
    if isinstance(model_response, JSONResponse):
        return model_response

    base64_image = model_response["data"][0]["b64_json"]
    revised_prompt = model_response["data"][0]["revised_prompt"]

    id = generate_id()
    created = model_response["created"]

    custom_content = build_custom_content(base64_image, revised_prompt)
    if file_storage is not None:
        await move_attachments_data_to_storage(custom_content, file_storage)

    if is_stream:
        return generate_stream(id, created, custom_content)
    else:
        return build_chunk(
            id,
            "stop",
            {"role": "assistant", "content": "", **custom_content},
            created,
            False,
            usage=IMG_USAGE,
        )