mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: support TLUs with signed input
- this is done with an absolute ugly hack closes #923 refs #1001
This commit is contained in:
@@ -270,7 +270,7 @@ class IntermediateNodeConverter:
|
||||
value = self.node.inputs[variable_input_index]
|
||||
assert_true(value.is_encrypted)
|
||||
|
||||
if not isinstance(value.dtype, Integer) or value.dtype.is_signed: # pragma: no cover
|
||||
if not isinstance(value.dtype, Integer): # pragma: no cover
|
||||
# this branch is not covered as it's impossible to get into due to how compilation works
|
||||
# however, it doesn't hurt to keep it as an extra measure
|
||||
raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")
|
||||
|
||||
@@ -89,17 +89,6 @@ def check_node_compatibility_with_mlir(
|
||||
)
|
||||
== 1
|
||||
)
|
||||
|
||||
if not value_is_unsigned_integer(inputs[0]):
|
||||
# this branch is not reachable because compilation fails during inputset evaluation
|
||||
if node.op_name == "TLU": # pragma: no cover
|
||||
return "only unsigned integer lookup tables are supported"
|
||||
|
||||
if node.op_name == "MultiTLU": # pragma: no cover
|
||||
return "only unsigned integer multi lookup tables are supported"
|
||||
|
||||
# e.g., `np.absolute is not supported for the time being`
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
else:
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
|
||||
|
||||
@@ -370,8 +370,8 @@ class GenericFunction(IntermediateNode):
|
||||
def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
|
||||
"""Get the table for the current input value of this GenericFunction.
|
||||
|
||||
This function only works if the GenericFunction variable input value is an unsigned Integer.
|
||||
It only works if there is a single variable input node among ordered_preds.
|
||||
This function only works if the GenericFunction variable input value is an Integer.
|
||||
This function only works if there is a single variable input node among ordered_preds.
|
||||
|
||||
Args:
|
||||
ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must
|
||||
@@ -393,16 +393,12 @@ class GenericFunction(IntermediateNode):
|
||||
|
||||
variable_input_idx = variable_input_indices[0]
|
||||
variable_input_dtype = self.inputs[variable_input_idx].dtype
|
||||
# Check the input is an unsigned integer to be able to build a table
|
||||
# Check the input is an integer to be able to build a table
|
||||
assert_true(
|
||||
isinstance(variable_input_dtype, Integer),
|
||||
f"{self.get_table.__name__} only works for an unsigned Integer input",
|
||||
)
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
assert_true(
|
||||
not variable_input_dtype.is_signed,
|
||||
f"{self.get_table.__name__} only works for an unsigned Integer input",
|
||||
)
|
||||
|
||||
input_value_constructor = self.inputs[0].underlying_constructor
|
||||
if input_value_constructor is None:
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy
|
||||
from zamalang import CompilerEngine
|
||||
@@ -21,7 +22,8 @@ from ..common.mlir.utils import (
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.optimization.topological import fuse_float_operations
|
||||
from ..common.values import BaseValue
|
||||
from ..common.representation.intermediate import Add, Constant, GenericFunction
|
||||
from ..common.values import BaseValue, ClearScalar
|
||||
from ..numpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import (
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
@@ -311,7 +313,66 @@ def compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
raise
|
||||
|
||||
|
||||
def prepare_op_graph_for_mlir(op_graph):
|
||||
# HACK
|
||||
# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001 is
|
||||
# done
|
||||
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1015
|
||||
def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None:
|
||||
"""Hack the op_graph to add offsets to signed inputs to TLUs.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph to hack.
|
||||
"""
|
||||
# Ugly hack to add an offset before entering a TLU if its variable input node has a signed
|
||||
# output.
|
||||
# This is ugly as this makes hardcoded assumptions about the way bit widths are handled in MLIR.
|
||||
# This does not update the TLU input values to allow for proper table generation.
|
||||
# Thankfully we are not supposed to touch the op_graph beyond that point
|
||||
for node in list((nx_graph := op_graph.graph).nodes):
|
||||
if isinstance(node, GenericFunction):
|
||||
ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node)
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, (pred, _) in enumerate(ordered_preds_and_inputs)
|
||||
if not isinstance(pred, Constant)
|
||||
]
|
||||
assert_true(len(variable_input_indices) == 1)
|
||||
variable_input_idx = variable_input_indices[0]
|
||||
variable_input_node = ordered_preds_and_inputs[variable_input_idx][0]
|
||||
variable_input_value = variable_input_node.outputs[0]
|
||||
variable_input_dtype = variable_input_value.dtype
|
||||
assert_true(isinstance(variable_input_dtype, Integer))
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
if not variable_input_dtype.is_signed:
|
||||
continue
|
||||
|
||||
# input_bit_width + 1 to be MLIR compliant
|
||||
input_bit_width = variable_input_dtype.bit_width
|
||||
mlir_compliant_int_type = Integer(input_bit_width + 1, True)
|
||||
|
||||
# Manually fix the output values to be MLIR compliant
|
||||
# offset_constant is set to abs(min_value) for the variable input so that the values
|
||||
# [- 2 ** (n - 1); 2 ** (n - 1) - 1] is mapped to [0; 2 ** n - 1], changing the signed
|
||||
# TLU to an actual unsigned TLU. The get_table function creates the table from the min
|
||||
# value to the max value. As we keep the input value as a signed value, it will be from
|
||||
# - 2 ** (n - 1) to 2 ** (n - 1) - 1. Then, the get_table function stores corresponding
|
||||
# values in increasing indexes from 0 to 2 ** n - 1. As our signed values have been
|
||||
# shifted by 2 ** (n - 1), the table will be usable as-is, without needing any change in
|
||||
# the lambda function of the GenericFunction.
|
||||
offset_constant = Constant(abs(variable_input_dtype.min_value()))
|
||||
offset_constant.outputs[0].dtype = deepcopy(mlir_compliant_int_type)
|
||||
add_offset = Add(
|
||||
[deepcopy(variable_input_value), ClearScalar(deepcopy(mlir_compliant_int_type))]
|
||||
)
|
||||
add_offset.outputs[0] = deepcopy(variable_input_value)
|
||||
|
||||
nx_graph.remove_edge(variable_input_node, node)
|
||||
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0, output_idx=0)
|
||||
nx_graph.add_edge(offset_constant, add_offset, input_idx=1, output_idx=0)
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_idx, output_idx=0)
|
||||
|
||||
|
||||
def prepare_op_graph_for_mlir(op_graph: OPGraph):
|
||||
"""Prepare OPGraph for MLIR lowering.
|
||||
|
||||
This includes checking compatibility, changing bit-widths, and modifying lookup tables.
|
||||
@@ -337,6 +398,12 @@ def prepare_op_graph_for_mlir(op_graph):
|
||||
# TODO: workaround extend LUT #359
|
||||
extend_direct_lookup_tables(op_graph)
|
||||
|
||||
# HACK
|
||||
# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001
|
||||
# is done
|
||||
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1015
|
||||
hack_offset_negative_inputs_to_lookup_tables(op_graph)
|
||||
|
||||
|
||||
def _compile_numpy_function_internal(
|
||||
function_to_compile: Callable,
|
||||
|
||||
@@ -267,8 +267,8 @@ def check_is_good_execution(compiler_engine, function, args, verbose=True):
|
||||
# Bad computation after nb_tries
|
||||
raise AssertionError(
|
||||
f"bad computation after {nb_tries} tries, which was supposed to happen with a "
|
||||
f"probability of {expected_bad_luck}.\nLast engine result: {last_engine_result} "
|
||||
f"last function result: {last_function_result}"
|
||||
f"probability of {expected_bad_luck}.\nLast engine result:\n{last_engine_result}\n"
|
||||
f"Last function result:\n{last_function_result}"
|
||||
)
|
||||
|
||||
|
||||
@@ -725,7 +725,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
"function,parameters,inputset,test_input,use_check_good_exec",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
@@ -740,11 +740,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 8],
|
||||
[7, 2],
|
||||
[3, 6],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
|
||||
@@ -759,11 +755,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 7],
|
||||
[8, 1],
|
||||
[5, 6],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
@@ -786,11 +778,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
],
|
||||
2,
|
||||
),
|
||||
[
|
||||
[2, 9],
|
||||
[8, 3],
|
||||
[4, 7],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
@@ -817,11 +805,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[3, 4],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 13],
|
||||
[8, 6],
|
||||
[5, 9],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 100 - x,
|
||||
@@ -836,11 +820,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[100, 93],
|
||||
[94, 99],
|
||||
[98, 95],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
|
||||
@@ -855,11 +835,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[10, 8],
|
||||
[14, 14],
|
||||
[8, 25],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 2,
|
||||
@@ -874,11 +850,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[0, 14],
|
||||
[12, 2],
|
||||
[4, 10],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
|
||||
@@ -893,11 +865,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[4, 14],
|
||||
[12, 1],
|
||||
[6, 5],
|
||||
],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: LookupTable([2, 1, 3, 0])[x],
|
||||
@@ -912,11 +880,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[3, 0],
|
||||
],
|
||||
),
|
||||
[
|
||||
[2, 1],
|
||||
[3, 1],
|
||||
[0, 2],
|
||||
],
|
||||
True,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.dot(x, 2),
|
||||
@@ -925,7 +889,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
([2, 7, 1],),
|
||||
[4, 14, 2],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.dot(2, x),
|
||||
@@ -934,7 +898,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
([2, 7, 1],),
|
||||
[4, 14, 2],
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.clip(x, 1, 5),
|
||||
@@ -949,11 +913,22 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 5],
|
||||
[5, 1],
|
||||
[2, 5],
|
||||
],
|
||||
True,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.clip(x + (-4), -3, 5) + 3,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
True,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.clip(1, 5),
|
||||
@@ -968,16 +943,32 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 5],
|
||||
[5, 1],
|
||||
[2, 5],
|
||||
],
|
||||
True,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (x + (-4)).clip(-3, 5) + 3,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_tensor_correctness(
|
||||
function, parameters, inputset, test_input, expected_output, default_compilation_configuration
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
test_input,
|
||||
use_check_good_exec,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
@@ -987,14 +978,18 @@ def test_compile_and_run_tensor_correctness(
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
numpy_test_input = (
|
||||
numpy_test_input = tuple(
|
||||
item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8)
|
||||
for item in test_input
|
||||
)
|
||||
assert numpy.array_equal(
|
||||
circuit.run(*numpy_test_input),
|
||||
numpy.array(expected_output, dtype=numpy.uint8),
|
||||
)
|
||||
|
||||
if use_check_good_exec:
|
||||
check_is_good_execution(circuit, function, numpy_test_input)
|
||||
else:
|
||||
assert numpy.array_equal(
|
||||
circuit.run(*numpy_test_input),
|
||||
numpy.array(function(*numpy_test_input), dtype=numpy.uint8),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -1544,47 +1539,6 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation
|
||||
assert str(excinfo.value) == match, str(excinfo.value)
|
||||
|
||||
|
||||
def test_fail_with_intermediate_signed_values(default_compilation_configuration):
|
||||
"""Test function with failing compilation due to intermediate signed integers."""
|
||||
|
||||
def function(x, y):
|
||||
z = numpy.abs(10 * numpy.negative(x))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
try:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
{
|
||||
"x": EncryptedScalar(Integer(2, is_signed=False)),
|
||||
"y": EncryptedScalar(Integer(2, is_signed=False)),
|
||||
},
|
||||
[(i, j) for i in range(2 ** 2) for j in range(2 ** 2)],
|
||||
default_compilation_configuration,
|
||||
show_mlir=True,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
match = """
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = y # EncryptedScalar<uint2>
|
||||
%1 = 10 # ClearScalar<uint4>
|
||||
%2 = x # EncryptedScalar<uint2>
|
||||
%3 = negative(%2) # EncryptedScalar<int3>
|
||||
%4 = mul(%3, %1) # EncryptedScalar<int6>
|
||||
%5 = absolute(%4) # EncryptedScalar<uint5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ absolute is not supported for the time being
|
||||
%6 = astype(%5, dtype=int32) # EncryptedScalar<uint5>
|
||||
%7 = add(%6, %0) # EncryptedScalar<uint6>
|
||||
return %7
|
||||
|
||||
""".strip() # noqa: E501 # pylint: disable=line-too-long
|
||||
assert str(error) == match
|
||||
raise
|
||||
|
||||
|
||||
def test_small_inputset_no_fail():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
|
||||
Reference in New Issue
Block a user