mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -23,7 +23,7 @@ class TestIndexing(unittest.TestCase):
|
||||
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
|
||||
needle[1337] = 1
|
||||
needle.realize()
|
||||
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
|
||||
with Context(NOOPT=1, FUSE_ARANGE=1):
|
||||
GlobalCounters.reset()
|
||||
# TODO: it should work without these reshapes
|
||||
out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
|
||||
@@ -38,7 +38,7 @@ class TestIndexing(unittest.TestCase):
|
||||
idxs = Tensor([0,3,5,6]).realize()
|
||||
real_index = dataset.numpy()[idxs.numpy()]
|
||||
print("*** indexing ***")
|
||||
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
|
||||
with Context(NOOPT=1, FUSE_ARANGE=1):
|
||||
GlobalCounters.reset()
|
||||
rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
|
||||
idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
|
||||
@@ -72,12 +72,12 @@ class TestIndexing(unittest.TestCase):
|
||||
idxs = Tensor([0,3,5,6]).realize()
|
||||
real_index = dataset.numpy()[idxs.numpy()]
|
||||
print("*** indexing ***")
|
||||
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
|
||||
with Context(NOOPT=1, FUSE_ARANGE=1):
|
||||
GlobalCounters.reset()
|
||||
X = dataset[idxs]
|
||||
assert X.shape == (4,256)
|
||||
sched = X.schedule()
|
||||
assert len(sched) == 1
|
||||
assert len(sched) == 2
|
||||
run_schedule(sched)
|
||||
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
|
||||
np.testing.assert_allclose(real_index, X.numpy())
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad import nn, dtypes
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
|
||||
from tinygrad.helpers import DEBUG, flatten, getenv
|
||||
from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
@@ -1268,36 +1268,41 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Tensor, cnt:int):
|
||||
s = xt.schedule()
|
||||
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
|
||||
run_schedule(s)
|
||||
self.assertEqual(kernel_cnt, cnt)
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
s = xt.schedule()
|
||||
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
|
||||
run_schedule(s)
|
||||
if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
|
||||
|
||||
def test_simple_indexing(self):
|
||||
X = Tensor.randn(10, 10).realize()
|
||||
idxs = Tensor([0, 2]).realize()
|
||||
xt = X[idxs]
|
||||
self.check_schedule(xt, 3)
|
||||
self.check_schedule(xt, 2)
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_simple_indexing_alt(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[[1, 2], [1, 2]]
|
||||
self.check_schedule(xt, 5)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_advanced_indexing(self):
|
||||
X = Tensor.arange(10)+1
|
||||
xt = X[[0]]
|
||||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_advanced_indexing_alt(self):
|
||||
X = Tensor.arange(6).reshape(3, 2)+1
|
||||
xt = X[[Tensor([2]), Tensor([1])]]
|
||||
self.check_schedule(xt, 6)
|
||||
np.testing.assert_equal(xt.numpy(), 6)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[1:2, [1, 2]]
|
||||
@@ -1308,7 +1313,7 @@ class TestIndexing(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(10, 20).realize()
|
||||
out = x.argmax(1)
|
||||
self.check_schedule(out, 3)
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1))
|
||||
|
||||
def test_arange_push_through_expand(self):
|
||||
@@ -1316,22 +1321,82 @@ class TestIndexing(unittest.TestCase):
|
||||
a = Tensor.arange(4,)
|
||||
b = Tensor.randn(4, 4).realize()
|
||||
out = a+b
|
||||
self.check_schedule(out, 2)
|
||||
self.check_schedule(out, 1)
|
||||
np.testing.assert_allclose(out.numpy(), np.arange(4)+b.numpy())
|
||||
|
||||
def test_argmin(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.argmin(-1)
|
||||
self.check_schedule(out, 3)
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
|
||||
|
||||
def test_argmax(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.argmax(-1)
|
||||
self.check_schedule(out, 3)
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
|
||||
|
||||
def test_arange_transposed(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randint(4, 1)
|
||||
a = (Tensor.arange(4,)*x).T
|
||||
self.check_schedule(a, 2)
|
||||
np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T)
|
||||
|
||||
def test_arange_transposed_descendants(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randint(4, 1)
|
||||
a = (Tensor.arange(4,)*x).T
|
||||
b = Tensor.randint(4, 4).realize()
|
||||
out = a+b
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_equal(out.numpy(), (np.arange(4)*x.numpy()).T+b.numpy())
|
||||
|
||||
def test_arange_index(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10)
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 1)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum())
|
||||
|
||||
def test_arange_index_contiguous(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum())
|
||||
|
||||
def test_arange_index_child(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10)+1
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 1)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
|
||||
|
||||
def test_arange_index_contiguous_child(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = (Tensor.arange(10)+1).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
|
||||
|
||||
def test_arange_childless(self):
|
||||
a = Tensor.arange(4)
|
||||
self.check_schedule(a, 1)
|
||||
np.testing.assert_equal(a.numpy(), np.arange(4))
|
||||
|
||||
def test_arange_group_childless(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randint(4)
|
||||
a = Tensor.arange(4)+x
|
||||
self.check_schedule(a, 1)
|
||||
np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user