Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-08 19:44:57 -05:00)

feat(mlir): conversion of dot node into MLIR
@@ -17,7 +17,9 @@ from zamalang.dialects import hlfhe
 from ...common.data_types.integers import Integer
 from ..data_types.dtypes_helpers import (
     value_is_clear_scalar_integer,
+    value_is_clear_tensor_integer,
     value_is_encrypted_scalar_unsigned_integer,
+    value_is_encrypted_tensor_integer,
 )
 from ..representation import intermediate as ir

@@ -131,7 +133,7 @@ def constant(node, _, __, ctx):


 def apply_lut(node, preds, ir_to_mlir_node, ctx):
-    """Converted function for the arbitrary function intermediate node."""
+    """Converter function for the arbitrary function intermediate node."""
     assert len(node.inputs) == 1, "LUT should have a single input"
     assert len(node.outputs) == 1, "LUT should have a single output"
     if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
@@ -156,12 +158,42 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
     ).result


+def dot(node, preds, ir_to_mlir_node, ctx):
+    """Converter function for the dot intermediate node."""
+    assert len(node.inputs) == 2, "Dot should have two inputs"
+    assert len(node.outputs) == 1, "Dot should have a single output"
+    if not (
+        (
+            value_is_encrypted_tensor_integer(node.inputs[0])
+            and value_is_clear_tensor_integer(node.inputs[1])
+        )
+        or (
+            value_is_encrypted_tensor_integer(node.inputs[1])
+            and value_is_clear_tensor_integer(node.inputs[0])
+        )
+    ):
+        raise TypeError(
+            f"Don't support dot product between {type(node.inputs[0])} and {type(node.inputs[1])}"
+        )
+    lhs_node, rhs_node = preds
+    # need to flip as the underlying operation needs the encrypted operand first
+    if value_is_clear_tensor_integer(node.inputs[0]):
+        lhs_node, rhs_node = rhs_node, lhs_node
+    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
+    return hlfhe.Dot(
+        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
+        lhs,
+        rhs,
+    ).result


 V0_OPSET_CONVERSION_FUNCTIONS = {
     ir.Add: add,
     ir.Sub: sub,
     ir.Mul: mul,
     ir.Constant: constant,
     ir.ArbitraryFunction: apply_lut,
+    ir.Dot: dot,
 }

 # pylint: enable=no-name-in-module,no-member

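Note on the operand flip in the dot converter above: hlfhe.Dot takes the encrypted tensor as its first operand, and swapping is safe because a 1-D dot product is symmetric in its arguments. A quick standalone numpy check of that assumption (plain Python, independent of the compiler):

import numpy

x = numpy.array([1, 2, 3, 4])
y = numpy.array([5, 6, 7, 8])

# dot is symmetric for 1-D vectors, so the converter can always put the
# encrypted operand first without changing the computed value
assert numpy.dot(x, y) == numpy.dot(y, x) == 70
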
@@ -51,7 +51,7 @@ class MLIRConverter:
         self.context = Context()
         zamalang.register_dialects(self.context)

-    def _get_tensor_element_type(
+    def _get_tensor_type(
         self,
         bit_width: int,
         is_encrypted: bool,
@@ -69,13 +69,13 @@ class MLIRConverter:
         Returns:
             MLIRType: corresponding MLIR type
         """
-        element_type = self._get_scalar_element_type(bit_width, is_encrypted, is_signed)
+        element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed)
         if len(shape):  # ranked tensor
             return RankedTensorType.get(shape, element_type)
         # unranked tensor
         return UnrankedTensorType.get(element_type)

-    def _get_scalar_element_type(
+    def _get_scalar_integer_type(
         self, bit_width: int, is_encrypted: bool, is_signed: bool
     ) -> MLIRType:
         """Get the MLIRType for a scalar element given its properties.
@@ -92,7 +92,7 @@ class MLIRConverter:
             return hlfhe.EncryptedIntegerType.get(self.context, bit_width)
         if is_signed and not is_encrypted:  # clear signed
             return IntegerType.get_signed(bit_width)
-        # shoulld be clear unsigned at this point
+        # should be clear unsigned at this point
         assert not is_signed and not is_encrypted
         # unsigned integers are considered signless in the compiler
         return IntegerType.get_signless(bit_width)
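For orientation, the sketch below shows roughly what the two helpers above produce, using the mlir Python bindings that the tests already import (the type spellings in the comments are indicative, not quoted from compiler output):

from mlir.ir import Context, IntegerType, Location, RankedTensorType, UnrankedTensorType

with Context(), Location.unknown():
    # clear unsigned integers are signless in the compiler
    i8 = IntegerType.get_signless(8)
    # clear signed integers keep an explicit sign
    si8 = IntegerType.get_signed(8)
    # a non-empty shape yields a ranked tensor wrapping the element type
    print(RankedTensorType.get([4], i8))  # tensor<4xi8>
    # an empty shape falls back to an unranked tensor
    print(UnrankedTensorType.get(si8))  # tensor<*xsi8>

Encrypted values map to hlfhe.EncryptedIntegerType.get(...) as shown above; that path needs the zamalang dialects registered on the context and is not reproduced here.
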
@@ -107,24 +107,29 @@ class MLIRConverter:
             corresponding MLIR type
         """
         if value_is_encrypted_scalar_unsigned_integer(value):
-            return self._get_scalar_element_type(
+            return self._get_scalar_integer_type(
                 cast(Integer, value.data_type).bit_width, True, False
             )
         if value_is_clear_scalar_integer(value):
             dtype = cast(Integer, value.data_type)
-            return self._get_scalar_element_type(dtype.bit_width, False, dtype.is_signed)
+            return self._get_scalar_integer_type(
+                dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed
+            )
         if value_is_encrypted_tensor_unsigned_integer(value):
             dtype = cast(Integer, value.data_type)
-            return self._get_tensor_element_type(
-                dtype.bit_width, True, False, cast(values.TensorValue, value).shape
+            return self._get_tensor_type(
+                dtype.bit_width,
+                is_encrypted=True,
+                is_signed=False,
+                shape=cast(values.TensorValue, value).shape,
             )
         if value_is_clear_tensor_integer(value):
             dtype = cast(Integer, value.data_type)
-            return self._get_tensor_element_type(
+            return self._get_tensor_type(
                 dtype.bit_width,
-                False,
-                dtype.is_signed,
-                cast(values.TensorValue, value).shape,
+                is_encrypted=False,
+                is_signed=dtype.is_signed,
+                shape=cast(values.TensorValue, value).shape,
             )
         raise TypeError(f"can't convert value of type {type(value)} to MLIR type")

@@ -4,7 +4,9 @@ from typing import cast
 from ..data_types import Integer
 from ..data_types.dtypes_helpers import (
     value_is_clear_scalar_integer,
+    value_is_clear_tensor_integer,
     value_is_encrypted_scalar_integer,
+    value_is_encrypted_tensor_integer,
     value_is_scalar_integer,
 )
 from ..operator_graph import OPGraph
@@ -37,9 +39,11 @@ def _set_all_bit_width(op_graph: OPGraph, p: int):
     """
     for node in op_graph.graph.nodes:
         for value in node.outputs + node.inputs:
-            if value_is_clear_scalar_integer(value):
+            if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value):
                 value.data_type.bit_width = p + 1
-            elif value_is_encrypted_scalar_integer(value):
+            elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer(
+                value
+            ):
                 value.data_type.bit_width = p

@@ -52,8 +56,10 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
     max_bit_width = 0
     for node in op_graph.graph.nodes:
         for value_out in node.outputs:
-            if value_is_clear_scalar_integer(value_out):
+            if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
                 max_bit_width = max(max_bit_width, value_out.data_type.bit_width - 1)
-            elif value_is_encrypted_scalar_integer(value_out):
+            elif value_is_encrypted_scalar_integer(value_out) or value_is_encrypted_tensor_integer(
+                value_out
+            ):
                 max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
     _set_all_bit_width(op_graph, max_bit_width)

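In words, the two hunks above widen every integer in the graph so that clear values carry one extra bit: the target width p is the maximum of bit_width for encrypted outputs and bit_width - 1 for clear outputs, after which encrypted values are set to p bits and clear values to p + 1. A toy rundown of that rule with stand-in values (ToyValue is illustrative only, not the real hdk value class):

from dataclasses import dataclass


@dataclass
class ToyValue:
    bit_width: int
    is_encrypted: bool


# outputs seen while walking a hypothetical graph:
# an encrypted 7-bit value and a clear 9-bit value
outputs = [ToyValue(7, True), ToyValue(9, False)]

# clear values contribute bit_width - 1, encrypted ones contribute bit_width
p = max(v.bit_width if v.is_encrypted else v.bit_width - 1 for v in outputs)
assert p == 8

# every value is then widened: encrypted to p bits, clear to p + 1 bits
for v in outputs:
    v.bit_width = p if v.is_encrypted else p + 1
assert [v.bit_width for v in outputs] == [8, 9]
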
@@ -3,7 +3,7 @@ import pytest

 from hdk.common.data_types.floats import Float
 from hdk.common.data_types.integers import Integer
-from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub
+from hdk.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
 from hdk.common.values import ClearScalar, EncryptedScalar

@@ -21,7 +21,7 @@ class MockNode:
         self.outputs = outputs


-@pytest.mark.parametrize("converter", [add, sub, mul])
+@pytest.mark.parametrize("converter", [add, sub, mul, dot])
 def test_failing_converter(converter):
     """Test failing converter"""
     with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
@@ -2,6 +2,7 @@
 # pylint: disable=no-name-in-module,no-member
+import itertools

 import numpy
 import pytest
 from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
 from zamalang import compiler
@@ -66,6 +67,11 @@ def lut(x):
     return table[x]


+def dot(x, y):
+    """Test dot"""
+    return numpy.dot(x, y)
+
+
 def datagen(*args):
     """Generate data from ranges"""
     for prod in itertools.product(*args):
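The new itertools import feeds the datagen helper shown above, which expands the per-argument ranges of the parametrization below into concrete input tuples. A standalone illustration of that expansion (the numbers are only a sanity check of itertools.product, not taken from the test suite):

import itertools


def datagen(*args):
    """Generate data from ranges (mirrors the helper above)."""
    for prod in itertools.product(*args):
        yield prod


# two range(0, 8) arguments expand to 8 * 8 = 64 input tuples
samples = list(datagen(range(0, 8), range(0, 8)))
assert len(samples) == 64
assert samples[0] == (0, 0) and samples[-1] == (7, 7)
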
@@ -178,6 +184,22 @@ def datagen(*args):
             },
             (range(0, 8),),
         ),
+        (
+            dot,
+            {
+                "x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
+                "y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
+            },
+            (range(0, 8), range(0, 8)),
+        ),
+        (
+            dot,
+            {
+                "x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
+                "y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
+            },
+            (range(0, 8), range(0, 8)),
+        ),
     ],
 )
 def test_mlir_converter(func, args_dict, args_ranges):