feat(mlir): conversion of dot node into MLIR

This commit is contained in:
youben11
2021-09-01 15:00:54 +01:00
committed by Ayoub Benaissa
parent f686ca535a
commit 31c1787af2
5 changed files with 84 additions and 19 deletions

View File

@@ -17,7 +17,9 @@ from zamalang.dialects import hlfhe
from ...common.data_types.integers import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_unsigned_integer,
value_is_encrypted_tensor_integer,
)
from ..representation import intermediate as ir
@@ -131,7 +133,7 @@ def constant(node, _, __, ctx):
def apply_lut(node, preds, ir_to_mlir_node, ctx):
"""Converted function for the arbitrary function intermediate node."""
"""Converter function for the arbitrary function intermediate node."""
assert len(node.inputs) == 1, "LUT should have a single input"
assert len(node.outputs) == 1, "LUT should have a single output"
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
@@ -156,12 +158,42 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
).result
def dot(node, preds, ir_to_mlir_node, ctx):
"""Converter function for the dot intermediate node."""
assert len(node.inputs) == 2, "Dot should have two inputs"
assert len(node.outputs) == 1, "Dot should have a single output"
if not (
(
value_is_encrypted_tensor_integer(node.inputs[0])
and value_is_clear_tensor_integer(node.inputs[1])
)
or (
value_is_encrypted_tensor_integer(node.inputs[1])
and value_is_clear_tensor_integer(node.inputs[0])
)
):
raise TypeError(
f"Don't support subtraction between {type(node.inputs[0])} and {type(node.inputs[1])}"
)
lhs_node, rhs_node = preds
# need to flip as underlying operation need encrypted first
if value_is_clear_tensor_integer(node.inputs[0]):
lhs_node, rhs_node = rhs_node, lhs_node
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
return hlfhe.Dot(
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
lhs,
rhs,
).result
V0_OPSET_CONVERSION_FUNCTIONS = {
ir.Add: add,
ir.Sub: sub,
ir.Mul: mul,
ir.Constant: constant,
ir.ArbitraryFunction: apply_lut,
ir.Dot: dot,
}
# pylint: enable=no-name-in-module,no-member

View File

@@ -51,7 +51,7 @@ class MLIRConverter:
self.context = Context()
zamalang.register_dialects(self.context)
def _get_tensor_element_type(
def _get_tensor_type(
self,
bit_width: int,
is_encrypted: bool,
@@ -69,13 +69,13 @@ class MLIRConverter:
Returns:
MLIRType: corresponding MLIR type
"""
element_type = self._get_scalar_element_type(bit_width, is_encrypted, is_signed)
element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed)
if len(shape): # randked tensor
return RankedTensorType.get(shape, element_type)
# unranked tensor
return UnrankedTensorType.get(element_type)
def _get_scalar_element_type(
def _get_scalar_integer_type(
self, bit_width: int, is_encrypted: bool, is_signed: bool
) -> MLIRType:
"""Get the MLIRType for a scalar element given its properties.
@@ -92,7 +92,7 @@ class MLIRConverter:
return hlfhe.EncryptedIntegerType.get(self.context, bit_width)
if is_signed and not is_encrypted: # clear signed
return IntegerType.get_signed(bit_width)
# shoulld be clear unsigned at this point
# should be clear unsigned at this point
assert not is_signed and not is_encrypted
# unsigned integer are considered signless in the compiler
return IntegerType.get_signless(bit_width)
@@ -107,24 +107,29 @@ class MLIRConverter:
corresponding MLIR type
"""
if value_is_encrypted_scalar_unsigned_integer(value):
return self._get_scalar_element_type(
return self._get_scalar_integer_type(
cast(Integer, value.data_type).bit_width, True, False
)
if value_is_clear_scalar_integer(value):
dtype = cast(Integer, value.data_type)
return self._get_scalar_element_type(dtype.bit_width, False, dtype.is_signed)
return self._get_scalar_integer_type(
dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed
)
if value_is_encrypted_tensor_unsigned_integer(value):
dtype = cast(Integer, value.data_type)
return self._get_tensor_element_type(
dtype.bit_width, True, False, cast(values.TensorValue, value).shape
return self._get_tensor_type(
dtype.bit_width,
is_encrypted=True,
is_signed=False,
shape=cast(values.TensorValue, value).shape,
)
if value_is_clear_tensor_integer(value):
dtype = cast(Integer, value.data_type)
return self._get_tensor_element_type(
return self._get_tensor_type(
dtype.bit_width,
False,
dtype.is_signed,
cast(values.TensorValue, value).shape,
is_encrypted=False,
is_signed=dtype.is_signed,
shape=cast(values.TensorValue, value).shape,
)
raise TypeError(f"can't convert value of type {type(value)} to MLIR type")

View File

@@ -4,7 +4,9 @@ from typing import cast
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_integer,
value_is_encrypted_tensor_integer,
value_is_scalar_integer,
)
from ..operator_graph import OPGraph
@@ -37,9 +39,11 @@ def _set_all_bit_width(op_graph: OPGraph, p: int):
"""
for node in op_graph.graph.nodes:
for value in node.outputs + node.inputs:
if value_is_clear_scalar_integer(value):
if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value):
value.data_type.bit_width = p + 1
elif value_is_encrypted_scalar_integer(value):
elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer(
value
):
value.data_type.bit_width = p
@@ -52,8 +56,10 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
max_bit_width = 0
for node in op_graph.graph.nodes:
for value_out in node.outputs:
if value_is_clear_scalar_integer(value_out):
if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
max_bit_width = max(max_bit_width, value_out.data_type.bit_width - 1)
elif value_is_encrypted_scalar_integer(value_out):
elif value_is_encrypted_scalar_integer(value_out) or value_is_encrypted_tensor_integer(
value_out
):
max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
_set_all_bit_width(op_graph, max_bit_width)

View File

@@ -3,7 +3,7 @@ import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub
from hdk.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
from hdk.common.values import ClearScalar, EncryptedScalar
@@ -21,7 +21,7 @@ class MockNode:
self.outputs = outputs
@pytest.mark.parametrize("converter", [add, sub, mul])
@pytest.mark.parametrize("converter", [add, sub, mul, dot])
def test_failing_converter(converter):
"""Test failing converter"""
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):

View File

@@ -2,6 +2,7 @@
# pylint: disable=no-name-in-module,no-member
import itertools
import numpy
import pytest
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
from zamalang import compiler
@@ -66,6 +67,11 @@ def lut(x):
return table[x]
def dot(x, y):
"""Test dot"""
return numpy.dot(x, y)
def datagen(*args):
"""Generate data from ranges"""
for prod in itertools.product(*args):
@@ -178,6 +184,22 @@ def datagen(*args):
},
(range(0, 8),),
),
(
dot,
{
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 8), range(0, 8)),
),
(
dot,
{
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 8), range(0, 8)),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges):