feat: add support for __floordiv__

closes #871
closes #872
This commit is contained in:
Arthur Meyre
2021-11-10 15:18:56 +01:00
parent c5952cd09f
commit 93820e1588
3 changed files with 47 additions and 9 deletions

View File

@@ -1,9 +1,11 @@
"""This file holds the code that can be shared between tracers."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
from ..data_types import Float
from ..data_types.base import BaseDataType
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import (
IR_MIX_VALUES_FUNC_ARG_NAME,
@@ -173,8 +175,13 @@ class BaseTracer(ABC):
traced_computation = IndexConstant(self.output, item)
return self.__class__([self], traced_computation, 0)
def _truediv(
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
def _div_common(
self,
lhs: Union["BaseTracer", Any],
rhs: Union["BaseTracer", Any],
div_op: Callable,
op_name: str,
output_dtype: Optional[BaseDataType] = None,
) -> "BaseTracer":
if isinstance(lhs, BaseTracer):
if not self._supports_other_operand(rhs):
@@ -190,27 +197,43 @@ class BaseTracer(ABC):
isinstance(sanitized_inputs[0].traced_computation, Constant)
or isinstance(sanitized_inputs[1].traced_computation, Constant)
):
raise NotImplementedError("Can't manage binary operator truediv")
raise NotImplementedError(f"Can't manage binary operator {op_name}")
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
output_value = self._get_mix_values_func()(*sanitized_input_values)
# The true division in python is always float64
output_value.dtype = Float(64)
if output_dtype is not None:
output_value.dtype = deepcopy(output_dtype)
traced_computation = GenericFunction(
inputs=sanitized_input_values,
arbitrary_func=lambda x, y: x / y,
arbitrary_func=div_op,
output_value=output_value,
op_kind="TLU",
op_name="truediv",
op_name=op_name,
)
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
return result_tracer
def _truediv(
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
) -> "BaseTracer":
return self._div_common(lhs, rhs, lambda x, y: x / y, "truediv", Float(64))
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._truediv(self, other)
def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._truediv(other, self)
def _floordiv(
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
) -> "BaseTracer":
return self._div_common(lhs, rhs, lambda x, y: x // y, "floordiv")
def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._floordiv(self, other)
def __rfloordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._floordiv(other, self)

View File

@@ -174,6 +174,7 @@ def mix_x_and_y_and_call_f_with_float_inputs(func, x, y):
def mix_x_and_y_and_call_f_with_integer_inputs(func, x, y):
"""Create an upper function to test `func`, with inputs which are forced to be integers but
in a way which is fusable into a TLU"""
x = x // 2
a = x + 0.1
a = numpy.rint(a).astype(numpy.int32)
z = numpy.abs(10 * func(a))
@@ -201,8 +202,9 @@ def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y):
def mix_x_and_y_and_call_f_avoid_0_input(func, x, y):
"""Create an upper function to test `func`, which makes that inputs are not 0"""
a = numpy.abs(7 * numpy.sin(x)) + 1
c = 100 // a
b = 100 / a
a = a + b
a = a + b + c
z = numpy.abs(5 * func(a))
z = z.astype(numpy.int32) + y
return z

View File

@@ -651,9 +651,22 @@ def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
TypeError,
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
),
pytest.param(
lambda x: x // "fail",
TypeError,
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" // x,
TypeError,
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
),
pytest.param(
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
),
pytest.param(
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
),
],
)
def test_nptracer_unsupported_operands(operation, exception_type, match):