def benchmarkAggregateCombineManualActuation()

in tfx/benchmarks/tfma_benchmark_base.py [0:0]


  def benchmarkAggregateCombineManualActuation(self):
    """Benchmark the aggregate combine stage "manually".

    Runs _AggregateCombineFn "manually" outside a Beam pipeline. The raw
    examples are first run through _TFMAPredictionDoFn (untimed). The
    predictions are then fed into several accumulators, which are merged and
    passed to extract_output. Only that accumulate / merge / extract_output
    phase is timed. Afterwards the example count is sanity-checked against
    the dataset (untimed), and the wall time is reported via
    report_benchmark.

    Raises:
      ValueError: if _ExtractOutputDoFn does not yield exactly one output, or
        if the reported example count does not match the dataset's.
    """

    def _fresh_eval_shared_model():
      # Builds a new EvalSharedModel for the dataset's TFMA saved model on
      # every call; each stage below gets its own instance.
      return tfma.default_eval_shared_model(
          eval_saved_model_path=self._dataset.tfma_saved_model_path())

    # Run InputsToExtracts manually.
    records = [
        {tfma.constants.INPUT_KEY: x}
        for x in self._dataset.read_raw_dataset(
            deserialize=False, limit=self._max_num_examples())
    ]

    fn = tfma.extractors.legacy_predict_extractor._TFMAPredictionDoFn(  # pylint: disable=protected-access
        eval_shared_models={"": _fresh_eval_shared_model()},
        eval_config=None)
    fn.setup()

    # Predict (not timed).
    predict_batch_size = 1000
    predict_result = []
    for batch in benchmark_utils.batched_iterator(records, predict_batch_size):
      predict_result.extend(fn.process(batch))

    # AggregateCombineFn
    #
    # We simulate accumulating records into multiple different accumulators,
    # each with up to inputs_per_accumulator records, and then merging the
    # resulting accumulators together at one go.

    # Number of elements to feed into a single accumulator.
    # (This means we will have ceil(len(predict_result) /
    # inputs_per_accumulator) accumulators to merge).
    inputs_per_accumulator = 1000

    combiner = tfma.evaluators.legacy_aggregate._AggregateCombineFn(  # pylint: disable=protected-access
        eval_shared_model=_fresh_eval_shared_model())
    combiner.setup()
    accumulators = []

    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted during the measurement.
    start = time.perf_counter()
    for batch in benchmark_utils.batched_iterator(predict_result,
                                                  inputs_per_accumulator):
      accumulator = combiner.create_accumulator()
      for elem in batch:
        # Per the Beam CombineFn contract, add_input returns the updated
        # accumulator; rebinding keeps this correct even if the combiner does
        # not mutate the accumulator in place.
        accumulator = combiner.add_input(accumulator, elem)
      accumulators.append(accumulator)
    final_accumulator = combiner.merge_accumulators(accumulators)
    final_output = combiner.extract_output(final_accumulator)
    delta = time.perf_counter() - start

    # Extract output to sanity check example count. This is not timed.
    extract_fn = tfma.evaluators.legacy_aggregate._ExtractOutputDoFn(  # pylint: disable=protected-access
        eval_shared_model=_fresh_eval_shared_model())
    extract_fn.setup()
    interpreted_output = list(extract_fn.process(((), final_output)))
    if len(interpreted_output) != 1:
      raise ValueError("expecting exactly 1 interpreted output, got %d" %
                       len(interpreted_output))
    got_example_count = interpreted_output[0][1].get(
        "post_export_metrics/example_count")

    # Computed once and reused for both the check and the report.
    expected_num_examples = self._dataset.num_examples(
        limit=self._max_num_examples())
    if got_example_count != expected_num_examples:
      # %s (not %d) for the observed value: .get() returns None when the
      # metric is missing, and "%d" % None would raise TypeError, masking
      # this error.
      raise ValueError("example count mismatch: expecting %d got %s" %
                       (expected_num_examples, got_example_count))

    self.report_benchmark(
        iters=1,
        wall_time=delta,
        extras={
            "inputs_per_accumulator": inputs_per_accumulator,
            "num_examples": expected_num_examples,
        })