From d99436e0980fbdb22e15c0809c2655ca85542df5 Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 11 May 2023 10:04:45 +0200 Subject: [PATCH] fix(frontend-python): broadcast tlu input on multi tlus --- .../concrete/fhe/mlir/context.py | 2 ++ .../tests/execution/test_others.py | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index b2c294b66..dfe32f35f 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -1487,6 +1487,8 @@ class Context: mapping = np.array(mapping, dtype=np.uint64) + on = self.broadcast_to(on, mapping.shape) + assert mapping.shape == on.shape assert mapping.min() == 0 assert mapping.max() == len(tables) - 1 diff --git a/frontends/concrete-python/tests/execution/test_others.py b/frontends/concrete-python/tests/execution/test_others.py index f7c54fe6a..d1b7332e9 100644 --- a/frontends/concrete-python/tests/execution/test_others.py +++ b/frontends/concrete-python/tests/execution/test_others.py @@ -914,3 +914,33 @@ def test_others_bad_univariate(helpers): "Function bad_univariate cannot be used with fhe.univariate", str(excinfo.value), ) + + +def test_dynamic_indexing_hack(helpers): + """ + Test dynamic indexing using basic operators. + """ + + @fhe.compiler({"array": "encrypted", "index": "encrypted"}) + def function(array, index): + all_indices = np.arange(array.size) + index_selection = index == all_indices + selection_and_zeros = array * index_selection + selection = np.sum(selection_and_zeros) + return selection + + inputset = [ + ( + np.random.randint(0, 16, size=(4,)), + np.random.randint(0, 4, size=()), + ) + for _ in range(100) + ] + circuit = function.compile(inputset, helpers.configuration()) + + sample = np.random.randint(0, 16, size=(4,)) + + helpers.check_execution(circuit, function, [sample, 0], retries=3) + helpers.check_execution(circuit, function, [sample, 1], retries=3) + helpers.check_execution(circuit, function, [sample, 2], retries=3) + helpers.check_execution(circuit, function, [sample, 3], retries=3)