aidial_adapter_bedrock/llm/message.py (213 lines of code) (raw):

from abc import ABC, abstractmethod
from typing import List, Optional, Self, Union

from aidial_sdk.chat_completion import Attachment, CustomContent, FunctionCall
from aidial_sdk.chat_completion import Message as DialMessage
from aidial_sdk.chat_completion import (
    MessageContentPart,
    MessageContentTextPart,
    Role,
    ToolCall,
)
from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.request import (
    collect_text_content,
    is_plain_text_content,
    is_text_content,
    to_message_content,
)
from aidial_adapter_bedrock.llm.errors import ValidationError


def _attachments_of(custom_content: Optional[CustomContent]) -> List[Attachment]:
    """Return the attachments carried by *custom_content*.

    Returns an empty list when there is no custom content or when its
    ``attachments`` value is falsy (``None`` or empty).
    """
    if custom_content is None:
        return []
    return custom_content.attachments or []


class MessageABC(ABC, BaseModel):
    """A message model that converts to and from a DIAL chat message."""

    @abstractmethod
    def to_message(self) -> DialMessage:
        """Convert this model back into a DIAL message."""

    @classmethod
    @abstractmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Build this model from a DIAL message.

        Implementations in this module return ``None`` when *message* is not
        of the kind the class represents (so the caller can try the next
        candidate), and raise ``ValidationError`` when it is of that kind but
        malformed.
        """


class BaseMessageABC(MessageABC):
    """A message model that exposes its content as text."""

    @property
    @abstractmethod
    def text_content(self) -> str:
        """The textual content of the message."""


class SystemMessage(BaseMessageABC):
    """A system message whose content is a string or a list of text parts."""

    content: str | List[MessageContentTextPart]

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.SYSTEM,
            content=to_message_content(self.content),
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a system message; return ``None`` for any other role.

        Raises:
            ValidationError: if ``is_text_content`` rejects the content.
        """
        if message.role != Role.SYSTEM:
            return None

        content = message.content
        if not is_text_content(content):
            raise ValidationError(
                "System message is expected to be a string or a list of text content parts"
            )

        return cls(content=content)

    @property
    def text_content(self) -> str:
        """The content, as combined by ``collect_text_content``."""
        return collect_text_content(self.content)


class HumanRegularMessage(BaseMessageABC):
    """A user message; its content may mix text and non-text parts."""

    content: str | List[MessageContentPart]
    custom_content: Optional[CustomContent] = None

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.USER,
            content=self.content,
            custom_content=self.custom_content,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a user message; return ``None`` for any other role.

        Raises:
            ValidationError: if the message has no content.
        """
        if message.role != Role.USER:
            return None

        content = message.content
        if content is None:
            raise ValidationError(
                "User message is expected to have content field"
            )

        return cls(content=content, custom_content=message.custom_content)

    @property
    def text_content(self) -> str:
        """The content, as combined by ``collect_text_content``."""
        return collect_text_content(self.content)

    @property
    def attachments(self) -> List[Attachment]:
        """Attachments from ``custom_content``, or ``[]`` if there are none."""
        return _attachments_of(self.custom_content)


class HumanToolResultMessage(MessageABC):
    """The result of a tool call, sent back with the ``tool`` role."""

    # Identifier of the tool call this message answers (``tool_call_id``).
    id: str
    content: str

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.TOOL,
            tool_call_id=self.id,
            content=self.content,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a tool message; return ``None`` for any other role.

        Raises:
            ValidationError: if the content is not plain text, or if either
                ``content`` or ``tool_call_id`` is missing.
        """
        if message.role != Role.TOOL:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The tool message shouldn't contain content parts"
            )

        if message.content is None or message.tool_call_id is None:
            raise ValidationError(
                "The tool message is expected to have content and tool_call_id fields"
            )

        return cls(id=message.tool_call_id, content=message.content)


class HumanFunctionResultMessage(MessageABC):
    """The result of a function call, sent back with the ``function`` role."""

    # Name of the function whose result this message carries.
    name: str
    content: str

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.FUNCTION,
            name=self.name,
            content=self.content,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a function message; return ``None`` for any other role.

        Raises:
            ValidationError: if the content is not plain text, or if either
                ``content`` or ``name`` is missing.
        """
        if message.role != Role.FUNCTION:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The function message shouldn't contain content parts"
            )

        if message.content is None or message.name is None:
            raise ValidationError(
                "The function message is expected to have content and name fields"
            )

        return cls(name=message.name, content=message.content)


class AIRegularMessage(BaseMessageABC):
    """A plain assistant reply, with no tool or function call."""

    content: str
    custom_content: Optional[CustomContent] = None

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.ASSISTANT,
            content=self.content,
            custom_content=self.custom_content,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a plain assistant message.

        Returns ``None`` for any non-assistant role and for assistant
        messages that carry ``function_call`` or ``tool_calls``. Those are
        handled by ``AIFunctionCallMessage`` and ``AIToolCallMessage``.

        Raises:
            ValidationError: if the content is not plain text or is missing.
        """
        if message.role != Role.ASSISTANT:
            return None

        if message.function_call is not None or message.tool_calls is not None:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The assistant message shouldn't contain content parts"
            )

        if message.content is None:
            raise ValidationError(
                "The assistant message is expected to have content"
            )

        return cls(
            content=message.content, custom_content=message.custom_content
        )

    @property
    def text_content(self) -> str:
        """The reply text (``content`` as-is)."""
        return self.content

    @property
    def attachments(self) -> List[Attachment]:
        """Attachments from ``custom_content``, or ``[]`` if there are none."""
        return _attachments_of(self.custom_content)


class AIToolCallMessage(MessageABC):
    """An assistant message that requests one or more tool calls."""

    calls: List[ToolCall]
    content: Optional[str] = None

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.ASSISTANT,
            content=self.content,
            tool_calls=self.calls,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse an assistant message carrying ``tool_calls``.

        Returns ``None`` for any non-assistant role, when ``tool_calls`` is
        absent, or when ``function_call`` is also set.

        Raises:
            ValidationError: if the content is not plain text.
        """
        if message.role != Role.ASSISTANT:
            return None

        if message.tool_calls is None or message.function_call is not None:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The assistant message with tool calls shouldn't contain content parts"
            )

        return cls(calls=message.tool_calls, content=message.content)


class AIFunctionCallMessage(MessageABC):
    """An assistant message that requests a single function call."""

    call: FunctionCall
    content: Optional[str] = None

    def to_message(self) -> DialMessage:
        return DialMessage(
            role=Role.ASSISTANT,
            content=self.content,
            function_call=self.call,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse an assistant message carrying ``function_call``.

        Returns ``None`` for any non-assistant role, when ``function_call``
        is absent, or when ``tool_calls`` is also set.

        Raises:
            ValidationError: if the content is not plain text.
        """
        if message.role != Role.ASSISTANT:
            return None

        if message.function_call is None or message.tool_calls is not None:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The assistant message with function call shouldn't contain content parts"
            )

        return cls(call=message.function_call, content=message.content)


# Messages that exchange ordinary text/content between user and assistant.
BaseMessage = Union[SystemMessage, HumanRegularMessage, AIRegularMessage]

# Messages that request, or report the result of, a tool/function call.
ToolMessage = Union[
    HumanToolResultMessage,
    HumanFunctionResultMessage,
    AIToolCallMessage,
    AIFunctionCallMessage,
]


def parse_dial_message(msg: DialMessage) -> BaseMessage | ToolMessage:
    """Convert a DIAL message into the matching adapter message model.

    Each candidate class's ``from_message`` is tried in order, and the first
    result that is not ``None`` is returned. A ``ValidationError`` raised by a
    candidate (the message is of its kind but malformed) propagates
    immediately.

    Raises:
        ValidationError: if no candidate accepts the message (for example, an
            assistant message with both ``function_call`` and ``tool_calls``
            set), or if the matching candidate rejects it as malformed.
    """
    parsers = (
        SystemMessage.from_message,
        HumanRegularMessage.from_message,
        HumanToolResultMessage.from_message,
        HumanFunctionResultMessage.from_message,
        AIRegularMessage.from_message,
        AIToolCallMessage.from_message,
        AIFunctionCallMessage.from_message,
    )
    for parse in parsers:
        # Test for None explicitly instead of chaining with `or`, which would
        # depend on the truthiness of the returned model instance.
        if (message := parse(msg)) is not None:
            return message

    raise ValidationError("Unknown message type or invalid message")