make Embedding device aware for multigpu (#3051)

* make Embedding device aware for multigpu

* split line instead of igore because that's cheating

* add test incomplete

* add test complete

* remove comment

* fix white space

* remove nn.Embedding
This commit is contained in:
Yixiang Gao
2024-01-08 20:09:26 -08:00
committed by GitHub
parent 19298e7a3f
commit a686663657
2 changed files with 17 additions and 2 deletions

View File

@@ -138,6 +138,20 @@ class TestMultiTensor(unittest.TestCase):
lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
lr_sched.step()
def test_embedding(self):
B, T, embed_size, vocab_size = 4, 10, 20, 28
layer = nn.Embedding(vocab_size, embed_size)
x = Tensor(np.random.randint(0, vocab_size, (B, T)))
z = layer(x)
layer_sharded = nn.Embedding(vocab_size, embed_size)
layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize()
x_sharded = x.shard((d0, d1), axis=None)
z_shard = layer_sharded(x_sharded)
np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)
def test_data_parallel_resnet(self):
import sys, pathlib
sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())

View File

@@ -126,7 +126,8 @@ class Embedding:
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size)
if not hasattr(self, 'vocab_counter'):
self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False, device=self.weight.device).reshape(1, 1, self.vocab_size)
batch_size, seqlen = idx.shape
if seqlen == 0: return Tensor.empty(batch_size, 0, self.embed_size)
if seqlen == 0: return Tensor.empty(batch_size, 0, self.embed_size, device=self.weight.device)
return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight