diff --git a/hdk/common/mlir/converters.py b/hdk/common/mlir/converters.py index df615689f..a457253da 100644 --- a/hdk/common/mlir/converters.py +++ b/hdk/common/mlir/converters.py @@ -17,7 +17,9 @@ from zamalang.dialects import hlfhe from ...common.data_types.integers import Integer from ..data_types.dtypes_helpers import ( value_is_clear_scalar_integer, + value_is_clear_tensor_integer, value_is_encrypted_scalar_unsigned_integer, + value_is_encrypted_tensor_integer, ) from ..representation import intermediate as ir @@ -131,7 +133,7 @@ def constant(node, _, __, ctx): def apply_lut(node, preds, ir_to_mlir_node, ctx): - """Converted function for the arbitrary function intermediate node.""" + """Converter function for the arbitrary function intermediate node.""" assert len(node.inputs) == 1, "LUT should have a single input" assert len(node.outputs) == 1, "LUT should have a single output" if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): @@ -156,12 +158,42 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx): ).result +def dot(node, preds, ir_to_mlir_node, ctx): + """Converter function for the dot intermediate node.""" + assert len(node.inputs) == 2, "Dot should have two inputs" + assert len(node.outputs) == 1, "Dot should have a single output" + if not ( + ( + value_is_encrypted_tensor_integer(node.inputs[0]) + and value_is_clear_tensor_integer(node.inputs[1]) + ) + or ( + value_is_encrypted_tensor_integer(node.inputs[1]) + and value_is_clear_tensor_integer(node.inputs[0]) + ) + ): + raise TypeError( + f"Don't support subtraction between {type(node.inputs[0])} and {type(node.inputs[1])}" + ) + lhs_node, rhs_node = preds + # need to flip as underlying operation need encrypted first + if value_is_clear_tensor_integer(node.inputs[0]): + lhs_node, rhs_node = rhs_node, lhs_node + lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] + return hlfhe.Dot( + hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width), + lhs, + rhs, + ).result + + V0_OPSET_CONVERSION_FUNCTIONS = { ir.Add: add, ir.Sub: sub, ir.Mul: mul, ir.Constant: constant, ir.ArbitraryFunction: apply_lut, + ir.Dot: dot, } # pylint: enable=no-name-in-module,no-member diff --git a/hdk/common/mlir/mlir_converter.py b/hdk/common/mlir/mlir_converter.py index b3657cd9f..7c50c0988 100644 --- a/hdk/common/mlir/mlir_converter.py +++ b/hdk/common/mlir/mlir_converter.py @@ -51,7 +51,7 @@ class MLIRConverter: self.context = Context() zamalang.register_dialects(self.context) - def _get_tensor_element_type( + def _get_tensor_type( self, bit_width: int, is_encrypted: bool, @@ -69,13 +69,13 @@ class MLIRConverter: Returns: MLIRType: corresponding MLIR type """ - element_type = self._get_scalar_element_type(bit_width, is_encrypted, is_signed) + element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed) if len(shape): # randked tensor return RankedTensorType.get(shape, element_type) # unranked tensor return UnrankedTensorType.get(element_type) - def _get_scalar_element_type( + def _get_scalar_integer_type( self, bit_width: int, is_encrypted: bool, is_signed: bool ) -> MLIRType: """Get the MLIRType for a scalar element given its properties. @@ -92,7 +92,7 @@ class MLIRConverter: return hlfhe.EncryptedIntegerType.get(self.context, bit_width) if is_signed and not is_encrypted: # clear signed return IntegerType.get_signed(bit_width) - # shoulld be clear unsigned at this point + # should be clear unsigned at this point assert not is_signed and not is_encrypted # unsigned integer are considered signless in the compiler return IntegerType.get_signless(bit_width) @@ -107,24 +107,29 @@ class MLIRConverter: corresponding MLIR type """ if value_is_encrypted_scalar_unsigned_integer(value): - return self._get_scalar_element_type( + return self._get_scalar_integer_type( cast(Integer, value.data_type).bit_width, True, False ) if value_is_clear_scalar_integer(value): dtype = cast(Integer, value.data_type) - return self._get_scalar_element_type(dtype.bit_width, False, dtype.is_signed) + return self._get_scalar_integer_type( + dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed + ) if value_is_encrypted_tensor_unsigned_integer(value): dtype = cast(Integer, value.data_type) - return self._get_tensor_element_type( - dtype.bit_width, True, False, cast(values.TensorValue, value).shape + return self._get_tensor_type( + dtype.bit_width, + is_encrypted=True, + is_signed=False, + shape=cast(values.TensorValue, value).shape, ) if value_is_clear_tensor_integer(value): dtype = cast(Integer, value.data_type) - return self._get_tensor_element_type( + return self._get_tensor_type( dtype.bit_width, - False, - dtype.is_signed, - cast(values.TensorValue, value).shape, + is_encrypted=False, + is_signed=dtype.is_signed, + shape=cast(values.TensorValue, value).shape, ) raise TypeError(f"can't convert value of type {type(value)} to MLIR type") diff --git a/hdk/common/mlir/utils.py b/hdk/common/mlir/utils.py index 77b374708..c6dec7c8d 100644 --- a/hdk/common/mlir/utils.py +++ b/hdk/common/mlir/utils.py @@ -4,7 +4,9 @@ from typing import cast from ..data_types import Integer from ..data_types.dtypes_helpers import ( value_is_clear_scalar_integer, + value_is_clear_tensor_integer, value_is_encrypted_scalar_integer, + value_is_encrypted_tensor_integer, value_is_scalar_integer, ) from ..operator_graph import OPGraph @@ -37,9 +39,11 @@ def _set_all_bit_width(op_graph: OPGraph, p: int): """ for node in op_graph.graph.nodes: for value in node.outputs + node.inputs: - if value_is_clear_scalar_integer(value): + if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value): value.data_type.bit_width = p + 1 - elif value_is_encrypted_scalar_integer(value): + elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer( + value + ): value.data_type.bit_width = p @@ -52,8 +56,10 @@ def update_bit_width_for_mlir(op_graph: OPGraph): max_bit_width = 0 for node in op_graph.graph.nodes: for value_out in node.outputs: - if value_is_clear_scalar_integer(value_out): + if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out): max_bit_width = max(max_bit_width, value_out.data_type.bit_width - 1) - elif value_is_encrypted_scalar_integer(value_out): + elif value_is_encrypted_scalar_integer(value_out) or value_is_encrypted_tensor_integer( + value_out + ): max_bit_width = max(max_bit_width, value_out.data_type.bit_width) _set_all_bit_width(op_graph, max_bit_width) diff --git a/tests/common/mlir/test_converters.py b/tests/common/mlir/test_converters.py index ffcf3aff5..04ce27223 100644 --- a/tests/common/mlir/test_converters.py +++ b/tests/common/mlir/test_converters.py @@ -3,7 +3,7 @@ import pytest from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer -from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub +from hdk.common.mlir.converters import add, apply_lut, constant, dot, mul, sub from hdk.common.values import ClearScalar, EncryptedScalar @@ -21,7 +21,7 @@ class MockNode: self.outputs = outputs -@pytest.mark.parametrize("converter", [add, sub, mul]) +@pytest.mark.parametrize("converter", [add, sub, mul, dot]) def test_failing_converter(converter): """Test failing converter""" with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"): diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index de975ea22..1b45ce0d7 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -2,6 +2,7 @@ # pylint: disable=no-name-in-module,no-member import itertools +import numpy import pytest from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType from zamalang import compiler @@ -66,6 +67,11 @@ def lut(x): return table[x] +def dot(x, y): + """Test dot""" + return numpy.dot(x, y) + + def datagen(*args): """Generate data from ranges""" for prod in itertools.product(*args): @@ -178,6 +184,22 @@ def datagen(*args): }, (range(0, 8),), ), + ( + dot, + { + "x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)), + "y": ClearTensor(Integer(64, is_signed=False), shape=(4,)), + }, + (range(0, 8), range(0, 8)), + ), + ( + dot, + { + "x": ClearTensor(Integer(64, is_signed=False), shape=(4,)), + "y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)), + }, + (range(0, 8), range(0, 8)), + ), ], ) def test_mlir_converter(func, args_dict, args_ranges):