mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZATION] Fix performance for attention backward path with mma v3 (#2411)
Support having chain of mma with mixed size. Serialize the different block calculation in backward attention to workaround problem with ptxas and wgmma.
This commit is contained in:
@@ -141,16 +141,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
// Implement backward and forward slice that will go through scf blocks when
|
||||
// yield or scf results are in the slice.
|
||||
// Note that like exisiting forward and backard slice this may add operations to
|
||||
// the slice that are not actually dependent on the root because when a region
|
||||
// is added to the slice in the forward slice all the operations of the region
|
||||
// are added. We could implement a more accurate slice method by tracking value
|
||||
// usage across scf regions.
|
||||
void getBackwardSliceSCFAware(Operation *, SetVector<Operation *> *slices);
|
||||
void getForwardSliceSCFAware(Value root, SetVector<Operation *> *slices);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
@@ -424,6 +424,23 @@ bool supportMMA(Value value, int version) {
|
||||
(elemTy.isInteger(8) && version >= 2);
|
||||
}
|
||||
|
||||
static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
|
||||
auto src = srcEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dst = dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!src || !dst)
|
||||
return false;
|
||||
auto srcInstrShape = src.getInstrShape();
|
||||
auto dstInstrShape = dst.getInstrShape();
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src && dst && src.getVersionMajor() == 3 &&
|
||||
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
|
||||
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
|
||||
}
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
|
||||
}
|
||||
|
||||
// For MMAV3 dotOperand layout matches mma operand for f16 case.
|
||||
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
RankedTensorType dstTy) {
|
||||
@@ -432,7 +449,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
auto mmaLayout = srcLayout.cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
return mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mmaLayout &&
|
||||
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
|
||||
srcTy.getElementType().isF16();
|
||||
}
|
||||
|
||||
@@ -452,17 +469,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy);
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
|
||||
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
|
||||
srcElemsPerThread == dstElemsPerThread;
|
||||
}
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
|
||||
|
||||
@@ -81,8 +81,7 @@ public:
|
||||
// forwarding on mma->mma shortcut, lower distributed->distributed otherwise
|
||||
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
|
||||
if (isMmaToMmaShortcut(srcTy, dstTy)) {
|
||||
rewriter.replaceOp(op, op.getSrc());
|
||||
return success();
|
||||
return lowerMmaToMma(op, adaptor, rewriter);
|
||||
}
|
||||
}
|
||||
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
|
||||
@@ -963,6 +962,43 @@ private:
|
||||
return failure();
|
||||
}
|
||||
|
||||
// mma -> mma
|
||||
LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
if (triton::gpu::getTotalElemsPerThread(srcTy) ==
|
||||
triton::gpu::getTotalElemsPerThread(dstTy)) {
|
||||
rewriter.replaceOp(op, op.getSrc());
|
||||
return success();
|
||||
}
|
||||
// get source values
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
SmallVector<Value> retVals;
|
||||
SmallVector<unsigned> dstElementPerThread =
|
||||
triton::gpu::getElemsPerThread(dstTy);
|
||||
SmallVector<unsigned> srcElementPerThread =
|
||||
triton::gpu::getElemsPerThread(srcTy);
|
||||
for (unsigned j = 0; j < dstElementPerThread[0]; j++) {
|
||||
for (unsigned i = 0; i < dstElementPerThread[1]; i++) {
|
||||
if (i >= srcElementPerThread[1] || j >= srcElementPerThread[0]) {
|
||||
retVals.push_back(undef(vals[0].getType()));
|
||||
continue;
|
||||
}
|
||||
unsigned index = i + j * srcElementPerThread[1];
|
||||
retVals.push_back(vals[index]);
|
||||
}
|
||||
}
|
||||
assert(retVals.size() == triton::gpu::getTotalElemsPerThread(dstTy));
|
||||
Value view =
|
||||
getTypeConverter()->packLLElements(loc, retVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
|
||||
// shared -> dot_operand if the result layout is mma
|
||||
Value lowerSharedToDotOperandMMA(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
|
||||
@@ -220,7 +220,10 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter,
|
||||
Location loc,
|
||||
const SmallVector<Value> &elements,
|
||||
int startIndex, int numElements) {
|
||||
int startIndex, int numElements,
|
||||
Operation *insertBefore) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(insertBefore);
|
||||
if (!elements[0].getType().isF16()) {
|
||||
llvm::SmallVector<Value> mmaOut(numElements);
|
||||
for (int i = 0; i < numElements; ++i)
|
||||
@@ -351,9 +354,12 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
|
||||
int numTMADescs =
|
||||
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
|
||||
Operation *startSequence = nullptr;
|
||||
if (numTMADescs == 0)
|
||||
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
|
||||
rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
|
||||
startSequence = rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
|
||||
Operation *fenceOp = rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
|
||||
if (startSequence == nullptr)
|
||||
startSequence = fenceOp;
|
||||
// WGMMA fp8 -> fp32 accumulates in lower precision than fp32.
|
||||
bool needsPartialAccumulator = isFP8(eltTypeA) &&
|
||||
eltTypeC == triton::nvgpu::WGMMAEltType::f32 &&
|
||||
@@ -362,7 +368,8 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
llvm::SmallVector<Value> mmaOut =
|
||||
loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize);
|
||||
loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize,
|
||||
startSequence);
|
||||
llvm::SmallVector<Type> elemTypes;
|
||||
for (Value accEl : mmaOut)
|
||||
elemTypes.push_back(accEl.getType());
|
||||
@@ -379,8 +386,9 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
a = aLoader.smemLoad(m, k, rewriter, loc);
|
||||
} else {
|
||||
unsigned regASize = (instrShape[0] * instrShape[2]) / 32;
|
||||
llvm::SmallVector<Value> regA = loadReg(
|
||||
rewriter, loc, structA, (m * numRepK + k) * regASize, regASize);
|
||||
llvm::SmallVector<Value> regA =
|
||||
loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize,
|
||||
regASize, startSequence);
|
||||
auto regATy = LLVM::LLVMStructType::getLiteral(
|
||||
rewriter.getContext(),
|
||||
SmallVector<Type>(regA.size(), regA[0].getType()));
|
||||
|
||||
@@ -1480,7 +1480,7 @@ struct TritonGPUInferLayoutInterface
|
||||
auto dotOpEnc = operandEncoding.dyn_cast<DotOperandEncodingAttr>();
|
||||
if (!operandEncoding.isa<SharedEncodingAttr>() &&
|
||||
!(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 &&
|
||||
dotOpEnc.getParent() == mmaRetEncoding)) {
|
||||
dotOpEnc.getParent().isa<MmaEncodingAttr>())) {
|
||||
return emitOptionalError(
|
||||
location, "unexpected operand layout for MmaEncodingAttr v3");
|
||||
}
|
||||
|
||||
@@ -145,35 +145,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getMmaV3InstrN(tt::DotOp dotOp, unsigned currN) const {
|
||||
auto type = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (type.getEncoding().isa<MmaEncodingAttr>())
|
||||
return currN;
|
||||
auto it = dotOpInstNs.find(dotOp.getOperation());
|
||||
if (it != dotOpInstNs.end())
|
||||
return it->second;
|
||||
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSliceSCFAware(dotOp.getResult(), &slices);
|
||||
mlir::getBackwardSliceSCFAware(dotOp.getOperation(), &slices);
|
||||
unsigned N = currN;
|
||||
SmallVector<Operation *> dotOps;
|
||||
for (Operation *iter : slices) {
|
||||
if (auto nextDotOp = dyn_cast<tt::DotOp>(iter)) {
|
||||
auto type = nextDotOp.getResult().getType().cast<RankedTensorType>();
|
||||
auto AType = nextDotOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto shapePerCTA = ttg::getShapePerCTA(type);
|
||||
auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType);
|
||||
dotOps.push_back(iter);
|
||||
if (instrShape[1] < N)
|
||||
N = instrShape[1];
|
||||
}
|
||||
}
|
||||
for (Operation *dotOp : dotOps)
|
||||
dotOpInstNs[dotOp] = N;
|
||||
return N;
|
||||
}
|
||||
|
||||
static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter,
|
||||
int opIdx) {
|
||||
Value arg = v;
|
||||
@@ -232,9 +203,6 @@ public:
|
||||
|
||||
auto instrShape =
|
||||
mmaVersionToInstrShape(versionMajor, retShapePerCTA, AType);
|
||||
if (versionMajor == 3)
|
||||
instrShape[1] = getMmaV3InstrN(dotOp, instrShape[1]);
|
||||
|
||||
// operands
|
||||
Value a = dotOp.getA();
|
||||
Value b = dotOp.getB();
|
||||
|
||||
@@ -247,8 +247,10 @@ struct MMAV3UseRegOperand : public OpRewritePattern<triton::DotOp> {
|
||||
return failure();
|
||||
auto srcEncoding =
|
||||
getEncoding(convertLhs.getSrc()).dyn_cast<MmaEncodingAttr>();
|
||||
if (!srcEncoding || srcEncoding.getVersionMajor() != 3 ||
|
||||
srcEncoding != getEncoding(dotOp.getResult()))
|
||||
auto dstEncoding =
|
||||
getEncoding(dotOp.getResult()).dyn_cast<MmaEncodingAttr>();
|
||||
if (!srcEncoding || srcEncoding.getVersionMajor() != 3 || !dstEncoding ||
|
||||
dstEncoding.getVersionMajor() != 3)
|
||||
return failure();
|
||||
// We currently only support convert from f16 mma to f16 dot operand as the
|
||||
// other types require shuffling data across threads.
|
||||
|
||||
@@ -217,7 +217,8 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
|
||||
if (convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding() == encoding)
|
||||
.getEncoding()
|
||||
.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return true;
|
||||
}
|
||||
auto yield = dyn_cast<scf::YieldOp>(op);
|
||||
|
||||
@@ -492,46 +492,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
return linear;
|
||||
}
|
||||
|
||||
void getBackwardSliceSCFAware(Operation *op, SetVector<Operation *> *slices) {
|
||||
SmallVector<Operation *> queue = {op};
|
||||
while (!queue.empty()) {
|
||||
Operation *currentOp = queue.back();
|
||||
queue.pop_back();
|
||||
SetVector<Operation *> temp;
|
||||
auto filter = [slices](Operation *sliceOp) {
|
||||
return slices->count(sliceOp) == 0;
|
||||
};
|
||||
mlir::getBackwardSlice(currentOp, &temp, filter);
|
||||
for (Operation *sliceOp : temp) {
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(sliceOp)) {
|
||||
queue.push_back(forOp.getBody()->getTerminator());
|
||||
}
|
||||
}
|
||||
slices->insert(temp.begin(), temp.end());
|
||||
}
|
||||
}
|
||||
|
||||
void getForwardSliceSCFAware(Value root, SetVector<Operation *> *slices) {
|
||||
SmallVector<Value> queue = {root};
|
||||
while (!queue.empty()) {
|
||||
Value currentValue = queue.back();
|
||||
queue.pop_back();
|
||||
SetVector<Operation *> temp;
|
||||
auto filter = [slices](Operation *sliceOp) {
|
||||
return slices->count(sliceOp) == 0;
|
||||
};
|
||||
mlir::getForwardSlice(currentValue, &temp, filter);
|
||||
for (Operation *sliceOp : temp) {
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(sliceOp)) {
|
||||
auto forOp = yieldOp->getParentOfType<scf::ForOp>();
|
||||
if (forOp)
|
||||
queue.append(forOp->getResults().begin(), forOp->getResults().end());
|
||||
}
|
||||
}
|
||||
slices->insert(temp.begin(), temp.end());
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Detect dead arguments in scf.for op by assuming all the values are dead and
|
||||
|
||||
@@ -323,106 +323,102 @@ def _attn_bwd(
|
||||
# load scales
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
if (tl.program_id(1) == 0):
|
||||
# THIS BLOCK DOES DK/DV/DR:
|
||||
|
||||
# THIS BLOCK DOES DK/DV/DR:
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
|
||||
else:
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * BLOCK_N2, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * BLOCK_N2, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
|
||||
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
@@ -491,7 +487,7 @@ class _attention(torch.autograd.Function):
|
||||
BATCH, N_HEAD, N_CTX,
|
||||
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD)
|
||||
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
|
||||
_attn_bwd[grid](
|
||||
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
|
||||
M, delta,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s
|
||||
|
||||
// CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}>
|
||||
// CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
// CHECK: #[[MMA2:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 32, 16]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
@@ -22,7 +24,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2>
|
||||
// CHECK: scf.for
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]>
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]>
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
|
||||
%115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 {
|
||||
%172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked>
|
||||
%178 = triton_gpu.convert_layout %172 : (tensor<128x16xf16, #blocked>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
|
||||
@@ -30,8 +32,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
scf.yield %180 : tensor<128x64xf16, #blocked1>
|
||||
}
|
||||
// CHECK: scf.for
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA]]>
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]>
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]>
|
||||
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
|
||||
%149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 {
|
||||
%166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2>
|
||||
%172 = triton_gpu.convert_layout %166 : (tensor<128x32xf16, #blocked2>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
|
||||
|
||||
Reference in New Issue
Block a user