mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add support for __truediv__
- cannot use the standard binary op workflow as we don't have an op for div closes #866 closes #867
This commit is contained in:
@@ -3,10 +3,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
|
||||
|
||||
from ..data_types import Float
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..representation.intermediate import (
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME,
|
||||
Add,
|
||||
Constant,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
@@ -169,3 +172,45 @@ class BaseTracer(ABC):
|
||||
def __getitem__(self, item):
|
||||
traced_computation = IndexConstant(self.output, item)
|
||||
return self.__class__([self], traced_computation, 0)
|
||||
|
||||
def _truediv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
if isinstance(lhs, BaseTracer):
|
||||
if not self._supports_other_operand(rhs):
|
||||
return NotImplemented
|
||||
elif isinstance(rhs, BaseTracer):
|
||||
if not self._supports_other_operand(lhs):
|
||||
return NotImplemented
|
||||
|
||||
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
|
||||
|
||||
# One of the inputs has to be constant
|
||||
if not (
|
||||
isinstance(sanitized_inputs[0].traced_computation, Constant)
|
||||
or isinstance(sanitized_inputs[1].traced_computation, Constant)
|
||||
):
|
||||
raise NotImplementedError("Can't manage binary operator truediv")
|
||||
|
||||
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
|
||||
output_value = self._get_mix_values_func()(*sanitized_input_values)
|
||||
# The true division in python is always float64
|
||||
output_value.dtype = Float(64)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=sanitized_input_values,
|
||||
arbitrary_func=lambda x, y: x / y,
|
||||
output_value=output_value,
|
||||
op_kind="TLU",
|
||||
op_name="truediv",
|
||||
)
|
||||
|
||||
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
|
||||
|
||||
return result_tracer
|
||||
|
||||
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(self, other)
|
||||
|
||||
def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(other, self)
|
||||
|
||||
@@ -585,8 +585,10 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
tensor_shape
|
||||
)
|
||||
if tensor_shape != ()
|
||||
else 1
|
||||
else numpy.int64(1)
|
||||
)
|
||||
# Make sure the tensor diversifier is a numpy variable, otherwise some cases may fail
|
||||
# as python int and float don't have the astype method
|
||||
input_ = input_ * tensor_diversifier
|
||||
|
||||
num_params = len(params_names)
|
||||
|
||||
@@ -159,6 +159,7 @@ def complicated_topology(x):
|
||||
def mix_x_and_y_and_call_f(func, x, y):
|
||||
"""Create an upper function to test `func`"""
|
||||
z = numpy.abs(10 * func(x))
|
||||
z = z / 2
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
@@ -200,6 +201,8 @@ def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y):
|
||||
def mix_x_and_y_and_call_f_avoid_0_input(func, x, y):
|
||||
"""Create an upper function to test `func`, which makes that inputs are not 0"""
|
||||
a = numpy.abs(7 * numpy.sin(x)) + 1
|
||||
b = 100 / a
|
||||
a = a + b
|
||||
z = numpy.abs(5 * func(a))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
|
||||
import networkx as nx
|
||||
@@ -608,26 +609,64 @@ def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tracer",
|
||||
"operation,exception_type,match",
|
||||
[
|
||||
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), "x", 0), 0),
|
||||
pytest.param(
|
||||
lambda x: x + "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" + x,
|
||||
TypeError,
|
||||
'can only concatenate str (not "NPTracer") to str',
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" - x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * "fail",
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" * x,
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x / "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" / x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"operation",
|
||||
[
|
||||
lambda x: x + "fail",
|
||||
lambda x: "fail" + x,
|
||||
lambda x: x - "fail",
|
||||
lambda x: "fail" - x,
|
||||
lambda x: x * "fail",
|
||||
lambda x: "fail" * x,
|
||||
],
|
||||
)
|
||||
def test_nptracer_unsupported_operands(operation, tracer):
|
||||
def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
"""Test cases where NPTracer cannot be used with other operands."""
|
||||
with pytest.raises(TypeError):
|
||||
tracer = operation(tracer)
|
||||
tracers = [
|
||||
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
|
||||
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
|
||||
]
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
_ = operation(*tracers)
|
||||
|
||||
assert match in str(exc_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user