mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor(tracing): add generic tracing functions for generic functions
- remove nearly duplicated code when tracing GenericFunction for unary and binary operators refs #965
This commit is contained in:
@@ -153,79 +153,46 @@ class NPTracer(BaseTracer):
|
||||
return self.__class__([], NPConstant(constant_data), 0)
|
||||
|
||||
@classmethod
|
||||
def _unary_operator(
|
||||
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
def _np_operator(
|
||||
cls,
|
||||
numpy_operator,
|
||||
numpy_operator_string,
|
||||
numpy_operator_nin,
|
||||
*input_tracers: "NPTracer",
|
||||
**kwargs,
|
||||
) -> "NPTracer":
|
||||
"""Trace an unary operator.
|
||||
"""Trace a numpy operator.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true(len(input_tracers) == 1)
|
||||
assert_true(len(input_tracers) == numpy_operator_nin)
|
||||
|
||||
common_output_dtypes_and_shapes = (
|
||||
get_numpy_function_output_dtype_and_shape_from_input_tracers(
|
||||
unary_operator,
|
||||
numpy_operator,
|
||||
*input_tracers,
|
||||
)
|
||||
)
|
||||
assert_true(len(common_output_dtypes_and_shapes) == 1)
|
||||
|
||||
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
output_dtype, input_tracers[0].output.is_encrypted, output_shape
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, pred in enumerate(input_tracers)
|
||||
if not isinstance(pred.traced_computation, Constant)
|
||||
]
|
||||
assert_true(
|
||||
(non_constant_pred_count := len(variable_input_indices)) == 1,
|
||||
f"Can only have 1 non constant predecessor in {cls._np_operator.__name__}, "
|
||||
f"got {non_constant_pred_count} for operator {numpy_operator}",
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[input_tracers[0].output],
|
||||
arbitrary_func=unary_operator,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name=unary_operator_string,
|
||||
)
|
||||
output_tracer = cls(
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@classmethod
|
||||
def _binary_operator(
|
||||
cls, binary_operator, binary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
) -> "NPTracer":
|
||||
"""Trace a binary operator, supposing one of the input is a constant.
|
||||
|
||||
If no input is a constant, raises an error.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true(len(input_tracers) == 2)
|
||||
|
||||
# One of the inputs has to be constant
|
||||
if isinstance(input_tracers[0].traced_computation, Constant):
|
||||
in_which_input_is_constant = 0
|
||||
elif isinstance(input_tracers[1].traced_computation, Constant):
|
||||
in_which_input_is_constant = 1
|
||||
else:
|
||||
raise NotImplementedError(f"Can't manage binary operator {binary_operator}")
|
||||
|
||||
in_which_input_is_variable = 1 - in_which_input_is_constant
|
||||
common_output_dtypes_and_shapes = (
|
||||
get_numpy_function_output_dtype_and_shape_from_input_tracers(
|
||||
binary_operator,
|
||||
*input_tracers,
|
||||
)
|
||||
)
|
||||
assert_true(len(common_output_dtypes_and_shapes) == 1)
|
||||
|
||||
variable_input_idx = variable_input_indices[0]
|
||||
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
output_dtype,
|
||||
input_tracers[in_which_input_is_variable].output.is_encrypted,
|
||||
input_tracers[variable_input_idx].output.is_encrypted,
|
||||
output_shape,
|
||||
)
|
||||
|
||||
@@ -233,11 +200,11 @@ class NPTracer(BaseTracer):
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[input_tracer.output for input_tracer in input_tracers],
|
||||
arbitrary_func=binary_operator,
|
||||
arbitrary_func=numpy_operator,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs=op_kwargs,
|
||||
op_name=binary_operator_string,
|
||||
op_name=numpy_operator_string,
|
||||
)
|
||||
output_tracer = cls(
|
||||
input_tracers,
|
||||
@@ -607,8 +574,8 @@ def _get_unary_fun(function: numpy.ufunc):
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._unary_operator(
|
||||
function, f"{function.__name__}", *input_tracers, **kwargs
|
||||
return lambda *input_tracers, **kwargs: NPTracer._np_operator(
|
||||
function, f"{function.__name__}", 1, *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@@ -619,8 +586,8 @@ def _get_binary_fun(function: numpy.ufunc):
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._binary_operator(
|
||||
function, f"{function.__name__}", *input_tracers, **kwargs
|
||||
return lambda *input_tracers, **kwargs: NPTracer._np_operator(
|
||||
function, f"{function.__name__}", 2, *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@@ -687,7 +687,10 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"):
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=r"Can only have 1 non constant predecessor in _np_operator, got 2 for operator",
|
||||
):
|
||||
trace_numpy_function(
|
||||
function_to_trace,
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user