From be23772d431abb24299a8bd332bc154ba496d736 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Thu, 5 Mar 2026 15:43:50 +0800
Subject: [PATCH] llama3 fixes part2 (#15150)

---
 extra/thunder/amd/fa.py | 57 ++++++++++++++++++++++-------------------
 tinygrad/nn/__init__.py | 10 ++++++--
 2 files changed, 39 insertions(+), 28 deletions(-)

diff --git a/extra/thunder/amd/fa.py b/extra/thunder/amd/fa.py
index c201f9ee04..8a37d04b32 100644
--- a/extra/thunder/amd/fa.py
+++ b/extra/thunder/amd/fa.py
@@ -19,32 +19,8 @@ def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None
 def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor:
   return _sharded_empty(ref.shape, ref, axis)
 
-def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
-  assert attn_mask is None, "attn_mask not supported"
-  assert is_causal, "only causal attention supported"
-
-  xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
-
-  B, N, H, D = xq.shape
-  H_KV = xk.shape[2]
-  assert D == 128, "only D=128 supported"
-
-  num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
-  is_dp = xq.uop.axis == 0
-  is_mp = xq.uop.axis == 2
-  B_local = B // num_devices if is_dp else B
-  H_local = H // num_devices if is_mp else H
-  H_KV_local = H_KV // num_devices if is_mp else H_KV
-  shard_axis = 0 if is_dp else 2 if is_mp else None
-  shard_axis_t = 0 if is_dp else 1 if is_mp else None
-  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")
-
-  single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
-  arch = Device[single_device].renderer.arch
-
-  attn = _sharded_empty_like(xq, axis=shard_axis)
-  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
-
+@functools.cache
+def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch):
   def grad(dou:UOp, ker:UOp) -> tuple[None, None, UOp, UOp, UOp]:
     do = Tensor(dou, device=dou.device)
     attn = Tensor(ker.src[1].after(ker), device=ker.src[1].device)
@@ -73,6 +49,35 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
     dv = dv_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
 
     return None, None, dq.uop, dk.uop, dv.uop
+  return grad
+
+def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
+  assert attn_mask is None, "attn_mask not supported"
+  assert is_causal, "only causal attention supported"
+
+  xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
+
+  B, N, H, D = xq.shape
+  H_KV = xk.shape[2]
+  assert D == 128, "only D=128 supported"
+
+  num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
+  is_dp = xq.uop.axis == 0
+  is_mp = xq.uop.axis == 2
+  B_local = B // num_devices if is_dp else B
+  H_local = H // num_devices if is_mp else H
+  H_KV_local = H_KV // num_devices if is_mp else H_KV
+  shard_axis = 0 if is_dp else 2 if is_mp else None
+  shard_axis_t = 0 if is_dp else 1 if is_mp else None
+  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")
+
+  single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
+  arch = Device[single_device].renderer.arch
+
+  attn = _sharded_empty_like(xq, axis=shard_axis)
+  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
+
+  grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)
 
   attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv,
     fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D),
     grad_fxn=grad)[:2]
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index edb72684e4..9cfd1a20a0 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import math
+import math, functools
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
@@ -343,6 +343,10 @@ def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
   arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
   return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)
 
+@functools.cache
+def _embedding_fwd_fxn(wp, ip, device):
+  return _embedding_fwd(Tensor(wp, device=device), Tensor(ip, device=device))
+
 class Embedding:
   """
   A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -359,7 +363,9 @@ class Embedding:
 
   def __call__(self, idx:Tensor) -> Tensor:
     if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
-    if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
+    if USE_ATOMICS:
+      fxn = _embedding_fwd_fxn(self.weight.as_param(0).uop, idx.as_param(1).uop, self.weight.device)
+      return Tensor.call(self.weight, idx, fxn=fxn, grad_fxn=_embedding_bwd)
     return _embedding_fwd(self.weight, idx)
 
 class LSTMCell: