feat: allow broadcast_to operation on scalars

This commit is contained in:
Umut
2022-07-27 09:08:07 +02:00
parent 710ee3408d
commit 48014ed60a
3 changed files with 14 additions and 5 deletions

View File

@@ -314,7 +314,7 @@ class GraphConverter:
"""
# pylint: disable=invalid-name
OPS_TO_TENSORIZE = ["add", "dot", "multiply", "subtract"]
OPS_TO_TENSORIZE = ["add", "broadcast_to", "dot", "multiply", "subtract"]
# pylint: enable=invalid-name
tensorized_scalars: Dict[Node, Node] = {}
@@ -322,10 +322,14 @@ class GraphConverter:
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
assert_that(len(node.inputs) == 2)
assert len(node.inputs) in {1, 2}
if set(inp.is_scalar for inp in node.inputs) != {True, False}:
continue
if len(node.inputs) == 2:
if set(inp.is_scalar for inp in node.inputs) != {True, False}:
continue
else:
if not node.inputs[0].is_scalar:
continue
pred_to_tensorize: Optional[Node] = None
pred_to_tensorize_index = 0

View File

@@ -256,7 +256,10 @@ class NodeConverter:
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
zeros = fhe.ZeroTensorOp(resulting_type).result
result = fhelinalg.AddEintOp(resulting_type, self.preds[0], zeros).result
if self.node.inputs[0].is_encrypted:
result = fhelinalg.AddEintOp(resulting_type, zeros, self.preds[0]).result
else:
result = fhelinalg.AddEintIntOp(resulting_type, zeros, self.preds[0]).result
# TODO: convert this to a single operation once it can be done
# (https://github.com/zama-ai/concrete-numpy-internal/issues/1610)

View File

@@ -11,6 +11,8 @@ import concrete.numpy as cnp
@pytest.mark.parametrize(
"from_shape,to_shape",
[
pytest.param((), (2,)),
pytest.param((), (2, 3)),
pytest.param((3,), (2, 3)),
pytest.param((3,), (4, 2, 3)),
pytest.param((1, 2), (4, 3, 2)),