mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(mlir): implement mlir conversion of basic tensor operations
This commit is contained in:
@@ -71,18 +71,6 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def value_is_scalar(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a scalar.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a scalar
|
||||
"""
|
||||
return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar
|
||||
|
||||
|
||||
def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer.
|
||||
|
||||
@@ -112,6 +100,23 @@ def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is encrypted and is of type unsigned Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is encrypted and is of type unsigned Integer
|
||||
"""
|
||||
|
||||
return (
|
||||
value_to_check.is_encrypted
|
||||
and isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
and not cast(Integer, value_to_check.dtype).is_signed
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted TensorValue of type Integer.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""This file contains a wrapper class for direct table lookups."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Iterable, Tuple, Union
|
||||
from typing import Any, Iterable, List, Tuple, Union
|
||||
|
||||
from ..common_helpers import is_a_power_of_2
|
||||
from ..data_types.base import BaseDataType
|
||||
@@ -30,7 +30,7 @@ class LookupTable:
|
||||
self.table = table
|
||||
self.output_dtype = make_integer_to_hold(table, force_signed=False)
|
||||
|
||||
def __getitem__(self, key: Union[int, BaseTracer]):
|
||||
def __getitem__(self, key: Union[int, Iterable, BaseTracer]):
|
||||
# if a tracer is used for indexing,
|
||||
# we need to create an `GenericFunction` node
|
||||
# because the result will be determined during the runtime
|
||||
@@ -58,11 +58,58 @@ class LookupTable:
|
||||
return LookupTable._checked_indexing(key, self.table)
|
||||
|
||||
@staticmethod
|
||||
def _checked_indexing(x, table):
|
||||
def _check_index_out_of_range(x, table):
|
||||
if x < 0 or x >= len(table):
|
||||
raise ValueError(
|
||||
f"Lookup table with {len(table)} entries cannot be indexed with {x} "
|
||||
f"(you should check your inputset)",
|
||||
)
|
||||
|
||||
return table[x]
|
||||
@staticmethod
|
||||
def _checked_indexing(x, table):
|
||||
"""Index `table` using `x`.
|
||||
|
||||
There is a single table and the indexing works with the following semantics:
|
||||
- when x == c
|
||||
- table[x] == table[c]
|
||||
- when x == [c1, c2]
|
||||
- table[x] == [table[c1], table[c2]]
|
||||
- when x == [[c1, c2], [c3, c4], [c5, c6]]
|
||||
- table[x] == [[table[c1], table[c2]], [table[c3], table[c4]], [table[c5], table[c6]]]
|
||||
|
||||
Args:
|
||||
x (Union[int, Iterable]): index to use
|
||||
table (Tuple[int, ...]): table to index
|
||||
|
||||
Returns:
|
||||
Union[int, List[int]]: result of indexing
|
||||
"""
|
||||
|
||||
if not isinstance(x, Iterable):
|
||||
LookupTable._check_index_out_of_range(x, table)
|
||||
return table[x]
|
||||
|
||||
def fill_result(partial_result: List[Any], partial_x: Iterable[Any]):
|
||||
"""Fill partial result with partial x.
|
||||
|
||||
This function implements the recursive indexing of nested iterables.
|
||||
|
||||
Args:
|
||||
partial_result (List[Any]): currently accumulated result
|
||||
partial_x (Iterable[Any]): current index to use
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
for item in partial_x:
|
||||
if isinstance(item, Iterable):
|
||||
partial_result.append([])
|
||||
fill_result(partial_result[-1], item)
|
||||
else:
|
||||
LookupTable._check_index_out_of_range(item, table)
|
||||
partial_result.append(table[item])
|
||||
|
||||
result = []
|
||||
fill_result(result, x)
|
||||
return result
|
||||
|
||||
@@ -12,14 +12,15 @@ from typing import cast
|
||||
import numpy
|
||||
from mlir.dialects import arith as arith_dialect
|
||||
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
|
||||
from zamalang.dialects import hlfhe
|
||||
from zamalang.dialects import hlfhe, hlfhelinalg
|
||||
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_scalar_integer,
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
value_is_scalar_integer,
|
||||
value_is_tensor_integer,
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import assert_true
|
||||
@@ -27,26 +28,71 @@ from ..representation.intermediate import Add, Constant, Dot, GenericFunction, M
|
||||
from ..values import TensorValue
|
||||
|
||||
|
||||
def _convert_scalar_constant_op_to_single_element_tensor_constant_op(operation):
|
||||
"""Convert a scalar constant operation result to a dense tensor constant operation result.
|
||||
|
||||
see https://github.com/zama-ai/concretefhe-internal/issues/837.
|
||||
|
||||
This is a temporary workaround before the compiler natively supports
|
||||
`tensor + scalar`, `tensor - scalar`, `tensor * scalar` operations.
|
||||
|
||||
Example input = `%c3_i4 = arith.constant 3 : i4`
|
||||
Example output = `%cst = arith.constant dense<3> : tensor<1xi4>`
|
||||
|
||||
Args:
|
||||
operation: operation to convert
|
||||
|
||||
Returns:
|
||||
the converted operation
|
||||
"""
|
||||
|
||||
operation_str = str(operation)
|
||||
|
||||
constant_start_location = operation_str.find("arith.constant") + len("arith.constant") + 1
|
||||
constant_end_location = operation_str.find(f": {str(operation.type)}") - 1
|
||||
constant_value = operation_str[constant_start_location:constant_end_location]
|
||||
|
||||
resulting_type = RankedTensorType.get((1,), operation.type)
|
||||
value_attr = Attribute.parse(f"dense<{constant_value}> : tensor<1x{str(operation.type)}>")
|
||||
|
||||
return arith_dialect.ConstantOp(resulting_type, value_attr).result
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert an addition intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "addition should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "addition should have a single output")
|
||||
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _add_eint_int(node, preds, ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer(
|
||||
node.inputs[0]
|
||||
):
|
||||
# flip lhs and rhs
|
||||
return _add_eint_int(node, preds[::-1], ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _add_eint_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
)
|
||||
|
||||
is_convertible = True
|
||||
one_of_the_inputs_is_a_tensor = False
|
||||
both_of_the_inputs_are_encrypted = True
|
||||
ordered_preds = preds
|
||||
|
||||
for input_ in node.inputs:
|
||||
if value_is_tensor_integer(input_):
|
||||
one_of_the_inputs_is_a_tensor = True
|
||||
elif not value_is_scalar_integer(input_):
|
||||
is_convertible = False
|
||||
|
||||
if not is_convertible:
|
||||
raise TypeError(
|
||||
f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
)
|
||||
|
||||
if node.inputs[1].is_clear:
|
||||
both_of_the_inputs_are_encrypted = False
|
||||
if node.inputs[0].is_clear:
|
||||
both_of_the_inputs_are_encrypted = False
|
||||
ordered_preds = preds[::-1]
|
||||
|
||||
if one_of_the_inputs_is_a_tensor:
|
||||
if both_of_the_inputs_are_encrypted:
|
||||
return _linalg_add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx)
|
||||
return _linalg_add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
|
||||
|
||||
if both_of_the_inputs_are_encrypted:
|
||||
return _add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx)
|
||||
return _add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
|
||||
|
||||
|
||||
def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
@@ -61,7 +107,7 @@ def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert an addition intermediate node with (eint, int)."""
|
||||
"""Convert an addition intermediate node with (eint, eint)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.AddEintOp(
|
||||
@@ -71,17 +117,57 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def _linalg_add_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert an addition intermediate tensor node with (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
|
||||
if not str(rhs.type).startswith("tensor"):
|
||||
rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs)
|
||||
|
||||
int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
|
||||
vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
|
||||
|
||||
return hlfhelinalg.AddEintIntOp(vec_type, lhs, rhs).result
|
||||
|
||||
|
||||
def _linalg_add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert an addition intermediate tensor node with (eint, eint)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
|
||||
int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
|
||||
vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
|
||||
|
||||
return hlfhelinalg.AddEintOp(vec_type, lhs, rhs).result
|
||||
|
||||
|
||||
def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a subtraction intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "subtraction should have a single output")
|
||||
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support subtraction between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
)
|
||||
|
||||
is_convertible = True
|
||||
one_of_the_inputs_is_a_tensor = False
|
||||
|
||||
if value_is_clear_tensor_integer(node.inputs[0]):
|
||||
one_of_the_inputs_is_a_tensor = True
|
||||
elif not value_is_clear_scalar_integer(node.inputs[0]):
|
||||
is_convertible = False
|
||||
|
||||
if value_is_tensor_integer(node.inputs[1]):
|
||||
one_of_the_inputs_is_a_tensor = True
|
||||
elif not value_is_scalar_integer(node.inputs[1]):
|
||||
is_convertible = False
|
||||
|
||||
if not is_convertible:
|
||||
raise TypeError(
|
||||
f"Don't support subtraction between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
)
|
||||
|
||||
if one_of_the_inputs_is_a_tensor:
|
||||
return _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
|
||||
|
||||
def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
@@ -95,22 +181,46 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert a subtraction intermediate node with (int, eint)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
|
||||
if not str(lhs.type).startswith("tensor"):
|
||||
lhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(lhs)
|
||||
|
||||
int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
|
||||
vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
|
||||
|
||||
return hlfhelinalg.SubIntEintOp(vec_type, lhs, rhs).result
|
||||
|
||||
|
||||
def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a multiplication intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "multiplication should have a single output")
|
||||
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _mul_eint_int(node, preds, ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer(
|
||||
node.inputs[0]
|
||||
):
|
||||
# flip lhs and rhs
|
||||
return _mul_eint_int(node, preds[::-1], ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support multiplication between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
)
|
||||
|
||||
is_convertible = True
|
||||
one_of_the_inputs_is_a_tensor = False
|
||||
ordered_preds = preds
|
||||
|
||||
for input_ in node.inputs:
|
||||
if value_is_tensor_integer(input_):
|
||||
one_of_the_inputs_is_a_tensor = True
|
||||
elif not value_is_scalar_integer(input_):
|
||||
is_convertible = False
|
||||
|
||||
if not is_convertible:
|
||||
raise TypeError(
|
||||
f"Don't support multiplication between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
)
|
||||
|
||||
if node.inputs[0].is_clear:
|
||||
ordered_preds = preds[::-1]
|
||||
|
||||
if one_of_the_inputs_is_a_tensor:
|
||||
return _linalg_mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
|
||||
return _mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
|
||||
|
||||
|
||||
def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
@@ -124,6 +234,20 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def _linalg_mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert a subtraction intermediate node with (int, eint)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
|
||||
if not str(rhs.type).startswith("tensor"):
|
||||
rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs)
|
||||
|
||||
int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
|
||||
vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
|
||||
|
||||
return hlfhelinalg.MulEintIntOp(vec_type, lhs, rhs).result
|
||||
|
||||
|
||||
def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a constant input."""
|
||||
value = node.outputs[0]
|
||||
@@ -177,12 +301,12 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
|
||||
variable_input_value = node.inputs[variable_input_idx]
|
||||
|
||||
assert_true(len(node.outputs) == 1, "LUT should have a single output")
|
||||
if not value_is_encrypted_scalar_unsigned_integer(variable_input_value):
|
||||
if not value_is_encrypted_unsigned_integer(variable_input_value):
|
||||
raise TypeError(
|
||||
f"Only support LUT with encrypted unsigned integers inputs "
|
||||
f"(but {variable_input_value} is provided)"
|
||||
)
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
|
||||
if not value_is_encrypted_unsigned_integer(node.outputs[0]):
|
||||
raise TypeError(
|
||||
f"Only support LUT with encrypted unsigned integers outputs "
|
||||
f"(but {node.outputs[0]} is provided)"
|
||||
@@ -209,11 +333,13 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
|
||||
RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
|
||||
dense_elem,
|
||||
).result
|
||||
return hlfhe.ApplyLookupTableEintOp(
|
||||
hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width),
|
||||
x,
|
||||
tensor_lut,
|
||||
).result
|
||||
|
||||
int_type = hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width)
|
||||
|
||||
if value_is_encrypted_tensor_integer(node.inputs[0]):
|
||||
vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
|
||||
return hlfhelinalg.ApplyLookupTableEintOp(vec_type, x, tensor_lut).result
|
||||
return hlfhe.ApplyLookupTableEintOp(int_type, x, tensor_lut).result
|
||||
|
||||
|
||||
def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
|
||||
@@ -10,7 +10,6 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_integer,
|
||||
value_is_scalar,
|
||||
value_is_unsigned_integer,
|
||||
)
|
||||
from ..debugging import format_operation_graph
|
||||
@@ -46,18 +45,18 @@ def check_node_compatibility_with_mlir(
|
||||
|
||||
if isinstance(node, intermediate.Add): # constraints for addition
|
||||
for inp in inputs:
|
||||
if not value_is_scalar(inp):
|
||||
return "only scalar addition is supported"
|
||||
if not value_is_integer(inp):
|
||||
return "only integer addition is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Sub): # constraints for subtraction
|
||||
for inp in inputs:
|
||||
if not value_is_scalar(inp):
|
||||
return "only scalar subtraction is supported"
|
||||
if not value_is_integer(inp):
|
||||
return "only integer subtraction is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Mul): # constraints for multiplication
|
||||
for inp in inputs:
|
||||
if not value_is_scalar(inp):
|
||||
return "only scalar multiplication is supported"
|
||||
if not value_is_integer(inp):
|
||||
return "only integer multiplication is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Input): # constraints for inputs
|
||||
assert_true(len(outputs) == 1)
|
||||
@@ -85,8 +84,14 @@ def check_node_compatibility_with_mlir(
|
||||
)
|
||||
if node.op_name == "MultiTLU":
|
||||
return "direct multi table lookup is not supported for the time being"
|
||||
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
|
||||
return "only unsigned integer scalar lookup tables are supported"
|
||||
|
||||
if not value_is_unsigned_integer(inputs[0]):
|
||||
# this branch is not reachable because compilation fails during inputset evaluation
|
||||
if node.op_name == "TLU": # pragma: no cover
|
||||
return "only unsigned integer lookup tables are supported"
|
||||
|
||||
# e.g., `np.absolute is not supported for the time being`
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
else:
|
||||
return (
|
||||
f"{node.op_name} of kind {node.op_kind.value} is not supported for the time being"
|
||||
@@ -109,8 +114,8 @@ def check_node_compatibility_with_mlir(
|
||||
|
||||
if is_output:
|
||||
for out in outputs:
|
||||
if not value_is_scalar(out) or not value_is_unsigned_integer(out):
|
||||
return "only scalar unsigned integer outputs are supported"
|
||||
if not value_is_unsigned_integer(out):
|
||||
return "only unsigned integer outputs are supported"
|
||||
else:
|
||||
for out in outputs:
|
||||
# We currently can't fail on the following assert, but let it for possible changes in
|
||||
|
||||
@@ -22,8 +22,8 @@ def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
x_intermediate = x + 2.8
|
||||
y_intermediate = y + 9.3
|
||||
intermediate = x_intermediate + y_intermediate
|
||||
return intermediate.astype(numpy.int32)
|
||||
intermediate = x_intermediate - y_intermediate
|
||||
return (intermediate * 1.5).astype(numpy.int32)
|
||||
|
||||
|
||||
def identity_lut_generator(n):
|
||||
@@ -540,6 +540,228 @@ def test_compile_and_run_correctness(
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 8],
|
||||
[7, 2],
|
||||
[3, 6],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 7],
|
||||
[8, 1],
|
||||
[5, 6],
|
||||
],
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars can be converted to (1,) shaped tensors
|
||||
# this is easy with known constants but weird with variable things such as another input
|
||||
# there is tensor.from_elements but I coudn't figure out how to use it in the python API
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
"y": EncryptedScalar(UnsignedInteger(3)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
random.randint(0, (2 ** 3) - 1),
|
||||
)
|
||||
for _ in range(10)
|
||||
],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
2,
|
||||
),
|
||||
[
|
||||
[2, 9],
|
||||
[8, 3],
|
||||
[4, 7],
|
||||
],
|
||||
marks=pytest.mark.xfail(),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
"y": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
)
|
||||
for _ in range(10)
|
||||
],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
[
|
||||
[1, 6],
|
||||
[2, 5],
|
||||
[3, 4],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 13],
|
||||
[8, 6],
|
||||
[5, 9],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 100 - x,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[100, 93],
|
||||
[94, 99],
|
||||
[98, 95],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[10, 8],
|
||||
[14, 14],
|
||||
[8, 25],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 2,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[0, 14],
|
||||
[12, 2],
|
||||
[4, 10],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[4, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[4, 14],
|
||||
[12, 1],
|
||||
[6, 5],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: LookupTable([2, 1, 3, 0])[x],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 1],
|
||||
[2, 1],
|
||||
[3, 0],
|
||||
],
|
||||
),
|
||||
[
|
||||
[2, 1],
|
||||
[3, 1],
|
||||
[0, 2],
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_tensor_correctness(
|
||||
function, parameters, inputset, test_input, expected_output, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
numpy_test_input = (numpy.array(item, dtype=numpy.uint8) for item in test_input)
|
||||
assert numpy.array_equal(
|
||||
circuit.run(*numpy_test_input),
|
||||
numpy.array(expected_output, dtype=numpy.uint8),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size, input_range",
|
||||
[
|
||||
@@ -769,71 +991,7 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>
|
||||
%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 8, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%2 = Mul(%0, %1) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 127 - x,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = Constant(127) # ClearScalar<Integer<unsigned, 7 bits>>
|
||||
%1 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%2 = Sub(%0, %1) # EncryptedTensor<Integer<unsigned, 7 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
@@ -892,22 +1050,27 @@ return(%1)
|
||||
[(numpy.array(i), numpy.array(i)) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering\n
|
||||
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%1 = Constant(2.8) # ClearScalar<Float<64 bits>>
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = Constant(1.5) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%3 = Constant(9.3) # ClearScalar<Float<64 bits>>
|
||||
%1 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%2 = Constant(2.8) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
|
||||
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
|
||||
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
|
||||
%7 = astype(int32)(%6) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported
|
||||
return(%7)
|
||||
%3 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%4 = Constant(9.3) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%5 = Add(%1, %2) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%6 = Add(%3, %4) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%7 = Sub(%5, %6) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
|
||||
%8 = Mul(%7, %0) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
|
||||
%9 = astype(int32)(%8) # EncryptedScalar<Integer<signed, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype(int32) is not supported for the time being
|
||||
return(%9)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
@@ -1057,7 +1220,7 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>
|
||||
%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.absolute is not supported for the time being
|
||||
%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%7)
|
||||
@@ -1255,7 +1418,7 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user