mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: update only int64 astype warning to ignore floating point types
This commit is contained in:
@@ -477,9 +477,10 @@ class Tracer:
|
||||
"""
|
||||
|
||||
normalized_dtype = np.dtype(dtype)
|
||||
if normalized_dtype != np.int64:
|
||||
if np.issubdtype(normalized_dtype, np.integer) and normalized_dtype != np.int64:
|
||||
print(
|
||||
"Warning: when using `value.astype(newtype)`, "
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
|
||||
@@ -14,7 +14,7 @@ def quantized_sin(x):
|
||||
# from a 7 bit unsigned integer x, compute z in the [0; 2 * pi] range
|
||||
z = 2 * numpy.pi * x * (1 / 127)
|
||||
# quantize over 6 bits and offset to be >= 0, round and convert to integers in range [0; 63]
|
||||
quantized_sin = numpy.rint(31 * numpy.sin(z) + 31).astype(numpy.int32)
|
||||
quantized_sin = numpy.rint(31 * numpy.sin(z) + 31).astype(numpy.int64)
|
||||
# output quantized_sin and a further offset result
|
||||
return quantized_sin, quantized_sin + 32
|
||||
```
|
||||
@@ -48,12 +48,12 @@ Here is an example benefiting from the expanded search:
|
||||
def fusable_with_bigger_search(x, y):
|
||||
"""fusable with bigger search"""
|
||||
x = x + 1
|
||||
x_1 = x.astype(numpy.int32)
|
||||
x_1 = x.astype(numpy.int64)
|
||||
x_1 = x_1 + 1.5
|
||||
x_2 = x.astype(numpy.int32)
|
||||
x_2 = x.astype(numpy.int64)
|
||||
x_2 = x_2 + 3.4
|
||||
add = x_1 + x_2
|
||||
add_int = add.astype(numpy.int32)
|
||||
add_int = add.astype(numpy.int64)
|
||||
return add_int + y
|
||||
```
|
||||
|
||||
@@ -79,7 +79,7 @@ def non_fusable(x, y):
|
||||
x_1 = x + 1.5 # x_1 is now float
|
||||
y_1 = y + 3.4 # y_1 is now float
|
||||
add = x_1 + y_1
|
||||
add_int = add.astype(numpy.int32)
|
||||
add_int = add.astype(numpy.int64)
|
||||
return add_int
|
||||
```
|
||||
|
||||
|
||||
@@ -139,4 +139,4 @@ The framework also gives support for:
|
||||
- true div, i.e., `x / y`: if one of `x` or `y` is a constant
|
||||
- floor div, i.e., `x // y`: if one of `x` or `y` is a constant
|
||||
|
||||
There is support for astype as well, e.g. `x.astype(numpy.int32)`. This allows to control which data type to use for computations. In the context of FHE going back to integers may allow to fuse floating point operations together, see [this tutorial](../tutorial/working_with_floating_points.md) to see how to work with floating point values.
|
||||
There is support for astype as well, e.g. `x.astype(numpy.int64)`. This allows to control which data type to use for computations. In the context of FHE going back to integers may allow to fuse floating point operations together, see [this tutorial](../tutorial/working_with_floating_points.md) to see how to work with floating point values.
|
||||
|
||||
@@ -123,7 +123,7 @@ artifacts = cnp.CompilationArtifacts("/tmp/custom/export/path")
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, artifacts=artifacts)
|
||||
def f(x):
|
||||
return 127 - (50 * (np.sin(x) + 1)).astype(np.uint32)
|
||||
return 127 - (50 * (np.sin(x) + 1)).astype(np.int64)
|
||||
|
||||
f.compile(range(2 ** 3))
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ Here is an example function that results in fused table lookup:
|
||||
<!--pytest-codeblocks:skip-->
|
||||
```python
|
||||
def f(x):
|
||||
return 127 - (50 * (np.sin(x) + 1)).astype(np.uint32) # astype is to go back to integer world
|
||||
return 127 - (50 * (np.sin(x) + 1)).astype(np.int64) # astype is to go back to integer world
|
||||
```
|
||||
|
||||
where
|
||||
@@ -125,5 +125,5 @@ which is calculated by:
|
||||
|
||||
<!--pytest-codeblocks:skip-->
|
||||
```python
|
||||
[(50 * (np.sin(x) + 1)).astype(np.uint32) for x in range(2 ** 3)]
|
||||
[(50 * (np.sin(x) + 1)).astype(np.int64) for x in range(2 ** 3)]
|
||||
```
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
# Function using floating points values converted back to integers at the end
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return np.fabs(50 * (2 * np.sin(x) * np.cos(x))).astype(np.uint32)
|
||||
return np.fabs(50 * (2 * np.sin(x) * np.cos(x))).astype(np.int64)
|
||||
# astype is to go back to the integer world
|
||||
|
||||
circuit = f.compile(range(64))
|
||||
|
||||
@@ -15,14 +15,14 @@ def fusable_with_bigger_search(x, y):
|
||||
|
||||
x = x + 1
|
||||
|
||||
x_1 = x.astype(np.int32)
|
||||
x_1 = x.astype(np.int64)
|
||||
x_1 = x_1 + 1.5
|
||||
|
||||
x_2 = x.astype(np.int32)
|
||||
x_2 = x.astype(np.int64)
|
||||
x_2 = x_2 + 3.4
|
||||
|
||||
add = x_1 + x_2
|
||||
add_int = add.astype(np.int32)
|
||||
add_int = add.astype(np.int64)
|
||||
|
||||
return add_int + y
|
||||
|
||||
@@ -36,17 +36,17 @@ def fusable_with_bigger_search_needs_second_iteration(x, y):
|
||||
x = x + 0.5
|
||||
x = np.cos(x)
|
||||
|
||||
x_1 = x.astype(np.int32)
|
||||
x_1 = x.astype(np.int64)
|
||||
x_1 = x_1 + 1.5
|
||||
|
||||
x_p = x + 1
|
||||
x_p2 = x_p + 1
|
||||
|
||||
x_2 = (x_p + x_p2).astype(np.int32)
|
||||
x_2 = (x_p + x_p2).astype(np.int64)
|
||||
x_2 = x_2 + 3.4
|
||||
|
||||
add = x_1 + x_2
|
||||
add_int = add.astype(np.int32)
|
||||
add_int = add.astype(np.int64)
|
||||
|
||||
return add_int + y
|
||||
|
||||
@@ -70,7 +70,7 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
t8 = np.add(t7, t2)
|
||||
t9 = np.rint(t8)
|
||||
t10 = np.clip(t9, t0, t1)
|
||||
t11 = t10.astype(np.int32)
|
||||
t11 = t10.astype(np.int64)
|
||||
return t11
|
||||
|
||||
def subgraph_24(x):
|
||||
@@ -133,7 +133,7 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
t38 = np.add(t37, t2)
|
||||
t39 = np.rint(t38)
|
||||
t40 = np.clip(t39, t0, t1)
|
||||
t41 = t40.astype(np.int32)
|
||||
t41 = t40.astype(np.int64)
|
||||
return t41
|
||||
|
||||
return function
|
||||
@@ -159,18 +159,18 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
id="127 // x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (x / 3).astype(np.uint8),
|
||||
lambda x: (x / 3).astype(np.int64),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 127]},
|
||||
},
|
||||
id="(x / 3).astype(np.uint8)",
|
||||
id="(x / 3).astype(np.int64)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (127 / x).astype(np.uint8),
|
||||
lambda x: (127 / x).astype(np.int64),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [1, 127]},
|
||||
},
|
||||
id="(127 / x).astype(np.uint8)",
|
||||
id="(127 / x).astype(np.int64)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x ** 2,
|
||||
@@ -383,18 +383,18 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
id="x.clip(5, 10)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (60 * np.sin(x)).astype(np.int8) + 60,
|
||||
lambda x: (60 * np.sin(x)).astype(np.int64) + 60,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 127]},
|
||||
},
|
||||
id="(60 * np.sin(x)).astype(np.int8) + 60",
|
||||
id="(60 * np.sin(x)).astype(np.int64) + 60",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.uint8),
|
||||
lambda x: ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 127]},
|
||||
},
|
||||
id="((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.uint8)",
|
||||
id="((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.maximum(x, [[10, 20], [30, 40], [50, 60]]),
|
||||
@@ -485,7 +485,7 @@ def test_others_bad_fusing(helpers):
|
||||
|
||||
@cnp.compiler({"x": "encrypted", "y": "clear"}, configuration=configuration)
|
||||
def function1(x, y):
|
||||
return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.uint8)
|
||||
return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [(i, i) for i in range(100)]
|
||||
@@ -497,27 +497,27 @@ def test_others_bad_fusing(helpers):
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = 10 # ClearScalar<uint4>
|
||||
%2 = 2 # ClearScalar<uint2>
|
||||
%3 = 2 # ClearScalar<uint2>
|
||||
%4 = x # EncryptedScalar<uint7>
|
||||
%5 = y # ClearScalar<uint7>
|
||||
%6 = sin(%4) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%7 = cos(%5) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%8 = power(%6, %2) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%9 = power(%7, %3) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%10 = multiply(%0, %8) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%11 = multiply(%1, %9) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%12 = add(%10, %11) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%13 = astype(%12, dtype=ubyte) # EncryptedScalar<uint4>
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = 10 # ClearScalar<uint4>
|
||||
%2 = 2 # ClearScalar<uint2>
|
||||
%3 = 2 # ClearScalar<uint2>
|
||||
%4 = x # EncryptedScalar<uint7>
|
||||
%5 = y # ClearScalar<uint7>
|
||||
%6 = sin(%4) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%7 = cos(%5) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%8 = power(%6, %2) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%9 = power(%7, %3) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%10 = multiply(%0, %8) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%11 = multiply(%1, %9) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%12 = add(%10, %11) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint4>
|
||||
return %13
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -530,7 +530,7 @@ return %13
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
def function2(x):
|
||||
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int8)
|
||||
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = range(100)
|
||||
@@ -548,7 +548,7 @@ Function you are trying to compile cannot be converted to MLIR
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%3 = multiply(%2, %0) # EncryptedTensor<float64, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%4 = astype(%3, dtype=byte) # EncryptedTensor<int3, shape=(2, 2)>
|
||||
%4 = astype(%3, dtype=int_) # EncryptedTensor<int3, shape=(2, 2)>
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -561,7 +561,7 @@ return %4
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
def function3(x):
|
||||
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.uint8)
|
||||
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [np.random.randint(0, 2 ** 7, size=(3, 2)) for _ in range(100)]
|
||||
@@ -580,7 +580,7 @@ Function you are trying to compile cannot be converted to MLIR
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor<float64, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%4 = astype(%3, dtype=ubyte) # EncryptedTensor<uint1, shape=(2, 3)>
|
||||
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 3)>
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
|
||||
@@ -54,3 +54,96 @@ def test_tracer_bad_trace(function, parameters, expected_error, expected_message
|
||||
Tracer.trace(function, parameters)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_message",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.int8),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.int16),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.int32),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.uint8),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.uint16),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.uint32),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(np.uint64),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(3, 2))},
|
||||
(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
"only use `np.int64` as the newtype "
|
||||
"to avoid unexpected overflows "
|
||||
"during inputset evaluation"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracer_warning_trace(function, parameters, expected_message, capsys):
|
||||
"""
|
||||
Test `trace` function of `Tracer` class with parameters that result in a warning.
|
||||
"""
|
||||
|
||||
Tracer.trace(function, parameters)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == expected_message
|
||||
|
||||
Reference in New Issue
Block a user