test removing expander rules [pr] (#13994)

This commit is contained in:
chenyu
2026-01-03 12:38:01 -05:00
committed by GitHub
parent 35c2870b1f
commit c1b8644a3f

View File

@@ -1,7 +1,7 @@
# this converts a lowerer program into a vectorized program
import functools, itertools
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition
from tinygrad.helpers import dedup, flatten, all_same, prod, partition
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
from tinygrad.schedule.rangeify import BufferizeOpts
@@ -82,7 +82,7 @@ def end_unrolls(u:UOp):
return u.replace(src=(ret,)+tuple(src))
expander = PatternMatcher([
# push broadcast through AFTER
# push broadcast through AFTER/END
(UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
(UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
# END on UNROLL ends the UNROLL
@@ -97,14 +97,8 @@ expander = PatternMatcher([
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
# empty UNROLL is NOOP
(UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
# UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
(UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
# ****