feat(mlir): support constant inputs in mlir conversion

This commit is contained in:
youben11
2021-08-16 10:41:48 +01:00
committed by Ayoub Benaissa
parent 8fbe5dab4d
commit 3922bfe9b4
3 changed files with 84 additions and 5 deletions

View File

@@ -7,8 +7,14 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml
- `ctx`: MLIR context
"""
# pylint: disable=no-name-in-module,no-member
from typing import cast
from mlir.dialects import std as std_dialect
from mlir.ir import IntegerAttr, IntegerType
from zamalang.dialects import hlfhe
from hdk.common.data_types.integers import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_integer,
value_is_encrypted_unsigned_integer,
@@ -113,6 +119,22 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
).result
V0_OPSET_CONVERSION_FUNCTIONS = {ir.Add: add, ir.Sub: sub, ir.Mul: mul}
def constant(node, _, __, ctx):
"""Converter function for constant inputs."""
if not value_is_clear_integer(node.outputs[0]):
raise TypeError("Don't support non-integer constants")
dtype = cast(Integer, node.outputs[0].data_type)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer")
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result
V0_OPSET_CONVERSION_FUNCTIONS = {
ir.Add: add,
ir.Sub: sub,
ir.Mul: mul,
ir.ConstantInput: constant,
}
# pylint: enable=no-name-in-module,no-member

View File

@@ -1,15 +1,24 @@
"""Test converter functions"""
import pytest
from hdk.common.mlir.converters import add, mul, sub
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue
from hdk.common.mlir.converters import add, constant, mul, sub
class MockNode:
"""Mocking an intermediate node"""
def __init__(self, inputs=5, outputs=5):
self.inputs = [None for i in range(inputs)]
self.outputs = [None for i in range(outputs)]
def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
if inputs is None:
self.inputs = [None for i in range(inputs_n)]
else:
self.inputs = inputs
if outputs is None:
self.outputs = [None for i in range(outputs_n)]
else:
self.outputs = outputs
@pytest.mark.parametrize("converter", [add, sub, mul])
@@ -17,3 +26,15 @@ def test_failing_converter(converter):
"""Test failing converter"""
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
converter(MockNode(2, 1), None, None, None)
def test_fail_non_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support non-integer constants"):
constant(MockNode(outputs=[ClearValue(Float(32))]), None, None, None)
def test_fail_signed_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None)

View File

@@ -18,16 +18,31 @@ def add(x, y):
return x + y
def constant_add(x):
"""Test constant add"""
return x + 5
def sub(x, y):
"""Test simple sub"""
return x - y
def constant_sub(x):
"""Test constant sub"""
return 8 - x
def mul(x, y):
"""Test simple mul"""
return x * y
def constant_mul(x):
"""Test constant mul"""
return x * 2
def sub_add_mul(x, y, z):
"""Test combination of ops"""
return z - y + x * z
@@ -60,6 +75,13 @@ def datagen(*args):
},
(range(0, 8), range(1, 4)),
),
(
constant_add,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 8),),
),
(
add,
{
@@ -84,6 +106,13 @@ def datagen(*args):
},
(range(5, 10), range(2, 6)),
),
(
constant_sub,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 5),),
),
(
mul,
{
@@ -92,6 +121,13 @@ def datagen(*args):
},
(range(1, 5), range(2, 8)),
),
(
constant_mul,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 8),),
),
(
mul,
{