mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
test removing expander rules [pr] (#13994)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# this converts a lowerer program into a vectorized program
|
||||
import functools, itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition
|
||||
from tinygrad.helpers import dedup, flatten, all_same, prod, partition
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
|
||||
from tinygrad.schedule.rangeify import BufferizeOpts
|
||||
|
||||
@@ -82,7 +82,7 @@ def end_unrolls(u:UOp):
|
||||
return u.replace(src=(ret,)+tuple(src))
|
||||
|
||||
expander = PatternMatcher([
|
||||
# push broadcast through AFTER
|
||||
# push broadcast through AFTER/END
|
||||
(UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
|
||||
(UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
|
||||
# END on UNROLL ends the UNROLL
|
||||
@@ -97,14 +97,8 @@ expander = PatternMatcher([
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# BARRIERs aren't actually expanded
|
||||
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
|
||||
lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
|
||||
# empty UNROLL is NOOP
|
||||
(UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
|
||||
# UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
|
||||
(UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
|
||||
lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
|
||||
])
|
||||
|
||||
# ****
|
||||
|
||||
Reference in New Issue
Block a user