Support symbolic slice with no start [pr] (#10775)

* add symbolic slice with no start

* reshape the test

* step must be int

* just add a cast...

* more cast...
This commit is contained in:
Sieds Lykles
2025-06-11 22:00:38 +02:00
committed by GitHub
parent d465ef4acb
commit 10b61157b9
2 changed files with 15 additions and 6 deletions

View File

@@ -175,6 +175,15 @@ class TestSymbolicOps(unittest.TestCase):
expected = a[3:5, i:i+2].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_slice_no_start(self):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(7, 11)
symbolic = a[3:5, :vi:1].reshape(2,i)
symbolic = symbolic.numpy()
expected = a[3:5, :i:1].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_expand_padded(self):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)

View File

@@ -1143,19 +1143,19 @@ class Tensor(MathTrait):
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
case slice():
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
if all(s is None for s in (index.start,index.stop,index.step)):
boundary, stride = [0, size], 1
elif all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
start, stop = 0 if index.start is None else index.start, size if index.stop is None else index.stop
step = 1 if index.step is None else index.step
boundary, stride = [start, stop], step
if all(isinstance(s, int) for s in (start,stop,step)):
# handle int slicing
*boundary, stride = index.indices(cast(SupportsIndex, size))
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
# update size for slice
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
elif index.step is None and all(isinstance(s,(int,UOp))for s in (index.start,index.stop)) and resolve((index.stop-index.start) > 0, False):
elif (step == 1) and isinstance(step, int) and all(isinstance(s,(int,UOp)) for s in (start, stop)) and resolve((stop-start) > 0, False):
# simple symbolic slice
boundary = [index.start, index.stop]
size = (index.stop - index.start).ssimplify()
size = cast(UOp|int, cast(UOp, (stop - start)).ssimplify())
else: raise TypeError(f"slice {index=} is not supported")
case None: pass # do nothing
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")