From 4f6cf1f8cc13d4e8660522595bec68c377bff371 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:20:43 +0700 Subject: [PATCH] expand DEFINE_ACC [pr] (#7461) --- tinygrad/codegen/uopgraph.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 30c4322808..1f007b0262 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -414,9 +414,11 @@ expander = PatternMatcher([ (UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)), lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), # do expansion - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, UOps.ASSIGN, UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand), (UPat(UOps.CONTRACT, name="con"), do_contract), + # vectorize DEFINE_ACC + (UPat(UOps.VECTORIZE, src=UPat(UOps.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)), # remove EXPANDs from SINK (UPat(UOps.SINK, name="root"), lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) @@ -516,12 +518,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: # initial symbolic + migrate indexing (remove this) + transcendental sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)) - # expand - sink = graph_rewrite(sink, sym+expander) - # convert REDUCE to DEFINE_ACC + ASSIGN (contextual) sink = graph_rewrite(sink, sym+just_reduce, ctx=[0]) + # expand + sink = graph_rewrite(sink, sym+expander) + # devectorize + load_store_indexing sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)