refactor(mlir): re-write mlir conversion

This commit is contained in:
Umut
2021-11-11 17:32:38 +03:00
parent 6fec590e65
commit 239f66eb46
15 changed files with 736 additions and 1114 deletions

View File

@@ -0,0 +1,115 @@
"""Test file for MLIR conversion helpers."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
import pytest
import zamalang
from mlir.ir import Context, Location
from concrete.common.data_types import Float, SignedInteger, UnsignedInteger
from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
# pylint: enable=no-name-in-module
@pytest.mark.parametrize(
"integer,is_encrypted,expected_mlir_type_str",
[
pytest.param(SignedInteger(5), False, "i5"),
pytest.param(UnsignedInteger(5), False, "i5"),
pytest.param(SignedInteger(32), False, "i32"),
pytest.param(UnsignedInteger(32), False, "i32"),
pytest.param(SignedInteger(5), True, "!HLFHE.eint<5>"),
pytest.param(UnsignedInteger(5), True, "!HLFHE.eint<5>"),
],
)
def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str):
"""Test function for integer to MLIR type conversion."""
with Context() as ctx, Location.unknown():
zamalang.register_dialects(ctx)
assert str(integer_to_mlir_type(ctx, integer, is_encrypted)) == expected_mlir_type_str
@pytest.mark.parametrize(
"integer,is_encrypted,expected_error_message",
[
pytest.param(SignedInteger(32), True, "can't create eint with the given width"),
pytest.param(UnsignedInteger(32), True, "can't create eint with the given width"),
],
)
def test_fail_integer_to_mlir_type(integer, is_encrypted, expected_error_message):
"""Test function for failed integer to MLIR type conversion."""
with pytest.raises(ValueError) as excinfo:
with Context() as ctx, Location.unknown():
zamalang.register_dialects(ctx)
integer_to_mlir_type(ctx, integer, is_encrypted)
assert str(excinfo.value) == expected_error_message
@pytest.mark.parametrize(
"value,expected_mlir_type_str",
[
pytest.param(ClearScalar(SignedInteger(5)), "i5"),
pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
pytest.param(EncryptedScalar(SignedInteger(5)), "!HLFHE.eint<5>"),
pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"),
pytest.param(ClearScalar(UnsignedInteger(5)), "i5"),
pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
pytest.param(EncryptedScalar(UnsignedInteger(5)), "!HLFHE.eint<5>"),
pytest.param(
EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"
),
],
)
def test_value_to_mlir_type(value, expected_mlir_type_str):
"""Test function for value to MLIR type conversion."""
with Context() as ctx, Location.unknown():
zamalang.register_dialects(ctx)
assert str(value_to_mlir_type(ctx, value)) == expected_mlir_type_str
@pytest.mark.parametrize(
"value,expected_error_message",
[
pytest.param(
ClearScalar(Float(32)),
"ClearScalar<float32> is not supported for MLIR conversion",
),
pytest.param(
ClearTensor(Float(32), shape=(2, 3)),
"ClearTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
),
pytest.param(
EncryptedScalar(Float(32)),
"EncryptedScalar<float32> is not supported for MLIR conversion",
),
pytest.param(
EncryptedTensor(Float(32), shape=(2, 3)),
"EncryptedTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
),
pytest.param(
EncryptedScalar(UnsignedInteger(32)),
"EncryptedScalar<uint32> is not supported for MLIR conversion",
),
pytest.param(
EncryptedTensor(UnsignedInteger(32), shape=(2, 3)),
"EncryptedTensor<uint32, shape=(2, 3)> is not supported for MLIR conversion",
),
],
)
def test_fail_value_to_mlir_type(value, expected_error_message):
"""Test function for failed value to MLIR type conversion."""
with pytest.raises(TypeError) as excinfo:
with Context() as ctx, Location.unknown():
zamalang.register_dialects(ctx)
value_to_mlir_type(ctx, value)
assert str(excinfo.value) == expected_error_message

View File

@@ -1,81 +0,0 @@
"""Test converter functions"""
import pytest
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar
class MockNode:
"""Mocking an intermediate node"""
def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
if inputs is None:
self.inputs = [None for i in range(inputs_n)]
else:
self.inputs = inputs
if outputs is None:
self.outputs = [None for i in range(outputs_n)]
else:
self.outputs = outputs
@pytest.mark.parametrize("converter", [add, sub, mul, dot])
def test_failing_converter(converter):
"""Test failing converter"""
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
converter(MockNode(2, 1), None, None, None)
def test_fail_non_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support .* constants"):
constant(MockNode(outputs=[ClearScalar(Float(32))]), None, None, None)
with pytest.raises(TypeError, match=r"Don't support .* constants"):
constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)
@pytest.mark.parametrize(
"input_node",
[
ClearScalar(Integer(8, True)),
ClearScalar(Integer(8, False)),
EncryptedScalar(Integer(8, True)),
],
)
def test_fail_tlu_input(input_node):
"""Test failing LUT converter with invalid input"""
with pytest.raises(
TypeError, match=r"Only support LUT with encrypted unsigned integers inputs"
):
apply_lut(
MockNode(inputs=[input_node], outputs=[EncryptedScalar(Integer(8, False))]),
[None],
None,
None,
None,
)
@pytest.mark.parametrize(
"input_node",
[
ClearScalar(Integer(8, True)),
ClearScalar(Integer(8, False)),
EncryptedScalar(Integer(8, True)),
],
)
def test_fail_tlu_output(input_node):
"""Test failing LUT converter with invalid output"""
with pytest.raises(
TypeError, match=r"Only support LUT with encrypted unsigned integers outputs"
):
apply_lut(
MockNode(inputs=[EncryptedScalar(Integer(8, False))], outputs=[input_node]),
[None],
None,
None,
None,
)

View File

@@ -1,392 +0,0 @@
"""Test file for conversion to MLIR"""
# pylint: disable=no-name-in-module,no-member
import itertools
import numpy
import pytest
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
from zamalang import compiler
from zamalang.dialects import hlfhe
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.table import LookupTable
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from concrete.common.values import ClearScalar, EncryptedScalar
from concrete.common.values.tensors import ClearTensor, EncryptedTensor
from concrete.numpy.compile import compile_numpy_function_into_op_graph, prepare_op_graph_for_mlir
from concrete.numpy.np_mlir_converter import NPMLIRConverter
def add(x, y):
"""Test simple add"""
return x + y
def constant_add(x):
"""Test constant add"""
return x + 5
def sub(x, y):
"""Test simple sub"""
return x - y
def constant_sub(x):
"""Test constant sub"""
return 12 - x
def mul(x, y):
"""Test simple mul"""
return x * y
def constant_mul(x):
"""Test constant mul"""
return x * 2
def sub_add_mul(x, y, z):
"""Test combination of ops"""
return z - y + x * z
def ret_multiple(x, y, z):
"""Test return of multiple values"""
return x, y, z
def ret_multiple_different_order(x, y, z):
"""Test return of multiple values in a different order from input"""
return y, z, x
def lut(x):
"""Test lookup table"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
return table[x]
# TODO: remove workaround #359
def lut_more_bits_than_table_length(x, y):
"""Test lookup table when bit_width support longer LUT"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
return table[x] + y
# TODO: remove workaround #359
def lut_less_bits_than_table_length(x):
"""Test lookup table when bit_width support smaller LUT"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7])
return table[x]
def dot(x, y):
"""Test dot"""
return numpy.dot(x, y)
def datagen(*args):
"""Generate data from ranges"""
for prod in itertools.product(*args):
yield prod
@pytest.mark.parametrize(
"func, args_dict, args_ranges",
[
(
add,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
"y": ClearScalar(Integer(32, is_signed=False)),
},
(range(0, 8), range(1, 4)),
),
(
constant_add,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 10),),
),
(
add,
{
"x": ClearScalar(Integer(32, is_signed=False)),
"y": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 8), range(1, 4)),
),
(
add,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
},
(range(7, 15), range(1, 5)),
),
(
sub,
{
"x": ClearScalar(Integer(8, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
},
(range(5, 10), range(2, 6)),
),
(
constant_sub,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 10),),
),
(
mul,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": ClearScalar(Integer(8, is_signed=False)),
},
(range(1, 5), range(2, 8)),
),
(
constant_mul,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 10),),
),
(
mul,
{
"x": ClearScalar(Integer(8, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
},
(range(1, 5), range(2, 8)),
),
(
sub_add_mul,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
"z": ClearScalar(Integer(7, is_signed=False)),
},
(range(0, 8), range(1, 5), range(5, 12)),
),
(
ret_multiple,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
"z": ClearScalar(Integer(7, is_signed=False)),
},
(range(1, 5), range(1, 5), range(1, 5)),
),
(
ret_multiple_different_order,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
"z": ClearScalar(Integer(7, is_signed=False)),
},
(range(1, 5), range(1, 5), range(1, 5)),
),
(
lut,
{
"x": EncryptedScalar(Integer(3, is_signed=False)),
},
(range(0, 8),),
),
(
lut_more_bits_than_table_length,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
"y": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 8), range(0, 16)),
),
(
lut_less_bits_than_table_length,
{
"x": EncryptedScalar(Integer(3, is_signed=False)),
},
(range(0, 8),),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges, default_compilation_configuration):
"""Test the conversion to MLIR by calling the parser from the compiler"""
inputset = datagen(*args_ranges)
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
inputset,
default_compilation_configuration,
)
prepare_op_graph_for_mlir(result_graph)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
@pytest.mark.parametrize(
"func, args_dict, args_ranges",
[
(
dot,
{
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 4), range(0, 4)),
),
(
dot,
{
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 4), range(0, 4)),
),
],
)
def test_mlir_converter_dot_between_vectors(
func, args_dict, args_ranges, default_compilation_configuration
):
"""Test the conversion to MLIR by calling the parser from the compiler"""
assert len(args_dict["x"].shape) == 1
assert len(args_dict["y"].shape) == 1
n = args_dict["x"].shape[0]
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
(
(numpy.array([data[0]] * n), numpy.array([data[1]] * n))
for data in datagen(*args_ranges)
),
default_compilation_configuration,
)
prepare_op_graph_for_mlir(result_graph)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
def test_mlir_converter_dot_vector_and_constant(default_compilation_configuration):
"""Test the conversion to MLIR by calling the parser from the compiler"""
def left_dot_with_constant(x):
return numpy.dot(x, numpy.array([1, 2]))
def right_dot_with_constant(x):
return numpy.dot(numpy.array([1, 2]), x)
left_graph = compile_numpy_function_into_op_graph(
left_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
default_compilation_configuration,
)
prepare_op_graph_for_mlir(left_graph)
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
left_mlir = left_converter.convert(left_graph)
right_graph = compile_numpy_function_into_op_graph(
right_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
default_compilation_configuration,
)
prepare_op_graph_for_mlir(right_graph)
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
right_mlir = right_converter.convert(right_graph)
# testing that this doesn't raise an error
compiler.round_trip(left_mlir)
compiler.round_trip(right_mlir)
def test_concrete_encrypted_integer_to_mlir_type():
"""Test conversion of EncryptedScalar into MLIR"""
value = EncryptedScalar(Integer(7, is_signed=False))
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
eint = converter.common_value_to_mlir_type(value)
assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7)
@pytest.mark.parametrize("is_signed", [True, False])
def test_concrete_clear_integer_to_mlir_type(is_signed):
"""Test conversion of ClearScalar into MLIR"""
value = ClearScalar(Integer(5, is_signed=is_signed))
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context:
int_mlir = converter.common_value_to_mlir_type(value)
if is_signed:
assert int_mlir == IntegerType.get_signed(5)
else:
assert int_mlir == IntegerType.get_signless(5)
@pytest.mark.parametrize("is_signed", [True, False])
@pytest.mark.parametrize(
"shape",
[
(5,),
(5, 8),
(-1, 5),
],
)
def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
"""Test conversion of ClearTensor into MLIR"""
value = ClearTensor(Integer(5, is_signed=is_signed), shape)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.common_value_to_mlir_type(value)
if is_signed:
element_type = IntegerType.get_signed(5)
else:
element_type = IntegerType.get_signless(5)
if shape is None:
expected_type = UnrankedTensorType.get(element_type)
else:
expected_type = RankedTensorType.get(shape, element_type)
assert tensor_mlir == expected_type
@pytest.mark.parametrize(
"shape",
[
(5,),
(5, 8),
(-1, 5),
],
)
def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
"""Test conversion of EncryptedTensor into MLIR"""
value = EncryptedTensor(Integer(6, is_signed=False), shape)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.common_value_to_mlir_type(value)
element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
if shape is None:
expected_type = UnrankedTensorType.get(element_type)
else:
expected_type = RankedTensorType.get(shape, element_type)
assert tensor_mlir == expected_type
def test_failing_concrete_to_mlir_type():
"""Test failing conversion of an unsupported type into MLIR"""
value = "random"
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"):
converter.common_value_to_mlir_type(value)
# pylint: enable=no-name-in-module,no-member

View File

@@ -0,0 +1,85 @@
"""Test file for intermediate node to MLIR converter."""
import random
import numpy
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import compile_numpy_function
@pytest.mark.parametrize(
"function_to_compile,parameters,inputset,expected_error_type,expected_error_message",
[
pytest.param(
lambda x, y: x * y,
{
"x": EncryptedScalar(UnsignedInteger(3)),
"y": EncryptedScalar(UnsignedInteger(3)),
},
[(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)],
NotImplementedError,
"Multiplication "
"between "
"EncryptedScalar<uint6> "
"and "
"EncryptedScalar<uint6> "
"cannot be converted to MLIR yet",
),
pytest.param(
lambda x, y: x - y,
{
"x": EncryptedScalar(UnsignedInteger(3)),
"y": EncryptedScalar(UnsignedInteger(3)),
},
[(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)],
NotImplementedError,
"Subtraction "
"of "
"EncryptedScalar<uint3> "
"from "
"EncryptedScalar<uint3> "
"cannot be converted to MLIR yet",
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
"y": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
},
[
(
numpy.random.randint(0, 2 ** 3, size=(2,)),
numpy.random.randint(0, 2 ** 3, size=(2,)),
)
for _ in range(10)
]
+ [(numpy.array([7, 7]), numpy.array([7, 7]))],
NotImplementedError,
"Dot product "
"between "
"EncryptedTensor<uint7, shape=(2,)> "
"and "
"EncryptedTensor<uint7, shape=(2,)> "
"cannot be converted to MLIR yet",
),
],
)
def test_fail_node_conversion(
function_to_compile,
parameters,
inputset,
expected_error_type,
expected_error_message,
default_compilation_configuration,
):
"""Test function for failed intermediate node conversion."""
with pytest.raises(expected_error_type) as excinfo:
compile_numpy_function(
function_to_compile, parameters, inputset, default_compilation_configuration
)
assert str(excinfo.value) == expected_error_message

View File

@@ -519,6 +519,8 @@ def test_compile_function_multiple_outputs(
pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x: -x + 50, ((0, 20),), ["x"]),
pytest.param(lambda x: numpy.dot(x, 2), ((0, 20),), ["x"]),
pytest.param(lambda x: numpy.dot(2, x), ((0, 20),), ["x"]),
],
)
def test_compile_and_run_correctness(
@@ -548,6 +550,11 @@ def test_compile_and_run_correctness(
@pytest.mark.parametrize(
"function,parameters,inputset,test_input,expected_output",
[
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: x + 1,
{
@@ -566,6 +573,7 @@ def test_compile_and_run_correctness(
[7, 2],
[3, 6],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
@@ -590,9 +598,7 @@ def test_compile_and_run_correctness(
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars can be converted to (1,) shaped tensors
# this is easy with known constants but weird with variable things such as another input
# there is tensor.from_elements but I coudn't figure out how to use it in the python API
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x, y: x + y,
{
@@ -619,7 +625,7 @@ def test_compile_and_run_correctness(
[8, 3],
[4, 7],
],
marks=pytest.mark.xfail(),
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x, y: x + y,
@@ -652,6 +658,11 @@ def test_compile_and_run_correctness(
[5, 9],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: 100 - x,
{
@@ -670,6 +681,7 @@ def test_compile_and_run_correctness(
[94, 99],
[98, 95],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
@@ -690,6 +702,11 @@ def test_compile_and_run_correctness(
[8, 25],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: x * 2,
{
@@ -708,6 +725,7 @@ def test_compile_and_run_correctness(
[12, 2],
[4, 10],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
@@ -747,6 +765,36 @@ def test_compile_and_run_correctness(
[0, 2],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: numpy.dot(x, 2),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
marks=pytest.mark.xfail(strict=True),
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: numpy.dot(2, x),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
marks=pytest.mark.xfail(strict=True),
),
],
)
def test_compile_and_run_tensor_correctness(
@@ -874,7 +922,7 @@ def test_compile_and_run_constant_dot_correctness(
default_compilation_configuration,
)
right_circuit = compile_numpy_function(
left,
right,
{"x": EncryptedTensor(Integer(64, False), shape)},
inputset,
default_compilation_configuration,