diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index 6d21cfbf40..52e33c7c1c 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -126,15 +126,14 @@ class TransformerBlock: v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd) if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k) - # TODO: make UOp have SupportsIndex - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] q = apply_rope(q, freqs_cis) k = apply_rope(k, freqs_cis) # TODO: remove these kv cache realizes if not hasattr(self, "cache_kv"): self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize() - self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize() # type: ignore + self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize() k = self.cache_kv[0, :, :, 0:start_pos+T, :] v = self.cache_kv[1, :, :, 0:start_pos+T, :] diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 93f8db1b96..0ebd898a28 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -345,6 +345,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if len(dvars) == 0: return self with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)): return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name) + # NOTE: this is not called by Tensor slice (Tensor handles UOps directly), but satisfies SupportsIndex for type checking + def __index__(self): return self.__int__() # *** uop tracing stuff ***