diff --git a/hdk/common/mlir/converters.py b/hdk/common/mlir/converters.py index f9a680ac0..09fa646d3 100644 --- a/hdk/common/mlir/converters.py +++ b/hdk/common/mlir/converters.py @@ -7,8 +7,14 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml - `ctx`: MLIR context """ # pylint: disable=no-name-in-module,no-member +from typing import cast + +from mlir.dialects import std as std_dialect +from mlir.ir import IntegerAttr, IntegerType from zamalang.dialects import hlfhe +from hdk.common.data_types.integers import Integer + from ..data_types.dtypes_helpers import ( value_is_clear_integer, value_is_encrypted_unsigned_integer, @@ -113,6 +119,22 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx): ).result -V0_OPSET_CONVERSION_FUNCTIONS = {ir.Add: add, ir.Sub: sub, ir.Mul: mul} +def constant(node, _, __, ctx): + """Converter function for constant inputs.""" + if not value_is_clear_integer(node.outputs[0]): + raise TypeError("Don't support non-integer constants") + dtype = cast(Integer, node.outputs[0].data_type) + if dtype.is_signed: + raise TypeError("Don't support signed constant integer") + int_type = IntegerType.get_signless(dtype.bit_width, context=ctx) + return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result + + +V0_OPSET_CONVERSION_FUNCTIONS = { + ir.Add: add, + ir.Sub: sub, + ir.Mul: mul, + ir.ConstantInput: constant, +} # pylint: enable=no-name-in-module,no-member diff --git a/tests/common/mlir/test_converters.py b/tests/common/mlir/test_converters.py index 26d100925..894d42a77 100644 --- a/tests/common/mlir/test_converters.py +++ b/tests/common/mlir/test_converters.py @@ -1,15 +1,24 @@ """Test converter functions""" import pytest -from hdk.common.mlir.converters import add, mul, sub +from hdk.common.data_types.floats import Float +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import ClearValue +from hdk.common.mlir.converters import add, constant, mul, sub class MockNode: """Mocking an intermediate node""" - def __init__(self, inputs=5, outputs=5): - self.inputs = [None for i in range(inputs)] - self.outputs = [None for i in range(outputs)] + def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None): + if inputs is None: + self.inputs = [None for i in range(inputs_n)] + else: + self.inputs = inputs + if outputs is None: + self.outputs = [None for i in range(outputs_n)] + else: + self.outputs = outputs @pytest.mark.parametrize("converter", [add, sub, mul]) @@ -17,3 +26,15 @@ def test_failing_converter(converter): """Test failing converter""" with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"): converter(MockNode(2, 1), None, None, None) + + +def test_fail_non_integer_const(): + """Test failing constant converter with non-integer""" + with pytest.raises(TypeError, match=r"Don't support non-integer constants"): + constant(MockNode(outputs=[ClearValue(Float(32))]), None, None, None) + + +def test_fail_signed_integer_const(): + """Test failing constant converter with non-integer""" + with pytest.raises(TypeError, match=r"Don't support signed constant integer"): + constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None) diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index 1537d0cf2..22528a808 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -18,16 +18,31 @@ def add(x, y): return x + y +def constant_add(x): + """Test constant add""" + return x + 5 + + def sub(x, y): """Test simple sub""" return x - y +def constant_sub(x): + """Test constant sub""" + return 8 - x + + def mul(x, y): """Test simple mul""" return x * y +def constant_mul(x): + """Test constant mul""" + return x * 2 + + def sub_add_mul(x, y, z): """Test combination of ops""" return z - y + x * z @@ -60,6 +75,13 @@ def datagen(*args): }, (range(0, 8), range(1, 4)), ), + ( + constant_add, + { + "x": EncryptedValue(Integer(64, is_signed=False)), + }, + (range(0, 8),), + ), ( add, { @@ -84,6 +106,13 @@ def datagen(*args): }, (range(5, 10), range(2, 6)), ), + ( + constant_sub, + { + "x": EncryptedValue(Integer(64, is_signed=False)), + }, + (range(0, 5),), + ), ( mul, { @@ -92,6 +121,13 @@ def datagen(*args): }, (range(1, 5), range(2, 8)), ), + ( + constant_mul, + { + "x": EncryptedValue(Integer(64, is_signed=False)), + }, + (range(0, 8),), + ), ( mul, {