mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
fix
This commit is contained in:
@@ -222,7 +222,7 @@ public:
|
||||
return std::distance(value.user_begin(), value.user_end());
|
||||
}
|
||||
|
||||
void scheduleSlicedDot(ModuleOp m, int stages) {
|
||||
void scheduleSlicedDot(ModuleOp m, int stages, bool sinkLDSRd) {
|
||||
SmallVector<SmallVector<Operation *>> dotChains;
|
||||
|
||||
m.walk([&](tt::DotOp dotOp) {
|
||||
@@ -270,6 +270,19 @@ public:
|
||||
operations, i == 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!sinkLDSRd) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto chain : dotChains) {
|
||||
for (int i = 0; i < chain.size(); i++) {
|
||||
Operation *dotOp = chain[i];
|
||||
Operation *ldsRd = dotOp->getOperand(1).getDefiningOp();
|
||||
assert(isLDSRead(ldsRd));
|
||||
moveBefore(ldsRd, dotOp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
@@ -278,7 +291,8 @@ public:
|
||||
|
||||
moveQTensorOutOfTheLoop(m);
|
||||
int stages = 4;
|
||||
scheduleSlicedDot(m, stages);
|
||||
bool sinkLDSRd = true;
|
||||
scheduleSlicedDot(m, stages, sinkLDSRd);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -84,23 +84,23 @@ def _attn_fwd_inner(
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
],
|
||||
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
||||
)
|
||||
@@ -755,20 +755,20 @@ HAS_FLASH = FLASH_VER is not None
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = []
|
||||
for mode in ['fwd']:
|
||||
for D_HEAD in [128, 64]:
|
||||
for causal in [False, True]:
|
||||
for D_HEAD in [128]:
|
||||
for causal in [False]:
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['BATCH', 'H','N_CTX'],
|
||||
x_vals=[(16, 16, 1024),
|
||||
(8, 16, 2048),
|
||||
(4, 16, 4096),
|
||||
(2, 16, 8192),
|
||||
(1, 16, 16384),
|
||||
(4, 48, 1024),
|
||||
(4, 48, 2048),
|
||||
x_vals=[#(16, 16, 1024),
|
||||
# (8, 16, 2048),
|
||||
# (4, 16, 4096),
|
||||
# (2, 16, 8192),
|
||||
# (1, 16, 16384),
|
||||
# (4, 48, 1024),
|
||||
# (4, 48, 2048),
|
||||
(4, 48, 4096),
|
||||
(4, 48, 8192),
|
||||
(4, 48, 16384),
|
||||
# (4, 48, 8192),
|
||||
# (4, 48, 16384),
|
||||
],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
|
||||
Reference in New Issue
Block a user