mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(frontend-python): re-write bit width assignment
This commit is contained in:
@@ -16,6 +16,7 @@ ignore = [
|
||||
"**/__init__.py" = ["F401"]
|
||||
"concrete/fhe/compilation/configuration.py" = ["ARG002"]
|
||||
"concrete/fhe/mlir/processors/all.py" = ["F401"]
|
||||
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"]
|
||||
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
|
||||
"examples/**" = ["PLR2004"]
|
||||
"tests/**" = ["PLR2004", "PLW0603", "SIM300", "S311"]
|
||||
|
||||
@@ -802,7 +802,10 @@ class Context:
|
||||
|
||||
self.error(highlights)
|
||||
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
if x.is_encrypted and y.is_encrypted:
|
||||
assert self.is_bit_width_compatible(x, y)
|
||||
else:
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
|
||||
if x.is_scalar or y.is_scalar:
|
||||
return self.mul(resulting_type, x, y)
|
||||
@@ -1369,7 +1372,10 @@ class Context:
|
||||
|
||||
self.error(highlights)
|
||||
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
if x.is_encrypted and y.is_encrypted:
|
||||
assert self.is_bit_width_compatible(x, y)
|
||||
else:
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
|
||||
if resulting_type.shape == ():
|
||||
if x.is_clear:
|
||||
@@ -1489,7 +1495,10 @@ class Context:
|
||||
|
||||
self.error(highlights)
|
||||
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
if x.is_encrypted and y.is_encrypted:
|
||||
assert self.is_bit_width_compatible(x, y)
|
||||
else:
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
|
||||
use_linalg = x.is_tensor or y.is_tensor
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""
|
||||
Declaration of `AssignBitWidths` graph processor.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
from typing import Dict, List
|
||||
|
||||
import z3
|
||||
|
||||
from ...dtypes import Integer
|
||||
from ...representation import Graph, Node, Operation
|
||||
@@ -13,266 +13,344 @@ from . import GraphProcessor
|
||||
|
||||
class AssignBitWidths(GraphProcessor):
|
||||
"""
|
||||
Assign a precision to all nodes inputs/output.
|
||||
AssignBitWidths graph processor, to assign proper bit-widths to be compatible with FHE.
|
||||
|
||||
The precisions are compatible graph constraints and MLIR.
|
||||
There are two modes:
|
||||
- single precision: where all encrypted values have the same precision.
|
||||
- multi precision: where encrypted values can have different precisions.
|
||||
- Single Precision, where all encrypted values have the same precision.
|
||||
- Multi Precision, where encrypted values can have different precisions.
|
||||
"""
|
||||
|
||||
def __init__(self, single_precision=False):
|
||||
self.single_precision = single_precision
|
||||
|
||||
def apply(self, graph: Graph):
|
||||
nodes = graph.query_nodes()
|
||||
for node in nodes:
|
||||
optimizer = z3.Optimize()
|
||||
|
||||
max_bit_width: z3.Int = z3.Int("max")
|
||||
bit_widths: Dict[Node, z3.Int] = {}
|
||||
|
||||
additional_constraints = AdditionalConstraints(optimizer, graph, bit_widths)
|
||||
|
||||
nodes = graph.query_nodes(ordered=True)
|
||||
for i, node in enumerate(nodes):
|
||||
bit_width = z3.Int(f"%{i}")
|
||||
bit_widths[node] = bit_width
|
||||
|
||||
optimizer.add(max_bit_width >= bit_width)
|
||||
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
optimizer.add(bit_width >= node.output.dtype.bit_width)
|
||||
|
||||
additional_constraints.generate_for(node, bit_width)
|
||||
|
||||
if self.single_precision:
|
||||
assign_single_precision(nodes)
|
||||
else:
|
||||
assign_multi_precision(graph, nodes)
|
||||
for bit_width in bit_widths.values():
|
||||
optimizer.add(bit_width == max_bit_width)
|
||||
|
||||
optimizer.minimize(sum(bit_width**2 for bit_width in bit_widths.values()))
|
||||
|
||||
assert optimizer.check() == z3.sat
|
||||
model = optimizer.model()
|
||||
|
||||
for node, bit_width in bit_widths.items():
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
new_bit_width = model[bit_width].as_long()
|
||||
|
||||
if node.output.is_clear:
|
||||
new_bit_width += 1
|
||||
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
node.output.dtype.bit_width = new_bit_width
|
||||
|
||||
|
||||
def assign_single_precision(nodes: list[Node]):
|
||||
"""Assign one single encryption precision to all nodes."""
|
||||
p = required_encrypted_bitwidth(nodes)
|
||||
for node in nodes:
|
||||
assign_precisions_1_node(node, p, p)
|
||||
|
||||
|
||||
def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
|
||||
"""Assign input/output precision to a single node.
|
||||
|
||||
Precision are adjusted to match different use, e.g. encrypted and constant case.
|
||||
class AdditionalConstraints:
|
||||
"""
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
if node.output.is_encrypted:
|
||||
node.output.dtype.bit_width = output_p
|
||||
else:
|
||||
node.output.dtype.bit_width = output_p + 1
|
||||
|
||||
for value in node.inputs:
|
||||
assert isinstance(value.dtype, Integer)
|
||||
if value.is_encrypted:
|
||||
value.dtype.bit_width = inputs_p
|
||||
else:
|
||||
value.dtype.bit_width = inputs_p + 1
|
||||
|
||||
|
||||
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
|
||||
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
|
||||
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
|
||||
ROUNDING = {"round_bit_pattern"}
|
||||
MULTIPLY = {"multiply", "matmul", "dot"}
|
||||
|
||||
|
||||
def max_encrypted_bitwidth_node(node: Node):
|
||||
"""Give the minimal precision to implement the node.
|
||||
|
||||
This applies to both input and output precisions.
|
||||
"""
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
if node.output.is_encrypted or node.operation == Operation.Constant:
|
||||
normal_p = node.output.dtype.bit_width
|
||||
else:
|
||||
normal_p = -1
|
||||
name = node.properties.get("name")
|
||||
|
||||
if name in CHUNKED_COMPARISON:
|
||||
return max(normal_p, CHUNKED_COMPARISON_MIN_BITWIDTH)
|
||||
|
||||
if name in MAX_POOLS:
|
||||
return normal_p + 1
|
||||
|
||||
if name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
|
||||
# For operations that use multiply, an additional bit
|
||||
# needs to be added to the bitwidths of the inputs.
|
||||
# For single precision circuits the max of the input / output
|
||||
# precisions will be taken in required_encrypted_bitwidth. For
|
||||
# multi-precision, the circuit partitions will handle the
|
||||
# input and output precisions separately.
|
||||
all_inp_bitwidths = []
|
||||
# Need a loop here to allow typechecking and make mypy happy
|
||||
for inp in node.inputs:
|
||||
dtype_inp = inp.dtype
|
||||
assert isinstance(dtype_inp, Integer)
|
||||
all_inp_bitwidths.append(dtype_inp.bit_width)
|
||||
|
||||
normal_p = max(all_inp_bitwidths)
|
||||
|
||||
# FIXME: This probably does not work well with multi-precision!
|
||||
return max(normal_p + 1, node.output.dtype.bit_width)
|
||||
|
||||
return normal_p
|
||||
|
||||
|
||||
def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
|
||||
"""Give the minimal precision to implement all the nodes.
|
||||
|
||||
This function is called for both single-precision (for the whole circuit)
|
||||
and for multi-precision circuits (for circuit partitions).
|
||||
|
||||
Ops for which the compiler introduces TLUs need to be handled explicitly
|
||||
in `max_encrypted_bitwidth_node`. The maximum
|
||||
of all precisions of the various operations is returned.
|
||||
AdditionalConstraints class to customize bit-width assignment step easily.
|
||||
"""
|
||||
|
||||
bitwidths = map(max_encrypted_bitwidth_node, nodes)
|
||||
return max(bitwidths, default=-1)
|
||||
optimizer: z3.Optimize
|
||||
graph: Graph
|
||||
bit_widths: Dict[Node, z3.Int]
|
||||
|
||||
node: Node
|
||||
bit_width: z3.Int
|
||||
|
||||
def required_inputs_encrypted_bitwidth(graph, node, nodes_output_p: list[tuple[Node, int]]) -> int:
|
||||
"""Give the minimal precision to supports the inputs."""
|
||||
preds = graph.ordered_preds_of(node)
|
||||
get_prec = lambda node: nodes_output_p[node.properties[NODE_ID]][1]
|
||||
# by definition all inputs have the same block precision
|
||||
# see uniform_precision_per_blocks
|
||||
return get_prec(node) if len(preds) == 0 else get_prec(preds[0])
|
||||
# pylint: disable=missing-function-docstring,unused-argument
|
||||
|
||||
def __init__(self, optimizer: z3.Optimize, graph: Graph, bit_widths: Dict[Node, z3.Int]):
|
||||
self.optimizer = optimizer
|
||||
self.graph = graph
|
||||
self.bit_widths = bit_widths
|
||||
|
||||
def assign_multi_precision(graph, nodes):
|
||||
"""Assign a specific encryption precision to each nodes."""
|
||||
add_nodes_id(nodes)
|
||||
nodes_output_p = uniform_precision_per_blocks(graph, nodes)
|
||||
for node, _ in nodes_output_p:
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
nodes_inputs_p = [
|
||||
required_inputs_encrypted_bitwidth(graph, node, nodes_output_p)
|
||||
if can_change_precision(node)
|
||||
else output_p
|
||||
for node, output_p in nodes_output_p
|
||||
]
|
||||
for (node, output_p), inputs_p in zip(nodes_output_p, nodes_inputs_p):
|
||||
assign_precisions_1_node(node, output_p, inputs_p)
|
||||
clear_nodes_id(nodes)
|
||||
def generate_for(self, node: Node, bit_width: z3.Int):
|
||||
"""
|
||||
Generate additional constraints for a node.
|
||||
|
||||
Args:
|
||||
node (Node):
|
||||
node to generate constraints for
|
||||
|
||||
TLU_WITHOUT_PRECISION_CHANGE = CHUNKED_COMPARISON | MAX_POOLS | MULTIPLY
|
||||
bit_width (z3.Int):
|
||||
symbolic bit-width which will be assigned to node once constraints are solved
|
||||
"""
|
||||
|
||||
assert node.operation in {Operation.Generic, Operation.Constant, Operation.Input}
|
||||
operation_name = (
|
||||
node.properties["name"]
|
||||
if node.operation == Operation.Generic
|
||||
else ("constant" if node.operation == Operation.Constant else "input")
|
||||
)
|
||||
|
||||
def can_change_precision(node):
|
||||
"""Detect if a node completely ties inputs/output precisions together."""
|
||||
if (
|
||||
node.properties.get("name") in ROUNDING
|
||||
and node.properties["attributes"]["overflow_protection"]
|
||||
):
|
||||
return False # protection can change precision
|
||||
if hasattr(self, operation_name):
|
||||
constraints = getattr(self, operation_name)
|
||||
preds = self.graph.ordered_preds_of(node)
|
||||
|
||||
return (
|
||||
node.converted_to_table_lookup
|
||||
and node.properties.get("name") not in TLU_WITHOUT_PRECISION_CHANGE
|
||||
)
|
||||
if isinstance(constraints, set):
|
||||
for add_constraint in constraints:
|
||||
add_constraint(self, node, preds)
|
||||
|
||||
elif isinstance(constraints, dict):
|
||||
for condition, conditional_constraints in constraints.items():
|
||||
if condition(self, node, preds):
|
||||
for add_constraint in conditional_constraints:
|
||||
add_constraint(self, node, preds)
|
||||
|
||||
def convert_union_to_blocks(node_union: UnionFind) -> Iterable[list[int]]:
|
||||
"""Convert a `UnionFind` to blocks.
|
||||
else: # pragma: no cover
|
||||
message = (
|
||||
f"Expected a set or a dict "
|
||||
f"for additional constraints of '{operation_name}' operation"
|
||||
f"but got {type(constraints).__name__} instead"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
The result is an iterable of blocks.A block being a list of node id.
|
||||
"""
|
||||
blocks = {}
|
||||
for node_id in range(node_union.size):
|
||||
node_canon = node_union.find_canonical(node_id)
|
||||
if node_canon == node_id:
|
||||
assert node_canon not in blocks
|
||||
blocks[node_canon] = [node_id]
|
||||
else:
|
||||
blocks[node_canon].append(node_id)
|
||||
return blocks.values()
|
||||
# ==========
|
||||
# Conditions
|
||||
# ==========
|
||||
|
||||
def all_inputs_are_encrypted(self, node: Node, preds: List[Node]) -> bool:
|
||||
return all(pred.output.is_encrypted for pred in preds)
|
||||
|
||||
NODE_ID = "node_id"
|
||||
def some_inputs_are_clear(self, node: Node, preds: List[Node]) -> bool:
|
||||
return any(pred.output.is_clear for pred in preds)
|
||||
|
||||
def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool:
|
||||
return node.properties["attributes"]["overflow_protection"] is True
|
||||
|
||||
def add_nodes_id(nodes):
|
||||
"""Temporarily add a NODE_ID property to all nodes."""
|
||||
for node_id, node in enumerate(nodes):
|
||||
assert NODE_ID not in node.properties
|
||||
node.properties[NODE_ID] = node_id
|
||||
# ===========
|
||||
# Constraints
|
||||
# ===========
|
||||
|
||||
def inputs_share_precision(self, node: Node, preds: List[Node]):
|
||||
for i in range(len(preds) - 1):
|
||||
self.optimizer.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]])
|
||||
|
||||
def clear_nodes_id(nodes):
|
||||
"""Remove the NODE_ID property from all nodes."""
|
||||
for node in nodes:
|
||||
del node.properties[NODE_ID]
|
||||
def inputs_and_output_share_precision(self, node: Node, preds: List[Node]):
|
||||
self.inputs_share_precision(node, preds)
|
||||
if len(preds) != 0:
|
||||
self.optimizer.add(self.bit_widths[preds[-1]] == self.bit_widths[node])
|
||||
|
||||
def inputs_require_one_more_bit(self, node: Node, preds: List[Node]):
|
||||
for pred in preds:
|
||||
assert isinstance(pred.output.dtype, Integer)
|
||||
|
||||
def uniform_precision_per_blocks(graph: Graph, nodes: list[Node]) -> list[tuple[Node, int]]:
|
||||
"""Find the required precision of blocks and associate it corresponding nodes."""
|
||||
size = len(nodes)
|
||||
node_union = UnionFind(size)
|
||||
for node_id, node in enumerate(nodes):
|
||||
preds = graph.ordered_preds_of(node)
|
||||
if not preds:
|
||||
continue
|
||||
# we always unify all inputs
|
||||
first_input_id = preds[0].properties[NODE_ID]
|
||||
for pred in preds[1:]:
|
||||
pred_id = pred.properties[NODE_ID]
|
||||
node_union.union(first_input_id, pred_id)
|
||||
# we unify with outputs only if no precision change can occur
|
||||
if not can_change_precision(node):
|
||||
node_union.union(first_input_id, node_id)
|
||||
actual_bit_width = pred.output.dtype.bit_width
|
||||
required_bit_width = actual_bit_width + 1
|
||||
|
||||
blocks = convert_union_to_blocks(node_union)
|
||||
result: list[None | tuple[Node, int]]
|
||||
result = [None] * len(nodes)
|
||||
for nodes_id in blocks:
|
||||
output_p = required_encrypted_bitwidth(nodes[node_id] for node_id in nodes_id)
|
||||
for node_id in nodes_id:
|
||||
result[node_id] = (nodes[node_id], output_p)
|
||||
assert None not in result
|
||||
return typing.cast("list[tuple[Node, int]]", result)
|
||||
self.optimizer.add(self.bit_widths[pred] >= required_bit_width)
|
||||
|
||||
def inputs_require_at_least_four_bits(self, node: Node, preds: List[Node]):
|
||||
for pred in preds:
|
||||
self.optimizer.add(self.bit_widths[pred] >= 4)
|
||||
|
||||
class UnionFind:
|
||||
"""
|
||||
Utility class joins the nodes in equivalent precision classes.
|
||||
# ==========
|
||||
# Operations
|
||||
# ==========
|
||||
|
||||
Nodes are just integers id.
|
||||
"""
|
||||
add = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
parent: list[int]
|
||||
array = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
def __init__(self, size: int):
|
||||
"""Create a union find suitable for `size` nodes."""
|
||||
self.parent = list(range(size))
|
||||
assign_static = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
"""Size in number of nodes."""
|
||||
return len(self.parent)
|
||||
bitwise_and = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
def find_canonical(self, a: int) -> int:
|
||||
"""Find the current canonical node for a given input node."""
|
||||
parent = self.parent[a]
|
||||
if a == parent:
|
||||
return a
|
||||
canonical = self.find_canonical(parent)
|
||||
self.parent[a] = canonical
|
||||
return canonical
|
||||
bitwise_or = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
def union(self, a: int, b: int):
|
||||
"""Union both nodes."""
|
||||
self.united_common_ancestor(a, b)
|
||||
bitwise_xor = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
def united_common_ancestor(self, a: int, b: int) -> int:
|
||||
"""Deduce the common ancestor of both nodes after unification."""
|
||||
parent_a = self.parent[a]
|
||||
parent_b = self.parent[b]
|
||||
broadcast_to = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
if parent_a == parent_b:
|
||||
return parent_a
|
||||
concatenate = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
if a == parent_a and parent_b < parent_a:
|
||||
common_ancestor = parent_b
|
||||
elif b == parent_b and parent_a < parent_b:
|
||||
common_ancestor = parent_a
|
||||
else:
|
||||
common_ancestor = self.united_common_ancestor(parent_a, parent_b)
|
||||
conv1d = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
self.parent[a] = common_ancestor
|
||||
self.parent[b] = common_ancestor
|
||||
return common_ancestor
|
||||
conv2d = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
conv3d = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
copy = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
dot = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
},
|
||||
some_inputs_are_clear: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
expand_dims = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
greater = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_at_least_four_bits,
|
||||
},
|
||||
}
|
||||
|
||||
greater_equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_at_least_four_bits,
|
||||
},
|
||||
}
|
||||
|
||||
index_static = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
left_shift = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
less = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_at_least_four_bits,
|
||||
},
|
||||
}
|
||||
|
||||
less_equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_at_least_four_bits,
|
||||
},
|
||||
}
|
||||
|
||||
matmul = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
},
|
||||
some_inputs_are_clear: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
maxpool1d = {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
}
|
||||
|
||||
maxpool2d = {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
}
|
||||
|
||||
maxpool3d = {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
}
|
||||
|
||||
multiply = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
},
|
||||
some_inputs_are_clear: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
negative = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
not_equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
reshape = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
right_shift = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
round_bit_pattern = {
|
||||
has_overflow_protection: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
subtract = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
sum = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
squeeze = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
transpose = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
@@ -2,3 +2,4 @@ networkx>=2.6
|
||||
numpy>=1.23
|
||||
scipy>=1.10
|
||||
torch>=1.13
|
||||
z3-solver>=4.12
|
||||
|
||||
@@ -263,7 +263,7 @@ def test_round_bit_pattern_no_overflow_protection(helpers, pytestconfig):
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
|
||||
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
|
||||
%c2_i8 = arith.constant 2 : i8
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<11>
|
||||
return %1 : !FHE.eint<11>
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
from concrete import fhe
|
||||
from concrete.fhe.mlir import GraphConverter
|
||||
|
||||
from ..conftest import USE_MULTI_PRECISION
|
||||
|
||||
|
||||
def assign(x, y):
|
||||
"""
|
||||
@@ -607,6 +609,22 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
|
||||
(note that it's assigned 18-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedScalar<uint5> ∈ [20, 20]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 18-bit value is used as an operand to an encrypted multiplication
|
||||
(note that it's assigned 18-bits during compilation because of its relation with other operations)
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint21> ∈ [2000000, 2000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted multiplications are supported
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedScalar<uint17> ∈ [100000, 100000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 21-bit value is used as an operand to an encrypted multiplication
|
||||
(note that it's assigned 21-bits during compilation because of its relation with other operations)
|
||||
@@ -633,6 +651,22 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint36> ∈ [40000000000, 40000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted dot products are supported
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2,)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted dot products
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
@@ -665,6 +699,22 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%1 = y # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 19-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 19-bits during compilation because of its relation with other operations)
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint36, shape=(2, 2)> ∈ [40000000000, 50000000000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
Function you are trying to compile cannot be compiled
|
||||
|
||||
%0 = x # EncryptedTensor<uint18, shape=(2, 2)> ∈ [100000, 200000]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this 36-bit value is used as an operand to an encrypted matrix multiplication
|
||||
(note that it's assigned 36-bits during compilation because of its relation with other operations)
|
||||
@@ -758,6 +808,184 @@ def test_converter_bad_convert(
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_mlir",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x2x!FHE.eint<6>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<4>>, tensor<2x!FHE.eint<4>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x!FHE.eint<7>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_converter_convert_multi_precision(function, parameters, expected_mlir, helpers):
|
||||
"""
|
||||
Test `convert` method of `Converter` with multi precision.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration().fork(single_precision=False)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_mlir",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<6>>, tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x2x!FHE.eint<6>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<7>>, tensor<2x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (2, 4)},
|
||||
},
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<7>>, tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x!FHE.eint<7>>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_converter_convert_single_precision(function, parameters, expected_mlir, helpers):
|
||||
"""
|
||||
Test `convert` method of `Converter` with multi precision.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration().fork(single_precision=True)
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_graph",
|
||||
[
|
||||
@@ -769,7 +997,7 @@ def test_converter_bad_convert(
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
|
||||
%1 = 2 # ClearScalar<uint5> ∈ [2, 2]
|
||||
%1 = 2 # ClearScalar<uint3> ∈ [2, 2]
|
||||
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
|
||||
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
|
||||
|
||||
Reference in New Issue
Block a user