diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 3b15a4bdc4..c7258e2947 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -190,7 +190,6 @@ class TestLinearizer(unittest.TestCase): assert stores[1].src[1].dtype == dtypes.float assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort()) - @unittest.skipIf(Device.DEFAULT=="CPU", "CPU splits the cat so cant upcast") def test_zero_fold(self): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() r = Tensor.stack(a, b) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index b3273b481f..d15b22a96f 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -16,7 +16,7 @@ from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_f from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads from tinygrad.codegen.opt.postrange import apply_opts, make_images -from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store +from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize @@ -50,9 +50,6 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # optimize (schedule) the AST sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges") - # split store range (only on CPU for now) - sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges") - # create image buffers sink = make_images(sink, ren) diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index b81743e72e..6a7e1f43d7 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -1,7 +1,7 @@ import itertools from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start from tinygrad.uop.symbolic import symbolic -from tinygrad.helpers import partition, dedup +from tinygrad.helpers import partition from tinygrad.dtype import dtypes, ImageDType def flatten_range(r:UOp) -> UOp|None: @@ -147,16 +147,3 @@ pm_load_collapse = PatternMatcher([ # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse ((UPat.var("x", dtypes.index)+UPat.var("y")) UOp|None: - # only cut ranges on CPU for now - if r.src[0].op is not Ops.CONST or ctx!="CPU": return None - if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None - cuts = sorted(dedup([0] + cuts + [r.src[0].arg])) - ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))] - - return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])]) - -pm_split_store = pm_flatten_range+PatternMatcher([ - (UPat(Ops.END, src=(UPat(Ops.STORE, name="store"), UPat.var("r"))), cut_store_range), -])