From 93820e15885846456d654a8d3ef67702f422a061 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 10 Nov 2021 15:18:56 +0100 Subject: [PATCH] feat: add support for __floordiv__ closes #871 closes #872 --- concrete/common/tracing/base_tracer.py | 39 ++++++++++++++++++++------ tests/numpy/test_compile.py | 4 ++- tests/numpy/test_tracing.py | 13 +++++++++ 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index acd4163e0..f132466b6 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -1,9 +1,11 @@ """This file holds the code that can be shared between tracers.""" from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, List, Tuple, Type, Union +from copy import deepcopy +from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union from ..data_types import Float +from ..data_types.base import BaseDataType from ..debugging.custom_assert import assert_true from ..representation.intermediate import ( IR_MIX_VALUES_FUNC_ARG_NAME, @@ -173,8 +175,13 @@ class BaseTracer(ABC): traced_computation = IndexConstant(self.output, item) return self.__class__([self], traced_computation, 0) - def _truediv( - self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] + def _div_common( + self, + lhs: Union["BaseTracer", Any], + rhs: Union["BaseTracer", Any], + div_op: Callable, + op_name: str, + output_dtype: Optional[BaseDataType] = None, ) -> "BaseTracer": if isinstance(lhs, BaseTracer): if not self._supports_other_operand(rhs): @@ -190,27 +197,43 @@ class BaseTracer(ABC): isinstance(sanitized_inputs[0].traced_computation, Constant) or isinstance(sanitized_inputs[1].traced_computation, Constant) ): - raise NotImplementedError("Can't manage binary operator truediv") + raise NotImplementedError(f"Can't manage binary operator {op_name}") sanitized_input_values = [san_input.output for san_input in sanitized_inputs] output_value = self._get_mix_values_func()(*sanitized_input_values) - # The true division in python is always float64 - output_value.dtype = Float(64) + if output_dtype is not None: + output_value.dtype = deepcopy(output_dtype) traced_computation = GenericFunction( inputs=sanitized_input_values, - arbitrary_func=lambda x, y: x / y, + arbitrary_func=div_op, output_value=output_value, op_kind="TLU", - op_name="truediv", + op_name=op_name, ) result_tracer = self.__class__(sanitized_inputs, traced_computation, 0) return result_tracer + def _truediv( + self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] + ) -> "BaseTracer": + return self._div_common(lhs, rhs, lambda x, y: x / y, "truediv", Float(64)) + def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": return self._truediv(self, other) def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": return self._truediv(other, self) + + def _floordiv( + self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] + ) -> "BaseTracer": + return self._div_common(lhs, rhs, lambda x, y: x // y, "floordiv") + + def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + return self._floordiv(self, other) + + def __rfloordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + return self._floordiv(other, self) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 9c998b63e..ac59eca02 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -174,6 +174,7 @@ def mix_x_and_y_and_call_f_with_float_inputs(func, x, y): def mix_x_and_y_and_call_f_with_integer_inputs(func, x, y): """Create an upper function to test `func`, with inputs which are forced to be integers but in a way which is fusable into a TLU""" + x = x // 2 a = x + 0.1 a = numpy.rint(a).astype(numpy.int32) z = numpy.abs(10 * func(a)) @@ -201,8 +202,9 @@ def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y): def mix_x_and_y_and_call_f_avoid_0_input(func, x, y): """Create an upper function to test `func`, which makes that inputs are not 0""" a = numpy.abs(7 * numpy.sin(x)) + 1 + c = 100 // a b = 100 / a - a = a + b + a = a + b + c z = numpy.abs(5 * func(a)) z = z.astype(numpy.int32) + y return z diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 24873b07c..fd0f7ce11 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -651,9 +651,22 @@ def test_nptracer_get_tracing_func_for_np_functions_not_implemented(): TypeError, "unsupported operand type(s) for /: 'str' and 'NPTracer'", ), + pytest.param( + lambda x: x // "fail", + TypeError, + "unsupported operand type(s) for //: 'NPTracer' and 'str'", + ), + pytest.param( + lambda x: "fail" // x, + TypeError, + "unsupported operand type(s) for //: 'str' and 'NPTracer'", + ), pytest.param( lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv" ), + pytest.param( + lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv" + ), ], ) def test_nptracer_unsupported_operands(operation, exception_type, match):