push permutes through fused reduces (#10628)

* fix pushing reshapes through reduceops

* reduceop_view_right should assert on ndims mismatch

* update that, view.reshape asserts it
This commit is contained in:
qazal
2025-06-05 16:14:04 +03:00
committed by GitHub
parent 8db0ba1161
commit 8c5ea00522
2 changed files with 13 additions and 8 deletions

View File

@@ -89,14 +89,20 @@ class TestSchedule(unittest.TestCase):
with self.assertRaises(RecursionError):
with Context(FUSE_ARANGE=1, NOOPT=0): self.test_arange_avgpool2d(kcount=1)
# grouper error
@unittest.expectedFailure
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
# all permutes, reshapes, expands and shrinks push through the reduce
def test_arange_sum(self):
a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(a, 1))
self.assertListEqual(a.tolist(), [1, 5, 9])
def test_permute_arange(self):
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(a, 1))
self.assertListEqual(a.tolist(), [[15]])
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
def test_error_on_device_mismatch(self):
a = Tensor.empty(10)

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, can_pad, sint
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
from tinygrad.codegen.lowerer import get_contraction_with_reduce
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
@@ -313,7 +313,9 @@ def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub Vi
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
# don't swizzle if we can push the view to children
if unwrap(view.st).contiguous and view.size == r.size: return None
if unwrap(view.st).contiguous and view.size == r.size and \
(not (len(r.arg) == 3 and r.arg[2]) or tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1))):
return None
# swizzle the input
input_st = ShapeTracker.from_shape(src.shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
@@ -331,10 +333,7 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
if (contraction:=get_contraction(v.shape, src.shape)) is None: return None
new_axis: list[int] = []
for i,pairs in enumerate(contraction):
if any(x in r.axis_arg for x in pairs): new_axis.append(i)
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
def elementwise_view_right(root:UOp):