mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make Embedding device aware for multigpu (#3051)
* make Embedding device aware for multigpu * split line instead of ignore because that's cheating * add test incomplete * add test complete * remove comment * fix whitespace * remove nn.Embedding
This commit is contained in:
@@ -138,6 +138,20 @@ class TestMultiTensor(unittest.TestCase):
|
||||
lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
|
||||
lr_sched.step()
|
||||
|
||||
def test_embedding(self):
  """Check that an Embedding with weights sharded over (d0, d1) matches the unsharded layer."""
  batch, seq_len, embed_dim, vocab = 4, 10, 20, 28

  # reference result from a plain, unsharded layer
  reference = nn.Embedding(vocab, embed_dim)
  tokens = Tensor(np.random.randint(0, vocab, (batch, seq_len)))
  expected = reference(tokens)

  # second layer reuses the reference weights, sharded along axis 1 across both devices
  sharded = nn.Embedding(vocab, embed_dim)
  sharded_weight = reference.weight.shard((d0, d1), axis=1)
  sharded.weight.assign(sharded_weight).realize()
  # the indices are sharded over the same devices with axis=None
  tokens_sharded = tokens.shard((d0, d1), axis=None)
  actual = sharded(tokens_sharded)

  np.testing.assert_allclose(expected.numpy(), actual.numpy(), atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_data_parallel_resnet(self):
|
||||
import sys, pathlib
|
||||
sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
|
||||
|
||||
@@ -126,7 +126,8 @@ class Embedding:
|
||||
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
  """Return the embedding vectors for a 2-D tensor of token indices.

  `idx` is unpacked as (batch_size, seqlen). Each index is compared against a cached
  arange over the vocabulary, and the resulting (batch_size, seqlen, vocab_size) mask is
  matrix-multiplied with `self.weight`.

  Every helper tensor is created on `self.weight.device`, so the layer keeps working
  when the weight does not live on the default device (e.g. after sharding).
  """
  # Lazily build and cache the [0, vocab_size) counter, placed on the weight's device.
  # NOTE(review): the cached counter is not rebuilt if self.weight later moves to a
  # different device after the first call -- confirm no caller depends on that.
  if not hasattr(self, 'vocab_counter'):
    self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False, device=self.weight.device).reshape(1, 1, self.vocab_size)
  batch_size, seqlen = idx.shape
  # Empty sequence: nothing to look up, return an empty tensor on the weight's device.
  if seqlen == 0: return Tensor.empty(batch_size, 0, self.embed_size, device=self.weight.device)
  # Compare counter with indices to get a one-hot style mask, then select rows of weight via matmul.
  return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight
|
||||
|
||||
Reference in New Issue
Block a user