in tfx/orchestration/kubeflow/base_component.py [0:0]
def __init__(self,
             component: tfx_base_node.BaseNode,
             depends_on: Set[dsl.ContainerOp],
             pipeline: tfx_pipeline.Pipeline,
             pipeline_root: dsl.PipelineParam,
             tfx_image: str,
             kubeflow_metadata_config: kubeflow_pb2.KubeflowMetadataConfig,
             tfx_ir: pipeline_pb2.Pipeline,
             pod_labels_to_attach: Dict[str, str],
             runtime_parameters: List[data_types.RuntimeParameter],
             metadata_ui_path: str = '/mlpipeline-ui-metadata.json'):
"""Creates a new Kubeflow-based component.
This class essentially wraps a dsl.ContainerOp construct in Kubeflow
Pipelines.
Args:
component: The logical TFX component to wrap.
depends_on: The set of upstream KFP ContainerOp components that this
component will depend on.
pipeline: The logical TFX pipeline to which this component belongs.
pipeline_root: The pipeline root specified, as a dsl.PipelineParam
tfx_image: The container image to use for this component.
kubeflow_metadata_config: Configuration settings for connecting to the
MLMD store in a Kubeflow cluster.
tfx_ir: The TFX intermedia representation of the pipeline.
pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
runtime_parameters: Runtime parameters of the pipeline.
metadata_ui_path: File location for metadata-ui-metadata.json file.
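
  Example:
    An illustrative sketch only; `BaseComponent` stands in for the enclosing
    class, and every argument value below is a placeholder:

      step = BaseComponent(
          component=my_component,
          depends_on={upstream_op},
          pipeline=my_pipeline,
          pipeline_root=dsl.PipelineParam(name='pipeline-root'),
          tfx_image='gcr.io/my-project/my-tfx-image:latest',
          kubeflow_metadata_config=metadata_config,
          tfx_ir=pipeline_ir,
          pod_labels_to_attach={},
          runtime_parameters=[])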
"""
  _replace_placeholder(component)
  arguments = [
      '--pipeline_root',
      pipeline_root,
      '--kubeflow_metadata_config',
      json_format.MessageToJson(
          message=kubeflow_metadata_config, preserving_proto_field_name=True),
      '--node_id',
      component.id,
      # TODO(b/182220464): Write the IR to pipeline_root and let
      # container_entrypoint.py read it back, to avoid the IR exceeding the
      # flag size limit in the future.
      '--tfx_ir',
      json_format.MessageToJson(tfx_ir),
      '--metadata_ui_path',
      metadata_ui_path,
  ]
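  # For illustration only, the argv assembled above looks roughly like this
  # (JSON payloads elided); it is parsed by container_entrypoint.py inside
  # the launched container:
  #   --pipeline_root <pipeline-root-param>
  #   --kubeflow_metadata_config '{...}'
  #   --node_id <component.id>
  #   --tfx_ir '{...}'
  #   --metadata_ui_path /mlpipeline-ui-metadata.json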
  for param in runtime_parameters:
    arguments.append('--runtime_parameter')
    arguments.append(_encode_runtime_parameter(param))
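  # Each runtime parameter travels as its own --runtime_parameter flag;
  # _encode_runtime_parameter is assumed to serialize the parameter's name,
  # type, and a KFP placeholder into a single string value.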
  self.container_op = dsl.ContainerOp(
      name=component.id,
      command=_COMMAND,
      image=tfx_image,
      arguments=arguments,
      output_artifact_paths={
          'mlpipeline-ui-metadata': metadata_ui_path,
      },
  )
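  # 'mlpipeline-ui-metadata' is a well-known artifact name in KFP: mapping
  # it to metadata_ui_path lets the KFP UI collect the file once the step
  # finishes and render any visualizations it declares.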
  logging.info('Adding upstream dependencies for component %s',
               self.container_op.name)
  for op in depends_on:
    logging.info(' -> Component: %s', op.name)
    self.container_op.after(op)
  # TODO(b/140172100): Document the use of additional_pipeline_args.
  if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
    # Allow overriding the pipeline's run_id externally, primarily for
    # testing.
    self.container_op.container.add_env_variable(
        k8s_client.V1EnvVar(
            name=_WORKFLOW_ID_KEY,
            value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY]))
  else:
    # Expose the Argo workflow ID to the container as an environment
    # variable, so it can be used to uniquely place pipeline outputs under
    # the pipeline_root.
    field_path = "metadata.labels['workflows.argoproj.io/workflow']"
    self.container_op.container.add_env_variable(
        k8s_client.V1EnvVar(
            name=_WORKFLOW_ID_KEY,
            value_from=k8s_client.V1EnvVarSource(
                field_ref=k8s_client.V1ObjectFieldSelector(
                    field_path=field_path))))
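  # Via the Kubernetes downward API, this resolves to the Argo workflow name
  # when the pod starts. Sketch of the resulting pod spec entry, assuming
  # _WORKFLOW_ID_KEY == 'WORKFLOW_ID':
  #   env:
  #   - name: WORKFLOW_ID
  #     valueFrom:
  #       fieldRef:
  #         fieldPath: "metadata.labels['workflows.argoproj.io/workflow']"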
  if pod_labels_to_attach:
    for k, v in pod_labels_to_attach.items():
      self.container_op.add_pod_label(k, v)
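  # Illustration only: callers might pass labels such as
  # {'pipelines.kubeflow.org/pipeline-sdk-type': 'tfx'} so that cluster
  # operators can filter or audit TFX-generated pods.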