feat: make USE_ATOMICS embedding bwd faster (#15151)

Author: wozeparrot
Date: 2026-03-15 12:21:10 +08:00
Committed by: GitHub
Parent: 3858bfc83d
Commit: 473e5e4368


@@ -303,7 +303,7 @@ class RMSNorm:
     x = self._norm(x.float()).cast(x.dtype)
     return x if self.weight is None else x * self.weight
-from tinygrad.uop.ops import UOp, KernelInfo, Ops
+from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType
 def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
   weight, idx = call.src[1:]
   # for multi-device: unshard inputs to one device
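
Context for the change: the embedding forward pass is a row gather, out[i, :] = weight[idx[i], :], so the backward pass for the weight is a scatter-add of grad_emb rows into the rows selected by idx. A minimal pure-Python reference of that semantic (an illustrative sketch, not tinygrad code; the function name is hypothetical):

def embedding_bwd_reference(grad_emb, idx, vocab_size, embed_size):
  # grad_weight[idx[i], j] += grad_emb[i, j] for every flattened position i
  grad_weight = [[0.0] * embed_size for _ in range(vocab_size)]
  for i, token_id in enumerate(idx):  # i ranges over batch_size * sequence_length
    for j in range(embed_size):
      grad_weight[token_id][j] += grad_emb[i][j]
  return grad_weight

# two positions with the same token id accumulate into the same row, which is
# why the real kernel below needs atomics when rows are updated in parallel
gw = embedding_bwd_reference([[1.0, 2.0], [3.0, 4.0]], idx=[1, 1], vocab_size=3, embed_size=2)
assert gw[1] == [4.0, 6.0]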
@@ -326,15 +326,29 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
   # this is the real atomic kernel
   def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
     idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
-    i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
-    j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
+    embed_size = grad_weight.shape[-1]
+    BLOCK_J = min(256, embed_size)
+    assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"
+    n_j_blocks = embed_size // BLOCK_J
+    i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length -> GLOBAL
+    j_inner = UOp.range(BLOCK_J, 2, AxisType.LOOP if device in ("CPU", "NULL") else AxisType.LOCAL) # BLOCK_J threads per workgroup
+    j_outer = UOp.range(n_j_blocks, 1)
+    j = j_outer * BLOCK_J + j_inner
     token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
     # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
     if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
     elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
     else: raise NotImplementedError(f"no atomics for device {device}")
     atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg)
-    return atomic.end(i, j).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
+    return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
   grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
   return (grad_weight_uop.cast(weight.dtype), None)
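
Why the change helps: the old kernel gave every (i, j) pair its own plain range; the new one splits the embed axis into n_j_blocks outer steps of BLOCK_J lanes, so on GPU backends j_inner maps to AxisType.LOCAL (threads within a workgroup) while CPU/NULL keep a serial AxisType.LOOP. The scatter-add stays correct only if j = j_outer * BLOCK_J + j_inner visits every column exactly once. A sketch checking that index math in plain Python (an illustrative harness, not tinygrad code; blocked_columns is hypothetical):

def blocked_columns(embed_size, block_j=256):
  block_j = min(block_j, embed_size)  # mirrors BLOCK_J = min(256, embed_size)
  assert embed_size % block_j == 0, f"embed_size {embed_size} must be divisible by {block_j}"
  n_j_blocks = embed_size // block_j
  for j_outer in range(n_j_blocks):
    for j_inner in range(block_j):    # on GPU: one LOCAL thread per j_inner
      yield j_outer * block_j + j_inner

assert list(blocked_columns(1024)) == list(range(1024))  # every column hit exactly once
assert list(blocked_columns(100)) == list(range(100))    # BLOCK_J clamps to small embed sizes

Atomics are still required because two flattened positions can carry the same token_id and race on the same grad_weight row; the kernel uses relaxed memory order (__ATOMIC_RELAXED) since only the final sum matters, not the order of the adds. Judging by the commit title, this path is gated behind a USE_ATOMICS toggle (presumably an environment flag; its definition is not shown in these hunks).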