# sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
def pack_per_key_combiners(stages, context, can_pack=lambda s: True):
# type: (Iterable[Stage], TransformContext, Callable[[str], bool]) -> Iterator[Stage]
"""Packs sibling CombinePerKey stages into a single CombinePerKey.
If CombinePerKey stages have a common input, one input each, and one output
each, pack the stages into a single stage that runs all CombinePerKeys and
outputs resulting tuples to a new PCollection. A subsequent stage unpacks
tuples from this PCollection and sends them to the original output
PCollections.
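Example (hypothetical transforms, for illustration only): two sibling
stages CombinePerKey(sum) and CombinePerKey(max) reading the same
(key, value) PCollection become one packed CombinePerKey emitting
(key, (sum_result, max_result)), followed by an Unpack ParDo that re-emits
(key, sum_result) and (key, max_result) to the original output PCollections.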
"""
class _UnpackFn(core.DoFn):
"""A DoFn that unpacks a packed to multiple tagged outputs.
Example:
tags = (T1, T2, ...)
input = (K, (V1, V2, ...))
output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V2)), ...
"""
def __init__(self, tags):
self._tags = tags
def process(self, element):
key, values = element
return [
core.pvalue.TaggedOutput(tag, (key, value)) for tag,
value in zip(self._tags, values)
]
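# Returns the id of the coder registered for plain 'object', used as a
# generic fallback when a KV component coder cannot be determined below.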
def _get_fallback_coder_id():
return context.add_or_get_coder_id(
# passing None works here because there are no component coders
coders.registry.get_coder(object).to_runner_api(None)) # type: ignore[arg-type]
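# Returns the key (index 0) or value (index 1) component coder id of a KV
# coder, or the generic fallback coder id if the given coder is not a
# two-component KV coder.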
def _get_component_coder_id_from_kv_coder(coder, index):
assert index < 2
if coder.spec.urn == common_urns.coders.KV.urn and len(
coder.component_coder_ids) == 2:
return coder.component_coder_ids[index]
return _get_fallback_coder_id()
def _get_key_coder_id_from_kv_coder(coder):
return _get_component_coder_id_from_kv_coder(coder, 0)
def _get_value_coder_id_from_kv_coder(coder):
return _get_component_coder_id_from_kv_coder(coder, 1)
def _try_fuse_stages(a, b):
if a.can_fuse(b, context):
return a.fuse(b, context)
else:
raise ValueError
# Partition stages by whether they are eligible for CombinePerKey packing
# and group eligible CombinePerKey stages by parent and environment.
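# Eligible stages are keyed by (input PCollection id, environment id), so
# only siblings that read the same input in the same environment (one that
# advertises the PACKED_COMBINE_FN capability) are grouped together;
# ineligible stages map to None and are passed through unchanged.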
def get_stage_key(stage):
if (len(stage.transforms) == 1 and can_pack(stage.name) and
stage.environment is not None and python_urns.PACKED_COMBINE_FN in
context.components.environments[stage.environment].capabilities):
transform = only_transform(stage.transforms)
if (transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and
len(transform.inputs) == 1 and len(transform.outputs) == 1):
combine_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.CombinePayload)
if combine_payload.combine_fn.urn == python_urns.PICKLED_COMBINE_FN:
return (only_element(transform.inputs.values()), stage.environment)
return None
grouped_eligible_stages, ineligible_stages = _group_stages_by_key(
stages, get_stage_key)
for stage in ineligible_stages:
yield stage
for stage_key, packable_stages in grouped_eligible_stages.items():
input_pcoll_id, _ = stage_key
try:
if len(packable_stages) < 2:
raise ValueError('Only one stage in this group: Skipping stage packing')
# Fused stage is used as template and is not yielded.
fused_stage = functools.reduce(_try_fuse_stages, packable_stages)
except ValueError:
# Skip packing stages in this group.
# Yield the stages unmodified, and then continue to the next group.
for stage in packable_stages:
yield stage
continue
transforms = [only_transform(stage.transforms) for stage in packable_stages]
combine_payloads = [
proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.CombinePayload)
for transform in transforms
]
output_pcoll_ids = [
only_element(transform.outputs.values()) for transform in transforms
]
# Build accumulator coder for (acc1, acc2, ...)
accumulator_coder_ids = [
combine_payload.accumulator_coder_id
for combine_payload in combine_payloads
]
tuple_accumulator_coder_id = context.add_or_get_coder_id(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=python_urns.TUPLE_CODER),
component_coder_ids=accumulator_coder_ids))
# Build packed output coder for (key, (out1, out2, ...))
input_kv_coder_id = context.components.pcollections[input_pcoll_id].coder_id
key_coder_id = _get_key_coder_id_from_kv_coder(
context.components.coders[input_kv_coder_id])
output_kv_coder_ids = [
context.components.pcollections[output_pcoll_id].coder_id
for output_pcoll_id in output_pcoll_ids
]
output_value_coder_ids = [
_get_value_coder_id_from_kv_coder(
context.components.coders[output_kv_coder_id])
for output_kv_coder_id in output_kv_coder_ids
]
pack_output_value_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=python_urns.TUPLE_CODER),
component_coder_ids=output_value_coder_ids)
pack_output_value_coder_id = context.add_or_get_coder_id(
pack_output_value_coder)
pack_output_kv_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
component_coder_ids=[key_coder_id, pack_output_value_coder_id])
pack_output_kv_coder_id = context.add_or_get_coder_id(pack_output_kv_coder)
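# _make_pack_name below combines the sibling stage/transform names into a
# single packed name (roughly of the form 'Packed[..., ...]'; see that
# helper for the exact format) used for the new Pack and Unpack stages.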
pack_stage_name = _make_pack_name([stage.name for stage in packable_stages])
pack_transform_name = _make_pack_name([
only_transform(stage.transforms).unique_name
for stage in packable_stages
])
pack_pcoll_id = unique_name(context.components.pcollections, 'pcollection')
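# The new packed PCollection holds (key, (out1, out2, ...)) elements and
# inherits its windowing strategy and boundedness from the common input.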
input_pcoll = context.components.pcollections[input_pcoll_id]
context.components.pcollections[pack_pcoll_id].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=pack_transform_name + '/Pack.out',
coder_id=pack_output_kv_coder_id,
windowing_strategy_id=input_pcoll.windowing_strategy_id,
is_bounded=input_pcoll.is_bounded))
# Set up Pack stage.
# TODO(BEAM-7746): classes that inherit from RunnerApiFn are expected to
# accept a PipelineContext for from_runner_api/to_runner_api. Determine
# how to accommodate this.
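# SingleInputTupleCombineFn applies each of the original CombineFns to the
# same input value and produces one output per CombineFn, matching the
# (out1, out2, ...) tuple shape unpacked below.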
pack_combine_fn = combiners.SingleInputTupleCombineFn(
*[
core.CombineFn.from_runner_api(combine_payload.combine_fn, context) # type: ignore[arg-type]
for combine_payload in combine_payloads
]).to_runner_api(context) # type: ignore[arg-type]
pack_transform = beam_runner_api_pb2.PTransform(
unique_name=pack_transform_name + '/Pack',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.composites.COMBINE_PER_KEY.urn,
payload=beam_runner_api_pb2.CombinePayload(
combine_fn=pack_combine_fn,
accumulator_coder_id=tuple_accumulator_coder_id).
SerializeToString()),
inputs={'in': input_pcoll_id},
# 'None' single output key follows convention for CombinePerKey.
outputs={'None': pack_pcoll_id},
environment_id=fused_stage.environment)
pack_stage = Stage(
pack_stage_name + '/Pack', [pack_transform],
downstream_side_inputs=fused_stage.downstream_side_inputs,
must_follow=fused_stage.must_follow,
parent=fused_stage.parent,
environment=fused_stage.environment)
yield pack_stage
# Set up Unpack stage.
tags = [str(i) for i in range(len(output_pcoll_ids))]
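# Output tags are the stringified indices of the original outputs. The
# pickled payload below appears to follow the (fn, args, kwargs,
# side-input tags_and_types, windowing) layout expected for
# PICKLED_DOFN_INFO; only the DoFn itself is needed here.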
pickled_do_fn_data = pickler.dumps((_UnpackFn(tags), (), {}, [], None))
unpack_transform = beam_runner_api_pb2.PTransform(
unique_name=pack_transform_name + '/Unpack',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.primitives.PAR_DO.urn,
payload=beam_runner_api_pb2.ParDoPayload(
do_fn=beam_runner_api_pb2.FunctionSpec(
urn=python_urns.PICKLED_DOFN_INFO,
payload=pickled_do_fn_data)).SerializeToString()),
inputs={'in': pack_pcoll_id},
outputs=dict(zip(tags, output_pcoll_ids)),
environment_id=fused_stage.environment)
unpack_stage = Stage(
pack_stage_name + '/Unpack', [unpack_transform],
downstream_side_inputs=fused_stage.downstream_side_inputs,
must_follow=fused_stage.must_follow,
parent=fused_stage.parent,
environment=fused_stage.environment)
yield unpack_stage