def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
  """Build the atomic scatter-add kernel for the embedding backward pass.

  Accumulates grad_weight[idx[i], j] += grad_emb[i, j] with device atomics, so
  duplicate token ids in `idx` sum correctly without a serial reduction.

  NOTE(review): `device` is a free variable captured from the enclosing
  `_embedding_bwd` scope (not visible in this chunk) — presumably the single
  target device after unsharding; confirm against the enclosing function.

  Args:
    grad_weight: output buffer of shape (vocab_size, embed_size), pre-zeroed by caller (assumed — TODO confirm).
    grad_emb:    incoming gradient, reshaped below to (idx.size, embed_size).
    idx:         token indices; flattened below to (idx.size,).
  Returns:
    the sunk kernel AST (UOp SINK) named "embedding_bwd".
  Raises:
    NotImplementedError: for devices with no atomic template below.
  """
  # flatten batch/seq dims: idx -> (N,), grad_emb -> (N, embed_size)
  idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))

  embed_size = grad_weight.shape[-1]
  # tile the embedding axis: up to 256 lanes per workgroup
  BLOCK_J = min(256, embed_size)
  assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"

  n_j_blocks = embed_size // BLOCK_J
  i = UOp.range(grad_emb_flat.shape[0], 0)  # batch_size * sequence_length -> GLOBAL
  # BLOCK_J parallel lanes per workgroup on GPU; a plain LOOP axis on CPU/NULL where LOCAL axes don't apply
  j_inner = UOp.range(BLOCK_J, 2, AxisType.LOOP if device in ("CPU", "NULL") else AxisType.LOCAL)
  j_outer = UOp.range(n_j_blocks, 1)
  j = j_outer * BLOCK_J + j_inner  # full index into the embedding axis

  # clip keeps out-of-range token ids in bounds instead of faulting; cast to index dtype for pointer math
  token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
  # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
  # {0} = destination pointer, {1} = value to add (relaxed ordering: only atomicity is needed)
  if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
  elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
  else: raise NotImplementedError(f"no atomics for device {device}")
  atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg=atomic_arg)
  # close all three range axes, then sink; opts_to_apply=() presumably pins the schedule (no further opt search) — TODO confirm
  return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))