mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
delete the arange dim mismatch restriction (#11568)
* delete the arange dim mismatch restriction
* skip that test race
This commit is contained in:
@@ -1727,7 +1727,8 @@ class TestIndexing(unittest.TestCase):
|
|||||||
s = Tensor.schedule(*lst)
|
s = Tensor.schedule(*lst)
|
||||||
lowered = [x[1] for x in lower_schedule(s.copy())]
|
lowered = [x[1] for x in lower_schedule(s.copy())]
|
||||||
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
|
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
|
||||||
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
|
if FUSE_ARANGE and len(kernels) != cnt:
|
||||||
|
raise KernelCountException(f"{len(kernels)} != {cnt}")
|
||||||
for ei in lowered: ei.run(do_update_stats=True)
|
for ei in lowered: ei.run(do_update_stats=True)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@@ -1741,7 +1742,7 @@ class TestIndexing(unittest.TestCase):
|
|||||||
def test_simple_indexing_alt(self):
|
def test_simple_indexing_alt(self):
|
||||||
X = Tensor.arange(16).reshape(4, 4)
|
X = Tensor.arange(16).reshape(4, 4)
|
||||||
xt = X[[1, 2], [1, 2]]
|
xt = X[[1, 2], [1, 2]]
|
||||||
self.check_schedule(xt, 5)
|
self.check_schedule(xt, 3)
|
||||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
||||||
|
|
||||||
def test_advanced_indexing(self):
|
def test_advanced_indexing(self):
|
||||||
@@ -1753,13 +1754,13 @@ class TestIndexing(unittest.TestCase):
|
|||||||
def test_advanced_indexing_alt(self):
|
def test_advanced_indexing_alt(self):
|
||||||
X = Tensor.arange(6).reshape(3, 2)+1
|
X = Tensor.arange(6).reshape(3, 2)+1
|
||||||
xt = X[[Tensor([2]), Tensor([1])]]
|
xt = X[[Tensor([2]), Tensor([1])]]
|
||||||
self.check_schedule(xt, 6)
|
self.check_schedule(xt, 3)
|
||||||
np.testing.assert_equal(xt.numpy(), 6)
|
np.testing.assert_equal(xt.numpy(), 6)
|
||||||
|
|
||||||
def test_advanced_simple_indexing_combined(self):
|
def test_advanced_simple_indexing_combined(self):
|
||||||
X = Tensor.arange(16).reshape(4, 4)
|
X = Tensor.arange(16).reshape(4, 4)
|
||||||
xt = X[1:2, [1, 2]]
|
xt = X[1:2, [1, 2]]
|
||||||
self.check_schedule(xt, 4)
|
self.check_schedule(xt, 2)
|
||||||
|
|
||||||
def test_push_through_reshape(self):
|
def test_push_through_reshape(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import unittest, base64, functools
|
import unittest, base64, functools, sys
|
||||||
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
|
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
|
||||||
from tinygrad.helpers import fetch
|
from tinygrad.helpers import fetch
|
||||||
|
|
||||||
|
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
|
||||||
class TestLLMTokenizer(unittest.TestCase):
|
class TestLLMTokenizer(unittest.TestCase):
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })
|
def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
|
|||||||
do_realize = PatternMatcher([
|
do_realize = PatternMatcher([
|
||||||
# always realize SINK parents
|
# always realize SINK parents
|
||||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||||
# always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
|
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
|
||||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
|
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
|
||||||
# realize before expand or unsafe pad ops
|
# realize before expand or unsafe pad ops
|
||||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
|
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
|
||||||
# realize parents of COPY, MSELECT, MSTACK
|
# realize parents of COPY, MSELECT, MSTACK
|
||||||
|
|||||||
@@ -257,8 +257,6 @@ def fuse_arange(root:UOp):
|
|||||||
for u in toposort:
|
for u in toposort:
|
||||||
for s in u.src: local_children.setdefault(s, []).append(u)
|
for s in u.src: local_children.setdefault(s, []).append(u)
|
||||||
fuse_rep: dict[UOp, UOp] = {}
|
fuse_rep: dict[UOp, UOp] = {}
|
||||||
# skip if root depends on aranges with different ndims. This can be improved
|
|
||||||
if any(len(set(dims)) > 1 for dims in zip(*[r.src[0].shape for r in local_arange])): return
|
|
||||||
for r in local_arange:
|
for r in local_arange:
|
||||||
# skip if already fused
|
# skip if already fused
|
||||||
if len(r.arg) > 2: continue
|
if len(r.arg) > 2: continue
|
||||||
|
|||||||
Reference in New Issue
Block a user