diff --git a/test/test_schedule.py b/test/test_schedule.py index 87530de80e..59aba3ee74 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1727,7 +1727,8 @@ class TestIndexing(unittest.TestCase): s = Tensor.schedule(*lst) lowered = [x[1] for x in lower_schedule(s.copy())] kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)] - if FUSE_ARANGE: self.assertEqual(len(kernels), cnt) + if FUSE_ARANGE and len(kernels) != cnt: + raise KernelCountException(f"{len(kernels)} != {cnt}") for ei in lowered: ei.run(do_update_stats=True) return s @@ -1741,7 +1742,7 @@ class TestIndexing(unittest.TestCase): def test_simple_indexing_alt(self): X = Tensor.arange(16).reshape(4, 4) xt = X[[1, 2], [1, 2]] - self.check_schedule(xt, 5) + self.check_schedule(xt, 3) np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]]) def test_advanced_indexing(self): @@ -1753,13 +1754,13 @@ class TestIndexing(unittest.TestCase): def test_advanced_indexing_alt(self): X = Tensor.arange(6).reshape(3, 2)+1 xt = X[[Tensor([2]), Tensor([1])]] - self.check_schedule(xt, 6) + self.check_schedule(xt, 3) np.testing.assert_equal(xt.numpy(), 6) def test_advanced_simple_indexing_combined(self): X = Tensor.arange(16).reshape(4, 4) xt = X[1:2, [1, 2]] - self.check_schedule(xt, 4) + self.check_schedule(xt, 2) def test_push_through_reshape(self): Tensor.manual_seed(0) diff --git a/test/unit/test_llm_tokenizer.py b/test/unit/test_llm_tokenizer.py index fca60bd7bb..7b65818a6f 100644 --- a/test/unit/test_llm_tokenizer.py +++ b/test/unit/test_llm_tokenizer.py @@ -1,7 +1,8 @@ -import unittest, base64, functools +import unittest, base64, functools, sys from tinygrad.apps.llm import SimpleTokenizer, get_llama_re from tinygrad.helpers import fetch +@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows") class TestLLMTokenizer(unittest.TestCase): @functools.cached_property def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "": 5, "": 6, "": 7 }) diff --git a/tinygrad/schedule/grouper.py b/tinygrad/schedule/grouper.py index af62971b44..70682ad685 100644 --- a/tinygrad/schedule/grouper.py +++ b/tinygrad/schedule/grouper.py @@ -26,8 +26,8 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None: do_realize = PatternMatcher([ # always realize SINK parents (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), - # always realize ASSIGN/CONTIGUOUS/GroupOp.Meta - (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize), + # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW + (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), # realize before expand or unsafe pad ops (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view), # realize parents of COPY, MSELECT, MSTACK diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py index ad54e205e7..687a37c14f 100644 --- a/tinygrad/schedule/kernelize.py +++ b/tinygrad/schedule/kernelize.py @@ -257,8 +257,6 @@ def fuse_arange(root:UOp): for u in toposort: for s in u.src: local_children.setdefault(s, []).append(u) fuse_rep: dict[UOp, UOp] = {} - # skip if root depends on aranges with different ndims. This can be improved - if any(len(set(dims)) > 1 for dims in zip(*[r.src[0].shape for r in local_arange])): return for r in local_arange: # skip if already fused if len(r.arg) > 2: continue