mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
merge_views for buffer ops + create valids last (#9472)
* merge_views for buffer ops + create valids last * view.arg * pass
This commit is contained in:
@@ -98,7 +98,6 @@ class TestSchedule(unittest.TestCase):
|
||||
a.realize()
|
||||
assert not a.lazydata.is_realized
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_simplify_padded_const(self):
|
||||
a = Tensor.empty(1022).cummax(axis=0)
|
||||
sched = check_schedule(a, 5)
|
||||
|
||||
@@ -345,6 +345,8 @@ add_buffer_ops = PatternMatcher([
|
||||
# otherwise the store is contiguous
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)),
|
||||
# if the last child is a VIEW we merge the ShapeTrackers and store the base
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
|
||||
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
|
||||
@@ -366,8 +368,6 @@ def check_load_st(glbl:UOp, view:UOp):
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# BIND in shapetracker becomes DEFINE_VAR
|
||||
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
|
||||
# remove unmasked valid
|
||||
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
|
||||
# no ImageDType after load
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
@@ -385,7 +385,7 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
|
||||
# unbind_vars + push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
|
||||
# add buffer ops + fix_kernel_ops
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
|
||||
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
# create subbuffer (TODO: this does not belong here)
|
||||
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
|
||||
@@ -980,6 +980,9 @@ merge_views = PatternMatcher([
|
||||
# merge unmasked const views
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
|
||||
lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
|
||||
# merge view on load/store/valid
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
|
||||
lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
# remove view if it's a contiguous and the shapes match
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
|
||||
# remove mask if there's a zero in the masked dim
|
||||
@@ -989,13 +992,8 @@ merge_views = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
|
||||
])
|
||||
|
||||
# push VIEW to parents
|
||||
# view before elementwise ops
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# VIEW(CONST) becomes VALID
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
|
||||
# VIEW before elementwise/buffer ops
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
|
||||
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
|
||||
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user