merge_views for buffer ops + create valids last (#9472)

* merge_views for buffer ops + create valids last

* view.arg

* pass
This commit is contained in:
qazal
2025-03-17 17:15:44 +08:00
committed by GitHub
parent bd1f71c1e2
commit 813f713edc
3 changed files with 7 additions and 10 deletions

View File

@@ -98,7 +98,6 @@ class TestSchedule(unittest.TestCase):
a.realize()
assert not a.lazydata.is_realized
@unittest.expectedFailure
def test_simplify_padded_const(self):
a = Tensor.empty(1022).cummax(axis=0)
sched = check_schedule(a, 5)

View File

@@ -345,6 +345,8 @@ add_buffer_ops = PatternMatcher([
# otherwise the store is contiguous
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
# VALID
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)),
# if the last child is a VIEW we merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
@@ -366,8 +368,6 @@ def check_load_st(glbl:UOp, view:UOp):
fix_kernel_ops = PatternMatcher([
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
@@ -385,7 +385,7 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# add buffer ops + fix_kernel_ops
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# create subbuffer (TODO: this does not belong here)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)

View File

@@ -980,6 +980,9 @@ merge_views = PatternMatcher([
# merge unmasked const views
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
# merge view on load/store/valid
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
# remove view if it's a contiguous and the shapes match
(UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
# remove mask if there's a zero in the masked dim
@@ -989,13 +992,8 @@ merge_views = PatternMatcher([
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
])
# push VIEW to parents
# view before elementwise ops
view_left = merge_views+PatternMatcher([
# VIEW(CONST) becomes VALID
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
# VIEW before elementwise/buffer ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])