in agents/bc_kl.py [0:0]
def __init__(self,
             state_dim,
             action_dim,
             max_action,
             device,
             discount,
             tau,
             lr=3e-4,
             num_samples_match=10,
             kl_type='backward',
             *,
             latent_dim=None,
             ):
    """Build the VAE and actor networks plus one Adam optimizer for each.

    Args:
        state_dim: State dimensionality, forwarded to ``VAE`` and
            ``RegularActor``.
        action_dim: Action dimensionality, forwarded to both networks and
            stored on ``self.action_dim``.
        max_action: Forwarded to both networks and stored on
            ``self.max_action`` (presumably the bound on action magnitude --
            TODO confirm against the network definitions).
        device: Torch device both networks are moved to with ``.to(device)``;
            also forwarded to their constructors and stored on ``self.device``.
        discount: Stored on ``self.discount``; not used by this constructor.
        tau: Stored on ``self.tau``; not used by this constructor.
        lr: Learning rate shared by the VAE and actor Adam optimizers.
        num_samples_match: Stored on ``self.num_samples_match``; consumed
            elsewhere in the class (presumably the number of action samples
            drawn per state for the KL estimate -- TODO confirm).
        kl_type: Stored on ``self.kl_type``. The accepted values are decided
            by code outside this constructor.
        latent_dim: Size of the VAE latent space. The default ``None`` keeps
            the original hard-coded choice of ``2 * action_dim``, so existing
            callers get identical networks.

    Raises:
        ValueError: If ``latent_dim`` is given explicitly and is not positive.
    """
    if latent_dim is None:
        # Original hard-coded default, preserved for backward compatibility.
        latent_dim = action_dim * 2
    elif latent_dim <= 0:
        raise ValueError(f"latent_dim must be positive, got {latent_dim!r}")

    # Generative model over actions, with its own optimizer.
    self.vae = VAE(state_dim, action_dim, latent_dim, max_action, device).to(device)
    self.vae_optimizer = torch.optim.Adam(self.vae.parameters(), lr=lr)

    # Policy network, optimized independently of the VAE.
    self.actor = RegularActor(state_dim, action_dim, max_action, device).to(device)
    self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

    # Hyperparameters kept for use by other methods of the class.
    self.max_action = max_action
    self.action_dim = action_dim
    self.discount = discount
    self.tau = tau
    self.device = device
    self.num_samples_match = num_samples_match
    self.kl_type = kl_type