From 840c0eba8ce074f76f438627acc1ffc30ef2a9c5 Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 15 Feb 2023 16:13:37 +0100 Subject: [PATCH] fix: mlir conversion of np.integer slices during indexing and assignment --- concrete/numpy/mlir/node_converter.py | 16 ++++++++-------- tests/execution/test_static_indexing.py | 5 +++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index cce55b58d..fb5a44be7 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -878,15 +878,15 @@ class NodeConverter: for indexing_element, dimension_size in zip(index, input_shape): if isinstance(indexing_element, slice): - size = np.zeros(dimension_size)[indexing_element].shape[0] - stride = indexing_element.step if isinstance(indexing_element.step, int) else 1 - offset = ( + size = int(np.zeros(dimension_size)[indexing_element].shape[0]) + stride = int(indexing_element.step if indexing_element.step is not None else 1) + offset = int( ( indexing_element.start if indexing_element.start >= 0 else indexing_element.start + dimension_size ) - if isinstance(indexing_element.start, int) + if indexing_element.start is not None else (0 if stride > 0 else dimension_size - 1) ) @@ -950,15 +950,15 @@ class NodeConverter: for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): if isinstance(indexing_element, slice): - size = np.zeros(dimension_size)[indexing_element].shape[0] - stride = indexing_element.step if isinstance(indexing_element.step, int) else 1 - offset = ( + size = int(np.zeros(dimension_size)[indexing_element].shape[0]) + stride = int(indexing_element.step if indexing_element.step is not None else 1) + offset = int( ( indexing_element.start if indexing_element.start >= 0 else indexing_element.start + dimension_size ) - if isinstance(indexing_element.start, int) + if indexing_element.start is not None else (0 if stride > 0 else dimension_size - 1) ) diff --git a/tests/execution/test_static_indexing.py b/tests/execution/test_static_indexing.py index 199552f20..07772c153 100644 --- a/tests/execution/test_static_indexing.py +++ b/tests/execution/test_static_indexing.py @@ -146,6 +146,11 @@ import concrete.numpy as cnp lambda x: x[::-1], id="x[::-1] where x.shape == (3, 4)", ), + pytest.param( + (10,), + lambda x: x[slice(np.int64(8), np.int64(2), np.int64(-2))], + id="x[8:2:-2] where x.shape == (10,)", + ), ], ) def test_static_indexing(shape, function, helpers):