Promote Embedding to nn (#798)

* feat: promote Embedding to nn

* fix: fix failing test

* feat: add test with jit

* feat: rewrite embedding to no longer need stacked for loops

* clean+fix: don't know how that happened
Author: wozeparrot
Date: 2023-05-25 21:39:45 -04:00
Committed by: GitHub
Parent: f4f23dc9a3
Commit: 0dc333cfab
5 changed files with 51 additions and 34 deletions


@@ -81,12 +81,12 @@ class TestInferenceMinKernels(unittest.TestCase):
       out.realize()
   def test_llama(self):
-    from examples.llama import Transformer, onehot_encode
+    from examples.llama import Transformer
     args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
-    with CLCache(85):
-      model(onehot_encode([1,2,3,4], vocab_size=args_tiny['vocab_size']), 0).realize()
+    with CLCache(86):
+      model(Tensor([[1,2,3,4]]), 0).realize()
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptBinOp(unittest.TestCase):
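
As context for the diff above: the llama test now feeds raw token ids (Tensor([[1,2,3,4]])) instead of a pre-built one-hot array, because the promoted nn Embedding layer builds the lookup internally. Below is a minimal NumPy sketch of that idea, expressing the lookup as a single broadcasted one-hot matmul rather than stacked Python for loops; the class, weight init, and shapes here are illustrative assumptions, not tinygrad's actual implementation.

import numpy as np

class Embedding:
  # Illustrative embedding layer: lookup written as one-hot @ weight, no per-token loop.
  def __init__(self, vocab_size, embed_size):
    self.vocab_size = vocab_size
    self.weight = np.random.uniform(-0.1, 0.1, (vocab_size, embed_size))  # placeholder init

  def __call__(self, idx):
    # idx: (batch, seqlen) integer token ids.
    # Broadcasting idx against arange(vocab_size) yields a (batch, seqlen, vocab_size)
    # one-hot tensor in one shot, replacing stacked Python for loops over tokens.
    onehot = (idx[..., None] == np.arange(self.vocab_size)).astype(self.weight.dtype)
    # One batched matmul performs every lookup: (B, T, V) @ (V, E) -> (B, T, E).
    return onehot @ self.weight

emb = Embedding(vocab_size=1000, embed_size=512)
print(emb(np.array([[1, 2, 3, 4]])).shape)  # (1, 4, 512)

The call-site change in the diff follows from this design: callers pass integer token ids directly and the one-hot construction happens inside the layer, which likely accounts for the kernel-count bump from CLCache(85) to CLCache(86).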