mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
indexing getting better (#5389)
* indexing getting better [run_process_replay] [no_assert] * fix test * test_arange_2_reduce is a simpler test * put that print back, NOOPT * don't merge reduces (they could be different reduces) * FUSE_AS_ONE_KERNEL * fix tests * fix test_var_multireduce * w/e put that there * fails on others too * fix test, revert UNMUL change * in case order matters * one kernel indexing works * one kernel indexing works (test other)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
class TestArange(unittest.TestCase):
|
||||
def _get_flops(self, N):
|
||||
@@ -10,10 +12,54 @@ class TestArange(unittest.TestCase):
|
||||
return GlobalCounters.global_ops
|
||||
|
||||
def test_complexity(self):
|
||||
f1 = self._get_flops(256)
|
||||
f2 = self._get_flops(2560)
|
||||
# add 1 to avoid divide by 0. arange is 0 flops now!
|
||||
f1 = self._get_flops(256) + 1
|
||||
f2 = self._get_flops(2560) + 1
|
||||
print(f"{f1=}, {f2=}")
|
||||
assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def test_arange_2_reduce(self):
|
||||
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
|
||||
needle[1337] = 1
|
||||
needle.realize()
|
||||
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
|
||||
GlobalCounters.reset()
|
||||
# TODO: it should work without these reshapes
|
||||
out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
|
||||
sched = out.schedule()
|
||||
assert len(sched) == 1
|
||||
run_schedule(sched)
|
||||
assert out.item() == 1337, f"expected 1337, got {out.item()}"
|
||||
|
||||
def test_manual_index(self):
|
||||
dataset = Tensor.rand(16384, 256).realize()
|
||||
idxs = Tensor([0,3,5,6]).realize()
|
||||
real_index = dataset.numpy()[idxs.numpy()]
|
||||
print("*** indexing ***")
|
||||
with Context(NOOPT=1, FUSE_AS_ONE_KERNEL=1):
|
||||
GlobalCounters.reset()
|
||||
rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
|
||||
idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
|
||||
X = ((rng==idxs).float() * dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)).sum(axis=(2,3))
|
||||
sched = X.schedule()
|
||||
assert len(sched) == 1
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(real_index, X.numpy())
|
||||
|
||||
def test_index(self):
|
||||
dataset = Tensor.rand(16384, 256).realize()
|
||||
idxs = Tensor([0,3,5,6]).realize()
|
||||
real_index = dataset.numpy()[idxs.numpy()]
|
||||
print("*** indexing ***")
|
||||
with Context(NOOPT=1):
|
||||
GlobalCounters.reset()
|
||||
X = dataset[idxs]
|
||||
assert X.shape == (4,256)
|
||||
sched = X.schedule()
|
||||
#assert len(sched) == 1
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(real_index, X.numpy())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -102,8 +102,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert len(mutable_bufs) == len(stores) == 2
|
||||
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "AMD", "remu doesn't have multiple wave syncs yet")
|
||||
@unittest.skip("still wrong")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
|
||||
@unittest.skip("still broken")
|
||||
def test_var_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(3, 27, 32).realize()
|
||||
|
||||
Reference in New Issue
Block a user