Files
ROCm/lib/Conversion/TritonGPUToLLVM/Utility.cpp
Shucai Xiao 79bebc4ffe fp8 type support (#357)
* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton.
* add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16`
* change flashattention to support fp8 in q/k.
2023-11-02 15:51:23 -05:00

408 lines
15 KiB
C++

#include "Utility.h"
#include "TypeConverter.h"
namespace mlir {
namespace LLVM {
using namespace mlir::triton;
// Build an LLVM-dialect constant of type i32 holding \p v.
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
  auto intTy = rewriter.getIntegerType(32);
  auto attr = IntegerAttr::get(intTy, v);
  return rewriter.create<LLVM::ConstantOp>(loc, intTy, attr);
}
// Build an LLVM-dialect f32 constant holding \p v.
Value createConstantF32(Location loc, OpBuilder &rewriter, float v) {
  auto f32Type = type::f32Ty(rewriter.getContext());
  auto attr = rewriter.getF32FloatAttr(v);
  return rewriter.create<LLVM::ConstantOp>(loc, f32Type, attr);
}
// Build an LLVM-dialect f64 constant holding \p v.
// NOTE(review): the parameter is `float`, so callers cannot supply full
// double precision for an f64 constant — the value is widened only when
// passed to getF64FloatAttr. Widening the signature to `double` would be
// backward-compatible at call sites but must be coordinated with the
// declaration in the header; flagging instead of changing here.
Value createConstantF64(Location loc, OpBuilder &rewriter, float v) {
  auto f64Type = type::f64Ty(rewriter.getContext());
  auto attr = rewriter.getF64FloatAttr(v);
  return rewriter.create<LLVM::ConstantOp>(loc, f64Type, attr);
}
// Create a constant of the target's converted index type holding \p value.
Value createIndexConstant(OpBuilder &builder, Location loc,
                          TypeConverter *converter, int64_t value) {
  // The index type must go through the type converter so it matches the
  // lowered integer width chosen for the target.
  Type indexTy = converter->convertType(builder.getIndexType());
  auto attr = builder.getIntegerAttr(indexTy, value);
  return builder.create<LLVM::ConstantOp>(loc, indexTy, attr);
}
// Create an integer constant of \p width bits holding \p value.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
                                int64_t value) {
  Type intTy = builder.getIntegerType(width);
  auto attr = builder.getIntegerAttr(intTy, value);
  return builder.create<LLVM::ConstantOp>(loc, intTy, attr);
}
// Scalar (vec == 1) wrapper around LoadDSmemOp:
// (1) derive the element bitwidth from the pointee type,
// (2) emit the load, which yields an integer of that width (u16/u32/u64),
// (3) bitcast the raw integer back to the element type.
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
                      Value ctaId) {
  assert(addr.getType().isa<LLVMPointerType>() &&
         "addr must be a pointer type");
  auto ptrTy = addr.getType().cast<LLVMPointerType>();
  assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
  Type elemTy = ptrTy.getElementType();
  unsigned width = elemTy.getIntOrFloatBitWidth();
  Value loaded =
      rewriter.create<triton::nvgpu::LoadDSmemOp>(loc, addr, ctaId, width);
  return bitcast(loaded, elemTy);
}
// Vector (vec > 1) wrapper around LoadDSmemOp:
// (1) derive the element bitwidth from the pointee type,
// (2) emit the load, which yields a struct of `vec` integers,
// (3) extract each field and bitcast it back to the element type.
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
                                   Value addr, Value ctaId, unsigned vec) {
  assert(addr.getType().isa<LLVMPointerType>() &&
         "addr must be a pointer type");
  auto ptrTy = addr.getType().cast<LLVMPointerType>();
  assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem");
  auto elemTy = ptrTy.getElementType();
  unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
  Value retStruct = rewriter.create<triton::nvgpu::LoadDSmemOp>(
      loc, addr, ctaId, bitwidth, vec);
  // All struct fields share the same integer type, so compute it once
  // outside the loop (it was recomputed every iteration before).
  auto dataTy = rewriter.getIntegerType(bitwidth);
  SmallVector<Value> retVals;
  retVals.reserve(vec);
  for (unsigned i = 0; i < vec; ++i) {
    Value data = extract_val(dataTy, retStruct, i);
    retVals.push_back(bitcast(data, elemTy));
  }
  return retVals;
}
// Scalar (vec == 1) wrapper around StoreDSmemOp:
// (1) derive the element bitwidth from the pointee type,
// (2) bitcast the value from its element type to an integer (u16/u32/u64),
// (3) emit the predicated store.
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
                      Value ctaId, Value value, Value pred) {
  assert(addr.getType().isa<LLVMPointerType>() &&
         "addr must be a pointer type");
  auto ptrTy = addr.getType().cast<LLVMPointerType>();
  // Fixed: message previously said "load_dsmem" in this store path.
  assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for store_dsmem");
  auto elemTy = ptrTy.getElementType();
  unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
  auto dataTy = rewriter.getIntegerType(bitwidth);
  Value data = bitcast(value, dataTy);
  rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
}
// Scalar store with an always-true predicate (vec = 1, pred = 1).
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
                      Value ctaId, Value value) {
  Value alwaysTrue = int_val(/*width=*/1, 1);
  createStoreDSmem(loc, rewriter, addr, ctaId, value, alwaysTrue);
}
// Vector (vec > 1) wrapper around StoreDSmemOp:
// (1) derive the element bitwidth from the pointee type,
// (2) bitcast each value from its element type to an integer (u16/u32/u64),
// (3) emit the predicated store.
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
                      Value ctaId, ArrayRef<Value> values, Value pred) {
  assert(addr.getType().isa<LLVMPointerType>() &&
         "addr must be a pointer type");
  auto ptrTy = addr.getType().cast<LLVMPointerType>();
  // Fixed: message previously said "load_dsmem" in this store path.
  assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for store_dsmem");
  auto elemTy = ptrTy.getElementType();
  unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
  auto dataTy = rewriter.getIntegerType(bitwidth);
  SmallVector<Value> data;
  data.reserve(values.size());
  for (Value v : values)
    data.push_back(bitcast(v, dataTy));
  rewriter.create<triton::nvgpu::StoreDSmemOp>(loc, addr, ctaId, data, pred);
}
// Vector store with an always-true predicate (vec > 1, pred = 1).
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
                      Value ctaId, ArrayRef<Value> values) {
  Value alwaysTrue = int_val(/*width=*/1, 1);
  createStoreDSmem(loc, rewriter, addr, ctaId, values, alwaysTrue);
}
// Unpack an LLVM struct value into a SharedMemoryObject.
// Field layout: [base, stride_0..stride_{r-1}, offset_0..offset_{r-1}],
// so the rank is (numFields - 1) / 2.
SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
                                ConversionPatternRewriter &rewriter) {
  auto structTy = llvmStruct.getType().cast<LLVM::LLVMStructType>();
  ArrayRef<Type> fieldTypes = structTy.getBody();
  SmallVector<Value> fields;
  fields.reserve(fieldTypes.size());
  for (unsigned i = 0, e = fieldTypes.size(); i < e; ++i)
    fields.push_back(extract_val(fieldTypes[i], llvmStruct, i));
  auto rank = (fields.size() - 1) / 2;
  return {/*base=*/fields[0],
          /*strides=*/{fields.begin() + 1, fields.begin() + 1 + rank},
          /*offsets=*/{fields.begin() + 1 + rank, fields.end()}};
}
// Compute i32 strides for \p shape, walking dimensions from fastest- to
// slowest-varying as dictated by \p order.
SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
                            Location loc, ConversionPatternRewriter &rewriter) {
  SmallVector<Value> strides(shape.size());
  int64_t runningStride = 1;
  for (unsigned dim : order) {
    strides[dim] = i32_val(runningStride);
    runningStride *= shape[dim];
  }
  return strides;
}
// Convert a linear \p linear index into a multi-dim coordinate for \p shape,
// interpreting dimensions in the order given by \p order.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape,
                               ArrayRef<unsigned> order) {
  unsigned rank = shape.size();
  assert(rank == order.size());
  auto reorderedShape = reorder(shape, order);
  // Fold at compile time when the linear index is a known constant.
  SmallVector<Value> reorderedDims;
  if (auto cst = linear.getDefiningOp<arith::ConstantOp>()) {
    unsigned c = cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
    reorderedDims = delinearize(rewriter, loc, c, reorderedShape);
  } else {
    reorderedDims = delinearize(rewriter, loc, linear, reorderedShape);
  }
  // Undo the reordering so coordinates come back in the original dim order.
  SmallVector<Value> multiDim(rank);
  for (unsigned d = 0; d < rank; ++d)
    multiDim[order[d]] = reorderedDims[d];
  return multiDim;
}
// Delinearize a compile-time-constant index: emit one i32 constant per dim.
// Dimension 0 is the fastest-varying.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, unsigned linear,
                               ArrayRef<unsigned> shape) {
  unsigned rank = shape.size();
  assert(rank > 0);
  SmallVector<Value> coords(rank);
  unsigned rest = linear;
  for (unsigned d = 0; d < rank; ++d) {
    coords[d] = i32_val(rest % shape[d]);
    rest /= shape[d];
  }
  return coords;
}
// Delinearize a runtime index: emit urem/udiv chains per dimension.
// Dimension 0 is the fastest-varying.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape) {
  unsigned rank = shape.size();
  assert(rank > 0);
  SmallVector<Value> coords(rank);
  Value rest = linear;
  for (unsigned d = 0; d < rank; ++d) {
    Value extent = i32_val(shape[d]);
    coords[d] = urem(rest, extent);
    rest = udiv(rest, extent);
  }
  return coords;
}
// Linearize a multi-dim coordinate after permuting dims/shape by \p order.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                ArrayRef<unsigned> order) {
  auto permutedDims = reorder<Value>(multiDim, order);
  auto permutedShape = reorder<unsigned>(shape, order);
  return linearize(rewriter, loc, permutedDims, permutedShape);
}
// Horner-style linearization: starting from the slowest-varying (last)
// coordinate, fold in each remaining dim as acc = acc * extent + coord.
// An empty coordinate list linearizes to 0.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) {
  if (multiDim.empty())
    return i32_val(0);
  Value acc = multiDim.back();
  for (int d = static_cast<int>(multiDim.size()) - 2; d >= 0; --d)
    acc = add(mul(acc, i32_val(shape[d])), multiDim[d]);
  return acc;
}
// Store \p val to shared memory at \p ptr.
// On CUDA this emits a predicated PTX `st.shared.b<bits>`; the ROCm path
// emits a plain store.
// NOTE(review): `pred` is ignored on the ROCm path — confirm callers do not
// rely on predicated-store semantics there.
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
                  Value val, Value pred) {
// Fixed: was `#if USE_ROCM`, inconsistent with the `#ifdef USE_ROCM` used
// everywhere else in this file (behavior diverged if the macro was defined
// as 0 or defined empty).
#ifdef USE_ROCM
  store(val, ptr);
  return val;
#else
  MLIRContext *ctx = rewriter.getContext();
  unsigned bits = val.getType().getIntOrFloatBitWidth();
  // PTX register constraint for the value operand: 64-bit -> "l",
  // 16-bit -> "h", otherwise "r".
  const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
  PTXBuilder builder;
  auto *ptrOpr = builder.newAddrOperand(ptr, "r");
  auto *valOpr = builder.newOperand(val, c);
  auto &st = builder.create<>("st")->shared().b(bits);
  st(ptrOpr, valOpr).predicate(pred, "b");
  return builder.launch(rewriter, loc, void_ty(ctx));
#endif
}
// Emit a cross-lane shuffle of `val` with stride/offset `i`.
// `shuffleType` selects the pattern: "bfly" (butterfly/xor exchange) or "up"
// (pull values from lower lanes; requires `laneId`). On the CUDA path this
// lowers to `shfl.sync.<type>.b32` with `clamp` as the c operand; the ROCm
// path lowers to ds_swizzle/ds_permute/ds_bpermute sequences and does not use
// `clamp`.
static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
                            Value val, int i, const std::string &shuffleType,
                            const std::string &clamp, Value laneId = Value()) {
  unsigned bits = val.getType().getIntOrFloatBitWidth();
#ifdef USE_ROCM
  // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on
  // 32bit/dwords so we need promote to 32 here.
  if (bits == 8) {
    Value i32Val = sext(i32_ty, val);
    Value result =
        commonShflSync(loc, rewriter, i32Val, i, shuffleType, clamp, laneId);
    return trunc(i8_ty, result);
  }
#endif
  // 64-bit values are shuffled as two independent 32-bit halves and then
  // reassembled (the underlying shuffles are b32-only).
  if (bits == 64) {
    Type vecTy = vec_ty(f32_ty, 2);
    Value vec = bitcast(val, vecTy);
    Value val0 = extract_element(f32_ty, vec, i32_val(0));
    Value val1 = extract_element(f32_ty, vec, i32_val(1));
    val0 = commonShflSync(loc, rewriter, val0, i, shuffleType, clamp, laneId);
    val1 = commonShflSync(loc, rewriter, val1, i, shuffleType, clamp, laneId);
    vec = undef(vecTy);
    vec = insert_element(vecTy, vec, val0, i32_val(0));
    vec = insert_element(vecTy, vec, val1, i32_val(1));
    return bitcast(vec, val.getType());
  }
#ifdef USE_ROCM
  GCNBuilder builder;
  if (shuffleType == "bfly") {
    if (i > 16) {
      // Strides above 16 exceed what a ds_swizzle offset can encode, so
      // compute a per-thread source address and use ds_permute_b32 instead.
      // NOTE(review): the address is (threadId + 32) << 2, i.e. byte address
      // of the lane 32 positions over — presumably exchanging the two 32-lane
      // halves of a 64-wide wavefront; confirm against callers' expected
      // butterfly semantics.
      Value threadId =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, TypeRange{i32_ty},
                  ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
                      loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
              .getResult(0);
      Value stride = i32_val(32);
      Value byteOffset = i32_val(2);
      Value lineId = add(threadId, stride);
      // Shift left by 2: ds_permute addresses are in bytes (lane * 4).
      Value permuteAddr = shl(lineId, byteOffset);
      auto shfl = builder.create("ds_permute_b32");
      auto dOpr = builder.newOperand("=v");
      auto addrOpr = builder.newOperand(permuteAddr, "v");
      auto aOpr = builder.newOperand(val, "v");
      (*shfl)(dOpr, addrOpr, aOpr);
    } else {
      // This map facilates the butterfly shuffle pattern for a stride less
      // than 16. The pattern stride is the key of the map.
      // (Valid keys are the power-of-two strides 1..16; other strides are
      // not expected on this path.)
      DenseMap<short, unsigned int> masks{
          {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
      auto shfl = builder.create("ds_swizzle_b32");
      auto dOpr = builder.newOperand("=v");
      auto aOpr = builder.newOperand(val, "v");
      auto maskOpr =
          builder.newConstantOperand("offset:" + std::to_string(masks[i]));
      (*shfl)(dOpr, aOpr, maskOpr);
    }
  } else { // shuffle_up
    assert(shuffleType == "up" && "Only shfl_bfly and shfl_up are supported");
    // Source lane is (laneId - i), clamped to the lane's own index when the
    // subtraction would go below lane 0 (lanes < i keep their own value).
    Value mask = icmp_slt(laneId, i32_val(i));
    Value delta = sub(laneId, i32_val(i));
    Value index = select(mask, laneId, delta);
    Value byteOffset = i32_val(2);
    // ds_bpermute addresses are in bytes (lane * 4).
    Value permuteAddr = shl(index, byteOffset);
    auto shfl = builder.create("ds_bpermute_b32");
    auto dOpr = builder.newOperand("=v");
    auto addrOpr = builder.newOperand(permuteAddr, "v");
    auto aOpr = builder.newOperand(val, "v");
    (*shfl)(dOpr, addrOpr, aOpr);
  }
  // LDS-addressed permutes complete asynchronously; wait before reading the
  // result.
  auto swait = builder.create("s_waitcnt lgkmcnt(0)");
  (*swait)();
  return builder.launch(rewriter, loc, val.getType(), true);
#else
  // CUDA: single shfl.sync.<bfly|up>.b32 with full warp membership mask.
  PTXBuilder builder;
  auto &shfl = builder.create("shfl.sync")->o(shuffleType).o("b32");
  auto *dOpr = builder.newOperand("=r");
  auto *aOpr = builder.newOperand(val, "r");
  auto *bOpr = builder.newConstantOperand(i);
  auto *cOpr = builder.newConstantOperand(clamp);
  auto *maskOpr = builder.newConstantOperand("0xffffffff");
  shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
  return builder.launch(rewriter, loc, val.getType(), false);
#endif
}
// Butterfly (xor-pattern) shuffle of `val` by stride `i`.
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
               int i) {
  return commonShflSync(loc, rewriter, val, i, /*shuffleType=*/"bfly",
                        /*clamp=*/"0x1f");
}
// Shuffle-up of `val` by delta `i`; `laneId` identifies the calling lane.
Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
                 int i, Value laneId) {
  return commonShflSync(loc, rewriter, val, i, /*shuffleType=*/"up",
                        /*clamp=*/"0x0", laneId);
}
// Read a PTX special register (e.g. "%tid.x") into a u32 value via
// `mov.u32`.
Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) {
  PTXBuilder builder;
  auto &mov = builder.create("mov")->o("u32");
  auto *dst = builder.newOperand("=r");
  auto *src = builder.newConstantOperand(sRegStr);
  mov(dst, src);
  return builder.launch(b, loc, b.getIntegerType(32), false);
}
// Intern \p content as an internal, constant i8-array global in the
// enclosing module and return a pointer to its first byte. The global's
// symbol name is \p key plus the first numeric suffix that does not collide
// with an existing symbol.
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
                        StringRef key, StringRef content) {
  auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto ctx = moduleOp.getContext();

  // Find the first unused symbol name of the form "<key><N>".
  SmallString<16> symName;
  unsigned suffix = 0;
  do {
    symName.clear();
    (key + Twine(suffix++)).toStringRef(symName);
  } while (moduleOp.lookupSymbol(symName));

  llvm::SmallString<64> bytes(content);
  auto arrayTy = LLVM::LLVMArrayType::get(i8_ty, bytes.size_in_bytes());

  LLVM::GlobalOp global;
  {
    // Globals must be created at module scope; the guard restores the
    // original insertion point afterwards.
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    global = rewriter.create<LLVM::GlobalOp>(
        UnknownLoc::get(ctx), arrayTy,
        /*isConstant=*/true, LLVM::Linkage::Internal, symName,
        rewriter.getStringAttr(bytes));
  }

  // GEP to the first character: &global[0][0].
  Value zero = i32_val(0);
  Value globalPtr =
      rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
  return rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
                                      globalPtr,
                                      SmallVector<Value>({zero, zero}));
}
} // namespace LLVM
// True iff \p eType is one of the four supported 8-bit float element types.
bool isF8(Type eType) {
  return eType.isFloat8E4M3FNUZ() || eType.isFloat8E4M3FN() ||
         eType.isFloat8E5M2() || eType.isFloat8E5M2FNUZ();
}
} // namespace mlir