diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index b305467b5..cdc377657 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -757,13 +757,39 @@ class Context: def dot(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion: if x.is_clear and y.is_clear: - highlights = { + highlights: Dict[Node, Union[str, List[str]]] = { x.origin: "lhs is clear", y.origin: "rhs is clear" if x.origin is not y.origin else "operand is clear", self.converting: "but clear-clear dot products are not supported", } self.error(highlights) + if (x.is_encrypted and y.is_encrypted) and ( + x.bit_width > MAXIMUM_TLU_BIT_WIDTH or y.bit_width > MAXIMUM_TLU_BIT_WIDTH + ): + highlights = { + self.converting: [ + f"but only up to {MAXIMUM_TLU_BIT_WIDTH}-bit " + f"encrypted dot products are supported" + ], + } + + for operand in [x, y]: + if operand.bit_width > MAXIMUM_TLU_BIT_WIDTH: + highlights[operand.origin] = [ + f"this {operand.bit_width}-bit value " + f"is used as an operand to an encrypted dot products" + ] + if operand.bit_width != operand.original_bit_width: + highlights[operand.origin].append( # type: ignore + "(" + f"note that it's assigned {operand.bit_width}-bits " + f"during compilation because of its relation with other operations" + ")" + ) + + self.error(highlights) + assert self.is_bit_width_compatible(resulting_type, x, y) if x.is_scalar or y.is_scalar: @@ -1298,19 +1324,44 @@ class Context: def matmul(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion: if x.is_clear and y.is_clear: - highlights = { + highlights: Dict[Node, Union[str, List[str]]] = { x.origin: "lhs is clear", y.origin: "rhs is clear" if x.origin is not y.origin else "operand is clear", self.converting: "but clear-clear matrix multiplications are not supported", } self.error(highlights) + if (x.is_encrypted and y.is_encrypted) and ( + x.bit_width > MAXIMUM_TLU_BIT_WIDTH or y.bit_width > MAXIMUM_TLU_BIT_WIDTH + ): + highlights = { + self.converting: [ + f"but only up to {MAXIMUM_TLU_BIT_WIDTH}-bit " + f"encrypted matrix multiplications are supported" + ], + } + + for operand in [x, y]: + if operand.bit_width > MAXIMUM_TLU_BIT_WIDTH: + highlights[operand.origin] = [ + f"this {operand.bit_width}-bit value " + f"is used as an operand to an encrypted matrix multiplication" + ] + if operand.bit_width != operand.original_bit_width: + highlights[operand.origin].append( # type: ignore + "(" + f"note that it's assigned {operand.bit_width}-bits " + f"during compilation because of its relation with other operations" + ")" + ) + + self.error(highlights) + assert self.is_bit_width_compatible(resulting_type, x, y) if resulting_type.shape == (): if x.is_clear: x, y = y, x - operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot elif x.is_encrypted and y.is_encrypted: operation = fhelinalg.MatMulEintEintOp diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 578110f00..b8ac31f0d 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -615,6 +615,64 @@ Function you are trying to compile cannot be compiled (note that it's assigned 21-bits during compilation because of its relation with other operations) %2 = multiply(%0, %1) # EncryptedScalar ∈ [2000000, 2000000] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted multiplications are supported +return %2 + + """, # noqa: E501 + ), + pytest.param( + lambda x, y: np.dot(x, y), + {"x": "encrypted", "y": "encrypted"}, + [ + ( + [100_000, 200_000], + [200_000, 100_000], + ) + ], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # EncryptedTensor ∈ [100000, 200000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products + (note that it's assigned 36-bits during compilation because of its relation with other operations) +%1 = y # EncryptedTensor ∈ [100000, 200000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products + (note that it's assigned 36-bits during compilation because of its relation with other operations) +%2 = dot(%0, %1) # EncryptedScalar ∈ [40000000000, 40000000000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted dot products are supported +return %2 + + """, # noqa: E501 + ), + pytest.param( + lambda x, y: x @ y, + {"x": "encrypted", "y": "encrypted"}, + [ + ( + [ + [100_000, 200_000], + [200_000, 100_000], + ], + [ + [100_000, 200_000], + [200_000, 100_000], + ], + ) + ], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # EncryptedTensor ∈ [100000, 200000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication + (note that it's assigned 36-bits during compilation because of its relation with other operations) +%1 = y # EncryptedTensor ∈ [100000, 200000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication + (note that it's assigned 36-bits during compilation because of its relation with other operations) +%2 = matmul(%0, %1) # EncryptedTensor ∈ [40000000000, 50000000000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported return %2 """, # noqa: E501