enable rangeify const folding (#12181)

This commit is contained in:
George Hotz
2025-09-15 12:02:19 +08:00
committed by GitHub
parent 1353250b6c
commit 9fcc87761e

View File

@@ -3,7 +3,7 @@ import functools, operator
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map
from tinygrad.uop.symbolic import sym
from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.multi import multi_pm
@@ -391,6 +391,8 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# no buffers for const
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape)),
# if any CONST with DEVICE make it here (symbolic/copy issue), remove it
(UPat(Ops.DEVICE).f(Ops.CONST, name="c"), lambda c: c.replace(src=())),
])
# *****************
@@ -556,7 +558,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# rangeify
tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify")
#tsink = graph_rewrite(tsink, sym, name="symbolic") # this supports const folding
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph