kernelize some llama realizes (#10098)
@@ -175,14 +175,14 @@ class Transformer:
     _bsz, seqlen = tokens.shape
     h = self.tok_embeddings(tokens)

-    self.freqs_cis = self.freqs_cis.cast(h.dtype).realize()
+    self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))

-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
     logits = self.output(self.norm(h)).float()[:, -1, :]

-    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).kernelize()

   def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
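For context on what this change does: realize() in tinygrad both schedules the pending lazy graph and executes the resulting kernels immediately, while kernelize() only groups the graph into kernels, deferring execution until the values are actually needed. Below is a minimal sketch of that difference, assuming a recent tinygrad where Tensor.kernelize() is available; the tensor and shapes are illustrative, not taken from this PR.

from tinygrad import Tensor

x = Tensor.randn(4, 4)  # illustrative input, not from the PR

# realize(): schedules the lazy graph AND runs the kernels now,
# materializing the result buffer on the device
y = ((x @ x).relu() + 1).realize()

# kernelize(): groups the lazy graph into kernels but does not run them;
# execution is deferred until the data is read (or a later realize),
# so surrounding ops can still be scheduled and run together
z = ((x @ x).relu() + 1).kernelize()
print(z.numpy())  # the kernels actually execute here

Presumably this is the motivation for the swap in the llama forward pass: the freqs_cis cast, the attention mask, and the sampled token no longer each force an immediate kernel execution mid-forward; the work can run together when the output is finally consumed.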