[FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing arbitrary code. For example, `tl.sum` is now lowered as:

```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```

Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. For example, `argmin` is now lowered as:

```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %14 = arith.cmpf olt, %arg4, %arg6 : f32
  %15 = arith.cmpf ogt, %arg4, %arg6 : f32
  %16 = arith.cmpi slt, %arg5, %arg7 : i32
  %17 = arith.select %16, %arg5, %arg7 : i32
  %18 = arith.select %15, %arg7, %17 : i32
  %19 = arith.select %14, %arg5, %18 : i32
  %20 = arith.cmpf olt, %arg4, %arg6 : f32
  %21 = arith.select %20, %arg4, %arg6 : f32
  tt.reduce.return %21, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
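To make the new interface concrete, here is a minimal, hypothetical sketch of how the region-based `tt.reduce` could be constructed from C++ using the variadic builder added in this change (`OpBuilder<(ins "ValueRange":$operands, "int":$axis)>`). The helper name `emitRowMax` and the choice of `arith::MaxFOp` as the combine operation are illustrative assumptions, not code from this commit:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Build a tt.reduce over `axis` whose combine region computes max(acc, cur).
static mlir::Value emitRowMax(mlir::OpBuilder &b, mlir::Location loc,
                              mlir::Value src, int axis) {
  auto reduce =
      b.create<mlir::triton::ReduceOp>(loc, mlir::ValueRange{src}, axis);

  // The combine region receives one (accumulator, current) scalar pair per
  // operand and must terminate with tt.reduce.return.
  auto elemTy = src.getType().cast<mlir::RankedTensorType>().getElementType();
  mlir::OpBuilder::InsertionGuard guard(b);
  mlir::Block *block = b.createBlock(&reduce.getCombineOp(), {},
                                     {elemTy, elemTy}, {loc, loc});
  mlir::Value max = b.create<mlir::arith::MaxFOp>(loc, block->getArgument(0),
                                                  block->getArgument(1));
  b.create<mlir::triton::ReduceReturnOp>(loc, mlir::ValueRange{max});
  return reduce.getResult()[0];
}
```

The same builder accepts several operands at once, which is how argmin/argmax-style reductions carry a value tensor and an index tensor through a single combine region.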
@@ -12,13 +12,26 @@ namespace mlir {

class ReduceOpHelper {
public:
  explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
    srcTy = op.getOperand().getType().cast<RankedTensorType>();
  explicit ReduceOpHelper(triton::ReduceOp rop)
      : op(rop.getOperation()), axis(rop.getAxis()) {
    auto firstTy = rop.getOperands()[0].getType().cast<RankedTensorType>();
    srcShape = firstTy.getShape();
    srcEncoding = firstTy.getEncoding();
    srcElementTypes = rop.getElementTypes();

    for (const auto &t : rop.getInputTypes()) {
      if (t.getShape() != srcShape) {
        rop.emitError() << "shape mismatch";
      }
      if (t.getEncoding() != srcEncoding) {
        rop.emitError() << "encoding mismatch";
      }
    }
  }

  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
  ArrayRef<int64_t> getSrcShape() { return srcShape; }

  Attribute getSrcLayout() { return srcTy.getEncoding(); }
  Attribute getSrcLayout() { return srcEncoding; }

  bool isFastReduction();

@@ -37,8 +50,11 @@ public:
  bool isSupportedLayout();

private:
  triton::ReduceOp op;
  RankedTensorType srcTy{};
  Operation *op;
  ArrayRef<int64_t> srcShape;
  Attribute srcEncoding;
  SmallVector<Type> srcElementTypes;
  int axis;
};

bool isSharedEncoding(Value value);

@@ -34,30 +34,6 @@ def TT_PaddingOptionAttr : I32EnumAttr<
  let cppNamespace = "::mlir::triton";
}

// reduction
def TT_RedOpAttr : I32EnumAttr<
    /*name*/"RedOp", /*summary*/"",
    /*case*/
    [
      I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
      I32EnumAttrCase<"FADD", 2, "fadd">,
      I32EnumAttrCase<"MIN", 3, "min">,
      I32EnumAttrCase<"MAX", 4, "max">,
      I32EnumAttrCase<"UMIN", 5, "umin">,
      I32EnumAttrCase<"UMAX", 6, "umax">,
      I32EnumAttrCase<"ARGMIN", 7, "argmin">,
      I32EnumAttrCase<"ARGMAX", 8, "argmax">,
      I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
      I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
      I32EnumAttrCase<"FMIN", 11, "fmin">,
      I32EnumAttrCase<"FMAX", 12, "fmax">,
      I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
      I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
      I32EnumAttrCase<"XOR", 15, "xor">
    ]> {
  let cppNamespace = "::mlir::triton";
}

// atomic
def TT_AtomicRMWAttr : I32EnumAttr<
    "RMWOp", "",

@@ -388,27 +388,35 @@ def TT_DotOp : TT_Op<"dot", [Pure,
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [Pure,
                                   DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "reduce";

  let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);

  let results = (outs TT_Type:$result);

def TT_ReduceOp: TT_Op<"reduce",
                       [Pure,
                        SameOperandsEncoding,
                        SingleBlock,
                        DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Reduction using generic combination algorithm";
  let arguments = (ins Variadic<TT_Tensor>:$operands, I32Attr:$axis);
  let results = (outs Variadic<TT_Type>:$result);
  let regions = (region SizedRegion<1>:$combineOp);
  let builders = [
    OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
    OpBuilder<(ins "ValueRange":$operands, "int":$axis)>,
  ];

  let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";

  let hasVerifier = 1;
  let hasRegionVerifier = 1;
  let extraClassDeclaration = [{
    // This member function is marked static because we need to call it before the ReduceOp
    // is constructed, see the implementation of create_reduce in triton.cc.
    static bool withIndex(mlir::triton::RedOp redOp);
    llvm::SmallVector<RankedTensorType> getInputTypes();
    llvm::SmallVector<Type> getElementTypes();
    unsigned getNumOperands();
  }];
}

def TT_ReduceReturnOp: TT_Op<"reduce.return",
                             [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> {
  let summary = "terminator for reduce operator";
  let arguments = (ins Variadic<AnyType>:$result);
  let assemblyFormat = "$result attr-dict `:` type($result)";
}

//
// External Elementwise op
//

@@ -103,7 +103,7 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
                          TT_Tensor:$true_value,
                          TT_Tensor:$false_value);

  let results = (outs TT_Tensor:$result);
  let results = (outs TT_Type:$result);
}

@@ -74,7 +74,10 @@ void MembarAnalysis::visitTerminator(Operation *op,
    return;
  }
  // Otherwise, it could be a return op
  assert(isa<triton::ReturnOp>(op) && "Unknown terminator");
  if (isa<triton::ReduceReturnOp>(op) || isa<triton::ReturnOp>(op)) {
    return;
  }
  llvm_unreachable("Unknown terminator encountered in membar analysis");
}

void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,

@@ -10,49 +10,38 @@
namespace mlir {

bool ReduceOpHelper::isFastReduction() {
  auto srcLayout = srcTy.getEncoding();
  auto axis = op.getAxis();
  return axis == triton::gpu::getOrder(srcLayout)[0];
  return axis == triton::gpu::getOrder(getSrcLayout())[0];
}

unsigned ReduceOpHelper::getInterWarpSize() {
  auto srcLayout = srcTy.getEncoding();
  auto srcShape = srcTy.getShape();
  auto axis = op.getAxis();
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSize();
  return std::min(srcReduceDimSize / sizeIntraWarps,
                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
                  triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getIntraWarpSize() {
  auto srcLayout = srcTy.getEncoding();
  auto srcShape = srcTy.getShape();
  auto axis = op.getAxis();
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  return std::min(srcReduceDimSize,
                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
                  triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
  auto srcLayout = srcTy.getEncoding();
  auto axis = op.getAxis();
  auto srcLayout = getSrcLayout();
  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
  auto axis = op.getAxis();
  auto smemShape = convertType<unsigned>(getSrcShape());
  smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
  return smemShape;
}

SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
  auto axis = op.getAxis();
  SmallVector<SmallVector<unsigned>> smemShapes(3);

  auto argLayout = srcTy.getEncoding();
  auto argLayout = getSrcLayout();
  auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
  if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
      triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)

@@ -64,7 +53,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {

  /// FIXME(Qingyi): This size is actually larger than required.
  /// shared memory block1:
  auto mod = op.getOperation()->getParentOfType<ModuleOp>();
  auto mod = op->getParentOfType<ModuleOp>();
  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  smemShapes[1].push_back(numWarps * 32);

@@ -82,17 +71,15 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
    elems = product<unsigned>(smemShape);
  }

  auto tensorType = op.getOperand().getType().cast<RankedTensorType>();
  unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;

  if (triton::ReduceOp::withIndex(op.getRedOp()))
    bytes += elems * sizeof(int32_t);

  return bytes;
  unsigned bytesPerElem = 0;
  for (const auto &ty : srcElementTypes) {
    bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
  }
  return bytesPerElem * elems;
}

bool ReduceOpHelper::isSupportedLayout() {
  auto srcLayout = srcTy.getEncoding();
  auto srcLayout = getSrcLayout();
  if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
    return true;
  }

@@ -1073,12 +1073,14 @@ void populateElementwiseOpToLLVMPatterns(
  POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
  POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
  POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)   // &
  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)     // |
  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)   // ^
  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)   // <<
  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)   // &
  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)     // |
  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)   // ^
  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)   // <<
  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
  POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
  POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp)  // smin
#undef POPULATE_BINARY_OP

#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \

@@ -23,112 +23,59 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void accumulate(ConversionPatternRewriter &rewriter, Location loc,
|
||||
RedOp redOp, Value &acc, Value cur, bool isFirst) const {
|
||||
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
|
||||
llvm::SmallVectorImpl<Value> &acc, ValueRange cur,
|
||||
bool isFirst) const {
|
||||
if (isFirst) {
|
||||
acc = cur;
|
||||
acc.resize(cur.size());
|
||||
for (unsigned i = 0; i < cur.size(); ++i) {
|
||||
acc[i] = cur[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
switch (redOp) {
|
||||
case RedOp::ADD:
|
||||
acc = add(acc, cur);
|
||||
break;
|
||||
case RedOp::FADD:
|
||||
acc = fadd(acc.getType(), acc, cur);
|
||||
break;
|
||||
case RedOp::MIN:
|
||||
acc = smin(acc, cur);
|
||||
break;
|
||||
case RedOp::MAX:
|
||||
acc = smax(acc, cur);
|
||||
break;
|
||||
case RedOp::UMIN:
|
||||
acc = umin(acc, cur);
|
||||
break;
|
||||
case RedOp::UMAX:
|
||||
acc = umax(acc, cur);
|
||||
break;
|
||||
case RedOp::FMIN:
|
||||
acc = fmin(acc, cur);
|
||||
break;
|
||||
case RedOp::FMAX:
|
||||
acc = fmax(acc, cur);
|
||||
break;
|
||||
case RedOp::XOR:
|
||||
acc = xor_(acc, cur);
|
||||
break;
|
||||
case RedOp::ARGMIN:
|
||||
case RedOp::ARGMAX:
|
||||
case RedOp::ARGUMIN:
|
||||
case RedOp::ARGUMAX:
|
||||
case RedOp::ARGFMIN:
|
||||
case RedOp::ARGFMAX:
|
||||
llvm::report_fatal_error(
|
||||
"This accumulate implementation is not for argmin / argmax");
|
||||
default:
|
||||
llvm::report_fatal_error("Unsupported reduce op");
|
||||
|
||||
// Create a new copy of the reduce block, and inline it
|
||||
Block *currentBlock = rewriter.getBlock();
|
||||
Region &parent = *currentBlock->getParent();
|
||||
rewriter.cloneRegionBefore(combineOp, &parent.front());
|
||||
auto &newReduce = parent.front();
|
||||
auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
|
||||
|
||||
llvm::SmallVector<Value> combineArgs(2 * acc.size());
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
combineArgs[i] = acc[i];
|
||||
combineArgs[acc.size() + i] = cur[i];
|
||||
}
|
||||
|
||||
rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
|
||||
combineArgs);
|
||||
|
||||
auto results = returnOp.getResult();
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
acc[i] = results[i];
|
||||
}
|
||||
|
||||
// Delete the terminator, which is no longer used
|
||||
rewriter.eraseOp(returnOp);
|
||||
}
|
||||
|
||||
void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
|
||||
RedOp redOp, Value &acc, Value &accIndex, Value cur,
|
||||
Value curIndex, bool isFirst) const {
|
||||
if (isFirst) {
|
||||
acc = cur;
|
||||
accIndex = curIndex;
|
||||
return;
|
||||
}
|
||||
switch (redOp) {
|
||||
case RedOp::ARGMIN:
|
||||
accIndex = select(
|
||||
icmp_slt(acc, cur), accIndex,
|
||||
select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
|
||||
acc = smin(acc, cur);
|
||||
break;
|
||||
case RedOp::ARGMAX:
|
||||
accIndex = select(
|
||||
icmp_sgt(acc, cur), accIndex,
|
||||
select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
|
||||
acc = smax(acc, cur);
|
||||
break;
|
||||
case RedOp::ARGUMIN:
|
||||
accIndex = select(
|
||||
icmp_ult(acc, cur), accIndex,
|
||||
select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
|
||||
acc = umin(acc, cur);
|
||||
break;
|
||||
case RedOp::ARGUMAX:
|
||||
accIndex = select(
|
||||
icmp_ugt(acc, cur), accIndex,
|
||||
select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
|
||||
acc = umax(acc, cur);
|
||||
break;
|
||||
case RedOp::ARGFMIN:
|
||||
accIndex = select(
|
||||
fcmp_olt(acc, cur), accIndex,
|
||||
select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
|
||||
acc = fmin(acc, cur);
|
||||
break;
|
||||
case RedOp::ARGFMAX:
|
||||
accIndex = select(
|
||||
fcmp_ogt(acc, cur), accIndex,
|
||||
select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
|
||||
acc = fmax(acc, cur);
|
||||
break;
|
||||
case RedOp::ADD:
|
||||
case RedOp::FADD:
|
||||
case RedOp::MIN:
|
||||
case RedOp::MAX:
|
||||
case RedOp::UMIN:
|
||||
case RedOp::UMAX:
|
||||
case RedOp::FMIN:
|
||||
case RedOp::FMAX:
|
||||
case RedOp::XOR:
|
||||
llvm::report_fatal_error(
|
||||
"This accumulate implementation is only for argmin / argmax");
|
||||
default:
|
||||
llvm::report_fatal_error("Unsupported reduce op");
|
||||
SmallVector<SmallVector<Value>>
|
||||
unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto types = op.getInputTypes();
|
||||
auto operands = adaptor.getOperands();
|
||||
unsigned srcElems = getElemsPerThread(types[0]);
|
||||
SmallVector<SmallVector<Value>> srcValues(srcElems);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto values = getTypeConverter()->unpackLLElements(loc, operands[i],
|
||||
rewriter, types[i]);
|
||||
|
||||
assert(values.size() == srcValues.size());
|
||||
for (unsigned j = 0; j < srcValues.size(); ++j) {
|
||||
srcValues[j].push_back(values[j]);
|
||||
}
|
||||
}
|
||||
return srcValues;
|
||||
}
|
||||
|
||||
// Calculates the write index in the shared memory where we would be writing
|
||||
@@ -177,63 +124,64 @@ private:
|
||||
matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ReduceOpHelper helper(op);
|
||||
Location loc = op->getLoc();
|
||||
Location loc = op.getLoc();
|
||||
unsigned axis = op.getAxis();
|
||||
// Specifies whether the reduce operation returns an index
|
||||
// rather than a value, e.g. argmax, argmin, .. etc
|
||||
bool withIndex = triton::ReduceOp::withIndex(op.getRedOp());
|
||||
|
||||
auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
if (!helper.isSupportedLayout()) {
|
||||
assert(false && "Unexpected srcLayout in ReduceOpConversion");
|
||||
}
|
||||
// The order of the axes for the the threads within the warp
|
||||
auto srcOrd = triton::gpu::getOrder(srcLayout);
|
||||
auto sizePerThread = triton::gpu::getSizePerThread(srcLayout);
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcShape = helper.getSrcShape();
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
SmallVector<Type> elemPtrTys(srcTys.size());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto ty = srcTys[i].getElementType();
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty);
|
||||
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
}
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
|
||||
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
|
||||
auto smemShape = helper.getScratchConfigBasic();
|
||||
unsigned elems = product<unsigned>(smemShape);
|
||||
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
|
||||
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTy);
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] =
|
||||
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)),
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTys[0]);
|
||||
// Emits indices of the original tensor that each thread
|
||||
// would own
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
|
||||
auto srcValues = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperand(), rewriter, srcTy);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
|
||||
// Emits offsets (the offset from the base index)
|
||||
// of the original tensor that each thread would own
|
||||
// NOTE: Assumes offsets don't actually depend on type
|
||||
SmallVector<SmallVector<unsigned>> offset =
|
||||
emitOffsetForLayout(srcLayout, srcTy);
|
||||
emitOffsetForLayout(srcLayout, srcTys[0]);
|
||||
|
||||
// Keep track of accumulations and their indices
|
||||
std::map<SmallVector<unsigned>, Value> accs;
|
||||
std::map<SmallVector<unsigned>, Value> accIndices;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
|
||||
Region *combineOp = &op.getCombineOp();
|
||||
|
||||
// reduce within threads
|
||||
for (unsigned i = 0; i < srcElems; ++i) {
|
||||
SmallVector<unsigned> key = offset[i];
|
||||
key[axis] = 0;
|
||||
bool isFirst = accs.find(key) == accs.end();
|
||||
if (!withIndex) {
|
||||
accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i],
|
||||
isFirst);
|
||||
} else {
|
||||
Value curIndex = srcIndices[i][axis];
|
||||
accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key],
|
||||
accIndices[key], srcValues[i], curIndex, isFirst);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
|
||||
if (isFirst)
|
||||
indices[key] = srcIndices[i];
|
||||
}
|
||||
@@ -250,24 +198,20 @@ private:
|
||||
// reduce across threads
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
Value acc = it.second;
|
||||
Value accIndex;
|
||||
if (withIndex)
|
||||
accIndex = accIndices[key];
|
||||
auto &acc = it.second;
|
||||
// get the writeIdx at which to write in smem
|
||||
SmallVector<Value> writeIdx;
|
||||
getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints,
|
||||
axis);
|
||||
|
||||
// calculate the offset in smem for that writeIdx
|
||||
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
|
||||
// Get element pointers for the value and index
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
|
||||
// Store the within-thread accumulated value at writePtr
|
||||
store(acc, writePtr);
|
||||
// Store the index of within-thread accumulation at indexWritePtr
|
||||
if (withIndex)
|
||||
store(accIndex, indexWritePtr);
|
||||
SmallVector<Value> writePtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
// Store the within-thread accumulated value into shared memory
|
||||
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
|
||||
store(acc[i], writePtrs[i]);
|
||||
}
|
||||
|
||||
SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
|
||||
// Perform parallel reduction with sequential addressing
|
||||
@@ -286,27 +230,24 @@ private:
|
||||
Value readOffset = select(
|
||||
readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
|
||||
ints[0]);
|
||||
// The readPtr is readOffset away from writePtr
|
||||
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
|
||||
SmallVector<Value> readPtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
// The readPtr is readOffset away from writePtr
|
||||
readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
|
||||
}
|
||||
|
||||
barrier();
|
||||
// If we do not care about the index, i.e. this is not an argmax,
|
||||
// argmin, .. etc
|
||||
if (!withIndex) {
|
||||
// The value at the readPtr, whereas acc is the value at writePtr
|
||||
Value cur = load(readPtr);
|
||||
accumulate(rewriter, loc, op.getRedOp(), acc, cur, false);
|
||||
barrier();
|
||||
// Update writePtr value
|
||||
store(acc, writePtr);
|
||||
} else {
|
||||
Value cur = load(readPtr);
|
||||
Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
|
||||
Value curIndex = load(indexReadPtr);
|
||||
accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, cur,
|
||||
curIndex, false);
|
||||
barrier();
|
||||
store(acc, writePtr);
|
||||
store(accIndex, indexWritePtr);
|
||||
// Combine accumulator value from another thread
|
||||
SmallVector<Value> cur(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
cur[i] = load(readPtrs[i]);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, cur, false);
|
||||
|
||||
barrier();
|
||||
// Publish our new accumulator value to shared memory
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
store(acc[i], writePtrs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -314,33 +255,37 @@ private:
|
||||
barrier();
|
||||
|
||||
// set output values
|
||||
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
auto resultShape = resultTy.getShape();
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (unsigned i = 0; i < resultElems; ++i) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
|
||||
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (unsigned j = 0; j < resultElems; ++j) {
|
||||
SmallVector<Value> readIdx = resultIndices[j];
|
||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShape, srcOrd);
|
||||
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
|
||||
resultVals[j] = load(readPtr);
|
||||
}
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
results[i] = load(smemBases[i]);
|
||||
}
|
||||
Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
|
||||
resultTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
|
||||
rewriter.replaceOp(op, resultVal);
|
||||
}
|
||||
|
||||
auto parentBlock = op.getOperation()->getBlock();
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -351,60 +296,59 @@ private:
|
||||
ReduceOpHelper helper(op);
|
||||
Location loc = op->getLoc();
|
||||
unsigned axis = adaptor.getAxis();
|
||||
bool withIndex = triton::ReduceOp::withIndex(op.getRedOp());
|
||||
|
||||
auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
if (!helper.isSupportedLayout()) {
|
||||
assert(false && "Unexpected srcLayout in ReduceOpConversion");
|
||||
}
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto order = getOrder(srcLayout);
|
||||
auto srcOrd = triton::gpu::getOrder(srcLayout);
|
||||
auto srcShape = helper.getSrcShape();
|
||||
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
SmallVector<Type> elemPtrTys(srcTys.size());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto ty = srcTys[i].getElementType();
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty);
|
||||
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
}
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
||||
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
|
||||
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] =
|
||||
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSize();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSize();
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTy);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
|
||||
auto srcValues = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperand(), rewriter, srcTy);
|
||||
unsigned srcElems = getElemsPerThread(srcTys[0]);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
|
||||
SmallVector<SmallVector<unsigned>> offset =
|
||||
emitOffsetForLayout(srcLayout, srcTy);
|
||||
|
||||
std::map<SmallVector<unsigned>, Value> accs;
|
||||
std::map<SmallVector<unsigned>, Value> accIndices;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
|
||||
|
||||
// Assumes offsets don't actually depend on type
|
||||
SmallVector<SmallVector<unsigned>> offset =
|
||||
emitOffsetForLayout(srcLayout, srcTys[0]);
|
||||
|
||||
auto *combineOp = &op.getCombineOp();
|
||||
|
||||
// reduce within threads
|
||||
for (unsigned i = 0; i < srcElems; ++i) {
|
||||
SmallVector<unsigned> key = offset[i];
|
||||
key[axis] = 0;
|
||||
bool isFirst = accs.find(key) == accs.end();
|
||||
if (!withIndex) {
|
||||
accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i],
|
||||
isFirst);
|
||||
} else {
|
||||
Value curIndex = srcIndices[i][axis];
|
||||
accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key],
|
||||
accIndices[key], srcValues[i], curIndex, isFirst);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
|
||||
if (isFirst)
|
||||
indices[key] = srcIndices[i];
|
||||
}
|
||||
@@ -414,6 +358,9 @@ private:
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
|
||||
auto order = getOrder(srcLayout);
|
||||
SmallVector<Value> multiDimLaneId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
@@ -427,32 +374,24 @@ private:
|
||||
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
Value acc = it.second;
|
||||
Value accIndex;
|
||||
if (withIndex)
|
||||
accIndex = accIndices[key];
|
||||
SmallVector<Value> acc = it.second;
|
||||
|
||||
// Reduce within warps
|
||||
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
|
||||
Value shfl = shflSync(loc, rewriter, acc, N);
|
||||
if (!withIndex) {
|
||||
accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false);
|
||||
} else {
|
||||
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
|
||||
accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl,
|
||||
shflIndex, false);
|
||||
SmallVector<Value> shfl(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
||||
if (withIndex) {
|
||||
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
|
||||
storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc[i], laneZero);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,39 +408,36 @@ private:
|
||||
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
|
||||
Value readOffset = threadId;
|
||||
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
// FIXME(Qingyi): need predicate icmp_slt(threadId,
|
||||
// i32_val(sizeInerWarps))
|
||||
Value acc = load(readPtr);
|
||||
Value accIndex;
|
||||
if (withIndex) {
|
||||
Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
|
||||
accIndex = load(readIndexPtr);
|
||||
SmallVector<Value> acc(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
|
||||
acc[i] = load(readPtr);
|
||||
}
|
||||
|
||||
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
||||
Value shfl = shflSync(loc, rewriter, acc, N);
|
||||
if (!withIndex) {
|
||||
accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false);
|
||||
} else {
|
||||
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
|
||||
accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl,
|
||||
shflIndex, false);
|
||||
SmallVector<Value> shfl(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
// only the first thread in each sizeInterWarps is writing
|
||||
Value writeOffset = readOffset;
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
SmallVector<Value> writePtrs(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
|
||||
}
|
||||
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
||||
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
||||
Value laneIdModSizeInterWarpsIsZero =
|
||||
icmp_eq(laneIdModSizeInterWarps, zero);
|
||||
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
|
||||
storeShared(rewriter, loc, writePtr, acc, pred);
|
||||
if (withIndex) {
|
||||
Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
|
||||
storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
|
||||
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
storeShared(rewriter, loc, writePtrs[i], acc[i], pred);
|
||||
}
|
||||
|
||||
if (round != elemsPerThread - 1) {
|
||||
@@ -515,32 +451,34 @@ private:
|
||||
barrier();
|
||||
|
||||
// set output values
|
||||
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (size_t i = 0; i < resultElems; ++i) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShapes[0], order);
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
|
||||
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (size_t j = 0; j < resultElems; ++j) {
|
||||
SmallVector<Value> readIdx = resultIndices[j];
|
||||
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShapes[0], order);
|
||||
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
|
||||
resultVals[j] = load(readPtr);
|
||||
}
|
||||
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
results[i] = load(smemBases[i]);
|
||||
}
|
||||
Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
|
||||
resultTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
|
||||
rewriter.replaceOp(op, resultVal);
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -46,15 +46,30 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
Value TritonGPUToLLVMTypeConverter::packLLElements(
    Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter,
    Type type) {
  auto structType = this->convertType(type);
  if (!structType.isa<LLVM::LLVMStructType>()) {
  auto structType = this->convertType(type).dyn_cast<LLVM::LLVMStructType>();
  if (!structType) {
    assert(resultVals.size() == 1);
    return *resultVals.begin();
  }

  auto elementTypes = structType.getBody();
  if (elementTypes.size() != resultVals.size()) {
    emitError(loc) << " size mismatch when packing elements for LLVM struct"
                   << " expected " << elementTypes.size() << " but got "
                   << resultVals.size();
  }
  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
  // llvm::outs() << structType << "\n";
  for (const auto &v : llvm::enumerate(resultVals)) {
    assert(v.value() && "can not insert null values");
    if (!v.value()) {
      emitError(loc)
          << "cannot insert null values into struct, but tried to insert"
          << v.value();
    }
    if (v.value().getType() != elementTypes[v.index()]) {
      emitError(loc) << "invalid element type in packLLEElements. Expected "
                     << elementTypes[v.index()] << " but got "
                     << v.value().getType();
    }
    llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
  }
  return llvmStruct;

@@ -68,13 +68,15 @@ public:
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
    assert(value);
    if (value.getElementType().isInteger(1) && value.isSplat())
      // Workaround until https://reviews.llvm.org/D133743 is included.
      value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
    else
      // This is a hack. We just want to add encoding
      value = value.reshape(retType);
    if (dyn_cast<RankedTensorType>(retType)) {
      assert(value);
      if (value.getElementType().isInteger(1) && value.isSplat())
        // Workaround until https://reviews.llvm.org/D133743 is included.
        value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
      else
        // This is a hack. We just want to add encoding
        value = value.reshape(retType);
    }
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
        adaptor.getAttributes());

@@ -469,10 +471,28 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<triton::ReduceOp>(
            op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis()),
        adaptor.getAttributes());
    auto newReduce = rewriter.create<triton::ReduceOp>(
        op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
    addNamedAttrs(newReduce, adaptor.getAttributes());

    auto &newCombineOp = newReduce.getCombineOp();
    rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp,
                                newCombineOp.end());
    rewriter.replaceOp(op, newReduce.getResult());
    return success();
  }
};

struct TritonReduceReturnPattern
    : public OpConversionPattern<triton::ReduceReturnOp> {
  using OpConversionPattern<triton::ReduceReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ReduceReturnOp>(
                      op, adaptor.getResult()),
                  adaptor.getAttributes());
    return success();
  }
};

@@ -517,10 +537,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
      TritonGenericPattern<triton::PtrToIntOp>,
      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
      TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern,
      TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
      TritonReducePattern, TritonReduceReturnPattern, TritonTransPattern,
      TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
      TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
      TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(
      typeConverter, context);
}

//

@@ -310,21 +310,10 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
}

//-- ReduceOp --
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  Value arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto argEltTy = argTy.getElementType();
  auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
  auto redOp =
      attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
  bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
  auto retEltTy = withIndex ? i32Ty : argEltTy;
static mlir::LogicalResult
inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy,
                       int axis, SmallVectorImpl<Type> &inferredReturnTypes) {
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
  retShape.erase(retShape.begin() + axis);
  if (retShape.empty()) {
    // 0d-tensor -> scalar

@@ -352,15 +341,114 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
  return mlir::success();
}

bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
  return redOp == mlir::triton::RedOp::ARGMIN ||
         redOp == mlir::triton::RedOp::ARGMAX ||
         redOp == mlir::triton::RedOp::ARGUMIN ||
         redOp == mlir::triton::RedOp::ARGUMAX ||
         redOp == mlir::triton::RedOp::ARGFMIN ||
         redOp == mlir::triton::RedOp::ARGFMAX;
void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                     mlir::ValueRange operands, int axis) {
  SmallVector<Type> inferredReturnTypes;
  for (unsigned i = 0; i < operands.size(); ++i) {
    auto argTy = operands[i].getType().cast<RankedTensorType>();
    auto retEltTy = argTy.getElementType();
    (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
  }

  ReduceOp::build(builder, state, inferredReturnTypes, operands, axis);
}

mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  for (auto arg : operands) {
    auto argTy = arg.getType().cast<RankedTensorType>();
    auto retEltTy = argTy.getElementType();
    int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
    if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
            .failed()) {
      return failure();
    }
  }
  return success();
}

mlir::LogicalResult mlir::triton::ReduceOp::verify() {
  if (this->getOperands().size() < 1) {
    return this->emitOpError() << "must have at least 1 operand";
  }
  for (const auto &operand : this->getOperands()) {
    if (!dyn_cast<RankedTensorType>(operand.getType())) {
      return this->emitOpError() << "operands must be RankedTensorType";
    }
  }
  return success();
}

mlir::LogicalResult mlir::triton::ReduceOp::verifyRegions() {
  auto argElementTypes = this->getElementTypes();
  const auto &operands = this->getOperands();
  const auto numArgs = 2 * operands.size();
  auto &block = *this->getBody();
  if (block.getNumArguments() != numArgs) {
    return this->emitOpError() << "nested block must take " << numArgs
                               << " arguments, but given block with "
                               << block.getNumArguments() << " arguments";
  }
  unsigned i = 0;
  const auto &blockArgTypes = block.getArgumentTypes();
  for (unsigned i = 0; i < numArgs; ++i) {
    const auto &blockArgTy = blockArgTypes[i];
    const auto &argElemTy = argElementTypes[i % operands.size()];
    if (blockArgTy != argElemTy) {
      return this->emitOpError()
             << "type mismatch on combine operation. Expected argument " << i
             << " to have type " << argElemTy << " but got " << blockArgTy;
    }
  }

  auto terminator =
      dyn_cast<mlir::triton::ReduceReturnOp>(block.getTerminator());
  if (!terminator) {
    return this->emitOpError()
           << "combine operation must be terminated "
           << "with a ReduceReturnOp but got " << block.getTerminator();
  }
  const auto &combineResults = terminator->getOperands();
  if (combineResults.size() != operands.size()) {
    return this->emitOpError()
           << "expected combine operation to return " << operands.size()
           << " values but got " << combineResults.size();
  }
  for (unsigned i = 0; i < combineResults.size(); ++i) {
    const auto &resultTy = combineResults[i].getType();
    const auto &argElemTy = argElementTypes[i];
    if (resultTy != argElemTy) {
      return this->emitOpError()
             << "type mismatch on combine operation. Expected argument " << i
             << " to have type " << argElemTy << " but got " << resultTy;
    }
  }
  return mlir::success();
}

llvm::SmallVector<mlir::RankedTensorType> ReduceOp::getInputTypes() {
  llvm::SmallVector<RankedTensorType> srcTys;
  srcTys.reserve(this->getNumOperands());
  for (const auto &ty : this->getOperands().getTypes()) {
    srcTys.push_back(ty.cast<RankedTensorType>());
  }
  return srcTys;
}

llvm::SmallVector<Type> ReduceOp::getElementTypes() {
  llvm::SmallVector<Type> srcElemTys;
  srcElemTys.reserve(this->getNumOperands());
  for (const auto &op : this->getOperands()) {
    srcElemTys.push_back(
        op.getType().cast<RankedTensorType>().getElementType());
  }
  return srcElemTys;
}

unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }

//-- SplatOp --
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto value = adaptor.getSrc();

@@ -101,29 +101,59 @@ public:
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
triton::ReduceOp reduce;
|
||||
for (auto &use : convert.getResult().getUses()) {
|
||||
auto owner = use.getOwner();
|
||||
if (llvm::isa_and_nonnull<triton::ReduceOp>(owner)) {
|
||||
reduce = llvm::cast<triton::ReduceOp>(owner);
|
||||
break;
|
||||
auto owner = llvm::dyn_cast<triton::ReduceOp>(use.getOwner());
|
||||
if (!owner) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: This only moves conversions from the first argument which is
|
||||
// fine for argmin/argmax but may not be optimal generally
|
||||
if (convert.getResult() != owner.getOperands()[0]) {
|
||||
continue;
|
||||
}
|
||||
reduce = owner;
|
||||
break;
|
||||
}
|
||||
if (!reduce)
|
||||
return mlir::failure();
|
||||
|
||||
SmallVector<Value> newOperands = reduce.getOperands();
|
||||
|
||||
newOperands[0] = convert.getOperand();
|
||||
auto newEncoding =
|
||||
newOperands[0].getType().cast<RankedTensorType>().getEncoding();
|
||||
|
||||
// this may generate unsupported conversions in the LLVM codegen
|
||||
if (convert.getOperand()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return mlir::failure();
|
||||
if (newEncoding.isa<triton::gpu::MmaEncodingAttr>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
for (unsigned i = 1; i < newOperands.size(); ++i) {
|
||||
auto oldTy = newOperands[i].getType().cast<RankedTensorType>();
|
||||
RankedTensorType newTy =
|
||||
RankedTensorType::Builder(oldTy).setEncoding(newEncoding);
|
||||
|
||||
newOperands[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newTy, newOperands[i]);
|
||||
}
|
||||
|
||||
auto newReduce = rewriter.create<triton::ReduceOp>(
|
||||
op->getLoc(), reduce.getRedOp(), convert.getOperand(),
|
||||
reduce.getAxis());
|
||||
Value newRet = newReduce.getResult();
|
||||
if (newRet.getType() != reduce.getResult().getType())
|
||||
newRet = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), reduce.getResult().getType(), newRet);
|
||||
rewriter.replaceAllUsesWith(reduce, newRet);
|
||||
op->getLoc(), newOperands, reduce.getAxis());
|
||||
auto &newCombineOp = newReduce.getCombineOp();
|
||||
rewriter.inlineRegionBefore(reduce.getCombineOp(), newCombineOp,
|
||||
newCombineOp.end());
|
||||
|
||||
SmallVector<Value> newRet = newReduce.getResult();
|
||||
auto oldTypes = reduce.getResult().getType();
|
||||
for (unsigned i = 0; i < reduce.getNumOperands(); ++i) {
|
||||
// it's still beneficial to move the conversion
|
||||
// to after the reduce if necessary since it will be
|
||||
// done on a rank-reduced tensor hence cheaper
|
||||
if (newRet[i].getType() != oldTypes[i])
|
||||
newRet[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), oldTypes[i], newRet[i]);
|
||||
}
|
||||
rewriter.replaceAllUsesWith(reduce.getResult(), newRet);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -79,6 +79,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
  // Some ops from SCF are illegal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();
  // We have custom versions of some arith operators
  addIllegalOp<arith::CmpIOp, arith::CmpFOp>();

  addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
                             triton::TritonDialect, cf::ControlFlowDialect,

@@ -95,23 +95,6 @@ void init_triton_ir(py::module &&m) {
      .value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST)
      .export_values();

  py::enum_<mlir::triton::RedOp>(m, "REDUCE_OP")
      .value("ADD", mlir::triton::RedOp::ADD)
      .value("FADD", mlir::triton::RedOp::FADD)
      .value("MIN", mlir::triton::RedOp::MIN)
      .value("MAX", mlir::triton::RedOp::MAX)
      .value("UMIN", mlir::triton::RedOp::UMIN)
      .value("UMAX", mlir::triton::RedOp::UMAX)
      .value("ARGMIN", mlir::triton::RedOp::ARGMIN)
      .value("ARGMAX", mlir::triton::RedOp::ARGMAX)
      .value("ARGUMIN", mlir::triton::RedOp::ARGUMIN)
      .value("ARGUMAX", mlir::triton::RedOp::ARGUMAX)
      .value("FMIN", mlir::triton::RedOp::FMIN)
      .value("FMAX", mlir::triton::RedOp::FMAX)
      .value("ARGFMIN", mlir::triton::RedOp::ARGFMIN)
      .value("ARGFMAX", mlir::triton::RedOp::ARGFMAX)
      .value("XOR", mlir::triton::RedOp::XOR);

  py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP")
      .value("ADD", mlir::triton::RMWOp::ADD)
      .value("FADD", mlir::triton::RMWOp::FADD)

@@ -1349,21 +1332,20 @@ void init_triton_ir(py::module &&m) {
             return self.create<mlir::math::AbsIOp>(loc, val);
           })
      .def("create_reduce",
           [](mlir::OpBuilder &self, mlir::Value &operand,
              mlir::triton::RedOp redOp, int axis) -> mlir::Value {
           [](mlir::OpBuilder &self, std::vector<mlir::Value> operands,
              int axis) -> mlir::OpState {
             auto loc = self.getUnknownLoc();
             auto inputTensorType =
                 operand.getType().dyn_cast<mlir::RankedTensorType>();
             std::vector<int64_t> shape = inputTensorType.getShape();
             shape.erase(shape.begin() + axis);
             bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
             mlir::Type resType = withIndex ? self.getI32Type()
                                            : inputTensorType.getElementType();
             if (!shape.empty()) {
               resType = mlir::RankedTensorType::get(shape, resType);
             return self.create<mlir::triton::ReduceOp>(loc, operands, axis);
           })
      .def("create_reduce_ret",
           [](mlir::OpBuilder &self, py::args args) -> mlir::OpState {
             auto loc = self.getUnknownLoc();
             llvm::SmallVector<mlir::Value> return_values;
             for (const auto &arg : args) {
               return_values.push_back(py::cast<mlir::Value>(arg));
             }
             return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
                                                        operand, axis);
             return self.create<mlir::triton::ReduceReturnOp>(loc,
                                                              return_values);
           })
      .def("create_ptr_to_int",
           [](mlir::OpBuilder &self, mlir::Value &val,

@@ -1295,10 +1295,15 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
|
||||
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf32, #blocked>
|
||||
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xf32, #blocked>) -> tensor<{M}x{N}xf32, #src>
|
||||
%15 = tt.reduce %14 {{axis = {axis} : i32, redOp = 12 : i32}} : tensor<{M}x{N}xf32, #src> -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
|
||||
%16 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
|
||||
%17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
|
||||
tt.store %12, %17 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
|
||||
%15 = "tt.reduce"(%14) ({{
|
||||
^bb0(%arg3: f32, %arg4: f32):
|
||||
%16 = "triton_gpu.cmpf"(%arg3, %arg4) {{predicate = 2 : i64}} : (f32, f32) -> i1
|
||||
%17 = arith.select %16, %arg3, %arg4 : f32
|
||||
tt.reduce.return %17 : f32
|
||||
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xf32, #src>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
|
||||
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
|
||||
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
|
||||
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
|
||||
tt.return
|
||||
}}
|
||||
}}
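The templated IR above is roughly what `tl.max` now produces. A hypothetical kernel that would generate a reduction of this shape (kernel name and pointer arithmetic are illustrative only):

```python
import triton
import triton.language as tl

@triton.jit
def row_max_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    # Load one M x N block and reduce it along axis 1; tl.max now expands
    # into a "tt.reduce" region with a cmpf/select combine like the one above.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    x = tl.load(x_ptr + offs_m[:, None] * N + offs_n[None, :])
    row_max = tl.max(x, axis=1)
    tl.store(out_ptr + offs_m, row_max)
```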

@@ -1,4 +1,5 @@
import ast
import inspect
import re
import sys
import warnings
@@ -755,6 +756,43 @@ class CodeGenerator(ast.NodeVisitor):
        # Convert assert to triton's device_assert which happens on the device
        return language.core.device_assert(test, msg, _builder=self.builder)

    def call_JitFunction(self, fn: JITFunction, args, kwargs):
        args = inspect.getcallargs(fn.fn, *args, **kwargs)
        args = [args[name] for name in fn.arg_names]
        args = [arg if _is_triton_tensor(arg)
                else constexpr(arg) for arg in args]
        # generate function def
        attributes = dict()
        constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
        constants = {i: args[i] for i in constexprs}
        # generate call
        args = [None if i in constexprs else arg for i, arg in enumerate(args)]
        arg_vals = [arg.handle for arg in args if arg is not None]
        arg_types = [arg.type for arg in args if arg is not None]
        fn_name = mangle_fn(fn.__name__, arg_types, constants)
        # generate function def if necessary
        if not self.module.has_function(fn_name):
            prototype = language.function_type([], arg_types)
            gscope = sys.modules[fn.fn.__module__].__dict__
            generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
            generator.visit(fn.parse())
            callee_ret_type = generator.last_ret_type
            self.function_ret_types[fn_name] = callee_ret_type
        else:
            callee_ret_type = self.function_ret_types[fn_name]
        symbol = self.module.get_function(fn_name)
        call_op = self.builder.call(symbol, arg_vals)
        if call_op.get_num_results() == 0 or callee_ret_type is None:
            return None
        elif call_op.get_num_results() == 1:
            return tensor(call_op.get_result(0), callee_ret_type)
        else:
            # should return a tuple of tl.tensor
            results = []
            for i in range(call_op.get_num_results()):
                results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
            return tuple(results)

    def visit_Call(self, node):
        fn = _unwrap_if_constexpr(self.visit(node.func))

@@ -768,44 +806,13 @@ class CodeGenerator(ast.NodeVisitor):
        if not self.debug:
            return
        if isinstance(fn, JITFunction):
            from inspect import getcallargs
            args = getcallargs(fn.fn, *args, **kws)
            args = [args[name] for name in fn.arg_names]
            args = [arg if _is_triton_tensor(arg)
                    else constexpr(arg) for arg in args]
            # generate function def
            attributes = dict()
            constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
            constants = {i: args[i] for i in constexprs}
            # generate call
            args = [None if i in constexprs else arg for i, arg in enumerate(args)]
            arg_vals = [arg.handle for arg in args if arg is not None]
            arg_types = [arg.type for arg in args if arg is not None]
            fn_name = mangle_fn(fn.__name__, arg_types, constants)
            # generate function def if necessary
            if not self.module.has_function(fn_name):
                prototype = language.function_type([], arg_types)
                gscope = sys.modules[fn.fn.__module__].__dict__
                generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
                generator.visit(fn.parse())
                callee_ret_type = generator.last_ret_type
                self.function_ret_types[fn_name] = callee_ret_type
            else:
                callee_ret_type = self.function_ret_types[fn_name]
            symbol = self.module.get_function(fn_name)
            call_op = self.builder.call(symbol, arg_vals)
            if call_op.get_num_results() == 0 or callee_ret_type is None:
                return None
            elif call_op.get_num_results() == 1:
                return tensor(call_op.get_result(0), callee_ret_type)
            else:
                # should return a tuple of tl.tensor
                results = []
                for i in range(call_op.get_num_results()):
                    results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
                return tuple(results)
            return self.call_JitFunction(fn, args, kws)
        if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
            return fn(*args, _builder=self.builder, **kws)
            extra_kwargs = dict(_builder=self.builder)
            sig = inspect.signature(fn)
            if '_generator' in sig.parameters:
                extra_kwargs['_generator'] = self
            return fn(*args, **extra_kwargs, **kws)
        if fn in self.builtin_namespace.values():
            args = map(_unwrap_if_constexpr, args)
            return fn(*args, **kws)

@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import Callable, List, TypeVar
@@ -1190,46 +1191,166 @@ def _add_reduction_docstr(name: str) -> Callable[[T], T]:
    return _decorator


@contextmanager
def _insertion_guard(builder):
    ip = builder.get_insertion_point()
    yield
    builder.restore_insertion_point(ip)


@builtin
def reduction(input, axis, combine_fn, _builder=None, _generator=None):
    """Applies :code:`combine_fn` to all elements of the :code:`input` tensor(s) along the provided :code:`axis`

    :param input: the input tensor, or tuple of tensors
    :param axis: the dimension along which the reduction should be done
    :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)

    """
    if isinstance(input, tensor):
        return reduction((input,), axis, combine_fn,
                         _builder=_builder, _generator=_generator)[0]

    def make_combine_region(reduce_op):
        in_scalar_tys = [t.type.scalar for t in input]
        prototype = function_type(in_scalar_tys, in_scalar_tys * 2)

        region = reduce_op.get_region(0)
        with _insertion_guard(_builder):
            param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
            block = _builder.create_block_with_parent(region, param_types)
            args = [tensor(block.arg(i), ty)
                    for i, ty in enumerate(prototype.param_types)]
            results = _generator.call_JitFunction(combine_fn, args, kwargs={})
            if isinstance(results, tensor):
                handles = [results.handle]
            else:
                handles = [r.handle for r in results]
            _builder.create_reduce_ret(*handles)

    axis = _constexpr_to_value(axis)
    return semantic.reduction(input, axis, make_combine_region, _builder)
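
A minimal usage sketch of the new builtin from kernel code, assuming it is re-exported as `tl.reduction` (the kernel, pointers, and the product combine function are hypothetical):

```python
@triton.jit
def _product_combine(a, b):
    # user-defined combine function; must itself be a @triton.jit function
    return a * b

@triton.jit
def row_product_kernel(x_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    p = tl.reduction(x, 0, _product_combine)  # scalar product of the block
    tl.store(out_ptr, p)
```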


@builtin
def _promote_reduction_input(t, _builder=None):
    scalar_ty = t.type.scalar
    # input is extended to 32-bits if necessary
    # this increases numerical accuracy and can be done pretty much for free
    # on GPUs
    if scalar_ty.is_int() and scalar_ty.int_bitwidth < 32:
        return t.to(int32, _builder=_builder)

    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
    if scalar_ty is bfloat16:
        return t.to(float32, _builder=_builder)

    return t
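
The promotion above runs before the combine region is built, so narrow integer inputs accumulate at 32-bit width. A small, hypothetical illustration:

```python
@triton.jit
def promotion_demo_kernel(out_ptr):
    # Hypothetical fragment: the int8 block is widened to int32 before the
    # sum's combine region runs, so the result is 256 rather than wrapping.
    x = tl.zeros([256], dtype=tl.int8) + 1
    s = tl.sum(x, axis=0)
    tl.store(out_ptr, s)
```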


@builtin
def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
    axis = _constexpr_to_value(axis)
    n = input.shape[axis]
    index = arange(0, n, _builder=_builder)

    if len(input.shape) > 1:
        new_shape = [constexpr(1)] * len(input.shape)
        new_shape[axis] = constexpr(n)
        index = view(index, new_shape, _builder=_builder)
        index = broadcast_to(index, input.shape, _builder=_builder)

    rvalue, rindices = reduction((input, index), axis, combine_fn,
                                 _builder=_builder, _generator=_generator)
    return rindices
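
`_argreduce` expresses arg-reductions as a simultaneous reduction over a (value, index) pair, matching the `argmin` lowering shown in the commit message. A plain-Python reference of the same semantics (illustrative only, not Triton code):

```python
from functools import reduce

def argmax_reference(values):
    # Pairwise combine over (value, index), mirroring _argmax_combine:
    # keep the larger value; on ties, keep the smaller index.
    def combine(a, b):
        (va, ia), (vb, ib) = a, b
        if va > vb:
            return va, ia
        if va < vb:
            return vb, ib
        return va, min(ia, ib)
    return reduce(combine, zip(values, range(len(values))))[1]

assert argmax_reference([3.0, 7.0, 7.0, 1.0]) == 1
```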


@triton.jit
def _max_combine(a, b):
    return maximum(a, b)


@triton.jit
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
    axis = _constexpr_to_value(axis)
    return semantic.max(input, axis, _builder)
def max(input, axis):
    input = _promote_reduction_input(input)
    return reduction(input, axis, _max_combine)


@builtin
@triton.jit
def _argmax_combine(value1, index1, value2, index2):
    gt = value1 > value2
    lt = value1 < value2
    index_min = minimum(index1, index2)
    index_ret = where(gt, index1, where(lt, index2, index_min))
    value_ret = maximum(value1, value2)
    return value_ret, index_ret


@triton.jit
@_add_reduction_docstr("maximum index")
def argmax(input, axis, _builder=None):
    axis = _constexpr_to_value(axis)
    return semantic.argmax(input, axis, _builder)
def argmax(input, axis):
    input = _promote_reduction_input(input)
    return _argreduce(input, axis, _argmax_combine)


@builtin
@triton.jit
def _min_combine(a, b):
    # TODO: minimum/maximum doesn't get lowered to fmin/fmax...
    return minimum(a, b)


@triton.jit
@_add_reduction_docstr("minimum")
def min(input, axis, _builder=None):
    axis = _constexpr_to_value(axis)
    return semantic.min(input, axis, _builder)
def min(input, axis):
    input = _promote_reduction_input(input)
    return reduction(input, axis, _min_combine)


@builtin
@triton.jit
def _argmin_combine(value1, index1, value2, index2):
    lt = value1 < value2
    gt = value1 > value2
    index_min = minimum(index1, index2)
    index_ret = where(lt, index1, where(gt, index2, index_min))
    value_ret = minimum(value1, value2)
    return value_ret, index_ret


@triton.jit
@_add_reduction_docstr("minimum index")
def argmin(input, axis, _builder=None):
    axis = _constexpr_to_value(axis)
    return semantic.argmin(input, axis, _builder)
def argmin(input, axis):
    input = _promote_reduction_input(input)
    return _argreduce(input, axis, _argmin_combine)


@builtin
@triton.jit
def _sum_combine(a, b):
    return a + b


@triton.jit
@_add_reduction_docstr("sum")
def sum(input, axis, _builder=None):
    axis = _constexpr_to_value(axis)
    return semantic.sum(input, axis, _builder)
def sum(input, axis):
    input = _promote_reduction_input(input)
    return reduction(input, axis, _sum_combine)


@triton.jit
def _xor_combine(a, b):
    return a ^ b


@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None):
    axis = _constexpr_to_value(axis)
    return semantic.xor_sum(input, axis, _builder)
def xor_sum(input, axis, _builder=None, _generator=None):
    scalar_ty = input.type.scalar
    if not scalar_ty.is_int():
        raise ValueError("xor_sum only supported for integers")

    input = _promote_reduction_input(input, _builder=_builder)
    return reduction(input, axis, _xor_combine,
                     _builder=_builder, _generator=_generator)
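
`xor_sum` keeps its `@builtin` form and explicit `_builder`/`_generator` arguments, presumably so the integer-only check can run eagerly before delegating to `reduction`. A minimal usage sketch (kernel name and sizes hypothetical):

```python
@triton.jit
def xor_checksum_kernel(x_ptr, out_ptr, N: tl.constexpr):
    # XOR-reduce a block of integers, e.g. as a cheap checksum.
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    tl.store(out_ptr, tl.xor_sum(x, 0))
```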


# -----------------------

@@ -1,7 +1,7 @@
from __future__ import annotations  # remove after python 3.11

from functools import wraps
from typing import List, Optional, Tuple, TypeVar
from typing import List, Optional, Sequence, Tuple, TypeVar

from . import core as tl
from triton._C.libtriton.triton import ir
@@ -1228,91 +1228,36 @@ def where(condition: tl.tensor,
    return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)

# ===----------------------------------------------------------------------===//
# Reductions
# Reduction
# ===----------------------------------------------------------------------===


def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
                FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
    scalar_ty = input.type.scalar
    out_scalar_ty = scalar_ty
    # input is extended to 32-bits if necessary
    # this increases numerical accuracy and can be done pretty much for free
    # on GPUs
    if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
        input = cast(input, tl.int32, builder)
        out_scalar_ty = tl.int32
def reduction(
    inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
    # get result shape
    shape = inputs[0].type.shape
    print(shape, axis)
    ret_shape = [s for i, s in enumerate(shape) if i != axis]
    for t in inputs:
        assert t.type.shape == shape

    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
    if scalar_ty is tl.bfloat16:
        input = cast(input, tl.float32, builder)
        out_scalar_ty = tl.float32
    def wrap_tensor(x, scalar_ty):
        if ret_shape:
            res_ty = tl.block_type(scalar_ty, ret_shape)
        else:
            # 0d-tensor -> scalar
            res_ty = scalar_ty
        return tl.tensor(x, res_ty)

    # choose the right unsigned operation
    if scalar_ty.is_int_unsigned():
        int_op_to_unit = {
            ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
            ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
            ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
            ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
        }
        if INT_OP in int_op_to_unit:
            INT_OP = int_op_to_unit[INT_OP]
    reduce_op = builder.create_reduce([t.handle for t in inputs], axis)
    region_builder_fn(reduce_op)
    reduce_op.verify()

    # If we are doing an argmin or argmax we want to use an int32 output type
    if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
        out_scalar_ty = tl.int32
    elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
        out_scalar_ty = tl.int32

    # get result type
    shape = input.type.shape

    rank = len(shape)
    assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})"

    ret_shape = []
    for i, s in enumerate(shape):
        if i != axis:
            ret_shape.append(s)
    if ret_shape:
        res_ty = tl.block_type(out_scalar_ty, ret_shape)
    else:
        # 0d-tensor -> scalar
        res_ty = out_scalar_ty

    if scalar_ty.is_floating():
        return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
    elif scalar_ty.is_int():
        return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty)
    assert False


def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)


def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)


def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)


def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)


def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)


def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    scalar_ty = input.type.scalar
    if not scalar_ty.is_int():
        raise ValueError("xor_sum only supported for integers")
    return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
    return tuple(
        wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
        for i in range(len(inputs))
    )


# ===----------------------------------------------------------------------===

@@ -217,7 +217,11 @@ tt.func @alloc(%A : !tt.ptr<f16>) {
tt.func @scratch() {
    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
    // CHECK: scratch offset = 0, size = 512
    %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
    %b = "tt.reduce" (%cst0) ({
    ^bb0(%arg0: f16, %arg1: f16):
      %add = arith.addf %arg0, %arg1 : f16
      tt.reduce.return %add : f16
    }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
    tt.return
    // CHECK-NEXT: size = 512
}

@@ -79,7 +79,11 @@ tt.func @scratch() {
    // CHECK: gpu.barrier
    // CHECK-NEXT: triton_gpu.convert_layout
    %1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
    %2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
    %2 = "tt.reduce" (%1) ({
    ^bb0(%arg1: f16, %arg2: f16):
      %add = arith.addf %arg1, %arg2 : f16
      tt.reduce.return %add : f16
    }) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
    tt.return
}

@@ -79,18 +79,42 @@ tt.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
    // Test if reduce ops infer types correctly

    // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
    %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
    // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
    %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
    // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
    %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
    // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
    %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
    // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
    %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
    // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
    %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32
    // CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
    %a = "tt.reduce" (%v) ({
    ^bb0(%arg0: f32, %arg1: f32):
      %add = arith.addf %arg0, %arg1 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
    // CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
    %b = "tt.reduce" (%v) ({
    ^bb0(%arg0: f32, %arg1: f32):
      %add = arith.addf %arg0, %arg1 : f32
      tt.reduce.return %add : f32
    }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
    // CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
    %c = "tt.reduce" (%v) ({
    ^bb0(%arg0: f32, %arg1: f32):
      %add = arith.addf %arg0, %arg1 : f32
      tt.reduce.return %add : f32
    }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
    // CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
    %e = "tt.reduce" (%b) ({
    ^bb0(%arg0: f32, %arg1: f32):
      %add = arith.addf %arg0, %arg1 : f32
      tt.reduce.return %add : f32
    }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
    // CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
    %f = "tt.reduce" (%a) ({
    ^bb0(%arg0: f32, %arg1: f32):
      %add = arith.addf %arg0, %arg1 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
    // CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32
    %g = "tt.reduce" (%f) ({
    ^bb0(%arg0: f32, %arg1: f32):
      %add = arith.addf %arg0, %arg1 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<4xf32>) -> f32

    // Avoid optimizations for c, e, and g
    %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>

@@ -40,14 +40,30 @@ tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
    %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
    %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
    %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
    // CHECK: tensor<4x4xf32, #[[blocked0]]> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
    %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
    // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
    %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
    // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
    %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
    // CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
    %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
    // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
    %c0_ = "tt.reduce" (%c0) ({
    ^bb0(%arg1: f32, %arg2: f32):
      %add = arith.addf %arg1, %arg2 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32>
    // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
    %c1_ = "tt.reduce" (%c1) ({
    ^bb0(%arg3: f32, %arg4: f32):
      %add = arith.addf %arg3, %arg4 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32>
    // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
    %c2_ = "tt.reduce" (%c1) ({
    ^bb0(%arg5: f32, %arg6: f32):
      %add = arith.addf %arg5, %arg6 : f32
      tt.reduce.return %add : f32
    }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32>
    // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
    %c3_ = "tt.reduce" (%c2) ({
    ^bb0(%arg7: f32, %arg8: f32):
      %add = arith.addf %arg7, %arg8 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<16x16xf32>) -> tensor<16xf32>

    tt.return
}

@@ -787,7 +787,11 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
    %27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2>
    %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2>
    %29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
    %30 = tt.reduce %29 {axis = 1 : i32, redOp = 12 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %30 = "tt.reduce" (%29) ({
    ^bb0(%arg4: f32, %arg5: f32):
      %max = arith.maxf %arg4, %arg5 : f32
      tt.reduce.return %max : f32
    }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
    %32 = triton_gpu.convert_layout %31 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -803,7 +807,11 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
    %43 = math.exp %42 : tensor<16x16xf32, #blocked2>
    %44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2>
    %45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
    %46 = tt.reduce %45 {axis = 1 : i32, redOp = 2 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %46 = "tt.reduce" (%45) ({
    ^bb0(%arg4: f32, %arg5: f32):
      %add = arith.addf %arg4, %arg5 : f32
      tt.reduce.return %add : f32
    }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %47 = triton_gpu.convert_layout %46 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
    %48 = triton_gpu.convert_layout %47 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %49 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -907,7 +915,11 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
      %74 = "triton_gpu.select"(%54, %73, %arg7) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2>
      scf.yield %74 : tensor<64x64xf32, #blocked2>
    }
    %26 = tt.reduce %25 {axis = 1 : i32, redOp = 2 : i32} : tensor<64x64xf32, #blocked2> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %26 = "tt.reduce" (%25) ({
    ^bb0(%arg8: f32, %arg9: f32):
      %add = arith.addf %arg8, %arg9 : f32
      tt.reduce.return %add : f32
    }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %27 = triton_gpu.convert_layout %26 : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xf32, #blocked0>
    %28 = triton_gpu.convert_layout %27 : (tensor<64xf32, #blocked0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %29 = tt.expand_dims %28 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xf32, #blocked1>
@@ -1016,7 +1028,11 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
    %1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked>
    %3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked>
    %4 = tt.reduce %cst {axis = 1 : i32, redOp = 1 : i32} : tensor<1x2xi32, #blocked> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %4 = "tt.reduce" (%cst) ({
    ^bb0(%arg3: i32, %arg4: i32):
      %add = arith.addi %arg3, %arg4 : i32
      tt.reduce.return %add : i32
    }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1>
    %6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
@@ -1037,7 +1053,8 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {

// Check if the SimplifyReduceCvt handles convert_layout lifted from the for loop.
// CHECK-LABEL: reduce_cvt2
// CHECK: tt.reduce
// Match the reduction
// CHECK: }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
@@ -1092,7 +1109,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
      %59 = "triton_gpu.select"(%52, %58, %arg6) : (tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked>, tensor<1x256xf32, #blocked>) -> tensor<1x256xf32, #blocked>
      scf.yield %59 : tensor<1x256xf32, #blocked>
    }
    %16 = tt.reduce %15 {axis = 1 : i32, redOp = 2 : i32} : tensor<1x256xf32, #blocked> -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %16 = "tt.reduce" (%15) ({
    ^bb0(%arg7: f32, %arg8: f32):
      %add = arith.addf %arg7, %arg8 : f32
      tt.reduce.return %add : f32
    }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %17 = triton_gpu.convert_layout %16 : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xf32, #blocked1>
    %18 = triton_gpu.convert_layout %17 : (tensor<1xf32, #blocked1>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %19 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xf32, #blocked2>