This commit is contained in:
George Hotz
2025-10-02 12:58:31 +08:00
parent 3da569c20b
commit 3fd25a425b
3 changed files with 10 additions and 7 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
@@ -74,14 +74,14 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))
# expand
ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))
ret.append(RewriteStep(sym+migrate_indexing+pm_group_for_reduce, name="postopt symbolic"))
# add locals
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
# expand
ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))
# ** devectorizer (full_graph_rewrite) **
# remove reduce
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

View File

@@ -157,6 +157,9 @@ pm_pre_expander = PatternMatcher([
# fix REDUCEs with UNROLLs
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
(UPat(Ops.STORE, name="x"), fix_store_unroll),
])
# Standalone matcher holding only the group-for-reduce rule, so it can be composed
# into a rewrite step on its own rather than as part of pm_pre_expander.
pm_group_for_reduce = PatternMatcher([
# fix group for reduce
# the pattern matches on Ops.REDUCE (named "x") and is paired with fix_group_for_reduce
# NOTE(review): fix_group_for_reduce is defined outside this view — confirm what it rewrites to
(UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
])
])

View File

@@ -478,7 +478,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
# TODO: this should be the only one of these. this is the one RANGEIFY uses
s = self
while len(s.src) and s.op not in {Ops.BUFFER, Ops.MSTACK}: s = s.src[0]
while len(s.src) and s.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
return s
@property