diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index 75157ede56..a1c201d104 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -38,10 +38,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La assign_targets[(target:=buf.srcs[0])] = buf assert target._base is None, f"assign must be to base {target}" assert target.is_realized(), f"assign must be already realized to schedule {target}" - if buf.op is MetaOps.COPY: - assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig" - realizes[buf.srcs[0].base] = None - if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None for x in buf.srcs: if x.base.realized is None: children[x.base][buf] = None _recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx) diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index 95b759d83b..d37abf0ad3 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -118,6 +118,7 @@ class LazyBuffer(MathTrait): def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views) def _copy(self, device:str) -> LazyBuffer: + assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}" return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False) def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9df5e311ba..72b757816f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -238,6 +238,8 @@ def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), U do_realize = PatternMatcher([ # always realize meta ops (UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize), + (UPat((UOps.COPY, UOps.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"), + lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),) ]) break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 68480e510b..159ee21577 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -567,6 +567,7 @@ class UPat(MathTrait): # copied from UOp def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) + def view(self, st=None, **kwargs): return UPat(UOps.VIEW, self.dtype, (self,), st, **kwargs) def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,)) def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,)) def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))