diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 077ffccdf..4d98ca709 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -477,9 +477,10 @@ class Tracer: """ normalized_dtype = np.dtype(dtype) - if normalized_dtype != np.int64: + if np.issubdtype(normalized_dtype, np.integer) and normalized_dtype != np.int64: print( - "Warning: when using `value.astype(newtype)`, " + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " "only use `np.int64` as the newtype " "to avoid unexpected overflows " "during inputset evaluation" diff --git a/docs/dev/explanation/float-fusing.md b/docs/dev/explanation/float-fusing.md index eb19536bf..c413aa637 100644 --- a/docs/dev/explanation/float-fusing.md +++ b/docs/dev/explanation/float-fusing.md @@ -14,7 +14,7 @@ def quantized_sin(x): # from a 7 bit unsigned integer x, compute z in the [0; 2 * pi] range z = 2 * numpy.pi * x * (1 / 127) # quantize over 6 bits and offset to be >= 0, round and convert to integers in range [0; 63] - quantized_sin = numpy.rint(31 * numpy.sin(z) + 31).astype(numpy.int32) + quantized_sin = numpy.rint(31 * numpy.sin(z) + 31).astype(numpy.int64) # output quantized_sin and a further offset result return quantized_sin, quantized_sin + 32 ``` @@ -48,12 +48,12 @@ Here is an example benefiting from the expanded search: def fusable_with_bigger_search(x, y): """fusable with bigger search""" x = x + 1 - x_1 = x.astype(numpy.int32) + x_1 = x.astype(numpy.int64) x_1 = x_1 + 1.5 - x_2 = x.astype(numpy.int32) + x_2 = x.astype(numpy.int64) x_2 = x_2 + 3.4 add = x_1 + x_2 - add_int = add.astype(numpy.int32) + add_int = add.astype(numpy.int64) return add_int + y ``` @@ -79,7 +79,7 @@ def non_fusable(x, y): x_1 = x + 1.5 # x_1 is now float y_1 = y + 3.4 # y_1 is now float add = x_1 + y_1 - add_int = add.astype(numpy.int32) + add_int = add.astype(numpy.int64) return add_int ``` diff --git a/docs/user/howto/numpy_support.md b/docs/user/howto/numpy_support.md index fa8e59dd3..6a731c73a 100644 --- a/docs/user/howto/numpy_support.md +++ b/docs/user/howto/numpy_support.md @@ -139,4 +139,4 @@ The framework also gives support for: - true div, i.e., `x / y`: if one of `x` or `y` is a constant - floor div, i.e., `x // y`: if one of `x` or `y` is a constant -There is support for astype as well, e.g. `x.astype(numpy.int32)`. This allows to control which data type to use for computations. In the context of FHE going back to integers may allow to fuse floating point operations together, see [this tutorial](../tutorial/working_with_floating_points.md) to see how to work with floating point values. +There is support for astype as well, e.g. `x.astype(numpy.int64)`. This allows to control which data type to use for computations. In the context of FHE going back to integers may allow to fuse floating point operations together, see [this tutorial](../tutorial/working_with_floating_points.md) to see how to work with floating point values. diff --git a/docs/user/tutorial/compilation_artifacts.md b/docs/user/tutorial/compilation_artifacts.md index 616f3e91f..d5be1b412 100644 --- a/docs/user/tutorial/compilation_artifacts.md +++ b/docs/user/tutorial/compilation_artifacts.md @@ -123,7 +123,7 @@ artifacts = cnp.CompilationArtifacts("/tmp/custom/export/path") @cnp.compiler({"x": "encrypted"}, artifacts=artifacts) def f(x): - return 127 - (50 * (np.sin(x) + 1)).astype(np.uint32) + return 127 - (50 * (np.sin(x) + 1)).astype(np.int64) f.compile(range(2 ** 3)) diff --git a/docs/user/tutorial/table_lookup.md b/docs/user/tutorial/table_lookup.md index b62b29ea9..8d7528aec 100644 --- a/docs/user/tutorial/table_lookup.md +++ b/docs/user/tutorial/table_lookup.md @@ -85,7 +85,7 @@ Here is an example function that results in fused table lookup: ```python def f(x): - return 127 - (50 * (np.sin(x) + 1)).astype(np.uint32) # astype is to go back to integer world + return 127 - (50 * (np.sin(x) + 1)).astype(np.int64) # astype is to go back to integer world ``` where @@ -125,5 +125,5 @@ which is calculated by: ```python -[(50 * (np.sin(x) + 1)).astype(np.uint32) for x in range(2 ** 3)] +[(50 * (np.sin(x) + 1)).astype(np.int64) for x in range(2 ** 3)] ``` diff --git a/docs/user/tutorial/working_with_floating_points.md b/docs/user/tutorial/working_with_floating_points.md index 11102bdf7..e9ae346f2 100644 --- a/docs/user/tutorial/working_with_floating_points.md +++ b/docs/user/tutorial/working_with_floating_points.md @@ -9,7 +9,7 @@ import numpy as np # Function using floating points values converted back to integers at the end @cnp.compiler({"x": "encrypted"}) def f(x): - return np.fabs(50 * (2 * np.sin(x) * np.cos(x))).astype(np.uint32) + return np.fabs(50 * (2 * np.sin(x) * np.cos(x))).astype(np.int64) # astype is to go back to the integer world circuit = f.compile(range(64)) diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index 07e5b98e2..726525d4c 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -15,14 +15,14 @@ def fusable_with_bigger_search(x, y): x = x + 1 - x_1 = x.astype(np.int32) + x_1 = x.astype(np.int64) x_1 = x_1 + 1.5 - x_2 = x.astype(np.int32) + x_2 = x.astype(np.int64) x_2 = x_2 + 3.4 add = x_1 + x_2 - add_int = add.astype(np.int32) + add_int = add.astype(np.int64) return add_int + y @@ -36,17 +36,17 @@ def fusable_with_bigger_search_needs_second_iteration(x, y): x = x + 0.5 x = np.cos(x) - x_1 = x.astype(np.int32) + x_1 = x.astype(np.int64) x_1 = x_1 + 1.5 x_p = x + 1 x_p2 = x_p + 1 - x_2 = (x_p + x_p2).astype(np.int32) + x_2 = (x_p + x_p2).astype(np.int64) x_2 = x_2 + 3.4 add = x_1 + x_2 - add_int = add.astype(np.int32) + add_int = add.astype(np.int64) return add_int + y @@ -70,7 +70,7 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): t8 = np.add(t7, t2) t9 = np.rint(t8) t10 = np.clip(t9, t0, t1) - t11 = t10.astype(np.int32) + t11 = t10.astype(np.int64) return t11 def subgraph_24(x): @@ -133,7 +133,7 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): t38 = np.add(t37, t2) t39 = np.rint(t38) t40 = np.clip(t39, t0, t1) - t41 = t40.astype(np.int32) + t41 = t40.astype(np.int64) return t41 return function @@ -159,18 +159,18 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): id="127 // x", ), pytest.param( - lambda x: (x / 3).astype(np.uint8), + lambda x: (x / 3).astype(np.int64), { "x": {"status": "encrypted", "range": [0, 127]}, }, - id="(x / 3).astype(np.uint8)", + id="(x / 3).astype(np.int64)", ), pytest.param( - lambda x: (127 / x).astype(np.uint8), + lambda x: (127 / x).astype(np.int64), { "x": {"status": "encrypted", "range": [1, 127]}, }, - id="(127 / x).astype(np.uint8)", + id="(127 / x).astype(np.int64)", ), pytest.param( lambda x: x ** 2, @@ -383,18 +383,18 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): id="x.clip(5, 10)", ), pytest.param( - lambda x: (60 * np.sin(x)).astype(np.int8) + 60, + lambda x: (60 * np.sin(x)).astype(np.int64) + 60, { "x": {"status": "encrypted", "range": [0, 127]}, }, - id="(60 * np.sin(x)).astype(np.int8) + 60", + id="(60 * np.sin(x)).astype(np.int64) + 60", ), pytest.param( - lambda x: ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.uint8), + lambda x: ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64), { "x": {"status": "encrypted", "range": [0, 127]}, }, - id="((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.uint8)", + id="((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64)", ), pytest.param( lambda x: np.maximum(x, [[10, 20], [30, 40], [50, 60]]), @@ -485,7 +485,7 @@ def test_others_bad_fusing(helpers): @cnp.compiler({"x": "encrypted", "y": "clear"}, configuration=configuration) def function1(x, y): - return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.uint8) + return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = [(i, i) for i in range(100)] @@ -497,27 +497,27 @@ def test_others_bad_fusing(helpers): Function you are trying to compile cannot be converted to MLIR - %0 = 10 # ClearScalar - %1 = 10 # ClearScalar - %2 = 2 # ClearScalar - %3 = 2 # ClearScalar - %4 = x # EncryptedScalar - %5 = y # ClearScalar - %6 = sin(%4) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported - %7 = cos(%5) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported - %8 = power(%6, %2) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported - %9 = power(%7, %3) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%10 = multiply(%0, %8) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%11 = multiply(%1, %9) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%12 = add(%10, %11) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%13 = astype(%12, dtype=ubyte) # EncryptedScalar + %0 = 10 # ClearScalar + %1 = 10 # ClearScalar + %2 = 2 # ClearScalar + %3 = 2 # ClearScalar + %4 = x # EncryptedScalar + %5 = y # ClearScalar + %6 = sin(%4) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported + %7 = cos(%5) # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported + %8 = power(%6, %2) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported + %9 = power(%7, %3) # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%10 = multiply(%0, %8) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%11 = multiply(%1, %9) # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%12 = add(%10, %11) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%13 = astype(%12, dtype=int_) # EncryptedScalar return %13 """, # noqa: E501 @@ -530,7 +530,7 @@ return %13 @cnp.compiler({"x": "encrypted"}, configuration=configuration) def function2(x): - return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int8) + return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = range(100) @@ -548,7 +548,7 @@ Function you are trying to compile cannot be converted to MLIR ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported %3 = multiply(%2, %0) # EncryptedTensor ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%4 = astype(%3, dtype=byte) # EncryptedTensor +%4 = astype(%3, dtype=int_) # EncryptedTensor return %4 """, # noqa: E501 @@ -561,7 +561,7 @@ return %4 @cnp.compiler({"x": "encrypted"}, configuration=configuration) def function3(x): - return np.abs(np.sin(x)).reshape((2, 3)).astype(np.uint8) + return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = [np.random.randint(0, 2 ** 7, size=(3, 2)) for _ in range(100)] @@ -580,7 +580,7 @@ Function you are trying to compile cannot be converted to MLIR ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported %3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%4 = astype(%3, dtype=ubyte) # EncryptedTensor +%4 = astype(%3, dtype=int_) # EncryptedTensor return %4 """, # noqa: E501 diff --git a/tests/tracing/test_tracer.py b/tests/tracing/test_tracer.py index 7e06d65c3..82a3298c8 100644 --- a/tests/tracing/test_tracer.py +++ b/tests/tracing/test_tracer.py @@ -54,3 +54,96 @@ def test_tracer_bad_trace(function, parameters, expected_error, expected_message Tracer.trace(function, parameters) assert str(excinfo.value) == expected_message + + +@pytest.mark.parametrize( + "function,parameters,expected_message", + [ + pytest.param( + lambda x: x.astype(np.int8), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + pytest.param( + lambda x: x.astype(np.int16), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + pytest.param( + lambda x: x.astype(np.int32), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + pytest.param( + lambda x: x.astype(np.uint8), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + pytest.param( + lambda x: x.astype(np.uint16), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + pytest.param( + lambda x: x.astype(np.uint32), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + pytest.param( + lambda x: x.astype(np.uint64), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))}, + ( + "Warning: When using `value.astype(newtype)` " + "with an integer newtype, " + "only use `np.int64` as the newtype " + "to avoid unexpected overflows " + "during inputset evaluation" + ), + ), + ], +) +def test_tracer_warning_trace(function, parameters, expected_message, capsys): + """ + Test `trace` function of `Tracer` class with parameters that result in a warning. + """ + + Tracer.trace(function, parameters) + + captured = capsys.readouterr() + assert captured.out.strip() == expected_message