Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
update llama model RMSNorm casting (#5095)
Following the original implementation, cast back to the input dtype before multiplying by the weight; this is slightly faster. Reference: https://github.com/meta-llama/llama/blob/main/llama/model.py
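A minimal runnable sketch of the behavior this change targets, assuming only tinygrad's top-level Tensor import (the class, the eps default, and the usage lines are illustrative, not the exact file contents): the norm is still computed in float32, but the result is cast back to the input dtype before the weight multiply, so half-precision activations stay half for the rest of the block.

# Minimal sketch of the new RMSNorm casting behavior (illustrative only).
from tinygrad import Tensor

class RMSNorm:
  def __init__(self, dim, eps=1e-6):
    self.eps = eps
    self.weight = Tensor.ones(dim)

  def _norm(self, x:Tensor):
    # the normalization itself, run on whatever dtype it is given
    return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()

  def __call__(self, x:Tensor):
    # compute in float32 for stability, then cast back before multiplying the weight
    return self._norm(x.float()).cast(x.dtype) * self.weight

norm = RMSNorm(8)
norm.weight = norm.weight.half()  # assumption: in a half-precision model the weight is half as well
x = Tensor.randn(2, 8).half()
print(norm(x).dtype)              # half: the output keeps the input dtype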
@@ -36,9 +36,11 @@ class RMSNorm:
     self.eps = eps
     self.weight = Tensor.ones(dim)

+  def _norm(self, x:Tensor):
+    return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
+
   def __call__(self, x:Tensor):
-    x = x.float()
-    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
+    return self._norm(x.float()).cast(x.dtype) * self.weight

 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
@@ -101,7 +103,7 @@ class TransformerBlock:

   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h).half())).contiguous()
+    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()

 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
@@ -166,7 +168,7 @@ class Transformer:
     h = self.tok_embeddings(tokens)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-    logits = self.output(self.norm(h))[:, -1, :]
+    logits = self.output(self.norm(h)).float()[:, -1, :]

     return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()