diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index bc4ee0ab1b..9fc2204ae7 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -502,12 +502,11 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None: # gather the metadata metadatas = [ctx[y].metadata for y in lctx.parent_tags] - # NOTE: the hack for COPY is here - for u in ret.toposort(): - # TODO: this can be wrong if there's multiple of these - if u.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: - ret = u - break + # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops + if ret.op is Ops.STORE: stored = ret.src[1] + elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1] + else: raise RuntimeError(f"unknown kernel type {ret.op}") + if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)