def training_range()

in tfx/dsl/input_resolution/ops/training_range_op.py [0:0]


def training_range(store: Any, model: types.Artifact) -> List[types.Artifact]:
  """ContainsTrainingRange implementation, for shared use across ResolverOps.

  Returns the Examples artifact(s) the Model was trained on.

  Note that only the standard TFleX Model and Examples artifacts are supported.

  Args:
    store: The MetadataStore.
    model: The Model artifact whose trained Examples to return.

  Returns:
    List of Examples artifacts, sorted by (create_time_since_epoch, id), if
    found, else empty list. We intentionally don't raise SkipSignal, such that
    the caller can decide to raise it or not.
  """
  # In MLMD, an Examples and Model are related by:
  #
  #          Event 1           Event 2
  # Examples ------> Execution ------> Model
  #
  # For a single Model, there may be many parent Examples it was trained on.

  # TODO(kshivvy): Support querying multiple Model ids at once, to reduce the
  # number of round trip MLMD queries. This will be useful for resolving inputs
  # of a span driven evaluator.

  # Get all Executions associated with creating the Model.
  execution_ids = {
      event.execution_id
      for event in store.get_events_by_artifact_ids([model.id])
      if event_lib.is_valid_output_event(event)
  }
  if not execution_ids:
    # No producing Execution recorded; skip the remaining MLMD round trips.
    return []

  # Get all artifact ids associated with an INPUT Event in each Execution.
  # These ids correspond to parent artifacts of the Model.
  parent_artifact_ids = {
      event.artifact_id
      for event in store.get_events_by_execution_ids(execution_ids)
      if event_lib.is_valid_input_event(event)
  }
  if not parent_artifact_ids:
    return []

  # Key the fetched artifacts by id (deduplicating them). Later steps iterate
  # only what the store actually returned, so an id that the store does not
  # return is skipped instead of raising KeyError.
  artifact_by_artifact_id = {
      artifact.id: artifact
      for artifact in store.get_artifacts_by_id(parent_artifact_ids)
  }
  type_ids = {
      artifact.type_id for artifact in artifact_by_artifact_id.values()
  }

  # Find the ArtifactType associated with Examples.
  examples_type = next(
      (
          artifact_type
          for artifact_type in store.get_artifact_types_by_id(type_ids)
          if artifact_type.name == ops_utils.EXAMPLES_TYPE_NAME
      ),
      None,
  )
  if examples_type is None:
    return []

  mlmd_examples = [
      artifact
      for artifact in artifact_by_artifact_id.values()
      if artifact.type_id == examples_type.id
  ]
  if not mlmd_examples:
    return []

  # Return the sorted Examples. The id tie-breaker makes the order
  # deterministic when creation times are equal.
  artifacts = artifact_utils.deserialize_artifacts(examples_type, mlmd_examples)
  return sorted(
      artifacts, key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id)
  )