From 624143106f9e5ac300434f07019daef29378d57b Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 25 Oct 2021 12:06:24 +0200 Subject: [PATCH] refactor(compilation): remove unnecessary check in compile.py refs #645 --- concrete/numpy/compile.py | 13 +- .../common/compilation/test_configuration.py | 2 - tests/numpy/test_compile.py | 137 ++++++++++-------- 3 files changed, 79 insertions(+), 73 deletions(-) diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 59c6fc1a5..39ebe40f6 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -2,7 +2,7 @@ import sys import traceback -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional, Tuple import numpy from zamalang import CompilerEngine @@ -21,7 +21,6 @@ from ..common.mlir.utils import ( ) from ..common.operator_graph import OPGraph from ..common.optimization.topological import fuse_float_operations -from ..common.representation.intermediate import IntermediateNode from ..common.values import BaseValue from ..numpy.tracing import trace_numpy_function from .np_dtypes_helpers import ( @@ -102,16 +101,6 @@ def _compile_numpy_function_into_op_graph_internal( if not check_op_graph_is_integer_program(op_graph): fuse_float_operations(op_graph, compilation_artifacts) - # TODO: To be removed once we support more than integers - offending_non_integer_nodes: List[IntermediateNode] = [] - op_grap_is_int_prog = check_op_graph_is_integer_program(op_graph, offending_non_integer_nodes) - if not op_grap_is_int_prog: - raise ValueError( - f"{function_to_compile.__name__} cannot be compiled as it has nodes with either float" - f" inputs or outputs.\nOffending nodes : " - f"{', '.join(str(node) for node in offending_non_integer_nodes)}" - ) - # Find bounds with the inputset inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( op_graph, diff --git a/tests/common/compilation/test_configuration.py b/tests/common/compilation/test_configuration.py index daadbd307..454807a56 100644 --- a/tests/common/compilation/test_configuration.py +++ b/tests/common/compilation/test_configuration.py @@ -35,8 +35,6 @@ def simple_fuse_not_output(x): simple_fuse_not_output, True, id="simple_fuse_not_output", - marks=pytest.mark.xfail(strict=True), - # fails because it connot be compiled without topological optimizations ), ], ) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 95a76b6a7..95a21786e 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -448,12 +448,6 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration): ((4, 8), (3, 4), (0, 4)), ["x", "y", "z"], ), - pytest.param( - no_fuse_unhandled, - ((-2, 2), (-2, 2)), - ["x", "y"], - marks=pytest.mark.xfail(strict=True, raises=ValueError), - ), pytest.param(complicated_topology, ((0, 10),), ["x"]), ], ) @@ -735,6 +729,7 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ) +# pylint: disable=line-too-long,unnecessary-lambda @pytest.mark.parametrize( "function,parameters,inputset,match", [ @@ -745,10 +740,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = Constant(1) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = x # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Sub(%0, %1) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = Constant(1) # ClearScalar>\n" # noqa: E501 + "%1 = x # EncryptedScalar>\n" # noqa: E501 + "%2 = Sub(%0, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501 "return(%2)\n" ), ), @@ -759,10 +754,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = x # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = Constant(-1) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = x # EncryptedScalar>\n" # noqa: E501 + "%1 = Constant(-1) # ClearScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 + "%2 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 "return(%2)\n" ), ), @@ -773,10 +768,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Add(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 + "%2 = Add(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 "return(%2)\n" ), ), @@ -787,10 +782,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Add(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 + "%2 = Add(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 "return(%2)\n" ), ), @@ -801,10 +796,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Mul(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 + "%2 = Mul(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported\n" # noqa: E501 "return(%2)\n" ), ), @@ -815,15 +810,15 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = Constant(127) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Sub(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = Constant(127) # ClearScalar>\n" # noqa: E501 + "%1 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "%2 = Sub(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported\n" # noqa: E501 "return(%2)\n" ), ), pytest.param( - lambda x, y: numpy.dot(x, y), # pylint: disable=unnecessary-lambda + lambda x, y: numpy.dot(x, y), { "x": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)), "y": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)), @@ -843,12 +838,12 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = x # EncryptedTensor, shape=(1,)>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = y # EncryptedTensor, shape=(1,)>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = Dot(%0, %1) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = x # EncryptedTensor, shape=(1,)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 + "%1 = y # EncryptedTensor, shape=(1,)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 + "%2 = Dot(%0, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported\n" # noqa: E501 "return(%2)\n" ), ), @@ -866,22 +861,44 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura "return(%1)\n" ), ), + pytest.param( + no_fuse_unhandled, + {"x": EncryptedScalar(Integer(2, False)), "y": EncryptedScalar(Integer(2, False))}, + [(i, i) for i in range(10)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n\n" + "%0 = x # EncryptedScalar>\n" # noqa: E501 + "%1 = Constant(2.8) # ClearScalar>\n" + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 + "%2 = y # EncryptedScalar>\n" # noqa: E501 + "%3 = Constant(9.3) # ClearScalar>\n" + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 + "%4 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "%5 = Add(%2, %3) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "%6 = Add(%4, %5) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "%7 = astype(int32)(%6) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 + "return(%7)\n" + ), + ), ], ) +# pylint: enable=line-too-long,unnecessary-lambda def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration): """Test function compile_numpy_function_into_op_graph for a program with signed values""" - with pytest.raises(RuntimeError): - try: - compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - except RuntimeError as error: - assert str(error) == match - raise + with pytest.raises(RuntimeError) as excinfo: + compile_numpy_function( + function, + parameters, + inputset, + default_compilation_configuration, + ) + + assert str(excinfo.value) == match def test_fail_with_intermediate_signed_values(default_compilation_configuration): @@ -905,22 +922,24 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration) show_mlir=True, ) except RuntimeError as error: + # pylint: disable=line-too-long match = ( "function you are trying to compile isn't supported for MLIR lowering\n" "\n" - "%0 = y # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%1 = Constant(10) # ClearScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%2 = x # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%3 = np.negative(%2) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 # pylint: disable=line-too-long - "%4 = Mul(%3, %1) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 # pylint: disable=line-too-long - "%5 = np.absolute(%4) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 # pylint: disable=line-too-long - "%6 = astype(int32)(%5) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long - "%7 = Add(%6, %0) # EncryptedScalar>\n" # noqa: E501 # pylint: disable=line-too-long + "%0 = y # EncryptedScalar>\n" # noqa: E501 + "%1 = Constant(10) # ClearScalar>\n" # noqa: E501 + "%2 = x # EncryptedScalar>\n" # noqa: E501 + "%3 = np.negative(%2) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "%4 = Mul(%3, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "%5 = np.absolute(%4) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 + "%6 = astype(int32)(%5) # EncryptedScalar>\n" # noqa: E501 + "%7 = Add(%6, %0) # EncryptedScalar>\n" # noqa: E501 "return(%7)\n" ) + # pylint: enable=line-too-long assert str(error) == match raise