feat(tracing): let's trace k<<x and k>>x

closes #915
This commit is contained in:
Benoit Chevallier-Mames
2021-11-17 11:20:44 +01:00
committed by Benoit Chevallier
parent 1d691f232b
commit 05cacc8744
2 changed files with 139 additions and 86 deletions

View File

@@ -110,6 +110,89 @@ class BaseTracer(ABC):
return output_tracers
def _helper_for_unary_functions(self, op_lambda: Callable, op_name: str) -> "BaseTracer":
"""Trace a unary operator which maintains the shape, which will thus be replaced by a TLU.
Returns:
BaseTracer: The output NPTracer containing the traced function
"""
first_arg_output = self.output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
out_dtype = first_arg_output.dtype
out_shape = first_arg_output.shape
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
arbitrary_func=op_lambda,
output_value=generic_function_output_value,
op_kind="TLU",
op_name=f"{op_name}",
)
output_tracer = self.__class__(
[self],
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def _helper_for_binary_functions_with_one_cst_input(
self,
lhs: Union["BaseTracer", Any],
rhs: Union["BaseTracer", Any],
op_lambda: Callable,
op_name: str,
output_dtype: Optional[BaseDataType] = None,
) -> "BaseTracer":
"""Trace a binary operator which maintains the shape, when one input is a constant.
This function is helpful to convert an operation with two inputs, one of which being a
constant, into a TLU, while maintaining the constant somewhere in the graph, eg to simplify
debugging.
Returns:
BaseTracer: The output NPTracer containing the traced function
"""
if isinstance(lhs, BaseTracer):
if not self._supports_other_operand(rhs):
return NotImplemented
elif isinstance(rhs, BaseTracer):
if not self._supports_other_operand(lhs):
return NotImplemented
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
# One of the inputs has to be constant
if not (
isinstance(sanitized_inputs[0].traced_computation, Constant)
or isinstance(sanitized_inputs[1].traced_computation, Constant)
):
raise NotImplementedError(f"Can't manage binary operator {op_name}")
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
output_value = self._get_mix_values_func()(*sanitized_input_values)
if output_dtype is not None:
output_value.dtype = deepcopy(output_dtype)
traced_computation = GenericFunction(
inputs=sanitized_input_values,
arbitrary_func=op_lambda,
output_value=output_value,
op_kind="TLU",
op_name=op_name,
)
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
return result_tracer
def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -135,11 +218,31 @@ class BaseTracer(ABC):
# a reference to the same object
return 0 + self
def __lshift__(self, shift) -> "BaseTracer":
return 2 ** shift * self
def _lshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
return self._helper_for_binary_functions_with_one_cst_input(
lhs, rhs, lambda x, y: x << y, "lshift"
)
def __rshift__(self, shift) -> "BaseTracer":
return self // 2 ** shift
def __lshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# x << shift
return self._lshift(self, other)
def __rlshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# cst << x
return self._lshift(other, self)
def _rshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
return self._helper_for_binary_functions_with_one_cst_input(
lhs, rhs, lambda x, y: x >> y, "rshift"
)
def __rshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# x >> shift
return self._rshift(self, other)
def __rrshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
# cst >> x
return self._rshift(other, self)
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
@@ -182,94 +285,22 @@ class BaseTracer(ABC):
# some changes
__rmul__ = __mul__
def unary_ndarray_op(self, op_lambda, op_string: str):
"""Trace an operator which maintains the shape, which will thus be replaced by a TLU.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
first_arg_output = self.output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
out_dtype = first_arg_output.dtype
out_shape = first_arg_output.shape
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
arbitrary_func=op_lambda,
output_value=generic_function_output_value,
op_kind="TLU",
op_name=f"{op_string}",
)
output_tracer = self.__class__(
[self],
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def __abs__(self):
return self.unary_ndarray_op(lambda x: x.__abs__(), "__abs__")
return self._helper_for_unary_functions(lambda x: x.__abs__(), "__abs__")
def __invert__(self):
return self.unary_ndarray_op(lambda x: x.__invert__(), "__invert__")
return self._helper_for_unary_functions(lambda x: x.__invert__(), "__invert__")
def __getitem__(self, item):
traced_computation = IndexConstant(self.output, item)
return self.__class__([self], traced_computation, 0)
def _div_common(
self,
lhs: Union["BaseTracer", Any],
rhs: Union["BaseTracer", Any],
div_op: Callable,
op_name: str,
output_dtype: Optional[BaseDataType] = None,
) -> "BaseTracer":
if isinstance(lhs, BaseTracer):
if not self._supports_other_operand(rhs):
return NotImplemented
elif isinstance(rhs, BaseTracer):
if not self._supports_other_operand(lhs):
return NotImplemented
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
# One of the inputs has to be constant
if not (
isinstance(sanitized_inputs[0].traced_computation, Constant)
or isinstance(sanitized_inputs[1].traced_computation, Constant)
):
raise NotImplementedError(f"Can't manage binary operator {op_name}")
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
output_value = self._get_mix_values_func()(*sanitized_input_values)
if output_dtype is not None:
output_value.dtype = deepcopy(output_dtype)
traced_computation = GenericFunction(
inputs=sanitized_input_values,
arbitrary_func=div_op,
output_value=output_value,
op_kind="TLU",
op_name=op_name,
)
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
return result_tracer
def _truediv(
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
) -> "BaseTracer":
return self._div_common(lhs, rhs, lambda x, y: x / y, "truediv", Float(64))
return self._helper_for_binary_functions_with_one_cst_input(
lhs, rhs, lambda x, y: x / y, "truediv", Float(64)
)
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._truediv(self, other)
@@ -280,7 +311,9 @@ class BaseTracer(ABC):
def _floordiv(
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
) -> "BaseTracer":
return self._div_common(lhs, rhs, lambda x, y: x // y, "floordiv")
return self._helper_for_binary_functions_with_one_cst_input(
lhs, rhs, lambda x, y: x // y, "floordiv"
)
def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
return self._floordiv(self, other)

View File

@@ -701,8 +701,8 @@ def test_tracing_numpy_calls(
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15),
numpy.arange(15) * 8,
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) * 8,
)
],
),
@@ -711,8 +711,28 @@ def test_tracing_numpy_calls(
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15),
numpy.arange(15) // 2,
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) // 2,
)
],
),
pytest.param(
lambda x: 2 << x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5) % 8,
2 << (numpy.arange(15).reshape(3, 5) % 8),
)
],
),
pytest.param(
lambda x: 256 >> x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5) % 8,
256 >> (numpy.arange(15).reshape(3, 5) % 8),
)
],
),