reorder contiguous/assign ast rules [pr] (#9420)

* apply setitem ShapeTracker when creating store [pr]

* comments + early contiguous remove

* better

* linter
This commit is contained in:
qazal
2025-03-12 13:13:27 +02:00
committed by GitHub
parent 5f6d5b057d
commit 12978f0d05

View File

@@ -275,8 +275,15 @@ add_buffer_ops = PatternMatcher([
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
# partial assign can store to a non-contiguous ShapeTracker
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
# otherwise the store is contiguous
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
])
# ** push views to buffer ops
@@ -317,9 +324,6 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
# push VIEW to children
view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# STORE is the last child, so we just merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
# push a non contiguous ShapeTracker through reduceop
@@ -361,9 +365,6 @@ def check_load_st(glbl:UOp, view:UOp):
fix_kernel_ops = PatternMatcher([
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
# remove CONTIGUOUS/DEVICE
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load