From 59cacc35df8289e60b7bab6a25129df8b0a12b72 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 10 Jun 2022 16:18:03 +0200 Subject: [PATCH] feat: raise proper error if function being compiled returns something unsupported --- concrete/numpy/tracing/tracer.py | 21 +++++++++++++++++++-- tests/compilation/test_compiler.py | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index f01034cdd..d5a1a7fc4 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -65,12 +65,29 @@ class Tracer: input_indices[node] = index Tracer._is_tracing = True - output_tracers = function(**arguments) + output_tracers: Any = function(**arguments) Tracer._is_tracing = False - if isinstance(output_tracers, Tracer): + if not isinstance(output_tracers, tuple): output_tracers = (output_tracers,) + sanitized_tracers = [] + for tracer in output_tracers: + if isinstance(tracer, Tracer): + sanitized_tracers.append(tracer) + continue + + try: + sanitized_tracers.append(Tracer._sanitize(tracer)) + except Exception as error: + raise ValueError( + f"Function '{function.__name__}' " + f"returned '{tracer}', " + f"which is not supported" + ) from error + + output_tracers = tuple(sanitized_tracers) + def create_graph_from_output_tracers(output_tracers: Tuple[Tracer, ...]) -> nx.MultiDiGraph: graph = nx.MultiDiGraph() diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index c0fad81d4..5c55629bf 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -2,6 +2,7 @@ Tests of `Compiler` class. """ +import numpy as np import pytest from concrete.numpy.compilation import Compiler @@ -126,6 +127,9 @@ def test_compiler_bad_trace(helpers): configuration = helpers.configuration() + # without inputset + # ---------------- + def f(x, y, z): return x + y + z @@ -138,6 +142,18 @@ def test_compiler_bad_trace(helpers): assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported" + # bad return + # ---------- + + def g(): + return np.array([{}, ()], dtype=object) + + with pytest.raises(ValueError) as excinfo: + compiler = Compiler(g, {}) + compiler.trace(inputset=[()], configuration=configuration) + + assert str(excinfo.value) == "Function 'g' returned '[{} ()]', which is not supported" + def test_compiler_bad_compile(helpers): """