def batch_size()

in common/batch.py [0:0]


  def batch_size(self) -> int:
    """Return the size of the leading (batch) dimension of this batch.

    Walks the values of ``self.as_dict()`` in iteration order and returns
    ``shape[0]`` of the first ``torch.Tensor`` that has at least one
    dimension. Non-tensor values (including ``None``) and 0-dim (scalar)
    tensors are skipped, because they have no leading dimension to read.

    Returns:
      The size of dimension 0 of the first tensor with ``ndim > 0``.

    Raises:
      ValueError: If no value is a tensor with at least one dimension.
    """
    for tensor in self.as_dict().values():
      # isinstance() is False for None, so no separate None check is needed.
      # A 0-dim tensor has an empty shape; indexing shape[0] would raise
      # IndexError, so such tensors are skipped rather than crashing here.
      if isinstance(tensor, torch.Tensor) and tensor.ndim > 0:
        return tensor.shape[0]
    raise ValueError("Could not determine batch size from tensors.")