mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix gfr
This commit is contained in:
@@ -12,7 +12,7 @@ from tinygrad.codegen.quantize import pm_quant
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
|
||||
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
@@ -74,14 +74,14 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))
|
||||
|
||||
# expand
|
||||
ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))
|
||||
ret.append(RewriteStep(sym+migrate_indexing+pm_group_for_reduce, name="postopt symbolic"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
|
||||
|
||||
# expand
|
||||
ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))
|
||||
|
||||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
|
||||
|
||||
@@ -157,6 +157,9 @@ pm_pre_expander = PatternMatcher([
|
||||
# fix REDUCEs with UNROLLs
|
||||
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
|
||||
(UPat(Ops.STORE, name="x"), fix_store_unroll),
|
||||
])
|
||||
|
||||
pm_group_for_reduce = PatternMatcher([
|
||||
# fix group for reduce
|
||||
(UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
|
||||
])
|
||||
])
|
||||
@@ -478,7 +478,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
|
||||
# TODO: this should be the only one of these. this is the one RANGEIFY uses
|
||||
s = self
|
||||
while len(s.src) and s.op not in {Ops.BUFFER, Ops.MSTACK}: s = s.src[0]
|
||||
while len(s.src) and s.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
|
||||
return s
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user