in tfx/orchestration/pipeline.py [0:0]
def _set_components(self, components: List[base_node.BaseNode]) -> None:
    """Set a full list of components of the pipeline.

    After calling `self._check_mutable()`, this method:

    1. Drops duplicate entries from `components`.
    2. Verifies that the remaining nodes all have distinct ids.
    3. Links each node to its upstream nodes, based on the channels found in
       its inputs, in channel-wrapped placeholder exec properties, and in the
       predicates returned by `conditional.get_predicates`.
    4. Stores the nodes in `self._components`, flattened from the layers
       returned by `topsort.topsorted_layers`.
    5. If `self.beam_pipeline_args` is set, applies it to every stored
       component via `add_beam_pipeline_args_to_component`.

    Args:
      components: Nodes of the pipeline. Repeated entries are ignored.

    Raises:
      RuntimeError: If two distinct components share the same id.
    """
    self._check_mutable()

    # Dedupe while keeping the caller's order. A plain set() would make the
    # iteration order below (and therefore the order of warnings and of edge
    # insertion) depend on hashing rather than on the input.
    deduped_components = list(dict.fromkeys(components))

    # Fills in producer map, checking that every node has a unique id.
    node_by_id = {}
    for component in deduped_components:
        if component.id in node_by_id:
            raise RuntimeError(
                f'Duplicated node_id {component.id} for component type '
                f'{component.type}. Try setting a different node_id using '
                '`.with_id()`.'
            )
        node_by_id[component.id] = component

    # Connects nodes based on producer map.
    for component in deduped_components:
        # Collect every channel this node reads from: declared inputs,
        # channels wrapped in placeholder exec properties, and channels
        # referenced by conditional predicates.
        channels = list(component.inputs.values())
        channels.extend(
            exec_property.channel
            for exec_property in component.exec_properties.values()
            if isinstance(exec_property, ph.ChannelWrappedPlaceholder)
        )
        for predicate in conditional.get_predicates(
                component, self.dsl_context_registry):
            channels.extend(channel_utils.get_dependent_channels(predicate))

        for input_channel in channels:
            for node_id in input_channel.get_data_dependent_node_ids():
                if node_id == self.id:
                    # If a component's input channel depends on the (self)
                    # pipeline, it means that component consumes
                    # pipeline-level inputs. No need to add upstream node
                    # here. Pipeline-level inputs will be handled during
                    # compilation.
                    continue
                upstream_node = node_by_id.get(node_id)
                if upstream_node is None:
                    # The producer is not part of this pipeline; warn instead
                    # of failing so the caller can decide how to proceed.
                    warnings.warn(
                        f'Node {component.id} depends on the output of node '
                        f'{node_id}, but {node_id} is not included in the '
                        'components of pipeline. Did you forget to add it?')
                    continue
                component.add_upstream_node(upstream_node)
                upstream_node.add_downstream_node(component)

    layers = topsort.topsorted_layers(
        deduped_components,
        get_node_id_fn=lambda c: c.id,
        get_parent_nodes=lambda c: c.upstream_nodes,
        get_child_nodes=lambda c: c.downstream_nodes)
    # Flatten the layers into a single list, keeping layer order.
    self._components = [
        component for layer in layers for component in layer
    ]

    if self.beam_pipeline_args:
        for component in self._components:
            add_beam_pipeline_args_to_component(
                component, self.beam_pipeline_args)