in tfx/benchmarks/datasets/chicago_taxi/dataset.py [0:0]
def generate_models(self, args, force_tf_compat_v1=True):
  """Trains the Chicago Taxi models and copies them into this dataset's dirs.

  Runs a modified version of the Chicago Taxi Example pipeline
  (tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py) with
  BeamDagRunner inside a scratch directory. The pipeline consists of
  ImportExampleGen -> StatisticsGen -> SchemaGen -> Transform -> Trainer.
  It then copies these outputs, replacing anything already at the
  destination:

    * the Trainer serving model -> self.trained_saved_model_path()
    * the Trainer eval model -> self.tfma_saved_model_path()
    * the Transform transform_fn -> self.tft_saved_model_path(
      force_tf_compat_v1)

  The scratch directory is deleted before returning, whether or not the
  run succeeds.

  Args:
    args: Not used by this method.
    force_tf_compat_v1: Forwarded to the Transform component. Also selects
      the destination path for the transform_fn SavedModel.

  Raises:
    ValueError: If an output directory expected to contain exactly one
      entry contains zero or several.
  """

  def join_unique_subdir(path):
    """Returns `path` joined with its only entry; raises ValueError otherwise."""
    dirs = os.listdir(path)
    if len(dirs) != 1:
      raise ValueError(
          "expecting there to be only one subdirectory in %s, but "
          "subdirectories were: %s" % (path, dirs))
    return os.path.join(path, dirs[0])

  # Scratch space for the pipeline root and the ML-metadata SQLite DB.
  # mkdtemp() leaves deletion to the caller, so it is removed in the
  # `finally` below once the models have been copied out. Previously each
  # call leaked a full pipeline run in the temp directory.
  root = tempfile.mkdtemp()
  try:
    pipeline_root = os.path.join(root, "pipeline")
    metadata_path = os.path.join(root, "metadata/metadata.db")
    # User module (preprocessing_fn / trainer code) shared with the example.
    module_file = os.path.join(
        os.path.dirname(__file__),
        "../../../examples/chicago_taxi_pipeline/taxi_utils.py")

    example_gen = components.ImportExampleGen(
        input_base=os.path.dirname(self.dataset_path()))
    statistics_gen = components.StatisticsGen(
        examples=example_gen.outputs["examples"])
    schema_gen = components.SchemaGen(
        statistics=statistics_gen.outputs["statistics"],
        infer_feature_shape=False)
    transform = components.Transform(
        examples=example_gen.outputs["examples"],
        schema=schema_gen.outputs["schema"],
        module_file=module_file,
        force_tf_compat_v1=force_tf_compat_v1)
    trainer = components.Trainer(
        module_file=module_file,
        transformed_examples=transform.outputs["transformed_examples"],
        schema=schema_gen.outputs["schema"],
        transform_graph=transform.outputs["transform_graph"],
        train_args=trainer_pb2.TrainArgs(num_steps=100),
        eval_args=trainer_pb2.EvalArgs(num_steps=50))

    p = pipeline.Pipeline(
        pipeline_name="chicago_taxi_beam",
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, schema_gen, transform, trainer
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path))
    BeamDagRunner().run(p)

    # Each output location below is expected to hold exactly one
    # subdirectory (presumably the one produced by this single run in a
    # fresh root); join_unique_subdir resolves it or raises.
    trainer_output_dir = join_unique_subdir(
        os.path.join(pipeline_root, "Trainer/model"))
    eval_model_dir = join_unique_subdir(
        os.path.join(trainer_output_dir, "eval_model_dir"))
    serving_model_dir = join_unique_subdir(
        os.path.join(trainer_output_dir,
                     "serving_model_dir/export/chicago-taxi"))
    transform_output_dir = join_unique_subdir(
        os.path.join(pipeline_root, "Transform/transform_graph"))
    transform_model_dir = os.path.join(transform_output_dir, "transform_fn")

    # Clear the destinations first: shutil.copytree fails if the target
    # directory already exists.
    tft_saved_model_path = self.tft_saved_model_path(force_tf_compat_v1)
    shutil.rmtree(self.trained_saved_model_path(), ignore_errors=True)
    shutil.rmtree(self.tfma_saved_model_path(), ignore_errors=True)
    shutil.rmtree(tft_saved_model_path, ignore_errors=True)
    shutil.copytree(serving_model_dir, self.trained_saved_model_path())
    shutil.copytree(eval_model_dir, self.tfma_saved_model_path())
    shutil.copytree(transform_model_dir, tft_saved_model_path)
  finally:
    # Best-effort cleanup: a failure to delete scratch files must not mask
    # the result (or the original exception) of the run.
    shutil.rmtree(root, ignore_errors=True)