hotfix: llm kv cache uses clone instead of realize to avoid many realizes

This commit is contained in:
George Hotz
2026-03-04 19:07:03 +08:00
parent 8ebd24637b
commit 47faa2d7b4

View File

@@ -165,7 +165,8 @@ class TransformerBlock:
def __call__(self, x: Tensor, start_pos: int|UOp):
    """Run one transformer block on `x`: attention (with a lazily created
    KV cache) followed by feed-forward; result is made contiguous.

    NOTE(review): this span is a unified-diff fragment — the two cache_kv
    assignments below are the removed line and its replacement from the
    same hunk, not sequential code in the real file.
    """
    # Lazily allocate the KV cache on the first call, on the same device
    # as the input. Shape: (2, batch, n_kv_heads, max_context, head_dim);
    # presumably dim 0 separates keys from values — confirm in _attention.
    if not hasattr(self, "cache_kv"):
        # TODO: how is the dtype of this determined?
        # NOTE(review): pre-change line from the diff — forced a concrete
        # buffer via contiguous().realize(), i.e. an eager realize per block.
        self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).contiguous().realize()
        # NOTE: clone is used to promise the creation of a specific buffer
        self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).clone()
    # _attention consumes start_pos (int for concrete positions, UOp when
    # symbolic); _feed_forward is defined elsewhere in the class.
    return self._feed_forward(self._attention(x, start_pos)).contiguous()
class Transformer: