feat: add clip support

closes #965
closes #983
closes #984
This commit is contained in:
Arthur Meyre
2021-11-23 14:20:54 +01:00
parent f1ed07d580
commit 60d8079303
2 changed files with 56 additions and 0 deletions

View File

@@ -239,6 +239,23 @@ class NPTracer(BaseTracer):
)
return output_tracer
def clip(self, *args: Union["NPTracer", Any], **kwargs) -> "NPTracer":
"""Trace x.clip.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
sanitized_args = [cast(NPTracer, self._sanitize(arg)) for arg in args]
return self.numpy_clip(self, *sanitized_args, **kwargs)
def numpy_clip(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace numpy.clip.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._np_operator(numpy.clip, "clip", 3, *args, **kwargs)
def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace x.dot.
@@ -565,6 +582,7 @@ class NPTracer(BaseTracer):
numpy.transpose: numpy_transpose,
numpy.reshape: numpy_reshape,
numpy.ravel: numpy_ravel,
numpy.clip: numpy_clip,
}

View File

@@ -937,6 +937,44 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
([2, 7, 1],),
[4, 14, 2],
),
pytest.param(
lambda x: numpy.clip(x, 1, 5),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[1, 5],
[5, 1],
[2, 5],
],
),
pytest.param(
lambda x: x.clip(1, 5),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[1, 5],
[5, 1],
[2, 5],
],
),
],
)
def test_compile_and_run_tensor_correctness(