feat(mlir): support constant inputs in mlir conversion

This commit is contained in:
youben11
2021-08-16 10:41:48 +01:00
committed by Ayoub Benaissa
parent 8fbe5dab4d
commit 3922bfe9b4
3 changed files with 84 additions and 5 deletions

View File

@@ -1,15 +1,24 @@
"""Test converter functions"""
import pytest
from hdk.common.mlir.converters import add, mul, sub
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue
from hdk.common.mlir.converters import add, constant, mul, sub
class MockNode:
"""Mocking an intermediate node"""
def __init__(self, inputs=5, outputs=5):
self.inputs = [None for i in range(inputs)]
self.outputs = [None for i in range(outputs)]
def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
if inputs is None:
self.inputs = [None for i in range(inputs_n)]
else:
self.inputs = inputs
if outputs is None:
self.outputs = [None for i in range(outputs_n)]
else:
self.outputs = outputs
@pytest.mark.parametrize("converter", [add, sub, mul])
@@ -17,3 +26,15 @@ def test_failing_converter(converter):
"""Test failing converter"""
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
converter(MockNode(2, 1), None, None, None)
def test_fail_non_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support non-integer constants"):
constant(MockNode(outputs=[ClearValue(Float(32))]), None, None, None)
def test_fail_signed_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None)

View File

@@ -18,16 +18,31 @@ def add(x, y):
return x + y
def constant_add(x):
"""Test constant add"""
return x + 5
def sub(x, y):
"""Test simple sub"""
return x - y
def constant_sub(x):
"""Test constant sub"""
return 8 - x
def mul(x, y):
"""Test simple mul"""
return x * y
def constant_mul(x):
"""Test constant mul"""
return x * 2
def sub_add_mul(x, y, z):
"""Test combination of ops"""
return z - y + x * z
@@ -60,6 +75,13 @@ def datagen(*args):
},
(range(0, 8), range(1, 4)),
),
(
constant_add,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 8),),
),
(
add,
{
@@ -84,6 +106,13 @@ def datagen(*args):
},
(range(5, 10), range(2, 6)),
),
(
constant_sub,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 5),),
),
(
mul,
{
@@ -92,6 +121,13 @@ def datagen(*args):
},
(range(1, 5), range(2, 8)),
),
(
constant_mul,
{
"x": EncryptedValue(Integer(64, is_signed=False)),
},
(range(0, 8),),
),
(
mul,
{