mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
1d691f232b
commit
05cacc8744
@@ -110,6 +110,89 @@ class BaseTracer(ABC):
|
||||
|
||||
return output_tracers
|
||||
|
||||
def _helper_for_unary_functions(self, op_lambda: Callable, op_name: str) -> "BaseTracer":
|
||||
"""Trace a unary operator which maintains the shape, which will thus be replaced by a TLU.
|
||||
|
||||
Returns:
|
||||
BaseTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
first_arg_output = self.output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = first_arg_output.shape
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
arbitrary_func=op_lambda,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_name=f"{op_name}",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[self],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def _helper_for_binary_functions_with_one_cst_input(
|
||||
self,
|
||||
lhs: Union["BaseTracer", Any],
|
||||
rhs: Union["BaseTracer", Any],
|
||||
op_lambda: Callable,
|
||||
op_name: str,
|
||||
output_dtype: Optional[BaseDataType] = None,
|
||||
) -> "BaseTracer":
|
||||
"""Trace a binary operator which maintains the shape, when one input is a constant.
|
||||
|
||||
This function is helpful to convert an operation with two inputs, one of which being a
|
||||
constant, into a TLU, while maintaining the constant somewhere in the graph, eg to simplify
|
||||
debugging.
|
||||
|
||||
Returns:
|
||||
BaseTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
if isinstance(lhs, BaseTracer):
|
||||
if not self._supports_other_operand(rhs):
|
||||
return NotImplemented
|
||||
elif isinstance(rhs, BaseTracer):
|
||||
if not self._supports_other_operand(lhs):
|
||||
return NotImplemented
|
||||
|
||||
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
|
||||
|
||||
# One of the inputs has to be constant
|
||||
if not (
|
||||
isinstance(sanitized_inputs[0].traced_computation, Constant)
|
||||
or isinstance(sanitized_inputs[1].traced_computation, Constant)
|
||||
):
|
||||
raise NotImplementedError(f"Can't manage binary operator {op_name}")
|
||||
|
||||
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
|
||||
output_value = self._get_mix_values_func()(*sanitized_input_values)
|
||||
if output_dtype is not None:
|
||||
output_value.dtype = deepcopy(output_dtype)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=sanitized_input_values,
|
||||
arbitrary_func=op_lambda,
|
||||
output_value=output_value,
|
||||
op_kind="TLU",
|
||||
op_name=op_name,
|
||||
)
|
||||
|
||||
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
|
||||
|
||||
return result_tracer
|
||||
|
||||
def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
@@ -135,11 +218,31 @@ class BaseTracer(ABC):
|
||||
# a reference to the same object
|
||||
return 0 + self
|
||||
|
||||
def __lshift__(self, shift) -> "BaseTracer":
|
||||
return 2 ** shift * self
|
||||
def _lshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x << y, "lshift"
|
||||
)
|
||||
|
||||
def __rshift__(self, shift) -> "BaseTracer":
|
||||
return self // 2 ** shift
|
||||
def __lshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x << shift
|
||||
return self._lshift(self, other)
|
||||
|
||||
def __rlshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# cst << x
|
||||
return self._lshift(other, self)
|
||||
|
||||
def _rshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x >> y, "rshift"
|
||||
)
|
||||
|
||||
def __rshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x >> shift
|
||||
return self._rshift(self, other)
|
||||
|
||||
def __rrshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# cst >> x
|
||||
return self._rshift(other, self)
|
||||
|
||||
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
@@ -182,94 +285,22 @@ class BaseTracer(ABC):
|
||||
# some changes
|
||||
__rmul__ = __mul__
|
||||
|
||||
def unary_ndarray_op(self, op_lambda, op_string: str):
|
||||
"""Trace an operator which maintains the shape, which will thus be replaced by a TLU.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
first_arg_output = self.output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = first_arg_output.shape
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
arbitrary_func=op_lambda,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_name=f"{op_string}",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[self],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def __abs__(self):
|
||||
return self.unary_ndarray_op(lambda x: x.__abs__(), "__abs__")
|
||||
return self._helper_for_unary_functions(lambda x: x.__abs__(), "__abs__")
|
||||
|
||||
def __invert__(self):
|
||||
return self.unary_ndarray_op(lambda x: x.__invert__(), "__invert__")
|
||||
return self._helper_for_unary_functions(lambda x: x.__invert__(), "__invert__")
|
||||
|
||||
def __getitem__(self, item):
|
||||
traced_computation = IndexConstant(self.output, item)
|
||||
return self.__class__([self], traced_computation, 0)
|
||||
|
||||
def _div_common(
|
||||
self,
|
||||
lhs: Union["BaseTracer", Any],
|
||||
rhs: Union["BaseTracer", Any],
|
||||
div_op: Callable,
|
||||
op_name: str,
|
||||
output_dtype: Optional[BaseDataType] = None,
|
||||
) -> "BaseTracer":
|
||||
if isinstance(lhs, BaseTracer):
|
||||
if not self._supports_other_operand(rhs):
|
||||
return NotImplemented
|
||||
elif isinstance(rhs, BaseTracer):
|
||||
if not self._supports_other_operand(lhs):
|
||||
return NotImplemented
|
||||
|
||||
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
|
||||
|
||||
# One of the inputs has to be constant
|
||||
if not (
|
||||
isinstance(sanitized_inputs[0].traced_computation, Constant)
|
||||
or isinstance(sanitized_inputs[1].traced_computation, Constant)
|
||||
):
|
||||
raise NotImplementedError(f"Can't manage binary operator {op_name}")
|
||||
|
||||
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
|
||||
output_value = self._get_mix_values_func()(*sanitized_input_values)
|
||||
if output_dtype is not None:
|
||||
output_value.dtype = deepcopy(output_dtype)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=sanitized_input_values,
|
||||
arbitrary_func=div_op,
|
||||
output_value=output_value,
|
||||
op_kind="TLU",
|
||||
op_name=op_name,
|
||||
)
|
||||
|
||||
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
|
||||
|
||||
return result_tracer
|
||||
|
||||
def _truediv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
return self._div_common(lhs, rhs, lambda x, y: x / y, "truediv", Float(64))
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x / y, "truediv", Float(64)
|
||||
)
|
||||
|
||||
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(self, other)
|
||||
@@ -280,7 +311,9 @@ class BaseTracer(ABC):
|
||||
def _floordiv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
return self._div_common(lhs, rhs, lambda x, y: x // y, "floordiv")
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x // y, "floordiv"
|
||||
)
|
||||
|
||||
def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._floordiv(self, other)
|
||||
|
||||
@@ -701,8 +701,8 @@ def test_tracing_numpy_calls(
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15),
|
||||
numpy.arange(15) * 8,
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) * 8,
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -711,8 +711,28 @@ def test_tracing_numpy_calls(
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15),
|
||||
numpy.arange(15) // 2,
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) // 2,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 2 << x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5) % 8,
|
||||
2 << (numpy.arange(15).reshape(3, 5) % 8),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 256 >> x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5) % 8,
|
||||
256 >> (numpy.arange(15).reshape(3, 5) % 8),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user