From 8fcc41582f1fee23a05be6166331c5f0299de5d2 Mon Sep 17 00:00:00 2001 From: David Hou Date: Tue, 25 Jun 2024 19:13:20 -0700 Subject: [PATCH] make buffer view optional with a flag (#5120) --- tinygrad/lazy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 17747eace2..29eec16582 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -34,11 +34,9 @@ class LazyBuffer: self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized" - if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \ - not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices: + if self.op is LoadOps.VIEW: # some LazyBuffers can be processed with only a view, no AST required self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize) - self.op = LoadOps.VIEW else: self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype) self.buffer.ref(1) @@ -84,9 +82,11 @@ class LazyBuffer: assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}" return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base)) - def contiguous(self): + def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices + + def contiguous(self, allow_buffer_view=True): if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const(): - ret = self.e(LoadOps.CONTIGUOUS) + ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS) if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti return ret self.base.forced_realize = True @@ -107,7 +107,7 @@ class LazyBuffer: elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base: # TODO: applying this makes gpt2 slower return self.base.cast(dtype, bitcast)._view(self.st) - cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST + cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,)) def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)