From a65c8de735c8da218ffcbc89efff660ec3c7348f Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 14 May 2024 15:00:18 -0400
Subject: [PATCH] move .half() llama freq_cis to the end of sin and cos (#4587)

otherwise arange has inf if either dim or context length exceeds half.max
---
 extra/models/llama.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/extra/models/llama.py b/extra/models/llama.py
index 086301125a..7b4c8d6d67 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -4,9 +4,9 @@ from tinygrad.helpers import getenv
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2, dtype=dtypes.half)[:(dim // 2)] / dim))
-  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
-  return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
+  return Tensor.stack([freqs.cos().half(), freqs.sin().half()], dim=-1).reshape(1, end, 1, dim//2, 2)
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):