in agents/diffusion.py [0:0]
def sample_t_last(self, state):
    """Sample actions with a randomly shortened reverse chain.

    Starts from standard-normal noise of shape
    ``(state.shape[0], self.action_dim)`` and repeatedly applies
    ``self.p_sample`` for ``cur_T`` steps, where ``cur_T`` is drawn uniformly
    from ``[max(1, int(0.8 * n_timesteps)), n_timesteps)``.  Every step except
    the final one (timestep index 0) runs under ``torch.no_grad()``; the
    final step runs in whatever grad mode the caller has active, so only that
    last call can contribute to the autograd graph.

    Args:
        state: Conditioning input forwarded unchanged to ``self.p_sample``.
            Only ``state.shape[0]`` is read here, as the batch size.

    Returns:
        The output of the last ``self.p_sample`` call, clamped in place to
        ``[-self.max_action, self.max_action]``.
        (Presumably shape ``(batch, action_dim)`` -- this assumes
        ``p_sample`` preserves the shape of ``x``; confirm against p_sample.)

    Raises:
        ValueError: If ``self.n_timesteps`` is smaller than 1.
    """
    n_steps = self.n_timesteps
    if n_steps < 1:
        raise ValueError(
            "n_timesteps must be at least 1, got {}".format(n_steps)
        )

    batch_size = state.shape[0]
    device = self.betas.device
    x = torch.randn((batch_size, self.action_dim), device=device)

    # Draw the chain length from the top ~20% of the schedule.  The lower
    # bound is floored at 1 so at least one p_sample step always runs: with
    # n_timesteps == 1, int(0.8) == 0 used to yield cur_T == 0, which skipped
    # the loop entirely and returned raw clamped noise.  For n_timesteps >= 2
    # these bounds are identical to int(0.8 * n) and n.
    low = max(1, int(n_steps * 0.8))
    high = max(low + 1, n_steps)  # randint's upper bound is exclusive
    cur_T = np.random.randint(low, high)

    for i in reversed(range(cur_T)):
        timesteps = torch.full(
            (batch_size,), i, device=device, dtype=torch.long
        )
        if i != 0:
            # Intermediate steps are excluded from the autograd graph.
            with torch.no_grad():
                x = self.p_sample(x, timesteps, state)
        else:
            # Final step: deliberately inherits the caller's grad mode rather
            # than forcing gradients on.
            x = self.p_sample(x, timesteps, state)

    return x.clamp_(-self.max_action, self.max_action)