indexing fusion 2 (#5888)

* arange fusion

* kernels that fuse

* tests
This commit is contained in:
qazal
2024-08-03 18:13:39 +08:00
committed by GitHub
parent af59b2eea9
commit 65fa86901a
4 changed files with 94 additions and 18 deletions

View File

@@ -23,7 +23,7 @@ class TestIndexing(unittest.TestCase):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
needle[1337] = 1
needle.realize()
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
with Context(NOOPT=1, FUSE_ARANGE=1):
GlobalCounters.reset()
# TODO: it should work without these reshapes
out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
@@ -38,7 +38,7 @@ class TestIndexing(unittest.TestCase):
idxs = Tensor([0,3,5,6]).realize()
real_index = dataset.numpy()[idxs.numpy()]
print("*** indexing ***")
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
with Context(NOOPT=1, FUSE_ARANGE=1):
GlobalCounters.reset()
rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
@@ -72,12 +72,12 @@ class TestIndexing(unittest.TestCase):
idxs = Tensor([0,3,5,6]).realize()
real_index = dataset.numpy()[idxs.numpy()]
print("*** indexing ***")
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
with Context(NOOPT=1, FUSE_ARANGE=1):
GlobalCounters.reset()
X = dataset[idxs]
assert X.shape == (4,256)
sched = X.schedule()
assert len(sched) == 1
assert len(sched) == 2
run_schedule(sched)
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
np.testing.assert_allclose(real_index, X.numpy())

View File

@@ -9,7 +9,7 @@ from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
from tinygrad.helpers import DEBUG, flatten, getenv
from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -1268,36 +1268,41 @@ class TestSchedule(unittest.TestCase):
class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Tensor, cnt:int):
# Schedules `xt`, counts the schedule items whose ast.op is MetaOps.KERNEL, runs the schedule,
# and compares that count against `cnt`.
# NOTE(review): two variants of the body appear back to back below (a plain one, then one wrapped in
# `with Context(...)`) and the indentation is missing -- this looks like the removed/added lines of a
# rendered diff. Confirm which lines are live and what is nested under the `with` before relying on this.
s = xt.schedule()
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
run_schedule(s)
self.assertEqual(kernel_cnt, cnt)
# presumably getenv reads the FUSE_ARANGE environment variable and falls back to 1 -- verify against tinygrad.helpers
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
s = xt.schedule()
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
run_schedule(s)
# the kernel-count assertion only runs when FUSE_ARANGE is truthy
if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
def test_simple_indexing(self):
# Indexes a realized 10x10 random tensor with a realized index tensor [0, 2] and compares the
# result with the same indexing done in numpy.
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = X[idxs]
# NOTE(review): two consecutive check_schedule calls with different expected kernel counts (3, then 2);
# this looks like an old/new line pair from a rendered diff -- confirm which one is current.
self.check_schedule(xt, 3)
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
@unittest.expectedFailure
def test_simple_indexing_alt(self):
  # paired list indices ([1, 2], [1, 2]) into a 4x4 arange tensor; currently marked as an expected failure
  grid = Tensor.arange(16).reshape(4, 4)
  picked = grid[[1, 2], [1, 2]]
  self.check_schedule(picked, 5)
  got = picked.numpy()
  want = np.arange(16).reshape(4, 4)[[1, 2], [1, 2]]
  np.testing.assert_equal(got, want)
@unittest.expectedFailure
def test_advanced_indexing(self):
  # list index [0] into arange(10)+1; currently marked as an expected failure
  shifted = Tensor.arange(10)+1
  picked = shifted[[0]]
  self.check_schedule(picked, 3)
  got = picked.numpy()
  want = (np.arange(10)+1)[[0]]
  np.testing.assert_equal(got, want)
@unittest.expectedFailure
def test_advanced_indexing_alt(self):
  # indexes a 3x2 (arange(6)+1) tensor with a list of two single-element Tensors and expects the value 6;
  # currently marked as an expected failure
  table = Tensor.arange(6).reshape(3, 2)+1
  first_idx = Tensor([2])
  second_idx = Tensor([1])
  picked = table[[first_idx, second_idx]]
  self.check_schedule(picked, 6)
  got = picked.numpy()
  np.testing.assert_equal(got, 6)
@unittest.expectedFailure
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
@@ -1308,7 +1313,7 @@ class TestIndexing(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(10, 20).realize()
out = x.argmax(1)
self.check_schedule(out, 3)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1))
def test_arange_push_through_expand(self):
@@ -1316,22 +1321,82 @@ class TestIndexing(unittest.TestCase):
a = Tensor.arange(4,)
b = Tensor.randn(4, 4).realize()
out = a+b
self.check_schedule(out, 2)
self.check_schedule(out, 1)
np.testing.assert_allclose(out.numpy(), np.arange(4)+b.numpy())
def test_argmin(self):
# Compares argmin over the last axis of a seeded, realized 4x32 random tensor with numpy's argmin.
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmin(-1)
# NOTE(review): two consecutive check_schedule calls with different expected kernel counts (3, then 2);
# this looks like an old/new line pair from a rendered diff -- confirm which one is current.
self.check_schedule(out, 3)
self.check_schedule(out, 2)
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
def test_argmax(self):
# Compares argmax over the last axis of a seeded, realized 4x32 random tensor with numpy's argmax.
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmax(-1)
# NOTE(review): two consecutive check_schedule calls with different expected kernel counts (3, then 2);
# this looks like an old/new line pair from a rendered diff -- confirm which one is current.
self.check_schedule(out, 3)
self.check_schedule(out, 2)
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
def test_arange_transposed(self):
  # transpose of arange(4) multiplied by a seeded randint(4, 1) tensor; 2 kernels expected
  Tensor.manual_seed(0)
  factor = Tensor.randint(4, 1)
  transposed = (Tensor.arange(4)*factor).T
  self.check_schedule(transposed, 2)
  got = transposed.numpy()
  want = (np.arange(4)*factor.numpy()).T
  np.testing.assert_equal(got, want)
def test_arange_transposed_descendants(self):
  # same transposed arange product as above, then added to a realized randint(4, 4); 2 kernels expected
  Tensor.manual_seed(0)
  factor = Tensor.randint(4, 1)
  transposed = (Tensor.arange(4)*factor).T
  addend = Tensor.randint(4, 4).realize()
  total = transposed+addend
  self.check_schedule(total, 2)
  got = total.numpy()
  want = (np.arange(4)*factor.numpy()).T+addend.numpy()
  np.testing.assert_equal(got, want)
def test_arange_index(self):
  # element 2 of arange(10) added to a realized 5x2 randn tensor, then summed; 1 kernel expected
  Tensor.manual_seed(0)
  data = Tensor.randn(5, 2).realize()
  counter = Tensor.arange(10)
  total = (data + counter[2]).sum()
  self.check_schedule(total, 1)
  got = total.numpy()
  want = (data.numpy()+np.arange(10)[2]).sum()
  np.testing.assert_allclose(got, want)
def test_arange_index_contiguous(self):
  # like test_arange_index, but the arange is made contiguous first; 2 kernels expected
  Tensor.manual_seed(0)
  data = Tensor.randn(5, 2).realize()
  counter = Tensor.arange(10).contiguous()
  total = (data + counter[2]).sum()
  self.check_schedule(total, 2)
  got = total.numpy()
  want = (data.numpy()+np.arange(10)[2]).sum()
  np.testing.assert_allclose(got, want)
def test_arange_index_child(self):
  # element 2 of (arange(10)+1) added to a realized 5x2 randn tensor, then summed; 1 kernel expected
  Tensor.manual_seed(0)
  data = Tensor.randn(5, 2).realize()
  shifted = Tensor.arange(10)+1
  total = (data + shifted[2]).sum()
  self.check_schedule(total, 1)
  got = total.numpy()
  want = (data.numpy()+(np.arange(10)+1)[2]).sum()
  np.testing.assert_allclose(got, want)
def test_arange_index_contiguous_child(self):
  # like test_arange_index_child, but (arange(10)+1) is made contiguous first; 2 kernels expected
  Tensor.manual_seed(0)
  data = Tensor.randn(5, 2).realize()
  shifted = (Tensor.arange(10)+1).contiguous()
  total = (data + shifted[2]).sum()
  self.check_schedule(total, 2)
  got = total.numpy()
  want = (data.numpy()+(np.arange(10)+1)[2]).sum()
  np.testing.assert_allclose(got, want)
def test_arange_childless(self):
  # a bare arange(4) with nothing consuming it; 1 kernel expected
  counter = Tensor.arange(4)
  self.check_schedule(counter, 1)
  got = counter.numpy()
  np.testing.assert_equal(got, np.arange(4))
def test_arange_group_childless(self):
  # arange(4) plus a seeded randint(4) tensor, with nothing consuming the sum; 1 kernel expected
  Tensor.manual_seed(0)
  offset = Tensor.randint(4)
  shifted = Tensor.arange(4)+offset
  self.check_schedule(shifted, 1)
  got = shifted.numpy()
  want = np.arange(4)+offset.numpy()
  np.testing.assert_equal(got, want)
if __name__ == '__main__':
  # run this module's tests with per-test verbose output when executed as a script
  unittest.main(verbosity=2)