diff --git a/test/backend/test_arange.py b/test/backend/test_arange.py index a6b62bfe24..4ab9b6ef20 100644 --- a/test/backend/test_arange.py +++ b/test/backend/test_arange.py @@ -8,6 +8,7 @@ from tinygrad.engine.schedule import ExecItem from tinygrad.uop.ops import Ops from tinygrad.renderer import Estimates from tinygrad.renderer.ptx import PTXRenderer +from test.helpers import needs_second_gpu class TestArange(unittest.TestCase): def _get_flops(self, tensor, desired): @@ -188,6 +189,33 @@ class TestIndexing(unittest.TestCase): for i in idx.flatten().numpy(): expected_grad[i] += 2 np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5) + @needs_second_gpu + @unittest.skipIf(Device.DEFAULT not in ("CPU", "AMD"), "atomics only on AMD/CPU") + @Context(USE_ATOMICS=1, SPEC=1) + def test_embedding_backward_vocab_sharded(self): + from tinygrad.renderer.cstyle import CStyleLanguage + if Device.DEFAULT == "CPU" and not isinstance(Device["CPU"].renderer, CStyleLanguage): self.skipTest("CPU needs Clang renderer") + devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1") + vocab_size, embed_size = 1000, 128 + bs, seqlen = 4, 256 + idx = Tensor.randint(bs, seqlen, high=vocab_size) + emb = nn.Embedding(vocab_size, embed_size) + emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True) + gt = Tensor.zeros(bs, seqlen, embed_size) + Tensor.realize(idx, emb.weight, gt) + # compute expected grad on single device + expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32) + for i in idx.flatten().numpy(): expected_grad[i] += 2 + # now shard the embedding weight on vocab axis and recompute + emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True) + emb.weight.shard_(devices, axis=0) + idx = idx.shard(devices, axis=None) + gt = gt.shard(devices, axis=None) + Tensor.realize(idx, emb.weight, gt) + loss = (emb(idx)-gt).square().sum() + loss.backward() + np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5) + @unittest.skipUnless(Device.DEFAULT == "AMD" or (Device.DEFAULT == "NULL" and EMULATE.value.startswith("AMD")), "tests AMD bf16 cast overhead") def base_test_llama_8b_rope_backward(self, dtype): from extra.models.llama import precompute_freqs_cis, apply_rotary_emb diff --git a/test/unit/test_call.py b/test/unit/test_call.py index 95c498362b..d7c1cc2873 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -245,5 +245,36 @@ class TestCallSchedule(unittest.TestCase): out = f(a) + 2 np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3) + def test_call_reduce_sharded(self): + devs = ("CPU:0", "CPU:1") + a = Tensor.ones(10, 10).shard(devs, axis=0) + Tensor.realize(a) + c = Tensor.call(a, fxn=a.as_param(0).sum(axis=0)) + np.testing.assert_equal(c.numpy(), 10 * np.ones(10)) + + def test_call_reduce_sharded_mixed_args(self): + devs = ("CPU:0", "CPU:1") + a = Tensor.ones(10, 10).shard(devs, axis=0) + b = Tensor.ones(10).shard(devs, axis=None) + Tensor.realize(a, b) + c = Tensor.call(a, b, fxn=a.as_param(0).sum(axis=0) + b.as_param(1)) + np.testing.assert_equal(c.numpy(), 11 * np.ones(10)) + + def test_call_reduce_sharded_backward(self): + devs = ("CPU:0", "CPU:1") + a = Tensor.randn(10, 10, requires_grad=True).shard(devs, axis=0) + b = Tensor.randn(10, 10, requires_grad=True).shard(devs, axis=0) + Tensor.realize(a, b) + + def grad_fxn(grad, call): + a_arg, b_arg = call.src[1], call.src[2] + return (grad.expand(a_arg.shape) * b_arg, grad.expand(b_arg.shape) * a_arg) + + body = (a.as_param(0) * b.as_param(1)).sum(axis=0) + c = Tensor.call(a, b, fxn=body, grad_fxn=grad_fxn) + c.sum().backward() + np.testing.assert_allclose(a.grad.numpy(), b.numpy(), rtol=1e-5) + np.testing.assert_allclose(b.grad.numpy(), a.numpy(), rtol=1e-5) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 8b4bbb2c64..f7ec524346 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -306,13 +306,19 @@ class RMSNorm: from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple: weight, idx = call.src[1:] - # for multi-device: unshard inputs to one device + is_vocab_sharded = isinstance(weight.device, tuple) and weight.axis == 0 + # for multi-device: replicate grad_emb and idx on all devices if isinstance(weight.device, tuple): - assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS" + assert weight.axis is None or weight.axis == 0, "only vocab (axis=0) sharding supported on Embedding with USE_ATOMICS" grad_emb = grad_emb.copy_to_device(weight.device) idx = idx.copy_to_device(weight.device) - # weight is replicated, grad_weight should match - grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop + if is_vocab_sharded: + ndev = len(weight.device) + local_vocab_size = weight.shape[0] // ndev + grad_weight_uop = Tensor.empty(local_vocab_size, weight.shape[1], dtype=dtypes.float, device=weight.device).uop.multi(axis=0) + else: + # weight is replicated (or single device), grad_weight should match + grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop # TODO: how do we remove this dumb kernel and use Tensor.zeros? def _zero_kernel(out:UOp) -> UOp: @@ -327,9 +333,6 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple: def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp: idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1])) - i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length - j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size - embed_size = grad_weight.shape[-1] BLOCK_J = min(256, embed_size) assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}" @@ -340,13 +343,22 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple: j_outer = UOp.range(n_j_blocks, 1) j = j_outer * BLOCK_J + j_inner - token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index) - + if is_vocab_sharded: + # each device owns [offset, offset+local_vocab_size) of the global vocabulary + dnum = UOp.variable("_device_num", 0, ndev-1) + offset = dnum * local_vocab_size + global_token_id = idx_flat[i].cast(dtypes.index) + local_token_id = (global_token_id - offset).clip(0, grad_weight.shape[0]-1) + in_range = (global_token_id >= offset) & (global_token_id < (offset + local_vocab_size)) + grad_val = in_range.where(grad_emb_flat[i, j].cast(dtypes.float), 0.0) + else: + local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index) + grad_val = grad_emb_flat[i, j].cast(dtypes.float) # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j] if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);" elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);" else: raise NotImplementedError(f"no atomics for device {device}") - atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg) + atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(local_token_id, j, ptr=True), grad_val), arg = atomic_arg) return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=())) grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]