add UOp.__index__ (#14181)

Tensor slicing is handled by __getitem__, so the new __index__ method exists only to satisfy SupportsIndex for type checking.
This commit is contained in:
chenyu
2026-01-16 12:28:33 -05:00
committed by GitHub
parent 6790165ef8
commit fc10470883
2 changed files with 4 additions and 3 deletions

View File

@@ -126,15 +126,14 @@ class TransformerBlock:
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
# TODO: make UOp have SupportsIndex
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
# TODO: remove these kv cache realizes
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize() # type: ignore
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
v = self.cache_kv[1, :, :, 0:start_pos+T, :]

View File

@@ -345,6 +345,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if len(dvars) == 0: return self
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name)
# NOTE: Tensor slicing never reaches this method (Tensor handles UOps directly); it exists so that type checkers accept UOp as SupportsIndex
def __index__(self):
  """Delegate to `__int__` and return whatever it produces."""
  return self.__int__()
# *** uop tracing stuff ***