diff --git a/test/test_arange.py b/test/test_arange.py
index bbefe8d2d0..2b36197ed1 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -125,8 +125,15 @@ class TestIndexing(unittest.TestCase):
   @unittest.skip("not ready")
   def test_index_fused_opt(self): self.test_index_fused(0)
 
+  def test_index_fused_out_of_bounds(self):
+    dataset = Tensor.rand(256, 256).realize()
+    idxs = Tensor([-19238, -257, 256, 495, 10982377]).realize()
+    with Context(NOOPT=1, FUSE_ARANGE=1):
+      X = dataset[idxs]
+      np.testing.assert_equal(X.numpy(), 0)
+
   @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
-  def test_index_mnist(self, noopt=1):
+  def test_index_mnist(self, noopt=1, op_limit=512*784*5):
     from tinygrad.nn.datasets import mnist
     X_train, Y_train, _, _ = mnist()
     with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
@@ -134,14 +141,14 @@ class TestIndexing(unittest.TestCase):
       samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
       x = X_train[samples].numpy()
       y = Y_train[samples].numpy()
-      assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
+      assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
       np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
       np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
   @unittest.skip("not ready")
   def test_index_mnist_opt(self): self.test_index_mnist(0)
 
   @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
-  def test_llama_embedding(self, noopt=1, op_limit=0):
+  def test_llama_embedding(self, noopt=1, op_limit=100):
     # llama3 is 128256
     vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
     emb = nn.Embedding(vocab_size, embed_size)
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index c0d6b95e6b..344c5c8069 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -197,7 +197,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
 
 def index_collapse(idx,rng,buf,add,mul,ld,reduce):
   if rng not in reduce.src: return None
-  return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx)),)+
+  return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
     tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
 
 # this is symbolic 2.0