def to_batch()

in projects/home/recap/data/dataset.py


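# Imports required by this excerpt. The standard-library and TensorFlow
# imports are certain; `RecapBatch` and the two tensor-conversion helpers
# (`sparse_or_dense_tf_to_torch`, `keyed_jagged_tensor_from_tensors_dict`)
# are assumed to be defined elsewhere in this module or its sibling modules.
import functools
from typing import List, Optional

import tensorflow as tf
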
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
  """Converts a torch data loader output into `RecapBatch`."""

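  # Convert every leaf tensor in the (possibly nested) TF structure into its
  # torch equivalent, handling both dense and sparse tensors.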
  x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
  try:
    features_in, labels = x
  except ValueError:
    # For Mode.INFERENCE, we do not expect to receive labels as part of the input tuple
    features_in, labels = x, None

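  # Pack the configured sparse features into a single KeyedJaggedTensor;
  # default to an empty one when no sparse feature names are given.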
  sparse_features = keyed_jagged_tensor_from_tensors_dict({})
  if sparse_feature_names:
    sparse_features = keyed_jagged_tensor_from_tensors_dict(
      {embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
    )

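  # Each dense embedding is mutually exclusive with its sparse (id-based)
  # counterpart; receiving both would make the feature source ambiguous.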
  user_embedding, user_eng_embedding, author_embedding = None, None, None
  if "user_embedding" in features_in:
    if sparse_feature_names and "meta__user_id" in sparse_feature_names:
      raise ValueError("Only one source of embedding for user is supported")
    user_embedding = features_in["user_embedding"]

  if "user_eng_embedding" in features_in:
    if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
      raise ValueError("Only one source of embedding for user engagement is supported")
    user_eng_embedding = features_in["user_eng_embedding"]

  if "author_embedding" in features_in:
    if sparse_feature_names and "meta__author_id" in sparse_feature_names:
      raise ValueError("Only one source of embedding for author is supported")
    author_embedding = features_in["author_embedding"]

  return RecapBatch(
    continuous_features=features_in["continuous"],
    binary_features=features_in["binary"],
    discrete_features=features_in["discrete"],
    sparse_features=sparse_features,
    user_embedding=user_embedding,
    user_eng_embedding=user_eng_embedding,
    author_embedding=author_embedding,
    labels=labels,
    weights=features_in.get("weights", None),  # Defaults to torch.ones_like(labels)
  )
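
A minimal usage sketch (the feature names, tensor shapes, and the
`(features, labels)` tuple layout below are illustrative assumptions, not
taken from the repo):

# Hypothetical dataloader element: a (features, labels) tuple of TF tensors.
features = {
  "continuous": tf.zeros([2, 8]),
  "binary": tf.zeros([2, 4]),
  "discrete": tf.zeros([2, 3], dtype=tf.int64),
}
labels = tf.zeros([2, 1])
batch = to_batch((features, labels))  # no sparse_feature_names configured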