mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compilation): implement MLIR conversion of constant arrays
This commit is contained in:
@@ -11,18 +11,19 @@ from typing import cast
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
import numpy as np
|
||||
from mlir.dialects import std as std_dialect
|
||||
from mlir.ir import DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
|
||||
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from ...common.data_types.integers import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_scalar_integer,
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub
|
||||
from ..values import TensorValue
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx):
|
||||
@@ -123,14 +124,44 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def constant(node, _, __, ctx):
|
||||
"""Convert a constant inputs."""
|
||||
if not value_is_clear_scalar_integer(node.outputs[0]):
|
||||
raise TypeError("Don't support non-integer constants")
|
||||
dtype = cast(Integer, node.outputs[0].dtype)
|
||||
if dtype.is_signed:
|
||||
raise TypeError("Don't support signed constant integer")
|
||||
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
|
||||
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result
|
||||
"""Convert a constant input."""
|
||||
value = node.outputs[0]
|
||||
|
||||
if value_is_clear_scalar_integer(value):
|
||||
value = cast(TensorValue, value)
|
||||
|
||||
dtype = cast(Integer, value.dtype)
|
||||
if dtype.is_signed:
|
||||
raise TypeError("Don't support signed constant integer")
|
||||
data = node.constant_data
|
||||
|
||||
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
|
||||
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result
|
||||
|
||||
if value_is_clear_tensor_integer(value):
|
||||
value = cast(TensorValue, value)
|
||||
|
||||
dtype = cast(Integer, value.dtype)
|
||||
if dtype.is_signed:
|
||||
raise TypeError("Don't support signed constant integer tensor")
|
||||
data = node.constant_data
|
||||
|
||||
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
|
||||
vec_type = RankedTensorType.get(value.shape, int_type)
|
||||
|
||||
# usage of `Attribute.parse` is the result of some limitations in the MLIR module
|
||||
# provided by LLVM
|
||||
|
||||
# `DenseElementsAttr` should have been used instead but it's impossible to assign
|
||||
# custom bit-widths using it (e.g., uint5)
|
||||
|
||||
# since we coudn't create a `DenseElementsAttr` with a custom bit width using python api
|
||||
# we use `Attribute.parse` to let the underlying library do it by itself
|
||||
|
||||
value_attr = Attribute.parse(f"dense<{str(data.tolist())}> : {vec_type}")
|
||||
return std_dialect.ConstantOp(vec_type, value_attr).result
|
||||
|
||||
raise TypeError(f"Don't support {value} constants")
|
||||
|
||||
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
@@ -4,7 +4,7 @@ import pytest
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar
|
||||
|
||||
|
||||
class MockNode:
|
||||
@@ -30,14 +30,19 @@ def test_failing_converter(converter):
|
||||
|
||||
def test_fail_non_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support non-integer constants"):
|
||||
with pytest.raises(TypeError, match=r"Don't support .* constants"):
|
||||
constant(MockNode(outputs=[ClearScalar(Float(32))]), None, None, None)
|
||||
|
||||
with pytest.raises(TypeError, match=r"Don't support .* constants"):
|
||||
constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)
|
||||
|
||||
|
||||
def test_fail_signed_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
|
||||
constant(MockNode(outputs=[ClearScalar(Integer(8, True))]), None, None, None)
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer tensor"):
|
||||
constant(MockNode(outputs=[ClearTensor(Integer(8, True), shape=(2,))]), None, None, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -267,6 +267,36 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
compiler.round_trip(mlir_result)
|
||||
|
||||
|
||||
def test_mlir_converter_dot_vector_and_constant():
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
|
||||
def left_dot_with_constant(x):
|
||||
return numpy.dot(x, numpy.array([1, 2]))
|
||||
|
||||
def right_dot_with_constant(x):
|
||||
return numpy.dot(numpy.array([1, 2]), x)
|
||||
|
||||
left_graph = compile_numpy_function_into_op_graph(
|
||||
left_dot_with_constant,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
)
|
||||
left_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
left_mlir = left_converter.convert(left_graph)
|
||||
|
||||
right_graph = compile_numpy_function_into_op_graph(
|
||||
right_dot_with_constant,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
)
|
||||
right_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
right_mlir = right_converter.convert(right_graph)
|
||||
|
||||
# testing that this doesn't raise an error
|
||||
compiler.round_trip(left_mlir)
|
||||
compiler.round_trip(right_mlir)
|
||||
|
||||
|
||||
def test_concrete_encrypted_integer_to_mlir_type():
|
||||
"""Test conversion of EncryptedScalar into MLIR"""
|
||||
value = EncryptedScalar(Integer(7, is_signed=False))
|
||||
|
||||
@@ -161,11 +161,11 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
(0, 5),
|
||||
),
|
||||
pytest.param(
|
||||
8,
|
||||
6,
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
16,
|
||||
10,
|
||||
(0, 3),
|
||||
),
|
||||
],
|
||||
@@ -173,18 +173,22 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_and_run_dot_correctness(size, input_range):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def data_gen(input_range, size):
|
||||
for _ in range(1000):
|
||||
low, high = input_range
|
||||
args = [
|
||||
numpy.array([random.randint(low, high) for _ in range(size)]) for __ in range(2)
|
||||
]
|
||||
low, high = input_range
|
||||
shape = (size,)
|
||||
|
||||
yield args
|
||||
inputset = [
|
||||
(numpy.zeros(shape, dtype=numpy.uint32), numpy.zeros(shape, dtype=numpy.uint32)),
|
||||
(
|
||||
numpy.ones(shape, dtype=numpy.uint32) * high,
|
||||
numpy.ones(shape, dtype=numpy.uint32) * high,
|
||||
),
|
||||
]
|
||||
for _ in range(8):
|
||||
inputset.append((numpy.random.randint(low, high + 1), numpy.random.randint(low, high + 1)))
|
||||
|
||||
function_parameters = {
|
||||
"x": EncryptedTensor(Integer(64, False), (size,)),
|
||||
"y": ClearTensor(Integer(64, False), (size,)),
|
||||
"x": EncryptedTensor(Integer(64, False), shape),
|
||||
"y": ClearTensor(Integer(64, False), shape),
|
||||
}
|
||||
|
||||
def function(x, y):
|
||||
@@ -193,14 +197,71 @@ def test_compile_and_run_dot_correctness(size, input_range):
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(input_range, size),
|
||||
inputset,
|
||||
)
|
||||
|
||||
low, high = input_range
|
||||
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size,input_range",
|
||||
[
|
||||
pytest.param(
|
||||
1,
|
||||
(0, 8),
|
||||
),
|
||||
pytest.param(
|
||||
4,
|
||||
(0, 5),
|
||||
),
|
||||
pytest.param(
|
||||
6,
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
10,
|
||||
(0, 3),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
low, high = input_range
|
||||
shape = (size,)
|
||||
|
||||
inputset = [
|
||||
(numpy.zeros(shape, dtype=numpy.uint32),),
|
||||
(numpy.ones(shape, dtype=numpy.uint32) * high,),
|
||||
]
|
||||
for _ in range(8):
|
||||
inputset.append((numpy.random.randint(low, high + 1),))
|
||||
|
||||
constant = numpy.random.randint(low, high + 1, size=shape)
|
||||
|
||||
def left(x):
|
||||
return numpy.dot(x, constant)
|
||||
|
||||
def right(x):
|
||||
return numpy.dot(constant, x)
|
||||
|
||||
left_circuit = compile_numpy_function(
|
||||
left,
|
||||
{"x": EncryptedTensor(Integer(64, False), shape)},
|
||||
inputset,
|
||||
)
|
||||
right_circuit = compile_numpy_function(
|
||||
left,
|
||||
{"x": EncryptedTensor(Integer(64, False), shape)},
|
||||
inputset,
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
|
||||
assert left_circuit.run(*args) == left(*args)
|
||||
assert right_circuit.run(*args) == right(*args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user