def apply()

in tfx/dsl/input_resolution/ops/latest_policy_model_op.py [0:0]


  def apply(self, input_dict: typing_utils.ArtifactMultiMap):
    """Finds the latest created model via a certain policy.

    The input_dict is expected to have the following format:

    {
        "model": [Model 1, Model 2, ...],
        "model_blessing": [ModelBlessing 1, ModelBlessing 2, ...],
        "model_infra_blessing": [ModelInfraBlessing 1, ...]
    }

    "model" is a required key. "model_blessing" and "model_infra_blessing" are
    optional keys. If "model_blessing" and/or "model_infra_blessing" are
    provided, then only their lineage w.r.t. the Model artifacts will be
    considered.

    Example usecases for specifying "model_blessing"/"model_infra_blessing"
    include: 1) Resolving inputs to a Pusher 2) Specifying ModelBlessing
    artifacts from a specific Evaluator, in cases where the pipeline has
    multiple Evaluators.

    Note that only the standard TFleX Model, ModelBlessing, ModelInfraBlessing,
    and ModelPush artifacts are supported.

    Args:
      input_dict: An input dict containing "model", "model_blessing",
        "model_infra_blessing" as keys and lists of Model, ModelBlessing, and
        ModelInfraBlessing artifacts as values, respectively.

    Returns:
      A dictionary containing the latest Model artifact, as well as the
      ModelBlessing, ModelInfraBlessing, and/or ModelPush based on the Policy.

      For example, for a LATEST_BLESSED policy, the following dict will be
      returned:
      {
        "model": [Model],
        "model_blessing": [ModelBlessing],
        "model_infra_blessing": [ModelInfraBlessing]
      }

      For a LATEST_PUSHED policy, the following dict will be returned:
      {
        "model": [Model],
        "model_push": [ModelPush]
      }

    Raises:
      InvalidArgument: If the models are not Model artifacts.
      SkipSignal: If raise_skip_signal is True and one of the following:
        1. The input_dict is empty.
        2. If no models are passed in.
        3. If input_dict contains "model_blessing" and/or "model_infra_blessing"
           as keys but have empty lists as values for both of them.
        4. No latest model was found that matches the policy.
    """
    if not input_dict:
      return self._raise_skip_signal_or_return_empty_dict(
          'The input dictionary is empty.'
      )

    _validate_input_dict(input_dict)

    if not input_dict[ops_utils.MODEL_KEY]:
      return self._raise_skip_signal_or_return_empty_dict(
          'The "model" key in the input dict contained no Model artifacts.'
      )

    # Sort the models from from latest created to oldest.
    models = input_dict.get(ops_utils.MODEL_KEY)
    models.sort(  # pytype: disable=attribute-error
        key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id),
        reverse=True,
    )

    # Return the latest trained model if the policy is LATEST_EXPORTED.
    if self.policy == Policy.LATEST_EXPORTED:
      return {ops_utils.MODEL_KEY: [models[0]]}

    # If ModelBlessing and/or ModelInfraBlessing artifacts were included in
    # input_dict, then we will only consider those child artifacts.
    specifies_child_artifacts = (
        ops_utils.MODEL_BLESSSING_KEY in input_dict.keys()
        or ops_utils.MODEL_INFRA_BLESSING_KEY in input_dict.keys()
    )
    input_child_artifacts = input_dict.get(
        ops_utils.MODEL_BLESSSING_KEY, []
    ) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
    input_child_artifact_ids = set([a.id for a in input_child_artifacts])

    # If the ModelBlessing and ModelInfraBlessing lists are empty, then no
    # child artifacts can be considered and we raise a SkipSignal. This can
    # occur when a Model has been trained but not blessed yet, for example.
    if specifies_child_artifacts and not input_child_artifact_ids:
      return self._raise_skip_signal_or_return_empty_dict(
          '"model_blessing" and/or "model_infra_blessing" were specified as '
          'keys in the input dictionary, but contained no '
          'ModelBlessing/ModelInfraBlessing artifacts.'
      )

    # In MLMD, two artifacts are related by:
    #
    #       Event 1           Event 2
    # Model ------> Execution ------> Artifact B
    #
    # Artifact B can be:
    # 1. ModelBlessing output artifact from an Evaluator.
    # 2. ModelInfraBlessing output artifact from an InfraValidator.
    # 3. ModelPush output artifact from a Pusher.
    #
    # We query MLMD to get a list of candidate model artifact ids that have
    # a child artifact of type child_artifact_type. Note we perform batch
    # queries to reduce the number round trips to the database.

    # There could be multiple events with the same execution ID but different
    # artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
    # keep the values of model_artifact_ids_by_execution_id as sets.
    model_artifact_ids = sorted(set(m.id for m in models))
    model_artifact_ids_by_execution_id = collections.defaultdict(set)

    # Pusher takes uses the key "model_export" to take in the Model artifact,
    # but all other components use the key "model".
    if self.policy == Policy.LATEST_PUSHED:
      event_input_key = ops_utils.MODEL_EXPORT_KEY
    else:
      event_input_key = ops_utils.MODEL_KEY

    # Get all Executions in MLMD associated with the Model artifacts.
    execution_ids = set()
    for event in self.context.store.get_events_by_artifact_ids(
        model_artifact_ids
    ):
      if event_lib.is_valid_input_event(event, event_input_key):
        model_artifact_ids_by_execution_id[event.execution_id].add(
            event.artifact_id
        )
        execution_ids.add(event.execution_id)

    # Get all artifact ids associated with an OUTPUT Event in each Execution.
    # These ids correspond to descendant artifacts 1 hop distance away from the
    # Model.
    child_artifact_ids = set()
    child_artifact_ids_by_model_artifact_id = collections.defaultdict(set)
    model_artifact_ids_by_child_artifact_id = collections.defaultdict(set)
    for event in self.context.store.get_events_by_execution_ids(execution_ids):
      if not event_lib.is_valid_output_event(event):
        continue

      child_artifact_id = event.artifact_id
      # Only consider child artifacts present in input_dict, if specified.
      if (
          specifies_child_artifacts
          and child_artifact_id not in input_child_artifact_ids
      ):
        continue

      child_artifact_ids.add(child_artifact_id)
      model_artifact_ids = model_artifact_ids_by_execution_id[
          event.execution_id
      ]
      model_artifact_ids_by_child_artifact_id[child_artifact_id] = (
          model_artifact_ids
      )

      for model_artifact_id in model_artifact_ids:
        child_artifact_ids_by_model_artifact_id[model_artifact_id].add(
            child_artifact_id
        )

    # Get Model, ModelBlessing, ModelInfraBlessing, and ModelPush ArtifactTypes.
    artifact_type_by_type_id = {}
    artifact_type_by_name = {}
    for artifact_type in self.context.store.get_artifact_types():
      artifact_type_by_type_id[artifact_type.id] = artifact_type
      artifact_type_by_name[artifact_type.name] = artifact_type

    # Populate the ModelRelations associated with each Model artifact and its
    # children.
    child_artifact_by_artifact_id = {}
    model_relations_by_model_artifact_id = collections.defaultdict(
        ModelRelations
    )
    for artifact in self.context.store.get_artifacts_by_id(child_artifact_ids):
      child_artifact_by_artifact_id[artifact.id] = artifact
      for model_artifact_id in model_artifact_ids_by_child_artifact_id[
          artifact.id
      ]:
        model_relations = model_relations_by_model_artifact_id[
            model_artifact_id
        ]

        artifact_type_name = artifact_type_by_type_id[artifact.type_id].name
        if _is_eval_blessed(artifact_type_name, artifact):
          model_relations.model_blessing_by_artifact_id[artifact.id] = artifact

        elif _is_infra_blessed(artifact_type_name, artifact):
          model_relations.infra_blessing_by_artifact_id[artifact.id] = artifact

        elif artifact_type_name == ops_utils.MODEL_PUSH_TYPE_NAME:
          model_relations.model_push_by_artifact_id[artifact.id] = artifact

    # Find the latest model and ModelRelations that meets the Policy.
    result = {}
    for model in models:
      model_relations = model_relations_by_model_artifact_id[model.id]
      if model_relations.meets_policy(self.policy):
        result[ops_utils.MODEL_KEY] = [model]
        break
    else:
      return self._raise_skip_signal_or_return_empty_dict(
          f'No model found that meets the Policy {Policy(self.policy).name}'
      )

    return _build_result_dictionary(
        result, model_relations, self.policy, artifact_type_by_name
    )