mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: make graph picklable
This commit is contained in:
50
concrete/numpy/representation/evaluator.py
Normal file
50
concrete/numpy/representation/evaluator.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Declaration of various `Evaluator` classes, to make graphs picklable.
|
||||
"""
|
||||
|
||||
|
||||
class ConstantEvaluator:
|
||||
"""
|
||||
ConstantEvaluator class, to evaluate Operation.Constant nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, properties):
|
||||
self.properties = properties
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.properties["constant"]
|
||||
|
||||
|
||||
class InputEvaluator:
|
||||
"""
|
||||
InputEvaluator class, to evaluate Operation.Input nodes.
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return args[0]
|
||||
|
||||
|
||||
class GenericEvaluator:
|
||||
"""
|
||||
GenericEvaluator class, to evaluate Operation.Generic nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, operation, properties):
|
||||
self.operation = operation
|
||||
self.properties = properties
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.operation(*args, *self.properties["args"], **self.properties["kwargs"])
|
||||
|
||||
|
||||
class GenericTupleEvaluator:
|
||||
"""
|
||||
GenericEvaluator class, to evaluate Operation.Generic nodes where args are packed in a tuple.
|
||||
"""
|
||||
|
||||
def __init__(self, operation, properties):
|
||||
self.operation = operation
|
||||
self.properties = properties
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.operation(tuple(args), *self.properties["args"], **self.properties["kwargs"])
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from ..values import Value
|
||||
from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator
|
||||
from .operation import Operation
|
||||
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
|
||||
|
||||
@@ -50,7 +51,7 @@ class Node:
|
||||
raise ValueError(f"Constant {repr(constant)} is not supported") from error
|
||||
|
||||
properties = {"constant": np.array(constant)}
|
||||
return Node([], value, Operation.Constant, lambda: properties["constant"], properties)
|
||||
return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties)
|
||||
|
||||
@staticmethod
|
||||
def generic(
|
||||
@@ -99,17 +100,17 @@ class Node:
|
||||
"attributes": attributes if attributes is not None else {},
|
||||
}
|
||||
|
||||
if name == "concatenate":
|
||||
|
||||
def evaluator(*args):
|
||||
return operation(tuple(args), *properties["args"], **properties["kwargs"])
|
||||
|
||||
else:
|
||||
|
||||
def evaluator(*args):
|
||||
return operation(*args, *properties["args"], **properties["kwargs"])
|
||||
|
||||
return Node(inputs, output, Operation.Generic, evaluator, properties)
|
||||
return Node(
|
||||
inputs,
|
||||
output,
|
||||
Operation.Generic,
|
||||
(
|
||||
GenericTupleEvaluator(operation, properties) # type: ignore
|
||||
if name in ["concatenate"]
|
||||
else GenericEvaluator(operation, properties) # type: ignore
|
||||
),
|
||||
properties,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def input(name: str, value: Value) -> "Node":
|
||||
@@ -128,7 +129,7 @@ class Node:
|
||||
node representing input
|
||||
"""
|
||||
|
||||
return Node([value], value, Operation.Input, lambda x: x, {"name": name})
|
||||
return Node([value], value, Operation.Input, InputEvaluator(), {"name": name})
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user