mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(frontend-python): re-write bit width assignment
This commit is contained in:
@@ -263,7 +263,7 @@ def test_round_bit_pattern_no_overflow_protection(helpers, pytestconfig):
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
|
||||
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
|
||||
%c2_i8 = arith.constant 2 : i8
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<11>
|
||||
return %1 : !FHE.eint<11>
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
from concrete import fhe
|
||||
from concrete.fhe.mlir import GraphConverter
|
||||
|
||||
from ..conftest import USE_MULTI_PRECISION
|
||||
|
||||
|
||||
def assign(x, y):
|
||||
"""
|
||||
@@ -607,6 +609,22 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
|
||||
(note that it's assigned 18-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedScalar<uint5> ∈ [20, 20]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
|
||||
(note that it's assigned 18-bits during compilation because of its relation with other operations)
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint21> ∈ [2000000, 2000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted multiplications are supported
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 21-bit value is used as an operand to an encrypted multiplication
|
||||
(note that it's assigned 21-bits during compilation because of its relation with other operations)
|
||||
@@ -633,6 +651,22 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint36> ∈ [40000000000, 40000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted dot products are supported
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
@@ -665,6 +699,22 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint36, shape=(2, 2)> ∈ [40000000000, 50000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
@@ -758,6 +808,184 @@ def test_converter_bad_convert(
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_mlir",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x2x!FHE.eint<6>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<4>>, tensor<2x!FHE.eint<4>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x!FHE.eint<7>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_converter_convert_multi_precision(function, parameters, expected_mlir, helpers):
|
||||
"""
|
||||
Test `convert` method of `Converter` with multi precision.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration().fork(single_precision=False)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_mlir",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<6>>, tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x2x!FHE.eint<6>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<7>>, tensor<2x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<7>>, tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x!FHE.eint<7>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_converter_convert_single_precision(function, parameters, expected_mlir, helpers):
|
||||
"""
|
||||
Test `convert` method of `Converter` with multi precision.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration().fork(single_precision=True)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_graph",
|
||||
[
|
||||
@@ -769,7 +997,7 @@ def test_converter_bad_convert(
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
|
||||
%1 = 2 # ClearScalar<uint5> ∈ [2, 2]
|
||||
%1 = 2 # ClearScalar<uint3> ∈ [2, 2]
|
||||
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
|
||||
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
|
||||
|
||||
Reference in New Issue
Block a user