def distribute_trials_per_task()

in src/backend/entrypoints/llm_backend/api/optimization.py [0:0]


def _split_evenly(total: int, parts: int) -> list[int]:
    """Split ``total`` into ``parts`` integers that differ by at most one.

    The first ``total % parts`` entries receive one extra unit, so the result is
    non-increasing and sums to ``total``. ``parts`` must be positive.
    """
    base, extra = divmod(total, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


def distribute_trials_per_task(n_trials: int, studies: list[str], tasks_per_study: int) -> list[tuple[str, int]]:
    """Distribute ``n_trials`` across studies, then across tasks within each study.

    Trials are first split as evenly as possible across ``studies``. When the
    split is uneven, earlier studies receive one extra trial. Each study's share
    is then split the same way across ``tasks_per_study`` tasks. Tasks that end
    up with zero trials are omitted, so a study may contribute fewer than
    ``tasks_per_study`` entries, or none at all.

    Args:
        n_trials: Total number of trials to distribute.
        studies: Study identifiers, in the order their tasks should appear.
        tasks_per_study: Maximum number of tasks to create per study.

    Returns:
        ``(study, trials)`` pairs, ordered by study and then by task. Every
        ``trials`` value is positive, and the values sum to ``n_trials``.

    Raises:
        RuntimeError: If ``tasks_per_study`` exceeds ``n_trials``, if
            ``tasks_per_study`` is not positive, or if ``studies`` is empty.
    """
    if tasks_per_study > n_trials:
        raise RuntimeError("tasks_per_study was greater than n_trials")
    # Zero would divide by zero in the per-study split, and a negative value
    # would silently produce no tasks at all, so reject both up front.
    if tasks_per_study < 1:
        raise RuntimeError("tasks_per_study must be at least 1")
    # With no studies there is nowhere to place the trials, and the split
    # below would divide by zero.
    if not studies:
        raise RuntimeError("studies must not be empty")

    results: list[tuple[str, int]] = []
    for study, trials_in_study in zip(studies, _split_evenly(n_trials, len(studies))):
        # The guard above is global: one study can still receive fewer trials
        # than tasks_per_study. The surplus tasks get 0 trials and are dropped.
        results.extend(
            (study, trials) for trials in _split_evenly(trials_in_study, tasks_per_study) if trials > 0
        )
    return results