refactor(frontend/python): re-write MLIR conversion

This commit is contained in:
Umut
2023-04-06 17:12:17 +02:00
parent b73d465f1d
commit 0edfe59e18
45 changed files with 4089 additions and 3379 deletions

View File

@@ -190,6 +190,9 @@ good-names=i,
k,
ex,
Run,
xs,
on,
of,
_
# Good variable names regexes, separated by a comma. If names match any regex,
@@ -438,7 +441,10 @@ disable=raw-checker-failed,
too-many-instance-attributes,
too-many-lines,
too-many-locals,
too-many-public-methods,
too-many-statements,
unnecessary-lambda-assignment,
use-implicit-booleaness-not-comparison,
wrong-import-order
# Enable the message, report, category or checker with the given id(s). You can

View File

@@ -8,10 +8,13 @@ select = [
ignore = [
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "SIM105",
"RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901",
"PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901"
"PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901",
"E731"
]
[per-file-ignores]
"**/__init__.py" = ["F401"]
"concrete/fhe/mlir/processors/all.py" = ["F401"]
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
"examples/**" = ["PLR2004"]
"tests/**" = ["PLR2004", "PLW0603", "SIM300", "S311"]

View File

@@ -29,7 +29,15 @@ licenses:
pytest:
export LD_PRELOAD=$(RUNTIME_LIBRARY)
export PYTHONPATH=$(BINDINGS_DIRECTORY)
# test single precision
pytest tests -svv -n auto \
--key-cache "${KEY_CACHE_DIRECTORY}" \
-m "${PYTEST_MARKERS}"
# test multi precision
pytest tests -svv -n auto \
--precision=multi \
--cov=concrete \
--cov-fail-under=100 \
--cov-report=term-missing:skip-covered \

View File

@@ -1,5 +1,5 @@
"""
Setup concrete module to be enlarged with numpy module.
Setup concrete namespace.
"""
# Do not modify, this is to have a compatible namespace package

View File

@@ -1,10 +1,9 @@
"""
Export everything that users might need.
Concrete.
"""
# pylint: disable=import-error,no-name-in-module
# mypy: disable-error-code=attr-defined
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
from .compilation import (
@@ -24,6 +23,8 @@ from .extensions import (
AutoRounder,
LookupTable,
array,
conv,
maxpool,
one,
ones,
round_bit_pattern,

View File

@@ -8,7 +8,7 @@ import json
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import numpy as np
@@ -165,13 +165,7 @@ class Client:
)
if is_valid:
is_signed = self.specs.input_signs[index]
sanitizer = 0 if not is_signed else 2 ** (width - 1)
if isinstance(arg, int):
sanitized_args[index] = arg + sanitizer
else:
sanitized_args[index] = (arg + sanitizer).astype(np.uint64)
sanitized_args[index] = arg
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)
@@ -205,61 +199,7 @@ class Client:
self.keygen(force=False)
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, self._keyset, result)
if not isinstance(outputs, tuple):
outputs = (outputs,)
sanitized_outputs: List[Union[int, np.ndarray]] = []
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("outputs" in client_parameters_json)
output_specs = client_parameters_json["outputs"]
for index, output in enumerate(outputs):
is_signed = self.specs.output_signs[index]
crt_decomposition = (
output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", [])
)
if is_signed:
if crt_decomposition:
if isinstance(output, int):
sanititzed_output = (
output
if output < (int(np.prod(crt_decomposition)) // 2)
else -int(np.prod(crt_decomposition)) + output
)
else:
output = output.astype(np.longlong) # to prevent overflows in numpy
sanititzed_output = np.where(
output < (np.prod(crt_decomposition) // 2),
output,
-np.prod(crt_decomposition) + output,
).astype(
np.int64
) # type: ignore
sanitized_outputs.append(sanititzed_output)
else:
n = output_specs[index]["shape"]["width"]
output %= 2**n
if isinstance(output, int):
sanititzed_output = output if output < (2 ** (n - 1)) else output - (2**n)
sanitized_outputs.append(sanititzed_output)
else:
output = output.astype(np.longlong) # to prevent overflows in numpy
sanititzed_output = np.where(
output < (2 ** (n - 1)), output, output - (2**n)
).astype(
np.int64
) # type: ignore
sanitized_outputs.append(sanititzed_output)
else:
sanitized_outputs.append(
output if isinstance(output, int) else output.astype(np.uint64)
)
return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs)
return outputs
@property
def evaluation_keys(self) -> EvaluationKeys:

View File

@@ -434,7 +434,7 @@ class Compiler:
self._evaluate("Compiling", inputset)
assert self.graph is not None
mlir = GraphConverter.convert(self.graph)
mlir = GraphConverter().convert(self.graph, self.configuration)
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)

View File

@@ -30,6 +30,7 @@ class Configuration:
global_p_error: Optional[float]
insecure_key_cache_location: Optional[str]
auto_adjust_rounders: bool
single_precision: bool
def _validate(self):
"""
@@ -64,6 +65,7 @@ class Configuration:
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
auto_adjust_rounders: bool = False,
single_precision: bool = True,
):
self.verbose = verbose
self.show_graph = show_graph
@@ -82,6 +84,7 @@ class Configuration:
self.p_error = p_error
self.global_p_error = global_p_error
self.auto_adjust_rounders = auto_adjust_rounders
self.single_precision = single_precision
self._validate()

View File

@@ -608,7 +608,7 @@ def convert_subgraph_to_subgraph_node(
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
subgraph_node = Node.generic(
"subgraph",
subgraph_variable_input_node.inputs,
deepcopy(subgraph_variable_input_node.inputs),
terminal_node.output,
lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
kwargs={

View File

@@ -45,7 +45,7 @@ class Integer(BaseDataType):
if isinstance(value, list):
try:
value = np.array(value)
value = np.array(value, dtype=np.int64)
except Exception: # pylint: disable=broad-except
# here we try our best to convert the list to np.ndarray
# if it fails we raise the exception at the else branch below

View File

@@ -3,6 +3,8 @@ Provide additional features that are not present in numpy.
"""
from .array import array
from .convolution import conv
from .maxpool import maxpool
from .ones import one, ones
from .round_bit_pattern import AutoRounder, round_bit_pattern
from .table import LookupTable

View File

@@ -2,6 +2,7 @@
Declaration of `array` function, to simplify creation of encrypted arrays.
"""
from copy import deepcopy
from typing import Any, Union
import numpy as np
@@ -52,7 +53,7 @@ def array(values: Any) -> Union[np.ndarray, Tracer]:
computation = Node.generic(
"array",
[value.output for value in values],
[deepcopy(value.output) for value in values],
Value(dtype, shape, is_encrypted),
lambda *args: np.array(args).reshape(shape),
)

View File

@@ -1,17 +1,18 @@
"""
Convolution operations' tracing and evaluation.
Tracing and evaluation of convolution.
"""
import math
from copy import deepcopy
from typing import Callable, List, Optional, Tuple, Union, cast
import numpy as np
import torch
from ..fhe.internal.utils import assert_that
from ..fhe.representation import Node
from ..fhe.tracing import Tracer
from ..fhe.values import EncryptedTensor
from ..internal.utils import assert_that
from ..representation import Node
from ..tracing import Tracer
from ..values import EncryptedTensor
SUPPORTED_AUTO_PAD = {
"NOTSET",
@@ -23,8 +24,8 @@ SUPPORTED_AUTO_PAD = {
def conv(
x: Union[np.ndarray, Tracer],
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]] = None,
weight: Union[np.ndarray, List, Tracer],
bias: Optional[Union[np.ndarray, List, Tracer]] = None,
pads: Optional[Union[Tuple[int, ...], List[int]]] = None,
strides: Optional[Union[Tuple[int, ...], List[int]]] = None,
dilations: Optional[Union[Tuple[int, ...], List[int]]] = None,
@@ -63,11 +64,18 @@ def conv(
Returns:
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
if kernel_shape is not None and (
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
):
message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
raise ValueError(message)
if isinstance(weight, list): # pragma: no cover
try:
weight = np.array(weight)
except Exception: # pylint: disable=broad-except
pass
if bias is not None and isinstance(bias, list): # pragma: no cover
try:
bias = np.array(bias)
except Exception: # pylint: disable=broad-except
pass
if isinstance(x, np.ndarray):
if not isinstance(weight, np.ndarray):
@@ -84,6 +92,12 @@ def conv(
message = "expected bias to be of type Tracer or ndarray"
raise TypeError(message)
if kernel_shape is not None and (
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
):
message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
raise ValueError(message)
if x.ndim <= 2:
message = (
f"expected input x to have at least 3 dimensions (N, C, D1, ...), but got {x.ndim}"
@@ -511,7 +525,7 @@ def _trace_conv(
computation = Node.generic(
conv_func, # "conv1d" or "conv2d" or "conv3d"
input_values,
deepcopy(input_values),
output_value,
eval_func,
args=() if bias is not None else (np.zeros(n_filters, dtype=np.int64),),

View File

@@ -1,16 +1,17 @@
"""
Tracing and evaluation of maxpool function.
Tracing and evaluation of maxpool.
"""
from copy import deepcopy
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..fhe.internal.utils import assert_that
from ..fhe.representation import Node
from ..fhe.tracing import Tracer
from ..fhe.values import Value
from ..internal.utils import assert_that
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
# pylint: disable=too-many-branches,too-many-statements
@@ -288,9 +289,12 @@ def _trace_or_evaluate(
resulting_value.is_encrypted = x.output.is_encrypted
resulting_value.dtype = x.output.dtype
dims = x.ndim - 2
assert_that(dims in {1, 2, 3})
computation = Node.generic(
"maxpool",
[x.output],
f"maxpool{dims}d",
[deepcopy(x.output)],
resulting_value,
_evaluate,
kwargs={

View File

@@ -218,7 +218,7 @@ def round_bit_pattern(
if isinstance(x, Tracer):
computation = Node.generic(
"round_bit_pattern",
[x.output],
[deepcopy(x.output)],
deepcopy(x.output),
evaluator,
kwargs={"lsbs_to_remove": lsbs_to_remove},

View File

@@ -83,7 +83,7 @@ class LookupTable:
computation = Node.generic(
"tlu",
[key.output],
[deepcopy(key.output)],
output,
LookupTable.apply,
kwargs={"table": table},

View File

@@ -2,6 +2,7 @@
Declaration of `univariate` function.
"""
from copy import deepcopy
from typing import Any, Callable, Optional, Type, Union
import numpy as np
@@ -75,7 +76,7 @@ def univariate(
computation = Node.generic(
function.__name__,
[x.output],
[deepcopy(x.output)],
output_value,
lambda x: function(x), # pylint: disable=unnecessary-lambda
)

View File

@@ -2,5 +2,4 @@
Provide `computation graph` to `mlir` functionality.
"""
from .graph_converter import GraphConverter
from .node_converter import NodeConverter
from .converter import Converter as GraphConverter

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,183 @@
"""
Declaration of `ConversionType` and `Conversion` classes.
"""
# pylint: disable=import-error,
import re
from typing import Optional, Tuple
from mlir.ir import OpResult as MlirOperation
from mlir.ir import Type as MlirType
from ..representation import Node
# pylint: enable=import-error
SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")
SCALAR_EINT_SEARCH_REGEX = re.compile(r"^!FHE\.e(s)?int<([0-9]+)>$")
TENSOR_INT_SEARCH_REGEX = re.compile(r"^tensor<(([0-9]+x)+)i([0-9]+)>$")
TENSOR_EINT_SEARCH_REGEX = re.compile(r"^tensor<(([0-9]+x)+)!FHE\.e(s)?int<([0-9]+)>>$")


class ConversionType:
    """
    ConversionType class, to make it easier to work with MLIR types.

    The textual form of the wrapped MLIR type is matched against the
    regexes above to recover its bit-width, signedness, encryption
    status, and shape.
    """

    mlir: MlirType
    bit_width: int
    is_encrypted: bool
    is_signed: bool
    shape: Tuple[int, ...]

    def __init__(self, mlir: MlirType):
        self.mlir = mlir
        type_string = str(mlir)

        scalar_int = SCALAR_INT_SEARCH_REGEX.search(type_string)
        if scalar_int is not None:
            # clear scalar integers (e.g., "i5") are treated as signed
            (width,) = scalar_int.groups()
            self._describe(int(width), is_encrypted=False, is_signed=True, shape=())
            return

        scalar_eint = SCALAR_EINT_SEARCH_REGEX.search(type_string)
        if scalar_eint is not None:
            # "esint" is signed, "eint" is unsigned
            signedness, width = scalar_eint.groups()
            self._describe(int(width), is_encrypted=True, is_signed=signedness is not None, shape=())
            return

        tensor_int = TENSOR_INT_SEARCH_REGEX.search(type_string)
        if tensor_int is not None:
            dimensions, _, width = tensor_int.groups()
            self._describe(
                int(width),
                is_encrypted=False,
                is_signed=True,
                shape=self._parse_shape(dimensions),
            )
            return

        tensor_eint = TENSOR_EINT_SEARCH_REGEX.search(type_string)
        if tensor_eint is not None:
            dimensions, _, signedness, width = tensor_eint.groups()
            self._describe(
                int(width),
                is_encrypted=True,
                is_signed=signedness is not None,
                shape=self._parse_shape(dimensions),
            )
            return

        # fallback for unrecognized types: modeled as a 64-bit clear unsigned scalar
        self._describe(64, is_encrypted=False, is_signed=False, shape=())

    def _describe(
        self,
        bit_width: int,
        is_encrypted: bool,
        is_signed: bool,
        shape: Tuple[int, ...],
    ):
        # single assignment point for the parsed description of the type
        self.bit_width = bit_width
        self.is_encrypted = is_encrypted
        self.is_signed = is_signed
        self.shape = shape

    @staticmethod
    def _parse_shape(dimensions: str) -> Tuple[int, ...]:
        # dimensions look like "2x3x"; the trailing "x" is dropped before splitting
        return tuple(int(size) for size in dimensions.rstrip("x").split("x"))

    # pylint: disable=missing-function-docstring

    @property
    def is_clear(self) -> bool:
        return not self.is_encrypted

    @property
    def is_scalar(self) -> bool:
        return self.shape == ()

    @property
    def is_tensor(self) -> bool:
        return self.shape != ()

    @property
    def is_unsigned(self) -> bool:
        return not self.is_signed

    # pylint: enable=missing-function-docstring
class Conversion:
    """
    Conversion class, to store MLIR operations with additional information.

    Wraps the MLIR value produced for a graph node together with its parsed
    type and, optionally, the bit-width the value had before bit-width
    assignment widened it.
    """

    # graph node that produced this MLIR value
    origin: Node
    # parsed description of `result`'s MLIR type
    type: ConversionType
    # the underlying MLIR value
    result: MlirOperation

    # bit-width before bit-width assignment; None until explicitly set
    _original_bit_width: Optional[int]

    def __init__(self, origin: Node, result: MlirOperation):
        self.origin = origin
        self.type = ConversionType(result.type)
        self.result = result
        self._original_bit_width = None

    def __hash__(self):
        # hash follows the wrapped MLIR value; equality is left as the
        # default identity comparison
        return hash(self.result)

    def set_original_bit_width(self, original_bit_width: int):
        """
        Set the original bit-width of the conversion.
        """
        self._original_bit_width = original_bit_width

    @property
    def original_bit_width(self) -> int:
        """
        Get the original bit-width of the conversion.

        If not explicitly set, defaults to the actual bit width.
        """
        return self._original_bit_width if self._original_bit_width is not None else self.bit_width

    # convenience accessors delegating to the parsed type
    # pylint: disable=missing-function-docstring

    @property
    def bit_width(self) -> int:
        return self.type.bit_width

    @property
    def is_clear(self) -> bool:
        return self.type.is_clear

    @property
    def is_encrypted(self) -> bool:
        return self.type.is_encrypted

    @property
    def is_scalar(self) -> bool:
        return self.type.is_scalar

    @property
    def is_signed(self) -> bool:
        return self.type.is_signed

    @property
    def is_tensor(self) -> bool:
        return self.type.is_tensor

    @property
    def is_unsigned(self) -> bool:
        return self.type.is_unsigned

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.type.shape

    # pylint: enable=missing-function-docstring

View File

@@ -0,0 +1,501 @@
"""
Declaration of `Converter` class.
"""
# pylint: disable=import-error,no-name-in-module
from copy import deepcopy
from typing import List, Tuple
import concrete.lang
import networkx as nx
import numpy as np
from mlir.dialects import func
from mlir.ir import BlockArgument as MlirBlockArgument
from mlir.ir import Context as MlirContext
from mlir.ir import InsertionPoint as MlirInsertionPoint
from mlir.ir import Location as MlirLocation
from mlir.ir import Module as MlirModule
from mlir.ir import OpResult as MlirOperation
from concrete.fhe.compilation.configuration import Configuration
from ..representation import Graph, Node, Operation
from .context import Context
from .conversion import Conversion
from .processors.all import * # pylint: disable=wildcard-import
from .utils import MAXIMUM_TLU_BIT_WIDTH, construct_deduplicated_tables
# pylint: enable=import-error,no-name-in-module
class Converter:
    """
    Converter class, to convert a computation graph to MLIR.
    """

    def convert(self, graph: Graph, configuration: Configuration) -> str:
        """
        Convert a computation graph to MLIR.

        Args:
            graph (Graph):
                graph to convert

            configuration (Configuration):
                configuration to use

        Return:
            str:
                MLIR corresponding to graph
        """

        # pre-process the graph (e.g., bit-width assignment) before lowering
        graph = self.process(graph, configuration)

        with MlirContext() as context, MlirLocation.unknown():
            concrete.lang.register_dialects(context)  # pylint: disable=no-member

            module = MlirModule.create()
            with MlirInsertionPoint(module.body):
                ctx = Context(context, graph)
                input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()]

                @func.FuncOp.from_py_func(*input_types)
                def main(*args):
                    # wrap the function's block arguments into `Conversion`s
                    # for the graph's input nodes
                    for index, node in enumerate(graph.ordered_inputs()):
                        conversion = Conversion(node, args[index])
                        if "original_bit_width" in node.properties:
                            conversion.set_original_bit_width(node.properties["original_bit_width"])
                        ctx.conversions[node] = conversion

                    # convert the remaining nodes in a deterministic topological
                    # order, so predecessors are converted before their users
                    for node in nx.lexicographical_topological_sort(graph.graph):
                        if node.operation == Operation.Input:
                            continue

                        preds = [ctx.conversions[pred] for pred in graph.ordered_preds_of(node)]
                        self.node(ctx, node, preds)

                    outputs = []
                    for node in graph.ordered_outputs():
                        assert node in ctx.conversions
                        outputs.append(ctx.conversions[node].result)

                    return tuple(outputs)

            def extract_mlir_name(result: MlirOperation) -> str:
                # block arguments print as "%argN"; other values are parsed out
                # of their textual form (e.g., "Value(%0 = ...)")
                return (
                    f"%arg{result.arg_number}"
                    if isinstance(result, MlirBlockArgument)
                    else str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
                )

            # operations recorded in `ctx.from_elements_operations` are patched
            # textually into the printed module: each placeholder value is
            # replaced by a `tensor.from_elements` of its recorded elements
            direct_replacements = {}
            for placeholder, elements in ctx.from_elements_operations.items():
                element_names = [extract_mlir_name(element) for element in elements]
                actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
                direct_replacements[extract_mlir_name(placeholder)] = actual_value

            module_lines_after_direct_replacements_are_applied = []
            for line in str(module).split("\n"):
                # "%name = ..." lines are keyed by the left-hand side name
                mlir_name = line.split("=")[0].strip()

                if mlir_name not in direct_replacements:
                    module_lines_after_direct_replacements_are_applied.append(line)
                    continue

                new_value = direct_replacements[mlir_name]
                new_line = f" {mlir_name} = {new_value}"

                module_lines_after_direct_replacements_are_applied.append(new_line)

            return "\n".join(module_lines_after_direct_replacements_are_applied).strip()

    def process(self, graph: Graph, configuration: Configuration) -> Graph:
        """
        Process a computation graph for MLIR conversion.

        Args:
            graph (Graph):
                graph to convert

            configuration (Configuration):
                configuration to use

        Return:
            Graph:
                processed copy of the graph
        """

        # processors come from `.processors.all` via the wildcard import above
        pipeline = [
            CheckIntegerOnly(),
            AssignBitWidths(single_precision=configuration.single_precision),
        ]

        # work on a copy so the caller's graph is left untouched
        graph = deepcopy(graph)
        for processor in pipeline:
            processor.apply(graph)

        return graph

    def node(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        """
        Convert a computation graph node into MLIR.

        Args:
            ctx (Context):
                conversion context

            node (Node):
                node to convert

            preds (List[Conversion]):
                conversions of ordered predecessors of the node

        Return:
            Conversion:
                conversion object corresponding to node
        """

        # record the node currently being converted on the context
        # (NOTE(review): presumably consumed by `ctx.error` for highlighting —
        # confirm against the `Context` implementation)
        ctx.converting = node

        assert node.operation != Operation.Input
        operation = "constant" if node.operation == Operation.Constant else node.properties["name"]
        assert operation not in ["convert", "node"]

        # dispatch by operation name to the methods below; operations without
        # a dedicated method are lowered as a table lookup
        converter = getattr(self, operation) if hasattr(self, operation) else self.tlu
        conversion = converter(ctx, node, preds)
        conversion.set_original_bit_width(node.properties["original_bit_width"])

        ctx.conversions[node] = conversion
        return conversion

    # The name of the remaining methods all correspond to node names.
    # And they have the same signature so that they can be called in a generic way.

    # pylint: disable=missing-function-docstring,unused-argument

    def add(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        return ctx.add(ctx.typeof(node), preds[0], preds[1])

    def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) > 0
        return ctx.array(ctx.typeof(node), elements=preds)

    def assign_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        return ctx.assign_static(
            ctx.typeof(node),
            preds[0],
            preds[1],
            index=node.properties["kwargs"]["index"],
        )

    # bitwise/comparison/shift operations: a dedicated lowering exists only
    # when every operand is encrypted; mixed operands fall back to a TLU

    def bitwise_and(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.bitwise_and(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def bitwise_or(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.bitwise_or(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def bitwise_xor(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.bitwise_xor(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def broadcast_to(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.broadcast_to(preds[0], shape=node.output.shape)

    def concatenate(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        return ctx.concatenate(
            ctx.typeof(node),
            preds,
            axis=node.properties["kwargs"].get("axis", 0),
        )

    def constant(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 0
        # calling the node evaluates it to its constant data
        return ctx.constant(ctx.typeof(node), data=node())

    def conv1d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        ctx.error({node: "1-dimensional convolutions are not supported at the moment"})
        assert False, "unreachable"  # pragma: no cover

    def conv2d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        # bias is the optional third predecessor
        assert len(preds) in [2, 3]
        return ctx.conv2d(
            ctx.typeof(node),
            preds[0],
            preds[1],
            preds[2] if len(preds) == 3 else None,
            strides=node.properties["kwargs"]["strides"],
            dilations=node.properties["kwargs"]["dilations"],
            pads=node.properties["kwargs"]["pads"],
            group=node.properties["kwargs"]["group"],
        )

    def conv3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        ctx.error({node: "3-dimensional convolutions are not supported at the moment"})
        assert False, "unreachable"  # pragma: no cover

    def dot(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        return ctx.dot(ctx.typeof(node), preds[0], preds[1])

    def equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.equality(ctx.typeof(node), preds[0], preds[1], equals=True)
        return self.tlu(ctx, node, preds)

    def expand_dims(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.reshape(preds[0], shape=node.output.shape)

    def greater(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.greater(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def greater_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.greater_equal(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def index_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.index_static(
            ctx.typeof(node),
            preds[0],
            index=node.properties["kwargs"]["index"],
        )

    def left_shift(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.shift(
                ctx.typeof(node),
                preds[0],
                preds[1],
                orientation="left",
                original_resulting_bit_width=node.properties["original_bit_width"],
            )
        return self.tlu(ctx, node, preds)

    def less(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.less(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def less_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.less_equal(ctx.typeof(node), preds[0], preds[1])
        return self.tlu(ctx, node, preds)

    def matmul(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        return ctx.matmul(ctx.typeof(node), preds[0], preds[1])

    def maxpool1d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        ctx.error({node: "1-dimensional maxpooling is not supported at the moment"})
        assert False, "unreachable"  # pragma: no cover

    def maxpool2d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.maxpool2d(
            ctx.typeof(node),
            preds[0],
            kernel_shape=node.properties["kwargs"]["kernel_shape"],
            strides=node.properties["kwargs"]["strides"],
            dilations=node.properties["kwargs"]["dilations"],
        )

    def maxpool3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
        assert False, "unreachable"  # pragma: no cover

    def multiply(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        return ctx.mul(ctx.typeof(node), preds[0], preds[1])

    def negative(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.neg(ctx.typeof(node), preds[0])

    def not_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.equality(ctx.typeof(node), preds[0], preds[1], equals=False)
        return self.tlu(ctx, node, preds)

    def ones(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 0
        return ctx.ones(ctx.typeof(node))

    def reshape(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.reshape(preds[0], shape=node.output.shape)

    def right_shift(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        if all(pred.is_encrypted for pred in preds):
            return ctx.shift(
                ctx.typeof(node),
                preds[0],
                preds[1],
                orientation="right",
                original_resulting_bit_width=node.properties["original_bit_width"],
            )
        return self.tlu(ctx, node, preds)

    def subtract(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 2
        return ctx.sub(ctx.typeof(node), preds[0], preds[1])

    def sum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.sum(
            ctx.typeof(node),
            preds[0],
            axes=node.properties["kwargs"].get("axis", []),
            keep_dims=node.properties["kwargs"].get("keepdims", False),
        )

    def squeeze(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        # because of the tracing logic, we have the correct output shape
        # if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
        # and the result is a scalar, so we need to do indexing, not reshape
        if node.output.shape == ():
            assert all(size == 1 for size in preds[0].shape)
            index = (0,) * len(preds[0].shape)
            return ctx.index_static(ctx.typeof(node), preds[0], index)

        # otherwise, a simple reshape would work as we already have the correct shape
        return ctx.reshape(preds[0], shape=node.output.shape)

    def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert node.converted_to_table_lookup

        # find the single non-constant predecessor; the asserts below enforce
        # that a table lookup has exactly one variable input
        variable_input_index = -1

        pred_nodes = ctx.graph.ordered_preds_of(node)
        for i, pred_node in enumerate(pred_nodes):
            if pred_node.operation != Operation.Constant:
                if variable_input_index == -1:
                    variable_input_index = i
                else:
                    assert False, "unreachable"  # pragma: no cover

        assert variable_input_index != -1
        variable_input = preds[variable_input_index]

        # reject lookups whose input exceeds the supported bit-width, with a
        # hint when the width was widened during bit-width assignment
        if variable_input.bit_width > MAXIMUM_TLU_BIT_WIDTH:
            variable_input_messages = [
                f"this {variable_input.bit_width}-bit value "
                f"is used as an input to a table lookup"
            ]
            if variable_input.bit_width != variable_input.original_bit_width:
                variable_input_messages.append(
                    "("
                    f"note that it's assigned {variable_input.bit_width}-bits "
                    f"during compilation because of its relation with other operations"
                    ")"
                )

            highlights = {
                variable_input.origin: variable_input_messages,
                node: f"but only up to {MAXIMUM_TLU_BIT_WIDTH}-bit table lookups are supported",
            }
            ctx.error(highlights)  # type: ignore

        tables = construct_deduplicated_tables(node, pred_nodes)
        assert len(tables) > 0

        lut_shape: Tuple[int, ...] = ()
        map_shape: Tuple[int, ...] = ()

        if len(tables) == 1:
            # a single table applies to every element of the input
            table = tables[0][0]

            # The reduction on 63b is to avoid problems like doing a TLU of
            # the form T[j] = 2<<j, for j which is supposed to be 7b as per
            # constraint of the compiler, while in practice, it is a small
            # value. Reducing on 64b was not ok for some reason
            lut_shape = (len(table),)
            lut_values = np.array(table % (2 << 63), dtype=np.uint64)

            map_shape = ()
            map_values = None
        else:
            # multiple tables: build a stacked LUT plus a per-element mapping
            # from output positions to the table applied at that position
            individual_table_size = len(tables[0][0])

            lut_shape = (len(tables), individual_table_size)
            map_shape = node.output.shape

            lut_values = np.zeros(lut_shape, dtype=np.uint64)
            map_values = np.zeros(map_shape, dtype=np.intp)

            for i, (table, indices) in enumerate(tables):
                assert len(table) == individual_table_size
                lut_values[i, :] = table
                for index in indices:
                    map_values[index] = i

        if len(tables) == 1:
            return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist())

        assert map_values is not None
        return ctx.multi_tlu(
            ctx.typeof(node),
            on=variable_input,
            tables=lut_values.tolist(),
            mapping=map_values.tolist(),
        )

    def transpose(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 1
        return ctx.transpose(
            ctx.typeof(node),
            preds[0],
            axes=node.properties["kwargs"].get("axes", []),
        )

    def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
        assert len(preds) == 0
        return ctx.zeros(ctx.typeof(node))

    # pylint: enable=missing-function-docstring,unused-argument

View File

@@ -1,739 +0,0 @@
"""
Declaration of `GraphConverter` class.
"""
# pylint: disable=import-error,no-member,no-name-in-module
from copy import deepcopy
from typing import Any, Dict, List, Optional, cast
# mypy: disable-error-code=attr-defined
import concrete.lang as concretelang
import networkx as nx
import numpy as np
from concrete.lang.dialects import fhe, fhelinalg
from mlir.dialects import arith, func
from mlir.ir import (
Attribute,
Context,
InsertionPoint,
IntegerAttr,
IntegerType,
Location,
Module,
OpResult,
RankedTensorType,
)
from ..dtypes import Integer, SignedInteger
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar, EncryptedScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_TLU_BIT_WIDTH
# pylint: enable=import-error,no-member,no-name-in-module
class GraphConverter:
    """
    GraphConverter class, to convert computation graphs to their MLIR equivalent.
    """

    @staticmethod
    def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
        """
        Check node convertibility to MLIR.

        Args:
            graph (Graph):
                computation graph of the node

            node (Node):
                node to be checked

        Returns:
            Optional[str]:
                None if node is convertible to MLIR, the reason for inconvertibility otherwise
        """

        # pylint: disable=too-many-branches,too-many-return-statements,too-many-statements

        inputs = node.inputs
        output = node.output

        if node.operation == Operation.Constant:
            assert_that(len(inputs) == 0)
            if not isinstance(output.dtype, Integer):
                return "only integer constants are supported"

        elif node.operation == Operation.Input:
            # input nodes mirror their single input value as their output
            assert_that(len(inputs) == 1)
            assert_that(inputs[0] == output)
            if not isinstance(output.dtype, Integer):
                return "only integer inputs are supported"
            if output.dtype.is_signed and output.is_clear:
                return "only encrypted signed integer inputs are supported"

        else:
            assert_that(node.operation == Operation.Generic)

            if not isinstance(output.dtype, Integer):
                return "only integer operations are supported"

            # dispatch on the generic operation name; unknown names fall through
            # to the table-lookup case at the end
            name = node.properties["name"]

            if name == "add":
                assert_that(len(inputs) == 2)

            elif name == "array":
                assert_that(len(inputs) > 0)
                assert_that(all(input.is_scalar for input in inputs))

            elif name == "assign.static":
                if not inputs[0].is_encrypted:
                    return "only assignment to encrypted tensors are supported"

            elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]:
                assert_that(len(inputs) == 2)
                # the restrictions below only apply to the all-encrypted case
                if all(value.is_encrypted for value in node.inputs):
                    pred_nodes = graph.ordered_preds_of(node)
                    if (
                        name in ["left_shift", "right_shift"]
                        and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4
                    ):
                        return "only up to 4-bit shifts are supported"
                    for pred_node in pred_nodes:
                        assert isinstance(pred_node.output.dtype, Integer)
                        if pred_node.output.dtype.is_signed:
                            return "only unsigned bitwise operations are supported"

            elif name == "broadcast_to":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted broadcasting is supported"

            elif name == "concatenate":
                if not all(input.is_encrypted for input in inputs):
                    return "only all encrypted concatenate is supported"

            elif name in ["conv1d", "conv2d", "conv3d"]:
                # third input, when present, is the bias
                assert_that(len(inputs) == 2 or len(inputs) == 3)
                if not (inputs[0].is_encrypted and inputs[1].is_clear):
                    return f"only {name} with encrypted input and clear weight is supported"

            elif name == "dot":
                assert_that(len(inputs) == 2)
                if inputs[0].is_encrypted and inputs[1].is_encrypted:
                    return "only dot product between encrypted and clear is supported"

            elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]:
                assert_that(len(inputs) == 2)

            elif name == "expand_dims":
                assert_that(len(inputs) == 1)

            elif name == "index.static":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted indexing supported"

            elif name == "matmul":
                assert_that(len(inputs) == 2)
                if inputs[0].is_encrypted and inputs[1].is_encrypted:
                    return "only matrix multiplication between encrypted and clear is supported"

            elif name == "maxpool":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted maxpool is supported"

            elif name == "multiply":
                assert_that(len(inputs) == 2)
                if inputs[0].is_encrypted and inputs[1].is_encrypted:
                    return "only multiplication between encrypted and clear is supported"

            elif name == "negative":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted negation is supported"

            elif name == "ones":
                assert_that(len(inputs) == 0)

            elif name == "reshape":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted reshape is supported"

            elif name == "squeeze":
                assert_that(len(inputs) == 1)

            elif name == "subtract":
                assert_that(len(inputs) == 2)

            elif name == "sum":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted sum is supported"

            elif name == "transpose":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted transpose is supported"

            elif name == "zeros":
                assert_that(len(inputs) == 0)

            else:
                # any other operation must have been fused into a table lookup,
                # which requires exactly one non-constant predecessor
                assert_that(node.converted_to_table_lookup)
                variable_input_indices = [
                    idx
                    for idx, pred in enumerate(graph.ordered_preds_of(node))
                    if pred.operation != Operation.Constant
                ]
                assert_that(len(variable_input_indices) == 1)

            if len(inputs) > 0 and all(input.is_clear for input in inputs):
                return "one of the operands must be encrypted"

        return None

        # pylint: enable=too-many-branches,too-many-return-statements,too-many-statements

    @staticmethod
    def _check_graph_convertibility(graph: Graph):
        """
        Check graph convertibility to MLIR.

        Args:
            graph (Graph):
                computation graph to be checked

        Raises:
            RuntimeError:
                if `graph` is not convertible to MLIR
        """

        offending_nodes = {}

        if len(graph.output_nodes) > 1:
            offending_nodes.update(
                {
                    node: ["only a single output is supported", node.location]
                    for node in graph.output_nodes.values()
                }
            )

        # only check individual nodes when the output-count check passed
        if len(offending_nodes) == 0:
            for node in graph.graph.nodes:
                reason = GraphConverter._check_node_convertibility(graph, node)
                if reason is not None:
                    offending_nodes[node] = [reason, node.location]

        if len(offending_nodes) != 0:
            message = (
                "Function you are trying to compile cannot be converted to MLIR\n\n"
                + graph.format(highlighted_nodes=offending_nodes)
            )
            raise RuntimeError(message)

    @staticmethod
    def _update_bit_widths(graph: Graph):
        """
        Update bit-widths in a computation graph to be convertible to MLIR.

        Single-precision scheme: every encrypted value in the circuit is widened
        to the circuit-wide maximum bit-width (clear values get one extra bit).

        Args:
            graph (Graph):
                computation graph to be updated

        Raises:
            RuntimeError:
                if the circuit contains a table lookup but requires more than
                MAXIMUM_TLU_BIT_WIDTH bits
        """

        offending_nodes: Dict[Node, List[str]] = {}

        max_bit_width = 0
        max_bit_width_node = None

        first_tlu_node = None
        # NOTE(review): first_signed_node is tracked below but never read afterwards
        first_signed_node = None

        for node in nx.lexicographical_topological_sort(graph.graph):
            dtype = node.output.dtype
            assert_that(isinstance(dtype, Integer))

            # clear values get one extra bit in the final assignment below,
            # so they only "require" bit_width - 1 here
            current_node_bit_width = (
                dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
            )

            if (
                all(value.is_encrypted for value in node.inputs)
                and node.operation == Operation.Generic
                and node.properties["name"]
                in [
                    "greater",
                    "greater_equal",
                    "less",
                    "less_equal",
                ]
            ):
                # implementation of these operators require at least 4 bits
                current_node_bit_width = max(current_node_bit_width, 4)

            if max_bit_width < current_node_bit_width:
                max_bit_width = current_node_bit_width
                max_bit_width_node = node

            if node.converted_to_table_lookup and first_tlu_node is None:
                first_tlu_node = node

            if dtype.is_signed and first_signed_node is None:
                first_signed_node = node

        # the bit-width limit only applies when the circuit contains a TLU
        if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
            assert max_bit_width_node is not None
            offending_nodes[max_bit_width_node] = [
                (
                    {
                        Operation.Input: f"this input is {max_bit_width}-bits",
                        Operation.Constant: f"this constant is {max_bit_width}-bits",
                        Operation.Generic: f"this operation results in {max_bit_width}-bits",
                    }[max_bit_width_node.operation]
                ),
                max_bit_width_node.location,
            ]
            offending_nodes[first_tlu_node] = [
                f"table lookups are only supported on circuits with "
                f"up to {MAXIMUM_TLU_BIT_WIDTH}-bits",
                first_tlu_node.location,
            ]

        if len(offending_nodes) != 0:
            raise RuntimeError(
                "Function you are trying to compile cannot be converted to MLIR:\n\n"
                + graph.format(highlighted_nodes=offending_nodes)
            )

        for node in nx.topological_sort(graph.graph):
            assert isinstance(node.output.dtype, Integer)
            # remember the pre-widening width for later passes and formatting
            node.properties["original_bit_width"] = node.output.dtype.bit_width
            for value in node.inputs + [node.output]:
                dtype = value.dtype
                assert_that(isinstance(dtype, Integer))
                dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width

    @staticmethod
    def _offset_negative_lookup_table_inputs(graph: Graph):
        """
        Offset negative table lookup inputs to be convertible to MLIR.

        Args:
            graph (Graph):
                computation graph to apply offset
        """

        # ugly hack to add an offset before entering a TLU
        # if its variable input node has a signed output.
        # this makes hardcoded assumptions about the way bit widths are handled in MLIR.
        # this does not update the TLU input values to allow for proper table generation.

        nx_graph = graph.graph
        for node in list(nx_graph.nodes):
            if node.operation == Operation.Generic:
                if not node.converted_to_table_lookup:
                    continue

                # locate the single non-constant predecessor of the TLU
                variable_input_index = -1
                preds = graph.ordered_preds_of(node)
                for index, pred in enumerate(preds):
                    if pred.operation != Operation.Constant:
                        variable_input_index = index
                        break

                variable_input_node = preds[variable_input_index]
                variable_input_value = variable_input_node.output
                variable_input_dtype = variable_input_value.dtype

                assert_that(isinstance(variable_input_dtype, Integer))
                variable_input_dtype = cast(Integer, variable_input_dtype)

                # only signed inputs need the offset
                if not variable_input_dtype.is_signed:
                    continue

                variable_input_bit_width = variable_input_dtype.bit_width

                # adding |min| shifts the signed input range onto a non-negative range
                offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
                offset_constant_value = abs(variable_input_dtype.min())

                offset_constant = Node.constant(offset_constant_value)
                offset_constant.output.dtype = offset_constant_dtype

                original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width
                offset_constant.properties["original_bit_width"] = original_bit_width

                add_offset = Node.generic(
                    "add",
                    [variable_input_value, ClearScalar(offset_constant_dtype)],
                    variable_input_value,
                    np.add,
                )
                original_bit_width = variable_input_node.properties["original_bit_width"]
                add_offset.properties["original_bit_width"] = original_bit_width

                # reroute: variable input -> (+ offset constant) -> TLU node
                nx_graph.remove_edge(variable_input_node, node)
                nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
                nx_graph.add_edge(offset_constant, add_offset, input_idx=1)
                nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)

    @staticmethod
    def _broadcast_assignments(graph: Graph):
        """
        Broadcast assignments.

        Inserts a broadcast_to (or reshape as a fallback) node in front of
        `assign.static` nodes whose assigned value does not already have the
        shape required by the assignment index.

        Args:
            graph (Graph):
                computation graph to transform
        """

        nx_graph = graph.graph
        for node in list(nx_graph.nodes):
            if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
                shape = node.inputs[0].shape
                index = node.properties["kwargs"]["index"]

                assert_that(isinstance(index, tuple))
                # pad the index with full slices so it covers every dimension
                while len(index) < len(shape):
                    index = (*index, slice(None, None, None))

                # compute the shape the assigned value must have for this index
                required_value_shape_list = []
                for i, indexing_element in enumerate(index):
                    if isinstance(indexing_element, slice):
                        n = len(np.zeros(shape[i])[indexing_element])
                        required_value_shape_list.append(n)
                    else:
                        required_value_shape_list.append(1)

                required_value_shape = tuple(required_value_shape_list)
                actual_value_shape = node.inputs[1].shape

                if required_value_shape != actual_value_shape:
                    preds = graph.ordered_preds_of(node)
                    pred_to_modify = preds[1]

                    modified_value = deepcopy(pred_to_modify.output)
                    modified_value.shape = required_value_shape

                    # prefer broadcasting; fall back to reshape when the shapes
                    # are not broadcast-compatible
                    try:
                        np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
                        modified_value.is_encrypted = True
                        modified_value.dtype = node.output.dtype
                        modified_pred = Node.generic(
                            "broadcast_to",
                            [pred_to_modify.output],
                            modified_value,
                            np.broadcast_to,
                            kwargs={"shape": required_value_shape},
                        )
                    except Exception:  # pylint: disable=broad-except
                        # raises if the value cannot be reshaped either
                        np.reshape(np.zeros(actual_value_shape), required_value_shape)
                        modified_pred = Node.generic(
                            "reshape",
                            [pred_to_modify.output],
                            modified_value,
                            np.reshape,
                            kwargs={"newshape": required_value_shape},
                        )

                    modified_pred.properties["original_bit_width"] = pred_to_modify.properties[
                        "original_bit_width"
                    ]

                    nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)

                    nx_graph.remove_edge(pred_to_modify, node)
                    nx_graph.add_edge(modified_pred, node, input_idx=1)

                    node.inputs[1] = modified_value

    @staticmethod
    def _encrypt_clear_assignments(graph: Graph):
        """
        Encrypt clear assignments.

        Clear values assigned into encrypted tensors are converted to encrypted
        values by adding an encrypted zero to them.

        Args:
            graph (Graph):
                computation graph to transform
        """

        nx_graph = graph.graph
        for node in list(nx_graph.nodes):
            if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
                assigned_value = node.inputs[1]
                if assigned_value.is_clear:
                    preds = graph.ordered_preds_of(node)
                    assigned_pred = preds[1]

                    new_assigned_pred_value = deepcopy(assigned_value)
                    new_assigned_pred_value.is_encrypted = True
                    new_assigned_pred_value.dtype = preds[0].output.dtype

                    # encrypted zero used to lift the clear value into the encrypted domain
                    zero = Node.generic(
                        "zeros",
                        [],
                        EncryptedScalar(new_assigned_pred_value.dtype),
                        lambda: np.zeros((), dtype=np.int64),
                    )
                    original_bit_width = 1
                    zero.properties["original_bit_width"] = original_bit_width

                    new_assigned_pred = Node.generic(
                        "add",
                        [assigned_pred.output, zero.output],
                        new_assigned_pred_value,
                        np.add,
                    )
                    original_bit_width = assigned_pred.properties["original_bit_width"]
                    new_assigned_pred.properties["original_bit_width"] = original_bit_width

                    # reroute: clear value -> (+ encrypted zero) -> assignment
                    nx_graph.remove_edge(preds[1], node)
                    nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
                    nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
                    nx_graph.add_edge(new_assigned_pred, node, input_idx=1)

    @staticmethod
    def _tensorize_scalars_for_fhelinalg(graph: Graph):
        """
        Tensorize scalars if they are used within fhelinalg operations.

        Args:
            graph (Graph):
                computation graph to update
        """

        # pylint: disable=invalid-name
        OPS_TO_TENSORIZE = [
            "add",
            "bitwise_and",
            "bitwise_or",
            "bitwise_xor",
            "broadcast_to",
            "dot",
            "equal",
            "greater",
            "greater_equal",
            "left_shift",
            "less",
            "less_equal",
            "multiply",
            "not_equal",
            "right_shift",
            "subtract",
        ]
        # pylint: enable=invalid-name

        # cache so that a scalar feeding several fhelinalg ops is tensorized only once
        tensorized_scalars: Dict[Node, Node] = {}

        nx_graph = graph.graph
        for node in list(nx_graph.nodes):
            if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
                assert len(node.inputs) in {1, 2}

                if len(node.inputs) == 2:
                    # only mixed scalar/tensor operand pairs need tensorization
                    if {inp.is_scalar for inp in node.inputs} != {True, False}:
                        continue
                else:
                    if not node.inputs[0].is_scalar:  # noqa: PLR5501
                        continue

                # for bitwise and comparison operators that can have constants
                # we don't need broadcasting here
                if node.converted_to_table_lookup:
                    continue

                pred_to_tensorize: Optional[Node] = None
                pred_to_tensorize_index = 0

                preds = graph.ordered_preds_of(node)
                for index, pred in enumerate(preds):
                    if pred.output.is_scalar:
                        pred_to_tensorize = pred
                        pred_to_tensorize_index = index
                        break

                assert pred_to_tensorize is not None
                tensorized_pred = tensorized_scalars.get(pred_to_tensorize)

                if tensorized_pred is None:
                    # wrap the scalar in a 1-element "array" node
                    tensorized_value = deepcopy(pred_to_tensorize.output)
                    tensorized_value.shape = (1,)

                    tensorized_pred = Node.generic(
                        "array",
                        [pred_to_tensorize.output],
                        tensorized_value,
                        lambda *args: np.array(args),
                    )
                    original_bit_width = pred_to_tensorize.properties["original_bit_width"]
                    tensorized_pred.properties["original_bit_width"] = original_bit_width

                    original_shape = ()
                    tensorized_pred.properties["original_shape"] = original_shape

                    nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
                    tensorized_scalars[pred_to_tensorize] = tensorized_pred

                assert tensorized_pred is not None

                nx_graph.remove_edge(pred_to_tensorize, node)
                nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)

                # the node's recorded input value must match the new (1,) shape
                new_input_value = deepcopy(node.inputs[pred_to_tensorize_index])
                new_input_value.shape = (1,)
                node.inputs[pred_to_tensorize_index] = new_input_value

    @staticmethod
    def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
        """
        Use subtraction to sanitize signed inputs.

        Args:
            graph (Graph):
                computation graph being converted

            args (List[Any]):
                list of arguments from mlir main

            ctx (Context):
                mlir context where the conversion is being performed

        Returns:
            List[Any]:
                sanitized args (signed encrypted inputs replaced by subtraction
                results, other args passed through unchanged)
        """

        sanitized_args = []
        for i, arg in enumerate(args):
            input_node = graph.input_nodes[i]
            input_value = input_node.output

            assert_that(isinstance(input_value.dtype, Integer))
            input_dtype = cast(Integer, input_value.dtype)

            if input_dtype.is_signed:
                assert_that(input_value.is_encrypted)

                n = input_dtype.bit_width
                sanitizer_type = IntegerType.get_signless(n + 1)
                # NOTE(review): subtracting 2^(n-1) presumably undoes an offset
                # encoding of signed inputs — confirm against the runtime encoding
                sanitizer = 2 ** (n - 1)

                if input_value.is_scalar:
                    sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer)
                else:
                    # tensors use a 1-element dense constant instead of a scalar attr
                    sanitizer_type = RankedTensorType.get((1,), sanitizer_type)
                    sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}")

                # pylint: disable=too-many-function-args
                sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr)
                # pylint: enable=too-many-function-args

                resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
                if input_value.is_scalar:
                    sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
                else:
                    sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result

                sanitized_args.append(sanitized)
            else:
                sanitized_args.append(arg)

        return sanitized_args

    @staticmethod
    def convert(graph: Graph) -> str:
        """
        Convert a computation graph to its corresponding MLIR representation.

        Args:
            graph (Graph):
                computation graph to be converted

        Returns:
            str:
                textual MLIR representation corresponding to `graph`
        """

        # work on a copy so the caller's graph is not mutated by the passes below
        graph = deepcopy(graph)

        GraphConverter._check_graph_convertibility(graph)
        GraphConverter._update_bit_widths(graph)
        GraphConverter._offset_negative_lookup_table_inputs(graph)
        GraphConverter._broadcast_assignments(graph)
        GraphConverter._encrypt_clear_assignments(graph)
        GraphConverter._tensorize_scalars_for_fhelinalg(graph)

        # placeholder ops to be textually replaced by `tensor.from_elements` later
        from_elements_operations: Dict[OpResult, List[OpResult]] = {}

        with Context() as ctx, Location.unknown():
            concretelang.register_dialects(ctx)

            module = Module.create()
            with InsertionPoint(module.body):
                parameters = [
                    NodeConverter.value_to_mlir_type(ctx, input_node.output)
                    for input_node in graph.ordered_inputs()
                ]

                @func.FuncOp.from_py_func(*parameters)
                def main(*args):
                    sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)

                    # mapping from graph nodes to their emitted MLIR values
                    ir_to_mlir = {}
                    for arg_num, node in graph.input_nodes.items():
                        ir_to_mlir[node] = sanitized_args[arg_num]

                    constant_cache = {}
                    for node in nx.topological_sort(graph.graph):
                        if node.operation == Operation.Input:
                            continue

                        preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)]
                        node_converter = NodeConverter(
                            ctx,
                            graph,
                            node,
                            preds,
                            constant_cache,
                            from_elements_operations,
                        )
                        ir_to_mlir[node] = node_converter.convert()

                    results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
                    return results

        # build the textual replacements for the from_elements placeholders
        direct_replacements = {}
        for placeholder, elements in from_elements_operations.items():
            element_names = [NodeConverter.mlir_name(element) for element in elements]
            actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
            direct_replacements[NodeConverter.mlir_name(placeholder)] = actual_value

        # textual post-processing: splice the replacements into the printed module
        module_lines_after_hacks_are_applied = []
        for line in str(module).split("\n"):
            mlir_name = line.split("=")[0].strip()

            if mlir_name not in direct_replacements:
                module_lines_after_hacks_are_applied.append(line)
                continue

            new_value = direct_replacements[mlir_name]
            module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")

        return "\n".join(module_lines_after_hacks_are_applied).strip()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
"""
Declaration of `GraphProcessor` class.
"""
from abc import ABC, abstractmethod
from typing import List, Mapping, Union
from ...representation import Graph, Node
class GraphProcessor(ABC):
    """
    GraphProcessor base class, to define the API for a graph processing pipeline.
    """

    @abstractmethod
    def apply(self, graph: Graph):
        """
        Process the graph.
        """

    @staticmethod
    def error(graph: Graph, highlights: Mapping[Node, Union[str, List[str]]]):
        """
        Fail processing with an error.

        Args:
            graph (Graph):
                graph being processed

            highlights (Mapping[Node, Union[str, List[str]]]):
                nodes to highlight along with messages

        Raises:
            RuntimeError:
                always, with the formatted graph highlighting the offending nodes
        """
        highlights_with_location = {}
        for node, messages in highlights.items():
            # copy list messages so appending the location below does not
            # mutate the caller's `highlights` values
            messages_with_location = list(messages) if isinstance(messages, list) else [messages]
            messages_with_location.append(node.location)
            highlights_with_location[node] = messages_with_location

        message = "Function you are trying to compile cannot be compiled\n\n" + graph.format(
            highlighted_nodes=highlights_with_location
        )
        raise RuntimeError(message)

View File

@@ -0,0 +1,10 @@
"""
All graph processors.
"""
# pylint: disable=unused-import
from .assign_bit_widths import AssignBitWidths
from .check_integer_only import CheckIntegerOnly
# pylint: enable=unused-import

View File

@@ -0,0 +1,246 @@
"""
Declaration of `AssignBitWidths` graph processor.
"""
from __future__ import annotations
import typing
from collections.abc import Iterable
from ...dtypes import Integer
from ...representation import Graph, Node, Operation
from . import GraphProcessor
class AssignBitWidths(GraphProcessor):
    """
    Assign a precision to all nodes inputs/output.

    The precisions are compatible graph constraints and MLIR.

    Two modes are available:
    - single precision: every encrypted value shares one precision.
    - multi precision: encrypted values may carry different precisions.
    """

    def __init__(self, single_precision=False):
        self.single_precision = single_precision

    def apply(self, graph: Graph):
        all_nodes = graph.query_nodes()

        # snapshot the pre-assignment widths before they are overwritten
        for node in all_nodes:
            dtype = node.output.dtype
            assert isinstance(dtype, Integer)
            node.properties["original_bit_width"] = dtype.bit_width

        if self.single_precision:
            assign_single_precision(all_nodes)
        else:
            assign_multi_precision(graph, all_nodes)
def assign_single_precision(nodes: list[Node]):
    """Give every node the single, circuit-wide required encryption precision."""
    precision = required_encrypted_bitwidth(nodes)
    for current in nodes:
        assign_precisions_1_node(current, precision, precision)
def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
    """Assign input/output precision to a single node.

    Clear values receive one extra bit compared to encrypted ones.
    """

    def adjusted(value, precision):
        # clear values carry an extra bit
        return precision if value.is_encrypted else precision + 1

    assert isinstance(node.output.dtype, Integer)
    node.output.dtype.bit_width = adjusted(node.output, output_p)

    for value in node.inputs:
        assert isinstance(value.dtype, Integer)
        value.dtype.bit_width = adjusted(value, inputs_p)
# Comparisons implemented with a chunked algorithm; they need a minimum precision.
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
# Operations that need one extra bit internally (see max_encrypted_bitwidth_node).
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
MULTIPLY = {"multiply"}
def max_encrypted_bitwidth_node(node: Node):
    """Give the minimal precision to implement the node.

    This applies to both input and output precisions.
    """
    assert isinstance(node.output.dtype, Integer)

    # only encrypted values and constants constrain the precision
    if node.output.is_encrypted or node.operation == Operation.Constant:
        base = node.output.dtype.bit_width
    else:
        base = -1

    operation_name = node.properties.get("name")
    if operation_name in CHUNKED_COMPARISON:
        # chunked comparisons need at least CHUNKED_COMPARISON_MIN_BITWIDTH bits
        return max(base, CHUNKED_COMPARISON_MIN_BITWIDTH)
    if operation_name in MAX_POOLS:
        return base + 1
    if operation_name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
        # encrypted x encrypted multiplication needs one extra bit
        return base + 1
    return base
def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
    """Give the minimal precision to implement all the nodes (-1 when empty)."""
    return max((max_encrypted_bitwidth_node(node) for node in nodes), default=-1)
def required_inputs_encrypted_bitwidth(graph, node, nodes_output_p: list[tuple[Node, int]]) -> int:
    """Give the minimal precision that supports the node's inputs."""

    def output_precision_of(some_node):
        return nodes_output_p[some_node.properties[NODE_ID]][1]

    preds = graph.ordered_preds_of(node)
    # by definition all inputs share the same block precision
    # (see uniform_precision_per_blocks), so the first predecessor is enough
    if not preds:
        return output_precision_of(node)
    return output_precision_of(preds[0])
def assign_multi_precision(graph, nodes):
    """Assign a specific encryption precision to each node."""
    add_nodes_id(nodes)
    nodes_output_p = uniform_precision_per_blocks(graph, nodes)

    # refresh the snapshot before the widths are overwritten below
    for node, _ in nodes_output_p:
        node.properties["original_bit_width"] = node.output.dtype.bit_width

    # input precision follows the predecessors' block unless the node
    # ties its input and output precisions together
    for node, output_p in nodes_output_p:
        if can_change_precision(node):
            inputs_p = required_inputs_encrypted_bitwidth(graph, node, nodes_output_p)
        else:
            inputs_p = output_p
        assign_precisions_1_node(node, output_p, inputs_p)

    clear_nodes_id(nodes)
# TLU-like operations whose input and output precisions are tied together.
TLU_WITHOUT_PRECISION_CHANGE = CHUNKED_COMPARISON | MAX_POOLS | MULTIPLY


def can_change_precision(node):
    """Detect if a node completely ties inputs/output precisions together."""
    if not node.converted_to_table_lookup:
        return False
    return node.properties.get("name") not in TLU_WITHOUT_PRECISION_CHANGE
def convert_union_to_blocks(node_union: UnionFind) -> Iterable[list[int]]:
    """Convert a `UnionFind` to blocks.

    The result is an iterable of blocks, a block being the list of node ids
    that share one canonical node, in increasing id order.
    """
    groups: dict = {}
    for member in range(node_union.size):
        canonical = node_union.find_canonical(member)
        if canonical == member:
            # a canonical id is always the smallest of its block,
            # so it is seen before any other member
            assert canonical not in groups
            groups[canonical] = []
        groups[canonical].append(member)
    return groups.values()
# Name of the temporary property carrying a node's position.
NODE_ID = "node_id"


def add_nodes_id(nodes):
    """Tag each node with a temporary NODE_ID property (its position in `nodes`)."""
    for position, node in enumerate(nodes):
        assert NODE_ID not in node.properties
        node.properties[NODE_ID] = position


def clear_nodes_id(nodes):
    """Drop the temporary NODE_ID property added by `add_nodes_id`."""
    for node in nodes:
        node.properties.pop(NODE_ID)
def uniform_precision_per_blocks(graph: Graph, nodes: list[Node]) -> list[tuple[Node, int]]:
    """Find the required precision of blocks and associate it corresponding nodes."""
    size = len(nodes)
    node_union = UnionFind(size)
    # group nodes whose precisions must agree
    for node_id, node in enumerate(nodes):
        preds = graph.ordered_preds_of(node)
        if not preds:
            continue
        # we always unify all inputs
        first_input_id = preds[0].properties[NODE_ID]
        for pred in preds[1:]:
            pred_id = pred.properties[NODE_ID]
            node_union.union(first_input_id, pred_id)
        # we unify with outputs only if no precision change can occur
        if not can_change_precision(node):
            node_union.union(first_input_id, node_id)
    blocks = convert_union_to_blocks(node_union)
    result: list[None | tuple[Node, int]]
    result = [None] * len(nodes)
    # every node of a block receives the block-wide required precision
    for nodes_id in blocks:
        output_p = required_encrypted_bitwidth(nodes[node_id] for node_id in nodes_id)
        for node_id in nodes_id:
            result[node_id] = (nodes[node_id], output_p)
    # blocks partition the node ids, so every slot is filled
    assert None not in result
    return typing.cast("list[tuple[Node, int]]", result)
class UnionFind:
    """
    Utility class joins the nodes in equivalent precision classes.

    Nodes are just integers id; the canonical node of a class is its smallest id.
    """

    # parent[i] points towards the canonical id of i's class
    parent: list[int]

    def __init__(self, size: int):
        """Create a union find suitable for `size` nodes."""
        self.parent = [node_id for node_id in range(size)]

    @property
    def size(self):
        """Size in number of nodes."""
        return len(self.parent)

    def find_canonical(self, a: int) -> int:
        """Find the current canonical node for a given input node."""
        root = a
        while self.parent[root] != root:
            root = self.parent[root]
        # path compression: point every visited node directly at the root
        while self.parent[a] != root:
            self.parent[a], a = root, self.parent[a]
        return root

    def union(self, a: int, b: int):
        """Union both nodes."""
        self.united_common_ancestor(a, b)

    def united_common_ancestor(self, a: int, b: int) -> int:
        """Deduce the common ancestor of both nodes after unification."""
        parent_a, parent_b = self.parent[a], self.parent[b]
        if parent_a == parent_b:
            return parent_a

        if a == parent_a and parent_b < parent_a:
            # `a` is a root and `b`'s side is smaller: adopt it
            ancestor = parent_b
        elif b == parent_b and parent_a < parent_b:
            # `b` is a root and `a`'s side is smaller: adopt it
            ancestor = parent_a
        else:
            ancestor = self.united_common_ancestor(parent_a, parent_b)

        self.parent[a] = ancestor
        self.parent[b] = ancestor
        return ancestor

View File

@@ -0,0 +1,20 @@
"""
Declaration of `CheckIntegerOnly` graph processor.
"""
from ...dtypes import Integer
from ...representation import Graph
from . import GraphProcessor
class CheckIntegerOnly(GraphProcessor):
    """
    CheckIntegerOnly graph processor, to make sure the graph only contains integer nodes.
    """

    def apply(self, graph: Graph):
        def has_non_integer_output(node):
            return not isinstance(node.output.dtype, Integer)

        offenders = graph.query_nodes(custom_filter=has_non_integer_output)
        if offenders:
            self.error(graph, {node: "only integers are supported" for node in offenders})

View File

@@ -4,7 +4,7 @@ Declaration of various functions and constants related to MLIR conversion.
from collections import defaultdict, deque
from copy import deepcopy
from itertools import product
from itertools import chain, product
from typing import Any, DefaultDict, List, Optional, Tuple, Union, cast
import numpy as np
@@ -53,11 +53,11 @@ def flood_replace_none_values(table: list):
previous_idx = current_idx - 1
next_idx = current_idx + 1
if previous_idx >= 0 and table[previous_idx] is None:
if previous_idx >= 0 and table[previous_idx] is None: # pragma: no cover
table[previous_idx] = deepcopy(current_value)
not_none_values_idx.append(previous_idx)
if next_idx < len(table) and table[next_idx] is None:
if next_idx < len(table) and table[next_idx] is None: # pragma: no cover
table[next_idx] = deepcopy(current_value)
not_none_values_idx.append(next_idx)
@@ -93,12 +93,13 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = cast(Integer, variable_input_dtype)
inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
values = chain(range(0, variable_input_dtype.max() + 1), range(variable_input_dtype.min(), 0))
np.seterr(divide="ignore")
inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1):
for value in values:
try:
inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
table.append(node(*inputs))

View File

@@ -5,7 +5,7 @@ Declaration of `Graph` class.
import math
import re
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import networkx as nx
import numpy as np
@@ -296,10 +296,15 @@ class Graph:
bounds += "]"
output_value = node.output
if isinstance(output_value.dtype, Integer) and "original_bit_width" in node.properties:
output_value = deepcopy(output_value)
output_value.dtype.bit_width = node.properties["original_bit_width"]
# remember metadata of the node
line_metadata.append(
{
"type": f"# {node.output}",
"type": f"# {output_value}",
"bounds": bounds,
"tag": (f"@ {node.tag}" if node.tag != "" else ""),
"location": node.location,
@@ -559,6 +564,9 @@ class Graph:
self,
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
is_encrypted_filter: Optional[bool] = None,
custom_filter: Optional[Callable[[Node], bool]] = None,
ordered: bool = False,
) -> List[Node]:
"""
Query nodes within the graph.
@@ -575,6 +583,15 @@ class Graph:
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for operations
is_encrypted_filter (Optional[bool], default = None)
filter for encryption status
custom_filter (Optional[Callable[[Node], bool]], default = None):
flexible filter
ordered (bool)
whether to apply topological sorting before filtering nodes
Returns:
List[Node]:
filtered nodes
@@ -592,6 +609,12 @@ class Graph:
return any(text == alternative for alternative in text_filter)
def match_boolean_filter(boolean_filter, boolean):
if boolean_filter is None:
return True
return boolean == boolean_filter
def get_operation_name(node):
result: str
@@ -604,12 +627,15 @@ class Graph:
return result
nodes = nx.lexicographical_topological_sort(self.graph) if ordered else self.graph.nodes()
return [
node
for node in self.graph.nodes()
for node in nodes
if (
match_text_filter(tag_filter, node.tag)
and match_text_filter(operation_filter, get_operation_name(node))
and match_boolean_filter(is_encrypted_filter, node.output.is_encrypted)
and (custom_filter is None or custom_filter(node))
)
]
@@ -617,6 +643,8 @@ class Graph:
self,
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
is_encrypted_filter: Optional[bool] = None,
custom_filter: Optional[Callable[[Node], bool]] = None,
) -> int:
"""
Get maximum integer bit-width within the graph.
@@ -630,16 +658,21 @@ class Graph:
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for operations
is_encrypted_filter (Optional[bool], default = None)
filter for encryption status
custom_filter (Optional[Callable[[Node], bool]], default = None):
flexible filter
Returns:
int:
maximum integer bit-width within the graph
if there are no integer nodes matching the query, result is -1
"""
query = self.query_nodes(tag_filter, operation_filter, is_encrypted_filter, custom_filter)
filtered_bit_widths = (
node.output.dtype.bit_width
for node in self.query_nodes(tag_filter, operation_filter)
if isinstance(node.output.dtype, Integer)
node.output.dtype.bit_width for node in query if isinstance(node.output.dtype, Integer)
)
return max(filtered_bit_widths, default=-1)
@@ -647,6 +680,8 @@ class Graph:
self,
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
is_encrypted_filter: Optional[bool] = None,
custom_filter: Optional[Callable[[Node], bool]] = None,
) -> Optional[Tuple[int, int]]:
"""
Get integer range of the graph.
@@ -660,30 +695,39 @@ class Graph:
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for operations
is_encrypted_filter (Optional[bool], default = None)
filter for encryption status
custom_filter (Optional[Callable[[Node], bool]], default = None):
flexible filter
Returns:
Optional[Tuple[int, int]]:
minimum and maximum integer value observed during inputset evaluation
if there are no integer nodes matching the query, result is None
"""
if self.is_direct:
return None
result: Optional[Tuple[int, int]] = None
if not self.is_direct:
filtered_bounds = (
node.bounds
for node in self.query_nodes(tag_filter, operation_filter)
if isinstance(node.output.dtype, Integer) and node.bounds is not None
)
for min_bound, max_bound in filtered_bounds:
assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)
query = self.query_nodes(tag_filter, operation_filter, is_encrypted_filter, custom_filter)
filtered_bounds = (
node.bounds
for node in query
if isinstance(node.output.dtype, Integer) and node.bounds is not None
)
for min_bound, max_bound in filtered_bounds:
assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)
if result is None:
result = (int(min_bound), int(max_bound))
else:
old_min_bound, old_max_bound = result # pylint: disable=unpacking-non-sequence
result = (
min(old_min_bound, int(min_bound)),
max(old_max_bound, int(max_bound)),
)
if result is None:
result = (int(min_bound), int(max_bound))
else:
old_min_bound, old_max_bound = result
result = (
min(old_min_bound, int(min_bound)),
max(old_max_bound, int(max_bound)),
)
return result

View File

@@ -163,10 +163,6 @@ class Node:
fhe_directory = os.path.dirname(fhe.__file__)
import concrete.onnx as coonx
coonx_directory = os.path.dirname(coonx.__file__)
# pylint: enable=cyclic-import,import-outside-toplevel
for frame in reversed(traceback.extract_stack()):
@@ -176,9 +172,6 @@ class Node:
if frame.filename.startswith(fhe_directory):
continue
if frame.filename.startswith(coonx_directory):
continue
self.location = f"{frame.filename}:{frame.lineno}"
break
@@ -294,12 +287,12 @@ class Node:
name = self.properties["name"]
if name == "index.static":
if name == "index_static":
index = self.properties["kwargs"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"{predecessors[0]}[{', '.join(elements)}]"
if name == "assign.static":
if name == "assign_static":
index = self.properties["kwargs"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"
@@ -345,10 +338,10 @@ class Node:
name = self.properties["name"]
if name == "index.static":
if name == "index_static":
name = self.format([""])
if name == "assign.static":
if name == "assign_static":
name = self.format(["", ""])[1:-1]
return name
@@ -386,7 +379,7 @@ class Node:
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"array",
"assign.static",
"assign_static",
"broadcast_to",
"concatenate",
"conv1d",
@@ -394,7 +387,7 @@ class Node:
"conv3d",
"dot",
"expand_dims",
"index.static",
"index_static",
"matmul",
"maxpool",
"multiply",

View File

@@ -426,7 +426,7 @@ class Tracer:
computation = Node.generic(
operation.__name__,
[tracer.output for tracer in tracers],
[deepcopy(tracer.output) for tracer in tracers],
output_value,
operation,
kwargs=kwargs,
@@ -618,7 +618,7 @@ class Tracer:
computation = Node.generic(
"astype",
[self.output],
[deepcopy(self.output)],
output_value,
lambda x: x, # unused for direct definition
)
@@ -662,7 +662,7 @@ class Tracer:
computation = Node.generic(
"astype",
[self.output],
[deepcopy(self.output)],
output_value,
evaluator,
kwargs={"dtype": dtype},
@@ -753,8 +753,8 @@ class Tracer:
output_value.shape = np.zeros(output_value.shape)[index].shape
computation = Node.generic(
"index.static",
[self.output],
"index_static",
[deepcopy(self.output)],
output_value,
lambda x, index: x[index],
kwargs={"index": index},
@@ -803,8 +803,8 @@ class Tracer:
sanitized_value = self.sanitize(value)
computation = Node.generic(
"assign.static",
[self.output, sanitized_value.output],
"assign_static",
[deepcopy(self.output), deepcopy(sanitized_value.output)],
self.output,
assign,
kwargs={"index": index},

View File

@@ -1,6 +0,0 @@
"""
Implement machine learning operations as specified by ONNX.
"""
from .convolution import conv
from .maxpool import maxpool

View File

@@ -78,7 +78,6 @@ setuptools.setup(
package_dir={
"concrete.fhe": "./concrete/fhe",
"concrete.onnx": "./concrete/onnx",
"": bindings_directory(),
},
packages=setuptools.find_namespace_packages(
@@ -87,9 +86,6 @@ setuptools.setup(
) + setuptools.find_namespace_packages(
where=".",
include=["concrete.fhe", "concrete.fhe.*"],
) + setuptools.find_namespace_packages(
where=".",
include=["concrete.onnx", "concrete.onnx.*"],
) + setuptools.find_namespace_packages(
where=bindings_directory(),
include=["concrete.compiler", "concrete.compiler.*"],

View File

@@ -18,6 +18,7 @@ tests_directory = os.path.dirname(tests.__file__)
INSECURE_KEY_CACHE_LOCATION = None
USE_MULTI_PRECISION = False
def pytest_addoption(parser):
@@ -39,6 +40,13 @@ def pytest_addoption(parser):
action="store",
help="Specify the location of the key cache",
)
parser.addoption(
"--precision",
type=str,
default=None,
action="store",
help="Which precision strategy to use in execution tests (single or multi)",
)
def pytest_sessionstart(session):
@@ -47,6 +55,7 @@ def pytest_sessionstart(session):
"""
# pylint: disable=global-statement
global INSECURE_KEY_CACHE_LOCATION
global USE_MULTI_PRECISION
# pylint: enable=global-statement
key_cache_location = session.config.getoption("--key-cache", default=None)
@@ -64,6 +73,9 @@ def pytest_sessionstart(session):
INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
precision = session.config.getoption("--precision", default="single")
USE_MULTI_PRECISION = precision == "multi"
def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument
"""
@@ -117,6 +129,7 @@ class Helpers:
jit=True,
insecure_key_cache_location=INSECURE_KEY_CACHE_LOCATION,
global_p_error=(1 / 10_000),
single_precision=(not USE_MULTI_PRECISION),
)
@staticmethod

View File

@@ -21,6 +21,22 @@ from concrete import fhe
lambda x, y: fhe.array([x, y]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "encrypted", "shape": ()},
},
id="fhe.array([x, y])",
),
pytest.param(
lambda x, y: fhe.array([x, y]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "clear", "shape": ()},
},
id="fhe.array([x, y])",
),
pytest.param(
lambda x, y: fhe.array([x, y]),
{
"x": {"range": [0, 10], "status": "clear", "shape": ()},
"y": {"range": [0, 10], "status": "clear", "shape": ()},
},
id="fhe.array([x, y])",
@@ -42,6 +58,14 @@ from concrete import fhe
},
id="fhe.array([[x, 1], [y, 2], [z, 3]])",
),
pytest.param(
lambda x, y: fhe.array([x, y]) + fhe.array([x, y]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "clear", "shape": ()},
},
id="fhe.array([x, y]) + fhe.array([x, y])",
),
],
)
def test_array(function, parameters, helpers):

View File

@@ -60,3 +60,58 @@ def test_bitwise(function, parameters, helpers):
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: (x & y) + (2**6),
id="x & y",
),
pytest.param(
lambda x, y: (x | y) + (2**6),
id="x | y",
),
pytest.param(
lambda x, y: (x ^ y) + (2**6),
id="x ^ y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 7], "status": "encrypted"},
},
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
},
],
)
def test_bitwise_optimized(function, parameters, helpers):
"""
Test optimized bitwise operations between encrypted integers.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -18,6 +18,20 @@ from concrete import fhe
"y": {"shape": (3, 2)},
},
),
pytest.param(
lambda x, y: np.concatenate((x, y)),
{
"x": {"shape": (4, 2), "status": "clear"},
"y": {"shape": (3, 2)},
},
),
pytest.param(
lambda x, y: np.concatenate((x, y)),
{
"x": {"shape": (4, 2)},
"y": {"shape": (3, 2), "status": "clear"},
},
),
pytest.param(
lambda x, y: np.concatenate((x, y), axis=0),
{

View File

@@ -5,7 +5,6 @@ Tests of execution of convolution operation.
import numpy as np
import pytest
import concrete.onnx as connx
from concrete import fhe
from concrete.fhe.representation.node import Node
from concrete.fhe.tracing.tracer import Tracer
@@ -62,7 +61,7 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
@fhe.compiler({"x": "encrypted"})
def function(x):
return connx.conv(x, weight, bias, strides=strides, dilations=dilations, group=group)
return fhe.conv(x, weight, bias, strides=strides, dilations=dilations, group=group)
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
circuit = function.compile(inputset, configuration)
@@ -307,32 +306,6 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
ValueError,
"expected number of channel in weight to be 1.0 (C / group), but got 2",
),
pytest.param(
(1, 1, 4),
(1, 1, 2),
(1,),
(0, 0),
(1,),
(1,),
None,
1,
"NOTSET",
NotImplementedError,
"conv1d conversion to MLIR is not yet implemented",
),
pytest.param(
(1, 1, 4, 4, 4),
(1, 1, 2, 2, 2),
(1,),
(0, 0, 0, 0, 0, 0),
(1, 1, 1),
(1, 1, 1),
None,
1,
"NOTSET",
NotImplementedError,
"conv3d conversion to MLIR is not yet implemented",
),
pytest.param(
(1, 1, 4, 4, 4, 4),
(1, 1, 2, 2, 2, 2),
@@ -388,7 +361,7 @@ def test_bad_conv_compilation(
@fhe.compiler({"x": "encrypted"})
def function(x):
return connx.conv(
return fhe.conv(
x,
weight,
bias=bias,
@@ -426,8 +399,8 @@ def test_bad_conv_compilation(
"func",
[
# pylint: disable=protected-access
connx.convolution._evaluate_conv,
connx.convolution._trace_conv,
fhe.extensions.convolution._evaluate_conv,
fhe.extensions.convolution._trace_conv,
# pylint: enable=protected-access
],
)
@@ -487,7 +460,7 @@ def test_inconsistent_input_types(
Test conv with inconsistent input types.
"""
with pytest.raises(expected_error) as excinfo:
connx.conv(
fhe.conv(
x,
weight,
bias=bias,

View File

@@ -5,7 +5,6 @@ Tests of execution of maxpool operation.
import numpy as np
import pytest
import concrete.onnx as connx
from concrete import fhe
@@ -69,16 +68,61 @@ def test_maxpool(
sample_input = np.expand_dims(np.array(sample_input), axis=(0, 1))
expected_output = np.expand_dims(np.array(expected_output), axis=(0, 1))
assert np.array_equal(connx.maxpool(sample_input, **operation), expected_output)
assert np.array_equal(fhe.maxpool(sample_input, **operation), expected_output)
@fhe.compiler({"x": "encrypted"})
def function(x):
return connx.maxpool(x, **operation)
return fhe.maxpool(x, **operation)
graph = function.trace([sample_input], helpers.configuration())
assert np.array_equal(graph(sample_input), expected_output)
@pytest.mark.parametrize(
"operation,parameters",
[
pytest.param(
{
"kernel_shape": (3, 2),
},
{
"x": {"status": "encrypted", "range": [0, 20], "shape": (1, 1, 6, 7)},
},
),
pytest.param(
{
"kernel_shape": (3, 2),
},
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 6, 7)},
},
),
],
)
def test_maxpool2d(
operation,
parameters,
helpers,
):
"""
Test maxpool2d.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
def function(x):
return fhe.maxpool(x, **operation)
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"input_shape,operation,expected_error,expected_message",
[
@@ -308,7 +352,7 @@ def test_bad_maxpool(
"""
with pytest.raises(expected_error) as excinfo:
connx.maxpool(np.random.randint(0, 10, size=input_shape), **operation)
fhe.maxpool(np.random.randint(0, 10, size=input_shape), **operation)
helpers.check_str(expected_message, str(excinfo.value))
@@ -318,51 +362,11 @@ def test_bad_maxpool_special(helpers):
Test maxpool with bad parameters for special cases.
"""
# compile
# -------
@fhe.compiler({"x": "encrypted"})
def not_compilable(x):
return connx.maxpool(x, kernel_shape=(4, 3))
inputset = [np.random.randint(0, 10, size=(1, 1, 10, 10)) for i in range(100)]
with pytest.raises(NotImplementedError) as excinfo:
not_compilable.compile(inputset, helpers.configuration())
helpers.check_str("MaxPool operation cannot be compiled yet", str(excinfo.value))
# clear input
# -----------
@fhe.compiler({"x": "clear"})
def clear_input(x):
return connx.maxpool(x, kernel_shape=(4, 3, 2))
inputset = [np.zeros((1, 1, 10, 10, 10), dtype=np.int64)]
with pytest.raises(RuntimeError) as excinfo:
clear_input.compile(inputset, helpers.configuration())
helpers.check_str(
# pylint: disable=line-too-long
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 10, 10)> ∈ [0, 0]
%1 = maxpool(%0, kernel_shape=(4, 3, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # ClearTensor<uint1, shape=(1, 1, 7, 8, 9)> ∈ [0, 0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted maxpool is supported
return %1
""".strip(), # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
# badly typed ndarray input
# -------------------------
with pytest.raises(TypeError) as excinfo:
connx.maxpool(np.array([{}, None]), ())
fhe.maxpool(np.array([{}, None]), ())
helpers.check_str(
# pylint: disable=line-too-long
@@ -379,7 +383,7 @@ Expected input elements to be of type np.integer, np.floating, or np.bool_ but i
# -----------------
with pytest.raises(TypeError) as excinfo:
connx.maxpool("", ())
fhe.maxpool("", ())
helpers.check_str(
# pylint: disable=line-too-long

View File

@@ -74,3 +74,98 @@ def test_constant_mul(function, parameters, helpers):
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x * y,
id="x * y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 10], "status": "clear"},
"y": {"range": [0, 10], "status": "encrypted"},
},
{
"x": {"range": [0, 10], "status": "encrypted"},
"y": {"range": [0, 10], "status": "clear"},
},
{
"x": {"range": [0, 10], "status": "encrypted"},
"y": {"range": [0, 10], "status": "encrypted"},
},
{
"x": {"range": [0, 10], "status": "clear", "shape": (3,)},
"y": {"range": [0, 10], "status": "encrypted"},
},
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 10], "status": "clear"},
},
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 10], "status": "encrypted"},
},
{
"x": {"range": [0, 10], "status": "clear"},
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "encrypted"},
"y": {"range": [0, 10], "status": "clear", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "encrypted"},
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "clear", "shape": (3,)},
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 10], "status": "clear", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "clear", "shape": (2, 1)},
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (2, 1)},
"y": {"range": [0, 10], "status": "clear", "shape": (3,)},
},
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (2, 1)},
"y": {"range": [0, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
},
],
)
def test_mul(function, parameters, helpers):
"""
Test mul where both of the operators are dynamic.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -273,7 +273,7 @@ def deterministic_unary_function(x):
{
"x": {"status": "encrypted", "range": [0, 84]},
},
id="abs(64 - x)",
id="abs(42 - x)",
),
pytest.param(
lambda x: ~x,

View File

@@ -0,0 +1,650 @@
"""
Tests of `Converter` class.
"""
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.mlir import GraphConverter
def assign(x, y):
"""
Assign scalar `y` into vector `x`.
"""
x[0] = y
return x
@pytest.mark.parametrize(
"function,encryption_statuses,inputset,expected_error,expected_message",
[
pytest.param(
lambda x, y: x + y,
{"x": "encrypted", "y": "encrypted"},
[(0.0, 0), (7.0, 7), (0.0, 7), (7.0, 0)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<float64> ∈ [0.0, 7.0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integers are supported
%1 = y # EncryptedScalar<uint3> ∈ [0, 7]
%2 = add(%0, %1) # EncryptedScalar<float64> ∈ [0.0, 14.0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integers are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: fhe.conv(x, [[[3, 1, 0, 2]]]),
{"x": "encrypted"},
[np.ones(shape=(1, 1, 10), dtype=np.int64)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10)> ∈ [1, 1]
%1 = [[[3 1 0 2]]] # ClearTensor<uint2, shape=(1, 1, 4)> ∈ [0, 3]
%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor<uint3, shape=(1, 1, 7)> ∈ [6, 6]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1-dimensional convolutions are not supported at the moment
return %2
""", # noqa: E501
),
pytest.param(
lambda x: fhe.conv(x, [[[[[1, 3], [4, 2]]]]]),
{"x": "encrypted"},
[np.ones(shape=(1, 1, 3, 4, 5), dtype=np.int64)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 3, 4, 5)> ∈ [1, 1]
%1 = [[[[[1 3] [4 2]]]]] # ClearTensor<uint3, shape=(1, 1, 1, 2, 2)> ∈ [1, 4]
%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor<uint4, shape=(1, 1, 3, 3, 4)> ∈ [10, 10]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 3-dimensional convolutions are not supported at the moment
return %2
""", # noqa: E501
),
pytest.param(
lambda x: fhe.maxpool(x, kernel_shape=(3,)),
{"x": "encrypted"},
[np.ones(shape=(1, 1, 10), dtype=np.int64)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10)> ∈ [1, 1]
%1 = maxpool1d(%0, kernel_shape=(3,), strides=(1,), pads=(0, 0), dilations=(1,), ceil_mode=False) # EncryptedTensor<uint1, shape=(1, 1, 8)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1-dimensional maxpooling is not supported at the moment
return %1
""", # noqa: E501
),
pytest.param(
lambda x: fhe.maxpool(x, kernel_shape=(3, 1, 2)),
{"x": "encrypted"},
[np.ones(shape=(1, 1, 3, 4, 5), dtype=np.int64)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 3, 4, 5)> ∈ [1, 1]
%1 = maxpool3d(%0, kernel_shape=(3, 1, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # EncryptedTensor<uint1, shape=(1, 1, 1, 4, 4)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 3-dimensional maxpooling is not supported at the moment
return %1
""", # noqa: E501
),
pytest.param(
lambda x, y: x + y,
{"x": "clear", "y": "clear"},
[(0, 0), (7, 7), (0, 7), (7, 0)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
%1 = y # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = add(%0, %1) # ClearScalar<uint4> ∈ [0, 14]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear additions are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x - y,
{"x": "clear", "y": "clear"},
[(0, 0), (7, 7), (0, 7), (7, 0)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
%1 = y # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = subtract(%0, %1) # ClearScalar<int4> ∈ [-7, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear subtractions are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x * y,
{"x": "clear", "y": "clear"},
[(0, 0), (7, 7), (0, 7), (7, 0)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
%1 = y # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = multiply(%0, %1) # ClearScalar<uint6> ∈ [0, 49]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear multiplications are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{"x": "clear", "y": "clear"},
[([1, 2], [3, 4])],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint2, shape=(2,)> ∈ [1, 2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
%1 = y # ClearTensor<uint3, shape=(2,)> ∈ [3, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = dot(%0, %1) # ClearScalar<uint4> ∈ [11, 11]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear dot products are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.broadcast_to(x, shape=(2, 2)),
{"x": "clear"},
[[1, 2], [3, 4]],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint3, shape=(2,)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ value is clear
%1 = broadcast_to(%0, shape=(2, 2)) # ClearTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear values cannot be broadcasted
return %1
""", # noqa: E501
),
pytest.param(
assign,
{"x": "clear", "y": "encrypted"},
[([1, 2, 3], 0)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint2, shape=(3,)> ∈ [0, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ tensor is clear
%1 = y # EncryptedScalar<uint1> ∈ [0, 0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ assigned value is encrypted
%2 = (%0[0] = %1) # ClearTensor<uint2, shape=(3,)> ∈ [0, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted values cannot be assigned to clear tensors
return %2
""", # noqa: E501
),
pytest.param(
lambda x: x**2 + (x + 1_000_000),
{"x": "encrypted"},
[100_000],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 34-bit value is used as an input to a table lookup
(note that it's assigned 34-bits during compilation because of its relation with other operations)
%1 = 2 # ClearScalar<uint2> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint34> ∈ [10000000000, 10000000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit table lookups are supported
%3 = 1000000 # ClearScalar<uint20> ∈ [1000000, 1000000]
%4 = add(%0, %3) # EncryptedScalar<uint21> ∈ [1100000, 1100000]
%5 = add(%2, %4) # EncryptedScalar<uint34> ∈ [10001100000, 10001100000]
return %5
""", # noqa: E501
),
pytest.param(
lambda x, y: x & y,
{"x": "encrypted", "y": "encrypted"},
[(-2, 4)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<int2> ∈ [-2, -2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is signed
%1 = y # EncryptedScalar<uint3> ∈ [4, 4]
%2 = bitwise_and(%0, %1) # EncryptedScalar<uint3> ∈ [4, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise operations are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x & y,
{"x": "encrypted", "y": "encrypted"},
[(4, -2)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint3> ∈ [4, 4]
%1 = y # EncryptedScalar<int2> ∈ [-2, -2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is signed
%2 = bitwise_and(%0, %1) # EncryptedScalar<uint3> ∈ [4, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise operations are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: np.concatenate((x, y)),
{"x": "clear", "y": "clear"},
[([1, 2], [3, 4])],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint2, shape=(2,)> ∈ [1, 2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ value is clear
%1 = y # ClearTensor<uint3, shape=(2,)> ∈ [3, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ value is clear
%2 = concatenate((%0, %1)) # ClearTensor<uint3, shape=(4,)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear concatenation is not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: fhe.conv(x, [[[[2, 1], [0, 3]]]]),
{"x": "clear"},
[np.ones(shape=(1, 1, 10, 10), dtype=np.int64)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 10)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = [[[[2 1] [0 3]]]] # ClearTensor<uint2, shape=(1, 1, 2, 2)> ∈ [0, 3]
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint3, shape=(1, 1, 9, 9)> ∈ [6, 6]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear convolutions are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: fhe.conv(x, weight=y),
{"x": "encrypted", "y": "encrypted"},
[
(
np.ones(shape=(1, 1, 10, 10), dtype=np.int64),
np.ones(shape=(1, 1, 2, 2), dtype=np.int64),
)
],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10, 10)> ∈ [1, 1]
%1 = y # EncryptedTensor<uint1, shape=(1, 1, 2, 2)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ weight is encrypted
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint3, shape=(1, 1, 9, 9)> ∈ [4, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but convolutions with encrypted weights are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: fhe.conv(x, weight=[[[[2, 1], [0, 3]]]], bias=y),
{"x": "encrypted", "y": "encrypted"},
[
(
np.ones(shape=(1, 1, 10, 10), dtype=np.int64),
np.ones(shape=(1,), dtype=np.int64),
)
],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 10, 10)> ∈ [1, 1]
%1 = y # EncryptedTensor<uint1, shape=(1,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bias is encrypted
%2 = [[[[2 1] [0 3]]]] # ClearTensor<uint2, shape=(1, 1, 2, 2)> ∈ [0, 3]
%3 = conv2d(%0, %2, %1, pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint3, shape=(1, 1, 9, 9)> ∈ [7, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but convolutions with encrypted biases are not supported
return %3
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{"x": "encrypted", "y": "encrypted"},
[
(
np.ones(shape=(3,), dtype=np.int64),
np.ones(shape=(3,), dtype=np.int64),
)
],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = dot(%0, %1) # EncryptedScalar<uint2> ∈ [3, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted dot products are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "clear", "y": "clear"},
[([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is clear
%1 = y # ClearTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = matmul(%0, %1) # ClearTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear matrix multiplications are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "encrypted", "y": "encrypted"},
[([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = matmul(%0, %1) # EncryptedTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted matrix multiplications are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: fhe.maxpool(x, kernel_shape=(3, 2)),
{"x": "clear"},
[np.ones(shape=(1, 1, 10, 5), dtype=np.int64)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 5)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = maxpool2d(%0, kernel_shape=(3, 2), strides=(1, 1), pads=(0, 0, 0, 0), dilations=(1, 1), ceil_mode=False) # ClearTensor<uint1, shape=(1, 1, 8, 4)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear maxpooling is not supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: x**2,
{"x": "clear"},
[3, 4, 5],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint3> ∈ [3, 5]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this clear value is used as an input to a table lookup
%1 = 2 # ClearScalar<uint2> ∈ [2, 2]
%2 = power(%0, %1) # ClearScalar<uint5> ∈ [9, 25]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only encrypted table lookups are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.sum(x),
{"x": "clear"},
[[1, 2]],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint2, shape=(2,)> ∈ [1, 2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = sum(%0) # ClearScalar<uint2> ∈ [3, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear summation is not supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x, y: x << y,
{"x": "encrypted", "y": "encrypted"},
[(-2, 4)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<int2> ∈ [-2, -2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is signed
%1 = y # EncryptedScalar<uint3> ∈ [4, 4]
%2 = left_shift(%0, %1) # EncryptedScalar<int6> ∈ [-32, -32]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise shifts are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x >> y,
{"x": "encrypted", "y": "encrypted"},
[(4, -2)],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint3> ∈ [4, 4]
%1 = y # EncryptedScalar<int2> ∈ [-2, -2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is signed
%2 = right_shift(%0, %1) # EncryptedScalar<uint1> ∈ [0, 0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only unsigned-unsigned bitwise shifts are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: -x,
{"x": "clear"},
[10],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint4> ∈ [10, 10]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = negative(%0) # ClearScalar<int5> ∈ [-10, -10]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear negations are not supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: fhe.LookupTable([fhe.LookupTable([0, 1]), fhe.LookupTable([1, 0])])[x],
{"x": "clear"},
[[1, 1], [1, 0], [0, 1], [0, 0]],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this clear value is used as an input to a table lookup
%1 = tlu(%0, table=[[0, 1] [1, 0]]) # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only encrypted table lookups are supported
return %1
""", # noqa: E501
),
],
)
def test_converter_bad_convert(
    function,
    encryption_statuses,
    inputset,
    expected_error,
    expected_message,
    helpers,
):
    """
    Test unsupported graph conversion.

    Compiling `function` is expected to fail with `expected_error`,
    and the error text is compared against `expected_message`.
    """
    config = helpers.configuration()
    compiler = fhe.Compiler(function, encryption_statuses)

    # compilation itself must raise
    with pytest.raises(expected_error) as excinfo:
        compiler.compile(inputset, config)

    # and the failure message must match the expected one
    helpers.check_str(expected_message, str(excinfo.value))
@pytest.mark.parametrize(
"function,parameters,expected_graph",
[
pytest.param(
lambda x: (x**2) + 100,
{
"x": {"range": [0, 10], "status": "encrypted"},
},
"""
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
%1 = 2 # ClearScalar<uint5> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
return %4
""",
)
],
)
def test_converter_process_multi_precision(function, parameters, expected_graph, helpers):
    """
    Test `process` method of `Converter` with multi precision.
    """
    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration().fork(single_precision=False)

    compiler = fhe.Compiler(function, statuses)
    graph = compiler.trace(helpers.generate_inputset(parameters), config)

    processed_graph = GraphConverter().process(graph, config)

    # drop bookkeeping properties so the formatted graph is comparable
    for node in processed_graph.query_nodes():
        node.properties.pop("original_bit_width", None)

    helpers.check_str(expected_graph, processed_graph.format())
@pytest.mark.parametrize(
"function,parameters,expected_graph",
[
pytest.param(
lambda x: (x**2) + 100,
{
"x": {"range": [0, 10], "status": "encrypted"},
},
"""
%0 = x # EncryptedScalar<uint8> ∈ [0, 10]
%1 = 2 # ClearScalar<uint9> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
return %4
""",
)
],
)
def test_converter_process_single_precision(function, parameters, expected_graph, helpers):
    """
    Test `process` method of `Converter` with single precision.
    """
    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration().fork(single_precision=True)

    compiler = fhe.Compiler(function, statuses)
    graph = compiler.trace(helpers.generate_inputset(parameters), config)

    processed_graph = GraphConverter().process(graph, config)

    # drop bookkeeping properties so the formatted graph is comparable
    for node in processed_graph.query_nodes():
        node.properties.pop("original_bit_width", None)

    helpers.check_str(expected_graph, processed_graph.format())

View File

@@ -1,499 +0,0 @@
"""
Tests of `GraphConverter` class.
"""
import numpy as np
import pytest
import concrete.onnx as connx
from concrete import fhe
def assign(x):
    """
    Simple assignment to a vector.

    Sets the first element of `x` to zero, in place, and returns `x`.
    """
    first = 0
    x[first] = 0
    return x
@pytest.mark.parametrize(
"function,encryption_statuses,inputset,expected_error,expected_message",
[
pytest.param(
lambda x, y: (x - y, x + y),
{"x": "encrypted", "y": "clear"},
[(0, 0), (7, 7), (0, 7), (7, 0)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<uint3> ∈ [0, 7]
%1 = y # ClearScalar<uint3> ∈ [0, 7]
%2 = subtract(%0, %1) # EncryptedScalar<int4> ∈ [-7, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
%3 = add(%0, %1) # EncryptedScalar<uint4> ∈ [0, 14]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
return (%2, %3)
""", # noqa: E501
),
pytest.param(
lambda x: x,
{"x": "clear"},
range(-10, 10),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearScalar<int5> ∈ [-10, 9]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported
return %0
""", # noqa: E501
),
pytest.param(
lambda x: x * 1.5,
{"x": "encrypted"},
[2.5 * x for x in range(100)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<float64> ∈ [0.0, 247.5]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported
%1 = 1.5 # ClearScalar<float64> ∈ [1.5, 1.5]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%2 = multiply(%0, %1) # EncryptedScalar<float64> ∈ [0.0, 371.25]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.sin(x),
{"x": "encrypted"},
range(100),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<uint7> ∈ [0, 99]
%1 = sin(%0) # EncryptedScalar<float64> ∈ [-0.99999, 0.999912]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x, y: np.concatenate((x, y)),
{"x": "encrypted", "y": "clear"},
[
(
np.random.randint(0, 2**3, size=(3, 2)),
np.random.randint(0, 2**3, size=(3, 2)),
)
for _ in range(100)
],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint3, shape=(3, 2)> ∈ [0, 7]
%1 = y # ClearTensor<uint3, shape=(3, 2)> ∈ [0, 7]
%2 = concatenate((%0, %1)) # EncryptedTensor<uint3, shape=(6, 2)> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only all encrypted concatenate is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, w: connx.conv(x, w),
{"x": "encrypted", "w": "encrypted"},
[
(
np.random.randint(0, 2, size=(1, 1, 4)),
np.random.randint(0, 2, size=(1, 1, 1)),
)
for _ in range(100)
],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4)> ∈ [0, 1]
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1)> ∈ [0, 1]
%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv1d with encrypted input and clear weight is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, w: connx.conv(x, w),
{"x": "encrypted", "w": "encrypted"},
[
(
np.random.randint(0, 2, size=(1, 1, 4, 4)),
np.random.randint(0, 2, size=(1, 1, 1, 1)),
)
for _ in range(100)
],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4)> ∈ [0, 1]
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1)> ∈ [0, 1]
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv2d with encrypted input and clear weight is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, w: connx.conv(x, w),
{"x": "encrypted", "w": "encrypted"},
[
(
np.random.randint(0, 2, size=(1, 1, 4, 4, 4)),
np.random.randint(0, 2, size=(1, 1, 1, 1, 1)),
)
for _ in range(100)
],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)> ∈ [0, 1]
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1, 1)> ∈ [0, 1]
%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv3d with encrypted input and clear weight is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{"x": "encrypted", "y": "encrypted"},
[([0], [0]), ([3], [3]), ([3], [0]), ([0], [3]), ([1], [1])],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint2, shape=(1,)> ∈ [0, 3]
%1 = y # EncryptedTensor<uint2, shape=(1,)> ∈ [0, 3]
%2 = dot(%0, %1) # EncryptedScalar<uint4> ∈ [0, 9]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only dot product between encrypted and clear is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: x[0],
{"x": "clear"},
[[0, 1, 2, 3], [7, 6, 5, 4]],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint3, shape=(4,)> ∈ [0, 7]
%1 = %0[0] # ClearScalar<uint3> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted indexing supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "encrypted", "y": "encrypted"},
[
(
np.random.randint(0, 2**1, size=(1, 1)),
np.random.randint(0, 2**1, size=(1, 1)),
)
for _ in range(100)
],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
%1 = y # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
%2 = matmul(%0, %1) # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only matrix multiplication between encrypted and clear is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x * y,
{"x": "encrypted", "y": "encrypted"},
[(0, 0), (7, 7), (0, 7), (7, 0)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<uint3> ∈ [0, 7]
%1 = y # EncryptedScalar<uint3> ∈ [0, 7]
%2 = multiply(%0, %1) # EncryptedScalar<uint6> ∈ [0, 49]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only multiplication between encrypted and clear is supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: -x,
{"x": "clear"},
[0, 7],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearScalar<uint3> ∈ [0, 7]
%1 = negative(%0) # ClearScalar<int4> ∈ [-7, 0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted negation is supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: x.reshape((3, 2)),
{"x": "clear"},
[np.random.randint(0, 2**3, size=(2, 3)) for _ in range(100)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint3, shape=(2, 3)> ∈ [0, 7]
%1 = reshape(%0, newshape=(3, 2)) # ClearTensor<uint3, shape=(3, 2)> ∈ [0, 7]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted reshape is supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: np.sum(x),
{"x": "clear"},
[np.random.randint(0, 2, size=(1,)) for _ in range(100)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(1,)> ∈ [0, 1]
%1 = sum(%0) # ClearScalar<uint1> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted sum is supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: np.maximum(x, np.array([3])),
{"x": "clear"},
[[0], [1]],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(1,)> ∈ [0, 1]
%1 = [3] # ClearTensor<uint2, shape=(1,)> ∈ [3, 3]
%2 = maximum(%0, %1) # ClearTensor<uint2, shape=(1,)> ∈ [3, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.transpose(x),
{"x": "clear"},
[np.random.randint(0, 2, size=(3, 2)) for _ in range(10)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(3, 2)> ∈ [0, 1]
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
return %1
""", # noqa: E501
),
pytest.param(
lambda x: np.broadcast_to(x, shape=(3, 2)),
{"x": "clear"},
[np.random.randint(0, 2, size=(2,)) for _ in range(10)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor<uint1, shape=(3, 2)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
return %1
""", # noqa: E501
),
pytest.param(
assign,
{"x": "clear"},
[np.random.randint(0, 2, size=(3,)) for _ in range(10)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(3,)> ∈ [0, 1]
%1 = 0 # ClearScalar<uint1> ∈ [0, 0]
%2 = (%0[0] = %1) # ClearTensor<uint1, shape=(3,)> ∈ [0, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64),
{"x": "encrypted"},
[200000],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint18> ∈ [200000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this input is 18-bits
%1 = 300 # ClearScalar<uint9> ∈ [300, 300]
%2 = add(%0, %1) # EncryptedScalar<uint18> ∈ [200300, 200300]
%3 = subgraph(%2) # EncryptedScalar<uint4> ∈ [9, 9]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bits
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = input # EncryptedScalar<uint2>
%1 = sin(%0) # EncryptedScalar<float64>
%2 = 10 # ClearScalar<uint4>
%3 = multiply(%2, %1) # EncryptedScalar<float64>
%4 = absolute(%3) # EncryptedScalar<float64>
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
return %5
""", # noqa: E501
),
pytest.param(
lambda x, y: x << y,
{"x": "encrypted", "y": "encrypted"},
[(-1, 1), (-2, 3)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<int2> ∈ [-2, -1]
%1 = y # EncryptedScalar<uint2> ∈ [1, 3]
%2 = left_shift(%0, %1) # EncryptedScalar<int5> ∈ [-16, -2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned bitwise operations are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x << y,
{"x": "encrypted", "y": "encrypted"},
[(1, 20), (2, 10)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<uint2> ∈ [1, 2]
%1 = y # EncryptedScalar<uint5> ∈ [10, 20]
%2 = left_shift(%0, %1) # EncryptedScalar<uint21> ∈ [2048, 1048576]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 4-bit shifts are supported
return %2
""", # noqa: E501
),
],
)
def test_graph_converter_bad_convert(
    function,
    encryption_statuses,
    inputset,
    expected_error,
    expected_message,
    helpers,
):
    """
    Test unsupported graph conversion.

    Compiling `function` is expected to fail with `expected_error`,
    and the error text is compared against `expected_message`.
    """
    config = helpers.configuration()
    compiler = fhe.Compiler(function, encryption_statuses)

    # compilation itself must raise
    with pytest.raises(expected_error) as excinfo:
        compiler.compile(inputset, config)

    # and the failure message must match the expected one
    helpers.check_str(expected_message, str(excinfo.value))
@pytest.mark.parametrize(
"function,inputset,expected_mlir",
[
pytest.param(
lambda x: 1 + fhe.LookupTable([4, 1, 2, 3])[x] + fhe.LookupTable([4, 1, 2, 3])[x + 1],
range(3),
"""
module {
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%c1_i4 = arith.constant 1 : i4
%cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
%1 = "FHE.add_eint_int"(%arg0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
%2 = "FHE.add_eint_int"(%0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
%3 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
%4 = "FHE.add_eint"(%2, %3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
return %4 : !FHE.eint<3>
}
}
""", # noqa: E501
# Notice that there is only a single 1 and a single table cst above
),
],
)
def test_constant_cache(function, inputset, expected_mlir, helpers):
    """
    Test caching MLIR constants.

    Compiles `function` and checks that the produced MLIR matches
    `expected_mlir` (which encodes the expected constant de-duplication).
    """
    compiler = fhe.Compiler(function, {"x": "encrypted"})
    circuit = compiler.compile(inputset, helpers.configuration())
    helpers.check_str(expected_mlir, circuit.mlir)
# pylint: enable=line-too-long

View File

@@ -39,6 +39,64 @@ def f(x):
return g(z + 3) * 2
def test_graph_format_show_lines(helpers):
"""
Test `format` method of `Graph` class with show_lines=True.
"""
configuration = helpers.configuration()
compiler = fhe.Compiler(f, {"x": "encrypted"})
graph = compiler.trace(range(10), configuration)
# pylint: disable=line-too-long
expected = f"""
%0 = x # EncryptedScalar<uint4> ∈ [0, 9] {tests_directory}/representation/test_graph.py:50
%1 = 2 # ClearScalar<uint2> ∈ [2, 2] @ abc {tests_directory}/representation/test_graph.py:34
%2 = multiply(%0, %1) # EncryptedScalar<uint5> ∈ [0, 18] @ abc {tests_directory}/representation/test_graph.py:34
%3 = 42 # ClearScalar<uint6> ∈ [42, 42] @ abc.foo {tests_directory}/representation/test_graph.py:36
%4 = add(%2, %3) # EncryptedScalar<uint6> ∈ [42, 60] @ abc.foo {tests_directory}/representation/test_graph.py:36
%5 = subgraph(%4) # EncryptedScalar<uint3> ∈ [6, 7] @ abc {tests_directory}/representation/test_graph.py:37
%6 = 3 # ClearScalar<uint2> ∈ [3, 3] {tests_directory}/representation/test_graph.py:39
%7 = add(%5, %6) # EncryptedScalar<uint4> ∈ [9, 10] {tests_directory}/representation/test_graph.py:39
%8 = 120 # ClearScalar<uint7> ∈ [120, 120] @ def {tests_directory}/representation/test_graph.py:23
%9 = subtract(%8, %7) # EncryptedScalar<uint7> ∈ [110, 111] @ def {tests_directory}/representation/test_graph.py:23
%10 = 4 # ClearScalar<uint3> ∈ [4, 4] @ def {tests_directory}/representation/test_graph.py:24
%11 = floor_divide(%9, %10) # EncryptedScalar<uint5> ∈ [27, 27] @ def {tests_directory}/representation/test_graph.py:24
%12 = 2 # ClearScalar<uint2> ∈ [2, 2] {tests_directory}/representation/test_graph.py:39
%13 = multiply(%11, %12) # EncryptedScalar<uint6> ∈ [54, 54] {tests_directory}/representation/test_graph.py:39
return %13
Subgraphs:
%5 = subgraph(%4):
%0 = input # EncryptedScalar<uint2> @ abc.foo {tests_directory}/representation/test_graph.py:36
%1 = sqrt(%0) # EncryptedScalar<float64> @ abc {tests_directory}/representation/test_graph.py:37
%2 = astype(%1, dtype=int_) # EncryptedScalar<uint1> @ abc {tests_directory}/representation/test_graph.py:37
return %2
""" # noqa: E501
# pylint: enable=line-too-long
actual = graph.format(show_locations=True)
assert (
actual.strip() == expected.strip()
), f"""
Expected Output
===============
{expected}
Actual Output
=============
{actual}
"""
@pytest.mark.parametrize(
"function,inputset,tag_filter,operation_filter,expected_result",
[
@@ -184,13 +242,14 @@ def test_graph_maximum_integer_bit_width(
@pytest.mark.parametrize(
"function,inputset,tag_filter,operation_filter,expected_result",
"function,inputset,tag_filter,operation_filter,is_encrypted_filter,expected_result",
[
pytest.param(
lambda x: x + 42,
range(-10, 10),
None,
None,
None,
(-10, 51),
),
pytest.param(
@@ -199,12 +258,14 @@ def test_graph_maximum_integer_bit_width(
None,
None,
None,
None,
),
pytest.param(
f,
range(10),
None,
None,
None,
(0, 120),
),
pytest.param(
@@ -212,6 +273,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
"",
None,
None,
(0, 54),
),
pytest.param(
@@ -219,6 +281,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
"abc",
None,
None,
(0, 18),
),
pytest.param(
@@ -226,6 +289,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
["abc", "def"],
None,
None,
(0, 120),
),
pytest.param(
@@ -233,6 +297,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
re.compile(".*b.*"),
None,
None,
(0, 60),
),
pytest.param(
@@ -240,6 +305,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
None,
"input",
None,
(0, 9),
),
pytest.param(
@@ -247,6 +313,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
None,
"constant",
None,
(2, 120),
),
pytest.param(
@@ -254,6 +321,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
None,
"subgraph",
None,
(6, 7),
),
pytest.param(
@@ -261,6 +329,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
None,
"add",
None,
(9, 60),
),
pytest.param(
@@ -268,6 +337,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
None,
["subgraph", "add"],
None,
(6, 60),
),
pytest.param(
@@ -275,6 +345,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
None,
re.compile("sub.*"),
None,
(6, 111),
),
pytest.param(
@@ -282,6 +353,7 @@ def test_graph_maximum_integer_bit_width(
range(10),
"abc.foo",
"add",
None,
(42, 60),
),
pytest.param(
@@ -290,6 +362,23 @@ def test_graph_maximum_integer_bit_width(
"abc",
"floor_divide",
None,
None,
),
pytest.param(
lambda x: x - 2,
range(5, 10),
None,
None,
True,
(3, 9),
),
pytest.param(
lambda x: x - 2,
range(5, 10),
None,
None,
False,
(2, 2),
),
],
)
@@ -298,6 +387,7 @@ def test_graph_integer_range(
inputset,
tag_filter,
operation_filter,
is_encrypted_filter,
expected_result,
helpers,
):
@@ -310,62 +400,23 @@ def test_graph_integer_range(
compiler = fhe.Compiler(function, {"x": "encrypted"})
graph = compiler.trace(inputset, configuration)
assert graph.integer_range(tag_filter, operation_filter) == expected_result
assert graph.integer_range(tag_filter, operation_filter, is_encrypted_filter) == expected_result
def test_graph_format_show_lines(helpers):
def test_direct_graph_integer_range(helpers):
"""
Test `format` method of `Graph` class with show_lines=True.
Test `integer_range` method of `Graph` class where `graph.is_direct` is `True`.
"""
configuration = helpers.configuration()
# pylint: disable=import-outside-toplevel
from concrete.fhe.dtypes import Integer
from concrete.fhe.values import Value
compiler = fhe.Compiler(f, {"x": "encrypted"})
graph = compiler.trace(range(10), configuration)
# pylint: enable=import-outside-toplevel
# pylint: disable=line-too-long
expected = f"""
%0 = x # EncryptedScalar<uint4> ∈ [0, 9] {tests_directory}/representation/test_graph.py:324
%1 = 2 # ClearScalar<uint2> ∈ [2, 2] @ abc {tests_directory}/representation/test_graph.py:34
%2 = multiply(%0, %1) # EncryptedScalar<uint5> ∈ [0, 18] @ abc {tests_directory}/representation/test_graph.py:34
%3 = 42 # ClearScalar<uint6> ∈ [42, 42] @ abc.foo {tests_directory}/representation/test_graph.py:36
%4 = add(%2, %3) # EncryptedScalar<uint6> ∈ [42, 60] @ abc.foo {tests_directory}/representation/test_graph.py:36
%5 = subgraph(%4) # EncryptedScalar<uint3> ∈ [6, 7] @ abc {tests_directory}/representation/test_graph.py:37
%6 = 3 # ClearScalar<uint2> ∈ [3, 3] {tests_directory}/representation/test_graph.py:39
%7 = add(%5, %6) # EncryptedScalar<uint4> ∈ [9, 10] {tests_directory}/representation/test_graph.py:39
%8 = 120 # ClearScalar<uint7> ∈ [120, 120] @ def {tests_directory}/representation/test_graph.py:23
%9 = subtract(%8, %7) # EncryptedScalar<uint7> ∈ [110, 111] @ def {tests_directory}/representation/test_graph.py:23
%10 = 4 # ClearScalar<uint3> ∈ [4, 4] @ def {tests_directory}/representation/test_graph.py:24
%11 = floor_divide(%9, %10) # EncryptedScalar<uint5> ∈ [27, 27] @ def {tests_directory}/representation/test_graph.py:24
%12 = 2 # ClearScalar<uint2> ∈ [2, 2] {tests_directory}/representation/test_graph.py:39
%13 = multiply(%11, %12) # EncryptedScalar<uint6> ∈ [54, 54] {tests_directory}/representation/test_graph.py:39
return %13
Subgraphs:
%5 = subgraph(%4):
%0 = input # EncryptedScalar<uint2> @ abc.foo {tests_directory}/representation/test_graph.py:36
%1 = sqrt(%0) # EncryptedScalar<float64> @ abc {tests_directory}/representation/test_graph.py:37
%2 = astype(%1, dtype=int_) # EncryptedScalar<uint1> @ abc {tests_directory}/representation/test_graph.py:37
return %2
""" # noqa: E501
# pylint: enable=line-too-long
actual = graph.format(show_locations=True)
assert (
actual.strip() == expected.strip()
), f"""
Expected Output
===============
{expected}
Actual Output
=============
{actual}
"""
circuit = fhe.Compiler.assemble(
lambda x: x,
{"x": Value(dtype=Integer(is_signed=False, bit_width=8), shape=(), is_encrypted=True)},
configuration=helpers.configuration(),
)
assert circuit.graph.integer_range() is None

View File

@@ -167,7 +167,7 @@ def test_node_bad_call(node, args, expected_error, expected_message):
),
pytest.param(
Node.generic(
name="index.static",
name="index_static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
operation=lambda x: x[slice(None, None, -1)],
@@ -208,7 +208,7 @@ def test_node_bad_call(node, args, expected_error, expected_message):
),
pytest.param(
Node.generic(
name="assign.static",
name="assign_static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
operation=lambda *args: args,
@@ -266,7 +266,7 @@ def test_node_format(node, predecessors, expected_result):
),
pytest.param(
Node.generic(
name="index.static",
name="index_static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
output=EncryptedTensor(UnsignedInteger(3), shape=()),
operation=lambda *args: args,
@@ -276,7 +276,7 @@ def test_node_format(node, predecessors, expected_result):
),
pytest.param(
Node.generic(
name="assign.static",
name="assign_static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
operation=lambda *args: args,