in tfx/components/example_diff/executor.py [0:0]
def Do(self, input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any]) -> None:
"""Computes example diffs for each split pair.
Args:
input_dict: Input dict from input key to a list of Artifacts.
output_dict: Output dict from output key to a list of Artifacts.
exec_properties: A dict of execution properties.
Raises:
ValueError: If examples are in a non- record-based format.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
# Load and deserialize included split pairs from execution properties.
included_split_pairs = _IncludedSplitPairs(
json_utils.loads(
exec_properties.get(
standard_component_specs.INCLUDE_SPLIT_PAIRS_KEY, 'null')) or
None)
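# Exactly one artifact is expected for the test examples, the baseline
# examples, and the output diff result.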
test_examples = artifact_utils.get_single_instance(
input_dict[standard_component_specs.EXAMPLES_KEY])
base_examples = artifact_utils.get_single_instance(
input_dict[standard_component_specs.BASELINE_EXAMPLES_KEY])
example_diff_artifact = artifact_utils.get_single_instance(
output_dict[standard_component_specs.EXAMPLE_DIFF_RESULT_KEY])
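# The optional diff config is converted to keyword arguments for the
# feature skew detector (see _config_to_kwargs below).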
diff_config = exec_properties.get(
standard_component_specs.EXAMPLE_DIFF_CONFIG_KEY)
logging.info('Running ExampleDiff with config %s', diff_config)
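# Build one TFXIO factory per artifact; each is called below with a
# per-split file pattern to construct the actual record reader.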
test_tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
examples=[test_examples], telemetry_descriptors=_TELEMETRY_DESCRIPTORS)
base_tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
examples=[base_examples], telemetry_descriptors=_TELEMETRY_DESCRIPTORS)
# First set up the pairs we'll operate on.
split_pairs = []
for test_split in artifact_utils.decode_split_names(
test_examples.split_names):
for base_split in artifact_utils.decode_split_names(
base_examples.split_names):
if included_split_pairs.included(test_split, base_split):
split_pairs.append((test_split, base_split))
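# Fail fast if nothing matched, or if an explicitly requested split pair is
# absent from the artifacts' split names.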
if not split_pairs:
raise ValueError(
'No split pairs from test and baseline examples: %s, %s' %
(test_examples, base_examples))
if included_split_pairs.get_included_split_pairs():
missing_split_pairs = (
    included_split_pairs.get_included_split_pairs() - set(split_pairs))
if missing_split_pairs:
raise ValueError(
'Missing split pairs identified in include_split_pairs: %s' %
', '.join([
'%s_%s' % (test, baseline)
for test, baseline in missing_split_pairs
]))
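# Process every split pair as an independent branch of a single Beam
# pipeline.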
with self._make_beam_pipeline() as p:
for test_split, base_split in split_pairs:
test_uri = artifact_utils.get_split_uri([test_examples], test_split)
base_uri = artifact_utils.get_split_uri([base_examples], base_split)
test_tfxio = test_tfxio_factory(io_utils.all_files_pattern(test_uri))
base_tfxio = base_tfxio_factory(io_utils.all_files_pattern(base_uri))
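# The skew detector consumes raw serialized records, so only record-based
# TFXIO implementations are supported.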
if not isinstance(
test_tfxio, record_based_tfxio.RecordBasedTFXIO) or not isinstance(
base_tfxio, record_based_tfxio.RecordBasedTFXIO):
# TODO(b/227361696): Support more general sources.
raise ValueError('Only RecordBasedTFXIO supported, got %s, %s' %
(test_tfxio, base_tfxio))
split_pair = '%s_%s' % (test_split, base_split)
logging.info('Processing split pair %s', split_pair)
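# _iteration builds the sub-graph for one split pair. It closes over the
# loop variables (test_tfxio, base_tfxio, diff_config, split_pair), which is
# safe because the transform is applied before the next iteration; hence the
# cell-var-from-loop disable below.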
# pylint: disable=cell-var-from-loop
@beam.ptransform_fn
def _iteration(p):
base_examples = (
    p | 'TFXIORead[base]' >> base_tfxio.RawRecordBeamSource()
    | 'Parse[base]' >> beam.Map(_parse_example))
test_examples = (
    p | 'TFXIORead[test]' >> test_tfxio.RawRecordBeamSource()
    | 'Parse[test]' >> beam.Map(_parse_example))
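# DetectFeatureSkewImpl takes the (base, test) pair of parsed examples and
# returns a dict of result PCollections keyed by the constants used below.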
results = ((base_examples, test_examples)
| feature_skew_detector.DetectFeatureSkewImpl(
**_config_to_kwargs(diff_config)))
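# Outputs for this pair are written under a 'SplitPair-<test>_<baseline>'
# subdirectory of the ExampleDiff output artifact.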
output_uri = os.path.join(example_diff_artifact.uri,
'SplitPair-%s' % split_pair)
_ = (
results[feature_skew_detector.SKEW_RESULTS_KEY]
| 'WriteStats' >> feature_skew_detector.skew_results_sink(
os.path.join(output_uri, STATS_FILE_NAME)))
_ = (
results[feature_skew_detector.SKEW_PAIRS_KEY]
| 'WriteSample' >> feature_skew_detector.skew_pair_sink(
os.path.join(output_uri, _SAMPLE_FILE_NAME)))
_ = (
results[feature_skew_detector.MATCH_STATS_KEY]
| 'WriteMatchStats' >> feature_skew_detector.match_stats_sink(
os.path.join(output_uri, MATCH_STATS_FILE_NAME)))
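# Confusion counts are only present when enabled in the diff config, so
# this output is written conditionally.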
if feature_skew_detector.CONFUSION_KEY in results:
_ = (
results[feature_skew_detector.CONFUSION_KEY]
| 'WriteConfusion' >> feature_skew_detector.confusion_count_sink(
    os.path.join(output_uri, CONFUSION_FILE_NAME)))
# pylint: enable=cell-var-from-loop
_ = p | 'ProcessSplits[%s]' % split_pair >> _iteration() # pylint: disable=no-value-for-parameter