From 5ff278446cf4a413be5ae0e66ea1097072cbbe43 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 2 Mar 2026 18:05:04 +0800
Subject: [PATCH] add contiguous_view_offset (#15084)

* add contiguous_view_offset

* no int
---
 tinygrad/apps/llm.py |  3 ++-
 tinygrad/nn/state.py |  4 ++--
 tinygrad/uop/ops.py  | 35 ++++++++++++++++++++++-------------
 tinygrad/uop/spec.py |  3 +++
 4 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 82c2a95326..a2fb7342e3 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -143,7 +143,8 @@ class TransformerBlock:
     #v = self.cache_kv[1, :, :, 0:start_pos+T, :]
 
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
-    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
+    # TODO: this if statement should be removed and it shouldn't generate extra kernels
+    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
     attn = self.attn_output(attn)
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 5af4d250b7..3ee8bf7677 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -304,7 +304,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
 
   # native types
   if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
-    return t[:dtype.itemsize * n].bitcast(dtype)
+    return t[:dtype.itemsize * n].contiguous().bitcast(dtype)
 
   def q_to_uint8(t: Tensor, b: int) -> Tensor:
     # TODO: rewrite with arange?
@@ -313,7 +313,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
 
   # map to (number of elements, number of bytes)
   if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 8: (32, 34), 12: (256, 144), 14: (256, 210), 39: (32, 17) }.get(ggml_type)) is not None:
-    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
+    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1])).contiguous()
     if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
     if ggml_type == 3: d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index b101d83096..e51a7cb4e4 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -657,10 +657,22 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
     return s
 
+  def contiguous_view_offset(self) -> tuple[int, int]|None:
+    """If movement ops on a BUFFER collapse to a contiguous range, return (size, offset) in elements. Otherwise None."""
+    from tinygrad.schedule.rangeify import pm_mops
+    from tinygrad.uop.symbolic import symbolic
+    out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset")
+    if out.op is not Ops.INDEX: return None
+    if out.src[1].op is Ops.CONST and self.size == 1: return (1, out.src[1].arg)
+    if out.src[1].op is Ops.RANGE: return (self.size, 0)
+    if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
+      return (self.size, out.src[1].src[1].arg)
+    return None
+
   def has_buffer_identity(self):
     """Check if this UOp has a concrete buffer identity in the graph (RESHAPE/MULTI -> BUFFER chain)."""
     if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
-    return self.op in {Ops.BUFFER, Ops.PARAM}
+    return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
 
   @property
   def buffer(self) -> Buffer|MultiBuffer:
@@ -668,23 +680,20 @@
     if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE}: return self.src[0].buffer
     # this buffer can process disk tensors and simple movement ops
     if self is not self.base:
-      from tinygrad.schedule.rangeify import pm_mops
-      from tinygrad.uop.symbolic import symbolic
-      out = graph_rewrite(self.flatten().index(UOp.range(self.size, 0)), pm_mops+symbolic)
-      buf = out.src[0].buffer
+      size_offset = self.contiguous_view_offset()
+      if size_offset is None: raise RuntimeError(f"cannot collapse movement ops on {self.base.op} to a contiguous view")
+      size, offset = size_offset
+      buf = self.base.buffer
       assert isinstance(buf, Buffer), "must be a Buffer for movement ops"
-      assert out.op is Ops.INDEX, "couldn't collapse to a single INDEX"
-      if out.src[1].op is Ops.CONST:
-        return buf.view(1, out.dtype, out.src[1].arg*out.dtype.itemsize)
-      if out.src[1].op is Ops.RANGE:
-        return buf.view(self.size, out.dtype, 0)
-      if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
-        return buf.view(self.size, out.dtype, out.src[1].src[1].arg*out.dtype.itemsize)
-      raise RuntimeError(f"cannot collapse INDEX {out.pyrender()} to a single size/offset")
+      return buf.view(size, self.dtype, offset*self.dtype.itemsize)
     if self.op is Ops.BITCAST:
       buf = self.src[0].buffer
       assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
       return buf.view(self.size, self.dtype, 0)
+    if self.op is Ops.BUFFER_VIEW:
+      buf = self.src[0].buffer
+      assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
+      return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
     if self.op is Ops.MSELECT:
       ret = self.src[0].buffer
       assert isinstance(ret, MultiBuffer)
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 9c91f34060..c7f18f0794 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -87,6 +87,9 @@ _tensor_spec = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
 
+  # BUFFER_VIEW on BUFFER is allowed if BUFFER is
+  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True),
+
   # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
   (UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
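
On the mask change in llm.py: with T new query rows attending over start_pos+T keys, row i may see keys 0..start_pos+i, so entry (i, j) must be -inf exactly when j > start_pos+i; on a (T, start_pos+T) grid that is .triu(start_pos+1), the causal_lower_right alignment named in the NOTE. When T == 1 the single row sees every key, so the mask is all zeros and can be skipped, hence the `if T > 1 else None`. A pure-Python sketch of that shape logic (start_pos=2 and T=3 are illustrative values, not from the patch):

# causal_lower_right: query row i (0-based, after start_pos cached tokens) may
# attend to keys 0..start_pos+i; everything to the right is masked with -inf
start_pos, T = 2, 3
mask = [[0.0 if j <= start_pos + i else float("-inf") for j in range(start_pos + T)] for i in range(T)]
assert mask[0][2] == 0.0 and mask[0][3] == float("-inf")  # first new token sees 3 keys
assert mask[2][4] == 0.0                                  # last row sees all 5 keys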
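The new UOp.contiguous_view_offset collapses movement ops by reshaping the view flat, indexing it with a symbolic RANGE, rewriting with pm_mops+symbolic, and matching the surviving INDEX source: a bare RANGE means (size, 0), RANGE+CONST means (size, const), and a lone CONST with size 1 means (1, const). The sketch below restates that predicate over an explicit list of flat indices instead of UOp rewriting; the helper name and list-based interface are illustrative only, not tinygrad's API:

# Standalone restatement of the contiguous_view_offset idea (illustrative, not
# tinygrad code): a view collapses to a zero-copy buffer view iff the flat
# element indices it reads are exactly offset, offset+1, ..., offset+size-1.
def contiguous_view_offset(flat_indices: list[int]) -> tuple[int, int] | None:
  if not flat_indices: return None
  offset = flat_indices[0]  # the CONST part of RANGE+CONST (0 for a bare RANGE)
  if all(idx == offset + i for i, idx in enumerate(flat_indices)): return (len(flat_indices), offset)
  return None  # e.g. strided or permuted views do not collapse

assert contiguous_view_offset([3, 4, 5, 6]) == (4, 3)   # shrink t[3:7] -> RANGE+CONST case
assert contiguous_view_offset([5]) == (1, 5)            # single element -> CONST case
assert contiguous_view_offset([0, 2, 4]) is None        # stride 2 -> no contiguous view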
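contiguous_view_offset reports size and offset in elements, and the rewritten buffer property converts to bytes in one place, buf.view(size, self.dtype, offset*self.dtype.itemsize); the new BUFFER_VIEW branch does the same scaling with self.arg[1]. A memoryview sketch of that element-to-byte conversion (assuming view(size, dtype, byte_offset) semantics as the hunk uses them; this is an analogy, not Buffer's implementation):

import array

# 8 float32 elements backing a "buffer"; a contiguous view of (size=4, offset=3)
# elements starts at byte offset 3*itemsize and spans 4*itemsize bytes
buf = array.array('f', range(8))
size, offset = 4, 3
start = offset * buf.itemsize                                     # element offset -> byte offset
view = memoryview(buf).cast('B')[start:start + size*buf.itemsize].cast('f')
assert view.tolist() == [3.0, 4.0, 5.0, 6.0]                      # zero-copy window into buf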