def Do()

in tfx/components/example_diff/executor.py [0:0]


  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]) -> None:
    """Computes example diffs for each split pair.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
      output_dict: Output dict from output key to a list of Artifacts.
      exec_properties: A dict of execution properties.

    Raises:
      ValueError: If examples are in a non- record-based format.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # Load and deserialize included split pairs from execution properties.
    included_split_pairs = _IncludedSplitPairs(
        json_utils.loads(
            exec_properties.get(
                standard_component_specs.INCLUDE_SPLIT_PAIRS_KEY, 'null')) or
        None)

    test_examples = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.EXAMPLES_KEY])
    base_examples = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.BASELINE_EXAMPLES_KEY])

    example_diff_artifact = artifact_utils.get_single_instance(
        output_dict[standard_component_specs.EXAMPLE_DIFF_RESULT_KEY])
    diff_config = exec_properties.get(
        standard_component_specs.EXAMPLE_DIFF_CONFIG_KEY)

    logging.info('Running examplediff with config %s', diff_config)

    test_tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
        examples=[test_examples], telemetry_descriptors=_TELEMETRY_DESCRIPTORS)
    base_tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
        examples=[base_examples], telemetry_descriptors=_TELEMETRY_DESCRIPTORS)

    # First set up the pairs we'll operate on.
    split_pairs = []
    for test_split in artifact_utils.decode_split_names(
        test_examples.split_names):
      for base_split in artifact_utils.decode_split_names(
          base_examples.split_names):
        if included_split_pairs.included(test_split, base_split):
          split_pairs.append((test_split, base_split))
    if not split_pairs:
      raise ValueError(
          'No split pairs from test and baseline examples: %s, %s' %
          (test_examples, base_examples))
    if included_split_pairs.get_included_split_pairs():
      missing_split_pairs = included_split_pairs.get_included_split_pairs(
      ) - set(split_pairs)
      if missing_split_pairs:
        raise ValueError(
            'Missing split pairs identified in include_split_pairs: %s' %
            ', '.join([
                '%s_%s' % (test, baseline)
                for test, baseline in missing_split_pairs
            ]))
    with self._make_beam_pipeline() as p:
      for test_split, base_split in split_pairs:
        test_uri = artifact_utils.get_split_uri([test_examples], test_split)
        base_uri = artifact_utils.get_split_uri([base_examples], base_split)
        test_tfxio = test_tfxio_factory(io_utils.all_files_pattern(test_uri))
        base_tfxio = base_tfxio_factory(io_utils.all_files_pattern(base_uri))
        if not isinstance(
            test_tfxio, record_based_tfxio.RecordBasedTFXIO) or not isinstance(
                base_tfxio, record_based_tfxio.RecordBasedTFXIO):
          # TODO(b/227361696): Support more general sources.
          raise ValueError('Only RecordBasedTFXIO supported, got %s, %s' %
                           (test_tfxio, base_tfxio))

        split_pair = '%s_%s' % (test_split, base_split)
        logging.info('Processing split pair %s', split_pair)
        # pylint: disable=cell-var-from-loop
        @beam.ptransform_fn
        def _iteration(p):
          base_examples = (
              p | 'TFXIORead[base]' >> test_tfxio.RawRecordBeamSource()
              | 'Parse[base]' >> beam.Map(_parse_example))
          test_examples = (
              p | 'TFXIORead[test]' >> base_tfxio.RawRecordBeamSource()
              | 'Parse[test]' >> beam.Map(_parse_example))
          results = ((base_examples, test_examples)
                     | feature_skew_detector.DetectFeatureSkewImpl(
                         **_config_to_kwargs(diff_config)))

          output_uri = os.path.join(example_diff_artifact.uri,
                                    'SplitPair-%s' % split_pair)
          _ = (
              results[feature_skew_detector.SKEW_RESULTS_KEY]
              | 'WriteStats' >> feature_skew_detector.skew_results_sink(
                  os.path.join(output_uri, STATS_FILE_NAME)))
          _ = (
              results[feature_skew_detector.SKEW_PAIRS_KEY]
              | 'WriteSample' >> feature_skew_detector.skew_pair_sink(
                  os.path.join(output_uri, _SAMPLE_FILE_NAME)))
          _ = (
              results[feature_skew_detector.MATCH_STATS_KEY]
              | 'WriteMatchStats' >> feature_skew_detector.match_stats_sink(
                  os.path.join(output_uri, MATCH_STATS_FILE_NAME)))
          if feature_skew_detector.CONFUSION_KEY in results:
            _ = (
                results[feature_skew_detector.CONFUSION_KEY]
                |
                'WriteConfusion' >> feature_skew_detector.confusion_count_sink(
                    os.path.join(output_uri, CONFUSION_FILE_NAME)))

          # pylint: enable=cell-var-from-loop

        _ = p | 'ProcessSplits[%s]' % split_pair >> _iteration()  # pylint: disable=no-value-for-parameter