in core/train_pipeline.py [0:0]
def __call__(self, *input, **kwargs) -> Awaitable:
    """Finish this module's pending input distribution, then run compute/output dist.

    Looks up the awaitable input-dist request registered under ``self._name``
    in the shared pipeline context and waits on it under ``self._dist_stream``.
    When a dist stream is configured, it orders the current CUDA stream after
    it and records the current stream on both the distributed data and the
    module context. It then optionally runs the context's feature-processor
    forwards over the distributed features. Finally it hands everything to
    ``self._module.compute_and_output_dist``.

    ``*input`` and ``**kwargs`` are accepted so this object can stand in for
    the wrapped module's forward. They are not read: the real inputs come from
    the pipeline context.

    Returns:
        Whatever ``self._module.compute_and_output_dist`` returns (annotated
        as an ``Awaitable``).
    """
    name = self._name
    pipeline_ctx = self._context

    assert name in pipeline_ctx.input_dist_requests
    pending = pipeline_ctx.input_dist_requests[name]
    assert isinstance(pending, Awaitable)

    with record_function("## wait_sparse_data_dist ##"):
        # Resolve the request while the dist stream is active, so any stream
        # work that wait() schedules lazily is issued on that stream rather
        # than on the current one.
        with torch.cuda.stream(self._dist_stream):
            data = pending.wait()

    dist_stream = self._dist_stream
    if dist_stream is not None:
        # Hand the dist output and the module context over to the current
        # stream: first order the current stream after the dist stream, then
        # record the current stream on both objects so their memory is
        # treated as in use there (torch.Tensor.record_stream semantics;
        # Multistreamable presumably mirrors this — confirm).
        consumer_stream = torch.cuda.current_stream()
        consumer_stream.wait_stream(dist_stream)
        assert isinstance(
            data, (torch.Tensor, Multistreamable)
        ), f"{type(data)} must implement Multistreamable interface"
        # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
        data.record_stream(consumer_stream)
        pipeline_ctx.module_contexts[name].record_stream(consumer_stream)

    fp_forwards = pipeline_ctx.feature_processor_forwards
    if fp_forwards:
        with record_function("## feature_processor ##"):
            # Each element of `data` is expected to expose
            # `id_score_list_features`. Entries where it is None are skipped.
            # Otherwise every feature-processor forward is applied in order,
            # each one feeding its output to the next.
            for features in data:
                if features.id_score_list_features is None:
                    continue
                for fp_forward in fp_forwards:
                    features.id_score_list_features = fp_forward(
                        features.id_score_list_features
                    )

    return self._module.compute_and_output_dist(
        pipeline_ctx.module_contexts[name], data
    )