aidial_adapter_bedrock/llm/model/amazon.py
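
"""
Adapter for Amazon Titan text models served through AWS Bedrock.

Converts DIAL chat-completion parameters into Titan's `textGenerationConfig`,
invokes the model in streaming or non-streaming mode, and turns the response
back into a text stream plus token-usage accounting.
"""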

from typing import Any, AsyncIterator, Dict, List, Optional

from aidial_sdk.chat_completion import Message
from pydantic import BaseModel
from typing_extensions import override

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulator import default_emulator
from aidial_adapter_bedrock.llm.chat_model import (
    PseudoChatModel,
    trivial_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AMAZON
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.tools.default_emulator import (
    default_tools_emulator,
)


class AmazonResult(BaseModel):
    tokenCount: int
    outputText: str
    completionReason: Optional[str]


class AmazonResponse(BaseModel):
    inputTextTokenCount: int
    results: List[AmazonResult]

    def content(self) -> str:
        assert (
            len(self.results) == 1
        ), "AmazonResponse should only have one result"
        return self.results[0].outputText

    def usage(self) -> TokenUsage:
        assert (
            len(self.results) == 1
        ), "AmazonResponse should only have one result"
        return TokenUsage(
            prompt_tokens=self.inputTextTokenCount,
            completion_tokens=self.results[0].tokenCount,
        )
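

# A non-streaming Titan response is expected to look roughly like this
# (a sketch inferred from the field names above, not an exhaustive schema):
#
#     {
#         "inputTextTokenCount": 5,
#         "results": [
#             {
#                 "tokenCount": 42,
#                 "outputText": "...",
#                 "completionReason": "FINISH",
#             }
#         ],
#     }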


def convert_params(params: ModelParameters) -> Dict[str, Any]:
    ret = {}

    if params.temperature is not None:
        ret["temperature"] = params.temperature

    if params.top_p is not None:
        ret["topP"] = params.top_p

    if params.max_tokens is not None:
        ret["maxTokenCount"] = params.max_tokens
    else:
        # The service default for max tokens is 128, which is too small
        # for most use cases, so we choose a more generous default instead.
        ret["maxTokenCount"] = DEFAULT_MAX_TOKENS_AMAZON

    # NOTE: Amazon Titan (amazon.titan-tg1-large) currently only supports
    # stop sequences matching the pattern "$\|+".
    # if params.stop is not None:
    #     ret["stopSequences"] = params.stop

    return ret
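
# For illustration: a call such as convert_params(ModelParameters(temperature=0.5))
# would be expected to yield
#     {"temperature": 0.5, "maxTokenCount": DEFAULT_MAX_TOKENS_AMAZON}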


def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]:
    return {"inputText": prompt, "textGenerationConfig": params}


async def chunks_to_stream(
    chunks: AsyncIterator[dict], usage: TokenUsage
) -> AsyncIterator[str]:
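    """Yield the text of each streamed chunk, recording the token counts
    reported inside the chunks into `usage` as a side effect."""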
    async for chunk in chunks:
        input_tokens = chunk.get("inputTextTokenCount")
        if input_tokens is not None:
            usage.prompt_tokens = input_tokens

        output_tokens = chunk.get("totalOutputTextTokenCount")
        if output_tokens is not None:
            usage.completion_tokens = output_tokens

        yield chunk["outputText"]


async def response_to_stream(
    response: dict, usage: TokenUsage
) -> AsyncIterator[str]:
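    """Adapt a non-streaming response to the streaming interface by
    yielding its full text as a single chunk."""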
    resp = AmazonResponse.parse_obj(response)

    token_usage = resp.usage()
    usage.completion_tokens = token_usage.completion_tokens
    usage.prompt_tokens = token_usage.prompt_tokens

    yield resp.content()


class AmazonAdapter(PseudoChatModel):
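    """Pseudo-chat adapter for Amazon Titan text models: chat messages are
    emulated over Titan's single-prompt text-completion API."""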

    model: str
    client: Bedrock

    @classmethod
    def create(cls, client: Bedrock, model: str):
        return cls(
            client=client,
            model=model,
            tokenize_string=default_tokenize_string,
            # TODO: to use conversational mode on Titan, prompt the model
            # in the "User: {{}} \n Bot:" format. See the note at the end of
            # https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-a-prompt.html
            chat_emulator=default_emulator,
            tools_emulator=default_tools_emulator,
            partitioner=trivial_partitioner,
        )

    @override
    def preprocess_messages(self, messages: List[Message]) -> List[Message]:
        messages = super().preprocess_messages(messages)

        # AWS Titan doesn't support empty messages,
        # so we replace them with a single space.
        for msg in messages:
            msg.content = msg.content or " "

        return messages

    async def predict(
        self, consumer: Consumer, params: ModelParameters, prompt: str
    ):
        args = create_request(prompt, convert_params(params))

        usage = TokenUsage()

        if params.stream:
            chunks = self.client.ainvoke_streaming(self.model, args)
            stream = chunks_to_stream(chunks, usage)
        else:
            response, _headers = await self.client.ainvoke_non_streaming(
                self.model, args
            )
            stream = response_to_stream(response, usage)

        stream = self.post_process_stream(stream, params, self.chat_emulator)

        async for content in stream:
            consumer.append_content(content)
        consumer.close_content()

        consumer.add_usage(usage)
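

# Usage sketch (illustrative only; `client`, `consumer` and `params` are
# assumed to be constructed by the surrounding adapter machinery):
#
#     adapter = AmazonAdapter.create(client, "amazon.titan-tg1-large")
#     await adapter.predict(consumer, params, prompt)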