mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
make buffer view optional with a flag (#5120)
This commit is contained in:
@@ -34,11 +34,9 @@ class LazyBuffer:
|
||||
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
|
||||
assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
|
||||
|
||||
if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
|
||||
not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
|
||||
if self.op is LoadOps.VIEW:
|
||||
# some LazyBuffers can be processed with only a view, no AST required
|
||||
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
||||
self.op = LoadOps.VIEW
|
||||
else:
|
||||
self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
|
||||
self.buffer.ref(1)
|
||||
@@ -84,9 +82,11 @@ class LazyBuffer:
|
||||
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
|
||||
return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
|
||||
|
||||
def contiguous(self):
|
||||
def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices
|
||||
|
||||
def contiguous(self, allow_buffer_view=True):
|
||||
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
|
||||
ret = self.e(LoadOps.CONTIGUOUS)
|
||||
ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
|
||||
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
|
||||
return ret
|
||||
self.base.forced_realize = True
|
||||
@@ -107,7 +107,7 @@ class LazyBuffer:
|
||||
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
||||
# TODO: applying this makes gpt2 slower
|
||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||
cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
|
||||
cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
||||
|
||||
def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
|
||||
|
||||
Reference in New Issue
Block a user