mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: raise proper error if function being compiled returns something unsupported
This commit is contained in:
@@ -65,12 +65,29 @@ class Tracer:
|
||||
input_indices[node] = index
|
||||
|
||||
Tracer._is_tracing = True
|
||||
output_tracers = function(**arguments)
|
||||
output_tracers: Any = function(**arguments)
|
||||
Tracer._is_tracing = False
|
||||
|
||||
if isinstance(output_tracers, Tracer):
|
||||
if not isinstance(output_tracers, tuple):
|
||||
output_tracers = (output_tracers,)
|
||||
|
||||
sanitized_tracers = []
|
||||
for tracer in output_tracers:
|
||||
if isinstance(tracer, Tracer):
|
||||
sanitized_tracers.append(tracer)
|
||||
continue
|
||||
|
||||
try:
|
||||
sanitized_tracers.append(Tracer._sanitize(tracer))
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"Function '{function.__name__}' "
|
||||
f"returned '{tracer}', "
|
||||
f"which is not supported"
|
||||
) from error
|
||||
|
||||
output_tracers = tuple(sanitized_tracers)
|
||||
|
||||
def create_graph_from_output_tracers(output_tracers: Tuple[Tracer, ...]) -> nx.MultiDiGraph:
|
||||
graph = nx.MultiDiGraph()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Tests of `Compiler` class.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.numpy.compilation import Compiler
|
||||
@@ -126,6 +127,9 @@ def test_compiler_bad_trace(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
# without inputset
|
||||
# ----------------
|
||||
|
||||
def f(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
@@ -138,6 +142,18 @@ def test_compiler_bad_trace(helpers):
|
||||
|
||||
assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported"
|
||||
|
||||
# bad return
|
||||
# ----------
|
||||
|
||||
def g():
|
||||
return np.array([{}, ()], dtype=object)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler = Compiler(g, {})
|
||||
compiler.trace(inputset=[()], configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Function 'g' returned '[{} ()]', which is not supported"
|
||||
|
||||
|
||||
def test_compiler_bad_compile(helpers):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user