in sdks/python/apache_beam/typehints/trivial_inference.py [0:0]
def infer_return_type_func(f, input_types, debug=False, depth=0):
  """Analyses a function's bytecode to deduce its return type.

  The instructions of ``f`` are emulated one at a time against an abstract
  frame state. States flowing into a jump target are merged with ``|``, and a
  backward jump whose merged target state changed is re-walked (at most 5
  times per jump site) so that loops get a chance to settle.

  Args:
    f: A Python function object to infer the return type of.
    input_types: A sequence of inputs corresponding to the input types.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference. Callees are only
      analysed recursively while ``depth > 0``; otherwise the call is treated
      as returning Any.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. If the function yields, the result is
    an Iterable of the union of the yielded types; otherwise it is the union
    of the returned types.

  Raises:
    TypeInferenceError: if no type can be inferred.
  """
  if debug:
    print()
    print(f, id(f), input_types)
    dis.dis(f)
  from . import opcodes
  # Handlers for "simple" opcodes, keyed by upper-cased attribute name so they
  # can be looked up directly by dis opname.
  simple_ops = {k.upper(): v for k, v in opcodes.__dict__.items()}

  co = f.__code__
  end = len(co.co_code)
  pc = 0
  free = None

  yields = set()
  returns = set()
  # TODO(robertwb): Default args via inspect module.
  # Locals that are not parameters start out as the empty Union.
  local_vars = list(input_types) + [typehints.Union[()]] * (
      len(co.co_varnames) - len(input_types))
  state = FrameState(f, local_vars)
  # Pending state for each jump target offset (None until a jump reaches it).
  states = collections.defaultdict(lambda: None)
  # Number of times each backward jump site has been re-walked.
  jumps = collections.defaultdict(int)

  # In Python 3, use dis library functions to disassemble bytecode and handle
  # EXTENDED_ARGs.
  ofs_table = {  # offset -> instruction
      instruction.offset: instruction
      for instruction in dis.get_instructions(f)
  }

  # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
  inst_size = 2
  opt_arg_size = 0

  # NOTE(review): last_pc is never reassigned below, so the '-->' marker in
  # the debug trace is effectively dead.
  last_pc = -1
  # Bind arg up front so that an argument-less first instruction cannot hit an
  # unbound local when it is handed to a simple op handler. After that, arg
  # keeps the value of the most recent instruction that carried an argument.
  arg = None
  while pc < end:  # pylint: disable=too-many-nested-blocks
    start = pc
    instruction = ofs_table[pc]
    op = instruction.opcode
    if debug:
      print('-->' if pc == last_pc else ' ', end=' ')
      print(repr(pc).rjust(4), end=' ')
      print(dis.opname[op].ljust(20), end=' ')

    pc += inst_size
    if op >= dis.HAVE_ARGUMENT:
      arg = instruction.arg
      pc += opt_arg_size
      if debug:
        print(str(arg).rjust(5), end=' ')
        if op in dis.hasconst:
          print('(' + repr(co.co_consts[arg]) + ')', end=' ')
        elif op in dis.hasname:
          print('(' + co.co_names[arg] + ')', end=' ')
        elif op in dis.hasjrel:
          print('(to ' + repr(pc + arg) + ')', end=' ')
        elif op in dis.haslocal:
          print('(' + co.co_varnames[arg] + ')', end=' ')
        elif op in dis.hascompare:
          print('(' + dis.cmp_op[arg] + ')', end=' ')
        elif op in dis.hasfree:
          if free is None:
            free = co.co_cellvars + co.co_freevars
          print('(' + free[arg] + ')', end=' ')

    # Actually emulate the op.
    if state is None and states[start] is None:
      # No control reaches here (yet).
      if debug:
        print()
      continue
    # Merge in whatever state earlier jumps delivered to this offset.
    state |= states[start]

    opname = dis.opname[op]
    jmp = jmp_state = None
    if opname.startswith('CALL_FUNCTION'):
      if opname == 'CALL_FUNCTION':
        # Stack holds the callable followed by `arg` positional arguments.
        pop_count = arg + 1
        if depth <= 0:
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[1 - pop_count:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_KW':
        # TODO(udim): Handle keyword arguments. Requires passing them by name
        #   to infer_return_type.
        # Stack holds the callable, `arg` argument values and, on top, the
        # keyword names.
        pop_count = arg + 2
        if isinstance(state.stack[-pop_count], Const):
          from apache_beam.pvalue import Row
          if state.stack[-pop_count].value == Row:
            # Row(...) construction: pair field names with the value types.
            fields = state.stack[-1].value
            return_type = row_type.RowTypeConstraint(
                zip(fields, Const.unwrap_all(state.stack[-pop_count + 1:-1])))
          else:
            return_type = Any
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_EX':
        # stack[-has_kwargs]: Map of keyword args.
        # stack[-1 - has_kwargs]: Iterable of positional args.
        # stack[-2 - has_kwargs]: Function to call.
        has_kwargs = arg & 1  # type: int
        pop_count = has_kwargs + 2
        if has_kwargs:
          # TODO(udim): Unimplemented. Requires same functionality as a
          #   CALL_FUNCTION_KW implementation.
          return_type = Any
        else:
          args = state.stack[-1]
          _callable = state.stack[-2]
          if depth <= 0 or not isinstance(_callable, Const):
            # Honour the depth limit and only recurse into known callables,
            # mirroring the CALL_FUNCTION and CALL_METHOD handling.
            return_type = Any
          else:
            if isinstance(args, typehints.ListConstraint):
              # Case where there's a single var_arg argument.
              args = [args]
            elif isinstance(args, typehints.TupleConstraint):
              args = list(args._inner_types())
            return_type = infer_return_type(
                _callable.value, args, debug=debug, depth=depth - 1)
      else:
        raise TypeInferenceError('unable to handle %s' % opname)
      # Replace the callable and its arguments with the call's result type.
      state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL_METHOD':
      pop_count = 1 + arg
      # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
      if isinstance(state.stack[-pop_count], Const) and depth > 0:
        return_type = infer_return_type(
            state.stack[-pop_count].value,
            state.stack[1 - pop_count:],
            debug=debug,
            depth=depth - 1)
      else:
        return_type = typehints.Any
      state.stack[-pop_count:] = [return_type]
    elif opname in simple_ops:
      if debug:
        print("Executing simple op " + opname)
      simple_ops[opname](state, arg)
    elif opname == 'RETURN_VALUE':
      returns.add(state.stack[-1])
      # Control does not fall through a return.
      state = None
    elif opname == 'YIELD_VALUE':
      yields.add(state.stack[-1])
    elif opname == 'JUMP_FORWARD':
      jmp = pc + arg
      jmp_state = state
      state = None
    elif opname == 'JUMP_ABSOLUTE':
      jmp = arg
      jmp_state = state
      state = None
    elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
      # The condition is popped on both the taken and fall-through paths.
      state.stack.pop()
      jmp = arg
      jmp_state = state.copy()
    elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
      # The value stays on the stack when jumping, and is popped otherwise.
      jmp = arg
      jmp_state = state.copy()
      state.stack.pop()
    elif opname == 'FOR_ITER':
      # Exhausted path: jump past the loop with the iterator popped.
      jmp = pc + arg
      jmp_state = state.copy()
      jmp_state.stack.pop()
      # Loop body path: push the iterator's element type.
      state.stack.append(element_type(state.stack[-1]))
    else:
      raise TypeInferenceError('unable to handle %s' % opname)

    if jmp is not None:
      # TODO(robertwb): Is this guaranteed to converge?
      new_state = states[jmp] | jmp_state
      # Re-walk a backward jump only while its target state keeps changing,
      # and at most 5 times per jump site.
      if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
        jumps[pc] += 1
        pc = jmp
      states[jmp] = new_state

    if debug:
      print()
      print(state)
      pprint.pprint({k: v for k, v in states.items() if v})

  if yields:
    result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
  else:
    result = reduce(union, Const.unwrap_all(returns))
  finalize_hints(result)

  if debug:
    print(f, id(f), input_types, '->', result)
  return result