mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
delete the arange dim mismatch restriction (#11568)
* delete the arange dim mismatch restriction
* skip that test race
This commit is contained in:
@@ -1727,7 +1727,8 @@ class TestIndexing(unittest.TestCase):
|
||||
s = Tensor.schedule(*lst)
|
||||
lowered = [x[1] for x in lower_schedule(s.copy())]
|
||||
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
|
||||
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
|
||||
if FUSE_ARANGE and len(kernels) != cnt:
|
||||
raise KernelCountException(f"{len(kernels)} != {cnt}")
|
||||
for ei in lowered: ei.run(do_update_stats=True)
|
||||
return s
|
||||
|
||||
@@ -1741,7 +1742,7 @@ class TestIndexing(unittest.TestCase):
|
||||
def test_simple_indexing_alt(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[[1, 2], [1, 2]]
|
||||
self.check_schedule(xt, 5)
|
||||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
||||
|
||||
def test_advanced_indexing(self):
|
||||
@@ -1753,13 +1754,13 @@ class TestIndexing(unittest.TestCase):
|
||||
def test_advanced_indexing_alt(self):
|
||||
X = Tensor.arange(6).reshape(3, 2)+1
|
||||
xt = X[[Tensor([2]), Tensor([1])]]
|
||||
self.check_schedule(xt, 6)
|
||||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), 6)
|
||||
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[1:2, [1, 2]]
|
||||
self.check_schedule(xt, 4)
|
||||
self.check_schedule(xt, 2)
|
||||
|
||||
def test_push_through_reshape(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import unittest, base64, functools
|
||||
import unittest, base64, functools, sys
|
||||
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
|
||||
class TestLLMTokenizer(unittest.TestCase):
|
||||
@functools.cached_property
|
||||
def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })
|
||||
|
||||
@@ -26,8 +26,8 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
|
||||
do_realize = PatternMatcher([
|
||||
# always realize SINK parents
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
|
||||
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
|
||||
# realize parents of COPY, MSELECT, MSTACK
|
||||
|
||||
@@ -257,8 +257,6 @@ def fuse_arange(root:UOp):
|
||||
for u in toposort:
|
||||
for s in u.src: local_children.setdefault(s, []).append(u)
|
||||
fuse_rep: dict[UOp, UOp] = {}
|
||||
# skip if root depends on aranges with different ndims. This can be improved
|
||||
if any(len(set(dims)) > 1 for dims in zip(*[r.src[0].shape for r in local_arange])): return
|
||||
for r in local_arange:
|
||||
# skip if already fused
|
||||
if len(r.arg) > 2: continue
|
||||
|
||||
Reference in New Issue
Block a user