mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
support variable shape none slice in getitem (#10724)
This commit is contained in:
@@ -197,6 +197,18 @@ class TestSymbolicJit(unittest.TestCase):
|
|||||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||||
assert_jit_cache_len(jf, 1)
|
assert_jit_cache_len(jf, 1)
|
||||||
|
|
||||||
|
def test_slice_var_shape(self):
|
||||||
|
def f(a): return (a+1).realize()
|
||||||
|
jf = TinyJit(f)
|
||||||
|
for i in range(1, 5):
|
||||||
|
vi = Variable("i", 1, 10).bind(i)
|
||||||
|
a = Tensor.ones(vi, 11).contiguous()
|
||||||
|
symbolic = a[:, 1:2]
|
||||||
|
symbolic = jf(symbolic).reshape(i, 1).numpy()
|
||||||
|
expected = f(a.reshape(i, 11)[:, 1:2]).numpy()
|
||||||
|
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||||
|
assert_jit_cache_len(jf, 1)
|
||||||
|
|
||||||
def test_ones_sum(self):
|
def test_ones_sum(self):
|
||||||
def f(a): return a.sum().realize()
|
def f(a): return a.sum().realize()
|
||||||
jf = TinyJit(f)
|
jf = TinyJit(f)
|
||||||
|
|||||||
@@ -175,6 +175,14 @@ class TestSymbolicOps(unittest.TestCase):
|
|||||||
expected = a[3:5, i:i+2].numpy()
|
expected = a[3:5, i:i+2].numpy()
|
||||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
def test_slice_var_shape(self):
|
||||||
|
for i in range(1, 5):
|
||||||
|
vi = Variable("i", 1, 10).bind(i)
|
||||||
|
a = Tensor.ones(vi, 11).contiguous()
|
||||||
|
symbolic = a[:, 1:2].reshape(i, 1).numpy()
|
||||||
|
expected = a.reshape(i, 11)[:, 1:2].numpy()
|
||||||
|
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
def test_ones_sum(self):
|
def test_ones_sum(self):
|
||||||
for i in range(1, 5):
|
for i in range(1, 5):
|
||||||
vi = Variable("i", 1, 10).bind(i)
|
vi = Variable("i", 1, 10).bind(i)
|
||||||
|
|||||||
@@ -1140,7 +1140,9 @@ class Tensor(MathTrait):
|
|||||||
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
|
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
|
||||||
case slice():
|
case slice():
|
||||||
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
|
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
|
||||||
if all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
|
if all(s is None for s in (index.start,index.stop,index.step)):
|
||||||
|
boundary, stride = [0, size], 1
|
||||||
|
elif all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
|
||||||
# handle int slicing
|
# handle int slicing
|
||||||
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
||||||
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
||||||
|
|||||||
Reference in New Issue
Block a user