diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 6908fa6c74..e51cc4a5d8 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -275,8 +275,15 @@ add_buffer_ops = PatternMatcher([ (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))), # STORE (except for COPY/BUFFER_VIEW) (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), + # partial assign can store to a non-contiguous ShapeTracker + (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)), + lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()), + # otherwise the store is contiguous (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), + # remove CONTIGUOUS/DEVICE from kernel AST + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), + (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), ]) # ** push views to buffer ops @@ -317,9 +324,6 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: # push VIEW to children view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), - lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), # STORE is the last child, so we just merge the ShapeTrackers and store the base (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), # push a non contiguous ShapeTracker through reduceop @@ -361,9 +365,6 @@ def check_load_st(glbl:UOp, view:UOp): fix_kernel_ops = PatternMatcher([ # BIND in shapetracker becomes DEFINE_VAR (UPat(Ops.VIEW, name="x"), unbind_shapetracker), - # remove CONTIGUOUS/DEVICE - (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())), # remove unmasked valid (UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None), # no ImageDType after load