move subbuffer to a rewrite rule in the scheduler (#8639)

* delete buffer_view from tensor

* add to the scheduler

* move buffer_view to the scheduler

* gradient doesn't care about BUFFER_VIEW (drop its no-gradient pattern).

* for/with
This commit is contained in:
qazal
2025-01-15 20:14:28 -05:00
committed by GitHub
parent b3efeeb717
commit d5c90da286
3 changed files with 10 additions and 26 deletions

View File

@@ -51,17 +51,6 @@ tensor_uop_spec = PatternMatcher([
# ASSIGN changes the value of a realized buffer
(UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
lambda assign,target,new_val: (target.op is Ops.BUFFER or target.is_realized) and (assign.dtype == target.dtype == new_val.dtype)),
# TODO: BUFFER_VIEW is overloaded, it should be removed.
# BUFFER_VIEW shares the device buffer with its source, it uses a subbuffer of the underlying source buffer
(UPat(Ops.BUFFER_VIEW, name="root", src=(UPat.var("x"),)), lambda root,x:
# BUFFER_VIEW can replace contiguous, keeping dtype the same
(root.dtype == x.dtype) or
# it can also replace bitcast, this changes the dtype, but the itemsize stays the same
(root.dtype != x.dtype and root.dtype.itemsize == x.dtype.itemsize) or
# it can also represent shape changing bitcast (only on DISK)
(root.dtype != x.dtype and root.dtype.itemsize != x.dtype.itemsize and x.device.startswith("DISK"))),
])
# **** ScheduleItem return type
@@ -455,6 +444,12 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **k
def sink_outputs(ctx:ScheduleContext, sink:UOp) -> None:
for x in sink.src: realize(ctx, x.buf_uop, x)
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  """Rewrite a scheduled BITCAST/CONTIGUOUS on DISK into a BUFFER_VIEW whose buffer is a subbuffer of x's buffer.

  Returns the rewritten UOp, or None (no rewrite) when the device is not DISK.
  """
  # subbuffers only apply on DISK — DISK uses subbuffer because it does not render kernels
  if not root.device.startswith("DISK"): return None
  if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone
  # register b's backing buffer as a view of the source's buffer: b's size/dtype, offset taken
  # from the first view of x's ShapeTracker scaled to bytes by the source itemsize
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  # replace the BITCAST/CONTIGUOUS root with a BUFFER_VIEW op sourced from b
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
# always realize sinked ops
(UPat(Ops.SINK, name="sink"), sink_outputs),
@@ -467,6 +462,8 @@ do_realize = PatternMatcher([
# realize before COPY or BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
@@ -502,10 +499,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
# BUFFER_VIEW overrides the underlying buffer
# TODO: this should be a shrink on the buffer
if op.op is Ops.BUFFER_VIEW:
buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(view.size, view.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
buf_uop.buffer.ref(1)
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])

View File

@@ -37,9 +37,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
# there's no gradient for...is this ASSIGN?
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
# also no gradient for bitcast
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda ctx: (None,)),
])

View File

@@ -291,7 +291,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
if self.op is Ops.BUFFER_VIEW:
if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
shape = src_sts[0].shape
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
# only reduce ops are allowed to change shape, everything else derives shape from sources
@@ -365,9 +365,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def bitcast(self, dtype:DType):
if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
raise RuntimeError(f"unsupported size in bitcast {dtype}")
# shape changing bitcast can use a subbuffer on DISK
# TODO: this should be moved to realize.py
if self._device is not None and self.device.startswith("DISK"): return UOp(Ops.BUFFER_VIEW, dtype, (self,))
return UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:Union[tuple[int, ...], int]):
if isinstance(i, int):
@@ -421,10 +418,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self):
# TODO: BUFFER_VIEW op should be deleted and subbuffer should be moved to realize.py
# NOTE: DISK uses subbuffer because DISK does not render kernels
if self.device.startswith("DISK"): return self.alu(Ops.BUFFER_VIEW)
# otherwise it's normal CONTIGUOUS
if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST:
return self.alu(Ops.CONTIGUOUS)
forced_realize.add(self.base)