kernelize some llama realizes (#10098)

qazal
2025-04-29 13:39:56 +03:00
committed by GitHub
parent cbf7347cd6
commit 3b67f56c02


@@ -175,14 +175,14 @@ class Transformer:
     _bsz, seqlen = tokens.shape
     h = self.tok_embeddings(tokens)
-    self.freqs_cis = self.freqs_cis.cast(h.dtype).realize()
+    self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
     logits = self.output(self.norm(h)).float()[:, -1, :]
-    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).kernelize()
   def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
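
Note: a minimal standalone sketch of the realize-vs-kernelize distinction this commit relies on. The semantics stated here are an assumption inferred from the names and the chained usage in the diff above, not something this diff spells out: realize builds and runs the kernels for a tensor's lazy graph, while kernelize only builds/schedules them and defers execution; both return the tensor, so they chain the same way.

from tinygrad import Tensor

# Hypothetical snippet, not part of the llama model above.
x = Tensor.rand(4, 4)
y = (x * 2 + 1).realize()    # kernels for y's graph are created and executed now
z = (x * 2 + 1).kernelize()  # assumed: kernels are created, execution is deferred
print(z.numpy())             # accessing the data forces the pending kernels to run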