feat: raise proper error if function being compiled returns something unsupported

This commit is contained in:
Umut
2022-06-10 16:18:03 +02:00
parent 53e5dda732
commit 59cacc35df
2 changed files with 35 additions and 2 deletions

View File

@@ -65,12 +65,29 @@ class Tracer:
input_indices[node] = index
Tracer._is_tracing = True
output_tracers = function(**arguments)
output_tracers: Any = function(**arguments)
Tracer._is_tracing = False
if isinstance(output_tracers, Tracer):
if not isinstance(output_tracers, tuple):
output_tracers = (output_tracers,)
sanitized_tracers = []
for tracer in output_tracers:
if isinstance(tracer, Tracer):
sanitized_tracers.append(tracer)
continue
try:
sanitized_tracers.append(Tracer._sanitize(tracer))
except Exception as error:
raise ValueError(
f"Function '{function.__name__}' "
f"returned '{tracer}', "
f"which is not supported"
) from error
output_tracers = tuple(sanitized_tracers)
def create_graph_from_output_tracers(output_tracers: Tuple[Tracer, ...]) -> nx.MultiDiGraph:
graph = nx.MultiDiGraph()

View File

@@ -2,6 +2,7 @@
Tests of `Compiler` class.
"""
import numpy as np
import pytest
from concrete.numpy.compilation import Compiler
@@ -126,6 +127,9 @@ def test_compiler_bad_trace(helpers):
configuration = helpers.configuration()
# without inputset
# ----------------
def f(x, y, z):
return x + y + z
@@ -138,6 +142,18 @@ def test_compiler_bad_trace(helpers):
assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported"
# bad return
# ----------
def g():
return np.array([{}, ()], dtype=object)
with pytest.raises(ValueError) as excinfo:
compiler = Compiler(g, {})
compiler.trace(inputset=[()], configuration=configuration)
assert str(excinfo.value) == "Function 'g' returned '[{} ()]', which is not supported"
def test_compiler_bad_compile(helpers):
"""