expand DEFINE_ACC [pr] (#7461)

This commit is contained in:
George Hotz
2024-11-01 14:20:43 +07:00
committed by GitHub
parent d9f38f9518
commit 4f6cf1f8cc

View File

@@ -414,9 +414,11 @@ expander = PatternMatcher([
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX,
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, UOps.ASSIGN,
UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(UPat(UOps.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
(UPat(UOps.VECTORIZE, src=UPat(UOps.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
# remove EXPANDs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
@@ -516,12 +518,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# initial symbolic + migrate indexing (remove this) + transcendental
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
# expand
sink = graph_rewrite(sink, sym+expander)
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
# expand
sink = graph_rewrite(sink, sym+expander)
# devectorize + load_store_indexing
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)