def expand_sdf()

in sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py


def expand_sdf(stages, context):
  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]

  """Transforms splitable DoFns into pair+split+read."""
  for stage in stages:
    transform = only_transform(stage.transforms)
    if transform.spec.urn == common_urns.primitives.PAR_DO.urn:

      pardo_payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)

      if pardo_payload.restriction_coder_id:
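        # A restriction coder is only present for splittable DoFns; expand
        # this ParDo into PairWithRestriction -> SplitAndSizeRestrictions ->
        # (optional Reshuffle) -> (Truncate when draining) -> Process stages.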

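        # Copies an existing proto (a transform or PCollection, given either
        # as an object or as its key in `protos`) under a fresh unique id,
        # then applies the given field overrides; dict and list fields are
        # replaced, while 'urn' and 'payload' target the FunctionSpec.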
        def copy_like(protos, original, suffix='_copy', **kwargs):
          if isinstance(original, str):
            key = original
            original = protos[original]
          else:
            key = 'component'
          new_id = unique_name(protos, key + suffix)
          protos[new_id].CopyFrom(original)
          proto = protos[new_id]
          for name, value in kwargs.items():
            if isinstance(value, dict):
              getattr(proto, name).clear()
              getattr(proto, name).update(value)
            elif isinstance(value, list):
              del getattr(proto, name)[:]
              getattr(proto, name).extend(value)
            elif name == 'urn':
              proto.spec.urn = value
            elif name == 'payload':
              proto.spec.payload = value
            else:
              setattr(proto, name, value)
          if 'unique_name' not in kwargs and hasattr(proto, 'unique_name'):
            proto.unique_name = unique_name(
                {p.unique_name
                 for p in protos.values()},
                original.unique_name + suffix)
          return new_id

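        # Wraps a single rewritten transform in a new Stage that inherits the
        # original stage's downstream side inputs, must-follow constraints
        # and environment.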
        def make_stage(base_stage, transform_id, extra_must_follow=()):
          # type: (Stage, str, Iterable[Stage]) -> Stage
          transform = context.components.transforms[transform_id]
          return Stage(
              transform.unique_name, [transform],
              base_stage.downstream_side_inputs,
              union(base_stage.must_follow, frozenset(extra_must_follow)),
              parent=base_stage.name,
              environment=base_stage.environment)

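        # The main input is the only input tag that is not a side input.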
        main_input_tag = only_element(
            tag for tag in transform.inputs.keys()
            if tag not in pardo_payload.side_inputs)
        main_input_id = transform.inputs[main_input_tag]
        element_coder_id = context.components.pcollections[
            main_input_id].coder_id
        # Tuple[element, restriction]
        paired_coder_id = context.add_or_get_coder_id(
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=common_urns.coders.KV.urn),
                component_coder_ids=[
                    element_coder_id, pardo_payload.restriction_coder_id
                ]))
        # Tuple[Tuple[element, restriction], double]
        sized_coder_id = context.add_or_get_coder_id(
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=common_urns.coders.KV.urn),
                component_coder_ids=[
                    paired_coder_id,
                    context.add_or_get_coder_id(
                        # context can be None here only because FloatCoder does
                        # not have components
                        coders.FloatCoder().to_runner_api(None),  # type: ignore
                        'doubles_coder')
                ]))

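        # Each intermediate PCollection is a copy of the main input (same
        # windowing, etc.) with only the coder swapped out.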
        paired_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_paired',
            coder_id=paired_coder_id)
        pair_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/PairWithRestriction',
            urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
            outputs={'out': paired_pcoll_id})

        split_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_split',
            coder_id=sized_coder_id)
        split_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/SplitAndSizeRestriction',
            urn=common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
            inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
            outputs={'out': split_pcoll_id})

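        # If the runner understands Reshuffle natively, insert one after the
        # split so that the sized restrictions can be redistributed across
        # workers; otherwise feed the split output directly downstream.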
        reshuffle_stage = None
        if common_urns.composites.RESHUFFLE.urn in context.known_runner_urns:
          reshuffle_pcoll_id = copy_like(
              context.components.pcollections,
              main_input_id,
              '_reshuffle',
              coder_id=sized_coder_id)
          reshuffle_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/Reshuffle',
              urn=common_urns.composites.RESHUFFLE.urn,
              payload=b'',
              inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}),
              outputs={'out': reshuffle_pcoll_id})
          reshuffle_stage = make_stage(stage, reshuffle_transform_id)
        else:
          reshuffle_pcoll_id = split_pcoll_id
          reshuffle_transform_id = None

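        # When draining, insert a TruncateAndSizeRestriction step ahead of
        # processing so that restrictions can be truncated to a finite amount
        # of remaining work.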
        if context.is_drain:
          truncate_pcoll_id = copy_like(
              context.components.pcollections,
              main_input_id,
              '_truncate_restriction',
              coder_id=sized_coder_id)
          # Length-prefix the truncate output.
          context.length_prefix_pcoll_coders(truncate_pcoll_id)
          truncate_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/TruncateAndSizeRestriction',
              urn=common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
              inputs=dict(
                  transform.inputs, **{main_input_tag: reshuffle_pcoll_id}),
              outputs={'out': truncate_pcoll_id})
          process_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/Process',
              urn=common_urns.sdf_components.
              PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
              inputs=dict(
                  transform.inputs, **{main_input_tag: truncate_pcoll_id}))
        else:
          process_transform_id = copy_like(
              context.components.transforms,
              transform,
              unique_name=transform.unique_name + '/Process',
              urn=common_urns.sdf_components.
              PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
              inputs=dict(
                  transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))

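        # Emit the replacement stages. The truncate (when draining) or process
        # stage is marked as having to follow the split stage, which prevents
        # it from being fused into the same stage as the split.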
        yield make_stage(stage, pair_transform_id)
        split_stage = make_stage(stage, split_transform_id)
        yield split_stage
        if reshuffle_stage:
          yield reshuffle_stage
        if context.is_drain:
          yield make_stage(
              stage, truncate_transform_id, extra_must_follow=[split_stage])
          yield make_stage(stage, process_transform_id)
        else:
          yield make_stage(
              stage, process_transform_id, extra_must_follow=[split_stage])

      else:
        yield stage

    else:
      yield stage
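
Usage sketch (not from the source file): expand_sdf is applied as one of the
stage-rewriting phases during pipeline optimization, each phase consuming and
yielding Stage objects while sharing a single TransformContext.

    # Hypothetical driver code, assuming `stages` and `context` were built by
    # the surrounding optimizer in translations.py.
    stages = list(expand_sdf(stages, context))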