mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPS] improved pointer arithmetic in attention (#1926)
this provides an additional 3-4% speed-up in non-causal attention, which now tops at 155TFLOPS
This commit is contained in:
@@ -44,7 +44,6 @@ matmul_data = {
|
||||
(512, 512, 512): {'float16': 0.061, 'float32': 0.097, 'int8': 0.05},
|
||||
(1024, 1024, 1024): {'float16': 0.283, 'float32': 0.313, 'int8': 0.169},
|
||||
(2048, 2048, 2048): {'float16': 0.618, 'float32': 0.532, 'int8': 0.34},
|
||||
(4096, 4096, 4096): {'float16': 0.751, 'float32': 0.726, 'int8': 0.46},
|
||||
(8192, 8192, 8192): {'float16': 0.786, 'float32': 0.754, 'int8': 0.51},
|
||||
# tall-skinny
|
||||
(16, 1024, 1024): {'float16': 0.006, 'float32': 0.009, 'int8': 0.005},
|
||||
@@ -56,6 +55,8 @@ matmul_data = {
|
||||
(1024, 64, 1024): {'float16': 0.029, 'float32': 0.046, 'int8': 0.017},
|
||||
(4096, 64, 4096): {'float16': 0.179, 'float32': 0.214, 'int8': 0.102},
|
||||
(8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177},
|
||||
# test EVEN_K==False
|
||||
(8192, 8192, 8176): {'float16': 0.786, 'float32': 0.696, 'int8': 0.51},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +116,6 @@ elementwise_data = {
|
||||
1024 * 64: {'float16': 0.013, 'float32': 0.026},
|
||||
1024 * 256: {'float16': 0.053, 'float32': 0.105},
|
||||
1024 * 1024: {'float16': 0.212, 'float32': 0.420},
|
||||
1024 * 4096: {'float16': 0.791, 'float32': 0.668},
|
||||
1024 * 16384: {'float16': 0.762, 'float32': 0.812},
|
||||
1024 * 65536: {'float16': 0.846, 'float32': 0.869},
|
||||
# Non pow 2
|
||||
@@ -162,7 +162,7 @@ flash_attention_data = {
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.248,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.220,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
|
||||
@@ -173,8 +173,8 @@ flash_attention_data = {
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
|
||||
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.242,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.248,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.251,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.220,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
|
||||
|
||||
@@ -94,11 +94,14 @@ def _fwd_kernel(
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
q = (q * qk_scale).to(K.dtype.element_ty)
|
||||
# advance block pointers to first iteration of the loop
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(tl.advance(K_block_ptr, (0, start_n)))
|
||||
k = tl.load(K_block_ptr)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, allow_tf32=True)
|
||||
if MODE == 1 or MODE == 3:
|
||||
@@ -120,12 +123,15 @@ def _fwd_kernel(
|
||||
acc_scale = l_i / l_i_new
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(tl.advance(V_block_ptr, (start_n, 0)))
|
||||
v = tl.load(V_block_ptr)
|
||||
p = p.to(V.dtype.element_ty)
|
||||
acc += tl.dot(p, v, allow_tf32=True)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
|
||||
@@ -93,11 +93,14 @@ def _fwd_kernel(
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
q = (q * qk_scale).to(tl.float16)
|
||||
# advance block pointers to first iteration of the loop
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(tl.advance(K_block_ptr, (0, start_n)))
|
||||
k = tl.load(K_block_ptr)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
if MODE == 1 or MODE == 3:
|
||||
@@ -119,12 +122,15 @@ def _fwd_kernel(
|
||||
acc_scale = l_i / l_i_new
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(tl.advance(V_block_ptr, (start_n, 0)))
|
||||
v = tl.load(V_block_ptr)
|
||||
p = p.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
|
||||
Reference in New Issue
Block a user