# src/backend/entrypoints/llm_backend/api/models/optimization.py
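"""Request/response models and parameter-validation helpers for the optimization API endpoints."""
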
from enum import Enum
from typing import Any, Optional, Union
from fastapi import Query
from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import FieldValidationInfo
from market_alerts.domain.dtos import (
optimization_samplers_dtos,
optimization_target_funcs_dtos,
)
from market_alerts.entrypoints.llm_backend.api.models.llm import BacktestingRequestModel
from market_alerts.entrypoints.llm_backend.domain.exceptions import (
InvalidOptimizationParamError,
)
class ParamType(str, Enum):
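    """Supported value types for optimizable parameters."""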
string = "str"
integer = "int"
float = "float"
class ParamRange(BaseModel):
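    """A parameter name and the list of candidate values to optimize over (see ``validate_params`` for the per-type constraints)."""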
name: str
values: list
class Param(BaseModel):
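    """A single concrete parameter value."""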
name: str
value: Union[int, float, str]
class OptimizationRequestModel(BacktestingRequestModel):
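    """Backtesting request extended with optimization settings: trial budget, train split, optimization directions, parameter ranges, sampler, and target function."""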
n_trials: int = Field(100, ge=1, le=1000)
train_size: float = Field(1.0, ge=0.01, le=1.0)
minimize: bool = Field(True)
maximize: bool = Field(True)
params: list[ParamRange]
sampler: str = optimization_samplers_dtos[0].query_param
target_func: str = optimization_target_funcs_dtos[0].query_param
@field_validator("maximize")
def check_minimize_or_maximize(cls, v: bool, info: FieldValidationInfo):
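        """Require at least one of ``minimize``/``maximize`` to be enabled."""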
if "minimize" in info.data:
if not (info.data["minimize"] or v):
raise ValueError("Either minimize or maximize must be True")
return v
@field_validator("sampler")
def check_sampler(cls, v: str):
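        """Allow only samplers registered in ``optimization_samplers_dtos``."""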
allowed_values = [dto.query_param for dto in optimization_samplers_dtos]
if v not in allowed_values:
raise ValueError(
f"Sampler value '{v}' was provided, which is not among the allowed values: {', '.join(allowed_values)}"
)
return v
@field_validator("target_func")
def check_target_func(cls, v: str):
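        """Allow only target functions registered in ``optimization_target_funcs_dtos``."""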
allowed_values = [dto.query_param for dto in optimization_target_funcs_dtos]
if v not in allowed_values:
raise ValueError(
f"Target func value '{v}' was provided, which is not among the allowed values: {', '.join(allowed_values)}"
)
return v
class OptimizationResult(BaseModel):
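    """Best parameters found and the raw per-trial records of a single optimization study."""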
best_params: dict[str, Any]
trials: list[tuple[int, float, float, dict[str, Any], float]]
class OptimizationResponseModel(BaseModel):
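    """Optimization outcome, with separate results for the minimization and maximization studies when they were requested."""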
minimization: Optional[OptimizationResult]
maximization: Optional[OptimizationResult]
sampler: Optional[str]
target_func: Optional[str]
class AfterOptimizationSetAsDefaultRequestModel(BaseModel):
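    """Concrete parameter values to apply as the new defaults after an optimization run."""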
params: list[Param]
class AfterOptimizationBacktestingRequestModel(BacktestingRequestModel):
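    """Backtesting request carrying one or more sets of concrete parameter values produced by optimization."""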
params: list[list[Param]]
def validate_params(
params: Union[list[Param], list[ParamRange]],
optimization_params: dict[str, list[Any]],
) -> None:
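    """Validate user-supplied params against the parameters declared in the code.

    ``optimization_params`` maps each parameter name found in the code to a pair
    whose second element is its ``ParamType``. Raises ``InvalidOptimizationParamError``
    when a declared parameter is missing, a supplied name is unknown, a value has the
    wrong type, or a range has an invalid number of values (1-10 for strings,
    exactly 2 for numeric types).
    """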
param_names = {param.name for param in params}
if not set(optimization_params.keys()).issubset(param_names):
missing_keys = set(optimization_params.keys()) - param_names
raise InvalidOptimizationParamError(f"Missing params: {', '.join(missing_keys)}")
for param in params:
range_mode = isinstance(param, ParamRange)
param_name = param.name
_validate_param_name(param_name, optimization_params)
_, param_type = optimization_params[param_name]
values = param.values if range_mode else [param.value]
if param_type == ParamType.string:
if not all(isinstance(v, str) for v in values):
raise InvalidOptimizationParamError(
f"'{param_name}' param has type 'str', all values must have typalerte 'str' as well"
)
if range_mode and (len(values) < 1 or len(values) > 10):
raise InvalidOptimizationParamError(
f"'{param_name}' param has type 'str', the values amount must be between 1 and 10"
)
elif param_type == ParamType.integer:
if not all(isinstance(v, int) for v in values):
raise InvalidOptimizationParamError(
f"'{param_name}' param has type 'int', all values must have type 'int' as well"
)
if range_mode and len(values) != 2:
raise InvalidOptimizationParamError(f"'{param_name}' param has type 'int', the values amount must be 2")
elif param_type == ParamType.float:
if not all(isinstance(v, (int, float)) for v in values):
raise InvalidOptimizationParamError(
f"'{param_name}' param has type 'float', all values must have type 'float' as well"
)
if range_mode and len(values) != 2:
raise InvalidOptimizationParamError(f"'{param_name}' param has type 'float', the values amount must be 2")
def _validate_param_name(param_name, optimization_params: dict[str, list[Any]]):
if param_name not in optimization_params:
raise InvalidOptimizationParamError(f"'{param_name}' param wasn't met in the code")
class OptimizationResultRequestModel(BaseModel):
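    """Query model selecting which optimization studies' results to return."""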
studies_names: list[str] = Field(Query([]))
class OptimizationCalendarResponseModel(BaseModel):
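    """Cutoff, start, and end dates describing the optimization calendar."""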
cutoff_date: str
start_date: str
end_date: str