aidial_analytics_realtime/rates.py (66 lines of code) (raw):

import os
from decimal import Decimal
from typing import Annotated, Dict, Literal, Union

from pydantic import BaseModel, Field, parse_raw_as


class ModelRate(BaseModel):
    """Per-unit prices shared by every rate model."""

    # Price of one prompt (request) unit; defaults to ``Decimal()`` (zero).
    prompt_price: Annotated[Decimal, Field(default_factory=Decimal)]
    # Price of one completion (response) unit; defaults to ``Decimal()``.
    completion_price: Annotated[Decimal, Field(default_factory=Decimal)]


class TokenModelRate(ModelRate):
    """Rate priced per token, based on the counts in the usage payload."""

    unit: Literal["token"]

    def calculate(
        self, request_content: str, response_content: str, usage: dict | None
    ) -> Decimal:
        """Return the price for the token counts reported in *usage*.

        Args:
            request_content: Unused; accepted so that all rate models share
                the same ``calculate`` signature.
            response_content: Unused; see ``request_content``.
            usage: Mapping that may carry ``"prompt_tokens"`` and
                ``"completion_tokens"``. ``None`` yields a zero price.

        Returns:
            ``prompt_price * prompt_tokens
            + completion_price * completion_tokens``.
        """
        if usage is None:
            return Decimal(0)

        # A missing or null counter (e.g. a payload that reports only prompt
        # tokens) contributes nothing instead of raising KeyError/TypeError.
        prompt_tokens = Decimal(usage.get("prompt_tokens") or 0)
        completion_tokens = Decimal(usage.get("completion_tokens") or 0)

        return (
            self.prompt_price * prompt_tokens
            + self.completion_price * completion_tokens
        )


class CharWithoutSpaceModelRate(ModelRate):
    """Rate priced per non-whitespace character of the message contents."""

    unit: Literal["char_without_whitespace"]

    @staticmethod
    def get_chars_without_whitespaces(a: str) -> int:
        """Return the number of characters in *a* that are not whitespace.

        Whitespace is whatever ``str.isspace`` reports, so tabs, newlines
        and Unicode spaces are excluded along with the plain space.
        """
        return sum(1 for char in a if not char.isspace())

    def calculate(
        self, request_content: str, response_content: str, usage: dict | None
    ) -> Decimal:
        """Return the price for the request and response texts.

        Args:
            request_content: Prompt text; its non-whitespace characters are
                charged at ``prompt_price``.
            response_content: Completion text; its non-whitespace characters
                are charged at ``completion_price``.
            usage: Unused; accepted so that all rate models share the same
                ``calculate`` signature.

        Returns:
            ``prompt_price * request_len + completion_price * response_len``.
        """
        request_len = self.get_chars_without_whitespaces(request_content)
        response_len = self.get_chars_without_whitespaces(response_content)

        return (
            self.prompt_price * request_len
            + self.completion_price * response_len
        )


# Mapping from a lookup key (RatesCalculator looks up the deployment name
# first, then the model name) to its rate; the concrete rate class is
# selected by the ``unit`` discriminator field.
Rates = Dict[
    str,
    Annotated[
        Union[TokenModelRate, CharWithoutSpaceModelRate],
        Field(discriminator="unit"),
    ],
]


class RatesCalculator:
    """Looks up the configured rate for a deployment/model and prices calls."""

    def __init__(self, rates_str: str | None = None):
        """Parse the rates configuration.

        Args:
            rates_str: JSON text describing a ``Rates`` mapping. When
                ``None``, the ``MODEL_RATES`` environment variable is read,
                falling back to ``"{}"`` (no rates) if it is not set.
        """
        if rates_str is None:
            rates_str = os.environ.get("MODEL_RATES", "{}")

        self.rates = parse_raw_as(Rates, rates_str)

    def get_rate(
        self, deployment: str, model: str
    ) -> TokenModelRate | CharWithoutSpaceModelRate | None:
        """Return the rate for *deployment*, else for *model*, else ``None``.

        A rate registered under the deployment name takes precedence over
        one registered under the model name.
        """
        deployment_rate = self.rates.get(deployment)
        if deployment_rate is not None:
            return deployment_rate

        return self.rates.get(model)

    def calculate_price(
        self,
        deployment: str,
        model: str,
        request_content: str,
        response_content: str,
        usage: dict | None,
    ) -> Decimal:
        """Return the price of one request/response exchange.

        Args:
            deployment: Deployment name, looked up first in the rates.
            model: Model name, looked up when the deployment has no rate.
            request_content: Prompt text passed to the rate's ``calculate``.
            response_content: Completion text passed to the rate's
                ``calculate``.
            usage: Usage payload passed to the rate's ``calculate``; may be
                ``None``.

        Returns:
            The price computed by the matching rate, or ``Decimal(0)`` when
            neither the deployment nor the model has a configured rate.
        """
        rate = self.get_rate(deployment, model)
        if rate is None:
            return Decimal(0)

        return rate.calculate(request_content, response_content, usage)