This commit is contained in:
wozeparrot
2025-11-25 15:49:44 -08:00
committed by GitHub
parent 436ab6bfc7
commit ffc31a23f4
5 changed files with 873 additions and 275 deletions

View File

@@ -1,21 +1,21 @@
import unittest, math
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.helpers import CI
from tinygrad.renderer.ptx import PTXRenderer
import numpy as np
from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.kernel import Kernel
from extra.thunder.tiny.tk.tiles import ST_16X32, RT_16X32, RT_16X16, TileLayout
@unittest.skipIf(CI and Device.DEFAULT not in ["CUDA", "NV"], "only cuda")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "no ptx")
@unittest.skipIf(CI and Device.DEFAULT not in ["AMD"], "only amd")
class TestTK(unittest.TestCase):
@unittest.skipIf(CI, "no wmma in ci")
def test_simple_matmul(self):
N = 32
BLOCK_SIZE = 16
N = 8192
BLOCK_SIZE = 64
with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -25,11 +25,10 @@ class TestTK(unittest.TestCase):
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
col, row = ker.blockIdx_x, ker.blockIdx_y
@@ -39,13 +38,12 @@ class TestTK(unittest.TestCase):
b_smem = warp.load(b_smem, b, (), (0, 0, tile, col), axis=2)
a_reg = warp.load(a_reg, a_smem)
b_reg = warp.load(b_reg, b_smem, transpose=True)
b_reg = warp.load(b_reg, b_smem)
c_reg = warp.mma_AB(c_reg, a_reg, b_reg)
c_reg = ker.endrange()
c_smem = warp.store(c_smem, c_reg)
c = warp.store(c, c_smem, (0, 0, row, col), (), axis=2)
c = warp.store(c, c_reg, (0, 0, row, col), (), axis=2)
sink = ker.finish()
@@ -65,27 +63,26 @@ class TestTK(unittest.TestCase):
@unittest.skipIf(CI, "no wmma in ci")
def test_simple_matmul_transposed(self):
N = 32
BLOCK_SIZE = 16
with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
N = 8192
BLOCK_N, BLOCK_M, BLOCK_K = 64, 64, 128
with Kernel((N // BLOCK_N, N // BLOCK_M, 1), WARP_THREADS) as ker:
warp = ker.warp
c = ker.gl((1, 1, N, N), dtypes.float32)
a = ker.gl((1, 1, N, N), dtypes.bfloat16)
b = ker.gl((1, 1, N, N), dtypes.bfloat16)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_smem = ker.st((BLOCK_N, BLOCK_K), dtypes.bfloat16, base_shape=ST_16X32)
b_smem = ker.st((BLOCK_M, BLOCK_K), dtypes.bfloat16, base_shape=ST_16X32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_N, BLOCK_K), dtypes.bfloat16, base_shape=RT_16X32)
b_reg = ker.rt((BLOCK_M, BLOCK_K), dtypes.bfloat16, base_shape=RT_16X32)
c_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32, TileLayout.COL, base_shape=RT_16X16)
col, row = ker.blockIdx_x, ker.blockIdx_y
c_reg = warp.zero(c_reg)
for tile in ker.range(N // BLOCK_SIZE):
for tile in ker.range(N // BLOCK_K):
a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2)
b_smem = warp.load(b_smem, b, (), (0, 0, col, tile), axis=2)
@@ -95,8 +92,7 @@ class TestTK(unittest.TestCase):
c_reg = warp.mma_ABt(c_reg, a_reg, b_reg)
c_reg = ker.endrange()
c_smem = warp.store(c_smem, c_reg)
c = warp.store(c, c_smem, (0, 0, row, col), (), axis=2)
c = warp.store(c, c_reg, (0, 0, row, col), (), axis=2)
sink = ker.finish()
@@ -115,8 +111,8 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), ref.numpy())
def test_load_store(self):
N = 32
BLOCK_SIZE = 16
N = 64
BLOCK_SIZE = 32
with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -124,7 +120,6 @@ class TestTK(unittest.TestCase):
a = ker.gl((1, 1, N, N), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
@@ -134,8 +129,45 @@ class TestTK(unittest.TestCase):
a_smem = warp.load(a_smem, a, (), (0, 0, row, col), axis=2)
a_reg = warp.load(a_reg, a_smem)
b_reg = warp.copy(b_reg, a_reg)
b_smem = warp.store(b_smem, b_reg)
b = warp.store(b, b_smem, (0, 0, row, col), (), axis=2)
b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
sink = ker.finish()
with Context(DEBUG=0):
a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
for _ in range(5): ei.run(wait=True)
b = b.float()
ref = a.float()
np.testing.assert_allclose(b.numpy(), ref.numpy())
@unittest.skip("TODO")
def test_load_store_group(self):
N = 256
BLOCK_SIZE = 64
with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS * 2) as ker:
warp = ker.warp
group = ker.group(2)
b = ker.gl((1, 1, N, N), dtypes.float32)
a = ker.gl((1, 1, N, N), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
col, row = ker.blockIdx_x, ker.blockIdx_y
a_smem = group.load(a_smem, a, (), (0, 0, row, col), axis=2)
a_reg = warp.load(a_reg, a_smem)
b_reg = warp.copy(b_reg, a_reg)
b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
sink = ker.finish()
@@ -153,8 +185,8 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), ref.numpy())
def test_add(self):
N = 32
BLOCK_SIZE = 16
N = 64
BLOCK_SIZE = 32
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -172,8 +204,7 @@ class TestTK(unittest.TestCase):
a_reg += 1
a_smem = warp.store(a_smem, a_reg)
b = warp.store(b, a_smem, (0, 0, tile_row, tile_col), (), axis=2)
b = warp.store(b, a_reg, (0, 0, tile_row, tile_col), (), axis=2)
sink = ker.finish()
@@ -191,8 +222,8 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), ref.numpy())
def test_max(self):
N = 16
BLOCK_SIZE = 16
N = 64
BLOCK_SIZE = 32
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -200,27 +231,25 @@ class TestTK(unittest.TestCase):
a = ker.gl((1, 1, N, N), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
max_reg = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho")
max_reg = ker.rv(BLOCK_SIZE, dtypes.float32)
for tile_row in ker.range(N // BLOCK_SIZE):
max_reg = warp.neg_inf(max_reg.after(tile_row))
for tile_col in ker.range(N // BLOCK_SIZE):
max_reg = warp.neg_inf(max_reg.after(tile_col))
for tile_col in ker.range(N // BLOCK_SIZE):
for tile_row in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
a_reg = warp.load(a_reg, a_smem)
max_reg = warp.row_reduce(max_reg, a_reg, lambda a, b: a.maximum(b))
max_reg = warp.col_reduce(max_reg, a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
max_reg = ker.endrange()
b_reg = warp.map(b_reg, lambda _, idx: max_reg[idx[0], 0, (idx[2]%4)//2])
b_smem = warp.store(b_smem, b_reg)
b_reg = warp.map(b_reg, lambda _, idx: max_reg[idx[1], 0])
for tile_col in ker.range(N // BLOCK_SIZE):
b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)
for tile_row in ker.range(N // BLOCK_SIZE):
b = warp.store(b, b_reg, (0, 0, tile_row, tile_col), (), axis=2)
sink = ker.finish()
@@ -233,12 +262,12 @@ class TestTK(unittest.TestCase):
for _ in range(5): ei.run(wait=True)
b = b.float()
ref = a.float().max(axis=3, keepdim=True).expand(a.shape)
ref = a.float().max(axis=2, keepdim=True).expand(a.shape)
np.testing.assert_allclose(b.numpy(), ref.numpy())
def test_max_nonsquare(self):
N, M = 16, 64
N, M = 32, 128
BLOCK_N, BLOCK_M = 16, 64
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -247,27 +276,25 @@ class TestTK(unittest.TestCase):
a = ker.gl((1, 1, N, M), dtypes.float32)
a_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32)
b_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32)
a_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32)
b_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32)
a_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32, TileLayout.COL)
b_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32, TileLayout.COL)
max_reg = ker.rv(BLOCK_N, dtypes.float32, "ortho")
max_reg = ker.rv(BLOCK_M, dtypes.float32)
for tile_row in ker.range(N // BLOCK_N):
max_reg = warp.neg_inf(max_reg.after(tile_row))
for tile_col in ker.range(M // BLOCK_M):
max_reg = warp.neg_inf(max_reg.after(tile_col))
for tile_col in ker.range(M // BLOCK_M):
for tile_row in ker.range(N // BLOCK_N):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
a_reg = warp.load(a_reg, a_smem)
max_reg = warp.row_reduce(max_reg, a_reg, lambda a, b: a.maximum(b))
max_reg = warp.col_reduce(max_reg, a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
max_reg = ker.endrange()
b_reg = warp.map(b_reg, lambda _, idx: max_reg[idx[0], 0, (idx[2]%4)//2])
b_smem = warp.store(b_smem, b_reg)
b_reg = warp.map(b_reg, lambda _, idx: max_reg[idx[1], 0])
for tile_col in ker.range(M // BLOCK_M):
b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)
for tile_row in ker.range(N // BLOCK_N):
b = warp.store(b, b_reg, (0, 0, tile_row, tile_col), (), axis=2)
sink = ker.finish()
@@ -280,13 +307,13 @@ class TestTK(unittest.TestCase):
for _ in range(5): ei.run(wait=True)
b = b.float()
ref = a.float().max(axis=3, keepdim=True).expand(a.shape)
ref = a.float().max(axis=2, keepdim=True).expand(a.shape)
np.testing.assert_allclose(b.numpy(), ref.numpy())
def test_sum(self):
N = 32
BLOCK_SIZE = 16
N = 64
BLOCK_SIZE = 32
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -294,27 +321,25 @@ class TestTK(unittest.TestCase):
a = ker.gl((1, 1, N, N), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
sum_reg = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho")
sum_reg = ker.rv(BLOCK_SIZE, dtypes.float32)
for tile_row in ker.range(N // BLOCK_SIZE):
sum_reg = warp.zero(sum_reg.after(tile_row))
for tile_col in ker.range(N // BLOCK_SIZE):
sum_reg = warp.zero(sum_reg.after(tile_col))
for tile_col in ker.range(N // BLOCK_SIZE):
for tile_row in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
a_reg = warp.load(a_reg, a_smem)
sum_reg = warp.row_reduce(sum_reg, a_reg, lambda a, b: a + b)
sum_reg = warp.col_reduce(sum_reg, a_reg, lambda a, b: a + b)
sum_reg = ker.endrange()
b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2])
b_smem = warp.store(b_smem, b_reg)
b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[1], 0])
for tile_col in ker.range(N // BLOCK_SIZE):
b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)
for tile_row in ker.range(N // BLOCK_SIZE):
b = warp.store(b, b_reg, (0, 0, tile_row, tile_col), (), axis=2)
sink = ker.finish()
@@ -327,12 +352,12 @@ class TestTK(unittest.TestCase):
for _ in range(5): ei.run(wait=True)
b = b.float()
ref = a.float().sum(axis=3, keepdim=True).expand(a.shape)
ref = a.float().sum(axis=2, keepdim=True).expand(a.shape)
np.testing.assert_allclose(b.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
def test_sum_nonsquare(self):
N, M = 16, 64
N, M = 32, 128
BLOCK_N, BLOCK_M = 16, 64
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -341,27 +366,25 @@ class TestTK(unittest.TestCase):
a = ker.gl((1, 1, N, M), dtypes.float32)
a_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32)
b_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32)
a_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32)
b_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32)
a_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32, TileLayout.COL)
b_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32, TileLayout.COL)
sum_reg = ker.rv(BLOCK_N, dtypes.float32, "ortho")
sum_reg = ker.rv(BLOCK_M, dtypes.float32)
for tile_row in ker.range(N // BLOCK_N):
sum_reg = warp.zero(sum_reg.after(tile_row))
for tile_col in ker.range(M // BLOCK_M):
sum_reg = warp.zero(sum_reg.after(tile_col))
for tile_col in ker.range(M // BLOCK_M):
for tile_row in ker.range(N // BLOCK_N):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
a_reg = warp.load(a_reg, a_smem)
sum_reg = warp.row_reduce(sum_reg, a_reg, lambda a, b: a + b)
sum_reg = warp.col_reduce(sum_reg, a_reg, lambda a, b: a + b)
sum_reg = ker.endrange()
b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2])
b_smem = warp.store(b_smem, b_reg)
b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[1], 0])
for tile_col in ker.range(M // BLOCK_M):
b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)
for tile_row in ker.range(N // BLOCK_N):
b = warp.store(b, b_reg, (0, 0, tile_row, tile_col), (), axis=2)
sink = ker.finish()
@@ -374,14 +397,13 @@ class TestTK(unittest.TestCase):
for _ in range(5): ei.run(wait=True)
b = b.float()
ref = a.float().sum(axis=3, keepdim=True).expand(a.shape)
ref = a.float().sum(axis=2, keepdim=True).expand(a.shape)
np.testing.assert_allclose(b.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
@unittest.skip("fake range not ended")
def test_softmax(self):
N = 32
BLOCK_SIZE = 16
N = 64
BLOCK_SIZE = 32
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
@@ -392,9 +414,9 @@ class TestTK(unittest.TestCase):
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
max_vec_last = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho")
max_vec = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho")
norm_vec = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho")
max_vec_last = ker.rv(BLOCK_SIZE, dtypes.float32)
max_vec = ker.rv(BLOCK_SIZE, dtypes.float32)
norm_vec = ker.rv(BLOCK_SIZE, dtypes.float32)
max_vec = warp.neg_inf(max_vec)
norm_vec = warp.zero(norm_vec)
@@ -406,7 +428,7 @@ class TestTK(unittest.TestCase):
a_reg *= 1.0 / math.log(2)
max_vec_last = warp.copy(max_vec_last.after(tile_col), max_vec)
max_vec = warp.row_reduce(max_vec, a_reg, lambda a, b: a.maximum(b))
max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
a_reg = (a_reg - max_vec).exp2()
max_vec_last = (max_vec_last - max_vec).exp2()
norm_vec *= max_vec_last
@@ -415,14 +437,13 @@ class TestTK(unittest.TestCase):
for tile_col in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
a_reg = warp.load(a_reg, a_smem)
a_reg = warp.load(a_reg.after(norm_vec), a_smem)
a_reg *= 1.0 / math.log(2)
a_reg = (a_reg - max_vec).exp2()
a_reg /= norm_vec
a_smem = warp.store(a_smem, a_reg)
b = warp.store(b, a_smem, (0, 0, 0, tile_col), (), axis=2)
b = warp.store(b, a_reg, (0, 0, 0, tile_col), (), axis=2)
sink = ker.finish()
@@ -439,5 +460,177 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
def test_softmax_col(self):
N = 64
BLOCK_SIZE = 32
with Kernel((1, 1, 1), WARP_THREADS) as ker:
warp = ker.warp
b = ker.gl((1, 1, N, BLOCK_SIZE), dtypes.float32)
a = ker.gl((1, 1, N, BLOCK_SIZE), dtypes.float32)
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
max_vec_last = ker.rv(BLOCK_SIZE, dtypes.float32)
max_vec = ker.rv(BLOCK_SIZE, dtypes.float32)
norm_vec = ker.rv(BLOCK_SIZE, dtypes.float32)
max_vec = warp.neg_inf(max_vec)
norm_vec = warp.zero(norm_vec)
for tile_row in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
a_reg = warp.load(a_reg, a_smem)
a_reg *= 1.0 / math.log(2)
max_vec_last = warp.copy(max_vec_last.after(tile_row), max_vec)
max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
a_reg = (a_reg - max_vec).exp2()
max_vec_last = (max_vec_last - max_vec).exp2()
norm_vec *= max_vec_last
norm_vec = warp.col_reduce(norm_vec, a_reg, lambda a, b: a + b)
norm_vec = ker.endrange()
for tile_row in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
a_reg = warp.load(a_reg.after(norm_vec), a_smem)
a_reg *= 1.0 / math.log(2)
a_reg = (a_reg - max_vec).exp2()
a_reg /= norm_vec
b = warp.store(b, a_reg, (0, 0, tile_row, 0), (), axis=2)
sink = ker.finish()
with Context(DEBUG=0):
a = Tensor.rand(1, 1, N, BLOCK_SIZE, dtype="float32")
b = Tensor.empty(1, 1, N, BLOCK_SIZE, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
for _ in range(5): ei.run(wait=True)
b = b.float()
ref = a.float().softmax(axis=2)
np.testing.assert_allclose(b.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
def test_fa(self):
NUM_WORKERS = 1
B, N, H, H_KV, D = 1, 8192, 32, 8, 128
Q_BLOCK_SIZE = 16
KV_BLOCK_SIZE = 16
GROUP_SIZE = H // H_KV
with Kernel((H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
warp = ker.warp
# kernel
o = ker.gl((B, N, H, D), dtypes.bfloat16)
q = ker.gl((B, N, H, D), dtypes.bfloat16)
k = ker.gl((B, N, H_KV, D), dtypes.bfloat16)
v = ker.gl((B, N, H_KV, D), dtypes.bfloat16)
head = ker.blockIdx_x
head_kv = head // GROUP_SIZE
batch = ker.blockIdx_z
q_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid
k_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
v_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
q_reg_fl = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
q_reg_transposed = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
k_reg_transposed = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
o_reg = ker.rt((D, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
o_reg_transposed = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
max_vec_last = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
max_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
norm_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
scale_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
max_vec = warp.neg_inf(max_vec)
norm_vec = warp.zero(norm_vec)
o_reg = warp.zero(o_reg)
scale_vec = warp.ones(scale_vec)
# load q tile
q_reg_fl = warp.load(q_reg_fl, q, (), (batch, q_seq, head, 0), axis=1)
q_reg_fl *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
q_reg = warp.copy(q_reg, q_reg_fl)
q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)
for kv_idx in ker.range(N // KV_BLOCK_SIZE):
k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
k_reg = warp.load(k_reg, k_smem)
v_reg = warp.load(v_reg, v_smem)
# mma qk^t
att_block = warp.zero(att_block.after(kv_idx))
k_reg_transposed = warp.transpose(k_reg_transposed, k_reg)
att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)
# mask for causal
q_base = q_seq * Q_BLOCK_SIZE + (warp.laneid % 16)
kv_base = kv_idx * KV_BLOCK_SIZE + (warp.laneid // 16) * 4
att_block = warp.map(att_block,
lambda x, idx: ((kv_base + idx[0]*16 + idx[2]) > (q_base + idx[1]*16)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
# softmax
max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
max_vec = warp.row_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)
scale_vec = warp.map(scale_vec.after(max_vec_last, max_vec), lambda _, idx: max_vec_last[*idx] - max_vec[*idx])
scale_vec = scale_vec.exp2()
o_reg *= scale_vec
norm_vec *= scale_vec
att_block -= max_vec
att_block = att_block.exp2()
norm_vec = warp.row_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)
# mma av
att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
o_reg = warp.mma_AtB(o_reg, v_reg, att_block_mma)
o_reg = ker.endrange()
o_reg /= norm_vec
o_reg_transposed = warp.transpose(o_reg_transposed, o_reg)
o = warp.store(o, o_reg_transposed, (batch, q_seq, head, 0), (), axis=1)
sink = ker.finish()
with Context(DEBUG=0):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
out = Tensor.empty(B, N, H, D, dtype=dtypes.bfloat16)
Tensor.realize(q, k, v, out)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)])
for _ in range(5): ei.run(wait=True)
out = out.float()
q_permuted = q.permute(0, 2, 1, 3)
k_permuted = k.permute(0, 2, 1, 3)
v_permuted = v.permute(0, 2, 1, 3)
ref = q_permuted.scaled_dot_product_attention(k_permuted, v_permuted, is_causal=True, enable_gqa=True).float()
ref = ref.permute(0, 2, 1, 3)
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=1e-2, rtol=1e-5)
if __name__ == "__main__":
unittest.main()