mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
expand DEFINE_ACC [pr] (#7461)
This commit is contained in:
@@ -414,9 +414,11 @@ expander = PatternMatcher([
|
||||
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
|
||||
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX,
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, UOps.ASSIGN,
|
||||
UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
|
||||
(UPat(UOps.CONTRACT, name="con"), do_contract),
|
||||
# vectorize DEFINE_ACC
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
|
||||
# remove EXPANDs from SINK
|
||||
(UPat(UOps.SINK, name="root"),
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
|
||||
@@ -516,12 +518,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
# initial symbolic + migrate indexing (remove this) + transcendental
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
|
||||
# devectorize + load_store_indexing
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user