early reduce simplify (#12046)

* early reduce simplify

* min changes

* need that

* that goes in simplify

* no more arange reduce opt
This commit is contained in:
George Hotz
2025-09-10 21:02:46 +08:00
committed by GitHub
parent 21e6926a6a
commit 9789337722
6 changed files with 20 additions and 12 deletions

View File

@@ -327,13 +327,14 @@ class TestKernelOpts(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_arange_opts(self):
a = Tensor.arange(128)
# NOTE: arange no longer has reduce ops available for opt
helper_linearizer_opt(a, [
[Opt(OptOps.GROUP, 0, 32)],
[Opt(OptOps.GROUPTOP, 0, 32)],
#[Opt(OptOps.GROUP, 0, 32)],
#[Opt(OptOps.GROUPTOP, 0, 32)],
[Opt(op=OptOps.LOCAL, axis=0, arg=8)],
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)],
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)],
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
#[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)],
#[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_threads, "test requires threads")

View File

@@ -3,7 +3,6 @@ import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
from tinygrad.helpers import CI, Context, getenv
from tinygrad.engine.realize import run_schedule
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.uop.ops import Ops
from tinygrad.renderer.ptx import PTXRenderer
@@ -31,6 +30,9 @@ class TestArange(unittest.TestCase):
# PTX counts index ALU in flops
assert f1 <= limit, f"{f1=}, {limit=}"
# reduce collapse now happens before optimizations
"""
from tinygrad.codegen.opt import Opt, OptOps
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0)
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0)
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0)
@@ -47,6 +49,7 @@ class TestArange(unittest.TestCase):
def test_complexity_w_local_unroll4(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UNROLL, 0, 4)], limit=0)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])
"""
class TestRand(unittest.TestCase):
def test_fused_rand_less_ops(self, noopt=1):

View File

@@ -47,7 +47,8 @@ class TestLinearizerDumb(unittest.TestCase):
c10 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
c11 = c1.store((c4.alu(Ops.CMPNE, c7).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c8)).cast(dtypes.int)*(c9.f(Ops.VALID, dtype=dtypes.bool).where(UOp.const(dtypes.int, -1, src=c10), UOp.const(dtypes.int, 0, src=c10)).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (1,)))+UOp.const(dtypes.int, 1000, src=c8))))
ast = c11.sink()
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
#opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
opts = [Opt(op=OptOps.LOCAL, axis=0, arg=8)]
prg = get_program(ast, Device[Device.DEFAULT].renderer, opts)
print(prg.src)
assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"

View File

@@ -23,7 +23,6 @@ class TestLinearizerRewrite(unittest.TestCase):
si = out.schedule()[-1]
opts_to_apply = []
opts_to_apply.append(Opt(OptOps.UPCAST, 0, 4))
opts_to_apply.append(Opt(OptOps.UNROLL, 0, 4))
ast = si.ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
prg = get_program(ast, Device["CPU"].renderer)
print(prg.src)

View File

@@ -14,7 +14,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, cast_foldin
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
ReduceContext, correct_load_store, pm_render, pm_reduce_simplify
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt
@@ -66,6 +66,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
# optimize (schedule) the AST
ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges"))
ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces"))
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
# ** expander (expand_rewrite) **

View File

@@ -377,13 +377,16 @@ def reduce_unparented(red:UOp):
return ret
pm_reduce = PatternMatcher([
# remove any ranges from a REDUCE that aren't referenced in the reduce source
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
# REDUCE -> DEFINE_ACC+ASSIGN
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym
pm_reduce_simplify = PatternMatcher([
# remove any ranges from a REDUCE that aren't referenced in the reduce source
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
])