From 3b67f56c02692c5d886dc99cf92f501ef5a2e0ca Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Tue, 29 Apr 2025 13:39:56 +0300
Subject: [PATCH] kernelize some llama realizes (#10098)

---
 extra/models/llama.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/extra/models/llama.py b/extra/models/llama.py
index 49ff9890b9..3a089facbc 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -175,14 +175,14 @@ class Transformer:
     _bsz, seqlen = tokens.shape
     h = self.tok_embeddings(tokens)

-    self.freqs_cis = self.freqs_cis.cast(h.dtype).realize()
+    self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))

-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None

     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
     logits = self.output(self.norm(h)).float()[:, -1, :]
-    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).kernelize()

   def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
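
Note (not part of the patch): in tinygrad, Tensor.realize() schedules and runs the kernels for a tensor immediately, while Tensor.kernelize() only groups the tensor's lazy graph into kernels and leaves execution for later, so work from several calls can be scheduled together before anything runs. A minimal sketch of the difference follows; the tensor shapes and the exact scheduling behavior are assumptions for illustration, not something this patch specifies.

    from tinygrad import Tensor

    a = Tensor.rand(4, 4)
    b = Tensor.rand(4, 4)

    # realize(): kernels for this expression are scheduled and executed now.
    eager = (a @ b).relu().realize()

    # kernelize(): the lazy graph is grouped into kernels, but nothing runs yet,
    # so downstream ops can still be added and scheduled later.
    deferred = (a @ b).relu().kernelize()
    out = (deferred + 1).realize()  # execution happens here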