# aidial_adapter_bedrock/llm/converse/adapter.py

from typing import Any, Awaitable, Callable, List, Tuple
from aidial_sdk.chat_completion import Message as DialMessage
from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import FileStorage
from aidial_adapter_bedrock.llm.chat_model import (
ChatCompletionAdapter,
keep_last,
turn_based_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.converse.input import (
extract_converse_system_prompt,
to_converse_messages,
to_converse_tools,
)
from aidial_adapter_bedrock.llm.converse.output import (
process_non_streaming,
process_streaming,
)
from aidial_adapter_bedrock.llm.converse.types import (
ConverseDeployment,
ConverseDocumentType,
ConverseImageType,
ConverseMessage,
ConverseRequestWrapper,
ConverseTools,
InferenceConfig,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.truncate_prompt import (
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.list_projection import ListProjection
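
# A Converse message paired with per-message bookkeeping data (presumably the
# original DIAL message indices tracked by ListProjection); the adapter treats
# the second element opaquely and passes the pairs through to the tokenizer
# and truncation logic unchanged.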
ConverseMessages = List[Tuple[ConverseMessage, Any]]


class ConverseAdapter(ChatCompletionAdapter):
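    """Chat completion adapter backed by the Bedrock Converse API.

    Translates DIAL chat requests into Converse requests, truncates the
    prompt against an optional token budget, and relays streaming or
    non-streaming responses back through the consumer.
    """
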
deployment: str
bedrock: Bedrock
storage: FileStorage | None
supported_image_types: list[ConverseImageType]
supported_document_types: list[ConverseDocumentType]
tokenize_text: Callable[[str], int] = default_tokenize_string
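    # Builds a deployment-specific coroutine that counts the input tokens of
    # a list of Converse messages; used both for token statistics and for
    # prompt truncation.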
input_tokenizer_factory: Callable[
[ConverseDeployment, ConverseRequestWrapper],
Callable[[ConverseMessages], Awaitable[int]],
]
support_tools: bool
partitioner: Callable[[ConverseMessages], List[int]] = (
turn_based_partitioner
)

    async def _discard_messages(
self, params: ConverseRequestWrapper, max_prompt_tokens: int | None
) -> Tuple[DiscardedMessages | None, ConverseRequestWrapper]:
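        """Discard messages until the prompt fits within ``max_prompt_tokens``.

        Returns the discarded indices mapped back onto the original DIAL
        message positions, plus a request wrapper without the discarded
        messages. No-op when ``max_prompt_tokens`` is None.
        """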
if max_prompt_tokens is None:
return None, params
discarded_messages, messages = await truncate_prompt(
messages=params.messages.list,
tokenizer=self.input_tokenizer_factory(self.deployment, params),
keep_message=keep_last,
partitioner=self.partitioner,
model_limit=None,
user_limit=max_prompt_tokens,
)
return list(
params.messages.to_original_indices(discarded_messages)
), ConverseRequestWrapper(
messages=ListProjection(
omit_by_indices(messages, discarded_messages)
),
system=params.system,
inferenceConfig=params.inferenceConfig,
toolConfig=params.toolConfig,
)

    async def count_prompt_tokens(
self, params: ModelParameters, messages: List[DialMessage]
) -> int:
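        """Count prompt tokens by building the Converse request and running
        its messages through the deployment-specific input tokenizer."""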
        converse_params = await self.construct_converse_params(
            messages, params
        )
return await self.input_tokenizer_factory(
self.deployment, converse_params
)(converse_params.messages.list)

    async def count_completion_tokens(self, string: str) -> int:
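        """Count completion tokens with the plain-text tokenizer."""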
return self.tokenize_text(string)

    async def compute_discarded_messages(
self, params: ModelParameters, messages: List[DialMessage]
) -> DiscardedMessages | None:
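        """Return indices of the DIAL messages that prompt truncation would
        drop, or None when no ``max_prompt_tokens`` limit is set."""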
        converse_params = await self.construct_converse_params(
            messages, params
        )
discarded_messages, _ = await self._discard_messages(
converse_params, params.max_prompt_tokens
)
return discarded_messages

    def get_tool_config(self, params: ModelParameters) -> ConverseTools | None:
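        """Convert the DIAL tool configuration to the Converse format,
        rejecting tool requests on deployments without tool support."""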
if params.tool_config and not self.support_tools:
raise ValidationError("Tools are not supported")
return (
to_converse_tools(params.tool_config)
if params.tool_config
else None
)

    async def construct_converse_params(
self,
messages: List[DialMessage],
params: ModelParameters,
) -> ConverseRequestWrapper:
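        """Build the Converse request: extract the system prompt, convert the
        remaining messages (attachments are resolved via the optional file
        storage), and map DIAL sampling parameters onto the inference
        config."""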
system_prompt_extraction = extract_converse_system_prompt(messages)
converse_messages = await to_converse_messages(
system_prompt_extraction.non_system_messages,
self.storage,
start_offset=system_prompt_extraction.system_message_count,
supported_image_types=self.supported_image_types,
supported_document_types=self.supported_document_types,
)
system_message = system_prompt_extraction.system_prompt
if not converse_messages.list:
raise ValidationError("List of messages must not be empty")
return ConverseRequestWrapper(
system=[system_message] if system_message else None,
messages=converse_messages,
inferenceConfig=InferenceConfig(
**remove_nones(
{
"temperature": params.temperature,
"topP": params.top_p,
"maxTokens": params.max_tokens,
"stopSequences": params.stop,
}
)
),
toolConfig=self.get_tool_config(params),
)

    def is_stream(self, params: ModelParameters) -> bool:
return params.stream

    async def chat(
self,
consumer: Consumer,
params: ModelParameters,
messages: List[DialMessage],
) -> None:
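        """Run a chat completion: build the Converse request, truncate the
        prompt if needed, report discarded messages to the consumer, and
        dispatch to the streaming or non-streaming Converse call."""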
        converse_params = await self.construct_converse_params(
            messages, params
        )
discarded_messages, converse_params = await self._discard_messages(
converse_params, params.max_prompt_tokens
)
if not converse_params.messages.raw_list:
raise ValidationError("No messages left after truncation")
consumer.set_discarded_messages(discarded_messages)
if self.is_stream(params):
await process_streaming(
params=params,
stream=(
await self.bedrock.aconverse_streaming(
self.deployment, **converse_params.to_request()
)
),
consumer=consumer,
)
else:
process_non_streaming(
params=params,
response=await self.bedrock.aconverse_non_streaming(
self.deployment, **converse_params.to_request()
),
consumer=consumer,
)