Promote Embedding to nn (#798)

* feat: promote Embedding to nn

* fix: fix failing test

* feat: add test with jit

* feat: rewrite embedding to no longer need stacked for loops

* clean+fix: don't know how that happened

Author: wozeparrot
Committed: 2023-05-25 21:39:45 -04:00 (committer: GitHub)
Parent: f4f23dc9a3
Commit: 0dc333cfab
5 changed files with 51 additions and 34 deletions
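
The commit messages above mention promoting Embedding into tinygrad.nn and adding a test that runs it under the JIT. Below is a minimal sketch of such a check, assuming the tinygrad.jit.TinyJit import path that appears later in this diff and illustrative sizes; the PR's actual test is not reproduced here.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.nn import Embedding

# hypothetical sizes; the test added in the PR may use different ones
emb = Embedding(vocab_size=10, embed_size=4)

@TinyJit
def jit_embed(idx: Tensor) -> Tensor:
  # TinyJit records the kernels launched by this call and replays them on later calls
  return emb(idx).realize()

# call more than once so the JIT can capture, then reuse, the kernel graph
for _ in range(3):
  out = jit_embed(Tensor([[1, 2, 3]]).realize())
assert out.shape == (1, 3, 4)  # (batch, sequence, embed_size)

The only contract this illustrates is the lookup shape, (batch, seq) -> (batch, seq, embed_size), which is what the model in this diff relies on.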

@@ -1,6 +1,6 @@
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
-from tinygrad.nn import Linear
+from tinygrad.nn import Linear, Embedding
 import numpy as np
 from extra.utils import download_file
 from pathlib import Path
@@ -171,22 +171,6 @@ class Encoder:
     return x.transpose(0, 1), x_lens
 
-
-class Embedding:
-  def __init__(self, vocab_size: int, embed_size: int):
-    self.vocab_size = vocab_size
-    self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False)
-    self.weight = Tensor.scaled_uniform(vocab_size, embed_size)
-
-  def __call__(self, idx: Tensor) -> Tensor:
-    oha = []
-    for i in range(idx.shape[0]):
-      ohba = []
-      for j in range(idx.shape[1]):
-        ohba.append((self.vocab_counter == idx[i, j]).realize())
-      oha.append(Tensor.stack(ohba).realize())
-    return Tensor.stack(oha) @ self.weight
-
 
 class Prediction:
   def __init__(self, vocab_size, hidden_size, layers, dropout):
     self.hidden_size = hidden_size
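
The hunk above deletes the model-local Embedding, which built its one-hot lookup with nested Python loops and Tensor.stack, one token at a time. Per the "rewrite embedding to no longer need stacked for loops" commit message, the class promoted into tinygrad.nn does the same lookup with broadcasting instead. A rough sketch of that idea follows; it is not necessarily the exact code merged into nn.

from tinygrad.tensor import Tensor

class Embedding:
  def __init__(self, vocab_size: int, embed_size: int):
    self.vocab_size = vocab_size
    self.weight = Tensor.scaled_uniform(vocab_size, embed_size)

  def __call__(self, idx: Tensor) -> Tensor:
    batch, seq = idx.shape
    # one broadcasted compare builds the whole (batch, seq, vocab) one-hot tensor,
    # replacing the per-token loops and Tensor.stack calls removed above
    vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False) \
      .reshape(1, 1, self.vocab_size).expand(batch, seq, self.vocab_size)
    one_hot = vocab_counter == idx.reshape(batch, seq, 1).expand(batch, seq, self.vocab_size)
    # project the one-hot rows through the embedding matrix, as the old code did
    return one_hot.cast(self.weight.dtype) @ self.weight

Comparing against an arange over the vocabulary is the same trick the deleted class used per element; doing it once over the full index tensor lets the lookup lower to a single broadcasted kernel plus a matmul instead of a Python-level loop per token.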