Merge remote-tracking branch 'oai/main' into ifu230601

Conflicts:
	python/test/unit/language/assert_helper.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-06-01 20:53:33 +00:00
43 changed files with 1381 additions and 409 deletions

View File

@@ -222,6 +222,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
${conversion_libs}
# optimizations
MLIRBytecodeWriter
MLIRPass
MLIRTransforms
MLIRLLVMDialect

View File

@@ -23,6 +23,8 @@ Creation Ops
:nosignatures:
arange
cat
full
zeros
@@ -33,11 +35,13 @@ Shape Manipulation Ops
:toctree: generated
:nosignatures:
broadcast
broadcast_to
expand_dims
reshape
ravel
reshape
trans
view
Linear Algebra Ops
@@ -83,11 +87,13 @@ Math Ops
abs
exp
log
fdiv
cos
sin
sqrt
sigmoid
softmax
umulhi
Reduction Ops
@@ -151,4 +157,27 @@ Compiler Hint Ops
:toctree: generated
:nosignatures:
debug_barrier
max_contiguous
multiple_of
Debug Ops
-----------------
.. autosummary::
:toctree: generated
:nosignatures:
static_print
static_assert
device_print
device_assert
Iterators
-----------------
.. autosummary::
:toctree: generated
:nosignatures:
static_range

View File

@@ -52,4 +52,16 @@ def TT_AtomicRMWAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}
// Program ID dimensions.
def TT_ProgramDim : I32EnumAttr<
"ProgramIDDim", "",
[
I32EnumAttrCase<"X", 0, "x">,
I32EnumAttrCase<"Y", 1, "y">,
I32EnumAttrCase<"Z", 2, "z">,
]> {
let cppNamespace = "::mlir::triton";
}
#endif

View File

@@ -37,6 +37,7 @@ def Triton_Dialect : Dialect {
let hasConstantMaterializer = 1;
let useDefaultTypePrinterParser = 1;
let usePropertiesForAttributes = 1;
}
include "triton/Dialect/Triton/IR/TritonTypes.td"

View File

@@ -351,11 +351,17 @@ def TT_TransOp : TT_Op<"trans", [Pure,
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
let arguments = (ins I32Attr:$axis);
let arguments = (ins TT_ProgramDim:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
let assemblyFormat = "$axis attr-dict `:` type($result)";
let extraClassDeclaration = [{
int32_t getAxisAsInt() {
return static_cast<int32_t>(getAxis());
}
}];
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
@@ -479,7 +485,7 @@ def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
format are generated automatically from the arguments.
}];
let assemblyFormat = [{
$prefix attr-dict `:` ($args^ `:` type($args))?
$prefix attr-dict (`:` $args^ `:` type($args))?
}];
}

View File

@@ -92,6 +92,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
int opIdx = dotOpEnc.getOpIdx();
// number of rows per phase

View File

@@ -23,14 +23,16 @@ def TritonGPU_Dialect : Dialect {
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps");
if(!numWarps)
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
return numWarps.cast<IntegerAttr>().getInt();
}
}];
let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}
#endif

View File

@@ -111,15 +111,15 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
}
}
} else if (Operation *op = value.getDefiningOp()) {
if (Attribute attr = op->getAttr("tt.divisibility")) {
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
knownDivisibility = DimVectorT(vals.begin(), vals.end());
}
if (Attribute attr = op->getAttr("tt.contiguity")) {
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
knownContiguity = DimVectorT(vals.begin(), vals.end());
}
if (Attribute attr = op->getAttr("tt.constancy")) {
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
knownConstancy = DimVectorT(vals.begin(), vals.end());
}
@@ -888,15 +888,15 @@ void AxisInfoAnalysis::visitOperation(
auto newContiguity = curr.getContiguity();
auto newDivisibility = curr.getDivisibility();
auto newConstancy = curr.getConstancy();
if (Attribute attr = op->getAttr("tt.contiguity")) {
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
}
if (Attribute attr = op->getAttr("tt.divisibility")) {
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
}
if (Attribute attr = op->getAttr("tt.constancy")) {
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
}

View File

@@ -87,6 +87,10 @@ public:
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
isaDistributedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
@@ -544,9 +548,40 @@ private:
}
}
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);
return success();
}
LogicalResult
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(dstShape.size() == 2 &&
"Unexpected rank of ConvertLayout(shared->blocked)");
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
auto dstLayout = dstTy.getEncoding();
auto inOrd = getOrder(srcSharedLayout);
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto srcStrides =
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy);
SmallVector<Value> outVals = loadSharedToDistributed(
dst, dstIndices, src, smemObj, elemTy, loc, rewriter);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

View File

@@ -303,6 +303,8 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
} else {
if (needTrans && (4 / elemBytes) != kWidth)
llvm_unreachable("unimplemented Shared -> DotOperandMmav2 code path");
// base pointers
std::array<std::array<Value, 4>, 2> ptrs;
int vecWidth = 4 / elemBytes;
@@ -329,8 +331,6 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
// row + trans and col + no-trans are equivalent
bool isActualTrans =
(needTrans && kOrder == 1) || (!needTrans && kOrder == 0);
if (isActualTrans)
std::swap(vptrs[1], vptrs[2]);
// pack loaded vectors into 4 32-bit values
int inc = needTrans ? 1 : kWidth;
VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc);
@@ -348,12 +348,15 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
Value val = load(ptr);
Value canonval = bitcast(val, vec_ty(canonInt, canonWidth));
for (int w = 0; w < canonWidth; ++w) {
retElems[idx + w * kWidth / vecWidth] =
insert_element(retElems[idx + w * kWidth / vecWidth],
int ridx = idx + w * kWidth / vecWidth;
retElems[ridx] =
insert_element(retElems[ridx],
extract_element(canonval, i32_val(w)), i32_val(e));
}
}
}
if (isActualTrans)
std::swap(retElems[1], retElems[2]);
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
}

View File

@@ -440,10 +440,10 @@ struct GetProgramIdOpConversion
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
assert(op.getAxis() < 3);
assert(op.getAxisAsInt() < 3);
Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxis()]);
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
return success();
}

View File

@@ -119,9 +119,8 @@ protected:
// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
LLVM::Linkage linkage = LLVM::Linkage::External;
if (funcOp->hasAttr("llvm.linkage")) {
auto attr =
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
if (auto linkageAttr = funcOp->getDiscardableAttr("llvm.linkage")) {
auto attr = linkageAttr.dyn_cast<mlir::LLVM::LinkageAttr>();
if (!attr) {
funcOp->emitError()
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
@@ -360,6 +359,55 @@ public:
return ret;
}
SmallVector<Value>
loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
Value src, SharedMemoryObject smemObj, Type elemTy,
Location loc,
ConversionPatternRewriter &rewriter) const {
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(dstShape.size() == 2 &&
"Unexpected rank of loadSharedToDistributed");
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstDistributedLayout = dstTy.getEncoding();
if (auto mmaLayout = dstDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
assert((!mmaLayout.isVolta()) &&
"ConvertLayout Shared->MMAv1 is not supported yet");
}
auto srcSharedLayout =
srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto srcElemTy = srcTy.getElementType();
auto dstElemTy = dstTy.getElementType();
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
unsigned outVec =
inOrd == outOrd
? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]]
: 1;
unsigned inVec = srcSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
assert(outElems == dstIndices.size());
DenseMap<unsigned, Value> sharedPtrs = getSwizzledSharedPtrs(
loc, outVec, dstTy, srcSharedLayout, srcElemTy, smemObj, rewriter,
smemObj.offsets, smemObj.strides);
assert(outElems % minVec == 0 && "Unexpected number of elements");
unsigned numVecs = outElems / minVec;
auto wordTy = vec_ty(elemTy, minVec);
SmallVector<Value> outVals(outElems);
for (unsigned i = 0; i < numVecs; ++i) {
Value smemAddr = sharedPtrs[i * minVec];
smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
Value valVec = load(smemAddr);
for (unsigned v = 0; v < minVec; ++v) {
Value currVal = extract_element(dstElemTy, valVec, i32_val(v));
outVals[i * minVec + v] = currVal;
}
}
return outVals;
}
void storeDistributedToShared(Value src, Value llSrc,
ArrayRef<Value> dstStrides,
ArrayRef<SmallVector<Value>> srcIndices,
@@ -387,16 +435,11 @@ public:
: 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
assert(numElems == srcIndices.size());
auto inVals =
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
auto wordTy = vec_ty(elemTy, minVec);
auto elemPtrTy = ptr_ty(elemTy);
Value outVecVal = i32_val(outVec);
Value minVecVal = i32_val(minVec);
Value word;
SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};

View File

@@ -310,8 +310,7 @@ public:
// Preprocess
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
if (failed(decomposeInsertSliceAsyncOp(mod)))
return signalPassFailure();
decomposeInsertSliceAsyncOp(mod);
// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);
@@ -490,7 +489,7 @@ private:
});
}
LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
// TODO(Keren): This is a hacky knob that may cause performance regression
// when decomposition has been performed. We should remove this knob once we
@@ -518,6 +517,7 @@ private:
// Get the vectorized load size
auto src = insertSliceAsyncOp.getSrc();
auto dst = insertSliceAsyncOp.getDst();
auto mask = insertSliceAsyncOp.getMask();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcBlocked =
@@ -526,6 +526,9 @@ private:
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
if (mask)
inVec =
std::min<unsigned>(axisInfoAnalysis.getMaskAlignment(mask), inVec);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
auto maxBitWidth =
@@ -597,7 +600,6 @@ private:
}
#endif
});
return success();
}
};

View File

@@ -796,7 +796,6 @@ public:
if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(),
*converter)))
return failure();
rewriter.eraseOp(op);
return success();
}
};

View File

@@ -446,10 +446,11 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
Properties *prop = properties.as<Properties *>();
int axis = prop->axis.getInt();
for (auto arg : operands) {
auto argTy = arg.getType().cast<RankedTensorType>();
auto retEltTy = argTy.getElementType();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
.failed()) {
return failure();
@@ -557,7 +558,8 @@ mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
auto arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
Properties *prop = properties.as<Properties *>();
int axis = prop->axis.getInt();
retShape.insert(retShape.begin() + axis, 1);
// infer encoding
Attribute argEncoding = argTy.getEncoding();
@@ -740,7 +742,7 @@ void triton::FuncOp::print(OpAsmPrinter &printer) {
LogicalResult
triton::CallOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
auto fnAttr = (*this).getProperties().callee;
if (!fnAttr)
return emitOpError("requires a 'callee' symbol reference attribute");
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);

View File

@@ -55,10 +55,6 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)

View File

@@ -81,15 +81,18 @@ public:
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
static mlir::LogicalResult
isBlockedToDotOperand(mlir::Operation *op,
triton::gpu::DotOperandEncodingAttr &retEncoding,
triton::gpu::BlockedEncodingAttr &srcEncoding) {
if (!op)
return failure();
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcTy = cvt.getOperand().getType().cast<RankedTensorType>();
auto retTy = cvt.getResult().getType().dyn_cast<RankedTensorType>();
auto retEncoding =
retEncoding =
retTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
auto srcEncoding =
srcEncoding =
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
if (!retTy)
return failure();
@@ -101,6 +104,51 @@ public:
return failure();
if (!srcEncoding)
return failure();
return success();
}
static bool isTrans(const triton::gpu::DotOperandEncodingAttr &retEncoding,
const triton::gpu::BlockedEncodingAttr &srcEncoding) {
int kOrder = retEncoding.getOpIdx() ^ 1;
return kOrder != srcEncoding.getOrder()[0];
}
static bool isDotNT(triton::DotOp dotOp) {
triton::gpu::DotOperandEncodingAttr aRetEncoding;
triton::gpu::DotOperandEncodingAttr bRetEncoding;
triton::gpu::BlockedEncodingAttr aSrcEncoding;
triton::gpu::BlockedEncodingAttr bSrcEncoding;
if (isBlockedToDotOperand(dotOp.getOperand(0).getDefiningOp(), aRetEncoding,
aSrcEncoding)
.failed())
return false;
if (isBlockedToDotOperand(dotOp.getOperand(1).getDefiningOp(), bRetEncoding,
bSrcEncoding)
.failed())
return false;
if (!aRetEncoding || !bRetEncoding || !aSrcEncoding || !bSrcEncoding)
return false;
return !isTrans(aRetEncoding, aSrcEncoding) &&
!isTrans(bRetEncoding, bSrcEncoding);
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
triton::gpu::DotOperandEncodingAttr retEncoding;
triton::gpu::BlockedEncodingAttr srcEncoding;
if (isBlockedToDotOperand(op, retEncoding, srcEncoding).failed())
return mlir::failure();
// only supports dot NT
auto users = cvt->getUsers();
auto dotOp = dyn_cast_or_null<DotOp>(*users.begin());
if (!dotOp)
return failure();
if (!isDotNT(dotOp))
return failure();
// don't move things around when cvt operand is a block arg
Operation *argOp = cvt.getOperand().getDefiningOp();
if (!argOp)
@@ -129,8 +177,8 @@ public:
return failure();
// we don't want to use ldmatrix for 8-bit data that requires trans
// since Nvidia GPUs can't do it efficiently
bool isTrans =
(retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0);
int kOrder = retEncoding.getOpIdx() ^ 1;
bool isTrans = kOrder != srcEncoding.getOrder()[0];
bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8;
if (isTrans && isInt8)
return failure();

View File

@@ -8,6 +8,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/ADT/MapVector.h"
//===----------------------------------------------------------------------===//
//
@@ -17,6 +18,7 @@
//
//===----------------------------------------------------------------------===//
using llvm::MapVector;
using namespace mlir;
namespace ttg = triton::gpu;
@@ -25,9 +27,19 @@ namespace ttg = triton::gpu;
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
for (const NamedAttribute attr : dictAttrs.getValue())
if (!op->hasAttr(attr.getName()))
op->setAttr(attr.getName(), attr.getValue());
NamedAttrList attrs = op->getDiscardableAttrs();
// Collect the attributes to propagate: the ones in dictAttrs and not yet on
// the operation.
SmallVector<NamedAttribute> toPropagate;
for (const NamedAttribute attr : dictAttrs.getValue()) {
if (!attrs.get(attr.getName()))
toPropagate.push_back(attr);
}
// If we found any, let's set them here as a single step.
if (toPropagate.size()) {
attrs.append(toPropagate);
op->setDiscardableAttrs(attrs);
}
}
#define int_attr(num) builder.getI64IntegerAttr(num)
@@ -69,31 +81,46 @@ class LoopPipeliner {
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
/// Block arguments that loads depend on
SetVector<BlockArgument> depArgs;
/// For each argument, we need to record at which stage it is defined.
/// If we have a load that immediately depends on a block argument in the
/// current iteration, it is an immediate dependency. Otherwise, it is a
/// non-immediate dependency, which means the load depends on a block argument
/// in the previous iterations.
/// For example:
/// scf.for (%arg0, %arg1, %arg2) {
/// %0 = load %arg0 <--- immediate dep, this address is initialized at
/// numStages-2
/// %0 = load %arg0 <--- immediate dep, this address is initialized before
/// numStages-1
/// %1 = load %arg1
/// %2 = add %1, %arg2
/// %3 = load %2 <--- non-immediate dep, %arg1 must be an update-to-date
/// value
/// }
SetVector<BlockArgument> immedidateDepArgs;
/// Collect values that v depends on and are defined inside the loop
LogicalResult collectDeps(Value v, int stage,
MapVector<Value, int> &depStage);
SetVector<BlockArgument> nonImmedidateDepArgs;
/// Associate each variable with a unique stage. If a variable is defined
/// at multiple stages, we don't pipeline it.
LogicalResult addDep(Value v, int stage, MapVector<Value, int> &depStage);
int getArgDefStage(Value v, int stage);
/// Block arguments that loads depend on
MapVector<BlockArgument, int> depArgUseStage;
/// Block arguments that loads depend on (defined in the loop body)
MapVector<BlockArgument, int> depArgDefStage;
/// Operations (inside the loop body) that loads depend on
SetVector<Operation *> depOps;
MapVector<Operation *, int> depOpDefStage;
/// collect values that v depends on and are defined inside the loop
void collectDeps(Value v, int stages, SetVector<Value> &deps);
/// Operations (inside the loop body) that loads depend on (defined in the
/// loop body)
SetVector<BlockArgument> immediateDepArgs;
/// Operations (inside the loop body) that loads depend on (defined in the
/// previous iterations)
SetVector<BlockArgument> nonImmediateDepArgs;
void setValueMapping(Value origin, Value newValue, int stage);
@@ -140,31 +167,60 @@ Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
return valueMapping[origin][stage];
}
void LoopPipeliner::collectDeps(Value v, int stages, SetVector<Value> &deps) {
LogicalResult LoopPipeliner::addDep(Value v, int stage,
MapVector<Value, int> &depStage) {
if (!depStage.contains(v)) {
depStage.insert(std::make_pair(v, stage));
} else if (depStage[v] != stage)
return failure();
return success();
}
LogicalResult LoopPipeliner::collectDeps(Value v, int stage,
MapVector<Value, int> &depStage) {
// Loop-invariant value, skip
if (v.getParentRegion() != &forOp.getLoopBody())
return;
return success();
// Since we only need to peel the loop numStages-1 times, don't worry about
// depends that are too far away
if (stages < 0)
return;
if (stage < 0)
return success();
if (auto arg = v.dyn_cast<BlockArgument>()) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
if (arg.getArgNumber() > 0) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
deps.insert(v);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
deps);
// If we've found the first definition of this arg, we're done, don't
// recurse
if (addDep(v, stage, depStage).succeeded())
if (collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1,
depStage)
.failed())
return failure();
}
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depend on value in previous iterations
deps.insert(v);
// An operation cannot be dependent on different stages
if (addDep(v, stage, depStage).failed())
return failure();
for (Value op : v.getDefiningOp()->getOperands())
collectDeps(op, stages, deps);
if (collectDeps(op, stage, depStage).failed())
return failure();
}
return success();
}
int LoopPipeliner::getArgDefStage(Value v, int stage) {
if (stage < 0)
return -1;
if (auto arg = v.dyn_cast<BlockArgument>()) {
if (arg.getArgNumber() > 0) {
return getArgDefStage(yieldOp->getOperand(arg.getArgNumber() - 1),
stage - 1);
}
llvm_unreachable("Loop induction variable should not be a dependency");
} else
return stage;
}
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
@@ -188,7 +244,8 @@ LogicalResult LoopPipeliner::initialize() {
ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
// can we use forOp.walk(...) here?
// We cannot use forOp.walk(...) here because we only want to visit the
// operations in the loop body block. Nested blocks are handled separately.
SmallVector<triton::LoadOp, 2> validLoads;
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
@@ -205,7 +262,11 @@ LogicalResult LoopPipeliner::initialize() {
.cast<triton::PointerType>()
.getPointeeType();
unsigned width = vec * ty.getIntOrFloatBitWidth();
// cp.async's cp-size can only be 4, 8 and 16.
// We do not pipeline all loads for the following reasons:
// 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
// 2. It's likely that pipling small load won't offer much performance
// improvement and may even hurt performance by increasing register
// pressure.
if (width >= 32)
validLoads.push_back(loadOp);
}
@@ -215,12 +276,28 @@ LogicalResult LoopPipeliner::initialize() {
return failure();
// load => values that it depends on
// Don't pipeline if any load's operands
DenseMap<Value, SetVector<Value>> loadDeps;
MapVector<Value, int> depStage;
for (triton::LoadOp loadOp : validLoads) {
SetVector<Value> deps;
for (Value op : loadOp->getOperands())
collectDeps(op, numStages - 1, deps);
loadDeps[loadOp] = deps;
for (Value op : loadOp->getOperands()) {
MapVector<Value, int> operandDepStage;
if (collectDeps(op, numStages - 1, operandDepStage).failed())
return failure();
for (auto [v, stage] : operandDepStage) {
auto immedidate = operandDepStage.front().first.isa<BlockArgument>();
if (v.isa<BlockArgument>()) {
auto arg = v.cast<BlockArgument>();
if (immedidate)
immediateDepArgs.insert(arg);
else
nonImmediateDepArgs.insert(arg);
}
loadDeps[loadOp].insert(v);
if (addDep(v, stage, depStage).failed())
return failure();
}
}
}
// Don't pipeline valid loads that depend on other valid loads
@@ -268,9 +345,7 @@ LogicalResult LoopPipeliner::initialize() {
continue;
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
}
else
} else
isCandidate = false;
if (isCandidate)
@@ -306,7 +381,7 @@ LogicalResult LoopPipeliner::initialize() {
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
ttg::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
loadsBufferType[loadOp] =
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
}
@@ -314,27 +389,19 @@ LogicalResult LoopPipeliner::initialize() {
// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
for (Value loadOp : loads) {
auto &deps = loadDeps[loadOp];
for (auto &dep : deps) {
if (auto arg = dep.dyn_cast<BlockArgument>()) {
depArgs.insert(arg);
if (deps.front().isa<BlockArgument>()) {
immedidateDepArgs.insert(arg);
} else {
nonImmedidateDepArgs.insert(arg);
}
} else
depOps.insert(dep.getDefiningOp());
}
for (auto [dep, stage] : depStage) {
if (auto arg = dep.dyn_cast<BlockArgument>())
depArgUseStage.insert({arg, stage});
else
depOpDefStage.insert({dep.getDefiningOp(), stage});
}
return success();
}
// Check if immedidateDepArgs and nonImmedidateDepArgs are disjoint
// If yes, we cannot pipeline the loop for now
for (BlockArgument arg : immedidateDepArgs)
if (nonImmedidateDepArgs.contains(arg)) {
for (BlockArgument arg : immediateDepArgs)
if (nonImmediateDepArgs.contains(arg)) {
return failure();
}
@@ -365,12 +432,13 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
void LoopPipeliner::emitPrologue() {
OpBuilder builder(forOp);
// Get init operands for loop carried values
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
setValueMapping(arg, operand.get(), 0);
}
// prologue from [0, numStage-1)
// Emit prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
@@ -386,12 +454,12 @@ void LoopPipeliner::emitPrologue() {
// Rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
if (depOpDefStage.contains(&op))
orderedDeps.push_back(&op);
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
assert(depOpDefStage.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
for (Operation *op : orderedDeps) {
Operation *newOp = nullptr;
@@ -406,14 +474,13 @@ void LoopPipeliner::emitPrologue() {
Value newMask =
getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage),
loopCond, builder);
// TODO: check if the hardware supports async copy
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
newOp = builder.create<ttg::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
lookupOrDefault(loadOp.getPtr(), stage),
loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0);
builder.create<triton::gpu::AsyncCommitGroupOp>(op->getLoc());
builder.create<ttg::AsyncCommitGroupOp>(op->getLoc());
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
} else
llvm_unreachable("This should be LoadOp");
@@ -428,10 +495,9 @@ void LoopPipeliner::emitPrologue() {
lookupOrDefault(loadOp.getOther(), stage),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
addNamedAttrs(newOp, op->getAttrDictionary());
} else {
addNamedAttrs(newOp, op->getDiscardableAttrDictionary());
} else
newOp = builder.clone(*op);
}
// Update loop-carried uses
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx));
@@ -443,32 +509,40 @@ void LoopPipeliner::emitPrologue() {
}
}
// Update mapping of results
// if (stage == numStages - 2)
// continue;
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
Value originResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
// TODO: load should not be used in the preheader?
if (loads.contains(originalResult)) {
if (loads.contains(originResult))
break;
// originalResult = loadsMapping[originalResult];
}
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
setValueMapping(originResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
if (operand.get() == op->getResult(dstIdx)) {
auto yieldIdx = operand.getOperandNumber();
auto value = forOp.getRegionIterArgs()[yieldIdx];
setValueMapping(value, newOp->getResult(dstIdx), stage + 1);
}
}
}
} // for (Operation *op : orderedDeps)
// Update pipeline index
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
// Some values have not been used by any ops in the loop body
for (BlockArgument arg : forOp.getRegionIterArgs()) {
// Check if arg has a yieldOp use
for (OpOperand &operand : arg.getUses()) {
if (operand.getOwner() == yieldOp) {
auto yieldIdx = operand.getOperandNumber();
auto value = forOp.getRegionIterArgs()[yieldIdx];
if (!valueMapping[value][stage + 1])
setValueMapping(value, valueMapping[arg][stage], stage + 1);
}
}
}
} // for (int stage = 0; stage < numStages - 1; ++stage)
// async.wait & extract_slice
@@ -484,7 +558,7 @@ void LoopPipeliner::emitPrologue() {
sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]},
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
Value extractSlice = builder.create<ttg::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
@@ -505,7 +579,7 @@ void LoopPipeliner::emitEpilogue() {
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
}
scf::ForOp LoopPipeliner::createNewForOp() {
@@ -537,22 +611,24 @@ scf::ForOp LoopPipeliner::createNewForOp() {
newLoopArgs.push_back(loadsExtract[loadOp]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
for (auto [depArg, useStage] : depArgUseStage) {
depArgsIdx[depArg] = newLoopArgs.size();
if (immedidateDepArgs.contains(depArg)) {
auto defStage = getArgDefStage(depArg, useStage);
assert(defStage >= 0 &&
"newLoopArgs has null args without a define op. Consider either "
"rewrite the loop to reduce cross iteration dependencies or "
"increase the num_stages value.");
if (immediateDepArgs.contains(depArg) && defStage == numStages - 2) {
newLoopArgs.push_back(valueMapping[depArg][numStages - 2]);
} else
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
}
size_t nextIVIdx = newLoopArgs.size();
size_t ivIndex = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
newLoopArgs.push_back(pipelineIterIdx);
newLoopArgs.push_back(loopIterIdx);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
@@ -565,7 +641,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2. clone the loop body, replace original args with args of the new ForOp
// 3. clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
DenseSet<Value> isModified;
for (Operation &op : forOp.getBody()->without_terminator()) {
@@ -597,50 +673,54 @@ scf::ForOp LoopPipeliner::createNewForOp() {
isModified.insert(op.getResult(0));
}
// 3. prefetch the next iteration
// 4. prefetch the next iteration
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
if (depOpDefStage.contains(&op))
orderedDeps.push_back(&op);
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
assert(depOpDefStage.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
IRMapping nextMapping;
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
for (auto [depArg, useStage] : depArgUseStage) {
BlockArgument nextArg =
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx];
nextMapping.map(arg, nextArg);
nextMapping.map(depArg, nextArg);
++argIdx;
}
// Special handling for iv & loop condition
Value curIV = newForOp.getRegionIterArgs()[ivIndex];
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
newForOp.getInductionVar().getLoc(), curIV, newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
nextMapping.map(forOp.getInductionVar(), nextIV);
// Slice index
SmallVector<Value> nextBuffers;
SmallVector<Value> extractSlices;
pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
pipelineIterIdx = newForOp.getRegionIterArgs()[ivIndex + 1];
Value insertSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
loopIterIdx = newForOp.getRegionIterArgs()[ivIndex + 2];
Value extractSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
// Prefetch load deps
for (Operation *op : orderedDeps)
if (!loads.contains(op->getResult(0))) {
if (depOpDefStage[op] == numStages - 2)
nextMapping.map(forOp.getInductionVar(), curIV);
else
nextMapping.map(forOp.getInductionVar(), nextIV);
Operation *nextOp;
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
auto newMask =
@@ -652,20 +732,20 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextMapping.lookupOrDefault(loadOp.getOther()),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
addNamedAttrs(nextOp, op->getAttrDictionary());
addNamedAttrs(nextOp, op->getDiscardableAttrDictionary());
nextMapping.map(loadOp.getResult(), nextOp->getResult(0));
} else {
nextOp = builder.clone(*op, nextMapping);
}
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
for (OpOperand &operand : originYield->getOpOperands()) {
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
size_t originIdx = operand.getOperandNumber();
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
nextMapping.map(forOp.getRegionIterArgs()[originIdx],
size_t yieldIdx = operand.getOperandNumber();
size_t depYieldIdx =
depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]];
BlockArgument newArg = newForOp.getRegionIterArgs()[depYieldIdx];
nextMapping.map(forOp.getRegionIterArgs()[yieldIdx],
nextOp->getResult(dstIdx));
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
}
@@ -673,6 +753,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
}
// loads -> async loads
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// Update loading mask
@@ -689,14 +770,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextMapping.map(loadOp.getMask(), newMask);
newMask = nextMapping.lookupOrDefault(mask);
}
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
Value insertAsyncOp = builder.create<ttg::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
nextMapping.lookupOrDefault(loadOp.getPtr()),
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
insertSliceIndex, newMask,
nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0);
builder.create<triton::gpu::AsyncCommitGroupOp>(op->getLoc());
builder.create<ttg::AsyncCommitGroupOp>(op->getLoc());
nextBuffers.push_back(insertAsyncOp);
// ExtractSlice
auto bufferType = insertAsyncOp.getType().cast<RankedTensorType>();
@@ -706,7 +787,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
nextOp = builder.create<ttg::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
int_attr(0)},
@@ -720,12 +801,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
// If this is a loop-carried value, update the mapping for yield
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpOperand &operand : originYield->getOpOperands()) {
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
size_t originIdx = operand.getOperandNumber();
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
auto yieldIdx = operand.getOperandNumber();
auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]];
auto newArg = newForOp.getRegionIterArgs()[depYieldIdx];
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
}
}
@@ -733,6 +813,22 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
}
// Some values have not been used by any ops in the loop body
for (BlockArgument arg : forOp.getRegionIterArgs()) {
// Check if arg has a yieldOp use
for (OpOperand &operand : arg.getUses()) {
if (operand.getOwner() == yieldOp) {
auto yieldIdx = operand.getOperandNumber();
auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]];
auto newArg = newForOp.getRegionIterArgs()[depYieldIdx];
if (!depArgsMapping.contains(newArg)) {
auto argIdx = depArgsIdx[arg];
depArgsMapping[newArg] = newForOp.getRegionIterArgs()[argIdx];
}
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
@@ -751,14 +847,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// Finally, the YieldOp, need to sync with the order of newLoopArgs
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
for (Value v : yieldOp->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value nextBuffer : nextBuffers)
yieldValues.push_back(nextBuffer);
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
for (size_t i = depArgsBeginIdx; i < ivIndex; ++i) {
auto arg = newForOp.getRegionIterArgs()[i];
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
yieldValues.push_back(depArgsMapping[arg]);
@@ -768,8 +864,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
yieldValues.push_back(loopIterIdx);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
yieldValues);
builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
return newForOp;
}

View File

@@ -349,27 +349,23 @@ public:
SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return op->getBlock() == cvt->getBlock() &&
!isa<triton::gpu::ConvertLayoutOp, scf::YieldOp>(op) &&
!(isa<triton::ReduceOp>(op) &&
!op->getResult(0).getType().isa<RankedTensorType>()) &&
!isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
!op->getResult(0).getType().isa<RankedTensorType>());
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty()) {
if (cvtSlices.empty())
return failure();
}
llvm::MapVector<Value, Attribute> toConvert;
for (Operation *op : cvtSlices) {
// don't rematerialize anything expensive
if (expensiveToRemat(op, dstEncoding)) {
if (expensiveToRemat(op, dstEncoding))
return failure();
}
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op) && !isa<triton::ReduceOp>(op)) {
!isa<triton::StoreOp, triton::ReduceOp>(op))
return failure();
}
// don't rematerialize if it adds an extra conversion that can't
// be removed
for (Value arg : op->getOperands()) {
@@ -380,9 +376,8 @@ public:
int numAddedConvs = simulateBackwardRematerialization(
argOp, processed, layout, toConvert, srcEncoding);
if (argOp && !isa<triton::gpu::ConvertLayoutOp>(argOp) &&
cvtSlices.count(argOp) == 0 && numAddedConvs > 0) {
cvtSlices.count(argOp) == 0 && numAddedConvs > 0)
return failure();
}
}
}
@@ -425,7 +420,6 @@ public:
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
std::vector<std::pair<Operation *, Attribute>> queue;
if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
targetType.getEncoding()) > 0)
return mlir::failure();
@@ -507,49 +501,15 @@ public:
auto forOp = cast<scf::ForOp>(op);
auto iterArgs = forOp.getRegionIterArgs();
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
// if (iterArg.index() != 1)
// continue;
// skip non-tensor types
if (!iterArg.value().getType().isa<RankedTensorType>())
continue;
// we only move `iterArg` out of the loop if
// - there is only a single conversion use
// - moving this conversion out of the loop will not generate
// any extra non-removable conversion
auto users = iterArg.value().getUsers();
// check first condition
SetVector<Type> cvtTargetTypes;
for (auto user : users) {
if (isa<triton::gpu::ConvertLayoutOp>(user)) {
auto newType =
user->getResults()[0].getType().cast<RankedTensorType>();
auto oldType = user->getOperand(0).getType().cast<RankedTensorType>();
if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
newType.getEncoding()
.isa<triton::gpu::DotOperandEncodingAttr>()) {
continue;
}
if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
if (newType.getEncoding()
.cast<triton::gpu::SharedEncodingAttr>()
.getVec() == 1)
continue;
}
cvtTargetTypes.insert(newType);
}
}
if (cvtTargetTypes.size() != 1)
SmallVector<Operation *> cvts;
if (canMoveOutOfLoop(iterArg.value(), cvts).failed())
continue;
// TODO: check second condition
for (auto user : users) {
if (isa<triton::gpu::ConvertLayoutOp>(user))
continue;
}
// check
for (auto op : iterArg.value().getUsers()) {
for (auto *op : cvts) {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
continue;
auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
targetType, cvt);

View File

@@ -269,4 +269,64 @@ void rematerializeConversionChain(
}
}
LogicalResult canMoveOutOfLoop(BlockArgument arg,
SmallVector<Operation *> &cvts) {
auto parentOp = arg.getOwner()->getParentOp();
// Don't move if arg is defined in a while loop
if (isa<scf::WhileOp>(parentOp))
return failure();
// Skip if arg is not defined in scf.for
if (!isa<scf::ForOp>(parentOp))
return success();
auto forOp = cast<scf::ForOp>(parentOp);
// We only move `iterArg` out of the loop if
// 1. There is no conversion
// 2. There is only a single conversion
// 3. Moving this conversion out of the loop will not generate any extra
// non-removable conversion
DenseSet<Type> cvtTypes;
SetVector<Operation *> others;
auto oldType = arg.getType().cast<RankedTensorType>();
for (auto user : arg.getUsers()) {
if (isa<triton::gpu::ConvertLayoutOp>(user)) {
// Don't move if the conversion target is a dot operand or shared memory
auto newType = user->getResults()[0].getType().cast<RankedTensorType>();
if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
newType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
continue;
}
if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
if (newType.getEncoding()
.cast<triton::gpu::SharedEncodingAttr>()
.getVec() == 1)
continue;
}
cvts.emplace_back(user);
cvtTypes.insert(newType);
} else
others.insert(user);
}
// First condition
if (cvts.empty())
return success();
if (cvtTypes.size() == 1) {
// Second condition
if (others.empty())
return success();
// Third condition: not complete
// If the other or the cvt is in the different block, we cannot push the
// conversion forward or backward
for (auto *cvt : cvts) {
if (cvt->getBlock() != forOp.getBody())
return failure();
}
for (auto *other : others) {
if (other->getBlock() != forOp.getBody())
return failure();
}
return success();
}
return failure();
}
} // namespace mlir

View File

@@ -16,6 +16,8 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
bool expensiveToRemat(Operation *op, Attribute &targetEncoding);
// skipInit is True when we only consider the operands of the initOp but
// not the initOp itself.
int simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
@@ -28,6 +30,10 @@ void rematerializeConversionChain(
const llvm::MapVector<Value, Attribute> &toConvert,
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
IRMapping &mapping);
LogicalResult canMoveOutOfLoop(BlockArgument arg,
SmallVector<Operation *> &cvts);
} // namespace mlir
#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -158,12 +158,12 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
for (LLVM::LLVMFuncOp func : funcs) {
if (auto libnameAttr = func->getDiscardableAttr("libname")) {
auto name = libnameAttr.dyn_cast<StringAttr>();
auto path = func.getOperation()
->getDiscardableAttr("libpath")
.dyn_cast<StringAttr>();
if (name) {
std::string libName = name.str();
externLibs[libName] = path.str();
@@ -171,11 +171,8 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
if (auto externsAttr = module->getDiscardableAttr("triton_gpu.externs")) {
for (auto &attr : externsAttr.cast<DictionaryAttr>()) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}

View File

@@ -3,6 +3,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -350,6 +352,14 @@ void init_triton_ir(py::module &&m) {
self.print(os);
return str;
})
.def("bytecode",
[](mlir::ModuleOp &self) -> py::bytearray {
std::string bytecode;
llvm::raw_string_ostream os(bytecode);
if (failed(mlir::writeBytecodeToFile(self, os)))
throw std::runtime_error("Failed to write module bytecode");
return py::bytearray(bytecode);
})
.def("push_back",
[](mlir::ModuleOp &self, mlir::triton::FuncOp &funcOp) -> void {
self.push_back(funcOp);
@@ -441,11 +451,15 @@ void init_triton_ir(py::module &&m) {
// 1. Unreachable code after return
self.walk([&](mlir::Block *block) {
mlir::Operation *retOp = nullptr;
block->walk([&](mlir::Operation *op) {
// It's better to not use walk here because we only want to
// check operations in the current block
for (auto &op : block->getOperations()) {
if (mlir::isa<mlir::triton::ReturnOp>(op))
if (retOp == nullptr)
retOp = op;
});
if (retOp == nullptr) {
retOp = &op;
break;
}
}
if (retOp && retOp != &block->back()) {
auto pos = retOp->getIterator();
pos++;
@@ -1413,8 +1427,12 @@ void init_triton_ir(py::module &&m) {
.def("create_get_program_id",
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc();
if (axis < 0 || axis > 3)
throw std::runtime_error("program_id must be in [0,3]");
return self.create<mlir::triton::GetProgramIdOp>(
loc, self.getI32Type(), self.getI32IntegerAttr(axis));
loc, self.getI32Type(),
mlir::triton::ProgramIDDimAttr::get(
loc.getContext(), mlir::triton::ProgramIDDim(axis)));
})
.def("create_get_num_programs",
[](mlir::OpBuilder &self, int axis) -> mlir::Value {

View File

@@ -1,4 +1,7 @@
import numpy as np
import pytest
import torch
from numpy.random import RandomState
import triton
import triton.language as tl
@@ -66,3 +69,162 @@ def test_chained_matmul():
block_k=block_k)
assert (torch_result == triton_result).all()
def test_vecmat():
@triton.jit
def batched_vecmat(
# inputs
A, # shape: [dim_m, dim_k]
B, # shape: [dim_m, dim_n, dim_k]
# dimensions
dim_m, dim_n, dim_k,
# outputs
output,
# block information
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
):
m_index = tl.program_id(0)
n_index = tl.program_id(1)
# Output tile
output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \
+ (n_index * block_n + tl.arange(0, block_n))[None, :]
vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)
k_blocks = dim_k // block_k
for k_index in range(k_blocks):
# Load A tile
a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \
+ (k_index * block_k + tl.arange(0, block_k))[None, :]
a = tl.load(A + a_tile)
# Load B tile, transposed to [n, m, k] in order to broadcast A on a
# leading dimension.
b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \
+ (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \
+ (k_index * block_k + tl.arange(0, block_k))[None, None, :]
b = tl.load(B + b_tile)
expanded_a, _ = tl.broadcast(a, b)
vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))
tl.store(output + output_tile, vecmat)
M, N, K = 128, 128, 128
block_m, block_n, block_k = 16, 32, 64
rs = RandomState(17)
A_vec = rs.randint(0, 4, (M, K)).astype('float32')
B_vec = rs.randint(0, 4, (M, N, K)).astype('float32')
A = A_vec
B = B_vec
A_tri = torch.tensor(A, device='cuda')
B_tri = torch.tensor(B, device='cuda')
C_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')
grid = (M // block_m, N // block_n)
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
block_m=block_m, block_n=block_n, block_k=block_k,
num_warps=4, num_stages=1)
A_expanded = A[:, np.newaxis, :]
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
AB = A_broadcasted * B
C_ref = np.sum(AB, axis=2)
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
def test_iv_dependent_matmul(type):
@triton.jit
def kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
type: tl.constexpr
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
a_ptrs = a_ptr
b_ptrs = b_ptr
if type == "post_load_two_iters":
a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
elif type == "post_load_three_iters":
a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak
b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if type == "pre_load":
a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk
elif type == "post_pre_mixed":
a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator += tl.dot(a, b)
if type == "post_load":
a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak
b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
elif type == "post_pre_mixed":
b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
elif type == "post_load_two_iters":
a_ptrs = a_ptrs_next
b_ptrs = b_ptrs_next
a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak
b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk
elif type == "post_load_three_iters":
a_ptrs = a_ptrs_next
b_ptrs = b_ptrs_next
a_ptrs_next = a_ptrs_next_next
b_ptrs_next = b_ptrs_next_next
a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak
b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
M = 256
K = 256
N = 256
BLOCK_SIZE_K = 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_M = 32
a = torch.rand((M, K), device='cuda')
b = torch.rand((K, N), device='cuda')
torch_output = torch.mm(a, b)
triton_output = torch.empty_like(
torch_output, device=torch_output.device)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
num_stages = 4 if type == "post_load_three_iters" else 3
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
type=type, num_stages=num_stages)
torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2)

View File

@@ -22,6 +22,13 @@ def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit(debug=False)
def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.device_assert(x == 0, "x != 0")
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -41,8 +48,16 @@ def test_assert(func: str):
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_assert":
<<<<<<< HEAD
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
=======
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
elif func == "no_debug":
# TRITON_DEBUG=True can override the debug flag
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
>>>>>>> oai/main
elif func == "assert":
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "static_assert":
@@ -50,5 +65,72 @@ def test_assert(func: str):
assert_close(y, x)
@triton.jit
def jit_device_assert_none(x):
tl.device_assert(x == 0, "x != 0")
@triton.jit(debug=True)
def jit_device_assert_true(x):
tl.device_assert(x == 0, "x != 0")
@triton.jit(debug=False)
def jit_device_assert_false(x):
tl.device_assert(x == 0, "x != 0")
@triton.jit
def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
if jit_debug == "true":
jit_device_assert_true(x)
elif jit_debug == "false":
jit_device_assert_false(x)
else:
jit_device_assert_none(x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit(debug=True)
def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
if jit_debug == "true":
jit_device_assert_true(x)
elif jit_debug == "false":
jit_device_assert_false(x)
else:
jit_device_assert_none(x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit(debug=False)
def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
if jit_debug == "true":
jit_device_assert_true(x)
elif jit_debug == "false":
jit_device_assert_false(x)
else:
jit_device_assert_none(x)
tl.store(Y + tl.arange(0, BLOCK), x)
def test_assert_nested(caller: str, callee: str):
shape = (128, )
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if caller == "none":
kernel_device_assert_nested[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
elif caller == "true":
kernel_device_assert_nested_true[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
elif caller == "false":
kernel_device_assert_nested_false[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
assert_close(y, x)
if __name__ == "__main__":
test_assert(sys.argv[1])
if len(sys.argv) == 3:
test_assert_nested(sys.argv[1], sys.argv[2])
else:
test_assert(sys.argv[1])

View File

@@ -130,6 +130,17 @@ class BlockedLayout:
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
class SharedLayout:
def __init__(self, vec, per_phase, max_phase, order):
self.vec = str(vec)
self.per_phase = str(per_phase)
self.max_phase = str(max_phase)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@@ -456,6 +467,21 @@ def test_broadcast(dtype):
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
# ------------------
# test invalid slice
# ------------------
def test_invalid_slice():
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst):
dst[10:]
with pytest.raises(triton.CompilationError, match='unsupported tensor index'):
_kernel[(1,)](dst=dst)
# ----------------
# test expand_dims
@@ -537,6 +563,20 @@ def test_expand_dims_error_cases():
duplicate_dim2[(1,)](dummy_tensor, N)
# ----------------------------
# test invalid program id axis
# ----------------------------
def test_invalid_pid_axis():
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst):
pid = tl.program_id(20)
with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
_kernel[(1,)](dst)
# ---------------
# test where
# ---------------
@@ -1368,6 +1408,9 @@ reduce_configs2 = [
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
for shape in reduce2d_shapes
for axis in [0, 1]
] + [
(op, 'float32', [16, 32], None)
for op in ['min', 'max', 'sum']
]
@@ -1382,7 +1425,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
if AXIS == 1:
if AXIS is None:
tl.store(Z, z)
elif AXIS == 1:
tl.store(Z + range_m, z)
else:
tl.store(Z + range_n, z)
@@ -1407,7 +1452,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
else:
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
ret_numel = 1 if axis is None else shape[1 - axis]
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_numpy(z_tri)
@@ -1958,6 +2004,23 @@ def test_full(dtype_str):
assert torch.all(out_dynamic == 2)
@pytest.mark.parametrize("literal, dtype_str",
[(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"),
('float("inf")', "f32"), ('float("-inf")', "f32"),
('float("nan")', "f32"), ('float("-nan")', "f32"),
(0., "f32"),
(5, "i32"), (2**40, "i64"),])
def test_constexpr(literal, dtype_str):
@triton.jit
def kernel(out_ptr):
val = GENERATE_TEST_HERE
tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val)
kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"})
out = torch.zeros((1,), dtype=torch.float32, device="cuda")
h = kernel_patched[(1,)](out)
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
@@ -2628,41 +2691,64 @@ def add_fn_static_cond(x, cond: tl.constexpr):
return x + 1
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit",
"jit", "jit_if", "jit_ifexp", "jit_expr",
"jit_static_cond", "jit_noinline", "jit_extern"])
def test_if_call(call_type):
@triton.jit
def kernel(Out, call_type: tl.constexpr):
pid = tl.program_id(0)
o = tl.load(Out)
if pid == 0:
if call_type == "attribute":
# call attribute
a = o + 1
a = a.to(tl.int32).to(tl.int32)
o = a
else:
if call_type == "attribute":
# call attribute
if pid == 0:
a = o
if call_type == "jit_function":
# regular function call
a = add_fn(a)
elif call_type == "jit_function_return":
# function without end_if block
a = add_fn_return(a, pid)
elif call_type == "ifexp":
# ifexp expression
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
elif call_type == "expr":
if pid == 1:
return
a = add_fn(a)
if pid == 0:
# call without return
add_fn_expr(Out, a)
elif call_type == "jit_function_static_cond":
a = add_fn_static_cond(a, call_type)
elif call_type == "jit_function_noinline":
a = add_fn_noinline(a)
a = a.to(tl.int32).to(tl.int32) + 1
o = a
elif call_type == "attribute_jit":
# call attribute and jit function
if pid == 0:
a = o
a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1
o = a
elif call_type == "jit":
if pid == 0:
# regular function call
a = o
a = add_fn(a)
o = a
elif call_type == "jit_if":
# function without end_if block
if pid == 0:
a = o
a = add_fn_return(a, pid)
o = a
elif call_type == "jit_ifexp":
# ifexp expression
if pid == 0:
a = o
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
o = a
elif call_type == "jit_expr":
# call without return
if pid == 0:
a = o + 1
add_fn_expr(Out, a)
o = a
elif call_type == "jit_static_cond":
if pid == 0:
a = o + 1
add_fn_static_cond(o, call_type)
o = a
elif call_type == "jit_noinline":
if pid == 0:
a = o + 1
add_fn_noinline(a)
o = a
elif call_type == "jit_extern":
if pid == 0:
a = o + 1
tl.cdiv(a, a)
o = a
tl.store(Out, o)
@@ -2766,7 +2852,7 @@ def test_globaltimer():
def kernel(Out1, Out2):
start = tl.extra.cuda.globaltimer()
off = tl.arange(0, 128)
for i in range(100):
for i in range(10000):
tl.store(Out1 + off, tl.load(Out1 + off) + 1)
end = tl.extra.cuda.globaltimer()
tl.store(Out2, end - start)
@@ -2810,22 +2896,49 @@ layouts = [
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
intermediate_layouts = [
None,
SharedLayout(1, 1, 1, [1, 0]),
SharedLayout(4, 2, 4, [1, 0]),
SharedLayout(2, 2, 4, [1, 0]),
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
""" if interm_layout is None else f"""
#src = {src_layout}
#interm = {interm_layout}
#dst = {dst_layout}
"""
conversion = f"""
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
""" if interm_layout is None else f"""
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
"""
ir = layouts + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
@@ -2840,8 +2953,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
""" + conversion + """
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
tt.return

View File

@@ -9,7 +9,8 @@ print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")
# TODO: bfloat16 after LLVM-15
func_types = ["device_assert", "assert", "static_assert"]
func_types = ["device_assert", "assert", "static_assert", "no_debug"]
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@@ -51,3 +52,29 @@ def test_assert(func_type: str):
assert num_errs == 127
else:
assert num_errs == 0
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
def test_assert_nested(caller_type, callee_type):
    """Run assert_helper.py with nested caller/callee debug flags and count device asserts.

    An explicit callee setting ("true"/"false") overrides the caller's;
    a callee of "none" inherits the caller's setting. Whenever device-side
    assertions end up enabled, 127 threads report "x != 0" on stderr.
    """
    proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
    _, errs = proc.communicate()
    num_errs = sum(1 for err in errs.splitlines() if "x != 0" in err.decode("utf-8"))
    if caller_type == "none":
        expected = 127 if callee_type == "true" else 0
    elif caller_type == "true":
        # Asserts stay on unless the callee explicitly disables them.
        expected = 0 if callee_type == "false" else 127
    else:  # caller_type == "false": only an explicit callee "true" re-enables them
        expected = 127 if callee_type == "true" else 0
    assert num_errs == expected

View File

@@ -5,7 +5,10 @@ import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
(4, 48, 1024, 32),
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()

View File

@@ -160,12 +160,12 @@ def test_jit_debug() -> None:
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
assert len(kernel_add.cache[device]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 2
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
@triton.jit

View File

@@ -97,6 +97,97 @@ class enter_sub_region:
self.generator.local_defs = self.prev_defs
# Check if the given syntax node has an "early" return
class ContainsReturnChecker(ast.NodeVisitor):
    """AST visitor that reports whether a statement may execute a `return`.

    Each `visit_*` method returns a bool. `gscope` maps global names to their
    values so calls to other JITFunctions can be checked recursively (noinline
    functions are treated as opaque and assumed not to return early).
    """

    def __init__(self, gscope):
        self.gscope = gscope

    def _visit_stmts(self, body) -> bool:
        # True if any statement in the list may return.
        for s in body:
            if self.visit(s):
                return True
        return False

    def _visit_function(self, fn) -> bool:
        # Currently we only support JITFunctions defined in the global scope
        if isinstance(fn, JITFunction) and not fn.noinline:
            fn_node = fn.parse()
            return ContainsReturnChecker(self.gscope).visit(fn_node)
        return False

    def generic_visit(self, node) -> bool:
        # Default: a node may return if any of its children may.
        ret = False
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        ret = ret or self.visit(item)
            elif isinstance(value, ast.AST):
                ret = ret or self.visit(value)
        return ret

    def visit_Attribute(self, node: ast.Attribute) -> bool:
        # If the left part is a name, it's possible that
        # we call triton native function or a jit function from another module.
        # If the left part is not a name, it must return a tensor or a constexpr
        # whose methods do not contain return statements
        # e.g., (tl.load(x)).to(y)
        # So we only check if the expressions within value have return or not
        if isinstance(node.value, ast.Name):
            if node.value.id in self.gscope:
                value = self.gscope[node.value.id]
                fn = getattr(value, node.attr)
                return self._visit_function(fn)
            return False
        return self.visit(node.value)

    def visit_Name(self, node: ast.Name) -> bool:
        if type(node.ctx) == ast.Store:
            return False
        if node.id in self.gscope:
            fn = self.gscope[node.id]
            return self._visit_function(fn)
        return False

    def visit_Return(self, node: ast.Return) -> bool:
        return True

    def visit_Assign(self, node: ast.Assign) -> bool:
        # There couldn't be an early return
        # x = ...
        return False

    def visit_AugAssign(self, node: ast.AugAssign) -> bool:
        # There couldn't be an early return
        # x += ...
        return False

    def visit_Module(self, node: ast.Module) -> bool:
        return self._visit_stmts(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
        return self._visit_stmts(node.body)

    def visit_If(self, node: ast.If) -> bool:
        # TODO: optimize the following case in which we actually don't have
        # a return when static_cond is false:
        # if dynamic_cond
        #   if static_cond
        #     func_with_return
        #   else
        #     func_without_return
        ret = self._visit_stmts(node.body)
        if node.orelse:
            ret = ret or self._visit_stmts(node.orelse)
        return ret

    def visit_IfExp(self, node: ast.IfExp) -> bool:
        return self.visit(node.body) or self.visit(node.orelse)

    def visit_Call(self, node: ast.Call) -> bool:
        # Only the callee expression can introduce a return (via a jit call).
        return self.visit(node.func)
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name,
module=None, is_kernel=False, function_types: Optional[Dict] = None,
@@ -166,63 +257,6 @@ class CodeGenerator(ast.NodeVisitor):
if ret_type is not None and isinstance(stmt, ast.Return):
self.last_ret_type = ret_type
    # TODO: should be its own AST visitor
    def contains_return_op(self, node):
        # Recursively decide whether `node` may execute a `return` statement.
        # Returns True for a Return node and propagates through assignments,
        # module/function bodies, calls into (inlinable) JITFunctions, if
        # statements, if-expressions, and expression statements.
        if isinstance(node, ast.Return):
            return True
        elif isinstance(node, ast.Assign):
            return self.contains_return_op(node.value)
        elif isinstance(node, ast.Module):
            pred = lambda s: self.contains_return_op(s)
            return any(pred(s) for s in node.body)
        elif isinstance(node, ast.FunctionDef):
            pred = lambda s: self.contains_return_op(s)
            return any(pred(s) for s in node.body)
        elif isinstance(node, ast.Call):
            def check_undefined_name(cur_node):
                # Check if name is an undefined local variable,
                # which can only be a tensor or a constexpr
                if isinstance(cur_node.func, ast.Attribute):
                    if isinstance(cur_node.func.value, ast.Name):
                        name = cur_node.func.value.id
                        if name not in self.lscope and name not in self.gscope:
                            return True
                        return False
                    # chain of calls
                    # e.g., tl.load(a).to(tl.float32)
                    return check_undefined_name(cur_node.func.value)
                return False
            if check_undefined_name(node):
                return False
            fn = self.visit(node.func)
            if isinstance(fn, JITFunction) and fn.noinline is not True:
                # Temporarily adopt the callee module's globals so that names
                # inside the callee resolve against its own scope.
                old_gscope = self.gscope
                self.gscope = sys.modules[fn.fn.__module__].__dict__
                ret = self.contains_return_op(fn.parse())
                self.gscope = old_gscope
                return ret
            return False
        elif isinstance(node, ast.If):
            # Either branch containing a return makes the whole `if` return.
            pred = lambda s: self.contains_return_op(s)
            ret = any(pred(s) for s in node.body)
            if node.orelse:
                ret = ret or any(pred(s) for s in node.orelse)
            return ret
        elif isinstance(node, ast.IfExp):
            return self.contains_return_op(node.body) or self.contains_return_op(node.orelse)
        elif isinstance(node, ast.Expr):
            # Walk all child AST nodes of the expression statement.
            ret = False
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            ret = ret or self.contains_return_op(item)
                elif isinstance(value, ast.AST):
                    ret = ret or self.contains_return_op(value)
            return ret
        else:
            return False
def visit_Module(self, node):
ast.NodeVisitor.generic_visit(self, node)
@@ -354,7 +388,8 @@ class CodeGenerator(ast.NodeVisitor):
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
value = _unwrap_if_constexpr(value)
if not _is_triton_tensor(value) and \
if value is not None and \
not _is_triton_tensor(value) and \
not isinstance(value, native_nontensor_types):
value = language.core._to_tensor(value, self.builder)
self.set_value(name, value)
@@ -526,7 +561,11 @@ class CodeGenerator(ast.NodeVisitor):
cond = self.visit(node.test)
if _is_triton_tensor(cond):
cond = cond.to(language.int1, _builder=self.builder)
if self.scf_stack or not self.contains_return_op(node):
contains_return = ContainsReturnChecker(self.gscope).visit(node)
if self.scf_stack and contains_return:
raise UnsupportedLanguageConstruct(None, node,
"Cannot have `return` statements inside `while` or `for` statements in triton")
elif self.scf_stack or not contains_return:
self.visit_if_scf(cond, node)
else:
self.visit_if_top_level(cond, node)
@@ -825,7 +864,9 @@ class CodeGenerator(ast.NodeVisitor):
if not self.module.has_function(fn_name):
prototype = language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline)
# If the callee is not set, we use the same debug setting as the caller
debug = self.debug if fn.debug is None else fn.debug
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type

View File

@@ -55,7 +55,17 @@ def _to_tensor(x, builder):
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
return tensor(builder.get_fp32(x), float32)
min_float32 = 2 ** -126
max_float32 = (2 - 2**-23) * 2**127
abs_x = __builtins__['abs'](x)
if abs_x == float("inf") or\
abs_x == 0.0 or \
x != x or \
min_float32 <= abs_x <= max_float32:
return tensor(builder.get_fp32(x), float32)
else:
return tensor(builder.get_fp64(x), float64)
elif isinstance(x, constexpr):
return _to_tensor(x.value, builder)
elif isinstance(x, tensor):
@@ -701,7 +711,7 @@ class tensor:
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim, _builder)
elif sl == slice(None, None, None):
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
pass
else:
assert False, f"unsupported tensor index: {sl}"
@@ -1307,8 +1317,8 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
else:
handles = [r.handle for r in results]
_builder.create_reduce_ret(*handles)
axis = _constexpr_to_value(axis)
if axis is not None:
axis = _constexpr_to_value(axis)
return semantic.reduction(input, axis, make_combine_region, _builder)
@@ -1379,7 +1389,7 @@ def _max_combine(a, b):
@triton.jit
@_add_reduction_docstr("maximum")
def max(input, axis=None):
    # Deduplicated diff residue: the signature now defaults axis to None,
    # which reduces over all elements (backward compatible with axis=<int>).
    input = _promote_reduction_input(input)
    return reduce(input, axis, _max_combine)
@@ -1409,7 +1419,7 @@ def _min_combine(a, b):
@triton.jit
@_add_reduction_docstr("minimum")
def min(input, axis=None):
    # Deduplicated diff residue: the signature now defaults axis to None,
    # which reduces over all elements (backward compatible with axis=<int>).
    input = _promote_reduction_input(input)
    return reduce(input, axis, _min_combine)
@@ -1438,7 +1448,7 @@ def _sum_combine(a, b):
@triton.jit
@_add_reduction_docstr("sum")
def sum(input, axis=None):
    # Deduplicated diff residue: the signature now defaults axis to None,
    # which reduces over all elements (backward compatible with axis=<int>).
    input = _promote_reduction_input(input)
    return reduce(input, axis, _sum_combine)
@@ -1450,7 +1460,7 @@ def _xor_combine(a, b):
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None, _generator=None):
def xor_sum(input, axis=None, _builder=None, _generator=None):
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")
@@ -1461,12 +1471,15 @@ def xor_sum(input, axis, _builder=None, _generator=None):
# -----------------------
# Internal for debugging
# Compiler Hint Ops
# -----------------------
@builtin
def debug_barrier(_builder=None):
    '''
    Insert a barrier to synchronize all threads in a block.
    '''
    return semantic.debug_barrier(_builder)
@@ -1508,16 +1521,28 @@ def max_contiguous(input, values, _builder=None):
@builtin
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
    '''
    Print the values at compile time. The parameters are the same as the builtin :code:`print`.
    '''
    # Intentionally a no-op body: per the docstring this prints at compile
    # time, so the call is presumably intercepted by the compiler frontend —
    # confirm in the code generator.
    pass
@builtin
def static_assert(cond, msg="", _builder=None):
    '''
    Assert the condition at compile time. The parameters are the same as the builtin :code:`assert`.
    '''
    # Intentionally a no-op body: per the docstring the assertion is evaluated
    # during compilation, presumably by the compiler frontend — confirm in the
    # code generator.
    pass
@builtin
def device_print(prefix, *args, _builder=None):
'''
Print the values at runtime from the device.
:param prefix: a prefix to print before the values. This is required to be a string literal.
:param args: the values to print. They can be any tensor or scalar.
'''
import string
prefix = _constexpr_to_value(prefix)
assert isinstance(prefix, str), f"{prefix} is not string"
@@ -1535,6 +1560,12 @@ def device_print(prefix, *args, _builder=None):
@builtin
def device_assert(cond, msg="", _builder=None):
'''
Assert the condition at runtime from the device.
:param cond: the condition to assert. This is required to be a boolean tensor.
:param msg: the message to print if the assertion fails. This is required to be a string literal.
'''
msg = _constexpr_to_value(msg)
import inspect
frame = inspect.currentframe()
@@ -1560,7 +1591,22 @@ def device_assert(cond, msg="", _builder=None):
class static_range:
"""Iterator that counts upward forever."""
"""
Iterator that counts upward forever.
.. highlight:: python
.. code-block:: python
@triton.jit
def kernel(...):
for i in tl.static_range(10):
...
:note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
:code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
:param arg1: the start value.
:param arg2: the end value.
:param step: the step value.
"""
def __init__(self, arg1, arg2=None, step=None):
assert isinstance(arg1, constexpr)

View File

@@ -1280,6 +1280,13 @@ def where(condition: tl.tensor,
def reduction(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
if axis is None:
new_inputs = []
for i in range(len(inputs)):
new_shape = [inputs[i].numel.value]
new_inputs.append(view(inputs[i], new_shape, builder))
inputs = tuple(new_inputs)
axis = 0
# get result shape
shape = inputs[0].type.shape
ret_shape = [s for i, s in enumerate(shape) if i != axis]

View File

@@ -203,8 +203,7 @@ class _attention(torch.autograd.Function):
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# assert Lk in {16, 32, 64, 128}
assert Lk in {64} # TODO: fix other cases
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

View File

@@ -356,7 +356,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty

View File

@@ -40,29 +40,33 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
:type fast_flush: bool
"""
# Estimate the runtime of the function
fn()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()

View File

@@ -406,7 +406,7 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
// CHECK-LABEL: @store_constant_align
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%pid = tt.get_program_id {axis = 0 : i32} : i32
%pid = tt.get_program_id x : i32
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
%c128_i32 = arith.constant 128 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
@@ -438,7 +438,7 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
// CHECK-LABEL: @vecadd_mask_align_16
tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
@@ -467,7 +467,7 @@ tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
// CHECK-LABEL: @vecadd_mask_align_1
tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%3 = tt.splat %1 : (i32) -> tensor<64xi32>

View File

@@ -76,40 +76,64 @@ tt.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.return
}
// CHECK-LABEL: reduce_ops_infer
tt.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
// Test if reduce ops infer types correctly
// CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
// CHECK: tt.reduce
// CHECK-SAME: axis = 0
// CHECK: tt.reduce.return
// CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<2x4xf32>
%a = "tt.reduce" (%v) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
// CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
// CHECK: tt.reduce
// CHECK-SAME: axis = 1
// CHECK: tt.reduce.return
// CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x4xf32>
%b = "tt.reduce" (%v) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
// CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
// CHECK: tt.reduce
// CHECK-SAME: axis = 2
// CHECK: tt.reduce.return
// CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x2xf32>
%c = "tt.reduce" (%v) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
// CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
// CHECK: tt.reduce
// CHECK-SAME: axis = 1
// CHECK: tt.reduce.return
// CHECK-NEXT: (tensor<1x4xf32>) -> tensor<1xf32>
%e = "tt.reduce" (%b) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
// CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
// CHECK: tt.reduce
// CHECK-SAME: axis = 0
// CHECK: tt.reduce.return
// CHECK-NEXT: (tensor<2x4xf32>) -> tensor<4xf32>
%f = "tt.reduce" (%a) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
// CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32
// CHECK: tt.reduce
// CHECK-SAME: axis = 0
// CHECK: tt.reduce.return
// CHECK-NEXT: (tensor<4xf32>) -> f32
%g = "tt.reduce" (%f) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
@@ -154,3 +178,12 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
tt.store %ptr1x1, %r4 : tensor<1x1xf32>
tt.return
}
// CHECK-LABEL: @print_no_arg
tt.func @print_no_arg(%arg0: !tt.ptr<f32>) {
// CHECK: tt.print "test"
tt.print "test"
%0 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
tt.store %arg0, %0 {cache = 1 : i32, evict = 1 : i32} : f32
tt.return
}

View File

@@ -137,7 +137,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_no_vec
tt.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -229,7 +229,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
tt.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -306,7 +306,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
module attributes {"triton_gpu.num-warps" = 2 : i32} {
tt.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
@@ -340,7 +340,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec2
tt.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -461,7 +461,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
tt.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -643,8 +643,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
tt.func @basic_program_id() {
<<<<<<< HEAD
// PTX: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
=======
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id x : i32
>>>>>>> oai/main
tt.return
}
}
@@ -1528,6 +1533,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
<<<<<<< HEAD
%blockidx = tt.get_program_id {axis=0:i32} : i32
%blockidy = tt.get_program_id {axis=1:i32} : i32
%blockidz = tt.get_program_id {axis=2:i32} : i32
@@ -1537,6 +1543,14 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// GCN: rocdl.workgroup.id.x
// GCN: rocdl.workgroup.id.y
// GCN: rocdl.workgroup.id.z
=======
%blockidx = tt.get_program_id x : i32
%blockidy = tt.get_program_id y : i32
%blockidz = tt.get_program_id z : i32
// CHECK: nvvm.read.ptx.sreg.ctaid.x
// CHECK: nvvm.read.ptx.sreg.ctaid.y
// CHECK: nvvm.read.ptx.sreg.ctaid.z
>>>>>>> oai/main
%v0 = arith.addi %blockidx, %blockidy : i32
%v1 = arith.addi %v0, %blockidz : i32
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>

View File

@@ -10,8 +10,8 @@ tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.addi %arg3, %c127_i32 : i32
%3 = arith.divsi %2, %c128_i32 : i32
%4 = arith.addi %arg4, %c31_i32 : i32

View File

@@ -2,7 +2,7 @@
module {
tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%c256_i32 = arith.constant 256 : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
@@ -49,7 +49,7 @@ module {
// %c0 = arith.constant 0 : index
// %cst = arith.constant 0.000000e+00 : f32
// %c256_i32 = arith.constant 256 : i32
// %0 = tt.get_program_id {axis = 0 : i32} : i32
// %0 = tt.get_program_id x : i32
// %1 = arith.muli %0, %c256_i32 : i32
// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>

View File

@@ -86,7 +86,7 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout1>
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout1>
@@ -102,7 +102,7 @@ tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-LABEL: if_convert_else_not
tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
@@ -123,7 +123,7 @@ tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
// CHECK-LABEL: if_not_else_convert
tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
@@ -144,7 +144,7 @@ tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
// CHECK-LABEL: if_else_both_convert
tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0>
@@ -267,11 +267,63 @@ tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32,
tt.return
}
// CHECK-LABEL: loop_if
tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%i0 = arith.constant 0 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
%33 = "triton_gpu.cmpi"(%i0, %i0) {predicate = 4 : i64} : (i32, i32) -> i1
%34 = scf.if %33 -> (tensor<64x64xf32, #blocked1>) {
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
scf.yield %27 : tensor<64x64xf32, #blocked1>
} else {
scf.yield %arg6 : tensor<64x64xf32, #blocked1>
}
%28 = arith.addf %arg6, %34 : tensor<64x64xf32, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
}
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
tt.return
}
// CHECK-LABEL: vecadd
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
@@ -309,7 +361,7 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
%c0 = arith.constant 0 : index
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0>
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked0>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1x1xi32, #blocked1>
@@ -370,7 +422,7 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
%cst_12 = arith.constant dense<1> : tensor<1024xi32, #blocked0>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked0>
%cst_14 = arith.constant dense<0> : tensor<1024xi32, #blocked0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked0>
@@ -757,7 +809,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
%cst_2 = arith.constant dense<0xFF800000> : tensor<16x16xf32, #blocked2>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked2>
%cst_4 = arith.constant dense<0> : tensor<16x16xi32, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c16_i32 : i32
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0>
%3 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
@@ -856,7 +908,7 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
%cst_4 = arith.constant dense<2048> : tensor<64x1xi32, #blocked2>
%cst_5 = arith.constant dense<49152> : tensor<64x1xi32, #blocked2>
%cst_6 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
%3 = triton_gpu.convert_layout %2 : (tensor<64xi32, #blocked0>) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
@@ -992,7 +1044,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
%c-1_i64 = arith.constant -1 : i64
%cst = arith.constant 0.000000e+00 : f32
%c-1_i32 = arith.constant -1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
%2 = tt.load %1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
%3 = arith.cmpi eq, %2, %c-1_i64 : i64
@@ -1054,7 +1106,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// Check if the SimplifyReduceCvt handles convert_layout lifted from the for loop.
// CHECK-LABEL: reduce_cvt2
// Match the reduction
// CHECK: }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: tt.reduce
// CHECK-SAME: axis = 1
// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
@@ -1073,7 +1127,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%cst_3 = arith.constant dense<196> : tensor<1x256xi32, #blocked>
%cst_4 = arith.constant dense<3136> : tensor<1x256xi32, #blocked>
%cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1>
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>

View File

@@ -6,8 +6,10 @@
#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}>
#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#ALR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#ALC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}>
#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: tt.func @push_elementwise1
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
@@ -17,36 +19,106 @@
// CHECK: %[[C:.*]] = tt.dot %[[AF16]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise1(
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv2>
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
// Not modified for row-row
// CHECK: tt.func @push_elementwise2
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise2(
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BLR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLR>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLR>) -> tensor<16x16xf16, #Bv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
// Not modified for col-row
// CHECK: tt.func @push_elementwise3
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise3(
%pa: tensor<16x16x!tt.ptr<i8>, #ALC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BLR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALC>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLR>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALC> -> tensor<16x16xf8E5M2, #ALC>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALC> -> tensor<16x16xf16, #ALC>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALC>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLR>) -> tensor<16x16xf16, #Bv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
// Not modified for col-col
// CHECK: tt.func @push_elementwise4
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise4(
%pa: tensor<16x16x!tt.ptr<i8>, #ALC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALC>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALC> -> tensor<16x16xf8E5M2, #ALC>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALC> -> tensor<16x16xf16, #ALC>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALC>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
// Not modified for Volta
// CHECK: tt.func @push_elementwise5
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
tt.func @push_elementwise5(
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av1>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv1>
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av1>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv1>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
tt.return %newc : tensor<16x16xf32, #Cv1>
}

View File

@@ -12,7 +12,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32