[FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)

Fixes #1285

This changes `tt.reduce` to replace the `redOp` attribute with a region containing arbitrary combine code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
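
For reference, a hypothetical kernel along these lines (the name, shapes, and pointer arithmetic are illustrative, not taken from the test suite) produces the `tt.reduce` shown above when summing a 128x128 `f32` block over axis 1:
```python
import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(x_ptr, out_ptr):
    offs_m = tl.arange(0, 128)
    offs_n = tl.arange(0, 128)
    # load a 128x128 tile (row-major, no masking for simplicity)
    x = tl.load(x_ptr + offs_m[:, None] * 128 + offs_n[None, :])
    # lowers to "tt.reduce" with a single-block combine region containing arith.addf
    s = tl.sum(x, axis=1)
    tl.store(out_ptr + offs_m, s)
```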
Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. For example, `argmin` is now lowered as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
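
Since `tt.reduce` now takes a variadic operand list plus a user-defined combine region, custom reductions can be expressed from the frontend with the same machinery. A minimal sketch, assuming the new `reduction` builtin introduced later in this diff is reachable as `tl.reduction`; the min-and-max-in-one-pass combine is purely illustrative:
```python
import triton
import triton.language as tl

@triton.jit
def _minmax_combine(lo1, hi1, lo2, hi2):
    # combine two (running min, running max) pairs elementwise
    return tl.minimum(lo1, lo2), tl.maximum(hi1, hi2)

@triton.jit
def minmax_kernel(x_ptr, min_ptr, max_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # one simultaneous reduction over two tensors, the same mechanism the new
    # argmin/argmax lowering uses for its (value, index) pair
    lo, hi = tl.reduction((x, x), 0, _minmax_combine)
    tl.store(min_ptr, lo)
    tl.store(max_ptr, hi)
```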
Commit e152183570 (parent 5b9119117b), authored by peterbell10 on 2023-04-13 01:37:39 +00:00 and committed by GitHub. 23 changed files with 822 additions and 606 deletions.

View File

@@ -12,13 +12,26 @@ namespace mlir {
class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
srcTy = op.getOperand().getType().cast<RankedTensorType>();
explicit ReduceOpHelper(triton::ReduceOp rop)
: op(rop.getOperation()), axis(rop.getAxis()) {
auto firstTy = rop.getOperands()[0].getType().cast<RankedTensorType>();
srcShape = firstTy.getShape();
srcEncoding = firstTy.getEncoding();
srcElementTypes = rop.getElementTypes();
for (const auto &t : rop.getInputTypes()) {
if (t.getShape() != srcShape) {
rop.emitError() << "shape mismatch";
}
if (t.getEncoding() != srcEncoding) {
rop.emitError() << "encoding mismatch";
}
}
}
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
ArrayRef<int64_t> getSrcShape() { return srcShape; }
Attribute getSrcLayout() { return srcTy.getEncoding(); }
Attribute getSrcLayout() { return srcEncoding; }
bool isFastReduction();
@@ -37,8 +50,11 @@ public:
bool isSupportedLayout();
private:
triton::ReduceOp op;
RankedTensorType srcTy{};
Operation *op;
ArrayRef<int64_t> srcShape;
Attribute srcEncoding;
SmallVector<Type> srcElementTypes;
int axis;
};
bool isSharedEncoding(Value value);

View File

@@ -34,30 +34,6 @@ def TT_PaddingOptionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}
// reduction
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"FADD", 2, "fadd">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"MAX", 4, "max">,
I32EnumAttrCase<"UMIN", 5, "umin">,
I32EnumAttrCase<"UMAX", 6, "umax">,
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">,
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
I32EnumAttrCase<"XOR", 15, "xor">
]> {
let cppNamespace = "::mlir::triton";
}
// atomic
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",

View File

@@ -388,27 +388,35 @@ def TT_DotOp : TT_Op<"dot", [Pure,
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Type:$result);
def TT_ReduceOp: TT_Op<"reduce",
[Pure,
SameOperandsEncoding,
SingleBlock,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Reduction using generic combination algorithm";
let arguments = (ins Variadic<TT_Tensor>:$operands, I32Attr:$axis);
let results = (outs Variadic<TT_Type>:$result);
let regions = (region SizedRegion<1>:$combineOp);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
OpBuilder<(ins "ValueRange":$operands, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
let hasVerifier = 1;
let hasRegionVerifier = 1;
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
unsigned getNumOperands();
}];
}
def TT_ReduceReturnOp: TT_Op<"reduce.return",
[HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> {
let summary = "terminator for reduce operator";
let arguments = (ins Variadic<AnyType>:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}
//
// External Elementwise op
//

View File

@@ -103,7 +103,7 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
TT_Tensor:$true_value,
TT_Tensor:$false_value);
let results = (outs TT_Tensor:$result);
let results = (outs TT_Type:$result);
}

View File

@@ -74,7 +74,10 @@ void MembarAnalysis::visitTerminator(Operation *op,
return;
}
// Otherwise, it could be a return op
assert(isa<triton::ReturnOp>(op) && "Unknown terminator");
if (isa<triton::ReduceReturnOp>(op) || isa<triton::ReturnOp>(op)) {
return;
}
llvm_unreachable("Unknown terminator encountered in membar analysis");
}
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,

View File

@@ -10,49 +10,38 @@
namespace mlir {
bool ReduceOpHelper::isFastReduction() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.getAxis();
return axis == triton::gpu::getOrder(srcLayout)[0];
return axis == triton::gpu::getOrder(getSrcLayout())[0];
}
unsigned ReduceOpHelper::getInterWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.getAxis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
}
unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.getAxis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}
unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.getAxis();
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.getAxis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.getAxis();
SmallVector<SmallVector<unsigned>> smemShapes(3);
auto argLayout = srcTy.getEncoding();
auto argLayout = getSrcLayout();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
@@ -64,7 +53,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
@@ -82,17 +71,15 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
elems = product<unsigned>(smemShape);
}
auto tensorType = op.getOperand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
if (triton::ReduceOp::withIndex(op.getRedOp()))
bytes += elems * sizeof(int32_t);
return bytes;
unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
}
return bytesPerElem * elems;
}
bool ReduceOpHelper::isSupportedLayout() {
auto srcLayout = srcTy.getEncoding();
auto srcLayout = getSrcLayout();
if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
return true;
}

View File

@@ -1073,12 +1073,14 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \

View File

@@ -23,112 +23,59 @@ public:
}
private:
void accumulate(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value cur, bool isFirst) const {
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
llvm::SmallVectorImpl<Value> &acc, ValueRange cur,
bool isFirst) const {
if (isFirst) {
acc = cur;
acc.resize(cur.size());
for (unsigned i = 0; i < cur.size(); ++i) {
acc[i] = cur[i];
}
return;
}
switch (redOp) {
case RedOp::ADD:
acc = add(acc, cur);
break;
case RedOp::FADD:
acc = fadd(acc.getType(), acc, cur);
break;
case RedOp::MIN:
acc = smin(acc, cur);
break;
case RedOp::MAX:
acc = smax(acc, cur);
break;
case RedOp::UMIN:
acc = umin(acc, cur);
break;
case RedOp::UMAX:
acc = umax(acc, cur);
break;
case RedOp::FMIN:
acc = fmin(acc, cur);
break;
case RedOp::FMAX:
acc = fmax(acc, cur);
break;
case RedOp::XOR:
acc = xor_(acc, cur);
break;
case RedOp::ARGMIN:
case RedOp::ARGMAX:
case RedOp::ARGUMIN:
case RedOp::ARGUMAX:
case RedOp::ARGFMIN:
case RedOp::ARGFMAX:
llvm::report_fatal_error(
"This accumulate implementation is not for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
// Create a new copy of the reduce block, and inline it
Block *currentBlock = rewriter.getBlock();
Region &parent = *currentBlock->getParent();
rewriter.cloneRegionBefore(combineOp, &parent.front());
auto &newReduce = parent.front();
auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
llvm::SmallVector<Value> combineArgs(2 * acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
combineArgs[i] = acc[i];
combineArgs[acc.size() + i] = cur[i];
}
rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
combineArgs);
auto results = returnOp.getResult();
for (unsigned i = 0; i < acc.size(); ++i) {
acc[i] = results[i];
}
// Delete the terminator, which is no longer used
rewriter.eraseOp(returnOp);
}
void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value &accIndex, Value cur,
Value curIndex, bool isFirst) const {
if (isFirst) {
acc = cur;
accIndex = curIndex;
return;
}
switch (redOp) {
case RedOp::ARGMIN:
accIndex = select(
icmp_slt(acc, cur), accIndex,
select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smin(acc, cur);
break;
case RedOp::ARGMAX:
accIndex = select(
icmp_sgt(acc, cur), accIndex,
select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smax(acc, cur);
break;
case RedOp::ARGUMIN:
accIndex = select(
icmp_ult(acc, cur), accIndex,
select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umin(acc, cur);
break;
case RedOp::ARGUMAX:
accIndex = select(
icmp_ugt(acc, cur), accIndex,
select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umax(acc, cur);
break;
case RedOp::ARGFMIN:
accIndex = select(
fcmp_olt(acc, cur), accIndex,
select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmin(acc, cur);
break;
case RedOp::ARGFMAX:
accIndex = select(
fcmp_ogt(acc, cur), accIndex,
select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmax(acc, cur);
break;
case RedOp::ADD:
case RedOp::FADD:
case RedOp::MIN:
case RedOp::MAX:
case RedOp::UMIN:
case RedOp::UMAX:
case RedOp::FMIN:
case RedOp::FMAX:
case RedOp::XOR:
llvm::report_fatal_error(
"This accumulate implementation is only for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
SmallVector<SmallVector<Value>>
unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto types = op.getInputTypes();
auto operands = adaptor.getOperands();
unsigned srcElems = getElemsPerThread(types[0]);
SmallVector<SmallVector<Value>> srcValues(srcElems);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto values = getTypeConverter()->unpackLLElements(loc, operands[i],
rewriter, types[i]);
assert(values.size() == srcValues.size());
for (unsigned j = 0; j < srcValues.size(); ++j) {
srcValues[j].push_back(values[j]);
}
}
return srcValues;
}
// Calculates the write index in the shared memory where we would be writing
@@ -177,63 +124,64 @@ private:
matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ReduceOpHelper helper(op);
Location loc = op->getLoc();
Location loc = op.getLoc();
unsigned axis = op.getAxis();
// Specifies whether the reduce operation returns an index
// rather than a value, e.g. argmax, argmin, .. etc
bool withIndex = triton::ReduceOp::withIndex(op.getRedOp());
auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcTys = op.getInputTypes();
auto srcLayout = helper.getSrcLayout();
if (!helper.isSupportedLayout()) {
assert(false && "Unexpected srcLayout in ReduceOpConversion");
}
// The order of the axes for the the threads within the warp
auto srcOrd = triton::gpu::getOrder(srcLayout);
auto sizePerThread = triton::gpu::getSizePerThread(srcLayout);
auto srcShape = srcTy.getShape();
auto srcShape = helper.getSrcShape();
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
SmallVector<Type> elemPtrTys(srcTys.size());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto ty = srcTys[i].getElementType();
auto llvmElemTy = getTypeConverter()->convertType(ty);
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
auto smemShape = helper.getScratchConfigBasic();
unsigned elems = product<unsigned>(smemShape);
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned srcElems = getElemsPerThread(srcTy);
SmallVector<Value> smemBases(op.getNumOperands());
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] =
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)),
elemPtrTys[i]);
}
unsigned srcElems = getElemsPerThread(srcTys[0]);
// Emits indices of the original tensor that each thread
// would own
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
auto srcValues = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperand(), rewriter, srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
// Emits offsets (the offset from the base index)
// of the original tensor that each thread would own
// NOTE: Assumes offsets don't actually depend on type
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTy);
emitOffsetForLayout(srcLayout, srcTys[0]);
// Keep track of accumulations and their indices
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
Region *combineOp = &op.getCombineOp();
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
if (!withIndex) {
accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i],
isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key],
accIndices[key], srcValues[i], curIndex, isFirst);
}
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
if (isFirst)
indices[key] = srcIndices[i];
}
@@ -250,24 +198,20 @@ private:
// reduce across threads
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
auto &acc = it.second;
// get the writeIdx at which to write in smem
SmallVector<Value> writeIdx;
getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints,
axis);
// calculate the offset in smem for that writeIdx
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
// Get element pointers for the value and index
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
// Store the within-thread accumulated value at writePtr
store(acc, writePtr);
// Store the index of within-thread accumulation at indexWritePtr
if (withIndex)
store(accIndex, indexWritePtr);
SmallVector<Value> writePtrs(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
// Store the within-thread accumulated value into shared memory
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
store(acc[i], writePtrs[i]);
}
SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
// Perform parallel reduction with sequential addressing
@@ -286,27 +230,24 @@ private:
Value readOffset = select(
readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
ints[0]);
// The readPtr is readOffset away from writePtr
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
SmallVector<Value> readPtrs(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
// The readPtr is readOffset away from writePtr
readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
}
barrier();
// If we do not care about the index, i.e. this is not an argmax,
// argmin, .. etc
if (!withIndex) {
// The value at the readPtr, whereas acc is the value at writePtr
Value cur = load(readPtr);
accumulate(rewriter, loc, op.getRedOp(), acc, cur, false);
barrier();
// Update writePtr value
store(acc, writePtr);
} else {
Value cur = load(readPtr);
Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
Value curIndex = load(indexReadPtr);
accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, cur,
curIndex, false);
barrier();
store(acc, writePtr);
store(accIndex, indexWritePtr);
// Combine accumulator value from another thread
SmallVector<Value> cur(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
cur[i] = load(readPtrs[i]);
}
accumulate(rewriter, *combineOp, acc, cur, false);
barrier();
// Publish our new accumulator value to shared memory
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
store(acc[i], writePtrs[i]);
}
}
}
@@ -314,33 +255,37 @@ private:
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding();
auto resultShape = resultTy.getShape();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
auto resultLayout = resultTy.getEncoding();
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
for (unsigned j = 0; j < resultElems; ++j) {
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset =
linearize(rewriter, loc, readIdx, smemShape, srcOrd);
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
resultVals[j] = load(readPtr);
}
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
} else {
// 0d-tensor -> scalar
results[i] = load(smemBases[i]);
}
Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
resultTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
auto parentBlock = op.getOperation()->getBlock();
rewriter.replaceOp(op, results);
return success();
}
@@ -351,60 +296,59 @@ private:
ReduceOpHelper helper(op);
Location loc = op->getLoc();
unsigned axis = adaptor.getAxis();
bool withIndex = triton::ReduceOp::withIndex(op.getRedOp());
auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcTys = op.getInputTypes();
auto srcLayout = helper.getSrcLayout();
if (!helper.isSupportedLayout()) {
assert(false && "Unexpected srcLayout in ReduceOpConversion");
}
auto srcShape = srcTy.getShape();
auto order = getOrder(srcLayout);
auto srcOrd = triton::gpu::getOrder(srcLayout);
auto srcShape = helper.getSrcShape();
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
SmallVector<Type> elemPtrTys(srcTys.size());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto ty = srcTys[i].getElementType();
auto llvmElemTy = getTypeConverter()->convertType(ty);
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
SmallVector<Value> smemBases(op.getNumOperands());
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] =
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
elemPtrTys[i]);
}
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
auto srcValues = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperand(), rewriter, srcTy);
unsigned srcElems = getElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTy);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// Assumes offsets don't actually depend on type
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTys[0]);
auto *combineOp = &op.getCombineOp();
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
if (!withIndex) {
accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i],
isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key],
accIndices[key], srcValues[i], curIndex, isFirst);
}
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
if (isFirst)
indices[key] = srcIndices[i];
}
@@ -414,6 +358,9 @@ private:
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
auto order = getOrder(srcLayout);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimWarpId =
@@ -427,32 +374,24 @@ private:
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
SmallVector<Value> acc = it.second;
// Reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) {
accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl,
shflIndex, false);
SmallVector<Value> shfl(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
if (withIndex) {
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset);
storeShared(rewriter, loc, writePtr, acc[i], laneZero);
}
}
@@ -469,39 +408,36 @@ private:
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
// FIXME(Qingyi): need predicate icmp_slt(threadId,
// i32_val(sizeInerWarps))
Value acc = load(readPtr);
Value accIndex;
if (withIndex) {
Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
accIndex = load(readIndexPtr);
SmallVector<Value> acc(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
acc[i] = load(readPtr);
}
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) {
accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl,
shflIndex, false);
SmallVector<Value> shfl(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
SmallVector<Value> writePtrs(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
}
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
Value laneIdModSizeInterWarpsIsZero =
icmp_eq(laneIdModSizeInterWarps, zero);
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
storeShared(rewriter, loc, writePtr, acc, pred);
if (withIndex) {
Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
storeShared(rewriter, loc, writePtrs[i], acc[i], pred);
}
if (round != elemsPerThread - 1) {
@@ -515,32 +451,34 @@ private:
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
SmallVector<Value> resultVals(resultElems);
for (size_t j = 0; j < resultElems; ++j) {
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
resultVals[j] = load(readPtr);
}
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
} else {
// 0d-tensor -> scalar
results[i] = load(smemBases[i]);
}
Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
resultTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
rewriter.replaceOp(op, results);
return success();
}

View File

@@ -46,15 +46,30 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
Value TritonGPUToLLVMTypeConverter::packLLElements(
Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter,
Type type) {
auto structType = this->convertType(type);
if (!structType.isa<LLVM::LLVMStructType>()) {
auto structType = this->convertType(type).dyn_cast<LLVM::LLVMStructType>();
if (!structType) {
assert(resultVals.size() == 1);
return *resultVals.begin();
}
auto elementTypes = structType.getBody();
if (elementTypes.size() != resultVals.size()) {
emitError(loc) << " size mismatch when packing elements for LLVM struct"
<< " expected " << elementTypes.size() << " but got "
<< resultVals.size();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
// llvm::outs() << structType << "\n";
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
if (!v.value()) {
emitError(loc)
<< "cannot insert null values into struct, but tried to insert"
<< v.value();
}
if (v.value().getType() != elementTypes[v.index()]) {
emitError(loc) << "invalid element type in packLLEElements. Expected "
<< elementTypes[v.index()] << " but got "
<< v.value().getType();
}
llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
}
return llvmStruct;

View File

@@ -68,13 +68,15 @@ public:
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
assert(value);
if (value.getElementType().isInteger(1) && value.isSplat())
// Workaround until https://reviews.llvm.org/D133743 is included.
value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
else
// This is a hack. We just want to add encoding
value = value.reshape(retType);
if (dyn_cast<RankedTensorType>(retType)) {
assert(value);
if (value.getElementType().isInteger(1) && value.isSplat())
// Workaround until https://reviews.llvm.org/D133743 is included.
value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
else
// This is a hack. We just want to add encoding
value = value.reshape(retType);
}
addNamedAttrs(
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
adaptor.getAttributes());
@@ -469,10 +471,28 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
addNamedAttrs(
rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis()),
adaptor.getAttributes());
auto newReduce = rewriter.create<triton::ReduceOp>(
op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
addNamedAttrs(newReduce, adaptor.getAttributes());
auto &newCombineOp = newReduce.getCombineOp();
rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp,
newCombineOp.end());
rewriter.replaceOp(op, newReduce.getResult());
return success();
}
};
struct TritonReduceReturnPattern
: public OpConversionPattern<triton::ReduceReturnOp> {
using OpConversionPattern<triton::ReduceReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ReduceReturnOp>(
op, adaptor.getResult()),
adaptor.getAttributes());
return success();
}
};
@@ -517,10 +537,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern,
TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
TritonReducePattern, TritonReduceReturnPattern, TritonTransPattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(
typeConverter, context);
}
//

View File

@@ -310,21 +310,10 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
}
//-- ReduceOp --
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto argEltTy = argTy.getElementType();
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
auto redOp =
attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
auto retEltTy = withIndex ? i32Ty : argEltTy;
static mlir::LogicalResult
inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy,
int axis, SmallVectorImpl<Type> &inferredReturnTypes) {
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
// 0d-tensor -> scalar
@@ -352,15 +341,114 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
return mlir::success();
}
bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
return redOp == mlir::triton::RedOp::ARGMIN ||
redOp == mlir::triton::RedOp::ARGMAX ||
redOp == mlir::triton::RedOp::ARGUMIN ||
redOp == mlir::triton::RedOp::ARGUMAX ||
redOp == mlir::triton::RedOp::ARGFMIN ||
redOp == mlir::triton::RedOp::ARGFMAX;
void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
mlir::ValueRange operands, int axis) {
SmallVector<Type> inferredReturnTypes;
for (unsigned i = 0; i < operands.size(); ++i) {
auto argTy = operands[i].getType().cast<RankedTensorType>();
auto retEltTy = argTy.getElementType();
(void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
}
ReduceOp::build(builder, state, inferredReturnTypes, operands, axis);
}
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
for (auto arg : operands) {
auto argTy = arg.getType().cast<RankedTensorType>();
auto retEltTy = argTy.getElementType();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
.failed()) {
return failure();
}
}
return success();
}
mlir::LogicalResult mlir::triton::ReduceOp::verify() {
if (this->getOperands().size() < 1) {
return this->emitOpError() << "must have at least 1 operand";
}
for (const auto &operand : this->getOperands()) {
if (!dyn_cast<RankedTensorType>(operand.getType())) {
return this->emitOpError() << "operands must be RankedTensorType";
}
}
return success();
}
mlir::LogicalResult mlir::triton::ReduceOp::verifyRegions() {
auto argElementTypes = this->getElementTypes();
const auto &operands = this->getOperands();
const auto numArgs = 2 * operands.size();
auto &block = *this->getBody();
if (block.getNumArguments() != numArgs) {
return this->emitOpError() << "nested block must take " << numArgs
<< " arguments, but given block with "
<< block.getNumArguments() << " arguments";
}
unsigned i = 0;
const auto &blockArgTypes = block.getArgumentTypes();
for (unsigned i = 0; i < numArgs; ++i) {
const auto &blockArgTy = blockArgTypes[i];
const auto &argElemTy = argElementTypes[i % operands.size()];
if (blockArgTy != argElemTy) {
return this->emitOpError()
<< "type mismatch on combine operation. Expected argument " << i
<< " to have type " << argElemTy << " but got " << blockArgTy;
}
}
auto terminator =
dyn_cast<mlir::triton::ReduceReturnOp>(block.getTerminator());
if (!terminator) {
return this->emitOpError()
<< "combine operation must be terminated "
<< "with a ReduceReturnOp but got " << block.getTerminator();
}
const auto &combineResults = terminator->getOperands();
if (combineResults.size() != operands.size()) {
return this->emitOpError()
<< "expected combine operation to return " << operands.size()
<< " values but got " << combineResults.size();
}
for (unsigned i = 0; i < combineResults.size(); ++i) {
const auto &resultTy = combineResults[i].getType();
const auto &argElemTy = argElementTypes[i];
if (resultTy != argElemTy) {
return this->emitOpError()
<< "type mismatch on combine operation. Expected argument " << i
<< " to have type " << argElemTy << " but got " << resultTy;
}
}
return mlir::success();
}
llvm::SmallVector<mlir::RankedTensorType> ReduceOp::getInputTypes() {
llvm::SmallVector<RankedTensorType> srcTys;
srcTys.reserve(this->getNumOperands());
for (const auto &ty : this->getOperands().getTypes()) {
srcTys.push_back(ty.cast<RankedTensorType>());
}
return srcTys;
}
llvm::SmallVector<Type> ReduceOp::getElementTypes() {
llvm::SmallVector<Type> srcElemTys;
srcElemTys.reserve(this->getNumOperands());
for (const auto &op : this->getOperands()) {
srcElemTys.push_back(
op.getType().cast<RankedTensorType>().getElementType());
}
return srcElemTys;
}
unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }
//-- SplatOp --
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto value = adaptor.getSrc();

View File

@@ -101,29 +101,59 @@ public:
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
triton::ReduceOp reduce;
for (auto &use : convert.getResult().getUses()) {
auto owner = use.getOwner();
if (llvm::isa_and_nonnull<triton::ReduceOp>(owner)) {
reduce = llvm::cast<triton::ReduceOp>(owner);
break;
auto owner = llvm::dyn_cast<triton::ReduceOp>(use.getOwner());
if (!owner) {
continue;
}
// TODO: This only moves conversions from the first argument which is
// fine for argmin/argmax but may not be optimal generally
if (convert.getResult() != owner.getOperands()[0]) {
continue;
}
reduce = owner;
break;
}
if (!reduce)
return mlir::failure();
SmallVector<Value> newOperands = reduce.getOperands();
newOperands[0] = convert.getOperand();
auto newEncoding =
newOperands[0].getType().cast<RankedTensorType>().getEncoding();
// this may generate unsupported conversions in the LLVM codegen
if (convert.getOperand()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
if (newEncoding.isa<triton::gpu::MmaEncodingAttr>()) {
return failure();
}
for (unsigned i = 1; i < newOperands.size(); ++i) {
auto oldTy = newOperands[i].getType().cast<RankedTensorType>();
RankedTensorType newTy =
RankedTensorType::Builder(oldTy).setEncoding(newEncoding);
newOperands[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newTy, newOperands[i]);
}
auto newReduce = rewriter.create<triton::ReduceOp>(
op->getLoc(), reduce.getRedOp(), convert.getOperand(),
reduce.getAxis());
Value newRet = newReduce.getResult();
if (newRet.getType() != reduce.getResult().getType())
newRet = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), reduce.getResult().getType(), newRet);
rewriter.replaceAllUsesWith(reduce, newRet);
op->getLoc(), newOperands, reduce.getAxis());
auto &newCombineOp = newReduce.getCombineOp();
rewriter.inlineRegionBefore(reduce.getCombineOp(), newCombineOp,
newCombineOp.end());
SmallVector<Value> newRet = newReduce.getResult();
auto oldTypes = reduce.getResult().getType();
for (unsigned i = 0; i < reduce.getNumOperands(); ++i) {
// it's still beneficial to move the conversion
// to after the reduce if necessary since it will be
// done on a rank-reduced tensor hence cheaper
if (newRet[i].getType() != oldTypes[i])
newRet[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), oldTypes[i], newRet[i]);
}
rewriter.replaceAllUsesWith(reduce.getResult(), newRet);
return success();
}

View File

@@ -79,6 +79,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
// We have custom versions of some arith operators
addIllegalOp<arith::CmpIOp, arith::CmpFOp>();
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
triton::TritonDialect, cf::ControlFlowDialect,

View File

@@ -95,23 +95,6 @@ void init_triton_ir(py::module &&m) {
.value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST)
.export_values();
py::enum_<mlir::triton::RedOp>(m, "REDUCE_OP")
.value("ADD", mlir::triton::RedOp::ADD)
.value("FADD", mlir::triton::RedOp::FADD)
.value("MIN", mlir::triton::RedOp::MIN)
.value("MAX", mlir::triton::RedOp::MAX)
.value("UMIN", mlir::triton::RedOp::UMIN)
.value("UMAX", mlir::triton::RedOp::UMAX)
.value("ARGMIN", mlir::triton::RedOp::ARGMIN)
.value("ARGMAX", mlir::triton::RedOp::ARGMAX)
.value("ARGUMIN", mlir::triton::RedOp::ARGUMIN)
.value("ARGUMAX", mlir::triton::RedOp::ARGUMAX)
.value("FMIN", mlir::triton::RedOp::FMIN)
.value("FMAX", mlir::triton::RedOp::FMAX)
.value("ARGFMIN", mlir::triton::RedOp::ARGFMIN)
.value("ARGFMAX", mlir::triton::RedOp::ARGFMAX)
.value("XOR", mlir::triton::RedOp::XOR);
py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP")
.value("ADD", mlir::triton::RMWOp::ADD)
.value("FADD", mlir::triton::RMWOp::FADD)
@@ -1349,21 +1332,20 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::math::AbsIOp>(loc, val);
})
.def("create_reduce",
[](mlir::OpBuilder &self, mlir::Value &operand,
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
[](mlir::OpBuilder &self, std::vector<mlir::Value> operands,
int axis) -> mlir::OpState {
auto loc = self.getUnknownLoc();
auto inputTensorType =
operand.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = inputTensorType.getShape();
shape.erase(shape.begin() + axis);
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
mlir::Type resType = withIndex ? self.getI32Type()
: inputTensorType.getElementType();
if (!shape.empty()) {
resType = mlir::RankedTensorType::get(shape, resType);
return self.create<mlir::triton::ReduceOp>(loc, operands, axis);
})
.def("create_reduce_ret",
[](mlir::OpBuilder &self, py::args args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
llvm::SmallVector<mlir::Value> return_values;
for (const auto &arg : args) {
return_values.push_back(py::cast<mlir::Value>(arg));
}
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis);
return self.create<mlir::triton::ReduceReturnOp>(loc,
return_values);
})
.def("create_ptr_to_int",
[](mlir::OpBuilder &self, mlir::Value &val,

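For orientation (not part of this diff), a sketch of how these two new builder entry points are driven from Python, mirroring the `make_combine_region` helper added to `triton/language/core.py` further down; `builder`, `handles`, `param_types`, and `emit_combine` are assumptions of the sketch:
```python
def build_reduce(builder, handles, axis, param_types, emit_combine):
    # create the variadic tt.reduce with an (initially empty) combine region
    reduce_op = builder.create_reduce(handles, axis)
    region = reduce_op.get_region(0)
    ip = builder.get_insertion_point()
    # the combine block takes 2 * len(handles) scalar arguments
    block = builder.create_block_with_parent(region, param_types)
    results = emit_combine([block.arg(i) for i in range(len(param_types))])
    # tt.reduce.return terminator for the combine region
    builder.create_reduce_ret(*results)
    builder.restore_insertion_point(ip)
    return reduce_op
```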
View File

@@ -1295,10 +1295,15 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf32, #blocked>
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xf32, #blocked>) -> tensor<{M}x{N}xf32, #src>
%15 = tt.reduce %14 {{axis = {axis} : i32, redOp = 12 : i32}} : tensor<{M}x{N}xf32, #src> -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
%16 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
%17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
tt.store %12, %17 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
%15 = "tt.reduce"(%14) ({{
^bb0(%arg3: f32, %arg4: f32):
%16 = "triton_gpu.cmpf"(%arg3, %arg4) {{predicate = 2 : i64}} : (f32, f32) -> i1
%17 = arith.select %16, %arg3, %arg4 : f32
tt.reduce.return %17 : f32
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xf32, #src>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
tt.return
}}
}}

View File

@@ -1,4 +1,5 @@
import ast
import inspect
import re
import sys
import warnings
@@ -755,6 +756,43 @@ class CodeGenerator(ast.NodeVisitor):
# Convert assert to triton's device_assert which happens on the device
return language.core.device_assert(test, msg, _builder=self.builder)
def call_JitFunction(self, fn: JITFunction, args, kwargs):
args = inspect.getcallargs(fn.fn, *args, **kwargs)
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg)
else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
prototype = language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
else:
callee_ret_type = self.function_ret_types[fn_name]
symbol = self.module.get_function(fn_name)
call_op = self.builder.call(symbol, arg_vals)
if call_op.get_num_results() == 0 or callee_ret_type is None:
return None
elif call_op.get_num_results() == 1:
return tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
results = []
for i in range(call_op.get_num_results()):
results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
def visit_Call(self, node):
fn = _unwrap_if_constexpr(self.visit(node.func))
@@ -768,44 +806,13 @@ class CodeGenerator(ast.NodeVisitor):
if not self.debug:
return
if isinstance(fn, JITFunction):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg)
else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
prototype = language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
else:
callee_ret_type = self.function_ret_types[fn_name]
symbol = self.module.get_function(fn_name)
call_op = self.builder.call(symbol, arg_vals)
if call_op.get_num_results() == 0 or callee_ret_type is None:
return None
elif call_op.get_num_results() == 1:
return tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
results = []
for i in range(call_op.get_num_results()):
results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
return self.call_JitFunction(fn, args, kws)
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
return fn(*args, _builder=self.builder, **kws)
extra_kwargs = dict(_builder=self.builder)
sig = inspect.signature(fn)
if '_generator' in sig.parameters:
extra_kwargs['_generator'] = self
return fn(*args, **extra_kwargs, **kws)
if fn in self.builtin_namespace.values():
args = map(_unwrap_if_constexpr, args)
return fn(*args, **kws)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import Callable, List, TypeVar
@@ -1190,46 +1191,166 @@ def _add_reduction_docstr(name: str) -> Callable[[T], T]:
return _decorator
@contextmanager
def _insertion_guard(builder):
ip = builder.get_insertion_point()
yield
builder.restore_insertion_point(ip)
@builtin
def reduction(input, axis, combine_fn, _builder=None, _generator=None):
"""Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
:param input: the input tensor, or tuple of tensors
:param axis: the dimension along which the reduction should be done
:param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
"""
if isinstance(input, tensor):
return reduction((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
def make_combine_region(reduce_op):
in_scalar_tys = [t.type.scalar for t in input]
prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
region = reduce_op.get_region(0)
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_reduce_ret(*handles)
axis = _constexpr_to_value(axis)
return semantic.reduction(input, axis, make_combine_region, _builder)
@builtin
def _promote_reduction_input(t, _builder=None):
scalar_ty = t.type.scalar
# input is extended to 32-bits if necessary
# this increases numerical accuracy and can be done pretty much for free
# on GPUs
if scalar_ty.is_int() and scalar_ty.int_bitwidth < 32:
return t.to(int32, _builder=_builder)
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
if scalar_ty is bfloat16:
return t.to(float32, _builder=_builder)
return t
@builtin
def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
axis = _constexpr_to_value(axis)
n = input.shape[axis]
index = arange(0, n, _builder=_builder)
if len(input.shape) > 1:
new_shape = [constexpr(1)] * len(input.shape)
new_shape[axis] = constexpr(n)
index = view(index, new_shape, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)
rvalue, rindices = reduction((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
return rindices
@triton.jit
def _max_combine(a, b):
return maximum(a, b)
@triton.jit
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.max(input, axis, _builder)
def max(input, axis):
input = _promote_reduction_input(input)
return reduction(input, axis, _max_combine)
@builtin
@triton.jit
def _argmax_combine(value1, index1, value2, index2):
gt = value1 > value2
lt = value1 < value2
index_min = minimum(index1, index2)
index_ret = where(gt, index1, where(lt, index2, index_min))
value_ret = maximum(value1, value2)
return value_ret, index_ret
@triton.jit
@_add_reduction_docstr("maximum index")
def argmax(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.argmax(input, axis, _builder)
def argmax(input, axis):
input = _promote_reduction_input(input)
return _argreduce(input, axis, _argmax_combine)
@builtin
@triton.jit
def _min_combine(a, b):
# TODO: minimum/maximum doesn't get lowered to fmin/fmax...
return minimum(a, b)
@triton.jit
@_add_reduction_docstr("minimum")
def min(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.min(input, axis, _builder)
def min(input, axis):
input = _promote_reduction_input(input)
return reduction(input, axis, _min_combine)
@builtin
@triton.jit
def _argmin_combine(value1, index1, value2, index2):
lt = value1 < value2
gt = value1 > value2
index_min = minimum(index1, index2)
index_ret = where(lt, index1, where(gt, index2, index_min))
value_ret = minimum(value1, value2)
return value_ret, index_ret
@triton.jit
@_add_reduction_docstr("minimum index")
def argmin(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.argmin(input, axis, _builder)
def argmin(input, axis):
input = _promote_reduction_input(input)
return _argreduce(input, axis, _argmin_combine)
@builtin
@triton.jit
def _sum_combine(a, b):
return a + b
@triton.jit
@_add_reduction_docstr("sum")
def sum(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.sum(input, axis, _builder)
def sum(input, axis):
input = _promote_reduction_input(input)
return reduction(input, axis, _sum_combine)
@triton.jit
def _xor_combine(a, b):
return a ^ b
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.xor_sum(input, axis, _builder)
def xor_sum(input, axis, _builder=None, _generator=None):
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")
input = _promote_reduction_input(input, _builder=_builder)
return reduction(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
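The max/min/sum wrappers above all follow one pattern, so a new reduction needs only a jitted combine function. A hypothetical example in the same style (not part of this diff):

```python
@triton.jit
def _prod_combine(a, b):
    return a * b

@triton.jit
@_add_reduction_docstr("product")
def prod(input, axis):
    input = _promote_reduction_input(input)
    return reduction(input, axis, _prod_combine)
```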
# -----------------------

View File

@@ -1,7 +1,7 @@
from __future__ import annotations # remove after python 3.11
from functools import wraps
from typing import List, Optional, Tuple, TypeVar
from typing import List, Optional, Sequence, Tuple, TypeVar
from . import core as tl
from triton._C.libtriton.triton import ir
@@ -1228,91 +1228,36 @@ def where(condition: tl.tensor,
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===//
# Reductions
# Reduction
# ===----------------------------------------------------------------------===
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
scalar_ty = input.type.scalar
out_scalar_ty = scalar_ty
# input is extended to 32-bits if necessary
# this increases numerical accuracy and can be done pretty much for free
# on GPUs
if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
input = cast(input, tl.int32, builder)
out_scalar_ty = tl.int32
def reduction(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
# get result shape
shape = inputs[0].type.shape
ret_shape = [s for i, s in enumerate(shape) if i != axis]
for t in inputs:
assert t.type.shape == shape
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
if scalar_ty is tl.bfloat16:
input = cast(input, tl.float32, builder)
out_scalar_ty = tl.float32
def wrap_tensor(x, scalar_ty):
if ret_shape:
res_ty = tl.block_type(scalar_ty, ret_shape)
else:
# 0d-tensor -> scalar
res_ty = scalar_ty
return tl.tensor(x, res_ty)
# choose the right unsigned operation
if scalar_ty.is_int_unsigned():
int_op_to_unit = {
ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
}
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
reduce_op = builder.create_reduce([t.handle for t in inputs], axis)
region_builder_fn(reduce_op)
reduce_op.verify()
# If we are doing an argmin or argmax we want to use an int32 output type
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
out_scalar_ty = tl.int32
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
out_scalar_ty = tl.int32
# get result type
shape = input.type.shape
rank = len(shape)
assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})"
ret_shape = []
for i, s in enumerate(shape):
if i != axis:
ret_shape.append(s)
if ret_shape:
res_ty = tl.block_type(out_scalar_ty, ret_shape)
else:
# 0d-tensor -> scalar
res_ty = out_scalar_ty
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
elif scalar_ty.is_int():
return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty)
assert False
def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
return tuple(
wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
for i in range(len(inputs))
)
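For reference, a hand-written `region_builder_fn` sketch equivalent to what `make_combine_region` in core.py emits for a single-f32 add reduction (illustrative only: insertion-point save/restore, handled by `_insertion_guard` in core.py, is elided, and `builder` is assumed to be the same `ir.builder` passed to `reduction`):

```python
def add_region_builder(reduce_op):
    # the reduce op is created with one empty region; populate it with a block
    # taking (accumulator, current) scalar arguments for the single input
    region = reduce_op.get_region(0)
    f32 = tl.float32.to_ir(builder)
    block = builder.create_block_with_parent(region, [f32, f32])
    acc, cur = block.arg(0), block.arg(1)
    # combine and terminate the region with tt.reduce.return
    builder.create_reduce_ret(builder.create_fadd(acc, cur))

# e.g. reduction((x,), axis, add_region_builder, builder) for a float32 tensor x
```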
# ===----------------------------------------------------------------------===

View File

@@ -217,7 +217,11 @@ tt.func @alloc(%A : !tt.ptr<f16>) {
tt.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: scratch offset = 0, size = 512
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
%b = "tt.reduce" (%cst0) ({
^bb0(%arg0: f16, %arg1: f16):
%add = arith.addf %arg0, %arg1 : f16
tt.reduce.return %add : f16
}) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
tt.return
// CHECK-NEXT: size = 512
}

View File

@@ -79,7 +79,11 @@ tt.func @scratch() {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
%2 = "tt.reduce" (%1) ({
^bb0(%arg1: f16, %arg2: f16):
%add = arith.addf %arg1, %arg2 : f16
tt.reduce.return %add : f16
}) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
tt.return
}

View File

@@ -79,18 +79,42 @@ tt.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
// Test if reduce ops infer types correctly
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
%a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
%b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
%c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
%e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
%f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
%g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32
// CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
%a = "tt.reduce" (%v) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
// CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
%b = "tt.reduce" (%v) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
// CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
%c = "tt.reduce" (%v) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
// CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
%e = "tt.reduce" (%b) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
// CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
%f = "tt.reduce" (%a) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
// CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32
%g = "tt.reduce" (%f) ({
^bb0(%arg0: f32, %arg1: f32):
%add = arith.addf %arg0, %arg1 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<4xf32>) -> f32
// Avoid optimizations for c, e, and g
%ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>

View File

@@ -40,14 +40,30 @@ tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
// CHECK: tensor<4x4xf32, #[[blocked0]]> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
%c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
// CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
%c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
// CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
%c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
// CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
%c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
// CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
%c0_ = "tt.reduce" (%c0) ({
^bb0(%arg1: f32, %arg2: f32):
%add = arith.addf %arg1, %arg2 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32>
// CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
%c1_ = "tt.reduce" (%c1) ({
^bb0(%arg3: f32, %arg4: f32):
%add = arith.addf %arg3, %arg4 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32>
// CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
%c2_ = "tt.reduce" (%c1) ({
^bb0(%arg5: f32, %arg6: f32):
%add = arith.addf %arg5, %arg6 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32>
// CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
%c3_ = "tt.reduce" (%c2) ({
^bb0(%arg7: f32, %arg8: f32):
%add = arith.addf %arg7, %arg8 : f32
tt.reduce.return %add : f32
}) {axis = 0 : i32} : (tensor<16x16xf32>) -> tensor<16xf32>
tt.return
}

View File

@@ -787,7 +787,11 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
%27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2>
%28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2>
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%30 = tt.reduce %29 {axis = 1 : i32, redOp = 12 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%30 = "tt.reduce" (%29) ({
^bb0(%arg4: f32, %arg5: f32):
%max = arith.maxf %arg4, %arg5 : f32
tt.reduce.return %max : f32
}) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
%32 = triton_gpu.convert_layout %31 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -803,7 +807,11 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
%43 = math.exp %42 : tensor<16x16xf32, #blocked2>
%44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2>
%45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%46 = tt.reduce %45 {axis = 1 : i32, redOp = 2 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%46 = "tt.reduce" (%45) ({
^bb0(%arg4: f32, %arg5: f32):
%add = arith.addf %arg4, %arg5 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%47 = triton_gpu.convert_layout %46 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
%48 = triton_gpu.convert_layout %47 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -907,7 +915,11 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
%74 = "triton_gpu.select"(%54, %73, %arg7) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2>
scf.yield %74 : tensor<64x64xf32, #blocked2>
}
%26 = tt.reduce %25 {axis = 1 : i32, redOp = 2 : i32} : tensor<64x64xf32, #blocked2> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%26 = "tt.reduce" (%25) ({
^bb0(%arg8: f32, %arg9: f32):
%add = arith.addf %arg8, %arg9 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%27 = triton_gpu.convert_layout %26 : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xf32, #blocked0>
%28 = triton_gpu.convert_layout %27 : (tensor<64xf32, #blocked0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%29 = tt.expand_dims %28 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xf32, #blocked1>
@@ -1016,7 +1028,11 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked>
%3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked>
%4 = tt.reduce %cst {axis = 1 : i32, redOp = 1 : i32} : tensor<1x2xi32, #blocked> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%4 = "tt.reduce" (%cst) ({
^bb0(%arg3: i32, %arg4: i32):
%add = arith.addi %arg3, %arg4 : i32
tt.reduce.return %add : i32
}) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1>
%6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
@@ -1037,7 +1053,8 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// Check if the SimplifyReduceCvt handles convert_layout lifted from the for loop.
// CHECK-LABEL: reduce_cvt2
// CHECK: tt.reduce
// Match the reduction
// CHECK: }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
@@ -1092,7 +1109,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%59 = "triton_gpu.select"(%52, %58, %arg6) : (tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked>, tensor<1x256xf32, #blocked>) -> tensor<1x256xf32, #blocked>
scf.yield %59 : tensor<1x256xf32, #blocked>
}
%16 = tt.reduce %15 {axis = 1 : i32, redOp = 2 : i32} : tensor<1x256xf32, #blocked> -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%16 = "tt.reduce" (%15) ({
^bb0(%arg7: f32, %arg8: f32):
%add = arith.addf %arg7, %arg8 : f32
tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%17 = triton_gpu.convert_layout %16 : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xf32, #blocked1>
%18 = triton_gpu.convert_layout %17 : (tensor<1xf32, #blocked1>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%19 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xf32, #blocked2>