diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index ab9a038758..082d046550 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -175,6 +175,15 @@ class TestSymbolicOps(unittest.TestCase): expected = a[3:5, i:i+2].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_slice_no_start(self): + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(7, 11) + symbolic = a[3:5, :vi:1].reshape(2,i) + symbolic = symbolic.numpy() + expected = a[3:5, :i:1].numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_expand_padded(self): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8d046b1a1d..8ce86cb012 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1143,19 +1143,19 @@ class Tensor(MathTrait): boundary = [index, index+1] if index >= 0 else [index+size, index+size+1] case slice(): if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step") - if all(s is None for s in (index.start,index.stop,index.step)): - boundary, stride = [0, size], 1 - elif all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)): + start, stop = 0 if index.start is None else index.start, size if index.stop is None else index.stop + step = 1 if index.step is None else index.step + boundary, stride = [start, stop], step + if all(isinstance(s, int) for s in (start,stop,step)): # handle int slicing *boundary, stride = index.indices(cast(SupportsIndex, size)) if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1] # update size for slice size = ceildiv((boundary[1] - boundary[0]), abs(stride)) - elif index.step is None and all(isinstance(s,(int,UOp))for s in (index.start,index.stop)) and resolve((index.stop-index.start) > 0, False): + elif (step == 1) and isinstance(step, int) and all(isinstance(s,(int,UOp)) for s in (start, stop)) and resolve((stop-start) > 0, False): # simple symbolic slice - boundary = [index.start, index.stop] - size = (index.stop - index.start).ssimplify() + size = cast(UOp|int, cast(UOp, (stop - start)).ssimplify()) else: raise TypeError(f"slice {index=} is not supported") case None: pass # do nothing case _: raise IndexError(f"{type(index).__name__} indexing is not supported")