From 8d89614d8aeb42862ce1fc089bad6862749d3df2 Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 23 Feb 2023 13:44:50 +0100 Subject: [PATCH] feat: raise error if tracers are tried to be converted to bool --- concrete/numpy/tracing/tracer.py | 6 ++++++ tests/tracing/test_tracer.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index e3ae9f9de..49abb53b1 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -175,6 +175,12 @@ class Tracer: def __hash__(self) -> int: return id(self) + def __bool__(self) -> bool: + # pylint: disable=invalid-bool-returned + + message = "Branching within circuits is not possible" + raise RuntimeError(message) + @staticmethod def sanitize(value: Any) -> Any: """ diff --git a/tests/tracing/test_tracer.py b/tests/tracing/test_tracer.py index 95f95990f..773742feb 100644 --- a/tests/tracing/test_tracer.py +++ b/tests/tracing/test_tracer.py @@ -45,6 +45,12 @@ from concrete.numpy.values import EncryptedTensor "`astype` method must be called with a " "numpy type for compilation (e.g., value.astype(np.int64))", ), + pytest.param( + lambda x: x + 1 if x else x + x, + {"x": EncryptedTensor(UnsignedInteger(7), shape=())}, + RuntimeError, + "Branching within circuits is not possible", + ), ], ) def test_tracer_bad_trace(function, parameters, expected_error, expected_message):