diff --git a/extra/models/llama.py b/extra/models/llama.py index 75ea16a3ff..29429a53c2 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -36,9 +36,11 @@ class RMSNorm: self.eps = eps self.weight = Tensor.ones(dim) + def _norm(self, x:Tensor): + return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt() + def __call__(self, x:Tensor): - x = x.float() - return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight + return self._norm(x.float()).cast(x.dtype) * self.weight class Attention: def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear): @@ -101,7 +103,7 @@ class TransformerBlock: def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]): h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) - return (h + self.feed_forward(self.ffn_norm(h).half())).contiguous() + return (h + self.feed_forward(self.ffn_norm(h))).contiguous() # standard openai sampling def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float): @@ -166,7 +168,7 @@ class Transformer: h = self.tok_embeddings(tokens) mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) - logits = self.output(self.norm(h))[:, -1, :] + logits = self.output(self.norm(h)).float()[:, -1, :] return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()