From 8743ca40e29476a2e767f6f8541a4490026cb70f Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 24 Jun 2025 13:00:16 -0700
Subject: [PATCH] force reduce to be in axis order (#10837)

* force reduce to be in axis order

* disable rule causing loop

* disable that rule

* no ra there

* only move non reduce

* fix tests
---
 test/test_linearizer.py         |  2 +-
 test/test_schedule.py           |  2 +-
 tinygrad/kernelize/kernelize.py |  2 +-
 tinygrad/uop/ops.py             | 13 ++++++++++++-
 4 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index f5f8f69ae3..4535ced551 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -554,7 +554,7 @@ class TestLinearizer(unittest.TestCase):
     idxs = Tensor([0,3,5,6]).realize()
     with Context(FUSE_ARANGE=1):
       sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast
-    real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
+    real_index = dataset.numpy()[idxs.numpy()].reshape(4, 256, 1, 1)
     helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])
 
   # AssertionError: repeated stores in uops
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 1ff279a157..ae04c4690e 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1622,7 +1622,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 3))
 
   # TODO: push a reduceop through a reshape
   def test_conv2d(self): _test_conv2d(7)
-  def test_conv2d_fused(self): _test_conv2d(6, FUSE_CONV_BW=1)
+  def test_conv2d_fused(self): _test_conv2d(5, FUSE_CONV_BW=1)
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
   def test_conv2d_half(self): _test_conv2d(7, dtype=dtypes.half)
diff --git a/tinygrad/kernelize/kernelize.py b/tinygrad/kernelize/kernelize.py
index 87173581d8..0719fa206d 100644
--- a/tinygrad/kernelize/kernelize.py
+++ b/tinygrad/kernelize/kernelize.py
@@ -241,7 +241,7 @@ view_right = merge_views+PatternMatcher([
   # apply view after reduceops
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
   # apply view after elementwise ops
-  (UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER}, name="root"), elementwise_view_right),
+  (UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
   # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
    lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 3c299d53ca..4830df41cf 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -253,7 +253,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
   def r(self, op:Ops, axis:tuple[int, ...]):
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
-    return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+    if len(axis) == 0: return self
+    # move any non reduce axis before the first reduce axis
+    move_early = [i for i in range(axis[0], len(self.shape)) if i not in axis and resolve(self.shape[i] != 1)]
+    if move_early:
+      permute = tuple(range(axis[0])) + tuple(move_early) + tuple([i for i in range(axis[0], len(self.shape)) if i not in move_early])
+      ret = self.permute(permute)
+      new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
+      assert len(axis) == len(new_axis)
+    else:
+      ret, new_axis = self, axis
+    ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
+    return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
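
Note on the tinygrad/uop/ops.py hunk: the new UOp.r() keeps reduces in axis order by permuting any surviving non-reduce axis ahead of the first reduce axis, reducing over the now-trailing axes, and reshaping back to the original rank. Below is a minimal, dependency-free sketch of that bookkeeping on plain shape tuples; the helper name reorder_reduce and its tuple return are illustrative only (the real method builds permute/REDUCE_AXIS/reshape UOps), and it folds the patch's `if move_early:` fast path into the general case.

# Sketch (assumed helper, not tinygrad API): compute the permute, the
# post-permute reduce axes, and the output shape that UOp.r() builds.
def reorder_reduce(shape: tuple[int, ...], axis: tuple[int, ...]):
  # drop size-1 axes from the reduce, same as the patch
  axis = tuple(sorted(x for x in axis if shape[x] != 1))
  if len(axis) == 0: return None  # nothing to reduce
  # non-reduce axes of size > 1 sitting at or after the first reduce axis
  move_early = [i for i in range(axis[0], len(shape)) if i not in axis and shape[i] != 1]
  # move them ahead of the reduces (identity permute when move_early is
  # empty, the case the patch short-circuits with its else branch)
  permute = (tuple(range(axis[0])) + tuple(move_early)
             + tuple(i for i in range(axis[0], len(shape)) if i not in move_early))
  permuted_shape = tuple(shape[i] for i in permute)
  # after the permute, the reduce axes are the trailing size>1 axes, in order
  new_axis = tuple(x for x in range(axis[0]+len(move_early), len(shape)) if permuted_shape[x] != 1)
  assert len(axis) == len(new_axis)
  # the final reshape restores the original rank with reduced axes set to 1
  out_shape = tuple(1 if i in axis else x for i, x in enumerate(shape))
  return permute, new_axis, out_shape

# reducing axes (1, 3) of a (4, 5, 6, 7) tensor: axis 2 (size 6) is permuted
# ahead of the reduces, so REDUCE_AXIS sees its axes last and in order
print(reorder_reduce((4, 5, 6, 7), (1, 3)))  # ((0, 2, 1, 3), (2, 3), (4, 1, 6, 1))

After the permute, the reduce axes are contiguous, trailing, and increasing, which appears to be the invariant the subject line "force reduce to be in axis order" refers to; the trailing reshape keeps r()'s output shape unchanged for callers, as the updated reshape(4, 256, 1, 1) expectation in test_linearizer.py reflects.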