Files
concrete/hdk/common/tracing/base_tracer.py
Arthur Meyre ff260b2cd2 feat(nptracer): add dot tracing abilities
- remove no cover from Dot.label
- small refactor of BaseTracer to make _sanitize a class method
- small refactor of get_ufunc_numpy_output_dtype to manage funcs and ufuncs
- add function routing to NPTracer
- add dot tracing to NPTracer
- small refactor to get tracing functions for numpy funcs and ufuncs
2021-08-25 11:00:36 +02:00

156 lines
4.9 KiB
Python

"""This file holds the code that can be shared between tracers."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
from ..representation import intermediate as ir
from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME
from ..values import BaseValue
class BaseTracer(ABC):
    """Base class for implementing tracers.

    A tracer wraps one output of a traced IntermediateNode and records, through Python's
    operator protocol, the computations applied to it.

    Subclasses must implement ``_supports_other_operand`` and ``_make_const_input_tracer``.
    They must also define ``_mix_values_func`` if they trace IntermediateNode classes whose
    ``requires_mix_values_func()`` returns True.
    """

    # Tracers feeding the traced computation (constants already converted to tracers).
    inputs: List["BaseTracer"]
    # The IntermediateNode this tracer is one output of.
    traced_computation: ir.IntermediateNode
    # The output value of ``traced_computation`` selected by this tracer.
    output: BaseValue
    # Only annotated here, never assigned: subclasses are expected to provide it.
    _mix_values_func: Callable[..., BaseValue]

    def __init__(
        self,
        inputs: Iterable["BaseTracer"],
        traced_computation: ir.IntermediateNode,
        output_index: int,
    ) -> None:
        """Create a tracer for one output of a traced computation.

        Args:
            inputs (Iterable[BaseTracer]): The tracers used as inputs of the computation.
                They are copied into a new list.
            traced_computation (ir.IntermediateNode): The node being traced.
            output_index (int): Index into ``traced_computation.outputs`` selecting the
                output this tracer represents.
        """
        self.inputs = list(inputs)
        self.traced_computation = traced_computation
        self.output = traced_computation.outputs[output_index]

    @abstractmethod
    def _supports_other_operand(self, other: Any) -> bool:
        """Function to check if the current class supports tracing with the other operand.

        Subclasses must override this. They may call ``super()`` to reuse the default
        check, which only accepts instances of the same tracer class.

        Args:
            other (Any): the operand to check compatibility with.

        Returns:
            bool: True if the tracer can manage operations with the other operand.
        """
        return isinstance(other, self.__class__)

    @abstractmethod
    def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer":
        """Helper function to create a tracer for a constant input.

        Args:
            constant_data (Any): The constant to store.

        Returns:
            BaseTracer: The BaseTracer for that constant.
        """

    @classmethod
    def _get_mix_values_func(cls) -> Callable[..., BaseValue]:
        """Return the function used to mix input values, as defined by the subclass.

        Returns:
            Callable[..., BaseValue]: The class's ``_mix_values_func``.

        Raises:
            AttributeError: If the concrete tracer class never assigned ``_mix_values_func``.
        """
        return cls._mix_values_func

    def _sanitize(self, inp: Union["BaseTracer", Any]) -> "BaseTracer":
        """Return ``inp`` unchanged if it is a tracer, else wrap it as a constant tracer.

        Args:
            inp (Union[BaseTracer, Any]): A tracer or a constant value.

        Returns:
            BaseTracer: ``inp`` itself, or the result of ``_make_const_input_tracer(inp)``.
        """
        if isinstance(inp, BaseTracer):
            return inp
        return self._make_const_input_tracer(inp)

    def instantiate_output_tracers(
        self,
        inputs: Iterable[Union["BaseTracer", Any]],
        computation_to_trace: Type[ir.IntermediateNode],
    ) -> Tuple["BaseTracer", ...]:
        """Helper function to instantiate all output BaseTracers for a given computation.

        Args:
            inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as
                inputs for a new node. Non-tracer values are converted to constant tracers.
            computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
                to instantiate for the computation being traced.

        Returns:
            Tuple[BaseTracer, ...]: A tuple containing one BaseTracer per output of the
                instantiated node, each of the same class as ``self``.
        """
        # For inputs which are actually constant, first convert into a tracer
        sanitized_inputs = [self._sanitize(inp) for inp in inputs]

        # Only nodes that ask for it receive the subclass's value-mixing function; this
        # avoids requiring ``_mix_values_func`` on tracers that never trace such nodes.
        additional_parameters = (
            {IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}
            if computation_to_trace.requires_mix_values_func()
            else {}
        )

        traced_computation = computation_to_trace(
            (x.output for x in sanitized_inputs),
            **additional_parameters,
        )

        output_tracers = tuple(
            self.__class__(sanitized_inputs, traced_computation, output_index)
            for output_index in range(len(traced_computation.outputs))
        )

        return output_tracers

    def _trace_single_output_binary_op(
        self,
        other: Union["BaseTracer", Any],
        computation_to_trace: Type[ir.IntermediateNode],
        *,
        reverse: bool = False,
    ) -> "BaseTracer":
        """Trace a binary operation between ``self`` and ``other`` that has one output.

        Args:
            other (Union[BaseTracer, Any]): The other operand, a tracer or a constant.
            computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
                to instantiate for the operation.
            reverse (bool): If True, ``other`` is used as the first (left-hand) input, as
                needed by reflected operators such as ``__rsub__``. Defaults to False.

        Returns:
            BaseTracer: The single output tracer of the traced operation. Returns
                ``NotImplemented`` instead if ``other`` is not supported, so that Python
                can fall back to the other operand's reflected method.
        """
        if not self._supports_other_operand(other):
            return NotImplemented

        operands = [other, self] if reverse else [self, other]
        result_tracer = self.instantiate_output_tracers(operands, computation_to_trace)

        # The binary operations traced through this helper must produce exactly one output.
        assert len(result_tracer) == 1
        return result_tracer[0]

    def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        """Trace ``self + other`` as an ``ir.Add`` node."""
        return self._trace_single_output_binary_op(other, ir.Add)

    # NOTE: __radd__ reuses __add__, so ``x + 1`` and ``1 + x`` produce the same graph
    # (the tracer is always the first input). If operand order must be preserved, trace
    # it with ``reverse=True`` the way __rsub__ does.
    __radd__ = __add__

    def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        """Trace ``self - other`` as an ``ir.Sub`` node."""
        return self._trace_single_output_binary_op(other, ir.Sub)

    def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        """Trace ``other - self`` as an ``ir.Sub`` node, keeping ``other`` as first input."""
        # Subtraction is not commutative, so the operand order must be preserved.
        return self._trace_single_output_binary_op(other, ir.Sub, reverse=True)

    def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        """Trace ``self * other`` as an ``ir.Mul`` node."""
        return self._trace_single_output_binary_op(other, ir.Mul)

    # NOTE: __rmul__ reuses __mul__, so ``x * 3`` and ``3 * x`` produce the same graph
    # (the tracer is always the first input). If operand order must be preserved, trace
    # it with ``reverse=True`` the way __rsub__ does.
    __rmul__ = __mul__