split cat (on cpu) (#12864)

* split ranges but only on cpu

* except KernelOptError for threads

* use GROUP and END

* no more flatten_range needed

* remove noop end

* always process replay for openpilot

* update test

* skip test

* fix in outs calculation

With the new linearizer the toposort is a problem, this matches the spec
now

* undo that
This commit is contained in:
Sieds Lykles
2025-10-28 07:55:19 +01:00
committed by GitHub
parent 3b82dee625
commit e110f4632a
5 changed files with 22 additions and 3 deletions

View File

@@ -155,6 +155,7 @@ class TestLinearizer(unittest.TestCase):
assert stores[1].src[1].dtype == dtypes.float
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())
@unittest.skipIf(Device.DEFAULT=="CPU", "CPU splits the cat so cant upcast")
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)

View File

@@ -15,7 +15,7 @@ from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_ex
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@@ -43,6 +43,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# optimize (schedule) the AST
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
# split store range (only on CPU for now)
sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges")
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren)

View File

@@ -181,7 +181,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
if threads > k.ren.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
for axis in k.axes_of(AxisType.LOOP):
if k.full_shape[axis] % threads == 0:
k.apply_opt(Opt(OptOps.THREAD, axis, threads))
try: k.apply_opt(Opt(OptOps.THREAD, axis, threads))
except KernelOptError: pass
break
if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break

View File

@@ -1,6 +1,6 @@
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
from tinygrad.uop.symbolic import symbolic_flat
from tinygrad.helpers import partition
from tinygrad.helpers import partition, dedup
from tinygrad.dtype import dtypes
def flatten_range(r:UOp):
@@ -136,3 +136,16 @@ pm_load_collapse = PatternMatcher([
# we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
])
def cut_store_range(ctx, store:UOp, r:UOp):
# only cut ranges on CPU for now
if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])
pm_split_store = pm_flatten_range+PatternMatcher([
(UPat(Ops.END, src=(UPat(Ops.STORE, name="store"), UPat.var("r"))), cut_store_range),
])

View File

@@ -550,6 +550,7 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
(UPat(Ops.END, src=(UPat(Ops.NOOP, name="noop"),), allow_any_len=True), lambda noop:noop),
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)