mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Resolve conflicts in FA tutorial
This commit is contained in:
@@ -701,12 +701,7 @@ class _attention(torch.autograd.Function):
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
=======
|
||||
o = torch.empty_like(q)
|
||||
>>>>>>> 0eed50883... resolve some merge conflicts
|
||||
if torch.version.hip is None:
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
@@ -721,18 +716,6 @@ class _attention(torch.autograd.Function):
|
||||
)
|
||||
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
o = torch.empty_like(q)
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
=======
|
||||
>>>>>>> 0eed50883... resolve some merge conflicts
|
||||
_attn_fwd[grid](
|
||||
q, k, v, sm_scale, M, o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
@@ -749,20 +732,7 @@ class _attention(torch.autograd.Function):
|
||||
best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
|
||||
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
|
||||
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_DMODEL=Lk,
|
||||
STAGE=3,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
|
||||
=======
|
||||
|
||||
>>>>>>> 0eed50883... resolve some merge conflicts
|
||||
ctx.save_for_backward(q, k, v, o, M)
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
@@ -862,39 +832,17 @@ attention = _attention.apply
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
<<<<<<< HEAD
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
if TORCH_HAS_FP8E5:
|
||||
q = q.to(torch_dtype)
|
||||
k = k.to(torch_dtype)
|
||||
=======
|
||||
q = (
|
||||
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0., std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
k = (
|
||||
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
v = (
|
||||
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
>>>>>>> 0eed50883... resolve some merge conflicts
|
||||
sm_scale = 0.5
|
||||
dout = torch.randn_like(q, dtype=torch.float16)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
<<<<<<< HEAD
|
||||
p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
|
||||
=======
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
>>>>>>> 0eed50883... resolve some merge conflicts
|
||||
if causal:
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
@@ -918,27 +866,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
|
||||
=======
|
||||
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 2, 1024, 64)])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = (
|
||||
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
k = (
|
||||
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
v = (
|
||||
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
sm_scale = 0.5
|
||||
split_kernel = True
|
||||
dout = torch.randn_like(q)
|
||||
@@ -984,7 +911,6 @@ for mode in ['fwd', 'bwd']:
|
||||
for D_HEAD in [128, 64]:
|
||||
if mode == 'bwd' and D_HEAD == 128:
|
||||
continue
|
||||
<<<<<<< HEAD
|
||||
for causal in [False, True]:
|
||||
if mode == 'bwd' and causal == False:
|
||||
continue
|
||||
@@ -1013,49 +939,6 @@ for mode in ['fwd', 'bwd']:
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
=======
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["N_CTX"],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
line_arg="provider",
|
||||
line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
|
||||
line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}",
|
||||
args={
|
||||
"H": N_HEADS,
|
||||
"BATCH": BATCH,
|
||||
"D_HEAD": D_HEAD,
|
||||
"dtype": torch.float16,
|
||||
"mode": mode,
|
||||
"causal": causal,
|
||||
},
|
||||
)
|
||||
for mode in ["fwd", "bwd"]
|
||||
for causal in [True]
|
||||
]
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
=======
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
>>>>>>> 0eed50883... resolve some merge conflicts
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
|
||||
Reference in New Issue
Block a user