add contiguous_view_offset (#15084)

* add contiguous_view_offset

* no int
This commit is contained in:
George Hotz
2026-03-02 18:05:04 +08:00
committed by GitHub
parent 977c270774
commit 5ff278446c
4 changed files with 29 additions and 16 deletions

View File

@@ -143,7 +143,8 @@ class TransformerBlock:
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
# TODO: this if statement should be removed and it shouldn't generate extra kernels
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
attn = self.attn_output(attn)

View File

@@ -304,7 +304,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
# native types
if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
return t[:dtype.itemsize * n].bitcast(dtype)
return t[:dtype.itemsize * n].contiguous().bitcast(dtype)
def q_to_uint8(t: Tensor, b: int) -> Tensor:
# TODO: rewrite with arange?
@@ -313,7 +313,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
# map to (number of elements, number of bytes)
if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 8: (32, 34), 12: (256, 144), 14: (256, 210), 39: (32, 17) }.get(ggml_type)) is not None:
blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1])).contiguous()
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
if ggml_type == 3:
d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])

View File

@@ -657,10 +657,22 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
return s
def contiguous_view_offset(self) -> tuple[int, int]|None:
  """If movement ops on a BUFFER collapse to a contiguous range, return (size, offset) in elements. Otherwise None."""
  from tinygrad.schedule.rangeify import pm_mops
  from tinygrad.uop.symbolic import symbolic
  # flatten to 1-D, index with a full-size RANGE, and let the rewriter fold the movement ops into a single address expression
  collapsed = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset")
  if collapsed.op is not Ops.INDEX: return None
  addr = collapsed.src[1]
  # a single-element view reduces to a constant address: size 1 at that offset
  if addr.op is Ops.CONST and self.size == 1: return (1, addr.arg)
  # a bare RANGE means the view is the whole buffer starting at 0
  if addr.op is Ops.RANGE: return (self.size, 0)
  # RANGE + CONST is a contiguous window at a fixed element offset
  if addr.op is Ops.ADD and addr.src[0].op is Ops.RANGE and addr.src[1].op is Ops.CONST:
    return (self.size, addr.src[1].arg)
  # anything else (strided / permuted access) is not a contiguous view
  return None
def has_buffer_identity(self):
  """Check if this UOp has a concrete buffer identity in the graph (RESHAPE/MULTI -> BUFFER chain)."""
  # walk through shape/multi wrappers down to the base op
  if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
  # BUG FIX: the block contained two consecutive return statements (a stripped-diff artifact);
  # the first made the second unreachable. Keep the one that also accepts BUFFER_VIEW,
  # matching the BUFFER_VIEW handling added elsewhere in this change.
  return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
@property
def buffer(self) -> Buffer|MultiBuffer:
@@ -668,23 +680,20 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE}: return self.src[0].buffer
# this buffer can process disk tensors and simple movement ops
if self is not self.base:
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.uop.symbolic import symbolic
out = graph_rewrite(self.flatten().index(UOp.range(self.size, 0)), pm_mops+symbolic)
buf = out.src[0].buffer
size_offset = self.contiguous_view_offset()
if size_offset is None: raise RuntimeError(f"cannot collapse movement ops on {self.base.op} to a contiguous view")
size, offset = size_offset
buf = self.base.buffer
assert isinstance(buf, Buffer), "must be a Buffer for movement ops"
assert out.op is Ops.INDEX, "couldn't collapse to a single INDEX"
if out.src[1].op is Ops.CONST:
return buf.view(1, out.dtype, out.src[1].arg*out.dtype.itemsize)
if out.src[1].op is Ops.RANGE:
return buf.view(self.size, out.dtype, 0)
if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
return buf.view(self.size, out.dtype, out.src[1].src[1].arg*out.dtype.itemsize)
raise RuntimeError(f"cannot collapse INDEX {out.pyrender()} to a single size/offset")
return buf.view(size, self.dtype, offset*self.dtype.itemsize)
if self.op is Ops.BITCAST:
buf = self.src[0].buffer
assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
return buf.view(self.size, self.dtype, 0)
if self.op is Ops.BUFFER_VIEW:
buf = self.src[0].buffer
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
if self.op is Ops.MSELECT:
ret = self.src[0].buffer
assert isinstance(ret, MultiBuffer)

View File

@@ -87,6 +87,9 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),