aidial_adapter_bedrock/llm/chat_model.py:

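"""
Abstract chat completion adapters for Bedrock models.

The module defines:

- `ChatCompletionAdapter` - the base chat interface: `chat`, prompt/completion
  token counting and computation of discarded messages under a token limit;
- `TextCompletionAdapter` - adapters which linearize chat messages into a
  single text prompt plus stop sequences;
- `PseudoChatModel` - a `TextCompletionAdapter` driven by a `ChatEmulator`,
  a string tokenizer and a message partitioner.
"""
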
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, List, Optional

from aidial_sdk.chat_completion import Message, Role
from pydantic import BaseModel
from typing_extensions import override

import aidial_adapter_bedrock.utils.stream as stream_utils
from aidial_adapter_bedrock.dial_api.request import (
    ModelParameters,
    collect_text_content,
)
from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.message import BaseMessage, SystemMessage
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.llm.truncate_prompt import (
    DiscardedMessages,
    truncate_prompt,
)
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


def _is_empty_system_message(msg: Message) -> bool:
    return (
        msg.role == Role.SYSTEM
        and collect_text_content(msg.content).strip() == ""
    )


class ChatCompletionAdapter(ABC, BaseModel):
    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        pass

    async def count_prompt_tokens(
        self, params: ModelParameters, messages: List[Message]
    ) -> int:
        raise NotImplementedError

    async def count_completion_tokens(self, string: str) -> int:
        raise NotImplementedError

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        """
        The method truncates the list of messages to fit into the token limit
        set in `params.max_prompt_tokens`.

        If the limit isn't provided, then it returns None.

        Otherwise, it returns the indices of the _discarded_ messages which
        should be removed from the list to make the rest fit into the token
        limit.
        """
        raise NotImplementedError


class TextCompletionPrompt(BaseModel):
    text: str
    stop_sequences: List[str]
    discarded_messages: Optional[DiscardedMessages] = None


class TextCompletionAdapter(ChatCompletionAdapter):
    tools_emulator: Callable[[Optional[ToolsConfig]], ToolsEmulator]

    @abstractmethod
    async def predict(
        self, consumer: Consumer, params: ModelParameters, prompt: str
    ) -> None:
        pass

    @abstractmethod
    async def truncate_and_linearize_messages(
        self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
    ) -> TextCompletionPrompt:
        pass

    def preprocess_messages(self, messages: List[Message]) -> List[Message]:
        # Skipping empty system messages
        messages = [
            msg for msg in messages if not _is_empty_system_message(msg)
        ]

        if len(messages) == 0:
            raise ValidationError("List of messages must not be empty")

        return messages

    async def get_text_completion_prompt(
        self, params: ModelParameters, messages: List[Message]
    ) -> TextCompletionPrompt:
        messages = self.preprocess_messages(messages)
        tools_emulator = self.tools_emulator(params.tool_config)
        base_messages = tools_emulator.parse_dial_messages(messages)
        tool_stop_sequences = tools_emulator.get_stop_sequences()

        prompt = await self.truncate_and_linearize_messages(
            base_messages, params.max_prompt_tokens
        )

        prompt.stop_sequences.extend(tool_stop_sequences)
        prompt.stop_sequences.extend(params.stop)

        return prompt

    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        prompt = await self.get_text_completion_prompt(params, messages)

        params.stop = prompt.stop_sequences
        consumer.set_discarded_messages(prompt.discarded_messages)

        log.debug(f"model parameters: {params.json(exclude_none=True)}")
        log.debug(f"prompt: {prompt.text!r}")

        await self.predict(consumer, params, prompt.text)

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        prompt = await self.get_text_completion_prompt(params, messages)
        return prompt.discarded_messages


def keep_last(messages: List[Any], idx: int) -> bool:
    return idx == len(messages) - 1


def keep_last_and_system_messages(
    messages: List[BaseMessage], idx: int
) -> bool:
    return isinstance(messages[idx], SystemMessage) or keep_last(messages, idx)


def trivial_partitioner(messages: List[Any]) -> List[int]:
    return [1] * len(messages)


def turn_based_partitioner(messages: List[Any]) -> List[int]:
    n = len(messages)
    return [2] * (n // 2) + [1] * (n % 2)


class PseudoChatModel(TextCompletionAdapter):
    chat_emulator: ChatEmulator
    tokenize_string: Callable[[str], int]
    partitioner: Callable[[List[BaseMessage]], List[int]]

    async def count_prompt_tokens(
        self, params: ModelParameters, messages: List[Message]
    ) -> int:
        messages = self.preprocess_messages(messages)
        tools_emulator = self.tools_emulator(params.tool_config)
        base_messages = tools_emulator.parse_dial_messages(messages)
        return await self.tokenize_messages(base_messages)

    async def count_completion_tokens(self, string: str) -> int:
        return self.tokenize_string(string)

    async def tokenize_messages(self, messages: List[BaseMessage]) -> int:
        return self.tokenize_string(self.chat_emulator.display(messages)[0])

    @override
    async def truncate_and_linearize_messages(
        self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
    ) -> TextCompletionPrompt:
        discarded_messages, messages = await truncate_prompt(
            messages=messages,
            tokenizer=self.tokenize_messages,
            keep_message=keep_last_and_system_messages,
            partitioner=self.partitioner,
            model_limit=None,
            user_limit=max_prompt_tokens,
        )
        text, stop_sequences = self.chat_emulator.display(messages)

        if max_prompt_tokens is None:
            discarded_messages = None

        return TextCompletionPrompt(
            text=text,
            stop_sequences=stop_sequences,
            discarded_messages=discarded_messages,
        )

    @staticmethod
    def post_process_stream(
        stream: AsyncIterator[str],
        params: ModelParameters,
        emulator: ChatEmulator,
    ) -> AsyncIterator[str]:
        # Removing leading spaces
        stream = stream_utils.lstrip(stream)

        # Model may occasionally start responding with its cue.
        ai_cue = emulator.get_ai_cue()
        if ai_cue is not None:
            stream = stream_utils.remove_prefix(stream, ai_cue)
            stream = stream_utils.lstrip(stream)

        # The model may not support stop sequences, so do it manually
        if params.stop:
            stream = stream_utils.stop_at(stream, params.stop)

        # After all the post processing, the stream may become empty.
        # To avoid this, add a space to the stream.
        stream = stream_utils.ensure_not_empty(stream, " ")

        return stream
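

# Illustrative usage sketch (not part of this module): a concrete subclass
# would typically route the raw completion stream returned by the model
# through `post_process_stream` before pushing it to the consumer.
# `_invoke_model` and `consumer.append_content` below are hypothetical
# placeholders, not real APIs of this package:
#
#     async def predict(self, consumer, params, prompt):
#         raw_stream = self._invoke_model(params, prompt)
#         stream = PseudoChatModel.post_process_stream(
#             raw_stream, params, self.chat_emulator
#         )
#         async for chunk in stream:
#             consumer.append_content(chunk)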