From 955470fb8944784a79c284bc04379c2043520d3e Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 10 Nov 2021 11:29:30 +0100 Subject: [PATCH] feat: add support for __truediv__ - cannot use the standard binary op workflow as we don't have an op for div closes #866 closes #867 --- concrete/common/tracing/base_tracer.py | 45 ++++++++++++ .../common/optimization/test_float_fusing.py | 4 +- tests/numpy/test_compile.py | 3 + tests/numpy/test_tracing.py | 71 ++++++++++++++----- 4 files changed, 106 insertions(+), 17 deletions(-) diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 4a729359a..acd4163e0 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -3,10 +3,13 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Iterable, List, Tuple, Type, Union +from ..data_types import Float from ..debugging.custom_assert import assert_true from ..representation.intermediate import ( IR_MIX_VALUES_FUNC_ARG_NAME, Add, + Constant, + GenericFunction, IndexConstant, IntermediateNode, Mul, @@ -169,3 +172,45 @@ class BaseTracer(ABC): def __getitem__(self, item): traced_computation = IndexConstant(self.output, item) return self.__class__([self], traced_computation, 0) + + def _truediv( + self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] + ) -> "BaseTracer": + if isinstance(lhs, BaseTracer): + if not self._supports_other_operand(rhs): + return NotImplemented + elif isinstance(rhs, BaseTracer): + if not self._supports_other_operand(lhs): + return NotImplemented + + sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]] + + # One of the inputs has to be constant + if not ( + isinstance(sanitized_inputs[0].traced_computation, Constant) + or isinstance(sanitized_inputs[1].traced_computation, Constant) + ): + raise NotImplementedError("Can't manage binary operator truediv") + + sanitized_input_values = [san_input.output for san_input in sanitized_inputs] + output_value = self._get_mix_values_func()(*sanitized_input_values) + # The true division in python is always float64 + output_value.dtype = Float(64) + + traced_computation = GenericFunction( + inputs=sanitized_input_values, + arbitrary_func=lambda x, y: x / y, + output_value=output_value, + op_kind="TLU", + op_name="truediv", + ) + + result_tracer = self.__class__(sanitized_inputs, traced_computation, 0) + + return result_tracer + + def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + return self._truediv(self, other) + + def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + return self._truediv(other, self) diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 6e05474f5..1d7175dc6 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -585,8 +585,10 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): tensor_shape ) if tensor_shape != () - else 1 + else numpy.int64(1) ) + # Make sure the tensor diversifier is a numpy variable, otherwise some cases may fail + # as python int and float don't have the astype method input_ = input_ * tensor_diversifier num_params = len(params_names) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index b104ae9b0..7fff439ff 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -159,6 +159,7 @@ def complicated_topology(x): def mix_x_and_y_and_call_f(func, x, y): """Create an upper function to test `func`""" z = numpy.abs(10 * func(x)) + z = z / 2 z = z.astype(numpy.int32) + y return z @@ -200,6 +201,8 @@ def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y): def mix_x_and_y_and_call_f_avoid_0_input(func, x, y): """Create an upper function to test `func`, which makes that inputs are not 0""" a = numpy.abs(7 * numpy.sin(x)) + 1 + b = 100 / a + a = a + b z = numpy.abs(5 * func(a)) z = z.astype(numpy.int32) + y return z diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index f84df3fc3..24873b07c 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -1,5 +1,6 @@ """Test file for numpy tracing""" +import inspect from copy import deepcopy import networkx as nx @@ -608,26 +609,64 @@ def test_nptracer_get_tracing_func_for_np_functions_not_implemented(): @pytest.mark.parametrize( - "tracer", + "operation,exception_type,match", [ - tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), "x", 0), 0), + pytest.param( + lambda x: x + "fail", + TypeError, + "unsupported operand type(s) for +: 'NPTracer' and 'str'", + ), + pytest.param( + lambda x: "fail" + x, + TypeError, + 'can only concatenate str (not "NPTracer") to str', + ), + pytest.param( + lambda x: x - "fail", + TypeError, + "unsupported operand type(s) for -: 'NPTracer' and 'str'", + ), + pytest.param( + lambda x: "fail" - x, + TypeError, + "unsupported operand type(s) for -: 'str' and 'NPTracer'", + ), + pytest.param( + lambda x: x * "fail", + TypeError, + "can't multiply sequence by non-int of type 'NPTracer'", + ), + pytest.param( + lambda x: "fail" * x, + TypeError, + "can't multiply sequence by non-int of type 'NPTracer'", + ), + pytest.param( + lambda x: x / "fail", + TypeError, + "unsupported operand type(s) for /: 'NPTracer' and 'str'", + ), + pytest.param( + lambda x: "fail" / x, + TypeError, + "unsupported operand type(s) for /: 'str' and 'NPTracer'", + ), + pytest.param( + lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv" + ), ], ) -@pytest.mark.parametrize( - "operation", - [ - lambda x: x + "fail", - lambda x: "fail" + x, - lambda x: x - "fail", - lambda x: "fail" - x, - lambda x: x * "fail", - lambda x: "fail" * x, - ], -) -def test_nptracer_unsupported_operands(operation, tracer): +def test_nptracer_unsupported_operands(operation, exception_type, match): """Test cases where NPTracer cannot be used with other operands.""" - with pytest.raises(TypeError): - tracer = operation(tracer) + tracers = [ + tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0) + for idx, param_name in enumerate(inspect.signature(operation).parameters.keys()) + ] + + with pytest.raises(exception_type) as exc_info: + _ = operation(*tracers) + + assert match in str(exc_info) @pytest.mark.parametrize(