refactor(frontend-python): re-write bit width assignment

This commit is contained in:
Umut
2023-08-02 17:34:47 +02:00
parent 9a5b08938e
commit 46f3de63cc
6 changed files with 539 additions and 222 deletions

View File

@@ -263,7 +263,7 @@ def test_round_bit_pattern_no_overflow_protection(helpers, pytestconfig):
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
%c2_i8 = arith.constant 2 : i8
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<11>
return %1 : !FHE.eint<11>

View File

@@ -8,6 +8,8 @@ import pytest
from concrete import fhe
from concrete.fhe.mlir import GraphConverter
from ..conftest import USE_MULTI_PRECISION
def assign(x, y):
"""
@@ -607,6 +609,22 @@ return %2
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
(note that it's assigned 18-bits during compilation because of its relation with other operations)
%1 = y # EncryptedScalar<uint5> ∈ [20, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
(note that it's assigned 18-bits during compilation because of its relation with other operations)
%2 = multiply(%0, %1) # EncryptedScalar<uint21> ∈ [2000000, 2000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted multiplications are supported
return %2
""" # noqa: E501
if USE_MULTI_PRECISION
else """
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 21-bit value is used as an operand to an encrypted multiplication
(note that it's assigned 21-bits during compilation because of its relation with other operations)
@@ -633,6 +651,22 @@ return %2
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%1 = y # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%2 = dot(%0, %1) # EncryptedScalar<uint36> ∈ [40000000000, 40000000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted dot products are supported
return %2
""" # noqa: E501
if USE_MULTI_PRECISION
else """
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products
(note that it's assigned 36-bits during compilation because of its relation with other operations)
@@ -665,6 +699,22 @@ return %2
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%1 = y # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%2 = matmul(%0, %1) # EncryptedTensor<uint36, shape=(2, 2)> ∈ [40000000000, 50000000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported
return %2
""" # noqa: E501
if USE_MULTI_PRECISION
else """
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication
(note that it's assigned 36-bits during compilation because of its relation with other operations)
@@ -758,6 +808,184 @@ def test_converter_bad_convert(
helpers.check_str(expected_message, str(excinfo.value))
@pytest.mark.parametrize(
"function,parameters,expected_mlir",
[
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>>
return %0 : tensor<3x2x!FHE.eint<6>>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
},
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<4>>, tensor<2x!FHE.eint<4>>) -> !FHE.eint<7>
return %0 : !FHE.eint<7>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>>
return %0 : tensor<3x4x!FHE.eint<7>>
}
}
""", # noqa: E501
),
],
)
def test_converter_convert_multi_precision(function, parameters, expected_mlir, helpers):
"""
Test `convert` method of `Converter` with multi precision.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration().fork(single_precision=False)
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
@pytest.mark.parametrize(
"function,parameters,expected_mlir",
[
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<6>>, tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>>
return %0 : tensor<3x2x!FHE.eint<6>>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
},
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<7>>, tensor<2x!FHE.eint<7>>) -> !FHE.eint<7>
return %0 : !FHE.eint<7>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<7>>, tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
return %0 : tensor<3x4x!FHE.eint<7>>
}
}
""", # noqa: E501
),
],
)
def test_converter_convert_single_precision(function, parameters, expected_mlir, helpers):
"""
Test `convert` method of `Converter` with multi precision.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration().fork(single_precision=True)
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
@pytest.mark.parametrize(
"function,parameters,expected_graph",
[
@@ -769,7 +997,7 @@ def test_converter_bad_convert(
"""
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
%1 = 2 # ClearScalar<uint5> ∈ [2, 2]
%1 = 2 # ClearScalar<uint3> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]