in agents/bc_gan2.py [0:0]
def __init__(self,
             state_dim,
             action_dim,
             max_action,
             device,
             discount,
             tau,
             lr=3e-4,
             *,
             gen_lr=2e-4,
             disc_lr=2e-4,
             critic_lr=2e-4,
             z_dim=None):
    """Build the generator/actor, discriminator and critic plus their optimizers.

    Args:
        state_dim: State dimensionality forwarded to ``Generator``,
            ``Discriminator`` and ``Critic``.
        action_dim: Action dimensionality forwarded to the same networks;
            also bounds the default latent size (see ``z_dim``). Stored as
            ``self.action_dim``.
        max_action: Forwarded to ``Generator`` and stored as
            ``self.max_action``.
        device: Device every network is moved to with ``.to(device)``; also
            forwarded to ``Generator`` and stored as ``self.device``.
        discount: Stored as ``self.discount``; not used in this constructor.
        tau: Stored as ``self.tau``; not used in this constructor.
        lr: Accepted for signature compatibility but NOT used. The optimizer
            learning rates come from ``gen_lr``, ``disc_lr`` and
            ``critic_lr`` instead.
            NOTE(review): callers passing ``lr`` expecting it to take effect
            should switch to the per-network keywords.
        gen_lr: Adam learning rate for the generator (default 2e-4).
        disc_lr: Adam learning rate for the discriminator (default 2e-4).
        critic_lr: Adam learning rate for the critic (default 2e-4).
        z_dim: Latent noise size passed to ``Generator``. ``None`` uses
            ``min(action_dim, 10)``.
    """
    if z_dim is None:
        z_dim = min(action_dim, 10)

    # The generator doubles as the actor: both attributes name the same module.
    # NOTE: construction order below is kept fixed so seeded parameter
    # initialization stays reproducible.
    self.actor = self.generator = Generator(state_dim,
                                            action_dim,
                                            max_action,
                                            device,
                                            z_dim=z_dim).to(device)
    # Independent copy taken at construction time; presumably soft-updated
    # with ``tau`` elsewhere in the agent — TODO confirm.
    self.actor_target = copy.deepcopy(self.actor)
    self.gen_optim = torch.optim.Adam(self.generator.parameters(), lr=gen_lr)

    self.discriminator = Discriminator(state_dim, action_dim).to(device)
    self.disc_optim = torch.optim.Adam(self.discriminator.parameters(), lr=disc_lr)
    # BCEWithLogitsLoss applies the sigmoid itself; assumes the discriminator
    # returns raw logits — confirm against Discriminator.forward.
    self.adversarial_loss = torch.nn.BCEWithLogitsLoss()

    self.critic = Critic(state_dim, action_dim).to(device)
    self.critic_target = copy.deepcopy(self.critic)
    self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

    self.max_action = max_action
    self.action_dim = action_dim
    self.discount = discount
    self.tau = tau
    self.device = device