in sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py [0:0]
def lift_combiners(stages, context):
  # type: (List[Stage], TransformContext) -> Iterator[Stage]
  """Expands CombinePerKey into pre- and post-grouping stages.

  ... -> CombinePerKey -> ...

  becomes

  ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...

  A CombinePerKey is only lifted when the trigger of its input's windowing
  strategy does not need to observe every individual element. Otherwise the
  composite is replaced by one stage per original subtransform.

  Args:
    stages: the stages to rewrite. Each stage is passed through
      `only_transform`, so it is expected to hold a single transform.
    context: the TransformContext whose components (pcollections, coders,
      windowing strategies, transforms) are read. For lifted combines, new
      intermediate PCollections and coders are added to it.

  Yields:
    The input stages in order, with every CombinePerKey stage replaced by
    its expansion.
  """
  def is_compatible_with_combiner_lifting(trigger):
    # type: (beam_runner_api_pb2.Trigger) -> bool
    '''Returns whether this trigger is compatible with combiner lifting.

    Certain triggers, such as those that fire after a certain number of
    elements, need to observe every element, and as such are incompatible
    with combiner lifting (which may aggregate several elements into one
    before they reach the triggering code after shuffle).

    Triggers not recognized below are conservatively treated as
    incompatible.
    '''
    if trigger is None:
      return True
    elif trigger.WhichOneof('trigger') in (
        'default',
        'always',
        'never',
        'after_processing_time',
        'after_synchronized_processing_time'):
      return True
    elif trigger.HasField('element_count'):
      # AfterCount(1) fires on the first element it sees, pre-combined or
      # not. A larger count may be reached later (or never) once several
      # elements have been merged into one accumulator.
      return trigger.element_count.element_count == 1
    elif trigger.HasField('after_end_of_window'):
      # Unset proto3 sub-message fields read back as empty Trigger messages
      # (not None). Such an empty message would fall through to the final
      # `return False`. Only inspect early/late firings that are actually
      # set, so a plain AfterWatermark trigger can still be lifted.
      after_end_of_window = trigger.after_end_of_window
      early_ok = (
          not after_end_of_window.HasField('early_firings') or
          is_compatible_with_combiner_lifting(
              after_end_of_window.early_firings))
      late_ok = (
          not after_end_of_window.HasField('late_firings') or
          is_compatible_with_combiner_lifting(
              after_end_of_window.late_firings))
      return early_ok and late_ok
    elif trigger.HasField('after_any'):
      return all(
          is_compatible_with_combiner_lifting(t)
          for t in trigger.after_any.subtriggers)
    elif trigger.HasField('repeat'):
      return is_compatible_with_combiner_lifting(trigger.repeat.subtrigger)
    else:
      return False

  def can_lift(combine_per_key_transform):
    """Returns whether a CombinePerKey transform may be lifted.

    The decision is based on the trigger of the windowing strategy of the
    transform's single input PCollection.
    """
    windowing = context.components.windowing_strategies[
        context.components.pcollections[only_element(
            list(combine_per_key_transform.inputs.values())
        )].windowing_strategy_id]
    return is_compatible_with_combiner_lifting(windowing.trigger)

  def make_stage(base_stage, transform):
    # type: (Stage, beam_runner_api_pb2.PTransform) -> Stage
    """Wraps `transform` in a new Stage derived from `base_stage`.

    The new stage is named after the transform. It inherits the base
    stage's downstream side inputs, must-follow set and environment, and
    records the base stage's name as its parent.
    """
    return Stage(
        transform.unique_name, [transform],
        downstream_side_inputs=base_stage.downstream_side_inputs,
        must_follow=base_stage.must_follow,
        parent=base_stage.name,
        environment=base_stage.environment)

  def lifted_stages(stage):
    """Yields the Precombine, Group, Merge and ExtractOutputs stages.

    Also registers the coders and intermediate PCollections those stages
    need in `context.components`.
    """
    transform = stage.transforms[0]
    combine_payload = proto_utils.parse_Bytes(
        transform.spec.payload, beam_runner_api_pb2.CombinePayload)

    input_pcoll = context.components.pcollections[only_element(
        list(transform.inputs.values()))]
    output_pcoll = context.components.pcollections[only_element(
        list(transform.outputs.values()))]

    # The input elements are KVs; the first component coder is the key's.
    element_coder_id = input_pcoll.coder_id
    element_coder = context.components.coders[element_coder_id]
    key_coder_id, _ = element_coder.component_coder_ids
    accumulator_coder_id = combine_payload.accumulator_coder_id

    # KV<key, accumulator>: the output of Precombine and of Merge.
    key_accumulator_coder = beam_runner_api_pb2.Coder(
        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
        component_coder_ids=[key_coder_id, accumulator_coder_id])
    key_accumulator_coder_id = context.add_or_get_coder_id(
        key_accumulator_coder)

    # Iterable<accumulator>, then KV<key, Iterable<accumulator>>: the output
    # of the GroupByKey.
    accumulator_iter_coder = beam_runner_api_pb2.Coder(
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.coders.ITERABLE.urn),
        component_coder_ids=[accumulator_coder_id])
    accumulator_iter_coder_id = context.add_or_get_coder_id(
        accumulator_iter_coder)

    key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
        component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
    key_accumulator_iter_coder_id = context.add_or_get_coder_id(
        key_accumulator_iter_coder)

    # The pre-grouping PCollection keeps the input's windowing. The
    # post-grouping PCollections take the windowing and boundedness of the
    # original output.
    precombined_pcoll_id = unique_name(
        context.components.pcollections, 'pcollection')
    context.components.pcollections[precombined_pcoll_id].CopyFrom(
        beam_runner_api_pb2.PCollection(
            unique_name=transform.unique_name + '/Precombine.out',
            coder_id=key_accumulator_coder_id,
            windowing_strategy_id=input_pcoll.windowing_strategy_id,
            is_bounded=input_pcoll.is_bounded))

    grouped_pcoll_id = unique_name(
        context.components.pcollections, 'pcollection')
    context.components.pcollections[grouped_pcoll_id].CopyFrom(
        beam_runner_api_pb2.PCollection(
            unique_name=transform.unique_name + '/Group.out',
            coder_id=key_accumulator_iter_coder_id,
            windowing_strategy_id=output_pcoll.windowing_strategy_id,
            is_bounded=output_pcoll.is_bounded))

    merged_pcoll_id = unique_name(
        context.components.pcollections, 'pcollection')
    context.components.pcollections[merged_pcoll_id].CopyFrom(
        beam_runner_api_pb2.PCollection(
            unique_name=transform.unique_name + '/Merge.out',
            coder_id=key_accumulator_coder_id,
            windowing_strategy_id=output_pcoll.windowing_strategy_id,
            is_bounded=output_pcoll.is_bounded))

    yield make_stage(
        stage,
        beam_runner_api_pb2.PTransform(
            unique_name=transform.unique_name + '/Precombine',
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.
                urn,
                payload=transform.spec.payload),
            inputs=transform.inputs,
            outputs={'out': precombined_pcoll_id},
            environment_id=transform.environment_id))

    # No environment_id is set here: GroupByKey is a runner primitive rather
    # than SDK-executed code.
    yield make_stage(
        stage,
        beam_runner_api_pb2.PTransform(
            unique_name=transform.unique_name + '/Group',
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.primitives.GROUP_BY_KEY.urn),
            inputs={'in': precombined_pcoll_id},
            outputs={'out': grouped_pcoll_id}))

    yield make_stage(
        stage,
        beam_runner_api_pb2.PTransform(
            unique_name=transform.unique_name + '/Merge',
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.combine_components.
                COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
                payload=transform.spec.payload),
            inputs={'in': grouped_pcoll_id},
            outputs={'out': merged_pcoll_id},
            environment_id=transform.environment_id))

    yield make_stage(
        stage,
        beam_runner_api_pb2.PTransform(
            unique_name=transform.unique_name + '/ExtractOutputs',
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.combine_components.
                COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
                payload=transform.spec.payload),
            inputs={'in': merged_pcoll_id},
            outputs=transform.outputs,
            environment_id=transform.environment_id))

  def unlifted_stages(stage):
    """Replaces the composite with one stage per original subtransform."""
    transform = stage.transforms[0]
    for sub in transform.subtransforms:
      yield make_stage(stage, context.components.transforms[sub])

  for stage in stages:
    transform = only_transform(stage.transforms)
    if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
      expansion = lifted_stages if can_lift(transform) else unlifted_stages
      for substage in expansion(stage):
        yield substage
    else:
      yield stage