def generate_models()

in tfx/benchmarks/datasets/chicago_taxi/dataset.py [0:0]


  def generate_models(self, args, force_tf_compat_v1=True):
    """Runs a reduced Chicago Taxi pipeline and exports the resulting models.

    Modified version of the Chicago Taxi Example pipeline
    (tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py). The pipeline
    (ImportExampleGen, StatisticsGen, SchemaGen, Transform, Trainer) is run
    with BeamDagRunner inside a temporary directory. Afterwards the trainer's
    serving model, the trainer's eval model and the transform function are
    copied to `self.trained_saved_model_path()`,
    `self.tfma_saved_model_path()` and
    `self.tft_saved_model_path(force_tf_compat_v1)` respectively, replacing
    whatever was previously stored at those locations. The temporary
    directory is removed before returning, whether or not the run succeeded.

    Args:
      args: Unused by this implementation.
      force_tf_compat_v1: Forwarded to the Transform component and used to
        select the destination path of the transform saved model.

    Raises:
      ValueError: If one of the pipeline output directories does not contain
        exactly one entry.
    """

    def join_unique_subdir(path):
      """Returns the path of the single entry inside `path`."""
      dirs = os.listdir(path)
      if len(dirs) != 1:
        raise ValueError(
            "expecting there to be only one subdirectory in %s, but "
            "subdirectories were: %s" % (path, dirs))
      return os.path.join(path, dirs[0])

    # mkdtemp() leaves deletion to the caller, so everything that touches
    # `root` lives inside the try block and the directory is removed in the
    # finally clause below.
    root = tempfile.mkdtemp()
    try:
      pipeline_root = os.path.join(root, "pipeline")
      metadata_path = os.path.join(root, "metadata/metadata.db")
      module_file = os.path.join(
          os.path.dirname(__file__),
          "../../../examples/chicago_taxi_pipeline/taxi_utils.py")

      example_gen = components.ImportExampleGen(
          input_base=os.path.dirname(self.dataset_path()))
      statistics_gen = components.StatisticsGen(
          examples=example_gen.outputs["examples"])
      schema_gen = components.SchemaGen(
          statistics=statistics_gen.outputs["statistics"],
          infer_feature_shape=False)
      transform = components.Transform(
          examples=example_gen.outputs["examples"],
          schema=schema_gen.outputs["schema"],
          module_file=module_file,
          force_tf_compat_v1=force_tf_compat_v1)
      trainer = components.Trainer(
          module_file=module_file,
          transformed_examples=transform.outputs["transformed_examples"],
          schema=schema_gen.outputs["schema"],
          transform_graph=transform.outputs["transform_graph"],
          train_args=trainer_pb2.TrainArgs(num_steps=100),
          eval_args=trainer_pb2.EvalArgs(num_steps=50))
      metadata_config = metadata.sqlite_metadata_connection_config(
          metadata_path)
      p = pipeline.Pipeline(
          pipeline_name="chicago_taxi_beam",
          pipeline_root=pipeline_root,
          components=[
              example_gen, statistics_gen, schema_gen, transform, trainer
          ],
          enable_cache=True,
          metadata_connection_config=metadata_config)
      BeamDagRunner().run(p)

      # Each component output directory is expected to hold exactly one
      # entry; join_unique_subdir raises if that is not the case.
      trainer_output_dir = join_unique_subdir(
          os.path.join(pipeline_root, "Trainer/model"))
      eval_model_dir = join_unique_subdir(
          os.path.join(trainer_output_dir, "eval_model_dir"))
      serving_model_dir = join_unique_subdir(
          os.path.join(trainer_output_dir,
                       "serving_model_dir/export/chicago-taxi"))
      transform_output_dir = join_unique_subdir(
          os.path.join(pipeline_root, "Transform/transform_graph"))
      transform_model_dir = os.path.join(transform_output_dir, "transform_fn")
      tft_saved_model_path = self.tft_saved_model_path(force_tf_compat_v1)

      # copytree() requires that the destination does not exist yet, so any
      # previously exported models are removed first.
      shutil.rmtree(self.trained_saved_model_path(), ignore_errors=True)
      shutil.rmtree(self.tfma_saved_model_path(), ignore_errors=True)
      shutil.rmtree(tft_saved_model_path, ignore_errors=True)
      shutil.copytree(serving_model_dir, self.trained_saved_model_path())
      shutil.copytree(eval_model_dir, self.tfma_saved_model_path())
      shutil.copytree(transform_model_dir, tft_saved_model_path)
    finally:
      # Best-effort cleanup: a failure to delete the scratch directory must
      # not mask the outcome of the pipeline run itself.
      shutil.rmtree(root, ignore_errors=True)