mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-19 02:44:40 -05:00
Promote Embedding to nn (#798)
* feat: promote Embedding to nn
* fix: fix failing test
* feat: add test with jit
* feat: rewrite embedding to no longer need stacked for loops
* clean+fix: don't know how that happened
@@ -81,3 +81,12 @@ class LayerNorm:
 
 class LayerNorm2d(LayerNorm):
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+class Embedding:
+  def __init__(self, vocab_size:int, embed_size:int):
+    self.vocab_size = vocab_size
+    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
+
+  def __call__(self, idx:Tensor) -> Tensor:
+    vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size)
+    return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight
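The rewritten __call__ replaces the old stacked for loops with a single one-hot matmul: an arange over the vocabulary is broadcast against the index tensor, and the resulting mask is multiplied with the weight matrix. A minimal usage sketch follows, assuming a tinygrad checkout that includes this commit; the vocabulary size, embedding size, and token ids are arbitrary illustrative values, not taken from the commit:

from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding  # Embedding lives in tinygrad.nn after this change

# Arbitrary illustrative sizes: a 100-token vocabulary mapped to 16-dim vectors.
emb = Embedding(100, 16)

# __call__ expects a 2-D tensor of integer token ids, here (batch=2, seq_len=3).
idx = Tensor([[1, 5, 7], [2, 0, 99]])

# Internally: arange(vocab_size) is compared against the broadcast ids to form a
# one-hot mask of shape (2, 3, 100), which is matmul'd with the (100, 16) weight.
out = emb(idx)
print(out.shape)  # (2, 3, 16)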