in tfx/benchmarks/tfma_benchmark_base.py [0:0]
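# NOTE: this excerpt assumes the file's module-level imports, including
# time, tensorflow_model_analysis as tfma, and benchmark_utils.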
def benchmarkAggregateCombineManualActuation(self):
"""Benchmark the aggregate combine stage "manually".
Runs _AggregateCombineFn "manually" outside a Beam pipeline. Records the
wall time taken.
"""
# Run InputsToExtracts manually: wrap each serialized record in an
# extracts dict keyed by tfma.constants.INPUT_KEY.
records = []
for x in self._dataset.read_raw_dataset(
deserialize=False, limit=self._max_num_examples()):
records.append({tfma.constants.INPUT_KEY: x})
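# Run the prediction stage manually, mirroring the work the legacy
# PredictExtractor would perform inside a Beam pipeline.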
fn = tfma.extractors.legacy_predict_extractor._TFMAPredictionDoFn( # pylint: disable=protected-access
eval_shared_models={"": tfma.default_eval_shared_model(
eval_saved_model_path=self._dataset.tfma_saved_model_path())},
eval_config=None)
fn.setup()
# Predict
predict_batch_size = 1000
predict_result = []
for batch in benchmark_utils.batched_iterator(records, predict_batch_size):
predict_result.extend(fn.process(batch))
# AggregateCombineFn
#
# We simulate accumulating records into multiple different accumulators,
# each with inputs_per_accumulator records, and then merging the resulting
# accumulators in one go.
#
# Number of elements to feed into a single accumulator. (This means we
# will have len(records) / inputs_per_accumulator accumulators to merge.)
inputs_per_accumulator = 1000
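# For example, 1,000,000 input records with 1000 inputs per accumulator
# yields 1000 partial accumulators to merge.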
combiner = tfma.evaluators.legacy_aggregate._AggregateCombineFn( # pylint: disable=protected-access
eval_shared_model=tfma.default_eval_shared_model(
eval_saved_model_path=self._dataset.tfma_saved_model_path()))
combiner.setup()
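# Time only the accumulate / merge / extract_output phases below;
# combiner construction and setup above are excluded from the measurement.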
accumulators = []
start = time.time()
for batch in benchmark_utils.batched_iterator(predict_result,
inputs_per_accumulator):
accumulator = combiner.create_accumulator()
for elem in batch:
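# Per the Beam CombineFn contract, add_input returns the updated
# accumulator rather than guaranteeing in-place mutation, so keep the
# return value.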
accumulator = combiner.add_input(accumulator, elem)
accumulators.append(accumulator)
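# Merge all partial accumulators in a single call, then extract the
# combined metrics from the merged result.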
final_accumulator = combiner.merge_accumulators(accumulators)
final_output = combiner.extract_output(final_accumulator)
end = time.time()
delta = end - start
# Extract output to sanity check example count. This is not timed.
extract_fn = tfma.evaluators.legacy_aggregate._ExtractOutputDoFn( # pylint: disable=protected-access
eval_shared_model=tfma.default_eval_shared_model(
eval_saved_model_path=self._dataset.tfma_saved_model_path()))
extract_fn.setup()
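# The empty tuple is the overall (empty) slice key paired with the combined
# metrics, matching the (slice_key, metrics) elements _ExtractOutputDoFn
# processes.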
interpreted_output = list(extract_fn.process(((), final_output)))
if len(interpreted_output) != 1:
  raise ValueError("expected exactly 1 interpreted output, got %d" %
                   len(interpreted_output))
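# "post_export_metrics/example_count" is the key under which TFMA's
# post-export example count metric is reported.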
got_example_count = interpreted_output[0][1].get(
    "post_export_metrics/example_count")
expected_example_count = self._dataset.num_examples(
    limit=self._max_num_examples())
if got_example_count != expected_example_count:
  raise ValueError("example count mismatch: expected %d, got %d" %
                   (expected_example_count, got_example_count))
self.report_benchmark(
    iters=1,
    wall_time=delta,
    extras={
        "inputs_per_accumulator": inputs_per_accumulator,
        "num_examples": expected_example_count,
    })