feat: support TLUs with signed input

- this is done with an absolute ugly hack

closes #923
refs #1001
This commit is contained in:
Arthur Meyre
2021-11-24 15:13:27 +01:00
parent 7909a4899f
commit f53d374d1f
5 changed files with 135 additions and 129 deletions

View File

@@ -270,7 +270,7 @@ class IntermediateNodeConverter:
value = self.node.inputs[variable_input_index]
assert_true(value.is_encrypted)
if not isinstance(value.dtype, Integer) or value.dtype.is_signed: # pragma: no cover
if not isinstance(value.dtype, Integer): # pragma: no cover
# this branch is not covered as it's impossible to get into due to how compilation works
# however, it doesn't hurt to keep it as an extra measure
raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")

View File

@@ -89,17 +89,6 @@ def check_node_compatibility_with_mlir(
)
== 1
)
if not value_is_unsigned_integer(inputs[0]):
# this branch is not reachable because compilation fails during inputset evaluation
if node.op_name == "TLU": # pragma: no cover
return "only unsigned integer lookup tables are supported"
if node.op_name == "MultiTLU": # pragma: no cover
return "only unsigned integer multi lookup tables are supported"
# e.g., `np.absolute is not supported for the time being`
return f"{node.op_name} is not supported for the time being"
else:
return f"{node.op_name} is not supported for the time being"

View File

@@ -370,8 +370,8 @@ class GenericFunction(IntermediateNode):
def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
"""Get the table for the current input value of this GenericFunction.
This function only works if the GenericFunction variable input value is an unsigned Integer.
It only works if there is a single variable input node among ordered_preds.
This function only works if the GenericFunction variable input value is an Integer.
This function only works if there is a single variable input node among ordered_preds.
Args:
ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must
@@ -393,16 +393,12 @@ class GenericFunction(IntermediateNode):
variable_input_idx = variable_input_indices[0]
variable_input_dtype = self.inputs[variable_input_idx].dtype
# Check the input is an unsigned integer to be able to build a table
# Check the input is an integer to be able to build a table
assert_true(
isinstance(variable_input_dtype, Integer),
f"{self.get_table.__name__} only works for an unsigned Integer input",
)
variable_input_dtype = cast(Integer, variable_input_dtype)
assert_true(
not variable_input_dtype.is_signed,
f"{self.get_table.__name__} only works for an unsigned Integer input",
)
input_value_constructor = self.inputs[0].underlying_constructor
if input_value_constructor is None:

View File

@@ -2,7 +2,8 @@
import sys
import traceback
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast
import numpy
from zamalang import CompilerEngine
@@ -21,7 +22,8 @@ from ..common.mlir.utils import (
)
from ..common.operator_graph import OPGraph
from ..common.optimization.topological import fuse_float_operations
from ..common.values import BaseValue
from ..common.representation.intermediate import Add, Constant, GenericFunction
from ..common.values import BaseValue, ClearScalar
from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import (
get_base_data_type_for_numpy_or_python_constant_data,
@@ -311,7 +313,66 @@ def compile_numpy_function_into_op_graph_and_measure_bounds(
raise
def prepare_op_graph_for_mlir(op_graph):
# HACK
# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001 is
# done
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1015
def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None:
"""Hack the op_graph to add offsets to signed inputs to TLUs.
Args:
op_graph (OPGraph): the OPGraph to hack.
"""
# Ugly hack to add an offset before entering a TLU if its variable input node has a signed
# output.
# This is ugly as this makes hardcoded assumptions about the way bit widths are handled in MLIR.
# This does not update the TLU input values to allow for proper table generation.
# Thankfully we are not supposed to touch the op_graph beyond that point
for node in list((nx_graph := op_graph.graph).nodes):
if isinstance(node, GenericFunction):
ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node)
variable_input_indices = [
idx
for idx, (pred, _) in enumerate(ordered_preds_and_inputs)
if not isinstance(pred, Constant)
]
assert_true(len(variable_input_indices) == 1)
variable_input_idx = variable_input_indices[0]
variable_input_node = ordered_preds_and_inputs[variable_input_idx][0]
variable_input_value = variable_input_node.outputs[0]
variable_input_dtype = variable_input_value.dtype
assert_true(isinstance(variable_input_dtype, Integer))
variable_input_dtype = cast(Integer, variable_input_dtype)
if not variable_input_dtype.is_signed:
continue
# input_bit_width + 1 to be MLIR compliant
input_bit_width = variable_input_dtype.bit_width
mlir_compliant_int_type = Integer(input_bit_width + 1, True)
# Manually fix the output values to be MLIR compliant
# offset_constant is set to abs(min_value) for the variable input so that the values
# [- 2 ** (n - 1); 2 ** (n - 1) - 1] is mapped to [0; 2 ** n - 1], changing the signed
# TLU to an actual unsigned TLU. The get_table function creates the table from the min
# value to the max value. As we keep the input value as a signed value, it will be from
# - 2 ** (n - 1) to 2 ** (n - 1) - 1. Then, the get_table function stores corresponding
# values in increasing indexes from 0 to 2 ** n - 1. As our signed values have been
# shifted by 2 ** (n - 1), the table will be usable as-is, without needing any change in
# the lambda function of the GenericFunction.
offset_constant = Constant(abs(variable_input_dtype.min_value()))
offset_constant.outputs[0].dtype = deepcopy(mlir_compliant_int_type)
add_offset = Add(
[deepcopy(variable_input_value), ClearScalar(deepcopy(mlir_compliant_int_type))]
)
add_offset.outputs[0] = deepcopy(variable_input_value)
nx_graph.remove_edge(variable_input_node, node)
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0, output_idx=0)
nx_graph.add_edge(offset_constant, add_offset, input_idx=1, output_idx=0)
nx_graph.add_edge(add_offset, node, input_idx=variable_input_idx, output_idx=0)
def prepare_op_graph_for_mlir(op_graph: OPGraph):
"""Prepare OPGraph for MLIR lowering.
This includes checking compatibility, changing bit-widths, and modifying lookup tables.
@@ -337,6 +398,12 @@ def prepare_op_graph_for_mlir(op_graph):
# TODO: workaround extend LUT #359
extend_direct_lookup_tables(op_graph)
# HACK
# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001
# is done
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1015
hack_offset_negative_inputs_to_lookup_tables(op_graph)
def _compile_numpy_function_internal(
function_to_compile: Callable,

View File

@@ -267,8 +267,8 @@ def check_is_good_execution(compiler_engine, function, args, verbose=True):
# Bad computation after nb_tries
raise AssertionError(
f"bad computation after {nb_tries} tries, which was supposed to happen with a "
f"probability of {expected_bad_luck}.\nLast engine result: {last_engine_result} "
f"last function result: {last_function_result}"
f"probability of {expected_bad_luck}.\nLast engine result:\n{last_engine_result}\n"
f"Last function result:\n{last_function_result}"
)
@@ -725,7 +725,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
@pytest.mark.parametrize(
"function,parameters,inputset,test_input,expected_output",
"function,parameters,inputset,test_input,use_check_good_exec",
[
pytest.param(
lambda x: x + 1,
@@ -740,11 +740,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[1, 8],
[7, 2],
[3, 6],
],
False,
),
pytest.param(
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
@@ -759,11 +755,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[1, 7],
[8, 1],
[5, 6],
],
False,
),
pytest.param(
lambda x, y: x + y,
@@ -786,11 +778,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
],
2,
),
[
[2, 9],
[8, 3],
[4, 7],
],
False,
),
pytest.param(
lambda x, y: x + y,
@@ -817,11 +805,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[3, 4],
],
),
[
[1, 13],
[8, 6],
[5, 9],
],
False,
),
pytest.param(
lambda x: 100 - x,
@@ -836,11 +820,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[100, 93],
[94, 99],
[98, 95],
],
False,
),
pytest.param(
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
@@ -855,11 +835,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[10, 8],
[14, 14],
[8, 25],
],
False,
),
pytest.param(
lambda x: x * 2,
@@ -874,11 +850,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[0, 14],
[12, 2],
[4, 10],
],
False,
),
pytest.param(
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
@@ -893,11 +865,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[4, 14],
[12, 1],
[6, 5],
],
False,
),
pytest.param(
lambda x: LookupTable([2, 1, 3, 0])[x],
@@ -912,11 +880,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[3, 0],
],
),
[
[2, 1],
[3, 1],
[0, 2],
],
True,
),
pytest.param(
lambda x: numpy.dot(x, 2),
@@ -925,7 +889,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
},
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
False,
),
pytest.param(
lambda x: numpy.dot(2, x),
@@ -934,7 +898,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
},
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
False,
),
pytest.param(
lambda x: numpy.clip(x, 1, 5),
@@ -949,11 +913,22 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[1, 5],
[5, 1],
[2, 5],
],
True,
),
pytest.param(
lambda x: numpy.clip(x + (-4), -3, 5) + 3,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
True,
),
pytest.param(
lambda x: x.clip(1, 5),
@@ -968,16 +943,32 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[2, 5],
],
),
[
[1, 5],
[5, 1],
[2, 5],
],
True,
),
pytest.param(
lambda x: (x + (-4)).clip(-3, 5) + 3,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
True,
),
],
)
def test_compile_and_run_tensor_correctness(
function, parameters, inputset, test_input, expected_output, default_compilation_configuration
function,
parameters,
inputset,
test_input,
use_check_good_exec,
default_compilation_configuration,
):
"""Test correctness of results when running a compiled function with tensor operators"""
circuit = compile_numpy_function(
@@ -987,14 +978,18 @@ def test_compile_and_run_tensor_correctness(
default_compilation_configuration,
)
numpy_test_input = (
numpy_test_input = tuple(
item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8)
for item in test_input
)
assert numpy.array_equal(
circuit.run(*numpy_test_input),
numpy.array(expected_output, dtype=numpy.uint8),
)
if use_check_good_exec:
check_is_good_execution(circuit, function, numpy_test_input)
else:
assert numpy.array_equal(
circuit.run(*numpy_test_input),
numpy.array(function(*numpy_test_input), dtype=numpy.uint8),
)
@pytest.mark.parametrize(
@@ -1544,47 +1539,6 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation
assert str(excinfo.value) == match, str(excinfo.value)
def test_fail_with_intermediate_signed_values(default_compilation_configuration):
"""Test function with failing compilation due to intermediate signed integers."""
def function(x, y):
z = numpy.abs(10 * numpy.negative(x))
z = z.astype(numpy.int32) + y
return z
with pytest.raises(RuntimeError):
try:
compile_numpy_function(
function,
{
"x": EncryptedScalar(Integer(2, is_signed=False)),
"y": EncryptedScalar(Integer(2, is_signed=False)),
},
[(i, j) for i in range(2 ** 2) for j in range(2 ** 2)],
default_compilation_configuration,
show_mlir=True,
)
except RuntimeError as error:
match = """
function you are trying to compile isn't supported for MLIR lowering
%0 = y # EncryptedScalar<uint2>
%1 = 10 # ClearScalar<uint4>
%2 = x # EncryptedScalar<uint2>
%3 = negative(%2) # EncryptedScalar<int3>
%4 = mul(%3, %1) # EncryptedScalar<int6>
%5 = absolute(%4) # EncryptedScalar<uint5>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ absolute is not supported for the time being
%6 = astype(%5, dtype=int32) # EncryptedScalar<uint5>
%7 = add(%6, %0) # EncryptedScalar<uint6>
return %7
""".strip() # noqa: E501 # pylint: disable=line-too-long
assert str(error) == match
raise
def test_small_inputset_no_fail():
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
compile_numpy_function_into_op_graph_and_measure_bounds(