mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
feat: make USE_ATOMICS embedding bwd faster (#15151)
This commit is contained in:
@@ -303,7 +303,7 @@ class RMSNorm:
|
||||
x = self._norm(x.float()).cast(x.dtype)
|
||||
return x if self.weight is None else x * self.weight
|
||||
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, Ops
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType
|
||||
def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
|
||||
weight, idx = call.src[1:]
|
||||
# for multi-device: unshard inputs to one device
|
||||
@@ -326,15 +326,29 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
|
||||
# this is the real atomic kernel
|
||||
def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
|
||||
idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
|
||||
|
||||
i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
|
||||
j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
|
||||
|
||||
embed_size = grad_weight.shape[-1]
|
||||
BLOCK_J = min(256, embed_size)
|
||||
assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"
|
||||
|
||||
n_j_blocks = embed_size // BLOCK_J
|
||||
i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length -> GLOBAL
|
||||
j_inner = UOp.range(BLOCK_J, 2, AxisType.LOOP if device in ("CPU", "NULL") else AxisType.LOCAL) # BLOCK_J threads per workgroup
|
||||
j_outer = UOp.range(n_j_blocks, 1)
|
||||
j = j_outer * BLOCK_J + j_inner
|
||||
|
||||
token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
|
||||
|
||||
# atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
|
||||
if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
|
||||
elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
|
||||
else: raise NotImplementedError(f"no atomics for device {device}")
|
||||
atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg)
|
||||
return atomic.end(i, j).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
|
||||
return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
|
||||
|
||||
grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
|
||||
|
||||
return (grad_weight_uop.cast(weight.dtype), None)
|
||||
|
||||
Reference in New Issue
Block a user