From 0dc333cfab474abd5f8b107cd2bd3515fd5dbe55 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 25 May 2023 21:39:45 -0400 Subject: [PATCH] Promote Embedding to `nn` (#798) * feat: promote Embedding to nn * fix: fix failing test * feat: add test with jit * feat: rewrite embedding to no longer need stacked for loops * clean+fix: don't know how that happened --- examples/llama.py | 19 ++++++----------- models/rnnt.py | 18 +--------------- test/external/external_test_opt.py | 6 +++--- test/test_nn.py | 33 +++++++++++++++++++++++++++++- tinygrad/nn/__init__.py | 9 ++++++++ 5 files changed, 51 insertions(+), 34 deletions(-) diff --git a/examples/llama.py b/examples/llama.py index 30ea38bfdd..ddfcdbdd79 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -14,7 +14,7 @@ from tinygrad.helpers import getenv, DEBUG from tinygrad.lazy import Device from extra.helpers import Timing from tinygrad.tensor import Tensor -from tinygrad.nn import Linear +from tinygrad.nn import Embedding, Linear from tinygrad.ops import GlobalCounters from tinygrad.jit import TinyJit @@ -133,13 +133,13 @@ class Transformer: def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024): self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps) for _ in range(n_layers)] self.norm = RMSNorm(dim, norm_eps) - self.tok_embeddings = {"weight": Tensor.glorot_uniform(vocab_size, dim)} + self.tok_embeddings = Embedding(vocab_size, dim) self.output = Linear(dim, vocab_size, bias=False) self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2)) def __call__(self, tokens:Tensor, start_pos:int): - _bsz, seqlen, _ = tokens.shape - h = tokens @ self.tok_embeddings['weight'] + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) # get only the part we are using. making it contiguous avoids more kernel calls freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize() @@ -174,13 +174,6 @@ WEIGHTS_13B_0_FILENAME = WEIGHTS_DIR / "13B/consolidated.00.pth" WEIGHTS_13B_1_FILENAME = WEIGHTS_DIR / "13B/consolidated.01.pth" # **** helper functions **** - -def onehot_encode(toks, vocab_size=VOCAB_SIZE): - # this allows the embedding to work in tinygrad - onehot = np.zeros((1, len(toks), vocab_size), dtype=np.float32) - onehot[0,range(len(toks)),toks] = 1 - return Tensor(onehot) - def sample(logits, temperature): if temperature < 1e-6: # so close to 0 we use argmax @@ -365,7 +358,7 @@ After you are done speaking, output [EOS]. You are not Chad. print(f"Preparing KV cache for chatbot with personality {args.personality}...") with Timing(): - model(onehot_encode(toks), 0).realize() # NOTE: output logits are not used + model(Tensor([toks]), 0).realize() # NOTE: output logits are not used start_pos = len(toks) else: # non chat bot mode @@ -400,7 +393,7 @@ After you are done speaking, output [EOS]. You are not Chad. 
if args.timing: print("") st = GlobalCounters.time_sum_s with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing): - logits = model(onehot_encode(toks[start_pos:]), start_pos).realize() + logits = model(Tensor([toks[start_pos:]]), start_pos).realize() with Timing("sync in ", enabled=args.timing): tok = sample(logits, args.temperature) diff --git a/models/rnnt.py b/models/rnnt.py index df14aa5068..e9f7b171af 100644 --- a/models/rnnt.py +++ b/models/rnnt.py @@ -1,6 +1,6 @@ from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit -from tinygrad.nn import Linear +from tinygrad.nn import Linear, Embedding import numpy as np from extra.utils import download_file from pathlib import Path @@ -171,22 +171,6 @@ class Encoder: return x.transpose(0, 1), x_lens -class Embedding: - def __init__(self, vocab_size: int, embed_size: int): - self.vocab_size = vocab_size - self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False) - self.weight = Tensor.scaled_uniform(vocab_size, embed_size) - - def __call__(self, idx: Tensor) -> Tensor: - oha = [] - for i in range(idx.shape[0]): - ohba = [] - for j in range(idx.shape[1]): - ohba.append((self.vocab_counter == idx[i, j]).realize()) - oha.append(Tensor.stack(ohba).realize()) - return Tensor.stack(oha) @ self.weight - - class Prediction: def __init__(self, vocab_size, hidden_size, layers, dropout): self.hidden_size = hidden_size diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 3275d257ce..06d0caf804 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -81,12 +81,12 @@ class TestInferenceMinKernels(unittest.TestCase): out.realize() def test_llama(self): - from examples.llama import Transformer, onehot_encode + from examples.llama import Transformer args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000} model = Transformer(**args_tiny) for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - with CLCache(85): - model(onehot_encode([1,2,3,4], vocab_size=args_tiny['vocab_size']), 0).realize() + with CLCache(86): + model(Tensor([[1,2,3,4]]), 0).realize() @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOptBinOp(unittest.TestCase): diff --git a/test/test_nn.py b/test/test_nn.py index 136138672e..f5a2963a23 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1,8 +1,9 @@ #!/usr/bin/env python import unittest import numpy as np +from tinygrad.jit import TinyJit from tinygrad.tensor import Tensor, Device -from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d +from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding import torch class TestNN(unittest.TestCase): @@ -150,5 +151,35 @@ class TestNN(unittest.TestCase): torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + def test_embedding(self): + B, T, C, VS = 4, 10, 20, 28 + + # create in tinygrad + layer = Embedding(VS, C) + + with torch.no_grad(): + torch_layer = torch.nn.Embedding(VS, C).eval() + torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) + + # test + x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32)) + z = layer(x) + torch_x = torch.tensor(x.cpu().numpy().astype(np.int32)) + torch_z = 
torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8) + + # test with jit enabled + @TinyJit + def layer_jit(x): + return layer(x).realize() + + for _ in range(3): + x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32)) + z = layer_jit(x) + torch_x = torch.tensor(x.cpu().numpy().astype(np.int32)) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8) + + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 2da901f2dc..3f949a328e 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -81,3 +81,12 @@ class LayerNorm: class LayerNorm2d(LayerNorm): def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + +class Embedding: + def __init__(self, vocab_size:int, embed_size:int): + self.vocab_size = vocab_size + self.weight = Tensor.glorot_uniform(vocab_size, embed_size) + + def __call__(self, idx:Tensor) -> Tensor: + vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size) + return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight
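
The new tinygrad.nn.Embedding replaces the stacked Python for loops of the old models/rnnt.py Embedding with a single broadcasted comparison: the indices are compared against arange(vocab_size) to build a one-hot tensor, and the lookup is then a matmul with the weight table. Below is a minimal NumPy sketch of the same idea for reference; the function name, shapes, and the equivalence check against a plain gather are illustrative and not part of the patch.

import numpy as np

def embedding_forward(idx: np.ndarray, weight: np.ndarray) -> np.ndarray:
    # idx: (B, T) integer token ids, weight: (vocab_size, embed_size)
    vocab_size = weight.shape[0]
    # broadcasted comparison builds the one-hot matrix in one shot,
    # mirroring the (vocab_counter == idx.unsqueeze(2)) expression in the patch
    vocab_counter = np.arange(vocab_size).reshape(1, 1, vocab_size)    # (1, 1, VS)
    one_hot = (vocab_counter == idx[..., None]).astype(weight.dtype)   # (B, T, VS)
    return one_hot @ weight                                            # (B, T, embed_size)

B, T, C, VS = 4, 10, 20, 28
weight = np.random.randn(VS, C).astype(np.float32)
idx = np.random.randint(0, VS, (B, T))
# the matmul formulation gives the same result as a direct row gather
assert np.allclose(embedding_forward(idx, weight), weight[idx])

Expressing the lookup as comparison-plus-matmul keeps the whole forward pass inside tinygrad's lazy graph (no per-token Python loop or realize calls), which is what lets the llama.py and test_nn.py call sites above pass raw token tensors straight into the model and run under TinyJit.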