diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 47321b6b5..d6e866bfe 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -249,9 +249,17 @@ class NPTracer(BaseTracer): # Supported functions are either univariate or bivariate for which one of the two # sources is a constant + # + # numpy.add, numpy.multiply and numpy.subtract are not there since already managed + # by leveled operations + # + # numpy.conjugate is not there since working on complex numbers + # + # numpy.isnat is not there since it is about timings + # + # numpy.divmod, numpy.modf and numpy.frexp are not there since output two values LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [ numpy.absolute, - # numpy.add, numpy.arccos, numpy.arccosh, numpy.arcsin, @@ -264,14 +272,12 @@ class NPTracer(BaseTracer): numpy.bitwise_xor, numpy.cbrt, numpy.ceil, - # numpy.conjugate, numpy.copysign, numpy.cos, numpy.cosh, numpy.deg2rad, numpy.degrees, - # numpy.divmod, - # numpy.equal, + numpy.equal, numpy.exp, numpy.exp2, numpy.expm1, @@ -282,22 +288,20 @@ class NPTracer(BaseTracer): numpy.fmax, numpy.fmin, numpy.fmod, - # numpy.frexp, numpy.gcd, - # numpy.greater, - # numpy.greater_equal, + numpy.greater, + numpy.greater_equal, numpy.heaviside, numpy.hypot, - # numpy.invert, + numpy.invert, numpy.isfinite, numpy.isinf, numpy.isnan, - # numpy.isnat, numpy.lcm, numpy.ldexp, numpy.left_shift, - # numpy.less, - # numpy.less_equal, + numpy.less, + numpy.less_equal, numpy.log, numpy.log10, numpy.log1p, @@ -311,11 +315,9 @@ class NPTracer(BaseTracer): # numpy.matmul, numpy.maximum, numpy.minimum, - # numpy.modf, - # numpy.multiply, numpy.negative, numpy.nextafter, - # numpy.not_equal, + numpy.not_equal, numpy.positive, numpy.power, numpy.rad2deg, @@ -331,7 +333,6 @@ class NPTracer(BaseTracer): numpy.spacing, numpy.sqrt, numpy.square, - # numpy.subtract, numpy.tan, numpy.tanh, numpy.true_divide, diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index bbadfc9d3..a44234d2f 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -38,8 +38,10 @@ def simple_fuse_output(x): return x.astype(numpy.float64).astype(numpy.int32) -def complex_fuse_indirect_input(function, x, y): - """Complex fuse""" +def mix_x_and_y_intricately_and_call_f(function, x, y): + """Mix x and y in an intricated way, that can't be simplified by + an optimizer eg, and then call function + """ intermediate = x + y intermediate = intermediate + 2 intermediate = intermediate.astype(numpy.float32) @@ -57,8 +59,8 @@ def complex_fuse_indirect_input(function, x, y): ) -def complex_fuse_direct_input(function, x, y): - """Complex fuse""" +def mix_x_and_y_and_call_f(function, x, y): + """Mix x and y and then call function""" x_p_1 = x + 0.1 x_p_2 = x + 0.2 x_p_3 = function(x_p_1 + x_p_2) @@ -72,6 +74,21 @@ def complex_fuse_direct_input(function, x, y): ) +def mix_x_and_y_into_integer_and_call_f(function, x, y): + """Mix x and y but keep the entry to function as an integer""" + x_p_1 = x + 1 + x_p_2 = x + 2 + x_p_3 = function(x_p_1 + x_p_2) + return ( + x_p_3.astype(numpy.int32), + x_p_2.astype(numpy.int32), + (x_p_2 + 3).astype(numpy.int32), + x_p_3.astype(numpy.int32) + 67, + y, + (y + 4.7).astype(numpy.int32) + 3, + ) + + @pytest.mark.parametrize( "function_to_trace,fused", [ @@ -80,14 +97,14 @@ def complex_fuse_direct_input(function, x, y): pytest.param(simple_fuse_not_output, True, id="no_fuse"), pytest.param(simple_fuse_output, True, id="no_fuse"), pytest.param( - lambda x, y: complex_fuse_indirect_input(numpy.rint, x, y), + lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y), True, - id="complex_fuse_indirect_input_with_rint", + id="mix_x_and_y_intricately_and_call_f_with_rint", ), pytest.param( - lambda x, y: complex_fuse_direct_input(numpy.rint, x, y), + lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y), True, - id="complex_fuse_direct_input_with_rint", + id="mix_x_and_y_and_call_f_with_rint", ), ], ) @@ -152,13 +169,16 @@ def subtest_fuse_float_unary_operations_correctness(fun): # Some manipulation to avoid issues with domain of definitions of functions if fun == numpy.arccosh: input_list = [1, 2, 42, 44] - super_fun_list = [complex_fuse_direct_input] + super_fun_list = [mix_x_and_y_and_call_f] elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: input_list = [0, 0.1, 0.2] - super_fun_list = [complex_fuse_direct_input] + super_fun_list = [mix_x_and_y_and_call_f] + elif fun == numpy.invert: + input_list = [1, 2, 42, 44] + super_fun_list = [mix_x_and_y_into_integer_and_call_f] else: input_list = [0, 2, 42, 44] - super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input] + super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f] for super_fun in super_fun_list: @@ -194,6 +214,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { numpy.bitwise_or, numpy.bitwise_xor, numpy.gcd, + numpy.invert, numpy.lcm, numpy.ldexp, numpy.left_shift, @@ -220,6 +241,10 @@ def subtest_fuse_float_binary_operations_correctness(fun): # a float output even for functions which return an integer (eg, XOR), such # that our frontend always try to fuse them + # The .astype(numpy.float64) that we have in cases 1 and 3 is here to force + # a float output even for functions which return a bool (eg, EQUAL), such + # that our frontend always try to fuse them + # For bivariate functions: fix one of the inputs if i == 0: # With an integer in first position @@ -229,7 +254,7 @@ def subtest_fuse_float_binary_operations_correctness(fun): elif i == 1: # With a float in first position def get_function_to_trace(): - return lambda x, y: fun(2.3, x + y).astype(numpy.int32) + return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32) elif i == 2: # With an integer in second position @@ -239,7 +264,7 @@ def subtest_fuse_float_binary_operations_correctness(fun): else: # With a float in second position def get_function_to_trace(): - return lambda x, y: fun(x + y, 5.7).astype(numpy.int32) + return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32) input_list = [0, 2, 42, 44] diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 7b8b9472a..01b188878 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -62,6 +62,7 @@ LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set( numpy.isinf, numpy.isnan, numpy.signbit, + numpy.logical_not, ] ) @@ -406,15 +407,17 @@ def test_tracing_astype( ), ], ) -# numpy.logical_not is removed from the following test since it is expecting inputs which are -# integer only, as opposed to other functions in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC @pytest.mark.parametrize( "function_to_trace_def", - [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1 if f != numpy.logical_not], + [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1], ) def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def): """Function to trace supported numpy ufuncs""" + # numpy.invert is expecting inputs which are integer only + if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer): + return + # We really need a lambda (because numpy functions are not playing # nice with inspect.signature), but pylint and flake8 are not happy # with it