refactor(frontend-python): re-write bit width assignment

This commit is contained in:
Umut
2023-08-02 17:34:47 +02:00
parent 9a5b08938e
commit 46f3de63cc
6 changed files with 539 additions and 222 deletions

View File

@@ -16,6 +16,7 @@ ignore = [
"**/__init__.py" = ["F401"]
"concrete/fhe/compilation/configuration.py" = ["ARG002"]
"concrete/fhe/mlir/processors/all.py" = ["F401"]
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"]
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
"examples/**" = ["PLR2004"]
"tests/**" = ["PLR2004", "PLW0603", "SIM300", "S311"]

View File

@@ -802,7 +802,10 @@ class Context:
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
if x.is_encrypted and y.is_encrypted:
assert self.is_bit_width_compatible(x, y)
else:
assert self.is_bit_width_compatible(resulting_type, x, y)
if x.is_scalar or y.is_scalar:
return self.mul(resulting_type, x, y)
@@ -1369,7 +1372,10 @@ class Context:
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
if x.is_encrypted and y.is_encrypted:
assert self.is_bit_width_compatible(x, y)
else:
assert self.is_bit_width_compatible(resulting_type, x, y)
if resulting_type.shape == ():
if x.is_clear:
@@ -1489,7 +1495,10 @@ class Context:
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
if x.is_encrypted and y.is_encrypted:
assert self.is_bit_width_compatible(x, y)
else:
assert self.is_bit_width_compatible(resulting_type, x, y)
use_linalg = x.is_tensor or y.is_tensor

View File

@@ -1,10 +1,10 @@
"""
Declaration of `AssignBitWidths` graph processor.
"""
from __future__ import annotations
import typing
from collections.abc import Iterable
from typing import Dict, List
import z3
from ...dtypes import Integer
from ...representation import Graph, Node, Operation
@@ -13,266 +13,344 @@ from . import GraphProcessor
class AssignBitWidths(GraphProcessor):
"""
Assign a precision to all nodes inputs/output.
AssignBitWidths graph processor, to assign proper bit-widths to be compatible with FHE.
The precisions are compatible with graph constraints and MLIR.
There are two modes:
- single precision: where all encrypted values have the same precision.
- multi precision: where encrypted values can have different precisions.
- Single Precision, where all encrypted values have the same precision.
- Multi Precision, where encrypted values can have different precisions.
"""
def __init__(self, single_precision=False):
self.single_precision = single_precision
def apply(self, graph: Graph):
nodes = graph.query_nodes()
for node in nodes:
optimizer = z3.Optimize()
max_bit_width: z3.Int = z3.Int("max")
bit_widths: Dict[Node, z3.Int] = {}
additional_constraints = AdditionalConstraints(optimizer, graph, bit_widths)
nodes = graph.query_nodes(ordered=True)
for i, node in enumerate(nodes):
bit_width = z3.Int(f"%{i}")
bit_widths[node] = bit_width
optimizer.add(max_bit_width >= bit_width)
assert isinstance(node.output.dtype, Integer)
node.properties["original_bit_width"] = node.output.dtype.bit_width
optimizer.add(bit_width >= node.output.dtype.bit_width)
additional_constraints.generate_for(node, bit_width)
if self.single_precision:
assign_single_precision(nodes)
else:
assign_multi_precision(graph, nodes)
for bit_width in bit_widths.values():
optimizer.add(bit_width == max_bit_width)
optimizer.minimize(sum(bit_width**2 for bit_width in bit_widths.values()))
assert optimizer.check() == z3.sat
model = optimizer.model()
for node, bit_width in bit_widths.items():
assert isinstance(node.output.dtype, Integer)
new_bit_width = model[bit_width].as_long()
if node.output.is_clear:
new_bit_width += 1
node.properties["original_bit_width"] = node.output.dtype.bit_width
node.output.dtype.bit_width = new_bit_width
def assign_single_precision(nodes: list[Node]):
"""Assign one single encryption precision to all nodes."""
p = required_encrypted_bitwidth(nodes)
for node in nodes:
assign_precisions_1_node(node, p, p)
def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
"""Assign input/output precision to a single node.
Precisions are adjusted to match each use, e.g. the encrypted and constant cases.
class AdditionalConstraints:
"""
assert isinstance(node.output.dtype, Integer)
if node.output.is_encrypted:
node.output.dtype.bit_width = output_p
else:
node.output.dtype.bit_width = output_p + 1
for value in node.inputs:
assert isinstance(value.dtype, Integer)
if value.is_encrypted:
value.dtype.bit_width = inputs_p
else:
value.dtype.bit_width = inputs_p + 1
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
ROUNDING = {"round_bit_pattern"}
MULTIPLY = {"multiply", "matmul", "dot"}
def max_encrypted_bitwidth_node(node: Node):
"""Give the minimal precision to implement the node.
This applies to both input and output precisions.
"""
assert isinstance(node.output.dtype, Integer)
if node.output.is_encrypted or node.operation == Operation.Constant:
normal_p = node.output.dtype.bit_width
else:
normal_p = -1
name = node.properties.get("name")
if name in CHUNKED_COMPARISON:
return max(normal_p, CHUNKED_COMPARISON_MIN_BITWIDTH)
if name in MAX_POOLS:
return normal_p + 1
if name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
# For operations that use multiply, an additional bit
# needs to be added to the bitwidths of the inputs.
# For single precision circuits the max of the input / output
# precisions will be taken in required_encrypted_bitwidth. For
# multi-precision, the circuit partitions will handle the
# input and output precisions separately.
all_inp_bitwidths = []
# Need a loop here to allow typechecking and make mypy happy
for inp in node.inputs:
dtype_inp = inp.dtype
assert isinstance(dtype_inp, Integer)
all_inp_bitwidths.append(dtype_inp.bit_width)
normal_p = max(all_inp_bitwidths)
# FIXME: This probably does not work well with multi-precision!
return max(normal_p + 1, node.output.dtype.bit_width)
return normal_p
def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
"""Give the minimal precision to implement all the nodes.
This function is called for both single-precision (for the whole circuit)
and for multi-precision circuits (for circuit partitions).
Ops for which the compiler introduces TLUs need to be handled explicitly
in `max_encrypted_bitwidth_node`. The maximum
of all precisions of the various operations is returned.
AdditionalConstraints class to customize bit-width assignment step easily.
"""
bitwidths = map(max_encrypted_bitwidth_node, nodes)
return max(bitwidths, default=-1)
optimizer: z3.Optimize
graph: Graph
bit_widths: Dict[Node, z3.Int]
node: Node
bit_width: z3.Int
def required_inputs_encrypted_bitwidth(graph, node, nodes_output_p: list[tuple[Node, int]]) -> int:
"""Give the minimal precision to support the inputs."""
preds = graph.ordered_preds_of(node)
get_prec = lambda node: nodes_output_p[node.properties[NODE_ID]][1]
# by definition all inputs have the same block precision
# see uniform_precision_per_blocks
return get_prec(node) if len(preds) == 0 else get_prec(preds[0])
# pylint: disable=missing-function-docstring,unused-argument
def __init__(self, optimizer: z3.Optimize, graph: Graph, bit_widths: Dict[Node, z3.Int]):
self.optimizer = optimizer
self.graph = graph
self.bit_widths = bit_widths
def assign_multi_precision(graph, nodes):
"""Assign a specific encryption precision to each node."""
add_nodes_id(nodes)
nodes_output_p = uniform_precision_per_blocks(graph, nodes)
for node, _ in nodes_output_p:
node.properties["original_bit_width"] = node.output.dtype.bit_width
nodes_inputs_p = [
required_inputs_encrypted_bitwidth(graph, node, nodes_output_p)
if can_change_precision(node)
else output_p
for node, output_p in nodes_output_p
]
for (node, output_p), inputs_p in zip(nodes_output_p, nodes_inputs_p):
assign_precisions_1_node(node, output_p, inputs_p)
clear_nodes_id(nodes)
def generate_for(self, node: Node, bit_width: z3.Int):
"""
Generate additional constraints for a node.
Args:
node (Node):
node to generate constraints for
TLU_WITHOUT_PRECISION_CHANGE = CHUNKED_COMPARISON | MAX_POOLS | MULTIPLY
bit_width (z3.Int):
symbolic bit-width which will be assigned to node once constraints are solved
"""
assert node.operation in {Operation.Generic, Operation.Constant, Operation.Input}
operation_name = (
node.properties["name"]
if node.operation == Operation.Generic
else ("constant" if node.operation == Operation.Constant else "input")
)
def can_change_precision(node):
"""Detect if a node completely ties inputs/output precisions together."""
if (
node.properties.get("name") in ROUNDING
and node.properties["attributes"]["overflow_protection"]
):
return False # protection can change precision
if hasattr(self, operation_name):
constraints = getattr(self, operation_name)
preds = self.graph.ordered_preds_of(node)
return (
node.converted_to_table_lookup
and node.properties.get("name") not in TLU_WITHOUT_PRECISION_CHANGE
)
if isinstance(constraints, set):
for add_constraint in constraints:
add_constraint(self, node, preds)
elif isinstance(constraints, dict):
for condition, conditional_constraints in constraints.items():
if condition(self, node, preds):
for add_constraint in conditional_constraints:
add_constraint(self, node, preds)
def convert_union_to_blocks(node_union: UnionFind) -> Iterable[list[int]]:
"""Convert a `UnionFind` to blocks.
else: # pragma: no cover
message = (
f"Expected a set or a dict "
f"for additional constraints of '{operation_name}' operation"
f"but got {type(constraints).__name__} instead"
)
raise ValueError(message)
The result is an iterable of blocks. A block is a list of node ids.
"""
blocks = {}
for node_id in range(node_union.size):
node_canon = node_union.find_canonical(node_id)
if node_canon == node_id:
assert node_canon not in blocks
blocks[node_canon] = [node_id]
else:
blocks[node_canon].append(node_id)
return blocks.values()
# ==========
# Conditions
# ==========
def all_inputs_are_encrypted(self, node: Node, preds: List[Node]) -> bool:
return all(pred.output.is_encrypted for pred in preds)
NODE_ID = "node_id"
def some_inputs_are_clear(self, node: Node, preds: List[Node]) -> bool:
return any(pred.output.is_clear for pred in preds)
def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool:
return node.properties["attributes"]["overflow_protection"] is True
def add_nodes_id(nodes):
"""Temporarily add a NODE_ID property to all nodes."""
for node_id, node in enumerate(nodes):
assert NODE_ID not in node.properties
node.properties[NODE_ID] = node_id
# ===========
# Constraints
# ===========
def inputs_share_precision(self, node: Node, preds: List[Node]):
for i in range(len(preds) - 1):
self.optimizer.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]])
def clear_nodes_id(nodes):
"""Remove the NODE_ID property from all nodes."""
for node in nodes:
del node.properties[NODE_ID]
def inputs_and_output_share_precision(self, node: Node, preds: List[Node]):
self.inputs_share_precision(node, preds)
if len(preds) != 0:
self.optimizer.add(self.bit_widths[preds[-1]] == self.bit_widths[node])
def inputs_require_one_more_bit(self, node: Node, preds: List[Node]):
for pred in preds:
assert isinstance(pred.output.dtype, Integer)
def uniform_precision_per_blocks(graph: Graph, nodes: list[Node]) -> list[tuple[Node, int]]:
"""Find the required precision of blocks and associate it corresponding nodes."""
size = len(nodes)
node_union = UnionFind(size)
for node_id, node in enumerate(nodes):
preds = graph.ordered_preds_of(node)
if not preds:
continue
# we always unify all inputs
first_input_id = preds[0].properties[NODE_ID]
for pred in preds[1:]:
pred_id = pred.properties[NODE_ID]
node_union.union(first_input_id, pred_id)
# we unify with outputs only if no precision change can occur
if not can_change_precision(node):
node_union.union(first_input_id, node_id)
actual_bit_width = pred.output.dtype.bit_width
required_bit_width = actual_bit_width + 1
blocks = convert_union_to_blocks(node_union)
result: list[None | tuple[Node, int]]
result = [None] * len(nodes)
for nodes_id in blocks:
output_p = required_encrypted_bitwidth(nodes[node_id] for node_id in nodes_id)
for node_id in nodes_id:
result[node_id] = (nodes[node_id], output_p)
assert None not in result
return typing.cast("list[tuple[Node, int]]", result)
self.optimizer.add(self.bit_widths[pred] >= required_bit_width)
def inputs_require_at_least_four_bits(self, node: Node, preds: List[Node]):
for pred in preds:
self.optimizer.add(self.bit_widths[pred] >= 4)
class UnionFind:
"""
Utility class that joins nodes into equivalent precision classes.
# ==========
# Operations
# ==========
Nodes are just integers id.
"""
add = {
inputs_and_output_share_precision,
}
parent: list[int]
array = {
inputs_and_output_share_precision,
}
def __init__(self, size: int):
"""Create a union find suitable for `size` nodes."""
self.parent = list(range(size))
assign_static = {
inputs_and_output_share_precision,
}
@property
def size(self):
"""Size in number of nodes."""
return len(self.parent)
bitwise_and = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
def find_canonical(self, a: int) -> int:
"""Find the current canonical node for a given input node."""
parent = self.parent[a]
if a == parent:
return a
canonical = self.find_canonical(parent)
self.parent[a] = canonical
return canonical
bitwise_or = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
def union(self, a: int, b: int):
"""Union both nodes."""
self.united_common_ancestor(a, b)
bitwise_xor = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
def united_common_ancestor(self, a: int, b: int) -> int:
"""Deduce the common ancestor of both nodes after unification."""
parent_a = self.parent[a]
parent_b = self.parent[b]
broadcast_to = {
inputs_and_output_share_precision,
}
if parent_a == parent_b:
return parent_a
concatenate = {
inputs_and_output_share_precision,
}
if a == parent_a and parent_b < parent_a:
common_ancestor = parent_b
elif b == parent_b and parent_a < parent_b:
common_ancestor = parent_a
else:
common_ancestor = self.united_common_ancestor(parent_a, parent_b)
conv1d = {
inputs_and_output_share_precision,
}
self.parent[a] = common_ancestor
self.parent[b] = common_ancestor
return common_ancestor
conv2d = {
inputs_and_output_share_precision,
}
conv3d = {
inputs_and_output_share_precision,
}
copy = {
inputs_and_output_share_precision,
}
dot = {
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
},
some_inputs_are_clear: {
inputs_and_output_share_precision,
},
}
equal = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
expand_dims = {
inputs_and_output_share_precision,
}
greater = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
inputs_require_at_least_four_bits,
},
}
greater_equal = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
inputs_require_at_least_four_bits,
},
}
index_static = {
inputs_and_output_share_precision,
}
left_shift = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
less = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
inputs_require_at_least_four_bits,
},
}
less_equal = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
inputs_require_at_least_four_bits,
},
}
matmul = {
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
},
some_inputs_are_clear: {
inputs_and_output_share_precision,
},
}
maxpool1d = {
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}
maxpool2d = {
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}
maxpool3d = {
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}
multiply = {
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
},
some_inputs_are_clear: {
inputs_and_output_share_precision,
},
}
negative = {
inputs_and_output_share_precision,
}
not_equal = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
reshape = {
inputs_and_output_share_precision,
}
right_shift = {
all_inputs_are_encrypted: {
inputs_and_output_share_precision,
},
}
round_bit_pattern = {
has_overflow_protection: {
inputs_and_output_share_precision,
},
}
subtract = {
inputs_and_output_share_precision,
}
sum = {
inputs_and_output_share_precision,
}
squeeze = {
inputs_and_output_share_precision,
}
transpose = {
inputs_and_output_share_precision,
}

View File

@@ -2,3 +2,4 @@ networkx>=2.6
numpy>=1.23
scipy>=1.10
torch>=1.13
z3-solver>=4.12

View File

@@ -263,7 +263,7 @@ def test_round_bit_pattern_no_overflow_protection(helpers, pytestconfig):
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
%c2_i8 = arith.constant 2 : i8
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<11>
return %1 : !FHE.eint<11>

View File

@@ -8,6 +8,8 @@ import pytest
from concrete import fhe
from concrete.fhe.mlir import GraphConverter
from ..conftest import USE_MULTI_PRECISION
def assign(x, y):
"""
@@ -607,6 +609,22 @@ return %2
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
(note that it's assigned 18-bits during compilation because of its relation with other operations)
%1 = y # EncryptedScalar<uint5> ∈ [20, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
(note that it's assigned 18-bits during compilation because of its relation with other operations)
%2 = multiply(%0, %1) # EncryptedScalar<uint21> ∈ [2000000, 2000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted multiplications are supported
return %2
""" # noqa: E501
if USE_MULTI_PRECISION
else """
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 21-bit value is used as an operand to an encrypted multiplication
(note that it's assigned 21-bits during compilation because of its relation with other operations)
@@ -633,6 +651,22 @@ return %2
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%1 = y # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%2 = dot(%0, %1) # EncryptedScalar<uint36> ∈ [40000000000, 40000000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted dot products are supported
return %2
""" # noqa: E501
if USE_MULTI_PRECISION
else """
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products
(note that it's assigned 36-bits during compilation because of its relation with other operations)
@@ -665,6 +699,22 @@ return %2
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%1 = y # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
(note that it's assigned 19-bits during compilation because of its relation with other operations)
%2 = matmul(%0, %1) # EncryptedTensor<uint36, shape=(2, 2)> ∈ [40000000000, 50000000000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported
return %2
""" # noqa: E501
if USE_MULTI_PRECISION
else """
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication
(note that it's assigned 36-bits during compilation because of its relation with other operations)
@@ -758,6 +808,184 @@ def test_converter_bad_convert(
helpers.check_str(expected_message, str(excinfo.value))
@pytest.mark.parametrize(
"function,parameters,expected_mlir",
[
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>>
return %0 : tensor<3x2x!FHE.eint<6>>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
},
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<4>>, tensor<2x!FHE.eint<4>>) -> !FHE.eint<7>
return %0 : !FHE.eint<7>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>>
return %0 : tensor<3x4x!FHE.eint<7>>
}
}
""", # noqa: E501
),
],
)
def test_converter_convert_multi_precision(function, parameters, expected_mlir, helpers):
"""
Test `convert` method of `Converter` with multi precision.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration().fork(single_precision=False)
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
@pytest.mark.parametrize(
"function,parameters,expected_mlir",
[
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<6>>, tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>>
return %0 : tensor<3x2x!FHE.eint<6>>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
},
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<7>>, tensor<2x!FHE.eint<7>>) -> !FHE.eint<7>
return %0 : !FHE.eint<7>
}
}
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
},
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<7>>, tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
return %0 : tensor<3x4x!FHE.eint<7>>
}
}
""", # noqa: E501
),
],
)
def test_converter_convert_single_precision(function, parameters, expected_mlir, helpers):
"""
Test `convert` method of `Converter` with single precision.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration().fork(single_precision=True)
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
@pytest.mark.parametrize(
"function,parameters,expected_graph",
[
@@ -769,7 +997,7 @@ def test_converter_bad_convert(
"""
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
%1 = 2 # ClearScalar<uint5> ∈ [2, 2]
%1 = 2 # ClearScalar<uint3> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]