mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add support for np.broadcast_to
This commit is contained in:
@@ -87,6 +87,11 @@ class GraphConverter:
|
||||
assert_that(len(inputs) > 0)
|
||||
assert_that(all(input.is_scalar for input in inputs))
|
||||
|
||||
elif name == "broadcast_to":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted broadcasting is supported"
|
||||
|
||||
elif name == "concatenate":
|
||||
if not all(input.is_encrypted for input in inputs):
|
||||
return "only all encrypted concatenate is supported"
|
||||
|
||||
@@ -161,6 +161,9 @@ class NodeConverter:
|
||||
elif name == "array":
|
||||
result = self._convert_array()
|
||||
|
||||
elif name == "broadcast_to":
|
||||
result = self._convert_broadcast_to()
|
||||
|
||||
elif name == "concatenate":
|
||||
result = self._convert_concat()
|
||||
|
||||
@@ -242,6 +245,25 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def _convert_broadcast_to(self) -> OpResult:
|
||||
"""
|
||||
Convert "broadcast_to" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
zeros = fhe.ZeroTensorOp(resulting_type).result
|
||||
result = fhelinalg.AddEintOp(resulting_type, self.preds[0], zeros).result
|
||||
|
||||
# TODO: convert this to a single operation once it can be done
|
||||
# (https://github.com/zama-ai/concrete-numpy-internal/issues/1610)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_array(self) -> OpResult:
|
||||
"""
|
||||
Convert "array" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -307,6 +307,7 @@ class Node:
|
||||
return self.operation == Operation.Generic and self.properties["name"] not in [
|
||||
"add",
|
||||
"array",
|
||||
"broadcast_to",
|
||||
"concatenate",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
|
||||
@@ -186,6 +186,7 @@ class Tracer:
|
||||
np.bitwise_and,
|
||||
np.bitwise_or,
|
||||
np.bitwise_xor,
|
||||
np.broadcast_to,
|
||||
np.cbrt,
|
||||
np.ceil,
|
||||
np.clip,
|
||||
@@ -273,6 +274,9 @@ class Tracer:
|
||||
np.around: {
|
||||
"decimals",
|
||||
},
|
||||
np.broadcast_to: {
|
||||
"shape",
|
||||
},
|
||||
np.concatenate: {
|
||||
"axis",
|
||||
},
|
||||
@@ -396,7 +400,11 @@ class Tracer:
|
||||
(https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)
|
||||
"""
|
||||
|
||||
if func is np.reshape:
|
||||
if func is np.broadcast_to:
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["shape"] = args[1]
|
||||
elif func is np.reshape:
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["newshape"] = args[1]
|
||||
|
||||
@@ -66,6 +66,7 @@ Some of these operations are not supported between two encrypted values. A detai
|
||||
* [np.bitwise\_and](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_and.html)
|
||||
* [np.bitwise\_or](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_or.html)
|
||||
* [np.bitwise\_xor](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_xor.html)
|
||||
* [np.broadcast\_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast\_to.html)
|
||||
* [np.cbrt](https://numpy.org/doc/stable/reference/generated/numpy.cbrt.html)
|
||||
* [np.ceil](https://numpy.org/doc/stable/reference/generated/numpy.ceil.html)
|
||||
* [np.clip](https://numpy.org/doc/stable/reference/generated/numpy.clip.html)
|
||||
|
||||
38
tests/execution/test_broadcast_to.py
Normal file
38
tests/execution/test_broadcast_to.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Tests of execution of broadcast to operation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"from_shape,to_shape",
|
||||
[
|
||||
pytest.param((3,), (2, 3)),
|
||||
pytest.param((3,), (4, 2, 3)),
|
||||
pytest.param((1, 2), (4, 3, 2)),
|
||||
pytest.param((3, 2), (4, 3, 2)),
|
||||
pytest.param((3, 1), (4, 3, 5)),
|
||||
pytest.param((3, 1, 4), (3, 2, 4)),
|
||||
pytest.param((3, 1, 1), (5, 3, 1, 3)),
|
||||
],
|
||||
)
|
||||
def test_broadcast_to(from_shape, to_shape, helpers):
|
||||
"""
|
||||
Test broadcast to.
|
||||
"""
|
||||
|
||||
def function(x):
|
||||
return np.broadcast_to(x, to_shape)
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2**2, size=from_shape) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2**2, size=from_shape)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -385,6 +385,22 @@ Function you are trying to compile cannot be converted to MLIR
|
||||
%0 = x # ClearTensor<uint1, shape=(3, 2)>
|
||||
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.broadcast_to(x, shape=(3, 2)),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(2,)) for _ in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(2,)>
|
||||
%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor<uint1, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user