diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index a3eb42bf0d..b12dc147e7 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -1,7 +1,7 @@ # this converts a lowerer program into a vectorized program import functools, itertools from tinygrad.dtype import dtypes, PtrDType, AddrSpace -from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition +from tinygrad.helpers import dedup, flatten, all_same, prod, partition from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start from tinygrad.schedule.rangeify import BufferizeOpts @@ -82,7 +82,7 @@ def end_unrolls(u:UOp): return u.replace(src=(ret,)+tuple(src)) expander = PatternMatcher([ - # push broadcast through AFTER + # push broadcast through AFTER/END (UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))), (UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))), # END on UNROLL ends the UNROLL @@ -97,14 +97,8 @@ expander = PatternMatcher([ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE, Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand), (UPat(Ops.CONTRACT, name="con"), do_contract), - # BARRIERs aren't actually expanded - (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)), - lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)), # empty UNROLL is NOOP (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x), - # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU - (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))), - lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)), ]) # ****