[BACKEND] Fix bfloat16 flash attention (#1306)

See https://github.com/openai/triton/issues/1245 for more detailed
information

---------

Co-authored-by: giorgio-arena <arena.cpp@gmail.com>
Author: Keren Zhou
Date: 2023-03-10 00:14:52 -05:00
Committed by: GitHub
parent a4a824a3c9
commit 8b25c30d39
5 changed files with 39 additions and 17 deletions

View File

@@ -558,16 +558,35 @@ private:
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
SmallVector<Type> types;
// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of ldmatrix in i32
auto elemSize = elemTy.getIntOrFloatBitWidth();
if (auto intTy = elemTy.dyn_cast<IntegerType>() && elemSize <= 16) {
auto fold = 32 / elemSize;
for (unsigned i = 0; i < elems; i += fold) {
Value val = i32_val(0);
for (unsigned j = 0; j < fold; j++) {
auto ext =
shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
val = or_(i32_ty, val, ext);
}
vecVals.push_back(val);
}
elems = elems / (32 / elemSize);
types = SmallVector<Type>(elems, i32_ty);
} else {
unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
Type vecTy = vec_ty(elemTy, vecSize);
types = SmallVector<Type>(elems / vecSize, vecTy);
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
}
// This needs to be ordered the same way that

View File

@@ -259,7 +259,7 @@ struct FpToFpOpConversion
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f32.bf16");
auto &cvt = *builder.create("cvt.f32.bf16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);

View File

@@ -27,6 +27,7 @@
// Shorthand macros for building LLVM-dialect ops. Each expands at the call
// site and refers to the `rewriter` and `loc` variables that must be in scope
// there (the common pattern in the surrounding conversion code).
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
// Added for the bf16/i16 repacking workaround: left-shift, used together with
// zext/or_ below to pack sub-word values into an i32.
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)

View File

@@ -5,7 +5,8 @@ import triton
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
@@ -21,7 +22,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
@@ -38,6 +39,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
decimal = 1 if dtype == torch.bfloat16 else 2
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)

View File

@@ -63,7 +63,7 @@ def _fwd_kernel(
p *= l_rcp
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(tl.float16)
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
@@ -167,7 +167,7 @@ def _bwd_kernel(
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
@@ -175,10 +175,10 @@ def _bwd_kernel(
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm