feat: add support for __truediv__

- cannot use the standard binary op workflow as we don't have an op for div

closes #866
closes #867
This commit is contained in:
Arthur Meyre
2021-11-10 11:29:30 +01:00
parent e316c1b3ba
commit 955470fb89
4 changed files with 106 additions and 17 deletions

View File

@@ -3,10 +3,13 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
from ..data_types import Float
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import (
IR_MIX_VALUES_FUNC_ARG_NAME,
Add,
Constant,
GenericFunction,
IndexConstant,
IntermediateNode,
Mul,
@@ -169,3 +172,45 @@ class BaseTracer(ABC):
def __getitem__(self, item):
traced_computation = IndexConstant(self.output, item)
return self.__class__([self], traced_computation, 0)
def _truediv(
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
) -> "BaseTracer":
if isinstance(lhs, BaseTracer):
if not self._supports_other_operand(rhs):
return NotImplemented
elif isinstance(rhs, BaseTracer):
if not self._supports_other_operand(lhs):
return NotImplemented
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
# One of the inputs has to be constant
if not (
isinstance(sanitized_inputs[0].traced_computation, Constant)
or isinstance(sanitized_inputs[1].traced_computation, Constant)
):
raise NotImplementedError("Can't manage binary operator truediv")
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
output_value = self._get_mix_values_func()(*sanitized_input_values)
# The true division in python is always float64
output_value.dtype = Float(64)
traced_computation = GenericFunction(
inputs=sanitized_input_values,
arbitrary_func=lambda x, y: x / y,
output_value=output_value,
op_kind="TLU",
op_name="truediv",
)
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
return result_tracer
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._truediv(self, other)
def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._truediv(other, self)

View File

@@ -585,8 +585,10 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
tensor_shape
)
if tensor_shape != ()
else 1
else numpy.int64(1)
)
# Make sure the tensor diversifier is a numpy variable, otherwise some cases may fail
# as python int and float don't have the astype method
input_ = input_ * tensor_diversifier
num_params = len(params_names)

View File

@@ -159,6 +159,7 @@ def complicated_topology(x):
def mix_x_and_y_and_call_f(func, x, y):
"""Create an upper function to test `func`"""
z = numpy.abs(10 * func(x))
z = z / 2
z = z.astype(numpy.int32) + y
return z
@@ -200,6 +201,8 @@ def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y):
def mix_x_and_y_and_call_f_avoid_0_input(func, x, y):
"""Create an upper function to test `func`, which makes that inputs are not 0"""
a = numpy.abs(7 * numpy.sin(x)) + 1
b = 100 / a
a = a + b
z = numpy.abs(5 * func(a))
z = z.astype(numpy.int32) + y
return z

View File

@@ -1,5 +1,6 @@
"""Test file for numpy tracing"""
import inspect
from copy import deepcopy
import networkx as nx
@@ -608,26 +609,64 @@ def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
@pytest.mark.parametrize(
"tracer",
"operation,exception_type,match",
[
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), "x", 0), 0),
pytest.param(
lambda x: x + "fail",
TypeError,
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" + x,
TypeError,
'can only concatenate str (not "NPTracer") to str',
),
pytest.param(
lambda x: x - "fail",
TypeError,
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" - x,
TypeError,
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
),
pytest.param(
lambda x: x * "fail",
TypeError,
"can't multiply sequence by non-int of type 'NPTracer'",
),
pytest.param(
lambda x: "fail" * x,
TypeError,
"can't multiply sequence by non-int of type 'NPTracer'",
),
pytest.param(
lambda x: x / "fail",
TypeError,
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" / x,
TypeError,
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
),
pytest.param(
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
),
],
)
@pytest.mark.parametrize(
"operation",
[
lambda x: x + "fail",
lambda x: "fail" + x,
lambda x: x - "fail",
lambda x: "fail" - x,
lambda x: x * "fail",
lambda x: "fail" * x,
],
)
def test_nptracer_unsupported_operands(operation, tracer):
def test_nptracer_unsupported_operands(operation, exception_type, match):
"""Test cases where NPTracer cannot be used with other operands."""
with pytest.raises(TypeError):
tracer = operation(tracer)
tracers = [
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
]
with pytest.raises(exception_type) as exc_info:
_ = operation(*tracers)
assert match in str(exc_info)
@pytest.mark.parametrize(