in tfx/components/distribution_validator/executor.py [0:0]
def Do(self, input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any]) -> None:
"""DistributionValidator executor entrypoint.
This checks for changes in data distributions from one dataset to another,
based on the summary statitics for those datasets.
Args:
input_dict: Input dict from input key to a list of artifacts.
output_dict: Output dict from key to a list of artifacts.
exec_properties: A dict of execution properties.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
# Load and deserialize include splits from execution properties.
include_splits_list = json_utils.loads(
exec_properties.get(standard_component_specs.INCLUDE_SPLIT_PAIRS_KEY,
'null')) or []
include_splits = set((test, base) for test, base in include_splits_list)
test_statistics = artifact_utils.get_single_instance(
input_dict[standard_component_specs.STATISTICS_KEY])
baseline_statistics = artifact_utils.get_single_instance(
input_dict[standard_component_specs.BASELINE_STATISTICS_KEY])
config = exec_properties.get(
standard_component_specs.DISTRIBUTION_VALIDATOR_CONFIG_KEY)
custom_validation_config = exec_properties.get(
standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY)
logging.info('Running distribution_validator with config %s', config)
# Set up pairs of splits to validate.
split_pairs = []
for test_split in artifact_utils.decode_split_names(
test_statistics.split_names):
for baseline_split in artifact_utils.decode_split_names(
baseline_statistics.split_names):
if (test_split, baseline_split) in include_splits:
split_pairs.append((test_split, baseline_split))
elif not include_splits and test_split == baseline_split:
split_pairs.append((test_split, baseline_split))
if not split_pairs:
raise ValueError(
'No split pairs from test and baseline statistics: %s, %s' %
(test_statistics, baseline_statistics))
if include_splits:
missing_split_pairs = include_splits - set(split_pairs)
if missing_split_pairs:
raise ValueError(
'Missing split pairs identified in include_split_pairs: %s' %
', '.join([
'%s_%s' % (test, baseline)
for test, baseline in missing_split_pairs
]))
anomalies_artifact = artifact_utils.get_single_instance(
output_dict[standard_component_specs.ANOMALIES_KEY])
anomalies_artifact.split_names = artifact_utils.encode_split_names(
['%s_%s' % (test, baseline) for test, baseline in split_pairs])
for test_split, baseline_split in split_pairs:
split_pair = '%s_%s' % (test_split, baseline_split)
logging.info('Processing split pair %s', split_pair)
test_stats_split = stats_artifact_utils.load_statistics(
test_statistics, test_split).proto()
baseline_stats_split = stats_artifact_utils.load_statistics(
baseline_statistics, baseline_split).proto()
schema = _make_schema_from_config(config, baseline_stats_split)
full_anomalies = tfdv.validate_statistics(
test_stats_split,
schema,
previous_statistics=baseline_stats_split,
custom_validation_config=custom_validation_config)
anomalies = _get_comparison_only_anomalies(full_anomalies)
anomalies = _add_anomalies_for_missing_comparisons(anomalies, config)
writer_utils.write_anomalies(
os.path.join(
anomalies_artifact.uri,
'SplitPair-%s' % split_pair,
DEFAULT_FILE_NAME,
),
anomalies,
)