mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FA fwd D=128] Reduce LDS usage in epilogue (#340)
* rebase onto improve_fwd_fa * Fixed a leftover from rebase * rebase onto improve_fa_fwd * Reduce tuning space * Disable bwd with D=128 * Add test for d=128 * Fix an issue with get_best_config when there is only one config * Added better configs for d=128 * Fix typos --------- Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
@@ -64,6 +64,10 @@ static void addWSNamedAttrs(Operation *op,
|
||||
op->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
// Shared-memory (LDS) capacity per workgroup, in bytes, assumed for the
// targeted AMD GPUs.
constexpr int LDSSize = 65536;
// Bit width charged per pointer-typed tensor element when sizing LDS
// scratch buffers.
constexpr int kPtrBitWidth = 64;
#endif
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
|
||||
@@ -410,6 +414,7 @@ struct ConvertTritonGPUToLLVM
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
#ifdef USE_ROCM
|
||||
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
reduceCvtOpLDSUsage(mod);
|
||||
#endif
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
@@ -710,6 +715,151 @@ private:
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Compute the number of LDS (shared memory) bytes the lowering of this
// ConvertLayoutOp will need for its scratch buffer.
int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) const {
  unsigned inVec = 0;
  unsigned outVec = 0;
  auto scratchShape = getScratchConfigForCvtLayout(cvtOp, inVec, outVec);

  // Total element count is the product of the scratch-shape dimensions.
  unsigned numElems = 1;
  for (auto dim : scratchShape)
    numElems *= dim;

  auto tensorTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
  // Pointer elements occupy a fixed 64-bit slot; all other element types
  // are padded to at least one byte.
  unsigned bitsPerElem;
  if (tensorTy.getElementType().isa<triton::PointerType>())
    bitsPerElem = kPtrBitWidth;
  else
    bitsPerElem = std::max<int>(8, tensorTy.getElementTypeBitWidth());

  return numElems * bitsPerElem / 8;
}
// True iff x has exactly one bit set (0 is not a power of two).
bool isPowerOfTwo(unsigned x) const { return x != 0 && !(x & (x - 1)); }
// Enumerate every ordered pair (a, b) of powers of two with a * b == n.
// n itself must be a power of two.
//
// Uses integer bit arithmetic instead of floating-point pow()/log2() so the
// factors are exact, and emits the symmetric pair only when the two factors
// differ, avoiding a duplicate entry when the exponent splits evenly.
std::vector<std::pair<int, int>> factorizePowerOf2(int n) const {
  assert(isPowerOfTwo(n));
  // n == 2^totalLog.
  int totalLog = 0;
  while ((1 << totalLog) < n)
    ++totalLog;

  std::vector<std::pair<int, int>> pairs;
  for (int i = 0; i <= totalLog / 2; ++i) {
    int j = totalLog - i;
    pairs.push_back({1 << i, 1 << j});
    if (i != j)
      pairs.push_back({1 << j, 1 << i});
  }

  return pairs;
}
std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
|
||||
createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
|
||||
triton::gpu::ConvertLayoutOp &cvtOp,
|
||||
std::pair<unsigned, unsigned> warpsPerCta) const {
|
||||
unsigned warpsPerCtaX = warpsPerCta.first;
|
||||
unsigned warpsPerCtaY = warpsPerCta.second;
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
|
||||
auto srcMfma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
|
||||
auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
|
||||
mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
|
||||
srcMfma.getIsTransposed(), srcMfma.getCTALayout());
|
||||
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
|
||||
auto newSrcType = RankedTensorType::get(
|
||||
srcType.getShape(), srcType.getElementType(), newMfmaEnc);
|
||||
|
||||
auto tmpCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newSrcType, cvtOp.getOperand());
|
||||
auto newEpilogueCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newDstType, tmpCvt);
|
||||
|
||||
return std::make_pair(tmpCvt, newEpilogueCvt);
|
||||
}
|
||||
|
||||
// Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of
|
||||
// WarpsPerCta attribute in mfma layout. The implicit LDS usage of
|
||||
// cvt(mfma->blocked) op depends on the number of warps per CTA that mfma
|
||||
// layout uses along x dimension and block layout uses across y dimension.
|
||||
//
|
||||
// clang-format off
|
||||
//
|
||||
// LDS usage of this op is roughly calculated as:
|
||||
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
|
||||
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
|
||||
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
|
||||
//
|
||||
// clang-format on
|
||||
//
|
||||
// When LDS_USAGE exceeds the size of LDS, try to lower LDS usage by
|
||||
// decomposing cvt(mfma->blocked) op into 2 conversions: cvt(mfma->mfma_tmp)
|
||||
// and cvt(mfma_tmp->blocked), where mfma_tmp has WarpsPerCta attribute that
|
||||
// minimizes uses of LDS for these conversions.
|
||||
void reduceCvtOpLDSUsage(ModuleOp mod) const {
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
|
||||
auto srcMfma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
|
||||
auto dstBlocked =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
|
||||
if (!srcMfma || !dstBlocked) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto currLDSUsage = getCvtOpLDSUsage(cvtOp);
|
||||
if (currLDSUsage <= LDSSize) {
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned numWarps =
|
||||
srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1];
|
||||
|
||||
triton::gpu::ConvertLayoutOp tmpCvt;
|
||||
triton::gpu::ConvertLayoutOp newEpilogueCvt;
|
||||
|
||||
// Find all possible shapes of WarpsPerCTA by finding all possible
|
||||
// factorizations of numWarps. Pick shape for which both conversions in
|
||||
// decomposition use LDS less than LDSSize and for which sum of LDS usage
|
||||
// is minimal. If no such shape exists, do not decompose.
|
||||
unsigned minLDSUsage = 2 * LDSSize;
|
||||
int minIdx = -1;
|
||||
auto factorizedNumWarps = factorizePowerOf2(numWarps);
|
||||
|
||||
for (int i = 0; i < factorizedNumWarps.size(); i++) {
|
||||
auto warpsPerCTAPair = factorizedNumWarps[i];
|
||||
std::tie(tmpCvt, newEpilogueCvt) =
|
||||
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);
|
||||
|
||||
int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
|
||||
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
|
||||
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
|
||||
int LDSUsage = tmpCvtLDS + newCvtLDS;
|
||||
if (LDSUsage < minLDSUsage) {
|
||||
minLDSUsage = LDSUsage;
|
||||
minIdx = i;
|
||||
}
|
||||
}
|
||||
newEpilogueCvt.erase();
|
||||
tmpCvt.erase();
|
||||
}
|
||||
|
||||
if (minIdx == -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
|
||||
auto warpsPerCTAPair = factorizedNumWarps[minIdx];
|
||||
std::tie(tmpCvt, newEpilogueCvt) =
|
||||
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);
|
||||
|
||||
cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult());
|
||||
cvtOp.erase();
|
||||
});
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void decomposeBlockedToDotOperand(ModuleOp mod) const {
|
||||
|
||||
@@ -100,7 +100,7 @@ class Autotuner(KernelInterface):
|
||||
key_values.append(kwargs[name])
|
||||
key = tuple(key_values)
|
||||
|
||||
return self.cache[key] if key in self.cache else Config({})
|
||||
return self.best_config
|
||||
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
||||
@@ -80,28 +80,12 @@ def _attn_fwd_inner(
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
],
|
||||
key=['N_CTX', 'STAGE'],
|
||||
key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
||||
)
|
||||
|
||||
|
||||
@@ -114,9 +98,9 @@ def _attn_fwd(
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H,
|
||||
N_CTX,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
pre_load_v: tl.constexpr,
|
||||
):
|
||||
@@ -562,7 +546,7 @@ class _attention(torch.autograd.Function):
|
||||
)
|
||||
|
||||
## restore the grid for bwd kernel
|
||||
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
|
||||
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
|
||||
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
|
||||
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
|
||||
|
||||
@@ -655,6 +639,9 @@ attention = _attention.apply
|
||||
[(4, 48, 1024, 64),
|
||||
(4, 48, 2048, 64),
|
||||
(4, 48, 4096, 64),
|
||||
(4, 48, 1024, 128),
|
||||
(4, 48, 2048, 128),
|
||||
(4, 48, 4096, 128),
|
||||
#(4, 48, 8192, 64),
|
||||
#(4, 48, 16384, 64)
|
||||
])
|
||||
@@ -747,30 +734,33 @@ except BaseException:
|
||||
FLASH_VER = None
|
||||
HAS_FLASH = FLASH_VER is not None
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
BATCH, N_HEADS, N_CTX= 4, 48, 4096
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = []
|
||||
for mode in ['fwd', 'bwd']:
|
||||
for causal in [False, True]:
|
||||
if mode == 'bwd' and causal == False:
|
||||
continue
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
for D_HEAD in [64, 128]:
|
||||
if mode == 'bwd' and D_HEAD == 128:
|
||||
continue
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
|
||||
Reference in New Issue
Block a user