mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix bfloat16 flash attention (#1306)
See https://github.com/openai/triton/issues/1245 for more detailed information --------- Co-authored-by: giorgio-arena <arena.cpp@gmail.com>
This commit is contained in:
@@ -558,16 +558,35 @@ private:
|
||||
this->getTypeConverter()->convertType(srcTy.getElementType());
|
||||
// for the destination type, we need to pack values together
|
||||
// so they can be consumed by tensor core operations
|
||||
unsigned vecSize =
|
||||
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
|
||||
Type vecTy = vec_ty(elemTy, vecSize);
|
||||
SmallVector<Type> types(elems / vecSize, vecTy);
|
||||
SmallVector<Value> vecVals;
|
||||
for (unsigned i = 0; i < elems; i += vecSize) {
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (unsigned j = 0; j < vecSize; j++)
|
||||
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
|
||||
vecVals.push_back(packed);
|
||||
SmallVector<Type> types;
|
||||
// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
|
||||
// instructions to pack & unpack sub-word integers. A workaround is to
|
||||
// store the results of ldmatrix in i32
|
||||
auto elemSize = elemTy.getIntOrFloatBitWidth();
|
||||
if (auto intTy = elemTy.dyn_cast<IntegerType>() && elemSize <= 16) {
|
||||
auto fold = 32 / elemSize;
|
||||
for (unsigned i = 0; i < elems; i += fold) {
|
||||
Value val = i32_val(0);
|
||||
for (unsigned j = 0; j < fold; j++) {
|
||||
auto ext =
|
||||
shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
|
||||
val = or_(i32_ty, val, ext);
|
||||
}
|
||||
vecVals.push_back(val);
|
||||
}
|
||||
elems = elems / (32 / elemSize);
|
||||
types = SmallVector<Type>(elems, i32_ty);
|
||||
} else {
|
||||
unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
|
||||
Type vecTy = vec_ty(elemTy, vecSize);
|
||||
types = SmallVector<Type>(elems / vecSize, vecTy);
|
||||
for (unsigned i = 0; i < elems; i += vecSize) {
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (unsigned j = 0; j < vecSize; j++)
|
||||
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
|
||||
vecVals.push_back(packed);
|
||||
}
|
||||
}
|
||||
|
||||
// This needs to be ordered the same way that
|
||||
|
||||
@@ -259,7 +259,7 @@ struct FpToFpOpConversion
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.rn.f32.bf16");
|
||||
auto &cvt = *builder.create("cvt.f32.bf16");
|
||||
auto res = builder.newOperand("=r");
|
||||
auto operand = builder.newOperand(v, "h");
|
||||
cvt(res, operand);
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
|
||||
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
|
||||
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
|
||||
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
|
||||
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
|
||||
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
|
||||
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
|
||||
|
||||
@@ -5,7 +5,8 @@ import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
@@ -21,7 +22,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
p = torch.softmax(p.float(), dim=-1).to(dtype)
|
||||
# p = torch.exp(p)
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
@@ -38,6 +39,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
decimal = 1 if dtype == torch.bfloat16 else 2
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
@@ -63,7 +63,7 @@ def _fwd_kernel(
|
||||
p *= l_rcp
|
||||
acc *= (l_prev * l_rcp)[:, None]
|
||||
# update acc
|
||||
p = p.to(tl.float16)
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
v = tl.load(v_ptrs)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
@@ -167,7 +167,7 @@ def _bwd_kernel(
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
@@ -175,10 +175,10 @@ def _bwd_kernel(
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
|
||||
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
|
||||
# compute dq
|
||||
dq = tl.load(dq_ptrs)
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
|
||||
tl.store(dq_ptrs, dq)
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
|
||||
Reference in New Issue
Block a user