mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(mlir): TLU conversion
This commit is contained in:
@@ -6,11 +6,12 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml
|
||||
- `ir_to_mlir_node`: Dict mapping intermediate nodes to MLIR nodes or values
|
||||
- `ctx`: MLIR context
|
||||
"""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
from typing import cast
|
||||
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
import numpy as np
|
||||
from mlir.dialects import std as std_dialect
|
||||
from mlir.ir import IntegerAttr, IntegerType
|
||||
from mlir.ir import DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from ...common.data_types.integers import Integer
|
||||
@@ -129,11 +130,38 @@ def constant(node, _, __, ctx):
|
||||
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result
|
||||
|
||||
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converted function for the arbitrary function intermediate node."""
|
||||
assert len(node.inputs) == 1, "LUT should have a single input"
|
||||
assert len(node.outputs) == 1, "LUT should have a single output"
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
|
||||
raise TypeError("Only support LUT with encrypted unsigned integers inputs")
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
|
||||
raise TypeError("Only support LUT with encrypted unsigned integers outputs")
|
||||
|
||||
x_node = preds[0]
|
||||
x = ir_to_mlir_node[x_node]
|
||||
table = node.get_table()
|
||||
out_dtype = cast(Integer, node.outputs[0].data_type)
|
||||
# Create table
|
||||
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
|
||||
tensor_lut = std_dialect.ConstantOp(
|
||||
RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
|
||||
dense_elem,
|
||||
).result
|
||||
return hlfhe.ApplyLookupTableEintOp(
|
||||
hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width),
|
||||
x,
|
||||
tensor_lut,
|
||||
).result
|
||||
|
||||
|
||||
V0_OPSET_CONVERSION_FUNCTIONS = {
|
||||
ir.Add: add,
|
||||
ir.Sub: sub,
|
||||
ir.Mul: mul,
|
||||
ir.Constant: constant,
|
||||
ir.ArbitraryFunction: apply_lut,
|
||||
}
|
||||
|
||||
# pylint: enable=no-name-in-module,no-member
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.mlir.converters import add, constant, mul, sub
|
||||
from hdk.common.values import ClearValue
|
||||
from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub
|
||||
from hdk.common.values import ClearValue, EncryptedValue
|
||||
|
||||
|
||||
class MockNode:
|
||||
@@ -38,3 +38,45 @@ def test_fail_signed_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
|
||||
constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_node",
|
||||
[
|
||||
ClearValue(Integer(8, True)),
|
||||
ClearValue(Integer(8, False)),
|
||||
EncryptedValue(Integer(8, True)),
|
||||
],
|
||||
)
|
||||
def test_fail_tlu_input(input_node):
|
||||
"""Test failing LUT converter with invalid input"""
|
||||
with pytest.raises(
|
||||
TypeError, match=r"Only support LUT with encrypted unsigned integers inputs"
|
||||
):
|
||||
apply_lut(
|
||||
MockNode(inputs=[input_node], outputs=[EncryptedValue(Integer(8, False))]),
|
||||
[None],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_node",
|
||||
[
|
||||
ClearValue(Integer(8, True)),
|
||||
ClearValue(Integer(8, False)),
|
||||
EncryptedValue(Integer(8, True)),
|
||||
],
|
||||
)
|
||||
def test_fail_tlu_output(input_node):
|
||||
"""Test failing LUT converter with invalid output"""
|
||||
with pytest.raises(
|
||||
TypeError, match=r"Only support LUT with encrypted unsigned integers outputs"
|
||||
):
|
||||
apply_lut(
|
||||
MockNode(inputs=[EncryptedValue(Integer(8, False))], outputs=[input_node]),
|
||||
[None],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from zamalang import compiler
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from hdk.common.values import ClearValue, EncryptedValue
|
||||
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
|
||||
@@ -58,6 +59,12 @@ def ret_multiple_different_order(x, y, z):
|
||||
return y, z, x
|
||||
|
||||
|
||||
def lut(x):
|
||||
"""Test lookup table"""
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
|
||||
return table[x]
|
||||
|
||||
|
||||
def datagen(*args):
|
||||
"""Generate data from ranges"""
|
||||
for prod in itertools.product(*args):
|
||||
@@ -163,6 +170,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(1, 5), range(1, 5), range(1, 5)),
|
||||
),
|
||||
(
|
||||
lut,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
|
||||
@@ -23,6 +23,12 @@ def no_fuse_unhandled(x, y):
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def lut(x):
|
||||
"""Test lookup table"""
|
||||
table = LookupTable(list(range(128)))
|
||||
return table[x]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
@@ -75,6 +81,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(lut, ((0, 127),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
|
||||
Reference in New Issue
Block a user