aidial_adapter_bedrock/llm/chat_emulator.py (82 lines of code) (raw):

from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple, TypedDict

from pydantic import BaseModel

from aidial_adapter_bedrock.llm.message import (
    AIRegularMessage,
    BaseMessage,
    HumanRegularMessage,
)


class ChatEmulator(ABC, BaseModel):
    """Interface for rendering a list of chat messages as one prompt string."""

    @abstractmethod
    def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]:
        """Returns a prompt string and a list of stop sequences."""

    @abstractmethod
    def get_ai_cue(self) -> Optional[str]:
        """Returns the cue that prefixes AI messages, or None if there is none."""


class CueMapping(TypedDict):
    """Per-role cue strings; a None value means the role has no cue."""

    system: Optional[str]
    human: Optional[str]
    ai: Optional[str]


class BasicChatEmulator(ChatEmulator):
    """Chat emulator that joins cue-prefixed messages with a separator.

    The prompt is: optional prelude, then every message (each optionally
    prefixed with its role cue), then an optional bare AI cue, all joined
    with `separator`.
    """

    # Template for the text placed before the dialog; formatted with the
    # entries of `cues` as keyword arguments. None disables the prelude.
    prelude_template: Optional[str]
    # Predicate deciding whether a message gets its cue; called with the
    # message and the number of prompt parts emitted before it (the prelude,
    # when present, counts as one part).
    add_cue: Callable[[BaseMessage, int], bool]
    # Whether to finish the prompt with an empty AI message.
    add_invitation_cue: bool
    # Whether a lone human message is returned as is, without any emulation.
    fallback_to_completion: bool
    cues: CueMapping
    # String placed between consecutive prompt parts.
    separator: str

    @property
    def _prelude(self) -> Optional[str]:
        """The prelude template with the cues substituted, or None."""
        if self.prelude_template is None:
            return None
        return self.prelude_template.format(**self.cues)

    def _get_cue(self, message: BaseMessage) -> Optional[str]:
        """Returns the cue configured for the role of `message`.

        Human and AI messages map to their own cues; any other `BaseMessage`
        falls back to the system cue.

        Raises:
            ValueError: if `message` is not a `BaseMessage` at all.
        """
        if isinstance(message, HumanRegularMessage):
            return self.cues["human"]
        elif isinstance(message, AIRegularMessage):
            return self.cues["ai"]
        elif isinstance(message, BaseMessage):
            return self.cues["system"]
        else:
            raise ValueError(f"Unknown message type: {message.type}")

    def _format_message(self, message: BaseMessage, idx: int) -> str:
        """Renders one message as "<cue> <text>" with outer whitespace removed.

        The cue is omitted when the role has none or when `add_cue` rejects
        the message. `idx` is forwarded to `add_cue` unchanged.
        """
        cue = self._get_cue(message)

        if cue is None or not self.add_cue(message, idx):
            cue_prefix = ""
        else:
            cue_prefix = cue + " "

        # The trailing rstrip also drops the space after the cue when the
        # message text is empty, so an empty message renders as the bare cue.
        return (cue_prefix + message.text_content.lstrip()).rstrip()

    def get_ai_cue(self) -> Optional[str]:
        """Returns the cue configured for AI messages."""
        return self.cues["ai"]

    def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]:
        """Returns a prompt string and a list of stop sequences.

        When `fallback_to_completion` is set and `messages` is a single human
        message, its text is returned untouched with no stop sequences.
        Otherwise the stop sequences are `separator + human cue` when a human
        cue is configured, and empty when it is not.
        """
        if (
            self.fallback_to_completion
            and len(messages) == 1
            and isinstance(messages[0], HumanRegularMessage)
        ):
            return messages[0].text_content, []

        ret: List[str] = []

        # Evaluate the property once: every access re-formats the template.
        prelude = self._prelude
        if prelude is not None:
            ret.append(prelude)

        for message in messages:
            ret.append(self._format_message(message, len(ret)))

        if self.add_invitation_cue:
            # An empty AI message renders as the AI cue alone.
            ret.append(
                self._format_message(AIRegularMessage(content=""), len(ret))
            )

        stop_sequences: List[str] = []
        human_role = self.cues["human"]
        if human_role is not None:
            stop_sequences = [self.separator + human_role]

        return self.separator.join(ret), stop_sequences


# The prelude placeholders are filled from `cues`: user messages carry the
# `human` cue and the model's own messages carry the `ai` cue.
default_emulator = BasicChatEmulator(
    prelude_template="""
You are a helpful assistant participating in a dialog with a user.
The messages from the user start with "{human}".
The messages from you start with "{ai}".
Reply to the last message from the user taking into account the preceding dialog history.
====================
""".strip(),
    add_cue=lambda *_: True,
    add_invitation_cue=True,
    fallback_to_completion=True,
    cues=CueMapping(
        system="Human:",
        human="Human:",
        ai="Assistant:",
    ),
    separator="\n\n",
)