realize before copy rule [pr] (#7476)

* realize before COPY and BUFFER_VIEW rule [pr]

* only upat.view

* move the assert to lazy
This commit is contained in:
qazal
2024-11-02 07:07:27 +02:00
committed by GitHub
parent 3819f5cf4d
commit c56364fad0
4 changed files with 4 additions and 4 deletions

View File

@@ -38,10 +38,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
assign_targets[(target:=buf.srcs[0])] = buf
assert target._base is None, f"assign must be to base {target}"
assert target.is_realized(), f"assign must be already realized to schedule {target}"
if buf.op is MetaOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes[buf.srcs[0].base] = None
if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
for x in buf.srcs:
if x.base.realized is None: children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)

View File

@@ -118,6 +118,7 @@ class LazyBuffer(MathTrait):
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
def _copy(self, device:str) -> LazyBuffer:
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:

View File

@@ -238,6 +238,8 @@ def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), U
do_realize = PatternMatcher([
# always realize meta ops
(UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize),
(UPat((UOps.COPY, UOps.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),)
])
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])

View File

@@ -567,6 +567,7 @@ class UPat(MathTrait):
# copied from UOp
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(UOps.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))