def call()

in projects/home/recap/data/preprocessors.py
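Down-samples a batch to 1/batch_multiplier of its incoming size: positive examples (any engagement label set) are kept first, the remaining slots are filled with negatives, and the retained negatives are up-weighted so that their total weight matches the pre-downsampling negative count.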


  def call(self, inputs, training=None, mask=None):
    labels = self.config.engagements_list
    # union of engagements: an example counts as positive if any listed engagement label is 1
    mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))
    n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))
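    # target batch size after down-sampling; the incoming batch is batch_multiplier times larger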
    batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)
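    # weight assigned to each negative that survives truncation, chosen so the total
    # negative weight matches the pre-downsampling negative count:
    # (batch_multiplier * batch_size - n_positives) originals over (batch_size - n_positives) kept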
    negative_weights = tf.math.divide_no_nan(
      tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
      tf.cast(batch_size - n_positives, tf.float32),
    )
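    # positives keep weight 1.0; negatives are scaled by negative_weights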
    new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights

    def _split_by_label_concatenate_and_truncate(input_tensor):
      # take the positive examples, append the negative examples, and truncate to batch_size
      # DANGER: if n_positives > batch_size, down-sampling is incorrect (do not use pb_50)
      return tf.concat(
        [
          input_tensor[mask],
          input_tensor[tf.math.logical_not(mask)],
        ],
        0,
      )[:batch_size]

    if "weights" not in inputs:
      # add placeholder so logic below applies even if weights aren't present in inputs
      inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements])

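    # scale the "weights" tensor by new_weights, then down-sample every tensor in the batch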
    for tensor in inputs:
      if tensor == "weights":
        inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1])

      inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor])

    return inputs
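
For intuition, below is a minimal, self-contained sketch (not part of the repository) of the arithmetic above on a toy batch of 8 examples with batch_multiplier = 2. The label names ("fav", "reply") and tensor values are illustrative assumptions.

import tensorflow as tf

batch_multiplier = 2
labels = ["fav", "reply"]  # hypothetical engagement labels
inputs = {
  "fav": tf.constant([[1], [0], [0], [0], [1], [0], [0], [1]]),
  "reply": tf.constant([[0], [0], [0], [0], [0], [0], [0], [0]]),
}

# union of engagements -> 3 positives (rows 0, 4, 7)
mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))
n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))  # 3
batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / batch_multiplier, tf.int32)  # 8 / 2 = 4

# 5 original negatives collapse into the single negative that fits the down-sampled
# batch of 4, so that negative is weighted 5.0
negative_weights = tf.math.divide_no_nan(
  tf.cast(batch_multiplier * batch_size - n_positives, tf.float32),
  tf.cast(batch_size - n_positives, tf.float32),
)
print(negative_weights.numpy())  # 5.0

# positives first, negatives after, truncated to the down-sampled batch size
downsampled_fav = tf.concat(
  [inputs["fav"][mask], inputs["fav"][tf.math.logical_not(mask)]], 0
)[:batch_size]
print(downsampled_fav.numpy().ravel())  # [1 1 1 0]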