def build_preprocess()

in projects/home/recap/data/preprocessors.py [0:0]


def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
  """Assembles every enabled preprocessing stage into one Keras model.

  Args:
    preprocess_config: Config object whose `downsample_negatives`,
      `truncate_and_slice`, `downcast`, `rectify_labels` and
      `extract_features` fields each enable (when truthy) and configure
      the corresponding stage.
    mode: Job mode. When it equals `config_mod.JobMode.INFERENCE`, no
      preprocessor is built.

  Returns:
    `None` in inference mode; otherwise a `tf.keras.Model` that passes its
    input through each enabled stage, in the order the fields are listed
    above.

  Raises:
    ValueError: If none of the stages is enabled in `preprocess_config`.
  """
  if mode == config_mod.JobMode.INFERENCE:
    logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
    return None

  # (stage config, stage class) pairs; tuple order is the order of application.
  stage_specs = (
    (preprocess_config.downsample_negatives, DownsampleNegatives),
    (preprocess_config.truncate_and_slice, TruncateAndSlice),
    (preprocess_config.downcast, DownCast),
    (preprocess_config.rectify_labels, RectifyLabels),
    (preprocess_config.extract_features, ExtractFeatures),
  )
  # Only stages whose config is truthy get instantiated.
  preprocess_models = [stage_cls(stage_cfg) for stage_cfg, stage_cls in stage_specs if stage_cfg]

  if not preprocess_models:
    raise ValueError("No known preprocessor.")

  if len(preprocess_models) > 1:
    logging.warning(
      "With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
    )

  class PreprocessModel(tf.keras.Model):
    """Chains the given stage models, feeding each one's output to the next."""

    def __init__(self, preprocess_models):
      super().__init__()
      self.preprocess_models = preprocess_models

    def call(self, inputs, training=None, mask=None):
      result = inputs
      # `training` and `mask` are forwarded positionally to every stage.
      for stage in self.preprocess_models:
        result = stage(result, training, mask)
      return result

  return PreprocessModel(preprocess_models)