def Do()

in tfx/components/distribution_validator/executor.py


  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]) -> None:
    """DistributionValidator executor entrypoint.

    This checks for changes in the data distribution from one dataset to
    another, based on the summary statistics for those datasets.

    Args:
      input_dict: Input dict from input key to a list of artifacts.
      output_dict: Output dict from key to a list of artifacts.
      exec_properties: A dict of execution properties.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # Load and deserialize include splits from execution properties.
    include_splits_list = json_utils.loads(
        exec_properties.get(standard_component_specs.INCLUDE_SPLIT_PAIRS_KEY,
                            'null')) or []
    include_splits = set((test, base) for test, base in include_splits_list)
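    # For example, an include_split_pairs value serialized as
    # '[["eval", "train"]]' yields include_splits == {('eval', 'train')}:
    # the 'eval' test split is compared against the 'train' baseline split.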

    test_statistics = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.STATISTICS_KEY])
    baseline_statistics = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.BASELINE_STATISTICS_KEY])

    config = exec_properties.get(
        standard_component_specs.DISTRIBUTION_VALIDATOR_CONFIG_KEY)
    custom_validation_config = exec_properties.get(
        standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY)

    logging.info('Running distribution_validator with config %s', config)

    # Set up pairs of splits to validate.
    split_pairs = []
    for test_split in artifact_utils.decode_split_names(
        test_statistics.split_names):
      for baseline_split in artifact_utils.decode_split_names(
          baseline_statistics.split_names):
        if (test_split, baseline_split) in include_splits:
          split_pairs.append((test_split, baseline_split))
        elif not include_splits and test_split == baseline_split:
          split_pairs.append((test_split, baseline_split))
    if not split_pairs:
      raise ValueError(
          'No split pairs found between test and baseline statistics: %s, %s' %
          (test_statistics, baseline_statistics))
    if include_splits:
      missing_split_pairs = include_splits - set(split_pairs)
      if missing_split_pairs:
        raise ValueError(
            'Missing split pairs identified in include_split_pairs: %s' %
            ', '.join([
                '%s_%s' % (test, baseline)
                for test, baseline in missing_split_pairs
            ]))

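    # The output artifact gets one split per pair, named '<test>_<baseline>'.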
    anomalies_artifact = artifact_utils.get_single_instance(
        output_dict[standard_component_specs.ANOMALIES_KEY])
    anomalies_artifact.split_names = artifact_utils.encode_split_names(
        ['%s_%s' % (test, baseline) for test, baseline in split_pairs])

    for test_split, baseline_split in split_pairs:
      split_pair = '%s_%s' % (test_split, baseline_split)
      logging.info('Processing split pair %s', split_pair)
      test_stats_split = stats_artifact_utils.load_statistics(
          test_statistics, test_split).proto()
      baseline_stats_split = stats_artifact_utils.load_statistics(
          baseline_statistics, baseline_split).proto()

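      # Build a schema encoding the configured distribution checks, derived
      # from the validator config and the baseline statistics.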
      schema = _make_schema_from_config(config, baseline_stats_split)
      full_anomalies = tfdv.validate_statistics(
          test_stats_split,
          schema,
          previous_statistics=baseline_stats_split,
          custom_validation_config=custom_validation_config)
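      # Keep only the anomalies produced by the baseline-vs-test comparison,
      # then add explicit anomalies for any configured comparisons that could
      # not be run.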
      anomalies = _get_comparison_only_anomalies(full_anomalies)
      anomalies = _add_anomalies_for_missing_comparisons(anomalies, config)
      writer_utils.write_anomalies(
          os.path.join(
              anomalies_artifact.uri,
              'SplitPair-%s' % split_pair,
              DEFAULT_FILE_NAME,
          ),
          anomalies,
      )
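
For reference, the split-pairing rule above can be distilled into a standalone
function. This is a minimal sketch of the same logic, not part of the executor;
the name select_split_pairs and its parameters are illustrative.

from typing import List, Set, Tuple


def select_split_pairs(
    test_splits: List[str],
    baseline_splits: List[str],
    include_splits: Set[Tuple[str, str]],
) -> List[Tuple[str, str]]:
  """Mirrors the pairing rule in Do()."""
  pairs = []
  for test in test_splits:
    for baseline in baseline_splits:
      if (test, baseline) in include_splits:
        pairs.append((test, baseline))
      elif not include_splits and test == baseline:
        pairs.append((test, baseline))
  return pairs


# With no include_split_pairs, only identically named splits are compared:
# select_split_pairs(['train', 'eval'], ['train', 'eval'], set())
#   -> [('train', 'train'), ('eval', 'eval')]
# With explicit pairs, only those listed are kept:
# select_split_pairs(['eval'], ['train'], {('eval', 'train')})
#   -> [('eval', 'train')]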