From 5a065769bb7ac426eb9bc227ef1a8360f3b39c09 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 4 Jul 2022 12:02:48 +0200 Subject: [PATCH] fix: allow generator inputsets again --- concrete/numpy/compilation/compiler.py | 8 +++++--- tests/compilation/test_compiler.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 333fd1200..6ca9b0754 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -172,11 +172,16 @@ class Compiler: """ if inputset is not None: + previous_inputset_length = len(self.inputset) for index, sample in enumerate(iter(inputset)): + self.inputset.append(sample) + if not isinstance(sample, tuple): sample = (sample,) if len(sample) != len(self.parameter_encryption_statuses): + self.inputset = self.inputset[:previous_inputset_length] + expected = ( "a single value" if len(self.parameter_encryption_statuses) == 1 @@ -191,9 +196,6 @@ class Compiler: f"(expected {expected} got {actual})" ) - for input_ in inputset: - self.inputset.append(input_) - if self.graph is None: try: first_sample = next(iter(self.inputset)) diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index 5c55629bf..b3915b2f1 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -236,6 +236,8 @@ def test_compiler_virtual_compile(helpers): return x + 400 compiler = Compiler(f, {"x": "encrypted"}) - circuit = compiler.compile(inputset=range(400), configuration=configuration, virtual=True) + + inputset = (i for i in range(400)) + circuit = compiler.compile(inputset, configuration=configuration, virtual=True) assert circuit.encrypt_run_decrypt(200) == 600