hotfix fix coder. RMSNorm cannot have float16 input (#2932)

* hotfix fix coder. RMSNorm cannot have float16 input

* update real world test due to new kernels

* more type casts
chenyu, 2023-12-25 02:28:11 -05:00 (committed by GitHub)
parent b469fe3723, commit 1fb815e77e
4 changed files with 9 additions and 7 deletions
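For context (this note is not part of the commit): the core failure mode behind "RMSNorm cannot have float16 input" is that RMSNorm squares its input, and float16 tops out around 65504, so any activation with magnitude above roughly 256 overflows to inf before the mean of squares is even taken. A minimal numpy sketch of that behavior, assuming nothing about the llama.py code beyond the formula in the diff below:

import numpy as np

# Squaring moderately large float16 activations overflows, so the mean-of-squares
# used by RMSNorm becomes inf; upcasting to float32 first keeps it finite.
x16 = np.array([300.0, -300.0, 10.0], dtype=np.float16)
print(np.mean(np.square(x16)))                     # inf  (300**2 > 65504)
print(np.mean(np.square(x16.astype(np.float32))))  # ~60033.3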

@@ -36,8 +36,8 @@ class RMSNorm:
     self.weight = Tensor.ones(dim)

   def __call__(self, x:Tensor):
-    # TODO: convert to float?
-    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
+    x = x.float()
+    return ((x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight)

 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
@@ -53,6 +53,7 @@ class Attention:
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    x = x.half()
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@@ -96,7 +97,7 @@ class TransformerBlock:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h))).realize()
+    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()

 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
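The shape of the change across this file: the norm now upcasts to float32 (for range and precision), so its float32 output has to be cast back to half before it reaches the fp16 attention and feed-forward weights, which is what the new x.half() and .half() calls do. A rough sketch of that mixed-precision pattern, using only Tensor methods that already appear in the diff; the function and variable names here are illustrative, not the actual llama.py code:

from tinygrad.tensor import Tensor

def rms_norm_fp32(x: Tensor, weight: Tensor, eps: float = 1e-5) -> Tensor:
  x = x.float()  # upcast: squares of fp16 activations can overflow
  return (x * (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()) * weight

x = Tensor.randn(1, 8, 512).half()   # fp16 activations coming out of a fp16 layer
w = Tensor.ones(512)
h = rms_norm_fp32(x, w).half()       # cast back to half before the next fp16 matmul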

@@ -90,6 +90,7 @@ class TestInferenceMinKernels(unittest.TestCase):
     assert len(CacheCollector.cache) == 0, "ViT prerealized?"
     out.realize()

+  @unittest.skip("llama is fp16 but CI does not have fp16")
   def test_llama(self):
     from examples.llama import Transformer
     args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}

@@ -66,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
     helper_test("test_mini_sd", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43)

   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
-  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1")
+  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_llama(self):
     dtypes.default_float = dtypes.float16
@@ -75,8 +75,8 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t): return model(t, 0).realize()
-    # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 13.5, 181 if CI else 685, all_jitted=True)
+    # TODO: test first token vs rest properly
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.6, 190 if CI else 718, all_jitted=True)

   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):

@@ -228,7 +228,7 @@ class Tensor:
     assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
-    cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1)
+    cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
     unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
     return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
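Why the cumsum needs the .float(): float16 has so few mantissa bits that a long running sum stops advancing (it cannot even represent 2049), so the normalized CDF that multinomial compares uniform samples against saturates early and the tail of the distribution becomes unsampleable. A standalone numpy sketch of that effect, separate from the tinygrad code:

import numpy as np

w = np.ones(4096, dtype=np.float16)        # 4096 equal weights
cum16 = np.cumsum(w)                       # accumulated in float16: stalls at 2048.0
cum32 = np.cumsum(w.astype(np.float32))    # accumulated in float32: reaches 4096.0
print(cum16[-1], cum32[-1])                # 2048.0 4096.0
print(cum16[3000] / cum16[-1])             # 1.0  -> CDF already saturated, tail never sampled
print(cum32[3000] / cum32[-1])             # ~0.73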