mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(frontend/python): re-write MLIR conversion
This commit is contained in:
@@ -190,6 +190,9 @@ good-names=i,
|
||||
k,
|
||||
ex,
|
||||
Run,
|
||||
xs,
|
||||
on,
|
||||
of,
|
||||
_
|
||||
|
||||
# Good variable names regexes, separated by a comma. If names match any regex,
|
||||
@@ -438,7 +441,10 @@ disable=raw-checker-failed,
|
||||
too-many-instance-attributes,
|
||||
too-many-lines,
|
||||
too-many-locals,
|
||||
too-many-public-methods,
|
||||
too-many-statements,
|
||||
unnecessary-lambda-assignment,
|
||||
use-implicit-booleaness-not-comparison,
|
||||
wrong-import-order
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
|
||||
@@ -8,10 +8,13 @@ select = [
|
||||
ignore = [
|
||||
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "SIM105",
|
||||
"RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901",
|
||||
"PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901"
|
||||
"PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901",
|
||||
"E731"
|
||||
]
|
||||
|
||||
[per-file-ignores]
|
||||
"**/__init__.py" = ["F401"]
|
||||
"concrete/fhe/mlir/processors/all.py" = ["F401"]
|
||||
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
|
||||
"examples/**" = ["PLR2004"]
|
||||
"tests/**" = ["PLR2004", "PLW0603", "SIM300", "S311"]
|
||||
|
||||
@@ -29,7 +29,15 @@ licenses:
|
||||
pytest:
|
||||
export LD_PRELOAD=$(RUNTIME_LIBRARY)
|
||||
export PYTHONPATH=$(BINDINGS_DIRECTORY)
|
||||
|
||||
# test single precision
|
||||
pytest tests -svv -n auto \
|
||||
--key-cache "${KEY_CACHE_DIRECTORY}" \
|
||||
-m "${PYTEST_MARKERS}"
|
||||
|
||||
# test multi precision
|
||||
pytest tests -svv -n auto \
|
||||
--precision=multi \
|
||||
--cov=concrete \
|
||||
--cov-fail-under=100 \
|
||||
--cov-report=term-missing:skip-covered \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Setup concrete module to be enlarged with numpy module.
|
||||
Setup concrete namespace.
|
||||
"""
|
||||
|
||||
# Do not modify, this is to have a compatible namespace package
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Export everything that users might need.
|
||||
Concrete.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
# mypy: disable-error-code=attr-defined
|
||||
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
|
||||
|
||||
from .compilation import (
|
||||
@@ -24,6 +23,8 @@ from .extensions import (
|
||||
AutoRounder,
|
||||
LookupTable,
|
||||
array,
|
||||
conv,
|
||||
maxpool,
|
||||
one,
|
||||
ones,
|
||||
round_bit_pattern,
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -165,13 +165,7 @@ class Client:
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
is_signed = self.specs.input_signs[index]
|
||||
sanitizer = 0 if not is_signed else 2 ** (width - 1)
|
||||
|
||||
if isinstance(arg, int):
|
||||
sanitized_args[index] = arg + sanitizer
|
||||
else:
|
||||
sanitized_args[index] = (arg + sanitizer).astype(np.uint64)
|
||||
sanitized_args[index] = arg
|
||||
|
||||
if not is_valid:
|
||||
actual_value = Value.of(arg, is_encrypted=is_encrypted)
|
||||
@@ -205,61 +199,7 @@ class Client:
|
||||
|
||||
self.keygen(force=False)
|
||||
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, self._keyset, result)
|
||||
if not isinstance(outputs, tuple):
|
||||
outputs = (outputs,)
|
||||
|
||||
sanitized_outputs: List[Union[int, np.ndarray]] = []
|
||||
|
||||
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
|
||||
assert_that("outputs" in client_parameters_json)
|
||||
output_specs = client_parameters_json["outputs"]
|
||||
|
||||
for index, output in enumerate(outputs):
|
||||
is_signed = self.specs.output_signs[index]
|
||||
crt_decomposition = (
|
||||
output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", [])
|
||||
)
|
||||
|
||||
if is_signed:
|
||||
if crt_decomposition:
|
||||
if isinstance(output, int):
|
||||
sanititzed_output = (
|
||||
output
|
||||
if output < (int(np.prod(crt_decomposition)) // 2)
|
||||
else -int(np.prod(crt_decomposition)) + output
|
||||
)
|
||||
else:
|
||||
output = output.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_output = np.where(
|
||||
output < (np.prod(crt_decomposition) // 2),
|
||||
output,
|
||||
-np.prod(crt_decomposition) + output,
|
||||
).astype(
|
||||
np.int64
|
||||
) # type: ignore
|
||||
|
||||
sanitized_outputs.append(sanititzed_output)
|
||||
|
||||
else:
|
||||
n = output_specs[index]["shape"]["width"]
|
||||
output %= 2**n
|
||||
if isinstance(output, int):
|
||||
sanititzed_output = output if output < (2 ** (n - 1)) else output - (2**n)
|
||||
sanitized_outputs.append(sanititzed_output)
|
||||
else:
|
||||
output = output.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_output = np.where(
|
||||
output < (2 ** (n - 1)), output, output - (2**n)
|
||||
).astype(
|
||||
np.int64
|
||||
) # type: ignore
|
||||
sanitized_outputs.append(sanititzed_output)
|
||||
else:
|
||||
sanitized_outputs.append(
|
||||
output if isinstance(output, int) else output.astype(np.uint64)
|
||||
)
|
||||
|
||||
return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs)
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def evaluation_keys(self) -> EvaluationKeys:
|
||||
|
||||
@@ -434,7 +434,7 @@ class Compiler:
|
||||
self._evaluate("Compiling", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
mlir = GraphConverter.convert(self.graph)
|
||||
mlir = GraphConverter().convert(self.graph, self.configuration)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_mlir_to_compile(mlir)
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ class Configuration:
|
||||
global_p_error: Optional[float]
|
||||
insecure_key_cache_location: Optional[str]
|
||||
auto_adjust_rounders: bool
|
||||
single_precision: bool
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
@@ -64,6 +65,7 @@ class Configuration:
|
||||
p_error: Optional[float] = None,
|
||||
global_p_error: Optional[float] = None,
|
||||
auto_adjust_rounders: bool = False,
|
||||
single_precision: bool = True,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.show_graph = show_graph
|
||||
@@ -82,6 +84,7 @@ class Configuration:
|
||||
self.p_error = p_error
|
||||
self.global_p_error = global_p_error
|
||||
self.auto_adjust_rounders = auto_adjust_rounders
|
||||
self.single_precision = single_precision
|
||||
|
||||
self._validate()
|
||||
|
||||
|
||||
@@ -608,7 +608,7 @@ def convert_subgraph_to_subgraph_node(
|
||||
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
|
||||
subgraph_node = Node.generic(
|
||||
"subgraph",
|
||||
subgraph_variable_input_node.inputs,
|
||||
deepcopy(subgraph_variable_input_node.inputs),
|
||||
terminal_node.output,
|
||||
lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
|
||||
kwargs={
|
||||
|
||||
@@ -45,7 +45,7 @@ class Integer(BaseDataType):
|
||||
|
||||
if isinstance(value, list):
|
||||
try:
|
||||
value = np.array(value)
|
||||
value = np.array(value, dtype=np.int64)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# here we try our best to convert the list to np.ndarray
|
||||
# if it fails we raise the exception at the else branch below
|
||||
|
||||
@@ -3,6 +3,8 @@ Provide additional features that are not present in numpy.
|
||||
"""
|
||||
|
||||
from .array import array
|
||||
from .convolution import conv
|
||||
from .maxpool import maxpool
|
||||
from .ones import one, ones
|
||||
from .round_bit_pattern import AutoRounder, round_bit_pattern
|
||||
from .table import LookupTable
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Declaration of `array` function, to simplify creation of encrypted arrays.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -52,7 +53,7 @@ def array(values: Any) -> Union[np.ndarray, Tracer]:
|
||||
|
||||
computation = Node.generic(
|
||||
"array",
|
||||
[value.output for value in values],
|
||||
[deepcopy(value.output) for value in values],
|
||||
Value(dtype, shape, is_encrypted),
|
||||
lambda *args: np.array(args).reshape(shape),
|
||||
)
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
"""
|
||||
Convolution operations' tracing and evaluation.
|
||||
Tracing and evaluation of convolution.
|
||||
"""
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from typing import Callable, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..fhe.internal.utils import assert_that
|
||||
from ..fhe.representation import Node
|
||||
from ..fhe.tracing import Tracer
|
||||
from ..fhe.values import EncryptedTensor
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import EncryptedTensor
|
||||
|
||||
SUPPORTED_AUTO_PAD = {
|
||||
"NOTSET",
|
||||
@@ -23,8 +24,8 @@ SUPPORTED_AUTO_PAD = {
|
||||
|
||||
def conv(
|
||||
x: Union[np.ndarray, Tracer],
|
||||
weight: Union[np.ndarray, Tracer],
|
||||
bias: Optional[Union[np.ndarray, Tracer]] = None,
|
||||
weight: Union[np.ndarray, List, Tracer],
|
||||
bias: Optional[Union[np.ndarray, List, Tracer]] = None,
|
||||
pads: Optional[Union[Tuple[int, ...], List[int]]] = None,
|
||||
strides: Optional[Union[Tuple[int, ...], List[int]]] = None,
|
||||
dilations: Optional[Union[Tuple[int, ...], List[int]]] = None,
|
||||
@@ -63,11 +64,18 @@ def conv(
|
||||
Returns:
|
||||
Union[np.ndarray, Tracer]: evaluation result or traced computation
|
||||
"""
|
||||
if kernel_shape is not None and (
|
||||
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
|
||||
):
|
||||
message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
|
||||
raise ValueError(message)
|
||||
|
||||
if isinstance(weight, list): # pragma: no cover
|
||||
try:
|
||||
weight = np.array(weight)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
if bias is not None and isinstance(bias, list): # pragma: no cover
|
||||
try:
|
||||
bias = np.array(bias)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
if isinstance(x, np.ndarray):
|
||||
if not isinstance(weight, np.ndarray):
|
||||
@@ -84,6 +92,12 @@ def conv(
|
||||
message = "expected bias to be of type Tracer or ndarray"
|
||||
raise TypeError(message)
|
||||
|
||||
if kernel_shape is not None and (
|
||||
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
|
||||
):
|
||||
message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
|
||||
raise ValueError(message)
|
||||
|
||||
if x.ndim <= 2:
|
||||
message = (
|
||||
f"expected input x to have at least 3 dimensions (N, C, D1, ...), but got {x.ndim}"
|
||||
@@ -511,7 +525,7 @@ def _trace_conv(
|
||||
|
||||
computation = Node.generic(
|
||||
conv_func, # "conv1d" or "conv2d" or "conv3d"
|
||||
input_values,
|
||||
deepcopy(input_values),
|
||||
output_value,
|
||||
eval_func,
|
||||
args=() if bias is not None else (np.zeros(n_filters, dtype=np.int64),),
|
||||
@@ -1,16 +1,17 @@
|
||||
"""
|
||||
Tracing and evaluation of maxpool function.
|
||||
Tracing and evaluation of maxpool.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..fhe.internal.utils import assert_that
|
||||
from ..fhe.representation import Node
|
||||
from ..fhe.tracing import Tracer
|
||||
from ..fhe.values import Value
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
|
||||
@@ -288,9 +289,12 @@ def _trace_or_evaluate(
|
||||
resulting_value.is_encrypted = x.output.is_encrypted
|
||||
resulting_value.dtype = x.output.dtype
|
||||
|
||||
dims = x.ndim - 2
|
||||
assert_that(dims in {1, 2, 3})
|
||||
|
||||
computation = Node.generic(
|
||||
"maxpool",
|
||||
[x.output],
|
||||
f"maxpool{dims}d",
|
||||
[deepcopy(x.output)],
|
||||
resulting_value,
|
||||
_evaluate,
|
||||
kwargs={
|
||||
@@ -218,7 +218,7 @@ def round_bit_pattern(
|
||||
if isinstance(x, Tracer):
|
||||
computation = Node.generic(
|
||||
"round_bit_pattern",
|
||||
[x.output],
|
||||
[deepcopy(x.output)],
|
||||
deepcopy(x.output),
|
||||
evaluator,
|
||||
kwargs={"lsbs_to_remove": lsbs_to_remove},
|
||||
|
||||
@@ -83,7 +83,7 @@ class LookupTable:
|
||||
|
||||
computation = Node.generic(
|
||||
"tlu",
|
||||
[key.output],
|
||||
[deepcopy(key.output)],
|
||||
output,
|
||||
LookupTable.apply,
|
||||
kwargs={"table": table},
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Declaration of `univariate` function.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -75,7 +76,7 @@ def univariate(
|
||||
|
||||
computation = Node.generic(
|
||||
function.__name__,
|
||||
[x.output],
|
||||
[deepcopy(x.output)],
|
||||
output_value,
|
||||
lambda x: function(x), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
|
||||
@@ -2,5 +2,4 @@
|
||||
Provide `computation graph` to `mlir` functionality.
|
||||
"""
|
||||
|
||||
from .graph_converter import GraphConverter
|
||||
from .node_converter import NodeConverter
|
||||
from .converter import Converter as GraphConverter
|
||||
|
||||
1907
frontends/concrete-python/concrete/fhe/mlir/context.py
Normal file
1907
frontends/concrete-python/concrete/fhe/mlir/context.py
Normal file
File diff suppressed because it is too large
Load Diff
183
frontends/concrete-python/concrete/fhe/mlir/conversion.py
Normal file
183
frontends/concrete-python/concrete/fhe/mlir/conversion.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Declaration of `ConversionType` and `Conversion` classes.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from mlir.ir import OpResult as MlirOperation
|
||||
from mlir.ir import Type as MlirType
|
||||
|
||||
from ..representation import Node
|
||||
|
||||
# pylint: enable=import-error
|
||||
|
||||
|
||||
SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")
|
||||
SCALAR_EINT_SEARCH_REGEX = re.compile(r"^!FHE\.e(s)?int<([0-9]+)>$")
|
||||
|
||||
TENSOR_INT_SEARCH_REGEX = re.compile(r"^tensor<(([0-9]+x)+)i([0-9]+)>$")
|
||||
TENSOR_EINT_SEARCH_REGEX = re.compile(r"^tensor<(([0-9]+x)+)!FHE\.e(s)?int<([0-9]+)>>$")
|
||||
|
||||
|
||||
class ConversionType:
|
||||
"""
|
||||
ConversionType class, to make it easier to work with MLIR types.
|
||||
"""
|
||||
|
||||
mlir: MlirType
|
||||
|
||||
bit_width: int
|
||||
is_encrypted: bool
|
||||
is_signed: bool
|
||||
shape: Tuple[int, ...]
|
||||
|
||||
def __init__(self, mlir: MlirType):
|
||||
self.mlir = mlir
|
||||
mlir_type_str = str(mlir)
|
||||
|
||||
search = SCALAR_INT_SEARCH_REGEX.search(mlir_type_str)
|
||||
if search:
|
||||
(matched_bit_width,) = search.groups()
|
||||
|
||||
self.bit_width = int(matched_bit_width)
|
||||
self.is_encrypted = False
|
||||
self.is_signed = True
|
||||
self.shape = ()
|
||||
|
||||
return
|
||||
|
||||
search = SCALAR_EINT_SEARCH_REGEX.search(mlir_type_str)
|
||||
if search:
|
||||
matched_is_signed, matched_bit_width = search.groups()
|
||||
|
||||
self.bit_width = int(matched_bit_width)
|
||||
self.is_encrypted = True
|
||||
self.is_signed = matched_is_signed is not None
|
||||
self.shape = ()
|
||||
|
||||
return
|
||||
|
||||
search = TENSOR_INT_SEARCH_REGEX.search(mlir_type_str)
|
||||
if search:
|
||||
matched_shape, _, matched_bit_width = search.groups()
|
||||
|
||||
self.bit_width = int(matched_bit_width)
|
||||
self.is_encrypted = False
|
||||
self.is_signed = True
|
||||
self.shape = tuple(int(size) for size in matched_shape.rstrip("x").split("x"))
|
||||
|
||||
return
|
||||
|
||||
search = TENSOR_EINT_SEARCH_REGEX.search(mlir_type_str)
|
||||
if search:
|
||||
matched_shape, _, matched_is_signed, matched_bit_width = search.groups()
|
||||
|
||||
self.bit_width = int(matched_bit_width)
|
||||
self.is_encrypted = True
|
||||
self.is_signed = matched_is_signed is not None
|
||||
self.shape = tuple(int(size) for size in matched_shape.rstrip("x").split("x"))
|
||||
|
||||
return
|
||||
|
||||
self.is_encrypted = False
|
||||
self.bit_width = 64
|
||||
self.is_signed = False
|
||||
self.shape = ()
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
|
||||
@property
|
||||
def is_clear(self) -> bool:
|
||||
return not self.is_encrypted
|
||||
|
||||
@property
|
||||
def is_scalar(self) -> bool:
|
||||
return self.shape == ()
|
||||
|
||||
@property
|
||||
def is_tensor(self) -> bool:
|
||||
return self.shape != ()
|
||||
|
||||
@property
|
||||
def is_unsigned(self) -> bool:
|
||||
return not self.is_signed
|
||||
|
||||
# pylint: enable=missing-function-docstring
|
||||
|
||||
|
||||
class Conversion:
|
||||
"""
|
||||
Conversion class, to store MLIR operations with additional information.
|
||||
"""
|
||||
|
||||
origin: Node
|
||||
|
||||
type: ConversionType
|
||||
result: MlirOperation
|
||||
|
||||
_original_bit_width: Optional[int]
|
||||
|
||||
def __init__(self, origin: Node, result: MlirOperation):
|
||||
self.origin = origin
|
||||
|
||||
self.type = ConversionType(result.type)
|
||||
self.result = result
|
||||
|
||||
self._original_bit_width = None
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.result)
|
||||
|
||||
def set_original_bit_width(self, original_bit_width: int):
|
||||
"""
|
||||
Set the original bit-width of the conversion.
|
||||
"""
|
||||
self._original_bit_width = original_bit_width
|
||||
|
||||
@property
|
||||
def original_bit_width(self) -> int:
|
||||
"""
|
||||
Get the original bit-width of the conversion.
|
||||
|
||||
If not explicitly set, defaults to the actual bit width.
|
||||
"""
|
||||
return self._original_bit_width if self._original_bit_width is not None else self.bit_width
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
|
||||
@property
|
||||
def bit_width(self) -> int:
|
||||
return self.type.bit_width
|
||||
|
||||
@property
|
||||
def is_clear(self) -> bool:
|
||||
return self.type.is_clear
|
||||
|
||||
@property
|
||||
def is_encrypted(self) -> bool:
|
||||
return self.type.is_encrypted
|
||||
|
||||
@property
|
||||
def is_scalar(self) -> bool:
|
||||
return self.type.is_scalar
|
||||
|
||||
@property
|
||||
def is_signed(self) -> bool:
|
||||
return self.type.is_signed
|
||||
|
||||
@property
|
||||
def is_tensor(self) -> bool:
|
||||
return self.type.is_tensor
|
||||
|
||||
@property
|
||||
def is_unsigned(self) -> bool:
|
||||
return self.type.is_unsigned
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
return self.type.shape
|
||||
|
||||
# pylint: enable=missing-function-docstring
|
||||
501
frontends/concrete-python/concrete/fhe/mlir/converter.py
Normal file
501
frontends/concrete-python/concrete/fhe/mlir/converter.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""
|
||||
Declaration of `Converter` class.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple
|
||||
|
||||
import concrete.lang
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from mlir.dialects import func
|
||||
from mlir.ir import BlockArgument as MlirBlockArgument
|
||||
from mlir.ir import Context as MlirContext
|
||||
from mlir.ir import InsertionPoint as MlirInsertionPoint
|
||||
from mlir.ir import Location as MlirLocation
|
||||
from mlir.ir import Module as MlirModule
|
||||
from mlir.ir import OpResult as MlirOperation
|
||||
|
||||
from concrete.fhe.compilation.configuration import Configuration
|
||||
|
||||
from ..representation import Graph, Node, Operation
|
||||
from .context import Context
|
||||
from .conversion import Conversion
|
||||
from .processors.all import * # pylint: disable=wildcard-import
|
||||
from .utils import MAXIMUM_TLU_BIT_WIDTH, construct_deduplicated_tables
|
||||
|
||||
# pylint: enable=import-error,no-name-in-module
|
||||
|
||||
|
||||
class Converter:
|
||||
"""
|
||||
Converter class, to convert a computation graph to MLIR.
|
||||
"""
|
||||
|
||||
def convert(self, graph: Graph, configuration: Configuration) -> str:
|
||||
"""
|
||||
Convert a computation graph to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to convert
|
||||
|
||||
configuration (Configuration):
|
||||
configuration to use
|
||||
|
||||
Return:
|
||||
str:
|
||||
MLIR corresponding to graph
|
||||
"""
|
||||
|
||||
graph = self.process(graph, configuration)
|
||||
|
||||
with MlirContext() as context, MlirLocation.unknown():
|
||||
concrete.lang.register_dialects(context) # pylint: disable=no-member
|
||||
|
||||
module = MlirModule.create()
|
||||
with MlirInsertionPoint(module.body):
|
||||
ctx = Context(context, graph)
|
||||
|
||||
input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()]
|
||||
|
||||
@func.FuncOp.from_py_func(*input_types)
|
||||
def main(*args):
|
||||
for index, node in enumerate(graph.ordered_inputs()):
|
||||
conversion = Conversion(node, args[index])
|
||||
if "original_bit_width" in node.properties:
|
||||
conversion.set_original_bit_width(node.properties["original_bit_width"])
|
||||
ctx.conversions[node] = conversion
|
||||
|
||||
for node in nx.lexicographical_topological_sort(graph.graph):
|
||||
if node.operation == Operation.Input:
|
||||
continue
|
||||
|
||||
preds = [ctx.conversions[pred] for pred in graph.ordered_preds_of(node)]
|
||||
self.node(ctx, node, preds)
|
||||
|
||||
outputs = []
|
||||
for node in graph.ordered_outputs():
|
||||
assert node in ctx.conversions
|
||||
outputs.append(ctx.conversions[node].result)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
def extract_mlir_name(result: MlirOperation) -> str:
|
||||
return (
|
||||
f"%arg{result.arg_number}"
|
||||
if isinstance(result, MlirBlockArgument)
|
||||
else str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
)
|
||||
|
||||
direct_replacements = {}
|
||||
for placeholder, elements in ctx.from_elements_operations.items():
|
||||
element_names = [extract_mlir_name(element) for element in elements]
|
||||
actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
|
||||
direct_replacements[extract_mlir_name(placeholder)] = actual_value
|
||||
|
||||
module_lines_after_direct_replacements_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
if mlir_name not in direct_replacements:
|
||||
module_lines_after_direct_replacements_are_applied.append(line)
|
||||
continue
|
||||
|
||||
new_value = direct_replacements[mlir_name]
|
||||
new_line = f" {mlir_name} = {new_value}"
|
||||
|
||||
module_lines_after_direct_replacements_are_applied.append(new_line)
|
||||
|
||||
return "\n".join(module_lines_after_direct_replacements_are_applied).strip()
|
||||
|
||||
def process(self, graph: Graph, configuration: Configuration) -> Graph:
|
||||
"""
|
||||
Process a computation graph for MLIR conversion.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to convert
|
||||
|
||||
configuration (Configuration):
|
||||
configuration to use
|
||||
|
||||
Return:
|
||||
str:
|
||||
MLIR corresponding to graph
|
||||
"""
|
||||
|
||||
pipeline = [
|
||||
CheckIntegerOnly(),
|
||||
AssignBitWidths(single_precision=configuration.single_precision),
|
||||
]
|
||||
|
||||
graph = deepcopy(graph)
|
||||
for processor in pipeline:
|
||||
processor.apply(graph)
|
||||
|
||||
return graph
|
||||
|
||||
def node(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
"""
|
||||
Convert a computation graph node into MLIR.
|
||||
|
||||
Args:
|
||||
ctx (Context):
|
||||
conversion context
|
||||
|
||||
node (Node):
|
||||
node to convert
|
||||
|
||||
preds (List[Conversion]):
|
||||
conversions of ordered predecessors of the node
|
||||
|
||||
Return:
|
||||
Conversion:
|
||||
conversion object corresponding to node
|
||||
"""
|
||||
|
||||
ctx.converting = node
|
||||
|
||||
assert node.operation != Operation.Input
|
||||
operation = "constant" if node.operation == Operation.Constant else node.properties["name"]
|
||||
assert operation not in ["convert", "node"]
|
||||
|
||||
converter = getattr(self, operation) if hasattr(self, operation) else self.tlu
|
||||
conversion = converter(ctx, node, preds)
|
||||
conversion.set_original_bit_width(node.properties["original_bit_width"])
|
||||
|
||||
ctx.conversions[node] = conversion
|
||||
return conversion
|
||||
|
||||
# The name of the remaining methods all correspond to node names.
|
||||
# And they have the same signature so that they can be called in a generic way.
|
||||
|
||||
# pylint: disable=missing-function-docstring,unused-argument
|
||||
|
||||
def add(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
return ctx.add(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) > 0
|
||||
return ctx.array(ctx.typeof(node), elements=preds)
|
||||
|
||||
def assign_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
return ctx.assign_static(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
preds[1],
|
||||
index=node.properties["kwargs"]["index"],
|
||||
)
|
||||
|
||||
def bitwise_and(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.bitwise_and(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def bitwise_or(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.bitwise_or(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def bitwise_xor(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.bitwise_xor(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def broadcast_to(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.broadcast_to(preds[0], shape=node.output.shape)
|
||||
|
||||
def concatenate(self, ctx: Context, node: Node, preds: List[Conversion]):
|
||||
return ctx.concatenate(
|
||||
ctx.typeof(node),
|
||||
preds,
|
||||
axis=node.properties["kwargs"].get("axis", 0),
|
||||
)
|
||||
|
||||
def constant(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 0
|
||||
return ctx.constant(ctx.typeof(node), data=node())
|
||||
|
||||
def conv1d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
ctx.error({node: "1-dimensional convolutions are not supported at the moment"})
|
||||
assert False, "unreachable" # pragma: no cover
|
||||
|
||||
def conv2d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) in [2, 3]
|
||||
return ctx.conv2d(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
preds[1],
|
||||
preds[2] if len(preds) == 3 else None,
|
||||
strides=node.properties["kwargs"]["strides"],
|
||||
dilations=node.properties["kwargs"]["dilations"],
|
||||
pads=node.properties["kwargs"]["pads"],
|
||||
group=node.properties["kwargs"]["group"],
|
||||
)
|
||||
|
||||
def conv3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
ctx.error({node: "3-dimensional convolutions are not supported at the moment"})
|
||||
assert False, "unreachable" # pragma: no cover
|
||||
|
||||
def dot(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
return ctx.dot(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.equality(ctx.typeof(node), preds[0], preds[1], equals=True)
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def expand_dims(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.reshape(preds[0], shape=node.output.shape)
|
||||
|
||||
def greater(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.greater(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def greater_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.greater_equal(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def index_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.index_static(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
index=node.properties["kwargs"]["index"],
|
||||
)
|
||||
|
||||
def left_shift(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.shift(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
preds[1],
|
||||
orientation="left",
|
||||
original_resulting_bit_width=node.properties["original_bit_width"],
|
||||
)
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def less(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.less(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def less_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.less_equal(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def matmul(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
return ctx.matmul(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def maxpool1d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
ctx.error({node: "1-dimensional maxpooling is not supported at the moment"})
|
||||
assert False, "unreachable" # pragma: no cover
|
||||
|
||||
def maxpool2d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.maxpool2d(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
kernel_shape=node.properties["kwargs"]["kernel_shape"],
|
||||
strides=node.properties["kwargs"]["strides"],
|
||||
dilations=node.properties["kwargs"]["dilations"],
|
||||
)
|
||||
|
||||
def maxpool3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
|
||||
assert False, "unreachable" # pragma: no cover
|
||||
|
||||
def multiply(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
return ctx.mul(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def negative(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.neg(ctx.typeof(node), preds[0])
|
||||
|
||||
def not_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.equality(ctx.typeof(node), preds[0], preds[1], equals=False)
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def ones(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 0
|
||||
return ctx.ones(ctx.typeof(node))
|
||||
|
||||
def reshape(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.reshape(preds[0], shape=node.output.shape)
|
||||
|
||||
def right_shift(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.shift(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
preds[1],
|
||||
orientation="right",
|
||||
original_resulting_bit_width=node.properties["original_bit_width"],
|
||||
)
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def subtract(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
return ctx.sub(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def sum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.sum(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
axes=node.properties["kwargs"].get("axis", []),
|
||||
keep_dims=node.properties["kwargs"].get("keepdims", False),
|
||||
)
|
||||
|
||||
def squeeze(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
# because of the tracing logic, we have the correct output shape
|
||||
|
||||
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
|
||||
# and the result is a scalar, so we need to do indexing, not reshape
|
||||
if node.output.shape == ():
|
||||
assert all(size == 1 for size in preds[0].shape)
|
||||
index = (0,) * len(preds[0].shape)
|
||||
return ctx.index_static(ctx.typeof(node), preds[0], index)
|
||||
|
||||
# otherwise, a simple reshape would work as we already have the correct shape
|
||||
return ctx.reshape(preds[0], shape=node.output.shape)
|
||||
|
||||
def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert node.converted_to_table_lookup
|
||||
|
||||
variable_input_index = -1
|
||||
|
||||
pred_nodes = ctx.graph.ordered_preds_of(node)
|
||||
for i, pred_node in enumerate(pred_nodes):
|
||||
if pred_node.operation != Operation.Constant:
|
||||
if variable_input_index == -1:
|
||||
variable_input_index = i
|
||||
else:
|
||||
assert False, "unreachable" # pragma: no cover
|
||||
|
||||
assert variable_input_index != -1
|
||||
|
||||
variable_input = preds[variable_input_index]
|
||||
if variable_input.bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
variable_input_messages = [
|
||||
f"this {variable_input.bit_width}-bit value "
|
||||
f"is used as an input to a table lookup"
|
||||
]
|
||||
if variable_input.bit_width != variable_input.original_bit_width:
|
||||
variable_input_messages.append(
|
||||
"("
|
||||
f"note that it's assigned {variable_input.bit_width}-bits "
|
||||
f"during compilation because of its relation with other operations"
|
||||
")"
|
||||
)
|
||||
|
||||
highlights = {
|
||||
variable_input.origin: variable_input_messages,
|
||||
node: f"but only up to {MAXIMUM_TLU_BIT_WIDTH}-bit table lookups are supported",
|
||||
}
|
||||
ctx.error(highlights) # type: ignore
|
||||
|
||||
tables = construct_deduplicated_tables(node, pred_nodes)
|
||||
assert len(tables) > 0
|
||||
|
||||
lut_shape: Tuple[int, ...] = ()
|
||||
map_shape: Tuple[int, ...] = ()
|
||||
|
||||
if len(tables) == 1:
|
||||
table = tables[0][0]
|
||||
|
||||
# The reduction on 63b is to avoid problems like doing a TLU of
|
||||
# the form T[j] = 2<<j, for j which is supposed to be 7b as per
|
||||
# constraint of the compiler, while in practice, it is a small
|
||||
# value. Reducing on 64b was not ok for some reason
|
||||
lut_shape = (len(table),)
|
||||
lut_values = np.array(table % (2 << 63), dtype=np.uint64)
|
||||
|
||||
map_shape = ()
|
||||
map_values = None
|
||||
else:
|
||||
individual_table_size = len(tables[0][0])
|
||||
|
||||
lut_shape = (len(tables), individual_table_size)
|
||||
map_shape = node.output.shape
|
||||
|
||||
lut_values = np.zeros(lut_shape, dtype=np.uint64)
|
||||
map_values = np.zeros(map_shape, dtype=np.intp)
|
||||
|
||||
for i, (table, indices) in enumerate(tables):
|
||||
assert len(table) == individual_table_size
|
||||
lut_values[i, :] = table
|
||||
for index in indices:
|
||||
map_values[index] = i
|
||||
|
||||
if len(tables) == 1:
|
||||
return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist())
|
||||
|
||||
assert map_values is not None
|
||||
return ctx.multi_tlu(
|
||||
ctx.typeof(node),
|
||||
on=variable_input,
|
||||
tables=lut_values.tolist(),
|
||||
mapping=map_values.tolist(),
|
||||
)
|
||||
|
||||
def transpose(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.transpose(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
axes=node.properties["kwargs"].get("axes", []),
|
||||
)
|
||||
|
||||
def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 0
|
||||
return ctx.zeros(ctx.typeof(node))
|
||||
|
||||
# pylint: enable=missing-function-docstring,unused-argument
|
||||
@@ -1,739 +0,0 @@
|
||||
"""
|
||||
Declaration of `GraphConverter` class.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,no-member,no-name-in-module
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
# mypy: disable-error-code=attr-defined
|
||||
import concrete.lang as concretelang
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from concrete.lang.dialects import fhe, fhelinalg
|
||||
from mlir.dialects import arith, func
|
||||
from mlir.ir import (
|
||||
Attribute,
|
||||
Context,
|
||||
InsertionPoint,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
Location,
|
||||
Module,
|
||||
OpResult,
|
||||
RankedTensorType,
|
||||
)
|
||||
|
||||
from ..dtypes import Integer, SignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ClearScalar, EncryptedScalar
|
||||
from .node_converter import NodeConverter
|
||||
from .utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
|
||||
|
||||
class GraphConverter:
|
||||
"""
|
||||
GraphConverter class, to convert computation graphs to their MLIR equivalent.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
|
||||
"""
|
||||
Check node convertibility to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph of the node
|
||||
|
||||
node (Node):
|
||||
node to be checked
|
||||
|
||||
Returns:
|
||||
Optional[str]:
|
||||
None if node is convertible to MLIR, the reason for inconvertibility otherwise
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-return-statements,too-many-statements
|
||||
|
||||
inputs = node.inputs
|
||||
output = node.output
|
||||
|
||||
if node.operation == Operation.Constant:
|
||||
assert_that(len(inputs) == 0)
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer constants are supported"
|
||||
|
||||
elif node.operation == Operation.Input:
|
||||
assert_that(len(inputs) == 1)
|
||||
assert_that(inputs[0] == output)
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer inputs are supported"
|
||||
if output.dtype.is_signed and output.is_clear:
|
||||
return "only encrypted signed integer inputs are supported"
|
||||
|
||||
else:
|
||||
assert_that(node.operation == Operation.Generic)
|
||||
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer operations are supported"
|
||||
|
||||
name = node.properties["name"]
|
||||
|
||||
if name == "add":
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "array":
|
||||
assert_that(len(inputs) > 0)
|
||||
assert_that(all(input.is_scalar for input in inputs))
|
||||
|
||||
elif name == "assign.static":
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only assignment to encrypted tensors are supported"
|
||||
|
||||
elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]:
|
||||
assert_that(len(inputs) == 2)
|
||||
if all(value.is_encrypted for value in node.inputs):
|
||||
pred_nodes = graph.ordered_preds_of(node)
|
||||
if (
|
||||
name in ["left_shift", "right_shift"]
|
||||
and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4
|
||||
):
|
||||
return "only up to 4-bit shifts are supported"
|
||||
|
||||
for pred_node in pred_nodes:
|
||||
assert isinstance(pred_node.output.dtype, Integer)
|
||||
if pred_node.output.dtype.is_signed:
|
||||
return "only unsigned bitwise operations are supported"
|
||||
|
||||
elif name == "broadcast_to":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted broadcasting is supported"
|
||||
|
||||
elif name == "concatenate":
|
||||
if not all(input.is_encrypted for input in inputs):
|
||||
return "only all encrypted concatenate is supported"
|
||||
|
||||
elif name in ["conv1d", "conv2d", "conv3d"]:
|
||||
assert_that(len(inputs) == 2 or len(inputs) == 3)
|
||||
if not (inputs[0].is_encrypted and inputs[1].is_clear):
|
||||
return f"only {name} with encrypted input and clear weight is supported"
|
||||
|
||||
elif name == "dot":
|
||||
assert_that(len(inputs) == 2)
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only dot product between encrypted and clear is supported"
|
||||
|
||||
elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]:
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "expand_dims":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
elif name == "index.static":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted indexing supported"
|
||||
|
||||
elif name == "matmul":
|
||||
assert_that(len(inputs) == 2)
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only matrix multiplication between encrypted and clear is supported"
|
||||
|
||||
elif name == "maxpool":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted maxpool is supported"
|
||||
|
||||
elif name == "multiply":
|
||||
assert_that(len(inputs) == 2)
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only multiplication between encrypted and clear is supported"
|
||||
|
||||
elif name == "negative":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted negation is supported"
|
||||
|
||||
elif name == "ones":
|
||||
assert_that(len(inputs) == 0)
|
||||
|
||||
elif name == "reshape":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted reshape is supported"
|
||||
|
||||
elif name == "squeeze":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
elif name == "subtract":
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "sum":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted sum is supported"
|
||||
|
||||
elif name == "transpose":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted transpose is supported"
|
||||
|
||||
elif name == "zeros":
|
||||
assert_that(len(inputs) == 0)
|
||||
|
||||
else:
|
||||
assert_that(node.converted_to_table_lookup)
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, pred in enumerate(graph.ordered_preds_of(node))
|
||||
if pred.operation != Operation.Constant
|
||||
]
|
||||
assert_that(len(variable_input_indices) == 1)
|
||||
|
||||
if len(inputs) > 0 and all(input.is_clear for input in inputs):
|
||||
return "one of the operands must be encrypted"
|
||||
|
||||
return None
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements,too-many-statements
|
||||
|
||||
@staticmethod
|
||||
def _check_graph_convertibility(graph: Graph):
|
||||
"""
|
||||
Check graph convertibility to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to be checked
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if `graph` is not convertible to MLIR
|
||||
"""
|
||||
|
||||
offending_nodes = {}
|
||||
|
||||
if len(graph.output_nodes) > 1:
|
||||
offending_nodes.update(
|
||||
{
|
||||
node: ["only a single output is supported", node.location]
|
||||
for node in graph.output_nodes.values()
|
||||
}
|
||||
)
|
||||
|
||||
if len(offending_nodes) == 0:
|
||||
for node in graph.graph.nodes:
|
||||
reason = GraphConverter._check_node_convertibility(graph, node)
|
||||
if reason is not None:
|
||||
offending_nodes[node] = [reason, node.location]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
message = (
|
||||
"Function you are trying to compile cannot be converted to MLIR\n\n"
|
||||
+ graph.format(highlighted_nodes=offending_nodes)
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
|
||||
@staticmethod
|
||||
def _update_bit_widths(graph: Graph):
|
||||
"""
|
||||
Update bit-widths in a computation graph to be convertible to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to be updated
|
||||
"""
|
||||
|
||||
offending_nodes: Dict[Node, List[str]] = {}
|
||||
|
||||
max_bit_width = 0
|
||||
max_bit_width_node = None
|
||||
|
||||
first_tlu_node = None
|
||||
first_signed_node = None
|
||||
|
||||
for node in nx.lexicographical_topological_sort(graph.graph):
|
||||
dtype = node.output.dtype
|
||||
assert_that(isinstance(dtype, Integer))
|
||||
|
||||
current_node_bit_width = (
|
||||
dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
|
||||
)
|
||||
if (
|
||||
all(value.is_encrypted for value in node.inputs)
|
||||
and node.operation == Operation.Generic
|
||||
and node.properties["name"]
|
||||
in [
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"less",
|
||||
"less_equal",
|
||||
]
|
||||
):
|
||||
# implementation of these operators require at least 4 bits
|
||||
current_node_bit_width = max(current_node_bit_width, 4)
|
||||
|
||||
if max_bit_width < current_node_bit_width:
|
||||
max_bit_width = current_node_bit_width
|
||||
max_bit_width_node = node
|
||||
|
||||
if node.converted_to_table_lookup and first_tlu_node is None:
|
||||
first_tlu_node = node
|
||||
|
||||
if dtype.is_signed and first_signed_node is None:
|
||||
first_signed_node = node
|
||||
|
||||
if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
assert max_bit_width_node is not None
|
||||
offending_nodes[max_bit_width_node] = [
|
||||
(
|
||||
{
|
||||
Operation.Input: f"this input is {max_bit_width}-bits",
|
||||
Operation.Constant: f"this constant is {max_bit_width}-bits",
|
||||
Operation.Generic: f"this operation results in {max_bit_width}-bits",
|
||||
}[max_bit_width_node.operation]
|
||||
),
|
||||
max_bit_width_node.location,
|
||||
]
|
||||
offending_nodes[first_tlu_node] = [
|
||||
f"table lookups are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bits",
|
||||
first_tlu_node.location,
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
"Function you are trying to compile cannot be converted to MLIR:\n\n"
|
||||
+ graph.format(highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
|
||||
for value in node.inputs + [node.output]:
|
||||
dtype = value.dtype
|
||||
assert_that(isinstance(dtype, Integer))
|
||||
dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width
|
||||
|
||||
@staticmethod
|
||||
def _offset_negative_lookup_table_inputs(graph: Graph):
|
||||
"""
|
||||
Offset negative table lookup inputs to be convertible to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to apply offset
|
||||
"""
|
||||
|
||||
# ugly hack to add an offset before entering a TLU
|
||||
# if its variable input node has a signed output.
|
||||
# this makes hardcoded assumptions about the way bit widths are handled in MLIR.
|
||||
# this does not update the TLU input values to allow for proper table generation.
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic:
|
||||
if not node.converted_to_table_lookup:
|
||||
continue
|
||||
|
||||
variable_input_index = -1
|
||||
|
||||
preds = graph.ordered_preds_of(node)
|
||||
for index, pred in enumerate(preds):
|
||||
if pred.operation != Operation.Constant:
|
||||
variable_input_index = index
|
||||
break
|
||||
|
||||
variable_input_node = preds[variable_input_index]
|
||||
|
||||
variable_input_value = variable_input_node.output
|
||||
variable_input_dtype = variable_input_value.dtype
|
||||
|
||||
assert_that(isinstance(variable_input_dtype, Integer))
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
|
||||
if not variable_input_dtype.is_signed:
|
||||
continue
|
||||
|
||||
variable_input_bit_width = variable_input_dtype.bit_width
|
||||
offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
|
||||
|
||||
offset_constant_value = abs(variable_input_dtype.min())
|
||||
|
||||
offset_constant = Node.constant(offset_constant_value)
|
||||
offset_constant.output.dtype = offset_constant_dtype
|
||||
|
||||
original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width
|
||||
offset_constant.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
add_offset = Node.generic(
|
||||
"add",
|
||||
[variable_input_value, ClearScalar(offset_constant_dtype)],
|
||||
variable_input_value,
|
||||
np.add,
|
||||
)
|
||||
|
||||
original_bit_width = variable_input_node.properties["original_bit_width"]
|
||||
add_offset.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
nx_graph.remove_edge(variable_input_node, node)
|
||||
|
||||
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
|
||||
nx_graph.add_edge(offset_constant, add_offset, input_idx=1)
|
||||
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_assignments(graph: Graph):
|
||||
"""
|
||||
Broadcast assignments.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to transform
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
|
||||
shape = node.inputs[0].shape
|
||||
index = node.properties["kwargs"]["index"]
|
||||
|
||||
assert_that(isinstance(index, tuple))
|
||||
while len(index) < len(shape):
|
||||
index = (*index, slice(None, None, None))
|
||||
|
||||
required_value_shape_list = []
|
||||
|
||||
for i, indexing_element in enumerate(index):
|
||||
if isinstance(indexing_element, slice):
|
||||
n = len(np.zeros(shape[i])[indexing_element])
|
||||
required_value_shape_list.append(n)
|
||||
else:
|
||||
required_value_shape_list.append(1)
|
||||
|
||||
required_value_shape = tuple(required_value_shape_list)
|
||||
actual_value_shape = node.inputs[1].shape
|
||||
|
||||
if required_value_shape != actual_value_shape:
|
||||
preds = graph.ordered_preds_of(node)
|
||||
pred_to_modify = preds[1]
|
||||
|
||||
modified_value = deepcopy(pred_to_modify.output)
|
||||
modified_value.shape = required_value_shape
|
||||
|
||||
try:
|
||||
np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
|
||||
modified_value.is_encrypted = True
|
||||
modified_value.dtype = node.output.dtype
|
||||
modified_pred = Node.generic(
|
||||
"broadcast_to",
|
||||
[pred_to_modify.output],
|
||||
modified_value,
|
||||
np.broadcast_to,
|
||||
kwargs={"shape": required_value_shape},
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
np.reshape(np.zeros(actual_value_shape), required_value_shape)
|
||||
modified_pred = Node.generic(
|
||||
"reshape",
|
||||
[pred_to_modify.output],
|
||||
modified_value,
|
||||
np.reshape,
|
||||
kwargs={"newshape": required_value_shape},
|
||||
)
|
||||
|
||||
modified_pred.properties["original_bit_width"] = pred_to_modify.properties[
|
||||
"original_bit_width"
|
||||
]
|
||||
|
||||
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
|
||||
|
||||
nx_graph.remove_edge(pred_to_modify, node)
|
||||
nx_graph.add_edge(modified_pred, node, input_idx=1)
|
||||
|
||||
node.inputs[1] = modified_value
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_clear_assignments(graph: Graph):
|
||||
"""
|
||||
Encrypt clear assignments.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to transform
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
|
||||
assigned_value = node.inputs[1]
|
||||
if assigned_value.is_clear:
|
||||
preds = graph.ordered_preds_of(node)
|
||||
assigned_pred = preds[1]
|
||||
|
||||
new_assigned_pred_value = deepcopy(assigned_value)
|
||||
new_assigned_pred_value.is_encrypted = True
|
||||
new_assigned_pred_value.dtype = preds[0].output.dtype
|
||||
|
||||
zero = Node.generic(
|
||||
"zeros",
|
||||
[],
|
||||
EncryptedScalar(new_assigned_pred_value.dtype),
|
||||
lambda: np.zeros((), dtype=np.int64),
|
||||
)
|
||||
|
||||
original_bit_width = 1
|
||||
zero.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
new_assigned_pred = Node.generic(
|
||||
"add",
|
||||
[assigned_pred.output, zero.output],
|
||||
new_assigned_pred_value,
|
||||
np.add,
|
||||
)
|
||||
|
||||
original_bit_width = assigned_pred.properties["original_bit_width"]
|
||||
new_assigned_pred.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
nx_graph.remove_edge(preds[1], node)
|
||||
|
||||
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
|
||||
nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
|
||||
|
||||
nx_graph.add_edge(new_assigned_pred, node, input_idx=1)
|
||||
|
||||
@staticmethod
|
||||
def _tensorize_scalars_for_fhelinalg(graph: Graph):
|
||||
"""
|
||||
Tensorize scalars if they are used within fhelinalg operations.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to update
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
OPS_TO_TENSORIZE = [
|
||||
"add",
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"broadcast_to",
|
||||
"dot",
|
||||
"equal",
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"left_shift",
|
||||
"less",
|
||||
"less_equal",
|
||||
"multiply",
|
||||
"not_equal",
|
||||
"right_shift",
|
||||
"subtract",
|
||||
]
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
tensorized_scalars: Dict[Node, Node] = {}
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
|
||||
assert len(node.inputs) in {1, 2}
|
||||
|
||||
if len(node.inputs) == 2:
|
||||
if {inp.is_scalar for inp in node.inputs} != {True, False}:
|
||||
continue
|
||||
else:
|
||||
if not node.inputs[0].is_scalar: # noqa: PLR5501
|
||||
continue
|
||||
|
||||
# for bitwise and comparison operators that can have constants
|
||||
# we don't need broadcasting here
|
||||
if node.converted_to_table_lookup:
|
||||
continue
|
||||
|
||||
pred_to_tensorize: Optional[Node] = None
|
||||
pred_to_tensorize_index = 0
|
||||
|
||||
preds = graph.ordered_preds_of(node)
|
||||
for index, pred in enumerate(preds):
|
||||
if pred.output.is_scalar:
|
||||
pred_to_tensorize = pred
|
||||
pred_to_tensorize_index = index
|
||||
break
|
||||
|
||||
assert pred_to_tensorize is not None
|
||||
|
||||
tensorized_pred = tensorized_scalars.get(pred_to_tensorize)
|
||||
if tensorized_pred is None:
|
||||
tensorized_value = deepcopy(pred_to_tensorize.output)
|
||||
tensorized_value.shape = (1,)
|
||||
|
||||
tensorized_pred = Node.generic(
|
||||
"array",
|
||||
[pred_to_tensorize.output],
|
||||
tensorized_value,
|
||||
lambda *args: np.array(args),
|
||||
)
|
||||
|
||||
original_bit_width = pred_to_tensorize.properties["original_bit_width"]
|
||||
tensorized_pred.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
original_shape = ()
|
||||
tensorized_pred.properties["original_shape"] = original_shape
|
||||
|
||||
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
|
||||
tensorized_scalars[pred_to_tensorize] = tensorized_pred
|
||||
|
||||
assert tensorized_pred is not None
|
||||
|
||||
nx_graph.remove_edge(pred_to_tensorize, node)
|
||||
nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
|
||||
|
||||
new_input_value = deepcopy(node.inputs[pred_to_tensorize_index])
|
||||
new_input_value.shape = (1,)
|
||||
node.inputs[pred_to_tensorize_index] = new_input_value
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
|
||||
"""
|
||||
Use subtraction to sanitize signed inputs.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph being converted
|
||||
|
||||
args (List[Any]):
|
||||
list of arguments from mlir main
|
||||
|
||||
ctx (Context):
|
||||
mlir context where the conversion is being performed
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[Any]]:
|
||||
sanitized args and name of the sanitized variables in MLIR
|
||||
"""
|
||||
|
||||
sanitized_args = []
|
||||
for i, arg in enumerate(args):
|
||||
input_node = graph.input_nodes[i]
|
||||
input_value = input_node.output
|
||||
|
||||
assert_that(isinstance(input_value.dtype, Integer))
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
|
||||
if input_dtype.is_signed:
|
||||
assert_that(input_value.is_encrypted)
|
||||
n = input_dtype.bit_width
|
||||
|
||||
sanitizer_type = IntegerType.get_signless(n + 1)
|
||||
sanitizer = 2 ** (n - 1)
|
||||
|
||||
if input_value.is_scalar:
|
||||
sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer)
|
||||
else:
|
||||
sanitizer_type = RankedTensorType.get((1,), sanitizer_type)
|
||||
sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}")
|
||||
|
||||
# pylint: disable=too-many-function-args
|
||||
sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr)
|
||||
# pylint: enable=too-many-function-args
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
|
||||
if input_value.is_scalar:
|
||||
sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
|
||||
else:
|
||||
sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
|
||||
|
||||
sanitized_args.append(sanitized)
|
||||
else:
|
||||
sanitized_args.append(arg)
|
||||
|
||||
return sanitized_args
|
||||
|
||||
@staticmethod
|
||||
def convert(graph: Graph) -> str:
|
||||
"""
|
||||
Convert a computation graph to its corresponding MLIR representation.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to be converted
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual MLIR representation corresponding to `graph`
|
||||
"""
|
||||
|
||||
graph = deepcopy(graph)
|
||||
|
||||
GraphConverter._check_graph_convertibility(graph)
|
||||
GraphConverter._update_bit_widths(graph)
|
||||
GraphConverter._offset_negative_lookup_table_inputs(graph)
|
||||
GraphConverter._broadcast_assignments(graph)
|
||||
GraphConverter._encrypt_clear_assignments(graph)
|
||||
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
|
||||
|
||||
from_elements_operations: Dict[OpResult, List[OpResult]] = {}
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
parameters = [
|
||||
NodeConverter.value_to_mlir_type(ctx, input_node.output)
|
||||
for input_node in graph.ordered_inputs()
|
||||
]
|
||||
|
||||
@func.FuncOp.from_py_func(*parameters)
|
||||
def main(*args):
|
||||
sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)
|
||||
|
||||
ir_to_mlir = {}
|
||||
for arg_num, node in graph.input_nodes.items():
|
||||
ir_to_mlir[node] = sanitized_args[arg_num]
|
||||
|
||||
constant_cache = {}
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
if node.operation == Operation.Input:
|
||||
continue
|
||||
|
||||
preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)]
|
||||
node_converter = NodeConverter(
|
||||
ctx,
|
||||
graph,
|
||||
node,
|
||||
preds,
|
||||
constant_cache,
|
||||
from_elements_operations,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert()
|
||||
|
||||
results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
|
||||
return results
|
||||
|
||||
direct_replacements = {}
|
||||
for placeholder, elements in from_elements_operations.items():
|
||||
element_names = [NodeConverter.mlir_name(element) for element in elements]
|
||||
actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
|
||||
direct_replacements[NodeConverter.mlir_name(placeholder)] = actual_value
|
||||
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
if mlir_name not in direct_replacements:
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
continue
|
||||
|
||||
new_value = direct_replacements[mlir_name]
|
||||
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
|
||||
|
||||
return "\n".join(module_lines_after_hacks_are_applied).strip()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Declaration of `GraphProcessor` class.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Mapping, Union
|
||||
|
||||
from ...representation import Graph, Node
|
||||
|
||||
|
||||
class GraphProcessor(ABC):
|
||||
"""
|
||||
GraphProcessor base class, to define the API for a graph processing pipeline.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, graph: Graph):
|
||||
"""
|
||||
Process the graph.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def error(graph: Graph, highlights: Mapping[Node, Union[str, List[str]]]):
|
||||
"""
|
||||
Fail processing with an error.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph being processed
|
||||
|
||||
highlights (Mapping[Node, Union[str, List[str]]]):
|
||||
nodes to highlight along with messages
|
||||
"""
|
||||
|
||||
highlights_with_location = {}
|
||||
for node, messages in highlights.items():
|
||||
messages_with_location = messages if isinstance(messages, list) else [messages]
|
||||
messages_with_location.append(node.location)
|
||||
highlights_with_location[node] = messages_with_location
|
||||
|
||||
message = "Function you are trying to compile cannot be compiled\n\n" + graph.format(
|
||||
highlighted_nodes=highlights_with_location
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
All graph processors.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
|
||||
from .assign_bit_widths import AssignBitWidths
|
||||
from .check_integer_only import CheckIntegerOnly
|
||||
|
||||
# pylint: enable=unused-import
|
||||
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Declaration of `AssignBitWidths` graph processor.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
|
||||
from ...dtypes import Integer
|
||||
from ...representation import Graph, Node, Operation
|
||||
from . import GraphProcessor
|
||||
|
||||
|
||||
class AssignBitWidths(GraphProcessor):
|
||||
"""
|
||||
Assign a precision to all nodes inputs/output.
|
||||
|
||||
The precisions are compatible graph constraints and MLIR.
|
||||
There are two modes:
|
||||
- single precision: where all encrypted values have the same precision.
|
||||
- multi precision: where encrypted values can have different precisions.
|
||||
"""
|
||||
|
||||
def __init__(self, single_precision=False):
|
||||
self.single_precision = single_precision
|
||||
|
||||
def apply(self, graph: Graph):
|
||||
nodes = graph.query_nodes()
|
||||
for node in nodes:
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
|
||||
if self.single_precision:
|
||||
assign_single_precision(nodes)
|
||||
else:
|
||||
assign_multi_precision(graph, nodes)
|
||||
|
||||
|
||||
def assign_single_precision(nodes: list[Node]):
|
||||
"""Assign one single encryption precision to all nodes."""
|
||||
p = required_encrypted_bitwidth(nodes)
|
||||
for node in nodes:
|
||||
assign_precisions_1_node(node, p, p)
|
||||
|
||||
|
||||
def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
|
||||
"""Assign input/output precision to a single node.
|
||||
|
||||
Precision are adjusted to match different use, e.g. encrypted and constant case.
|
||||
"""
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
if node.output.is_encrypted:
|
||||
node.output.dtype.bit_width = output_p
|
||||
else:
|
||||
node.output.dtype.bit_width = output_p + 1
|
||||
|
||||
for value in node.inputs:
|
||||
assert isinstance(value.dtype, Integer)
|
||||
if value.is_encrypted:
|
||||
value.dtype.bit_width = inputs_p
|
||||
else:
|
||||
value.dtype.bit_width = inputs_p + 1
|
||||
|
||||
|
||||
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
|
||||
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
|
||||
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
|
||||
MULTIPLY = {"multiply"}
|
||||
|
||||
|
||||
def max_encrypted_bitwidth_node(node: Node):
|
||||
"""Give the minimal precision to implement the node.
|
||||
|
||||
This applies to both input and output precisions.
|
||||
"""
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
if node.output.is_encrypted or node.operation == Operation.Constant:
|
||||
normal_p = node.output.dtype.bit_width
|
||||
else:
|
||||
normal_p = -1
|
||||
name = node.properties.get("name")
|
||||
|
||||
if name in CHUNKED_COMPARISON:
|
||||
return max(normal_p, CHUNKED_COMPARISON_MIN_BITWIDTH)
|
||||
|
||||
if name in MAX_POOLS:
|
||||
return normal_p + 1
|
||||
|
||||
if name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
|
||||
return normal_p + 1
|
||||
|
||||
return normal_p
|
||||
|
||||
|
||||
def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
|
||||
"""Give the minimal precision to implement all the nodes."""
|
||||
bitwidths = map(max_encrypted_bitwidth_node, nodes)
|
||||
return max(bitwidths, default=-1)
|
||||
|
||||
|
||||
def required_inputs_encrypted_bitwidth(graph, node, nodes_output_p: list[tuple[Node, int]]) -> int:
|
||||
"""Give the minimal precision to supports the inputs."""
|
||||
preds = graph.ordered_preds_of(node)
|
||||
get_prec = lambda node: nodes_output_p[node.properties[NODE_ID]][1]
|
||||
# by definition all inputs have the same block precision
|
||||
# see uniform_precision_per_blocks
|
||||
return get_prec(node) if len(preds) == 0 else get_prec(preds[0])
|
||||
|
||||
|
||||
def assign_multi_precision(graph, nodes):
|
||||
"""Assign a specific encryption precision to each nodes."""
|
||||
add_nodes_id(nodes)
|
||||
nodes_output_p = uniform_precision_per_blocks(graph, nodes)
|
||||
for node, _ in nodes_output_p:
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
nodes_inputs_p = [
|
||||
required_inputs_encrypted_bitwidth(graph, node, nodes_output_p)
|
||||
if can_change_precision(node)
|
||||
else output_p
|
||||
for node, output_p in nodes_output_p
|
||||
]
|
||||
for (node, output_p), inputs_p in zip(nodes_output_p, nodes_inputs_p):
|
||||
assign_precisions_1_node(node, output_p, inputs_p)
|
||||
clear_nodes_id(nodes)
|
||||
|
||||
|
||||
TLU_WITHOUT_PRECISION_CHANGE = CHUNKED_COMPARISON | MAX_POOLS | MULTIPLY
|
||||
|
||||
|
||||
def can_change_precision(node):
|
||||
"""Detect if a node completely ties inputs/output precisions together."""
|
||||
return (
|
||||
node.converted_to_table_lookup
|
||||
and node.properties.get("name") not in TLU_WITHOUT_PRECISION_CHANGE
|
||||
)
|
||||
|
||||
|
||||
def convert_union_to_blocks(node_union: UnionFind) -> Iterable[list[int]]:
|
||||
"""Convert a `UnionFind` to blocks.
|
||||
|
||||
The result is an iterable of blocks.A block being a list of node id.
|
||||
"""
|
||||
blocks = {}
|
||||
for node_id in range(node_union.size):
|
||||
node_canon = node_union.find_canonical(node_id)
|
||||
if node_canon == node_id:
|
||||
assert node_canon not in blocks
|
||||
blocks[node_canon] = [node_id]
|
||||
else:
|
||||
blocks[node_canon].append(node_id)
|
||||
return blocks.values()
|
||||
|
||||
|
||||
NODE_ID = "node_id"
|
||||
|
||||
|
||||
def add_nodes_id(nodes):
|
||||
"""Temporarily add a NODE_ID property to all nodes."""
|
||||
for node_id, node in enumerate(nodes):
|
||||
assert NODE_ID not in node.properties
|
||||
node.properties[NODE_ID] = node_id
|
||||
|
||||
|
||||
def clear_nodes_id(nodes):
|
||||
"""Remove the NODE_ID property from all nodes."""
|
||||
for node in nodes:
|
||||
del node.properties[NODE_ID]
|
||||
|
||||
|
||||
def uniform_precision_per_blocks(graph: Graph, nodes: list[Node]) -> list[tuple[Node, int]]:
|
||||
"""Find the required precision of blocks and associate it corresponding nodes."""
|
||||
size = len(nodes)
|
||||
node_union = UnionFind(size)
|
||||
for node_id, node in enumerate(nodes):
|
||||
preds = graph.ordered_preds_of(node)
|
||||
if not preds:
|
||||
continue
|
||||
# we always unify all inputs
|
||||
first_input_id = preds[0].properties[NODE_ID]
|
||||
for pred in preds[1:]:
|
||||
pred_id = pred.properties[NODE_ID]
|
||||
node_union.union(first_input_id, pred_id)
|
||||
# we unify with outputs only if no precision change can occur
|
||||
if not can_change_precision(node):
|
||||
node_union.union(first_input_id, node_id)
|
||||
|
||||
blocks = convert_union_to_blocks(node_union)
|
||||
result: list[None | tuple[Node, int]]
|
||||
result = [None] * len(nodes)
|
||||
for nodes_id in blocks:
|
||||
output_p = required_encrypted_bitwidth(nodes[node_id] for node_id in nodes_id)
|
||||
for node_id in nodes_id:
|
||||
result[node_id] = (nodes[node_id], output_p)
|
||||
assert None not in result
|
||||
return typing.cast("list[tuple[Node, int]]", result)
|
||||
|
||||
|
||||
class UnionFind:
|
||||
"""
|
||||
Utility class joins the nodes in equivalent precision classes.
|
||||
|
||||
Nodes are just integers id.
|
||||
"""
|
||||
|
||||
parent: list[int]
|
||||
|
||||
def __init__(self, size: int):
|
||||
"""Create a union find suitable for `size` nodes."""
|
||||
self.parent = list(range(size))
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
"""Size in number of nodes."""
|
||||
return len(self.parent)
|
||||
|
||||
def find_canonical(self, a: int) -> int:
|
||||
"""Find the current canonical node for a given input node."""
|
||||
parent = self.parent[a]
|
||||
if a == parent:
|
||||
return a
|
||||
canonical = self.find_canonical(parent)
|
||||
self.parent[a] = canonical
|
||||
return canonical
|
||||
|
||||
def union(self, a: int, b: int):
|
||||
"""Union both nodes."""
|
||||
self.united_common_ancestor(a, b)
|
||||
|
||||
def united_common_ancestor(self, a: int, b: int) -> int:
|
||||
"""Deduce the common ancestor of both nodes after unification."""
|
||||
parent_a = self.parent[a]
|
||||
parent_b = self.parent[b]
|
||||
|
||||
if parent_a == parent_b:
|
||||
return parent_a
|
||||
|
||||
if a == parent_a and parent_b < parent_a:
|
||||
common_ancestor = parent_b
|
||||
elif b == parent_b and parent_a < parent_b:
|
||||
common_ancestor = parent_a
|
||||
else:
|
||||
common_ancestor = self.united_common_ancestor(parent_a, parent_b)
|
||||
|
||||
self.parent[a] = common_ancestor
|
||||
self.parent[b] = common_ancestor
|
||||
return common_ancestor
|
||||
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Declaration of `CheckIntegerOnly` graph processor.
|
||||
"""
|
||||
|
||||
from ...dtypes import Integer
|
||||
from ...representation import Graph
|
||||
from . import GraphProcessor
|
||||
|
||||
|
||||
class CheckIntegerOnly(GraphProcessor):
|
||||
"""
|
||||
CheckIntegerOnly graph processor, to make sure the graph only contains integer nodes.
|
||||
"""
|
||||
|
||||
def apply(self, graph: Graph):
|
||||
non_integer_nodes = graph.query_nodes(
|
||||
custom_filter=(lambda node: not isinstance(node.output.dtype, Integer))
|
||||
)
|
||||
if non_integer_nodes:
|
||||
self.error(graph, {node: "only integers are supported" for node in non_integer_nodes})
|
||||
@@ -4,7 +4,7 @@ Declaration of various functions and constants related to MLIR conversion.
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from itertools import chain, product
|
||||
from typing import Any, DefaultDict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -53,11 +53,11 @@ def flood_replace_none_values(table: list):
|
||||
previous_idx = current_idx - 1
|
||||
next_idx = current_idx + 1
|
||||
|
||||
if previous_idx >= 0 and table[previous_idx] is None:
|
||||
if previous_idx >= 0 and table[previous_idx] is None: # pragma: no cover
|
||||
table[previous_idx] = deepcopy(current_value)
|
||||
not_none_values_idx.append(previous_idx)
|
||||
|
||||
if next_idx < len(table) and table[next_idx] is None:
|
||||
if next_idx < len(table) and table[next_idx] is None: # pragma: no cover
|
||||
table[next_idx] = deepcopy(current_value)
|
||||
not_none_values_idx.append(next_idx)
|
||||
|
||||
@@ -93,12 +93,13 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
|
||||
assert_that(isinstance(variable_input_dtype, Integer))
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
|
||||
inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
|
||||
values = chain(range(0, variable_input_dtype.max() + 1), range(variable_input_dtype.min(), 0))
|
||||
|
||||
np.seterr(divide="ignore")
|
||||
|
||||
inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
|
||||
table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
|
||||
for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1):
|
||||
for value in values:
|
||||
try:
|
||||
inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
|
||||
table.append(node(*inputs))
|
||||
|
||||
@@ -5,7 +5,7 @@ Declaration of `Graph` class.
|
||||
import math
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -296,10 +296,15 @@ class Graph:
|
||||
|
||||
bounds += "]"
|
||||
|
||||
output_value = node.output
|
||||
if isinstance(output_value.dtype, Integer) and "original_bit_width" in node.properties:
|
||||
output_value = deepcopy(output_value)
|
||||
output_value.dtype.bit_width = node.properties["original_bit_width"]
|
||||
|
||||
# remember metadata of the node
|
||||
line_metadata.append(
|
||||
{
|
||||
"type": f"# {node.output}",
|
||||
"type": f"# {output_value}",
|
||||
"bounds": bounds,
|
||||
"tag": (f"@ {node.tag}" if node.tag != "" else ""),
|
||||
"location": node.location,
|
||||
@@ -559,6 +564,9 @@ class Graph:
|
||||
self,
|
||||
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
is_encrypted_filter: Optional[bool] = None,
|
||||
custom_filter: Optional[Callable[[Node], bool]] = None,
|
||||
ordered: bool = False,
|
||||
) -> List[Node]:
|
||||
"""
|
||||
Query nodes within the graph.
|
||||
@@ -575,6 +583,15 @@ class Graph:
|
||||
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for operations
|
||||
|
||||
is_encrypted_filter (Optional[bool], default = None)
|
||||
filter for encryption status
|
||||
|
||||
custom_filter (Optional[Callable[[Node], bool]], default = None):
|
||||
flexible filter
|
||||
|
||||
ordered (bool)
|
||||
whether to apply topological sorting before filtering nodes
|
||||
|
||||
Returns:
|
||||
List[Node]:
|
||||
filtered nodes
|
||||
@@ -592,6 +609,12 @@ class Graph:
|
||||
|
||||
return any(text == alternative for alternative in text_filter)
|
||||
|
||||
def match_boolean_filter(boolean_filter, boolean):
|
||||
if boolean_filter is None:
|
||||
return True
|
||||
|
||||
return boolean == boolean_filter
|
||||
|
||||
def get_operation_name(node):
|
||||
result: str
|
||||
|
||||
@@ -604,12 +627,15 @@ class Graph:
|
||||
|
||||
return result
|
||||
|
||||
nodes = nx.lexicographical_topological_sort(self.graph) if ordered else self.graph.nodes()
|
||||
return [
|
||||
node
|
||||
for node in self.graph.nodes()
|
||||
for node in nodes
|
||||
if (
|
||||
match_text_filter(tag_filter, node.tag)
|
||||
and match_text_filter(operation_filter, get_operation_name(node))
|
||||
and match_boolean_filter(is_encrypted_filter, node.output.is_encrypted)
|
||||
and (custom_filter is None or custom_filter(node))
|
||||
)
|
||||
]
|
||||
|
||||
@@ -617,6 +643,8 @@ class Graph:
|
||||
self,
|
||||
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
is_encrypted_filter: Optional[bool] = None,
|
||||
custom_filter: Optional[Callable[[Node], bool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get maximum integer bit-width within the graph.
|
||||
@@ -630,16 +658,21 @@ class Graph:
|
||||
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for operations
|
||||
|
||||
is_encrypted_filter (Optional[bool], default = None)
|
||||
filter for encryption status
|
||||
|
||||
custom_filter (Optional[Callable[[Node], bool]], default = None):
|
||||
flexible filter
|
||||
|
||||
Returns:
|
||||
int:
|
||||
maximum integer bit-width within the graph
|
||||
if there are no integer nodes matching the query, result is -1
|
||||
"""
|
||||
|
||||
query = self.query_nodes(tag_filter, operation_filter, is_encrypted_filter, custom_filter)
|
||||
filtered_bit_widths = (
|
||||
node.output.dtype.bit_width
|
||||
for node in self.query_nodes(tag_filter, operation_filter)
|
||||
if isinstance(node.output.dtype, Integer)
|
||||
node.output.dtype.bit_width for node in query if isinstance(node.output.dtype, Integer)
|
||||
)
|
||||
return max(filtered_bit_widths, default=-1)
|
||||
|
||||
@@ -647,6 +680,8 @@ class Graph:
|
||||
self,
|
||||
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
is_encrypted_filter: Optional[bool] = None,
|
||||
custom_filter: Optional[Callable[[Node], bool]] = None,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
Get integer range of the graph.
|
||||
@@ -660,30 +695,39 @@ class Graph:
|
||||
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for operations
|
||||
|
||||
is_encrypted_filter (Optional[bool], default = None)
|
||||
filter for encryption status
|
||||
|
||||
custom_filter (Optional[Callable[[Node], bool]], default = None):
|
||||
flexible filter
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]:
|
||||
minimum and maximum integer value observed during inputset evaluation
|
||||
if there are no integer nodes matching the query, result is None
|
||||
"""
|
||||
|
||||
if self.is_direct:
|
||||
return None
|
||||
|
||||
result: Optional[Tuple[int, int]] = None
|
||||
|
||||
if not self.is_direct:
|
||||
filtered_bounds = (
|
||||
node.bounds
|
||||
for node in self.query_nodes(tag_filter, operation_filter)
|
||||
if isinstance(node.output.dtype, Integer) and node.bounds is not None
|
||||
)
|
||||
for min_bound, max_bound in filtered_bounds:
|
||||
assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)
|
||||
query = self.query_nodes(tag_filter, operation_filter, is_encrypted_filter, custom_filter)
|
||||
filtered_bounds = (
|
||||
node.bounds
|
||||
for node in query
|
||||
if isinstance(node.output.dtype, Integer) and node.bounds is not None
|
||||
)
|
||||
for min_bound, max_bound in filtered_bounds:
|
||||
assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)
|
||||
|
||||
if result is None:
|
||||
result = (int(min_bound), int(max_bound))
|
||||
else:
|
||||
old_min_bound, old_max_bound = result # pylint: disable=unpacking-non-sequence
|
||||
result = (
|
||||
min(old_min_bound, int(min_bound)),
|
||||
max(old_max_bound, int(max_bound)),
|
||||
)
|
||||
if result is None:
|
||||
result = (int(min_bound), int(max_bound))
|
||||
else:
|
||||
old_min_bound, old_max_bound = result
|
||||
result = (
|
||||
min(old_min_bound, int(min_bound)),
|
||||
max(old_max_bound, int(max_bound)),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -163,10 +163,6 @@ class Node:
|
||||
|
||||
fhe_directory = os.path.dirname(fhe.__file__)
|
||||
|
||||
import concrete.onnx as coonx
|
||||
|
||||
coonx_directory = os.path.dirname(coonx.__file__)
|
||||
|
||||
# pylint: enable=cyclic-import,import-outside-toplevel
|
||||
|
||||
for frame in reversed(traceback.extract_stack()):
|
||||
@@ -176,9 +172,6 @@ class Node:
|
||||
if frame.filename.startswith(fhe_directory):
|
||||
continue
|
||||
|
||||
if frame.filename.startswith(coonx_directory):
|
||||
continue
|
||||
|
||||
self.location = f"{frame.filename}:{frame.lineno}"
|
||||
break
|
||||
|
||||
@@ -294,12 +287,12 @@ class Node:
|
||||
|
||||
name = self.properties["name"]
|
||||
|
||||
if name == "index.static":
|
||||
if name == "index_static":
|
||||
index = self.properties["kwargs"]["index"]
|
||||
elements = [format_indexing_element(element) for element in index]
|
||||
return f"{predecessors[0]}[{', '.join(elements)}]"
|
||||
|
||||
if name == "assign.static":
|
||||
if name == "assign_static":
|
||||
index = self.properties["kwargs"]["index"]
|
||||
elements = [format_indexing_element(element) for element in index]
|
||||
return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"
|
||||
@@ -345,10 +338,10 @@ class Node:
|
||||
|
||||
name = self.properties["name"]
|
||||
|
||||
if name == "index.static":
|
||||
if name == "index_static":
|
||||
name = self.format(["□"])
|
||||
|
||||
if name == "assign.static":
|
||||
if name == "assign_static":
|
||||
name = self.format(["□", "□"])[1:-1]
|
||||
|
||||
return name
|
||||
@@ -386,7 +379,7 @@ class Node:
|
||||
return self.operation == Operation.Generic and self.properties["name"] not in [
|
||||
"add",
|
||||
"array",
|
||||
"assign.static",
|
||||
"assign_static",
|
||||
"broadcast_to",
|
||||
"concatenate",
|
||||
"conv1d",
|
||||
@@ -394,7 +387,7 @@ class Node:
|
||||
"conv3d",
|
||||
"dot",
|
||||
"expand_dims",
|
||||
"index.static",
|
||||
"index_static",
|
||||
"matmul",
|
||||
"maxpool",
|
||||
"multiply",
|
||||
|
||||
@@ -426,7 +426,7 @@ class Tracer:
|
||||
|
||||
computation = Node.generic(
|
||||
operation.__name__,
|
||||
[tracer.output for tracer in tracers],
|
||||
[deepcopy(tracer.output) for tracer in tracers],
|
||||
output_value,
|
||||
operation,
|
||||
kwargs=kwargs,
|
||||
@@ -618,7 +618,7 @@ class Tracer:
|
||||
|
||||
computation = Node.generic(
|
||||
"astype",
|
||||
[self.output],
|
||||
[deepcopy(self.output)],
|
||||
output_value,
|
||||
lambda x: x, # unused for direct definition
|
||||
)
|
||||
@@ -662,7 +662,7 @@ class Tracer:
|
||||
|
||||
computation = Node.generic(
|
||||
"astype",
|
||||
[self.output],
|
||||
[deepcopy(self.output)],
|
||||
output_value,
|
||||
evaluator,
|
||||
kwargs={"dtype": dtype},
|
||||
@@ -753,8 +753,8 @@ class Tracer:
|
||||
output_value.shape = np.zeros(output_value.shape)[index].shape
|
||||
|
||||
computation = Node.generic(
|
||||
"index.static",
|
||||
[self.output],
|
||||
"index_static",
|
||||
[deepcopy(self.output)],
|
||||
output_value,
|
||||
lambda x, index: x[index],
|
||||
kwargs={"index": index},
|
||||
@@ -803,8 +803,8 @@ class Tracer:
|
||||
|
||||
sanitized_value = self.sanitize(value)
|
||||
computation = Node.generic(
|
||||
"assign.static",
|
||||
[self.output, sanitized_value.output],
|
||||
"assign_static",
|
||||
[deepcopy(self.output), deepcopy(sanitized_value.output)],
|
||||
self.output,
|
||||
assign,
|
||||
kwargs={"index": index},
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
Implement machine learning operations as specified by ONNX.
|
||||
"""
|
||||
|
||||
from .convolution import conv
|
||||
from .maxpool import maxpool
|
||||
@@ -78,7 +78,6 @@ setuptools.setup(
|
||||
|
||||
package_dir={
|
||||
"concrete.fhe": "./concrete/fhe",
|
||||
"concrete.onnx": "./concrete/onnx",
|
||||
"": bindings_directory(),
|
||||
},
|
||||
packages=setuptools.find_namespace_packages(
|
||||
@@ -87,9 +86,6 @@ setuptools.setup(
|
||||
) + setuptools.find_namespace_packages(
|
||||
where=".",
|
||||
include=["concrete.fhe", "concrete.fhe.*"],
|
||||
) + setuptools.find_namespace_packages(
|
||||
where=".",
|
||||
include=["concrete.onnx", "concrete.onnx.*"],
|
||||
) + setuptools.find_namespace_packages(
|
||||
where=bindings_directory(),
|
||||
include=["concrete.compiler", "concrete.compiler.*"],
|
||||
|
||||
@@ -18,6 +18,7 @@ tests_directory = os.path.dirname(tests.__file__)
|
||||
|
||||
|
||||
INSECURE_KEY_CACHE_LOCATION = None
|
||||
USE_MULTI_PRECISION = False
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
@@ -39,6 +40,13 @@ def pytest_addoption(parser):
|
||||
action="store",
|
||||
help="Specify the location of the key cache",
|
||||
)
|
||||
parser.addoption(
|
||||
"--precision",
|
||||
type=str,
|
||||
default=None,
|
||||
action="store",
|
||||
help="Which precision strategy to use in execution tests (single or multi)",
|
||||
)
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
@@ -47,6 +55,7 @@ def pytest_sessionstart(session):
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global INSECURE_KEY_CACHE_LOCATION
|
||||
global USE_MULTI_PRECISION
|
||||
# pylint: enable=global-statement
|
||||
|
||||
key_cache_location = session.config.getoption("--key-cache", default=None)
|
||||
@@ -64,6 +73,9 @@ def pytest_sessionstart(session):
|
||||
|
||||
INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
|
||||
|
||||
precision = session.config.getoption("--precision", default="single")
|
||||
USE_MULTI_PRECISION = precision == "multi"
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument
|
||||
"""
|
||||
@@ -117,6 +129,7 @@ class Helpers:
|
||||
jit=True,
|
||||
insecure_key_cache_location=INSECURE_KEY_CACHE_LOCATION,
|
||||
global_p_error=(1 / 10_000),
|
||||
single_precision=(not USE_MULTI_PRECISION),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -21,6 +21,22 @@ from concrete import fhe
|
||||
lambda x, y: fhe.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
},
|
||||
id="fhe.array([x, y])",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: fhe.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
},
|
||||
id="fhe.array([x, y])",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: fhe.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
},
|
||||
id="fhe.array([x, y])",
|
||||
@@ -42,6 +58,14 @@ from concrete import fhe
|
||||
},
|
||||
id="fhe.array([[x, 1], [y, 2], [z, 3]])",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: fhe.array([x, y]) + fhe.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
},
|
||||
id="fhe.array([x, y]) + fhe.array([x, y])",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_array(function, parameters, helpers):
|
||||
|
||||
@@ -60,3 +60,58 @@ def test_bitwise(function, parameters, helpers):
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: (x & y) + (2**6),
|
||||
id="x & y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x | y) + (2**6),
|
||||
id="x | y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x ^ y) + (2**6),
|
||||
id="x ^ y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_bitwise_optimized(function, parameters, helpers):
|
||||
"""
|
||||
Test optimized bitwise operations between encrypted integers.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -18,6 +18,20 @@ from concrete import fhe
|
||||
"y": {"shape": (3, 2)},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.concatenate((x, y)),
|
||||
{
|
||||
"x": {"shape": (4, 2), "status": "clear"},
|
||||
"y": {"shape": (3, 2)},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.concatenate((x, y)),
|
||||
{
|
||||
"x": {"shape": (4, 2)},
|
||||
"y": {"shape": (3, 2), "status": "clear"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.concatenate((x, y), axis=0),
|
||||
{
|
||||
|
||||
@@ -5,7 +5,6 @@ Tests of execution of convolution operation.
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.onnx as connx
|
||||
from concrete import fhe
|
||||
from concrete.fhe.representation.node import Node
|
||||
from concrete.fhe.tracing.tracer import Tracer
|
||||
@@ -62,7 +61,7 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(x, weight, bias, strides=strides, dilations=dilations, group=group)
|
||||
return fhe.conv(x, weight, bias, strides=strides, dilations=dilations, group=group)
|
||||
|
||||
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
|
||||
circuit = function.compile(inputset, configuration)
|
||||
@@ -307,32 +306,6 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
|
||||
ValueError,
|
||||
"expected number of channel in weight to be 1.0 (C / group), but got 2",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 4),
|
||||
(1, 1, 2),
|
||||
(1,),
|
||||
(0, 0),
|
||||
(1,),
|
||||
(1,),
|
||||
None,
|
||||
1,
|
||||
"NOTSET",
|
||||
NotImplementedError,
|
||||
"conv1d conversion to MLIR is not yet implemented",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 4, 4, 4),
|
||||
(1, 1, 2, 2, 2),
|
||||
(1,),
|
||||
(0, 0, 0, 0, 0, 0),
|
||||
(1, 1, 1),
|
||||
(1, 1, 1),
|
||||
None,
|
||||
1,
|
||||
"NOTSET",
|
||||
NotImplementedError,
|
||||
"conv3d conversion to MLIR is not yet implemented",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 4, 4, 4, 4),
|
||||
(1, 1, 2, 2, 2, 2),
|
||||
@@ -388,7 +361,7 @@ def test_bad_conv_compilation(
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(
|
||||
return fhe.conv(
|
||||
x,
|
||||
weight,
|
||||
bias=bias,
|
||||
@@ -426,8 +399,8 @@ def test_bad_conv_compilation(
|
||||
"func",
|
||||
[
|
||||
# pylint: disable=protected-access
|
||||
connx.convolution._evaluate_conv,
|
||||
connx.convolution._trace_conv,
|
||||
fhe.extensions.convolution._evaluate_conv,
|
||||
fhe.extensions.convolution._trace_conv,
|
||||
# pylint: enable=protected-access
|
||||
],
|
||||
)
|
||||
@@ -487,7 +460,7 @@ def test_inconsistent_input_types(
|
||||
Test conv with inconsistent input types.
|
||||
"""
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
connx.conv(
|
||||
fhe.conv(
|
||||
x,
|
||||
weight,
|
||||
bias=bias,
|
||||
|
||||
@@ -5,7 +5,6 @@ Tests of execution of maxpool operation.
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.onnx as connx
|
||||
from concrete import fhe
|
||||
|
||||
|
||||
@@ -69,16 +68,61 @@ def test_maxpool(
|
||||
sample_input = np.expand_dims(np.array(sample_input), axis=(0, 1))
|
||||
expected_output = np.expand_dims(np.array(expected_output), axis=(0, 1))
|
||||
|
||||
assert np.array_equal(connx.maxpool(sample_input, **operation), expected_output)
|
||||
assert np.array_equal(fhe.maxpool(sample_input, **operation), expected_output)
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.maxpool(x, **operation)
|
||||
return fhe.maxpool(x, **operation)
|
||||
|
||||
graph = function.trace([sample_input], helpers.configuration())
|
||||
assert np.array_equal(graph(sample_input), expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation,parameters",
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
"kernel_shape": (3, 2),
|
||||
},
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 20], "shape": (1, 1, 6, 7)},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"kernel_shape": (3, 2),
|
||||
},
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 6, 7)},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_maxpool2d(
|
||||
operation,
|
||||
parameters,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test maxpool2d.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def function(x):
|
||||
return fhe.maxpool(x, **operation)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,operation,expected_error,expected_message",
|
||||
[
|
||||
@@ -308,7 +352,7 @@ def test_bad_maxpool(
|
||||
"""
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
connx.maxpool(np.random.randint(0, 10, size=input_shape), **operation)
|
||||
fhe.maxpool(np.random.randint(0, 10, size=input_shape), **operation)
|
||||
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
@@ -318,51 +362,11 @@ def test_bad_maxpool_special(helpers):
|
||||
Test maxpool with bad parameters for special cases.
|
||||
"""
|
||||
|
||||
# compile
|
||||
# -------
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def not_compilable(x):
|
||||
return connx.maxpool(x, kernel_shape=(4, 3))
|
||||
|
||||
inputset = [np.random.randint(0, 10, size=(1, 1, 10, 10)) for i in range(100)]
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
not_compilable.compile(inputset, helpers.configuration())
|
||||
|
||||
helpers.check_str("MaxPool operation cannot be compiled yet", str(excinfo.value))
|
||||
|
||||
# clear input
|
||||
# -----------
|
||||
|
||||
@fhe.compiler({"x": "clear"})
|
||||
def clear_input(x):
|
||||
return connx.maxpool(x, kernel_shape=(4, 3, 2))
|
||||
|
||||
inputset = [np.zeros((1, 1, 10, 10, 10), dtype=np.int64)]
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
clear_input.compile(inputset, helpers.configuration())
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 10, 10)> ∈ [0, 0]
|
||||
%1 = maxpool(%0, kernel_shape=(4, 3, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # ClearTensor<uint1, shape=(1, 1, 7, 8, 9)> ∈ [0, 0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted maxpool is supported
|
||||
return %1
|
||||
|
||||
""".strip(), # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
# badly typed ndarray input
|
||||
# -------------------------
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
connx.maxpool(np.array([{}, None]), ())
|
||||
fhe.maxpool(np.array([{}, None]), ())
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
@@ -379,7 +383,7 @@ Expected input elements to be of type np.integer, np.floating, or np.bool_ but i
|
||||
# -----------------
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
connx.maxpool("", ())
|
||||
fhe.maxpool("", ())
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
@@ -74,3 +74,98 @@ def test_constant_mul(function, parameters, helpers):
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
id="x * y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "clear"},
|
||||
"y": {"range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted"},
|
||||
"y": {"range": [0, 10], "status": "clear"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted"},
|
||||
"y": {"range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "clear", "shape": (3,)},
|
||||
"y": {"range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 10], "status": "clear"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "clear"},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted"},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted"},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "clear", "shape": (3,)},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "clear", "shape": (2, 1)},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": (2, 1)},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": (2, 1)},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_mul(function, parameters, helpers):
|
||||
"""
|
||||
Test mul where both of the operators are dynamic.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -273,7 +273,7 @@ def deterministic_unary_function(x):
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 84]},
|
||||
},
|
||||
id="abs(64 - x)",
|
||||
id="abs(42 - x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: ~x,
|
||||
|
||||
650
frontends/concrete-python/tests/mlir/test_converter.py
Normal file
650
frontends/concrete-python/tests/mlir/test_converter.py
Normal file
@@ -0,0 +1,650 @@
|
||||
"""
|
||||
Tests of `Converter` class.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete import fhe
|
||||
from concrete.fhe.mlir import GraphConverter
|
||||
|
||||
|
||||
def assign(x, y):
|
||||
"""
|
||||
Assign scalar `y` into vector `x`.
|
||||
"""
|
||||
|
||||
x[0] = y
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,encryption_statuses,inputset,expected_error,expected_message",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(0.0, 0), (7.0, 7), (0.0, 7), (7.0, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<float64> ∈ [0.0, 7.0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integers are supported
|
||||
%1 = y # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%2 = add(%0, %1) # EncryptedScalar<float64> ∈ [0.0, 14.0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integers are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.conv(x, [[[3, 1, 0, 2]]]),
|
||||
{"x": "encrypted"},
|
||||
[np.ones(shape=(1, 1, 10), dtype=np.int64)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10)> ∈ [1, 1]
|
||||
%1 = [[[3 1 0 2]]] # ClearTensor<uint2, shape=(1, 1, 4)> ∈ [0, 3]
|
||||
%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor<uint3, shape=(1, 1, 7)> ∈ [6, 6]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1-dimensional convolutions are not supported at the moment
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.conv(x, [[[[[1, 3], [4, 2]]]]]),
|
||||
{"x": "encrypted"},
|
||||
[np.ones(shape=(1, 1, 3, 4, 5), dtype=np.int64)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 3, 4, 5)> ∈ [1, 1]
|
||||
%1 = [[[[[1 3] [4 2]]]]] # ClearTensor<uint3, shape=(1, 1, 1, 2, 2)> ∈ [1, 4]
|
||||
%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor<uint4, shape=(1, 1, 3, 3, 4)> ∈ [10, 10]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 3-dimensional convolutions are not supported at the moment
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.maxpool(x, kernel_shape=(3,)),
|
||||
{"x": "encrypted"},
|
||||
[np.ones(shape=(1, 1, 10), dtype=np.int64)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10)> ∈ [1, 1]
|
||||
%1 = maxpool1d(%0, kernel_shape=(3,), strides=(1,), pads=(0, 0), dilations=(1,), ceil_mode=False) # EncryptedTensor<uint1, shape=(1, 1, 8)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1-dimensional maxpooling is not supported at the moment
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.maxpool(x, kernel_shape=(3, 1, 2)),
|
||||
{"x": "encrypted"},
|
||||
[np.ones(shape=(1, 1, 3, 4, 5), dtype=np.int64)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 3, 4, 5)> ∈ [1, 1]
|
||||
%1 = maxpool3d(%0, kernel_shape=(3, 1, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # EncryptedTensor<uint1, shape=(1, 1, 1, 4, 4)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 3-dimensional maxpooling is not supported at the moment
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{"x": "clear", "y": "clear"},
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
|
||||
%1 = y # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
|
||||
%2 = add(%0, %1) # ClearScalar<uint4> ∈ [0, 14]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear additions are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
{"x": "clear", "y": "clear"},
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
|
||||
%1 = y # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
|
||||
%2 = subtract(%0, %1) # ClearScalar<int4> ∈ [-7, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear subtractions are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{"x": "clear", "y": "clear"},
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
|
||||
%1 = y # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
|
||||
%2 = multiply(%0, %1) # ClearScalar<uint6> ∈ [0, 49]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear multiplications are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{"x": "clear", "y": "clear"},
|
||||
[([1, 2], [3, 4])],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint2, shape=(2,)> ∈ [1, 2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
|
||||
%1 = y # ClearTensor<uint3, shape=(2,)> ∈ [3, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
|
||||
%2 = dot(%0, %1) # ClearScalar<uint4> ∈ [11, 11]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear dot products are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.broadcast_to(x, shape=(2, 2)),
|
||||
{"x": "clear"},
|
||||
[[1, 2], [3, 4]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(2,)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ value is clear
|
||||
%1 = broadcast_to(%0, shape=(2, 2)) # ClearTensor<uint3, shape=(2, 2)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear values cannot be broadcasted
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
assign,
|
||||
{"x": "clear", "y": "encrypted"},
|
||||
[([1, 2, 3], 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint2, shape=(3,)> ∈ [0, 3]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ tensor is clear
|
||||
%1 = y # EncryptedScalar<uint1> ∈ [0, 0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ assigned value is encrypted
|
||||
%2 = (%0[0] = %1) # ClearTensor<uint2, shape=(3,)> ∈ [0, 3]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted values cannot be assigned to clear tensors
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x**2 + (x + 1_000_000),
|
||||
{"x": "encrypted"},
|
||||
[100_000],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 34-bit value is used as an input to a table lookup
|
||||
(note that it's assigned 34-bits during compilation because of its relation with other operations)
|
||||
%1 = 2 # ClearScalar<uint2> ∈ [2, 2]
|
||||
%2 = power(%0, %1) # EncryptedScalar<uint34> ∈ [10000000000, 10000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit table lookups are supported
|
||||
%3 = 1000000 # ClearScalar<uint20> ∈ [1000000, 1000000]
|
||||
%4 = add(%0, %3) # EncryptedScalar<uint21> ∈ [1100000, 1100000]
|
||||
%5 = add(%2, %4) # EncryptedScalar<uint34> ∈ [10001100000, 10001100000]
|
||||
return %5
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x & y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(-2, 4)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<int2> ∈ [-2, -2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is signed
|
||||
%1 = y # EncryptedScalar<uint3> ∈ [4, 4]
|
||||
%2 = bitwise_and(%0, %1) # EncryptedScalar<uint3> ∈ [4, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise operations are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x & y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(4, -2)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint3> ∈ [4, 4]
|
||||
%1 = y # EncryptedScalar<int2> ∈ [-2, -2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is signed
|
||||
%2 = bitwise_and(%0, %1) # EncryptedScalar<uint3> ∈ [4, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise operations are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.concatenate((x, y)),
|
||||
{"x": "clear", "y": "clear"},
|
||||
[([1, 2], [3, 4])],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint2, shape=(2,)> ∈ [1, 2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ value is clear
|
||||
%1 = y # ClearTensor<uint3, shape=(2,)> ∈ [3, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ value is clear
|
||||
%2 = concatenate((%0, %1)) # ClearTensor<uint3, shape=(4,)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear concatenation is not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.conv(x, [[[[2, 1], [0, 3]]]]),
|
||||
{"x": "clear"},
|
||||
[np.ones(shape=(1, 1, 10, 10), dtype=np.int64)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 10)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
|
||||
%1 = [[[[2 1] [0 3]]]] # ClearTensor<uint2, shape=(1, 1, 2, 2)> ∈ [0, 3]
|
||||
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint3, shape=(1, 1, 9, 9)> ∈ [6, 6]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear convolutions are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: fhe.conv(x, weight=y),
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.ones(shape=(1, 1, 10, 10), dtype=np.int64),
|
||||
np.ones(shape=(1, 1, 2, 2), dtype=np.int64),
|
||||
)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10, 10)> ∈ [1, 1]
|
||||
%1 = y # EncryptedTensor<uint1, shape=(1, 1, 2, 2)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ weight is encrypted
|
||||
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint3, shape=(1, 1, 9, 9)> ∈ [4, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but convolutions with encrypted weights are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: fhe.conv(x, weight=[[[[2, 1], [0, 3]]]], bias=y),
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.ones(shape=(1, 1, 10, 10), dtype=np.int64),
|
||||
np.ones(shape=(1,), dtype=np.int64),
|
||||
)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10, 10)> ∈ [1, 1]
|
||||
%1 = y # EncryptedTensor<uint1, shape=(1,)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bias is encrypted
|
||||
%2 = [[[[2 1] [0 3]]]] # ClearTensor<uint2, shape=(1, 1, 2, 2)> ∈ [0, 3]
|
||||
%3 = conv2d(%0, %2, %1, pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint3, shape=(1, 1, 9, 9)> ∈ [7, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but convolutions with encrypted biases are not supported
|
||||
return %3
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.ones(shape=(3,), dtype=np.int64),
|
||||
np.ones(shape=(3,), dtype=np.int64),
|
||||
)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
|
||||
%1 = y # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint2> ∈ [3, 3]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted dot products are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{"x": "clear", "y": "clear"},
|
||||
[([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(2, 2)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
|
||||
%1 = y # ClearTensor<uint3, shape=(2, 2)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
|
||||
%2 = matmul(%0, %1) # ClearTensor<uint5, shape=(2, 2)> ∈ [5, 20]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear matrix multiplications are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
|
||||
%1 = y # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint5, shape=(2, 2)> ∈ [5, 20]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted matrix multiplications are not supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.maxpool(x, kernel_shape=(3, 2)),
|
||||
{"x": "clear"},
|
||||
[np.ones(shape=(1, 1, 10, 5), dtype=np.int64)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 5)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
|
||||
%1 = maxpool2d(%0, kernel_shape=(3, 2), strides=(1, 1), pads=(0, 0, 0, 0), dilations=(1, 1), ceil_mode=False) # ClearTensor<uint1, shape=(1, 1, 8, 4)> ∈ [1, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear maxpooling is not supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x**2,
|
||||
{"x": "clear"},
|
||||
[3, 4, 5],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearScalar<uint3> ∈ [3, 5]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this clear value is used as an input to a table lookup
|
||||
%1 = 2 # ClearScalar<uint2> ∈ [2, 2]
|
||||
%2 = power(%0, %1) # ClearScalar<uint5> ∈ [9, 25]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only encrypted table lookups are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.sum(x),
|
||||
{"x": "clear"},
|
||||
[[1, 2]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint2, shape=(2,)> ∈ [1, 2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
|
||||
%1 = sum(%0) # ClearScalar<uint2> ∈ [3, 3]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear summation is not supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(-2, 4)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<int2> ∈ [-2, -2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is signed
|
||||
%1 = y # EncryptedScalar<uint3> ∈ [4, 4]
|
||||
%2 = left_shift(%0, %1) # EncryptedScalar<int6> ∈ [-32, -32]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise shifts are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x >> y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(4, -2)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint3> ∈ [4, 4]
|
||||
%1 = y # EncryptedScalar<int2> ∈ [-2, -2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is signed
|
||||
%2 = right_shift(%0, %1) # EncryptedScalar<uint1> ∈ [0, 0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise shifts are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: -x,
|
||||
{"x": "clear"},
|
||||
[10],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearScalar<uint4> ∈ [10, 10]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
|
||||
%1 = negative(%0) # ClearScalar<int5> ∈ [-10, -10]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear negations are not supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: fhe.LookupTable([fhe.LookupTable([0, 1]), fhe.LookupTable([1, 0])])[x],
|
||||
{"x": "clear"},
|
||||
[[1, 1], [1, 0], [0, 1], [0, 0]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this clear value is used as an input to a table lookup
|
||||
%1 = tlu(%0, table=[[0, 1] [1, 0]]) # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only encrypted table lookups are supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_converter_bad_convert(
|
||||
function,
|
||||
encryption_statuses,
|
||||
inputset,
|
||||
expected_error,
|
||||
expected_message,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test unsupported graph conversion.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = fhe.Compiler(function, encryption_statuses)
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_graph",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: (x**2) + 100,
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
|
||||
%1 = 2 # ClearScalar<uint5> ∈ [2, 2]
|
||||
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
|
||||
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
|
||||
return %4
|
||||
|
||||
""",
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_converter_process_multi_precision(function, parameters, expected_graph, helpers):
|
||||
"""
|
||||
Test `process` method of `Converter` with multi precision.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration().fork(single_precision=False)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
processed_graph = GraphConverter().process(graph, configuration)
|
||||
for node in processed_graph.query_nodes():
|
||||
if "original_bit_width" in node.properties:
|
||||
del node.properties["original_bit_width"]
|
||||
|
||||
helpers.check_str(expected_graph, processed_graph.format())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_graph",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: (x**2) + 100,
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint8> ∈ [0, 10]
|
||||
%1 = 2 # ClearScalar<uint9> ∈ [2, 2]
|
||||
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
|
||||
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
|
||||
return %4
|
||||
|
||||
""",
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_converter_process_single_precision(function, parameters, expected_graph, helpers):
|
||||
"""
|
||||
Test `process` method of `Converter` with single precision.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration().fork(single_precision=True)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
processed_graph = GraphConverter().process(graph, configuration)
|
||||
for node in processed_graph.query_nodes():
|
||||
if "original_bit_width" in node.properties:
|
||||
del node.properties["original_bit_width"]
|
||||
|
||||
helpers.check_str(expected_graph, processed_graph.format())
|
||||
@@ -1,499 +0,0 @@
|
||||
"""
|
||||
Tests of `GraphConverter` class.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.onnx as connx
|
||||
from concrete import fhe
|
||||
|
||||
|
||||
def assign(x):
|
||||
"""
|
||||
Simple assignment to a vector.
|
||||
"""
|
||||
|
||||
x[0] = 0
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,encryption_statuses,inputset,expected_error,expected_message",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: (x - y, x + y),
|
||||
{"x": "encrypted", "y": "clear"},
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%1 = y # ClearScalar<uint3> ∈ [0, 7]
|
||||
%2 = subtract(%0, %1) # EncryptedScalar<int4> ∈ [-7, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
|
||||
%3 = add(%0, %1) # EncryptedScalar<uint4> ∈ [0, 14]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
|
||||
return (%2, %3)
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x,
|
||||
{"x": "clear"},
|
||||
range(-10, 10),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearScalar<int5> ∈ [-10, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported
|
||||
return %0
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 1.5,
|
||||
{"x": "encrypted"},
|
||||
[2.5 * x for x in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<float64> ∈ [0.0, 247.5]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported
|
||||
%1 = 1.5 # ClearScalar<float64> ∈ [1.5, 1.5]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<float64> ∈ [0.0, 371.25]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.sin(x),
|
||||
{"x": "encrypted"},
|
||||
range(100),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint7> ∈ [0, 99]
|
||||
%1 = sin(%0) # EncryptedScalar<float64> ∈ [-0.99999, 0.999912]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.concatenate((x, y)),
|
||||
{"x": "encrypted", "y": "clear"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2**3, size=(3, 2)),
|
||||
np.random.randint(0, 2**3, size=(3, 2)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)> ∈ [0, 7]
|
||||
%1 = y # ClearTensor<uint3, shape=(3, 2)> ∈ [0, 7]
|
||||
%2 = concatenate((%0, %1)) # EncryptedTensor<uint3, shape=(6, 2)> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only all encrypted concatenate is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, w: connx.conv(x, w),
|
||||
{"x": "encrypted", "w": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2, size=(1, 1, 4)),
|
||||
np.random.randint(0, 2, size=(1, 1, 1)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4)> ∈ [0, 1]
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1)> ∈ [0, 1]
|
||||
%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv1d with encrypted input and clear weight is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, w: connx.conv(x, w),
|
||||
{"x": "encrypted", "w": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2, size=(1, 1, 4, 4)),
|
||||
np.random.randint(0, 2, size=(1, 1, 1, 1)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4)> ∈ [0, 1]
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1)> ∈ [0, 1]
|
||||
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv2d with encrypted input and clear weight is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, w: connx.conv(x, w),
|
||||
{"x": "encrypted", "w": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2, size=(1, 1, 4, 4, 4)),
|
||||
np.random.randint(0, 2, size=(1, 1, 1, 1, 1)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)> ∈ [0, 1]
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1, 1)> ∈ [0, 1]
|
||||
%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv3d with encrypted input and clear weight is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[([0], [0]), ([3], [3]), ([3], [0]), ([0], [3]), ([1], [1])],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(1,)> ∈ [0, 3]
|
||||
%1 = y # EncryptedTensor<uint2, shape=(1,)> ∈ [0, 3]
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint4> ∈ [0, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only dot product between encrypted and clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[0],
|
||||
{"x": "clear"},
|
||||
[[0, 1, 2, 3], [7, 6, 5, 4]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(4,)> ∈ [0, 7]
|
||||
%1 = %0[0] # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted indexing supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2**1, size=(1, 1)),
|
||||
np.random.randint(0, 2**1, size=(1, 1)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
|
||||
%1 = y # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only matrix multiplication between encrypted and clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%1 = y # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint6> ∈ [0, 49]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only multiplication between encrypted and clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: -x,
|
||||
{"x": "clear"},
|
||||
[0, 7],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearScalar<uint3> ∈ [0, 7]
|
||||
%1 = negative(%0) # ClearScalar<int4> ∈ [-7, 0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted negation is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2**3, size=(2, 3)) for _ in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(2, 3)> ∈ [0, 7]
|
||||
%1 = reshape(%0, newshape=(3, 2)) # ClearTensor<uint3, shape=(3, 2)> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted reshape is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.sum(x),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(1,)) for _ in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1,)> ∈ [0, 1]
|
||||
%1 = sum(%0) # ClearScalar<uint1> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted sum is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.maximum(x, np.array([3])),
|
||||
{"x": "clear"},
|
||||
[[0], [1]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1,)> ∈ [0, 1]
|
||||
%1 = [3] # ClearTensor<uint2, shape=(1,)> ∈ [3, 3]
|
||||
%2 = maximum(%0, %1) # ClearTensor<uint2, shape=(1,)> ∈ [3, 3]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(3, 2)) for _ in range(10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(3, 2)> ∈ [0, 1]
|
||||
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.broadcast_to(x, shape=(3, 2)),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(2,)) for _ in range(10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
|
||||
%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor<uint1, shape=(3, 2)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
assign,
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(3,)) for _ in range(10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(3,)> ∈ [0, 1]
|
||||
%1 = 0 # ClearScalar<uint1> ∈ [0, 0]
|
||||
%2 = (%0[0] = %1) # ClearTensor<uint1, shape=(3,)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64),
|
||||
{"x": "encrypted"},
|
||||
[200000],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint18> ∈ [200000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this input is 18-bits
|
||||
%1 = 300 # ClearScalar<uint9> ∈ [300, 300]
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint18> ∈ [200300, 200300]
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint4> ∈ [9, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bits
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = input # EncryptedScalar<uint2>
|
||||
%1 = sin(%0) # EncryptedScalar<float64>
|
||||
%2 = 10 # ClearScalar<uint4>
|
||||
%3 = multiply(%2, %1) # EncryptedScalar<float64>
|
||||
%4 = absolute(%3) # EncryptedScalar<float64>
|
||||
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %5
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(-1, 1), (-2, 3)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<int2> ∈ [-2, -1]
|
||||
%1 = y # EncryptedScalar<uint2> ∈ [1, 3]
|
||||
%2 = left_shift(%0, %1) # EncryptedScalar<int5> ∈ [-16, -2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned bitwise operations are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(1, 20), (2, 10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint2> ∈ [1, 2]
|
||||
%1 = y # EncryptedScalar<uint5> ∈ [10, 20]
|
||||
%2 = left_shift(%0, %1) # EncryptedScalar<uint21> ∈ [2048, 1048576]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 4-bit shifts are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_graph_converter_bad_convert(
|
||||
function,
|
||||
encryption_statuses,
|
||||
inputset,
|
||||
expected_error,
|
||||
expected_message,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test unsupported graph conversion.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = fhe.Compiler(function, encryption_statuses)
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,inputset,expected_mlir",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: 1 + fhe.LookupTable([4, 1, 2, 3])[x] + fhe.LookupTable([4, 1, 2, 3])[x + 1],
|
||||
range(3),
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%c1_i4 = arith.constant 1 : i4
|
||||
%cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
%1 = "FHE.add_eint_int"(%arg0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
|
||||
%2 = "FHE.add_eint_int"(%0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
|
||||
%3 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
%4 = "FHE.add_eint"(%2, %3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
return %4 : !FHE.eint<3>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
# Notice that there is only a single 1 and a single table cst above
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_cache(function, inputset, expected_mlir, helpers):
|
||||
"""
|
||||
Test caching MLIR constants.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = fhe.Compiler(function, {"x": "encrypted"})
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_mlir, circuit.mlir)
|
||||
|
||||
|
||||
# pylint: enable=line-too-long
|
||||
@@ -39,6 +39,64 @@ def f(x):
|
||||
return g(z + 3) * 2
|
||||
|
||||
|
||||
def test_graph_format_show_lines(helpers):
|
||||
"""
|
||||
Test `format` method of `Graph` class with show_lines=True.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = fhe.Compiler(f, {"x": "encrypted"})
|
||||
graph = compiler.trace(range(10), configuration)
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
expected = f"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4> ∈ [0, 9] {tests_directory}/representation/test_graph.py:50
|
||||
%1 = 2 # ClearScalar<uint2> ∈ [2, 2] @ abc {tests_directory}/representation/test_graph.py:34
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint5> ∈ [0, 18] @ abc {tests_directory}/representation/test_graph.py:34
|
||||
%3 = 42 # ClearScalar<uint6> ∈ [42, 42] @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint6> ∈ [42, 60] @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%5 = subgraph(%4) # EncryptedScalar<uint3> ∈ [6, 7] @ abc {tests_directory}/representation/test_graph.py:37
|
||||
%6 = 3 # ClearScalar<uint2> ∈ [3, 3] {tests_directory}/representation/test_graph.py:39
|
||||
%7 = add(%5, %6) # EncryptedScalar<uint4> ∈ [9, 10] {tests_directory}/representation/test_graph.py:39
|
||||
%8 = 120 # ClearScalar<uint7> ∈ [120, 120] @ def {tests_directory}/representation/test_graph.py:23
|
||||
%9 = subtract(%8, %7) # EncryptedScalar<uint7> ∈ [110, 111] @ def {tests_directory}/representation/test_graph.py:23
|
||||
%10 = 4 # ClearScalar<uint3> ∈ [4, 4] @ def {tests_directory}/representation/test_graph.py:24
|
||||
%11 = floor_divide(%9, %10) # EncryptedScalar<uint5> ∈ [27, 27] @ def {tests_directory}/representation/test_graph.py:24
|
||||
%12 = 2 # ClearScalar<uint2> ∈ [2, 2] {tests_directory}/representation/test_graph.py:39
|
||||
%13 = multiply(%11, %12) # EncryptedScalar<uint6> ∈ [54, 54] {tests_directory}/representation/test_graph.py:39
|
||||
return %13
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%5 = subgraph(%4):
|
||||
|
||||
%0 = input # EncryptedScalar<uint2> @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%1 = sqrt(%0) # EncryptedScalar<float64> @ abc {tests_directory}/representation/test_graph.py:37
|
||||
%2 = astype(%1, dtype=int_) # EncryptedScalar<uint1> @ abc {tests_directory}/representation/test_graph.py:37
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
actual = graph.format(show_locations=True)
|
||||
|
||||
assert (
|
||||
actual.strip() == expected.strip()
|
||||
), f"""
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{actual}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,inputset,tag_filter,operation_filter,expected_result",
|
||||
[
|
||||
@@ -184,13 +242,14 @@ def test_graph_maximum_integer_bit_width(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,inputset,tag_filter,operation_filter,expected_result",
|
||||
"function,inputset,tag_filter,operation_filter,is_encrypted_filter,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + 42,
|
||||
range(-10, 10),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
(-10, 51),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -199,12 +258,14 @@ def test_graph_maximum_integer_bit_width(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
(0, 120),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -212,6 +273,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
"",
|
||||
None,
|
||||
None,
|
||||
(0, 54),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -219,6 +281,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
"abc",
|
||||
None,
|
||||
None,
|
||||
(0, 18),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -226,6 +289,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
["abc", "def"],
|
||||
None,
|
||||
None,
|
||||
(0, 120),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -233,6 +297,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
re.compile(".*b.*"),
|
||||
None,
|
||||
None,
|
||||
(0, 60),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -240,6 +305,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
None,
|
||||
"input",
|
||||
None,
|
||||
(0, 9),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -247,6 +313,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
None,
|
||||
"constant",
|
||||
None,
|
||||
(2, 120),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -254,6 +321,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
None,
|
||||
"subgraph",
|
||||
None,
|
||||
(6, 7),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -261,6 +329,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
None,
|
||||
"add",
|
||||
None,
|
||||
(9, 60),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -268,6 +337,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
None,
|
||||
["subgraph", "add"],
|
||||
None,
|
||||
(6, 60),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -275,6 +345,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
None,
|
||||
re.compile("sub.*"),
|
||||
None,
|
||||
(6, 111),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -282,6 +353,7 @@ def test_graph_maximum_integer_bit_width(
|
||||
range(10),
|
||||
"abc.foo",
|
||||
"add",
|
||||
None,
|
||||
(42, 60),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -290,6 +362,23 @@ def test_graph_maximum_integer_bit_width(
|
||||
"abc",
|
||||
"floor_divide",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - 2,
|
||||
range(5, 10),
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
(3, 9),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - 2,
|
||||
range(5, 10),
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
(2, 2),
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -298,6 +387,7 @@ def test_graph_integer_range(
|
||||
inputset,
|
||||
tag_filter,
|
||||
operation_filter,
|
||||
is_encrypted_filter,
|
||||
expected_result,
|
||||
helpers,
|
||||
):
|
||||
@@ -310,62 +400,23 @@ def test_graph_integer_range(
|
||||
compiler = fhe.Compiler(function, {"x": "encrypted"})
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
assert graph.integer_range(tag_filter, operation_filter) == expected_result
|
||||
assert graph.integer_range(tag_filter, operation_filter, is_encrypted_filter) == expected_result
|
||||
|
||||
|
||||
def test_graph_format_show_lines(helpers):
|
||||
def test_direct_graph_integer_range(helpers):
|
||||
"""
|
||||
Test `format` method of `Graph` class with show_lines=True.
|
||||
Test `integer_range` method of `Graph` class where `graph.is_direct` is `True`.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from concrete.fhe.dtypes import Integer
|
||||
from concrete.fhe.values import Value
|
||||
|
||||
compiler = fhe.Compiler(f, {"x": "encrypted"})
|
||||
graph = compiler.trace(range(10), configuration)
|
||||
# pylint: enable=import-outside-toplevel
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
expected = f"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4> ∈ [0, 9] {tests_directory}/representation/test_graph.py:324
|
||||
%1 = 2 # ClearScalar<uint2> ∈ [2, 2] @ abc {tests_directory}/representation/test_graph.py:34
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint5> ∈ [0, 18] @ abc {tests_directory}/representation/test_graph.py:34
|
||||
%3 = 42 # ClearScalar<uint6> ∈ [42, 42] @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint6> ∈ [42, 60] @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%5 = subgraph(%4) # EncryptedScalar<uint3> ∈ [6, 7] @ abc {tests_directory}/representation/test_graph.py:37
|
||||
%6 = 3 # ClearScalar<uint2> ∈ [3, 3] {tests_directory}/representation/test_graph.py:39
|
||||
%7 = add(%5, %6) # EncryptedScalar<uint4> ∈ [9, 10] {tests_directory}/representation/test_graph.py:39
|
||||
%8 = 120 # ClearScalar<uint7> ∈ [120, 120] @ def {tests_directory}/representation/test_graph.py:23
|
||||
%9 = subtract(%8, %7) # EncryptedScalar<uint7> ∈ [110, 111] @ def {tests_directory}/representation/test_graph.py:23
|
||||
%10 = 4 # ClearScalar<uint3> ∈ [4, 4] @ def {tests_directory}/representation/test_graph.py:24
|
||||
%11 = floor_divide(%9, %10) # EncryptedScalar<uint5> ∈ [27, 27] @ def {tests_directory}/representation/test_graph.py:24
|
||||
%12 = 2 # ClearScalar<uint2> ∈ [2, 2] {tests_directory}/representation/test_graph.py:39
|
||||
%13 = multiply(%11, %12) # EncryptedScalar<uint6> ∈ [54, 54] {tests_directory}/representation/test_graph.py:39
|
||||
return %13
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%5 = subgraph(%4):
|
||||
|
||||
%0 = input # EncryptedScalar<uint2> @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%1 = sqrt(%0) # EncryptedScalar<float64> @ abc {tests_directory}/representation/test_graph.py:37
|
||||
%2 = astype(%1, dtype=int_) # EncryptedScalar<uint1> @ abc {tests_directory}/representation/test_graph.py:37
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
actual = graph.format(show_locations=True)
|
||||
|
||||
assert (
|
||||
actual.strip() == expected.strip()
|
||||
), f"""
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{actual}
|
||||
|
||||
"""
|
||||
circuit = fhe.Compiler.assemble(
|
||||
lambda x: x,
|
||||
{"x": Value(dtype=Integer(is_signed=False, bit_width=8), shape=(), is_encrypted=True)},
|
||||
configuration=helpers.configuration(),
|
||||
)
|
||||
assert circuit.graph.integer_range() is None
|
||||
|
||||
@@ -167,7 +167,7 @@ def test_node_bad_call(node, args, expected_error, expected_message):
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="index.static",
|
||||
name="index_static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
operation=lambda x: x[slice(None, None, -1)],
|
||||
@@ -208,7 +208,7 @@ def test_node_bad_call(node, args, expected_error, expected_message):
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="assign.static",
|
||||
name="assign_static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
|
||||
operation=lambda *args: args,
|
||||
@@ -266,7 +266,7 @@ def test_node_format(node, predecessors, expected_result):
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="index.static",
|
||||
name="index_static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=()),
|
||||
operation=lambda *args: args,
|
||||
@@ -276,7 +276,7 @@ def test_node_format(node, predecessors, expected_result):
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="assign.static",
|
||||
name="assign_static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
|
||||
operation=lambda *args: args,
|
||||
|
||||
Reference in New Issue
Block a user