split cat (on cpu) (#12864)
* split ranges but only on cpu
* except KernelOptError for threads
* use GROUP and END
* no more flatten_range needed
* remove noop end
* always process replay for openpilot
* update test
* skip test
* fix in outs calculation (with the new linearizer the toposort is a problem; this matches the spec now)
* undo that
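As a minimal sketch of the behavior being changed (assuming a tinygrad install; the sizes are illustrative, not from the diff): a cat produces one kernel whose output store covers the full range, and on CPU the new pm_split_store pass cuts that range at the mask boundaries, roughly one subrange per source.

    from tinygrad import Tensor

    # two sources of sizes 4 and 6; the cat's output store spans range(10)
    a, b = Tensor.randn(4), Tensor.randn(6)
    c = a.cat(b)
    # on CPU, pm_split_store cuts the store at the CMPLT constant 4 coming
    # from the cat's validity mask, giving stores over range(4) and range(6)
    print(c.numpy().shape)  # (10,)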
@@ -155,6 +155,7 @@ class TestLinearizer(unittest.TestCase):
     assert stores[1].src[1].dtype == dtypes.float
     assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())
 
+  @unittest.skipIf(Device.DEFAULT=="CPU", "CPU splits the cat so cant upcast")
   def test_zero_fold(self):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = Tensor.stack(a, b)
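The skip is needed because Tensor.stack builds its result with a cat, so on CPU the split above replaces the single fused store this test upcasts with one store per stacked tensor. A quick sketch of what the test constructs (assuming tinygrad):

    from tinygrad import Tensor

    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = Tensor.stack(a, b)  # unsqueeze + cat along a new axis
    print(r.shape)          # (2, 1)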
@@ -15,7 +15,7 @@ from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_ex
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
   ReduceContext, correct_load_store, pm_render
 from tinygrad.codegen.opt.postrange import apply_opts
-from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
+from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@@ -43,6 +43,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
   # optimize (schedule) the AST
   sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
 
+  # split store range (only on CPU for now)
+  sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges")
+
   # do postrange optimization, BEAM or hand_coded_optimizations
   sink = apply_opts(sink, ren)
 
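For context on the ctx plumbing: graph_rewrite forwards its ctx argument (here ren.device) to rules whose first parameter is named ctx, which is how cut_store_range further down sees the device string. A self-contained sketch with stand-in types (run_pass and the tuple "node" are stubs, not the real PatternMatcher dispatch):

    def cut_store_range(ctx, store, r):
        # mirrors the real rule's gate: a no-op on anything but CPU
        if ctx != "CPU": return None
        return ("split", store, r)

    def run_pass(node, rule, ctx):  # stand-in for graph_rewrite(sink, pm, ctx=...)
        return rule(ctx, *node) or node

    print(run_pass(("store", "range"), cut_store_range, "CPU"))   # ('split', 'store', 'range')
    print(run_pass(("store", "range"), cut_store_range, "CUDA"))  # ('store', 'range')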
@@ -181,7 +181,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
     if threads > k.ren.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
     for axis in k.axes_of(AxisType.LOOP):
       if k.full_shape[axis] % threads == 0:
-        k.apply_opt(Opt(OptOps.THREAD, axis, threads))
+        try: k.apply_opt(Opt(OptOps.THREAD, axis, threads))
+        except KernelOptError: pass
         break
     if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break
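Why the try/except: a THREAD opt that used to apply can now raise KernelOptError, and swallowing it lets the outer loop fall through to another candidate thread count instead of crashing (the trailing check only breaks out if the opt actually landed). A self-contained sketch of that retry shape (KernelOptError and apply_thread_opt here are stand-ins, not the real tinygrad classes):

    class KernelOptError(Exception): pass

    def apply_thread_opt(threads: int) -> int:
        if threads > 8: raise KernelOptError(f"{threads} threads not supported")
        return threads

    chosen = None
    for threads in (32, 16, 8):        # candidate thread counts (illustrative)
        try: chosen = apply_thread_opt(threads)
        except KernelOptError: pass    # invalid opt: keep trying other candidates
        if chosen is not None: break
    print(chosen)  # 8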
@@ -1,6 +1,6 @@
 from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
 from tinygrad.uop.symbolic import symbolic_flat
-from tinygrad.helpers import partition
+from tinygrad.helpers import partition, dedup
 from tinygrad.dtype import dtypes
 
 def flatten_range(r:UOp):
@@ -136,3 +136,16 @@ pm_load_collapse = PatternMatcher([
   # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
   ((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
 ])
+
+def cut_store_range(ctx, store:UOp, r:UOp):
+  # only cut ranges on CPU for now
+  if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
+  if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
+  cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
+  ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
+
+  return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])
+
+pm_split_store = pm_flatten_range+PatternMatcher([
+  (UPat(Ops.END, src=(UPat(Ops.STORE, name="store"), UPat.var("r"))), cut_store_range),
+])
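A worked example of the cut arithmetic in cut_store_range, with plain ints standing in for UOps: a store whose range bound is the CONST 10, compared (CMPLT) against the constants 4 and 7, splits into the subranges [0,4), [4,7), and [7,10), each re-indexed by its start offset:

    end = 10                                       # r.src[0].arg, the CONST range bound
    cmplt_consts = [4, 7]                          # cuts gathered from CMPLT consumers
    cuts = sorted(set([0] + cmplt_consts + [end])) # dedup([0] + cuts + [r.src[0].arg])
    for i, (start, stop) in enumerate(zip(cuts[:-1], cuts[1:])):
        # each span becomes UOp.range(stop-start, ...); the store is rewritten
        # with {r: new_r + start}, then ended on the new range
        print(f"range {i}: size {stop-start}, index offset {start}")
    # range 0: size 4, index offset 0
    # range 1: size 3, index offset 4
    # range 2: size 3, index offset 7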
@@ -550,6 +550,7 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
   (UPat(Ops.SINK, name="root"),
     lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
     if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
+  (UPat(Ops.END, src=(UPat(Ops.NOOP, name="noop"),), allow_any_len=True), lambda noop:noop),
   ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
   ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
   ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
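The context rules around the new END(NOOP) line encode simple reciprocal identities; a quick numeric spot-check of the two commented ones (pure Python; a tolerance is used because float division is not exact):

    x, c = 3.0, 5.0
    assert abs(1/(x*x) - (1/x)*(1/x)) < 1e-12   # 1/(x^c) -> (1/x)^c, with c=2
    assert abs(1/(x*c) - (1/x)*(1/c)) < 1e-12   # 1/(x*c) -> (1/c)*(1/x)
    print("reciprocal rewrites check out")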