mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(mlir): implement mlir conversion of reshape and flatten
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
import numpy
|
||||
from mlir.dialects import arith, tensor
|
||||
from mlir.dialects import arith, linalg, tensor
|
||||
from mlir.ir import (
|
||||
ArrayAttr,
|
||||
Attribute,
|
||||
@@ -117,7 +117,11 @@ class IntermediateNodeConverter:
|
||||
result = self.convert_dot()
|
||||
|
||||
elif isinstance(self.node, GenericFunction):
|
||||
result = self.convert_generic_function(additional_conversion_info)
|
||||
if self.node.op_name in ["flatten", "reshape"]:
|
||||
# notice flatten() == reshape(-1) and convert_reshape can handle that
|
||||
result = self.convert_reshape()
|
||||
else:
|
||||
result = self.convert_generic_function(additional_conversion_info)
|
||||
|
||||
elif isinstance(self.node, IndexConstant):
|
||||
result = self.convert_index_constant()
|
||||
@@ -481,6 +485,129 @@ class IntermediateNodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_reshape(self) -> OpResult:
|
||||
"""Convert a "reshape" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
assert_true(isinstance(self.node.inputs[0], TensorValue))
|
||||
input_shape = cast(TensorValue, self.node.inputs[0]).shape
|
||||
|
||||
assert_true(isinstance(self.node.outputs[0], TensorValue))
|
||||
output_shape = cast(TensorValue, self.node.outputs[0]).shape
|
||||
|
||||
pred = self.preds[0]
|
||||
if input_shape == output_shape:
|
||||
return pred
|
||||
|
||||
# we can either collapse or expand, which changes the number of dimensions
|
||||
# this is a limitation of the current compiler and it will be improved in the future (#1060)
|
||||
can_be_converted_directly = len(input_shape) != len(output_shape)
|
||||
|
||||
reassociation: List[List[int]] = []
|
||||
if can_be_converted_directly:
|
||||
if len(output_shape) == 1:
|
||||
# output is 1 dimensional so collapse every dimension into the same dimension
|
||||
reassociation.append(list(range(len(input_shape))))
|
||||
else:
|
||||
# input is m dimensional
|
||||
# output is n dimensional
|
||||
# and m is different than n
|
||||
|
||||
# we don't want to duplicate code so we forget about input and output
|
||||
# and we focus on smaller shape and bigger shape
|
||||
|
||||
smaller_shape, bigger_shape = (
|
||||
(output_shape, input_shape)
|
||||
if len(output_shape) < len(input_shape)
|
||||
else (input_shape, output_shape)
|
||||
)
|
||||
s_index, b_index = 0, 0
|
||||
|
||||
# now we will figure out how to group the bigger shape to get the smaller shape
|
||||
# think of the algorithm below as
|
||||
# keep merging the dimensions of the bigger shape
|
||||
# until we have a match on the smaller shape
|
||||
# then try to match the next dimension of the smaller shape
|
||||
# if all dimensions of the smaller shape is matched
|
||||
# we can convert it
|
||||
|
||||
group = []
|
||||
size = 1
|
||||
while s_index < len(smaller_shape) and b_index < len(bigger_shape):
|
||||
# dimension `b_index` of `bigger_shape` belongs to current group
|
||||
group.append(b_index)
|
||||
|
||||
# and current group has `size * bigger_shape[b_index]` elements now
|
||||
size *= bigger_shape[b_index]
|
||||
|
||||
# if current group size matches the dimension `s_index` of `smaller_shape`
|
||||
if size == smaller_shape[s_index]:
|
||||
# we finalize this group and reset everything
|
||||
size = 1
|
||||
reassociation.append(group)
|
||||
group = []
|
||||
|
||||
# now try to match the next dimension of `smaller_shape`
|
||||
s_index += 1
|
||||
|
||||
# now process the next dimension of `bigger_shape`
|
||||
b_index += 1
|
||||
|
||||
# handle the case where bigger shape has proceeding 1s
|
||||
# e.g., (5,) -> (5, 1)
|
||||
while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
|
||||
reassociation[-1].append(b_index)
|
||||
b_index += 1
|
||||
|
||||
# if not all dimensions of both shapes are processed exactly
|
||||
if s_index != len(smaller_shape) or b_index != len(bigger_shape):
|
||||
# we cannot convert
|
||||
can_be_converted_directly = False
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
|
||||
if can_be_converted_directly:
|
||||
reassociation_attr = ArrayAttr.get(
|
||||
[
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
|
||||
for group in reassociation
|
||||
]
|
||||
)
|
||||
if len(output_shape) < len(input_shape):
|
||||
return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
|
||||
return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result
|
||||
|
||||
flattened_type = value_to_mlir_type(
|
||||
self.ctx,
|
||||
TensorValue(
|
||||
self.node.inputs[0].dtype,
|
||||
self.node.inputs[0].is_encrypted,
|
||||
(numpy.prod(input_shape),),
|
||||
),
|
||||
)
|
||||
flattened_result = linalg.TensorCollapseShapeOp(
|
||||
flattened_type,
|
||||
pred,
|
||||
ArrayAttr.get(
|
||||
[ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
|
||||
),
|
||||
).result
|
||||
|
||||
return linalg.TensorExpandShapeOp(
|
||||
resulting_type,
|
||||
flattened_result,
|
||||
ArrayAttr.get(
|
||||
[ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
|
||||
),
|
||||
).result
|
||||
|
||||
def convert_sub(self) -> OpResult:
|
||||
"""Convert a Sub node to its corresponding MLIR representation.
|
||||
|
||||
|
||||
@@ -89,7 +89,8 @@ def check_node_compatibility_with_mlir(
|
||||
== 1
|
||||
)
|
||||
else:
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
if node.op_name not in ["flatten", "reshape"]:
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.Dot): # constraints for dot product
|
||||
assert_true(len(inputs) == 2)
|
||||
|
||||
@@ -362,13 +362,13 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def reshape(self, arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
|
||||
def reshape(self, newshape: Tuple[Any, ...], **kwargs) -> "NPTracer":
|
||||
"""Trace x.reshape.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self.numpy_reshape(self, arg1, **kwargs)
|
||||
return self.numpy_reshape(self, newshape, **kwargs)
|
||||
|
||||
def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.reshape.
|
||||
@@ -390,15 +390,13 @@ class NPTracer(BaseTracer):
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work,
|
||||
# while classical form is numpy.reshape(x, (170,))
|
||||
newshape = deepcopy(arg1) if not isinstance(arg1, int) else (arg1,)
|
||||
|
||||
# Check shape compatibility
|
||||
assert_true(
|
||||
numpy.product(newshape) == first_arg_output.size,
|
||||
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
|
||||
)
|
||||
try:
|
||||
# calculate a newshape using numpy to handle edge cases such as `-1`s within new shape
|
||||
newshape = numpy.zeros(first_arg_output.shape).reshape(arg1).shape
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {arg1})"
|
||||
) from error
|
||||
|
||||
reshape_is_fusable = newshape == first_arg_output.shape
|
||||
|
||||
|
||||
@@ -1675,23 +1675,6 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = ravel(%0) # EncryptedTensor<uint3, shape=(6,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.reshape(x, (2, 6)),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
|
||||
[numpy.random.randint(0, 2 ** 3, size=(3, 4)) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 4)>
|
||||
%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor<uint3, shape=(2, 6)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
|
||||
265
tests/numpy/test_compile_memory_operations.py
Normal file
265
tests/numpy/test_compile_memory_operations.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Test module for memory operations."""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[0, 1, 1, 2, 2, 3],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((1, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
|
||||
[5, 9, 1],
|
||||
[[5, 9, 1]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 1)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
|
||||
[5, 9, 1],
|
||||
[[5], [9], [1]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
|
||||
[[0, 1, 1], [2, 2, 3]],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape(-1),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[0, 1, 1, 2, 2, 3],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((2, 2, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((4, 3)),
|
||||
(numpy.arange(12) % 10).reshape((2, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((4, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((2, 2, 3)),
|
||||
(numpy.arange(12) % 10).reshape((4, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((3, 4)),
|
||||
(numpy.arange(12) % 10).reshape((3, 2, 2)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 4)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((3, 2, 2)),
|
||||
(numpy.arange(12) % 10).reshape((3, 4)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
|
||||
(numpy.arange(30) % 10).reshape((6, 5)),
|
||||
(numpy.arange(30) % 10).reshape((5, 3, 2)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 6)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
|
||||
(numpy.arange(30) % 10).reshape((2, 3, 5)),
|
||||
(numpy.arange(30) % 10).reshape((5, 6)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, 4, 30)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10).reshape((6, 4, 30)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((2, 60, 6)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10).reshape((2, 60, 6)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, 6, -1)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((6, 6, -1)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, -1, 12)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((6, -1, 12)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((-1, 18, 4)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((-1, 18, 4)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_memory_operation_run_correctness(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
test_input,
|
||||
expected_output,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""
|
||||
Test correctness of results when running a compiled function with memory operators.
|
||||
|
||||
e.g.,
|
||||
- flatten
|
||||
- reshape
|
||||
"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
actual = circuit.run(numpy.array(test_input, dtype=numpy.uint8))
|
||||
expected = numpy.array(expected_output, dtype=numpy.uint8)
|
||||
|
||||
check_array_equality(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,error,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.reshape((-1, -1, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
|
||||
ValueError,
|
||||
"shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, -1, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
|
||||
ValueError,
|
||||
"shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_memory_operation_failed_compilation(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
error,
|
||||
match,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""
|
||||
Test compilation failures of compiled function with memory operations.
|
||||
|
||||
e.g.,
|
||||
- reshape
|
||||
"""
|
||||
|
||||
with pytest.raises(error) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value) == match
|
||||
), f"""
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{excinfo.value}
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{match}
|
||||
|
||||
"""
|
||||
@@ -644,7 +644,7 @@ def test_tracing_numpy_calls(
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
@@ -985,7 +985,7 @@ def test_tracing_ndarray_calls(
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
|
||||
@@ -294,7 +294,7 @@ def test_tracing_numpy_calls(
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -178,7 +178,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
|
||||
Reference in New Issue
Block a user