From f686ca535af53cac3c6b02bb03e992943c7f4139 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 1 Sep 2021 14:18:32 +0100 Subject: [PATCH] feat(mlir): convert TensorValue inputs to MLIR factorized the input type conversion of scalar values --- hdk/common/data_types/dtypes_helpers.py | 62 ++++++++++++++++++ hdk/common/mlir/mlir_converter.py | 82 +++++++++++++++++++++--- tests/common/mlir/test_mlir_converter.py | 55 +++++++++++++++- 3 files changed, 189 insertions(+), 10 deletions(-) diff --git a/hdk/common/data_types/dtypes_helpers.py b/hdk/common/data_types/dtypes_helpers.py index 7502c3a29..7c2c027b3 100644 --- a/hdk/common/data_types/dtypes_helpers.py +++ b/hdk/common/data_types/dtypes_helpers.py @@ -84,6 +84,68 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool: ) +def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: + """Helper function to check that a value is an encrypted TensorValue of type Integer. + + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is an encrypted TensorValue of type Integer + """ + return ( + isinstance(value_to_check, TensorValue) + and value_to_check.is_encrypted + and isinstance(value_to_check.data_type, INTEGER_TYPES) + ) + + +def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool: + """Helper function to check that a value is an encrypted TensorValue of type unsigned Integer. + + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is an encrypted TensorValue of type Integer and + unsigned + """ + return ( + value_is_encrypted_tensor_integer(value_to_check) + and not cast(Integer, value_to_check.data_type).is_signed + ) + + +def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool: + """Helper function to check that a value is a clear TensorValue of type Integer. 
+ + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is a clear TensorValue of type Integer + """ + return ( + isinstance(value_to_check, TensorValue) + and value_to_check.is_clear + and isinstance(value_to_check.data_type, INTEGER_TYPES) + ) + + +def value_is_tensor_integer(value_to_check: BaseValue) -> bool: + """Helper function to check that a value is a TensorValue of type Integer. + + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is a TensorValue of type Integer + """ + return isinstance(value_to_check, TensorValue) and isinstance( + value_to_check.data_type, INTEGER_TYPES + ) + + def find_type_to_hold_both_lossy( dtype1: BaseDataType, dtype2: BaseDataType, diff --git a/hdk/common/mlir/mlir_converter.py b/hdk/common/mlir/mlir_converter.py index 1e0124251..b3657cd9f 100644 --- a/hdk/common/mlir/mlir_converter.py +++ b/hdk/common/mlir/mlir_converter.py @@ -1,19 +1,29 @@ """File containing code to convert a DAG containing ir nodes to the compiler opset.""" # pylint: disable=no-name-in-module,no-member -from typing import cast +from typing import Tuple, cast import networkx as nx import zamalang from mlir.dialects import builtin -from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module +from mlir.ir import ( + Context, + InsertionPoint, + IntegerType, + Location, + Module, + RankedTensorType, +) from mlir.ir import Type as MLIRType +from mlir.ir import UnrankedTensorType from zamalang.dialects import hlfhe from .. 
import values from ..data_types import Integer from ..data_types.dtypes_helpers import ( value_is_clear_scalar_integer, + value_is_clear_tensor_integer, value_is_encrypted_scalar_unsigned_integer, + value_is_encrypted_tensor_unsigned_integer, ) from ..operator_graph import OPGraph from ..representation import intermediate as ir @@ -41,6 +51,52 @@ class MLIRConverter: self.context = Context() zamalang.register_dialects(self.context) + def _get_tensor_element_type( + self, + bit_width: int, + is_encrypted: bool, + is_signed: bool, + shape: Tuple[int, ...], + ) -> MLIRType: + """Get the MLIRType for a tensor element given its properties. + + Args: + bit_width (int): number of bits used for the scalar + is_encrypted (bool): is the scalar encrypted or not + is_signed (bool): is the scalar signed or not + shape (Tuple[int, ...]): shape of the tensor + + Returns: + MLIRType: corresponding MLIR type + """ + element_type = self._get_scalar_element_type(bit_width, is_encrypted, is_signed) + if len(shape): # ranked tensor + return RankedTensorType.get(shape, element_type) + # unranked tensor + return UnrankedTensorType.get(element_type) + def _get_scalar_element_type( + self, bit_width: int, is_encrypted: bool, is_signed: bool + ) -> MLIRType: + """Get the MLIRType for a scalar element given its properties. 
+ + Args: + bit_width (int): number of bits used for the scalar + is_encrypted (bool): is the scalar encrypted or not + is_signed (bool): is the scalar signed or not + + Returns: + MLIRType: corresponding MLIR type + """ + if is_encrypted and not is_signed: + return hlfhe.EncryptedIntegerType.get(self.context, bit_width) + if is_signed and not is_encrypted: # clear signed + return IntegerType.get_signed(bit_width) + # should be clear unsigned at this point + assert not is_signed and not is_encrypted + # unsigned integer are considered signless in the compiler + return IntegerType.get_signless(bit_width) + def hdk_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType: """Convert an HDK value to its corresponding MLIR Type. @@ -51,15 +107,25 @@ class MLIRConverter: corresponding MLIR type """ if value_is_encrypted_scalar_unsigned_integer(value): - return hlfhe.EncryptedIntegerType.get( - self.context, cast(Integer, value.data_type).bit_width + return self._get_scalar_element_type( + cast(Integer, value.data_type).bit_width, True, False ) if value_is_clear_scalar_integer(value): dtype = cast(Integer, value.data_type) - if dtype.is_signed: - return IntegerType.get_signed(dtype.bit_width, context=self.context) - # unsigned integer are considered signless in the compiler - return IntegerType.get_signless(dtype.bit_width, context=self.context) + return self._get_scalar_element_type(dtype.bit_width, False, dtype.is_signed) + if value_is_encrypted_tensor_unsigned_integer(value): + dtype = cast(Integer, value.data_type) + return self._get_tensor_element_type( + dtype.bit_width, True, False, cast(values.TensorValue, value).shape + ) + if value_is_clear_tensor_integer(value): + dtype = cast(Integer, value.data_type) + return self._get_tensor_element_type( + dtype.bit_width, + False, + dtype.is_signed, + cast(values.TensorValue, value).shape, + ) raise TypeError(f"can't convert value of type {type(value)} to MLIR type") def convert(self, op_graph: OPGraph) -> str: diff 
--git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index c0981992b..de975ea22 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -3,7 +3,7 @@ import itertools import pytest -from mlir.ir import IntegerType +from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType from zamalang import compiler from zamalang.dialects import hlfhe @@ -11,6 +11,7 @@ from hdk.common.data_types.integers import Integer from hdk.common.extensions.table import LookupTable from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter from hdk.common.values import ClearScalar, EncryptedScalar +from hdk.common.values.tensors import ClearTensor, EncryptedTensor from hdk.numpy.compile import compile_numpy_function_into_op_graph @@ -202,14 +203,64 @@ def test_hdk_clear_integer_to_mlir_type(is_signed): """Test conversion of ClearScalar into MLIR""" value = ClearScalar(Integer(5, is_signed=is_signed)) converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - int_mlir = converter.hdk_value_to_mlir_type(value) with converter.context: + int_mlir = converter.hdk_value_to_mlir_type(value) if is_signed: assert int_mlir == IntegerType.get_signed(5) else: assert int_mlir == IntegerType.get_signless(5) +@pytest.mark.parametrize("is_signed", [True, False]) +@pytest.mark.parametrize( + "shape", + [ + None, + (5,), + (5, 8), + (-1, 5), + ], +) +def test_hdk_clear_tensor_integer_to_mlir_type(is_signed, shape): + """Test conversion of ClearTensor into MLIR""" + value = ClearTensor(Integer(5, is_signed=is_signed), shape) + converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + with converter.context, Location.unknown(): + tensor_mlir = converter.hdk_value_to_mlir_type(value) + if is_signed: + element_type = IntegerType.get_signed(5) + else: + element_type = IntegerType.get_signless(5) + if shape is None: + expected_type = UnrankedTensorType.get(element_type) + else: + expected_type 
= RankedTensorType.get(shape, element_type) + assert tensor_mlir == expected_type + + +@pytest.mark.parametrize( + "shape", + [ + None, + (5,), + (5, 8), + (-1, 5), + ], +) +def test_hdk_encrypted_tensor_integer_to_mlir_type(shape): + """Test conversion of EncryptedTensor into MLIR""" + value = EncryptedTensor(Integer(6, is_signed=False), shape) + converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + with converter.context, Location.unknown(): + tensor_mlir = converter.hdk_value_to_mlir_type(value) + element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6) + if shape is None: + expected_type = UnrankedTensorType.get(element_type) + else: + expected_type = RankedTensorType.get(shape, element_type) + assert tensor_mlir == expected_type + + def test_failing_hdk_to_mlir_type(): """Test failing conversion of an unsupported type into MLIR""" value = "random"