mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
realize before copy rule [pr] (#7476)
* realize before COPY and BUFFER_VIEW rule [pr] * only upat.view * move the assert to lazy
This commit is contained in:
@@ -38,10 +38,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
assign_targets[(target:=buf.srcs[0])] = buf
|
||||
assert target._base is None, f"assign must be to base {target}"
|
||||
assert target.is_realized(), f"assign must be already realized to schedule {target}"
|
||||
if buf.op is MetaOps.COPY:
|
||||
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
|
||||
realizes[buf.srcs[0].base] = None
|
||||
if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
|
||||
for x in buf.srcs:
|
||||
if x.base.realized is None: children[x.base][buf] = None
|
||||
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
|
||||
|
||||
@@ -118,6 +118,7 @@ class LazyBuffer(MathTrait):
|
||||
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
||||
|
||||
def _copy(self, device:str) -> LazyBuffer:
|
||||
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
|
||||
|
||||
def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
|
||||
|
||||
@@ -238,6 +238,8 @@ def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), U
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize),
|
||||
(UPat((UOps.COPY, UOps.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
|
||||
lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),)
|
||||
])
|
||||
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
|
||||
|
||||
|
||||
@@ -567,6 +567,7 @@ class UPat(MathTrait):
|
||||
|
||||
# copied from UOp
|
||||
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st=None, **kwargs): return UPat(UOps.VIEW, self.dtype, (self,), st, **kwargs)
|
||||
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
|
||||
|
||||
Reference in New Issue
Block a user