aidial_adapter_vertexai/chat/bison/prompt.py (118 lines of code) (raw):

"""Prompt model for Bison chat models.

Validates a list of DIAL chat messages, splits it into an optional system
instruction, a history and the last user message, and converts the history
into Vertex AI ``ChatMessage`` objects.
"""

from enum import Enum
from typing import List, Optional, Set, Tuple

from aidial_sdk.chat_completion import Message, Role
from pydantic.v1 import BaseModel
from vertexai.preview.language_models import ChatMessage, ChatSession

from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatablePrompt
from aidial_adapter_vertexai.dial_api.request import collect_text_content


class ChatAuthor(str, Enum):
    """Author tags of chat messages, taken from ``ChatSession`` constants."""

    USER = ChatSession.USER_AUTHOR
    BOT = ChatSession.MODEL_AUTHOR

    def __repr__(self) -> str:
        """Return the repr of the underlying author string."""
        return f"{self.value!r}"


class BisonPrompt(BaseModel, TruncatablePrompt):
    """A chat prompt split into system instruction, history and last message.

    Messages are addressed by a flat index: the system instruction (when
    present) is at index 0, followed by the history messages, followed by
    the last user message at index ``len(self) - 1``.
    """

    # Text of the leading system message; None when absent.
    system_instruction: Optional[str] = None
    # Messages between the system instruction and the last user message.
    history: List[ChatMessage] = []
    # Text of the final user message.
    last_user_message: str

    @classmethod
    def parse(cls, history: List[Message]) -> "BisonPrompt":
        """Build a prompt from DIAL chat messages.

        Raises:
            ValidationError: if ``_validate_and_split_messages`` rejects
                the messages.
        """
        system_instruction, history, last_user_message = (
            _validate_and_split_messages(history)
        )

        return cls(
            system_instruction=system_instruction,
            history=list(map(_to_bison_message, history)),
            last_user_message=last_user_message,
        )

    @property
    def has_system_instruction(self) -> bool:
        """Whether the prompt carries a system instruction."""
        return self.system_instruction is not None

    def is_required_message(self, index: int) -> bool:
        """Tell whether the message at ``index`` must always be kept.

        The system instruction (index 0, when present) and the last user
        message (the final index) are required.
        """
        is_system_instruction = self.has_system_instruction and index == 0
        is_last_user_message = index == len(self) - 1
        return is_system_instruction or is_last_user_message

    def __len__(self) -> int:
        """Count the system instruction (if any), history and last message."""
        return int(self.has_system_instruction) + len(self.history) + 1

    def partition_messages(self) -> List[int]:
        """Return consecutive group sizes that sum up to ``len(self)``.

        The system instruction and the last user message each form a group
        of one; history messages are grouped in pairs, with a trailing group
        of one when the history length is odd.
        """
        # NOTE(review): presumably the truncation logic keeps or drops each
        # group as a unit -- confirm against TruncatablePrompt.
        n = len(self.history)
        return (
            [1] * int(self.has_system_instruction)
            + [2] * (n // 2)
            + [1] * (n % 2)
            + [1]
        )

    def select(self, indices: Set[int]) -> "BisonPrompt":
        """Return a new prompt made of the messages at the given indices.

        Args:
            indices: flat message indices to keep (see the class docstring
                for the index layout).

        Raises:
            RuntimeError: if the index of the last user message is missing
                from ``indices``.
        """
        system_instruction: Optional[str] = None

        offset = 0
        if self.has_system_instruction:
            if 0 in indices:
                system_instruction = self.system_instruction
            # The system instruction occupies index 0 whether or not it is
            # selected, so the history always starts right after it.
            offset = 1

        history = [
            message
            for idx, message in enumerate(self.history, start=offset)
            if idx in indices
        ]

        last_message_index = offset + len(self.history)
        if last_message_index not in indices:
            raise RuntimeError("The last user prompt must not be omitted.")

        return BisonPrompt(
            system_instruction=system_instruction,
            history=history,
            last_user_message=self.last_user_message,
        )


# Roles accepted in the incoming chat messages.
_SUPPORTED_ROLES = {Role.SYSTEM, Role.USER, Role.ASSISTANT}


def _validate_and_split_messages(
    messages: List[Message],
) -> Tuple[Optional[str], List[Message], str]:
    """Validate chat messages and split them into prompt parts.

    Note that empty contents of non-system messages are replaced with a
    single space in place, i.e. the passed ``Message`` objects are mutated.

    Args:
        messages: chat messages, optionally starting with a system message.

    Returns:
        A tuple of the system instruction text (None when there is no
        leading system message or its text is blank), the messages preceding
        the last one, and the text of the last user message.

    Raises:
        ValidationError: if the list is empty, a message has no content or
            an unsupported role, there is no non-system message, a system
            message appears after the first position, adjacent messages
            share a role, the number of non-system messages is even, or the
            last message is not a user message.
    """
    if not messages:
        raise ValidationError("The chat history must have at least one message")

    for message in messages:
        if message.content is None:
            raise ValidationError("Message content must be present")

        if message.role not in _SUPPORTED_ROLES:
            raise ValidationError(
                f"Message role must be one of {_SUPPORTED_ROLES}"
            )

    system_instruction: Optional[str] = None
    history: List[Message] = messages
    if messages[0].role == Role.SYSTEM:
        system_message, *history = messages
        system_text = collect_text_content(system_message.content)
        # A blank system message is treated as an absent one.
        system_instruction = system_text if system_text.strip() else None

    # The history can only be empty here when the sole message is a system
    # message; report that regardless of whether its text was blank.
    if not history:
        raise ValidationError(
            "The chat history must have at least one non-system message"
        )

    previous_role: Optional[Role] = None
    for message in history:
        if message.role == Role.SYSTEM:
            raise ValidationError(
                "System messages other than the initial system message "
                "are not allowed"
            )

        # Bison doesn't support empty messages,
        # so we replace it with a single space.
        message.content = message.content or " "

        if previous_role == message.role:
            raise ValidationError("Messages must alternate between authors")

        previous_role = message.role

    if len(history) % 2 == 0:
        raise ValidationError(
            "There should be odd number of messages for correct alternating turn"
        )

    *history, last_message = history
    if last_message.role != Role.USER:
        raise ValidationError("The last message must be a user message")

    return (
        system_instruction,
        history,
        collect_text_content(last_message.content),
    )


def _to_bison_message(message: Message) -> ChatMessage:
    """Convert a chat message into a Vertex AI ``ChatMessage``.

    Assistant messages are attributed to the bot author; messages with any
    other role are attributed to the user author.
    """
    author = (
        ChatAuthor.BOT if message.role == Role.ASSISTANT else ChatAuthor.USER
    )
    return ChatMessage(
        author=author, content=collect_text_content(message.content)
    )