delete the arange dim mismatch restriction (#11568)

* delete the arange dim mismatch restriction

* skip that test race
This commit is contained in:
George Hotz
2025-08-07 13:46:17 -07:00
committed by GitHub
parent 7ae4335127
commit 6ed2dfd187
4 changed files with 9 additions and 9 deletions

View File

@@ -1727,7 +1727,8 @@ class TestIndexing(unittest.TestCase):
s = Tensor.schedule(*lst) s = Tensor.schedule(*lst)
lowered = [x[1] for x in lower_schedule(s.copy())] lowered = [x[1] for x in lower_schedule(s.copy())]
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)] kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt) if FUSE_ARANGE and len(kernels) != cnt:
raise KernelCountException(f"{len(kernels)} != {cnt}")
for ei in lowered: ei.run(do_update_stats=True) for ei in lowered: ei.run(do_update_stats=True)
return s return s
@@ -1741,7 +1742,7 @@ class TestIndexing(unittest.TestCase):
def test_simple_indexing_alt(self): def test_simple_indexing_alt(self):
X = Tensor.arange(16).reshape(4, 4) X = Tensor.arange(16).reshape(4, 4)
xt = X[[1, 2], [1, 2]] xt = X[[1, 2], [1, 2]]
self.check_schedule(xt, 5) self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]]) np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
def test_advanced_indexing(self): def test_advanced_indexing(self):
@@ -1753,13 +1754,13 @@ class TestIndexing(unittest.TestCase):
def test_advanced_indexing_alt(self): def test_advanced_indexing_alt(self):
X = Tensor.arange(6).reshape(3, 2)+1 X = Tensor.arange(6).reshape(3, 2)+1
xt = X[[Tensor([2]), Tensor([1])]] xt = X[[Tensor([2]), Tensor([1])]]
self.check_schedule(xt, 6) self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), 6) np.testing.assert_equal(xt.numpy(), 6)
def test_advanced_simple_indexing_combined(self): def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4) X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]] xt = X[1:2, [1, 2]]
self.check_schedule(xt, 4) self.check_schedule(xt, 2)
def test_push_through_reshape(self): def test_push_through_reshape(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)

View File

@@ -1,7 +1,8 @@
import unittest, base64, functools import unittest, base64, functools, sys
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
from tinygrad.helpers import fetch from tinygrad.helpers import fetch
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
class TestLLMTokenizer(unittest.TestCase): class TestLLMTokenizer(unittest.TestCase):
@functools.cached_property @functools.cached_property
def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 }) def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })

View File

@@ -26,8 +26,8 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
do_realize = PatternMatcher([ do_realize = PatternMatcher([
# always realize SINK parents # always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize ASSIGN/CONTIGUOUS/GroupOp.Meta # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize), (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
# realize before expand or unsafe pad ops # realize before expand or unsafe pad ops
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view), (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
# realize parents of COPY, MSELECT, MSTACK # realize parents of COPY, MSELECT, MSTACK

View File

@@ -257,8 +257,6 @@ def fuse_arange(root:UOp):
for u in toposort: for u in toposort:
for s in u.src: local_children.setdefault(s, []).append(u) for s in u.src: local_children.setdefault(s, []).append(u)
fuse_rep: dict[UOp, UOp] = {} fuse_rep: dict[UOp, UOp] = {}
# skip if root depends on aranges with different ndims. This can be improved
if any(len(set(dims)) > 1 for dims in zip(*[r.src[0].shape for r in local_arange])): return
for r in local_arange: for r in local_arange:
# skip if already fused # skip if already fused
if len(r.arg) > 2: continue if len(r.arg) > 2: continue