aidial_adapter_bedrock/llm/converse/input.py (353 lines of code) (raw):

import json import re import uuid from dataclasses import dataclass from typing import List, Set, Tuple, assert_never from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall from aidial_sdk.chat_completion import Message as DialMessage from aidial_sdk.chat_completion import ( MessageContentImagePart, MessageContentTextPart, ) from aidial_sdk.chat_completion import Role as DialRole from aidial_sdk.chat_completion import ToolCall as DialToolCall from aidial_sdk.exceptions import RuntimeServerError from aidial_adapter_bedrock.dial_api.request import ToolsConfig from aidial_adapter_bedrock.dial_api.resource import ( AttachmentResource, UnsupportedContentType, URLResource, ) from aidial_adapter_bedrock.dial_api.storage import FileStorage from aidial_adapter_bedrock.llm.converse.constants import ( CONVERSE_DOCUMENT_TYPE_TO_MIME, CONVERSE_IMAGE_TYPE_TO_MIME, DOCUMENT_MIME_TO_CONVERSE_TYPE, IMAGE_MIME_TO_CONVERSE_TYPE, ) from aidial_adapter_bedrock.llm.converse.types import ( ConverseContentPart, ConverseDocumentPart, ConverseDocumentPartConfig, ConverseDocumentType, ConverseImagePart, ConverseImagePartConfig, ConverseImageType, ConverseMessage, ConverseRole, ConverseTextPart, ConverseToolResultPart, ConverseTools, ConverseToolSpec, ConverseToolUsePart, ) from aidial_adapter_bedrock.llm.errors import UserError, ValidationError from aidial_adapter_bedrock.utils.list import group_by from aidial_adapter_bedrock.utils.list_projection import ListProjection from aidial_adapter_bedrock.utils.resource import Resource def to_converse_role(role: DialRole) -> ConverseRole: """ Converse API accepts only 'user' and 'assistant' roles """ match role: case DialRole.USER | DialRole.TOOL | DialRole.FUNCTION: return ConverseRole.USER case DialRole.ASSISTANT: return ConverseRole.ASSISTANT case DialRole.SYSTEM: raise ValidationError("System messages are not allowed") case _: assert_never(role) def to_converse_tools(tools_config: ToolsConfig) -> ConverseTools: tools: list[ConverseToolSpec] = [] for function in tools_config.functions: tools.append( { "toolSpec": { "name": function.name, "description": function.description or "", "inputSchema": { "json": function.parameters or {"type": "object", "properties": {}} }, } } ) return { "tools": tools, "toolChoice": ({"any": {}} if tools_config.required else {"auto": {}}), } def function_call_to_content_part( dial_call: DialFunctionCall, ) -> ConverseToolUsePart: return { "toolUse": { "toolUseId": dial_call.name, "name": dial_call.name, "input": json.loads(dial_call.arguments), } } def tool_call_to_content_part( dial_call: DialToolCall, ) -> ConverseToolUsePart: return { "toolUse": { "toolUseId": dial_call.id, "name": dial_call.function.name, "input": json.loads(dial_call.function.arguments), } } def function_result_to_content_part( message: DialMessage, ) -> ConverseToolResultPart: if message.role != DialRole.FUNCTION: raise RuntimeServerError( "Function result message is expected to have function role" ) if not message.name or not isinstance(message.content, str): raise RuntimeServerError( "Function result message is expected to have function name and plain text content" ) return { "toolResult": { "toolUseId": message.name, "content": [{"text": message.content}], "status": "success", } } def tool_result_to_content_part( message: DialMessage, ) -> ConverseToolResultPart: if message.role != DialRole.TOOL: raise RuntimeServerError( "Tool result message is expected to have tool role" ) if not message.tool_call_id or not isinstance(message.content, str): raise RuntimeServerError( "Tool result message is expected to have tool call id and plain text content" ) try: json_content = json.loads(message.content) return { "toolResult": { "toolUseId": message.tool_call_id, "content": [{"json": json_content}], "status": "success", } } except json.JSONDecodeError: return { "toolResult": { "toolUseId": message.tool_call_id, "content": [{"text": message.content}], "status": "success", } } def sanitize_document_name(name: str) -> str: """ The name must: - Be between 1-200 characters long - Only contain alphanumeric chars, spaces, hyphens, parentheses, and square brackets - Not have consecutive spaces """ name = re.sub(r"\s+", " ", name) name = re.sub(r"[^a-zA-Z0-9\-\(\)\[\] _]", "_", name) return name[:200] def to_converse_multi_modal_part( resource: Resource, name: str | None = None, ) -> ConverseImagePart | ConverseDocumentPart: if converse_type := IMAGE_MIME_TO_CONVERSE_TYPE.get(resource.type): return ConverseImagePart( image=ConverseImagePartConfig( format=converse_type, source={"bytes": resource.data}, ) ) elif converse_type := DOCUMENT_MIME_TO_CONVERSE_TYPE.get(resource.type): return ConverseDocumentPart( document=ConverseDocumentPartConfig( format=converse_type, name=sanitize_document_name(name or str(uuid.uuid4())), source={"bytes": resource.data}, ) ) else: raise UnsupportedContentType( message="Unknown multi-modal type", type=resource.type, supported_types=[], ) async def _get_converse_message_content( message: DialMessage, storage: FileStorage | None, supported_image_types: list[ConverseImageType], supported_document_types: list[ConverseDocumentType], ) -> List[ConverseContentPart]: image_mime_types = [ CONVERSE_IMAGE_TYPE_TO_MIME[t] for t in supported_image_types ] document_mime_types = [ CONVERSE_DOCUMENT_TYPE_TO_MIME[t] for t in supported_document_types ] def _unsupported_multi_modal_error(t: str) -> str: message = f"Unsupported attachment type: {t}\n" if not supported_image_types and not supported_document_types: return message + "Model does not support multi-modal" if supported_image_types: message += f"Supported image types: {', '.join([t.value for t in supported_image_types])}\n" else: message += "Images are not supported\n" if supported_document_types: message += f"Supported document types: {', '.join([t.value for t in supported_document_types])}" else: message += "Documents are not supported" return message if message.role == DialRole.FUNCTION: return [function_result_to_content_part(message)] elif message.role == DialRole.TOOL: return [tool_result_to_content_part(message)] content = [] match message.content: case str(): content.append({"text": message.content}) case list(): for part in message.content: match part: case MessageContentTextPart(): content.append({"text": part.text}) case MessageContentImagePart(): try: resource = await URLResource( url=part.image_url.url, supported_types=image_mime_types, ).download(storage) content.append( to_converse_multi_modal_part(resource) ) except UnsupportedContentType as e: raise UserError( error_message=_unsupported_multi_modal_error( e.type ) ) case None: pass case _: assert_never(message.content) if message.custom_content and message.custom_content.attachments: for attachment in message.custom_content.attachments: try: resource = await AttachmentResource( attachment=attachment, supported_types=image_mime_types + document_mime_types, ).download(storage) content.append( to_converse_multi_modal_part( resource, name=attachment.title, ) ) except UnsupportedContentType as e: raise UserError( error_message=_unsupported_multi_modal_error(e.type), ) if message.function_call and message.tool_calls: raise ValidationError( "You cannot use both function call and tool calls in the same message" ) elif message.function_call: content.append(function_call_to_content_part(message.function_call)) elif message.tool_calls: content.extend( [ tool_call_to_content_part(tool_call) for tool_call in message.tool_calls ] ) return content async def to_converse_message( message: DialMessage, storage: FileStorage | None = None, supported_image_types: list[ConverseImageType] | None = None, supported_document_types: list[ConverseDocumentType] | None = None, ) -> ConverseMessage: return { "role": to_converse_role(message.role), "content": await _get_converse_message_content( message, storage, supported_image_types or [], supported_document_types or [], ), } @dataclass class ExtractSystemPromptResult: system_prompt: ConverseTextPart | None system_message_count: int non_system_messages: List[DialMessage] def extract_converse_system_prompt( messages: List[DialMessage], ) -> ExtractSystemPromptResult: system_msgs = [] found_non_system = False system_messages_count = 0 non_system_messages = [] for msg in messages: if msg.role == DialRole.SYSTEM: if found_non_system: raise ValidationError( "A system message can only follow another system message" ) system_messages_count += 1 match msg.content: case str(): system_msgs.append(msg.content) case list(): for part in msg.content: match part: case MessageContentTextPart(): system_msgs.append(part.text) case MessageContentImagePart(): raise ValidationError( "System messages cannot contain images" ) case None: pass case _: assert_never(msg.content) else: found_non_system = True non_system_messages.append(msg) combined = "\n\n".join(msg for msg in system_msgs if msg) return ExtractSystemPromptResult( system_prompt=ConverseTextPart(text=combined) if combined else None, system_message_count=system_messages_count, non_system_messages=non_system_messages, ) async def to_converse_messages( messages: List[DialMessage], storage: FileStorage | None = None, supported_image_types: list[ConverseImageType] | None = None, supported_document_types: list[ConverseDocumentType] | None = None, # Offset for system messages at the beginning start_offset: int = 0, ) -> ListProjection[ConverseMessage]: def _merge( a: Tuple[ConverseMessage, Set[int]], b: Tuple[ConverseMessage, Set[int]], ) -> Tuple[ConverseMessage, Set[int]]: (msg1, set1), (msg2, set2) = a, b content1 = msg1["content"] content2 = msg2["content"] return { "role": msg1["role"], "content": content1 + content2, }, set1 | set2 converted = [ ( await to_converse_message( msg, storage, supported_image_types, supported_document_types ), {idx}, ) for idx, msg in enumerate(messages, start=start_offset) ] # Merge messages with the same roles to achieve an alternation of user-assistant roles. return ListProjection( group_by( lst=converted, key=lambda msg: msg[0]["role"], init=lambda msg: msg, merge=_merge, ) )