feat(mlir): convert TensorValue inputs to MLIR

factorized the input type conversion of scalar values
This commit is contained in:
youben11
2021-09-01 14:18:32 +01:00
committed by Ayoub Benaissa
parent cfe48cca15
commit f686ca535a
3 changed files with 189 additions and 10 deletions

View File

@@ -84,6 +84,68 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
)
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
"""Helper function to check that a value is an encrypted TensorValue of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted TensorValue of type Integer
"""
return (
isinstance(value_to_check, TensorValue)
and value_to_check.is_encrypted
and isinstance(value_to_check.data_type, INTEGER_TYPES)
)
def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool:
"""Helper function to check that a value is an encrypted TensorValue of type unsigned Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted TensorValue of type Integer and
unsigned
"""
return (
value_is_encrypted_tensor_integer(value_to_check)
and not cast(Integer, value_to_check.data_type).is_signed
)
def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
"""Helper function to check that a value is a clear TensorValue of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is a clear TensorValue of type Integer
"""
return (
isinstance(value_to_check, TensorValue)
and value_to_check.is_clear
and isinstance(value_to_check.data_type, INTEGER_TYPES)
)
def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
"""Helper function to check that a value is a TensorValue of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is a TensorValue of type Integer
"""
return isinstance(value_to_check, TensorValue) and isinstance(
value_to_check.data_type, INTEGER_TYPES
)
def find_type_to_hold_both_lossy(
dtype1: BaseDataType,
dtype2: BaseDataType,

View File

@@ -1,19 +1,29 @@
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
# pylint: disable=no-name-in-module,no-member
from typing import cast
from typing import Tuple, cast
import networkx as nx
import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module
from mlir.ir import (
Context,
InsertionPoint,
IntegerType,
Location,
Module,
RankedTensorType,
)
from mlir.ir import Type as MLIRType
from mlir.ir import UnrankedTensorType
from zamalang.dialects import hlfhe
from .. import values
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_unsigned_integer,
value_is_encrypted_tensor_unsigned_integer,
)
from ..operator_graph import OPGraph
from ..representation import intermediate as ir
@@ -41,6 +51,52 @@ class MLIRConverter:
self.context = Context()
zamalang.register_dialects(self.context)
def _get_tensor_element_type(
self,
bit_width: int,
is_encrypted: bool,
is_signed: bool,
shape: Tuple[int, ...],
) -> MLIRType:
"""Get the MLIRType for a tensor element given its properties.
Args:
bit_width (int): number of bits used for the scalar
is_encrypted (bool): is the scalar encrypted or not
is_signed (bool): is the scalar signed or not
shape (Tuple[int, ...]): shape of the tensor
Returns:
MLIRType: corresponding MLIR type
"""
element_type = self._get_scalar_element_type(bit_width, is_encrypted, is_signed)
if len(shape): # randked tensor
return RankedTensorType.get(shape, element_type)
# unranked tensor
return UnrankedTensorType.get(element_type)
def _get_scalar_element_type(
self, bit_width: int, is_encrypted: bool, is_signed: bool
) -> MLIRType:
"""Get the MLIRType for a scalar element given its properties.
Args:
bit_width (int): number of bits used for the scalar
is_encrypted (bool): is the scalar encrypted or not
is_signed (bool): is the scalar signed or not
Returns:
MLIRType: corresponding MLIR type
"""
if is_encrypted and not is_signed:
return hlfhe.EncryptedIntegerType.get(self.context, bit_width)
if is_signed and not is_encrypted: # clear signed
return IntegerType.get_signed(bit_width)
# shoulld be clear unsigned at this point
assert not is_signed and not is_encrypted
# unsigned integer are considered signless in the compiler
return IntegerType.get_signless(bit_width)
def hdk_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType:
"""Convert an HDK value to its corresponding MLIR Type.
@@ -51,15 +107,25 @@ class MLIRConverter:
corresponding MLIR type
"""
if value_is_encrypted_scalar_unsigned_integer(value):
return hlfhe.EncryptedIntegerType.get(
self.context, cast(Integer, value.data_type).bit_width
return self._get_scalar_element_type(
cast(Integer, value.data_type).bit_width, True, False
)
if value_is_clear_scalar_integer(value):
dtype = cast(Integer, value.data_type)
if dtype.is_signed:
return IntegerType.get_signed(dtype.bit_width, context=self.context)
# unsigned integer are considered signless in the compiler
return IntegerType.get_signless(dtype.bit_width, context=self.context)
return self._get_scalar_element_type(dtype.bit_width, False, dtype.is_signed)
if value_is_encrypted_tensor_unsigned_integer(value):
dtype = cast(Integer, value.data_type)
return self._get_tensor_element_type(
dtype.bit_width, True, False, cast(values.TensorValue, value).shape
)
if value_is_clear_tensor_integer(value):
dtype = cast(Integer, value.data_type)
return self._get_tensor_element_type(
dtype.bit_width,
False,
dtype.is_signed,
cast(values.TensorValue, value).shape,
)
raise TypeError(f"can't convert value of type {type(value)} to MLIR type")
def convert(self, op_graph: OPGraph) -> str:

View File

@@ -3,7 +3,7 @@
import itertools
import pytest
from mlir.ir import IntegerType
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
from zamalang import compiler
from zamalang.dialects import hlfhe
@@ -11,6 +11,7 @@ from hdk.common.data_types.integers import Integer
from hdk.common.extensions.table import LookupTable
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from hdk.common.values import ClearScalar, EncryptedScalar
from hdk.common.values.tensors import ClearTensor, EncryptedTensor
from hdk.numpy.compile import compile_numpy_function_into_op_graph
@@ -202,14 +203,64 @@ def test_hdk_clear_integer_to_mlir_type(is_signed):
"""Test conversion of ClearScalar into MLIR"""
value = ClearScalar(Integer(5, is_signed=is_signed))
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
int_mlir = converter.hdk_value_to_mlir_type(value)
with converter.context:
int_mlir = converter.hdk_value_to_mlir_type(value)
if is_signed:
assert int_mlir == IntegerType.get_signed(5)
else:
assert int_mlir == IntegerType.get_signless(5)
@pytest.mark.parametrize("is_signed", [True, False])
@pytest.mark.parametrize(
"shape",
[
None,
(5,),
(5, 8),
(-1, 5),
],
)
def test_hdk_clear_tensor_integer_to_mlir_type(is_signed, shape):
"""Test conversion of ClearTensor into MLIR"""
value = ClearTensor(Integer(5, is_signed=is_signed), shape)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.hdk_value_to_mlir_type(value)
if is_signed:
element_type = IntegerType.get_signed(5)
else:
element_type = IntegerType.get_signless(5)
if shape is None:
expected_type = UnrankedTensorType.get(element_type)
else:
expected_type = RankedTensorType.get(shape, element_type)
assert tensor_mlir == expected_type
@pytest.mark.parametrize(
"shape",
[
None,
(5,),
(5, 8),
(-1, 5),
],
)
def test_hdk_encrypted_tensor_integer_to_mlir_type(shape):
"""Test conversion of EncryptedTensor into MLIR"""
value = EncryptedTensor(Integer(6, is_signed=False), shape)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.hdk_value_to_mlir_type(value)
element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
if shape is None:
expected_type = UnrankedTensorType.get(element_type)
else:
expected_type = RankedTensorType.get(shape, element_type)
assert tensor_mlir == expected_type
def test_failing_hdk_to_mlir_type():
"""Test failing conversion of an unsupported type into MLIR"""
value = "random"