mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: mlir conversion of np.integer slices during indexing and assignment
This commit is contained in:
@@ -878,15 +878,15 @@ class NodeConverter:
|
||||
for indexing_element, dimension_size in zip(index, input_shape):
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
size = np.zeros(dimension_size)[indexing_element].shape[0]
|
||||
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
|
||||
offset = (
|
||||
size = int(np.zeros(dimension_size)[indexing_element].shape[0])
|
||||
stride = int(indexing_element.step if indexing_element.step is not None else 1)
|
||||
offset = int(
|
||||
(
|
||||
indexing_element.start
|
||||
if indexing_element.start >= 0
|
||||
else indexing_element.start + dimension_size
|
||||
)
|
||||
if isinstance(indexing_element.start, int)
|
||||
if indexing_element.start is not None
|
||||
else (0 if stride > 0 else dimension_size - 1)
|
||||
)
|
||||
|
||||
@@ -950,15 +950,15 @@ class NodeConverter:
|
||||
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
size = np.zeros(dimension_size)[indexing_element].shape[0]
|
||||
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
|
||||
offset = (
|
||||
size = int(np.zeros(dimension_size)[indexing_element].shape[0])
|
||||
stride = int(indexing_element.step if indexing_element.step is not None else 1)
|
||||
offset = int(
|
||||
(
|
||||
indexing_element.start
|
||||
if indexing_element.start >= 0
|
||||
else indexing_element.start + dimension_size
|
||||
)
|
||||
if isinstance(indexing_element.start, int)
|
||||
if indexing_element.start is not None
|
||||
else (0 if stride > 0 else dimension_size - 1)
|
||||
)
|
||||
|
||||
|
||||
@@ -146,6 +146,11 @@ import concrete.numpy as cnp
|
||||
lambda x: x[::-1],
|
||||
id="x[::-1] where x.shape == (3, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
(10,),
|
||||
lambda x: x[slice(np.int64(8), np.int64(2), np.int64(-2))],
|
||||
id="x[8:2:-2] where x.shape == (10,)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_static_indexing(shape, function, helpers):
|
||||
|
||||
Reference in New Issue
Block a user