make buffer view optional with a flag (#5120)

This commit is contained in:
David Hou
2024-06-25 19:13:20 -07:00
committed by GitHub
parent 63ba2d05d1
commit 8fcc41582f

View File

@@ -34,11 +34,9 @@ class LazyBuffer:
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
if self.op is LoadOps.VIEW:
# some LazyBuffers can be processed with only a view, no AST required
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
self.op = LoadOps.VIEW
else:
self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
self.buffer.ref(1)
@@ -84,9 +82,11 @@ class LazyBuffer:
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
def contiguous(self):
def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices
def contiguous(self, allow_buffer_view=True):
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
ret = self.e(LoadOps.CONTIGUOUS)
ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
return ret
self.base.forced_realize = True
@@ -107,7 +107,7 @@ class LazyBuffer:
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)