aidial_adapter_vertexai/chat/bison/adapter.py (127 lines of code):

from typing import AsyncIterator, List, Optional

from typing_extensions import TypedDict, override
from vertexai.preview.language_models import ChatModel, CodeChatModel

from aidial_adapter_vertexai.chat.bison.base import BisonChatCompletionAdapter
from aidial_adapter_vertexai.chat.bison.prompt import BisonPrompt
from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.vertex_ai import (
    get_chat_model,
    get_code_chat_model,
)


class CodeChatParamsBase(TypedDict, total=False):
    max_output_tokens: Optional[int]
    temperature: Optional[float]


class ChatParamsBase(TypedDict, total=False):
    max_output_tokens: Optional[int]
    temperature: Optional[float]
    # Extra compared to CodeChatParams
    stop_sequences: Optional[List[str]]
    top_k: Optional[int]
    top_p: Optional[float]


# Streaming requests take the base parameters only;
# candidate_count is meaningful solely for non-streaming calls.
ChatParamsStream = ChatParamsBase
CodeChatParamsStream = CodeChatParamsBase


class NoStreamParams(TypedDict, total=False):
    candidate_count: Optional[int]


class ChatParamsNoStream(ChatParamsBase, NoStreamParams):
    pass


class CodeChatParamsNoStream(CodeChatParamsBase, NoStreamParams):
    pass


class BisonChatAdapter(BisonChatCompletionAdapter):
    model: ChatModel

    @classmethod
    async def create(cls, model_id: str) -> "BisonChatAdapter":
        return cls(await get_chat_model(model_id))

    def prepare_parameters_no_stream(
        self, params: ModelParameters
    ) -> ChatParamsNoStream:
        return {
            "max_output_tokens": params.max_tokens,
            "temperature": params.temperature,
            "stop_sequences": params.stop,
            "top_p": params.top_p,
            "candidate_count": params.n,
        }

    def prepare_parameters_stream(
        self, params: ModelParameters
    ) -> ChatParamsStream:
        return {
            "max_output_tokens": params.max_tokens,
            "temperature": params.temperature,
            "stop_sequences": params.stop,
            "top_p": params.top_p,
        }

    @override
    async def send_message_async(
        self, params: ModelParameters, prompt: BisonPrompt
    ) -> AsyncIterator[str]:
        chat = self.model.start_chat(
            context=prompt.system_instruction,
            message_history=prompt.history,
        )

        generic_validate_parameters(params)

        if params.stream:
            stream = chat.send_message_streaming_async(
                message=prompt.last_user_message,
                **self.prepare_parameters_stream(params),
            )
            async for chunk in stream:
                yield chunk.text
        else:
            response = await chat.send_message_async(
                message=prompt.last_user_message,
                **self.prepare_parameters_no_stream(params),
            )
            yield response.text


class BisonCodeChatAdapter(BisonChatCompletionAdapter):
    model: CodeChatModel

    @classmethod
    async def create(cls, model_id: str) -> "BisonCodeChatAdapter":
        return cls(await get_code_chat_model(model_id))

    def validate_parameters(self, params: ModelParameters) -> None:
        if params.stop is not None and params.stop != []:
            raise ValidationError(
                "stop sequences are not supported for code chat model"
            )
        if params.top_p is not None:
            raise ValidationError("top_p is not supported for code chat model")

    def prepare_parameters_no_stream(
        self, params: ModelParameters
    ) -> CodeChatParamsNoStream:
        return {
            "max_output_tokens": params.max_tokens,
            "temperature": params.temperature,
            "candidate_count": params.n,
        }

    def prepare_parameters_stream(
        self, params: ModelParameters
    ) -> CodeChatParamsStream:
        return {
            "max_output_tokens": params.max_tokens,
            "temperature": params.temperature,
        }

    @override
    async def send_message_async(
        self, params: ModelParameters, prompt: BisonPrompt
    ) -> AsyncIterator[str]:
        chat = self.model.start_chat(
            context=prompt.system_instruction,
            message_history=prompt.history,
        )

        generic_validate_parameters(params)
        self.validate_parameters(params)

        if params.stream:
            stream = chat.send_message_streaming_async(
                message=prompt.last_user_message,
                **self.prepare_parameters_stream(params),
            )
            async for chunk in stream:
                yield chunk.text
        else:
            response = await chat.send_message_async(
                message=prompt.last_user_message,
                **self.prepare_parameters_no_stream(params),
            )
            yield response.text


def generic_validate_parameters(params: ModelParameters) -> None:
    # Currently n>1 is emulated by calling the model n times.
    # So the individual generation requests are expected to have n=1 or unset.
    if params.n is not None and params.n > 1:
        raise ValueError("n is expected to be 1 or unset")
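
# --- Illustrative sketch (not part of the original module) ---
# generic_validate_parameters above notes that n>1 is emulated by calling the
# model n times. A minimal caller-side sketch of that emulation, kept
# commented out because it relies on two assumptions: ModelParameters is
# pydantic-based (so .copy(update=...) exists), and `prompt` is an
# already-built BisonPrompt. `emulate_candidates` is a hypothetical helper,
# not part of the adapter's API.
#
# async def emulate_candidates(
#     adapter: BisonChatCompletionAdapter,
#     params: ModelParameters,
#     prompt: BisonPrompt,
# ) -> List[str]:
#     # Each individual request must carry n=1, per the check above.
#     single = params.copy(update={"n": 1})
#     completions: List[str] = []
#     for _ in range(params.n or 1):
#         text = ""
#         async for chunk in adapter.send_message_async(single, prompt):
#             text += chunk
#         completions.append(text)
#     return completions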