# aidial_adapter_bedrock/llm/model/claude/v3/adapter.py
from dataclasses import dataclass
from logging import DEBUG
from typing import List, Optional, Tuple, assert_never
from aidial_sdk.chat_completion import Message as DialMessage
from anthropic import NOT_GIVEN, MessageStopEvent, NotGiven
from anthropic.lib.bedrock import AsyncAnthropicBedrock
from anthropic.lib.streaming import (
AsyncMessageStream,
ContentBlockStopEvent,
InputJsonEvent,
TextEvent,
)
from anthropic.types import (
ContentBlockDeltaEvent,
ContentBlockStartEvent,
MessageDeltaEvent,
)
from anthropic.types import MessageParam as ClaudeMessage
from anthropic.types import (
MessageStartEvent,
MessageStreamEvent,
TextBlock,
ToolUseBlock,
)
from anthropic.types.message_create_params import ToolChoice
from aidial_adapter_bedrock.adapter_deployments import AdapterDeployment
from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.deployments import Claude3Deployment
from aidial_adapter_bedrock.dial_api.request import (
ModelParameters as DialParameters,
)
from aidial_adapter_bedrock.dial_api.storage import (
FileStorage,
create_file_storage,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import (
ChatCompletionAdapter,
keep_last,
turn_based_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.message import parse_dial_message
from aidial_adapter_bedrock.llm.model.claude.v3.converters import (
ClaudeFinishReason,
to_claude_messages,
to_claude_tool_config,
to_dial_finish_reason,
)
from aidial_adapter_bedrock.llm.model.claude.v3.params import ClaudeParameters
from aidial_adapter_bedrock.llm.model.claude.v3.tokenizer import (
create_tokenizer,
tokenize_text,
)
from aidial_adapter_bedrock.llm.model.claude.v3.tools import (
process_tools_block,
process_with_tools,
)
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode
from aidial_adapter_bedrock.llm.truncate_prompt import (
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import json_dumps_short
from aidial_adapter_bedrock.utils.list_projection import ListProjection
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
class UsageEventHandler(AsyncMessageStream):
prompt_tokens: int = 0
completion_tokens: int = 0
stop_reason: Optional[ClaudeFinishReason] = None
async def on_stream_event(self, event: MessageStreamEvent):
if isinstance(event, MessageStartEvent):
self.prompt_tokens = event.message.usage.input_tokens
elif isinstance(event, MessageDeltaEvent):
self.completion_tokens += event.usage.output_tokens
self.stop_reason = event.delta.stop_reason
# NOTE: it's not pydantic BaseModel, because
# ClaudeMessage.content is of Iterable type and
# pydantic automatically converts lists into
# list iterators following the type.
# See https://github.com/anthropics/anthropic-sdk-python/issues/656 for details.
@dataclass
class ClaudeRequest:
    """A prepared Claude call: generation parameters plus the messages.

    Messages are kept as a ListProjection so that indices of truncated
    (discarded) Claude messages can be mapped back to the indices of the
    original DIAL messages.
    """

    # Parameters forwarded to the Anthropic Messages API.
    params: ClaudeParameters
    # Claude messages with a projection onto the original DIAL messages.
    messages: ListProjection[ClaudeMessage]
class Adapter(ChatCompletionAdapter):
    """Chat-completion adapter for Claude 3 deployments on AWS Bedrock.

    Translates DIAL requests into Anthropic Messages API calls (both
    streaming and non-streaming), supports prompt truncation driven by
    ``max_prompt_tokens``, and reports content, token usage and discarded
    messages back through the Consumer.
    """

    # Deployment descriptor holding reference/upstream deployment ids.
    deployment: AdapterDeployment[Claude3Deployment]
    # Optional DIAL file storage passed to the message converter.
    storage: Optional[FileStorage]
    # Anthropic SDK client configured to talk to Bedrock.
    client: AsyncAnthropicBedrock

    async def _prepare_claude_request(
        self, params: DialParameters, messages: List[DialMessage]
    ) -> ClaudeRequest:
        """Convert DIAL parameters and messages into a ClaudeRequest.

        Raises:
            ValidationError: if the list of messages is empty.
        """
        if len(messages) == 0:
            raise ValidationError("List of messages must not be empty")

        tools = NOT_GIVEN
        tool_choice: ToolChoice | NotGiven = NOT_GIVEN
        if (tool_config := params.tool_config) is not None:
            tools = [
                to_claude_tool_config(tool_function)
                for tool_function in tool_config.functions
            ]
            # "any" forces a tool call when tools are required, "auto" lets
            # the model decide.
            tool_choice = (
                {"type": "any"} if tool_config.required else {"type": "auto"}
            )
            # NOTE tool_choice.disable_parallel_tool_use=True option isn't supported
            # by older Claude3 versions, so we limit the number of generated function calls
            # to one in the adapter itself for the functions mode.

        parsed_messages = [
            process_with_tools(parse_dial_message(m), params.tools_mode)
            for m in messages
        ]

        system_prompt, claude_messages = await to_claude_messages(
            parsed_messages, self.storage
        )

        claude_params = ClaudeParameters(
            max_tokens=params.max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC,
            stop_sequences=params.stop,
            system=system_prompt or NOT_GIVEN,
            # Temperature is halved before being passed to Anthropic —
            # presumably to map DIAL's [0, 2] range onto Anthropic's [0, 1];
            # confirm against the DIAL API contract.
            temperature=(
                NOT_GIVEN
                if params.temperature is None
                else params.temperature / 2
            ),
            top_p=params.top_p or NOT_GIVEN,
            tools=tools,
            tool_choice=tool_choice,
        )

        return ClaudeRequest(params=claude_params, messages=claude_messages)

    async def _compute_discarded_messages(
        self, request: ClaudeRequest, max_prompt_tokens: int | None
    ) -> Tuple[DiscardedMessages | None, ClaudeRequest]:
        """Truncate the request to fit within ``max_prompt_tokens``.

        Returns the indices of discarded messages (in the original DIAL
        message numbering) together with the possibly-truncated request,
        or ``(None, request)`` unchanged when no limit is given.
        """
        if max_prompt_tokens is None:
            return None, request

        discarded_messages, messages = await truncate_prompt(
            messages=request.messages.list,
            tokenizer=create_tokenizer(
                self.deployment.reference_deployment_id, request.params
            ),
            keep_message=keep_last,
            partitioner=turn_based_partitioner,
            model_limit=None,
            user_limit=max_prompt_tokens,
        )

        claude_messages = ListProjection(messages)
        # Map indices of discarded Claude messages back to the indices of
        # the original DIAL messages.
        discarded_messages = list(
            request.messages.to_original_indices(discarded_messages)
        )

        return discarded_messages, ClaudeRequest(
            params=request.params,
            messages=claude_messages,
        )

    async def chat(
        self,
        consumer: Consumer,
        params: DialParameters,
        messages: List[DialMessage],
    ):
        """Run one chat completion, dispatching to the streaming or
        non-streaming implementation depending on ``params.stream``."""
        request = await self._prepare_claude_request(params, messages)

        discarded_messages, request = await self._compute_discarded_messages(
            request, params.max_prompt_tokens
        )

        if params.stream:
            await self.invoke_streaming(
                consumer,
                params.tools_mode,
                request,
                discarded_messages,
            )
        else:
            await self.invoke_non_streaming(
                consumer,
                params.tools_mode,
                request,
                discarded_messages,
            )

    async def count_prompt_tokens(
        self, params: DialParameters, messages: List[DialMessage]
    ) -> int:
        """Count the prompt tokens of the fully prepared Claude request."""
        request = await self._prepare_claude_request(params, messages)
        return await create_tokenizer(
            self.deployment.reference_deployment_id, request.params
        )(request.messages.list)

    async def count_completion_tokens(self, string: str) -> int:
        """Count the tokens in a completion string."""
        return tokenize_text(string)

    async def compute_discarded_messages(
        self, params: DialParameters, messages: List[DialMessage]
    ) -> DiscardedMessages | None:
        """Return indices of DIAL messages that would be discarded by
        prompt truncation, or None when ``max_prompt_tokens`` isn't set."""
        request = await self._prepare_claude_request(params, messages)
        discarded_messages, _request = await self._compute_discarded_messages(
            request, params.max_prompt_tokens
        )
        return discarded_messages

    async def invoke_streaming(
        self,
        consumer: Consumer,
        tools_mode: ToolsMode | None,
        request: ClaudeRequest,
        discarded_messages: DiscardedMessages | None,
    ):
        """Send a streaming request and forward events to the consumer.

        Accumulates token usage and the stop reason from the stream events
        and reports them to the consumer once the stream is exhausted.
        """
        if log.isEnabledFor(DEBUG):
            msg = json_dumps_short(
                {
                    "deployment": self.deployment,
                    "request": request,
                }
            )
            log.debug(f"Streaming request: {msg}")

        async with self.client.messages.stream(
            messages=request.messages.raw_list,
            model=self.deployment.upstream_deployment_id,
            **request.params,
        ) as stream:
            prompt_tokens = 0
            completion_tokens = 0
            stop_reason = None
            async for event in stream:
                if log.isEnabledFor(DEBUG):
                    log.debug(
                        f"claude response event: {json_dumps_short(event)}"
                    )

                match event:
                    case MessageStartEvent(message=message):
                        prompt_tokens += message.usage.input_tokens
                    case TextEvent(text=text):
                        consumer.append_content(text)
                    case MessageDeltaEvent(usage=usage):
                        completion_tokens += usage.output_tokens
                    case ContentBlockStopEvent(content_block=content_block):
                        match content_block:
                            case ToolUseBlock():
                                # Tool-use blocks are only complete (name +
                                # full JSON input) at block stop.
                                process_tools_block(
                                    consumer, content_block, tools_mode
                                )
                            case TextBlock():
                                # Already handled in TextEvent
                                pass
                            case _:
                                assert_never(content_block)
                    case MessageStopEvent(message=message):
                        # NOTE(review): output tokens are summed from both
                        # message_delta events (above) and the final message
                        # here — verify against the SDK's event semantics
                        # that this doesn't double-count completion tokens.
                        completion_tokens += message.usage.output_tokens
                        stop_reason = message.stop_reason
                    case (
                        InputJsonEvent()
                        | ContentBlockStartEvent()
                        | ContentBlockDeltaEvent()
                    ):
                        # Intermediate events: text/tool deltas are handled
                        # via TextEvent/ContentBlockStopEvent instead.
                        pass
                    case _:
                        assert_never(event)

            consumer.close_content(
                to_dial_finish_reason(stop_reason, tools_mode)
            )

            consumer.add_usage(
                TokenUsage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )
            )

            consumer.set_discarded_messages(discarded_messages)

    async def invoke_non_streaming(
        self,
        consumer: Consumer,
        tools_mode: ToolsMode | None,
        request: ClaudeRequest,
        discarded_messages: DiscardedMessages | None,
    ):
        """Send a non-streaming request and report the full response,
        usage and discarded messages to the consumer."""
        if log.isEnabledFor(DEBUG):
            msg = json_dumps_short(
                {
                    "deployment": self.deployment,
                    "request": request,
                }
            )
            log.debug(f"Request: {msg}")

        message = await self.client.messages.create(
            messages=request.messages.raw_list,
            model=self.deployment.upstream_deployment_id,
            **request.params,
            stream=False,
        )

        if log.isEnabledFor(DEBUG):
            log.debug(f"claude response message: {json_dumps_short(message)}")

        for content in message.content:
            match content:
                case TextBlock(text=text):
                    consumer.append_content(text)
                case ToolUseBlock():
                    process_tools_block(consumer, content, tools_mode)
                case _:
                    assert_never(content)

        consumer.close_content(
            to_dial_finish_reason(message.stop_reason, tools_mode)
        )

        consumer.add_usage(
            TokenUsage(
                prompt_tokens=message.usage.input_tokens,
                completion_tokens=message.usage.output_tokens,
            )
        )

        consumer.set_discarded_messages(discarded_messages)

    @classmethod
    def create(
        cls,
        deployment: AdapterDeployment[Claude3Deployment],
        api_key: str,
        aws_client_config: AWSClientConfig,
    ):
        """Factory: build an Adapter with DIAL file storage and an
        Anthropic Bedrock client derived from the AWS client config."""
        storage: Optional[FileStorage] = create_file_storage(api_key=api_key)
        client_kwargs = aws_client_config.get_anthropic_bedrock_client_kwargs()
        return cls(
            deployment=deployment,
            storage=storage,
            client=AsyncAnthropicBedrock(**client_kwargs),
        )