[OPS] Add more perf-tests, new features to FA (#1849)

This adds new tests across the board for float32, bfloat16, and non-power-of-2
shapes (to exercise the masks), as well as sequence-parallel tests for atomics. It
also ports the sequence-parallel features from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py.
I am not sure of the best way to grab the baseline benchmarking numbers. I have
access to V100s and A100s, but the existing tests note: "# A100 in the CI server
is slow-ish for some reason. # On some other servers, we are getting about 90%
peak for 8kx8x8k float16". The current plan is to run CI here and use those
numbers as the baseline, then compare against my GPUs as a sanity check.
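
For reference, the new tests call the updated op with the causal and sequence-parallel
flags passed positionally. A minimal usage sketch, assuming an Ampere GPU (the op skips
older compute capabilities) and the (Z, H, N_CTX, D_HEAD) = (4, 48, 4096, 64)
configuration used in the perf tests:

import torch
import triton.ops

Z, H, N_CTX, D_HEAD = 4, 48, 4096, 64
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
sm_scale = 0.2

# causal=True applies the lower-triangular mask;
# sequence_parallel=True parallelizes the backward pass over k/v column blocks
out = triton.ops.attention(q, k, v, True, sm_scale, True)
out.backward(torch.randn_like(out))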

---------

Co-authored-by: Phil Tillet <phil@openai.com>
Author: Izzy Putterman
Date: 2023-07-10 18:52:59 -07:00
Committed by: GitHub
Parent: 73e18e9588
Commit: d39d78fa08
4 changed files with 380 additions and 183 deletions


@@ -55,26 +55,39 @@ matmul_data = {
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
# Non pow 2 shapes
(1000, 200, 100): {'float16': 0.084},
(1000, 200, 700): {'float16': 0.084},
(994, 136, 402): {'float16': 0.084},
(995, 135, 409): {'float16': 0.084},
(99, 1357, 409): {'float16': 0.084},
},
# NOTE:
# A100 in the CI server is slow-ish for some reason.
# On some other servers, we are getting about 90% peak for 8kx8x8k float16
'a100': {
(512, 512, 512): {'float16': 0.084, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.332, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.641, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.785, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.805, 'float32': 0.85, 'int8': 0.51},
# square
(512, 512, 512): {'float16': 0.084, 'float32': 0.12, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.332, 'float32': 0.352, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.635, 'float32': 0.522, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.750, 'float32': 0.810, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.760, 'float32': 0.760, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.044, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.028, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.163, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.285, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.033, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.254, 'float32': 0.230, 'int8': 0.177},
(16, 1024, 1024): {'float16': 0.008, 'float32': 0.009, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.036, 'float32': 0.038, 'int8': 0.026},
(16, 8192, 8192): {'float16': 0.056, 'float32': 0.061, 'int8': 0.043},
(64, 1024, 1024): {'float16': 0.020, 'float32': 0.030, 'int8': 0.017},
(64, 4096, 4096): {'float16': 0.160, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.280, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.040, 'float32': 0.050, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.160, 'float32': 0.200, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.250, 'float32': 0.23, 'int8': 0.177},
# Non pow 2 shapes
(1000, 200, 100): {'float16': 0.011, 'float32': 0.017, 'int8': 0.05},
(1000, 200, 700): {'float16': 0.027, 'float32': 0.047, 'int8': 0.05},
(994, 136, 402): {'float16': 0.015, 'float32': 0.024, 'int8': 0.05},
(995, 135, 409): {'float16': 0.015, 'float32': 0.025, 'int8': 0.05},
(99, 1357, 409): {'float16': 0.011, 'float32': 0.036, 'int8': 0.05}
}
}
@@ -82,10 +95,12 @@ matmul_data = {
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
for dtype_str in ['float16', 'float32']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
if (M, N, K) in [(64, 4096, 4096), (64, 8192, 8192), (8192, 64, 8192)] and dtype_str == 'float32':
pytest.skip('Out of shared memory in float32')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
@@ -126,32 +141,44 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = {
'v100': {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.530,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
1024 * 16: {'float16': 0.0219, 'float32': 0.010},
1024 * 64: {'float16': 0.0791, 'float32': 0.010},
1024 * 256: {'float16': 0.243, 'float32': 0.010},
1024 * 1024: {'float16': 0.530, 'float32': 0.010},
1024 * 4096: {'float16': 0.796, 'float32': 0.010},
1024 * 16384: {'float16': 0.905, 'float32': 0.010},
1024 * 65536: {'float16': 0.939, 'float32': 0.010},
# Non pow 2
1020 * 100: {'float16': 0.010, 'float32': 0.010},
995 * 125: {'float16': 0.010, 'float32': 0.010},
10003 * 7007: {'float16': 0.010, 'float32': 0.010},
},
'a100': {
1024 * 16: 0.010,
1024 * 64: 0.040,
1024 * 256: 0.132,
1024 * 1024: 0.353,
1024 * 4096: 0.605,
1024 * 16384: 0.758,
1024 * 65536: 0.850,
1024 * 16: {'float16': 0.010, 'bfloat16': 0.010, 'float32': 0.020},
1024 * 64: {'float16': 0.040, 'bfloat16': 0.040, 'float32': 0.066},
1024 * 256: {'float16': 0.132, 'bfloat16': 0.132, 'float32': 0.227},
1024 * 1024: {'float16': 0.353, 'bfloat16': 0.353, 'float32': 0.488},
1024 * 4096: {'float16': 0.605, 'bfloat16': 0.605, 'float32': 0.705},
1024 * 16384: {'float16': 0.758, 'bfloat16': 0.750, 'float32': 0.819},
1024 * 65536: {'float16': 0.850, 'bfloat16': 0.850, 'float32': 0.870},
# Non pow 2
1020 * 100: {'float16': 0.051, 'bfloat16': 0.051, 'float32': 0.103},
995 * 125: {'float16': 0.063, 'bfloat16': 0.063, 'float32': 0.126},
10003 * 7007: {'float16': 0.544, 'bfloat16': 0.541, 'float32': 0.861},
}
}
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32'])
def test_elementwise(N, dtype_str):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100':
pytest.skip('Only test bfloat16 on a100')
dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]
ref_gpu_util = elementwise_data[DEVICE_NAME][N][dtype_str]
max_gpu_perf = get_dram_gbps()
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
z = torch.empty((N, ), dtype=dtype, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
@@ -169,29 +196,56 @@ def test_elementwise(N):
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.420,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.202,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.355,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.201,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.099,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.087,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.238,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.135,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.211,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.062,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.424,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.370,
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.099,
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.125,
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.238,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158,
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.211,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.062,
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.075,
}
}
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32'])
@pytest.mark.parametrize("mode", ['forward', 'backward'])
@pytest.mark.parametrize("dtype_str", ['float16'])
def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("seq_par", [True, False])
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
is_backward = mode == 'backward'
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]
# init data
if dtype_str == 'float32':
N_CTX = 1024
D_HEAD = 16
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
# benchmark
fn = lambda: triton.ops.attention(q, k, v, sm_scale)
fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
if is_backward:
o = fn()
do = torch.randn_like(o)
@@ -207,6 +261,6 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)]
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)


@@ -10,22 +10,23 @@ import triton.ops
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0.5
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
@@ -34,7 +35,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = triton.ops.attention(q, k, v, sm_scale)
tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
# print(ref_out)
# print(tri_out)
tri_out.backward(dout)


@@ -3,6 +3,9 @@ Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
Sequence Parallel implementation inspired by HazyResearch
(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py)
"""
import torch
@@ -23,68 +26,113 @@ def _fwd_kernel(
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
MODE: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# causal check on every loop iteration can be expensive
# and peeling the last iteration of the loop does not work well with ptxas
# so we have a mode to do the causal check in a separate kernel entirely
if MODE == 0: # entire non-causal attention
lo, hi = 0, N_CTX
if MODE == 1: # entire causal attention
lo, hi = 0, (start_m + 1) * BLOCK_M
if MODE == 2: # off band-diagonal
lo, hi = 0, start_m * BLOCK_M
if MODE == 3: # on band-diagonal
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
m_i = tl.load(m_ptrs)
l_i = tl.load(l_ptrs)
acc += tl.load(O_block_ptr).to(tl.float32)
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
# credits to: Adam P. Goucher (https://github.com/apgoucher):
# scale sm_scale by log_2(e) (≈ 1.44269504) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(K.dtype.element_ty)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs)
k = tl.load(tl.advance(K_block_ptr, (0, start_n)))
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1. / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
qk += tl.dot(q, k, allow_tf32=True)
if MODE == 1 or MODE == 3:
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.math.exp2(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.math.exp2(m_i - m_i_new)
beta = tl.math.exp2(m_ij - m_i_new)
l_i *= alpha
l_i_new = l_i + beta * l_ij
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new
acc = acc * acc_scale[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
v = tl.load(tl.advance(V_block_ptr, (start_n, 0)))
p = p.to(V.dtype.element_ty)
acc += tl.dot(p, v, allow_tf32=True)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
l_i = l_i_new
m_i = m_i_new
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_prev)
tl.store(m_ptrs, m_prev)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# write back O
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
@jit
@@ -108,93 +156,167 @@ def _bwd_preprocess(
@jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
def _bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale,
Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
off_hz, start_n, num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
MODE: tl.constexpr,
):
if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * start_n
if MODE == 0:
lo = 0
else:
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv and dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
if MODE == 1:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= qk_scale
m = tl.load(m_ptrs + offs_m_curr)
p = tl.math.exp2(qk - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
# dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp = tl.dot(do, tl.trans(v), allow_tf32=True)
# compute ds = p * (dp - delta[:, None])
ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty)
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q, allow_tf32=True)
# compute dq
if not SEQUENCE_PARALLEL:
dq = tl.load(dq_ptrs)
dq += tl.dot(ds, k, allow_tf32=True)
tl.store(dq_ptrs, dq)
elif SEQUENCE_PARALLEL:
# dq = tl.dot(ds, k, allow_tf32=True)
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
@jit
def _bwd_kernel(
# fmt: off
Q, K, V, sm_scale,
Out, DO,
DQ, DK, DV,
L, M,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
MODE: tl.constexpr,
# fmt: on
):
qk_scale = sm_scale * 1.44269504
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv and dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
DK += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
if not SEQUENCE_PARALLEL:
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
MODE=MODE,
)
else:
start_n = tl.program_id(1)
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
MODE=MODE,
)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
@@ -209,58 +331,80 @@ class _attention(torch.autograd.Function):
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
if causal:
modes = [1] if q.shape[2] <= 2048 else [2, 3]
else:
modes = [0]
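# mode semantics (mirroring the MODE comments in _fwd_kernel):
# 0 = full non-causal attention, 1 = full causal attention (short sequences),
# 2 = off-band-diagonal blocks and 3 = on-band-diagonal blocks for long causal sequences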
for mode in modes:
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=128, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk,
MODE=mode,
num_warps=num_warps,
num_stages=2)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
ctx.sequence_parallel = sequence_parallel
return o
@staticmethod
def backward(ctx, do):
BLOCK = 128
q, k, v, o, l, m = ctx.saved_tensors
sequence_parallel = ctx.sequence_parallel
seq_len_kv = k.shape[2]
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
if sequence_parallel:
replicas = cdiv(seq_len_kv, BLOCK)
new_dq_shape = (replicas,) + q.shape
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
if ctx.causal:
mode = 1
else:
mode = 0
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
MODE=mode,
num_warps=8,
num_stages=1,
)
return dq, dk, dv, None
if len(dq.shape) == 5:
dq = dq.sum(dim=0)
return dq, dk, dv, None, None, None
attention = _attention.apply
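
To make the sequence-parallel backward above concrete: with sequence_parallel=True the
launch grid gains a second axis over key/value column blocks, each program writes its
partial dq into its own replica of the dq buffer (offset by stride_dqa * start_n), and
the replicas are summed on the host afterwards. A minimal sketch of that host-side
allocation and reduction, with illustrative shapes taken from the perf tests (not the
exact internal interface):

import torch
import triton

BLOCK = 128                              # backward block size used above
Z, H, N_CTX, D_HEAD = 4, 48, 4096, 64
seq_len_kv = N_CTX

# one dq replica per key/value column block (replication + reduction instead of atomics)
replicas = triton.cdiv(seq_len_kv, BLOCK)
dq = torch.zeros((replicas, Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda")

# ... the 2D-grid backward kernel writes each column block's partial dq
#     into dq[start_n] via the stride_dqa offset ...

# final reduction: collapse the per-block partials into the real gradient
dq_final = dq.sum(dim=0)
assert dq_final.shape == (Z, H, N_CTX, D_HEAD)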


@@ -346,9 +346,7 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)