feat: add encrypted zeros and ones functions

This commit is contained in:
Umut
2022-06-10 15:23:28 +02:00
parent 59cacc35df
commit 57aae5afdb
9 changed files with 294 additions and 3 deletions

View File

@@ -15,6 +15,6 @@ from .compilation import (
Server,
compiler,
)
from .extensions import LookupTable, univariate
from .extensions import LookupTable, one, ones, univariate, zero, zeros
from .mlir.utils import MAXIMUM_BIT_WIDTH
from .representation import Graph

View File

@@ -2,5 +2,7 @@
Provide additional features that are not present in numpy.
"""
from .ones import one, ones
from .table import LookupTable
from .univariate import univariate
from .zeros import zero, zeros

View File

@@ -0,0 +1,56 @@
"""
Declaration of `ones` and `one` functions, to simplify creation of encrypted ones.
"""
from typing import Tuple, Union
import numpy as np
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def ones(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
"""
Create an encrypted array of ones.
Args:
shape (Tuple[int, ...]):
shape of the array
Returns:
Union[np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
ndarray filled with ones otherwise
"""
# pylint: disable=protected-access
is_tracing = Tracer._is_tracing
# pylint: enable=protected-access
numpy_ones = np.ones(shape, dtype=np.int64)
if is_tracing:
computation = Node.generic(
"ones",
[],
Value.of(numpy_ones, is_encrypted=True),
lambda: np.ones(shape, dtype=np.int64),
)
return Tracer(computation, [])
return numpy_ones
def one() -> Union[np.ndarray, Tracer]:
"""
Create an encrypted scalar with the value of one.
Returns:
Union[np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
ndarray with one otherwise
"""
return ones(())

View File

@@ -0,0 +1,56 @@
"""
Declaration of `zeros` and `zero` functions, to simplify creation of encrypted zeros.
"""
from typing import Tuple, Union
import numpy as np
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def zeros(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
"""
Create an encrypted array of zeros.
Args:
shape (Tuple[int, ...]):
shape of the array
Returns:
Union[np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
ndarray filled with zeros otherwise
"""
# pylint: disable=protected-access
is_tracing = Tracer._is_tracing
# pylint: enable=protected-access
numpy_zeros = np.zeros(shape, dtype=np.int64)
if is_tracing:
computation = Node.generic(
"zeros",
[],
Value.of(numpy_zeros, is_encrypted=True),
lambda: np.zeros(shape, dtype=np.int64),
)
return Tracer(computation, [])
return numpy_zeros
def zero() -> Union[np.ndarray, Tracer]:
"""
Create an encrypted scalar with the value of zero.
Returns:
Union[np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
ndarray with zero otherwise
"""
return zeros(())

View File

@@ -106,6 +106,9 @@ class GraphConverter:
if not inputs[0].is_encrypted:
return "only encrypted negation is supported"
elif name == "ones":
assert_that(len(inputs) == 0)
elif name == "reshape":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
@@ -126,7 +129,11 @@ class GraphConverter:
if not inputs[0].is_encrypted:
return "only encrypted transpose is supported"
elif name == "zeros":
assert_that(len(inputs) == 0)
else:
assert_that(node.converted_to_table_lookup)
variable_input_indices = [
idx
for idx, pred in enumerate(graph.ordered_preds_of(node))
@@ -135,7 +142,7 @@ class GraphConverter:
if len(variable_input_indices) != 1:
return "only single input table lookups are supported"
if all(input.is_clear for input in inputs):
if len(inputs) > 0 and all(input.is_clear for input in inputs):
return "one of the operands must be encrypted"
return None

View File

@@ -124,7 +124,7 @@ class NodeConverter:
in-memory MLIR representation corresponding to `self.node`
"""
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches,too-many-statements
if self.node.operation == Operation.Constant:
result = self.convert_constant()
@@ -163,6 +163,9 @@ class NodeConverter:
elif name == "negative":
result = self.convert_neg()
elif name == "ones":
result = self.convert_ones()
elif name == "reshape":
result = self.convert_reshape()
@@ -175,7 +178,11 @@ class NodeConverter:
elif name == "transpose":
result = self.convert_transpose()
elif name == "zeros":
result = self.convert_zeros()
else:
assert_that(self.node.converted_to_table_lookup)
result = self.convert_tlu()
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
@@ -458,6 +465,49 @@ class NodeConverter:
return result
def convert_ones(self) -> OpResult:
"""
Convert "ones" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
assert isinstance(self.node.output.dtype, Integer)
bit_width = self.node.output.dtype.bit_width
if self.node.output.is_scalar:
constant_value = Value(
Integer(is_signed=False, bit_width=bit_width + 1),
shape=(),
is_encrypted=False,
)
constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value)
constant_attr = IntegerAttr.get(constant_type, 1)
zero = fhe.ZeroEintOp(resulting_type).result
one = arith.ConstantOp(constant_type, constant_attr).result
result = fhe.AddEintIntOp(resulting_type, zero, one).result
else:
constant_value = Value(
Integer(is_signed=False, bit_width=bit_width + 1),
shape=(1,),
is_encrypted=False,
)
constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value)
constant_attr = Attribute.parse(f"dense<[1]> : {constant_type}")
zeros = fhe.ZeroTensorOp(resulting_type).result
ones = arith.ConstantOp(constant_type, constant_attr).result
result = fhelinalg.AddEintIntOp(resulting_type, zeros, ones).result
return result
def convert_reshape(self) -> OpResult:
"""
Convert "reshape" node to its corresponding MLIR representation.
@@ -843,3 +893,21 @@ class NodeConverter:
preds = self.preds
return fhelinalg.TransposeOp(resulting_type, *preds).result
def convert_zeros(self) -> OpResult:
"""
Convert "zeros" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
if self.node.output.is_scalar:
result = fhe.ZeroEintOp(resulting_type).result
else:
result = fhe.ZeroTensorOp(resulting_type).result
return result

View File

@@ -313,10 +313,12 @@ class Node:
"matmul",
"multiply",
"negative",
"ones",
"reshape",
"subtract",
"sum",
"transpose",
"zeros",
]
@property
@@ -336,5 +338,7 @@ class Node:
"add",
"multiply",
"negative",
"ones",
"subtract",
"zeros",
]

View File

@@ -0,0 +1,49 @@
"""
Tests of execution of ones operation.
"""
import random
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x: cnp.one() + x,
id="cnp.one() + x",
),
pytest.param(
lambda x: cnp.ones(()) + x,
id="cnp.ones(()) + x",
),
pytest.param(
lambda x: cnp.ones(10) + x,
id="cnp.ones(10) + x",
),
pytest.param(
lambda x: cnp.ones((10,)) + x,
id="cnp.ones((10,)) + x",
),
pytest.param(
lambda x: cnp.ones((3, 2)) + x,
id="cnp.ones((3, 2)) + x",
),
],
)
def test_ones(function, helpers):
"""
Test ones.
"""
configuration = helpers.configuration()
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = range(10)
circuit = compiler.compile(inputset, configuration)
sample = random.randint(0, 11)
helpers.check_execution(circuit, function, sample)

View File

@@ -0,0 +1,49 @@
"""
Tests of execution of zeros operation.
"""
import random
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x: cnp.zero() + x,
id="cnp.zero() + x",
),
pytest.param(
lambda x: cnp.zeros(()) + x,
id="cnp.zeros(()) + x",
),
pytest.param(
lambda x: cnp.zeros(10) + x,
id="cnp.zeros(10) + x",
),
pytest.param(
lambda x: cnp.zeros((10,)) + x,
id="cnp.zeros((10,)) + x",
),
pytest.param(
lambda x: cnp.zeros((3, 2)) + x,
id="cnp.zeros((3, 2)) + x",
),
],
)
def test_zeros(function, helpers):
"""
Test zeros.
"""
configuration = helpers.configuration()
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = range(10)
circuit = compiler.compile(inputset, configuration)
sample = random.randint(0, 11)
helpers.check_execution(circuit, function, sample)