def __call__()

in tfx/orchestration/experimental/core/async_pipeline_task_gen.py


  def __call__(self) -> List[task_lib.Task]:
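    """Generates tasks for all eligible nodes of the async pipeline."""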
    result = []
    for node in [node_proto_view.get_view(n) for n in self._pipeline.nodes]:
      node_uid = task_lib.NodeUid.from_node(self._pipeline, node)
      node_id = node.node_info.id

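      # Nodes in a stopping/stopped, pausing/paused, or failed state are
      # skipped for task generation.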
      with self._pipeline_state:
        node_state = self._pipeline_state.get_node_state(node_uid)
        if node_state.state in (pstate.NodeState.STOPPING,
                                pstate.NodeState.STOPPED,
                                pstate.NodeState.PAUSING,
                                pstate.NodeState.PAUSED,
                                pstate.NodeState.FAILED):
          logging.info("Ignoring node in state '%s' for task generation: %s",
                       node_state.state, node_uid)
          continue

      # If this is a pure service node, there is no ExecNodeTask to generate,
      # but we still ensure node services are running and check their status.
      service_status = self._ensure_node_services_if_pure(node_id)
      if service_status is not None:
        if service_status != service_jobs.ServiceStatus.RUNNING:
          error_msg = f'associated service job failed; node uid: {node_uid}'
          result.append(
              task_lib.UpdateNodeStateTask(
                  node_uid=node_uid,
                  state=pstate.NodeState.FAILED,
                  status=status_lib.Status(
                      code=status_lib.Code.UNKNOWN, message=error_msg)))
        elif node_state.state != pstate.NodeState.RUNNING:
          result.append(
              task_lib.UpdateNodeStateTask(
                  node_uid=node_uid, state=pstate.NodeState.RUNNING
              )
          )
        continue

      # For mixed service nodes, we ensure node services and check service
      # status; the node is aborted if its service jobs have failed.
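      # Unlike pure service nodes, a healthy mixed service node still falls
      # through to ExecNodeTask generation below.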
      service_status = self._ensure_node_services_if_mixed(node_id)
      if service_status is not None:
        if service_status != service_jobs.ServiceStatus.RUNNING:
          error_msg = f'associated service job failed; node uid: {node_uid}'
          result.append(
              task_lib.UpdateNodeStateTask(
                  node_uid=node_uid,
                  state=pstate.NodeState.FAILED,
                  status=status_lib.Status(
                      code=status_lib.Code.UNKNOWN, message=error_msg)))
          continue

      # If a task for the node is already tracked by the task queue, it need
      # not be considered for generation again.
      if self._is_task_id_tracked_fn(
          task_lib.exec_node_task_id_from_node(self._pipeline, node)):
        continue

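      # Otherwise, generate tasks for the node, passing along any backfill
      # token recorded in its node state.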
      result.extend(
          self._generate_tasks_for_node(
              self._mlmd_handle, node, node_state.backfill_token
          )
      )
    return result
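
For orientation, here is a minimal, self-contained sketch of the state-gated
task-generation pattern that __call__ implements. All names below (NodeState,
UpdateNodeStateTask, ExecNodeTask, generate_tasks) are simplified stand-ins
invented for illustration, not the TFX classes used in the method above.

import dataclasses
import enum
from typing import Dict, List, Optional, Set, Union


class NodeState(enum.Enum):
  # Simplified stand-in for pstate.NodeState.
  STARTED = 'started'
  RUNNING = 'running'
  STOPPED = 'stopped'
  FAILED = 'failed'


# Node states for which no new tasks are generated, mirroring the check at
# the top of the loop in __call__.
_INACTIVE_STATES = frozenset({NodeState.STOPPED, NodeState.FAILED})


@dataclasses.dataclass
class UpdateNodeStateTask:
  node_id: str
  state: NodeState
  message: Optional[str] = None


@dataclasses.dataclass
class ExecNodeTask:
  node_id: str


Task = Union[UpdateNodeStateTask, ExecNodeTask]


def generate_tasks(node_states: Dict[str, NodeState],
                   failed_service_nodes: Set[str]) -> List[Task]:
  """Emits a FAILED update for nodes whose service jobs failed, an exec task
  for the remaining active nodes, and nothing for inactive nodes."""
  result: List[Task] = []
  for node_id, state in node_states.items():
    if state in _INACTIVE_STATES:
      continue  # Analogous to skipping stopped/paused/failed nodes above.
    if node_id in failed_service_nodes:
      result.append(
          UpdateNodeStateTask(node_id, NodeState.FAILED,
                              message='associated service job failed'))
      continue
    result.append(ExecNodeTask(node_id))
  return result


# Usage: only 'Trainer' yields a task; the stopped 'Pusher' node is skipped.
print(generate_tasks({'Trainer': NodeState.STARTED,
                      'Pusher': NodeState.STOPPED},
                     failed_service_nodes=set()))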