feat(mlir): TLU conversion

This commit is contained in:
youben11
2021-08-17 10:30:10 +01:00
committed by Ayoub Benaissa
parent dbda93639b
commit 784158741e
4 changed files with 95 additions and 4 deletions

View File

@@ -6,11 +6,12 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml
- `ir_to_mlir_node`: Dict mapping intermediate nodes to MLIR nodes or values
- `ctx`: MLIR context
"""
# pylint: disable=no-name-in-module,no-member
from typing import cast
# pylint: disable=no-name-in-module,no-member
import numpy as np
from mlir.dialects import std as std_dialect
from mlir.ir import IntegerAttr, IntegerType
from mlir.ir import DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from zamalang.dialects import hlfhe
from ...common.data_types.integers import Integer
@@ -129,11 +130,38 @@ def constant(node, _, __, ctx):
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result
def apply_lut(node, preds, ir_to_mlir_node, ctx):
"""Converted function for the arbitrary function intermediate node."""
assert len(node.inputs) == 1, "LUT should have a single input"
assert len(node.outputs) == 1, "LUT should have a single output"
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
raise TypeError("Only support LUT with encrypted unsigned integers inputs")
if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
raise TypeError("Only support LUT with encrypted unsigned integers outputs")
x_node = preds[0]
x = ir_to_mlir_node[x_node]
table = node.get_table()
out_dtype = cast(Integer, node.outputs[0].data_type)
# Create table
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
tensor_lut = std_dialect.ConstantOp(
RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
dense_elem,
).result
return hlfhe.ApplyLookupTableEintOp(
hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width),
x,
tensor_lut,
).result
V0_OPSET_CONVERSION_FUNCTIONS = {
ir.Add: add,
ir.Sub: sub,
ir.Mul: mul,
ir.Constant: constant,
ir.ArbitraryFunction: apply_lut,
}
# pylint: enable=no-name-in-module,no-member

View File

@@ -3,8 +3,8 @@ import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.mlir.converters import add, constant, mul, sub
from hdk.common.values import ClearValue
from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub
from hdk.common.values import ClearValue, EncryptedValue
class MockNode:
@@ -38,3 +38,45 @@ def test_fail_signed_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None)
@pytest.mark.parametrize(
"input_node",
[
ClearValue(Integer(8, True)),
ClearValue(Integer(8, False)),
EncryptedValue(Integer(8, True)),
],
)
def test_fail_tlu_input(input_node):
"""Test failing LUT converter with invalid input"""
with pytest.raises(
TypeError, match=r"Only support LUT with encrypted unsigned integers inputs"
):
apply_lut(
MockNode(inputs=[input_node], outputs=[EncryptedValue(Integer(8, False))]),
[None],
None,
None,
)
@pytest.mark.parametrize(
"input_node",
[
ClearValue(Integer(8, True)),
ClearValue(Integer(8, False)),
EncryptedValue(Integer(8, True)),
],
)
def test_fail_tlu_output(input_node):
"""Test failing LUT converter with invalid output"""
with pytest.raises(
TypeError, match=r"Only support LUT with encrypted unsigned integers outputs"
):
apply_lut(
MockNode(inputs=[EncryptedValue(Integer(8, False))], outputs=[input_node]),
[None],
None,
None,
)

View File

@@ -8,6 +8,7 @@ from zamalang import compiler
from zamalang.dialects import hlfhe
from hdk.common.data_types.integers import Integer
from hdk.common.extensions.table import LookupTable
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from hdk.common.values import ClearValue, EncryptedValue
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
@@ -58,6 +59,12 @@ def ret_multiple_different_order(x, y, z):
return y, z, x
def lut(x):
"""Test lookup table"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
return table[x]
def datagen(*args):
"""Generate data from ranges"""
for prod in itertools.product(*args):
@@ -163,6 +170,13 @@ def datagen(*args):
},
(range(1, 5), range(1, 5), range(1, 5)),
),
(
lut,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 8),),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges):

View File

@@ -23,6 +23,12 @@ def no_fuse_unhandled(x, y):
return intermediate.astype(numpy.int32)
def lut(x):
"""Test lookup table"""
table = LookupTable(list(range(128)))
return table[x]
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
@@ -75,6 +81,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
pytest.param(lut, ((0, 127),), ["x"]),
],
)
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):