aidial_adapter_bedrock/server/exceptions.py (112 lines of code) (raw):
"""
The kinds of service exceptions which bedrock invocation may throw:
https://github.com/boto/botocore/blob/1.31.57/botocore/data/bedrock-runtime/2023-09-30/service-2.json#L46-L57
The service exceptions have the following inheritance hierarchy:
- ValidationException (botocore.errorfactory)
- ClientError (botocore.exceptions)
- Exception (builtins)
- BaseException (builtins)
- object (builtins)
The recommended way to discriminate service exceptions is
to access the `response` field of a ClientError instance:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/error-handling.html#parsing-error-responses-and-catching-exceptions-from-aws-services
"""
import json
from enum import Enum
from functools import wraps
from typing import assert_never
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import (
InternalServerError,
InvalidRequestError,
ResourceNotFoundError,
)
from anthropic import APIStatusError
from botocore.exceptions import ClientError
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.utils.log_config import app_logger as log
def create_error(status_code: int, message: str) -> DialException:
    """
    Build a DIAL exception carrying `message` for the given HTTP status code.

    Status codes below 500 yield an `InvalidRequestError`;
    everything else yields an `InternalServerError`.
    """
    if status_code < 500:
        return InvalidRequestError(message)
    return InternalServerError(message)
class BedrockExceptionCode(Enum):
"""
See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelWithResponseStream.html#API_runtime_InvokeModelWithResponseStream_ResponseSyntax
for the types of exceptions
"""
INTERNAL_SERVER = "internalServerException"
MODEL_STREAM_ERROR = "modelStreamErrorException"
MODEL_TIMEOUT = "modelTimeoutException"
SERVER_UNAVAILABLE = "serviceUnavailableException"
THROTTLING = "throttlingException"
VALIDATION = "validationException"
def __eq__(self, other):
if isinstance(other, str):
return self.value.lower() == other.lower()
return NotImplemented
def get_status_code(self) -> int:
match self:
case BedrockExceptionCode.INTERNAL_SERVER:
return 500
case BedrockExceptionCode.MODEL_STREAM_ERROR:
return 424
case BedrockExceptionCode.MODEL_TIMEOUT:
return 408
case BedrockExceptionCode.SERVER_UNAVAILABLE:
return 503
case BedrockExceptionCode.THROTTLING:
return 429
case BedrockExceptionCode.VALIDATION:
return 400
case _:
assert_never(self)
def _get_meta_status_code(response: dict) -> int | None:
code = response.get("ResponseMetadata", {}).get("HTTPStatusCode")
if isinstance(code, int):
return code
return None
def _get_response_error_code(response: dict) -> int | None:
    """
    Translate the `Error.Code` entry of a response into an HTTP status code.

    Returns None when the code cannot be turned into a
    `BedrockExceptionCode` or the conversion fails for any other reason.
    """
    error_section = response.get("Error", {})
    raw_code = error_section.get("Code")
    try:
        bedrock_code = BedrockExceptionCode(raw_code)
        return bedrock_code.get_status_code()
    except Exception:
        return None
def _get_content_filter_error(response: dict) -> DialException | None:
    """
    Detect a content filter rejection in a response.

    Returns an `InvalidRequestError` with the `content_filter` code when
    the `message` entry of the response mentions filtered words,
    otherwise None.
    """
    message = response.get("message")
    if not message:
        return None
    if "One or more prompts contains filtered words" not in message:
        return None
    return InvalidRequestError(message=message, code="content_filter")
def to_dial_exception(e: Exception) -> DialException:
    """
    Convert an arbitrary exception into a DIAL exception.

    The conversion rules, in order of precedence:
    - `ClientError` with a dictionary `response`: a content filter error
      if the response reports one, otherwise an error created from the
      status code taken from the error code, then from the response
      metadata, then defaulting to 500;
    - `APIStatusError`: an error created from its status code and message;
    - `ValidationError` and `UserError`: their own `to_dial_exception()`;
    - `DialException`: returned as is;
    - anything else: `InternalServerError` with the exception text.
    """
    if (
        isinstance(e, ClientError)
        and hasattr(e, "response")
        and isinstance(e.response, dict)
    ):
        response = e.response
        # `default=str` keeps this debug dump from raising a TypeError on
        # values JSON cannot encode, which would otherwise replace the
        # exception being converted with an unrelated one.
        log.debug(
            f"botocore.exceptions.ClientError.response: {json.dumps(response, default=str)}"
        )

        if error := _get_content_filter_error(response):
            return error

        status_code = (
            _get_response_error_code(response)
            or _get_meta_status_code(response)
            or 500
        )
        return create_error(status_code, str(e))

    if isinstance(e, APIStatusError):
        return create_error(e.status_code, e.message)

    if isinstance(e, (ValidationError, UserError)):
        return e.to_dial_exception()

    if isinstance(e, DialException):
        return e

    return InternalServerError(str(e))
def dial_exception_decorator(func):
    """
    Wrap an async callable so that every exception it raises is logged
    and re-raised as a DIAL exception chained to the original one.
    """

    @wraps(func)
    async def converting_wrapper(*args, **kwargs):
        try:
            result = await func(*args, **kwargs)
        except Exception as error:
            dial_exception = to_dial_exception(error)
            log.exception(
                f"Caught exception: {type(error).__module__}.{type(error).__name__}. "
                f"The exception converted to the dial exception: {dial_exception!r}."
            )
            raise dial_exception from error
        return result

    return converting_wrapper
def not_implemented_handler(func):
    """
    Wrap an async callable so that a `NotImplementedError` raised by it
    is reported as a `ResourceNotFoundError`.
    """

    @wraps(func)
    async def guarded(*args, **kwargs):
        try:
            result = await func(*args, **kwargs)
        except NotImplementedError:
            raise ResourceNotFoundError("The endpoint is not implemented")
        return result

    return guarded