diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index e8412de8a..56e7694fe 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -17,8 +17,8 @@ from ..operator_graph import OPGraph from ..representation import intermediate from ..representation.intermediate import Conv2D, IntermediateNode -# TODO: should come from compiler, through an API, #402 -ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7 +# TODO: should be removed as the supported bit-width is now dynamic +ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 8 def check_node_compatibility_with_mlir( diff --git a/docs/user/howto/reduce_needed_precision.md b/docs/user/howto/reduce_needed_precision.md index d494212fa..018686112 100644 --- a/docs/user/howto/reduce_needed_precision.md +++ b/docs/user/howto/reduce_needed_precision.md @@ -38,7 +38,7 @@ Traceback (most recent call last): update_bit_width_for_mlir(op_graph) File "/home/default/Documents/Projects/Zama/hdk/concrete/common/mlir/utils.py", line 204, in update_bit_width_for_mlir raise RuntimeError( -RuntimeError: max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with: +RuntimeError: max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 8) which is not compatible with: %0 = x # EncryptedScalar %1 = 42 # ClearScalar diff --git a/tests/common/mlir/test_conversion_helpers.py b/tests/common/mlir/test_conversion_helpers.py index a2bd480a4..2f672206b 100644 --- a/tests/common/mlir/test_conversion_helpers.py +++ b/tests/common/mlir/test_conversion_helpers.py @@ -34,24 +34,6 @@ def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str): assert str(integer_to_mlir_type(ctx, integer, is_encrypted)) == expected_mlir_type_str -@pytest.mark.parametrize( - "integer,is_encrypted,expected_error_message", - [ - pytest.param(SignedInteger(32), True, "can't create eint with the given width"), - pytest.param(UnsignedInteger(32), True, "can't create eint with the given width"), - ], -) -def test_fail_integer_to_mlir_type(integer, is_encrypted, expected_error_message): - """Test function for failed integer to MLIR type conversion.""" - - with pytest.raises(ValueError) as excinfo: - with Context() as ctx, Location.unknown(): - concretelang.register_dialects(ctx) - integer_to_mlir_type(ctx, integer, is_encrypted) - - assert str(excinfo.value) == expected_error_message - - @pytest.mark.parametrize( "value,expected_mlir_type_str", [ @@ -92,14 +74,6 @@ def test_value_to_mlir_type(value, expected_mlir_type_str): EncryptedTensor(Float(32), shape=(2, 3)), "EncryptedTensor is not supported for MLIR conversion", ), - pytest.param( - EncryptedScalar(UnsignedInteger(32)), - "EncryptedScalar is not supported for MLIR conversion", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(32), shape=(2, 3)), - "EncryptedTensor is not supported for MLIR conversion", - ), ], ) def test_fail_value_to_mlir_type(value, expected_error_message): diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 4fe476ed0..43839abe3 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -181,6 +181,39 @@ def random_lut_7b(x): return table[x] +def random_lut_8b(x): + """8-bit random table lookup""" + + # fmt: off + table = LookupTable( + [ + 212, 38, 101, 238, 40, 147, 221, 83, 237, 177, 130, 149, 173, + 122, 217, 118, 77, 109, 180, 223, 84, 39, 83, 161, 44, 177, + 95, 28, 9, 171, 163, 114, 178, 50, 57, 142, 6, 78, 153, + 184, 170, 40, 94, 45, 237, 50, 220, 45, 97, 212, 174, 42, + 190, 93, 153, 212, 71, 153, 253, 207, 24, 171, 179, 86, 5, + 99, 92, 217, 134, 151, 154, 17, 95, 238, 194, 21, 73, 243, + 158, 242, 216, 176, 251, 46, 23, 20, 18, 107, 49, 87, 149, + 150, 96, 96, 245, 164, 56, 67, 63, 209, 70, 145, 95, 55, + 52, 34, 13, 71, 149, 86, 254, 50, 85, 95, 124, 179, 190, + 108, 137, 116, 245, 207, 83, 131, 128, 39, 232, 138, 198, 164, + 218, 19, 170, 134, 237, 165, 231, 178, 170, 65, 7, 225, 160, + 152, 33, 114, 0, 82, 232, 68, 3, 212, 107, 16, 140, 66, + 177, 102, 48, 234, 204, 164, 54, 23, 120, 164, 172, 5, 107, + 50, 210, 139, 228, 94, 188, 169, 90, 113, 77, 252, 235, 104, + 124, 70, 65, 186, 29, 59, 113, 147, 71, 17, 65, 255, 15, + 88, 24, 229, 74, 16, 90, 122, 146, 176, 33, 250, 175, 170, + 75, 81, 62, 120, 127, 109, 89, 37, 98, 152, 157, 18, 166, + 121, 248, 19, 243, 159, 161, 133, 166, 154, 126, 175, 65, 8, + 203, 252, 82, 145, 189, 143, 123, 245, 104, 197, 16, 222, 236, + 24, 139, 132, 32, 242, 94, 117, 231, 98 + ] + ) + # fmt: on + + return table[x] + + def multi_lut(x): """2-bit multi table lookup""" @@ -771,6 +804,7 @@ def test_compile_function_multiple_outputs( pytest.param(lambda x: -x + 50, ((0, 20),), ["x"]), pytest.param(lambda x: numpy.dot(x, 2), ((0, 20),), ["x"]), pytest.param(lambda x: numpy.dot(2, x), ((0, 20),), ["x"]), + pytest.param(lambda x: x + 128, ((0, 127),), ["x"]), ], ) def test_compile_and_run_correctness( @@ -1867,6 +1901,7 @@ def test_compile_and_run_matmul_correctness( pytest.param(identity_lut_generator(5), (5,), ["x"], id="identity function (5-bit)"), pytest.param(identity_lut_generator(6), (6,), ["x"], id="identity function (6-bit)"), pytest.param(identity_lut_generator(7), (7,), ["x"], id="identity function (7-bit)"), + pytest.param(identity_lut_generator(8), (8,), ["x"], id="identity function (8-bit)"), pytest.param(random_lut_1b, (1,), ["x"], id="random function (1-bit)"), pytest.param(random_lut_2b, (2,), ["x"], id="random function (2-bit)"), pytest.param(random_lut_3b, (3,), ["x"], id="random function (3-bit)"), @@ -1874,6 +1909,7 @@ def test_compile_and_run_matmul_correctness( pytest.param(random_lut_5b, (5,), ["x"], id="random function (5-bit)"), pytest.param(random_lut_6b, (6,), ["x"], id="random function (6-bit)"), pytest.param(random_lut_7b, (7,), ["x"], id="random function (7-bit)"), + pytest.param(random_lut_8b, (8,), ["x"], id="random function (8-bit)"), pytest.param(small_fused_table, (5,), ["x"], id="small fused table (5-bits)"), ], ) @@ -1918,15 +1954,15 @@ def test_compile_and_run_lut_correctness( "function,table,bit_width", [ pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)") - for n in range(1, 8) + for n in range(1, 9) ] + [ pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)") - for n in range(1, 8) + for n in range(1, 9) ] + [ pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)") - for n in range(1, 7) + for n in range(1, 8) ] + [ pytest.param(*weird_lut(3), 3, id="weird"), @@ -2391,7 +2427,7 @@ def test_compile_too_high_bitwidth(default_compilation_configuration): } # A bit too much - input_ranges = [(0, 100), (0, 28)] + input_ranges = [(100, 200), (100, 200)] with pytest.raises(RuntimeError) as excinfo: compile_numpy_function( @@ -2405,12 +2441,12 @@ def test_compile_too_high_bitwidth(default_compilation_configuration): str(excinfo.value) == """ -max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with: +max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 8) which is not compatible with: -%0 = x # EncryptedScalar -%1 = y # EncryptedScalar -%2 = add(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being +%0 = x # EncryptedScalar +%1 = y # EncryptedScalar +%2 = add(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 9 bits is not supported for the time being return %2 """.strip() # noqa: E501 # pylint: disable=line-too-long @@ -2518,6 +2554,7 @@ def test_wrong_inputs(default_compilation_configuration): "function,input_ranges,list_of_arg_names", [ pytest.param(lambda x: (x + (-27)) + 32, ((0, 10),), ["x"]), + pytest.param(lambda x: (x + (-128)) + 373, ((0, 10),), ["x"]), pytest.param(lambda x: ((-3) * x) + (100 - (x + 1)), ((0, 10),), ["x"]), pytest.param( lambda x, y: (-1) * x + (-2) * y + 40,