aidial_assistant/model/model_client.py

from abc import ABC
from itertools import islice
from typing import Any, AsyncIterator, List

from aidial_sdk.utils.merge_chunks import merge
from openai import AsyncOpenAI
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionMessageToolCallParam,
)

from aidial_assistant.utils.open_ai import Usage


class ReasonLengthException(Exception):
    """Raised when generation stops because the completion token limit
    was reached (finish_reason == "length")."""


class ExtraResultsCallback:
    """No-op hooks for side-channel results that accompany a streamed
    completion: discarded message indices, prompt token usage, and tool
    calls. Subclasses override only the methods they care about."""

    def on_discarded_messages(self, discarded_messages: list[int]):
        pass

    def on_prompt_tokens(self, prompt_tokens: int):
        pass

    def on_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCallParam]
    ):
        pass


async def _flush_stream(stream: AsyncIterator[str]):
    """Drain a generation stream for its side effects, ignoring truncation."""
    try:
        async for _ in stream:
            pass
    except ReasonLengthException:
        pass


def _discarded_messages_count_to_indices(
    messages: list[ChatCompletionMessageParam], discarded_messages: int
) -> list[int]:
    """Translate a count of discarded messages into the indices of the
    first `discarded_messages` non-system messages, since system messages
    are never discarded."""
    return list(
        islice(
            (
                i
                for i, message in enumerate(messages)
                if message["role"] != "system"
            ),
            discarded_messages,
        )
    )


class ModelClient(ABC):
    def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]):
        self.client = client
        self.model_args = model_args

        self._total_prompt_tokens: int = 0
        self._total_completion_tokens: int = 0

    async def agenerate(
        self,
        messages: List[ChatCompletionMessageParam],
        extra_results_callback: ExtraResultsCallback | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completion content for the given messages, reporting
        token usage, discarded messages, and tool calls via the callback."""
        model_result = await self.client.chat.completions.create(
            **self.model_args,
            extra_body=kwargs,
            stream=True,
            messages=messages,
        )

        finish_reason_length = False
        tool_calls_chunks: list[list[dict[str, Any]]] = []
        async for chunk in model_result:
            chunk_dict = chunk.dict()

            # Accumulate token usage and report prompt tokens as soon as
            # they arrive.
            usage: Usage | None = chunk_dict.get("usage")
            if usage:
                prompt_tokens = usage["prompt_tokens"]
                self._total_prompt_tokens += prompt_tokens
                self._total_completion_tokens += usage["completion_tokens"]
                if extra_results_callback:
                    extra_results_callback.on_prompt_tokens(prompt_tokens)

            # Discarded messages may be reported either as a count or as a
            # list of message indices; normalize to indices.
            if extra_results_callback:
                discarded_messages: int | list[int] | None = chunk_dict.get(
                    "statistics", {}
                ).get("discarded_messages")
                if discarded_messages is not None:
                    extra_results_callback.on_discarded_messages(
                        _discarded_messages_count_to_indices(
                            messages, discarded_messages
                        )
                        if isinstance(discarded_messages, int)
                        else discarded_messages
                    )

            choice = chunk.choices[0]
            delta = choice.delta
            if delta.content:
                yield delta.content

            # Collect tool call fragments; they are merged into complete
            # tool calls once the stream ends.
            if delta.tool_calls:
                tool_calls_chunks.append(
                    [
                        tool_call_chunk.dict()
                        for tool_call_chunk in delta.tool_calls
                    ]
                )

            if choice.finish_reason == "length":
                finish_reason_length = True

        if finish_reason_length:
            raise ReasonLengthException()

        if extra_results_callback and tool_calls_chunks:
            tool_calls: list[ChatCompletionMessageToolCallParam] = [
                ChatCompletionMessageToolCallParam(**tool_call)
                for tool_call in merge(*tool_calls_chunks)
            ]
            extra_results_callback.on_tool_calls(tool_calls)

    # TODO: Use a dedicated endpoint for counting tokens.
    #  This request may throw an error if the number of tokens is too large.
    async def count_tokens(
        self, messages: list[ChatCompletionMessageParam]
    ) -> int:
        """Count prompt tokens by issuing a one-token request and reading
        the reported usage."""

        class PromptTokensCallback(ExtraResultsCallback):
            def __init__(self):
                self.token_count: int | None = None

            def on_prompt_tokens(self, prompt_tokens: int):
                self.token_count = prompt_tokens

        callback = PromptTokensCallback()
        await _flush_stream(
            self.agenerate(
                messages, extra_results_callback=callback, max_tokens=1
            )
        )
        if callback.token_count is None:
            raise Exception("No token count received.")

        return callback.token_count

    # TODO: Use a dedicated endpoint for discarded_messages.
    #  https://github.com/epam/ai-dial-assistant/issues/39
    async def get_discarded_messages(
        self, messages: list[ChatCompletionMessageParam], max_prompt_tokens: int
    ) -> list[int]:
        """Determine which messages would be discarded to fit the prompt
        into max_prompt_tokens, again via a one-token request."""

        class DiscardedMessagesCallback(ExtraResultsCallback):
            def __init__(self):
                self.discarded_messages: list[int] | None = None

            def on_discarded_messages(self, discarded_messages: list[int]):
                self.discarded_messages = discarded_messages

        callback = DiscardedMessagesCallback()
        await _flush_stream(
            self.agenerate(
                messages,
                extra_results_callback=callback,
                max_prompt_tokens=max_prompt_tokens,
                max_tokens=1,
            )
        )
        if callback.discarded_messages is None:
            raise Exception("Discarded messages were not provided.")

        return callback.discarded_messages

    @property
    def total_prompt_tokens(self) -> int:
        return self._total_prompt_tokens

    @property
    def total_completion_tokens(self) -> int:
        return self._total_completion_tokens
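
For context, a minimal usage sketch follows. It assumes a DIAL-compatible OpenAI endpoint; the base URL, API key, and deployment name are hypothetical placeholders, and ModelClient is instantiated directly, which works because it declares no abstract methods.

import asyncio

from openai import AsyncOpenAI

from aidial_assistant.model.model_client import ModelClient


async def main():
    client = ModelClient(
        client=AsyncOpenAI(
            base_url="http://localhost:8080/openai",  # hypothetical endpoint
            api_key="dial-api-key",  # hypothetical key
        ),
        model_args={"model": "gpt-4"},  # hypothetical deployment name
    )
    messages = [{"role": "user", "content": "Hello!"}]

    # Stream the completion text chunk by chunk.
    async for content in client.agenerate(messages):
        print(content, end="")

    # Token accounting is accumulated across all calls on this instance.
    print(f"\nprompt tokens: {client.total_prompt_tokens}")
    print(f"completion tokens: {client.total_completion_tokens}")


asyncio.run(main())

Note that count_tokens and get_discarded_messages reuse the same streaming path: each issues a max_tokens=1 request purely to observe the usage or statistics reported back, which is why both raise if the expected field never arrives.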