diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f8307ad028..4f6721e48d 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -349,6 +349,16 @@ def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
   ctx.bufs.append(x)
   return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
 
+def check_load_st(glbl:UOp, view:UOp):
+  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return  # only a non-contiguous view of DEFINE_GLOBAL 0 needs the checks below
+  # if it has a single view and it becomes contiguous when you shrink expanded (stride 0) axes, it's fine
+  if len(st.views) == 1 and st.shrink(tuple((0,1) if stride == 0 else (0,s) for s,stride in zip(st.shape, st.views[0].strides))).contiguous: return
+  # if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+  # otherwise, it's not fine
+  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
 to_si = PatternMatcher([
   # BUFFER -> DEFINE_GLOBAL
   (UPat(Ops.BUFFER, name="x"), _append_buf),
@@ -365,6 +375,8 @@ to_si = PatternMatcher([
   (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
   # once images are loaded they become the base dtype
   (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+  (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
 ])
 
 def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
@@ -384,17 +396,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
     # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
     if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
     # PRELOAD tells the toposort this kernel should run before ASSIGN
-    if x.op is Ops.PRELOAD:
-      assign_preloads[x.buf_uop] = None
-      # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
-      if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
-        # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
-        if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
-        # if it has a single view and it's equal when you shrink a contig, it's fine
-        elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
-        # otherwise, it's not fine
-        else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                                 +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+    if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None
   # NOTE: we only add the metadata for fused tensors
   metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
   return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))