aidial_adapter_bedrock/llm/converse/input.py (353 lines of code) (raw):
import json
import re
import uuid
from dataclasses import dataclass
from typing import List, Set, Tuple, assert_never
from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall
from aidial_sdk.chat_completion import Message as DialMessage
from aidial_sdk.chat_completion import (
MessageContentImagePart,
MessageContentTextPart,
)
from aidial_sdk.chat_completion import Role as DialRole
from aidial_sdk.chat_completion import ToolCall as DialToolCall
from aidial_sdk.exceptions import RuntimeServerError
from aidial_adapter_bedrock.dial_api.request import ToolsConfig
from aidial_adapter_bedrock.dial_api.resource import (
AttachmentResource,
UnsupportedContentType,
URLResource,
)
from aidial_adapter_bedrock.dial_api.storage import FileStorage
from aidial_adapter_bedrock.llm.converse.constants import (
CONVERSE_DOCUMENT_TYPE_TO_MIME,
CONVERSE_IMAGE_TYPE_TO_MIME,
DOCUMENT_MIME_TO_CONVERSE_TYPE,
IMAGE_MIME_TO_CONVERSE_TYPE,
)
from aidial_adapter_bedrock.llm.converse.types import (
ConverseContentPart,
ConverseDocumentPart,
ConverseDocumentPartConfig,
ConverseDocumentType,
ConverseImagePart,
ConverseImagePartConfig,
ConverseImageType,
ConverseMessage,
ConverseRole,
ConverseTextPart,
ConverseToolResultPart,
ConverseTools,
ConverseToolSpec,
ConverseToolUsePart,
)
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.utils.list import group_by
from aidial_adapter_bedrock.utils.list_projection import ListProjection
from aidial_adapter_bedrock.utils.resource import Resource
def to_converse_role(role: DialRole) -> ConverseRole:
"""
Converse API accepts only 'user' and 'assistant' roles
"""
match role:
case DialRole.USER | DialRole.TOOL | DialRole.FUNCTION:
return ConverseRole.USER
case DialRole.ASSISTANT:
return ConverseRole.ASSISTANT
case DialRole.SYSTEM:
raise ValidationError("System messages are not allowed")
case _:
assert_never(role)
def to_converse_tools(tools_config: ToolsConfig) -> ConverseTools:
    """
    Build the Converse tool configuration from the DIAL tools config.

    Every function in ``tools_config.functions`` becomes one ``toolSpec``.
    A missing description is sent as an empty string, and missing
    parameters are sent as an object schema without properties.

    The tool choice is ``any`` when ``tools_config.required`` is truthy and
    ``auto`` otherwise.
    """
    specs: list[ConverseToolSpec] = [
        {
            "toolSpec": {
                "name": fn.name,
                "description": fn.description or "",
                "inputSchema": {
                    "json": fn.parameters
                    or {"type": "object", "properties": {}}
                },
            }
        }
        for fn in tools_config.functions
    ]

    if tools_config.required:
        return {"tools": specs, "toolChoice": {"any": {}}}
    return {"tools": specs, "toolChoice": {"auto": {}}}
def function_call_to_content_part(
    dial_call: DialFunctionCall,
) -> ConverseToolUsePart:
    """
    Convert a DIAL function call into a Converse ``toolUse`` part.

    The function name is used both as the tool name and as the tool use id.

    Raises:
        json.JSONDecodeError: if the call arguments are not valid JSON.
    """
    function_name = dial_call.name
    parsed_arguments = json.loads(dial_call.arguments)
    return {
        "toolUse": {
            "toolUseId": function_name,
            "name": function_name,
            "input": parsed_arguments,
        }
    }
def tool_call_to_content_part(
    dial_call: DialToolCall,
) -> ConverseToolUsePart:
    """
    Convert a DIAL tool call into a Converse ``toolUse`` part.

    The tool call id is forwarded as the tool use id.

    Raises:
        json.JSONDecodeError: if the call arguments are not valid JSON.
    """
    called_function = dial_call.function
    tool_use = {
        "toolUseId": dial_call.id,
        "name": called_function.name,
        "input": json.loads(called_function.arguments),
    }
    return {"toolUse": tool_use}
def function_result_to_content_part(
    message: DialMessage,
) -> ConverseToolResultPart:
    """
    Convert a DIAL function-result message into a Converse ``toolResult``.

    The function name serves as the tool use id and the message content is
    passed on as a single text block with ``success`` status.

    Raises:
        RuntimeServerError: if the message does not have the function role,
            has no name, or its content is not a plain string.
    """
    if message.role != DialRole.FUNCTION:
        raise RuntimeServerError(
            "Function result message is expected to have function role"
        )

    tool_use_id = message.name
    text = message.content
    if not tool_use_id or not isinstance(text, str):
        raise RuntimeServerError(
            "Function result message is expected to have function name and plain text content"
        )

    return {
        "toolResult": {
            "toolUseId": tool_use_id,
            "content": [{"text": text}],
            "status": "success",
        }
    }
def tool_result_to_content_part(
    message: DialMessage,
) -> ConverseToolResultPart:
    """
    Convert a DIAL tool-result message into a Converse ``toolResult`` part.

    The content is sent as a ``json`` block only when it parses to a JSON
    object. Anything else (plain text, or JSON scalars/arrays such as
    ``"42"`` or ``"[1, 2]"``) is sent verbatim as a ``text`` block.

    Raises:
        RuntimeServerError: if the message does not have the tool role,
            has no tool call id, or its content is not a plain string.
    """
    if message.role != DialRole.TOOL:
        raise RuntimeServerError(
            "Tool result message is expected to have tool role"
        )
    if not message.tool_call_id or not isinstance(message.content, str):
        raise RuntimeServerError(
            "Tool result message is expected to have tool call id and plain text content"
        )

    # Keep the try body minimal: only the parse itself can fail here.
    try:
        parsed = json.loads(message.content)
    except json.JSONDecodeError:
        parsed = None

    # NOTE(review): Converse appears to accept only JSON objects in the
    # `json` field of a tool result - confirm against the Bedrock API
    # reference. Text is accepted for any content, so non-object values
    # fall back to the original string.
    if isinstance(parsed, dict):
        return {
            "toolResult": {
                "toolUseId": message.tool_call_id,
                "content": [{"json": parsed}],
                "status": "success",
            }
        }

    return {
        "toolResult": {
            "toolUseId": message.tool_call_id,
            "content": [{"text": message.content}],
            "status": "success",
        }
    }
_WHITESPACE_RUN = re.compile(r"\s+")
_DISALLOWED_NAME_CHAR = re.compile(r"[^a-zA-Z0-9\-\(\)\[\] _]")


def sanitize_document_name(name: str) -> str:
    """
    Normalize a document name for use in a Converse document part.

    Target constraints for the name:
    - Be between 1-200 characters long
    - Only contain alphanumeric chars, spaces, hyphens, parentheses, and square brackets
    - Not have consecutive spaces

    Steps, in order:
    1. every run of whitespace is collapsed into a single space;
    2. every character other than ASCII letters, digits, hyphens,
       parentheses, square brackets, spaces and underscores is replaced
       with an underscore;
    3. the result is cut to at most 200 characters.

    An empty input yields an empty string; callers are expected to pass a
    non-empty name.
    """
    single_spaced = _WHITESPACE_RUN.sub(" ", name)
    cleaned = _DISALLOWED_NAME_CHAR.sub("_", single_spaced)
    return cleaned[:200]
def to_converse_multi_modal_part(
    resource: Resource,
    name: str | None = None,
) -> ConverseImagePart | ConverseDocumentPart:
    """
    Wrap a downloaded resource into a Converse image or document part.

    The resource MIME type is looked up among the image types first and
    among the document types second. Documents get a sanitized name; when
    no name is given, a random UUID is used instead.

    Raises:
        UnsupportedContentType: if the MIME type matches neither an image
            nor a document type.
    """
    image_format = IMAGE_MIME_TO_CONVERSE_TYPE.get(resource.type)
    if image_format:
        image_config = ConverseImagePartConfig(
            format=image_format,
            source={"bytes": resource.data},
        )
        return ConverseImagePart(image=image_config)

    document_format = DOCUMENT_MIME_TO_CONVERSE_TYPE.get(resource.type)
    if document_format:
        document_config = ConverseDocumentPartConfig(
            format=document_format,
            name=sanitize_document_name(name or str(uuid.uuid4())),
            source={"bytes": resource.data},
        )
        return ConverseDocumentPart(document=document_config)

    raise UnsupportedContentType(
        message="Unknown multi-modal type",
        type=resource.type,
        supported_types=[],
    )
async def _get_converse_message_content(
message: DialMessage,
storage: FileStorage | None,
supported_image_types: list[ConverseImageType],
supported_document_types: list[ConverseDocumentType],
) -> List[ConverseContentPart]:
image_mime_types = [
CONVERSE_IMAGE_TYPE_TO_MIME[t] for t in supported_image_types
]
document_mime_types = [
CONVERSE_DOCUMENT_TYPE_TO_MIME[t] for t in supported_document_types
]
def _unsupported_multi_modal_error(t: str) -> str:
message = f"Unsupported attachment type: {t}\n"
if not supported_image_types and not supported_document_types:
return message + "Model does not support multi-modal"
if supported_image_types:
message += f"Supported image types: {', '.join([t.value for t in supported_image_types])}\n"
else:
message += "Images are not supported\n"
if supported_document_types:
message += f"Supported document types: {', '.join([t.value for t in supported_document_types])}"
else:
message += "Documents are not supported"
return message
if message.role == DialRole.FUNCTION:
return [function_result_to_content_part(message)]
elif message.role == DialRole.TOOL:
return [tool_result_to_content_part(message)]
content = []
match message.content:
case str():
content.append({"text": message.content})
case list():
for part in message.content:
match part:
case MessageContentTextPart():
content.append({"text": part.text})
case MessageContentImagePart():
try:
resource = await URLResource(
url=part.image_url.url,
supported_types=image_mime_types,
).download(storage)
content.append(
to_converse_multi_modal_part(resource)
)
except UnsupportedContentType as e:
raise UserError(
error_message=_unsupported_multi_modal_error(
e.type
)
)
case None:
pass
case _:
assert_never(message.content)
if message.custom_content and message.custom_content.attachments:
for attachment in message.custom_content.attachments:
try:
resource = await AttachmentResource(
attachment=attachment,
supported_types=image_mime_types + document_mime_types,
).download(storage)
content.append(
to_converse_multi_modal_part(
resource,
name=attachment.title,
)
)
except UnsupportedContentType as e:
raise UserError(
error_message=_unsupported_multi_modal_error(e.type),
)
if message.function_call and message.tool_calls:
raise ValidationError(
"You cannot use both function call and tool calls in the same message"
)
elif message.function_call:
content.append(function_call_to_content_part(message.function_call))
elif message.tool_calls:
content.extend(
[
tool_call_to_content_part(tool_call)
for tool_call in message.tool_calls
]
)
return content
async def to_converse_message(
    message: DialMessage,
    storage: FileStorage | None = None,
    supported_image_types: list[ConverseImageType] | None = None,
    supported_document_types: list[ConverseDocumentType] | None = None,
) -> ConverseMessage:
    """
    Convert one DIAL message into a Converse message.

    Omitted supported type lists are treated as empty lists.
    """
    # The role is resolved first, so an unsupported role fails before any
    # content is processed.
    converse_role = to_converse_role(message.role)
    converse_content = await _get_converse_message_content(
        message,
        storage,
        supported_image_types or [],
        supported_document_types or [],
    )
    return {"role": converse_role, "content": converse_content}
@dataclass
class ExtractSystemPromptResult:
    """Result of separating system messages from the rest of a chat."""

    # Combined text of the system messages; None when that text is empty.
    system_prompt: ConverseTextPart | None
    # Number of system messages found in the input.
    system_message_count: int
    # Messages with any other role, in their original order.
    non_system_messages: List[DialMessage]
def extract_converse_system_prompt(
    messages: List[DialMessage],
) -> ExtractSystemPromptResult:
    """
    Separate the leading system messages from the remaining messages.

    Text of the system messages (plain string content or text parts) is
    collected, empty pieces are dropped and the rest is joined with a blank
    line in between. The system prompt is None when nothing is left.

    Raises:
        ValidationError: if a system message follows a non-system message
            or if a system message contains an image part.
    """
    prompt_pieces = []
    system_count = 0
    remaining = []

    for message in messages:
        if message.role != DialRole.SYSTEM:
            remaining.append(message)
            continue

        # Any message collected so far means the system block has ended.
        if remaining:
            raise ValidationError(
                "A system message can only follow another system message"
            )

        system_count += 1
        body = message.content
        if isinstance(body, str):
            prompt_pieces.append(body)
        elif isinstance(body, list):
            for part in body:
                if isinstance(part, MessageContentTextPart):
                    prompt_pieces.append(part.text)
                elif isinstance(part, MessageContentImagePart):
                    raise ValidationError(
                        "System messages cannot contain images"
                    )
        elif body is None:
            pass
        else:
            assert_never(body)

    joined = "\n\n".join(filter(None, prompt_pieces))
    prompt = ConverseTextPart(text=joined) if joined else None

    return ExtractSystemPromptResult(
        system_prompt=prompt,
        system_message_count=system_count,
        non_system_messages=remaining,
    )
async def to_converse_messages(
    messages: List[DialMessage],
    storage: FileStorage | None = None,
    supported_image_types: list[ConverseImageType] | None = None,
    supported_document_types: list[ConverseDocumentType] | None = None,
    # Offset for system messages at the beginning
    start_offset: int = 0,
) -> ListProjection[ConverseMessage]:
    """
    Convert DIAL messages into a projection of Converse messages.

    Each message is converted one after another and paired with a set
    holding its index, counted from ``start_offset``. The pairs are then
    handed to ``group_by`` keyed by role, so that messages with the same
    role are merged to achieve an alternation of user-assistant roles.
    A merged entry concatenates the content lists and unites the index
    sets.
    """

    def _merge(
        a: Tuple[ConverseMessage, Set[int]],
        b: Tuple[ConverseMessage, Set[int]],
    ) -> Tuple[ConverseMessage, Set[int]]:
        left_message, left_indices = a
        right_message, right_indices = b
        merged_message: ConverseMessage = {
            "role": left_message["role"],
            "content": left_message["content"] + right_message["content"],
        }
        return merged_message, left_indices | right_indices

    indexed_messages = []
    for index, dial_message in enumerate(messages, start=start_offset):
        converse_message = await to_converse_message(
            dial_message,
            storage,
            supported_image_types,
            supported_document_types,
        )
        indexed_messages.append((converse_message, {index}))

    grouped = group_by(
        lst=indexed_messages,
        key=lambda entry: entry[0]["role"],
        init=lambda entry: entry,
        merge=_merge,
    )
    return ListProjection(grouped)