diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index ba60f8c23..fe3610e63 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -11,16 +11,20 @@ from ..data_types.dtypes_helpers import ( mix_scalar_values_determine_holding_dtype, ) +IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" + class IntermediateNode(ABC): """Abstract Base Class to derive from to represent source program operations.""" inputs: List[BaseValue] outputs: List[BaseValue] + _n_in: int # _n_in indicates how many inputs are required to evaluate the IntermediateNode def __init__( self, inputs: Iterable[BaseValue], + **_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes ) -> None: self.inputs = list(inputs) assert all(isinstance(x, BaseValue) for x in self.inputs) @@ -28,13 +32,15 @@ class IntermediateNode(ABC): def _init_binary( self, inputs: Iterable[BaseValue], + mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype, + **_kwargs, # Required to conform to __init__ typing ) -> None: """__init__ for a binary operation, ie two inputs.""" IntermediateNode.__init__(self, inputs) assert len(self.inputs) == 2 - self.outputs = [mix_scalar_values_determine_holding_dtype(self.inputs[0], self.inputs[1])] + self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] def _is_equivalent_to_binary_commutative(self, other: object) -> bool: """is_equivalent_to for a binary and commutative operation.""" @@ -82,10 +88,30 @@ class IntermediateNode(ABC): Any: the result of the computation """ + @classmethod + def n_in(cls) -> int: + """Returns how many inputs the node has. + + Returns: + int: The number of inputs of the node. + """ + return cls._n_in + + @classmethod + def requires_mix_values_func(cls) -> bool: + """Function to determine whether the Class requires a mix_values_func to be built. + + Returns: + bool: True if __init__ expects a mix_values_func argument. + """ + return cls.n_in() > 1 + class Add(IntermediateNode): """Addition between two values.""" + _n_in: int = 2 + __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative @@ -96,6 +122,8 @@ class Add(IntermediateNode): class Sub(IntermediateNode): """Subtraction between two values.""" + _n_in: int = 2 + __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative @@ -106,6 +134,8 @@ class Sub(IntermediateNode): class Mul(IntermediateNode): """Multiplication between two values.""" + _n_in: int = 2 + __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative @@ -118,6 +148,7 @@ class Input(IntermediateNode): input_name: str program_input_idx: int + _n_in: int = 1 def __init__( self, @@ -147,6 +178,7 @@ class ConstantInput(IntermediateNode): """Node representing a constant of the program.""" _constant_data: Any + _n_in: int = 0 def __init__( self, @@ -191,6 +223,7 @@ class ArbitraryFunction(IntermediateNode): op_args: Tuple[Any, ...] op_kwargs: Dict[str, Any] op_name: str + _n_in: int = 1 def __init__( self, diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py index e6774147a..2c9018e54 100644 --- a/hdk/common/tracing/base_tracer.py +++ b/hdk/common/tracing/base_tracer.py @@ -1,10 +1,11 @@ """This file holds the code that can be shared between tracers.""" from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Tuple, Type, Union +from typing import Any, Callable, Iterable, List, Tuple, Type, Union from ..data_types import BaseValue from ..representation import intermediate as ir +from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME class BaseTracer(ABC): @@ -13,6 +14,7 @@ class BaseTracer(ABC): inputs: List["BaseTracer"] traced_computation: ir.IntermediateNode output: BaseValue + _mix_values_func: Callable[..., BaseValue] def __init__( self, @@ -47,6 +49,10 @@ class BaseTracer(ABC): BaseTracer: The BaseTracer for that constant. """ + @classmethod + def _get_mix_values_func(cls): + return cls._mix_values_func + def instantiate_output_tracers( self, inputs: Iterable[Union["BaseTracer", Any]], @@ -71,8 +77,15 @@ class BaseTracer(ABC): sanitized_inputs = [sanitize(inp) for inp in inputs] + additional_parameters = ( + {IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()} + if computation_to_trace.requires_mix_values_func() + else {} + ) + traced_computation = computation_to_trace( (x.output for x in sanitized_inputs), + **additional_parameters, ) output_tracers = tuple( diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 075409c38..91a4ac401 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -7,6 +7,7 @@ import numpy from numpy.typing import DTypeLike from ..common.data_types import BaseValue +from ..common.data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype from ..common.operator_graph import OPGraph from ..common.representation.intermediate import ArbitraryFunction, ConstantInput from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters @@ -30,6 +31,8 @@ NPConstantInput = partial( class NPTracer(BaseTracer): """Tracer class for numpy operations.""" + _mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype + def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs): """Catch calls to numpy ufunc and routes them to tracing functions if supported.