mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -190,7 +190,6 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert stores[1].src[1].dtype == dtypes.float
|
||||
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT=="CPU", "CPU splits the cat so cant upcast")
|
||||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = Tensor.stack(a, b)
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_f
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads
|
||||
from tinygrad.codegen.opt.postrange import apply_opts, make_images
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
|
||||
@@ -50,9 +50,6 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
# optimize (schedule) the AST
|
||||
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
|
||||
|
||||
# split store range (only on CPU for now)
|
||||
sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges")
|
||||
|
||||
# create image buffers
|
||||
sink = make_images(sink, ren)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import partition, dedup
|
||||
from tinygrad.helpers import partition
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
|
||||
def flatten_range(r:UOp) -> UOp|None:
|
||||
@@ -147,16 +147,3 @@ pm_load_collapse = PatternMatcher([
|
||||
# we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
|
||||
((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
|
||||
])
|
||||
|
||||
def cut_store_range(ctx:str, store:UOp, r:UOp) -> UOp|None:
|
||||
# only cut ranges on CPU for now
|
||||
if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
|
||||
if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
|
||||
cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
|
||||
ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
|
||||
|
||||
return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])
|
||||
|
||||
pm_split_store = pm_flatten_range+PatternMatcher([
|
||||
(UPat(Ops.END, src=(UPat(Ops.STORE, name="store"), UPat.var("r"))), cut_store_range),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user