support variable shape none slice in getitem (#10724)

This commit is contained in:
wozeparrot
2025-06-09 11:53:02 -07:00
committed by GitHub
parent afd5140a09
commit 27dd97f688
3 changed files with 23 additions and 1 deletions

View File

@@ -197,6 +197,18 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1) assert_jit_cache_len(jf, 1)
def test_slice_var_shape(self):
def f(a): return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.ones(vi, 11).contiguous()
symbolic = a[:, 1:2]
symbolic = jf(symbolic).reshape(i, 1).numpy()
expected = f(a.reshape(i, 11)[:, 1:2]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
def test_ones_sum(self): def test_ones_sum(self):
def f(a): return a.sum().realize() def f(a): return a.sum().realize()
jf = TinyJit(f) jf = TinyJit(f)

View File

@@ -175,6 +175,14 @@ class TestSymbolicOps(unittest.TestCase):
expected = a[3:5, i:i+2].numpy() expected = a[3:5, i:i+2].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_slice_var_shape(self):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.ones(vi, 11).contiguous()
symbolic = a[:, 1:2].reshape(i, 1).numpy()
expected = a.reshape(i, 11)[:, 1:2].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_ones_sum(self): def test_ones_sum(self):
for i in range(1, 5): for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i) vi = Variable("i", 1, 10).bind(i)

View File

@@ -1140,7 +1140,9 @@ class Tensor(MathTrait):
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1] boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
case slice(): case slice():
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step") if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
if all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)): if all(s is None for s in (index.start,index.stop,index.step)):
boundary, stride = [0, size], 1
elif all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
# handle int slicing # handle int slicing
*boundary, stride = index.indices(cast(SupportsIndex, size)) *boundary, stride = index.indices(cast(SupportsIndex, size))
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]