def train_agent()

in run_bc.py


def train_agent(env, state_dim, action_dim, max_action, device, output_dir, args):
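    """Build the requested behavior-cloning agent, train it on the D4RL buffer for env,
    and periodically evaluate it, saving evaluation results (and, optionally, the best
    checkpoint) to output_dir."""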
    # Load buffer
    dataset = d4rl.qlearning_dataset(env)
    data_sampler = Data_Sampler(dataset, device, args.reward_tune)
    utils.print_banner('Loaded buffer')

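    # Select and construct the agent for the requested behavior-cloning variant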
    if args.algo == 'bc':
        from agents.bc_diffusion import BC as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      beta_schedule=args.beta_schedule,
                      n_timesteps=args.T,
                      model_type=args.model,
                      lr=args.lr)
    elif args.algo == 'bc_mle':
        from agents.bc_mle import BC_MLE as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      lr=args.lr)
    elif args.algo == 'bc_cvae':
        from agents.bc_cvae import BC_CVAE as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      lr=args.lr)
    elif args.algo == 'bc_kl':
        from agents.bc_kl import BC_KL as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      num_samples_match=10,
                      lr=args.lr)
    elif args.algo == 'bc_mmd':
        from agents.bc_mmd import BC_MMD as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      num_samples_match=10,
                      mmd_sigma=20.0,
                      lr=args.lr)
    elif args.algo == 'bc_w':
        from agents.bc_w import BC_W as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      w_gamma=5.0,
                      lr=args.lr)
    elif args.algo == 'bc_gan':
        from agents.bc_gan import BC_GAN as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      lr=args.lr)
    elif args.algo == 'bc_gan2':
        from agents.bc_gan2 import BC_GAN as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      lr=args.lr)
    else:
        raise NotImplementedError(f'Unknown algo: {args.algo}')

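    # Training loop: train for args.eval_freq epochs' worth of steps at a time, then evaluate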
    evaluations = []
    training_iters = 0
    max_timesteps = args.num_epochs * args.num_steps_per_epoch
    best_score = -100.
    while training_iters < max_timesteps:
        iterations = int(args.eval_freq * args.num_steps_per_epoch)
        utils.print_banner(f"Train step: {training_iters}", separator="*", num_star=90)
        agent.train(data_sampler,
                    iterations=iterations,
                    batch_size=args.batch_size)
        training_iters += iterations
        curr_epoch = int(training_iters // int(args.num_steps_per_epoch))
        logger.record_tabular('Trained Epochs', curr_epoch)

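        # Evaluate the current policy and log raw / normalized episodic returns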
        eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = eval_policy(agent, args.env_name, args.seed,
                                                                               eval_episodes=args.eval_episodes)
        evaluations.append([eval_res, eval_res_std, eval_norm_res, eval_norm_res_std])
        np.save(os.path.join(output_dir, "eval"), evaluations)
        logger.record_tabular('Average Episodic Reward', eval_res)
        logger.record_tabular('Average Episodic N-Reward', eval_norm_res)
        logger.dump_tabular()

        # record and save the best model
        if eval_norm_res >= best_score:
            if args.save_best_model:
                agent.save_model(output_dir)
            best_score = eval_norm_res
            best_res = {'epoch': curr_epoch, 'best normalized score avg': eval_norm_res,
                        'best normalized score std': eval_norm_res_std,
                        'best raw score avg': eval_res, 'best raw score std': eval_res_std}
            with open(os.path.join(output_dir, "best_score.txt"), 'w') as f:
                f.write(json.dumps(best_res))
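
A minimal sketch of how train_agent might be driven from a __main__ block at the bottom of run_bc.py. This is illustrative, not the script's actual entry point: the flags mirror the args fields used above, the default values are placeholders, and os / np are assumed to already be imported at module level (the function body uses them).

if __name__ == '__main__':
    import argparse
    import gym
    import torch
    import d4rl  # noqa: F401 -- importing d4rl registers the offline envs with gym

    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='halfcheetah-medium-v2')
    parser.add_argument('--algo', default='bc',
                        choices=['bc', 'bc_mle', 'bc_cvae', 'bc_kl',
                                 'bc_mmd', 'bc_w', 'bc_gan', 'bc_gan2'])
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--output_dir', default='results')
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--tau', default=0.005, type=float)
    parser.add_argument('--T', default=5, type=int)
    parser.add_argument('--beta_schedule', default='linear')
    parser.add_argument('--model', default='MLP')
    parser.add_argument('--reward_tune', default='no')
    parser.add_argument('--num_epochs', default=1000, type=int)
    parser.add_argument('--num_steps_per_epoch', default=1000, type=int)
    parser.add_argument('--eval_freq', default=50, type=int)
    parser.add_argument('--eval_episodes', default=10, type=int)
    parser.add_argument('--save_best_model', action='store_true')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Environment and seeding
    env = gym.make(args.env_name)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Dimensions and device inferred from the environment
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_agent(env, state_dim, action_dim, max_action, device, args.output_dir, args)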