mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
push permutes through fused reduces (#10628)
* fix pushing reshapes through reduceops * reduceop_view_right should assert on ndims mismatch * update that, view.reshape asserts it
This commit is contained in:
@@ -89,14 +89,20 @@ class TestSchedule(unittest.TestCase):
|
||||
with self.assertRaises(RecursionError):
|
||||
with Context(FUSE_ARANGE=1, NOOPT=0): self.test_arange_avgpool2d(kcount=1)
|
||||
|
||||
# grouper error
|
||||
@unittest.expectedFailure
|
||||
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
|
||||
# all permutes, reshapes, expands and shrinks push through the reduce
|
||||
def test_arange_sum(self):
|
||||
a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [1, 5, 9])
|
||||
|
||||
def test_permute_arange(self):
|
||||
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [[15]])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
|
||||
def test_error_on_device_mismatch(self):
|
||||
a = Tensor.empty(10)
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, can_pad, sint
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
|
||||
from tinygrad.codegen.lowerer import get_contraction_with_reduce
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
|
||||
@@ -313,7 +313,9 @@ def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub Vi
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
# don't swizzle if we can push the view to children
|
||||
if unwrap(view.st).contiguous and view.size == r.size: return None
|
||||
if unwrap(view.st).contiguous and view.size == r.size and \
|
||||
(not (len(r.arg) == 3 and r.arg[2]) or tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1))):
|
||||
return None
|
||||
# swizzle the input
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
@@ -331,10 +333,7 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
if (contraction:=get_contraction(v.shape, src.shape)) is None: return None
|
||||
new_axis: list[int] = []
|
||||
for i,pairs in enumerate(contraction):
|
||||
if any(x in r.axis_arg for x in pairs): new_axis.append(i)
|
||||
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
|
||||
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
def elementwise_view_right(root:UOp):
|
||||
|
||||
Reference in New Issue
Block a user