mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(mlir): support constant inputs in mlir conversion
This commit is contained in:
@@ -1,15 +1,24 @@
|
||||
"""Test converter functions"""
|
||||
import pytest
|
||||
|
||||
from hdk.common.mlir.converters import add, mul, sub
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue
|
||||
from hdk.common.mlir.converters import add, constant, mul, sub
|
||||
|
||||
|
||||
class MockNode:
|
||||
"""Mocking an intermediate node"""
|
||||
|
||||
def __init__(self, inputs=5, outputs=5):
|
||||
self.inputs = [None for i in range(inputs)]
|
||||
self.outputs = [None for i in range(outputs)]
|
||||
def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
|
||||
if inputs is None:
|
||||
self.inputs = [None for i in range(inputs_n)]
|
||||
else:
|
||||
self.inputs = inputs
|
||||
if outputs is None:
|
||||
self.outputs = [None for i in range(outputs_n)]
|
||||
else:
|
||||
self.outputs = outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("converter", [add, sub, mul])
|
||||
@@ -17,3 +26,15 @@ def test_failing_converter(converter):
|
||||
"""Test failing converter"""
|
||||
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
|
||||
converter(MockNode(2, 1), None, None, None)
|
||||
|
||||
|
||||
def test_fail_non_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support non-integer constants"):
|
||||
constant(MockNode(outputs=[ClearValue(Float(32))]), None, None, None)
|
||||
|
||||
|
||||
def test_fail_signed_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
|
||||
constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None)
|
||||
|
||||
@@ -18,16 +18,31 @@ def add(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
def constant_add(x):
|
||||
"""Test constant add"""
|
||||
return x + 5
|
||||
|
||||
|
||||
def sub(x, y):
|
||||
"""Test simple sub"""
|
||||
return x - y
|
||||
|
||||
|
||||
def constant_sub(x):
|
||||
"""Test constant sub"""
|
||||
return 8 - x
|
||||
|
||||
|
||||
def mul(x, y):
|
||||
"""Test simple mul"""
|
||||
return x * y
|
||||
|
||||
|
||||
def constant_mul(x):
|
||||
"""Test constant mul"""
|
||||
return x * 2
|
||||
|
||||
|
||||
def sub_add_mul(x, y, z):
|
||||
"""Test combination of ops"""
|
||||
return z - y + x * z
|
||||
@@ -60,6 +75,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(0, 8), range(1, 4)),
|
||||
),
|
||||
(
|
||||
constant_add,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
add,
|
||||
{
|
||||
@@ -84,6 +106,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(5, 10), range(2, 6)),
|
||||
),
|
||||
(
|
||||
constant_sub,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 5),),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
{
|
||||
@@ -92,6 +121,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(1, 5), range(2, 8)),
|
||||
),
|
||||
(
|
||||
constant_mul,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user