mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
83ea485fe1
commit
56e0ed4a11
@@ -45,7 +45,9 @@ class NPTracer(BaseTracer):
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}",
|
||||
)
|
||||
return tracing_func(*input_tracers, **kwargs)
|
||||
# Create constant tracers when needed
|
||||
sanitized_input_tracers = [self._sanitize(inp) for inp in input_tracers]
|
||||
return tracing_func(*sanitized_input_tracers, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def __array_function__(self, func, _types, args, kwargs):
|
||||
@@ -163,6 +165,61 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@classmethod
|
||||
def _binary_operator(
|
||||
cls, binary_operator, binary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
) -> "NPTracer":
|
||||
"""Trace a binary operator, supposing one of the input is a constant.
|
||||
|
||||
If no input is a constant, raises an error.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
custom_assert(len(input_tracers) == 2)
|
||||
|
||||
# One of the inputs has to be constant
|
||||
if isinstance(input_tracers[0].traced_computation, Constant):
|
||||
in_which_input_is_constant = 0
|
||||
baked_constant = deepcopy(input_tracers[0].traced_computation.constant_data)
|
||||
elif isinstance(input_tracers[1].traced_computation, Constant):
|
||||
in_which_input_is_constant = 1
|
||||
baked_constant = deepcopy(input_tracers[1].traced_computation.constant_data)
|
||||
else:
|
||||
raise NotImplementedError(f"Can't manage binary operator {binary_operator}")
|
||||
|
||||
in_which_input_is_variable = 1 - in_which_input_is_constant
|
||||
|
||||
if in_which_input_is_constant == 0:
|
||||
|
||||
def arbitrary_func(x, baked_constant, **kwargs):
|
||||
return binary_operator(baked_constant, x, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
def arbitrary_func(x, baked_constant, **kwargs):
|
||||
return binary_operator(x, baked_constant, **kwargs)
|
||||
|
||||
common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers)
|
||||
custom_assert(len(common_output_dtypes) == 1)
|
||||
|
||||
op_kwargs = deepcopy(kwargs)
|
||||
op_kwargs["baked_constant"] = baked_constant
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=input_tracers[in_which_input_is_variable].output,
|
||||
arbitrary_func=arbitrary_func,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=op_kwargs,
|
||||
op_name=binary_operator_string,
|
||||
)
|
||||
output_tracer = cls(
|
||||
(input_tracers[in_which_input_is_variable],),
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Trace numpy.dot.
|
||||
|
||||
@@ -188,8 +245,9 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
# Supported functions are either univariate or bivariate for which one of the two
|
||||
# sources is a constant
|
||||
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
|
||||
# The commented functions are functions require more than a single argument
|
||||
numpy.absolute,
|
||||
# numpy.add,
|
||||
numpy.arccos,
|
||||
@@ -197,7 +255,7 @@ class NPTracer(BaseTracer):
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
# numpy.arctan2,
|
||||
numpy.arctan2,
|
||||
numpy.arctanh,
|
||||
# numpy.bitwise_and,
|
||||
# numpy.bitwise_or,
|
||||
@@ -216,7 +274,7 @@ class NPTracer(BaseTracer):
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
# numpy.float_power,
|
||||
numpy.float_power,
|
||||
numpy.floor,
|
||||
# numpy.floor_divide,
|
||||
# numpy.fmax,
|
||||
@@ -289,7 +347,7 @@ class NPTracer(BaseTracer):
|
||||
}
|
||||
|
||||
|
||||
def _get_fun(function: numpy.ufunc):
|
||||
def _get_unary_fun(function: numpy.ufunc):
|
||||
"""Wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
@@ -301,32 +359,41 @@ def _get_fun(function: numpy.ufunc):
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _get_binary_fun(function: numpy.ufunc):
|
||||
"""Wrap _binary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._binary_operator(
|
||||
function, f"np.{function.__name__}", *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
# We are populating NPTracer.UFUNC_ROUTING dynamically
|
||||
NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC}
|
||||
NPTracer.UFUNC_ROUTING = {
|
||||
fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1
|
||||
}
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.arctan2] = _get_binary_fun(numpy.arctan2)
|
||||
NPTracer.UFUNC_ROUTING[numpy.float_power] = _get_binary_fun(numpy.float_power)
|
||||
|
||||
|
||||
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`
|
||||
# (note that this is not the proper complete handling of these functions)
|
||||
|
||||
|
||||
def _on_numpy_add(lhs, rhs):
|
||||
if isinstance(lhs, BaseTracer):
|
||||
return lhs.__add__(rhs)
|
||||
|
||||
return rhs.__radd__(lhs)
|
||||
return lhs.__add__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_subtract(lhs, rhs):
|
||||
if isinstance(lhs, BaseTracer):
|
||||
return lhs.__sub__(rhs)
|
||||
|
||||
return rhs.__rsub__(lhs)
|
||||
return lhs.__sub__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_multiply(lhs, rhs):
|
||||
if isinstance(lhs, BaseTracer):
|
||||
return lhs.__mul__(rhs)
|
||||
|
||||
return rhs.__rmul__(lhs)
|
||||
return lhs.__mul__(rhs)
|
||||
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
|
||||
|
||||
@@ -146,45 +146,127 @@ def test_tensor_no_fuse():
|
||||
assert orig_num_nodes == fused_num_nodes
|
||||
|
||||
|
||||
def test_fuse_float_operations_correctness():
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
with fuse_float_operations."""
|
||||
def subtest_fuse_float_unary_operations_correctness(fun):
|
||||
"""Test a unary function with fuse_float_operations."""
|
||||
|
||||
for fun in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
# Some manipulation to avoid issues with domain of definitions of functions
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
|
||||
for input_ in input_list:
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: super_fun(fun, x, y)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
|
||||
|
||||
for i in range(4):
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(3, x + y).astype(numpy.int32)
|
||||
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(2.3, x + y).astype(numpy.int32)
|
||||
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 4).astype(numpy.int32)
|
||||
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
# With a float in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 5.7).astype(numpy.int32)
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
for input_ in input_list:
|
||||
for input_ in input_list:
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: super_fun(fun, x, y)
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
input_ = numpy.int32(input_)
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun):
|
||||
"""Test a binary function with fuse_float_operations, with no constant as
|
||||
a source."""
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x, y).astype(numpy.int32)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"):
|
||||
trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
def test_ufunc_operations(fun):
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
|
||||
if fun.nin == 1:
|
||||
subtest_fuse_float_unary_operations_correctness(fun)
|
||||
elif fun.nin == 2:
|
||||
subtest_fuse_float_binary_operations_correctness(fun)
|
||||
subtest_fuse_float_binary_operations_dont_support_two_variables(fun)
|
||||
else:
|
||||
raise NotImplementedError("Only unary and binary functions are tested for now")
|
||||
|
||||
@@ -195,11 +195,11 @@ def test_numpy_tracing_tensors():
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(5, 6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(7, 4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(4, 7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(3, 8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(9, 2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(10, 1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(11, 0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(0, 11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip()
|
||||
|
||||
@@ -234,11 +234,11 @@ def test_numpy_explicit_tracing_tensors():
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(5, 6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(7, 4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(4, 7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(3, 8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(9, 2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(10, 1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(11, 0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(0, 11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip()
|
||||
|
||||
@@ -406,44 +406,43 @@ def test_tracing_astype(
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_trace_numpy_supported_ufuncs(inputs, expected_output_node):
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace_def", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
|
||||
)
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda,cell-var-from-loop
|
||||
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
|
||||
# pylint: enable=unnecessary-lambda,cell-var-from-loop
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda,cell-var-from-loop
|
||||
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
|
||||
# pylint: enable=unnecessary-lambda,cell-var-from-loop
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
|
||||
elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL:
|
||||
|
||||
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
|
||||
elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL:
|
||||
# Boolean function
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Integer(8, is_signed=False))
|
||||
else:
|
||||
|
||||
# Boolean function
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(
|
||||
Integer(8, is_signed=False)
|
||||
)
|
||||
else:
|
||||
# Function keeping more or less input type
|
||||
input_node_type = inputs["x"]
|
||||
|
||||
# Function keeping more or less input type
|
||||
input_node_type = inputs["x"]
|
||||
expected_output_node_type = deepcopy(input_node_type)
|
||||
|
||||
expected_output_node_type = deepcopy(input_node_type)
|
||||
expected_output_node_type.dtype.bit_width = max(
|
||||
expected_output_node_type.dtype.bit_width, 32
|
||||
)
|
||||
|
||||
expected_output_node_type.dtype.bit_width = max(
|
||||
expected_output_node_type.dtype.bit_width, 32
|
||||
)
|
||||
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_not_supported():
|
||||
@@ -516,15 +515,13 @@ def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expect
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
|
||||
def test_nptracer_get_tracing_func_for_np_functions():
|
||||
@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
def test_nptracer_get_tracing_func_for_np_functions(np_function):
|
||||
"""Test NPTracer get_tracing_func_for_np_function"""
|
||||
|
||||
for np_function in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
|
||||
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
|
||||
|
||||
assert (
|
||||
tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
)
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
|
||||
|
||||
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
|
||||
Reference in New Issue
Block a user