From 250cb10e8f3eb1d0e6835cef0fbbd5580e75743d Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 29 Sep 2025 07:27:57 +0300
Subject: [PATCH] rangeify permuted assign (#12299)

* enable RANGEIFY=1 test_assign
* work
* rangeify=0 asserts this ast
* remove that
* beta test, it's correct though
* skip multi
* matches torch/np output
* memcopy without memcopy
* can remove this
* rangeify isn't silently wrong anymore
* diff cleanup
* use UOp toposort instead of global tags
* actual assert TestRangeifyAssign
* step
* work
* this isn't optimizing away now
* some todos
* test fusion schedule
* typo
* dedup idxs
* cleaner
* pre
* work
* diff
---
 .github/workflows/test.yml    |  5 +++--
 test/test_assign.py           | 36 ++++++++++++++++++++++++++++-------
 test/test_rangeify.py         |  2 ++
 test/test_schedule.py         | 12 ++++++++++++
 tinygrad/schedule/rangeify.py | 12 ++++++++----
 5 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 23c075ca7d..bb5742a146 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -524,12 +524,13 @@ jobs:
     - name: Test CPU=1 RANGEIFY=1
       # TODO: add more passing tests here
       # test_instancenorm_3d is very slow
+      # rangeify diamond cycle gives the wrong answer
       run: |
         CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
-          -k "not test_instancenorm_3d" \
+          -k "not test_instancenorm_3d and not test_assign_diamond_cycle" \
           test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_tensor_variable.py \
           test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py \
-          test/test_setitem.py
+          test/test_setitem.py test/test_assign.py
     - name: Test const folding
       run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
     - name: Test multitensor
diff --git a/test/test_assign.py b/test/test_assign.py
index 837e4141e1..b35223fc12 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -1,9 +1,10 @@
 #!/usr/bin/env python
 import unittest
+import contextlib
 import numpy as np
 from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
 from tinygrad.device import is_dtype_supported
-from tinygrad.helpers import temp
+from tinygrad.helpers import temp, RANGEIFY
 
 N = 200  # has to be bigger than the cache to fail
 
@@ -254,6 +255,8 @@ class TestAssign(unittest.TestCase):
     b.assign(a.contiguous()).realize()
     assert GlobalCounters.kernel_count - kc == 2
 
+  # passing in RANGEIFY=1, RANGEIFY=0 asserts permuted assigns it can't fuse
+  def assert_permuted_assign(self): return self.assertRaisesRegex(RuntimeError, "contiguous") if not RANGEIFY else contextlib.nullcontext()
   def test_permuted_assignment(self):
     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
     b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -277,7 +280,7 @@ class TestAssign(unittest.TestCase):
     #GlobalCounters.cache = []
     ba1 = a.uop.base.realized  # noqa: F841
     bb1 = b.uop.base.realized  # noqa: F841
-    with self.assertRaisesRegex(RuntimeError, "contiguous"):
+    with self.assert_permuted_assign():
       a.assign(a.permute(1,0) + b)  # this should not work!
      a.realize()
     ba2 = a.uop.base.realized  # noqa: F841
@@ -285,6 +288,22 @@ class TestAssign(unittest.TestCase):
     #assert ba1 == ba2 and ba1 != bb1
     np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
 
+  @unittest.skipUnless(RANGEIFY, "only correct in rangeify")
+  def test_post_permuted_assignment_alt(self):
+    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
+    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
+    new_a = (a.T+b).numpy()
+    a.assign(a.T+b)
+    np.testing.assert_allclose(a.numpy(), new_a)
+
+  def test_post_reshape_assignment_fine(self):
+    a = Tensor.arange(N*N).reshape(N, N).contiguous().realize()
+    b = Tensor.arange(N*N).reshape(N, N).contiguous().realize()
+    rhs = a.reshape(-1).reshape(N, N)
+    new_a = (rhs+b).numpy()
+    a.assign(rhs+b)  # self-assign with reshape view is fine
+    np.testing.assert_allclose(a.numpy(), new_a)
+
   @unittest.skip("multi output not supported anymore")
   def test_simple_assignment_multioutput(self):
     a = Tensor.randn(32, 32).realize()
@@ -309,8 +328,8 @@ class TestAssign(unittest.TestCase):
   def test_permuted_assignment_correct(self):
     a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
     b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
-    # TODO: scheduler limitation, should NOT raise AssertionError from numpy.
-    with self.assertRaisesRegex(RuntimeError, "contiguous"):
+    # TODO: swizzler.py limitation, should NOT raise AssertionError from numpy.
+    with self.assert_permuted_assign():
       a = a.permute(1, 0)
       new_val = a + b
       a.assign(new_val)
@@ -319,10 +338,11 @@ class TestAssign(unittest.TestCase):
   def test_permuted_reduceop_child_dual_use(self):
     a = Tensor.randn(32, 32, 32).realize()
     b = Tensor.full((32, 32), 1.).contiguous().realize()
-    with self.assertRaisesRegex(RuntimeError, "contiguous"):
+    with self.assert_permuted_assign():
       r = a.sum(axis=1)
       b.assign(r + b.permute(1, 0))
       b.realize()
+    np.testing.assert_allclose(b.numpy(), a.numpy().sum(axis=1)+np.ones((32, 32)).transpose(1, 0), atol=1e-6, rtol=1e-3)
 
   @unittest.skip("multi output not supported anymore")
   def test_permuted_reduceop_multioutput_dual_use(self):
@@ -359,15 +379,17 @@ class TestAssign(unittest.TestCase):
     a.assign(a + b)
     kc = GlobalCounters.kernel_count
     a.realize()
-    assert GlobalCounters.kernel_count - kc == 1
+    # rangeify makes two kernels
+    assert GlobalCounters.kernel_count - kc == (2 if RANGEIFY else 1)
     np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2))
 
   def test_permuted_assignment_masked_view_not_contiguous(self):
     a = Tensor.ones(4, 4).contiguous().realize()
-    with self.assertRaisesRegex(RuntimeError, "contiguous"):
+    with self.assert_permuted_assign():
       b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
       a.assign(a + b)
       a.realize()
+    self.assertListEqual(a.tolist(), [[2.,2.,2.,2.],[2.,2.,2.,2.],[3.,3.,3.,3.], [3.,3.,3.,3.]])
 
   # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
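Why the permuted-assign tests above need staging: when the right-hand side of an assign reads the target buffer through a permuted view, a single fused in-place kernel can read elements it has already overwritten. A minimal NumPy sketch of that aliasing hazard, not part of the patch; the names and the 4x4 shape are chosen only for illustration:

import numpy as np

N = 4
a = np.arange(N*N, dtype=np.float32).reshape(N, N)
b = np.arange(N*N, dtype=np.float32).reshape(N, N)
expected = a.T + b                       # RHS evaluated against the original `a`

# what a single fused in-place kernel could do: read a[j, i] after it may
# already have been overwritten by an earlier output element
naive = a.copy()
for i in range(N):
  for j in range(N):
    naive[i, j] = naive[j, i] + b[i, j]

# what the staged version does: materialize the RHS into a temporary buffer,
# then copy it into the target (two passes over the data)
staged = a.copy()
tmp = staged.T + b
staged[:, :] = tmp

assert not np.allclose(naive, expected)  # fused in-place update is wrong
assert np.allclose(staged, expected)     # staged update matches

This is exactly the behavior split the tests encode: with RANGEIFY=1 the scheduler stages the RHS first and the results match numpy, while RANGEIFY=0 still refuses such assigns with a "contiguous" RuntimeError, which is what assert_permuted_assign switches between.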
diff --git a/test/test_rangeify.py b/test/test_rangeify.py
index 3c86f1e32d..302d0ebd43 100644
--- a/test/test_rangeify.py
+++ b/test/test_rangeify.py
@@ -15,6 +15,8 @@ class TestRangeifyAssign(unittest.TestCase):
     print(lst)
     print(lst2)
     print(lst3)
+    self.assertListEqual(lst, lst3)
+    self.assertListEqual(lst2, B.permute(1, 0).tolist())
 
 
 N = 256
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 3235d8b148..2c29782105 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1903,6 +1903,18 @@ class TestSchedule(unittest.TestCase):
     # NOTE: this is a bug on non rangeify
     np.testing.assert_equal(tst.numpy(), a.numpy())
 
+  def test_setitem_sched(self, transpose=False):
+    a = Tensor.arange(16, device="CPU").reshape(4, 4).contiguous().realize()
+    a2 = a.T if transpose else a
+    expected = (a+a2).tolist()
+    a.assign(a+a2)
+    kcount = len(sched:=a.schedule())
+    run_schedule(sched)
+    self.assertListEqual(a.tolist(), expected)
+    self.assertEqual(kcount, 2 if transpose else 1)
+  @unittest.skipUnless(RANGEIFY>0, "this asserts on non rangeify")
+  def test_setitem_permuted_sched(self): self.test_setitem_sched(transpose=True)
+
   def test_sparse_categorical_crossentropy_simple(self):
     X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
     Y = Tensor([1, 2]).realize()
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 743d17d54d..5cb5ea45a9 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -13,6 +13,10 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, si
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
 
+ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
+                               Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
+                               Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
+
 double_reshape = PatternMatcher([
   # RESHAPE on RESHAPE is the second reshape
   (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"),
@@ -42,6 +46,10 @@ earliest_rewrites = double_reshape+PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
    lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None),
 
+  # realize before assign if input permutes the target buffer
+  (UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), lambda a,b,assign: assign.replace(src=(a, b.contiguous())) \
+    if any(x.base is a.base and x is not a for x in b.toposort(gate=lambda x:x.op not in ALWAYS_CONTIGUOUS)) else None),
+
   # copy only to different device
   (UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"),
    lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
@@ -58,10 +66,6 @@ earliest_rewrites = double_reshape+PatternMatcher([
 # *****************
 # 1. add realize where we have to
 
-ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
-                               Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
-                               Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
-
 def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
 
 def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
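The new earliest_rewrites rule works by walking the assign's right-hand side with b.toposort, gated so the walk stops at ops in ALWAYS_CONTIGUOUS (things that already are, or already read from, their own realized buffer). If any node in that walk shares the target's base buffer but is not the exact target view itself, the RHS is wrapped in contiguous(), so it gets staged into a temporary buffer before the write-back. At the Tensor level that is what test_setitem_sched asserts: a plain self-assign stays one kernel, a permuted self-assign costs one extra staging kernel. A rough usage sketch of that behavior, assuming RANGEIFY=1; the CPU device, 4x4 shape, and variable names are only for illustration:

from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

# plain self-assign: the RHS reads `a` only as the exact view being written, 1 kernel
a = Tensor.arange(16, device="CPU").reshape(4, 4).contiguous().realize()
a.assign(a + a)
sched = a.schedule()
run_schedule(sched)
print(len(sched))   # expected: 1

# permuted self-assign: the RHS reads `a` through a transposed view of the same
# buffer, so the rule inserts a CONTIGUOUS and the RHS is staged first, 2 kernels
b = Tensor.arange(16, device="CPU").reshape(4, 4).contiguous().realize()
b.assign(b + b.T)
sched = b.schedule()
run_schedule(sched)
print(len(sched))   # expected: 2

Reshape-only views of the target (test_post_reshape_assignment_fine) do not trigger the staging, since after the double_reshape cleanup they resolve to the same view of the buffer that is being assigned.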