mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Support symbolic slice with no start [pr] (#10775)
* add symbolic slice with no start * reshape the test * step must be int * just add a cast... * more cast...
This commit is contained in:
@@ -175,6 +175,15 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
expected = a[3:5, i:i+2].numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_slice_no_start(self):
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a[3:5, :vi:1].reshape(2,i)
|
||||
symbolic = symbolic.numpy()
|
||||
expected = a[3:5, :i:1].numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_expand_padded(self):
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
|
||||
@@ -1143,19 +1143,19 @@ class Tensor(MathTrait):
|
||||
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
|
||||
case slice():
|
||||
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
|
||||
if all(s is None for s in (index.start,index.stop,index.step)):
|
||||
boundary, stride = [0, size], 1
|
||||
elif all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
|
||||
start, stop = 0 if index.start is None else index.start, size if index.stop is None else index.stop
|
||||
step = 1 if index.step is None else index.step
|
||||
boundary, stride = [start, stop], step
|
||||
if all(isinstance(s, int) for s in (start,stop,step)):
|
||||
# handle int slicing
|
||||
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
||||
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
||||
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
|
||||
# update size for slice
|
||||
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
|
||||
elif index.step is None and all(isinstance(s,(int,UOp))for s in (index.start,index.stop)) and resolve((index.stop-index.start) > 0, False):
|
||||
elif (step == 1) and isinstance(step, int) and all(isinstance(s,(int,UOp)) for s in (start, stop)) and resolve((stop-start) > 0, False):
|
||||
# simple symbolic slice
|
||||
boundary = [index.start, index.stop]
|
||||
size = (index.stop - index.start).ssimplify()
|
||||
size = cast(UOp|int, cast(UOp, (stop - start)).ssimplify())
|
||||
else: raise TypeError(f"slice {index=} is not supported")
|
||||
case None: pass # do nothing
|
||||
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
|
||||
|
||||
Reference in New Issue
Block a user