feat: manage sanitization properly for numpy functions

- allows to use dot with constant inputs
This commit is contained in:
Arthur Meyre
2021-10-07 12:39:03 +02:00
parent 4a77d0515a
commit 0317dd49ea

View File

@@ -7,7 +7,7 @@ import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import custom_assert
from ..common.debugging.custom_assert import assert_true, custom_assert
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
@@ -53,7 +53,7 @@ class NPTracer(BaseTracer):
raise NotImplementedError("Only __call__ method is supported currently")
def __array_function__(self, func, _types, args, kwargs):
"""Catch calls to numpy function in routes them to hnp functions if supported.
"""Catch calls to numpy function in routes them to tracing functions if supported.
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
"""
@@ -62,7 +62,8 @@ class NPTracer(BaseTracer):
(len(kwargs) == 0),
f"**kwargs are currently not supported for numpy functions, func: {func}",
)
return tracing_func(*args, **kwargs)
sanitized_args = [self._sanitize(arg) for arg in args]
return tracing_func(self, *sanitized_args, **kwargs)
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
r"""Support numpy astype feature.
@@ -222,26 +223,25 @@ class NPTracer(BaseTracer):
)
return output_tracer
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
def dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
"""Trace numpy.dot.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
# input_tracers contains the other tracer of the dot product
dot_inputs = (self, self._sanitize(other_tracer))
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs)
common_output_dtypes = self._manage_dtypes(numpy.dot, *args)
custom_assert(len(common_output_dtypes) == 1)
traced_computation = Dot(
[input_tracer.output for input_tracer in dot_inputs],
[input_tracer.output for input_tracer in args],
common_output_dtypes[0],
delegate_evaluation_function=numpy.dot,
)
output_tracer = self.__class__(
dot_inputs,
args,
traced_computation=traced_computation,
output_index=0,
)