diff --git a/test/test_schedule.py b/test/test_schedule.py index d888bfadc3..e8fe0ab5f7 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -89,14 +89,20 @@ class TestSchedule(unittest.TestCase): with self.assertRaises(RecursionError): with Context(FUSE_ARANGE=1, NOOPT=0): self.test_arange_avgpool2d(kcount=1) - # grouper error - @unittest.expectedFailure + # when we're fusing a reduce, all ReduceOps must have the same N in the dimensions + # all permutes, reshapes, expands and shrinks push through the reduce def test_arange_sum(self): a = Tensor.arange(6).reshape(3, 2).sum(axis=1) with Context(FUSE_ARANGE=1): run_schedule(check_schedule(a, 1)) self.assertListEqual(a.tolist(), [1, 5, 9]) + def test_permute_arange(self): + a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1) + with Context(FUSE_ARANGE=1): + run_schedule(check_schedule(a, 1)) + self.assertListEqual(a.tolist(), [[15]]) + @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") def test_error_on_device_mismatch(self): a = Tensor.empty(10) diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 9e828380e4..47e06d3ece 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, can_pad, sint from tinygrad.uop.ops import track_rewrites, _substitute from tinygrad.uop.spec import type_verify, tensor_uop_spec -from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction +from tinygrad.codegen.lowerer import get_contraction_with_reduce from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP @@ -313,7 +313,9 @@ def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub Vi # change reduceop axes and input ShapeTrackers, view gets replaced with a reshape. def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False): # don't swizzle if we can push the view to children - if unwrap(view.st).contiguous and view.size == r.size: return None + if unwrap(view.st).contiguous and view.size == r.size and \ + (not (len(r.arg) == 3 and r.arg[2]) or tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1))): + return None # swizzle the input input_st = ShapeTracker.from_shape(src.shape) tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) @@ -331,10 +333,7 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False): def reduceop_view_right(src:UOp, v:UOp, r:UOp): assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}" - if (contraction:=get_contraction(v.shape, src.shape)) is None: return None - new_axis: list[int] = [] - for i,pairs in enumerate(contraction): - if any(x in r.axis_arg for x in pairs): new_axis.append(i) + new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u] return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape) def elementwise_view_right(root:UOp):