# aidial_adapter_openai/gpt.py
from typing import AsyncIterator, List, Tuple, cast

from aidial_sdk.exceptions import InvalidRequestError
from openai import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from aidial_adapter_openai.utils.auth import OpenAICreds
from aidial_adapter_openai.utils.parsers import chat_completions_parser
from aidial_adapter_openai.utils.reflection import call_with_extra_body
from aidial_adapter_openai.utils.streaming import (
    chunk_to_dict,
    debug_print,
    generate_stream,
    map_stream,
)
from aidial_adapter_openai.utils.tokenizer import PlainTextTokenizer
from aidial_adapter_openai.utils.truncate_prompt import (
    DiscardedMessages,
    TruncatedTokens,
    truncate_prompt,
)


def plain_text_truncate_prompt(
    request: dict,
    messages: List[dict],
    max_prompt_tokens: int,
    tokenizer: PlainTextTokenizer,
) -> Tuple[List[dict], DiscardedMessages, TruncatedTokens]:
    """Truncate the chat history to fit into ``max_prompt_tokens``.

    System messages are never discarded. Returns the truncated message
    list, the indices of the discarded messages and the resulting prompt
    token count.
    """
return truncate_prompt(
messages=messages,
message_tokens=tokenizer.tokenize_request_message,
is_system_message=lambda message: message["role"] == "system",
max_prompt_tokens=max_prompt_tokens,
initial_prompt_tokens=tokenizer.tokenize_request(request, []),
)
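
# A commented-out sketch (not part of the original module) of how the helper
# above might be driven. The PlainTextTokenizer constructor arguments are an
# assumption; the real class may require different ones.
#
#     tokenizer = PlainTextTokenizer(model="gpt-3.5-turbo")
#     messages = [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "First question"},
#         {"role": "user", "content": "Second question"},
#     ]
#     truncated, discarded, prompt_tokens = plain_text_truncate_prompt(
#         request={}, messages=messages, max_prompt_tokens=100, tokenizer=tokenizer
#     )
#     # System messages survive truncation; "discarded" holds the indices of
#     # the dropped messages, "prompt_tokens" the size of the truncated prompt.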


async def gpt_chat_completion(
    request: dict,
    deployment_id: str,
    upstream_endpoint: str,
    creds: OpenAICreds,
    api_version: str,
    tokenizer: PlainTextTokenizer,
    eliminate_empty_choices: bool,
):
    """Forward a chat completion request to the upstream endpoint,
    truncating the prompt first when ``max_prompt_tokens`` is requested."""
discarded_messages = None
estimated_prompt_tokens = None
if "max_prompt_tokens" in request:
max_prompt_tokens = request["max_prompt_tokens"]
if not isinstance(max_prompt_tokens, int):
raise InvalidRequestError(
f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
)
if max_prompt_tokens < 1:
raise InvalidRequestError(
f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
)
del request["max_prompt_tokens"]
request["messages"], discarded_messages, estimated_prompt_tokens = (
plain_text_truncate_prompt(
request=request,
messages=cast(List[dict], request["messages"]),
max_prompt_tokens=max_prompt_tokens,
tokenizer=tokenizer,
)
)
    # Build an OpenAI client bound to the parsed upstream endpoint.
    client = chat_completions_parser.parse(upstream_endpoint).get_client(
{**creds, "api_version": api_version}
)
response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
await call_with_extra_body(client.chat.completions.create, request)
)
    if isinstance(response, AsyncIterator):
        # Streaming mode: re-emit upstream chunks as plain dicts, computing
        # the prompt token count locally if it wasn't estimated during
        # truncation.
        return generate_stream(
stream=map_stream(chunk_to_dict, response),
get_prompt_tokens=lambda: estimated_prompt_tokens
or tokenizer.tokenize_request(request, request["messages"]),
tokenize_response=tokenizer.tokenize_response,
deployment=deployment_id,
discarded_messages=discarded_messages,
eliminate_empty_choices=eliminate_empty_choices,
)
    else:
        rest = response.to_dict()
        if discarded_messages is not None:
            # Report which messages were dropped by the truncation above.
            rest |= {"statistics": {"discarded_messages": discarded_messages}}
debug_print("response", rest)
return rest
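
# An illustrative, commented-out sketch (assumption, not part of the original
# module) of wiring gpt_chat_completion into a handler. The OpenAICreds fields,
# endpoint URL and tokenizer arguments below are placeholders.
#
#     creds: OpenAICreds = {"api_key": "..."}
#     response = await gpt_chat_completion(
#         request={"messages": [{"role": "user", "content": "Hi"}]},
#         deployment_id="gpt-35-turbo",
#         upstream_endpoint="https://example.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions",
#         creds=creds,
#         api_version="2024-02-01",
#         tokenizer=PlainTextTokenizer(model="gpt-3.5-turbo"),
#         eliminate_empty_choices=False,
#     )
#     # A non-streaming request returns a plain dict; a streaming one returns
#     # the async generator produced by generate_stream.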