refactor(tracing): add generic tracing functions for generic functions

- remove nearly duplicated code when tracing GenericFunction for unary and
binary operators

refs #965
This commit is contained in:
Arthur Meyre
2021-11-22 18:14:23 +01:00
parent ba6207e71e
commit f1ed07d580
2 changed files with 32 additions and 62 deletions

View File

@@ -153,79 +153,46 @@ class NPTracer(BaseTracer):
return self.__class__([], NPConstant(constant_data), 0)
@classmethod
def _unary_operator(
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
def _np_operator(
cls,
numpy_operator,
numpy_operator_string,
numpy_operator_nin,
*input_tracers: "NPTracer",
**kwargs,
) -> "NPTracer":
"""Trace an unary operator.
"""Trace a numpy operator.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert_true(len(input_tracers) == 1)
assert_true(len(input_tracers) == numpy_operator_nin)
common_output_dtypes_and_shapes = (
get_numpy_function_output_dtype_and_shape_from_input_tracers(
unary_operator,
numpy_operator,
*input_tracers,
)
)
assert_true(len(common_output_dtypes_and_shapes) == 1)
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
generic_function_output_value = TensorValue(
output_dtype, input_tracers[0].output.is_encrypted, output_shape
variable_input_indices = [
idx
for idx, pred in enumerate(input_tracers)
if not isinstance(pred.traced_computation, Constant)
]
assert_true(
(non_constant_pred_count := len(variable_input_indices)) == 1,
f"Can only have 1 non constant predecessor in {cls._np_operator.__name__}, "
f"got {non_constant_pred_count} for operator {numpy_operator}",
)
traced_computation = GenericFunction(
inputs=[input_tracers[0].output],
arbitrary_func=unary_operator,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs=deepcopy(kwargs),
op_name=unary_operator_string,
)
output_tracer = cls(
input_tracers,
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
@classmethod
def _binary_operator(
cls, binary_operator, binary_operator_string, *input_tracers: "NPTracer", **kwargs
) -> "NPTracer":
"""Trace a binary operator, supposing one of the input is a constant.
If no input is a constant, raises an error.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert_true(len(input_tracers) == 2)
# One of the inputs has to be constant
if isinstance(input_tracers[0].traced_computation, Constant):
in_which_input_is_constant = 0
elif isinstance(input_tracers[1].traced_computation, Constant):
in_which_input_is_constant = 1
else:
raise NotImplementedError(f"Can't manage binary operator {binary_operator}")
in_which_input_is_variable = 1 - in_which_input_is_constant
common_output_dtypes_and_shapes = (
get_numpy_function_output_dtype_and_shape_from_input_tracers(
binary_operator,
*input_tracers,
)
)
assert_true(len(common_output_dtypes_and_shapes) == 1)
variable_input_idx = variable_input_indices[0]
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
generic_function_output_value = TensorValue(
output_dtype,
input_tracers[in_which_input_is_variable].output.is_encrypted,
input_tracers[variable_input_idx].output.is_encrypted,
output_shape,
)
@@ -233,11 +200,11 @@ class NPTracer(BaseTracer):
traced_computation = GenericFunction(
inputs=[input_tracer.output for input_tracer in input_tracers],
arbitrary_func=binary_operator,
arbitrary_func=numpy_operator,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs=op_kwargs,
op_name=binary_operator_string,
op_name=numpy_operator_string,
)
output_tracer = cls(
input_tracers,
@@ -607,8 +574,8 @@ def _get_unary_fun(function: numpy.ufunc):
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
# dynamically
# pylint: disable=protected-access
return lambda *input_tracers, **kwargs: NPTracer._unary_operator(
function, f"{function.__name__}", *input_tracers, **kwargs
return lambda *input_tracers, **kwargs: NPTracer._np_operator(
function, f"{function.__name__}", 1, *input_tracers, **kwargs
)
# pylint: enable=protected-access
@@ -619,8 +586,8 @@ def _get_binary_fun(function: numpy.ufunc):
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
# dynamically
# pylint: disable=protected-access
return lambda *input_tracers, **kwargs: NPTracer._binary_operator(
function, f"{function.__name__}", *input_tracers, **kwargs
return lambda *input_tracers, **kwargs: NPTracer._np_operator(
function, f"{function.__name__}", 2, *input_tracers, **kwargs
)
# pylint: enable=protected-access

View File

@@ -687,7 +687,10 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_
params_names = signature(function_to_trace).parameters.keys()
with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"):
with pytest.raises(
AssertionError,
match=r"Can only have 1 non constant predecessor in _np_operator, got 2 for operator",
):
trace_numpy_function(
function_to_trace,
{