feat(compiler): add matmul eint eint op

This commit is contained in:
Andrei Stoian
2023-04-12 17:33:14 +02:00
committed by Andrei Stoian
parent a5c679f0dc
commit 817ee6b637
23 changed files with 1522 additions and 235 deletions

View File

@@ -1,10 +1,21 @@
# Python Frontend
## Installation for end-users
End-users should install `concrete-python` using `pip`:
```shell
pip install concrete-python
```
## Setup for development
Developers that want to contribute to the Concrete-Python project can use the following
approach to setup their environment.
```shell
# clone the repository
git clone https://github.com/zama-ai/concrete.git
git clone https://github.com/zama-ai/concrete.git --recursive
cd concrete
# create virtual environment
@@ -19,6 +30,8 @@ cd ../../compilers/concrete-compiler/compiler
make python-bindings
# set bindings build directory as an environment variable
# *** NOTE ***: You must use the Release build of the compiler!
# For now, the Debug build is not compatible with concrete-python
export COMPILER_BUILD_DIRECTORY=$(pwd)/build
echo "export COMPILER_BUILD_DIRECTORY=$(pwd)/build" >> ~/.bashrc
@@ -26,3 +39,17 @@ echo "export COMPILER_BUILD_DIRECTORY=$(pwd)/build" >> ~/.bashrc
cd ../../../frontends/concrete-python
make pytest
```
### VSCode setup
Alternatively you can use VSCode to develop Concrete-Python:
Suppose the compiler bindings were built in `/home/zama/concrete/compilers/concrete-compiler/compiler/build`:
- Create a `.env` file in the concrete-python root directory
- Determine the absolute path of the local compiler repository, e.g. `/home/zama/concrete`. Replace this with your
path in the following two lines
- Add to it `PYTHONPATH=$(PYTHON_PATH):/home/zama/concrete/compilers/concrete-compiler/compiler/build/tools/concretelang/python_packages/concretelang_core/`
- Add to it `LD_PRELOAD=/home/zama/concrete/compilers/concrete-compiler/compiler/build/lib/libConcretelangRuntime.so`
You can now configure `pytest` in VScode and run the tests using the graphical interface.

View File

@@ -764,28 +764,40 @@ class Context:
}
self.error(highlights)
if x.is_encrypted and y.is_encrypted:
highlights = {
x.origin: "lhs is encrypted",
y.origin: (
"rhs is encrypted" if x.origin is not y.origin else "operand is encrypted"
),
self.converting: "but encrypted-encrypted dot products are not supported",
}
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
if x.is_scalar or y.is_scalar:
return self.mul(resulting_type, x, y)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot
if x.is_clear:
x, y = y, x
return self.operation(fhelinalg.Dot, resulting_type, x.result, y.result)
if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
x = self.to_signed(x)
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
)
)
intermediate_result = self.operation(
operation,
signed_resulting_type,
x.result,
y.result,
)
return self.to_unsigned(intermediate_result)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
return self.operation(operation, resulting_type, x.result, y.result)
def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
assert self.is_bit_width_compatible(resulting_type, x)
@@ -1293,29 +1305,41 @@ class Context:
}
self.error(highlights)
if x.is_encrypted and y.is_encrypted:
highlights = {
x.origin: "lhs is encrypted",
y.origin: (
"rhs is encrypted" if x.origin is not y.origin else "operand is encrypted"
),
self.converting: "but encrypted-encrypted matrix multiplications are not supported",
}
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
if resulting_type.shape == ():
if x.is_clear:
x, y = y, x
operation = fhelinalg.Dot
operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot
elif x.is_encrypted and y.is_encrypted:
operation = fhelinalg.MatMulEintEintOp
else:
operation = fhelinalg.MatMulEintIntOp if x.is_encrypted else fhelinalg.MatMulIntEintOp
if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
x = self.to_signed(x)
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
)
)
intermediate_result = self.operation(
operation,
signed_resulting_type,
x.result,
y.result,
)
return self.to_unsigned(intermediate_result)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
return self.operation(operation, resulting_type, x.result, y.result)
def maxpool2d(

View File

@@ -65,8 +65,8 @@ def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
MULTIPLY = {"multiply"}
ROUNDING = {"round_bit_pattern"}
MULTIPLY = {"multiply", "matmul"}
def max_encrypted_bitwidth_node(node: Node):
@@ -88,13 +88,38 @@ def max_encrypted_bitwidth_node(node: Node):
return normal_p + 1
if name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
return normal_p + 1
# For operations that use multiply, an additional bit
# needs to be added to the bitwidths of the inputs.
# For single precision circuits the max of the input / output
# precisions will be taken in required_encrypted_bitwidth. For
# multi-precision, the circuit partitions will handle the
# input and output precisions separately.
all_inp_bitwidths = []
# Need a loop here to allow typechecking and make mypy happy
for inp in node.inputs:
dtype_inp = inp.dtype
assert isinstance(dtype_inp, Integer)
all_inp_bitwidths.append(dtype_inp.bit_width)
normal_p = max(all_inp_bitwidths)
# FIXME: This probably does not work well with multi-precision!
return max(normal_p + 1, node.output.dtype.bit_width)
return normal_p
def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
"""Give the minimal precision to implement all the nodes."""
"""Give the minimal precision to implement all the nodes.
This function is called for both single-precision (for the whole circuit)
and for multi-precision circuits (for circuit partitions).
Ops for which the compiler introduces TLUs need to be handled explicitly
in `max_encrypted_bitwidth_node`. The maximum
of all precisions of the various operations is returned.
"""
bitwidths = map(max_encrypted_bitwidth_node, nodes)
return max(bitwidths, default=-1)

View File

@@ -23,7 +23,7 @@ def test_dot(size, helpers):
cst = np.random.randint(0, bound, size=(size,))
@fhe.compiler({"x": "encrypted"})
def left_function(x):
def dot_enc_enc_function(x):
return np.dot(x, cst)
@fhe.compiler({"x": "encrypted"})
@@ -36,12 +36,55 @@ def test_dot(size, helpers):
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
left_function_circuit = left_function.compile(inputset, configuration)
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
right_function_circuit = right_function.compile(inputset, configuration)
method_circuit = method.compile(inputset, configuration)
sample = np.random.randint(0, bound, size=(size,))
helpers.check_execution(left_function_circuit, left_function, sample)
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)
helpers.check_execution(right_function_circuit, right_function, sample)
helpers.check_execution(method_circuit, method, sample)
@pytest.mark.parametrize(
"size",
[1, 10],
)
@pytest.mark.parametrize(
"bitwidth",
[2, 6],
)
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("negative_only", [True, False])
def test_dot_enc_enc(size, bitwidth, negative_only, signed, helpers):
"""
Test dot.
"""
configuration = helpers.configuration()
minv = 0 if not signed else -(2 ** (bitwidth - 1))
# +1 since randint max is not inclusive
maxv = 2**bitwidth if not signed else 2 ** (bitwidth - 1)
if negative_only:
maxv = 1
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def dot_enc_enc_function(x, y):
return np.dot(x, y)
inputset = [
(np.random.randint(minv, maxv, size=(size,)), np.random.randint(minv, maxv, size=(size,)))
for i in range(100)
]
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
sample = [
np.random.randint(minv, maxv, size=(size,)),
np.random.randint(minv, maxv, size=(size,)),
]
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)

View File

@@ -16,6 +16,11 @@ from concrete import fhe
(2, 3),
(0, 3),
),
pytest.param(
(3, 2),
(2, 3),
(0, 127),
),
pytest.param(
(1, 2),
(2, 1),
@@ -46,6 +51,11 @@ from concrete import fhe
(5, 5),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(-127, 127),
),
pytest.param(
(5,),
(5, 3),
@@ -59,7 +69,7 @@ from concrete import fhe
pytest.param(
(5,),
(4, 5, 3),
(0, 5),
(-5, 5),
),
pytest.param(
(4, 5, 3),
@@ -74,7 +84,7 @@ from concrete import fhe
pytest.param(
(2, 4, 5, 3),
(3,),
(0, 5),
(-1, 5),
),
pytest.param(
(5, 4, 3),
@@ -156,3 +166,200 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
helpers.check_execution(rhs_operator_circuit, rhs_operator, rhs_sample)
helpers.check_execution(lhs_function_circuit, lhs_function, lhs_sample)
helpers.check_execution(rhs_function_circuit, rhs_function, rhs_sample)
@pytest.mark.parametrize(
"lhs_shape,rhs_shape,bounds",
[
pytest.param(
(3, 2),
(2, 3),
(0, 3),
),
pytest.param(
(3, 2),
(2, 3),
(0, 127),
),
pytest.param(
(1, 2),
(2, 1),
(0, 3),
),
pytest.param(
(3, 3),
(3, 3),
(0, 3),
),
pytest.param(
(2, 1),
(1, 2),
(0, 7),
),
pytest.param(
(2,),
(2,),
(0, 7),
),
pytest.param(
(5, 5),
(5,),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(-63, 63),
),
pytest.param(
(2,),
(2, 7),
(-63, 0),
),
pytest.param(
(5,),
(5, 3),
(0, 3),
),
pytest.param(
(5, 3),
(3,),
(0, 3),
),
pytest.param(
(5,),
(4, 5, 3),
(-5, 5),
),
pytest.param(
(4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5,),
(2, 4, 5, 3),
(0, 5),
),
pytest.param(
(2, 4, 5, 3),
(3,),
(-1, 5),
),
pytest.param(
(5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(2, 5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(1, 3, 2),
(0, 5),
),
pytest.param(
(1, 4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(2, 1, 3, 2),
(0, 5),
),
pytest.param(
(2, 1, 4, 3),
(5, 3, 2),
(0, 5),
),
],
)
def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
"""
Test matmul.
"""
configuration = helpers.configuration()
minimum, maximum = bounds
# Matmul of clear values and encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "clear"})
def lhs_operator_clear(x, y):
return x @ y
# Matmul of two encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def enc_function_xy(x, y):
return np.matmul(x, y)
# Put all the dual operand functions in a list
# FIXME: add lhs_operator_clear to this list to
# re-enable the testing with clear values
dual_operand_functions = [enc_function_xy]
# Compile each dual operand function and test it on random data
for func in dual_operand_functions:
dual_op_inputset = [
(
np.random.randint(minimum, maximum, size=lhs_shape),
np.random.randint(minimum, maximum, size=rhs_shape),
)
for i in range(100)
]
dual_op_circuit = func.compile(dual_op_inputset, configuration)
lhs_sample, rhs_sample = np.random.randint(
minimum, maximum, size=lhs_shape
), np.random.randint(minimum, maximum, size=rhs_shape)
helpers.check_execution(dual_op_circuit, func, [lhs_sample, rhs_sample])
@pytest.mark.parametrize("bitwidth", [4, 10])
@pytest.mark.parametrize("signed", [True, False])
def test_matmul_zero(bitwidth, signed, helpers):
"""
Test matmul.
"""
lhs_shape = (2, 1)
rhs_shape = (1, 2)
range_lhs = (-(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1) if signed else (0, 2**bitwidth)
configuration = helpers.configuration()
# Matmul of two encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def enc_function_xy(x, y):
return x * y
dual_op_inputset = [
(
np.random.randint(range_lhs[0], range_lhs[1], size=lhs_shape),
np.zeros(rhs_shape, dtype=np.int64),
)
for i in range(100)
]
dual_op_circuit = enc_function_xy.compile(dual_op_inputset, configuration)
lhs_sample, rhs_sample = np.random.randint(
range_lhs[0], range_lhs[1], size=lhs_shape
), np.zeros(rhs_shape, dtype=np.int64)
helpers.check_execution(dual_op_circuit, enc_function_xy, [lhs_sample, rhs_sample])

View File

@@ -152,6 +152,26 @@ def test_constant_mul(function, parameters, helpers):
"x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
},
{
"x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 0], "status": "encrypted", "shape": (3, 1)},
},
{
"x": {"range": [10, 20], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 0], "status": "encrypted", "shape": (1, 3)},
},
{
"x": {"range": [2**12, 2**13 - 1], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 0], "status": "encrypted", "shape": (1, 3)},
},
{
"x": {"range": [2**12, 2**13 - 1], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 2 * 3 - 1], "status": "encrypted", "shape": (1, 3)},
},
{
"x": {"range": [-(2**7), 2**7 - 1], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [-(2**7), 2**7 - 1], "status": "encrypted", "shape": (1, 3)},
},
],
)
def test_mul(function, parameters, helpers):

View File

@@ -359,30 +359,6 @@ return %3
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{"x": "encrypted", "y": "encrypted"},
[
(
np.ones(shape=(3,), dtype=np.int64),
np.ones(shape=(3,), dtype=np.int64),
)
],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = dot(%0, %1) # EncryptedScalar<uint2> ∈ [3, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted dot products are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "clear", "y": "clear"},
@@ -398,25 +374,6 @@ Function you are trying to compile cannot be compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = matmul(%0, %1) # ClearTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear matrix multiplications are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "encrypted", "y": "encrypted"},
[([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = matmul(%0, %1) # EncryptedTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted matrix multiplications are not supported
return %2
""", # noqa: E501