def tftransform()

in spotify_tensorflow/tfx/tft.py

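The excerpt relies on module-level imports that are not shown. A plausible set is sketched
below, assuming Apache Beam and a pre-1.0 tensorflow_transform release; exact module paths
vary between versions, and the spotify_tensorflow helper locations are guesses.

# Assumed imports for this excerpt; module paths are best guesses, not taken from the source.
import getpass
import os
import time
import warnings
from typing import Any, List, Union

import apache_beam as beam
from apache_beam.io import tfrecordio
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.runners.runner import PipelineState

from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.coders import ExampleProtoCoder
from tensorflow_transform.tf_metadata import dataset_metadata, dataset_schema

# Package-internal helpers; these module paths are assumptions.
from spotify_tensorflow.tf_schema_utils import schema_txt_file_to_feature_spec
from spotify_tensorflow.tfx.utils import assert_not_empty_string, assert_not_none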

def tftransform(pipeline_args,                          # type: List[str]
                temp_location,                          # type: str
                schema_file,                            # type: str
                output_dir,                             # type: str
                preprocessing_fn,                       # type: Any
                training_data=None,                     # type: Union[None, str]
                evaluation_data=None,                   # type: Union[None, str]
                transform_fn_dir=None,                  # type: Union[None, str]
                compression_type=None                   # type: Union[None, str]
                ):  # type: (...) -> PipelineState
    """
    Generic tf.Transform pipeline that reads training and evaluation datasets of tf.Example
    records from TFRecord files and writes out the transformed data together with the
    transform function as a SavedModel.

    :param pipeline_args: un-parsed Dataflow arguments
    :param temp_location: temporary location for the Dataflow job working directory
    :param schema_file: path to the raw feature schema text file
    :param output_dir: output directory for the transformed data and the transform function
    :param preprocessing_fn: tf.Transform preprocessing function
    :param training_data: path to the training data
    :param evaluation_data: path to the evaluation data
    :param transform_fn_dir: directory of a previously saved transform function to apply
    :param compression_type: compression type to use when writing the TFRecord output
    :return: final state of the Beam pipeline
    """
    assert_not_empty_string(temp_location)
    assert_not_empty_string(schema_file)
    assert_not_empty_string(output_dir)
    assert_not_none(preprocessing_fn)

    if compression_type is None:
        compression_type = CompressionTypes.AUTO

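    # parse the schema text file into a feature spec and build a coder
    # for decoding the raw tf.Example records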
    raw_feature_spec = schema_txt_file_to_feature_spec(schema_file)
    raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
    raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)
    raw_data_coder = ExampleProtoCoder(raw_data_metadata.schema)

    transformed_train_output_dir = os.path.join(output_dir, "training")
    transformed_eval_output_dir = os.path.join(output_dir, "evaluation")

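    # give the job a unique, per-user name unless the caller provided one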
    if not any(i.startswith("--job_name") for i in pipeline_args):
        pipeline_args.append("--job_name=tf-transform-{}-{}".format(getpass.getuser(),
                                                                    int(time.time())))

    pipeline = beam.Pipeline(argv=pipeline_args)
    with beam_impl.Context(temp_dir=temp_location):
        if training_data is not None:
            # if training data is provided, transform_fn_dir will be ignored
            if transform_fn_dir is not None:
                warnings.warn("transform_fn_dir is ignored because training_data is provided")

            transform_fn_output = os.path.join(output_dir, "transform_fn", "saved_model.pb")
            if FileSystems.exists(transform_fn_output):
                raise ValueError("Transform fn already exists at %s!" % transform_fn_output)

            # compute the transform_fn and apply to the training data
            raw_train_data = (
                    pipeline
                    | "ReadTrainData" >> tfrecordio.ReadFromTFRecord(training_data,
                                                                     coder=raw_data_coder))

            ((transformed_train_data, transformed_train_metadata), transform_fn) = (
                    (raw_train_data, raw_data_metadata)
                    | "AnalyzeAndTransformTrainData" >> beam_impl.AnalyzeAndTransformDataset(
                        preprocessing_fn))

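            # persist the transform function under output_dir; the check above
            # guards against overwriting an existing one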
            _ = (   # noqa: F841
                    transform_fn
                    | "WriteTransformFn" >>
                    transform_fn_io.WriteTransformFn(output_dir))

            transformed_train_coder = ExampleProtoCoder(transformed_train_metadata.schema)
            _ = (   # noqa: F841
                    transformed_train_data
                    | "WriteTransformedTrainData" >> tfrecordio.WriteToTFRecord(
                        os.path.join(transformed_train_output_dir, "part"),
                        coder=transformed_train_coder,
                        compression_type=compression_type,
                        file_name_suffix=".tfrecords"))
        else:
            if transform_fn_dir is None:
                raise ValueError("Either training_data or transform_fn_dir needs to be provided")
            # load the transform_fn
            transform_fn = pipeline | transform_fn_io.ReadTransformFn(transform_fn_dir)

        if evaluation_data is not None:
            # if evaluation_data exists, apply the transform_fn to the evaluation data
            raw_eval_data = (
                    pipeline
                    | "ReadEvalData" >> tfrecordio.ReadFromTFRecord(evaluation_data,
                                                                    coder=raw_data_coder))

            (transformed_eval_data, transformed_eval_metadata) = (
                    ((raw_eval_data, raw_data_metadata), transform_fn)
                    | "TransformEvalData" >> beam_impl.TransformDataset())

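            # encode and write out the transformed evaluation data as TFRecords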
            transformed_eval_coder = ExampleProtoCoder(transformed_eval_metadata.schema)
            _ = (   # noqa: F841
                    transformed_eval_data
                    | "WriteTransformedEvalData" >> tfrecordio.WriteToTFRecord(
                        os.path.join(transformed_eval_output_dir, "part"),
                        coder=transformed_eval_coder,
                        compression_type=compression_type,
                        file_name_suffix=".tfrecords"))
    result = pipeline.run().wait_until_finish()

    return result
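
A minimal calling sketch for illustration (not from the source): the GCS paths, the
DirectRunner flag, and the preprocessing function with its "f1" feature are placeholders,
and tft.scale_to_z_score merely stands in for whatever transformations a real pipeline
would apply.

# Hypothetical invocation; every path, flag and feature name below is a placeholder.
import tensorflow_transform as tft

def my_preprocessing_fn(inputs):
    # Z-score a single numeric feature and pass the rest through unchanged.
    outputs = dict(inputs)
    outputs["f1"] = tft.scale_to_z_score(inputs["f1"])
    return outputs

state = tftransform(
    pipeline_args=["--runner=DirectRunner"],
    temp_location="gs://my-bucket/tmp",
    schema_file="gs://my-bucket/schema.pbtxt",
    output_dir="gs://my-bucket/output",
    preprocessing_fn=my_preprocessing_fn,
    training_data="gs://my-bucket/train/part-*",
    evaluation_data="gs://my-bucket/eval/part-*")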