diff --git a/concrete/numpy/representation/evaluator.py b/concrete/numpy/representation/evaluator.py new file mode 100644 index 000000000..56470e3fa --- /dev/null +++ b/concrete/numpy/representation/evaluator.py @@ -0,0 +1,50 @@ +""" +Declaration of various `Evaluator` classes, to make graphs picklable. +""" + + +class ConstantEvaluator: + """ + ConstantEvaluator class, to evaluate Operation.Constant nodes. + """ + + def __init__(self, properties): + self.properties = properties + + def __call__(self, *args, **kwargs): + return self.properties["constant"] + + +class InputEvaluator: + """ + InputEvaluator class, to evaluate Operation.Input nodes. + """ + + def __call__(self, *args, **kwargs): + return args[0] + + +class GenericEvaluator: + """ + GenericEvaluator class, to evaluate Operation.Generic nodes. + """ + + def __init__(self, operation, properties): + self.operation = operation + self.properties = properties + + def __call__(self, *args, **kwargs): + return self.operation(*args, *self.properties["args"], **self.properties["kwargs"]) + + +class GenericTupleEvaluator: + """ + GenericEvaluator class, to evaluate Operation.Generic nodes where args are packed in a tuple. + """ + + def __init__(self, operation, properties): + self.operation = operation + self.properties = properties + + def __call__(self, *args, **kwargs): + return self.operation(tuple(args), *self.properties["args"], **self.properties["kwargs"]) diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index 00c4172a0..5d796bfef 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -9,6 +9,7 @@ import numpy as np from ..internal.utils import assert_that from ..values import Value +from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator from .operation import Operation from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element @@ -50,7 +51,7 @@ class Node: raise ValueError(f"Constant {repr(constant)} is not supported") from error properties = {"constant": np.array(constant)} - return Node([], value, Operation.Constant, lambda: properties["constant"], properties) + return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties) @staticmethod def generic( @@ -99,17 +100,17 @@ class Node: "attributes": attributes if attributes is not None else {}, } - if name == "concatenate": - - def evaluator(*args): - return operation(tuple(args), *properties["args"], **properties["kwargs"]) - - else: - - def evaluator(*args): - return operation(*args, *properties["args"], **properties["kwargs"]) - - return Node(inputs, output, Operation.Generic, evaluator, properties) + return Node( + inputs, + output, + Operation.Generic, + ( + GenericTupleEvaluator(operation, properties) # type: ignore + if name in ["concatenate"] + else GenericEvaluator(operation, properties) # type: ignore + ), + properties, + ) @staticmethod def input(name: str, value: Value) -> "Node": @@ -128,7 +129,7 @@ class Node: node representing input """ - return Node([value], value, Operation.Input, lambda x: x, {"name": name}) + return Node([value], value, Operation.Input, InputEvaluator(), {"name": name}) def __init__( self,