mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(mlir): support constant inputs in mlir conversion
This commit is contained in:
@@ -7,8 +7,14 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml
|
||||
- `ctx`: MLIR context
|
||||
"""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
from typing import cast
|
||||
|
||||
from mlir.dialects import std as std_dialect
|
||||
from mlir.ir import IntegerAttr, IntegerType
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
@@ -113,6 +119,22 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
V0_OPSET_CONVERSION_FUNCTIONS = {ir.Add: add, ir.Sub: sub, ir.Mul: mul}
|
||||
def constant(node, _, __, ctx):
|
||||
"""Converter function for constant inputs."""
|
||||
if not value_is_clear_integer(node.outputs[0]):
|
||||
raise TypeError("Don't support non-integer constants")
|
||||
dtype = cast(Integer, node.outputs[0].data_type)
|
||||
if dtype.is_signed:
|
||||
raise TypeError("Don't support signed constant integer")
|
||||
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
|
||||
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result
|
||||
|
||||
|
||||
V0_OPSET_CONVERSION_FUNCTIONS = {
|
||||
ir.Add: add,
|
||||
ir.Sub: sub,
|
||||
ir.Mul: mul,
|
||||
ir.ConstantInput: constant,
|
||||
}
|
||||
|
||||
# pylint: enable=no-name-in-module,no-member
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
"""Test converter functions"""
|
||||
import pytest
|
||||
|
||||
from hdk.common.mlir.converters import add, mul, sub
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue
|
||||
from hdk.common.mlir.converters import add, constant, mul, sub
|
||||
|
||||
|
||||
class MockNode:
|
||||
"""Mocking an intermediate node"""
|
||||
|
||||
def __init__(self, inputs=5, outputs=5):
|
||||
self.inputs = [None for i in range(inputs)]
|
||||
self.outputs = [None for i in range(outputs)]
|
||||
def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
|
||||
if inputs is None:
|
||||
self.inputs = [None for i in range(inputs_n)]
|
||||
else:
|
||||
self.inputs = inputs
|
||||
if outputs is None:
|
||||
self.outputs = [None for i in range(outputs_n)]
|
||||
else:
|
||||
self.outputs = outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("converter", [add, sub, mul])
|
||||
@@ -17,3 +26,15 @@ def test_failing_converter(converter):
|
||||
"""Test failing converter"""
|
||||
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
|
||||
converter(MockNode(2, 1), None, None, None)
|
||||
|
||||
|
||||
def test_fail_non_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support non-integer constants"):
|
||||
constant(MockNode(outputs=[ClearValue(Float(32))]), None, None, None)
|
||||
|
||||
|
||||
def test_fail_signed_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
|
||||
constant(MockNode(outputs=[ClearValue(Integer(8, True))]), None, None, None)
|
||||
|
||||
@@ -18,16 +18,31 @@ def add(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
def constant_add(x):
|
||||
"""Test constant add"""
|
||||
return x + 5
|
||||
|
||||
|
||||
def sub(x, y):
|
||||
"""Test simple sub"""
|
||||
return x - y
|
||||
|
||||
|
||||
def constant_sub(x):
|
||||
"""Test constant sub"""
|
||||
return 8 - x
|
||||
|
||||
|
||||
def mul(x, y):
|
||||
"""Test simple mul"""
|
||||
return x * y
|
||||
|
||||
|
||||
def constant_mul(x):
|
||||
"""Test constant mul"""
|
||||
return x * 2
|
||||
|
||||
|
||||
def sub_add_mul(x, y, z):
|
||||
"""Test combination of ops"""
|
||||
return z - y + x * z
|
||||
@@ -60,6 +75,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(0, 8), range(1, 4)),
|
||||
),
|
||||
(
|
||||
constant_add,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
add,
|
||||
{
|
||||
@@ -84,6 +106,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(5, 10), range(2, 6)),
|
||||
),
|
||||
(
|
||||
constant_sub,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 5),),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
{
|
||||
@@ -92,6 +121,13 @@ def datagen(*args):
|
||||
},
|
||||
(range(1, 5), range(2, 8)),
|
||||
),
|
||||
(
|
||||
constant_mul,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user