mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add encrypted zeros and ones functions
This commit is contained in:
@@ -15,6 +15,6 @@ from .compilation import (
|
||||
Server,
|
||||
compiler,
|
||||
)
|
||||
from .extensions import LookupTable, univariate
|
||||
from .extensions import LookupTable, one, ones, univariate, zero, zeros
|
||||
from .mlir.utils import MAXIMUM_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
|
||||
@@ -2,5 +2,7 @@
|
||||
Provide additional features that are not present in numpy.
|
||||
"""
|
||||
|
||||
from .ones import one, ones
|
||||
from .table import LookupTable
|
||||
from .univariate import univariate
|
||||
from .zeros import zero, zeros
|
||||
|
||||
56
concrete/numpy/extensions/ones.py
Normal file
56
concrete/numpy/extensions/ones.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Declaration of `ones` and `one` functions, to simplify creation of encrypted ones.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
|
||||
|
||||
def ones(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
|
||||
"""
|
||||
Create an encrypted array of ones.
|
||||
|
||||
Args:
|
||||
shape (Tuple[int, ...]):
|
||||
shape of the array
|
||||
|
||||
Returns:
|
||||
Union[np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
ndarray filled with ones otherwise
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
is_tracing = Tracer._is_tracing
|
||||
# pylint: enable=protected-access
|
||||
|
||||
numpy_ones = np.ones(shape, dtype=np.int64)
|
||||
|
||||
if is_tracing:
|
||||
computation = Node.generic(
|
||||
"ones",
|
||||
[],
|
||||
Value.of(numpy_ones, is_encrypted=True),
|
||||
lambda: np.ones(shape, dtype=np.int64),
|
||||
)
|
||||
return Tracer(computation, [])
|
||||
|
||||
return numpy_ones
|
||||
|
||||
|
||||
def one() -> Union[np.ndarray, Tracer]:
|
||||
"""
|
||||
Create an encrypted scalar with the value of one.
|
||||
|
||||
Returns:
|
||||
Union[np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
ndarray with one otherwise
|
||||
"""
|
||||
|
||||
return ones(())
|
||||
56
concrete/numpy/extensions/zeros.py
Normal file
56
concrete/numpy/extensions/zeros.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Declaration of `zeros` and `zero` functions, to simplify creation of encrypted zeros.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
|
||||
|
||||
def zeros(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
|
||||
"""
|
||||
Create an encrypted array of zeros.
|
||||
|
||||
Args:
|
||||
shape (Tuple[int, ...]):
|
||||
shape of the array
|
||||
|
||||
Returns:
|
||||
Union[np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
ndarray filled with zeros otherwise
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
is_tracing = Tracer._is_tracing
|
||||
# pylint: enable=protected-access
|
||||
|
||||
numpy_zeros = np.zeros(shape, dtype=np.int64)
|
||||
|
||||
if is_tracing:
|
||||
computation = Node.generic(
|
||||
"zeros",
|
||||
[],
|
||||
Value.of(numpy_zeros, is_encrypted=True),
|
||||
lambda: np.zeros(shape, dtype=np.int64),
|
||||
)
|
||||
return Tracer(computation, [])
|
||||
|
||||
return numpy_zeros
|
||||
|
||||
|
||||
def zero() -> Union[np.ndarray, Tracer]:
|
||||
"""
|
||||
Create an encrypted scalar with the value of zero.
|
||||
|
||||
Returns:
|
||||
Union[np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
ndarray with zero otherwise
|
||||
"""
|
||||
|
||||
return zeros(())
|
||||
@@ -106,6 +106,9 @@ class GraphConverter:
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted negation is supported"
|
||||
|
||||
elif name == "ones":
|
||||
assert_that(len(inputs) == 0)
|
||||
|
||||
elif name == "reshape":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
@@ -126,7 +129,11 @@ class GraphConverter:
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted transpose is supported"
|
||||
|
||||
elif name == "zeros":
|
||||
assert_that(len(inputs) == 0)
|
||||
|
||||
else:
|
||||
assert_that(node.converted_to_table_lookup)
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, pred in enumerate(graph.ordered_preds_of(node))
|
||||
@@ -135,7 +142,7 @@ class GraphConverter:
|
||||
if len(variable_input_indices) != 1:
|
||||
return "only single input table lookups are supported"
|
||||
|
||||
if all(input.is_clear for input in inputs):
|
||||
if len(inputs) > 0 and all(input.is_clear for input in inputs):
|
||||
return "one of the operands must be encrypted"
|
||||
|
||||
return None
|
||||
|
||||
@@ -124,7 +124,7 @@ class NodeConverter:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
|
||||
if self.node.operation == Operation.Constant:
|
||||
result = self.convert_constant()
|
||||
@@ -163,6 +163,9 @@ class NodeConverter:
|
||||
elif name == "negative":
|
||||
result = self.convert_neg()
|
||||
|
||||
elif name == "ones":
|
||||
result = self.convert_ones()
|
||||
|
||||
elif name == "reshape":
|
||||
result = self.convert_reshape()
|
||||
|
||||
@@ -175,7 +178,11 @@ class NodeConverter:
|
||||
elif name == "transpose":
|
||||
result = self.convert_transpose()
|
||||
|
||||
elif name == "zeros":
|
||||
result = self.convert_zeros()
|
||||
|
||||
else:
|
||||
assert_that(self.node.converted_to_table_lookup)
|
||||
result = self.convert_tlu()
|
||||
|
||||
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
@@ -458,6 +465,49 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_ones(self) -> OpResult:
|
||||
"""
|
||||
Convert "ones" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
assert isinstance(self.node.output.dtype, Integer)
|
||||
bit_width = self.node.output.dtype.bit_width
|
||||
|
||||
if self.node.output.is_scalar:
|
||||
constant_value = Value(
|
||||
Integer(is_signed=False, bit_width=bit_width + 1),
|
||||
shape=(),
|
||||
is_encrypted=False,
|
||||
)
|
||||
constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value)
|
||||
constant_attr = IntegerAttr.get(constant_type, 1)
|
||||
|
||||
zero = fhe.ZeroEintOp(resulting_type).result
|
||||
one = arith.ConstantOp(constant_type, constant_attr).result
|
||||
|
||||
result = fhe.AddEintIntOp(resulting_type, zero, one).result
|
||||
else:
|
||||
constant_value = Value(
|
||||
Integer(is_signed=False, bit_width=bit_width + 1),
|
||||
shape=(1,),
|
||||
is_encrypted=False,
|
||||
)
|
||||
constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value)
|
||||
constant_attr = Attribute.parse(f"dense<[1]> : {constant_type}")
|
||||
|
||||
zeros = fhe.ZeroTensorOp(resulting_type).result
|
||||
ones = arith.ConstantOp(constant_type, constant_attr).result
|
||||
|
||||
result = fhelinalg.AddEintIntOp(resulting_type, zeros, ones).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_reshape(self) -> OpResult:
|
||||
"""
|
||||
Convert "reshape" node to its corresponding MLIR representation.
|
||||
@@ -843,3 +893,21 @@ class NodeConverter:
|
||||
preds = self.preds
|
||||
|
||||
return fhelinalg.TransposeOp(resulting_type, *preds).result
|
||||
|
||||
def convert_zeros(self) -> OpResult:
|
||||
"""
|
||||
Convert "zeros" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
if self.node.output.is_scalar:
|
||||
result = fhe.ZeroEintOp(resulting_type).result
|
||||
else:
|
||||
result = fhe.ZeroTensorOp(resulting_type).result
|
||||
|
||||
return result
|
||||
|
||||
@@ -313,10 +313,12 @@ class Node:
|
||||
"matmul",
|
||||
"multiply",
|
||||
"negative",
|
||||
"ones",
|
||||
"reshape",
|
||||
"subtract",
|
||||
"sum",
|
||||
"transpose",
|
||||
"zeros",
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -336,5 +338,7 @@ class Node:
|
||||
"add",
|
||||
"multiply",
|
||||
"negative",
|
||||
"ones",
|
||||
"subtract",
|
||||
"zeros",
|
||||
]
|
||||
|
||||
49
tests/execution/test_ones.py
Normal file
49
tests/execution/test_ones.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Tests of execution of ones operation.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: cnp.one() + x,
|
||||
id="cnp.one() + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.ones(()) + x,
|
||||
id="cnp.ones(()) + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.ones(10) + x,
|
||||
id="cnp.ones(10) + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.ones((10,)) + x,
|
||||
id="cnp.ones((10,)) + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.ones((3, 2)) + x,
|
||||
id="cnp.ones((3, 2)) + x",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_ones(function, helpers):
|
||||
"""
|
||||
Test ones.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = range(10)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = random.randint(0, 11)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
49
tests/execution/test_zeros.py
Normal file
49
tests/execution/test_zeros.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Tests of execution of zeros operation.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: cnp.zero() + x,
|
||||
id="cnp.zero() + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.zeros(()) + x,
|
||||
id="cnp.zeros(()) + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.zeros(10) + x,
|
||||
id="cnp.zeros(10) + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.zeros((10,)) + x,
|
||||
id="cnp.zeros((10,)) + x",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.zeros((3, 2)) + x,
|
||||
id="cnp.zeros((3, 2)) + x",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_zeros(function, helpers):
|
||||
"""
|
||||
Test zeros.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = range(10)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = random.randint(0, 11)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
Reference in New Issue
Block a user