in tfx/dsl/input_resolution/ops/training_range_op.py [0:0]
def training_range(store: Any, model: types.Artifact) -> List[types.Artifact]:
  """ContainsTrainingRange implementation, for shared use across ResolverOps.

  Returns the Examples artifacts the Model was trained on.

  Note that only the standard TFleX Model and Examples artifacts are supported.

  Args:
    store: The MetadataStore.
    model: The Model artifact whose trained Examples to return.

  Returns:
    List of Examples artifacts if found, else empty list. The list is sorted by
    (create_time_since_epoch, id). We intentionally don't raise SkipSignal,
    such that the caller can decide to raise it or not.
  """
  # In MLMD, an Examples and Model are related by:
  #
  #           Event 1           Event 2
  # Examples ------> Execution ------> Model
  #
  # For a single Model, there may be many parent Examples it was trained on.

  # TODO(kshivvy): Support querying multiple Model ids at once, to reduce the
  # number of round trip MLMD queries. This will be useful for resolving inputs
  # of a span driven evaluator.

  # Get all Executions associated with creating the Model.
  execution_ids = {
      event.execution_id
      for event in store.get_events_by_artifact_ids([model.id])
      if event_lib.is_valid_output_event(event)
  }
  if not execution_ids:
    # No producing Execution, so there is nothing to walk back from. Returning
    # here avoids further MLMD round trips that could only yield nothing.
    return []

  # Get all artifact ids associated with an INPUT Event in each Execution.
  # These ids correspond to parent artifacts of the Model.
  parent_artifact_ids = {
      event.artifact_id
      for event in store.get_events_by_execution_ids(execution_ids)
      if event_lib.is_valid_input_event(event)
  }
  if not parent_artifact_ids:
    return []

  # Fetch the parent artifacts, keyed by id. Everything below works on the
  # artifacts that were actually returned rather than on the requested ids, so
  # an id the store does not return cannot trigger a KeyError.
  artifact_by_artifact_id = {
      artifact.id: artifact
      for artifact in store.get_artifacts_by_id(parent_artifact_ids)
  }
  if not artifact_by_artifact_id:
    return []
  type_ids = {artifact.type_id for artifact in artifact_by_artifact_id.values()}

  # Find the ArtifactType associated with Examples.
  examples_type = next(
      (
          artifact_type
          for artifact_type in store.get_artifact_types_by_id(type_ids)
          if artifact_type.name == ops_utils.EXAMPLES_TYPE_NAME
      ),
      None,
  )
  if examples_type is None:
    # None of the parent artifacts is an Examples artifact.
    return []

  mlmd_examples = [
      artifact
      for artifact in artifact_by_artifact_id.values()
      if artifact.type_id == examples_type.id
  ]
  if not mlmd_examples:
    return []

  # Return the sorted Examples. The id tie-breaker makes the order fully
  # deterministic even when creation timestamps collide.
  artifacts = artifact_utils.deserialize_artifacts(examples_type, mlmd_examples)
  return sorted(
      artifacts, key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id)
  )