mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move subbuffer to a rewrite rule in the scheduler (#8639)
* delete buffer_view from tensor * add to the scheduler * move buffer_view to the scheduler * gradient doesn't care. * for/with
This commit is contained in:
@@ -51,17 +51,6 @@ tensor_uop_spec = PatternMatcher([
|
||||
# ASSIGN changes the value of a realized buffer
|
||||
(UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
|
||||
lambda assign,target,new_val: (target.op is Ops.BUFFER or target.is_realized) and (assign.dtype == target.dtype == new_val.dtype)),
|
||||
|
||||
# TODO: BUFFER_VIEW is overloaded, it should be removed.
|
||||
# BUFFER_VIEW shares the device buffer with its source, it uses a subbuffer of the underlying source buffer
|
||||
|
||||
(UPat(Ops.BUFFER_VIEW, name="root", src=(UPat.var("x"),)), lambda root,x:
|
||||
# BUFFER_VIEW can replace contiguous, keeping dtype the same
|
||||
(root.dtype == x.dtype) or
|
||||
# it can also replace bitcast, this changes the dtype, but the itemsize stays the same
|
||||
(root.dtype != x.dtype and root.dtype.itemsize == x.dtype.itemsize) or
|
||||
# it can also represent shape changing bitcast (only on DISK)
|
||||
(root.dtype != x.dtype and root.dtype.itemsize != x.dtype.itemsize and x.device.startswith("DISK"))),
|
||||
])
|
||||
|
||||
# **** ScheduleItem return type
|
||||
@@ -455,6 +444,12 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **k
|
||||
# Force-realize every tensor feeding the SINK: each source's backing buffer
# UOp is materialized via realize(). NOTE(review): realize() and
# ScheduleContext are defined elsewhere in this file — not visible in this
# diff chunk; presumably realize() marks (buf_uop -> uop) in ctx. Confirm.
def sink_outputs(ctx:ScheduleContext, sink:UOp) -> None:
|
||||
for x in sink.src: realize(ctx, x.buf_uop, x)
|
||||
|
||||
# Scheduler rewrite rule (new in this commit): on DISK devices, replace a
# scheduled BITCAST/CONTIGUOUS with a BUFFER_VIEW whose buffer is a subbuffer
# (view) of the source's underlying buffer, so no kernel has to run.
# Returns None (no rewrite) when the device is not DISK.
# NOTE(review): `buffers` and `unwrap` come from enclosing scope — not visible
# in this chunk.
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
|
||||
if not root.device.startswith("DISK"): return None
|
||||
if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone
|
||||
# register a zero-copy view into the source buffer; offset is converted
# from elements to bytes via itemsize
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
|
||||
# keep `base` but retag the root op as BUFFER_VIEW so downstream passes see it
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize sinked ops
|
||||
(UPat(Ops.SINK, name="sink"), sink_outputs),
|
||||
@@ -467,6 +462,8 @@ do_realize = PatternMatcher([
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
|
||||
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
|
||||
])
|
||||
|
||||
# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
|
||||
@@ -502,10 +499,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
||||
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
|
||||
for x in op.base.src:
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
|
||||
# BUFFER_VIEW overrides the underlying buffer
|
||||
# TODO: this should be a shrink on the buffer
|
||||
if op.op is Ops.BUFFER_VIEW:
|
||||
buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(view.size, view.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
|
||||
buf_uop.buffer.ref(1)
|
||||
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
|
||||
|
||||
|
||||
@@ -37,9 +37,7 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
|
||||
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
|
||||
|
||||
# there's no gradient for...is this ASSIGN?
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
|
||||
# also no gradient for bitcast
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda ctx: (None,)),
|
||||
])
|
||||
|
||||
|
||||
@@ -291,7 +291,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
|
||||
if self.op is Ops.BUFFER_VIEW:
|
||||
if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
|
||||
shape = src_sts[0].shape
|
||||
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
|
||||
# only reduce ops are allowed to change shape, everything else derives shape from sources
|
||||
@@ -365,9 +365,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# Reinterpret this UOp's bytes as `dtype` without conversion.
# NOTE(review): this is the PRE-commit version — the hunk (-365,9 +365,6) is
# pure removal, and this commit deletes the DISK subbuffer special case below
# (moved into the scheduler as create_subbuffer).
def bitcast(self, dtype:DType):
|
||||
# reject bitcasts where the last dimension's byte size doesn't divide evenly
# into the target itemsize (would need a fractional element count)
if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
|
||||
raise RuntimeError(f"unsupported size in bitcast {dtype}")
|
||||
# shape changing bitcast can use a subbuffer on DISK
|
||||
# TODO: this should be moved to realize.py
|
||||
if self._device is not None and self.device.startswith("DISK"): return UOp(Ops.BUFFER_VIEW, dtype, (self,))
|
||||
return UOp(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:Union[tuple[int, ...], int]):
|
||||
if isinstance(i, int):
|
||||
@@ -421,10 +418,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
def contiguous(self):
|
||||
# TODO: BUFFER_VIEW op should be deleted and subbuffer should be moved to realize.py
|
||||
# NOTE: DISK uses subbuffer because DISK does not render kernels
|
||||
if self.device.startswith("DISK"): return self.alu(Ops.BUFFER_VIEW)
|
||||
# otherwise it's normal CONTIGUOUS
|
||||
if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST:
|
||||
return self.alu(Ops.CONTIGUOUS)
|
||||
forced_realize.add(self.base)
|
||||
|
||||
Reference in New Issue
Block a user