def sample_t_middle()

in agents/diffusion.py [0:0]


    def sample_t_middle(self, state):
        """Sample actions with the full reverse loop, flagging one random step.

        Starts from Gaussian noise of shape ``(state.shape[0], self.action_dim)``
        and calls ``self.p_sample`` once per timestep index, from
        ``self.n_timesteps - 1`` down to ``0``. Exactly one of those calls
        receives ``grad=True``: the one whose index equals ``t``, drawn
        uniformly from ``[0, int(0.2 * self.n_timesteps))``. Because the loop
        runs in reverse, these low indices are the last steps executed.

        NOTE(review): what ``grad=True`` does is decided inside ``p_sample``
        (presumably it enables gradient tracking for that step only) --
        confirm against that method.

        Args:
            state: Batched conditioning input; only ``state.shape[0]`` is read
                here, the tensor itself is forwarded to ``self.p_sample``.

        Returns:
            The final sample, clamped in place to
            ``[-self.max_action, self.max_action]``.
        """
        batch_size = state.shape[0]
        device = self.betas.device
        x = torch.randn((batch_size, self.action_dim), device=device)

        # Exclusive upper bound for the flagged step. It is clamped to >= 1
        # because np.random.randint raises ValueError when high <= low, which
        # is what int(n_timesteps * 0.2) produced for n_timesteps < 5.
        grad_step_bound = max(1, int(self.n_timesteps * 0.2))
        t = np.random.randint(0, grad_step_bound)

        for i in reversed(range(self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, timesteps, state, grad=(i == t))
        return x.clamp_(-self.max_action, self.max_action)