mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
reorder contiguous/assign ast rules [pr] (#9420)
* apply setitem ShapeTracker when creating store [pr] * comments + early contiguous remove * better * linter
This commit is contained in:
@@ -275,8 +275,15 @@ add_buffer_ops = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
|
||||
# STORE (except for COPY/BUFFER_VIEW)
|
||||
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
||||
# partial assign can store to a non-contiguous ShapeTracker
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
|
||||
# otherwise the store is contiguous
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
])
|
||||
|
||||
# ** push views to buffer ops
|
||||
@@ -317,9 +324,6 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
|
||||
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
|
||||
# STORE is the last child, so we just merge the ShapeTrackers and store the base
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
@@ -361,9 +365,6 @@ def check_load_st(glbl:UOp, view:UOp):
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# BIND in shapetracker becomes DEFINE_VAR
|
||||
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
|
||||
# remove CONTIGUOUS/DEVICE
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
|
||||
# remove unmasked valid
|
||||
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
|
||||
# no ImageDType after load
|
||||
|
||||
Reference in New Issue
Block a user