# spotify_tensorflow/tfx/tft.py
#
# NOTE: the import block below was reconstructed for this excerpt. The
# apache_beam and tensorflow_transform paths match the pre-1.0 tf.transform
# API used in this module; the spotify_tensorflow helper locations are assumed.
import getpass
import os
import time
import warnings
from typing import Any, List, Union  # noqa: F401

import apache_beam as beam
import tensorflow_transform.beam.impl as beam_impl
from apache_beam.io import tfrecordio
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.runners.runner import PipelineState  # noqa: F401
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.coders.example_proto_coder import ExampleProtoCoder
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema

from spotify_tensorflow.tf_schema_utils import schema_txt_file_to_feature_spec  # assumed path
from spotify_tensorflow.tfx.utils import assert_not_empty_string, assert_not_none  # assumed path
def tftransform(pipeline_args, # type: List[str]
temp_location, # type: str
schema_file, # type: str
output_dir, # type: str
preprocessing_fn, # type: Any
training_data=None, # type: Union[None, str]
evaluation_data=None, # type: Union[None, str]
transform_fn_dir=None, # type: Union[None, str]
                compression_type=None  # type: Union[None, str]
): # type: (...) -> PipelineState
"""
Generic tf.transform pipeline that takes tf.{example, record} training and evaluation
datasets and outputs transformed data together with transform function Saved Model.
:param pipeline_args: un-parsed Dataflow arguments
:param temp_location: temporary location for dataflow job working dir
:param schema_file: path to the raw feature schema text file
:param output_dir: output dir for transformed data and function
:param preprocessing_fn: tf.transform preprocessing function
:param training_data: path to the training data
:param evaluation_data: path to the evaluation data
:param transform_fn_dir: dir to previously saved transformation function to apply
:param compression_type: compression type for writing of tf.records
:return final state of the Beam pipeline
"""
assert_not_empty_string(temp_location)
assert_not_empty_string(schema_file)
assert_not_empty_string(output_dir)
assert_not_none(preprocessing_fn)
if compression_type is None:
compression_type = CompressionTypes.AUTO
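    # Parse the schema text file into a feature spec and wrap it in the
    # metadata and proto coder objects the TFRecord sources below use to
    # decode raw tf.Example records.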
raw_feature_spec = schema_txt_file_to_feature_spec(schema_file)
raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)
raw_data_coder = ExampleProtoCoder(raw_data_metadata.schema)
transformed_train_output_dir = os.path.join(output_dir, "training")
transformed_eval_output_dir = os.path.join(output_dir, "evaluation")
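    # Give the Dataflow job a unique, human-readable name unless the caller
    # already passed one via --job_name.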
if not any(i.startswith("--job_name") for i in pipeline_args):
pipeline_args.append("--job_name=tf-transform-{}-{}".format(getpass.getuser(),
int(time.time())))
pipeline = beam.Pipeline(argv=pipeline_args)
with beam_impl.Context(temp_dir=temp_location):
if training_data is not None:
# if training data is provided, transform_fn_dir will be ignored
if transform_fn_dir is not None:
warnings.warn("Transform_fn_dir is ignored because training_data is provided")
transform_fn_output = os.path.join(output_dir, "transform_fn", "saved_model.pb")
if FileSystems.exists(transform_fn_output):
raise ValueError("Transform fn already exists at %s!" % transform_fn_output)
# compute the transform_fn and apply to the training data
raw_train_data = (
pipeline
| "ReadTrainData" >> tfrecordio.ReadFromTFRecord(training_data,
coder=raw_data_coder))
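            # AnalyzeAndTransformDataset runs the full-pass analysis phase
            # (computing any vocabularies, means, etc. that preprocessing_fn
            # requires) and applies the resulting transform to the same data.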
            ((transformed_train_data, transformed_train_metadata), transform_fn) = (
                (raw_train_data, raw_data_metadata)
                | "AnalyzeAndTransformTrainData" >> beam_impl.AnalyzeAndTransformDataset(
                    preprocessing_fn))
_ = ( # noqa: F841
transform_fn
| "WriteTransformFn" >>
transform_fn_io.WriteTransformFn(output_dir))
transformed_train_coder = ExampleProtoCoder(transformed_train_metadata.schema)
            _ = (  # noqa: F841
                transformed_train_data
                | "WriteTransformedTrainData" >> tfrecordio.WriteToTFRecord(
                    os.path.join(transformed_train_output_dir, "part"),
                    coder=transformed_train_coder,
                    compression_type=compression_type,
                    file_name_suffix=".tfrecords"))
else:
if transform_fn_dir is None:
raise ValueError("Either training_data or transformed_fn needs to be provided")
# load the transform_fn
transform_fn = pipeline | transform_fn_io.ReadTransformFn(transform_fn_dir)
if evaluation_data is not None:
# if evaluation_data exists, apply the transform_fn to the evaluation data
raw_eval_data = (
pipeline
| "ReadEvalData" >> tfrecordio.ReadFromTFRecord(evaluation_data,
coder=raw_data_coder))
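            # TransformDataset only applies the transform function here; the
            # evaluation data is never analyzed, so it is transformed with the
            # statistics computed from training and cannot leak into them.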
(transformed_eval_data, transformed_eval_metadata) = (
((raw_eval_data, raw_data_metadata), transform_fn)
| "TransformEvalData" >> beam_impl.TransformDataset())
transformed_eval_coder = ExampleProtoCoder(transformed_eval_metadata.schema)
            _ = (  # noqa: F841
                transformed_eval_data
                | "WriteTransformedEvalData" >> tfrecordio.WriteToTFRecord(
                    os.path.join(transformed_eval_output_dir, "part"),
                    coder=transformed_eval_coder,
                    compression_type=compression_type,
                    file_name_suffix=".tfrecords"))
result = pipeline.run().wait_until_finish()
return result
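

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one plausible way to drive
# tftransform locally with Beam's DirectRunner. Every path, the feature name
# "age", and example_preprocessing_fn are hypothetical placeholders; only the
# tftransform signature above and tft.scale_to_z_score are real APIs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tensorflow_transform as tft

    def example_preprocessing_fn(inputs):
        # Scale one numeric feature to z-scores and pass everything else through.
        outputs = dict(inputs)
        outputs["age_scaled"] = tft.scale_to_z_score(inputs["age"])
        return outputs

    tftransform(pipeline_args=["--runner=DirectRunner"],
                temp_location="/tmp/tft-temp",
                schema_file="/tmp/schema.pbtxt",
                output_dir="/tmp/tft-output",
                preprocessing_fn=example_preprocessing_fn,
                training_data="/tmp/train/part-*",
                evaluation_data="/tmp/eval/part-*")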