[BACKEND] Improve torch inductor performance (#1108)

- Rewrite the AxisInfo analysis to handle each op case by case.
- Add bit shift, min max, div/rem, and select ops to AxisInfo.
- Rematerialize across load/store ops in the following two cases:
- A size-1 tensor is considered not expensive since all threads will
load the same value
- the targetEncoding may expose more vectorization opportunities (more
elements per thread on the first dim)

**_res2next_** benchmark GPU Kernel time comparison on A100.
- Average kernel sum. Triton 16838630ns vs Triton-MLIR 17105166ns.
**1.016x slowdown**.
- Total kernel sum. Triton 6511735460ns vs Triton-MLIR 6512370620ns.
This commit is contained in:
Keren Zhou
2023-02-01 18:21:15 -08:00
committed by GitHub
parent ccd17d6bf9
commit 82befe32ad
16 changed files with 1333 additions and 260 deletions

View File

@@ -3,11 +3,14 @@
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <optional>
#include <type_traits>
namespace mlir {
//===----------------------------------------------------------------------===//
@@ -15,40 +18,47 @@ namespace mlir {
//===----------------------------------------------------------------------===//
/// This lattice value represents known information on the axes of a lattice.
/// Axis information is represented by a std::map<int, int>
class AxisInfo {
public:
typedef SmallVector<int, 4> DimVectorT;
typedef SmallVector<int64_t, 4> DimVectorT;
public:
// Default constructor
/// Default constructor
AxisInfo() : AxisInfo({}, {}, {}) {}
// Construct contiguity info with known contiguity
/// Construct contiguity info with known contiguity
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
DimVectorT knownConstancy)
: AxisInfo(knownContiguity, knownDivisibility, knownConstancy, {}) {}
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
DimVectorT knownConstancy, std::optional<int64_t> knownConstantValue)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == (size_t)rank);
assert(knownConstancy.size() == (size_t)rank);
constancy(knownConstancy), constantValue(knownConstantValue),
rank(contiguity.size()) {
assert(knownContiguity.size() == static_cast<size_t>(rank));
assert(knownDivisibility.size() == static_cast<size_t>(rank));
assert(knownConstancy.size() == static_cast<size_t>(rank));
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
/// Accessors
int64_t getContiguity(size_t dim) const { return contiguity[dim]; }
const DimVectorT &getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
int64_t getDivisibility(size_t dim) const { return divisibility[dim]; }
const DimVectorT &getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
const DimVectorT &getConstancy() const { return constancy; }
int getRank() const { return rank; }
// Comparison
std::optional<int64_t> getConstantValue() const { return constantValue; }
/// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&
(divisibility == other.divisibility) &&
(constancy == other.constancy);
(constancy == other.constancy) &&
(constantValue == other.constantValue) && (rank == other.rank);
}
/// The pessimistic value state of the contiguity is unknown.
@@ -57,13 +67,18 @@ public:
}
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
/// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of contiguous integers along it
/// sequence of contiguous integers along it.
/// Suppose we have an array of N elements,
/// with a contiguity value C,
/// the array can be divided into a list of
/// N/C sequences of C contiguous elements.
/// Since we have N = 2^k, C must be a power of two.
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
@@ -97,42 +112,147 @@ private:
/// dimension to the length of the shortest
/// sequence of constant integer along it. This is
/// particularly useful to infer the contiguity
/// of operations (e.g., add) involving a constant
/// of operations (e.g., add) involving a constant.
/// Suppose we have an array of N elements,
/// with a constancy value C,
/// the array can be divided into a list of
/// N/C sequences of C elements with the same value.
/// Since we have N = 2^k, C must be a power of two.
/// For example
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
DimVectorT constancy;
/// The constant value of the lattice if we can infer it.
std::optional<int64_t> constantValue;
// number of dimensions of the lattice
int rank;
int rank{};
};
/// Abstract interface implemented by every per-op axis-info visitor.
/// A visitor advertises which operations it handles via `match` and
/// computes the resulting lattice value via `getAxisInfo`.
class AxisInfoVisitor {
public:
  AxisInfoVisitor() = default;
  virtual ~AxisInfoVisitor() = default;

  /// Returns true iff dimension `dim` of `info` is contiguous across the
  /// whole extent of that dimension.
  static bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape,
                              int dim) {
    return shape[dim] == info.getContiguity(dim);
  }

  /// Returns true iff dimension `dim` of `info` holds one constant value
  /// across the whole extent of that dimension.
  static bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape,
                            int dim) {
    return shape[dim] == info.getConstancy(dim);
  }

  /// Computes the axis info of `op` from its operands' lattice values.
  virtual AxisInfo
  getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0;

  /// Returns true iff this visitor can handle `op`.
  virtual bool match(Operation *op) = 0;
};
/// Base class for all operations
/// Adapter that narrows the generic visitor interface down to one concrete
/// operation type `OpTy`.
template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
public:
  using AxisInfoVisitor::AxisInfoVisitor;

  /// Matches exactly the operations of type `OpTy`.
  bool match(Operation *op) final { return isa<OpTy>(op); }

  /// Downcasts `op` and forwards to the typed overload below.
  AxisInfo getAxisInfo(Operation *op,
                       ArrayRef<LatticeElement<AxisInfo> *> operands) final {
    return getAxisInfo(cast<OpTy>(op), operands);
  }

  /// Typed hook; every concrete visitor must override this.
  virtual AxisInfo getAxisInfo(OpTy op,
                               ArrayRef<LatticeElement<AxisInfo> *> operands) {
    llvm_unreachable("Unimplemented getAxisInfo");
  }
};
/// Binary operations
/// Shared implementation for binary operations: walks every dimension and
/// delegates the per-dimension contiguity/divisibility/constancy computation
/// to overridable hooks. If the visitor folds the op to a single constant,
/// the result's divisibility is derived from the folded value directly.
template <typename OpTy>
class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
public:
  using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;

  AxisInfo getAxisInfo(OpTy op,
                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
    // Validate the operand count BEFORE dereferencing the lattice elements
    // (the original read operands[0]/operands[1] ahead of this assert).
    assert(operands.size() == 2 && "Expected two operands");
    // Bind by reference: AxisInfo carries several SmallVectors, so taking
    // copies here is wasted work.
    const auto &lhsInfo = operands[0]->getValue();
    const auto &rhsInfo = operands[1]->getValue();
    auto rank = lhsInfo.getRank();
    AxisInfo::DimVectorT contiguity;
    AxisInfo::DimVectorT divisibility;
    AxisInfo::DimVectorT constancy;
    auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
    for (auto d = 0; d < rank; ++d) {
      if (constantValue.has_value()) {
        // Folded to a constant: not contiguous, constancy taken from the
        // operands, divisibility computed from the value itself.
        contiguity.push_back(1);
        constancy.push_back(
            std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
        divisibility.push_back(highestPowOf2Divisor(constantValue.value()));
      } else {
        contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
        constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
        divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
      }
    }
    return AxisInfo(contiguity, divisibility, constancy, constantValue);
  }

protected:
  /// Per-dimension hooks. The defaults are pessimistic (1); concrete op
  /// visitors override the ones they can do better on.
  virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs,
                                const AxisInfo &rhs, int dim) {
    return 1;
  }

  virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs,
                                  const AxisInfo &rhs, int dim) {
    return 1;
  }

  virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
                               const AxisInfo &rhs, int dim) {
    return 1;
  }

  /// Folds the op to a constant when possible; default is "unknown".
  virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                                  const AxisInfo &rhs) {
    return {};
  }
};
class AxisInfoVisitorList {
public:
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
void append() {
(visitors.emplace_back(std::make_unique<Ts>()), ...);
}
AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
for (auto &visitor : visitors)
if (visitor->match(op))
return visitor->getAxisInfo(op, operands);
return AxisInfo();
}
private:
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
};
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n) {
if (n == 0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
AxisInfo visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
AxisInfoVisitorList visitors;
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
AxisInfoAnalysis(MLIRContext *context);
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
unsigned getPtrVectorSize(Value ptr);
unsigned getPtrContiguity(Value ptr);
unsigned getPtrAlignment(Value ptr);

View File

@@ -77,6 +77,15 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
return result;
}
/// Largest power of two that divides `n` (the value of `n`'s lowest set
/// bit). Zero is divisible by every power of two, so it maps to the biggest
/// power of two we represent: 2^(bit-width - 2).
template <typename T> T highestPowOf2Divisor(T n) {
  if (n != 0)
    return n & (~(n - 1));
  return static_cast<T>(1) << (sizeof(T) * 8 - 2);
}
bool isSingleValue(Value value);
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);

View File

@@ -31,7 +31,7 @@ SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
SmallVector<unsigned> getContigPerThread(Attribute layout);
SmallVector<unsigned> getContigPerThread(const Attribute &layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);

View File

@@ -1,7 +1,6 @@
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -9,20 +8,16 @@
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
// Function for extended Euclidean Algorithm
static int gcd_impl(int a, int b, int *x, int *y) {
static int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) {
// Base Case
if (a == 0) {
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b % a, a, &x1, &y1);
int64_t x1, y1; // To store results of recursive call
int64_t gcd = gcdImpl(b % a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b / a) * x1;
@@ -30,16 +25,30 @@ static int gcd_impl(int a, int b, int *x, int *y) {
return gcd;
}
static int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
static int64_t gcd(int64_t a, int64_t b) {
if (a == 0)
return b;
if (b == 0)
return a;
int64_t x, y;
return gcdImpl(a, b, &x, &y);
}
/// Floor of log2 for inputs > 1; evaluates to 0 for any num <= 1.
static constexpr int log2Int(int64_t num) {
  int bits = 0;
  while (num > 1) {
    num /= 2;
    ++bits;
  }
  return bits;
}
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
size_t rank = 1;
auto rank = 1;
if (TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
int divHint = 1;
auto contiHint = 1;
auto divHint = 1;
auto constHint = 1;
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
@@ -53,139 +62,342 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
} else {
// Derive the divisibility of the induction variable only when
// the step and the lower bound are both constants
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
if (blockArg == forOp.getInductionVar()) {
if (auto lowerBound =
forOp.getLowerBound().getDefiningOp<arith::ConstantOp>()) {
if (auto step =
forOp.getStep().getDefiningOp<arith::ConstantOp>()) {
auto lowerBoundVal = lowerBound.getValue()
.cast<IntegerAttr>()
.getValue()
.getZExtValue();
auto stepVal =
step.getValue().cast<IntegerAttr>().getValue().getZExtValue();
auto k = gcd(lowerBoundVal, stepVal);
if (k != 0)
divHint = k;
}
}
}
}
}
}
DimVectorT contiguity(rank, 1);
DimVectorT divisibility(rank, divHint);
DimVectorT constancy(rank, 1);
return AxisInfo(contiguity, divisibility, constancy);
return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
/*knownDivisibility=*/DimVectorT(rank, divHint),
/*knownConstancy=*/DimVectorT(rank, constHint));
}
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
DimVectorT retContiguity;
DimVectorT retDivisibility;
DimVectorT retConstancy;
for (int d = 0; d < lhs.getRank(); ++d) {
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
for (auto d = 0; d < lhs.getRank(); ++d) {
contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
}
return AxisInfo(retContiguity, retDivisibility, retConstancy);
std::optional<int64_t> constantValue;
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value() &&
lhs.getConstantValue() == rhs.getConstantValue())
constantValue = lhs.getConstantValue();
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
// AxisInfoVisitor
//===----------------------------------------------------------------------===//
/// Applies the three per-dimension callbacks to every dimension of the
/// operand lattice values and assembles the resulting axis info.
AxisInfo AxisInfoAnalysis::visitBinaryOp(
    Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
    const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
    const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
    const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
  AxisInfo::DimVectorT contiguity;
  AxisInfo::DimVectorT divisibility;
  AxisInfo::DimVectorT constancy;
  for (int dim = 0, rank = lhsInfo.getRank(); dim < rank; ++dim) {
    contiguity.push_back(getContiguity(lhsInfo, rhsInfo, dim));
    divisibility.push_back(getDivisibility(lhsInfo, rhsInfo, dim));
    constancy.push_back(getConstancy(lhsInfo, rhsInfo, dim));
  }
  return AxisInfo(contiguity, divisibility, constancy);
}
template <typename OpTy>
class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
ChangeResult AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
AxisInfo curr;
// This preserves the input axes (e.g., cast):
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
triton::PtrToIntOp, triton::IntToPtrOp,
triton::gpu::ConvertLayoutOp>(op))
curr = operands[0]->getValue();
// Constant ranges
if (triton::MakeRangeOp make_range =
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
int start = make_range.start();
int end = make_range.end();
AxisInfo::DimVectorT contiguity = {end - start};
AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
AxisInfo::DimVectorT constancy = {1};
curr = AxisInfo(contiguity, divisibility, constancy);
AxisInfo getAxisInfo(OpTy op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
return operands[0]->getValue();
}
// Constant
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
if (intAttr) {
size_t val = intAttr.getValue().getZExtValue();
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
};
/// tt.make_range produces [start, end): one fully contiguous dimension,
/// never constant, aligned according to its start value.
class MakeRangeOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<triton::MakeRangeOp> {
public:
  using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl;

  AxisInfo getAxisInfo(triton::MakeRangeOp op,
                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
    return AxisInfo(/*contiguity=*/{op.end() - op.start()},
                    /*divisibility=*/{highestPowOf2Divisor(op.start())},
                    /*constancy=*/{1});
  }
};
class ConstantOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<arith::ConstantOp> {
public:
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
AxisInfo getAxisInfo(arith::ConstantOp op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
if (intAttr || boolAttr) {
int64_t value{};
if (intAttr)
value = intAttr.getValue().getZExtValue();
else
value = boolAttr.getValue() ? 1 : 0;
return AxisInfo(/*contiguity=*/{1},
/*divisibility=*/{highestPowOf2Divisor(value)},
/*constancy=*/{1},
/*knownConstantValue=*/{value});
}
// TODO: generalize to dense attr
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
auto value = splatAttr.getSplatValue<int>();
auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
TensorType ty = splatAttr.getType().cast<TensorType>();
curr = AxisInfo(
AxisInfo::DimVectorT(ty.getRank(), 1),
return AxisInfo(
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
/*divisibility=*/
AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
/*constancy=*/
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()),
/*knownConstantValue=*/{value});
}
return AxisInfo();
}
// TODO: refactor & complete binary ops
// Addition
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
};
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
};
template <typename OpTy>
class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
private:
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)),
gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)));
}
// Multiplication
if (llvm::isa<arith::MulIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
// lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs)
// rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
// lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) *
// gcd(d_lhs, d_rhs)
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
}
// Remainder
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
std::is_same_v<OpTy, triton::AddPtrOp>) {
return {lhs.getConstantValue().value() +
rhs.getConstantValue().value()};
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
return {lhs.getConstantValue().value() -
rhs.getConstantValue().value()};
}
}
return {};
}
// Splat
if (llvm::isa<triton::SplatOp>(op)) {
};
/// arith.muli: multiplying by a known constant 1 preserves the other
/// operand's contiguity; divisibilities multiply; constancy is the gcd.
class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
public:
  using BinaryOpVisitorImpl<arith::MulIOp>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs,
                        const AxisInfo &rhs, int dim) override {
    int64_t result = 1;
    // lhs * 1 = lhs
    if (rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)
      result = std::max(result, lhs.getContiguity(dim));
    // 1 * rhs = rhs
    if (lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)
      result = std::max(result, rhs.getContiguity(dim));
    return result;
  }

  int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
                       const AxisInfo &rhs, int dim) override {
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
                          const AxisInfo &rhs, int dim) override {
    // lhs = k * d_lhs, rhs = p * d_rhs
    // => lhs * rhs = (k * p) * (d_lhs * d_rhs)
    return lhs.getDivisibility(dim) * rhs.getDivisibility(dim);
  }

  std::optional<int64_t> getConstantValue(arith::MulIOp op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    if (!lhs.getConstantValue().has_value() ||
        !rhs.getConstantValue().has_value())
      return {};
    return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
  }
};
/// Division (signed or unsigned, chosen by OpTy): dividing by a known 1
/// preserves contiguity, and contiguous/constant operand pairs let us
/// derive extra constancy.
template <typename OpTy>
class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    // lhs / 1 = lhs
    return rhs.getConstantValue().has_value() &&
                   rhs.getConstantValue().value() == 1
               ? lhs.getContiguity(dim)
               : 1;
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
    auto shape = resTy.getShape();
    // Case 1: both lhs and rhs are constants.
    auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
    // Case 2: lhs contiguous, rhs constant.
    // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
    // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
    // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
    // ..., (d_lhs * k + n) / (d_rhs * p)
    // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0,
    // the minimal constancy is gcd(d_lhs, d_rhs).
    // Since gcd(d_lhs, d_rhs) may be > len(lhs),
    // we need another gcd to get the actual constancy.
    if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) &&
        AxisInfoVisitor::isConstantDim(rhs, shape, dim)) {
      constancy = std::max(constancy, gcd(lhs.getContiguity(dim),
                                          gcd(lhs.getDivisibility(dim),
                                              rhs.getDivisibility(dim))));
    }
    return constancy;
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs)
    // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
    // lhs / rhs = k * k' * gcd(d_lhs, d_rhs) / (p * p' * gcd(d_lhs, d_rhs))
    //           = k / p * k' / p'
    auto lhsDivisibility = lhs.getDivisibility(dim);
    auto rhsDivisibility = rhs.getDivisibility(dim);
    auto initGcd = gcd(lhsDivisibility, rhsDivisibility);
    return std::max(lhsDivisibility / initGcd, rhsDivisibility / initGcd);
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    // Only fold when the divisor is a known NON-ZERO constant; evaluating
    // x / 0 here would be undefined behavior in the compiler itself.
    if (lhs.getConstantValue().has_value() &&
        rhs.getConstantValue().has_value() &&
        rhs.getConstantValue().value() != 0)
      return {lhs.getConstantValue().value() / rhs.getConstantValue().value()};
    return {};
  }
};
/// Remainder (signed or unsigned, chosen by OpTy): x % 1 is uniformly 0;
/// contiguous/constant operand pairs let us derive extra contiguity.
template <typename OpTy>
class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
    auto shape = resTy.getShape();
    int64_t contiguity = 1;
    // lhs contiguous, rhs constant:
    // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
    // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
    // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p),
    // ..., (d_lhs * k + n) % (d_rhs * p)
    // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0,
    // the minimal contiguity is gcd(d_lhs, d_rhs).
    // Since gcd(d_lhs, d_rhs) may be > len(lhs),
    // we need another gcd to get the actual contiguity.
    if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) &&
        AxisInfoVisitor::isConstantDim(rhs, shape, dim)) {
      contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim),
                                            gcd(lhs.getDivisibility(dim),
                                                rhs.getDivisibility(dim))));
    }
    return contiguity;
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k''
    // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p''
    // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r
    // => r must be divisible by gcd(d_lhs, d_rhs)
    return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
    auto shape = resTy.getShape();
    // lhs % 1 = 0, so the whole dimension becomes constant.
    return rhs.getConstantValue().has_value() &&
                   rhs.getConstantValue().value() == 1
               ? shape[dim]
               : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    // lhs % 1 = 0
    if (rhs.getConstantValue().has_value() &&
        rhs.getConstantValue().value() == 1)
      return {0};
    // Only fold when the divisor is a known NON-ZERO constant; evaluating
    // x % 0 here would be undefined behavior in the compiler itself.
    if (lhs.getConstantValue().has_value() &&
        rhs.getConstantValue().has_value() &&
        rhs.getConstantValue().value() != 0)
      return {lhs.getConstantValue().value() % rhs.getConstantValue().value()};
    return {};
  }
};
class SplatOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<triton::SplatOp> {
public:
using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl;
AxisInfo getAxisInfo(triton::SplatOp op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
@@ -197,21 +409,37 @@ ChangeResult AxisInfoAnalysis::visitOperation(
divisibility.push_back(opInfo.getDivisibility(0));
constancy.push_back(retTy.getShape()[d]);
}
curr = AxisInfo(contiguity, divisibility, constancy);
return AxisInfo(contiguity, divisibility, constancy,
operands[0]->getValue().getConstantValue());
}
// expandDims
if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
};
class ExpandDimsOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<triton::ExpandDimsOp> {
public:
using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl;
AxisInfo getAxisInfo(triton::ExpandDimsOp op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
constancy.insert(constancy.begin() + expandDims.axis(), 1);
curr = AxisInfo(contiguity, divisibility, constancy);
contiguity.insert(contiguity.begin() + op.axis(), 1);
divisibility.insert(divisibility.begin() + op.axis(), 1);
constancy.insert(constancy.begin() + op.axis(), 1);
return AxisInfo(contiguity, divisibility, constancy,
operands[0]->getValue().getConstantValue());
}
// Broadcast
if (llvm::isa<triton::BroadcastOp>(op)) {
};
class BroadcastOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<triton::BroadcastOp> {
public:
using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl;
AxisInfo getAxisInfo(triton::BroadcastOp op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -228,42 +456,362 @@ ChangeResult AxisInfoAnalysis::visitOperation(
constancy.push_back(opShape[d] == 1 ? retShape[d]
: opInfo.getConstancy(d));
}
curr = AxisInfo(contiguity, divisibility, constancy);
return AxisInfo(contiguity, divisibility, constancy,
operands[0]->getValue().getConstantValue());
}
};
// CmpI
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
op->getResult(0).getType().dyn_cast<TensorType>()) {
auto resTy = op->getResult(0).getType().cast<TensorType>();
template <typename OpTy>
class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
AxisInfo getAxisInfo(OpTy op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!resTy)
return AxisInfo();
auto shape = resTy.getShape();
short rank = resTy.getRank();
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
auto shape = resTy.getShape();
AxisInfo::DimVectorT contiguity, divisibility, constancy;
std::optional<int64_t> constantValue;
for (short d = 0; d < rank; ++d) {
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
constancy.push_back(
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
else
constancy.push_back(1);
int64_t constHint = 1;
if (lhsInfo.getConstantValue().has_value() &&
rhsInfo.getConstantValue().has_value()) {
constHint = lhsInfo.getConstancy(d);
constantValue =
compare(getPredicate(op), lhsInfo.getConstantValue().value(),
rhsInfo.getConstantValue().value())
? 1
: 0;
} else {
// Case 1: lhs and rhs are both partial constants
constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d));
// Case 2: lhs all constant, rhs all contiguous
// NOTE:
// lhs: 4 4 4 4
// rhs: 4 5 6 7
// lhs ge rhs: 1, 0, 0, 0
// Case 3: lhs all contiguous, rhs all constant
// NOTE
// lhs: 4 5 6 7
// rhs: 4 4 4 4
// lhs sle rhs: 1, 0, 0, 0
if (/*Case 2=*/(
notGePredicate(getPredicate(op)) &&
(AxisInfoVisitor::isConstantDim(lhsInfo, shape, d) &&
AxisInfoVisitor::isContiguousDim(rhsInfo, shape, d))) ||
/*Case 3=*/(notLePredicate(getPredicate(op)) &&
(AxisInfoVisitor::isContiguousDim(lhsInfo, shape, d) &&
AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)))) {
constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d),
gcd(lhsInfo.getDivisibility(d),
rhsInfo.getDivisibility(d))));
}
}
divisibility.push_back(shape[d]);
constancy.push_back(constHint);
divisibility.push_back(1);
contiguity.push_back(1);
}
curr = AxisInfo(contiguity, divisibility, constancy);
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
// UnrealizedConversionCast
private:
static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
return op.predicate();
}
static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
return op.getPredicate();
}
static bool notGePredicate(arith::CmpIPredicate predicate) {
return predicate != arith::CmpIPredicate::sge &&
predicate != arith::CmpIPredicate::uge;
}
static bool notLePredicate(arith::CmpIPredicate predicate) {
return predicate != arith::CmpIPredicate::sle &&
predicate != arith::CmpIPredicate::ule;
}
static bool compare(arith::CmpIPredicate predicate, int64_t lhs,
int64_t rhs) {
switch (predicate) {
case arith::CmpIPredicate::eq:
return lhs == rhs;
case arith::CmpIPredicate::ne:
return lhs != rhs;
case arith::CmpIPredicate::slt:
return lhs < rhs;
case arith::CmpIPredicate::sle:
return lhs <= rhs;
case arith::CmpIPredicate::sgt:
return lhs > rhs;
case arith::CmpIPredicate::sge:
return lhs >= rhs;
case arith::CmpIPredicate::ult:
return (uint64_t)lhs < (uint64_t)rhs;
case arith::CmpIPredicate::ule:
return (uint64_t)lhs <= (uint64_t)rhs;
case arith::CmpIPredicate::ugt:
return (uint64_t)lhs > (uint64_t)rhs;
case arith::CmpIPredicate::uge:
return (uint64_t)lhs >= (uint64_t)rhs;
default:
break;
}
llvm_unreachable("unknown comparison predicate");
}
};
/// Lattice transfer function for select-like ops (mlir::SelectOp and
/// triton_gpu.select): result = cond ? lhs : rhs, with operands[0] = cond,
/// operands[1] = lhs, operands[2] = rhs.
template <typename OpTy>
class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
AxisInfo getAxisInfo(OpTy op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
// Non-tensor results carry no per-dimension information.
if (!resTy)
return AxisInfo();
auto shape = resTy.getShape();
auto rank = shape.size();
auto condConstancy = operands[0]->getValue().getConstancy();
auto lhsInfo = operands[1]->getValue();
auto rhsInfo = operands[2]->getValue();
AxisInfo::DimVectorT contiguity, divisibility, constancy;
std::optional<int64_t> constantValue;
if (operands[0]->getValue().getConstantValue().has_value()) {
// The condition is a known constant, so the select statically resolves
// to exactly one branch; forward that branch's lattice value unchanged.
if (operands[0]->getValue().getConstantValue() == 0) {
contiguity = rhsInfo.getContiguity();
divisibility = rhsInfo.getDivisibility();
constancy = rhsInfo.getConstancy();
constantValue = rhsInfo.getConstantValue();
} else {
contiguity = lhsInfo.getContiguity();
divisibility = lhsInfo.getDivisibility();
constancy = lhsInfo.getConstancy();
constantValue = lhsInfo.getConstantValue();
}
} else {
// Condition unknown: within a run where the condition is constant the
// result follows a single branch, so each property is bounded by the
// weaker of the two branches, further limited by the condition's
// constancy run length along that dimension.
for (auto d = 0; d < rank; ++d) {
constancy.push_back(
std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]),
gcd(rhsInfo.getConstancy(d), condConstancy[d])));
divisibility.push_back(
std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
contiguity.push_back(
std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]),
gcd(rhsInfo.getContiguity(d), condConstancy[d])));
}
// The result is a known constant only if both branches agree on one.
if (lhsInfo.getConstantValue().has_value() &&
rhsInfo.getConstantValue().has_value() &&
lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
constantValue = lhsInfo.getConstantValue();
}
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
};
/// Lattice transfer function for the bitwise logical ops
/// arith::AndIOp / arith::OrIOp / arith::XOrIOp.
template <typename OpTy>
class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    // A bitwise op yields a constant run only where both inputs are
    // constant, so the run length is the gcd of the two constancies.
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    // Fold to a constant only when both operands are known constants.
    if (!lhs.getConstantValue().has_value() ||
        !rhs.getConstantValue().has_value())
      return {};
    const int64_t lhsVal = lhs.getConstantValue().value();
    const int64_t rhsVal = rhs.getConstantValue().value();
    if constexpr (std::is_same_v<OpTy, arith::AndIOp>)
      return {lhsVal & rhsVal};
    else if constexpr (std::is_same_v<OpTy, arith::OrIOp>)
      return {lhsVal | rhsVal};
    else if constexpr (std::is_same_v<OpTy, arith::XOrIOp>)
      return {lhsVal ^ rhsVal};
    else
      return {};
  }
};
/// Lattice transfer function for left shift (arith::ShLIOp).
class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
public:
using BinaryOpVisitorImpl<arith::ShLIOp>::BinaryOpVisitorImpl;
private:
int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
// A shift by a constant 0 is the identity; any other shift breaks the
// +1 stride between neighboring elements.
if (rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 0)
return lhs.getContiguity(dim);
else
return 1;
}
int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
// Shifting left by k multiplies the known power-of-2 divisor by 2^k.
// When the shift amount is not a known constant, its divisibility is a
// conservative lower bound on the shift.
auto shift = rhs.getConstantValue().has_value()
? rhs.getConstantValue().value()
: rhs.getDivisibility(dim);
auto numBits = log2Int(lhs.getDivisibility(dim));
auto maxBits = log2Int(highestPowOf2Divisor<int64_t>(0));
// Make sure the return value doesn't exceed highestPowOf2Divisor<int64>(0)
// (i.e. saturate instead of overflowing the 64-bit shift).
if (shift + numBits > maxBits)
return highestPowOf2Divisor<int64_t>(0);
return lhs.getDivisibility(dim) << shift;
}
int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
// Constant runs survive only where both operands are constant.
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}
std::optional<int64_t> getConstantValue(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
// Fold to a constant only when both operands are known constants.
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value())
return {lhs.getConstantValue().value() << rhs.getConstantValue().value()};
return {};
}
};
/// Lattice transfer function for right shifts (arith::ShRUIOp logical,
/// arith::ShRSIOp arithmetic).
template <typename OpTy>
class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    // A shift by a constant 0 is the identity; any other shift breaks the
    // +1 stride between neighboring elements.
    if (rhs.getConstantValue().has_value() &&
        rhs.getConstantValue().value() == 0)
      return lhs.getContiguity(dim);
    else
      return 1;
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    // Shifting right by k divides the known power-of-2 divisor by 2^k.
    // Use the exact amount when the shift is a known constant; otherwise
    // fall back to the divisibility of the shift operand.
    int64_t shift = rhs.getConstantValue().has_value()
                        ? rhs.getConstantValue().value()
                        : rhs.getDivisibility(dim);
    // BUGFIX: the previous code computed `1 << shift` on a 32-bit int with
    // an unbounded shift amount (divisibility can be e.g. 2^30), which is
    // undefined behavior. Clamp the shift and use a 64-bit one instead.
    if (shift < 0 || shift >= 63)
      return 1;
    return std::max<int64_t>(1, lhs.getDivisibility(dim) >> shift);
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    // Constant runs survive only where both operands are constant.
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    // Fold to a constant only when both operands are known constants.
    if (lhs.getConstantValue().has_value() &&
        rhs.getConstantValue().has_value()) {
      const int64_t shift = rhs.getConstantValue().value();
      // A shift count outside [0, 64) is undefined; don't fold it.
      if (shift < 0 || shift >= 64)
        return {};
      // BUGFIX: shrui is a *logical* shift; folding it with `>>` on a
      // signed int64_t would sign-extend. Shift the unsigned bit pattern.
      if constexpr (std::is_same_v<OpTy, arith::ShRUIOp>)
        return {static_cast<int64_t>(
            static_cast<uint64_t>(lhs.getConstantValue().value()) >> shift)};
      else
        return {lhs.getConstantValue().value() >> shift};
    }
    return {};
  }
};
/// Lattice transfer function for the integer min/max ops
/// (arith::MaxSIOp/MaxUIOp/MinSIOp/MinUIOp). Only constant folding is
/// attempted; all per-dimension properties are conservatively dropped.
template <typename OpTy>
class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
  using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;

  AxisInfo getAxisInfo(OpTy op,
                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
    auto lhsInfo = operands[0]->getValue();
    auto rhsInfo = operands[1]->getValue();
    std::optional<int64_t> constantValue;
    if (lhsInfo.getConstantValue().has_value() &&
        rhsInfo.getConstantValue().has_value()) {
      const int64_t lhsVal = lhsInfo.getConstantValue().value();
      const int64_t rhsVal = rhsInfo.getConstantValue().value();
      if constexpr (std::is_same_v<OpTy, arith::MaxSIOp>) {
        constantValue = {std::max(lhsVal, rhsVal)};
      } else if constexpr (std::is_same_v<OpTy, arith::MinSIOp>) {
        constantValue = {std::min(lhsVal, rhsVal)};
      } else if constexpr (std::is_same_v<OpTy, arith::MaxUIOp>) {
        // BUGFIX: maxui/minui compare the *unsigned* bit patterns; the
        // previous code folded them with signed std::max/std::min, which
        // is wrong for values with the top bit set.
        constantValue = {static_cast<int64_t>(
            std::max(static_cast<uint64_t>(lhsVal),
                     static_cast<uint64_t>(rhsVal)))};
      } else if constexpr (std::is_same_v<OpTy, arith::MinUIOp>) {
        constantValue = {static_cast<int64_t>(
            std::min(static_cast<uint64_t>(lhsVal),
                     static_cast<uint64_t>(rhsVal)))};
      }
    }
    auto rank = lhsInfo.getRank();
    // min/max preserves neither contiguity, divisibility, nor constancy in
    // general, so report 1 along every dimension.
    return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
                    /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
                    /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
                    /*constantValue=*/constantValue);
  }
};
//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//
AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
: ForwardDataFlowAnalysis<AxisInfo>(context) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
// may exist
if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
curr = operands[0]->getValue();
}
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
CastOpAxisInfoVisitor<arith::ExtUIOp>,
CastOpAxisInfoVisitor<arith::TruncIOp>,
CastOpAxisInfoVisitor<arith::IndexCastOp>,
CastOpAxisInfoVisitor<triton::PtrToIntOp>,
CastOpAxisInfoVisitor<triton::IntToPtrOp>,
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor>();
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
AddSubOpAxisInfoVisitor<arith::AddIOp>,
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
visitors.append<MulIOpAxisInfoVisitor>();
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
DivOpAxisInfoVisitor<arith::DivUIOp>>();
visitors.append<RemOpAxisInfoVisitor<arith::RemSIOp>,
RemOpAxisInfoVisitor<arith::RemUIOp>>();
visitors.append<BroadcastOpAxisInfoVisitor>();
visitors.append<SplatOpAxisInfoVisitor>();
visitors.append<ExpandDimsOpAxisInfoVisitor>();
visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>,
CmpOpAxisInfoVisitor<triton::gpu::CmpIOp>>();
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
LogicalOpAxisInfoVisitor<arith::OrIOp>,
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>,
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
visitors.append<MaxMinOpAxisInfoVisitor<arith::MaxSIOp>,
MaxMinOpAxisInfoVisitor<arith::MaxUIOp>,
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
}
ChangeResult AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
AxisInfo curr = visitors.apply(op, operands);
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}
@@ -276,7 +824,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
return result;
}
unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
@@ -289,10 +837,10 @@ unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
unsigned align = getPtrAlignment(ptr);
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
contigPerThread = std::min(align, contigPerThread);
contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);
return vec;
return contigPerThread;
}
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
@@ -302,8 +850,8 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto axisInfo = lookupLatticeElement(ptr)->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
unsigned maxContig = axisInfo.getContiguity(order[0]);
auto maxMultiple = axisInfo.getDivisibility(order[0]);
auto maxContig = axisInfo.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}

View File

@@ -165,6 +165,19 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
dotOperandLayout.getParent() == mmaLayout;
}
// Returns true when every thread would observe the same single element,
// e.g. a scalar or a one-element tensor. Used to decide whether a load is
// cheap enough to rematerialize.
bool isSingleValue(Value value) {
  auto tensorTy = value.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    // Non-tensor (scalar) types trivially hold a single value.
    // TODO: Handle other cases — e.g. a pointer tensor produced by a chain
    // of broadcasts holds one logical value; rematerializing it without
    // considering contiguous memory access patterns is fine.
    return true;
  return tensorTy.getNumElements() == 1;
}
namespace {
/// A data structure similar to SetVector but maintains

View File

@@ -31,8 +31,26 @@ struct LoadStoreConversionBase {
return valueVals;
}
/// Returns the contiguity (in elements) of the fastest-varying dimension of
/// `ptr` as computed by the axis analysis, or 1 for non-tensor pointers.
unsigned getContiguity(Value ptr) const {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
// Scalar pointers have no per-dimension contiguity information.
if (!tensorTy)
return 1;
return axisAnalysisPass.getPtrContiguity(ptr);
}
unsigned getVectorSize(Value ptr) const {
return axisAnalysisPass.getPtrVectorSize(ptr);
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto contiguity = getContiguity(ptr);
unsigned numElemBits = 0;
auto ptrTy = tensorTy.getElementType().cast<triton::PointerType>();
auto pointeeType = ptrTy.getPointeeType();
numElemBits = pointeeType.isa<triton::Float8Type>()
? 8
: pointeeType.getIntOrFloatBitWidth();
// The maximum vector size is 128 bits on NVIDIA GPUs.
return std::min<unsigned>(128 / numElemBits, contiguity);
}
unsigned getMaskAlignment(Value mask) const {
@@ -734,7 +752,10 @@ struct InsertSliceAsyncOpConversion
assert(srcElems.size() == otherElems.size());
}
unsigned inVec = getVectorSize(src);
// We don't use getVec() here because we are copying from memory to memory.
// If contiguity > vector size, we can have one pointer maintaining the
// start of the vector and the other pointer moving to the next vector.
unsigned inVec = getContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcTy);

View File

@@ -342,7 +342,7 @@ private:
auto resSharedLayout =
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
auto maxBitWidth =

View File

@@ -143,7 +143,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
}
}
SmallVector<unsigned> getContigPerThread(Attribute layout) {
SmallVector<unsigned> getContigPerThread(const Attribute &layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
return {1, 2};

View File

@@ -281,26 +281,35 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
return success();
}
// TODO: Interface
LogicalResult getForwardEncoding(Attribute sourceEncoding, Operation *op,
Attribute &ret) {
if (op->hasTrait<mlir::OpTrait::Elementwise>()) {
ret = sourceEncoding;
return success();
inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1: A size-1 tensor is not expensive, since all threads will load the
// same value
if (isSingleValue(op->getOperand(0)))
return false;
auto ptr = op->getOperand(0);
if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
auto encoding = tensorTy.getEncoding();
// Case 2: Different type conversion is expensive (e.g., mma <-> block)
if (encoding.getTypeID() != targetEncoding.getTypeID())
return true;
auto sizePerThread = triton::gpu::getSizePerThread(encoding);
auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
auto order = triton::gpu::getOrder(encoding);
auto targetOrder = triton::gpu::getOrder(targetEncoding);
// Case 3: The targetEncoding may expose more vectorization opportunities
return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
}
if (isa<triton::ReduceOp>(op)) {
ret = Attribute();
return success();
}
return failure();
return false;
}
inline bool expensiveToRemat(Operation *op, const Attribute &targetEncoding) {
inline bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return expensiveLoadOrStore(op, targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
op))
@@ -509,6 +518,7 @@ public:
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
// XXX: why is this needed?
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;

View File

@@ -168,7 +168,7 @@ LogicalResult LoopPipeliner::initialize() {
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
auto ptr = loadOp.ptr();
unsigned vec = axisInfoAnalysis.getPtrVectorSize(ptr);
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
auto ty = getElementTypeOrSelf(ptr.getType())
.cast<triton::PointerType>()
.getPointeeType();

View File

@@ -1,51 +1,336 @@
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%cst = arith.constant dense<true> : tensor<128x128xi1>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
// CHECK-LABEL: cast
func @cast() {
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
%cst = arith.constant 1 : i32
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
%0 = arith.extsi %cst : i32 to i64
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
return
}
// -----
// CHECK-LABEL: add
func @add() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%1 = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%2 = arith.addi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127]
%3 = arith.constant dense<127> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
%4 = arith.addi %1, %3 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: sub
func @sub() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%1 = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%2 = arith.subi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129]
%3 = arith.constant dense<129> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
%4 = arith.subi %3, %1 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: mul
func @mul() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%1 = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%2 = arith.muli %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
%3 = arith.constant dense<128> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
%4 = arith.muli %3, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2]
%5 = arith.constant dense<2> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256]
%6 = arith.muli %4, %5 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: div
func @div() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%1 = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%2 = arith.divsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%3 = arith.divui %1, %0 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
%4 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
%5 = arith.divsi %0, %4 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [1] ; ConstantValue: [None]
%6 = arith.divsi %4, %0 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
%7 = arith.constant dense<66> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [536870912] ; Constancy: [2] ; ConstantValue: [None]
%8 = arith.divui %0, %7 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: rem
func @rem() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
%1 = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
%2 = arith.remsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%3 = arith.remui %1, %0 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
%4 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
%5 = arith.remsi %0, %4 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
%6 = arith.remsi %4, %0 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
%7 = arith.constant dense<66> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None]
%8 = arith.remui %0, %7 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: broadcast
func @broadcast() {
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
%0 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64]
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64]
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
return
}
// -----
// CHECK-LABEL: splat
func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None]
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
return
}
// -----
// CHECK-LABEL: cmp
func @cmp() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
%1 = arith.constant dense<0> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%4 = arith.cmpi sle, %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
%5 = arith.cmpi sge, %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
%6 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0]
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: logic
func @logic() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
%1 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
%2 = arith.divsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
%3 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None]
%4 = arith.divsi %0, %3 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%5 = arith.andi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%6 = arith.ori %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%7 = arith.xori %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
%8 = arith.andi %2, %4 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
%9 = arith.ori %2, %4 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
%10 = arith.xori %2, %4 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: select
func @select() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
%1 = arith.constant dense<0> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
%4 = arith.constant 0 : i1
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
%7 = tt.splat %4 : (i1) -> tensor<128xi1>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
%5 = select %4, %3, %7 : tensor<128xi1>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
return
}
// -----
func @shift() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
%1 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
%2 = arith.constant dense<4> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None]
%3 = arith.shli %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None]
%4 = arith.shrsi %0, %2 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
%5 = arith.shli %1, %2 : tensor<128xi32>
return
}
// -----
func @max_min() {
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
%1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%2 = arith.maxsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%3 = arith.minsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
%4 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
%5 = arith.constant dense<4> : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8]
%6 = arith.maxsi %4, %5 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: for
func @for() {
// CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0]
%a_init = arith.constant dense<0> : tensor<128x32xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1]
%b_init = arith.constant dense<1> : tensor<128x32xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
%c_init = arith.constant dense<4> : tensor<128x32xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
%ub = arith.constant 128 : index
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
%lb = arith.constant 0 : index
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16]
%step = arith.constant 16 : index
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
%t = arith.index_cast %iv : index to i32
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
// CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
}
return
}
// -----
// CHECK-LABEL: permute_2d
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1]
%cst = arith.constant dense<true> : tensor<128x128xi1>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
%4 = arith.muli %2, %3 : tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None]
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None]
%16 = arith.muli %14, %15 : tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None]
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None]
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
tt.store %19, %20, %cst : tensor<128x128xf32>
return
@@ -56,28 +341,29 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
module {
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
// CHECK-LABEL: store_constant_align
func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%pid = tt.get_program_id {axis = 0 : i32} : i32
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
%c128_i32 = arith.constant 128 : i32
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
%1 = arith.muli %pid, %c128_i32 : i32
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None]
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
%4 = arith.addi %3, %2 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
%9 = tt.splat %n : (i32) -> tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None]
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
%cst = arith.constant dense<0.0> : tensor<128xf32>
tt.store %5, %cst, %mask : tensor<128xf32>
return
@@ -89,6 +375,7 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
// This IR is dumped from vecadd test.
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
// CHECK-LABEL: vecadd_mask_align_16
func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -101,13 +388,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%13 = arith.addf %11, %12 : tensor<64xf32>
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %15, %13, %mask : tensor<64xf32>
return
@@ -117,6 +404,7 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
// This IR is dumped from vecadd test.
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
// CHECK-LABEL: vecadd_mask_align_1
func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -129,7 +417,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>

View File

@@ -2,6 +2,7 @@
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout2 = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
// CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -54,6 +55,61 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
}
// CHECK-LABEL: remat_load_store
func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout1>
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout1>
tt.store %5, %4 : tensor<64xi32, #layout1>
return
}
// Don't rematerialize vectorized loads
// CHECK-LABEL: remat_expensive
func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout1>
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout1>) -> tensor<64xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout1>) -> tensor<64x!tt.ptr<i32>, #layout0>
tt.store %5, %4 : tensor<64xi32, #layout0>
return
}
// Don't rematerialize loads when original and target layouts are different
// CHECK-LABEL: remat_multi_layout
func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout2>
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout2>
tt.store %5, %4 : tensor<64xi32, #layout2>
return
}
// Always rematerialize single value loads
// CHECK-LABEL: remat_single_value
func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #layout1>) -> tensor<1xi32, #layout0>
%3 = triton_gpu.convert_layout %0 : (tensor<1x!tt.ptr<i32>, #layout1>) -> tensor<1x!tt.ptr<i32>, #layout0>
tt.store %3, %2 : tensor<1xi32, #layout0>
return
}
// CHECK-LABEL: if
func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout

View File

@@ -36,8 +36,8 @@ struct TestAliasPass
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
SharedMemoryAliasAnalysis analysis(&getContext());
analysis.run(operation);

View File

@@ -19,9 +19,9 @@ struct TestAllocationPass
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
// Convert to std::string can remove quotes from op_name
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
// Convert to std::string can remove quotes from opName
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
Allocation allocation(operation);
operation->walk([&](Operation *op) {
auto scratchBufferId = allocation.getBufferId(op);

View File

@@ -11,7 +11,7 @@ struct TestAxisInfoPass
// LLVM15+
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
void print(const std::string &name, raw_ostream &os, ArrayRef<int> vals) {
void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
os << name << ": [";
for (size_t d = 0; d < vals.size(); d++) {
if (d != 0)
@@ -29,7 +29,8 @@ struct TestAxisInfoPass
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
os << "Testing: " << operation->getName() << "\n";
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
AxisInfoAnalysis analysis(&getContext());
analysis.run(operation);
operation->walk([&](Operation *op) {
@@ -51,7 +52,14 @@ struct TestAxisInfoPass
print("Divisibility", os, info.getDivisibility());
os << " ; ";
print("Constancy", os, info.getConstancy());
os << " ( ";
os << " ; ";
auto constantValue = info.getConstantValue();
os << "ConstantValue: [";
if (constantValue.has_value())
os << constantValue.value();
else
os << "None";
os << "] ( ";
result.print(os);
os << " ) ";
os << "\n";

View File

@@ -23,8 +23,8 @@ struct TestMembarPass
Operation *operation = getOperation();
auto &os = llvm::errs();
// Convert to std::string can remove quotes from op_name
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
Allocation allocation(operation);
MembarAnalysis membarPass(&allocation);
membarPass.run();