Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 05:48:08 -05:00)
allow VIEW on BUFFER [pr] (#8136)

* allow VIEW of BUFFER [pr]
* base it later
* better diff
* base shouldn't exist anywhere after merge_views
@@ -67,25 +67,23 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
     buf.buffer.dtype = dtype
     buf.buffer.options = None
   # base is a VIEW of (BUFFER, (optional) op)
-  if buf.is_realized:
-    # TODO: this is the same underlying Buffer in all schedules
-    buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
-    op = None
+  # TODO: this is the same underlying Buffer in all schedules, delete_lazy fixes this
+  if buf.is_realized: ret = UOp.new_buffer(buf.device, buf.size, dtype).view(buf.st)
   # ASSIGN uses the target buffer, otherwise we create a new buffer
   else:
     src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
     buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
     op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
-  # keep track of ops outside the big graph
-  buffers[buf_uop] = buf.buffer
-  if op is not None:
+    ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
+    # keep track of scheduled ops
     buf.buffer.ref(1)
     ctx.lazybufs[buf_uop] = buf
     ctx.allbufs[buf_uop] = ret
     if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
     for x in op.src:
       if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+  cache[buf] = ret
+  buffers[ret.buf_uop] = buf.buffer
   return ret
 
 # **** AST graph rewrite
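After this change every LazyBuffer lowers to a UOp whose base is a VIEW of a BUFFER: a realized buffer becomes UOp.new_buffer(...).view(buf.st), a scheduled op becomes VIEW(BUFFER, op), and the cache dict memoizes the conversion so a source shared by several consumers resolves to a single UOp. The following is a minimal standalone sketch of that memoized lowering pattern; ToyNode, ToyView and to_view are illustrative stand-ins, not tinygrad's API.

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyNode:
    op: str                               # e.g. "ADD", or "REALIZED" for data that already exists
    srcs: Tuple["ToyNode", ...] = ()

@dataclass(frozen=True)
class ToyView:
    buffer_id: int                        # stands in for the BUFFER uop
    op: Optional[tuple] = None            # stands in for the optional compute op

counter = iter(range(1_000_000))

def to_view(node: ToyNode, cache: dict) -> ToyView:
    if node in cache: return cache[node]                  # memoize: shared sources convert once
    if node.op == "REALIZED":                             # realized data: a plain view of a buffer
        ret = ToyView(next(counter))
    else:                                                 # scheduled op: view of (new buffer, op over converted srcs)
        srcs = tuple(to_view(s, cache) for s in node.srcs)
        ret = ToyView(next(counter), (node.op, srcs))
    cache[node] = ret
    return ret

a = ToyNode("REALIZED")
b = ToyNode("ADD", (a, a))
cache: dict = {}
out = to_view(b, cache)
assert out.op[1][0] is out.op[1][1]                       # both ADD sources resolve to the same view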
@@ -425,7 +423,7 @@ break_sched = PatternMatcher([
   # everything else is a VIEW of BUFFER that either realizes or fuses
   (UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
   # just load realized buffers
-  (UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
+  (UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
 ])
 
 @track_rewrites(named=True)
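The one-line break_sched change takes the dtype of the emitted load from the underlying buffer (b.dtype.base) rather than from the view wrapping it (base.dtype), while keeping the PRELOAD-vs-LOAD choice keyed on whether the buffer is an ASSIGN target (ctx.assigns). Below is a small self-contained sketch of that first-matching-rule dispatch using plain dicts; it only illustrates the style, not tinygrad's UPat/PatternMatcher machinery.

from typing import Callable, List, Optional

Node = dict  # toy node: {"kind": ..., "buffer": ..., "buffer_dtype": ...}

def make_load(node: Node, assigns: set) -> Optional[Node]:
    if node["kind"] != "realized_view": return None
    # PRELOAD if this buffer is the target of an ASSIGN, else a plain LOAD;
    # the dtype comes from the buffer itself, not from the view wrapping it
    return {"kind": "PRELOAD" if node["buffer"] in assigns else "LOAD",
            "buffer": node["buffer"], "dtype": node["buffer_dtype"]}

rules: List[Callable[[Node, set], Optional[Node]]] = [make_load]

def rewrite(node: Node, assigns: set) -> Node:
    for rule in rules:                                    # first matching rule wins
        if (out := rule(node, assigns)) is not None: return out
    return node

v = {"kind": "realized_view", "buffer": "buf0", "buffer_dtype": "float32"}
assert rewrite(v, {"buf0"})["kind"] == "PRELOAD"
assert rewrite(v, set())["kind"] == "LOAD"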
@@ -356,11 +356,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
   def view(self, new_st:ShapeTracker) -> UOp:
-    assert self.st is not None and self.base.st is not None, f"must have shape {self}"
+    if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
     ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
     # instant folding rules
     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
-    if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
+    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return ret
   def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
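UOp.view no longer asserts that the source has a ShapeTracker; a source with no st is simply wrapped in a VIEW instead of tripping an assert, which is what allows a VIEW directly on a BUFFER. The instant-folding rules remain: a zero-size or fully masked-out view folds to a const 0, and a contiguous view with the base's shape folds back to the base. The sketch below illustrates just those folding decisions with a toy view type; ToyView and fold_view are stand-ins, not tinygrad's ShapeTracker.

from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyView:
    shape: Tuple[int, ...]
    contiguous: bool = True
    mask: Optional[Tuple[Tuple[int, int], ...]] = None  # per-dim (lo, hi) valid ranges

def fold_view(base_shape: Tuple[int, ...], new: ToyView) -> str:
    # rule 1: nothing to read (zero size, or a mask that is empty in some dim) -> constant zero
    if prod(base_shape) == 0 or (new.mask is not None and any(hi - lo == 0 for lo, hi in new.mask)):
        return "const 0"
    # rule 2: a contiguous view with the base's shape is the base itself
    if new.contiguous and new.shape == base_shape:
        return "base"
    # otherwise keep an explicit VIEW node
    return "VIEW"

assert fold_view((4, 3), ToyView((4, 3))) == "base"
assert fold_view((4, 3), ToyView((4, 5), contiguous=False, mask=((0, 4), (2, 2)))) == "const 0"
assert fold_view((4, 3), ToyView((3, 4), contiguous=False)) == "VIEW"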