aidial_adapter_bedrock/llm/converse/output.py (128 lines of code) (raw):

import json from typing import Any, AsyncIterator, Dict, assert_never from aidial_sdk.chat_completion import FinishReason as DialFinishReason from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall from aidial_sdk.chat_completion import ToolCall as DialToolCall from aidial_sdk.exceptions import RuntimeServerError from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.consumer import Consumer from aidial_adapter_bedrock.llm.converse.constants import ( CONVERSE_TO_DIAL_FINISH_REASON, ) from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode def to_dial_finish_reason( converse_stop_reason: ConverseStopReason, ) -> DialFinishReason: if converse_stop_reason not in CONVERSE_TO_DIAL_FINISH_REASON.keys(): raise RuntimeServerError( f"Unsupported converse stop reason: {converse_stop_reason}" ) return CONVERSE_TO_DIAL_FINISH_REASON[converse_stop_reason] async def process_streaming( params: ModelParameters, stream: AsyncIterator[Any], consumer: Consumer, ) -> None: current_tool_use = None async for event in stream: if (content_block_start := event.get("contentBlockStart")) and ( tool_use := content_block_start.get("start", {}).get("toolUse") ): if current_tool_use is not None: raise ValueError("Tool use already started") current_tool_use = {"input": ""} | tool_use elif content_block := event.get("contentBlockDelta"): delta = content_block.get("delta", {}) if message := delta.get("text"): consumer.append_content(message) if "toolUse" in delta: if current_tool_use is None: raise ValueError("Received tool delta before start block") else: current_tool_use["input"] += delta["toolUse"].get( "input", "" ) elif event.get("contentBlockStop"): if current_tool_use: match params.tools_mode: case ToolsMode.TOOLS: consumer.create_function_tool_call( tool_call=DialToolCall( type="function", id=current_tool_use["toolUseId"], index=None, function=DialFunctionCall( name=current_tool_use["name"], arguments=current_tool_use["input"], ), ) ) case ToolsMode.FUNCTIONS: # ignoring multiple function calls in one response if not consumer.has_function_call: consumer.create_function_call( function_call=DialFunctionCall( name=current_tool_use["name"], arguments=current_tool_use["input"], ) ) case None: raise RuntimeError( "Tool use received without tools mode" ) case _: assert_never(params.tools_mode) current_tool_use = None elif (message_stop := event.get("messageStop")) and ( stop_reason := message_stop.get("stopReason") ): consumer.close_content(to_dial_finish_reason(stop_reason)) def process_non_streaming( params: ModelParameters, response: Dict[str, Any], consumer: Consumer, ) -> None: message = response["output"]["message"] for content_block in message.get("content", []): if "text" in content_block: consumer.append_content(content_block["text"]) if "toolUse" in content_block: match params.tools_mode: case ToolsMode.TOOLS: consumer.create_function_tool_call( tool_call=DialToolCall( type="function", id=content_block["toolUse"]["toolUseId"], index=None, function=DialFunctionCall( name=content_block["toolUse"]["name"], arguments=json.dumps( content_block["toolUse"]["input"] ), ), ) ) case ToolsMode.FUNCTIONS: # ignoring multiple function calls in one response if not consumer.has_function_call: consumer.create_function_call( function_call=DialFunctionCall( name=content_block["toolUse"]["name"], arguments=json.dumps( content_block["toolUse"]["input"] ), ) ) case None: raise RuntimeError("Tool use received without tools mode") case _: assert_never(params.tools_mode) if usage := response.get("usage"): consumer.add_usage( TokenUsage( prompt_tokens=usage.get("inputTokens", 0), completion_tokens=usage.get("outputTokens", 0), ) ) if stop_reason := response.get("stopReason"): consumer.close_content(to_dial_finish_reason(stop_reason))