mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llama3 fixes part2 (#15150)
@@ -19,32 +19,8 @@ def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None
def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor:
  return _sharded_empty(ref.shape, ref, axis)

def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
  assert attn_mask is None, "attn_mask not supported"
  assert is_causal, "only causal attention supported"

  xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

  B, N, H, D = xq.shape
  H_KV = xk.shape[2]
  assert D == 128, "only D=128 supported"

  num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
  is_dp = xq.uop.axis == 0
  is_mp = xq.uop.axis == 2
  B_local = B // num_devices if is_dp else B
  H_local = H // num_devices if is_mp else H
  H_KV_local = H_KV // num_devices if is_mp else H_KV
  shard_axis = 0 if is_dp else 2 if is_mp else None
  shard_axis_t = 0 if is_dp else 1 if is_mp else None
  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")

  single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
  arch = Device[single_device].renderer.arch

  attn = _sharded_empty_like(xq, axis=shard_axis)
  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)

@functools.cache
def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch):
  def grad(dou:UOp, ker:UOp) -> tuple[None, None, UOp, UOp, UOp]:
    do = Tensor(dou, device=dou.device)
    attn = Tensor(ker.src[1].after(ker), device=ker.src[1].device)
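For context on the sharding arithmetic above (a note for this hunk, not part of the diff): with data parallelism the batch axis is split across the devices, with model parallelism the head axes are, and the per-device `B_local`/`H_local`/`H_KV_local` follow by integer division. A minimal standalone sketch with made-up shapes:

# Illustrative only: B=8, H=32 query heads, H_KV=8 KV heads, 4 devices.
def local_shapes(B, H, H_KV, num_devices, shard_axis):
  is_dp = shard_axis == 0   # batch axis sharded -> data parallel
  is_mp = shard_axis == 2   # head axis sharded  -> model parallel
  B_local = B // num_devices if is_dp else B
  H_local = H // num_devices if is_mp else H
  H_KV_local = H_KV // num_devices if is_mp else H_KV
  return B_local, H_local, H_KV_local

print(local_shapes(8, 32, 8, 4, shard_axis=0))  # data parallel  -> (2, 32, 8)
print(local_shapes(8, 32, 8, 4, shard_axis=2))  # model parallel -> (8, 8, 2)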
@@ -73,6 +49,35 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
    dv = dv_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)

    return None, None, dq.uop, dk.uop, dv.uop
  return grad

def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
  assert attn_mask is None, "attn_mask not supported"
  assert is_causal, "only causal attention supported"

  xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

  B, N, H, D = xq.shape
  H_KV = xk.shape[2]
  assert D == 128, "only D=128 supported"

  num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
  is_dp = xq.uop.axis == 0
  is_mp = xq.uop.axis == 2
  B_local = B // num_devices if is_dp else B
  H_local = H // num_devices if is_mp else H
  H_KV_local = H_KV // num_devices if is_mp else H_KV
  shard_axis = 0 if is_dp else 2 if is_mp else None
  shard_axis_t = 0 if is_dp else 1 if is_mp else None
  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")

  single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
  arch = Device[single_device].renderer.arch

  attn = _sharded_empty_like(xq, axis=shard_axis)
  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)

  grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)

  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D), grad_fxn=grad)[:2]
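The refactor above hinges on `_fa_grad_fxn` being a module-level function decorated with `@functools.cache`: every argument (ints, strings, `None`) is hashable, so the backward closure is built once per unique shape/sharding/arch combination and reused on later calls, instead of being reconstructed inside `flash_attention` every time. A self-contained sketch of that pattern (the names and placeholder body are illustrative, not from tinygrad):

import functools

builds = 0

@functools.cache
def make_grad_fxn(B, H, N, D, arch):
  global builds
  builds += 1            # runs only on a cache miss
  def grad(upstream):
    return upstream * 2  # placeholder; stands in for the real backward kernel
  return grad

g1 = make_grad_fxn(8, 32, 2048, 128, "sm_90")
g2 = make_grad_fxn(8, 32, 2048, 128, "sm_90")
assert g1 is g2 and builds == 1   # same closure object, constructed exactly once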
@@ -1,5 +1,5 @@
from __future__ import annotations
import math
import math, functools
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
@@ -343,6 +343,10 @@ def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
  arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
  return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)

@functools.cache
def _embedding_fwd_fxn(wp, ip, device):
  return _embedding_fwd(Tensor(wp, device=device), Tensor(ip, device=device))

class Embedding:
  """
  A simple lookup table that stores embeddings of a fixed dictionary and size.
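For reference (not part of the diff), `_embedding_fwd` performs the table lookup without a gather: `arange == idx.unsqueeze(-1)` builds a one-hot mask over the vocabulary, `where(weight, 0)` keeps only the matching rows, and the sum over the vocab axis collapses them to one embedding per index. The same computation, checked with NumPy on made-up shapes:

import numpy as np

vocab_size, embed_dim = 5, 3
weight = np.arange(vocab_size * embed_dim, dtype=np.float32).reshape(vocab_size, embed_dim)
idx = np.array([1, 4, 1])                           # three token ids

arange = np.arange(vocab_size)                      # (5,)
one_hot = arange == idx[..., None]                  # (3, 5) mask, True at each token's row
selected = np.where(one_hot[..., None], weight, 0)  # (3, 5, 3), non-matching rows zeroed
out = selected.sum(axis=-2)                         # (3, 3), one embedding per token

assert np.array_equal(out, weight[idx])             # identical to a direct row gather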
@@ -359,7 +363,9 @@ class Embedding:

  def __call__(self, idx:Tensor) -> Tensor:
    if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
    if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
    if USE_ATOMICS:
      fxn = _embedding_fwd_fxn(self.weight.as_param(0).uop, idx.as_param(1).uop, self.weight.device)
      return Tensor.call(self.weight, idx, fxn=fxn, grad_fxn=_embedding_bwd)
    return _embedding_fwd(self.weight, idx)

class LSTMCell:
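As a usage sketch (assuming the usual tinygrad constructor `Embedding(vocab_size, embed_size)`; the sizes below are made up), the `__call__` above is the path an ordinary lookup takes, and the cached `_embedding_fwd_fxn` is only consulted when `USE_ATOMICS` is set:

from tinygrad import Tensor
from tinygrad.nn import Embedding

emb = Embedding(128, 16)          # 128-entry vocabulary, 16-dim embeddings
tokens = Tensor([[3, 7, 42]])     # integer token ids, shape (1, 3)
out = emb(tokens)                 # shape (1, 3, 16)
print(out.shape)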