feat: add support for np.broadcast_to

This commit is contained in:
Umut
2022-07-21 11:50:35 +02:00
parent 3c52b89747
commit e398a4fbd0
7 changed files with 92 additions and 1 deletions

View File

@@ -87,6 +87,11 @@ class GraphConverter:
assert_that(len(inputs) > 0)
assert_that(all(input.is_scalar for input in inputs))
elif name == "broadcast_to":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted broadcasting is supported"
elif name == "concatenate":
if not all(input.is_encrypted for input in inputs):
return "only all encrypted concatenate is supported"

View File

@@ -161,6 +161,9 @@ class NodeConverter:
elif name == "array":
result = self._convert_array()
elif name == "broadcast_to":
result = self._convert_broadcast_to()
elif name == "concatenate":
result = self._convert_concat()
@@ -242,6 +245,25 @@ class NodeConverter:
return result
def _convert_broadcast_to(self) -> OpResult:
"""
Convert "broadcast_to" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
zeros = fhe.ZeroTensorOp(resulting_type).result
result = fhelinalg.AddEintOp(resulting_type, self.preds[0], zeros).result
# TODO: convert this to a single operation once it can be done
# (https://github.com/zama-ai/concrete-numpy-internal/issues/1610)
return result
def _convert_array(self) -> OpResult:
"""
Convert "array" node to its corresponding MLIR representation.

View File

@@ -307,6 +307,7 @@ class Node:
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"array",
"broadcast_to",
"concatenate",
"conv1d",
"conv2d",

View File

@@ -186,6 +186,7 @@ class Tracer:
np.bitwise_and,
np.bitwise_or,
np.bitwise_xor,
np.broadcast_to,
np.cbrt,
np.ceil,
np.clip,
@@ -273,6 +274,9 @@ class Tracer:
np.around: {
"decimals",
},
np.broadcast_to: {
"shape",
},
np.concatenate: {
"axis",
},
@@ -396,7 +400,11 @@ class Tracer:
(https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)
"""
if func is np.reshape:
if func is np.broadcast_to:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["shape"] = args[1]
elif func is np.reshape:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["newshape"] = args[1]

View File

@@ -66,6 +66,7 @@ Some of these operations are not supported between two encrypted values. A detai
* [np.bitwise\_and](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_and.html)
* [np.bitwise\_or](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_or.html)
* [np.bitwise\_xor](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_xor.html)
* [np.broadcast\_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast\_to.html)
* [np.cbrt](https://numpy.org/doc/stable/reference/generated/numpy.cbrt.html)
* [np.ceil](https://numpy.org/doc/stable/reference/generated/numpy.ceil.html)
* [np.clip](https://numpy.org/doc/stable/reference/generated/numpy.clip.html)

View File

@@ -0,0 +1,38 @@
"""
Tests of execution of broadcast to operation.
"""
import numpy as np
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"from_shape,to_shape",
[
pytest.param((3,), (2, 3)),
pytest.param((3,), (4, 2, 3)),
pytest.param((1, 2), (4, 3, 2)),
pytest.param((3, 2), (4, 3, 2)),
pytest.param((3, 1), (4, 3, 5)),
pytest.param((3, 1, 4), (3, 2, 4)),
pytest.param((3, 1, 1), (5, 3, 1, 3)),
],
)
def test_broadcast_to(from_shape, to_shape, helpers):
"""
Test broadcast to.
"""
def function(x):
return np.broadcast_to(x, to_shape)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**2, size=from_shape) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2**2, size=from_shape)
helpers.check_execution(circuit, function, sample)

View File

@@ -385,6 +385,22 @@ Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(3, 2)>
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: np.broadcast_to(x, shape=(3, 2)),
{"x": "clear"},
[np.random.randint(0, 2, size=(2,)) for _ in range(100)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(2,)>
%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor<uint1, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
return %1
""", # noqa: E501