in tfx/orchestration/experimental/core/sync_pipeline_task_gen.py [0:0]
def _generate_tasks_for_node(
    self, node: node_proto_view.NodeProtoView) -> List[task_lib.Task]:
  """Generates the tasks needed to advance a single pipeline node.

  The checks below are applied in order; the first one that matches decides
  what is returned:

  1. A node in STOPPING, STOPPED, PAUSING or PAUSED state yields no tasks.
  2. If `_ensure_node_services_if_pure` reports a status (treated here as
     "this is a pure service node"), at most one node-state task is returned
     based on that status and no further checks run.
  3. If `_ensure_node_services_if_mixed` reports FAILED, a single task from
     `_update_node_state_to_failed_task` is returned.
  4. If the node's exec task id is already tracked, nothing is generated.
  5. Otherwise the node's latest execution set drives the outcome: no
     executions -> tasks from resolved inputs; all successful -> COMPLETE;
     a failed execution -> retry or FAILED; canceled executions on a
     STARTING node -> restart; an active execution -> tasks from it.

  Args:
    node: The node to generate tasks for.

  Returns:
    A (possibly empty) list of tasks for the node.

  Raises:
    RuntimeError: If the latest executions match none of the cases above.
  """
  node_uid = task_lib.NodeUid.from_node(self._pipeline, node)
  node_id = node.node_info.id
  node_state = self._node_states_dict[node_uid]

  # Nodes that are (being) stopped or paused are left alone.
  skipped_states = (
      pstate.NodeState.STOPPING,
      pstate.NodeState.STOPPED,
      pstate.NodeState.PAUSING,
      pstate.NodeState.PAUSED,
  )
  if node_state.state in skipped_states:
    logging.info('Ignoring node in state \'%s\' for task generation: %s',
                 node_state.state, node_uid)
    return []

  # Pure service nodes never get an ExecNodeTask; we only make sure their
  # services are up and translate the service status into a node state.
  pure_status = self._ensure_node_services_if_pure(node_id)
  if pure_status is not None:
    if pure_status == service_jobs.ServiceStatus.FAILED:
      # TODO(b/205642811): Mark all pending executions as either failed (if
      # active) or canceled (if new), and delete the executions' temporary
      # and output directories.
      error_msg = f'service job failed; node uid: {node_uid}'
      return [
          self._update_node_state_to_failed_task(
              node_uid,
              error_code=status_lib.Code.UNKNOWN,
              error_msg=error_msg,
          )
      ]
    if pure_status == service_jobs.ServiceStatus.SUCCESS:
      logging.info('Service node successful: %s', node_uid)
      return [
          task_lib.UpdateNodeStateTask(
              node_uid=node_uid, state=pstate.NodeState.COMPLETE)
      ]
    if (pure_status == service_jobs.ServiceStatus.RUNNING and
        node_state.state != pstate.NodeState.RUNNING):
      return [
          task_lib.UpdateNodeStateTask(
              node_uid=node_uid, state=pstate.NodeState.RUNNING)
      ]
    # Any other status (or RUNNING on an already RUNNING node): nothing to do.
    return []

  # Mixed service nodes: make sure the services are up; a failed service job
  # results in a failed-state task for the node.
  mixed_status = self._ensure_node_services_if_mixed(node_id)
  if mixed_status == service_jobs.ServiceStatus.FAILED:
    error_msg = f'associated service job failed; node uid: {node_uid}'
    return [
        self._update_node_state_to_failed_task(
            node_uid, error_code=status_lib.Code.UNKNOWN, error_msg=error_msg
        )
    ]

  # A node whose exec task is already tracked by the task queue must not be
  # considered for generation again.
  exec_task_id = task_lib.exec_node_task_id_from_node(self._pipeline, node)
  if self._is_task_id_tracked_fn(exec_task_id):
    return []

  all_executions = task_gen_utils.get_executions(self._mlmd_handle, node)
  latest_set = task_gen_utils.get_latest_executions_set(all_executions)

  # No execution yet: start from freshly resolved inputs.
  if not latest_set:
    return list(self._generate_tasks_from_resolved_inputs(node))

  # Every execution in the latest set succeeded: the node is COMPLETE.
  if all(map(execution_lib.is_execution_successful, latest_set)):
    logging.info('Node successful: %s', node_uid)
    return [
        task_lib.UpdateNodeStateTask(
            node_uid=node_uid, state=pstate.NodeState.COMPLETE)
    ]

  # A failed execution is retried while the retry budget allows it.
  failed = list(filter(execution_lib.is_execution_failed, latest_set))
  if failed:
    if len(failed) > 1:
      error_msg = (f'node {node_uid} failed; error: More than one failed '
                   'executions found in the latest execution set.')
      return [
          self._update_node_state_to_failed_task(
              node_uid,
              error_code=status_lib.Code.INTERNAL,
              error_msg=error_msg,
          )
      ]

    sole_failure = failed[0]
    retry_budget = node.execution_options.max_execution_retries
    failures_so_far = task_gen_utils.get_num_of_failures_from_failed_execution(
        all_executions, sole_failure)
    if retry_budget >= failures_so_far:
      # Exactly one new execution is expected back; unpacking enforces that.
      [retry_execution] = (
          task_gen_utils.register_executions_from_existing_executions(
              self._mlmd_handle, node, failed
          )
      )
      return list(
          self._generate_tasks_from_existing_execution(retry_execution, node))

    # Retries exhausted: surface the failure status on the node.
    return [
        task_lib.UpdateNodeStateTask(
            node_uid=node_uid,
            state=pstate.NodeState.FAILED,
            status=task_gen_utils.interpret_status_from_failed_execution(
                sole_failure
            ),
        )
    ]

  # Canceled executions are restarted only when the node state is STARTING.
  canceled = list(filter(execution_lib.is_execution_canceled, latest_set))
  if canceled and node_state.state == pstate.NodeState.STARTING:
    restarted = task_gen_utils.register_executions_from_existing_executions(
        self._mlmd_handle, node, canceled
    )
    first_restarted = restarted[0]
    # Only the first newly registered execution is flipped to RUNNING here.
    with mlmd_state.mlmd_execution_atomic_op(
        mlmd_handle=self._mlmd_handle, execution_id=first_restarted.id
    ) as execution:
      execution.last_known_state = metadata_store_pb2.Execution.RUNNING
    return list(
        self._generate_tasks_from_existing_execution(first_restarted, node))

  # Otherwise continue from the first active execution in the latest set
  # (the oldest one, assuming the set is ordered oldest-first -- TODO
  # confirm against `get_latest_executions_set`).
  oldest_active = next(
      filter(execution_lib.is_execution_active, latest_set), None)
  if oldest_active:
    return list(
        self._generate_tasks_from_existing_execution(oldest_active, node))

  raise RuntimeError('Task generation process should not reach this point.')