mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: allow broadcast_to operation on scalars
This commit is contained in:
@@ -314,7 +314,7 @@ class GraphConverter:
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
OPS_TO_TENSORIZE = ["add", "dot", "multiply", "subtract"]
|
||||
OPS_TO_TENSORIZE = ["add", "broadcast_to", "dot", "multiply", "subtract"]
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
tensorized_scalars: Dict[Node, Node] = {}
|
||||
@@ -322,10 +322,14 @@ class GraphConverter:
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
|
||||
assert_that(len(node.inputs) == 2)
|
||||
assert len(node.inputs) in {1, 2}
|
||||
|
||||
if set(inp.is_scalar for inp in node.inputs) != {True, False}:
|
||||
continue
|
||||
if len(node.inputs) == 2:
|
||||
if set(inp.is_scalar for inp in node.inputs) != {True, False}:
|
||||
continue
|
||||
else:
|
||||
if not node.inputs[0].is_scalar:
|
||||
continue
|
||||
|
||||
pred_to_tensorize: Optional[Node] = None
|
||||
pred_to_tensorize_index = 0
|
||||
|
||||
@@ -256,7 +256,10 @@ class NodeConverter:
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
zeros = fhe.ZeroTensorOp(resulting_type).result
|
||||
result = fhelinalg.AddEintOp(resulting_type, self.preds[0], zeros).result
|
||||
if self.node.inputs[0].is_encrypted:
|
||||
result = fhelinalg.AddEintOp(resulting_type, zeros, self.preds[0]).result
|
||||
else:
|
||||
result = fhelinalg.AddEintIntOp(resulting_type, zeros, self.preds[0]).result
|
||||
|
||||
# TODO: convert this to a single operation once it can be done
|
||||
# (https://github.com/zama-ai/concrete-numpy-internal/issues/1610)
|
||||
|
||||
@@ -11,6 +11,8 @@ import concrete.numpy as cnp
|
||||
@pytest.mark.parametrize(
|
||||
"from_shape,to_shape",
|
||||
[
|
||||
pytest.param((), (2,)),
|
||||
pytest.param((), (2, 3)),
|
||||
pytest.param((3,), (2, 3)),
|
||||
pytest.param((3,), (4, 2, 3)),
|
||||
pytest.param((1, 2), (4, 3, 2)),
|
||||
|
||||
Reference in New Issue
Block a user