mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
ai slop flash attention (it works) (#15401)
* ai slop flash attention (it works)
* speed up, 2 TFLOPS + 7 GB/s
* simpler
* simpler
* optimize
* faster
* warp shuffle
* sqtt: link dispatch to exec (#15396)
* sqtt packet linking infra python
* javascript
* ~doubly linked list
* ui works
* work
* exec can also highlight the pc, coloring work
* more work
* rm sqtt/model.py, doesn't need to be upstreamed
* viz: no context enters in cli, update llama profile (#15404)
* removed unused named arg in rules [pr] (#15414)
* viz: sqtt printer in viz/cli.py (#15411)
* work
* sqtt timeline in CLI
* format all printers nicely
* s/Showed/Printed
* ansistrip
* sys.exit
* keep colors in list
* work from amd_copy_matmul
* has_more always gets returned
* linter
* don't print colors
* more colors
* wow this is so deep
* work
* minor details
* selected
* improve progress bar
* remove it
* 22, global_load_vaddr is so long
* remove *0 hack in sign, gradient materializes zeros for unconnected nodes (#15416)

Amp-Thread-ID: https://ampcode.com/threads/T-019d1612-6322-706b-a94d-a812400a55cb
Co-authored-by: Amp <amp@ampcode.com>

* works
* cnt=20
* revert that
* uop slice tests
* simpler

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: gg <ggordbegli@gmail.com>
Co-authored-by: Amp <amp@ampcode.com>
extra/gemm/amd_flash_attention.py  208 lines  (new file)
@@ -0,0 +1,208 @@
from tinygrad import Tensor, UOp, getenv
from tinygrad.uop.ops import AxisType, KernelInfo, Ops
from tinygrad.dtype import AddrSpace, dtypes
from tinygrad.helpers import DEBUG, GlobalCounters, Context
import math

B = getenv("B", 1)
H = getenv("H", 32)
N = getenv("N", 1024)
D = getenv("D", 64)
assert D % 16 == 0 and N % 16 == 0

BLOCK_M, BLOCK_N = 64, 64
WARP_SIZE = 32
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
WAVES_M, WAVES_N = 4, 1
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N

TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)
TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)
TD = D // (WAVES_N * LANES_PER_WAVE_N)
LDS_PAD = 4  # pad LDS rows to reduce bank conflicts

WMMA_ARG = ((WMMA_M, WMMA_N, WMMA_K), 'AMD', 32)
SCALE = 1.0 / math.sqrt(D)
LOG2E = math.log2(math.e)
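For orientation, the tile arithmetic under the default config (B=1, H=32, N=1024, D=64) works out as below; this is a standalone sanity check derived from the constants above, not anything the kernel itself computes at runtime:

    WARP_SIZE, BLOCK_M, BLOCK_N, D = 32, 64, 64, 64
    WAVES_M, WAVES_N = 4, 1
    LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
    WMMA_ACC = 16 // LANES_PER_WAVE_M              # 8 accumulator rows per lane
    TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)   # 8 output rows per thread
    TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)   # 4 score columns per thread
    TD = D // (WAVES_N * LANES_PER_WAVE_N)         # 4 output columns per thread
    assert (TM, TN, TD) == (8, 4, 4)
    assert WARP_SIZE * WAVES_M * WAVES_N == 128    # THREADS_PER_BLOCK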
def warp_shfl_xor(val, offset, lane):
  """Read val from lane ^ offset using ds_bpermute."""
  idx = ((lane ^ offset) * 4).cast(dtypes.int)
  return UOp(Ops.CUSTOM, dtypes.float, (idx, val),
             arg="__builtin_bit_cast(float, __builtin_amdgcn_ds_bpermute({0}, __builtin_bit_cast(int, {1})))")

def warp_reduce_max(val, lane):
  """Tree reduce MAX across LANES_PER_WAVE_N=16 lanes."""
  for offset in [8, 4, 2, 1]:
    val = UOp(Ops.MAX, dtypes.float, (val, warp_shfl_xor(val, offset, lane)))
  return val

def warp_reduce_sum(val, lane):
  """Tree reduce SUM across LANES_PER_WAVE_N=16 lanes."""
  for offset in [8, 4, 2, 1]:
    val = val + warp_shfl_xor(val, offset, lane)
  return val
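The `* 4` in warp_shfl_xor is there because ds_bpermute addresses lanes by byte offset (4 bytes per 32-bit lane register). The xor-butterfly itself can be modeled in plain Python to see why offsets 8, 4, 2, 1 leave all 16 lanes holding the full reduction; an illustration only, not tinygrad code:

    def butterfly_max(vals):  # one float per simulated lane
        for offset in [8, 4, 2, 1]:
            vals = [max(vals[l], vals[l ^ offset]) for l in range(len(vals))]
        return vals

    out = butterfly_max([3.0, 7.0, 1.0, 9.0] * 4)
    assert all(x == 9.0 for x in out)  # every lane ends with the max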
def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
  block_bh = UOp.range(B * H, 0, AxisType.GLOBAL)
  block_m = UOp.range(N // BLOCK_M, 1, AxisType.GLOBAL)

  q = q.reshape(B*H, N//BLOCK_M, BLOCK_M, D)[block_bh, block_m]
  k = k.reshape(B*H, N//BLOCK_N, BLOCK_N, D)[block_bh]
  v = v.reshape(B*H, N//BLOCK_N, BLOCK_N, D)[block_bh]
  o = o.reshape(B*H, N//BLOCK_M, BLOCK_M, D)[block_bh, block_m]

  wave_m = UOp.range(WAVES_M, 2, AxisType.LOCAL)
  wave_n = UOp.range(WAVES_N, 3, AxisType.LOCAL)
  lane = UOp.range(WARP_SIZE, -1, AxisType.WARP)
  tid = (wave_m * WAVES_N + wave_n) * WARP_SIZE + lane
  lane_m = lane // LANES_PER_WAVE_N
  lane_n = lane % LANES_PER_WAVE_N

  # LDS allocation: slot 0 = Q then P (shared), slot 1 = K then V
  # TODO: the memory planner should be able to find this reuse
  ELEMS_PER_THREAD = BLOCK_M * D // THREADS_PER_BLOCK
  QP_lds = UOp.placeholder((BLOCK_M, D + LDS_PAD), dtypes.half, slot=0, addrspace=AddrSpace.LOCAL)
  KV_lds = UOp.placeholder((BLOCK_N, D + LDS_PAD), dtypes.half, slot=1, addrspace=AddrSpace.LOCAL)[:, :D]

  # register state
  acc = UOp.placeholder((TM, TD), dtypes.float, slot=2, addrspace=AddrSpace.REG)
  m_i = UOp.placeholder((TM,), dtypes.float, slot=3, addrspace=AddrSpace.REG)
  l_i = UOp.placeholder((TM,), dtypes.float, slot=4, addrspace=AddrSpace.REG)
  acc = acc.after(acc.store(acc.const_like(0)))
  m_i = m_i.after(m_i.store(m_i.const_like(-math.inf)))
  l_i = l_i.after(l_i.store(l_i.const_like(0)))
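A rough model of why LDS_PAD helps, assuming the usual AMD layout of 32 four-byte LDS banks: with half rows of stride 64, every row's column 0 lands in the same bank, while stride 68 staggers row starts across banks. Illustration only, under that banking assumption:

    BANKS = 32
    def bank(row, col, stride_halfs):
        return ((row * stride_halfs + col) // 2) % BANKS  # two halfs per 4-byte word

    unpadded = {bank(r, 0, 64) for r in range(16)}      # stride D: all in one bank
    padded = {bank(r, 0, 64 + 4) for r in range(16)}    # stride D + LDS_PAD
    assert len(unpadded) == 1 and len(padded) == 16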
  # ====== KV tile loop ======
  n_tile = UOp.range(N // BLOCK_N, 100, AxisType.REDUCE)

  # load Q + K into LDS (Q reloaded each iteration since P overwrites slot 0)
  Q_lds = QP_lds[:, :D]
  Q_store = Q_lds.after(n_tile).reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
    q.reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
  K_store = KV_lds.reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
    k[n_tile].reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
  qk_load_barrier = UOp.barrier(UOp.group(Q_store, K_store))
  Q_lds = Q_lds.after(qk_load_barrier)
  KV_lds_k = KV_lds.after(qk_load_barrier)
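Each tile copy is a flat cooperative load: the 64 x 64 half tile is viewed as (THREADS_PER_BLOCK, ELEMS_PER_THREAD) and thread tid moves one contiguous chunk. Under the default config the split works out as follows (standalone arithmetic check):

    BLOCK_M, D, THREADS_PER_BLOCK = 64, 64, 128
    ELEMS_PER_THREAD = BLOCK_M * D // THREADS_PER_BLOCK
    assert ELEMS_PER_THREAD == 32  # 32 halfs, i.e. 64 bytes, per thread per tile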
  # -- S = Q @ K^T via WMMA (re-init each n_tile) --
  S_reg = UOp.placeholder((TM, TN), dtypes.float, slot=6, addrspace=AddrSpace.REG)
  S_reg = S_reg.after(S_reg.after(n_tile).store(S_reg.const_like(0)))
  k_qk = UOp.range(D // WMMA_K, 101, AxisType.REDUCE)
  tm1 = UOp.range(TM // WMMA_ACC, 200, AxisType.LOOP)
  tn1 = UOp.range(TN, 201, AxisType.LOOP)
  S_frag = S_reg.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0, 2, 1)[tm1, tn1]
  q_frag = Q_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, D // WMMA_K, WMMA_K)[wave_m, tm1, lane_n, k_qk]
  k_frag = KV_lds_k.reshape(WAVES_N, TN, WMMA_N, D // WMMA_K, WMMA_K)[wave_n, tn1, lane_n, k_qk]
  qk = UOp(Ops.SHAPED_WMMA, dtypes.float, (q_frag, k_frag, S_frag.after(k_qk)), arg=WMMA_ARG)
  qk_done = S_frag.store(qk).end(tm1, tn1).end(k_qk)
  S_reg = S_reg.after(qk_done)
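What one trip through this block computes, as a shape-level numpy reference (the WMMA path produces the same 64 x 64 score tile; SCALE is applied in the softmax step below):

    import numpy as np
    BLOCK_M, BLOCK_N, D = 64, 64, 64
    Q = np.random.rand(BLOCK_M, D).astype(np.float32)
    K = np.random.rand(BLOCK_N, D).astype(np.float32)
    S = Q @ K.T
    assert S.shape == (BLOCK_M, BLOCK_N)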
  # -- softmax in registers with warp shuffles --
  S_reg = S_reg.after(S_reg.store(S_reg * SCALE))

  # per-thread local row max over TN=4 elements, then warp reduce across 16 lanes
  m_ij = UOp.placeholder((TM,), dtypes.float, slot=7, addrspace=AddrSpace.REG)
  m_ij = m_ij.after(m_ij.after(n_tile).store(m_ij.const_like(-math.inf)))
  rm1 = UOp.range(TM, 260, AxisType.LOOP)
  rm2 = UOp.range(TN, 261, AxisType.REDUCE)
  m_ij = m_ij.after(m_ij[rm1].store(UOp(Ops.MAX, dtypes.float, (m_ij.after(rm1, rm2)[rm1], S_reg[rm1, rm2]))).end(rm2, rm1))
  # warp reduce max (in-place)
  ri_w = UOp.range(TM, 270, AxisType.LOOP)
  m_ij = m_ij.after(m_ij[ri_w].store(warp_reduce_max(m_ij[ri_w], lane)).end(ri_w))

  # compute P = exp(S - m_ij) in S_reg (manual ranges)
  rp0a = UOp.range(TM, 275, AxisType.LOOP)
  rp0b = UOp.range(TN, 276, AxisType.LOOP)
  S_reg = S_reg.after(S_reg[rp0a, rp0b].store(((S_reg[rp0a, rp0b] - m_ij[rp0a]) * LOG2E).exp2()).end(rp0a, rp0b))

  p_local = UOp.placeholder((TM,), dtypes.float, slot=8, addrspace=AddrSpace.REG)
  p_local = p_local.after(p_local.after(n_tile).store(p_local.const_like(0)))
  rp1 = UOp.range(TM, 290, AxisType.LOOP)
  rp2 = UOp.range(TN, 291, AxisType.REDUCE)
  p_local = p_local.after(p_local[rp1].store(p_local.after(rp1, rp2)[rp1] + S_reg[rp1, rp2]).end(rp2, rp1))
  ri_ws = UOp.range(TM, 295, AxisType.LOOP)
  p_sum = p_local.after(p_local[ri_ws].store(warp_reduce_sum(p_local[ri_ws], lane)).end(ri_ws))
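The exp2/LOG2E pairing is the standard identity exp(x) = 2**(x * log2(e)); base-2 exponentiation is the natively supported transcendental on AMD GPUs, so exp is phrased in terms of it. Quick check:

    import math
    x = -3.7
    assert abs(math.exp(x) - 2 ** (x * math.log2(math.e))) < 1e-12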
  # write P = exp(S - m_ij) to P_lds (reuses slot 0, Q no longer needed)
  P_lds = QP_lds[:, :BLOCK_N]
  P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
  P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
  rw1 = UOp.range(TM, 296, AxisType.LOOP)
  rw2 = UOp.range(TN, 297, AxisType.LOOP)
  P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
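The seven-axis reshape/permute maps each thread's (TM, TN) register tile onto its scattered home in the 64 x 64 LDS tile. A numpy stand-in (default config) confirms the axis bookkeeping:

    import numpy as np
    # (WAVES_M, TM//WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
    x = np.arange(64 * 64).reshape(4, 1, 8, 2, 1, 4, 16)
    y = x.transpose(0, 4, 3, 6, 1, 2, 5).reshape(128, 8, 4)  # (THREADS_PER_BLOCK, TM, TN)
    assert y.shape == (128, 8, 4)
    assert sorted(y.ravel()) == list(range(64 * 64))  # every LDS element covered exactly once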
  # -- online softmax correction --
  ri4 = UOp.range(TM, 330, AxisType.LOOP)
  m_new_val = UOp(Ops.MAX, dtypes.float, (m_i[ri4], m_ij[ri4]))
  alpha_val = ((m_i[ri4] - m_new_val) * LOG2E).exp2()
  beta_val = ((m_ij[ri4] - m_new_val) * LOG2E).exp2()
  rj4 = UOp.range(TD, 331, AxisType.LOOP)
  correction = UOp.group(
    acc[ri4, rj4].store(alpha_val * acc[ri4, rj4]).end(rj4),
    l_i[ri4].store(alpha_val * l_i[ri4] + beta_val * p_sum[ri4]),
    m_i[ri4].store(m_new_val),
  ).end(ri4)
  acc = acc.after(correction)
  l_i = l_i.after(correction)
  m_i = m_i.after(correction)
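For reference, this is the online-softmax update from FlashAttention: rescale the stale accumulator and row sum by alpha = exp(m_old - m_new) before folding in the new tile. A standalone numpy model of the recurrence (textbook form, with P taken relative to the running max so the beta factor is folded into P) reproduces a one-shot softmax exactly:

    import numpy as np
    np.random.seed(0)
    S = np.random.randn(4, 128); V = np.random.randn(128, 8)
    m = np.full(4, -np.inf); l = np.zeros(4); acc = np.zeros((4, 8))
    for t in range(0, 128, 32):                  # walk KV tiles of width 32
        St, Vt = S[:, t:t+32], V[t:t+32]
        m_new = np.maximum(m, St.max(axis=1))    # running row max
        alpha = np.exp(m - m_new)                # rescales the stale state
        P = np.exp(St - m_new[:, None])
        l = alpha * l + P.sum(axis=1)
        acc = alpha[:, None] * acc + P @ Vt
        m = m_new
    E = np.exp(S - S.max(axis=1, keepdims=True))
    assert np.allclose(acc / l[:, None], (E / E.sum(axis=1, keepdims=True)) @ V)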
  # load V into KV_lds (must wait for QK WMMA to finish reading K from KV_lds)
  V_store = KV_lds.after(qk_done).reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
    v[n_tile].reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
  pv_barrier = UOp.barrier(UOp.group(P_store, V_store))
  P_lds = P_lds.after(pv_barrier)
  KV_lds_v = KV_lds.after(pv_barrier)

  # -- acc += P @ V via WMMA --
  k_pv = UOp.range(BLOCK_N // WMMA_K, 400, AxisType.REDUCE)
  tm2 = UOp.range(TM // WMMA_ACC, 401, AxisType.LOOP)
  tn2 = UOp.range(TD, 402, AxisType.LOOP)
  acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TD).permute(0, 2, 1)[tm2, tn2]
  p_frag = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_N // WMMA_K, WMMA_K)[wave_m, tm2, lane_n, k_pv]
  v_frag = KV_lds_v.reshape(WAVES_N, TD, WMMA_N, BLOCK_N // WMMA_K, WMMA_K)[wave_n, tn2, lane_n, k_pv]
  pv = UOp(Ops.SHAPED_WMMA, dtypes.float, (p_frag, v_frag, acc_frag.after(k_pv)), arg=WMMA_ARG)

  # end KV tile loop
  n_tile_end = acc_frag.store(pv).end(tm2, tn2).end(k_pv).barrier().end(n_tile)
  acc = acc.after(n_tile_end)
  l_i = l_i.after(n_tile_end)
  m_i = m_i.after(n_tile_end)

  # normalize: acc /= l_i
  acc = acc.after(acc.store(acc * (1 / l_i).reshape(TM, 1).expand(TM, TD)))

  # store output
  o = o.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TD, LANES_PER_WAVE_N)
  o = o.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TD)
  return o[tid].store(acc).end(wave_m, wave_n, lane).end(block_m, block_bh).sink(arg=KernelInfo(opts_to_apply=()))
if __name__ == "__main__":
  q = Tensor.rand(B, H, N, D).cast(dtypes.half)
  k = Tensor.rand(B, H, N, D).cast(dtypes.half)
  v = Tensor.rand(B, H, N, D).cast(dtypes.half)
  o = Tensor.empty(B, H, N, D, dtype=dtypes.float)
  with Context(DEBUG=0): Tensor.realize(q, k, v)

  q_flat, k_flat, v_flat, o_flat = q.reshape(B*H, N, D), k.reshape(B*H, N, D), v.reshape(B*H, N, D), o.reshape(B*H, N, D)
  NUM_RUNS = getenv("CNT", 5)
  ets = []
  with Context(DEBUG=getenv("KDBG", 2)):
    for _ in range(NUM_RUNS):
      GlobalCounters.reset()
      tst = Tensor.custom_kernel(o_flat, q_flat, k_flat, v_flat, fxn=amd_flash_attention)[0].realize()
      ets.append(GlobalCounters.time_sum_s)
  print(f"best time: {min(ets)*1e3:.2f}ms")

  if getenv("VERIFY", 1):
    with Context(DEBUG=0):
      ref = q.float().scaled_dot_product_attention(k.float(), v.float()).reshape(B*H, N, D).realize()
      err = (ref - tst).square().mean().item()
      print(f"mean squared error {err}")
      if err > 1e-2:
        raise RuntimeError("flash attention is wrong!")
      else:
        print("flash attention is correct!")
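To put the quoted "2 TFLOPS" in context: attention at this shape is two N x N x D matmuls per head (QK^T and PV), roughly 4*B*H*N*N*D FLOPs per call. Back-of-envelope arithmetic under the default config, not something the script prints:

    B, H, N, D = 1, 32, 1024, 64
    flops = 4 * B * H * N * N * D
    print(f"{flops / 1e9:.1f} GFLOP per call")  # ~8.6 GFLOP, so ~4.3 ms at 2 TFLOPS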
@@ -816,5 +816,58 @@ class TestUOpTags(unittest.TestCase):
    g = graph_rewrite(g, pm_plus_1)
    assert g.ssimplify() == 6

class TestUOpGetItem(unittest.TestCase):
  def _placeholder(self, shape, dtype=dtypes.half):
    return UOp.placeholder(shape, dtype, slot=0, addrspace=AddrSpace.LOCAL)

  # full slices (no shrink)
  def test_full_slice(self):
    p = self._placeholder((64, 64))
    self.assertEqual(p[:, :].shape, (64, 64))
  def test_full_slice_explicit(self):
    p = self._placeholder((64, 64))
    self.assertEqual(p[0:64, 0:64].shape, (64, 64))

  # partial slices (shrink)
  def test_shrink_cols(self):
    p = self._placeholder((64, 80))
    self.assertEqual(p[:, :64].shape, (64, 64))
  def test_shrink_rows(self):
    p = self._placeholder((80, 64))
    self.assertEqual(p[:64, :].shape, (64, 64))
  def test_shrink_both(self):
    p = self._placeholder((80, 80))
    self.assertEqual(p[:64, :64].shape, (64, 64))
  def test_shrink_start(self):
    p = self._placeholder((64, 64))
    self.assertEqual(p[8:, :].shape, (56, 64))
  def test_shrink_start_and_end(self):
    p = self._placeholder((64, 64))
    self.assertEqual(p[8:56, 4:60].shape, (48, 56))

  # mixed slice and index
  def test_index_and_slice(self):
    p = self._placeholder((64, 80))
    r = UOp.range(64, 100)
    result = p[r, :64]
    self.assertEqual(result.shape, (64,))
  def test_slice_and_index(self):
    p = self._placeholder((80, 64))
    r = UOp.range(64, 100)
    result = p[:64, r]
    self.assertEqual(result.shape, (64,))
  def test_shrink_then_index(self):
    p = self._placeholder((64, 80))
    s = p[:, :64]
    r = UOp.range(64, 100)
    result = s[r]
    self.assertEqual(result.shape, (64,))

  # integer index (no slice)
  def test_int_index(self):
    p = self._placeholder((64, 64))
    result = p[0]
    self.assertEqual(result.shape, (64,))

if __name__ == '__main__':
  unittest.main(verbosity=2)
@@ -431,13 +431,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if len(idx) < len(self.shape): idx += tuple([slice(None)]*(len(self.shape)-len(idx)))
     assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
     if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
-      perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
+      # apply SHRINK for slices that aren't the full range
+      bounds = tuple((s.start or 0, s.stop if s.stop is not None else self.shape[i]) if isinstance(s, slice) else (0, self.shape[i])
+                     for i, s in enumerate(idx))
+      src = self if all(b == (0, self.shape[i]) for i, b in enumerate(bounds)) else self.shrink(bounds)
+      perm = src.permute(tuple([i for i in range(src.ndim) if i not in slice_idx] + slice_idx))
       return perm.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
     else:
       return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
-    return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
+    return UOp.const(self.dtype.base, b, device=self._device, shape=self._shape)
   def broadcast(self, count:int):
     assert self.dtype.vcount == 1
     if count == 1: return self
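The bounds computation in the new __getitem__ path can be checked standalone in plain Python; each slice becomes a (start, stop) pair, and the shrink is skipped only when every bound covers its full axis:

    shape = (64, 80)
    idx = (slice(None), slice(None, 64))  # i.e. p[:, :64]
    bounds = tuple((s.start or 0, s.stop if s.stop is not None else shape[i]) if isinstance(s, slice) else (0, shape[i])
                   for i, s in enumerate(idx))
    assert bounds == ((0, 64), (0, 64))
    needs_shrink = not all(b == (0, shape[i]) for i, b in enumerate(bounds))
    assert needs_shrink  # (0, 64) != (0, 80) on the second axis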