From 1fb815e77e846945065a8c0e4e327d0252a677c3 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 25 Dec 2023 02:28:11 -0500
Subject: [PATCH] hotfix fix coder. RMSNorm cannot have float16 input (#2932)

* hotfix fix coder. RMSNorm cannot have float16 input

* update real world test due to new kernels

* more type casts
---
 extra/models/llama.py              | 7 ++++---
 test/external/external_test_opt.py | 1 +
 test/models/test_real_world.py     | 6 +++---
 tinygrad/tensor.py                 | 2 +-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/extra/models/llama.py b/extra/models/llama.py
index 4f4eb2e082..c4f41b98c0 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -36,8 +36,8 @@ class RMSNorm:
     self.weight = Tensor.ones(dim)
 
   def __call__(self, x:Tensor):
-    # TODO: convert to float?
-    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
+    x = x.float()
+    return ((x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight)
 
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
@@ -53,6 +53,7 @@ class Attention:
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    x = x.half()
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@@ -96,7 +97,7 @@ class TransformerBlock:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h))).realize()
+    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index 33d6bf6b6f..f4d99bc25e 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -90,6 +90,7 @@ class TestInferenceMinKernels(unittest.TestCase):
       assert len(CacheCollector.cache) == 0, "ViT prerealized?"
       out.realize()
 
+  @unittest.skip("llama is fp16 but CI does not have fp16")
   def test_llama(self):
     from examples.llama import Transformer
     args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index c4f1f0edd3..faa4388ed8 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -66,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
     helper_test("test_mini_sd", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43)
 
   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
-  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1")
+  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_llama(self):
     dtypes.default_float = dtypes.float16
 
@@ -75,8 +75,8 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t): return model(t, 0).realize()
-    # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 13.5, 181 if CI else 685, all_jitted=True)
+    # TODO: test first token vs rest properly
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.6, 190 if CI else 718, all_jitted=True)
 
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index fe5a25feb1..b974c438d0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -228,7 +228,7 @@ class Tensor:
     assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
-    cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1)
+    cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
     unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
     return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
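Not part of the patch: a minimal numpy sketch of the failure mode the new x.float() cast in RMSNorm most likely guards against. In float16, squaring activations with magnitude above roughly 256 overflows (fp16 tops out at 65504), so the mean-of-squares becomes inf and the subsequent rsqrt() collapses to 0; the same math in float32 stays finite. The 300.0 activation scale below is a hypothetical value chosen to show the overflow.

# Illustrative only, not from the patch: why RMSNorm's mean-of-squares
# cannot be computed directly on float16 activations.
import numpy as np

x16 = np.full((1, 1024), 300.0, dtype=np.float16)  # hypothetical pre-norm activations
ms16 = np.square(x16).mean(-1, keepdims=True)                     # 300**2 = 90000 > 65504, each square is inf in fp16
ms32 = np.square(x16.astype(np.float32)).mean(-1, keepdims=True)  # finite when squared in fp32

print(ms16)  # [[inf]]     -> rsqrt of this is 0, normalization is destroyed
print(ms32)  # [[90000.]]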