mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'oai/main' into ifu230601
Conflicts: python/test/unit/language/assert_helper.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -222,6 +222,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
||||
${conversion_libs}
|
||||
|
||||
# optimizations
|
||||
MLIRBytecodeWriter
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRLLVMDialect
|
||||
|
||||
@@ -23,6 +23,8 @@ Creation Ops
|
||||
:nosignatures:
|
||||
|
||||
arange
|
||||
cat
|
||||
full
|
||||
zeros
|
||||
|
||||
|
||||
@@ -33,11 +35,13 @@ Shape Manipulation Ops
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
broadcast
|
||||
broadcast_to
|
||||
expand_dims
|
||||
reshape
|
||||
ravel
|
||||
|
||||
reshape
|
||||
trans
|
||||
view
|
||||
|
||||
|
||||
Linear Algebra Ops
|
||||
@@ -83,11 +87,13 @@ Math Ops
|
||||
abs
|
||||
exp
|
||||
log
|
||||
fdiv
|
||||
cos
|
||||
sin
|
||||
sqrt
|
||||
sigmoid
|
||||
softmax
|
||||
umulhi
|
||||
|
||||
|
||||
Reduction Ops
|
||||
@@ -151,4 +157,27 @@ Compiler Hint Ops
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
debug_barrier
|
||||
max_contiguous
|
||||
multiple_of
|
||||
|
||||
Debug Ops
|
||||
-----------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
static_print
|
||||
static_assert
|
||||
device_print
|
||||
device_assert
|
||||
|
||||
Iterators
|
||||
-----------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
static_range
|
||||
|
||||
@@ -52,4 +52,16 @@ def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// Program ID dimensions.
|
||||
def TT_ProgramDim : I32EnumAttr<
|
||||
"ProgramIDDim", "",
|
||||
[
|
||||
I32EnumAttrCase<"X", 0, "x">,
|
||||
I32EnumAttrCase<"Y", 1, "y">,
|
||||
I32EnumAttrCase<"Z", 2, "z">,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -37,6 +37,7 @@ def Triton_Dialect : Dialect {
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let usePropertiesForAttributes = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
|
||||
@@ -351,11 +351,17 @@ def TT_TransOp : TT_Op<"trans", [Pure,
|
||||
// SPMD Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
let arguments = (ins TT_ProgramDim:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
let assemblyFormat = "$axis attr-dict `:` type($result)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
int32_t getAxisAsInt() {
|
||||
return static_cast<int32_t>(getAxis());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
|
||||
@@ -479,7 +485,7 @@ def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
|
||||
format are generated automatically from the arguments.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$prefix attr-dict `:` ($args^ `:` type($args))?
|
||||
$prefix attr-dict (`:` $args^ `:` type($args))?
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
if(!mmaEnc)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
|
||||
// number of rows per phase
|
||||
|
||||
@@ -23,14 +23,16 @@ def TritonGPU_Dialect : Dialect {
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps");
|
||||
if(!numWarps)
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
return numWarps.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
}];
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let usePropertiesForAttributes = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -111,15 +111,15 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
}
|
||||
}
|
||||
} else if (Operation *op = value.getDefiningOp()) {
|
||||
if (Attribute attr = op->getAttr("tt.divisibility")) {
|
||||
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
|
||||
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
|
||||
knownDivisibility = DimVectorT(vals.begin(), vals.end());
|
||||
}
|
||||
if (Attribute attr = op->getAttr("tt.contiguity")) {
|
||||
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
|
||||
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
|
||||
knownContiguity = DimVectorT(vals.begin(), vals.end());
|
||||
}
|
||||
if (Attribute attr = op->getAttr("tt.constancy")) {
|
||||
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
|
||||
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
|
||||
knownConstancy = DimVectorT(vals.begin(), vals.end());
|
||||
}
|
||||
@@ -888,15 +888,15 @@ void AxisInfoAnalysis::visitOperation(
|
||||
auto newContiguity = curr.getContiguity();
|
||||
auto newDivisibility = curr.getDivisibility();
|
||||
auto newConstancy = curr.getConstancy();
|
||||
if (Attribute attr = op->getAttr("tt.contiguity")) {
|
||||
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
|
||||
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
|
||||
newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
|
||||
}
|
||||
if (Attribute attr = op->getAttr("tt.divisibility")) {
|
||||
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
|
||||
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
|
||||
newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
|
||||
}
|
||||
if (Attribute attr = op->getAttr("tt.constancy")) {
|
||||
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
|
||||
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
|
||||
newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
|
||||
}
|
||||
|
||||
@@ -87,6 +87,10 @@ public:
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
return lowerMmaToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
if (srcLayout.isa<SharedEncodingAttr>() &&
|
||||
isaDistributedLayout(dstLayout)) {
|
||||
return lowerSharedToDistributed(op, adaptor, rewriter);
|
||||
}
|
||||
// TODO: to be implemented
|
||||
llvm_unreachable("unsupported layout conversion");
|
||||
return failure();
|
||||
@@ -544,9 +548,40 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> types(outElems, llvmElemTy);
|
||||
auto *ctx = llvmElemTy.getContext();
|
||||
Type structTy = struct_ty(types);
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
Value src = op.getSrc();
|
||||
Value dst = op.getResult();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
assert(dstShape.size() == 2 &&
|
||||
"Unexpected rank of ConvertLayout(shared->blocked)");
|
||||
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto inOrd = getOrder(srcSharedLayout);
|
||||
|
||||
auto smemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
|
||||
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
||||
|
||||
auto srcStrides =
|
||||
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
|
||||
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy);
|
||||
|
||||
SmallVector<Value> outVals = loadSharedToDistributed(
|
||||
dst, dstIndices, src, smemObj, elemTy, loc, rewriter);
|
||||
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
@@ -303,6 +303,8 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
|
||||
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
|
||||
} else {
|
||||
if (needTrans && (4 / elemBytes) != kWidth)
|
||||
llvm_unreachable("unimplemented Shared -> DotOperandMmav2 code path");
|
||||
// base pointers
|
||||
std::array<std::array<Value, 4>, 2> ptrs;
|
||||
int vecWidth = 4 / elemBytes;
|
||||
@@ -329,8 +331,6 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
// row + trans and col + no-trans are equivalent
|
||||
bool isActualTrans =
|
||||
(needTrans && kOrder == 1) || (!needTrans && kOrder == 0);
|
||||
if (isActualTrans)
|
||||
std::swap(vptrs[1], vptrs[2]);
|
||||
// pack loaded vectors into 4 32-bit values
|
||||
int inc = needTrans ? 1 : kWidth;
|
||||
VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc);
|
||||
@@ -348,12 +348,15 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
Value val = load(ptr);
|
||||
Value canonval = bitcast(val, vec_ty(canonInt, canonWidth));
|
||||
for (int w = 0; w < canonWidth; ++w) {
|
||||
retElems[idx + w * kWidth / vecWidth] =
|
||||
insert_element(retElems[idx + w * kWidth / vecWidth],
|
||||
int ridx = idx + w * kWidth / vecWidth;
|
||||
retElems[ridx] =
|
||||
insert_element(retElems[ridx],
|
||||
extract_element(canonval, i32_val(w)), i32_val(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isActualTrans)
|
||||
std::swap(retElems[1], retElems[2]);
|
||||
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
|
||||
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
|
||||
}
|
||||
|
||||
@@ -440,10 +440,10 @@ struct GetProgramIdOpConversion
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxis()]);
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -119,9 +119,8 @@ protected:
|
||||
// Create an LLVM function, use external linkage by default until MLIR
|
||||
// functions have linkage.
|
||||
LLVM::Linkage linkage = LLVM::Linkage::External;
|
||||
if (funcOp->hasAttr("llvm.linkage")) {
|
||||
auto attr =
|
||||
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
|
||||
if (auto linkageAttr = funcOp->getDiscardableAttr("llvm.linkage")) {
|
||||
auto attr = linkageAttr.dyn_cast<mlir::LLVM::LinkageAttr>();
|
||||
if (!attr) {
|
||||
funcOp->emitError()
|
||||
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
|
||||
@@ -360,6 +359,55 @@ public:
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
|
||||
Value src, SharedMemoryObject smemObj, Type elemTy,
|
||||
Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
assert(dstShape.size() == 2 &&
|
||||
"Unexpected rank of loadSharedToDistributed");
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstDistributedLayout = dstTy.getEncoding();
|
||||
if (auto mmaLayout = dstDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert((!mmaLayout.isVolta()) &&
|
||||
"ConvertLayout Shared->MMAv1 is not supported yet");
|
||||
}
|
||||
auto srcSharedLayout =
|
||||
srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto srcElemTy = srcTy.getElementType();
|
||||
auto dstElemTy = dstTy.getElementType();
|
||||
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
|
||||
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
|
||||
unsigned outVec =
|
||||
inOrd == outOrd
|
||||
? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]]
|
||||
: 1;
|
||||
unsigned inVec = srcSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
|
||||
assert(outElems == dstIndices.size());
|
||||
|
||||
DenseMap<unsigned, Value> sharedPtrs = getSwizzledSharedPtrs(
|
||||
loc, outVec, dstTy, srcSharedLayout, srcElemTy, smemObj, rewriter,
|
||||
smemObj.offsets, smemObj.strides);
|
||||
assert(outElems % minVec == 0 && "Unexpected number of elements");
|
||||
unsigned numVecs = outElems / minVec;
|
||||
auto wordTy = vec_ty(elemTy, minVec);
|
||||
SmallVector<Value> outVals(outElems);
|
||||
for (unsigned i = 0; i < numVecs; ++i) {
|
||||
Value smemAddr = sharedPtrs[i * minVec];
|
||||
smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
|
||||
Value valVec = load(smemAddr);
|
||||
for (unsigned v = 0; v < minVec; ++v) {
|
||||
Value currVal = extract_element(dstElemTy, valVec, i32_val(v));
|
||||
outVals[i * minVec + v] = currVal;
|
||||
}
|
||||
}
|
||||
return outVals;
|
||||
}
|
||||
|
||||
void storeDistributedToShared(Value src, Value llSrc,
|
||||
ArrayRef<Value> dstStrides,
|
||||
ArrayRef<SmallVector<Value>> srcIndices,
|
||||
@@ -387,16 +435,11 @@ public:
|
||||
: 1;
|
||||
unsigned outVec = dstSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned perPhase = dstSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = dstSharedLayout.getMaxPhase();
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
assert(numElems == srcIndices.size());
|
||||
auto inVals =
|
||||
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
|
||||
auto wordTy = vec_ty(elemTy, minVec);
|
||||
auto elemPtrTy = ptr_ty(elemTy);
|
||||
Value outVecVal = i32_val(outVec);
|
||||
Value minVecVal = i32_val(minVec);
|
||||
Value word;
|
||||
|
||||
SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};
|
||||
|
||||
@@ -310,8 +310,7 @@ public:
|
||||
// Preprocess
|
||||
decomposeMmaToDotOperand(mod, numWarps);
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
if (failed(decomposeInsertSliceAsyncOp(mod)))
|
||||
return signalPassFailure();
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
|
||||
// Allocate shared memory and set barrier
|
||||
ModuleAllocation allocation(mod);
|
||||
@@ -490,7 +489,7 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
|
||||
// TODO(Keren): This is a hacky knob that may cause performance regression
|
||||
// when decomposition has been performed. We should remove this knob once we
|
||||
@@ -518,6 +517,7 @@ private:
|
||||
// Get the vectorized load size
|
||||
auto src = insertSliceAsyncOp.getSrc();
|
||||
auto dst = insertSliceAsyncOp.getDst();
|
||||
auto mask = insertSliceAsyncOp.getMask();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
@@ -526,6 +526,9 @@ private:
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
|
||||
if (mask)
|
||||
inVec =
|
||||
std::min<unsigned>(axisInfoAnalysis.getMaskAlignment(mask), inVec);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
@@ -597,7 +600,6 @@ private:
|
||||
}
|
||||
#endif
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -796,7 +796,6 @@ public:
|
||||
if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(),
|
||||
*converter)))
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -446,10 +446,11 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
Properties *prop = properties.as<Properties *>();
|
||||
int axis = prop->axis.getInt();
|
||||
for (auto arg : operands) {
|
||||
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||
auto retEltTy = argTy.getElementType();
|
||||
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
|
||||
.failed()) {
|
||||
return failure();
|
||||
@@ -557,7 +558,8 @@ mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
|
||||
auto arg = operands[0];
|
||||
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||
auto retShape = argTy.getShape().vec();
|
||||
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
Properties *prop = properties.as<Properties *>();
|
||||
int axis = prop->axis.getInt();
|
||||
retShape.insert(retShape.begin() + axis, 1);
|
||||
// infer encoding
|
||||
Attribute argEncoding = argTy.getEncoding();
|
||||
@@ -740,7 +742,7 @@ void triton::FuncOp::print(OpAsmPrinter &printer) {
|
||||
LogicalResult
|
||||
triton::CallOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
|
||||
// Check that the callee attribute was specified.
|
||||
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||
auto fnAttr = (*this).getProperties().callee;
|
||||
if (!fnAttr)
|
||||
return emitOpError("requires a 'callee' symbol reference attribute");
|
||||
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
|
||||
|
||||
@@ -55,10 +55,6 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
bool changed = false;
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||
// seems buggy for shape = [32, 16] ?
|
||||
do {
|
||||
changed = false;
|
||||
if (ret[0] * ret[1] >= numWarps)
|
||||
|
||||
@@ -81,15 +81,18 @@ public:
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
static mlir::LogicalResult
|
||||
isBlockedToDotOperand(mlir::Operation *op,
|
||||
triton::gpu::DotOperandEncodingAttr &retEncoding,
|
||||
triton::gpu::BlockedEncodingAttr &srcEncoding) {
|
||||
if (!op)
|
||||
return failure();
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcTy = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto retTy = cvt.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retEncoding =
|
||||
retTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto srcEncoding =
|
||||
srcEncoding =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
if (!retTy)
|
||||
return failure();
|
||||
@@ -101,6 +104,51 @@ public:
|
||||
return failure();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isTrans(const triton::gpu::DotOperandEncodingAttr &retEncoding,
|
||||
const triton::gpu::BlockedEncodingAttr &srcEncoding) {
|
||||
int kOrder = retEncoding.getOpIdx() ^ 1;
|
||||
return kOrder != srcEncoding.getOrder()[0];
|
||||
}
|
||||
|
||||
static bool isDotNT(triton::DotOp dotOp) {
|
||||
triton::gpu::DotOperandEncodingAttr aRetEncoding;
|
||||
triton::gpu::DotOperandEncodingAttr bRetEncoding;
|
||||
triton::gpu::BlockedEncodingAttr aSrcEncoding;
|
||||
triton::gpu::BlockedEncodingAttr bSrcEncoding;
|
||||
if (isBlockedToDotOperand(dotOp.getOperand(0).getDefiningOp(), aRetEncoding,
|
||||
aSrcEncoding)
|
||||
.failed())
|
||||
return false;
|
||||
if (isBlockedToDotOperand(dotOp.getOperand(1).getDefiningOp(), bRetEncoding,
|
||||
bSrcEncoding)
|
||||
.failed())
|
||||
return false;
|
||||
if (!aRetEncoding || !bRetEncoding || !aSrcEncoding || !bSrcEncoding)
|
||||
return false;
|
||||
return !isTrans(aRetEncoding, aSrcEncoding) &&
|
||||
!isTrans(bRetEncoding, bSrcEncoding);
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
triton::gpu::DotOperandEncodingAttr retEncoding;
|
||||
triton::gpu::BlockedEncodingAttr srcEncoding;
|
||||
if (isBlockedToDotOperand(op, retEncoding, srcEncoding).failed())
|
||||
return mlir::failure();
|
||||
|
||||
// only supports dot NT
|
||||
auto users = cvt->getUsers();
|
||||
auto dotOp = dyn_cast_or_null<DotOp>(*users.begin());
|
||||
if (!dotOp)
|
||||
return failure();
|
||||
if (!isDotNT(dotOp))
|
||||
return failure();
|
||||
|
||||
// don't move things around when cvt operand is a block arg
|
||||
Operation *argOp = cvt.getOperand().getDefiningOp();
|
||||
if (!argOp)
|
||||
@@ -129,8 +177,8 @@ public:
|
||||
return failure();
|
||||
// we don't want to use ldmatrix for 8-bit data that requires trans
|
||||
// since Nvidia GPUs can't do it efficiently
|
||||
bool isTrans =
|
||||
(retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0);
|
||||
int kOrder = retEncoding.getOpIdx() ^ 1;
|
||||
bool isTrans = kOrder != srcEncoding.getOrder()[0];
|
||||
bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8;
|
||||
if (isTrans && isInt8)
|
||||
return failure();
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
@@ -17,6 +18,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using llvm::MapVector;
|
||||
using namespace mlir;
|
||||
namespace ttg = triton::gpu;
|
||||
|
||||
@@ -25,9 +27,19 @@ namespace ttg = triton::gpu;
|
||||
|
||||
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
|
||||
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
|
||||
for (const NamedAttribute attr : dictAttrs.getValue())
|
||||
if (!op->hasAttr(attr.getName()))
|
||||
op->setAttr(attr.getName(), attr.getValue());
|
||||
NamedAttrList attrs = op->getDiscardableAttrs();
|
||||
// Collect the attributes to propagate: the ones in dictAttrs and not yet on
|
||||
// the operation.
|
||||
SmallVector<NamedAttribute> toPropagate;
|
||||
for (const NamedAttribute attr : dictAttrs.getValue()) {
|
||||
if (!attrs.get(attr.getName()))
|
||||
toPropagate.push_back(attr);
|
||||
}
|
||||
// If we found any, let's set them here as a single step.
|
||||
if (toPropagate.size()) {
|
||||
attrs.append(toPropagate);
|
||||
op->setDiscardableAttrs(attrs);
|
||||
}
|
||||
}
|
||||
|
||||
#define int_attr(num) builder.getI64IntegerAttr(num)
|
||||
@@ -69,31 +81,46 @@ class LoopPipeliner {
|
||||
/// value (in loop) => value at stage N
|
||||
DenseMap<Value, SmallVector<Value>> valueMapping;
|
||||
|
||||
/// Block arguments that loads depend on
|
||||
SetVector<BlockArgument> depArgs;
|
||||
|
||||
/// For each argument, we need to record at which stage it is defined.
|
||||
/// If we have a load that immediately depends on a block argument in the
|
||||
/// current iteration, it is an immediate dependency. Otherwise, it is a
|
||||
/// non-immediate dependency, which means the load depends on a block argument
|
||||
/// in the previous iterations.
|
||||
/// For example:
|
||||
/// scf.for (%arg0, %arg1, %arg2) {
|
||||
/// %0 = load %arg0 <--- immediate dep, this address is initialized at
|
||||
/// numStages-2
|
||||
/// %0 = load %arg0 <--- immediate dep, this address is initialized before
|
||||
/// numStages-1
|
||||
/// %1 = load %arg1
|
||||
/// %2 = add %1, %arg2
|
||||
/// %3 = load %2 <--- non-immediate dep, %arg1 must be an update-to-date
|
||||
/// value
|
||||
/// }
|
||||
SetVector<BlockArgument> immedidateDepArgs;
|
||||
/// Collect values that v depends on and are defined inside the loop
|
||||
LogicalResult collectDeps(Value v, int stage,
|
||||
MapVector<Value, int> &depStage);
|
||||
|
||||
SetVector<BlockArgument> nonImmedidateDepArgs;
|
||||
/// Associate each variable with a unique stage. If a variable is defined
|
||||
/// at multiple stages, we don't pipeline it.
|
||||
LogicalResult addDep(Value v, int stage, MapVector<Value, int> &depStage);
|
||||
|
||||
int getArgDefStage(Value v, int stage);
|
||||
|
||||
/// Block arguments that loads depend on
|
||||
MapVector<BlockArgument, int> depArgUseStage;
|
||||
|
||||
/// Block arguments that loads depend on (defined in the loop body)
|
||||
MapVector<BlockArgument, int> depArgDefStage;
|
||||
|
||||
/// Operations (inside the loop body) that loads depend on
|
||||
SetVector<Operation *> depOps;
|
||||
MapVector<Operation *, int> depOpDefStage;
|
||||
|
||||
/// collect values that v depends on and are defined inside the loop
|
||||
void collectDeps(Value v, int stages, SetVector<Value> &deps);
|
||||
/// Operations (inside the loop body) that loads depend on (defined in the
|
||||
/// loop body)
|
||||
SetVector<BlockArgument> immediateDepArgs;
|
||||
|
||||
/// Operations (inside the loop body) that loads depend on (defined in the
|
||||
/// previous iterations)
|
||||
SetVector<BlockArgument> nonImmediateDepArgs;
|
||||
|
||||
void setValueMapping(Value origin, Value newValue, int stage);
|
||||
|
||||
@@ -140,31 +167,60 @@ Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
|
||||
return valueMapping[origin][stage];
|
||||
}
|
||||
|
||||
void LoopPipeliner::collectDeps(Value v, int stages, SetVector<Value> &deps) {
|
||||
LogicalResult LoopPipeliner::addDep(Value v, int stage,
|
||||
MapVector<Value, int> &depStage) {
|
||||
if (!depStage.contains(v)) {
|
||||
depStage.insert(std::make_pair(v, stage));
|
||||
} else if (depStage[v] != stage)
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LoopPipeliner::collectDeps(Value v, int stage,
|
||||
MapVector<Value, int> &depStage) {
|
||||
// Loop-invariant value, skip
|
||||
if (v.getParentRegion() != &forOp.getLoopBody())
|
||||
return;
|
||||
return success();
|
||||
|
||||
// Since we only need to peel the loop numStages-1 times, don't worry about
|
||||
// depends that are too far away
|
||||
if (stages < 0)
|
||||
return;
|
||||
if (stage < 0)
|
||||
return success();
|
||||
|
||||
if (auto arg = v.dyn_cast<BlockArgument>()) {
|
||||
// Skip the first arg (loop induction variable)
|
||||
// Otherwise the op idx is arg.getArgNumber()-1
|
||||
if (arg.getArgNumber() > 0) {
|
||||
// Skip the first arg (loop induction variable)
|
||||
// Otherwise the op idx is arg.getArgNumber()-1
|
||||
deps.insert(v);
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
|
||||
deps);
|
||||
// If we've found the first definition of this arg, we're done, don't
|
||||
// recurse
|
||||
if (addDep(v, stage, depStage).succeeded())
|
||||
if (collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1,
|
||||
depStage)
|
||||
.failed())
|
||||
return failure();
|
||||
}
|
||||
} else { // value
|
||||
// v might be in deps, but we still need to visit v.
|
||||
// This is because v might depend on value in previous iterations
|
||||
deps.insert(v);
|
||||
// An operation cannot be dependent on different stages
|
||||
if (addDep(v, stage, depStage).failed())
|
||||
return failure();
|
||||
for (Value op : v.getDefiningOp()->getOperands())
|
||||
collectDeps(op, stages, deps);
|
||||
if (collectDeps(op, stage, depStage).failed())
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
int LoopPipeliner::getArgDefStage(Value v, int stage) {
|
||||
if (stage < 0)
|
||||
return -1;
|
||||
if (auto arg = v.dyn_cast<BlockArgument>()) {
|
||||
if (arg.getArgNumber() > 0) {
|
||||
return getArgDefStage(yieldOp->getOperand(arg.getArgNumber() - 1),
|
||||
stage - 1);
|
||||
}
|
||||
llvm_unreachable("Loop induction variable should not be a dependency");
|
||||
} else
|
||||
return stage;
|
||||
}
|
||||
|
||||
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
@@ -188,7 +244,8 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
|
||||
|
||||
// can we use forOp.walk(...) here?
|
||||
// We cannot use forOp.walk(...) here because we only want to visit the
|
||||
// operations in the loop body block. Nested blocks are handled separately.
|
||||
SmallVector<triton::LoadOp, 2> validLoads;
|
||||
for (Operation &op : *loop)
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
|
||||
@@ -205,7 +262,11 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
.cast<triton::PointerType>()
|
||||
.getPointeeType();
|
||||
unsigned width = vec * ty.getIntOrFloatBitWidth();
|
||||
// cp.async's cp-size can only be 4, 8 and 16.
|
||||
// We do not pipeline all loads for the following reasons:
|
||||
// 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
|
||||
// 2. It's likely that pipling small load won't offer much performance
|
||||
// improvement and may even hurt performance by increasing register
|
||||
// pressure.
|
||||
if (width >= 32)
|
||||
validLoads.push_back(loadOp);
|
||||
}
|
||||
@@ -215,12 +276,28 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
return failure();
|
||||
|
||||
// load => values that it depends on
|
||||
// Don't pipeline if any load's operands
|
||||
DenseMap<Value, SetVector<Value>> loadDeps;
|
||||
MapVector<Value, int> depStage;
|
||||
for (triton::LoadOp loadOp : validLoads) {
|
||||
SetVector<Value> deps;
|
||||
for (Value op : loadOp->getOperands())
|
||||
collectDeps(op, numStages - 1, deps);
|
||||
loadDeps[loadOp] = deps;
|
||||
for (Value op : loadOp->getOperands()) {
|
||||
MapVector<Value, int> operandDepStage;
|
||||
if (collectDeps(op, numStages - 1, operandDepStage).failed())
|
||||
return failure();
|
||||
for (auto [v, stage] : operandDepStage) {
|
||||
auto immedidate = operandDepStage.front().first.isa<BlockArgument>();
|
||||
if (v.isa<BlockArgument>()) {
|
||||
auto arg = v.cast<BlockArgument>();
|
||||
if (immedidate)
|
||||
immediateDepArgs.insert(arg);
|
||||
else
|
||||
nonImmediateDepArgs.insert(arg);
|
||||
}
|
||||
loadDeps[loadOp].insert(v);
|
||||
if (addDep(v, stage, depStage).failed())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Don't pipeline valid loads that depend on other valid loads
|
||||
@@ -268,9 +345,7 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
continue;
|
||||
isCandidate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
|
||||
else
|
||||
} else
|
||||
isCandidate = false;
|
||||
|
||||
if (isCandidate)
|
||||
@@ -306,7 +381,7 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
|
||||
ttg::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
|
||||
loadsBufferType[loadOp] =
|
||||
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
@@ -314,27 +389,19 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
// We have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// Update depArgs & depOps
|
||||
for (Value loadOp : loads) {
|
||||
auto &deps = loadDeps[loadOp];
|
||||
for (auto &dep : deps) {
|
||||
if (auto arg = dep.dyn_cast<BlockArgument>()) {
|
||||
depArgs.insert(arg);
|
||||
if (deps.front().isa<BlockArgument>()) {
|
||||
immedidateDepArgs.insert(arg);
|
||||
} else {
|
||||
nonImmedidateDepArgs.insert(arg);
|
||||
}
|
||||
} else
|
||||
depOps.insert(dep.getDefiningOp());
|
||||
}
|
||||
for (auto [dep, stage] : depStage) {
|
||||
if (auto arg = dep.dyn_cast<BlockArgument>())
|
||||
depArgUseStage.insert({arg, stage});
|
||||
else
|
||||
depOpDefStage.insert({dep.getDefiningOp(), stage});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Check if immedidateDepArgs and nonImmedidateDepArgs are disjoint
|
||||
// If yes, we cannot pipeline the loop for now
|
||||
for (BlockArgument arg : immedidateDepArgs)
|
||||
if (nonImmedidateDepArgs.contains(arg)) {
|
||||
for (BlockArgument arg : immediateDepArgs)
|
||||
if (nonImmediateDepArgs.contains(arg)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -365,12 +432,13 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
|
||||
|
||||
void LoopPipeliner::emitPrologue() {
|
||||
OpBuilder builder(forOp);
|
||||
// Get init operands for loop carried values
|
||||
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
||||
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
||||
setValueMapping(arg, operand.get(), 0);
|
||||
}
|
||||
|
||||
// prologue from [0, numStage-1)
|
||||
// Emit prologue from [0, numStage-1)
|
||||
Value iv = forOp.getLowerBound();
|
||||
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
for (int stage = 0; stage < numStages - 1; ++stage) {
|
||||
@@ -386,12 +454,12 @@ void LoopPipeliner::emitPrologue() {
|
||||
// Rematerialize peeled values
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
if (depOpDefStage.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
|
||||
orderedDeps.push_back(&op);
|
||||
}
|
||||
assert(depOps.size() + loads.size() == orderedDeps.size() &&
|
||||
assert(depOpDefStage.size() + loads.size() == orderedDeps.size() &&
|
||||
"depOps contains invalid values");
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *newOp = nullptr;
|
||||
@@ -406,14 +474,13 @@ void LoopPipeliner::emitPrologue() {
|
||||
Value newMask =
|
||||
getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage),
|
||||
loopCond, builder);
|
||||
// TODO: check if the hardware supports async copy
|
||||
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
||||
newOp = builder.create<ttg::InsertSliceAsyncOp>(
|
||||
op->getLoc(), loadsBuffer[loadOp].getType(),
|
||||
lookupOrDefault(loadOp.getPtr(), stage),
|
||||
loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
|
||||
lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(),
|
||||
loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0);
|
||||
builder.create<triton::gpu::AsyncCommitGroupOp>(op->getLoc());
|
||||
builder.create<ttg::AsyncCommitGroupOp>(op->getLoc());
|
||||
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
|
||||
} else
|
||||
llvm_unreachable("This should be LoadOp");
|
||||
@@ -428,10 +495,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
lookupOrDefault(loadOp.getOther(), stage),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
addNamedAttrs(newOp, op->getAttrDictionary());
|
||||
} else {
|
||||
addNamedAttrs(newOp, op->getDiscardableAttrDictionary());
|
||||
} else
|
||||
newOp = builder.clone(*op);
|
||||
}
|
||||
// Update loop-carried uses
|
||||
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
|
||||
auto it = valueMapping.find(op->getOperand(opIdx));
|
||||
@@ -443,32 +509,40 @@ void LoopPipeliner::emitPrologue() {
|
||||
}
|
||||
}
|
||||
|
||||
// Update mapping of results
|
||||
// if (stage == numStages - 2)
|
||||
// continue;
|
||||
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
Value originalResult = op->getResult(dstIdx);
|
||||
Value originResult = op->getResult(dstIdx);
|
||||
// copy_async will update the value of its only use
|
||||
// TODO: load should not be used in the preheader?
|
||||
if (loads.contains(originalResult)) {
|
||||
if (loads.contains(originResult))
|
||||
break;
|
||||
// originalResult = loadsMapping[originalResult];
|
||||
}
|
||||
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
|
||||
setValueMapping(originResult, newOp->getResult(dstIdx), stage);
|
||||
// update mapping for loop-carried values (args)
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx))
|
||||
setValueMapping(
|
||||
forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
if (operand.get() == op->getResult(dstIdx)) {
|
||||
auto yieldIdx = operand.getOperandNumber();
|
||||
auto value = forOp.getRegionIterArgs()[yieldIdx];
|
||||
setValueMapping(value, newOp->getResult(dstIdx), stage + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for (Operation *op : orderedDeps)
|
||||
|
||||
// Update pipeline index
|
||||
pipelineIterIdx = builder.create<arith::AddIOp>(
|
||||
iv.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
||||
|
||||
// Some values have not been used by any ops in the loop body
|
||||
for (BlockArgument arg : forOp.getRegionIterArgs()) {
|
||||
// Check if arg has a yieldOp use
|
||||
for (OpOperand &operand : arg.getUses()) {
|
||||
if (operand.getOwner() == yieldOp) {
|
||||
auto yieldIdx = operand.getOperandNumber();
|
||||
auto value = forOp.getRegionIterArgs()[yieldIdx];
|
||||
if (!valueMapping[value][stage + 1])
|
||||
setValueMapping(value, valueMapping[arg][stage], stage + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
||||
|
||||
// async.wait & extract_slice
|
||||
@@ -484,7 +558,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]},
|
||||
sliceType.getElementType(),
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
|
||||
Value extractSlice = builder.create<ttg::ExtractSliceOp>(
|
||||
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
||||
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
|
||||
SmallVector<OpFoldResult>{int_attr(1),
|
||||
@@ -505,7 +579,7 @@ void LoopPipeliner::emitEpilogue() {
|
||||
OpBuilder builder(forOp);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
|
||||
builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
|
||||
}
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
@@ -537,22 +611,24 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
newLoopArgs.push_back(loadsExtract[loadOp]);
|
||||
|
||||
size_t depArgsBeginIdx = newLoopArgs.size();
|
||||
for (BlockArgument depArg : depArgs) {
|
||||
for (auto [depArg, useStage] : depArgUseStage) {
|
||||
depArgsIdx[depArg] = newLoopArgs.size();
|
||||
if (immedidateDepArgs.contains(depArg)) {
|
||||
auto defStage = getArgDefStage(depArg, useStage);
|
||||
assert(defStage >= 0 &&
|
||||
"newLoopArgs has null args without a define op. Consider either "
|
||||
"rewrite the loop to reduce cross iteration dependencies or "
|
||||
"increase the num_stages value.");
|
||||
if (immediateDepArgs.contains(depArg) && defStage == numStages - 2) {
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages - 2]);
|
||||
} else
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
|
||||
}
|
||||
|
||||
size_t nextIVIdx = newLoopArgs.size();
|
||||
size_t ivIndex = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
|
||||
newLoopArgs.push_back(pipelineIterIdx);
|
||||
newLoopArgs.push_back(loopIterIdx);
|
||||
|
||||
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
||||
assert(newLoopArgs[i]);
|
||||
|
||||
// 1. signature of the new ForOp
|
||||
auto newForOp = builder.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
@@ -565,7 +641,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
|
||||
|
||||
// 2. clone the loop body, replace original args with args of the new ForOp
|
||||
// 3. clone the loop body, replace original args with args of the new ForOp
|
||||
// Insert async wait if necessary.
|
||||
DenseSet<Value> isModified;
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
@@ -597,50 +673,54 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
isModified.insert(op.getResult(0));
|
||||
}
|
||||
|
||||
// 3. prefetch the next iteration
|
||||
// 4. prefetch the next iteration
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
if (depOpDefStage.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
|
||||
orderedDeps.push_back(&op);
|
||||
}
|
||||
assert(depOps.size() + loads.size() == orderedDeps.size() &&
|
||||
assert(depOpDefStage.size() + loads.size() == orderedDeps.size() &&
|
||||
"depOps contains invalid values");
|
||||
IRMapping nextMapping;
|
||||
DenseMap<BlockArgument, Value> depArgsMapping;
|
||||
size_t argIdx = 0;
|
||||
for (BlockArgument arg : depArgs) {
|
||||
for (auto [depArg, useStage] : depArgUseStage) {
|
||||
BlockArgument nextArg =
|
||||
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx];
|
||||
nextMapping.map(arg, nextArg);
|
||||
nextMapping.map(depArg, nextArg);
|
||||
++argIdx;
|
||||
}
|
||||
|
||||
// Special handling for iv & loop condition
|
||||
Value curIV = newForOp.getRegionIterArgs()[ivIndex];
|
||||
Value nextIV = builder.create<arith::AddIOp>(
|
||||
newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
|
||||
newForOp.getInductionVar().getLoc(), curIV, newForOp.getStep());
|
||||
Value nextLoopCond =
|
||||
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
nextMapping.map(forOp.getInductionVar(), nextIV);
|
||||
|
||||
// Slice index
|
||||
SmallVector<Value> nextBuffers;
|
||||
SmallVector<Value> extractSlices;
|
||||
|
||||
pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
|
||||
pipelineIterIdx = newForOp.getRegionIterArgs()[ivIndex + 1];
|
||||
Value insertSliceIndex = builder.create<arith::RemSIOp>(
|
||||
nextIV.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
||||
loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
|
||||
loopIterIdx = newForOp.getRegionIterArgs()[ivIndex + 2];
|
||||
Value extractSliceIndex = builder.create<arith::RemSIOp>(
|
||||
nextIV.getLoc(), loopIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
||||
|
||||
// Prefetch load deps
|
||||
for (Operation *op : orderedDeps)
|
||||
if (!loads.contains(op->getResult(0))) {
|
||||
if (depOpDefStage[op] == numStages - 2)
|
||||
nextMapping.map(forOp.getInductionVar(), curIV);
|
||||
else
|
||||
nextMapping.map(forOp.getInductionVar(), nextIV);
|
||||
Operation *nextOp;
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
|
||||
auto newMask =
|
||||
@@ -652,20 +732,20 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
nextMapping.lookupOrDefault(loadOp.getOther()),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
addNamedAttrs(nextOp, op->getAttrDictionary());
|
||||
addNamedAttrs(nextOp, op->getDiscardableAttrDictionary());
|
||||
nextMapping.map(loadOp.getResult(), nextOp->getResult(0));
|
||||
} else {
|
||||
nextOp = builder.clone(*op, nextMapping);
|
||||
}
|
||||
|
||||
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
for (OpOperand &operand : originYield->getOpOperands()) {
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx)) {
|
||||
size_t originIdx = operand.getOperandNumber();
|
||||
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
|
||||
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
|
||||
nextMapping.map(forOp.getRegionIterArgs()[originIdx],
|
||||
size_t yieldIdx = operand.getOperandNumber();
|
||||
size_t depYieldIdx =
|
||||
depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]];
|
||||
BlockArgument newArg = newForOp.getRegionIterArgs()[depYieldIdx];
|
||||
nextMapping.map(forOp.getRegionIterArgs()[yieldIdx],
|
||||
nextOp->getResult(dstIdx));
|
||||
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
|
||||
}
|
||||
@@ -673,6 +753,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
}
|
||||
}
|
||||
|
||||
// loads -> async loads
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// Update loading mask
|
||||
@@ -689,14 +770,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
nextMapping.map(loadOp.getMask(), newMask);
|
||||
newMask = nextMapping.lookupOrDefault(mask);
|
||||
}
|
||||
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
||||
Value insertAsyncOp = builder.create<ttg::InsertSliceAsyncOp>(
|
||||
op->getLoc(), loadsBuffer[loadOp].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.getPtr()),
|
||||
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
|
||||
insertSliceIndex, newMask,
|
||||
nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(),
|
||||
loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0);
|
||||
builder.create<triton::gpu::AsyncCommitGroupOp>(op->getLoc());
|
||||
builder.create<ttg::AsyncCommitGroupOp>(op->getLoc());
|
||||
nextBuffers.push_back(insertAsyncOp);
|
||||
// ExtractSlice
|
||||
auto bufferType = insertAsyncOp.getType().cast<RankedTensorType>();
|
||||
@@ -706,7 +787,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
sliceType.getElementType(),
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
|
||||
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
|
||||
nextOp = builder.create<ttg::ExtractSliceOp>(
|
||||
op->getLoc(), sliceType, insertAsyncOp,
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
|
||||
int_attr(0)},
|
||||
@@ -720,12 +801,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
|
||||
// If this is a loop-carried value, update the mapping for yield
|
||||
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
for (OpOperand &operand : originYield->getOpOperands()) {
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx)) {
|
||||
size_t originIdx = operand.getOperandNumber();
|
||||
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
|
||||
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
|
||||
auto yieldIdx = operand.getOperandNumber();
|
||||
auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]];
|
||||
auto newArg = newForOp.getRegionIterArgs()[depYieldIdx];
|
||||
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
|
||||
}
|
||||
}
|
||||
@@ -733,6 +813,22 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
}
|
||||
}
|
||||
|
||||
// Some values have not been used by any ops in the loop body
|
||||
for (BlockArgument arg : forOp.getRegionIterArgs()) {
|
||||
// Check if arg has a yieldOp use
|
||||
for (OpOperand &operand : arg.getUses()) {
|
||||
if (operand.getOwner() == yieldOp) {
|
||||
auto yieldIdx = operand.getOperandNumber();
|
||||
auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]];
|
||||
auto newArg = newForOp.getRegionIterArgs()[depYieldIdx];
|
||||
if (!depArgsMapping.contains(newArg)) {
|
||||
auto argIdx = depArgsIdx[arg];
|
||||
depArgsMapping[newArg] = newForOp.getRegionIterArgs()[argIdx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// async.wait & extract_slice
|
||||
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
|
||||
loads[0].getLoc(), loads.size() * (numStages - 2));
|
||||
@@ -751,14 +847,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
|
||||
// Finally, the YieldOp, need to sync with the order of newLoopArgs
|
||||
SmallVector<Value> yieldValues;
|
||||
for (Value v : forOp.getBody()->getTerminator()->getOperands())
|
||||
for (Value v : yieldOp->getOperands())
|
||||
yieldValues.push_back(mapping.lookup(v));
|
||||
for (Value nextBuffer : nextBuffers)
|
||||
yieldValues.push_back(nextBuffer);
|
||||
for (Value nextSlice : extractSlices)
|
||||
yieldValues.push_back(nextSlice);
|
||||
|
||||
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
|
||||
for (size_t i = depArgsBeginIdx; i < ivIndex; ++i) {
|
||||
auto arg = newForOp.getRegionIterArgs()[i];
|
||||
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
|
||||
yieldValues.push_back(depArgsMapping[arg]);
|
||||
@@ -768,8 +864,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
yieldValues.push_back(loopIterIdx);
|
||||
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
yieldValues);
|
||||
builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
|
||||
@@ -349,27 +349,23 @@ public:
|
||||
SetVector<Operation *> cvtSlices;
|
||||
auto filter = [&](Operation *op) {
|
||||
return op->getBlock() == cvt->getBlock() &&
|
||||
!isa<triton::gpu::ConvertLayoutOp, scf::YieldOp>(op) &&
|
||||
!(isa<triton::ReduceOp>(op) &&
|
||||
!op->getResult(0).getType().isa<RankedTensorType>()) &&
|
||||
!isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
|
||||
!op->getResult(0).getType().isa<RankedTensorType>());
|
||||
};
|
||||
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
|
||||
if (cvtSlices.empty()) {
|
||||
if (cvtSlices.empty())
|
||||
return failure();
|
||||
}
|
||||
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
for (Operation *op : cvtSlices) {
|
||||
// don't rematerialize anything expensive
|
||||
if (expensiveToRemat(op, dstEncoding)) {
|
||||
if (expensiveToRemat(op, dstEncoding))
|
||||
return failure();
|
||||
}
|
||||
// don't rematerialize non-element-wise
|
||||
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
!isa<triton::StoreOp>(op) && !isa<triton::ReduceOp>(op)) {
|
||||
!isa<triton::StoreOp, triton::ReduceOp>(op))
|
||||
return failure();
|
||||
}
|
||||
// don't rematerialize if it adds an extra conversion that can't
|
||||
// be removed
|
||||
for (Value arg : op->getOperands()) {
|
||||
@@ -380,9 +376,8 @@ public:
|
||||
int numAddedConvs = simulateBackwardRematerialization(
|
||||
argOp, processed, layout, toConvert, srcEncoding);
|
||||
if (argOp && !isa<triton::gpu::ConvertLayoutOp>(argOp) &&
|
||||
cvtSlices.count(argOp) == 0 && numAddedConvs > 0) {
|
||||
cvtSlices.count(argOp) == 0 && numAddedConvs > 0)
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,7 +420,6 @@ public:
|
||||
SetVector<Operation *> processed;
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
std::vector<std::pair<Operation *, Attribute>> queue;
|
||||
if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
|
||||
targetType.getEncoding()) > 0)
|
||||
return mlir::failure();
|
||||
@@ -507,49 +501,15 @@ public:
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
auto iterArgs = forOp.getRegionIterArgs();
|
||||
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
|
||||
// if (iterArg.index() != 1)
|
||||
// continue;
|
||||
// skip non-tensor types
|
||||
if (!iterArg.value().getType().isa<RankedTensorType>())
|
||||
continue;
|
||||
// we only move `iterArg` out of the loop if
|
||||
// - there is only a single conversion use
|
||||
// - moving this conversion out of the loop will not generate
|
||||
// any extra non-removable conversion
|
||||
auto users = iterArg.value().getUsers();
|
||||
// check first condition
|
||||
SetVector<Type> cvtTargetTypes;
|
||||
for (auto user : users) {
|
||||
if (isa<triton::gpu::ConvertLayoutOp>(user)) {
|
||||
auto newType =
|
||||
user->getResults()[0].getType().cast<RankedTensorType>();
|
||||
auto oldType = user->getOperand(0).getType().cast<RankedTensorType>();
|
||||
if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
newType.getEncoding()
|
||||
.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
continue;
|
||||
}
|
||||
if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
if (newType.getEncoding()
|
||||
.cast<triton::gpu::SharedEncodingAttr>()
|
||||
.getVec() == 1)
|
||||
continue;
|
||||
}
|
||||
cvtTargetTypes.insert(newType);
|
||||
}
|
||||
}
|
||||
if (cvtTargetTypes.size() != 1)
|
||||
SmallVector<Operation *> cvts;
|
||||
if (canMoveOutOfLoop(iterArg.value(), cvts).failed())
|
||||
continue;
|
||||
// TODO: check second condition
|
||||
for (auto user : users) {
|
||||
if (isa<triton::gpu::ConvertLayoutOp>(user))
|
||||
continue;
|
||||
}
|
||||
// check
|
||||
for (auto op : iterArg.value().getUsers()) {
|
||||
for (auto *op : cvts) {
|
||||
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
if (!cvt)
|
||||
continue;
|
||||
auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
|
||||
auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
|
||||
targetType, cvt);
|
||||
|
||||
@@ -269,4 +269,64 @@ void rematerializeConversionChain(
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult canMoveOutOfLoop(BlockArgument arg,
|
||||
SmallVector<Operation *> &cvts) {
|
||||
auto parentOp = arg.getOwner()->getParentOp();
|
||||
// Don't move if arg is defined in a while loop
|
||||
if (isa<scf::WhileOp>(parentOp))
|
||||
return failure();
|
||||
// Skip if arg is not defined in scf.for
|
||||
if (!isa<scf::ForOp>(parentOp))
|
||||
return success();
|
||||
auto forOp = cast<scf::ForOp>(parentOp);
|
||||
// We only move `iterArg` out of the loop if
|
||||
// 1. There is no conversion
|
||||
// 2. There is only a single conversion
|
||||
// 3. Moving this conversion out of the loop will not generate any extra
|
||||
// non-removable conversion
|
||||
DenseSet<Type> cvtTypes;
|
||||
SetVector<Operation *> others;
|
||||
auto oldType = arg.getType().cast<RankedTensorType>();
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (isa<triton::gpu::ConvertLayoutOp>(user)) {
|
||||
// Don't move if the conversion target is a dot operand or shared memory
|
||||
auto newType = user->getResults()[0].getType().cast<RankedTensorType>();
|
||||
if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
newType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
continue;
|
||||
}
|
||||
if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
if (newType.getEncoding()
|
||||
.cast<triton::gpu::SharedEncodingAttr>()
|
||||
.getVec() == 1)
|
||||
continue;
|
||||
}
|
||||
cvts.emplace_back(user);
|
||||
cvtTypes.insert(newType);
|
||||
} else
|
||||
others.insert(user);
|
||||
}
|
||||
// First condition
|
||||
if (cvts.empty())
|
||||
return success();
|
||||
if (cvtTypes.size() == 1) {
|
||||
// Second condition
|
||||
if (others.empty())
|
||||
return success();
|
||||
// Third condition: not complete
|
||||
// If the other or the cvt is in the different block, we cannot push the
|
||||
// conversion forward or backward
|
||||
for (auto *cvt : cvts) {
|
||||
if (cvt->getBlock() != forOp.getBody())
|
||||
return failure();
|
||||
}
|
||||
for (auto *other : others) {
|
||||
if (other->getBlock() != forOp.getBody())
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -16,6 +16,8 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
|
||||
|
||||
bool expensiveToRemat(Operation *op, Attribute &targetEncoding);
|
||||
|
||||
// skipInit is True when we only consider the operands of the initOp but
|
||||
// not the initOp itself.
|
||||
int simulateBackwardRematerialization(
|
||||
Operation *initOp, SetVector<Operation *> &processed,
|
||||
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
|
||||
@@ -28,6 +30,10 @@ void rematerializeConversionChain(
|
||||
const llvm::MapVector<Value, Attribute> &toConvert,
|
||||
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
|
||||
IRMapping &mapping);
|
||||
|
||||
LogicalResult canMoveOutOfLoop(BlockArgument arg,
|
||||
SmallVector<Operation *> &cvts);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
@@ -158,12 +158,12 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
funcs.push_back(func);
|
||||
});
|
||||
|
||||
for (auto &func : funcs) {
|
||||
if (func.getOperation()->hasAttr("libname")) {
|
||||
auto name =
|
||||
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
||||
auto path =
|
||||
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
||||
for (LLVM::LLVMFuncOp func : funcs) {
|
||||
if (auto libnameAttr = func->getDiscardableAttr("libname")) {
|
||||
auto name = libnameAttr.dyn_cast<StringAttr>();
|
||||
auto path = func.getOperation()
|
||||
->getDiscardableAttr("libpath")
|
||||
.dyn_cast<StringAttr>();
|
||||
if (name) {
|
||||
std::string libName = name.str();
|
||||
externLibs[libName] = path.str();
|
||||
@@ -171,11 +171,8 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
}
|
||||
}
|
||||
|
||||
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
||||
auto dict = module.getOperation()
|
||||
->getAttr("triton_gpu.externs")
|
||||
.dyn_cast<DictionaryAttr>();
|
||||
for (auto &attr : dict) {
|
||||
if (auto externsAttr = module->getDiscardableAttr("triton_gpu.externs")) {
|
||||
for (auto &attr : externsAttr.cast<DictionaryAttr>()) {
|
||||
externLibs[attr.getName().strref().trim().str()] =
|
||||
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
||||
#include "mlir/Bytecode/BytecodeWriter.h"
|
||||
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
@@ -350,6 +352,14 @@ void init_triton_ir(py::module &&m) {
|
||||
self.print(os);
|
||||
return str;
|
||||
})
|
||||
.def("bytecode",
|
||||
[](mlir::ModuleOp &self) -> py::bytearray {
|
||||
std::string bytecode;
|
||||
llvm::raw_string_ostream os(bytecode);
|
||||
if (failed(mlir::writeBytecodeToFile(self, os)))
|
||||
throw std::runtime_error("Failed to write module bytecode");
|
||||
return py::bytearray(bytecode);
|
||||
})
|
||||
.def("push_back",
|
||||
[](mlir::ModuleOp &self, mlir::triton::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
@@ -441,11 +451,15 @@ void init_triton_ir(py::module &&m) {
|
||||
// 1. Unreachable code after return
|
||||
self.walk([&](mlir::Block *block) {
|
||||
mlir::Operation *retOp = nullptr;
|
||||
block->walk([&](mlir::Operation *op) {
|
||||
// It's better to not use walk here because we only want to
|
||||
// check operations in the current block
|
||||
for (auto &op : block->getOperations()) {
|
||||
if (mlir::isa<mlir::triton::ReturnOp>(op))
|
||||
if (retOp == nullptr)
|
||||
retOp = op;
|
||||
});
|
||||
if (retOp == nullptr) {
|
||||
retOp = &op;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (retOp && retOp != &block->back()) {
|
||||
auto pos = retOp->getIterator();
|
||||
pos++;
|
||||
@@ -1413,8 +1427,12 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_get_program_id",
|
||||
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (axis < 0 || axis > 3)
|
||||
throw std::runtime_error("program_id must be in [0,3]");
|
||||
return self.create<mlir::triton::GetProgramIdOp>(
|
||||
loc, self.getI32Type(), self.getI32IntegerAttr(axis));
|
||||
loc, self.getI32Type(),
|
||||
mlir::triton::ProgramIDDimAttr::get(
|
||||
loc.getContext(), mlir::triton::ProgramIDDim(axis)));
|
||||
})
|
||||
.def("create_get_num_programs",
|
||||
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from numpy.random import RandomState
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -66,3 +69,162 @@ def test_chained_matmul():
|
||||
block_k=block_k)
|
||||
|
||||
assert (torch_result == triton_result).all()
|
||||
|
||||
|
||||
def test_vecmat():
|
||||
@triton.jit
|
||||
def batched_vecmat(
|
||||
# inputs
|
||||
A, # shape: [dim_m, dim_k]
|
||||
B, # shape: [dim_m, dim_n, dim_k]
|
||||
# dimensions
|
||||
dim_m, dim_n, dim_k,
|
||||
# outputs
|
||||
output,
|
||||
# block information
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
|
||||
):
|
||||
m_index = tl.program_id(0)
|
||||
n_index = tl.program_id(1)
|
||||
# Output tile
|
||||
output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \
|
||||
+ (n_index * block_n + tl.arange(0, block_n))[None, :]
|
||||
|
||||
vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)
|
||||
k_blocks = dim_k // block_k
|
||||
for k_index in range(k_blocks):
|
||||
# Load A tile
|
||||
a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \
|
||||
+ (k_index * block_k + tl.arange(0, block_k))[None, :]
|
||||
a = tl.load(A + a_tile)
|
||||
|
||||
# Load B tile, transposed to [n, m, k] in order to broadcast A on a
|
||||
# leading dimension.
|
||||
b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \
|
||||
+ (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \
|
||||
+ (k_index * block_k + tl.arange(0, block_k))[None, None, :]
|
||||
b = tl.load(B + b_tile)
|
||||
|
||||
expanded_a, _ = tl.broadcast(a, b)
|
||||
vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))
|
||||
|
||||
tl.store(output + output_tile, vecmat)
|
||||
|
||||
M, N, K = 128, 128, 128
|
||||
block_m, block_n, block_k = 16, 32, 64
|
||||
|
||||
rs = RandomState(17)
|
||||
A_vec = rs.randint(0, 4, (M, K)).astype('float32')
|
||||
B_vec = rs.randint(0, 4, (M, N, K)).astype('float32')
|
||||
A = A_vec
|
||||
B = B_vec
|
||||
|
||||
A_tri = torch.tensor(A, device='cuda')
|
||||
B_tri = torch.tensor(B, device='cuda')
|
||||
C_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')
|
||||
|
||||
grid = (M // block_m, N // block_n)
|
||||
|
||||
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
|
||||
block_m=block_m, block_n=block_n, block_k=block_k,
|
||||
num_warps=4, num_stages=1)
|
||||
|
||||
A_expanded = A[:, np.newaxis, :]
|
||||
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
|
||||
AB = A_broadcasted * B
|
||||
C_ref = np.sum(AB, axis=2)
|
||||
|
||||
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
|
||||
def test_iv_dependent_matmul(type):
|
||||
@triton.jit
|
||||
def kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
type: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
a_ptrs = a_ptr
|
||||
b_ptrs = b_ptr
|
||||
if type == "post_load_two_iters":
|
||||
a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_load_three_iters":
|
||||
a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
|
||||
a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
if type == "pre_load":
|
||||
a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_pre_mixed":
|
||||
a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
accumulator += tl.dot(a, b)
|
||||
if type == "post_load":
|
||||
a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_pre_mixed":
|
||||
b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_load_two_iters":
|
||||
a_ptrs = a_ptrs_next
|
||||
b_ptrs = b_ptrs_next
|
||||
a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_load_three_iters":
|
||||
a_ptrs = a_ptrs_next
|
||||
b_ptrs = b_ptrs_next
|
||||
a_ptrs_next = a_ptrs_next_next
|
||||
b_ptrs_next = b_ptrs_next_next
|
||||
a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk
|
||||
c = accumulator.to(tl.float16)
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
M = 256
|
||||
K = 256
|
||||
N = 256
|
||||
BLOCK_SIZE_K = 32
|
||||
BLOCK_SIZE_N = 32
|
||||
BLOCK_SIZE_M = 32
|
||||
|
||||
a = torch.rand((M, K), device='cuda')
|
||||
b = torch.rand((K, N), device='cuda')
|
||||
|
||||
torch_output = torch.mm(a, b)
|
||||
triton_output = torch.empty_like(
|
||||
torch_output, device=torch_output.device)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
|
||||
num_stages = 4 if type == "post_load_three_iters" else 3
|
||||
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
type=type, num_stages=num_stages)
|
||||
torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@@ -22,6 +22,13 @@ def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=False)
|
||||
def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
@@ -41,8 +48,16 @@ def test_assert(func: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
<<<<<<< HEAD
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
=======
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
# TRITON_DEBUG=True can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
|
||||
>>>>>>> oai/main
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
@@ -50,5 +65,72 @@ def test_assert(func: str):
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def jit_device_assert_none(x):
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit(debug=True)
|
||||
def jit_device_assert_true(x):
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit(debug=False)
|
||||
def jit_device_assert_false(x):
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
if jit_debug == "true":
|
||||
jit_device_assert_true(x)
|
||||
elif jit_debug == "false":
|
||||
jit_device_assert_false(x)
|
||||
else:
|
||||
jit_device_assert_none(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=True)
|
||||
def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
if jit_debug == "true":
|
||||
jit_device_assert_true(x)
|
||||
elif jit_debug == "false":
|
||||
jit_device_assert_false(x)
|
||||
else:
|
||||
jit_device_assert_none(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=False)
|
||||
def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
if jit_debug == "true":
|
||||
jit_device_assert_true(x)
|
||||
elif jit_debug == "false":
|
||||
jit_device_assert_false(x)
|
||||
else:
|
||||
jit_device_assert_none(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
def test_assert_nested(caller: str, callee: str):
|
||||
shape = (128, )
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if caller == "none":
|
||||
kernel_device_assert_nested[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "true":
|
||||
kernel_device_assert_nested_true[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "false":
|
||||
kernel_device_assert_nested_false[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_assert(sys.argv[1])
|
||||
if len(sys.argv) == 3:
|
||||
test_assert_nested(sys.argv[1], sys.argv[2])
|
||||
else:
|
||||
test_assert(sys.argv[1])
|
||||
|
||||
@@ -130,6 +130,17 @@ class BlockedLayout:
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
|
||||
|
||||
class SharedLayout:
|
||||
def __init__(self, vec, per_phase, max_phase, order):
|
||||
self.vec = str(vec)
|
||||
self.per_phase = str(per_phase)
|
||||
self.max_phase = str(max_phase)
|
||||
self.order = str(order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
@@ -456,6 +467,21 @@ def test_broadcast(dtype):
|
||||
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
|
||||
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
|
||||
|
||||
# ------------------
|
||||
# test invalid slice
|
||||
# ------------------
|
||||
|
||||
|
||||
def test_invalid_slice():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
dst[10:]
|
||||
|
||||
with pytest.raises(triton.CompilationError, match='unsupported tensor index'):
|
||||
_kernel[(1,)](dst=dst)
|
||||
|
||||
|
||||
# ----------------
|
||||
# test expand_dims
|
||||
@@ -537,6 +563,20 @@ def test_expand_dims_error_cases():
|
||||
duplicate_dim2[(1,)](dummy_tensor, N)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# test invalid program id axis
|
||||
# ----------------------------
|
||||
def test_invalid_pid_axis():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
pid = tl.program_id(20)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
|
||||
_kernel[(1,)](dst)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test where
|
||||
# ---------------
|
||||
@@ -1368,6 +1408,9 @@ reduce_configs2 = [
|
||||
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
|
||||
for shape in reduce2d_shapes
|
||||
for axis in [0, 1]
|
||||
] + [
|
||||
(op, 'float32', [16, 32], None)
|
||||
for op in ['min', 'max', 'sum']
|
||||
]
|
||||
|
||||
|
||||
@@ -1382,7 +1425,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if AXIS == 1:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m, z)
|
||||
else:
|
||||
tl.store(Z + range_n, z)
|
||||
@@ -1407,7 +1452,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
else:
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
|
||||
ret_numel = 1 if axis is None else shape[1 - axis]
|
||||
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
|
||||
z_tri = to_numpy(z_tri)
|
||||
@@ -1958,6 +2004,23 @@ def test_full(dtype_str):
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("literal, dtype_str",
|
||||
[(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"),
|
||||
('float("inf")', "f32"), ('float("-inf")', "f32"),
|
||||
('float("nan")', "f32"), ('float("-nan")', "f32"),
|
||||
(0., "f32"),
|
||||
(5, "i32"), (2**40, "i64"),])
|
||||
def test_constexpr(literal, dtype_str):
|
||||
@triton.jit
|
||||
def kernel(out_ptr):
|
||||
val = GENERATE_TEST_HERE
|
||||
tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val)
|
||||
|
||||
kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"})
|
||||
out = torch.zeros((1,), dtype=torch.float32, device="cuda")
|
||||
h = kernel_patched[(1,)](out)
|
||||
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
|
||||
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
@@ -2628,41 +2691,64 @@ def add_fn_static_cond(x, cond: tl.constexpr):
|
||||
return x + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
|
||||
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit",
|
||||
"jit", "jit_if", "jit_ifexp", "jit_expr",
|
||||
"jit_static_cond", "jit_noinline", "jit_extern"])
|
||||
def test_if_call(call_type):
|
||||
@triton.jit
|
||||
def kernel(Out, call_type: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
o = tl.load(Out)
|
||||
if pid == 0:
|
||||
if call_type == "attribute":
|
||||
# call attribute
|
||||
a = o + 1
|
||||
a = a.to(tl.int32).to(tl.int32)
|
||||
o = a
|
||||
else:
|
||||
if call_type == "attribute":
|
||||
# call attribute
|
||||
if pid == 0:
|
||||
a = o
|
||||
if call_type == "jit_function":
|
||||
# regular function call
|
||||
a = add_fn(a)
|
||||
elif call_type == "jit_function_return":
|
||||
# function without end_if block
|
||||
a = add_fn_return(a, pid)
|
||||
elif call_type == "ifexp":
|
||||
# ifexp expression
|
||||
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
|
||||
elif call_type == "expr":
|
||||
if pid == 1:
|
||||
return
|
||||
a = add_fn(a)
|
||||
if pid == 0:
|
||||
# call without return
|
||||
add_fn_expr(Out, a)
|
||||
elif call_type == "jit_function_static_cond":
|
||||
a = add_fn_static_cond(a, call_type)
|
||||
elif call_type == "jit_function_noinline":
|
||||
a = add_fn_noinline(a)
|
||||
a = a.to(tl.int32).to(tl.int32) + 1
|
||||
o = a
|
||||
elif call_type == "attribute_jit":
|
||||
# call attribute and jit function
|
||||
if pid == 0:
|
||||
a = o
|
||||
a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1
|
||||
o = a
|
||||
elif call_type == "jit":
|
||||
if pid == 0:
|
||||
# regular function call
|
||||
a = o
|
||||
a = add_fn(a)
|
||||
o = a
|
||||
elif call_type == "jit_if":
|
||||
# function without end_if block
|
||||
if pid == 0:
|
||||
a = o
|
||||
a = add_fn_return(a, pid)
|
||||
o = a
|
||||
elif call_type == "jit_ifexp":
|
||||
# ifexp expression
|
||||
if pid == 0:
|
||||
a = o
|
||||
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
|
||||
o = a
|
||||
elif call_type == "jit_expr":
|
||||
# call without return
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
add_fn_expr(Out, a)
|
||||
o = a
|
||||
elif call_type == "jit_static_cond":
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
add_fn_static_cond(o, call_type)
|
||||
o = a
|
||||
elif call_type == "jit_noinline":
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
add_fn_noinline(a)
|
||||
o = a
|
||||
elif call_type == "jit_extern":
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
tl.cdiv(a, a)
|
||||
o = a
|
||||
|
||||
tl.store(Out, o)
|
||||
@@ -2766,7 +2852,7 @@ def test_globaltimer():
|
||||
def kernel(Out1, Out2):
|
||||
start = tl.extra.cuda.globaltimer()
|
||||
off = tl.arange(0, 128)
|
||||
for i in range(100):
|
||||
for i in range(10000):
|
||||
tl.store(Out1 + off, tl.load(Out1 + off) + 1)
|
||||
end = tl.extra.cuda.globaltimer()
|
||||
tl.store(Out2, end - start)
|
||||
@@ -2810,22 +2896,49 @@ layouts = [
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
|
||||
|
||||
intermediate_layouts = [
|
||||
None,
|
||||
SharedLayout(1, 1, 1, [1, 0]),
|
||||
SharedLayout(4, 2, 4, [1, 0]),
|
||||
SharedLayout(2, 2, 4, [1, 0]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device='cuda'):
|
||||
if str(src_layout) == str(dst_layout):
|
||||
pytest.skip()
|
||||
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
||||
pytest.skip()
|
||||
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
layouts = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" if interm_layout is None else f"""
|
||||
#src = {src_layout}
|
||||
#interm = {interm_layout}
|
||||
#dst = {dst_layout}
|
||||
"""
|
||||
|
||||
conversion = f"""
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
""" if interm_layout is None else f"""
|
||||
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
|
||||
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
|
||||
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
|
||||
|
||||
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
"""
|
||||
|
||||
ir = layouts + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
@@ -2840,8 +2953,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
|
||||
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
|
||||
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
""" + conversion + """
|
||||
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
||||
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
||||
tt.return
|
||||
|
||||
@@ -9,7 +9,8 @@ print_path = os.path.join(dir_path, "print_helper.py")
|
||||
assert_path = os.path.join(dir_path, "assert_helper.py")
|
||||
|
||||
# TODO: bfloat16 after LLVM-15
|
||||
func_types = ["device_assert", "assert", "static_assert"]
|
||||
func_types = ["device_assert", "assert", "static_assert", "no_debug"]
|
||||
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
|
||||
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
@@ -51,3 +52,29 @@ def test_assert(func_type: str):
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
|
||||
def test_assert_nested(caller_type, callee_type):
|
||||
proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
|
||||
_, errs = proc.communicate()
|
||||
errs = errs.splitlines()
|
||||
num_errs = 0
|
||||
for err in errs:
|
||||
if "x != 0" in err.decode("utf-8"):
|
||||
num_errs += 1
|
||||
if caller_type == "none":
|
||||
if callee_type == "true":
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
elif caller_type == "true":
|
||||
if callee_type == "false":
|
||||
assert num_errs == 0
|
||||
else:
|
||||
assert num_errs == 127
|
||||
elif caller_type == "false":
|
||||
if callee_type == "true":
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
|
||||
@@ -5,7 +5,10 @@ import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
|
||||
(4, 48, 1024, 32),
|
||||
(4, 48, 1024, 64),
|
||||
(4, 48, 1024, 128)])
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
@@ -160,12 +160,12 @@ def test_jit_debug() -> None:
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
kernel_add.debug = False
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
kernel_add.debug = True
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
assert len(kernel_add.cache[device]) == 3
|
||||
bins = list(kernel_add.cache[device].values())
|
||||
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
|
||||
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@@ -97,6 +97,97 @@ class enter_sub_region:
|
||||
self.generator.local_defs = self.prev_defs
|
||||
|
||||
|
||||
# Check if the given syntax node has an "early" return
|
||||
class ContainsReturnChecker(ast.NodeVisitor):
|
||||
def __init__(self, gscope):
|
||||
self.gscope = gscope
|
||||
|
||||
def _visit_stmts(self, body) -> bool:
|
||||
for s in body:
|
||||
if self.visit(s):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _visit_function(self, fn) -> bool:
|
||||
# Currently we only support JITFunctions defined in the global scope
|
||||
if isinstance(fn, JITFunction) and not fn.noinline:
|
||||
fn_node = fn.parse()
|
||||
return ContainsReturnChecker(self.gscope).visit(fn_node)
|
||||
return False
|
||||
|
||||
def generic_visit(self, node) -> bool:
|
||||
ret = False
|
||||
for _, value in ast.iter_fields(node):
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, ast.AST):
|
||||
ret = ret or self.visit(item)
|
||||
elif isinstance(value, ast.AST):
|
||||
ret = ret or self.visit(value)
|
||||
return ret
|
||||
|
||||
def visit_Attribute(self, node: ast.Attribute) -> bool:
|
||||
# If the left part is a name, it's possible that
|
||||
# we call triton native function or a jit function from another module.
|
||||
# If the left part is not a name, it must return a tensor or a constexpr
|
||||
# whose methods do not contain return statements
|
||||
# e.g., (tl.load(x)).to(y)
|
||||
# So we only check if the expressions within value have return or not
|
||||
if isinstance(node.value, ast.Name):
|
||||
if node.value.id in self.gscope:
|
||||
value = self.gscope[node.value.id]
|
||||
fn = getattr(value, node.attr)
|
||||
return self._visit_function(fn)
|
||||
return False
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> bool:
|
||||
if type(node.ctx) == ast.Store:
|
||||
return False
|
||||
if node.id in self.gscope:
|
||||
fn = self.gscope[node.id]
|
||||
return self._visit_function(fn)
|
||||
return False
|
||||
|
||||
def visit_Return(self, node: ast.Return) -> bool:
|
||||
return True
|
||||
|
||||
def visit_Assign(self, node: ast.Assign) -> bool:
|
||||
# There couldn't be an early return
|
||||
# x = ...
|
||||
return False
|
||||
|
||||
def visit_AugAssign(self, node: ast.AugAssign) -> bool:
|
||||
# There couldn't be an early return
|
||||
# x += ...
|
||||
return False
|
||||
|
||||
def visit_Module(self, node: ast.Module) -> bool:
|
||||
return self._visit_stmts(node.body)
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
|
||||
return self._visit_stmts(node.body)
|
||||
|
||||
def visit_If(self, node: ast.If) -> bool:
|
||||
# TODO: optimize the following case in which we actually don't have
|
||||
# a return when static_cond is false:
|
||||
# if dynamic_cond
|
||||
# if static_cond
|
||||
# func_with_return
|
||||
# else
|
||||
# func_without_return
|
||||
ret = self._visit_stmts(node.body)
|
||||
if node.orelse:
|
||||
ret = ret or self._visit_stmts(node.orelse)
|
||||
return ret
|
||||
|
||||
def visit_IfExp(self, node: ast.IfExp) -> bool:
|
||||
return self.visit(node.body) or self.visit(node.orelse)
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> bool:
|
||||
return self.visit(node.func)
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
@@ -166,63 +257,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if ret_type is not None and isinstance(stmt, ast.Return):
|
||||
self.last_ret_type = ret_type
|
||||
|
||||
# TODO: should be its own AST visitor
|
||||
def contains_return_op(self, node):
|
||||
if isinstance(node, ast.Return):
|
||||
return True
|
||||
elif isinstance(node, ast.Assign):
|
||||
return self.contains_return_op(node.value)
|
||||
elif isinstance(node, ast.Module):
|
||||
pred = lambda s: self.contains_return_op(s)
|
||||
return any(pred(s) for s in node.body)
|
||||
elif isinstance(node, ast.FunctionDef):
|
||||
pred = lambda s: self.contains_return_op(s)
|
||||
return any(pred(s) for s in node.body)
|
||||
elif isinstance(node, ast.Call):
|
||||
def check_undefined_name(cur_node):
|
||||
# Check if name is an undefined local variable,
|
||||
# which can only be a tensor or a constexpr
|
||||
if isinstance(cur_node.func, ast.Attribute):
|
||||
if isinstance(cur_node.func.value, ast.Name):
|
||||
name = cur_node.func.value.id
|
||||
if name not in self.lscope and name not in self.gscope:
|
||||
return True
|
||||
return False
|
||||
# chain of calls
|
||||
# e.g., tl.load(a).to(tl.float32)
|
||||
return check_undefined_name(cur_node.func.value)
|
||||
return False
|
||||
if check_undefined_name(node):
|
||||
return False
|
||||
fn = self.visit(node.func)
|
||||
if isinstance(fn, JITFunction) and fn.noinline is not True:
|
||||
old_gscope = self.gscope
|
||||
self.gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
ret = self.contains_return_op(fn.parse())
|
||||
self.gscope = old_gscope
|
||||
return ret
|
||||
return False
|
||||
elif isinstance(node, ast.If):
|
||||
pred = lambda s: self.contains_return_op(s)
|
||||
ret = any(pred(s) for s in node.body)
|
||||
if node.orelse:
|
||||
ret = ret or any(pred(s) for s in node.orelse)
|
||||
return ret
|
||||
elif isinstance(node, ast.IfExp):
|
||||
return self.contains_return_op(node.body) or self.contains_return_op(node.orelse)
|
||||
elif isinstance(node, ast.Expr):
|
||||
ret = False
|
||||
for _, value in ast.iter_fields(node):
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, ast.AST):
|
||||
ret = ret or self.contains_return_op(item)
|
||||
elif isinstance(value, ast.AST):
|
||||
ret = ret or self.contains_return_op(value)
|
||||
return ret
|
||||
else:
|
||||
return False
|
||||
|
||||
def visit_Module(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
@@ -354,7 +388,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for name, value in zip(names, values):
|
||||
# by default, constexpr are assigned into python variable
|
||||
value = _unwrap_if_constexpr(value)
|
||||
if not _is_triton_tensor(value) and \
|
||||
if value is not None and \
|
||||
not _is_triton_tensor(value) and \
|
||||
not isinstance(value, native_nontensor_types):
|
||||
value = language.core._to_tensor(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
@@ -526,7 +561,11 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = self.visit(node.test)
|
||||
if _is_triton_tensor(cond):
|
||||
cond = cond.to(language.int1, _builder=self.builder)
|
||||
if self.scf_stack or not self.contains_return_op(node):
|
||||
contains_return = ContainsReturnChecker(self.gscope).visit(node)
|
||||
if self.scf_stack and contains_return:
|
||||
raise UnsupportedLanguageConstruct(None, node,
|
||||
"Cannot have `return` statements inside `while` or `for` statements in triton")
|
||||
elif self.scf_stack or not contains_return:
|
||||
self.visit_if_scf(cond, node)
|
||||
else:
|
||||
self.visit_if_top_level(cond, node)
|
||||
@@ -825,7 +864,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not self.module.has_function(fn_name):
|
||||
prototype = language.function_type([], arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline)
|
||||
# If the callee is not set, we use the same debug setting as the caller
|
||||
debug = self.debug if fn.debug is None else fn.debug
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
|
||||
@@ -55,7 +55,17 @@ def _to_tensor(x, builder):
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
return tensor(builder.get_fp32(x), float32)
|
||||
min_float32 = 2 ** -126
|
||||
max_float32 = (2 - 2**-23) * 2**127
|
||||
abs_x = __builtins__['abs'](x)
|
||||
if abs_x == float("inf") or\
|
||||
abs_x == 0.0 or \
|
||||
x != x or \
|
||||
min_float32 <= abs_x <= max_float32:
|
||||
return tensor(builder.get_fp32(x), float32)
|
||||
else:
|
||||
return tensor(builder.get_fp64(x), float64)
|
||||
|
||||
elif isinstance(x, constexpr):
|
||||
return _to_tensor(x.value, builder)
|
||||
elif isinstance(x, tensor):
|
||||
@@ -701,7 +711,7 @@ class tensor:
|
||||
for dim, sl in enumerate(slices):
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
ret = semantic.expand_dims(ret, dim, _builder)
|
||||
elif sl == slice(None, None, None):
|
||||
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
|
||||
pass
|
||||
else:
|
||||
assert False, f"unsupported tensor index: {sl}"
|
||||
@@ -1307,8 +1317,8 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_reduce_ret(*handles)
|
||||
|
||||
axis = _constexpr_to_value(axis)
|
||||
if axis is not None:
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.reduction(input, axis, make_combine_region, _builder)
|
||||
|
||||
|
||||
@@ -1379,7 +1389,7 @@ def _max_combine(a, b):
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("maximum")
|
||||
def max(input, axis):
|
||||
def max(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _max_combine)
|
||||
|
||||
@@ -1409,7 +1419,7 @@ def _min_combine(a, b):
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("minimum")
|
||||
def min(input, axis):
|
||||
def min(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _min_combine)
|
||||
|
||||
@@ -1438,7 +1448,7 @@ def _sum_combine(a, b):
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("sum")
|
||||
def sum(input, axis):
|
||||
def sum(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _sum_combine)
|
||||
|
||||
@@ -1450,7 +1460,7 @@ def _xor_combine(a, b):
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("xor sum")
|
||||
def xor_sum(input, axis, _builder=None, _generator=None):
|
||||
def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
scalar_ty = input.type.scalar
|
||||
if not scalar_ty.is_int():
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
@@ -1461,12 +1471,15 @@ def xor_sum(input, axis, _builder=None, _generator=None):
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Internal for debugging
|
||||
# Compiler Hint Ops
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def debug_barrier(_builder=None):
|
||||
'''
|
||||
Insert a barrier to synchronize all threads in a block.
|
||||
'''
|
||||
return semantic.debug_barrier(_builder)
|
||||
|
||||
|
||||
@@ -1508,16 +1521,28 @@ def max_contiguous(input, values, _builder=None):
|
||||
|
||||
@builtin
|
||||
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
|
||||
'''
|
||||
Print the values at compile time. The parameters are the same as the builtin :code:`print`.
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
@builtin
|
||||
def static_assert(cond, msg="", _builder=None):
|
||||
'''
|
||||
Assert the condition at compile time. The parameters are the same as the builtin :code:`assert`.
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
@builtin
|
||||
def device_print(prefix, *args, _builder=None):
|
||||
'''
|
||||
Print the values at runtime from the device.
|
||||
|
||||
:param prefix: a prefix to print before the values. This is required to be a string literal.
|
||||
:param args: the values to print. They can be any tensor or scalar.
|
||||
'''
|
||||
import string
|
||||
prefix = _constexpr_to_value(prefix)
|
||||
assert isinstance(prefix, str), f"{prefix} is not string"
|
||||
@@ -1535,6 +1560,12 @@ def device_print(prefix, *args, _builder=None):
|
||||
|
||||
@builtin
|
||||
def device_assert(cond, msg="", _builder=None):
|
||||
'''
|
||||
Assert the condition at runtime from the device.
|
||||
|
||||
:param cond: the condition to assert. This is required to be a boolean tensor.
|
||||
:param msg: the message to print if the assertion fails. This is required to be a string literal.
|
||||
'''
|
||||
msg = _constexpr_to_value(msg)
|
||||
import inspect
|
||||
frame = inspect.currentframe()
|
||||
@@ -1560,7 +1591,22 @@ def device_assert(cond, msg="", _builder=None):
|
||||
|
||||
class static_range:
|
||||
|
||||
"""Iterator that counts upward forever."""
|
||||
"""
|
||||
Iterator that counts upward forever.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.jit
|
||||
def kernel(...):
|
||||
for i in tl.static_range(10):
|
||||
...
|
||||
:note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
|
||||
:code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
|
||||
:param arg1: the start value.
|
||||
:param arg2: the end value.
|
||||
:param step: the step value.
|
||||
"""
|
||||
|
||||
def __init__(self, arg1, arg2=None, step=None):
|
||||
assert isinstance(arg1, constexpr)
|
||||
|
||||
@@ -1280,6 +1280,13 @@ def where(condition: tl.tensor,
|
||||
def reduction(
|
||||
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
|
||||
) -> Tuple[tl.tensor, ...]:
|
||||
if axis is None:
|
||||
new_inputs = []
|
||||
for i in range(len(inputs)):
|
||||
new_shape = [inputs[i].numel.value]
|
||||
new_inputs.append(view(inputs[i], new_shape, builder))
|
||||
inputs = tuple(new_inputs)
|
||||
axis = 0
|
||||
# get result shape
|
||||
shape = inputs[0].type.shape
|
||||
ret_shape = [s for i, s in enumerate(shape) if i != axis]
|
||||
|
||||
@@ -203,8 +203,7 @@ class _attention(torch.autograd.Function):
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
# assert Lk in {16, 32, 64, 128}
|
||||
assert Lk in {64} # TODO: fix other cases
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
|
||||
@@ -356,7 +356,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
|
||||
|
||||
@@ -40,29 +40,33 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
:type fast_flush: bool
|
||||
"""
|
||||
|
||||
# Estimate the runtime of the function
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(5):
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
estimate_ms = start_event.elapsed_time(end_event) / 5
|
||||
# compute number of warmup and repeat
|
||||
n_warmup = max(1, int(warmup / estimate_ms))
|
||||
n_repeat = max(1, int(rep / estimate_ms))
|
||||
|
||||
# We maintain a buffer of 256 MB that we clear
|
||||
# before each kernel call to make sure that the L2
|
||||
# doesn't contain any input data before the run
|
||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
if fast_flush:
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
else:
|
||||
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
|
||||
|
||||
# Estimate the runtime of the function
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(5):
|
||||
cache.zero_()
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
estimate_ms = start_event.elapsed_time(end_event) / 5
|
||||
|
||||
# compute number of warmup and repeat
|
||||
n_warmup = max(1, int(warmup / estimate_ms))
|
||||
n_repeat = max(1, int(rep / estimate_ms))
|
||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
# Warm-up
|
||||
for _ in range(n_warmup):
|
||||
fn()
|
||||
|
||||
@@ -406,7 +406,7 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
|
||||
// CHECK-LABEL: @store_constant_align
|
||||
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%pid = tt.get_program_id x : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
|
||||
@@ -438,7 +438,7 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
// CHECK-LABEL: @vecadd_mask_align_16
|
||||
tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
@@ -467,7 +467,7 @@ tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
// CHECK-LABEL: @vecadd_mask_align_1
|
||||
tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
|
||||
@@ -76,40 +76,64 @@ tt.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: reduce_ops_infer
|
||||
tt.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
// Test if reduce ops infer types correctly
|
||||
|
||||
// CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 0
|
||||
// CHECK: tt.reduce.return
|
||||
// CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<2x4xf32>
|
||||
%a = "tt.reduce" (%v) ({
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
tt.reduce.return %add : f32
|
||||
}) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
|
||||
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 1
|
||||
// CHECK: tt.reduce.return
|
||||
// CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x4xf32>
|
||||
%b = "tt.reduce" (%v) ({
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
tt.reduce.return %add : f32
|
||||
}) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
|
||||
// CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
|
||||
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 2
|
||||
// CHECK: tt.reduce.return
|
||||
// CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x2xf32>
|
||||
%c = "tt.reduce" (%v) ({
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
tt.reduce.return %add : f32
|
||||
}) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
|
||||
// CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
|
||||
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 1
|
||||
// CHECK: tt.reduce.return
|
||||
// CHECK-NEXT: (tensor<1x4xf32>) -> tensor<1xf32>
|
||||
%e = "tt.reduce" (%b) ({
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
tt.reduce.return %add : f32
|
||||
}) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
|
||||
// CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 0
|
||||
// CHECK: tt.reduce.return
|
||||
// CHECK-NEXT: (tensor<2x4xf32>) -> tensor<4xf32>
|
||||
%f = "tt.reduce" (%a) ({
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
tt.reduce.return %add : f32
|
||||
}) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
|
||||
// CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32
|
||||
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 0
|
||||
// CHECK: tt.reduce.return
|
||||
// CHECK-NEXT: (tensor<4xf32>) -> f32
|
||||
%g = "tt.reduce" (%f) ({
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
@@ -154,3 +178,12 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
||||
tt.store %ptr1x1, %r4 : tensor<1x1xf32>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @print_no_arg
|
||||
tt.func @print_no_arg(%arg0: !tt.ptr<f32>) {
|
||||
// CHECK: tt.print "test"
|
||||
tt.print "test"
|
||||
%0 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
|
||||
tt.store %arg0, %0 {cache = 1 : i32, evict = 1 : i32} : f32
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_no_vec
|
||||
tt.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -229,7 +229,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec4
|
||||
tt.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -306,7 +306,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
tt.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
|
||||
@@ -340,7 +340,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec2
|
||||
tt.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -461,7 +461,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec8
|
||||
tt.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -643,8 +643,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_program_id
|
||||
tt.func @basic_program_id() {
|
||||
<<<<<<< HEAD
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
=======
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
>>>>>>> oai/main
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -1528,6 +1533,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
<<<<<<< HEAD
|
||||
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
||||
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
||||
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
||||
@@ -1537,6 +1543,14 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
// GCN: rocdl.workgroup.id.x
|
||||
// GCN: rocdl.workgroup.id.y
|
||||
// GCN: rocdl.workgroup.id.z
|
||||
=======
|
||||
%blockidx = tt.get_program_id x : i32
|
||||
%blockidy = tt.get_program_id y : i32
|
||||
%blockidz = tt.get_program_id z : i32
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.z
|
||||
>>>>>>> oai/main
|
||||
%v0 = arith.addi %blockidx, %blockidy : i32
|
||||
%v1 = arith.addi %v0, %blockidz : i32
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
|
||||
@@ -10,8 +10,8 @@ tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.get_program_id {axis = 1 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.get_program_id y : i32
|
||||
%2 = arith.addi %arg3, %c127_i32 : i32
|
||||
%3 = arith.divsi %2, %c128_i32 : i32
|
||||
%4 = arith.addi %arg4, %c31_i32 : i32
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
module {
|
||||
tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
|
||||
@@ -49,7 +49,7 @@ module {
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %cst = arith.constant 0.000000e+00 : f32
|
||||
// %c256_i32 = arith.constant 256 : i32
|
||||
// %0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// %0 = tt.get_program_id x : i32
|
||||
// %1 = arith.muli %0, %c256_i32 : i32
|
||||
// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
|
||||
@@ -86,7 +86,7 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout1>
|
||||
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout1>
|
||||
@@ -102,7 +102,7 @@ tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-LABEL: if_convert_else_not
|
||||
tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
@@ -123,7 +123,7 @@ tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
// CHECK-LABEL: if_not_else_convert
|
||||
tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
@@ -144,7 +144,7 @@ tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
// CHECK-LABEL: if_else_both_convert
|
||||
tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
@@ -267,11 +267,63 @@ tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32,
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: loop_if
|
||||
tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
|
||||
%c1 = arith.constant 1 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%i0 = arith.constant 0 : i32
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
|
||||
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
|
||||
%33 = "triton_gpu.cmpi"(%i0, %i0) {predicate = 4 : i64} : (i32, i32) -> i1
|
||||
%34 = scf.if %33 -> (tensor<64x64xf32, #blocked1>) {
|
||||
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
||||
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
||||
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
||||
scf.yield %27 : tensor<64x64xf32, #blocked1>
|
||||
} else {
|
||||
scf.yield %arg6 : tensor<64x64xf32, #blocked1>
|
||||
}
|
||||
%28 = arith.addf %arg6, %34 : tensor<64x64xf32, #blocked1>
|
||||
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
}
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
|
||||
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
|
||||
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: vecadd
|
||||
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
|
||||
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
|
||||
@@ -309,7 +361,7 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
|
||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked0>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1x1xi32, #blocked1>
|
||||
@@ -370,7 +422,7 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
|
||||
%cst_12 = arith.constant dense<1> : tensor<1024xi32, #blocked0>
|
||||
%cst_13 = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked0>
|
||||
%cst_14 = arith.constant dense<0> : tensor<1024xi32, #blocked0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c1024_i32 : i32
|
||||
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked0>
|
||||
@@ -757,7 +809,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
|
||||
%cst_2 = arith.constant dense<0xFF800000> : tensor<16x16xf32, #blocked2>
|
||||
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked2>
|
||||
%cst_4 = arith.constant dense<0> : tensor<16x16xi32, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c16_i32 : i32
|
||||
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0>
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
@@ -856,7 +908,7 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
|
||||
%cst_4 = arith.constant dense<2048> : tensor<64x1xi32, #blocked2>
|
||||
%cst_5 = arith.constant dense<49152> : tensor<64x1xi32, #blocked2>
|
||||
%cst_6 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<64xi32, #blocked0>) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
@@ -992,7 +1044,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
|
||||
%c-1_i64 = arith.constant -1 : i64
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c-1_i32 = arith.constant -1 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
|
||||
%2 = tt.load %1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
|
||||
%3 = arith.cmpi eq, %2, %c-1_i64 : i64
|
||||
@@ -1054,7 +1106,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// Check if the SimplifyReduceCvt handles convert_layout lifted from the for loop.
|
||||
// CHECK-LABEL: reduce_cvt2
|
||||
// Match the reduction
|
||||
// CHECK: }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 1
|
||||
// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
|
||||
@@ -1073,7 +1127,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%cst_3 = arith.constant dense<196> : tensor<1x256xi32, #blocked>
|
||||
%cst_4 = arith.constant dense<3136> : tensor<1x256xi32, #blocked>
|
||||
%cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
|
||||
|
||||
@@ -6,8 +6,10 @@
|
||||
#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
|
||||
#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}>
|
||||
#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}>
|
||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#ALR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#ALC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
|
||||
// CHECK: tt.func @push_elementwise1
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
@@ -17,36 +19,106 @@
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[AF16]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise1(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv2>
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
|
||||
// Not modified for row-row
|
||||
// CHECK: tt.func @push_elementwise2
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise2(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLR>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLR>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
|
||||
// Not modified for col-row
|
||||
// CHECK: tt.func @push_elementwise3
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise3(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALC>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLR>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALC> -> tensor<16x16xf8E5M2, #ALC>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALC> -> tensor<16x16xf16, #ALC>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALC>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLR>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
// Not modified for col-col
|
||||
// CHECK: tt.func @push_elementwise4
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
|
||||
tt.func @push_elementwise4(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALC>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALC> -> tensor<16x16xf8E5M2, #ALC>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALC> -> tensor<16x16xf16, #ALC>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALC>) -> tensor<16x16xf16, #Av2>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv2>
|
||||
}
|
||||
|
||||
|
||||
// Not modified for Volta
|
||||
// CHECK: tt.func @push_elementwise5
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
|
||||
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
|
||||
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
|
||||
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
|
||||
tt.func @push_elementwise5(
|
||||
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%pb: tensor<16x16x!tt.ptr<f16>, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av1>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv1>
|
||||
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #ALR>
|
||||
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BLC>
|
||||
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR>
|
||||
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
|
||||
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av1>
|
||||
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv1>
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv1>
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%c63_i32 = arith.constant 63 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.addi %arg3, %c63_i32 : i32
|
||||
%2 = arith.divsi %1, %c64_i32 : i32
|
||||
%3 = arith.addi %arg4, %c63_i32 : i32
|
||||
|
||||
Reference in New Issue
Block a user