in sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py [0:0]
def expand_sdf(stages, context):
  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]

  """Transforms splittable DoFns into pair+split+read.

  For each single-transform stage whose transform is a ParDo carrying a
  splittable DoFn (detected via a non-empty ``restriction_coder_id`` on the
  ParDo payload), replaces that stage with a chain of component stages:

      PairWithRestriction -> SplitAndSizeRestrictions -> [Reshuffle]
        -> [TruncateAndSizeRestriction, drain mode only]
        -> ProcessSizedElementsAndRestrictions

  The Reshuffle stage is inserted only when the runner declares support for
  the RESHUFFLE composite urn.  All newly created transforms and
  PCollections are registered in ``context.components`` as a side effect.
  Stages that are not splittable ParDos are yielded unchanged.
  """
  for stage in stages:
    transform = only_transform(stage.transforms)
    if transform.spec.urn == common_urns.primitives.PAR_DO.urn:

      pardo_payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)

      if pardo_payload.restriction_coder_id:

        def copy_like(protos, original, suffix='_copy', **kwargs):
          # Registers a copy of `original` in the proto map `protos` under a
          # fresh, collision-free id and returns that id.  `original` may be
          # given either as a key into `protos` or as the proto itself.
          # Keyword overrides are applied to the copy; dict- and list-valued
          # overrides target map/repeated proto fields, while 'urn' and
          # 'payload' are routed into the nested `spec` message.
          if isinstance(original, str):
            key = original
            original = protos[original]
          else:
            key = 'component'
          new_id = unique_name(protos, key + suffix)
          protos[new_id].CopyFrom(original)
          proto = protos[new_id]
          for name, value in kwargs.items():
            if isinstance(value, dict):
              # Map-valued proto fields cannot be assigned; clear then update.
              getattr(proto, name).clear()
              getattr(proto, name).update(value)
            elif isinstance(value, list):
              # Repeated proto fields cannot be assigned; clear then extend.
              del getattr(proto, name)[:]
              getattr(proto, name).extend(value)
            elif name == 'urn':
              proto.spec.urn = value
            elif name == 'payload':
              proto.spec.payload = value
            else:
              setattr(proto, name, value)
          if 'unique_name' not in kwargs and hasattr(proto, 'unique_name'):
            # No explicit name was given: derive one from the original's
            # unique_name, avoiding clashes with every existing proto's name.
            proto.unique_name = unique_name(
                {p.unique_name
                 for p in protos.values()},
                original.unique_name + suffix)
          return new_id

        def make_stage(base_stage, transform_id, extra_must_follow=()):
          # type: (Stage, str, Iterable[Stage]) -> Stage
          # Wraps a single registered transform in a new Stage, inheriting
          # the side inputs, ordering constraints (must_follow) and
          # environment of `base_stage`.
          transform = context.components.transforms[transform_id]
          return Stage(
              transform.unique_name, [transform],
              base_stage.downstream_side_inputs,
              union(base_stage.must_follow, frozenset(extra_must_follow)),
              parent=base_stage.name,
              environment=base_stage.environment)

        # The main input is the sole input tag that is not a side input.
        main_input_tag = only_element(
            tag for tag in transform.inputs.keys()
            if tag not in pardo_payload.side_inputs)
        main_input_id = transform.inputs[main_input_tag]
        element_coder_id = context.components.pcollections[
            main_input_id].coder_id
        # Tuple[element, restriction]
        paired_coder_id = context.add_or_get_coder_id(
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=common_urns.coders.KV.urn),
                component_coder_ids=[
                    element_coder_id, pardo_payload.restriction_coder_id
                ]))
        # Tuple[Tuple[element, restriction], double]
        sized_coder_id = context.add_or_get_coder_id(
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=common_urns.coders.KV.urn),
                component_coder_ids=[
                    paired_coder_id,
                    context.add_or_get_coder_id(
                        # context can be None here only because FloatCoder does
                        # not have components
                        coders.FloatCoder().to_runner_api(None),  # type: ignore
                        'doubles_coder')
                ]))

        # Stage 1: pair each element with its initial restriction.
        paired_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_paired',
            coder_id=paired_coder_id)
        pair_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/PairWithRestriction',
            urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
            outputs={'out': paired_pcoll_id})

        # Stage 2: initial split and sizing of each (element, restriction).
        split_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_split',
            coder_id=sized_coder_id)
        split_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/SplitAndSizeRestriction',
            urn=common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
            inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
            outputs={'out': split_pcoll_id})

        # Optional stage: a runner-executed Reshuffle between splitting and
        # processing, only when the runner understands the RESHUFFLE urn.
        reshuffle_stage = None
        if common_urns.composites.RESHUFFLE.urn in context.known_runner_urns:
          reshuffle_pcoll_id = copy_like(
              context.components.pcollections,
              main_input_id,
              '_reshuffle',
              coder_id=sized_coder_id)
          reshuffle_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/Reshuffle',
              urn=common_urns.composites.RESHUFFLE.urn,
              payload=b'',
              inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}),
              outputs={'out': reshuffle_pcoll_id})
          reshuffle_stage = make_stage(stage, reshuffle_transform_id)
        else:
          # No reshuffle: processing consumes the split output directly.
          reshuffle_pcoll_id = split_pcoll_id
          reshuffle_transform_id = None

        if context.is_drain:
          # Drain mode: insert a TruncateAndSizeRestriction stage ahead of
          # processing.
          truncate_pcoll_id = copy_like(
              context.components.pcollections,
              main_input_id,
              '_truncate_restriction',
              coder_id=sized_coder_id)
          # Length-prefix the truncate output.
          context.length_prefix_pcoll_coders(truncate_pcoll_id)
          truncate_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/TruncateAndSizeRestriction',
              urn=common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
              inputs=dict(
                  transform.inputs, **{main_input_tag: reshuffle_pcoll_id}),
              outputs={'out': truncate_pcoll_id})
          process_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/Process',
              urn=common_urns.sdf_components.
              PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
              inputs=dict(
                  transform.inputs, **{main_input_tag: truncate_pcoll_id}))
        else:
          process_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/Process',
              urn=common_urns.sdf_components.
              PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
              inputs=dict(
                  transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))

        # Emit the component stages in pipeline order.  extra_must_follow on
        # the stage consuming the split output forces it to run after the
        # split stage itself.
        yield make_stage(stage, pair_transform_id)
        split_stage = make_stage(stage, split_transform_id)
        yield split_stage
        if reshuffle_stage:
          yield reshuffle_stage
        if context.is_drain:
          yield make_stage(
              stage, truncate_transform_id, extra_must_follow=[split_stage])
          yield make_stage(stage, process_transform_id)
        else:
          yield make_stage(
              stage, process_transform_id, extra_must_follow=[split_stage])

      else:
        # ParDo without a restriction coder: not a splittable DoFn.
        yield stage

    else:
      # Not a ParDo stage; pass through unchanged.
      yield stage