mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
early reduce simplify (#12046)
* early reduce simplify * min changes * need that * that goes in simplify * no more arange reduce opt
This commit is contained in:
@@ -327,13 +327,14 @@ class TestKernelOpts(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_arange_opts(self):
|
||||
a = Tensor.arange(128)
|
||||
# NOTE: arange no longer has reduce ops available for opt
|
||||
helper_linearizer_opt(a, [
|
||||
[Opt(OptOps.GROUP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
#[Opt(OptOps.GROUP, 0, 32)],
|
||||
#[Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(op=OptOps.LOCAL, axis=0, arg=8)],
|
||||
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)],
|
||||
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)],
|
||||
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
|
||||
#[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)],
|
||||
#[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_threads, "test requires threads")
|
||||
|
||||
@@ -3,7 +3,6 @@ import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
|
||||
from tinygrad.helpers import CI, Context, getenv
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
@@ -31,6 +30,9 @@ class TestArange(unittest.TestCase):
|
||||
# PTX counts index ALU in flops
|
||||
assert f1 <= limit, f"{f1=}, {limit=}"
|
||||
|
||||
# reduce collapse now happens before optimizations
|
||||
"""
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0)
|
||||
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0)
|
||||
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0)
|
||||
@@ -47,6 +49,7 @@ class TestArange(unittest.TestCase):
|
||||
def test_complexity_w_local_unroll4(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UNROLL, 0, 4)], limit=0)
|
||||
@unittest.skip("doesn't work yet")
|
||||
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])
|
||||
"""
|
||||
|
||||
class TestRand(unittest.TestCase):
|
||||
def test_fused_rand_less_ops(self, noopt=1):
|
||||
|
||||
@@ -47,7 +47,8 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
c10 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
|
||||
c11 = c1.store((c4.alu(Ops.CMPNE, c7).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c8)).cast(dtypes.int)*(c9.f(Ops.VALID, dtype=dtypes.bool).where(UOp.const(dtypes.int, -1, src=c10), UOp.const(dtypes.int, 0, src=c10)).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (1,)))+UOp.const(dtypes.int, 1000, src=c8))))
|
||||
ast = c11.sink()
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
|
||||
#opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=0, arg=8)]
|
||||
prg = get_program(ast, Device[Device.DEFAULT].renderer, opts)
|
||||
print(prg.src)
|
||||
assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
|
||||
|
||||
@@ -23,7 +23,6 @@ class TestLinearizerRewrite(unittest.TestCase):
|
||||
si = out.schedule()[-1]
|
||||
opts_to_apply = []
|
||||
opts_to_apply.append(Opt(OptOps.UPCAST, 0, 4))
|
||||
opts_to_apply.append(Opt(OptOps.UNROLL, 0, 4))
|
||||
ast = si.ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
|
||||
prg = get_program(ast, Device["CPU"].renderer)
|
||||
print(prg.src)
|
||||
|
||||
@@ -14,7 +14,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, cast_foldin
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
ReduceContext, correct_load_store, pm_render, pm_reduce_simplify
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
@@ -66,6 +66,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
|
||||
# optimize (schedule) the AST
|
||||
ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges"))
|
||||
ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces"))
|
||||
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
|
||||
@@ -377,13 +377,16 @@ def reduce_unparented(red:UOp):
|
||||
return ret
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# remove any ranges from a REDUCE that aren't referenced in the reduce source
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
|
||||
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
|
||||
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
||||
# tensor core built in accumulate
|
||||
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
])+sym
|
||||
|
||||
pm_reduce_simplify = PatternMatcher([
|
||||
# remove any ranges from a REDUCE that aren't referenced in the reduce source
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
|
||||
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
|
||||
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user