Promote Embedding to nn (#798)

* feat: promote Embedding to nn

* fix: fix failing test

* feat: add test with jit

* feat: rewrite embedding to no longer need stacked for loops

* clean+fix: don't know how that happened
Author: wozeparrot
Date: 2023-05-25 21:39:45 -04:00
Committed by: GitHub
Parent: f4f23dc9a3
Commit: 0dc333cfab
5 changed files with 51 additions and 34 deletions


@@ -81,3 +81,12 @@ class LayerNorm:
 class LayerNorm2d(LayerNorm):
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+class Embedding:
+  def __init__(self, vocab_size:int, embed_size:int):
+    self.vocab_size = vocab_size
+    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
+
+  def __call__(self, idx:Tensor) -> Tensor:
+    vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size)
+    return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight
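
Why this removes the stacked for loops: the broadcasted compare builds a one-hot tensor over the vocabulary for every index, and the matmul with the weight matrix selects the matching rows, so the whole lookup is a single compare plus a single matmul instead of Python-level loops over the batch and sequence dimensions. A minimal sketch of that equivalence in plain NumPy (illustrative only, not code from this PR):

import numpy as np

vocab_size, embed_size = 5, 3
weight = np.random.randn(vocab_size, embed_size)
idx = np.array([[1, 4], [0, 2]])                       # (batch, seqlen) token ids

# Broadcasted compare -> (batch, seqlen, vocab_size) one-hot tensor;
# matmul with (vocab_size, embed_size) picks out the matching weight rows.
one_hot = (np.arange(vocab_size) == idx[..., None]).astype(weight.dtype)
gathered = one_hot @ weight                            # (batch, seqlen, embed_size)

assert np.allclose(gathered, weight[idx])              # same as a direct row gather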
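
A minimal usage sketch of the promoted module; the tinygrad.nn import path and keyword names follow the diff above, but treat this as an assumption rather than the PR's own test code:

from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding   # assumed import path after this PR

emb = Embedding(vocab_size=100, embed_size=16)
idx = Tensor([[1, 5, 7], [2, 0, 99]])   # (batch=2, seqlen=3) token indices
out = emb(idx)                          # one-hot compare + matmul under the hood
print(out.shape)                        # (2, 3, 16)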