mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-12 21:54:55 -05:00
@@ -1,9 +1,11 @@
|
||||
"""This file holds the code that can be shared between tracers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..data_types import Float
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..representation.intermediate import (
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME,
|
||||
@@ -173,8 +175,13 @@ class BaseTracer(ABC):
|
||||
traced_computation = IndexConstant(self.output, item)
|
||||
return self.__class__([self], traced_computation, 0)
|
||||
|
||||
def _truediv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
def _div_common(
|
||||
self,
|
||||
lhs: Union["BaseTracer", Any],
|
||||
rhs: Union["BaseTracer", Any],
|
||||
div_op: Callable,
|
||||
op_name: str,
|
||||
output_dtype: Optional[BaseDataType] = None,
|
||||
) -> "BaseTracer":
|
||||
if isinstance(lhs, BaseTracer):
|
||||
if not self._supports_other_operand(rhs):
|
||||
@@ -190,27 +197,43 @@ class BaseTracer(ABC):
|
||||
isinstance(sanitized_inputs[0].traced_computation, Constant)
|
||||
or isinstance(sanitized_inputs[1].traced_computation, Constant)
|
||||
):
|
||||
raise NotImplementedError("Can't manage binary operator truediv")
|
||||
raise NotImplementedError(f"Can't manage binary operator {op_name}")
|
||||
|
||||
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
|
||||
output_value = self._get_mix_values_func()(*sanitized_input_values)
|
||||
# The true division in python is always float64
|
||||
output_value.dtype = Float(64)
|
||||
if output_dtype is not None:
|
||||
output_value.dtype = deepcopy(output_dtype)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=sanitized_input_values,
|
||||
arbitrary_func=lambda x, y: x / y,
|
||||
arbitrary_func=div_op,
|
||||
output_value=output_value,
|
||||
op_kind="TLU",
|
||||
op_name="truediv",
|
||||
op_name=op_name,
|
||||
)
|
||||
|
||||
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
|
||||
|
||||
return result_tracer
|
||||
|
||||
def _truediv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
return self._div_common(lhs, rhs, lambda x, y: x / y, "truediv", Float(64))
|
||||
|
||||
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(self, other)
|
||||
|
||||
def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(other, self)
|
||||
|
||||
def _floordiv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
return self._div_common(lhs, rhs, lambda x, y: x // y, "floordiv")
|
||||
|
||||
def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._floordiv(self, other)
|
||||
|
||||
def __rfloordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._floordiv(other, self)
|
||||
|
||||
@@ -174,6 +174,7 @@ def mix_x_and_y_and_call_f_with_float_inputs(func, x, y):
|
||||
def mix_x_and_y_and_call_f_with_integer_inputs(func, x, y):
|
||||
"""Create an upper function to test `func`, with inputs which are forced to be integers but
|
||||
in a way which is fusable into a TLU"""
|
||||
x = x // 2
|
||||
a = x + 0.1
|
||||
a = numpy.rint(a).astype(numpy.int32)
|
||||
z = numpy.abs(10 * func(a))
|
||||
@@ -201,8 +202,9 @@ def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y):
|
||||
def mix_x_and_y_and_call_f_avoid_0_input(func, x, y):
|
||||
"""Create an upper function to test `func`, which makes that inputs are not 0"""
|
||||
a = numpy.abs(7 * numpy.sin(x)) + 1
|
||||
c = 100 // a
|
||||
b = 100 / a
|
||||
a = a + b
|
||||
a = a + b + c
|
||||
z = numpy.abs(5 * func(a))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
@@ -651,9 +651,22 @@ def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x // "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" // x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
|
||||
Reference in New Issue
Block a user