This commit is contained in:
Ognjen
2024-01-31 14:13:36 +00:00
parent 171a67e837
commit 3a12d9d269
2 changed files with 41 additions and 27 deletions

View File

@@ -222,7 +222,7 @@ public:
return std::distance(value.user_begin(), value.user_end());
}
void scheduleSlicedDot(ModuleOp m, int stages) {
void scheduleSlicedDot(ModuleOp m, int stages, bool sinkLDSRd) {
SmallVector<SmallVector<Operation *>> dotChains;
m.walk([&](tt::DotOp dotOp) {
@@ -270,6 +270,19 @@ public:
operations, i == 0, 1);
}
}
if (!sinkLDSRd) {
return;
}
for (auto chain : dotChains) {
for (int i = 0; i < chain.size(); i++) {
Operation *dotOp = chain[i];
Operation *ldsRd = dotOp->getOperand(1).getDefiningOp();
assert(isLDSRead(ldsRd));
moveBefore(ldsRd, dotOp);
}
}
}
void runOnOperation() override {
@@ -278,7 +291,8 @@ public:
moveQTensorOutOfTheLoop(m);
int stages = 4;
scheduleSlicedDot(m, stages);
bool sinkLDSRd = true;
scheduleSlicedDot(m, stages, sinkLDSRd);
}
};

View File

@@ -84,23 +84,23 @@ def _attn_fwd_inner(
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
],
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)
@@ -755,20 +755,20 @@ HAS_FLASH = FLASH_VER is not None
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
for D_HEAD in [128, 64]:
for causal in [False, True]:
for D_HEAD in [128]:
for causal in [False]:
configs.append(triton.testing.Benchmark(
x_names=['BATCH', 'H','N_CTX'],
x_vals=[(16, 16, 1024),
(8, 16, 2048),
(4, 16, 4096),
(2, 16, 8192),
(1, 16, 16384),
(4, 48, 1024),
(4, 48, 2048),
x_vals=[#(16, 16, 1024),
# (8, 16, 2048),
# (4, 16, 4096),
# (2, 16, 8192),
# (1, 16, 16384),
# (4, 48, 1024),
# (4, 48, 2048),
(4, 48, 4096),
(4, 48, 8192),
(4, 48, 16384),
# (4, 48, 8192),
# (4, 48, 16384),
],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),