feat(tracing): add support for ndarray functions

ndarray.flatten
ndarray.__abs__
ndarray.__neg__, __pos__ and __invert__
ndarray.__rshift__ and __lshift_

refs #751
refs #218
This commit is contained in:
Benoit Chevallier-Mames
2021-11-12 17:59:00 +01:00
committed by Benoit Chevallier
parent baa8a97e98
commit a2a61a079f
3 changed files with 172 additions and 4 deletions

View File

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union, cast
from ..data_types import Float
from ..data_types.base import BaseDataType
@@ -17,7 +17,7 @@ from ..representation.intermediate import (
Mul,
Sub,
)
from ..values import BaseValue
from ..values import BaseValue, TensorValue
class BaseTracer(ABC):
@@ -130,6 +130,17 @@ class BaseTracer(ABC):
def __neg__(self) -> "BaseTracer":
return 0 - self
def __pos__(self) -> "BaseTracer":
# Remark that we don't want to return 'self' since we want the result to be a copy, ie not
# a reference to the same object
return 0 + self
def __lshift__(self, shift) -> "BaseTracer":
return 2 ** shift * self
def __rshift__(self, shift) -> "BaseTracer":
return self // 2 ** shift
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -171,6 +182,45 @@ class BaseTracer(ABC):
# some changes
__rmul__ = __mul__
def unary_ndarray_op(self, op_lambda, op_string: str):
"""Trace an operator which maintains the shape, which will thus be replaced by a TLU.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
first_arg_output = self.output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
out_dtype = first_arg_output.dtype
out_shape = first_arg_output.shape
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
arbitrary_func=op_lambda,
output_value=generic_function_output_value,
op_kind="TLU",
op_name=f"{op_string}",
)
output_tracer = self.__class__(
[self],
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def __abs__(self):
return self.unary_ndarray_op(lambda x: x.__abs__(), "__abs__")
def __invert__(self):
return self.unary_ndarray_op(lambda x: x.__invert__(), "__invert__")
def __getitem__(self, item):
traced_computation = IndexConstant(self.output, item)
return self.__class__([self], traced_computation, 0)

View File

@@ -427,7 +427,7 @@ class NPTracer(BaseTracer):
)
traced_computation = GenericFunction(
inputs=[first_arg_output],
inputs=[deepcopy(first_arg_output)],
arbitrary_func=numpy.reshape,
output_value=generic_function_output_value,
op_kind="Memory",
@@ -442,6 +442,45 @@ class NPTracer(BaseTracer):
)
return output_tracer
def flatten(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace x.flatten.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert_true((num_args := len(args)) == 0, f"flatten expect 0 input got {num_args}")
first_arg_output = self.output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
flatten_is_fusable = first_arg_output.ndim == 1
out_dtype = first_arg_output.dtype
out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),)
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
arbitrary_func=lambda x: x.flatten(),
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="flatten",
op_attributes={"fusable": flatten_is_fusable},
)
output_tracer = self.__class__(
[self],
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def __getitem__(self, item):
if isinstance(item, tuple):
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)

View File

@@ -1,5 +1,7 @@
"""Test file for numpy tracing"""
# pylint: disable=too-many-lines
import inspect
from copy import deepcopy
@@ -691,7 +693,11 @@ def subtest_tracing_calls(
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
assert numpy.array_equal(expected_output, evaluated_output)
if not numpy.array_equal(expected_output, evaluated_output):
print("Wrong result")
print(f"Expected: {expected_output}")
print(f"Got : {evaluated_output}")
raise AssertionError
@pytest.mark.parametrize(
@@ -831,6 +837,76 @@ def test_tracing_numpy_calls(
],
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
),
pytest.param(
lambda x: x.flatten(),
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15),
)
],
),
pytest.param(
lambda x: abs(x),
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: +x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: -x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
(numpy.arange(15).reshape(3, 5)) * (-1),
)
],
),
pytest.param(
lambda x: ~x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5).__invert__(),
)
],
),
pytest.param(
lambda x: x << 3,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15),
numpy.arange(15) * 8,
)
],
),
pytest.param(
lambda x: x >> 1,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15),
numpy.arange(15) // 2,
)
],
),
],
)
def test_tracing_ndarray_calls(
@@ -858,3 +934,6 @@ def test_errors_with_generic_function(lambda_f, params):
tracing.trace_numpy_function(lambda_f, params)
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
# pylint: enable=too-many-lines