mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix indexing out of bounds (#6208)
* fix indeing out of bounds * 5 ops per access is fine
This commit is contained in:
@@ -125,8 +125,15 @@ class TestIndexing(unittest.TestCase):
|
||||
@unittest.skip("not ready")
|
||||
def test_index_fused_opt(self): self.test_index_fused(0)
|
||||
|
||||
def test_index_fused_out_of_bounds(self):
|
||||
dataset = Tensor.rand(256, 256).realize()
|
||||
idxs = Tensor([-19238, -257, 256, 495, 10982377]).realize()
|
||||
with Context(NOOPT=1, FUSE_ARANGE=1):
|
||||
X = dataset[idxs]
|
||||
np.testing.assert_equal(X.numpy(), 0)
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
||||
def test_index_mnist(self, noopt=1):
|
||||
def test_index_mnist(self, noopt=1, op_limit=512*784*5):
|
||||
from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train, _, _ = mnist()
|
||||
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
|
||||
@@ -134,14 +141,14 @@ class TestIndexing(unittest.TestCase):
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
x = X_train[samples].numpy()
|
||||
y = Y_train[samples].numpy()
|
||||
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
|
||||
assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
|
||||
np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
|
||||
np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
|
||||
@unittest.skip("not ready")
|
||||
def test_index_mnist_opt(self): self.test_index_mnist(0)
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
||||
def test_llama_embedding(self, noopt=1, op_limit=0):
|
||||
def test_llama_embedding(self, noopt=1, op_limit=100):
|
||||
# llama3 is 128256
|
||||
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
|
||||
emb = nn.Embedding(vocab_size, embed_size)
|
||||
|
||||
@@ -197,7 +197,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
|
||||
|
||||
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
|
||||
if rng not in reduce.src: return None
|
||||
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx)),)+
|
||||
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
|
||||
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
|
||||
|
||||
# this is symbolic 2.0
|
||||
|
||||
Reference in New Issue
Block a user