indexing getting better (#5389)

* indexing getting better [run_process_replay] [no_assert]

* fix test

* test_arange_2_reduce is a simpler test

* put that print back, NOOPT

* don't merge reduces (they could be different reduces)

* FUSE_AS_ONE_KERNEL

* fix tests

* fix test_var_multireduce

* w/e put that there

* fails on others too

* fix test, revert UNMUL change

* in case order matters

* one kernel indexing works

* one kernel indexing works (test other)
This commit is contained in:
George Hotz
2024-07-11 16:41:51 -07:00
committed by GitHub
parent 9712d9ffb6
commit c2da4454cd
8 changed files with 92 additions and 18 deletions

View File

@@ -1,6 +1,8 @@
import unittest
from tinygrad import Tensor, GlobalCounters
import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.helpers import Context
from tinygrad.engine.realize import run_schedule
class TestArange(unittest.TestCase):
def _get_flops(self, N):
@@ -10,10 +12,54 @@ class TestArange(unittest.TestCase):
return GlobalCounters.global_ops
def test_complexity(self):
f1 = self._get_flops(256)
f2 = self._get_flops(2560)
# add 1 to avoid divide by 0. arange is 0 flops now!
f1 = self._get_flops(256) + 1
f2 = self._get_flops(2560) + 1
print(f"{f1=}, {f2=}")
assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
class TestIndexing(unittest.TestCase):
def test_arange_2_reduce(self):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
needle[1337] = 1
needle.realize()
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
GlobalCounters.reset()
# TODO: it should work without these reshapes
out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
sched = out.schedule()
assert len(sched) == 1
run_schedule(sched)
assert out.item() == 1337, f"expected 1337, got {out.item()}"
def test_manual_index(self):
dataset = Tensor.rand(16384, 256).realize()
idxs = Tensor([0,3,5,6]).realize()
real_index = dataset.numpy()[idxs.numpy()]
print("*** indexing ***")
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
GlobalCounters.reset()
rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
X = ((rng==idxs).float() * dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)).sum(axis=(2,3))
sched = X.schedule()
assert len(sched) == 1
run_schedule(sched)
np.testing.assert_allclose(real_index, X.numpy())
def test_index(self):
dataset = Tensor.rand(16384, 256).realize()
idxs = Tensor([0,3,5,6]).realize()
real_index = dataset.numpy()[idxs.numpy()]
print("*** indexing ***")
with Context(NOOPT=1):
GlobalCounters.reset()
X = dataset[idxs]
assert X.shape == (4,256)
sched = X.schedule()
#assert len(sched) == 1
run_schedule(sched)
np.testing.assert_allclose(real_index, X.numpy())
if __name__ == "__main__":
unittest.main()

View File

@@ -102,8 +102,8 @@ class TestLinearizer(unittest.TestCase):
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
@unittest.skipIf(CI and Device.DEFAULT == "AMD", "remu doesn't have multiple wave syncs yet")
@unittest.skip("still wrong")
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
@unittest.skip("still broken")
def test_var_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(3, 27, 32).realize()