Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
hotfix fix coder. RMSNorm cannot have float16 input (#2932)
* hotfix fix coder. RMSNorm cannot have float16 input
* update real world test due to new kernels
* more type casts
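The failure here is a dtype problem, not a logic change: with float16 activations, the x.pow(2).mean(-1, keepdim=True) reduction inside RMSNorm overflows easily (float16 tops out at 65504, so any activation with magnitude above roughly 256 squares to inf), which turns the normalization into zeros or NaNs. A minimal numpy sketch of that failure mode, purely illustrative and not taken from the commit (the value 300.0 is made up):

# Illustrative sketch: why a squared-mean reduction is fragile in float16.
import numpy as np

x16 = np.full(8, 300.0, dtype=np.float16)   # plausible pre-norm activation magnitude
x32 = x16.astype(np.float32)

# float16 squares overflow: 300**2 = 90000 > 65504 (float16 max) -> inf
print((x16 ** 2).mean())                          # inf
print(1.0 / np.sqrt((x16 ** 2).mean() + 1e-6))    # 0.0, the norm collapses

# the same reduction carried in float32 is fine
print((x32 ** 2).mean())                          # 90000.0
print(1.0 / np.sqrt((x32 ** 2).mean() + 1e-6))    # ~0.00333

Hence the shape of the fix in the hunks below: RMSNorm up-casts its input to float32 before reducing, and the callers cast back to half where the fp16 matmuls should stay fp16.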
@@ -36,8 +36,8 @@ class RMSNorm:
     self.weight = Tensor.ones(dim)
 
   def __call__(self, x:Tensor):
-    # TODO: convert to float?
-    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
+    x = x.float()
+    return ((x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight)
 
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
@@ -53,6 +53,7 @@ class Attention:
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    x = x.half()
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@@ -96,7 +97,7 @@ class TransformerBlock:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h))).realize()
+    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
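Taken together, the three hunks above set up a simple mixed-precision convention for the llama example: RMSNorm always reduces and returns in float32, and its callers cast back to half (the attention input via x.half(), the feed-forward input via self.ffn_norm(h).half()) so the large matmuls keep running in fp16. A condensed sketch of that flow, paraphrased from the patched examples/llama.py rather than copied verbatim (the shapes are arbitrary):

# Sketch of the post-fix dtype hand-off, not a verbatim copy of the example code.
from tinygrad.tensor import Tensor

class RMSNorm:
  def __init__(self, dim, eps=1e-6):
    self.eps, self.weight = eps, Tensor.ones(dim)
  def __call__(self, x: Tensor):
    x = x.float()   # reduce in float32 regardless of the input dtype
    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight

x = Tensor.randn(1, 4, 512).half()   # fp16 activations, as in the llama example
normed = RMSNorm(512)(x)             # comes back as float32
as_half = normed.half()              # callers cast back to fp16 before the matmuls
print(normed.dtype, as_half.dtype)

The printed dtypes (float for the norm output, half after the explicit cast) are exactly the hand-off the Attention and TransformerBlock hunks rely on.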
test/external/external_test_opt.py
@@ -90,6 +90,7 @@ class TestInferenceMinKernels(unittest.TestCase):
       assert len(CacheCollector.cache) == 0, "ViT prerealized?"
       out.realize()
 
+  @unittest.skip("llama is fp16 but CI does not have fp16")
   def test_llama(self):
     from examples.llama import Transformer
     args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
@@ -66,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
     helper_test("test_mini_sd", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43)
 
   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
-  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1")
+  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_llama(self):
     dtypes.default_float = dtypes.float16
@@ -75,8 +75,8 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t): return model(t, 0).realize()
-    # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 13.5, 181 if CI else 685, all_jitted=True)
+    # TODO: test first token vs rest properly
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.6, 190 if CI else 718, all_jitted=True)
 
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):
@@ -228,7 +228,7 @@ class Tensor:
     assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
-    cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1)
+    cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
     unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
     return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
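The "more type casts" bullet also reaches Tensor.multinomial: when dtypes.default_float is float16 (as the llama test above sets), the sampling weights are half precision, and a half-precision running sum over a vocabulary-sized vector loses precision and eventually saturates outright, so the added .float() keeps the accumulated values and the normalization in float32. An illustrative numpy sketch (not the tinygrad code; the 32000-entry weight vector is hypothetical):

# Illustrative sketch: an fp16 running sum degrades long before the end of a
# vocabulary-sized weight vector.
import numpy as np

np.random.seed(0)
w16 = np.random.rand(32000).astype(np.float16)   # hypothetical unnormalized sampling weights

cdf16 = np.cumsum(w16, dtype=np.float16)         # accumulate in fp16
cdf32 = np.cumsum(w16, dtype=np.float32)         # accumulate in float32 for comparison

# fp16 saturates at 2048 (spacing there is 2, so adding values < 1 no longer changes
# the sum), while the true total is around 16000; the fp16 "CDF" ends up mostly flat.
print(cdf16[-1], cdf32[-1])
print((np.diff(cdf16) == 0).sum(), "flat steps in the fp16 CDF")

In the patched line the cast sits right after cumsum, so the division that normalizes the CDF runs in float32 even when the weights themselves are half precision.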