tfx/orchestration/pipeline.py
def __init__(self,
pipeline_name: str,
pipeline_root: Optional[Union[str, ph.Placeholder]] = '',
metadata_connection_config: Optional[
metadata.ConnectionConfigType] = None,
components: Optional[List[base_node.BaseNode]] = None,
enable_cache: Optional[bool] = False,
beam_pipeline_args: Optional[List[Union[str,
ph.Placeholder]]] = None,
platform_config: Optional[message.Message] = None,
execution_mode: Optional[ExecutionMode] = ExecutionMode.SYNC,
inputs: Optional[PipelineInputs] = None,
outputs: Optional[Dict[str, channel.OutputChannel]] = None,
**kwargs):
"""Initialize pipeline.
Args:
pipeline_name: Name of the pipeline;
pipeline_root: Path to root directory of the pipeline. This will most
often be just a string. Some orchestrators may have limited support for
constructing this from a Placeholder, e.g. a RuntimeInfoPlaceholder that
        refers to fields from the platform config. pipeline_root is optional
        only if the pipeline is composed within a parent pipeline, in which
        case it inherits the parent pipeline's root.
      metadata_connection_config: The config to connect to ML Metadata (MLMD).
      components: Optional list of components that make up the pipeline.
      enable_cache: Whether caching is enabled for this run.
      beam_pipeline_args: Pipeline arguments for Beam-powered components.
      platform_config: Pipeline-level platform config, in proto form.
      execution_mode: The execution mode of the pipeline; either SYNC or ASYNC.
      inputs: Optional inputs of the pipeline.
      outputs: Optional outputs of the pipeline.
**kwargs: Additional kwargs forwarded as pipeline args.
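
    Raises:
      ValueError: If `pipeline_name` exceeds the maximum allowed length, or if
        a `dsl_context_registry` kwarg is passed that is not a
        DslContextRegistry.

    Example (a minimal sketch; the component and paths below are illustrative):

      example_gen = CsvExampleGen(input_base='/path/to/data')
      pipeline = Pipeline(
          pipeline_name='demo_pipeline',
          pipeline_root='/path/to/pipeline_root',
          components=[example_gen],
          enable_cache=True)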
"""
if len(pipeline_name) > _MAX_PIPELINE_NAME_LENGTH:
      raise ValueError(
          f'pipeline name "{pipeline_name}" exceeds maximum allowed length: '
          f'{_MAX_PIPELINE_NAME_LENGTH}.')
self.pipeline_name = pipeline_name
# Initialize pipeline as a node.
super().__init__()
if inputs:
inputs.pipeline = self
self._inputs = inputs
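    # Wrap each declared output in a PipelineOutputChannel so that, when this
    # pipeline is composed into a parent pipeline, the channel is addressed
    # through this pipeline node under the given output key.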
if outputs:
self._outputs = {
k: channel.PipelineOutputChannel(v, pipeline=self, output_key=k)
for k, v in outputs.items()
}
else:
self._outputs = {}
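    # A pipeline is itself a node, and its node id is its name, which lets it
    # be addressed like any other node when composed into a parent pipeline.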
self._id = pipeline_name
    # Once the pipeline is finalized, this instance is regarded as immutable
    # and any detectable mutation will raise an error.
self._finalized = False
# TODO(b/183621450): deprecate PipelineInfo.
self.pipeline_info = data_types.PipelineInfo( # pylint: disable=g-missing-from-attributes
pipeline_name=pipeline_name,
pipeline_root=pipeline_root)
self.enable_cache = enable_cache
self.metadata_connection_config = metadata_connection_config
self.execution_mode = execution_mode
self._beam_pipeline_args = beam_pipeline_args or []
self.platform_config = platform_config
self.additional_pipeline_args = kwargs.pop( # pylint: disable=g-missing-from-attributes
'additional_pipeline_args', {})
reg = kwargs.pop('dsl_context_registry', None)
    if reg is not None:
      if not isinstance(reg, dsl_context_registry.DslContextRegistry):
        raise ValueError('dsl_context_registry must be a DslContextRegistry '
                         f'but got {reg!r}')
self._dsl_context_registry = reg
else:
self._dsl_context_registry = dsl_context_registry.get()
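    # A non-empty context list means this pipeline was defined inside an
    # active DSL context (e.g. within an enclosing pipeline definition); in
    # that case, extract the portion of the registry that belongs to this
    # pipeline into its own registry.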
if self._dsl_context_registry.get_contexts(self):
self._dsl_context_registry = (
self._dsl_context_registry.extract_for_pipeline(self))
# TODO(b/216581002): Use self._dsl_context_registry to obtain components.
self._components = []
if components:
self._set_components(components)