def _generate_tasks_for_node()

in tfx/orchestration/experimental/core/sync_pipeline_task_gen.py [0:0]


  def _generate_tasks_for_node(
      self, node: node_proto_view.NodeProtoView) -> List[task_lib.Task]:
    """Generates the list of tasks for the given node.

    The node is examined in the order below; the first matching case decides
    what is returned:

    1. Node state is STOPPING, STOPPED, PAUSING or PAUSED: no tasks.
    2. Pure service node (`_ensure_node_services_if_pure` returns a status):
       at most one node-state update derived from the service status.
    3. Mixed service node whose service jobs FAILED: a task marking the node
       as failed.
    4. An exec node task for this node is already tracked: no tasks.
    5. The latest execution set is empty: tasks generated from the node's
       resolved inputs.
    6. All latest executions are successful: a COMPLETE node-state update.
    7. Some latest executions failed: a node failure if more than one failed;
       otherwise either a retry of the failed execution (while retries remain)
       or a FAILED node-state update.
    8. Some latest executions were canceled and the node is STARTING: new
       executions are registered from the canceled ones, the first is marked
       RUNNING in MLMD, and tasks are generated from it.
    9. A latest execution is active: tasks generated from it.

    Args:
      node: The node for which to generate tasks.

    Returns:
      A (possibly empty) list of tasks for the node.

    Raises:
      RuntimeError: If none of the cases above applies.
    """
    node_uid = task_lib.NodeUid.from_node(self._pipeline, node)
    node_id = node.node_info.id
    result = []

    node_state = self._node_states_dict[node_uid]
    if node_state.state in (pstate.NodeState.STOPPING, pstate.NodeState.STOPPED,
                            pstate.NodeState.PAUSING, pstate.NodeState.PAUSED):
      logging.info('Ignoring node in state \'%s\' for task generation: %s',
                   node_state.state, node_uid)
      return result

    # If this is a pure service node, there is no ExecNodeTask to generate
    # but we ensure node services and check service status. A non-None status
    # identifies the node as a pure service node.
    service_status = self._ensure_node_services_if_pure(node_id)
    if service_status is not None:
      if service_status == service_jobs.ServiceStatus.FAILED:
        # TODO(b/205642811): Mark all pending executions as either failed (if
        # active) or canceled (if new), and delete the executions' temporary
        # and output directories.
        error_msg = f'service job failed; node uid: {node_uid}'
        result.append(
            self._update_node_state_to_failed_task(
                node_uid,
                error_code=status_lib.Code.UNKNOWN,
                error_msg=error_msg,
            )
        )
      elif service_status == service_jobs.ServiceStatus.SUCCESS:
        logging.info('Service node successful: %s', node_uid)
        result.append(
            task_lib.UpdateNodeStateTask(
                node_uid=node_uid, state=pstate.NodeState.COMPLETE))
      elif (service_status == service_jobs.ServiceStatus.RUNNING and
            node_state.state != pstate.NodeState.RUNNING):
        # Only emit a state transition when the node is not already RUNNING.
        result.append(
            task_lib.UpdateNodeStateTask(
                node_uid=node_uid, state=pstate.NodeState.RUNNING))
      return result

    # For mixed service nodes, we ensure node services and check service
    # status; if the service jobs have failed, a task marking the node as
    # failed is returned.
    service_status = self._ensure_node_services_if_mixed(node_id)
    if service_status == service_jobs.ServiceStatus.FAILED:
      error_msg = f'associated service job failed; node uid: {node_uid}'
      result.append(
          self._update_node_state_to_failed_task(
              node_uid, error_code=status_lib.Code.UNKNOWN, error_msg=error_msg
          )
      )
      return result

    # If a task for the node is already tracked by the task queue, it need
    # not be considered for generation again.
    if self._is_task_id_tracked_fn(
        task_lib.exec_node_task_id_from_node(self._pipeline, node)):
      return result

    node_executions = task_gen_utils.get_executions(self._mlmd_handle, node)
    latest_executions_set = task_gen_utils.get_latest_executions_set(
        node_executions)

    # Generates tasks from resolved inputs if the node doesn't have any
    # execution.
    if not latest_executions_set:
      result.extend(self._generate_tasks_from_resolved_inputs(node))
      return result

    # If all the executions are successful, the node is COMPLETE.
    if all(
        execution_lib.is_execution_successful(e) for e in latest_executions_set
    ):
      logging.info('Node successful: %s', node_uid)
      result.append(
          task_lib.UpdateNodeStateTask(
              node_uid=node_uid, state=pstate.NodeState.COMPLETE))
      return result

    # If the node has a failed execution, try to retry the failed execution.
    failed_executions = [
        e for e in latest_executions_set if execution_lib.is_execution_failed(e)
    ]
    if failed_executions:
      if len(failed_executions) > 1:
        error_msg = (f'node {node_uid} failed; error: More than one failed '
                     'executions found in the latest execution set.')
        result.append(
            self._update_node_state_to_failed_task(
                node_uid,
                error_code=status_lib.Code.INTERNAL,
                error_msg=error_msg,
            )
        )
      # Retry while the failure count computed from the execution history does
      # not exceed the configured maximum number of retries.
      elif (node.execution_options.max_execution_retries >=
            task_gen_utils.get_num_of_failures_from_failed_execution(
                node_executions, failed_executions[0])):
        # Exactly one failed execution was passed in, so exactly one new
        # execution is expected back; the unpacking enforces that.
        [retry_execution] = (
            task_gen_utils.register_executions_from_existing_executions(
                self._mlmd_handle, node, failed_executions
            )
        )
        result.extend(
            self._generate_tasks_from_existing_execution(retry_execution, node))
      else:
        result.append(
            task_lib.UpdateNodeStateTask(
                node_uid=node_uid,
                state=pstate.NodeState.FAILED,
                status=task_gen_utils.interpret_status_from_failed_execution(
                    failed_executions[0]
                ),
            )
        )
      return result

    # Restarts canceled node, if the node state is STARTING.
    canceled_executions = [
        e for e in latest_executions_set
        if execution_lib.is_execution_canceled(e)
    ]
    if canceled_executions and node_state.state == pstate.NodeState.STARTING:
      new_executions = (
          task_gen_utils.register_executions_from_existing_executions(
              self._mlmd_handle, node, canceled_executions
          )
      )
      # Only the first newly registered execution is moved to RUNNING and used
      # for task generation here.
      with mlmd_state.mlmd_execution_atomic_op(
          mlmd_handle=self._mlmd_handle, execution_id=new_executions[0].id
      ) as execution:
        execution.last_known_state = metadata_store_pb2.Execution.RUNNING

      result.extend(
          self._generate_tasks_from_existing_execution(new_executions[0], node)
      )
      return result

    # If the node has active executions, creates tasks from the first active
    # execution in the latest execution set. NOTE(review): this is the oldest
    # one only if `get_latest_executions_set` returns executions ordered
    # oldest-first -- confirm in task_gen_utils.
    oldest_active_execution = next((e for e in latest_executions_set
                                    if execution_lib.is_execution_active(e)),
                                   None)
    if oldest_active_execution:
      result.extend(
          self._generate_tasks_from_existing_execution(oldest_active_execution,
                                                       node))
      return result

    raise RuntimeError('Task generation process should not reach this point.')