aidial_adapter_bedrock/dial_api/request.py (103 lines of code) (raw):
from typing import List, Optional, TypeGuard, assert_never
from aidial_sdk.chat_completion import (
MessageContentImagePart,
MessageContentPart,
MessageContentTextPart,
)
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tools.tools_config import (
ToolsConfig,
ToolsMode,
validate_messages,
)
# Content of a DIAL chat message as received in a request: a plain string,
# a list of typed content parts, or None.
MessageContent = str | List[MessageContentPart] | None
# MessageContent widened with lists narrowed to a single part type.
# `List` is invariant for type checkers, so a List[MessageContentTextPart]
# is not accepted where a List[MessageContentPart] is expected; this alias
# lets helpers such as collect_text_content accept narrowed lists too.
# to_message_content converts a value of this type back to MessageContent.
MessageContentSpecialized = (
    MessageContent
    | List[MessageContentTextPart]
    | List[MessageContentImagePart]
)
class ModelParameters(BaseModel):
    """Generation settings extracted from a DIAL chat completion request.

    Most fields are copied verbatim from the request. ``stop`` is always a
    list, and ``tool_config`` holds whatever ``ToolsConfig.from_request``
    builds for the request.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    # Stop sequences; normalized to a list by `create`.
    stop: List[str] = []
    max_tokens: Optional[int] = None
    max_prompt_tokens: Optional[int] = None
    stream: bool = False
    tool_config: Optional[ToolsConfig] = None

    @classmethod
    def create(cls, request: ChatCompletionRequest) -> "ModelParameters":
        """Build parameters from *request*.

        ``request.stop`` may be missing, a single string or a list of
        strings; it is normalized to a list. The request messages are
        checked with ``validate_messages`` before the tool configuration
        is derived, and any error raised there propagates to the caller.
        """
        requested_stop = request.stop
        if requested_stop is None:
            stop_sequences: List[str] = []
        elif isinstance(requested_stop, str):
            # Wrap a lone stop string so downstream code only sees lists.
            stop_sequences = [requested_stop]
        else:
            stop_sequences = requested_stop

        validate_messages(request)

        return cls(
            temperature=request.temperature,
            top_p=request.top_p,
            n=request.n,
            stop=stop_sequences,
            max_tokens=request.max_tokens,
            max_prompt_tokens=request.max_prompt_tokens,
            stream=request.stream,
            tool_config=ToolsConfig.from_request(request),
        )

    def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
        """Return a copy whose stop list is the current one followed by *stop*.

        A new list is built for the copy, so ``self.stop`` is not mutated.
        """
        combined = list(self.stop)
        combined.extend(stop)
        return self.copy(update={"stop": combined})

    @property
    def tools_mode(self) -> ToolsMode | None:
        """The ``tools_mode`` of ``tool_config``, or ``None`` when it is unset."""
        config = self.tool_config
        return None if config is None else config.tools_mode
def collect_text_content(
    content: MessageContentSpecialized, delimiter: str = "\n\n"
) -> str:
    """Flatten message content into a single string.

    ``None`` yields ``""`` and a string is returned unchanged. For a list of
    parts, the ``text`` of every text part is joined with *delimiter*
    (an empty list yields ``""``).

    Raises:
        ValidationError: if the list contains an image part.
        AssertionError: (from ``assert_never``) on an unexpected content or
            part type.
    """
    if content is None:
        return ""

    if isinstance(content, str):
        return content

    if isinstance(content, list):
        pieces: List[str] = []
        for item in content:
            if isinstance(item, MessageContentTextPart):
                pieces.append(item.text)
            elif isinstance(item, MessageContentImagePart):
                raise ValidationError(
                    "Can't extract text from an image content part"
                )
            else:
                assert_never(item)
        return delimiter.join(pieces)

    assert_never(content)
def to_message_content(content: MessageContentSpecialized) -> MessageContent:
    """Widen specialized content back to the general ``MessageContent`` type.

    ``None`` and strings are returned as the same object. A list is
    shallow-copied into a new list, and the part objects are shared.
    Any other type triggers ``assert_never`` (AssertionError at runtime).
    """
    if content is None or isinstance(content, str):
        return content
    if isinstance(content, list):
        return list(content)
    assert_never(content)
def is_text_content(
    content: MessageContent,
) -> TypeGuard[str | List[MessageContentTextPart]]:
    """Tell whether *content* is made of text only.

    Returns True for a string, and for a list in which every part is a
    ``MessageContentTextPart`` (an empty list qualifies). Returns False for
    ``None`` and for a list containing any other kind of part.
    """
    if content is None:
        return False
    if isinstance(content, str):
        return True
    if isinstance(content, list):
        # Stops at the first part that is not a text part.
        return not any(
            not isinstance(part, MessageContentTextPart) for part in content
        )
    assert_never(content)
def is_plain_text_content(content: MessageContent) -> TypeGuard[str | None]:
    """Tell whether *content* is a bare string or ``None``, i.e. not a list of parts."""
    return isinstance(content, (str, type(None)))