update llama model RMSNorm casting (#5095)

following the original implementation, cast back to the input dtype before multiplying by the weight. slightly faster.
https://github.com/meta-llama/llama/blob/main/llama/model.py
authored by chenyu on 2024-06-21 23:02:04 -04:00, committed by GitHub
parent 0c857ae2d6
commit 8bd6cb9511
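
For context, the linked meta-llama model.py computes the norm in float32 and casts back to the input dtype before the weight multiply; a rough PyTorch sketch of that reference RMSNorm, shown for comparison only (not part of this commit):

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # mean of squares over the last dim, then reciprocal sqrt
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # normalize in float32, cast back to x.dtype, then apply the learned scale
        return self._norm(x.float()).type_as(x) * self.weight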

@@ -36,9 +36,11 @@ class RMSNorm:
     self.eps = eps
     self.weight = Tensor.ones(dim)
 
+  def _norm(self, x:Tensor):
+    return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
+
   def __call__(self, x:Tensor):
-    x = x.float()
-    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
+    return self._norm(x.float()).cast(x.dtype) * self.weight
 
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
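
The practical effect on dtypes, as a minimal standalone sketch (assumes the tinygrad Tensor/dtypes API and a half-precision weight, as it would be after loading a half checkpoint; shapes and eps are illustrative):

from tinygrad import Tensor, dtypes

x = Tensor.randn(1, 4, 16, dtype=dtypes.half)   # half-precision activations
weight = Tensor.ones(16, dtype=dtypes.half)     # stand-in for checkpoint weights (half)
eps = 1e-6

norm = x.float() * (x.float().pow(2).mean(-1, keepdim=True) + eps).rsqrt()
old_out = norm * weight                # old path: stays float32, callers had to .half() it
new_out = norm.cast(x.dtype) * weight  # new path: cast back before the multiply, stays half
print(old_out.dtype, new_out.dtype)    # float32 vs half
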
@@ -101,7 +103,7 @@ class TransformerBlock:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h).half())).contiguous()
+    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
@@ -166,7 +168,7 @@ class Transformer:
 
     h = self.tok_embeddings(tokens)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-    logits = self.output(self.norm(h))[:, -1, :]
+    logits = self.output(self.norm(h)).float()[:, -1, :]
 
     return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
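
Since the norm output now keeps the activation dtype, self.output(self.norm(h)) yields half-precision logits; the added .float() upcasts them, presumably so sampling still runs on float32 logits as it did before the change. A small illustration (vocab size and temperature are made up, not from the diff):

from tinygrad import Tensor, dtypes

logits_half = Tensor.randn(32000, dtype=dtypes.half)  # vocab-sized logits in half
logits = logits_half.float()                          # what the added .float() does
probs = (logits / 0.7).softmax()                      # temperature-scaled softmax in float32
print(logits_half.dtype, probs.dtype)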