def topsorted_layers()

in tfx/utils/topsort.py [0:0]


def topsorted_layers(
    nodes: Sequence[NodeT], get_node_id_fn: Callable[[NodeT], str],
    get_parent_nodes: Callable[[NodeT], List[NodeT]],
    get_child_nodes: Callable[[NodeT], List[NodeT]]) -> List[List[NodeT]]:
  """Groups the nodes of a DAG into topologically ordered layers.

  The first layer holds every node without (known) parents. Each following
  layer holds the nodes whose parents have all been placed in earlier layers.

  Args:
    nodes: The nodes to sort. Their ids must be pairwise distinct.
    get_node_id_fn: Returns the unique text id of a node.
    get_parent_nodes: Returns the parents of a node. Parents whose id does not
      belong to `nodes`, as well as repeated entries, are dropped with a
      warning.
    get_child_nodes: Returns the children of a node. Children whose id does
      not belong to `nodes`, as well as repeated entries, are dropped with a
      warning.

  Returns:
    The layers in topological order. Within a layer, nodes are ordered by the
    id returned from `get_node_id_fn`.

  Raises:
    InvalidDAGError: If the layers end up holding fewer nodes than the input
      (a cycle) or more nodes than the input.
    ValueError: If two input nodes share an id.
  """
  known_ids = {get_node_id_fn(candidate) for candidate in nodes}
  if len(known_ids) != len(nodes):
    raise ValueError('Nodes must have unique ids.')

  def _neighbors(lookup_fn: Callable[[NodeT], List[NodeT]], lookup_name: str,
                 node: NodeT) -> List[NodeT]:
    """Calls `lookup_fn(node)`, dropping repeated and unknown neighbors."""
    kept = []
    kept_ids = set()
    for neighbor in lookup_fn(node):
      neighbor_id = get_node_id_fn(neighbor)
      if neighbor_id in kept_ids:
        logging.warning(
            'Duplicate node_id %s found when calling %s on node %s. '
            'This entry will be ignored.', neighbor_id, lookup_name, node)
        continue
      if neighbor_id not in known_ids:
        logging.warning(
            'node_id %s found when calling %s on node %s, but this node_id is '
            'not found in the set of input nodes %s. This entry will be '
            'ignored.', neighbor_id, lookup_name, node, known_ids)
        continue
      kept_ids.add(neighbor_id)
      kept.append(neighbor)
    return kept

  def _parents_of(node: NodeT) -> List[NodeT]:
    return _neighbors(get_parent_nodes, 'get_parent_nodes', node)

  def _children_of(node: NodeT) -> List[NodeT]:
    return _neighbors(get_child_nodes, 'get_child_nodes', node)

  layers = []
  placed_ids = set()
  placed_count = 0

  # Seed the traversal with the nodes that have no incoming edges.
  frontier = [node for node in nodes if not _parents_of(node)]
  while frontier:
    frontier.sort(key=get_node_id_fn)
    layers.append(frontier)
    placed_count += len(frontier)

    upcoming = []
    for node in frontier:
      placed_ids.add(get_node_id_fn(node))
      for child in _children_of(node):
        # A child becomes ready once every one of its parents is placed. A
        # child on a cycle always keeps at least one unplaced parent (another
        # member of that cycle), so it is never scheduled.
        required_ids = {get_node_id_fn(parent) for parent in _parents_of(child)}
        if required_ids <= placed_ids:
          upcoming.append(child)
    frontier = upcoming

  # Nodes on a cycle never reach a layer, so a shortfall signals a cycle.
  if placed_count < len(nodes):
    raise InvalidDAGError('Cycle detected.')
  # Not expected to happen; guard against it anyway.
  if placed_count > len(nodes):
    raise InvalidDAGError('Unknown DAG error.')

  return layers