mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton. * add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16` * change flashattention to support fp8 in q/k.
408 lines
15 KiB
C++
408 lines
15 KiB
C++
#include "Utility.h"
|
|
#include "TypeConverter.h"
|
|
|
|
namespace mlir {
|
|
|
|
namespace LLVM {
|
|
using namespace mlir::triton;
|
|
|
|
// Materialize the 32-bit signed integer `v` as an LLVM::ConstantOp at `loc`.
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
  auto int32Type = rewriter.getIntegerType(32);
  auto valueAttr = IntegerAttr::get(int32Type, v);
  return rewriter.create<LLVM::ConstantOp>(loc, int32Type, valueAttr);
}
|
|
|
|
// Materialize `v` as an f32 LLVM::ConstantOp at `loc`.
Value createConstantF32(Location loc, OpBuilder &rewriter, float v) {
  auto f32Type = type::f32Ty(rewriter.getContext());
  auto f32Attr = rewriter.getF32FloatAttr(v);
  return rewriter.create<LLVM::ConstantOp>(loc, f32Type, f32Attr);
}
|
|
|
|
// Materialize `v` as an f64 LLVM::ConstantOp at `loc`.
// NOTE(review): `v` is declared `float`, so a double argument is rounded to
// float precision before being widened into the f64 attribute. Switching the
// parameter to `double` must be done together with its declaration
// (presumably in Utility.h) — TODO.
Value createConstantF64(Location loc, OpBuilder &rewriter, float v) {
  auto f64Type = type::f64Ty(rewriter.getContext());
  auto f64Attr = rewriter.getF64FloatAttr(v);
  return rewriter.create<LLVM::ConstantOp>(loc, f64Type, f64Attr);
}
|
|
|
|
// Create an index type constant.
|
|
Value createIndexConstant(OpBuilder &builder, Location loc,
|
|
TypeConverter *converter, int64_t value) {
|
|
Type ty = converter->convertType(builder.getIndexType());
|
|
return builder.create<LLVM::ConstantOp>(loc, ty,
|
|
builder.getIntegerAttr(ty, value));
|
|
}
|
|
|
|
// Create an integer constant of \param width bits.
|
|
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
|
int64_t value) {
|
|
Type ty = builder.getIntegerType(width);
|
|
return builder.create<LLVM::ConstantOp>(loc, ty,
|
|
builder.getIntegerAttr(ty, value));
|
|
}
|
|
|
|
// A wrapper of LoadDSmemOp when vec = 1
|
|
// (1) Get bitwidth from elemTy
|
|
// (2) Create LoadDSmemOp
|
|
// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy
|
|
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
|
Value ctaId) {
|
|
assert(addr.getType().isa<LLVMPointerType>() &&
|
|
"addr must be a pointer type");
|
|
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
|
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
|
auto elemTy = ptrTy.getElementType();
|
|
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
|
Value ret =
|
|
rewriter.create<triton::nvgpu::LoadDSmemOp>(loc, addr, ctaId, bitwidth);
|
|
return bitcast(ret, elemTy);
|
|
}
|
|
|
|
// A wrapper of LoadDSmemOp when vec > 1
|
|
// (1) Get bitwidth from elemTy
|
|
// (2) Create LoadDSmemOp and extract results from retStruct
|
|
// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy
|
|
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
|
|
Value addr, Value ctaId, unsigned vec) {
|
|
assert(addr.getType().isa<LLVMPointerType>() &&
|
|
"addr must be a pointer type");
|
|
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
|
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
|
auto elemTy = ptrTy.getElementType();
|
|
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
|
Value retStruct = rewriter.create<triton::nvgpu::LoadDSmemOp>(
|
|
loc, addr, ctaId, bitwidth, vec);
|
|
SmallVector<Value> retVals;
|
|
for (unsigned i = 0; i < vec; ++i) {
|
|
auto dataTy = rewriter.getIntegerType(bitwidth);
|
|
Value data = extract_val(dataTy, retStruct, i);
|
|
retVals.push_back(bitcast(data, elemTy));
|
|
}
|
|
return retVals;
|
|
}
|
|
|
|
// A wrapper of StoreDSmemOp when vec = 1
|
|
// (1) Get bitwidth from elemTy
|
|
// (2) Bitcast value from elemTy to dataTy (u16/u32/u64)
|
|
// (3) Create StoreDSmemOp
|
|
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
|
Value ctaId, Value value, Value pred) {
|
|
assert(addr.getType().isa<LLVMPointerType>() &&
|
|
"addr must be a pointer type");
|
|
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
|
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
|
auto elemTy = ptrTy.getElementType();
|
|
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
|
auto dataTy = rewriter.getIntegerType(bitwidth);
|
|
Value data = bitcast(value, dataTy);
|
|
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
|
|
}
|
|
|
|
// A wrapper of StoreDSmemOp when vec = 1 and pred = 1
|
|
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
|
Value ctaId, Value value) {
|
|
Value pred = int_val(/*width=*/1, 1);
|
|
createStoreDSmem(loc, rewriter, addr, ctaId, value, pred);
|
|
}
|
|
|
|
// A wrapper of StoreDSmemOp when vec > 1
|
|
// (1) Get bitwidth from elemTy
|
|
// (2) Bitcast values from elemTy to dataTy (u16/u32/u64)
|
|
// (3) Create StoreDSmemOp
|
|
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
|
Value ctaId, ArrayRef<Value> values, Value pred) {
|
|
assert(addr.getType().isa<LLVMPointerType>() &&
|
|
"addr must be a pointer type");
|
|
auto ptrTy = addr.getType().cast<LLVMPointerType>();
|
|
assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
|
|
auto elemTy = ptrTy.getElementType();
|
|
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
|
auto dataTy = rewriter.getIntegerType(bitwidth);
|
|
SmallVector<Value> data;
|
|
for (unsigned i = 0; i < values.size(); ++i)
|
|
data.push_back(bitcast(values[i], dataTy));
|
|
rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
|
|
}
|
|
|
|
// A wrapper of StoreDSmemOp when vec > 1 and pred = 1
|
|
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
|
|
Value ctaId, ArrayRef<Value> values) {
|
|
Value pred = int_val(/*width=*/1, 1);
|
|
createStoreDSmem(loc, rewriter, addr, ctaId, values, pred);
|
|
}
|
|
|
|
// Unpack an LLVM struct into a SharedMemoryObject. The field layout assumed
// here is {base, strides..., offsets...}, with (numFields - 1) / 2 strides
// followed by the remaining fields as offsets.
SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
                                ConversionPatternRewriter &rewriter) {
  auto structTy = llvmStruct.getType().cast<LLVM::LLVMStructType>();
  ArrayRef<Type> fieldTypes = structTy.getBody();
  const unsigned numFields = fieldTypes.size();

  SmallVector<Value> fields(numFields);
  for (unsigned idx = 0; idx != numFields; ++idx)
    fields[idx] = extract_val(fieldTypes[idx], llvmStruct, idx);

  auto rank = (fields.size() - 1) / 2;
  auto stridesBegin = fields.begin() + 1;
  auto offsetsBegin = stridesBegin + rank;
  return {/*base=*/fields[0],
          /*strides=*/{stridesBegin, offsetsBegin},
          /*offsets=*/{offsetsBegin, fields.end()}};
}
|
|
|
|
// Compute i32 stride values for a tensor of `shape`. Dimensions are visited in
// the sequence given by `order`: the first one gets stride 1 and each later
// one gets the product of the extents of all dimensions visited before it.
SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
                            Location loc, ConversionPatternRewriter &rewriter) {
  SmallVector<Value> strides(shape.size());
  int64_t runningProduct = 1;
  for (unsigned dim : order) {
    strides[dim] = i32_val(runningProduct);
    runningProduct *= shape[dim];
  }
  return strides;
}
|
|
|
|
// Convert an \param index to a multi-dim coordinate given \param shape and
|
|
// \param order.
|
|
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value linear,
|
|
ArrayRef<unsigned> shape,
|
|
ArrayRef<unsigned> order) {
|
|
unsigned rank = shape.size();
|
|
assert(rank == order.size());
|
|
auto reordered = reorder(shape, order);
|
|
SmallVector<Value> reorderedMultiDim(rank);
|
|
if (auto constantOp = linear.getDefiningOp<arith::ConstantOp>()) {
|
|
unsigned intVal =
|
|
constantOp.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
|
reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered);
|
|
} else {
|
|
reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
|
|
}
|
|
SmallVector<Value> multiDim(rank);
|
|
for (unsigned i = 0; i < rank; ++i) {
|
|
multiDim[order[i]] = reorderedMultiDim[i];
|
|
}
|
|
return multiDim;
|
|
}
|
|
|
|
// Split the compile-time linear index \param linear into constant i32
// coordinates over \param shape, peeling dimension 0 first (mod / div by each
// extent in turn).
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, unsigned linear,
                               ArrayRef<unsigned> shape) {
  assert(!shape.empty());
  SmallVector<Value> coords;
  coords.reserve(shape.size());
  unsigned rest = linear;
  for (unsigned extent : shape) {
    coords.push_back(i32_val(rest % extent));
    rest /= extent;
  }
  return coords;
}
|
|
|
|
// Split the runtime linear index \param linear into coordinates over
// \param shape, peeling dimension 0 first with unsigned rem / div ops.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape) {
  assert(!shape.empty());
  SmallVector<Value> coords;
  coords.reserve(shape.size());
  Value rest = linear;
  for (unsigned extent : shape) {
    Value extentVal = i32_val(extent);
    coords.push_back(urem(rest, extentVal));
    rest = udiv(rest, extentVal);
  }
  return coords;
}
|
|
|
|
// Linearize \param multiDim over \param shape after permuting both by
// \param order.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                ArrayRef<unsigned> order) {
  auto orderedCoords = reorder<Value>(multiDim, order);
  auto orderedShape = reorder<unsigned>(shape, order);
  return linearize(rewriter, loc, orderedCoords, orderedShape);
}
|
|
|
|
// Fold the coordinates \param multiDim into a single i32 index over
// \param shape, with dimension 0 varying fastest:
//   linear = d0 + s0 * (d1 + s1 * (d2 + ...))
// The last extent in \param shape is not used. An empty \param multiDim
// yields the constant 0.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) {
  // Emitted unconditionally, as before, even when it is immediately replaced.
  Value linear = i32_val(0);
  if (multiDim.empty())
    return linear;

  // Horner's scheme, starting from the outermost coordinate.
  linear = multiDim.back();
  auto innerDims = llvm::zip(multiDim.drop_back(), shape.drop_back());
  for (auto [coord, extent] : llvm::reverse(innerDims))
    linear = add(mul(linear, i32_val(extent)), coord);
  return linear;
}
|
|
|
|
// Store `val` to the shared-memory address `ptr`.
//  - CUDA: emits inline PTX `st.shared.b<bits>` predicated on `pred` and
//    returns the result of launching that PTX (typed void_ty).
//  - ROCm: emits a plain store and returns `val`.
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
                  Value val, Value pred) {
  // `#ifdef` matches every other USE_ROCM guard in this file; `#if` would fail
  // to preprocess if USE_ROCM were defined with an empty value.
#ifdef USE_ROCM
  // NOTE(review): `pred` is not applied on this path, so the store is
  // unconditional — confirm no caller relies on predication here.
  store(val, ptr);
  return val;
#else
  MLIRContext *ctx = rewriter.getContext();
  unsigned bits = val.getType().getIntOrFloatBitWidth();
  // Inline-asm register constraint for the value operand:
  // "l" for 64-bit, "h" for 16-bit, "r" for every other width.
  const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");

  PTXBuilder builder;
  auto *ptrOpr = builder.newAddrOperand(ptr, "r");
  auto *valOpr = builder.newOperand(val, c);
  auto &st = builder.create<>("st")->shared().b(bits);
  st(ptrOpr, valOpr).predicate(pred, "b");
  return builder.launch(rewriter, loc, void_ty(ctx));
#endif
}
|
|
|
|
// Shuffle `val` across the lanes of a warp (CUDA) or wavefront (ROCm).
//
// \param i           lane stride for "bfly" (xor) shuffles, or lane delta for
//                    "up" shuffles.
// \param shuffleType "bfly" or "up"; the ROCm path asserts on anything else.
// \param clamp       clamp operand forwarded to PTX shfl.sync (CUDA only).
// \param laneId      current lane index; required by the ROCm "up" path.
//
// 64-bit values are bitcast to <2 x f32>, each half is shuffled recursively,
// and the halves are reassembled. On ROCm, 8-bit values are sign-extended to
// i32, shuffled, and truncated back to i8.
static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
                            Value val, int i, const std::string &shuffleType,
                            const std::string &clamp, Value laneId = Value()) {
  unsigned bits = val.getType().getIntOrFloatBitWidth();

#ifdef USE_ROCM
  // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32-bit
  // dwords, so 8-bit values are promoted to 32 bits here.
  if (bits == 8) {
    Value i32Val = sext(i32_ty, val);
    Value result =
        commonShflSync(loc, rewriter, i32Val, i, shuffleType, clamp, laneId);
    return trunc(i8_ty, result);
  }
#endif

  if (bits == 64) {
    // Shuffle the two 32-bit halves independently.
    Type vecTy = vec_ty(f32_ty, 2);
    Value vec = bitcast(val, vecTy);
    Value val0 = extract_element(f32_ty, vec, i32_val(0));
    Value val1 = extract_element(f32_ty, vec, i32_val(1));
    val0 = commonShflSync(loc, rewriter, val0, i, shuffleType, clamp, laneId);
    val1 = commonShflSync(loc, rewriter, val1, i, shuffleType, clamp, laneId);
    vec = undef(vecTy);
    vec = insert_element(vecTy, vec, val0, i32_val(0));
    vec = insert_element(vecTy, vec, val1, i32_val(1));
    return bitcast(vec, val.getType());
  }

#ifdef USE_ROCM
  GCNBuilder builder;
  if (shuffleType == "bfly") {
    if (i > 16) {
      // The ds_swizzle_b32 bit-mask encoding used below has only a 5-bit xor
      // field, so a stride-32 butterfly uses ds_permute_b32 instead: each lane
      // forwards its value to lane (threadId + 32). This presumably relies on
      // the hardware wrapping the lane index within the wavefront — TODO
      // confirm. The +32 is hard-coded, so only stride 32 is correct here.
      assert(i == 32 && "ds_permute_b32 butterfly only supports stride 32");
      Value threadId =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, TypeRange{i32_ty},
                  ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
                      loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
              .getResult(0);
      Value stride = i32_val(32);
      Value byteOffset = i32_val(2);
      Value lineId = add(threadId, stride);
      // ds_permute_b32 takes a byte address: lane index * 4.
      Value permuteAddr = shl(lineId, byteOffset);
      auto shfl = builder.create("ds_permute_b32");
      auto dOpr = builder.newOperand("=v");
      auto addrOpr = builder.newOperand(permuteAddr, "v");
      auto aOpr = builder.newOperand(val, "v");
      (*shfl)(dOpr, addrOpr, aOpr);
    } else {
      // ds_swizzle_b32 bit-mask offsets for butterfly strides up to and
      // including 16: and_mask = 0x1F (bits 4:0), or_mask = 0 (bits 9:5),
      // xor_mask = stride (bits 14:10). The stride is the key of the map.
      const DenseMap<short, unsigned int> masks{
          {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
      // lookup() yields 0 for an unsupported stride without inserting into the
      // map; catch that in debug builds instead of silently emitting
      // "offset:0".
      unsigned int swizzleMask = masks.lookup(i);
      assert(swizzleMask != 0 &&
             "ds_swizzle_b32 butterfly stride must be 1, 2, 4, 8 or 16");
      auto shfl = builder.create("ds_swizzle_b32");
      auto dOpr = builder.newOperand("=v");
      auto aOpr = builder.newOperand(val, "v");
      auto maskOpr =
          builder.newConstantOperand("offset:" + std::to_string(swizzleMask));
      (*shfl)(dOpr, aOpr, maskOpr);
    }
  } else { // shuffle_up
    assert(shuffleType == "up" && "Only shfl_bfly and shfl_up are supported");
    assert(laneId && "shfl_up requires laneId");
    // Lanes below `i` read their own value; the others read lane (laneId - i).
    Value mask = icmp_slt(laneId, i32_val(i));
    Value delta = sub(laneId, i32_val(i));
    Value index = select(mask, laneId, delta);
    Value byteOffset = i32_val(2);
    // ds_bpermute_b32 takes a byte address: lane index * 4.
    Value permuteAddr = shl(index, byteOffset);
    auto shfl = builder.create("ds_bpermute_b32");
    auto dOpr = builder.newOperand("=v");
    auto addrOpr = builder.newOperand(permuteAddr, "v");
    auto aOpr = builder.newOperand(val, "v");
    (*shfl)(dOpr, addrOpr, aOpr);
  }
  // Wait for the outstanding LDS (lgkm) operation before the result is used.
  auto swait = builder.create("s_waitcnt lgkmcnt(0)");
  (*swait)();
  return builder.launch(rewriter, loc, val.getType(), true);
#else
  PTXBuilder builder;
  // shfl.sync.<mode>.b32 d, a, b(=i), c(=clamp), membermask (all 32 lanes).
  auto &shfl = builder.create("shfl.sync")->o(shuffleType).o("b32");
  auto *dOpr = builder.newOperand("=r");
  auto *aOpr = builder.newOperand(val, "r");
  auto *bOpr = builder.newConstantOperand(i);
  auto *cOpr = builder.newConstantOperand(clamp);
  auto *maskOpr = builder.newConstantOperand("0xffffffff");
  shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
  return builder.launch(rewriter, loc, val.getType(), false);
#endif
}
|
|
|
|
// Butterfly (xor) shuffle of `val` with lane stride `i`, clamp "0x1f".
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
               int i) {
  const std::string mode = "bfly";
  const std::string clampValue = "0x1f";
  return commonShflSync(loc, rewriter, val, i, mode, clampValue);
}
|
|
|
|
// "Up" shuffle of `val` by `i` lanes, clamp "0x0"; `laneId` is the current
// lane index, forwarded to the shared implementation.
Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
                 int i, Value laneId) {
  const std::string mode = "up";
  const std::string clampValue = "0x0";
  return commonShflSync(loc, rewriter, val, i, mode, clampValue, laneId);
}
|
|
// Read the PTX special register named `sRegStr` into an i32 using an inline
// `mov.u32` instruction.
Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) {
  PTXBuilder ptxBuilder;
  auto &movInstr = ptxBuilder.create("mov")->o("u32");
  auto *dst = ptxBuilder.newOperand("=r");
  auto *src = ptxBuilder.newConstantOperand(sRegStr);
  movInstr(dst, src);
  return ptxBuilder.launch(b, loc, b.getIntegerType(32), false);
}
|
|
|
|
// Emit `content` as an internal, constant i8-array global at the start of the
// enclosing module, under the first unused symbol name `<key>0`, `<key>1`, ...
// Returns a GEP (indices 0, 0) into that global, typed as a pointer to i8.
// The global, address-of and GEP ops are created with UnknownLoc.
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
                        StringRef key, StringRef content) {
  auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto ctx = moduleOp.getContext();

  // Probe key0, key1, ... until a name not yet in the symbol table is found.
  SmallString<16> symName;
  for (unsigned suffix = 0;; ++suffix) {
    symName.clear();
    (key + Twine(suffix)).toStringRef(symName);
    if (!moduleOp.lookupSymbol(symName))
      break;
  }

  llvm::SmallString<64> text(content);
  auto arrayTy = LLVM::LLVMArrayType::get(i8_ty, text.size_in_bytes());

  LLVM::GlobalOp globalOp;
  {
    // Insert at module scope; the guard restores the insertion point.
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    globalOp = rewriter.create<LLVM::GlobalOp>(
        UnknownLoc::get(ctx), arrayTy,
        /*isConstant=*/true, LLVM::Linkage::Internal, symName,
        rewriter.getStringAttr(text));
  }

  Value zero = i32_val(0);
  Value globalAddr =
      rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), globalOp);
  Value firstByte = rewriter.create<LLVM::GEPOp>(
      UnknownLoc::get(ctx), ptr_ty(i8_ty), globalAddr,
      SmallVector<Value>({zero, zero}));
  return firstByte;
}
|
|
|
|
} // namespace LLVM
|
|
|
|
// True if `eType` is one of the four 8-bit float types recognized here:
// f8E4M3FNUZ, f8E4M3FN, f8E5M2, or f8E5M2FNUZ.
bool isF8(Type eType) {
  return eType.isFloat8E5M2() || eType.isFloat8E5M2FNUZ() ||
         eType.isFloat8E4M3FN() || eType.isFloat8E4M3FNUZ();
}
|
|
|
|
} // namespace mlir
|