def __call__()

in core/train_pipeline.py
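
Waits on the pipelined (asynchronous) input_dist request for this module,
transfers the result and the module context onto the current CUDA stream,
applies any configured feature processors, and then launches the module's
fused compute_and_output_dist.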


  def __call__(self, *input, **kwargs) -> Awaitable:
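    # The pipeline must already have kicked off the async input_dist for this
    # module; calling model.forward() directly (bypassing the pipeline) would
    # trip this assert.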
    assert self._name in self._context.input_dist_requests
    request = self._context.input_dist_requests[self._name]
    assert isinstance(request, Awaitable)
    with record_function("## wait_sparse_data_dist ##"):
      # Finish the wait() under the dist stream, so that any stream work
      # scheduled inside the wait() call lands on dist_stream rather than
      # on the current stream.
      with torch.cuda.stream(self._dist_stream):
        data = request.wait()

    # Make sure that both the input_dist result and the module context are
    # properly transferred to the current stream.
    if self._dist_stream is not None:
      torch.cuda.current_stream().wait_stream(self._dist_stream)
      cur_stream = torch.cuda.current_stream()

      assert isinstance(
        data, (torch.Tensor, Multistreamable)
      ), f"{type(data)} must be a Tensor or implement Multistreamable"
      # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
      data.record_stream(cur_stream)

      ctx = self._context.module_contexts[self._name]
      ctx.record_stream(cur_stream)

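    # Run any configured feature processors (e.g. per-feature weighting) over
    # the weighted (id_score_list) sparse features before the sharded compute.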
    if len(self._context.feature_processor_forwards) > 0:
      with record_function("## feature_processor ##"):
        for sparse_feature in data:
          if sparse_feature.id_score_list_features is not None:
            for fp_forward in self._context.feature_processor_forwards:
              sparse_feature.id_score_list_features = fp_forward(
                sparse_feature.id_score_list_features
              )

    return self._module.compute_and_output_dist(
      self._context.module_contexts[self._name], data
    )
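
The stream handling above is the standard CUDA multi-stream handoff: the
consumer stream calls wait_stream() on the producer stream for ordering, and
record_stream() on the tensor so the caching allocator keeps the memory alive.
A minimal, self-contained sketch of that pattern (the function name is
illustrative and assumes a CUDA device is available; nothing below is part of
train_pipeline.py):

  import torch

  def stream_handoff_sketch() -> torch.Tensor:
    # Producer: create data on a side stream (a stand-in for the
    # request.wait() on dist_stream above).
    dist_stream = torch.cuda.Stream()
    with torch.cuda.stream(dist_stream):
      data = torch.randn(1024, device="cuda")

    # Consumer: order the current stream after the producer stream ...
    cur_stream = torch.cuda.current_stream()
    cur_stream.wait_stream(dist_stream)
    # ... and mark `data` as in use on cur_stream, so the caching allocator
    # does not reuse its memory while cur_stream may still read it.
    data.record_stream(cur_stream)

    return data * 2  # now safe to consume on the current stream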