llama3 fixes part2 (#15150)

This commit is contained in:
wozeparrot
2026-03-05 15:43:50 +08:00
committed by GitHub
parent 0c769289eb
commit be23772d43
2 changed files with 39 additions and 28 deletions

@@ -19,32 +19,8 @@ def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None
def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor:
  return _sharded_empty(ref.shape, ref, axis)
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
  assert attn_mask is None, "attn_mask not supported"
  assert is_causal, "only causal attention supported"
  xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
  B, N, H, D = xq.shape
  H_KV = xk.shape[2]
  assert D == 128, "only D=128 supported"
  num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
  is_dp = xq.uop.axis == 0
  is_mp = xq.uop.axis == 2
  B_local = B // num_devices if is_dp else B
  H_local = H // num_devices if is_mp else H
  H_KV_local = H_KV // num_devices if is_mp else H_KV
  shard_axis = 0 if is_dp else 2 if is_mp else None
  shard_axis_t = 0 if is_dp else 1 if is_mp else None
  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")
  single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
  arch = Device[single_device].renderer.arch
  attn = _sharded_empty_like(xq, axis=shard_axis)
  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
@functools.cache
def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch):
  def grad(dou:UOp, ker:UOp) -> tuple[None, None, UOp, UOp, UOp]:
    do = Tensor(dou, device=dou.device)
    attn = Tensor(ker.src[1].after(ker), device=ker.src[1].device)
@@ -73,6 +49,35 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
    dv = dv_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
    return None, None, dq.uop, dk.uop, dv.uop
  return grad
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
  assert attn_mask is None, "attn_mask not supported"
  assert is_causal, "only causal attention supported"
  xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
  B, N, H, D = xq.shape
  H_KV = xk.shape[2]
  assert D == 128, "only D=128 supported"
  num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
  is_dp = xq.uop.axis == 0
  is_mp = xq.uop.axis == 2
  B_local = B // num_devices if is_dp else B
  H_local = H // num_devices if is_mp else H
  H_KV_local = H_KV // num_devices if is_mp else H_KV
  shard_axis = 0 if is_dp else 2 if is_mp else None
  shard_axis_t = 0 if is_dp else 1 if is_mp else None
  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")
  single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
  arch = Device[single_device].renderer.arch
  attn = _sharded_empty_like(xq, axis=shard_axis)
  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
  grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)
  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D), grad_fxn=grad)[:2]
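
Note on the refactor in this hunk: the backward closure is now produced by a module-level factory decorated with functools.cache, keyed only on hashable scalars (shapes, shard axes, device name, arch string), so repeated flash_attention calls with the same configuration get back the same closure instead of rebuilding it. A minimal sketch of that pattern, with made-up names (make_fa_grad is not tinygrad API):

import functools

# Illustrative only: memoize the backward closure on hashable scalars so
# identical configurations reuse the exact same closure object.
@functools.cache
def make_fa_grad(B: int, N: int, H: int, D: int, device: str, arch: str):
  def grad(dout):
    # stand-in for the real flash-attention backward kernels
    return dout
  return grad

g1 = make_fa_grad(1, 4096, 32, 128, "CUDA:0", "sm_90")
g2 = make_fa_grad(1, 4096, 32, 128, "CUDA:0", "sm_90")
assert g1 is g2  # second call is a cache hit, no rebuild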

@@ -1,5 +1,5 @@
from __future__ import annotations
import math
import math, functools
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
@@ -343,6 +343,10 @@ def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
  arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
  return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)
@functools.cache
def _embedding_fwd_fxn(wp, ip, device):
  return _embedding_fwd(Tensor(wp, device=device), Tensor(ip, device=device))
class Embedding:
  """
  A simple lookup table that stores embeddings of a fixed dictionary and size.
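
For context, the _embedding_fwd body shown above is a one-hot gather: an arange over the vocabulary is compared against the indices, and the resulting mask selects and sums rows of the weight matrix. A rough numpy rendering of the same computation, for illustration only:

import numpy as np

# One-hot gather mirroring the expression above: build a boolean mask over the
# vocabulary, broadcast it against the weight rows, and sum out the vocab axis.
def embedding_fwd(weight: np.ndarray, idx: np.ndarray) -> np.ndarray:
  vocab = weight.shape[0]                          # weight: (vocab, embed_dim)
  onehot = (np.arange(vocab) == idx[..., None])    # (..., vocab) boolean mask
  return np.where(onehot[..., None], weight, 0).sum(-2)  # (..., embed_dim)

weight = np.arange(12, dtype=np.float32).reshape(4, 3)   # vocab=4, dim=3
idx = np.array([[1, 3], [0, 0]])
assert np.allclose(embedding_fwd(weight, idx), weight[idx])  # same as a plain gather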
@@ -359,7 +363,9 @@ class Embedding:
  def __call__(self, idx:Tensor) -> Tensor:
    if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
    if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
    if USE_ATOMICS:
      fxn = _embedding_fwd_fxn(self.weight.as_param(0).uop, idx.as_param(1).uop, self.weight.device)
      return Tensor.call(self.weight, idx, fxn=fxn, grad_fxn=_embedding_bwd)
    return _embedding_fwd(self.weight, idx)
class LSTMCell:
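
The __call__ change applies the same memoization idea at the Embedding level: instead of re-tracing _embedding_fwd on every lookup, the traced expression is cached in _embedding_fwd_fxn, keyed on the two param UOps plus the device, which are hashable and stable across calls. A tiny sketch of that shape of API, with hypothetical string placeholders standing in for the param UOps:

import functools

# Hypothetical sketch: string placeholders stand in for the hashable param UOps;
# the point is only that the cache key is (weight slot, idx slot, device).
@functools.cache
def build_embedding_fwd(weight_slot: str, idx_slot: str, device: str):
  # stands in for tracing the _embedding_fwd expression once per configuration
  return ("embedding_fwd", weight_slot, idx_slot, device)

e1 = build_embedding_fwd("param0", "param1", "CPU")
e2 = build_embedding_fwd("param0", "param1", "CPU")
assert e1 is e2  # second lookup reuses the cached expression, no re-trace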