feat: make graph picklable

This commit is contained in:
Umut
2022-05-05 16:44:26 +02:00
parent a8d929f1f1
commit 6739e2e8ab
2 changed files with 64 additions and 13 deletions

View File

@@ -0,0 +1,50 @@
"""
Declaration of various `Evaluator` classes, to make graphs picklable.
"""
class ConstantEvaluator:
"""
ConstantEvaluator class, to evaluate Operation.Constant nodes.
"""
def __init__(self, properties):
self.properties = properties
def __call__(self, *args, **kwargs):
return self.properties["constant"]
class InputEvaluator:
"""
InputEvaluator class, to evaluate Operation.Input nodes.
"""
def __call__(self, *args, **kwargs):
return args[0]
class GenericEvaluator:
"""
GenericEvaluator class, to evaluate Operation.Generic nodes.
"""
def __init__(self, operation, properties):
self.operation = operation
self.properties = properties
def __call__(self, *args, **kwargs):
return self.operation(*args, *self.properties["args"], **self.properties["kwargs"])
class GenericTupleEvaluator:
"""
GenericEvaluator class, to evaluate Operation.Generic nodes where args are packed in a tuple.
"""
def __init__(self, operation, properties):
self.operation = operation
self.properties = properties
def __call__(self, *args, **kwargs):
return self.operation(tuple(args), *self.properties["args"], **self.properties["kwargs"])

View File

@@ -9,6 +9,7 @@ import numpy as np
from ..internal.utils import assert_that
from ..values import Value
from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator
from .operation import Operation
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
@@ -50,7 +51,7 @@ class Node:
raise ValueError(f"Constant {repr(constant)} is not supported") from error
properties = {"constant": np.array(constant)}
return Node([], value, Operation.Constant, lambda: properties["constant"], properties)
return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties)
@staticmethod
def generic(
@@ -99,17 +100,17 @@ class Node:
"attributes": attributes if attributes is not None else {},
}
if name == "concatenate":
def evaluator(*args):
return operation(tuple(args), *properties["args"], **properties["kwargs"])
else:
def evaluator(*args):
return operation(*args, *properties["args"], **properties["kwargs"])
return Node(inputs, output, Operation.Generic, evaluator, properties)
return Node(
inputs,
output,
Operation.Generic,
(
GenericTupleEvaluator(operation, properties) # type: ignore
if name in ["concatenate"]
else GenericEvaluator(operation, properties) # type: ignore
),
properties,
)
@staticmethod
def input(name: str, value: Value) -> "Node":
@@ -128,7 +129,7 @@ class Node:
node representing input
"""
return Node([value], value, Operation.Input, lambda x: x, {"name": name})
return Node([value], value, Operation.Input, InputEvaluator(), {"name": name})
def __init__(
self,