remove cut_store_range (#14505)

special scheduling for CPU
This commit is contained in:
chenyu
2026-02-02 21:58:36 -05:00
committed by GitHub
parent 4f2e7aed24
commit 3c5845e8a5
3 changed files with 2 additions and 19 deletions

View File

@@ -190,7 +190,6 @@ class TestLinearizer(unittest.TestCase):
assert stores[1].src[1].dtype == dtypes.float
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())
@unittest.skipIf(Device.DEFAULT=="CPU", "CPU splits the cat so cant upcast")
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)

View File

@@ -16,7 +16,7 @@ from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_f
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads
from tinygrad.codegen.opt.postrange import apply_opts, make_images
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@@ -50,9 +50,6 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# optimize (schedule) the AST
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
# split store range (only on CPU for now)
sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges")
# create image buffers
sink = make_images(sink, ren)

View File

@@ -1,7 +1,7 @@
import itertools
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import partition, dedup
from tinygrad.helpers import partition
from tinygrad.dtype import dtypes, ImageDType
def flatten_range(r:UOp) -> UOp|None:
@@ -147,16 +147,3 @@ pm_load_collapse = PatternMatcher([
# we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
])
def cut_store_range(ctx:str, store:UOp, r:UOp) -> UOp|None:
# only cut ranges on CPU for now
if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])
pm_split_store = pm_flatten_range+PatternMatcher([
(UPat(Ops.END, src=(UPat(Ops.STORE, name="store"), UPat.var("r"))), cut_store_range),
])