allow VIEW on BUFFER [pr] (#8136)

* allow VIEW of BUFFER [pr]

* base it later

* better diff

* base shouldn't exist after anywhere merge_views
This commit is contained in:
qazal
2024-12-10 15:29:38 +02:00
committed by GitHub
parent 3a2658efbd
commit 2d26b011ac
2 changed files with 9 additions and 11 deletions

View File

@@ -67,25 +67,23 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
buf.buffer.dtype = dtype
buf.buffer.options = None
# base is a VIEW of (BUFFER, (optional) op)
if buf.is_realized:
# TODO: this is the same underlying Buffer in all schedules
buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
op = None
# TODO: this is the same underlying Buffer in all schedules, delete_lazy fixes this
if buf.is_realized: ret = UOp.new_buffer(buf.device, buf.size, dtype).view(buf.st)
# ASSIGN uses the target buffer, otherwise we create a new buffer
else:
src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
# keep track of ops outside the big graph
buffers[buf_uop] = buf.buffer
if op is not None:
ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
# keep track of scheduled ops
buf.buffer.ref(1)
ctx.lazybufs[buf_uop] = buf
ctx.allbufs[buf_uop] = ret
if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
cache[buf] = ret
buffers[ret.buf_uop] = buf.buffer
return ret
# **** AST graph rewrite
@@ -425,7 +423,7 @@ break_sched = PatternMatcher([
# everything else is a VIEW of BUFFER that either realizes or fuses
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# just load realized buffers
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
])
@track_rewrites(named=True)

View File

@@ -356,11 +356,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
# instant folding rules
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return ret
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))