mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
test removing expander rules [pr] (#13994)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
# this converts a lowerer program into a vectorized program
|
# this converts a lowerer program into a vectorized program
|
||||||
import functools, itertools
|
import functools, itertools
|
||||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
|
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
|
||||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition
|
from tinygrad.helpers import dedup, flatten, all_same, prod, partition
|
||||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
|
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
|
||||||
from tinygrad.schedule.rangeify import BufferizeOpts
|
from tinygrad.schedule.rangeify import BufferizeOpts
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ def end_unrolls(u:UOp):
|
|||||||
return u.replace(src=(ret,)+tuple(src))
|
return u.replace(src=(ret,)+tuple(src))
|
||||||
|
|
||||||
expander = PatternMatcher([
|
expander = PatternMatcher([
|
||||||
# push broadcast through AFTER
|
# push broadcast through AFTER/END
|
||||||
(UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
|
(UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
|
||||||
(UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
|
(UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
|
||||||
# END on UNROLL ends the UNROLL
|
# END on UNROLL ends the UNROLL
|
||||||
@@ -97,14 +97,8 @@ expander = PatternMatcher([
|
|||||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||||
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||||
# BARRIERs aren't actually expanded
|
|
||||||
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
|
|
||||||
lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
|
|
||||||
# empty UNROLL is NOOP
|
# empty UNROLL is NOOP
|
||||||
(UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
|
(UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
|
||||||
# UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
|
|
||||||
(UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
|
|
||||||
lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
# ****
|
# ****
|
||||||
|
|||||||
Reference in New Issue
Block a user