mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add UOp.__index__ (#14181)
Tensor slice is handled by __getitem__, so the index method is just for SupportsIndex
This commit is contained in:
@@ -126,15 +126,14 @@ class TransformerBlock:
|
||||
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
||||
# TODO: make UOp have SupportsIndex
|
||||
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
|
||||
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
|
||||
q = apply_rope(q, freqs_cis)
|
||||
k = apply_rope(k, freqs_cis)
|
||||
|
||||
# TODO: remove these kv cache realizes
|
||||
if not hasattr(self, "cache_kv"):
|
||||
self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
|
||||
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize() # type: ignore
|
||||
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()
|
||||
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
||||
v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
|
||||
@@ -345,6 +345,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if len(dvars) == 0: return self
|
||||
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
|
||||
return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name)
|
||||
# NOTE: this is not called by Tensor slice (Tensor handles UOps directly), but satisfies SupportsIndex for type checking
|
||||
def __index__(self): return self.__int__()
|
||||
|
||||
# *** uop tracing stuff ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user