def generate_urn_files()

in sdks/python/gen_protos.py [0:0]


def generate_urn_files(log, out_dir):
  """
  Create python files with statically defined URN constants.

  Creates a <proto>_pb2_urns.py file next to each <proto>_pb2.py file that
  contains at least one enum value carrying one of the recognized options
  (beam_urn, beam_constant, monitoring_info_spec, label_props).

  This works by importing each api.<proto>_pb2 module created by `protoc`,
  inspecting the module's contents, and generating a new side-car urn module.
  This is executed at build time rather than dynamically on import to ensure
  that it is compatible with static type checkers like mypy.

  Args:
    log: logger-like object; only its ``info(msg)`` method is used.
    out_dir: directory holding the ``*_pb2.py`` files generated by `protoc`.
      The modules are imported as ``api.<name>`` with the parent directory of
      ``out_dir`` temporarily placed at the front of ``sys.path``.

  Raises:
    IndexError: if ``out_dir`` contains no ``*_pb2.py`` files.
  """
  import google.protobuf.message as message
  import google.protobuf.pyext._message as pyext_message

  class Context(object):
    """Accumulates the source lines, imports and constants of one module."""
    INDENT = '  '
    # Splits a CamelCase (or leading lowercase) name into its words.
    CAP_SPLIT = re.compile('([A-Z][^A-Z]*|^[a-z]+)')

    def __init__(self, indent=0):
      self.lines = []
      self.imports = set()
      self.empty_types = set()
      self._indent = indent

    @contextlib.contextmanager
    def indent(self):
      """Increase the indent level by one for the duration of the block."""
      self._indent += 1
      try:
        yield
      finally:
        # Restore the level even if the body raises.
        self._indent -= 1

    def _format(self, s):
      """Return `s` as an indented, newline-terminated source line."""
      if s:
        return (self.INDENT * self._indent) + s + '\n'
      # Blank lines carry no indentation (avoids trailing whitespace).
      return '\n'

    def prepend(self, s):
      """Insert `s` as the first line, at the current indent level."""
      self.lines.insert(0, self._format(s))

    def line(self, s):
      """Append `s` as the last line, at the current indent level."""
      self.lines.append(self._format(s))

    def import_type(self, typ):
      """Return the source-level name of `typ`, recording its module import."""
      modname = typ.__module__
      # '__builtin__' is the python2 name, 'builtins' the python3 name.
      if modname in ('__builtin__', 'builtins'):
        return typ.__name__
      else:
        self.imports.add(modname)
        return modname + '.' + typ.__name__

    @staticmethod
    def is_message_type(obj):
      """Return True if `obj` is a class derived from protobuf's Message."""
      return isinstance(obj, type) and \
             issubclass(obj, message.Message)

    @staticmethod
    def is_enum_type(obj):
      """Return True if `obj` is a protobuf enum wrapper (matched by name)."""
      return type(obj).__name__ == 'EnumTypeWrapper'

    def python_repr(self, obj):
      """Return python source text representing `obj`.

      Messages are rendered via `message_repr`, lists and repeated protobuf
      containers as a bracketed list of their items, anything else via `repr`.
      """
      if isinstance(obj, message.Message):
        return self.message_repr(obj)
      elif isinstance(obj, (list,
                            pyext_message.RepeatedCompositeContainer,  # pylint: disable=c-extension-no-member
                            pyext_message.RepeatedScalarContainer)):  # pylint: disable=c-extension-no-member
        return '[%s]' % ', '.join(self.python_repr(x) for x in obj)
      else:
        return repr(obj)

    def empty_type(self, typ):
      """Register an ``EMPTY_<TYPE_NAME>`` constant for `typ`; return its name.

      The constant definition is stored in `self.empty_types` so that it can be
      emitted once at the top of the generated module.
      """
      name = ('EMPTY_' +
              '_'.join(x.upper()
                       for x in self.CAP_SPLIT.findall(typ.__name__)))
      self.empty_types.add('%s = %s()' % (name, self.import_type(typ)))
      return name

    def message_repr(self, msg):
      """Return a constructor expression for `msg`, or an EMPTY_* constant.

      Only the fields reported by `msg.ListFields()` are rendered; a message
      with no such fields is replaced by a shared empty constant.
      """
      parts = []
      for field, value in msg.ListFields():
        parts.append('%s=%s' % (field.name, self.python_repr(value)))
      if parts:
        return '%s(%s)' % (self.import_type(type(msg)), ', '.join(parts))
      else:
        return self.empty_type(type(msg))

    def write_enum(self, enum_name, enum, indent=0):
      """Return the source lines of a class mirroring the enum's properties.

      Enum values whose options are all empty are skipped; if every value is
      skipped the returned list is empty. Imports and empty-type constants
      needed by the rendered values are recorded on `self`.
      """
      ctx = Context(indent=indent)

      with ctx.indent():
        for v in enum.DESCRIPTOR.values:
          extensions = v.GetOptions().Extensions

          # beam_runner_api_pb2 and metrics_pb2 are bound by the enclosing
          # function before any write_* method is invoked.
          prop = (
              extensions[beam_runner_api_pb2.beam_urn],
              extensions[beam_runner_api_pb2.beam_constant],
              extensions[metrics_pb2.monitoring_info_spec],
              extensions[metrics_pb2.label_props],
          )
          reprs = [self.python_repr(x) for x in prop]
          if all(x == "''" or x.startswith('EMPTY_') for x in reprs):
            continue
          # Reuse the reprs computed above instead of rendering them again.
          ctx.line('%s = PropertiesFromEnumValue(%s)' %
                   (v.name, ', '.join(reprs)))

      if ctx.lines:
        ctx.prepend('class %s(object):' % enum_name)
        ctx.prepend('')
        ctx.line('')
      return ctx.lines

    def write_message(self, message_name, message_cls, indent=0):
      """Return the source lines of a class mirroring a message's enums.

      Recurses into nested message types; returns an empty list when neither
      the message nor any nested message produces enum output.
      """
      ctx = Context(indent=indent)

      with ctx.indent():
        for obj_name, obj in inspect.getmembers(message_cls):
          if self.is_message_type(obj):
            ctx.lines += self.write_message(obj_name, obj, ctx._indent)
          elif self.is_enum_type(obj):
            ctx.lines += self.write_enum(obj_name, obj, ctx._indent)

      if ctx.lines:
        ctx.prepend('class %s(object):' % message_name)
        ctx.prepend('')
      return ctx.lines

  # Sorted so that the files are processed in a deterministic order.
  pb2_files = sorted(glob.glob(os.path.join(out_dir, '*_pb2.py')))
  # NOTE: raises IndexError when no *_pb2.py file exists in out_dir.
  api_path = os.path.dirname(pb2_files[0])
  import_root = os.path.dirname(api_path)
  sys.path.insert(0, import_root)

  def _import(m):
    # TODO: replace with importlib when we drop support for python2.
    return __import__('api.%s' % m, fromlist=[None])

  try:
    beam_runner_api_pb2 = _import('beam_runner_api_pb2')
    metrics_pb2 = _import('metrics_pb2')

    for pb2_file in pb2_files:
      base_path = os.path.splitext(pb2_file)[0]
      out_file = base_path + '_urns.py'
      mod = _import(os.path.basename(base_path))

      ctx = Context()
      for obj_name, obj in inspect.getmembers(mod):
        if ctx.is_message_type(obj):
          ctx.lines += ctx.write_message(obj_name, obj)

      if ctx.lines:
        # Each header group is prepended in reverse order so that it ends up
        # sorted; the groups are prepended bottom-up (constants, module
        # imports, then the utils import as the very first line).
        for empty_def in sorted(ctx.empty_types, reverse=True):
          ctx.prepend(empty_def)

        for imported_mod in sorted(ctx.imports, reverse=True):
          ctx.prepend('from . import %s' % imported_mod)

        ctx.prepend('from ..utils import PropertiesFromEnumValue')

        log.info("Writing urn stubs: %s" % out_file)
        with open(out_file, 'w') as f:
          f.writelines(ctx.lines)

  finally:
    # Remove our entry by value: an imported module may have modified
    # sys.path, in which case index 0 is no longer the entry added above.
    if import_root in sys.path:
      sys.path.remove(import_root)