diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index dbdd9c457..aab61e13d 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -314,7 +314,7 @@ class GraphConverter: """ # pylint: disable=invalid-name - OPS_TO_TENSORIZE = ["add", "dot", "multiply", "subtract"] + OPS_TO_TENSORIZE = ["add", "broadcast_to", "dot", "multiply", "subtract"] # pylint: enable=invalid-name tensorized_scalars: Dict[Node, Node] = {} @@ -322,10 +322,14 @@ class GraphConverter: nx_graph = graph.graph for node in list(nx_graph.nodes): if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE: - assert_that(len(node.inputs) == 2) + assert len(node.inputs) in {1, 2} - if set(inp.is_scalar for inp in node.inputs) != {True, False}: - continue + if len(node.inputs) == 2: + if set(inp.is_scalar for inp in node.inputs) != {True, False}: + continue + else: + if not node.inputs[0].is_scalar: + continue pred_to_tensorize: Optional[Node] = None pred_to_tensorize_index = 0 diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 667f6184f..bcadb5448 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -256,7 +256,10 @@ class NodeConverter: resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) zeros = fhe.ZeroTensorOp(resulting_type).result - result = fhelinalg.AddEintOp(resulting_type, self.preds[0], zeros).result + if self.node.inputs[0].is_encrypted: + result = fhelinalg.AddEintOp(resulting_type, zeros, self.preds[0]).result + else: + result = fhelinalg.AddEintIntOp(resulting_type, zeros, self.preds[0]).result # TODO: convert this to a single operation once it can be done # (https://github.com/zama-ai/concrete-numpy-internal/issues/1610) diff --git a/tests/execution/test_broadcast_to.py b/tests/execution/test_broadcast_to.py index aeba83092..175d0a6f4 100644 --- a/tests/execution/test_broadcast_to.py +++ b/tests/execution/test_broadcast_to.py @@ -11,6 +11,8 @@ import concrete.numpy as cnp @pytest.mark.parametrize( "from_shape,to_shape", [ + pytest.param((), (2,)), + pytest.param((), (2, 3)), pytest.param((3,), (2, 3)), pytest.param((3,), (4, 2, 3)), pytest.param((1, 2), (4, 3, 2)),