embedded bwd vocab shard (#15001)

* fix: remove more multi from call

* feat: embedding bwd vocab sharding

* clean: unused import

* clean: don't actually need this pattern
wozeparrot authored 2026-03-17 10:37:16 +08:00, committed by GitHub
parent 62bfd48d95
commit 674c760974
3 changed files with 81 additions and 10 deletions


@@ -8,6 +8,7 @@ from tinygrad.engine.schedule import ExecItem
 from tinygrad.uop.ops import Ops
 from tinygrad.renderer import Estimates
 from tinygrad.renderer.ptx import PTXRenderer
+from test.helpers import needs_second_gpu

 class TestArange(unittest.TestCase):
   def _get_flops(self, tensor, desired):
@@ -188,6 +189,33 @@ class TestIndexing(unittest.TestCase):
     for i in idx.flatten().numpy(): expected_grad[i] += 2
     np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5)

+  @needs_second_gpu
+  @unittest.skipIf(Device.DEFAULT not in ("CPU", "AMD"), "atomics only on AMD/CPU")
+  @Context(USE_ATOMICS=1, SPEC=1)
+  def test_embedding_backward_vocab_sharded(self):
+    from tinygrad.renderer.cstyle import CStyleLanguage
+    if Device.DEFAULT == "CPU" and not isinstance(Device["CPU"].renderer, CStyleLanguage): self.skipTest("CPU needs Clang renderer")
+    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
+    vocab_size, embed_size = 1000, 128
+    bs, seqlen = 4, 256
+    idx = Tensor.randint(bs, seqlen, high=vocab_size)
+    emb = nn.Embedding(vocab_size, embed_size)
+    emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
+    gt = Tensor.zeros(bs, seqlen, embed_size)
+    Tensor.realize(idx, emb.weight, gt)
+    # compute expected grad on single device
+    expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32)
+    for i in idx.flatten().numpy(): expected_grad[i] += 2
+    # now shard the embedding weight on vocab axis and recompute
+    emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
+    emb.weight.shard_(devices, axis=0)
+    idx = idx.shard(devices, axis=None)
+    gt = gt.shard(devices, axis=None)
+    Tensor.realize(idx, emb.weight, gt)
+    loss = (emb(idx)-gt).square().sum()
+    loss.backward()
+    np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5)
+
   @unittest.skipUnless(Device.DEFAULT == "AMD" or (Device.DEFAULT == "NULL" and EMULATE.value.startswith("AMD")), "tests AMD bf16 cast overhead")
   def base_test_llama_8b_rope_backward(self, dtype):
     from extra.models.llama import precompute_freqs_cis, apply_rotary_emb
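
Note (not part of the diff): the constant 2 in expected_grad follows from the loss the test builds. With emb.weight all ones and gt all zeros, every gathered output element equals 1, so the upstream gradient of ((out-gt)**2).sum() at each element is 2*(1-0) = 2, and each occurrence of a token adds 2 to every column of that token's weight row. A minimal numpy sketch of the same reference gradient, with shapes copied from the test:

import numpy as np
vocab_size, embed_size, bs, seqlen = 1000, 128, 4, 256
idx = np.random.randint(0, vocab_size, size=(bs, seqlen))
# d/d(out) of ((out - 0)**2).sum() at out == 1 is 2; scatter-add it per token
expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32)
np.add.at(expected_grad, idx.flatten(), 2.0)  # duplicate-safe scatter-add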


@@ -245,5 +245,36 @@ class TestCallSchedule(unittest.TestCase):
     out = f(a) + 2
     np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3)

+  def test_call_reduce_sharded(self):
+    devs = ("CPU:0", "CPU:1")
+    a = Tensor.ones(10, 10).shard(devs, axis=0)
+    Tensor.realize(a)
+    c = Tensor.call(a, fxn=a.as_param(0).sum(axis=0))
+    np.testing.assert_equal(c.numpy(), 10 * np.ones(10))
+
+  def test_call_reduce_sharded_mixed_args(self):
+    devs = ("CPU:0", "CPU:1")
+    a = Tensor.ones(10, 10).shard(devs, axis=0)
+    b = Tensor.ones(10).shard(devs, axis=None)
+    Tensor.realize(a, b)
+    c = Tensor.call(a, b, fxn=a.as_param(0).sum(axis=0) + b.as_param(1))
+    np.testing.assert_equal(c.numpy(), 11 * np.ones(10))
+
+  def test_call_reduce_sharded_backward(self):
+    devs = ("CPU:0", "CPU:1")
+    a = Tensor.randn(10, 10, requires_grad=True).shard(devs, axis=0)
+    b = Tensor.randn(10, 10, requires_grad=True).shard(devs, axis=0)
+    Tensor.realize(a, b)
+    def grad_fxn(grad, call):
+      a_arg, b_arg = call.src[1], call.src[2]
+      return (grad.expand(a_arg.shape) * b_arg, grad.expand(b_arg.shape) * a_arg)
+    body = (a.as_param(0) * b.as_param(1)).sum(axis=0)
+    c = Tensor.call(a, b, fxn=body, grad_fxn=grad_fxn)
+    c.sum().backward()
+    np.testing.assert_allclose(a.grad.numpy(), b.numpy(), rtol=1e-5)
+    np.testing.assert_allclose(b.grad.numpy(), a.numpy(), rtol=1e-5)
+
 if __name__ == '__main__':
   unittest.main()
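
Aside (not part of the diff): grad_fxn in test_call_reduce_sharded_backward is just the product rule for c = (a*b).sum(axis=0), namely dc/da = grad*b and dc/db = grad*a with grad expanded back over the reduced axis, which is why the test expects a.grad == b and b.grad == a after c.sum().backward() (upstream grad of all ones). A quick numpy check of that identity:

import numpy as np
a = np.random.randn(10, 10).astype(np.float32)
b = np.random.randn(10, 10).astype(np.float32)
g = np.ones(10, dtype=np.float32)  # upstream grad from c.sum()
grad_a = g[None, :] * b            # expand over the reduced axis, product rule
grad_b = g[None, :] * a
assert np.allclose(grad_a, b) and np.allclose(grad_b, a)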


@@ -306,13 +306,19 @@ class RMSNorm:
 from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType

 def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
   weight, idx = call.src[1:]
-  # for multi-device: unshard inputs to one device
+  is_vocab_sharded = isinstance(weight.device, tuple) and weight.axis == 0
+  # for multi-device: replicate grad_emb and idx on all devices
   if isinstance(weight.device, tuple):
-    assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS"
+    assert weight.axis is None or weight.axis == 0, "only vocab (axis=0) sharding supported on Embedding with USE_ATOMICS"
     grad_emb = grad_emb.copy_to_device(weight.device)
     idx = idx.copy_to_device(weight.device)
-  # weight is replicated, grad_weight should match
-  grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop
+  if is_vocab_sharded:
+    ndev = len(weight.device)
+    local_vocab_size = weight.shape[0] // ndev
+    grad_weight_uop = Tensor.empty(local_vocab_size, weight.shape[1], dtype=dtypes.float, device=weight.device).uop.multi(axis=0)
+  else:
+    # weight is replicated (or single device), grad_weight should match
+    grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop
   # TODO: how do we remove this dumb kernel and use Tensor.zeros?
   def _zero_kernel(out:UOp) -> UOp:
@@ -327,9 +333,6 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
   def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
     idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
-    i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
-    j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
     embed_size = grad_weight.shape[-1]
     BLOCK_J = min(256, embed_size)
     assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"
@@ -340,13 +343,22 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
     j_outer = UOp.range(n_j_blocks, 1)
     j = j_outer * BLOCK_J + j_inner
-    token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
+    if is_vocab_sharded:
+      # each device owns [offset, offset+local_vocab_size) of the global vocabulary
+      dnum = UOp.variable("_device_num", 0, ndev-1)
+      offset = dnum * local_vocab_size
+      global_token_id = idx_flat[i].cast(dtypes.index)
+      local_token_id = (global_token_id - offset).clip(0, grad_weight.shape[0]-1)
+      in_range = (global_token_id >= offset) & (global_token_id < (offset + local_vocab_size))
+      grad_val = in_range.where(grad_emb_flat[i, j].cast(dtypes.float), 0.0)
+    else:
+      local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
+      grad_val = grad_emb_flat[i, j].cast(dtypes.float)
     # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
     if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
     elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
     else: raise NotImplementedError(f"no atomics for device {device}")
-    atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg)
+    atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(local_token_id, j, ptr=True), grad_val), arg = atomic_arg)
     return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))

   grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
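
For intuition, here is a single-process numpy model of the sharded scatter-add above; the shapes and the loop over devices are illustrative stand-ins for what the ndev kernel instances do in parallel, not code from this PR. Each device dnum owns the contiguous vocab slice [dnum*local_vocab_size, (dnum+1)*local_vocab_size): it remaps global token ids to local rows and masks out-of-range contributions to zero (the clip keeps the index legal, the mask makes the add a no-op), and the per-device shards concatenated on axis 0 reproduce the full gradient, matching the .multi(axis=0) placement of grad_weight_uop:

import numpy as np
vocab_size, embed_size, ndev = 1000, 128, 2
local_vocab_size = vocab_size // ndev
idx_flat = np.random.randint(0, vocab_size, size=4*256)
grad_emb_flat = np.full((idx_flat.size, embed_size), 2.0, dtype=np.float32)
shards = []
for dnum in range(ndev):  # one iteration per device
  offset = dnum * local_vocab_size
  in_range = (idx_flat >= offset) & (idx_flat < offset + local_vocab_size)
  local_ids = np.clip(idx_flat - offset, 0, local_vocab_size - 1)
  grad_weight = np.zeros((local_vocab_size, embed_size), dtype=np.float32)
  # masked scatter-add: tokens outside this shard add 0 to the clipped row
  np.add.at(grad_weight, local_ids, np.where(in_range[:, None], grad_emb_flat, 0.0))
  shards.append(grad_weight)
full_grad = np.concatenate(shards, axis=0)  # equals the unsharded gradient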