mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(mlir): convert TensorValue inputs to MLIR
factorized the input type conversion of scalar values
This commit is contained in:
@@ -84,6 +84,68 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted TensorValue of type Integer
|
||||
"""
|
||||
return (
|
||||
isinstance(value_to_check, TensorValue)
|
||||
and value_to_check.is_encrypted
|
||||
and isinstance(value_to_check.data_type, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted TensorValue of type unsigned Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted TensorValue of type Integer and
|
||||
unsigned
|
||||
"""
|
||||
return (
|
||||
value_is_encrypted_tensor_integer(value_to_check)
|
||||
and not cast(Integer, value_to_check.data_type).is_signed
|
||||
)
|
||||
|
||||
|
||||
def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a clear TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a clear TensorValue of type Integer
|
||||
"""
|
||||
return (
|
||||
isinstance(value_to_check, TensorValue)
|
||||
and value_to_check.is_clear
|
||||
and isinstance(value_to_check.data_type, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a TensorValue of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check, TensorValue) and isinstance(
|
||||
value_to_check.data_type, INTEGER_TYPES
|
||||
)
|
||||
|
||||
|
||||
def find_type_to_hold_both_lossy(
|
||||
dtype1: BaseDataType,
|
||||
dtype2: BaseDataType,
|
||||
|
||||
@@ -1,19 +1,29 @@
|
||||
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
from typing import cast
|
||||
from typing import Tuple, cast
|
||||
|
||||
import networkx as nx
|
||||
import zamalang
|
||||
from mlir.dialects import builtin
|
||||
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module
|
||||
from mlir.ir import (
|
||||
Context,
|
||||
InsertionPoint,
|
||||
IntegerType,
|
||||
Location,
|
||||
Module,
|
||||
RankedTensorType,
|
||||
)
|
||||
from mlir.ir import Type as MLIRType
|
||||
from mlir.ir import UnrankedTensorType
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from .. import values
|
||||
from ..data_types import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_scalar_integer,
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
value_is_encrypted_tensor_unsigned_integer,
|
||||
)
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
@@ -41,6 +51,52 @@ class MLIRConverter:
|
||||
self.context = Context()
|
||||
zamalang.register_dialects(self.context)
|
||||
|
||||
def _get_tensor_element_type(
|
||||
self,
|
||||
bit_width: int,
|
||||
is_encrypted: bool,
|
||||
is_signed: bool,
|
||||
shape: Tuple[int, ...],
|
||||
) -> MLIRType:
|
||||
"""Get the MLIRType for a tensor element given its properties.
|
||||
|
||||
Args:
|
||||
bit_width (int): number of bits used for the scalar
|
||||
is_encrypted (bool): is the scalar encrypted or not
|
||||
is_signed (bool): is the scalar signed or not
|
||||
shape (Tuple[int, ...]): shape of the tensor
|
||||
|
||||
Returns:
|
||||
MLIRType: corresponding MLIR type
|
||||
"""
|
||||
element_type = self._get_scalar_element_type(bit_width, is_encrypted, is_signed)
|
||||
if len(shape): # randked tensor
|
||||
return RankedTensorType.get(shape, element_type)
|
||||
# unranked tensor
|
||||
return UnrankedTensorType.get(element_type)
|
||||
|
||||
def _get_scalar_element_type(
|
||||
self, bit_width: int, is_encrypted: bool, is_signed: bool
|
||||
) -> MLIRType:
|
||||
"""Get the MLIRType for a scalar element given its properties.
|
||||
|
||||
Args:
|
||||
bit_width (int): number of bits used for the scalar
|
||||
is_encrypted (bool): is the scalar encrypted or not
|
||||
is_signed (bool): is the scalar signed or not
|
||||
|
||||
Returns:
|
||||
MLIRType: corresponding MLIR type
|
||||
"""
|
||||
if is_encrypted and not is_signed:
|
||||
return hlfhe.EncryptedIntegerType.get(self.context, bit_width)
|
||||
if is_signed and not is_encrypted: # clear signed
|
||||
return IntegerType.get_signed(bit_width)
|
||||
# shoulld be clear unsigned at this point
|
||||
assert not is_signed and not is_encrypted
|
||||
# unsigned integer are considered signless in the compiler
|
||||
return IntegerType.get_signless(bit_width)
|
||||
|
||||
def hdk_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType:
|
||||
"""Convert an HDK value to its corresponding MLIR Type.
|
||||
|
||||
@@ -51,15 +107,25 @@ class MLIRConverter:
|
||||
corresponding MLIR type
|
||||
"""
|
||||
if value_is_encrypted_scalar_unsigned_integer(value):
|
||||
return hlfhe.EncryptedIntegerType.get(
|
||||
self.context, cast(Integer, value.data_type).bit_width
|
||||
return self._get_scalar_element_type(
|
||||
cast(Integer, value.data_type).bit_width, True, False
|
||||
)
|
||||
if value_is_clear_scalar_integer(value):
|
||||
dtype = cast(Integer, value.data_type)
|
||||
if dtype.is_signed:
|
||||
return IntegerType.get_signed(dtype.bit_width, context=self.context)
|
||||
# unsigned integer are considered signless in the compiler
|
||||
return IntegerType.get_signless(dtype.bit_width, context=self.context)
|
||||
return self._get_scalar_element_type(dtype.bit_width, False, dtype.is_signed)
|
||||
if value_is_encrypted_tensor_unsigned_integer(value):
|
||||
dtype = cast(Integer, value.data_type)
|
||||
return self._get_tensor_element_type(
|
||||
dtype.bit_width, True, False, cast(values.TensorValue, value).shape
|
||||
)
|
||||
if value_is_clear_tensor_integer(value):
|
||||
dtype = cast(Integer, value.data_type)
|
||||
return self._get_tensor_element_type(
|
||||
dtype.bit_width,
|
||||
False,
|
||||
dtype.is_signed,
|
||||
cast(values.TensorValue, value).shape,
|
||||
)
|
||||
raise TypeError(f"can't convert value of type {type(value)} to MLIR type")
|
||||
|
||||
def convert(self, op_graph: OPGraph) -> str:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
from mlir.ir import IntegerType
|
||||
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
|
||||
from zamalang import compiler
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
@@ -11,6 +11,7 @@ from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from hdk.common.values import ClearScalar, EncryptedScalar
|
||||
from hdk.common.values.tensors import ClearTensor, EncryptedTensor
|
||||
from hdk.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
@@ -202,14 +203,64 @@ def test_hdk_clear_integer_to_mlir_type(is_signed):
|
||||
"""Test conversion of ClearScalar into MLIR"""
|
||||
value = ClearScalar(Integer(5, is_signed=is_signed))
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
int_mlir = converter.hdk_value_to_mlir_type(value)
|
||||
with converter.context:
|
||||
int_mlir = converter.hdk_value_to_mlir_type(value)
|
||||
if is_signed:
|
||||
assert int_mlir == IntegerType.get_signed(5)
|
||||
else:
|
||||
assert int_mlir == IntegerType.get_signless(5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_signed", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
None,
|
||||
(5,),
|
||||
(5, 8),
|
||||
(-1, 5),
|
||||
],
|
||||
)
|
||||
def test_hdk_clear_tensor_integer_to_mlir_type(is_signed, shape):
|
||||
"""Test conversion of ClearTensor into MLIR"""
|
||||
value = ClearTensor(Integer(5, is_signed=is_signed), shape)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with converter.context, Location.unknown():
|
||||
tensor_mlir = converter.hdk_value_to_mlir_type(value)
|
||||
if is_signed:
|
||||
element_type = IntegerType.get_signed(5)
|
||||
else:
|
||||
element_type = IntegerType.get_signless(5)
|
||||
if shape is None:
|
||||
expected_type = UnrankedTensorType.get(element_type)
|
||||
else:
|
||||
expected_type = RankedTensorType.get(shape, element_type)
|
||||
assert tensor_mlir == expected_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
None,
|
||||
(5,),
|
||||
(5, 8),
|
||||
(-1, 5),
|
||||
],
|
||||
)
|
||||
def test_hdk_encrypted_tensor_integer_to_mlir_type(shape):
|
||||
"""Test conversion of EncryptedTensor into MLIR"""
|
||||
value = EncryptedTensor(Integer(6, is_signed=False), shape)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with converter.context, Location.unknown():
|
||||
tensor_mlir = converter.hdk_value_to_mlir_type(value)
|
||||
element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
|
||||
if shape is None:
|
||||
expected_type = UnrankedTensorType.get(element_type)
|
||||
else:
|
||||
expected_type = RankedTensorType.get(shape, element_type)
|
||||
assert tensor_mlir == expected_type
|
||||
|
||||
|
||||
def test_failing_hdk_to_mlir_type():
|
||||
"""Test failing conversion of an unsupported type into MLIR"""
|
||||
value = "random"
|
||||
|
||||
Reference in New Issue
Block a user