""" Tests of `Node` class. """ import numpy as np import pytest from concrete.fhe.dtypes import UnsignedInteger from concrete.fhe.representation import Node from concrete.fhe.values import ClearScalar, EncryptedScalar, EncryptedTensor, ValueDescription @pytest.mark.parametrize( "constant,expected_error,expected_message", [ pytest.param( "abc", ValueError, "Constant 'abc' is not supported", ), ], ) def test_node_bad_constant(constant, expected_error, expected_message): """ Test `constant` function of `Node` class with bad parameters. """ with pytest.raises(expected_error) as excinfo: Node.constant(constant) assert str(excinfo.value) == expected_message @pytest.mark.parametrize( "node,args,expected_error,expected_message", [ pytest.param( Node.constant(1), ["abc"], ValueError, "Evaluation of constant '1' node using 'abc' failed " "because of invalid number of arguments", ), pytest.param( Node.generic( name="add", inputs=[ValueDescription.of(4), ValueDescription.of(10, is_encrypted=True)], output=ValueDescription.of(14), operation=lambda x, y: x + y, ), ["abc"], ValueError, "Evaluation of generic 'add' node using 'abc' failed " "because of invalid number of arguments", ), pytest.param( Node.generic( name="add", inputs=[ValueDescription.of(4), ValueDescription.of(10, is_encrypted=True)], output=ValueDescription.of(14), operation=lambda x, y: x + y, ), ["abc", "def"], ValueError, "Evaluation of generic 'add' node using 'abc', 'def' failed " "because argument 'abc' is not valid", ), pytest.param( Node.generic( name="add", inputs=[ValueDescription.of([3, 4]), ValueDescription.of(10, is_encrypted=True)], output=ValueDescription.of([13, 14]), operation=lambda x, y: x + y, ), [[1, 2, 3, 4], 10], ValueError, "Evaluation of generic 'add' node using [1, 2, 3, 4], 10 failed " "because argument [1, 2, 3, 4] does not have the expected shape of (2,)", ), pytest.param( Node.generic( name="unknown", inputs=[], output=ValueDescription.of(10), operation=lambda: "abc", ), [], ValueError, "Evaluation of generic 'unknown' node resulted in 'abc' of type str " "which is not acceptable either because of the type or because of overflow", ), pytest.param( Node.generic( name="unknown", inputs=[], output=ValueDescription.of(10), operation=lambda: np.array(["abc", "def"]), ), [], ValueError, "Evaluation of generic 'unknown' node resulted in array(['abc', 'def'], dtype='