mirror of https://github.com/tinygrad/tinygrad.git
hotfix: llm kv cache uses clone instead of realize to avoid many realizes
@@ -165,7 +165,8 @@ class TransformerBlock:
   def __call__(self, x: Tensor, start_pos: int|UOp):
     if not hasattr(self, "cache_kv"):
       # TODO: how is the dtype of this determined?
-      self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).contiguous().realize()
+      # NOTE: clone is used to promise the creation of a specific buffer
+      self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).clone()
     return self._feed_forward(self._attention(x, start_pos)).contiguous()
 
 class Transformer:
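For context, a minimal sketch (not part of the commit; shapes are placeholders, assuming only tinygrad's public Tensor API) contrasting the two cache-allocation styles. The old path calls .realize() per TransformerBlock, so model setup pays one small kernel dispatch per block; the new path keeps the cache lazy while .clone() still promises it a dedicated buffer, which, per the commit message, is what avoids the many per-block realizes.

from tinygrad import Tensor

# Old style: contiguous().realize() materializes this cache buffer immediately,
# so building a model with many blocks triggers many small realizes up front.
cache_eager = Tensor.zeros(2, 1, 8, 4096, 64).contiguous().realize()

# New style: clone() still guarantees the cache its own buffer,
# but it stays lazy and is only materialized when the graph is next realized.
cache_lazy = Tensor.zeros(2, 1, 8, 4096, 64).clone()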