mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: introduce tag extension, create integer range getter for graphs, allow filtering in integer bit width getter
This commit is contained in:
@@ -22,6 +22,7 @@ from .extensions import (
|
||||
one,
|
||||
ones,
|
||||
round_bit_pattern,
|
||||
tag,
|
||||
univariate,
|
||||
zero,
|
||||
zeros,
|
||||
|
||||
@@ -7,11 +7,9 @@ import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..representation import Graph, Node
|
||||
from ..representation import Graph
|
||||
|
||||
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
|
||||
|
||||
@@ -25,14 +23,9 @@ class DebugArtifacts:
|
||||
|
||||
source_code: Optional[str]
|
||||
parameter_encryption_statuses: Dict[str, str]
|
||||
|
||||
textual_representations_of_graphs: Dict[str, List[str]]
|
||||
|
||||
final_graph: Optional[Graph]
|
||||
bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]]
|
||||
|
||||
mlir_to_compile: Optional[str]
|
||||
|
||||
client_parameters: Optional[bytes]
|
||||
|
||||
def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
|
||||
@@ -40,14 +33,9 @@ class DebugArtifacts:
|
||||
|
||||
self.source_code = None
|
||||
self.parameter_encryption_statuses = {}
|
||||
|
||||
self.textual_representations_of_graphs = {}
|
||||
|
||||
self.final_graph = None
|
||||
self.bounds_of_the_final_graph = None
|
||||
|
||||
self.mlir_to_compile = None
|
||||
|
||||
self.client_parameters = None
|
||||
|
||||
def add_source_code(self, function: Union[str, Callable]):
|
||||
@@ -100,18 +88,6 @@ class DebugArtifacts:
|
||||
|
||||
self.final_graph = graph
|
||||
|
||||
def add_final_graph_bounds(self, bounds: Dict[Node, Dict[str, Any]]):
|
||||
"""
|
||||
Add bounds of the latest computation graph.
|
||||
|
||||
Args:
|
||||
bounds (Dict[Node, Dict[str, Any]]):
|
||||
bounds of the latest computation graph
|
||||
"""
|
||||
|
||||
assert self.final_graph is not None
|
||||
self.bounds_of_the_final_graph = bounds
|
||||
|
||||
def add_mlir_to_compile(self, mlir: str):
|
||||
"""
|
||||
Add textual representation of the resulting MLIR.
|
||||
@@ -201,13 +177,6 @@ class DebugArtifacts:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"{representation}\n")
|
||||
|
||||
if self.bounds_of_the_final_graph is not None:
|
||||
assert self.final_graph is not None
|
||||
with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
|
||||
for index, node in enumerate(nx.topological_sort(self.final_graph.graph)):
|
||||
bounds = self.bounds_of_the_final_graph.get(node)
|
||||
f.write(f"%{index} :: [{bounds['min']}, {bounds['max']}]\n")
|
||||
|
||||
if self.mlir_to_compile is not None:
|
||||
assert self.final_graph is not None
|
||||
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -278,10 +278,8 @@ class Compiler:
|
||||
assert self.graph is not None
|
||||
|
||||
bounds = self.graph.measure_bounds(self.inputset)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_final_graph_bounds(bounds)
|
||||
|
||||
self.graph.update_with_bounds(bounds)
|
||||
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("final", self.graph)
|
||||
|
||||
|
||||
@@ -557,14 +557,19 @@ def convert_subgraph_to_subgraph_node(
|
||||
|
||||
variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
|
||||
if len(variable_input_nodes) != 1:
|
||||
base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes}
|
||||
base_highlighted_nodes = {
|
||||
node: ["within this subgraph", node.location] for node in all_nodes
|
||||
}
|
||||
for variable_input_node in variable_input_nodes:
|
||||
base_highlighted_nodes[variable_input_node] = ["this is one of the input nodes"]
|
||||
base_highlighted_nodes[variable_input_node] = [
|
||||
"this is one of the input nodes",
|
||||
variable_input_node.location,
|
||||
]
|
||||
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because it has multiple input nodes\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes)
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
|
||||
)
|
||||
|
||||
variable_input_node = variable_input_nodes[0]
|
||||
@@ -577,6 +582,10 @@ def convert_subgraph_to_subgraph_node(
|
||||
subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
|
||||
nx_subgraph.add_node(subgraph_variable_input_node)
|
||||
|
||||
subgraph_variable_input_node.location = variable_input_node.location
|
||||
subgraph_variable_input_node.tag = variable_input_node.tag
|
||||
subgraph_variable_input_node.created_at = variable_input_node.created_at
|
||||
|
||||
variable_input_node_successors = {
|
||||
node: None for node in all_nodes if node in nx_graph.succ[variable_input_node]
|
||||
}
|
||||
@@ -592,6 +601,10 @@ def convert_subgraph_to_subgraph_node(
|
||||
**new_edge_data,
|
||||
)
|
||||
|
||||
original_location = terminal_node.location
|
||||
original_tag = terminal_node.tag
|
||||
original_created_at = terminal_node.created_at
|
||||
|
||||
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
|
||||
subgraph_node = Node.generic(
|
||||
"subgraph",
|
||||
@@ -604,6 +617,10 @@ def convert_subgraph_to_subgraph_node(
|
||||
},
|
||||
)
|
||||
|
||||
subgraph_node.location = original_location
|
||||
subgraph_node.tag = original_tag
|
||||
subgraph_node.created_at = original_created_at
|
||||
|
||||
return subgraph_node, variable_input_node
|
||||
|
||||
|
||||
@@ -635,8 +652,11 @@ def check_subgraph_fusability(
|
||||
if subgraph is not fusable
|
||||
"""
|
||||
|
||||
base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes}
|
||||
base_highlighted_nodes[variable_input_node] = ["with this input node"]
|
||||
base_highlighted_nodes = {node: ["within this subgraph", node.location] for node in all_nodes}
|
||||
base_highlighted_nodes[variable_input_node] = [
|
||||
"with this input node",
|
||||
variable_input_node.location,
|
||||
]
|
||||
|
||||
non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant)
|
||||
for node in non_constant_nodes:
|
||||
@@ -644,19 +664,22 @@ def check_subgraph_fusability(
|
||||
continue
|
||||
|
||||
if not node.is_fusable:
|
||||
base_highlighted_nodes[node] = ["this node is not fusable"]
|
||||
base_highlighted_nodes[node] = ["this node is not fusable", node.location]
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because of a node, which is marked explicitly as non-fusable\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes)
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
|
||||
)
|
||||
|
||||
if node.output.shape != variable_input_node.output.shape:
|
||||
base_highlighted_nodes[node] = ["this node has a different shape than the input node"]
|
||||
base_highlighted_nodes[node] = [
|
||||
"this node has a different shape than the input node",
|
||||
node.location,
|
||||
]
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because of a node, which is has a different shape than the input node\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes)
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -6,5 +6,6 @@ from .array import array
|
||||
from .ones import one, ones
|
||||
from .round_bit_pattern import AutoRounder, round_bit_pattern
|
||||
from .table import LookupTable
|
||||
from .tag import tag
|
||||
from .univariate import univariate
|
||||
from .zeros import zero, zeros
|
||||
|
||||
24
concrete/numpy/extensions/tag.py
Normal file
24
concrete/numpy/extensions/tag.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Declaration of `tag` context manager, to allow tagging certain nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
tag_context = threading.local()
|
||||
tag_context.stack = []
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tag(name: str):
|
||||
"""
|
||||
Introduce a new tag to the tag stack.
|
||||
|
||||
Can be nested, and the resulting tag will be `tag1.tag2`.
|
||||
"""
|
||||
|
||||
tag_context.stack.append(name)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tag_context.stack.pop()
|
||||
@@ -204,7 +204,7 @@ class GraphConverter:
|
||||
if len(graph.output_nodes) > 1:
|
||||
offending_nodes.update(
|
||||
{
|
||||
node: ["only a single output is supported"]
|
||||
node: ["only a single output is supported", node.location]
|
||||
for node in graph.output_nodes.values()
|
||||
}
|
||||
)
|
||||
@@ -213,7 +213,7 @@ class GraphConverter:
|
||||
for node in graph.graph.nodes:
|
||||
reason = GraphConverter._check_node_convertibility(graph, node, virtual)
|
||||
if reason is not None:
|
||||
offending_nodes[node] = [reason]
|
||||
offending_nodes[node] = [reason, node.location]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
@@ -257,14 +257,16 @@ class GraphConverter:
|
||||
if max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
offending_nodes[first_tlu_node] = [
|
||||
f"table lookups are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers",
|
||||
first_tlu_node.location,
|
||||
]
|
||||
|
||||
if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS:
|
||||
offending_nodes[first_signed_node] = [
|
||||
f"signed integers are only supported "
|
||||
f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits "
|
||||
f"on circuits with table lookups"
|
||||
f"on circuits with table lookups",
|
||||
first_signed_node.location,
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Declaration of `Graph` class.
|
||||
"""
|
||||
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -25,11 +26,14 @@ class Graph:
|
||||
|
||||
input_indices: Dict[Node, int]
|
||||
|
||||
is_direct: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Dict[int, Node],
|
||||
output_nodes: Dict[int, Node],
|
||||
is_direct: bool = False,
|
||||
):
|
||||
self.graph = graph
|
||||
|
||||
@@ -38,6 +42,8 @@ class Graph:
|
||||
|
||||
self.input_indices = {node: index for index, node in input_nodes.items()}
|
||||
|
||||
self.is_direct = is_direct
|
||||
|
||||
self.prune_useless_nodes()
|
||||
|
||||
def __call__(
|
||||
@@ -82,7 +88,10 @@ class Graph:
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
"Evaluation of the graph failed\n\n"
|
||||
+ self.format(highlighted_nodes={node: ["evaluation of this node failed"]})
|
||||
+ self.format(
|
||||
highlighted_nodes={node: ["evaluation of this node failed"]},
|
||||
show_bounds=False,
|
||||
)
|
||||
) from error
|
||||
|
||||
return node_results
|
||||
@@ -91,6 +100,10 @@ class Graph:
|
||||
self,
|
||||
maximum_constant_length: int = 25,
|
||||
highlighted_nodes: Optional[Dict[Node, List[str]]] = None,
|
||||
show_types: bool = True,
|
||||
show_bounds: bool = True,
|
||||
show_tags: bool = True,
|
||||
show_locations: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Get the textual representation of the `Graph`.
|
||||
@@ -102,11 +115,28 @@ class Graph:
|
||||
highlighted_nodes (Optional[Dict[Node, List[str]]], default = None):
|
||||
nodes to be highlighted and their corresponding messages
|
||||
|
||||
show_types (bool, default = True):
|
||||
whether to show types of nodes
|
||||
|
||||
show_bounds (bool, default = True):
|
||||
whether to show bounds of nodes
|
||||
|
||||
show_tags (bool, default = True):
|
||||
whether to show tags of nodes
|
||||
|
||||
show_locations (bool, default = False):
|
||||
whether to show line information of nodes
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual representation of the `Graph`
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||
|
||||
if self.is_direct:
|
||||
show_bounds = False
|
||||
|
||||
# node -> identifier
|
||||
# e.g., id_map[node1] = 2
|
||||
# means line for node1 is in this form %2 = node1.format(...)
|
||||
@@ -115,9 +145,9 @@ class Graph:
|
||||
# lines that will be merged at the end
|
||||
lines: List[str] = []
|
||||
|
||||
# type information to add to each line
|
||||
# (for alingment, this is done after lines are determined)
|
||||
type_informations: List[str] = []
|
||||
# metadata to add to each line
|
||||
# (for alignment, this is done after lines are determined)
|
||||
line_metadata: List[Dict[str, str]] = []
|
||||
|
||||
# default highlighted nodes is empty
|
||||
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
|
||||
@@ -130,7 +160,7 @@ class Graph:
|
||||
subgraphs: Dict[str, Graph] = {}
|
||||
|
||||
# format nodes
|
||||
for node in nx.topological_sort(self.graph):
|
||||
for node in nx.lexicographical_topological_sort(self.graph):
|
||||
# assign a unique id to outputs of node
|
||||
id_map[node] = len(id_map)
|
||||
|
||||
@@ -160,8 +190,17 @@ class Graph:
|
||||
if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]:
|
||||
subgraphs[line] = node.properties["kwargs"]["subgraph"]
|
||||
|
||||
# remember type information of the node
|
||||
type_informations.append(str(node.output))
|
||||
# remember metadata of the node
|
||||
line_metadata.append(
|
||||
{
|
||||
"type": f"# {node.output}",
|
||||
"bounds": (
|
||||
f"∈ [{node.bounds[0]}, {node.bounds[1]}]" if node.bounds is not None else ""
|
||||
),
|
||||
"tag": (f"@ {node.tag}" if node.tag != "" else ""),
|
||||
"location": node.location,
|
||||
},
|
||||
)
|
||||
|
||||
# align = signs
|
||||
#
|
||||
@@ -182,11 +221,28 @@ class Graph:
|
||||
" " * (longest_length_before_equals_sign - length_before_equals_sign)
|
||||
) + line
|
||||
|
||||
# add type information
|
||||
longest_line_length = max(len(line) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
lines[i] += " " * (longest_line_length - len(line))
|
||||
lines[i] += f" # {type_informations[i]}"
|
||||
# determine which metadata to show
|
||||
shown_metadata_keys = []
|
||||
if show_types:
|
||||
shown_metadata_keys.append("type")
|
||||
if show_bounds:
|
||||
shown_metadata_keys.append("bounds")
|
||||
if show_tags:
|
||||
shown_metadata_keys.append("tag")
|
||||
if show_locations:
|
||||
shown_metadata_keys.append("location")
|
||||
|
||||
# show requested metadata
|
||||
indent = 8
|
||||
for metadata_key in shown_metadata_keys:
|
||||
longest_line_length = max(len(line) for line in lines)
|
||||
lines = [
|
||||
line + (" " * ((longest_line_length - len(line)) + indent)) + metadata[metadata_key]
|
||||
for line, metadata in zip(lines, line_metadata)
|
||||
]
|
||||
|
||||
# strip whitespaces
|
||||
lines = [line.rstrip() for line in lines]
|
||||
|
||||
# add highlights (this is done in reverse to keep indices consistent)
|
||||
for i in reversed(range(len(lines))):
|
||||
@@ -209,13 +265,23 @@ class Graph:
|
||||
result += "\n\n"
|
||||
result += "Subgraphs:"
|
||||
for line, subgraph in subgraphs.items():
|
||||
subgraph_lines = subgraph.format(maximum_constant_length).split("\n")
|
||||
subgraph_lines = subgraph.format(
|
||||
maximum_constant_length=maximum_constant_length,
|
||||
highlighted_nodes={},
|
||||
show_types=show_types,
|
||||
show_bounds=False, # doesn't make sense as we don't measure bounds in subgraphs
|
||||
show_tags=show_tags,
|
||||
show_locations=show_locations,
|
||||
).split("\n")
|
||||
|
||||
result += "\n\n"
|
||||
result += f" {line}:\n\n"
|
||||
result += "\n".join(f" {line}" for line in subgraph_lines)
|
||||
|
||||
return result
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-locals,too-many-statements
|
||||
|
||||
def measure_bounds(
|
||||
self,
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
@@ -300,6 +366,8 @@ class Graph:
|
||||
min_bound = bounds[node]["min"]
|
||||
max_bound = bounds[node]["max"]
|
||||
|
||||
node.bounds = (min_bound, max_bound)
|
||||
|
||||
new_value = deepcopy(node.output)
|
||||
|
||||
if isinstance(min_bound, np.integer):
|
||||
@@ -384,17 +452,135 @@ class Graph:
|
||||
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
|
||||
self.graph.remove_nodes_from(useless_nodes)
|
||||
|
||||
def maximum_integer_bit_width(self) -> int:
|
||||
def query_nodes(
|
||||
self,
|
||||
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
) -> List[Node]:
|
||||
"""
|
||||
Query nodes within the graph.
|
||||
|
||||
Filters work like so:
|
||||
str -> nodes without exact match is skipped
|
||||
List[str] -> nodes without exact match with one of the strings in the list is skipped
|
||||
re.Pattern -> nodes without pattern match is skipped
|
||||
|
||||
Args:
|
||||
tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for tags
|
||||
|
||||
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for operations
|
||||
|
||||
Returns:
|
||||
List[Node]:
|
||||
filtered nodes
|
||||
"""
|
||||
|
||||
def match_text_filter(text_filter, text):
|
||||
if text_filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(text_filter, str):
|
||||
return text == text_filter
|
||||
|
||||
if isinstance(text_filter, re.Pattern):
|
||||
return text_filter.match(text)
|
||||
|
||||
return any(text == alternative for alternative in text_filter)
|
||||
|
||||
def get_operation_name(node):
|
||||
result: str
|
||||
|
||||
if node.operation == Operation.Input:
|
||||
result = "input"
|
||||
elif node.operation == Operation.Constant:
|
||||
result = "constant"
|
||||
else:
|
||||
result = node.properties["name"]
|
||||
|
||||
return result
|
||||
|
||||
return [
|
||||
node
|
||||
for node in self.graph.nodes()
|
||||
if (
|
||||
match_text_filter(tag_filter, node.tag)
|
||||
and match_text_filter(operation_filter, get_operation_name(node))
|
||||
)
|
||||
]
|
||||
|
||||
def maximum_integer_bit_width(
|
||||
self,
|
||||
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get maximum integer bit-width within the graph.
|
||||
|
||||
Only nodes after filtering will be used to calculate the result.
|
||||
|
||||
Args:
|
||||
tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for tags
|
||||
|
||||
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for operations
|
||||
|
||||
Returns:
|
||||
int:
|
||||
maximum integer bit-width within the graph (-1 is there are no integer nodes)
|
||||
maximum integer bit-width within the graph
|
||||
if there are no integer nodes matching the query, result is -1
|
||||
"""
|
||||
|
||||
result = -1
|
||||
for node in self.graph.nodes():
|
||||
if isinstance(node.output.dtype, Integer):
|
||||
result = max(result, node.output.dtype.bit_width)
|
||||
filtered_bit_widths = (
|
||||
node.output.dtype.bit_width
|
||||
for node in self.query_nodes(tag_filter, operation_filter)
|
||||
if isinstance(node.output.dtype, Integer)
|
||||
)
|
||||
return max(filtered_bit_widths, default=-1)
|
||||
|
||||
def integer_range(
|
||||
self,
|
||||
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
Get integer range of the graph.
|
||||
|
||||
Only nodes after filtering will be used to calculate the result.
|
||||
|
||||
Args:
|
||||
tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for tags
|
||||
|
||||
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
|
||||
filter for operations
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]:
|
||||
minimum and maximum integer value observed during inputset evaluation
|
||||
if there are no integer nodes matching the query, result is None
|
||||
"""
|
||||
|
||||
result: Optional[Tuple[int, int]] = None
|
||||
|
||||
if not self.is_direct:
|
||||
filtered_bounds = (
|
||||
node.bounds
|
||||
for node in self.query_nodes(tag_filter, operation_filter)
|
||||
if isinstance(node.output.dtype, Integer) and node.bounds is not None
|
||||
)
|
||||
for min_bound, max_bound in filtered_bounds:
|
||||
assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)
|
||||
|
||||
if result is None:
|
||||
result = (int(min_bound), int(max_bound))
|
||||
else:
|
||||
old_min_bound, old_max_bound = result # pylint: disable=unpacking-non-sequence
|
||||
result = (
|
||||
min(old_min_bound, int(min_bound)),
|
||||
max(old_max_bound, int(max_bound)),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
Declaration of `Node` class.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -25,8 +28,13 @@ class Node:
|
||||
operation: Operation
|
||||
evaluator: Callable
|
||||
|
||||
bounds: Optional[Tuple[Union[int, float], Union[int, float]]]
|
||||
properties: Dict[str, Any]
|
||||
|
||||
location: str
|
||||
tag: str
|
||||
created_at: float
|
||||
|
||||
@staticmethod
|
||||
def constant(constant: Any) -> "Node":
|
||||
"""
|
||||
@@ -145,8 +153,44 @@ class Node:
|
||||
self.operation = operation
|
||||
self.evaluator = evaluator # type: ignore
|
||||
|
||||
self.bounds = None
|
||||
self.properties = properties if properties is not None else {}
|
||||
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
cnp_directory = os.path.dirname(cnp.__file__)
|
||||
|
||||
import concrete.onnx as coonx
|
||||
|
||||
coonx_directory = os.path.dirname(coonx.__file__)
|
||||
|
||||
# pylint: enable=cyclic-import,import-outside-toplevel
|
||||
|
||||
for frame in reversed(traceback.extract_stack()):
|
||||
if frame.filename == "<__array_function__ internals>":
|
||||
continue
|
||||
|
||||
if frame.filename.startswith(cnp_directory):
|
||||
continue
|
||||
|
||||
if frame.filename.startswith(coonx_directory):
|
||||
continue
|
||||
|
||||
self.location = f"{frame.filename}:{frame.lineno}"
|
||||
break
|
||||
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel
|
||||
|
||||
from ..extensions.tag import tag_context
|
||||
|
||||
self.tag = ".".join(tag_context.stack)
|
||||
|
||||
# pylint: enable=cyclic-import,import-outside-toplevel
|
||||
|
||||
self.created_at = time.time()
|
||||
|
||||
def __call__(self, *args: List[Any]) -> Union[np.bool_, np.integer, np.floating, np.ndarray]:
|
||||
def generic_error_message() -> str:
|
||||
result = f"Evaluation of {self.operation.value} '{self.label()}' node"
|
||||
@@ -361,3 +405,6 @@ class Node:
|
||||
"subtract",
|
||||
"zeros",
|
||||
]
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
return self.created_at < other.created_at
|
||||
|
||||
@@ -151,7 +151,7 @@ class Tracer:
|
||||
output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers)
|
||||
}
|
||||
|
||||
return Graph(graph, input_nodes, output_nodes)
|
||||
return Graph(graph, input_nodes, output_nodes, is_direct)
|
||||
|
||||
def __init__(self, computation: Node, input_tracers: List["Tracer"]):
|
||||
self.computation = computation
|
||||
|
||||
@@ -43,7 +43,6 @@ def test_artifacts_export(helpers):
|
||||
assert (tmpdir / "3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "4.final.graph.txt").exists()
|
||||
|
||||
assert (tmpdir / "bounds.txt").exists()
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
assert (tmpdir / "client_parameters.json").exists()
|
||||
|
||||
@@ -60,6 +59,5 @@ def test_artifacts_export(helpers):
|
||||
assert (tmpdir / "3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "4.final.graph.txt").exists()
|
||||
|
||||
assert (tmpdir / "bounds.txt").exists()
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
assert (tmpdir / "client_parameters.json").exists()
|
||||
|
||||
@@ -25,16 +25,7 @@ def test_circuit_str(helpers):
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset, configuration.fork(p_error=6e-5))
|
||||
|
||||
assert str(circuit) == (
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = y # EncryptedScalar<uint5>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
assert str(circuit) == circuit.graph.format()
|
||||
|
||||
|
||||
def test_circuit_feedback(helpers):
|
||||
|
||||
@@ -262,9 +262,8 @@ def test_compiler_compile_bad_inputset(helpers):
|
||||
|
||||
assert str(excinfo.value) == "Bound measurement using inputset[0] failed"
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__).strip()
|
||||
== """
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
@@ -277,29 +276,30 @@ Subgraphs:
|
||||
|
||||
%1 = subgraph(%0):
|
||||
|
||||
%0 = inf # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint1>
|
||||
%1 = inf # ClearScalar<float64>
|
||||
%2 = add(%0, %1) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
""".strip(),
|
||||
str(excinfo.value.__cause__).strip(),
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__.__cause__).strip()
|
||||
== """
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
%0 = inf # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint1>
|
||||
%1 = inf # ClearScalar<float64>
|
||||
%2 = add(%0, %1) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
""".strip(),
|
||||
str(excinfo.value.__cause__.__cause__).strip(),
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -319,9 +319,8 @@ return %3
|
||||
|
||||
assert str(excinfo.value) == "Bound measurement using inputset[0] failed"
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__).strip()
|
||||
== """
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
@@ -334,29 +333,30 @@ Subgraphs:
|
||||
|
||||
%1 = subgraph(%0):
|
||||
|
||||
%0 = nan # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint1>
|
||||
%1 = nan # ClearScalar<float64>
|
||||
%2 = add(%0, %1) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
""".strip(),
|
||||
str(excinfo.value.__cause__).strip(),
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__.__cause__).strip()
|
||||
== """
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
%0 = nan # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint1>
|
||||
%1 = nan # ClearScalar<float64>
|
||||
%2 = add(%0, %1) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
""".strip(),
|
||||
str(excinfo.value.__cause__.__cause__).strip(),
|
||||
)
|
||||
|
||||
assert (
|
||||
|
||||
@@ -48,9 +48,9 @@ def test_compiler_verbose_trace(helpers, capsys):
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
@@ -112,19 +112,19 @@ def test_compiler_verbose_virtual_compile(helpers, capsys):
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
|
||||
MLIR
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
Virtual circuits don't have MLIR.
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
|
||||
Optimizer
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
Virtual circuits don't have optimizer output.
|
||||
------------------------------------------------
|
||||
------------------------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
@@ -140,7 +140,6 @@ def test_circuit(helpers):
|
||||
return x + 42
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit1),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint2>
|
||||
@@ -149,6 +148,7 @@ def test_circuit(helpers):
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
str(circuit1),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
@@ -158,7 +158,6 @@ return %2
|
||||
return x + 42
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit2),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
|
||||
@@ -167,6 +166,7 @@ return %2
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
str(circuit2),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
@@ -179,7 +179,6 @@ return %2
|
||||
return cnp.univariate(square, outputs=cnp.uint7)(x)
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit3),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
@@ -187,6 +186,7 @@ return %2
|
||||
return %1
|
||||
|
||||
""".strip(),
|
||||
str(circuit3),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
@@ -196,7 +196,6 @@ return %1
|
||||
return ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(cnp.uint3)
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit4),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
@@ -207,18 +206,19 @@ Subgraphs:
|
||||
|
||||
%1 = subgraph(%0):
|
||||
|
||||
%0 = 2 # ClearScalar<uint2>
|
||||
%1 = 2 # ClearScalar<uint2>
|
||||
%2 = input # EncryptedScalar<uint3>
|
||||
%3 = sin(%2) # EncryptedScalar<float64>
|
||||
%4 = cos(%2) # EncryptedScalar<float64>
|
||||
%5 = power(%3, %0) # EncryptedScalar<float64>
|
||||
%6 = power(%4, %1) # EncryptedScalar<float64>
|
||||
%7 = add(%5, %6) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint3>
|
||||
%1 = sin(%0) # EncryptedScalar<float64>
|
||||
%2 = 2 # ClearScalar<uint2>
|
||||
%3 = power(%1, %2) # EncryptedScalar<float64>
|
||||
%4 = cos(%0) # EncryptedScalar<float64>
|
||||
%5 = 2 # ClearScalar<uint2>
|
||||
%6 = power(%4, %5) # EncryptedScalar<float64>
|
||||
%7 = add(%3, %6) # EncryptedScalar<float64>
|
||||
%8 = astype(%7) # EncryptedScalar<uint3>
|
||||
return %8
|
||||
|
||||
""".strip(),
|
||||
str(circuit4),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
@@ -228,7 +228,6 @@ Subgraphs:
|
||||
return x + 42
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit5),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<int2>
|
||||
@@ -237,6 +236,7 @@ Subgraphs:
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
str(circuit5),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ Configuration of `pytest`.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
@@ -11,6 +12,10 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
import tests
|
||||
|
||||
tests_directory = os.path.dirname(tests.__file__)
|
||||
|
||||
|
||||
INSECURE_KEY_CACHE_LOCATION = None
|
||||
|
||||
@@ -279,6 +284,14 @@ Actual Output
|
||||
actual str
|
||||
"""
|
||||
|
||||
# remove error line information
|
||||
# there are explicit tests to make sure the line information is correct
|
||||
# however, it would have been very hard to keep the other tests up to date
|
||||
|
||||
actual = "\n".join(
|
||||
line for line in actual.splitlines() if not line.strip().startswith(tests_directory)
|
||||
)
|
||||
|
||||
assert (
|
||||
actual.strip() == expected.strip()
|
||||
), f"""
|
||||
|
||||
@@ -338,7 +338,7 @@ def test_bad_maxpool_special(helpers):
|
||||
def clear_input(x):
|
||||
return connx.maxpool(x, kernel_shape=(4, 3, 2))
|
||||
|
||||
inputset = [np.random.randint(0, 10, size=(1, 1, 10, 10, 10)) for i in range(100)]
|
||||
inputset = [np.zeros((1, 1, 10, 10, 10), dtype=np.int64)]
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
clear_input.compile(inputset, helpers.configuration())
|
||||
|
||||
@@ -348,9 +348,9 @@ def test_bad_maxpool_special(helpers):
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint4, shape=(1, 1, 10, 10, 10)>
|
||||
%1 = maxpool(%0, kernel_shape=(4, 3, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # ClearTensor<uint4, shape=(1, 1, 7, 8, 9)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted maxpool is supported
|
||||
%0 = x # ClearTensor<uint1, shape=(1, 1, 10, 10, 10)> ∈ [0, 0]
|
||||
%1 = maxpool(%0, kernel_shape=(4, 3, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # ClearTensor<uint1, shape=(1, 1, 7, 8, 9)> ∈ [0, 0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted maxpool is supported
|
||||
return %1
|
||||
|
||||
""".strip(), # noqa: E501
|
||||
|
||||
@@ -702,31 +702,31 @@ def test_others_bad_fusing(helpers):
|
||||
|
||||
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%1 = 10 # ClearScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%2 = 2 # ClearScalar<uint2>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%0 = x # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%1 = y # ClearScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%2 = sin(%0) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%3 = 2 # ClearScalar<uint2>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%4 = x # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%5 = y # ClearScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%6 = sin(%4) # EncryptedScalar<float64>
|
||||
%4 = power(%2, %3) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%7 = cos(%5) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%8 = power(%6, %2) # EncryptedScalar<float64>
|
||||
%5 = 10 # ClearScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%6 = multiply(%5, %4) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%9 = power(%7, %3) # ClearScalar<float64>
|
||||
%7 = cos(%1) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%10 = multiply(%0, %8) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%11 = multiply(%1, %9) # ClearScalar<float64>
|
||||
%8 = 2 # ClearScalar<uint2>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%9 = power(%7, %8) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%12 = add(%10, %11) # EncryptedScalar<float64>
|
||||
%10 = 10 # ClearScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%11 = multiply(%10, %9) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%12 = add(%6, %11) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
|
||||
@@ -201,7 +201,7 @@ def test_auto_rounding(helpers):
|
||||
return %4
|
||||
|
||||
""",
|
||||
str(circuit3),
|
||||
str(circuit3.graph.format(show_bounds=False)),
|
||||
)
|
||||
|
||||
|
||||
|
||||
64
tests/extensions/test_tag.py
Normal file
64
tests/extensions/test_tag.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Tests of 'tag' extension.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
def test_tag(helpers):
|
||||
"""
|
||||
Test tag extension.
|
||||
"""
|
||||
|
||||
def g(z):
|
||||
with cnp.tag("def"):
|
||||
a = 120 - z
|
||||
b = a // 4
|
||||
return b
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
with cnp.tag("abc"):
|
||||
x = x * 2
|
||||
with cnp.tag("foo"):
|
||||
y = x + 42
|
||||
z = np.sqrt(y).astype(np.int64)
|
||||
|
||||
return g(z + 3) * 2
|
||||
|
||||
inputset = range(10)
|
||||
circuit = f.trace(inputset, configuration=helpers.configuration())
|
||||
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = 2 # ClearScalar<uint2> @ abc
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint5> @ abc
|
||||
%3 = 42 # ClearScalar<uint6> @ abc.foo
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint6> @ abc.foo
|
||||
%5 = subgraph(%4) # EncryptedScalar<uint3> @ abc
|
||||
%6 = 3 # ClearScalar<uint2>
|
||||
%7 = add(%5, %6) # EncryptedScalar<uint4>
|
||||
%8 = 120 # ClearScalar<uint7> @ def
|
||||
%9 = subtract(%8, %7) # EncryptedScalar<uint7> @ def
|
||||
%10 = 4 # ClearScalar<uint3> @ def
|
||||
%11 = floor_divide(%9, %10) # EncryptedScalar<uint5> @ def
|
||||
%12 = 2 # ClearScalar<uint2>
|
||||
%13 = multiply(%11, %12) # EncryptedScalar<uint6>
|
||||
return %13
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%5 = subgraph(%4):
|
||||
|
||||
%0 = input # EncryptedScalar<uint2> @ abc.foo
|
||||
%1 = sqrt(%0) # EncryptedScalar<float64> @ abc
|
||||
%2 = astype(%1, dtype=int_) # EncryptedScalar<uint1> @ abc
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
circuit.format(show_bounds=False),
|
||||
)
|
||||
@@ -26,18 +26,18 @@ def assign(x):
|
||||
pytest.param(
|
||||
lambda x, y: (x - y, x + y),
|
||||
{"x": "encrypted", "y": "clear"},
|
||||
[(np.random.randint(0, 2**3), np.random.randint(0, 2**3)) for _ in range(100)],
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
%1 = y # ClearScalar<uint3>
|
||||
%2 = subtract(%0, %1) # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
|
||||
%3 = add(%0, %1) # EncryptedScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
|
||||
%0 = x # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%1 = y # ClearScalar<uint3> ∈ [0, 7]
|
||||
%2 = subtract(%0, %1) # EncryptedScalar<int4> ∈ [-7, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
|
||||
%3 = add(%0, %1) # EncryptedScalar<uint4> ∈ [0, 14]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported
|
||||
return (%2, %3)
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -51,8 +51,8 @@ return (%2, %3)
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearScalar<int5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported
|
||||
%0 = x # ClearScalar<int5> ∈ [-10, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported
|
||||
return %0
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -66,12 +66,12 @@ return %0
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported
|
||||
%1 = 1.5 # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%0 = x # EncryptedScalar<float64> ∈ [0.0, 247.5]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported
|
||||
%1 = 1.5 # ClearScalar<float64> ∈ [1.5, 1.5]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<float64> ∈ [0.0, 371.25]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -85,9 +85,9 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint7>
|
||||
%1 = sin(%0) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%0 = x # EncryptedScalar<uint7> ∈ [0, 99]
|
||||
%1 = sin(%0) # EncryptedScalar<float64> ∈ [-0.9999902065507035, 0.9999118601072672]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -107,10 +107,10 @@ return %1
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = y # ClearTensor<uint3, shape=(3, 2)>
|
||||
%2 = concatenate((%0, %1)) # EncryptedTensor<uint3, shape=(6, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only all encrypted concatenate is supported
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)> ∈ [0, 7]
|
||||
%1 = y # ClearTensor<uint3, shape=(3, 2)> ∈ [0, 7]
|
||||
%2 = concatenate((%0, %1)) # EncryptedTensor<uint3, shape=(6, 2)> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only all encrypted concatenate is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -130,10 +130,10 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4)>
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1)>
|
||||
%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv1d with encrypted input and clear weight is supported
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4)> ∈ [0, 1]
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1)> ∈ [0, 1]
|
||||
%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv1d with encrypted input and clear weight is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -153,10 +153,10 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4)>
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1)>
|
||||
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv2d with encrypted input and clear weight is supported
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4)> ∈ [0, 1]
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1)> ∈ [0, 1]
|
||||
%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv2d with encrypted input and clear weight is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -176,10 +176,10 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)>
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1, 1)>
|
||||
%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv3d with encrypted input and clear weight is supported
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)> ∈ [0, 1]
|
||||
%1 = w # EncryptedTensor<uint1, shape=(1, 1, 1, 1, 1)> ∈ [0, 1]
|
||||
%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor<uint1, shape=(1, 1, 4, 4, 4)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv3d with encrypted input and clear weight is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -187,22 +187,16 @@ return %2
|
||||
pytest.param(
|
||||
lambda x, y: np.dot(x, y),
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2**2, size=(1,)),
|
||||
np.random.randint(0, 2**2, size=(1,)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
[([0], [0]), ([3], [3]), ([3], [0]), ([0], [3]), ([1], [1])],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(1,)>
|
||||
%1 = y # EncryptedTensor<uint2, shape=(1,)>
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only dot product between encrypted and clear is supported
|
||||
%0 = x # EncryptedTensor<uint2, shape=(1,)> ∈ [0, 3]
|
||||
%1 = y # EncryptedTensor<uint2, shape=(1,)> ∈ [0, 3]
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint4> ∈ [0, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only dot product between encrypted and clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -210,15 +204,15 @@ return %2
|
||||
pytest.param(
|
||||
lambda x: x[0],
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2**3, size=(4,)) for _ in range(100)],
|
||||
[[0, 1, 2, 3], [7, 6, 5, 4]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(4,)>
|
||||
%1 = %0[0] # ClearScalar<uint3>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted indexing supported
|
||||
%0 = x # ClearTensor<uint3, shape=(4,)> ∈ [0, 7]
|
||||
%1 = %0[0] # ClearScalar<uint3> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted indexing supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -228,8 +222,8 @@ return %1
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[
|
||||
(
|
||||
np.random.randint(0, 2**2, size=(1, 1)),
|
||||
np.random.randint(0, 2**2, size=(1, 1)),
|
||||
np.random.randint(0, 2**1, size=(1, 1)),
|
||||
np.random.randint(0, 2**1, size=(1, 1)),
|
||||
)
|
||||
for _ in range(100)
|
||||
],
|
||||
@@ -238,10 +232,10 @@ return %1
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(1, 1)>
|
||||
%1 = y # EncryptedTensor<uint2, shape=(1, 1)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(1, 1)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only matrix multiplication between encrypted and clear is supported
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
|
||||
%1 = y # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only matrix multiplication between encrypted and clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -249,16 +243,16 @@ return %2
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(np.random.randint(0, 2**3), np.random.randint(0, 2**3)) for _ in range(100)],
|
||||
[(0, 0), (7, 7), (0, 7), (7, 0)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
%1 = y # EncryptedScalar<uint3>
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint6>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only multiplication between encrypted and clear is supported
|
||||
%0 = x # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%1 = y # EncryptedScalar<uint3> ∈ [0, 7]
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint6> ∈ [0, 49]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only multiplication between encrypted and clear is supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -266,15 +260,15 @@ return %2
|
||||
pytest.param(
|
||||
lambda x: -x,
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2**3) for _ in range(100)],
|
||||
[0, 7],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearScalar<uint3>
|
||||
%1 = negative(%0) # ClearScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted negation is supported
|
||||
%0 = x # ClearScalar<uint3> ∈ [0, 7]
|
||||
%1 = negative(%0) # ClearScalar<int4> ∈ [-7, 0]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted negation is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -288,9 +282,9 @@ return %1
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint3, shape=(2, 3)>
|
||||
%1 = reshape(%0, newshape=(3, 2)) # ClearTensor<uint3, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted reshape is supported
|
||||
%0 = x # ClearTensor<uint3, shape=(2, 3)> ∈ [0, 7]
|
||||
%1 = reshape(%0, newshape=(3, 2)) # ClearTensor<uint3, shape=(3, 2)> ∈ [0, 7]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted reshape is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -304,9 +298,9 @@ return %1
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1,)>
|
||||
%1 = sum(%0) # ClearScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted sum is supported
|
||||
%0 = x # ClearTensor<uint1, shape=(1,)> ∈ [0, 1]
|
||||
%1 = sum(%0) # ClearScalar<uint1> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted sum is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -314,16 +308,16 @@ return %1
|
||||
pytest.param(
|
||||
lambda x: np.maximum(x, np.array([3])),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(1,)) for _ in range(100)],
|
||||
[[0], [1]],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(1,)>
|
||||
%1 = [3] # ClearTensor<uint2, shape=(1,)>
|
||||
%2 = maximum(%0, %1) # ClearTensor<uint2, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted
|
||||
%0 = x # ClearTensor<uint1, shape=(1,)> ∈ [0, 1]
|
||||
%1 = [3] # ClearTensor<uint2, shape=(1,)> ∈ [3, 3]
|
||||
%2 = maximum(%0, %1) # ClearTensor<uint2, shape=(1,)> ∈ [3, 3]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -331,15 +325,15 @@ return %2
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(3, 2)) for _ in range(100)],
|
||||
[np.random.randint(0, 2, size=(3, 2)) for _ in range(10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(3, 2)>
|
||||
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
|
||||
%0 = x # ClearTensor<uint1, shape=(3, 2)> ∈ [0, 1]
|
||||
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -347,15 +341,15 @@ return %1
|
||||
pytest.param(
|
||||
lambda x: np.broadcast_to(x, shape=(3, 2)),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(2,)) for _ in range(100)],
|
||||
[np.random.randint(0, 2, size=(2,)) for _ in range(10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(2,)>
|
||||
%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor<uint1, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
|
||||
%0 = x # ClearTensor<uint1, shape=(2,)> ∈ [0, 1]
|
||||
%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor<uint1, shape=(3, 2)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -363,19 +357,18 @@ return %1
|
||||
pytest.param(
|
||||
assign,
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(3,)) for _ in range(100)],
|
||||
[np.random.randint(0, 2, size=(3,)) for _ in range(10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(3,)>
|
||||
%1 = 0 # ClearScalar<uint1>
|
||||
%2 = (%0[0] = %1) # ClearTensor<uint1, shape=(3,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
|
||||
%0 = x # ClearTensor<uint1, shape=(3,)> ∈ [0, 1]
|
||||
%1 = 0 # ClearScalar<uint1> ∈ [0, 0]
|
||||
%2 = (%0[0] = %1) # ClearTensor<uint1, shape=(3,)> ∈ [0, 1]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
|
||||
return %2
|
||||
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
@@ -387,21 +380,21 @@ return %2
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint18>
|
||||
%1 = 300 # ClearScalar<uint9>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint18>
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers
|
||||
%0 = x # EncryptedScalar<uint18> ∈ [200000, 200000]
|
||||
%1 = 300 # ClearScalar<uint9> ∈ [300, 300]
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint18> ∈ [200300, 200300]
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint4> ∈ [9, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = input # EncryptedScalar<uint2>
|
||||
%2 = sin(%1) # EncryptedScalar<float64>
|
||||
%3 = multiply(%0, %2) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint2>
|
||||
%1 = sin(%0) # EncryptedScalar<float64>
|
||||
%2 = 10 # ClearScalar<uint4>
|
||||
%3 = multiply(%2, %1) # EncryptedScalar<float64>
|
||||
%4 = absolute(%3) # EncryptedScalar<float64>
|
||||
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %5
|
||||
@@ -417,21 +410,21 @@ Subgraphs:
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint11>
|
||||
%1 = 300 # ClearScalar<uint9>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint12>
|
||||
%3 = subgraph(%2) # EncryptedScalar<int5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups
|
||||
%0 = x # EncryptedScalar<uint11> ∈ [1024, 2047]
|
||||
%1 = 300 # ClearScalar<uint9> ∈ [300, 300]
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint12> ∈ [1324, 2347]
|
||||
%3 = subgraph(%2) # EncryptedScalar<int5> ∈ [-9, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = input # EncryptedScalar<uint2>
|
||||
%2 = sin(%1) # EncryptedScalar<float64>
|
||||
%3 = multiply(%0, %2) # EncryptedScalar<float64>
|
||||
%0 = input # EncryptedScalar<uint2>
|
||||
%1 = sin(%0) # EncryptedScalar<float64>
|
||||
%2 = 10 # ClearScalar<uint4>
|
||||
%3 = multiply(%2, %1) # EncryptedScalar<float64>
|
||||
%4 = astype(%3, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %4
|
||||
|
||||
|
||||
@@ -2,37 +2,175 @@
|
||||
Tests of `Graph` class.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
import tests
|
||||
|
||||
tests_directory = os.path.dirname(tests.__file__)
|
||||
|
||||
|
||||
def g(z):
|
||||
"""
|
||||
Example function with a tag.
|
||||
"""
|
||||
|
||||
with cnp.tag("def"):
|
||||
a = 120 - z
|
||||
b = a // 4
|
||||
return b
|
||||
|
||||
|
||||
def f(x):
|
||||
"""
|
||||
Example function with nested tags.
|
||||
"""
|
||||
|
||||
with cnp.tag("abc"):
|
||||
x = x * 2
|
||||
with cnp.tag("foo"):
|
||||
y = x + 42
|
||||
z = np.sqrt(y).astype(np.int64)
|
||||
|
||||
return g(z + 3) * 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,inputset,expected_result",
|
||||
"function,inputset,tag_filter,operation_filter,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
range(5),
|
||||
None,
|
||||
None,
|
||||
3,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 42,
|
||||
range(10),
|
||||
None,
|
||||
None,
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 42,
|
||||
range(50),
|
||||
None,
|
||||
None,
|
||||
7,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1.2,
|
||||
[1.5, 4.2],
|
||||
None,
|
||||
None,
|
||||
-1,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
None,
|
||||
7,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"",
|
||||
None,
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"abc",
|
||||
None,
|
||||
5,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
["abc", "def"],
|
||||
None,
|
||||
7,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
re.compile(".*b.*"),
|
||||
None,
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"input",
|
||||
4,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"constant",
|
||||
7,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"subgraph",
|
||||
3,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"add",
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
["subgraph", "add"],
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
re.compile("sub.*"),
|
||||
7,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"abc.foo",
|
||||
"add",
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"abc",
|
||||
"floor_divide",
|
||||
-1,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_graph_maximum_integer_bit_width(function, inputset, expected_result, helpers):
|
||||
def test_graph_maximum_integer_bit_width(
|
||||
function,
|
||||
inputset,
|
||||
tag_filter,
|
||||
operation_filter,
|
||||
expected_result,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test `maximum_integer_bit_width` method of `Graph` class.
|
||||
"""
|
||||
@@ -42,6 +180,192 @@ def test_graph_maximum_integer_bit_width(function, inputset, expected_result, he
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
print(graph.format())
|
||||
assert graph.maximum_integer_bit_width(tag_filter, operation_filter) == expected_result
|
||||
|
||||
assert graph.maximum_integer_bit_width() == expected_result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,inputset,tag_filter,operation_filter,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + 42,
|
||||
range(-10, 10),
|
||||
None,
|
||||
None,
|
||||
(-10, 51),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1.2,
|
||||
[1.5, 4.2],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
None,
|
||||
(0, 120),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"",
|
||||
None,
|
||||
(0, 54),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"abc",
|
||||
None,
|
||||
(0, 18),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
["abc", "def"],
|
||||
None,
|
||||
(0, 120),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
re.compile(".*b.*"),
|
||||
None,
|
||||
(0, 60),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"input",
|
||||
(0, 9),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"constant",
|
||||
(2, 120),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"subgraph",
|
||||
(6, 7),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
"add",
|
||||
(9, 60),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
["subgraph", "add"],
|
||||
(6, 60),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
None,
|
||||
re.compile("sub.*"),
|
||||
(6, 111),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"abc.foo",
|
||||
"add",
|
||||
(42, 60),
|
||||
),
|
||||
pytest.param(
|
||||
f,
|
||||
range(10),
|
||||
"abc",
|
||||
"floor_divide",
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_graph_integer_range(
|
||||
function,
|
||||
inputset,
|
||||
tag_filter,
|
||||
operation_filter,
|
||||
expected_result,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test `integer_range` method of `Graph` class.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
assert graph.integer_range(tag_filter, operation_filter) == expected_result
|
||||
|
||||
|
||||
def test_graph_format_show_lines(helpers):
|
||||
"""
|
||||
Test `format` method of `Graph` class with show_lines=True.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(f, {"x": "encrypted"})
|
||||
graph = compiler.trace(range(10), configuration)
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
expected = f"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4> ∈ [0, 9] {tests_directory}/representation/test_graph.py:324
|
||||
%1 = 2 # ClearScalar<uint2> ∈ [2, 2] @ abc {tests_directory}/representation/test_graph.py:34
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<uint5> ∈ [0, 18] @ abc {tests_directory}/representation/test_graph.py:34
|
||||
%3 = 42 # ClearScalar<uint6> ∈ [42, 42] @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%4 = add(%2, %3) # EncryptedScalar<uint6> ∈ [42, 60] @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%5 = subgraph(%4) # EncryptedScalar<uint3> ∈ [6, 7] @ abc {tests_directory}/representation/test_graph.py:37
|
||||
%6 = 3 # ClearScalar<uint2> ∈ [3, 3] {tests_directory}/representation/test_graph.py:39
|
||||
%7 = add(%5, %6) # EncryptedScalar<uint4> ∈ [9, 10] {tests_directory}/representation/test_graph.py:39
|
||||
%8 = 120 # ClearScalar<uint7> ∈ [120, 120] @ def {tests_directory}/representation/test_graph.py:23
|
||||
%9 = subtract(%8, %7) # EncryptedScalar<uint7> ∈ [110, 111] @ def {tests_directory}/representation/test_graph.py:23
|
||||
%10 = 4 # ClearScalar<uint3> ∈ [4, 4] @ def {tests_directory}/representation/test_graph.py:24
|
||||
%11 = floor_divide(%9, %10) # EncryptedScalar<uint5> ∈ [27, 27] @ def {tests_directory}/representation/test_graph.py:24
|
||||
%12 = 2 # ClearScalar<uint2> ∈ [2, 2] {tests_directory}/representation/test_graph.py:39
|
||||
%13 = multiply(%11, %12) # EncryptedScalar<uint6> ∈ [54, 54] {tests_directory}/representation/test_graph.py:39
|
||||
return %13
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%5 = subgraph(%4):
|
||||
|
||||
%0 = input # EncryptedScalar<uint2> @ abc.foo {tests_directory}/representation/test_graph.py:36
|
||||
%1 = sqrt(%0) # EncryptedScalar<float64> @ abc {tests_directory}/representation/test_graph.py:37
|
||||
%2 = astype(%1, dtype=int_) # EncryptedScalar<uint1> @ abc {tests_directory}/representation/test_graph.py:37
|
||||
return %2
|
||||
|
||||
""" # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
actual = graph.format(show_locations=True)
|
||||
|
||||
assert (
|
||||
actual.strip() == expected.strip()
|
||||
), f"""
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{actual}
|
||||
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user