diff --git a/hdk/common/mlir/converters.py b/hdk/common/mlir/converters.py index 520b2a0d4..df615689f 100644 --- a/hdk/common/mlir/converters.py +++ b/hdk/common/mlir/converters.py @@ -6,11 +6,12 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml - `ir_to_mlir_node`: Dict mapping intermediate nodes to MLIR nodes or values - `ctx`: MLIR context """ -# pylint: disable=no-name-in-module,no-member from typing import cast +# pylint: disable=no-name-in-module,no-member +import numpy as np from mlir.dialects import std as std_dialect -from mlir.ir import IntegerAttr, IntegerType +from mlir.ir import DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType from zamalang.dialects import hlfhe from ...common.data_types.integers import Integer @@ -129,11 +130,38 @@ def constant(node, _, __, ctx): return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result +def apply_lut(node, preds, ir_to_mlir_node, ctx): + """Converted function for the arbitrary function intermediate node.""" + assert len(node.inputs) == 1, "LUT should have a single input" + assert len(node.outputs) == 1, "LUT should have a single output" + if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): + raise TypeError("Only support LUT with encrypted unsigned integers inputs") + if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]): + raise TypeError("Only support LUT with encrypted unsigned integers outputs") + + x_node = preds[0] + x = ir_to_mlir_node[x_node] + table = node.get_table() + out_dtype = cast(Integer, node.outputs[0].data_type) + # Create table + dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx) + tensor_lut = std_dialect.ConstantOp( + RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)), + dense_elem, + ).result + return hlfhe.ApplyLookupTableEintOp( + hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width), + x, + tensor_lut, + ).result + + V0_OPSET_CONVERSION_FUNCTIONS = { ir.Add: add, ir.Sub: sub, ir.Mul: mul, ir.Constant: constant, + ir.ArbitraryFunction: apply_lut, } # pylint: enable=no-name-in-module,no-member diff --git a/tests/common/mlir/test_converters.py b/tests/common/mlir/test_converters.py index 6c2ef33ca..c84c34f53 100644 --- a/tests/common/mlir/test_converters.py +++ b/tests/common/mlir/test_converters.py @@ -3,8 +3,8 @@ import pytest from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer -from hdk.common.mlir.converters import add, constant, mul, sub -from hdk.common.values import ClearValue +from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub +from hdk.common.values import ClearValue, EncryptedValue class MockNode: @@ -38,3 +38,45 @@ def test_fail_signed_integer_const(): """Test failing constant converter with non-integer""" with pytest.raises(TypeError, match=r"Don't support signed constant integer"): constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None) + + +@pytest.mark.parametrize( + "input_node", + [ + ClearValue(Integer(8, True)), + ClearValue(Integer(8, False)), + EncryptedValue(Integer(8, True)), + ], +) +def test_fail_tlu_input(input_node): + """Test failing LUT converter with invalid input""" + with pytest.raises( + TypeError, match=r"Only support LUT with encrypted unsigned integers inputs" + ): + apply_lut( + MockNode(inputs=[input_node], outputs=[EncryptedValue(Integer(8, False))]), + [None], + None, + None, + ) + + +@pytest.mark.parametrize( + "input_node", + [ + ClearValue(Integer(8, True)), + ClearValue(Integer(8, False)), + EncryptedValue(Integer(8, True)), + ], +) +def test_fail_tlu_output(input_node): + """Test failing LUT converter with invalid output""" + with pytest.raises( + TypeError, match=r"Only support LUT with encrypted unsigned integers outputs" + ): + apply_lut( + MockNode(inputs=[EncryptedValue(Integer(8, False))], outputs=[input_node]), + [None], + None, + None, + ) diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index 5a370d016..b4fadaa42 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -8,6 +8,7 @@ from zamalang import compiler from zamalang.dialects import hlfhe from hdk.common.data_types.integers import Integer +from hdk.common.extensions.table import LookupTable from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter from hdk.common.values import ClearValue, EncryptedValue from hdk.hnumpy.compile import compile_numpy_function_into_op_graph @@ -58,6 +59,12 @@ def ret_multiple_different_order(x, y, z): return y, z, x +def lut(x): + """Test lookup table""" + table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7]) + return table[x] + + def datagen(*args): """Generate data from ranges""" for prod in itertools.product(*args): @@ -163,6 +170,13 @@ def datagen(*args): }, (range(1, 5), range(1, 5), range(1, 5)), ), + ( + lut, + { + "x": EncryptedValue(Integer(64, is_signed=False)), + }, + (range(0, 8),), + ), ], ) def test_mlir_converter(func, args_dict, args_ranges): diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index 5f04714d9..436abadb0 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -23,6 +23,12 @@ def no_fuse_unhandled(x, y): return intermediate.astype(numpy.int32) +def lut(x): + """Test lookup table""" + table = LookupTable(list(range(128))) + return table[x] + + @pytest.mark.parametrize( "function,input_ranges,list_of_arg_names", [ @@ -75,6 +81,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n pytest.param(lambda x: x * 2, ((0, 2),), ["x"]), pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]), pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]), + pytest.param(lut, ((0, 127),), ["x"]), ], ) def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):