From a68666365783641ba7e1a6af243bd487d5ac1a61 Mon Sep 17 00:00:00 2001 From: Yixiang Gao Date: Mon, 8 Jan 2024 20:09:26 -0800 Subject: [PATCH] make Embedding device aware for multigpu (#3051) * make Embedding device aware for multigpu * split line instead of ignore because that's cheating * add test incomplete * add test complete * remove comment * fix white space * remove nn.Embedding --- test/test_multitensor.py | 14 ++++++++++++++ tinygrad/nn/__init__.py | 5 +++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 1d25761dfd..a18fece9ea 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -138,6 +138,20 @@ class TestMultiTensor(unittest.TestCase): lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10) lr_sched.step() + def test_embedding(self): + B, T, embed_size, vocab_size = 4, 10, 20, 28 + + layer = nn.Embedding(vocab_size, embed_size) + x = Tensor(np.random.randint(0, vocab_size, (B, T))) + z = layer(x) + + layer_sharded = nn.Embedding(vocab_size, embed_size) + layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize() + x_sharded = x.shard((d0, d1), axis=None) + z_shard = layer_sharded(x_sharded) + + np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6) + def test_data_parallel_resnet(self): import sys, pathlib sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix()) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index c016afe407..67957749cc 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -126,7 +126,8 @@ class Embedding: self.weight = Tensor.glorot_uniform(vocab_size, embed_size) def __call__(self, idx:Tensor) -> Tensor: - if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size) + if not hasattr(self, 'vocab_counter'): + 
self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False, device=self.weight.device).reshape(1, 1, self.vocab_size) batch_size, seqlen = idx.shape - if seqlen == 0: return Tensor.empty(batch_size, 0, self.embed_size) + if seqlen == 0: return Tensor.empty(batch_size, 0, self.embed_size, device=self.weight.device) return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight