aidial_adapter_bedrock/llm/consumer.py (128 lines of code) (raw):

from abc import ABC, abstractmethod
from types import TracebackType
from typing import Optional, assert_never

from aidial_sdk.chat_completion import (
    Attachment,
    Choice,
    FinishReason,
    FunctionCall,
    Response,
    ToolCall,
)

from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.message import (
    AIFunctionCallMessage,
    AIToolCallMessage,
)
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages


class Consumer(ABC):
    """Abstract sink for the pieces of a model response.

    Implementations receive text content, attachments, token usage,
    discarded-message information and function/tool calls.
    """

    @abstractmethod
    def append_content(self, content: str) -> None:
        """Consume the next piece of textual content."""

    @abstractmethod
    def close_content(self, finish_reason: FinishReason | None = None) -> None:
        """Signal that no more content follows, with an optional finish reason."""

    @abstractmethod
    def add_attachment(self, attachment: Attachment) -> None:
        """Consume an attachment."""

    @abstractmethod
    def add_usage(self, usage: TokenUsage) -> None:
        """Consume token usage statistics."""

    @abstractmethod
    def set_discarded_messages(
        self, discarded_messages: Optional[DiscardedMessages]
    ) -> None:
        """Record the discarded messages (or ``None``)."""

    @abstractmethod
    def create_function_tool_call(self, tool_call: ToolCall) -> None:
        """Consume a tool call."""

    @abstractmethod
    def create_function_call(self, function_call: FunctionCall) -> None:
        """Consume a function call."""

    @property
    @abstractmethod
    def has_function_call(self) -> bool:
        """Whether a function call has been produced."""


class ChoiceConsumer(Consumer):
    """Consumer that forwards output to a lazily created choice of a response.

    The choice is created and opened on first access to ``choice``. Usage
    and discarded messages are only stored on the consumer itself.

    The instance is a context manager: on exit without an exception, the
    choice is closed if it has been opened.
    """

    usage: TokenUsage
    response: Response
    _choice: Optional[Choice]
    discarded_messages: Optional[DiscardedMessages]
    tools_emulator: Optional[ToolsEmulator]

    def __init__(self, response: Response):
        self.response = response
        self._choice = None
        self.usage = TokenUsage()
        self.discarded_messages = None
        self.tools_emulator = None

    @property
    def choice(self) -> Choice:
        """Return the choice, creating and opening it on first access."""
        if self._choice is None:
            # Delay opening a choice to the very last moment
            # so as to give opportunity for exceptions to bubble up to
            # the level of HTTP response (instead of error objects in a stream).
            choice = self._choice = self.response.create_choice()
            choice.open()
            return choice
        else:
            return self._choice

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Close the choice on a clean exit; never suppress exceptions."""
        # The choice is closed only when it was actually opened and
        # the managed block finished without raising.
        if exc is None and self._choice is not None:
            self._choice.close()
        # Returning False lets any in-flight exception propagate.
        return False

    def set_tools_emulator(self, tools_emulator: ToolsEmulator) -> None:
        """Install the emulator used to recognize calls in the content."""
        self.tools_emulator = tools_emulator

    def _process_content(
        self, content: str | None, finish_reason: FinishReason | None = None
    ) -> None:
        """Dispatch a content chunk (or ``None``) to the choice.

        When a tools emulator is set, ``content`` is first passed through its
        ``recognize_call``; otherwise it is used as is. The result is then
        handled by type:

        * ``None`` - ``finish_reason`` is recorded on the choice;
        * ``str`` - appended to the choice as content;
        * ``AIToolCallMessage`` - a tool call is created for each call;
        * ``AIFunctionCallMessage`` - a function call is created.
        """
        if self.tools_emulator is not None:
            res = self.tools_emulator.recognize_call(content)
        else:
            res = content

        if res is None:
            # Choice.close(finish_reason: Optional[FinishReason]) can be called only once
            # Currently, there's no other way to explicitly set the finish reason
            # NOTE: going through `self.choice` opens the choice
            # if it hasn't been opened yet.
            self.choice._last_finish_reason = finish_reason
            return

        if isinstance(res, str):
            self.choice.append_content(res)
            return

        if isinstance(res, AIToolCallMessage):
            for call in res.calls:
                self.create_function_tool_call(call)
            return

        if isinstance(res, AIFunctionCallMessage):
            # Delegate to the public method, mirroring the tool-call branch
            # above, so that both kinds of calls take a single code path.
            self.create_function_call(res.call)
            return

        assert_never(res)

    def close_content(self, finish_reason: FinishReason | None = None) -> None:
        """Process the end of content with the given finish reason."""
        self._process_content(None, finish_reason)

    def append_content(self, content: str) -> None:
        """Process the next chunk of content."""
        self._process_content(content)

    def add_attachment(self, attachment: Attachment) -> None:
        """Add the attachment to the choice."""
        self.choice.add_attachment(attachment)

    def add_usage(self, usage: TokenUsage) -> None:
        """Accumulate the given usage into the consumer's total usage."""
        self.usage.accumulate(usage)

    def set_discarded_messages(
        self, discarded_messages: Optional[DiscardedMessages]
    ) -> None:
        """Store the discarded messages on the consumer."""
        self.discarded_messages = discarded_messages

    def create_function_tool_call(self, tool_call: ToolCall) -> None:
        """Create a tool call on the choice from the given ``ToolCall``."""
        self.choice.create_function_tool_call(
            id=tool_call.id,
            name=tool_call.function.name,
            arguments=tool_call.function.arguments,
        )

    def create_function_call(self, function_call: FunctionCall) -> None:
        """Create a function call on the choice from the given ``FunctionCall``."""
        self.choice.create_function_call(
            name=function_call.name, arguments=function_call.arguments
        )

    @property
    def has_function_call(self) -> bool:
        """Whether the choice reports a function call.

        Reading this property goes through ``choice`` and therefore opens
        the choice if it hasn't been opened yet.
        """
        return self.choice.has_function_call