def infer_return_type_func()

in sdks/python/apache_beam/typehints/trivial_inference.py [0:0]


def infer_return_type_func(f, input_types, debug=False, depth=0):
  """Analyses a function to deduce its return type.

  The bytecode of ``f`` is walked instruction by instruction while a
  FrameState of abstract values is updated. States flowing along jump edges
  are merged (``|``) into the state recorded for the jump target, and a
  backward jump whose target state changed is re-followed a bounded number of
  times so loop bodies see the widened state.

  Args:
    f: A Python function object to infer the return type of.
    input_types: A sequence of type hints; they are assigned, in order, to the
      first entries of ``f.__code__.co_varnames`` (f's parameters). All other
      local variables start out as the empty ``Union[()]``.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference. Nested calls are
      only inferred recursively while ``depth > 0``; otherwise their result
      is ``Any``.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. If any yield is reached, the result is
    an Iterable over the union of the yielded types.

  Raises:
    TypeInferenceError: if no type can be inferred, e.g. an unsupported
      opcode is encountered or no return/yield instruction is reachable.
  """
  if debug:
    print()
    print(f, id(f), input_types)
    dis.dis(f)
  from . import opcodes
  simple_ops = {k.upper(): v for k, v in opcodes.__dict__.items()}

  co = f.__code__
  end = len(co.co_code)
  pc = 0
  free = None

  yields = set()
  returns = set()
  # TODO(robertwb): Default args via inspect module.
  local_vars = list(input_types) + [typehints.Union[()]] * (
      len(co.co_varnames) - len(input_types))
  state = FrameState(f, local_vars)
  # Offset -> merged state of all jumps targeting that offset (None if none).
  states = collections.defaultdict(lambda: None)
  # Jump-instruction offset -> how many times its backward edge was followed.
  jumps = collections.defaultdict(int)

  # In Python 3, use dis library functions to disassemble bytecode and handle
  # EXTENDED_ARGs.
  ofs_table = {}  # offset -> instruction
  for instruction in dis.get_instructions(f):
    ofs_table[instruction.offset] = instruction

  # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
  inst_size = 2
  opt_arg_size = 0

  last_pc = -1
  while pc < end:  # pylint: disable=too-many-nested-blocks
    start = pc
    instruction = ofs_table[pc]
    op = instruction.opcode
    # dis reports the fully resolved argument, or None for opcodes that take
    # none. Rebinding it every iteration keeps a previous instruction's
    # argument from leaking into an argument-less one.
    arg = instruction.arg
    if debug:
      print('-->' if pc == last_pc else '    ', end=' ')
      print(repr(pc).rjust(4), end=' ')
      print(dis.opname[op].ljust(20), end=' ')

    pc += inst_size
    if op >= dis.HAVE_ARGUMENT:
      pc += opt_arg_size
      if debug:
        print(str(arg).rjust(5), end=' ')
        if op in dis.hasconst:
          print('(' + repr(co.co_consts[arg]) + ')', end=' ')
        elif op in dis.hasname:
          print('(' + co.co_names[arg] + ')', end=' ')
        elif op in dis.hasjrel:
          print('(to ' + repr(pc + arg) + ')', end=' ')
        elif op in dis.haslocal:
          print('(' + co.co_varnames[arg] + ')', end=' ')
        elif op in dis.hascompare:
          print('(' + dis.cmp_op[arg] + ')', end=' ')
        elif op in dis.hasfree:
          if free is None:
            free = co.co_cellvars + co.co_freevars
          print('(' + free[arg] + ')', end=' ')

    # Actually emulate the op.
    if state is None and states[start] is None:
      # No control reaches here (yet).
      if debug:
        print()
      continue
    # Merge in whatever state earlier jumps recorded for this offset.
    state |= states[start]

    opname = dis.opname[op]
    jmp = jmp_state = None
    if opname.startswith('CALL_FUNCTION'):
      if opname == 'CALL_FUNCTION':
        # Stack: [..., callable, arg_1, ..., arg_n] with n == arg.
        pop_count = arg + 1
        if depth <= 0:
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          # Slice from an explicit start index: stack[1 - pop_count:] would
          # be the whole stack for a zero-argument call (pop_count == 1).
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[len(state.stack) - arg:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_KW':
        # TODO(udim): Handle keyword arguments. Requires passing them by name
        #   to infer_return_type.
        pop_count = arg + 2
        if isinstance(state.stack[-pop_count], Const):
          from apache_beam.pvalue import Row
          if state.stack[-pop_count].value == Row:
            fields = state.stack[-1].value
            return_type = row_type.RowTypeConstraint(
                zip(fields, Const.unwrap_all(state.stack[-pop_count + 1:-1])))
          else:
            return_type = Any
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_EX':
        # stack[-has_kwargs]: Map of keyword args.
        # stack[-1 - has_kwargs]: Iterable of positional args.
        # stack[-2 - has_kwargs]: Function to call.
        has_kwargs = arg & 1  # type: int
        pop_count = has_kwargs + 2
        if has_kwargs:
          # TODO(udim): Unimplemented. Requires same functionality as a
          #   CALL_FUNCTION_KW implementation.
          return_type = Any
        else:
          args = state.stack[-1]
          _callable = state.stack[-2]
          if depth <= 0 or not isinstance(_callable, Const):
            # Same policy as CALL_FUNCTION: only recurse into statically
            # known callables, and only while inspection depth remains.
            return_type = Any
          else:
            if isinstance(args, typehints.ListConstraint):
              # Case where there's a single var_arg argument.
              args = [args]
            elif isinstance(args, typehints.TupleConstraint):
              args = list(args._inner_types())
            return_type = infer_return_type(
                _callable.value, args, debug=debug, depth=depth - 1)
      else:
        raise TypeInferenceError('unable to handle %s' % opname)
      state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL_METHOD':
      # Stack: [..., method, arg_1, ..., arg_n] with n == arg.
      pop_count = 1 + arg
      # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
      if isinstance(state.stack[-pop_count], Const) and depth > 0:
        return_type = infer_return_type(
            state.stack[-pop_count].value,
            state.stack[len(state.stack) - arg:],
            debug=debug,
            depth=depth - 1)
      else:
        return_type = typehints.Any
      state.stack[-pop_count:] = [return_type]
    elif opname in simple_ops:
      if debug:
        print("Executing simple op " + opname)
      simple_ops[opname](state, arg)
    elif opname == 'RETURN_VALUE':
      returns.add(state.stack[-1])
      state = None
    elif opname == 'YIELD_VALUE':
      yields.add(state.stack[-1])
    elif opname == 'JUMP_FORWARD':
      jmp = pc + arg
      jmp_state = state
      state = None
    elif opname == 'JUMP_ABSOLUTE':
      jmp = arg
      jmp_state = state
      state = None
    elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
      state.stack.pop()
      jmp = arg
      jmp_state = state.copy()
    elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
      # The condition stays on the stack only along the jump edge.
      jmp = arg
      jmp_state = state.copy()
      state.stack.pop()
    elif opname == 'FOR_ITER':
      # Jump edge: iterator exhausted and popped. Fall-through: next element.
      jmp = pc + arg
      jmp_state = state.copy()
      jmp_state.stack.pop()
      state.stack.append(element_type(state.stack[-1]))
    else:
      raise TypeInferenceError('unable to handle %s' % opname)

    if jmp is not None:
      # TODO(robertwb): Is this guaranteed to converge?
      new_state = states[jmp] | jmp_state
      # Re-follow a backward jump (at most 5 times per jump instruction) when
      # it widened the target's state, so the loop body is re-analysed.
      if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
        jumps[pc] += 1
        pc = jmp
      states[jmp] = new_state

    if debug:
      print()
      print(state)
      pprint.pprint(dict(item for item in states.items() if item[1]))

  if not yields and not returns:
    # reduce() below would otherwise fail with a bare TypeError.
    raise TypeInferenceError('no return or yield reached in %s' % (f, ))
  if yields:
    result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
  else:
    result = reduce(union, Const.unwrap_all(returns))
  finalize_hints(result)

  if debug:
    print(f, id(f), input_types, '->', result)
  return result