[FA fwd D=128] Reduce LDS usage in epilogue (#340)

* rebase onto improve_fwd_fa

* Fixed a leftover from rebase

* rebase onto improve_fa_fwd

* Reduce tuning space

* Disable bwd with D=128

* Add test for d=128

* Fix an issue with get_best_config when there is only one config

* Added better configs for d=128

* Fix typos

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
oplavsic
2023-10-25 19:10:34 +02:00
committed by GitHub
parent e74bdb1581
commit 715a589ce3
3 changed files with 182 additions and 42 deletions

View File

@@ -64,6 +64,10 @@ static void addWSNamedAttrs(Operation *op,
op->setAttr(attr.getName(), attr.getValue());
}
#ifdef USE_ROCM
// Total LDS (shared/local memory) budget per workgroup, in bytes, that the
// conversion lowering may use. NOTE(review): 64 KiB — presumably sized for
// CDNA-class AMD GPUs; confirm against the actual target.
constexpr int LDSSize = 65536;
// Bit width assumed for pointer-typed tensor elements when computing
// scratch-buffer sizes (see getCvtOpLDSUsage).
constexpr int kPtrBitWidth = 64;
#endif
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
@@ -410,6 +414,7 @@ struct ConvertTritonGPUToLLVM
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
#ifdef USE_ROCM
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
reduceCvtOpLDSUsage(mod);
#endif
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
@@ -710,6 +715,151 @@ private:
}
});
}
// Computes how many bytes of LDS scratch space lowering `cvtOp` through
// shared memory would consume.
int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) const {
  unsigned inVec = 0;
  unsigned outVec = 0;
  const auto scratchShape = getScratchConfigForCvtLayout(cvtOp, inVec, outVec);
  const unsigned numElems = std::accumulate(
      scratchShape.begin(), scratchShape.end(), 1, std::multiplies{});
  auto tensorTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
  // Pointer elements occupy kPtrBitWidth bits; every other element type
  // takes at least one byte in the scratch buffer.
  const int bitsPerElem =
      tensorTy.getElementType().isa<triton::PointerType>()
          ? kPtrBitWidth
          : std::max<int>(8, tensorTy.getElementTypeBitWidth());
  return numElems * bitsPerElem / 8;
}
bool isPowerOfTwo(unsigned x) const { return x && (x & (x - 1)) == 0; }
// Returns all ordered pairs (a, b) of powers of two with a * b == n.
// `n` must itself be a power of two.
//
// Fixes vs. previous version:
//  * when n is an even power of two, the midpoint iteration had i == j and
//    pushed the identical pair twice — the duplicate caused the caller to
//    redundantly build and erase a whole candidate conversion; it is now
//    skipped.
//  * `log2`/`pow` did exact integer math through doubles; integer bit
//    shifts are used instead.
std::vector<std::pair<int, int>> factorizePowerOf2(int n) const {
  assert(isPowerOfTwo(n));
  // Exact integer log2(n): n is guaranteed to be a power of two.
  int x = 0;
  for (int m = n; m > 1; m >>= 1)
    ++x;
  std::vector<std::pair<int, int>> pairs;
  for (int i = 0; i <= x / 2; ++i) {
    int j = x - i;
    pairs.push_back({1 << i, 1 << j});
    if (i != j)
      pairs.push_back({1 << j, 1 << i});
  }
  return pairs;
}
// Builds the two-step replacement for an mfma->blocked `cvtOp`: first a
// conversion into a temporary mfma layout whose warpsPerCTA is
// {warpsPerCtaX, warpsPerCtaY}, then a conversion from that temporary
// layout to the original destination type. Both ops are inserted at the
// builder's current insertion point; the caller is responsible for either
// wiring them in (replaceAllUsesWith) or erasing them after measuring
// their LDS cost.
std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
triton::gpu::ConvertLayoutOp &cvtOp,
std::pair<unsigned, unsigned> warpsPerCta) const {
unsigned warpsPerCtaX = warpsPerCta.first;
unsigned warpsPerCtaY = warpsPerCta.second;
// Caller guarantees the source encoding is mfma (checked before calling);
// dyn_cast would yield null otherwise.
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcMfma =
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
// Same mfma attributes as the source, but with the candidate warpsPerCTA
// shape substituted in.
auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
srcMfma.getIsTransposed(), srcMfma.getCTALayout());
// NOTE(review): newDstType reuses dstType's own encoding, so it is
// equivalent to dstType; presumably kept for symmetry with newSrcType.
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
auto newSrcType = RankedTensorType::get(
srcType.getShape(), srcType.getElementType(), newMfmaEnc);
// cvt(mfma -> mfma_tmp) followed by cvt(mfma_tmp -> blocked).
auto tmpCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newSrcType, cvtOp.getOperand());
auto newEpilogueCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newDstType, tmpCvt);
return std::make_pair(tmpCvt, newEpilogueCvt);
}
// Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of
// WarpsPerCta attribute in mfma layout. The implicit LDS usage of
// cvt(mfma->blocked) op depends on the number of warps per CTA that mfma
// layout uses along x dimension and block layout uses across y dimension.
//
// clang-format off
//
// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
//
// clang-format on
//
// When LDS_USAGE exceeds the size of LDS, try to lower LDS usage by
// decomposing cvt(mfma->blocked) op into 2 conversions: cvt(mfma->mfma_tmp)
// and cvt(mfma_tmp->blocked), where mfma_tmp has WarpsPerCta attribute that
// minimizes uses of LDS for these conversions.
void reduceCvtOpLDSUsage(ModuleOp mod) const {
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
// Insert candidate ops right where the original conversion sits.
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcMfma =
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
auto dstBlocked =
dstType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
// Only mfma -> blocked conversions are handled; skip everything else.
if (!srcMfma || !dstBlocked) {
return;
}
// Nothing to do if the conversion already fits in LDS.
auto currLDSUsage = getCvtOpLDSUsage(cvtOp);
if (currLDSUsage <= LDSSize) {
return;
}
unsigned numWarps =
srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1];
triton::gpu::ConvertLayoutOp tmpCvt;
triton::gpu::ConvertLayoutOp newEpilogueCvt;
// Find all possible shapes of WarpsPerCTA by finding all possible
// factorizations of numWarps. Pick shape for which both conversions in
// decomposition use LDS less than LDSSize and for which sum of LDS usage
// is minimal. If no such shape exists, do not decompose.
unsigned minLDSUsage = 2 * LDSSize;
int minIdx = -1;
auto factorizedNumWarps = factorizePowerOf2(numWarps);
for (int i = 0; i < factorizedNumWarps.size(); i++) {
auto warpsPerCTAPair = factorizedNumWarps[i];
// Trial-build the decomposition just to measure its LDS cost; the ops
// are erased again below (in use-before-def order: epilogue first).
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);
int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
int LDSUsage = tmpCvtLDS + newCvtLDS;
if (LDSUsage < minLDSUsage) {
minLDSUsage = LDSUsage;
minIdx = i;
}
}
newEpilogueCvt.erase();
tmpCvt.erase();
}
// No factorization fits within the LDS budget: leave the op untouched.
if (minIdx == -1) {
return;
}
assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
// Re-create the winning decomposition for real and splice it in place
// of the original conversion.
auto warpsPerCTAPair = factorizedNumWarps[minIdx];
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);
cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult());
cvtOp.erase();
});
}
#endif
void decomposeBlockedToDotOperand(ModuleOp mod) const {

View File

@@ -100,7 +100,7 @@ class Autotuner(KernelInterface):
key_values.append(kwargs[name])
key = tuple(key_values)
return self.cache[key] if key in self.cache else Config({})
return self.best_config
def run(self, *args, **kwargs):

View File

@@ -80,28 +80,12 @@ def _attn_fwd_inner(
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
],
key=['N_CTX', 'STAGE'],
key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)
@@ -114,9 +98,9 @@ def _attn_fwd(
stride_oz, stride_oh, stride_om, stride_on,
Z, H,
N_CTX,
BLOCK_DMODEL: tl.constexpr,
STAGE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
pre_load_v: tl.constexpr,
):
@@ -562,7 +546,7 @@ class _attention(torch.autograd.Function):
)
## restore the grid for bwd kernel
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
@@ -655,6 +639,9 @@ attention = _attention.apply
[(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
(4, 48, 1024, 128),
(4, 48, 2048, 128),
(4, 48, 4096, 128),
#(4, 48, 8192, 64),
#(4, 48, 16384, 64)
])
@@ -747,30 +734,33 @@ except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX= 4, 48, 4096
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd', 'bwd']:
for causal in [False, True]:
if mode == 'bwd' and causal == False:
continue
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 15)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal})
)
for D_HEAD in [64, 128]:
if mode == 'bwd' and D_HEAD == 128:
continue
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 15)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal})
)
@triton.testing.perf_report(configs)