refactor: change IR and BaseTracer to support any input mixing function

- prepares support for tensor types
- introduce _n_in and n_in() and requires_mix_values_func() for
IntermediateNode to know if they require a function to mix input values and
determine output value
- update BaseTracer to pass the class _mix_values_func to IntermediateNodes
that need it
- update NPTracer to stay consistent with current behavior regarding values
mixing
This commit is contained in:
Arthur Meyre
2021-08-19 17:31:15 +02:00
parent 78480e5da7
commit 7a0f11b1b0
3 changed files with 51 additions and 2 deletions

View File

@@ -11,16 +11,20 @@ from ..data_types.dtypes_helpers import (
mix_scalar_values_determine_holding_dtype,
)
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
class IntermediateNode(ABC):
"""Abstract Base Class to derive from to represent source program operations."""
inputs: List[BaseValue]
outputs: List[BaseValue]
_n_in: int # _n_in indicates how many inputs are required to evaluate the IntermediateNode
def __init__(
self,
inputs: Iterable[BaseValue],
**_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes
) -> None:
self.inputs = list(inputs)
assert all(isinstance(x, BaseValue) for x in self.inputs)
@@ -28,13 +32,15 @@ class IntermediateNode(ABC):
def _init_binary(
self,
inputs: Iterable[BaseValue],
mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype,
**_kwargs, # Required to conform to __init__ typing
) -> None:
"""__init__ for a binary operation, ie two inputs."""
IntermediateNode.__init__(self, inputs)
assert len(self.inputs) == 2
self.outputs = [mix_scalar_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
"""is_equivalent_to for a binary and commutative operation."""
@@ -82,10 +88,30 @@ class IntermediateNode(ABC):
Any: the result of the computation
"""
@classmethod
def n_in(cls) -> int:
"""Returns how many inputs the node has.
Returns:
int: The number of inputs of the node.
"""
return cls._n_in
@classmethod
def requires_mix_values_func(cls) -> bool:
"""Function to determine whether the Class requires a mix_values_func to be built.
Returns:
bool: True if __init__ expects a mix_values_func argument.
"""
return cls.n_in() > 1
class Add(IntermediateNode):
"""Addition between two values."""
_n_in: int = 2
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
@@ -96,6 +122,8 @@ class Add(IntermediateNode):
class Sub(IntermediateNode):
"""Subtraction between two values."""
_n_in: int = 2
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
@@ -106,6 +134,8 @@ class Sub(IntermediateNode):
class Mul(IntermediateNode):
"""Multiplication between two values."""
_n_in: int = 2
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
@@ -118,6 +148,7 @@ class Input(IntermediateNode):
input_name: str
program_input_idx: int
_n_in: int = 1
def __init__(
self,
@@ -147,6 +178,7 @@ class ConstantInput(IntermediateNode):
"""Node representing a constant of the program."""
_constant_data: Any
_n_in: int = 0
def __init__(
self,
@@ -191,6 +223,7 @@ class ArbitraryFunction(IntermediateNode):
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_name: str
_n_in: int = 1
def __init__(
self,

View File

@@ -1,10 +1,11 @@
"""This file holds the code that can be shared between tracers."""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
from ..data_types import BaseValue
from ..representation import intermediate as ir
from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME
class BaseTracer(ABC):
@@ -13,6 +14,7 @@ class BaseTracer(ABC):
inputs: List["BaseTracer"]
traced_computation: ir.IntermediateNode
output: BaseValue
_mix_values_func: Callable[..., BaseValue]
def __init__(
self,
@@ -47,6 +49,10 @@ class BaseTracer(ABC):
BaseTracer: The BaseTracer for that constant.
"""
@classmethod
def _get_mix_values_func(cls):
return cls._mix_values_func
def instantiate_output_tracers(
self,
inputs: Iterable[Union["BaseTracer", Any]],
@@ -71,8 +77,15 @@ class BaseTracer(ABC):
sanitized_inputs = [sanitize(inp) for inp in inputs]
additional_parameters = (
{IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}
if computation_to_trace.requires_mix_values_func()
else {}
)
traced_computation = computation_to_trace(
(x.output for x in sanitized_inputs),
**additional_parameters,
)
output_tracers = tuple(

View File

@@ -7,6 +7,7 @@ import numpy
from numpy.typing import DTypeLike
from ..common.data_types import BaseValue
from ..common.data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import ArbitraryFunction, ConstantInput
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
@@ -30,6 +31,8 @@ NPConstantInput = partial(
class NPTracer(BaseTracer):
"""Tracer class for numpy operations."""
_mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.