fp8 type support (#357)

* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton.
* add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16`
* change flashattention to support fp8 in q/k.
This commit is contained in:
Shucai Xiao
2023-11-02 15:51:23 -05:00
committed by GitHub
parent 38f9136fc8
commit 79bebc4ffe
11 changed files with 445 additions and 131 deletions

View File

@@ -811,6 +811,10 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
})
.def("get_fp8e4b15_ty",
[](TritonOpBuilder &self) -> mlir::Type {
// TODO: upstream FP8E4B15 into MLIR, or find a way to externally
@@ -827,6 +831,10 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E5M2Type>();
})
.def("get_fp8e5b16_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E5M2FNUZType>();
})
.def("get_half_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getF16Type();

View File

@@ -959,6 +959,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
# float8e4m3nv does not have infinities
output[fp == 0b01111111] = torch.nan
output[fp == 0b11111111] = torch.nan
elif dtype in [tl.float8e4b8, tl.float8e5b16]:
output[fp==0b10000000] = torch.nan
else:
output = torch.where(exp == (1 << exp_width) - 1,
((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
@@ -1015,7 +1017,11 @@ def deserialize_fp8(np_data, in_dtype):
return np_data
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15,
tl.float8e4b15x4,
tl.float8e4b8,
tl.float8e5,
tl.float8e5b16])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
"""
@@ -1040,9 +1046,11 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask)
ref_fp8[is_nan] = 0
ref_fp8[is_subnormal] = 0
tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda()
tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda")
copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024)
@@ -1055,6 +1063,50 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
assert torch.all(tri_fp8 == ref_fp8)
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
tl_to_torch_types = {
tl.float16: torch.float16,
tl.float32: torch.float32,
}
if TORCH_HAS_FP8E5B16:
tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz
if TORCH_HAS_FP8E4B8:
tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
def gen_input(M, N, d_type, seed, device='cuda'):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if d_type == tl.float16:
input = torch.randn((M, N), dtype=torch.float16, device=device)
input_f16 = input
else: # d_type is float8
raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10
if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
(d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) :
input = raw_data.to(tl_to_torch_types[d_type])
input_f16 = input.to(torch.float16)
else:
f8_tensor = raw_data.to(torch.int8)
# keep only two bits of exponent to avoid overflow
f8_tensor = f8_tensor & 0b00111111
input = triton.reinterpret(f8_tensor, d_type)
input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
n_elements = raw_data.numel()
copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
return input, input_f16
@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
[(*shape, *ab_type, out_dtype)
for shape in [[128, 256, 32],
@@ -1071,6 +1123,10 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
[tl.float8e4b15, tl.float16],
[tl.float8e4b15x4, tl.float16],
[tl.float8e5, tl.float16],
[tl.float8e4b8, tl.float16],
[tl.float8e5b16, tl.float16],
[tl.float16, tl.float8e5b16],
[tl.float16, tl.float8e4b8],
[tl.float16, tl.float8e4nv],
[tl.float16, tl.float8e4b15],
[tl.float16, tl.float8e4b15x4],
@@ -1078,17 +1134,8 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
for out_dtype in [torch.float16, torch.float32]
])
def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
check_type_supported(out_dtype, device)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
@@ -1117,15 +1164,128 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b, out_dtype=compute_type)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator
c = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b, c_type):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
M, K = a.shape
K, N = b.shape
if c_type == torch.float16:
comp_type = tl.float16
else:
comp_type = tl.float32
c = torch.empty((M, N), device = a.device, dtype=c_type)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
compute_type = comp_type,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=4,
num_stages=1,
num_warps=2,
)
return c
a, a_f16 = gen_input(M, K, a_type, 11, device=device)
b, b_f16 = gen_input(K, N, b_type, 22, device=device)
# call torch function to compute gold
golden = torch.matmul(a_f16, b_f16)
c = matmul(a, b, out_dtype)
torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=1e-2)
# @pytest.mark.skip(reason="Pytorch does not support the following types, so need to skip for now")
@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
[(*shape, *ab_type, out_dtype)
for shape in [[128, 256, 32],
[128, 16, 32],
[32, 128, 64],
[128, 128, 64],
[64, 128, 128],
[32, 128, 64],
[64, 64, 32],
[32, 32, 128],
[128, 128, 64],
[64, 128, 128]]
for ab_type in [[tl.float8e4b8, tl.float8e4b8],
[tl.float8e5b16, tl.float8e4b8],
[tl.float8e4b8, tl.float8e5b16],
[tl.float8e5b16, tl.float8e5b16]]
for out_dtype in [torch.float32]
])
def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
check_type_supported(out_dtype, device)
if triton.language.semantic.gpu_matrix_core_version() != 3:
pytest.skip("fp8 data type is not available on hardware")
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
compute_type:tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
@@ -1168,33 +1328,14 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
return c
def gen_input(M, N, d_type, seed, device='cuda'):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if d_type == tl.float16:
input = torch.randn((M, K), dtype=torch.float16, device=device)
input_f16 = input
else: # d_type is float8
f8_tensor = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10
f8_tensor = f8_tensor.to(torch.int8)
# keep only two bits of exponent to avoid overflow
f8_tensor = f8_tensor & 0b00111111
input = triton.reinterpret(f8_tensor, d_type)
input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
n_elements = f8_tensor.numel()
copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
return input, input_f16
a, a_f16 = gen_input(M, K, a_type, 11, device=device)
a, a_f16 = gen_input(M, K, a_type, 21, device=device)
b, b_f16 = gen_input(K, N, b_type, 22, device=device)
# call torch function to compute gold
golden = torch.matmul(a_f16, b_f16)
c = matmul(a, b, out_dtype)
torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2)
torch.testing.assert_close(golden, c.to(golden.dtype), rtol=1e-2, atol=2e-2)
# ---------------
@@ -1591,9 +1732,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5:
# if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5:
# TODO change types when they are available
# if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8:
if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8:
x = x.to(in_dtype)
y = y.to(in_dtype)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
@@ -1626,13 +1767,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
effective_in_dtype = tl.bfloat16
elif in_dtype == "float8e5m2fnuz":
# TODO change types when they are available
effective_in_dtype = tl.float8e5
# effective_in_dtype = tl.float8e5b16
# effective_in_dtype = tl.float8e5
effective_in_dtype = tl.float8e5b16
in_dtype = "float32"
elif in_dtype == "float8e4m3fnuz":
# TODO change types when they are available
effective_in_dtype = tl.float8e4b15
# effective_in_dtype = tl.float8e4b8
# effective_in_dtype = tl.float8e4b15
effective_in_dtype = tl.float8e4b8
in_dtype = "float32"
else:
assert("unexpected in dtype")
@@ -1655,10 +1796,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
if effective_in_dtype.is_fp8():
if effective_in_dtype.is_fp8e5():
x = x + 1
y = y + 1
if effective_in_dtype.is_fp8e5b16():
mask = 0b111111000110 << 20
else:
mask = 0b111110000111 << 20
mask = 0b101111000111 << 20
x = (x.view('uint32') & np.uint32(mask)).view('float32')
y = (y.view('uint32') & np.uint32(mask)).view('float32')
x_tri = to_triton(x, device=device)

View File

@@ -1070,7 +1070,9 @@ def str_to_ty(name):
return language.pointer_type(ty)
tys = {
"fp8e4nv": language.float8e4nv,
"fp8e4b8": language.float8e4b8,
"fp8e5": language.float8e5,
"fp8e5b16": language.float8e5b16,
"fp8e4b15": language.float8e4b15,
"fp8e4b15x4": language.float8e4b15x4,
"fp16": language.float16,

View File

@@ -58,7 +58,9 @@ from .core import (
float8e4b15,
float8e4b15x4,
float8e4nv,
float8e4b8,
float8e5,
float8e5b16,
function_type,
inline_asm_elementwise,
int1,

View File

@@ -75,7 +75,7 @@ def _to_tensor(x, builder):
class dtype:
SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']
@@ -107,10 +107,18 @@ class dtype:
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
self.exponent_bias = 7
elif name == 'fp8e4b8':
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
self.exponent_bias = 8
elif name == 'fp8e5':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
self.exponent_bias = 15
elif name == 'fp8e5b16':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
self.exponent_bias = 16
elif name == 'fp16':
self.fp_mantissa_width = 10
self.primitive_bitwidth = 16
@@ -138,6 +146,9 @@ class dtype:
def is_fp8e4nv(self):
return self.name == 'fp8e4nv'
def is_fp8e4b8(self):
return self.name == 'fp8e4b8'
def is_fp8e4b15(self):
return self.name == 'fp8e4b15'
@@ -147,6 +158,9 @@ class dtype:
def is_fp8e5(self):
return self.name == 'fp8e5'
def is_fp8e5b16(self):
return self.name == 'fp8e5b16'
def is_fp16(self):
return self.name == 'fp16'
@@ -250,8 +264,12 @@ class dtype:
return builder.get_int64_ty()
elif self.name == 'fp8e5':
return builder.get_fp8e5_ty()
elif self.name == 'fp8e5b16':
return builder.get_fp8e5b16_ty()
elif self.name == 'fp8e4nv':
return builder.get_fp8e4nv_ty()
elif self.name == 'fp8e4b8':
return builder.get_fp8e4b8_ty()
elif self.name == 'fp8e4b15':
return builder.get_fp8e4b15_ty()
elif self.name == 'fp8e4b15x4':
@@ -388,7 +406,9 @@ uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
float8e5 = dtype('fp8e5')
float8e5b16 = dtype('fp8e5b16')
float8e4nv = dtype('fp8e4nv')
float8e4b8 = dtype('fp8e4b8')
float8e4b15 = dtype('fp8e4b15')
float8e4b15x4 = dtype('fp8e4b15x4')
float16 = dtype('fp16')

View File

@@ -247,7 +247,13 @@ class JITFunction(KernelInterface[T]):
tys = {
"bool": "i1",
"float8e4nv": "fp8e4nv",
"float8_e4m3fn": "fp8e4nv",
"float8e4b8": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4b8",
"float8e5": "fp8e5",
"float8_e5m2": "fp8e5",
"float8e5b16": "fp8e5b16",
"float8_e5m2fnuz": "fp8e5b16",
"float8e4b15": "fp8e4b15",
"float8e4b15x4": "fp8e4b15x4",
"float16": "fp16",

View File

@@ -17,6 +17,9 @@ import torch
import triton
import triton.language as tl
torch_dtype:tl.constexpr = torch.float16
# torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
@triton.jit
def max_fn(x, y):
@@ -67,7 +70,7 @@ def _attn_fwd_inner(
acc = acc * alpha[:, None]
if not pre_load_v:
v = tl.load(V_block_ptr)
acc += tl.dot(p.to(tl.float16), v)
acc += tl.dot(p.to(v.dtype), v)
# -- update m_i and l_i
l_ij = tl.sum(p, 1)
l_i = l_i * alpha + l_ij
@@ -144,7 +147,7 @@ def _attn_fwd(
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
q = (q * qk_scale).to(q.dtype)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
@@ -272,7 +275,7 @@ def _bwd_kernel(
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
@@ -357,7 +360,7 @@ def _bwd_kernel_dk_dv(
qk_scale = sm_scale * 1.44269504
# load k and v: they will stay in SRAM throughout
k = tl.load(K_block_ptr)
k = (k * qk_scale).to(tl.float16)
k = (k * qk_scale).to(k.dtype)
v = tl.load(V_block_ptr)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -378,7 +381,7 @@ def _bwd_kernel_dk_dv(
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i)
# -- compute dv ----
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
@@ -407,7 +410,7 @@ def _bwd_kernel_dk_dv(
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DK_block_ptr, (dk * sm_scale).to(tl.float16))
tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
tl.store(DV_block_ptr, dv.to(tl.float16))
@triton.jit
@@ -469,7 +472,7 @@ def _bwd_kernel_dq(
qk_scale = sm_scale * 1.44269504
# load q and do: they will stay in SRAM throughout
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
q = (q * qk_scale).to(q.dtype)
do = tl.load(DO_block_ptr)
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)
@@ -518,7 +521,7 @@ class _attention(torch.autograd.Function):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
o = torch.empty_like(q, dtype=v.dtype)
if torch.version.hip is None:
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32
@@ -569,6 +572,7 @@ class _attention(torch.autograd.Function):
q, k, v, o, L = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
# dk = torch.empty_like(k, dtype=torch_dtype)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta = torch.empty_like(L)
@@ -648,26 +652,17 @@ attention = _attention.apply
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
if TORCH_HAS_FP8E5:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = 0.5
dout = torch.randn_like(q)
dout = torch.randn_like(q, dtype=torch.float16)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
@@ -675,7 +670,7 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale)
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
@@ -690,7 +685,8 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0,5
sm_scale = 0.5
split_kernel = True
dout = torch.randn_like(q)
# reference implementation
@@ -777,6 +773,9 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd":
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
if mode == 'bwd':