tfx/tools/cli/handler/beam_handler.py (77 lines of code) (raw):

# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handler for Beam."""

import json
import sys
from typing import Any, Dict

import click

from tfx.dsl.io import fileio
from tfx.tools.cli import labels
from tfx.tools.cli.handler import base_handler
from tfx.tools.cli.handler import beam_dag_runner_patcher
from tfx.tools.cli.handler import dag_runner_patcher
from tfx.utils import io_utils


class BeamHandler(base_handler.BaseHandler):
  """CLI handler that implements the pipeline and run commands for Beam."""

  def _get_dag_runner_patcher(self) -> dag_runner_patcher.DagRunnerPatcher:
    """Returns a new `BeamDagRunnerPatcher` used when executing the DSL."""
    return beam_dag_runner_patcher.BeamDagRunnerPatcher()

  def create_pipeline(self, overwrite: bool = False) -> None:
    """Registers (or re-registers) a pipeline with the Beam handler.

    The pipeline DSL is executed through the dag runner patcher, and the
    pipeline name and root captured in the resulting context are saved to the
    handler directory.

    Args:
      overwrite: Set to true to update an already saved pipeline. The value is
        forwarded as `required` to `_check_pipeline_existence` and selects
        which confirmation message is printed.
    """
    runner_patcher = self._get_dag_runner_patcher()
    dsl_context = self.execute_dsl(runner_patcher)
    name = dsl_context[runner_patcher.PIPELINE_NAME]

    # Validate existence before anything is written to the handler directory.
    self._check_pipeline_existence(name, required=overwrite)

    self._save_pipeline({
        labels.PIPELINE_NAME: name,
        labels.PIPELINE_ROOT: dsl_context[runner_patcher.PIPELINE_ROOT],
    })

    message = ('Pipeline "{}" updated successfully.'
               if overwrite else 'Pipeline "{}" created successfully.')
    click.echo(message.format(name))

  def update_pipeline(self) -> None:
    """Updates a pipeline in Beam by re-running creation in overwrite mode."""
    self.create_pipeline(overwrite=True)

  def list_pipelines(self) -> None:
    """Prints the entries of the handler home directory, one per line."""
    if fileio.exists(self._handler_home_dir):
      names = fileio.listdir(self._handler_home_dir)

      # Frame the listing between two 30-character rules.
      separator = '-' * 30
      click.echo(separator)
      click.echo('\n'.join(names))
      click.echo(separator)
    else:
      click.echo('No pipelines to display.')

  def delete_pipeline(self) -> None:
    """Deletes the saved folder of the pipeline named in the CLI flags."""
    name = self.flags_dict[labels.PIPELINE_NAME]
    info_dir = self._get_pipeline_info_path(name)

    # Make sure the pipeline is known before removing its folder.
    self._check_pipeline_existence(name)

    io_utils.delete_dir(info_dir)
    click.echo('Pipeline "{}" deleted successfully.'.format(name))

  def compile_pipeline(self) -> None:
    """Compiles a pipeline in Beam by executing its DSL with the patcher."""
    self.execute_dsl(self._get_dag_runner_patcher())
    click.echo('Pipeline compiled successfully.')

  def create_run(self) -> None:
    """Runs a saved pipeline by executing its DSL file in a subprocess."""
    name = self.flags_dict[labels.PIPELINE_NAME]

    # Only pipelines that were saved earlier can be run.
    self._check_pipeline_existence(name)

    args_path = self._get_pipeline_args_path(name)
    with open(args_path, 'r') as args_file:
      saved_args = json.load(args_file)

    # Launch the DSL file with the interpreter running this CLI.
    dsl_path = str(saved_args[labels.PIPELINE_DSL_PATH])
    self._subprocess_call([sys.executable, dsl_path])

  def delete_run(self) -> None:
    """Reports that deleting a run is not supported by this orchestrator."""
    engine = self.flags_dict[labels.ENGINE_FLAG]
    click.echo('Not supported for {} orchestrator.'.format(engine))

  def terminate_run(self) -> None:
    """Reports that stopping a run is not supported by this orchestrator."""
    engine = self.flags_dict[labels.ENGINE_FLAG]
    click.echo('Not supported for {} orchestrator.'.format(engine))

  def list_runs(self) -> None:
    """Reports that listing runs is not supported by this orchestrator."""
    engine = self.flags_dict[labels.ENGINE_FLAG]
    click.echo('Not supported for {} orchestrator.'.format(engine))

  def get_run(self) -> None:
    """Reports that checking a run is not supported by this orchestrator."""
    engine = self.flags_dict[labels.ENGINE_FLAG]
    click.echo('Not supported for {} orchestrator.'.format(engine))

  def _save_pipeline(self, pipeline_args: Dict[str, Any]) -> None:
    """Writes the pipeline args as JSON into a fresh pipeline folder.

    Args:
      pipeline_args: Pipeline arguments to persist. Must contain
        `labels.PIPELINE_NAME`. The dict is modified in place: the DSL path
        from the CLI flags is stored under `labels.PIPELINE_DSL_PATH`.
    """
    # Record where the DSL file lives so that later commands can find it.
    dsl_path = self.flags_dict[labels.PIPELINE_DSL_PATH]
    pipeline_args[labels.PIPELINE_DSL_PATH] = dsl_path

    name = pipeline_args[labels.PIPELINE_NAME]
    info_dir = self._get_pipeline_info_path(name)

    # An existing folder means an update: remove it and start clean.
    if fileio.exists(info_dir):
      io_utils.delete_dir(info_dir)
    fileio.makedirs(info_dir)

    args_path = self._get_pipeline_args_path(name)
    with open(args_path, 'w') as args_file:
      json.dump(pipeline_args, args_file)