mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement representation module
This commit is contained in:
7
concrete/numpy/representation/__init__.py
Normal file
7
concrete/numpy/representation/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Declaration of `concrete.numpy.representation` namespace.
|
||||
"""
|
||||
|
||||
from .graph import Graph
|
||||
from .node import Node
|
||||
from .operation import Operation
|
||||
498
concrete/numpy/representation/graph.py
Normal file
498
concrete/numpy/representation/graph.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
Declaration of `Graph` class.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..dtypes import Float, Integer, UnsignedInteger
|
||||
from .node import Node
|
||||
from .operation import OPERATION_COLOR_MAPPING, Operation
|
||||
|
||||
|
||||
class Graph:
|
||||
"""
|
||||
Graph class, to represent computation graphs.
|
||||
"""
|
||||
|
||||
graph: nx.MultiDiGraph
|
||||
|
||||
input_nodes: Dict[int, Node]
|
||||
output_nodes: Dict[int, Node]
|
||||
|
||||
input_indices: Dict[Node, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Dict[int, Node],
|
||||
output_nodes: Dict[int, Node],
|
||||
):
|
||||
self.graph = graph
|
||||
|
||||
self.input_nodes = input_nodes
|
||||
self.output_nodes = output_nodes
|
||||
|
||||
self.input_indices = {node: index for index, node in input_nodes.items()}
|
||||
|
||||
self.prune_useless_nodes()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
) -> Union[
|
||||
np.bool_,
|
||||
np.integer,
|
||||
np.floating,
|
||||
np.ndarray,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
|
||||
]:
|
||||
evaluation = self.evaluate(*args)
|
||||
result = tuple(evaluation[node] for node in self.ordered_outputs())
|
||||
return result if len(result) > 1 else result[0]
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
*args: Any,
|
||||
) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]:
|
||||
r"""
|
||||
Perform the computation `Graph` represents and get resulting values for all nodes.
|
||||
|
||||
Args:
|
||||
*args (List[Any]):
|
||||
inputs to the computation
|
||||
|
||||
Returns:
|
||||
Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]:
|
||||
nodes and their values during computation
|
||||
"""
|
||||
|
||||
node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {}
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if node.operation == Operation.Input:
|
||||
node_results[node] = node(args[self.input_indices[node]])
|
||||
continue
|
||||
|
||||
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
|
||||
node_results[node] = node(*pred_results)
|
||||
return node_results
|
||||
|
||||
def draw(
|
||||
self,
|
||||
show: bool = False,
|
||||
horizontal: bool = False,
|
||||
save_to: Optional[Union[Path, str]] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
Draw the `Graph` and optionally save/show the drawing.
|
||||
|
||||
note that this function requires the python `pygraphviz` package
|
||||
which itself requires the installation of `graphviz` packages
|
||||
see https://pygraphviz.github.io/documentation/stable/install.html
|
||||
|
||||
Args:
|
||||
show (bool, default = False):
|
||||
whether to show the drawing using matplotlib or not
|
||||
|
||||
horizontal (bool, default = False):
|
||||
whether to draw horizontally or not
|
||||
|
||||
save_to (Optional[Path], default = None):
|
||||
path to save the drawing
|
||||
a temporary file will be used if it's None
|
||||
|
||||
Returns:
|
||||
Path:
|
||||
path to the saved drawing
|
||||
"""
|
||||
|
||||
def get_color(node, output_nodes):
|
||||
if node in output_nodes:
|
||||
return OPERATION_COLOR_MAPPING["output"]
|
||||
return OPERATION_COLOR_MAPPING[node.operation]
|
||||
|
||||
graph = self.graph
|
||||
output_nodes = set(self.output_nodes.values())
|
||||
|
||||
attributes = {
|
||||
node: {
|
||||
"label": node.label(),
|
||||
"color": get_color(node, output_nodes),
|
||||
"penwidth": 2, # double thickness for circles
|
||||
"peripheries": 2 if node in output_nodes else 1, # two circles for output nodes
|
||||
}
|
||||
for node in graph.nodes
|
||||
}
|
||||
nx.set_node_attributes(graph, attributes)
|
||||
|
||||
for edge in graph.edges(keys=True):
|
||||
idx = graph.edges[edge]["input_idx"]
|
||||
graph.edges[edge]["label"] = f" {idx} "
|
||||
|
||||
try:
|
||||
agraph = nx.nx_agraph.to_agraph(graph)
|
||||
except ImportError as error: # pragma: no cover
|
||||
if "pygraphviz" in str(error):
|
||||
raise ImportError(
|
||||
"Graph.draw requires pygraphviz. Install graphviz distribution to your OS "
|
||||
"following https://pygraphviz.github.io/documentation/stable/install.html "
|
||||
"and reinstall concrete-numpy with extras: `pip install --force-reinstall "
|
||||
"concrete-numpy[full]`"
|
||||
) from error
|
||||
|
||||
raise
|
||||
|
||||
agraph.graph_attr["rankdir"] = "LR" if horizontal else "TB"
|
||||
agraph.layout("dot")
|
||||
|
||||
if save_to is None:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
# we need to change the permissions of the temporary file
|
||||
# so that it can be read by all users
|
||||
|
||||
# (https://stackoverflow.com/a/44130605)
|
||||
|
||||
# get the old umask and replace it with 0o666
|
||||
old_umask = os.umask(0o666)
|
||||
|
||||
# restore the old umask back
|
||||
os.umask(old_umask)
|
||||
|
||||
# combine the old umask with the wanted permissions
|
||||
permissions = 0o666 & ~old_umask
|
||||
|
||||
# set new permissions
|
||||
os.chmod(tmp.name, permissions)
|
||||
|
||||
save_to_str = str(tmp.name)
|
||||
else:
|
||||
save_to_str = str(save_to)
|
||||
|
||||
agraph.draw(save_to_str)
|
||||
|
||||
if show: # pragma: no cover
|
||||
plt.close("all")
|
||||
plt.figure()
|
||||
|
||||
img = Image.open(save_to_str)
|
||||
plt.imshow(img)
|
||||
img.close()
|
||||
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
return Path(save_to_str)
|
||||
|
||||
def format(
|
||||
self,
|
||||
maximum_constant_length: int = 25,
|
||||
highlighted_nodes: Optional[Dict[Node, List[str]]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the textual representation of the `Graph`.
|
||||
|
||||
Args:
|
||||
maximum_constant_length (int, default = 25):
|
||||
maximum length of formatted constants
|
||||
|
||||
highlighted_nodes (Optional[Dict[Node, List[str]]], default = None):
|
||||
nodes to be highlighted and their corresponding messages
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual representation of the `Graph`
|
||||
"""
|
||||
|
||||
# node -> identifier
|
||||
# e.g., id_map[node1] = 2
|
||||
# means line for node1 is in this form %2 = node1.format(...)
|
||||
id_map: Dict[Node, int] = {}
|
||||
|
||||
# lines that will be merged at the end
|
||||
lines: List[str] = []
|
||||
|
||||
# type information to add to each line
|
||||
# (for alingment, this is done after lines are determined)
|
||||
type_informations: List[str] = []
|
||||
|
||||
# default highlighted nodes is empty
|
||||
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
|
||||
|
||||
# highlight information for lines, this is required because highlights are added to lines
|
||||
# after their type information is added, and we only have line numbers, not nodes
|
||||
highlighted_lines: Dict[int, List[str]] = {}
|
||||
|
||||
# subgraphs to format after the main graph is formatted
|
||||
subgraphs: Dict[str, Graph] = {}
|
||||
|
||||
# format nodes
|
||||
for node in nx.topological_sort(self.graph):
|
||||
# assign a unique id to outputs of node
|
||||
id_map[node] = len(id_map)
|
||||
|
||||
# remember highlights of the node
|
||||
if node in highlighted_nodes:
|
||||
highlighted_lines[len(lines)] = highlighted_nodes[node]
|
||||
|
||||
# extract predecessors and their ids
|
||||
predecessors = []
|
||||
for predecessor in self.ordered_preds_of(node):
|
||||
predecessors.append(f"%{id_map[predecessor]}")
|
||||
|
||||
# start the build the line for the node
|
||||
line = ""
|
||||
|
||||
# add output information to the line
|
||||
line += f"%{id_map[node]}"
|
||||
|
||||
# add node information to the line
|
||||
line += " = "
|
||||
line += node.format(predecessors, maximum_constant_length)
|
||||
|
||||
# append line to list of lines
|
||||
lines.append(line)
|
||||
|
||||
# if exists, save the subgraph
|
||||
if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]:
|
||||
subgraphs[line] = node.properties["kwargs"]["subgraph"]
|
||||
|
||||
# remember type information of the node
|
||||
type_informations.append(str(node.output))
|
||||
|
||||
# align = signs
|
||||
#
|
||||
# e.g.,
|
||||
#
|
||||
# %1 = ...
|
||||
# %2 = ...
|
||||
# ...
|
||||
# %8 = ...
|
||||
# %9 = ...
|
||||
# %10 = ...
|
||||
# %11 = ...
|
||||
# ...
|
||||
longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
length_before_equals_sign = len(line.split("=")[0])
|
||||
lines[i] = (
|
||||
" " * (longest_length_before_equals_sign - length_before_equals_sign)
|
||||
) + line
|
||||
|
||||
# add type information
|
||||
longest_line_length = max(len(line) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
lines[i] += " " * (longest_line_length - len(line))
|
||||
lines[i] += f" # {type_informations[i]}"
|
||||
|
||||
# add highlights (this is done in reverse to keep indices consistent)
|
||||
for i in reversed(range(len(lines))):
|
||||
if i in highlighted_lines:
|
||||
for j, message in enumerate(highlighted_lines[i]):
|
||||
highlight = "^" if j == 0 else " "
|
||||
lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")
|
||||
|
||||
# add return information
|
||||
# (if there is a single return, it's in the form `return %id`
|
||||
# (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
|
||||
returns: List[str] = []
|
||||
for node in self.output_nodes.values():
|
||||
returns.append(f"%{id_map[node]}")
|
||||
lines.append("return " + (returns[0] if len(returns) == 1 else f"({', '.join(returns)})"))
|
||||
|
||||
# format subgraphs after the actual graph
|
||||
result = "\n".join(lines)
|
||||
if len(subgraphs) > 0:
|
||||
result += "\n\n"
|
||||
result += "Subgraphs:"
|
||||
for line, subgraph in subgraphs.items():
|
||||
subgraph_lines = subgraph.format(maximum_constant_length).split("\n")
|
||||
result += "\n\n"
|
||||
result += f" {line}:\n\n"
|
||||
result += "\n".join(f" {line}" for line in subgraph_lines)
|
||||
|
||||
return result
|
||||
|
||||
def measure_bounds(
|
||||
self,
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
) -> Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
|
||||
"""
|
||||
Evaluate the `Graph` using an inputset and measure bounds.
|
||||
|
||||
inputset is either an iterable of anything
|
||||
for a single parameter
|
||||
|
||||
or
|
||||
|
||||
an iterable of tuples of anything (of rank number of parameters)
|
||||
for multiple parameters
|
||||
|
||||
e.g.,
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
inputset = [1, 3, 5, 2, 4]
|
||||
def f(x):
|
||||
...
|
||||
|
||||
inputset = [(1, 2), (2, 4), (3, 1), (2, 2)]
|
||||
def g(x, y):
|
||||
...
|
||||
|
||||
Args:
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
|
||||
inputset to use
|
||||
|
||||
Returns:
|
||||
Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
|
||||
bounds of each node in the `Graph`
|
||||
"""
|
||||
|
||||
bounds = {}
|
||||
|
||||
inputset_iterator = iter(inputset)
|
||||
|
||||
sample = next(inputset_iterator)
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
evaluation = self.evaluate(*sample)
|
||||
for node, value in evaluation.items():
|
||||
bounds[node] = {
|
||||
"min": value.min(),
|
||||
"max": value.max(),
|
||||
}
|
||||
|
||||
for sample in inputset_iterator:
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
evaluation = self.evaluate(*sample)
|
||||
for node, value in evaluation.items():
|
||||
bounds[node] = {
|
||||
"min": np.minimum(bounds[node]["min"], value.min()),
|
||||
"max": np.maximum(bounds[node]["max"], value.max()),
|
||||
}
|
||||
|
||||
return bounds
|
||||
|
||||
def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
|
||||
"""
|
||||
Update `Value`s within the `Graph` according to measured bounds.
|
||||
|
||||
Args:
|
||||
bounds (Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
|
||||
bounds of each node in the `Graph`
|
||||
"""
|
||||
|
||||
for node in self.graph.nodes():
|
||||
if node in bounds:
|
||||
min_bound = bounds[node]["min"]
|
||||
max_bound = bounds[node]["max"]
|
||||
|
||||
new_value = deepcopy(node.output)
|
||||
|
||||
if isinstance(min_bound, np.integer):
|
||||
new_value.dtype = Integer.that_can_represent(np.array([min_bound, max_bound]))
|
||||
else:
|
||||
new_value.dtype = {
|
||||
np.bool_: UnsignedInteger(1),
|
||||
np.float64: Float(64),
|
||||
np.float32: Float(32),
|
||||
np.float16: Float(16),
|
||||
}[type(min_bound)]
|
||||
|
||||
node.output = new_value
|
||||
|
||||
if node.operation == Operation.Input:
|
||||
node.inputs[0] = new_value
|
||||
|
||||
for successor in self.graph.successors(node):
|
||||
edge_data = self.graph.get_edge_data(node, successor)
|
||||
for edge in edge_data.values():
|
||||
input_idx = edge["input_idx"]
|
||||
successor.inputs[input_idx] = node.output
|
||||
|
||||
def ordered_inputs(self) -> List[Node]:
|
||||
"""
|
||||
Get the input nodes of the `Graph`, ordered by their indices.
|
||||
|
||||
Returns:
|
||||
List[Node]:
|
||||
ordered input nodes
|
||||
"""
|
||||
|
||||
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
|
||||
|
||||
def ordered_outputs(self) -> List[Node]:
|
||||
"""
|
||||
Get the output nodes of the `Graph`, ordered by their indices.
|
||||
|
||||
Returns:
|
||||
List[Node]:
|
||||
ordered output nodes
|
||||
"""
|
||||
|
||||
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
|
||||
|
||||
def ordered_preds_of(self, node: Node) -> List[Node]:
|
||||
"""
|
||||
Get predecessors of `node`, ordered by their indices.
|
||||
|
||||
Args:
|
||||
node (Node):
|
||||
node whose predecessors are requested
|
||||
|
||||
Returns:
|
||||
List[Node]:
|
||||
ordered predecessors of `node`.
|
||||
"""
|
||||
|
||||
idx_to_pred: Dict[int, Node] = {}
|
||||
for pred in self.graph.predecessors(node):
|
||||
edge_data = self.graph.get_edge_data(pred, node)
|
||||
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
|
||||
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
|
||||
|
||||
def prune_useless_nodes(self):
|
||||
"""
|
||||
Remove unreachable nodes from the graph.
|
||||
"""
|
||||
|
||||
useful_nodes: Dict[Node, None] = {}
|
||||
|
||||
current_nodes = {node: None for node in self.ordered_outputs()}
|
||||
while current_nodes:
|
||||
useful_nodes.update(current_nodes)
|
||||
|
||||
next_nodes: Dict[Node, None] = {}
|
||||
for node in current_nodes:
|
||||
next_nodes.update({node: None for node in self.graph.predecessors(node)})
|
||||
|
||||
current_nodes = next_nodes
|
||||
|
||||
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
|
||||
self.graph.remove_nodes_from(useless_nodes)
|
||||
|
||||
def maximum_integer_bit_width(self) -> int:
|
||||
"""
|
||||
Get maximum integer bit-width within the graph.
|
||||
|
||||
Returns:
|
||||
int:
|
||||
maximum integer bit-width within the graph (-1 is there are no integer nodes)
|
||||
"""
|
||||
|
||||
result = -1
|
||||
for node in self.graph.nodes():
|
||||
if isinstance(node.output.dtype, Integer):
|
||||
result = max(result, node.output.dtype.bit_width)
|
||||
return result
|
||||
316
concrete/numpy/representation/node.py
Normal file
316
concrete/numpy/representation/node.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Declaration of `Node` class.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from ..values import Value
|
||||
from .operation import Operation
|
||||
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
Node class, to represent computation in a computation graph.
|
||||
"""
|
||||
|
||||
inputs: List[Value]
|
||||
output: Value
|
||||
|
||||
operation: Operation
|
||||
evaluator: Callable
|
||||
|
||||
properties: Dict[str, Any]
|
||||
|
||||
@staticmethod
|
||||
def constant(constant: Any) -> "Node":
|
||||
"""
|
||||
Create an Operation.Constant node.
|
||||
|
||||
Args:
|
||||
constant (Any):
|
||||
constant to represent
|
||||
|
||||
Returns:
|
||||
Node:
|
||||
node representing constant
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
if the constant is not representable
|
||||
"""
|
||||
|
||||
try:
|
||||
value = Value.of(constant)
|
||||
except Exception as error:
|
||||
raise ValueError(f"Constant {repr(constant)} is not supported") from error
|
||||
|
||||
properties = {"constant": np.array(constant)}
|
||||
return Node([], value, Operation.Constant, lambda: properties["constant"], properties)
|
||||
|
||||
@staticmethod
|
||||
def generic(
|
||||
name: str,
|
||||
inputs: List[Value],
|
||||
output: Value,
|
||||
operation: Callable,
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
attributes: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Create an Operation.Generic node.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
name of the operation
|
||||
|
||||
inputs (List[Value]):
|
||||
inputs to the operation
|
||||
|
||||
output (Value):
|
||||
output of the operation
|
||||
|
||||
operation (Callable):
|
||||
operation itself
|
||||
|
||||
args (Optional[Tuple[Any, ...]]):
|
||||
args to pass to operation during evaluation
|
||||
|
||||
kwargs (Optional[Dict[str, Any]]):
|
||||
kwargs to pass to operation during evaluation
|
||||
|
||||
attributes (Optional[Dict[str, Any]]):
|
||||
attributes of the operation
|
||||
|
||||
Returns:
|
||||
Node:
|
||||
node representing operation
|
||||
"""
|
||||
|
||||
properties = {
|
||||
"name": name,
|
||||
"args": args if args is not None else (),
|
||||
"kwargs": kwargs if kwargs is not None else {},
|
||||
"attributes": attributes if attributes is not None else {},
|
||||
}
|
||||
|
||||
if name == "concatenate":
|
||||
|
||||
def evaluator(*args):
|
||||
return operation(tuple(args), *properties["args"], **properties["kwargs"])
|
||||
|
||||
else:
|
||||
|
||||
def evaluator(*args):
|
||||
return operation(*args, *properties["args"], **properties["kwargs"])
|
||||
|
||||
return Node(inputs, output, Operation.Generic, evaluator, properties)
|
||||
|
||||
@staticmethod
|
||||
def input(name: str, value: Value) -> "Node":
|
||||
"""
|
||||
Create an Operation.Input node.
|
||||
|
||||
Args:
|
||||
name (Any):
|
||||
name of the input
|
||||
|
||||
value (Any):
|
||||
value of the input
|
||||
|
||||
Returns:
|
||||
Node:
|
||||
node representing input
|
||||
"""
|
||||
|
||||
return Node([value], value, Operation.Input, lambda x: x, {"name": name})
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: List[Value],
|
||||
output: Value,
|
||||
operation: Operation,
|
||||
evaluator: Callable,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.inputs = inputs
|
||||
self.output = output
|
||||
|
||||
self.operation = operation
|
||||
self.evaluator = evaluator # type: ignore
|
||||
|
||||
self.properties = properties if properties is not None else {}
|
||||
|
||||
def __call__(self, *args: List[Any]) -> Union[np.bool_, np.integer, np.floating, np.ndarray]:
|
||||
def generic_error_message() -> str:
|
||||
result = f"Evaluation of {self.operation.value} '{self.label()}' node"
|
||||
if len(args) != 0:
|
||||
result += f" using {', '.join(repr(arg) for arg in args)}"
|
||||
return result
|
||||
|
||||
if len(args) != len(self.inputs):
|
||||
raise ValueError(
|
||||
f"{generic_error_message()} failed because of invalid number of arguments"
|
||||
)
|
||||
|
||||
for arg, input_ in zip(args, self.inputs):
|
||||
try:
|
||||
arg_value = Value.of(arg)
|
||||
except Exception as error:
|
||||
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
|
||||
raise ValueError(
|
||||
f"{generic_error_message()} failed because {arg_str} is not valid"
|
||||
) from error
|
||||
|
||||
if input_.shape != arg_value.shape:
|
||||
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
|
||||
raise ValueError(
|
||||
f"{generic_error_message()} failed because "
|
||||
f"{arg_str} does not have the expected "
|
||||
f"shape of {input_.shape}"
|
||||
)
|
||||
|
||||
result = self.evaluator(*args)
|
||||
|
||||
if isinstance(result, int) and -(2 ** 63) < result < (2 ** 63) - 1:
|
||||
result = np.int64(result)
|
||||
if isinstance(result, float):
|
||||
result = np.float64(result)
|
||||
|
||||
if isinstance(result, list):
|
||||
try:
|
||||
np_result = np.array(result)
|
||||
result = np_result
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# here we try our best to convert the list to np.ndarray
|
||||
# if it fails we raise the exception below
|
||||
pass
|
||||
|
||||
if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
|
||||
raise ValueError(
|
||||
f"{generic_error_message()} resulted in {repr(result)} "
|
||||
f"of type {result.__class__.__name__} "
|
||||
f"which is not acceptable either because of the type or because of overflow"
|
||||
)
|
||||
|
||||
if isinstance(result, np.ndarray):
|
||||
dtype = result.dtype
|
||||
if (
|
||||
not np.issubdtype(dtype, np.integer)
|
||||
and not np.issubdtype(dtype, np.floating)
|
||||
and not np.issubdtype(dtype, np.bool_)
|
||||
):
|
||||
raise ValueError(
|
||||
f"{generic_error_message()} resulted in {repr(result)} "
|
||||
f"of type np.ndarray and of underlying type '{type(dtype).__name__}' "
|
||||
f"which is not acceptable because of the underlying type"
|
||||
)
|
||||
|
||||
if result.shape != self.output.shape:
|
||||
raise ValueError(
|
||||
f"{generic_error_message()} resulted in {repr(result)} "
|
||||
f"which does not have the expected "
|
||||
f"shape of {self.output.shape}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def format(self, predecessors: List[str], maximum_constant_length: int = 45) -> str:
|
||||
"""
|
||||
Get the textual representation of the `Node` (for printing).
|
||||
|
||||
Args:
|
||||
predecessors (List[str]):
|
||||
predecessor names to this node
|
||||
|
||||
maximum_constant_length (int, default = 45):
|
||||
maximum length of formatted constants
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual representation of the `Node` (for printing)
|
||||
"""
|
||||
|
||||
if self.operation == Operation.Constant:
|
||||
return format_constant(self(), maximum_constant_length)
|
||||
|
||||
if self.operation == Operation.Input:
|
||||
return self.properties["name"]
|
||||
|
||||
assert_that(self.operation == Operation.Generic)
|
||||
|
||||
name = self.properties["name"]
|
||||
|
||||
if name == "index.static":
|
||||
index = self.properties["attributes"]["index"]
|
||||
elements = [format_indexing_element(element) for element in index]
|
||||
return f"{predecessors[0]}[{', '.join(elements)}]"
|
||||
|
||||
if name == "concatenate":
|
||||
args = [f"({', '.join(predecessors)})"]
|
||||
else:
|
||||
args = deepcopy(predecessors)
|
||||
|
||||
args.extend(
|
||||
format_constant(value, maximum_constant_length) for value in self.properties["args"]
|
||||
)
|
||||
args.extend(
|
||||
f"{name}={format_constant(value, maximum_constant_length)}"
|
||||
for name, value in self.properties["kwargs"].items()
|
||||
if name not in KWARGS_IGNORED_IN_FORMATTING
|
||||
)
|
||||
|
||||
return f"{name}({', '.join(args)})"
|
||||
|
||||
def label(self) -> str:
|
||||
"""
|
||||
Get the textual representation of the `Node` (for drawing).
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual representation of the `Node` (for drawing).
|
||||
"""
|
||||
|
||||
if self.operation == Operation.Constant:
|
||||
return format_constant(self(), maximum_length=45, keep_newlines=True)
|
||||
|
||||
if self.operation == Operation.Input:
|
||||
return self.properties["name"]
|
||||
|
||||
assert_that(self.operation == Operation.Generic)
|
||||
|
||||
name = self.properties["name"]
|
||||
|
||||
if name == "index.static":
|
||||
name = self.format(["index"])
|
||||
|
||||
return name
|
||||
|
||||
@property
|
||||
def converted_to_table_lookup(self) -> bool:
|
||||
"""
|
||||
Get whether the node is converted to a table lookup during MLIR conversion.
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
True if the node is converted to a table lookup, False otherwise
|
||||
"""
|
||||
|
||||
return self.operation == Operation.Generic and self.properties["name"] not in [
|
||||
"add",
|
||||
"concatenate",
|
||||
"conv2d",
|
||||
"dot",
|
||||
"index.static",
|
||||
"matmul",
|
||||
"multiply",
|
||||
"negative",
|
||||
"reshape",
|
||||
"subtract",
|
||||
"sum",
|
||||
]
|
||||
29
concrete/numpy/representation/operation.py
Normal file
29
concrete/numpy/representation/operation.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Declaration of `Operation` enum.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Operation(Enum):
|
||||
"""
|
||||
Operation enum, to distinguish nodes within a computation graph.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
Constant = "constant"
|
||||
Generic = "generic"
|
||||
Input = "input"
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
# https://graphviz.org/doc/info/colors.html#svg
|
||||
|
||||
OPERATION_COLOR_MAPPING = {
|
||||
Operation.Constant: "grey",
|
||||
Operation.Generic: "black",
|
||||
Operation.Input: "crimson",
|
||||
"output": "gold",
|
||||
}
|
||||
114
concrete/numpy/representation/utils.py
Normal file
114
concrete/numpy/representation/utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Declaration of various functions and constants related to representation of computation.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Hashable, Set, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
|
||||
KWARGS_IGNORED_IN_FORMATTING: Set[str] = {
|
||||
"subgraph",
|
||||
"terminal_node",
|
||||
}
|
||||
|
||||
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
|
||||
np.float16: "float16",
|
||||
np.float32: "float32",
|
||||
np.float64: "float64",
|
||||
np.int8: "int8",
|
||||
np.int16: "int16",
|
||||
np.int32: "int32",
|
||||
np.int64: "int64",
|
||||
np.uint8: "uint8",
|
||||
np.uint16: "uint16",
|
||||
np.uint32: "uint32",
|
||||
np.uint64: "uint64",
|
||||
np.byte: "byte",
|
||||
np.short: "short",
|
||||
np.intc: "intc",
|
||||
np.int_: "int_",
|
||||
np.longlong: "longlong",
|
||||
np.ubyte: "ubyte",
|
||||
np.ushort: "ushort",
|
||||
np.uintc: "uintc",
|
||||
np.uint: "uint",
|
||||
np.ulonglong: "ulonglong",
|
||||
}
|
||||
|
||||
|
||||
def format_constant(constant: Any, maximum_length: int = 45, keep_newlines: bool = False) -> str:
|
||||
"""
|
||||
Get the textual representation of a constant.
|
||||
|
||||
Args:
|
||||
constant (Any):
|
||||
constant to format
|
||||
|
||||
maximum_length (int, default = 45):
|
||||
maximum length of the resulting string
|
||||
|
||||
keep_newlines (bool, default = False):
|
||||
whether to keep newlines or not
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual representation of `constant`
|
||||
"""
|
||||
|
||||
if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
|
||||
return SPECIAL_OBJECT_MAPPING[constant]
|
||||
|
||||
# maximum_length should not be smaller than 7 characters because
|
||||
# the constant will be formatted to `x ... y`
|
||||
# where x and y are part of the constant, and they are at least 1 character
|
||||
assert_that(maximum_length >= 7)
|
||||
|
||||
result = str(constant)
|
||||
if not keep_newlines:
|
||||
result = result.replace("\n", "")
|
||||
|
||||
if len(result) > maximum_length:
|
||||
from_start = (maximum_length - 5) // 2
|
||||
from_end = (maximum_length - 5) - from_start
|
||||
|
||||
if keep_newlines and "\n" in result:
|
||||
result = f"{result[:from_start]}\n...\n{result[-from_end:]}"
|
||||
else:
|
||||
result = f"{result[:from_start]} ... {result[-from_end:]}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def format_indexing_element(indexing_element: Union[int, np.integer, slice]):
|
||||
"""
|
||||
Format an indexing element.
|
||||
|
||||
This is required mainly for slices. The reason is that string representation of slices
|
||||
are very long and verbose. To give an example, `x[:, 2:]` will have the following index
|
||||
`[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
|
||||
it will be formatted as `[:, 2:]`.
|
||||
|
||||
Args:
|
||||
indexing_element (Union[int, np.integer, slice]):
|
||||
indexing element to format
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual representation of `indexing_element`
|
||||
"""
|
||||
|
||||
result = ""
|
||||
if isinstance(indexing_element, slice):
|
||||
if indexing_element.start is not None:
|
||||
result += str(indexing_element.start)
|
||||
result += ":"
|
||||
if indexing_element.stop is not None:
|
||||
result += str(indexing_element.stop)
|
||||
if indexing_element.step is not None:
|
||||
result += ":"
|
||||
result += str(indexing_element.step)
|
||||
else:
|
||||
result += str(indexing_element)
|
||||
return result.replace("\n", " ")
|
||||
Reference in New Issue
Block a user