fix indexing out of bounds (#6208)

* fix indexing out of bounds

* 5 ops per access is fine
This commit is contained in:
George Hotz
2024-08-20 11:34:56 -07:00
committed by GitHub
parent 4451bcaf95
commit a5d79688db
2 changed files with 11 additions and 4 deletions

View File

@@ -125,8 +125,15 @@ class TestIndexing(unittest.TestCase):
# Optimized variant of test_index_fused (noopt=0); skipped until that path is ready.
@unittest.skip("not ready")
def test_index_fused_opt(self): self.test_index_fused(0)
# Regression test added by this commit: gathering with indices outside
# [0, len(dataset)) must yield 0 rather than reading out of bounds
# (backed by the bounds gate added to index_collapse in this same commit).
def test_index_fused_out_of_bounds(self):
dataset = Tensor.rand(256, 256).realize()
# indices deliberately far outside the valid range [0, 256)
idxs = Tensor([-19238, -257, 256, 495, 10982377]).realize()
with Context(NOOPT=1, FUSE_ARANGE=1):
X = dataset[idxs]
# every out-of-bounds gather must come back as all zeros
np.testing.assert_equal(X.numpy(), 0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_index_mnist(self, noopt=1):
def test_index_mnist(self, noopt=1, op_limit=512*784*5):
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
@@ -134,14 +141,14 @@ class TestIndexing(unittest.TestCase):
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
x = X_train[samples].numpy()
y = Y_train[samples].numpy()
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
# Optimized variant of test_index_mnist (noopt=0); skipped until that path is ready.
@unittest.skip("not ready")
def test_index_mnist_opt(self): self.test_index_mnist(0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_llama_embedding(self, noopt=1, op_limit=0):
def test_llama_embedding(self, noopt=1, op_limit=100):
# llama3 is 128256
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
emb = nn.Embedding(vocab_size, embed_size)

View File

@@ -197,7 +197,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
# Rewrite rule: collapse a reduce over a range var `rng` whose body loads
# buf[add + mul*idx] into a single indexed load, dropping `rng` from the reduce.
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
# only applies when the range var actually feeds this reduce
if rng not in reduce.src: return None
# pre-fix line (removed by this commit): unguarded load, could index out of bounds
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx)),)+
# post-fix line (added by this commit): load gated on rng.src[0] <= idx < rng.src[1],
# falling back to ld.const(0) when idx is out of range — this is the actual OOB fix
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
# this is symbolic 2.0