mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
364 lines
11 KiB
Python
364 lines
11 KiB
Python
"""
|
|
Declaration of `Node` class.
|
|
"""
|
|
|
|
from copy import deepcopy
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
|
|
from ..internal.utils import assert_that
|
|
from ..values import Value
|
|
from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator
|
|
from .operation import Operation
|
|
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
|
|
|
|
|
|
class Node:
    """
    Node class, to represent computation in a computation graph.

    A node holds the values it consumes (`inputs`), the value it produces (`output`),
    the kind of operation it represents (`operation`), the callable used to evaluate it
    (`evaluator`) and a dictionary of extra information (`properties`).
    """

    # values consumed by the node (one per predecessor / argument)
    inputs: List[Value]
    # value produced by the node
    output: Value

    # kind of the node (constant, input or generic operation)
    operation: Operation
    # callable invoked with the node arguments in `__call__`
    evaluator: Callable

    # extra information (e.g., "name", "args", "kwargs", "attributes", "constant")
    properties: Dict[str, Any]

    # generic operations that are NOT converted to a table lookup during MLIR conversion
    _NON_TABLE_LOOKUP_OPERATIONS = frozenset(
        {
            "add",
            "array",
            "assign.static",
            "broadcast_to",
            "concatenate",
            "conv1d",
            "conv2d",
            "conv3d",
            "dot",
            "expand_dims",
            "index.static",
            "matmul",
            "maxpool",
            "multiply",
            "negative",
            "ones",
            "reshape",
            "subtract",
            "sum",
            "transpose",
            "zeros",
        }
    )

    # generic operations that are not table lookups but can still be fused into one
    _FUSABLE_NON_TABLE_LOOKUP_OPERATIONS = frozenset(
        {
            "add",
            "multiply",
            "negative",
            "ones",
            "subtract",
            "zeros",
        }
    )

    @staticmethod
    def constant(constant: Any) -> "Node":
        """
        Create an Operation.Constant node.

        Args:
            constant (Any):
                constant to represent

        Returns:
            Node:
                node representing constant

        Raises:
            ValueError:
                if the constant is not representable
        """

        try:
            value = Value.of(constant)
        except Exception as error:
            raise ValueError(f"Constant {repr(constant)} is not supported") from error

        # the constant itself is kept (as a numpy array) in the properties of the node
        properties = {"constant": np.array(constant)}
        return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties)

    @staticmethod
    def generic(
        name: str,
        inputs: List[Value],
        output: Value,
        operation: Callable,
        args: Optional[Tuple[Any, ...]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> "Node":
        """
        Create an Operation.Generic node.

        Args:
            name (str):
                name of the operation

            inputs (List[Value]):
                inputs to the operation

            output (Value):
                output of the operation

            operation (Callable):
                operation itself

            args (Optional[Tuple[Any, ...]]):
                args to pass to operation during evaluation

            kwargs (Optional[Dict[str, Any]]):
                kwargs to pass to operation during evaluation

            attributes (Optional[Dict[str, Any]]):
                attributes of the operation

        Returns:
            Node:
                node representing operation
        """

        properties = {
            "name": name,
            "args": args if args is not None else (),
            "kwargs": kwargs if kwargs is not None else {},
            "attributes": attributes if attributes is not None else {},
        }

        # "concatenate" gets a tuple evaluator, presumably because `np.concatenate` takes
        # its operands as a single sequence argument rather than as separate arguments
        evaluator = (
            GenericTupleEvaluator(operation, properties)  # type: ignore
            if name in ["concatenate"]
            else GenericEvaluator(operation, properties)  # type: ignore
        )

        return Node(inputs, output, Operation.Generic, evaluator, properties)

    @staticmethod
    def input(name: str, value: Value) -> "Node":
        """
        Create an Operation.Input node.

        Args:
            name (str):
                name of the input

            value (Value):
                value of the input

        Returns:
            Node:
                node representing input
        """

        # the input value is also listed as the single input of the node,
        # so `__call__` checks the argument count and shape against it
        return Node([value], value, Operation.Input, InputEvaluator(), {"name": name})

    def __init__(
        self,
        inputs: List[Value],
        output: Value,
        operation: Operation,
        evaluator: Callable,
        properties: Optional[Dict[str, Any]] = None,
    ):
        self.inputs = inputs
        self.output = output

        self.operation = operation
        self.evaluator = evaluator  # type: ignore

        self.properties = properties if properties is not None else {}

    def __call__(self, *args: Any) -> Union[np.bool_, np.integer, np.floating, np.ndarray]:
        """
        Evaluate the node on concrete arguments.

        Args:
            *args (Any):
                one argument per input of the node, each with the shape of that input

        Returns:
            Union[np.bool_, np.integer, np.floating, np.ndarray]:
                result of the evaluation

        Raises:
            ValueError:
                if the number of arguments is wrong,
                if an argument is not a valid value or has an unexpected shape,
                if the result is not a numpy bool/integer/floating scalar or array,
                or if an array result does not have the shape of the output
        """

        def generic_error_message() -> str:
            result = f"Evaluation of {self.operation.value} '{self.label()}' node"
            if len(args) != 0:
                result += f" using {', '.join(repr(arg) for arg in args)}"
            return result

        if len(args) != len(self.inputs):
            raise ValueError(
                f"{generic_error_message()} failed because of invalid number of arguments"
            )

        for arg, input_ in zip(args, self.inputs):
            try:
                arg_value = Value.of(arg)
            except Exception as error:
                arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
                raise ValueError(
                    f"{generic_error_message()} failed because {arg_str} is not valid"
                ) from error

            if input_.shape != arg_value.shape:
                arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
                raise ValueError(
                    f"{generic_error_message()} failed because "
                    f"{arg_str} does not have the expected "
                    f"shape of {input_.shape}"
                )

        result = self.evaluator(*args)

        # convert python scalars to numpy scalars; python ints that do not fit
        # into int64 (inclusive bounds) are left as is and rejected below
        if isinstance(result, int) and -(2**63) <= result <= (2**63) - 1:
            result = np.int64(result)
        if isinstance(result, float):
            result = np.float64(result)

        if isinstance(result, list):
            try:
                np_result = np.array(result)
                result = np_result
            except Exception:  # pylint: disable=broad-except
                # here we try our best to convert the list to np.ndarray
                # if it fails we raise the exception below
                pass

        if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
            raise ValueError(
                f"{generic_error_message()} resulted in {repr(result)} "
                f"of type {result.__class__.__name__} "
                f"which is not acceptable either because of the type or because of overflow"
            )

        if isinstance(result, np.ndarray):
            dtype = result.dtype
            if (
                not np.issubdtype(dtype, np.integer)
                and not np.issubdtype(dtype, np.floating)
                and not np.issubdtype(dtype, np.bool_)
            ):
                raise ValueError(
                    f"{generic_error_message()} resulted in {repr(result)} "
                    f"of type np.ndarray and of underlying type '{type(dtype).__name__}' "
                    f"which is not acceptable because of the underlying type"
                )

            if result.shape != self.output.shape:
                raise ValueError(
                    f"{generic_error_message()} resulted in {repr(result)} "
                    f"which does not have the expected "
                    f"shape of {self.output.shape}"
                )

        return result

    def format(self, predecessors: List[str], maximum_constant_length: int = 45) -> str:
        """
        Get the textual representation of the `Node` (dependent on predecessors).

        Args:
            predecessors (List[str]):
                predecessor names to this node

            maximum_constant_length (int, default = 45):
                maximum length of formatted constants

        Returns:
            str:
                textual representation of the `Node` (dependent on predecessors)
        """

        if self.operation == Operation.Constant:
            return format_constant(self(), maximum_constant_length)

        if self.operation == Operation.Input:
            return self.properties["name"]

        assert_that(self.operation == Operation.Generic)

        name = self.properties["name"]

        if name == "index.static":
            index = self.properties["kwargs"]["index"]
            elements = [format_indexing_element(element) for element in index]
            return f"{predecessors[0]}[{', '.join(elements)}]"

        if name == "assign.static":
            index = self.properties["kwargs"]["index"]
            elements = [format_indexing_element(element) for element in index]
            return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"

        if name == "array":
            # lay predecessor names out in the output shape, dropping the quotes of the str repr
            values = str(np.array(predecessors).reshape(self.output.shape).tolist()).replace(
                "'", ""
            )
            return f"array({format_constant(values, maximum_constant_length)})"

        if name == "concatenate":
            # concatenate receives its operands as a single tuple argument
            args = [f"({', '.join(predecessors)})"]
        else:
            args = list(predecessors)

        args.extend(
            format_constant(value, maximum_constant_length) for value in self.properties["args"]
        )
        args.extend(
            f"{key}={format_constant(value, maximum_constant_length)}"
            for key, value in self.properties["kwargs"].items()
            if key not in KWARGS_IGNORED_IN_FORMATTING
        )

        return f"{name}({', '.join(args)})"

    def label(self) -> str:
        """
        Get the textual representation of the `Node` (independent of predecessors).

        Returns:
            str:
                textual representation of the `Node` (independent of predecessors)
        """

        if self.operation == Operation.Constant:
            return format_constant(self(), maximum_length=45, keep_newlines=True)

        if self.operation == Operation.Input:
            return self.properties["name"]

        assert_that(self.operation == Operation.Generic)

        name = self.properties["name"]

        if name == "index.static":
            name = self.format(["□"])
        elif name == "assign.static":
            # strip the surrounding parentheses added by `format`
            name = self.format(["□", "□"])[1:-1]

        return name

    @property
    def converted_to_table_lookup(self) -> bool:
        """
        Get whether the node is converted to a table lookup during MLIR conversion.

        Returns:
            bool:
                True if the node is converted to a table lookup, False otherwise
        """

        return (
            self.operation == Operation.Generic
            and self.properties["name"] not in Node._NON_TABLE_LOOKUP_OPERATIONS
        )

    @property
    def is_fusable(self) -> bool:
        """
        Get whether the node can be fused into a table lookup.

        Returns:
            bool:
                True if the node can be fused into a table lookup, False otherwise
        """

        if self.converted_to_table_lookup:
            return True

        return (
            self.operation != Operation.Generic
            or self.properties["name"] in Node._FUSABLE_NON_TABLE_LOOKUP_OPERATIONS
        )
|