From 2da3895f1aa4ea836613316a3fea68ee0989b732 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Wed, 6 Oct 2021 18:02:35 +0200 Subject: [PATCH] feat: add more bivariate operations which are supported if one of the operands is a constant refs #126 --- concrete/numpy/tracing.py | 67 +++++++++++-------- .../common/optimization/test_float_fusing.py | 34 +++++++++- tests/numpy/test_tracing.py | 5 +- 3 files changed, 74 insertions(+), 32 deletions(-) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 5cd78bc4f..47321b6b5 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -259,13 +259,13 @@ class NPTracer(BaseTracer): numpy.arctan, numpy.arctan2, numpy.arctanh, - # numpy.bitwise_and, - # numpy.bitwise_or, - # numpy.bitwise_xor, + numpy.bitwise_and, + numpy.bitwise_or, + numpy.bitwise_xor, numpy.cbrt, numpy.ceil, # numpy.conjugate, - # numpy.copysign, + numpy.copysign, numpy.cos, numpy.cosh, numpy.deg2rad, @@ -278,51 +278,51 @@ class NPTracer(BaseTracer): numpy.fabs, numpy.float_power, numpy.floor, - # numpy.floor_divide, - # numpy.fmax, - # numpy.fmin, - # numpy.fmod, + numpy.floor_divide, + numpy.fmax, + numpy.fmin, + numpy.fmod, # numpy.frexp, - # numpy.gcd, + numpy.gcd, # numpy.greater, # numpy.greater_equal, - # numpy.heaviside, - # numpy.hypot, + numpy.heaviside, + numpy.hypot, # numpy.invert, numpy.isfinite, numpy.isinf, numpy.isnan, # numpy.isnat, - # numpy.lcm, - # numpy.ldexp, - # numpy.left_shift, + numpy.lcm, + numpy.ldexp, + numpy.left_shift, # numpy.less, # numpy.less_equal, numpy.log, numpy.log10, numpy.log1p, numpy.log2, - # numpy.logaddexp, - # numpy.logaddexp2, - # numpy.logical_and, - # numpy.logical_not, - # numpy.logical_or, - # numpy.logical_xor, + numpy.logaddexp, + numpy.logaddexp2, + numpy.logical_and, + numpy.logical_not, + numpy.logical_or, + numpy.logical_xor, # numpy.matmul, - # numpy.maximum, - # numpy.minimum, + numpy.maximum, + numpy.minimum, # numpy.modf, # numpy.multiply, numpy.negative, - # numpy.nextafter, + numpy.nextafter, # numpy.not_equal, numpy.positive, - # numpy.power, + numpy.power, numpy.rad2deg, numpy.radians, numpy.reciprocal, - # numpy.remainder, - # numpy.right_shift, + numpy.remainder, + numpy.right_shift, numpy.rint, numpy.sign, numpy.signbit, @@ -334,7 +334,7 @@ class NPTracer(BaseTracer): # numpy.subtract, numpy.tan, numpy.tanh, - # numpy.true_divide, + numpy.true_divide, numpy.trunc, ] @@ -378,9 +378,18 @@ NPTracer.UFUNC_ROUTING = { fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1 } -NPTracer.UFUNC_ROUTING[numpy.arctan2] = _get_binary_fun(numpy.arctan2) -NPTracer.UFUNC_ROUTING[numpy.float_power] = _get_binary_fun(numpy.float_power) +NPTracer.UFUNC_ROUTING.update( + {fun: _get_binary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 2} +) +list_of_not_supported = [ + (ufunc.__name__, ufunc.nin) + for ufunc in NPTracer.LIST_OF_SUPPORTED_UFUNC + if ufunc.nin not in [1, 2] +] + +custom_assert(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}") +del list_of_not_supported # We are adding initial support for `np.array(...)` +,-,* `BaseTracer` # (note that this is not the proper complete handling of these functions) diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 7f017ed7a..bbadfc9d3 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -189,16 +189,42 @@ def subtest_fuse_float_unary_operations_correctness(fun): assert function_to_trace(*inputs) == op_graph(*inputs) +LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { + numpy.bitwise_and, + numpy.bitwise_or, + numpy.bitwise_xor, + numpy.gcd, + numpy.lcm, + numpy.ldexp, + numpy.left_shift, + numpy.logical_and, + numpy.logical_not, + numpy.logical_or, + numpy.logical_xor, + numpy.remainder, + numpy.right_shift, +} + + def subtest_fuse_float_binary_operations_correctness(fun): """Test a binary functions with fuse_float_operations, with a constant as a source.""" for i in range(4): + # Know if the function is defined for integer inputs + if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES: + if i not in [0, 2]: + continue + + # The .astype(numpy.float64) that we have in cases 0 and 2 is here to force + # a float output even for functions which return an integer (eg, XOR), such + # that our frontend always try to fuse them + # For bivariate functions: fix one of the inputs if i == 0: # With an integer in first position def get_function_to_trace(): - return lambda x, y: fun(3, x + y).astype(numpy.int32) + return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32) elif i == 1: # With a float in first position @@ -208,7 +234,7 @@ def subtest_fuse_float_binary_operations_correctness(fun): elif i == 2: # With an integer in second position def get_function_to_trace(): - return lambda x, y: fun(x + y, 4).astype(numpy.int32) + return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32) else: # With a float in second position @@ -217,6 +243,10 @@ def subtest_fuse_float_binary_operations_correctness(fun): input_list = [0, 2, 42, 44] + # Domain of definition + if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]: + input_list = [2, 42, 44] + for input_ in input_list: function_to_trace = get_function_to_trace() diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index dffa405fc..7b8b9472a 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -406,8 +406,11 @@ def test_tracing_astype( ), ], ) +# numpy.logical_not is removed from the following test since it is expecting inputs which are +# integer only, as opposed to other functions in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC @pytest.mark.parametrize( - "function_to_trace_def", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1] + "function_to_trace_def", + [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1 if f != numpy.logical_not], ) def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def): """Function to trace supported numpy ufuncs"""