Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
Promote Embedding to nn (#798)
* feat: promote Embedding to nn
* fix: fix failing test
* feat: add test with jit
* feat: rewrite embedding to no longer need stacked for loops
* clean+fix: don't know how that happened
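The key change is the fourth bullet: the removed Embedding in the diff below builds its one-hot encodings with two nested Python loops and a realize() per token. The loop-free replacement idea is a single broadcasted comparison of the ids against the vocab range, followed by one matmul. A minimal sketch of that idea in NumPy (illustration only, not tinygrad's actual nn.Embedding source; embedding_lookup is a hypothetical helper name):

import numpy as np

def embedding_lookup(idx: np.ndarray, weight: np.ndarray) -> np.ndarray:
  # idx:    (batch, seqlen) integer token ids
  # weight: (vocab_size, embed_size) embedding table
  vocab_size = weight.shape[0]
  # One broadcasted compare replaces the stacked for loops:
  # (batch, seqlen, 1) == (vocab_size,) -> (batch, seqlen, vocab_size) one-hot
  onehot = (idx[..., None] == np.arange(vocab_size)).astype(weight.dtype)
  # One matmul gathers the embedding rows: (..., vocab_size) @ (vocab_size, embed_size)
  return onehot @ weight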
@@ -1,6 +1,6 @@
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
-from tinygrad.nn import Linear
+from tinygrad.nn import Linear, Embedding
 import numpy as np
 from extra.utils import download_file
 from pathlib import Path
@@ -171,22 +171,6 @@ class Encoder:
     return x.transpose(0, 1), x_lens
 
 
-class Embedding:
-  def __init__(self, vocab_size: int, embed_size: int):
-    self.vocab_size = vocab_size
-    self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False)
-    self.weight = Tensor.scaled_uniform(vocab_size, embed_size)
-
-  def __call__(self, idx: Tensor) -> Tensor:
-    oha = []
-    for i in range(idx.shape[0]):
-      ohba = []
-      for j in range(idx.shape[1]):
-        ohba.append((self.vocab_counter == idx[i, j]).realize())
-      oha.append(Tensor.stack(ohba).realize())
-    return Tensor.stack(oha) @ self.weight
-
-
 class Prediction:
   def __init__(self, vocab_size, hidden_size, layers, dropout):
     self.hidden_size = hidden_size
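After this PR, models import the layer from tinygrad.nn instead of defining their own. A minimal usage sketch, assuming the constructor keeps the (vocab_size, embed_size) signature shown in the removed class; the concrete numbers are made up for illustration:

from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding

emb = Embedding(1024, 64)             # vocab_size=1024, embed_size=64
idx = Tensor([[1, 2, 3], [4, 5, 6]])  # (batch=2, seqlen=3) token ids
out = emb(idx)                        # -> (2, 3, 64), one 64-dim vector per token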