ROCM IFU: Resolve conflicts in FA tutorial

Jason Furmanek
2023-11-07 03:26:02 +00:00
parent 502525ff11
commit 85216ea5c5


@@ -701,12 +701,7 @@ class _attention(torch.autograd.Function):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
<<<<<<< HEAD
<<<<<<< HEAD
o = torch.empty_like(q, dtype=v.dtype)
=======
o = torch.empty_like(q)
>>>>>>> 0eed50883... resolve some merge conflicts
if torch.version.hip is None:
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32
@@ -721,18 +716,6 @@ class _attention(torch.autograd.Function):
)
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
<<<<<<< HEAD
=======
o = torch.empty_like(q)
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
=======
>>>>>>> 0eed50883... resolve some merge conflicts
_attn_fwd[grid](
q, k, v, sm_scale, M, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
@@ -749,20 +732,7 @@ class _attention(torch.autograd.Function):
best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
<<<<<<< HEAD
=======
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=Lk,
STAGE=3,
num_warps=num_warps,
num_stages=num_stages,
)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
=======
>>>>>>> 0eed50883... resolve some merge conflicts
ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
@@ -862,39 +832,17 @@ attention = _attention.apply
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
torch.manual_seed(20)
<<<<<<< HEAD
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
if TORCH_HAS_FP8E5:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
=======
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
>>>>>>> 0eed50883... resolve some merge conflicts
sm_scale = 0.5
dout = torch.randn_like(q, dtype=torch.float16)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
<<<<<<< HEAD
p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
=======
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
>>>>>>> 0eed50883... resolve some merge conflicts
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
@@ -918,27 +866,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
=======
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 2, 1024, 64)])
@pytest.mark.parametrize("causal", [True])
def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
sm_scale = 0.5
split_kernel = True
dout = torch.randn_like(q)
@@ -984,7 +911,6 @@ for mode in ['fwd', 'bwd']:
for D_HEAD in [128, 64]:
if mode == 'bwd' and D_HEAD == 128:
continue
<<<<<<< HEAD
for causal in [False, True]:
if mode == 'bwd' and causal == False:
continue
@@ -1013,49 +939,6 @@ for mode in ['fwd', 'bwd']:
'mode': mode,
'causal': causal})
)
=======
configs = [
triton.testing.Benchmark(
x_names=["N_CTX"],
x_vals=[2**i for i in range(10, 15)],
line_arg="provider",
line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}",
args={
"H": N_HEADS,
"BATCH": BATCH,
"D_HEAD": D_HEAD,
"dtype": torch.float16,
"mode": mode,
"causal": causal,
},
)
for mode in ["fwd", "bwd"]
for causal in [True]
]
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
=======
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 15)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal})
)
>>>>>>> 0eed50883... resolve some merge conflicts
@triton.testing.perf_report(configs)
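
For readers without the full tutorial file at hand, here is a minimal, self-contained sketch (not part of the commit) of the eager reference computation that the resolved test_op_fwd keeps on the HEAD side: scaled q @ k^T in fp16, an optional causal mask, then softmax. The helper name reference_attention, the cuda-availability guard, and the trailing p @ v product are illustrative assumptions; the hunk above only shows the computation up to the softmax.

import torch

def reference_attention(q, k, v, sm_scale, causal=True):
    # Eager reference path mirroring the kept HEAD lines of test_op_fwd:
    # scaled QK^T in fp16, causal mask, softmax in fp32, cast back to fp16.
    N_CTX = q.shape[2]
    mask = torch.tril(torch.ones((N_CTX, N_CTX), device=q.device))
    p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
    if causal:
        p[:, :, mask == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # The final p @ v product is not visible in the hunk; it is included here
    # only as the standard completing step of a reference implementation.
    return torch.matmul(p, v.half())

if torch.cuda.is_available():
    # Inputs mirroring the test parametrization (Z, H, N_CTX, D_HEAD) = (1, 2, 1024, 64).
    Z, H, N_CTX, D_HEAD = 1, 2, 1024, 64
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
    k = torch.empty_like(q).normal_(mean=0., std=0.5)
    v = torch.empty_like(q).normal_(mean=0., std=0.5)
    ref_out = reference_attention(q, k, v, sm_scale=0.5)
    print(ref_out.shape)  # torch.Size([1, 2, 1024, 64])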