mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(frontend-python): add missing error messages on high bit-width encrypted dot and matmul
This commit is contained in:
@@ -757,13 +757,39 @@ class Context:
|
||||
|
||||
def dot(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
|
||||
if x.is_clear and y.is_clear:
|
||||
highlights = {
|
||||
highlights: Dict[Node, Union[str, List[str]]] = {
|
||||
x.origin: "lhs is clear",
|
||||
y.origin: "rhs is clear" if x.origin is not y.origin else "operand is clear",
|
||||
self.converting: "but clear-clear dot products are not supported",
|
||||
}
|
||||
self.error(highlights)
|
||||
|
||||
if (x.is_encrypted and y.is_encrypted) and (
|
||||
x.bit_width > MAXIMUM_TLU_BIT_WIDTH or y.bit_width > MAXIMUM_TLU_BIT_WIDTH
|
||||
):
|
||||
highlights = {
|
||||
self.converting: [
|
||||
f"but only up to {MAXIMUM_TLU_BIT_WIDTH}-bit "
|
||||
f"encrypted dot products are supported"
|
||||
],
|
||||
}
|
||||
|
||||
for operand in [x, y]:
|
||||
if operand.bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
highlights[operand.origin] = [
|
||||
f"this {operand.bit_width}-bit value "
|
||||
f"is used as an operand to an encrypted dot products"
|
||||
]
|
||||
if operand.bit_width != operand.original_bit_width:
|
||||
highlights[operand.origin].append( # type: ignore
|
||||
"("
|
||||
f"note that it's assigned {operand.bit_width}-bits "
|
||||
f"during compilation because of its relation with other operations"
|
||||
")"
|
||||
)
|
||||
|
||||
self.error(highlights)
|
||||
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
|
||||
if x.is_scalar or y.is_scalar:
|
||||
@@ -1298,19 +1324,44 @@ class Context:
|
||||
|
||||
def matmul(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
|
||||
if x.is_clear and y.is_clear:
|
||||
highlights = {
|
||||
highlights: Dict[Node, Union[str, List[str]]] = {
|
||||
x.origin: "lhs is clear",
|
||||
y.origin: "rhs is clear" if x.origin is not y.origin else "operand is clear",
|
||||
self.converting: "but clear-clear matrix multiplications are not supported",
|
||||
}
|
||||
self.error(highlights)
|
||||
|
||||
if (x.is_encrypted and y.is_encrypted) and (
|
||||
x.bit_width > MAXIMUM_TLU_BIT_WIDTH or y.bit_width > MAXIMUM_TLU_BIT_WIDTH
|
||||
):
|
||||
highlights = {
|
||||
self.converting: [
|
||||
f"but only up to {MAXIMUM_TLU_BIT_WIDTH}-bit "
|
||||
f"encrypted matrix multiplications are supported"
|
||||
],
|
||||
}
|
||||
|
||||
for operand in [x, y]:
|
||||
if operand.bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
highlights[operand.origin] = [
|
||||
f"this {operand.bit_width}-bit value "
|
||||
f"is used as an operand to an encrypted matrix multiplication"
|
||||
]
|
||||
if operand.bit_width != operand.original_bit_width:
|
||||
highlights[operand.origin].append( # type: ignore
|
||||
"("
|
||||
f"note that it's assigned {operand.bit_width}-bits "
|
||||
f"during compilation because of its relation with other operations"
|
||||
")"
|
||||
)
|
||||
|
||||
self.error(highlights)
|
||||
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
|
||||
if resulting_type.shape == ():
|
||||
if x.is_clear:
|
||||
x, y = y, x
|
||||
|
||||
operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot
|
||||
elif x.is_encrypted and y.is_encrypted:
|
||||
operation = fhelinalg.MatMulEintEintOp
|
||||
|
||||
@@ -615,6 +615,64 @@ Function you are trying to compile cannot be compiled
|
||||
(note that it's assigned 21-bits during compilation because of its relation with other operations)
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint21> ∈ [2000000, 2000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted multiplications are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
[100_000, 200_000],
|
||||
[200_000, 100_000],
|
||||
)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint36> ∈ [40000000000, 40000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted dot products are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
[
|
||||
[100_000, 200_000],
|
||||
[200_000, 100_000],
|
||||
],
|
||||
[
|
||||
[100_000, 200_000],
|
||||
[200_000, 100_000],
|
||||
],
|
||||
)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint36, shape=(2, 2)> ∈ [40000000000, 50000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user