mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Improve torch inductor performance (#1108)
- Rewrite the AxisInfo analysis to handle each op case by case. - Add bit shift, min max, div/rem, and select ops to AxisInfo. - Rematerialize across load/store ops in the following two cases: - A size 1 tensor is considered not expensive since all threads will load the same - the targeEncoding may expose more vectorization opportunities (more elements per thread on the first dim) **_res2next_** benchmark GPU Kernel time comparison on A100. - Average kernel sum. Triton 16838630ns vs Triton-MLIR 17105166ns. **1.016x slowdown**. - Total kernel sum. Triton 6511735460ns vs Triton-MLIR 6512370620ns.
This commit is contained in:
@@ -3,11 +3,14 @@
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -15,40 +18,47 @@ namespace mlir {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
/// Axis information is represented by a std::map<int, int>
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef SmallVector<int, 4> DimVectorT;
|
||||
typedef SmallVector<int64_t, 4> DimVectorT;
|
||||
|
||||
public:
|
||||
// Default constructor
|
||||
/// Default constructor
|
||||
AxisInfo() : AxisInfo({}, {}, {}) {}
|
||||
// Construct contiguity info with known contiguity
|
||||
/// Construct contiguity info with known contiguity
|
||||
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
|
||||
DimVectorT knownConstancy)
|
||||
: AxisInfo(knownContiguity, knownDivisibility, knownConstancy, {}) {}
|
||||
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
|
||||
DimVectorT knownConstancy, std::optional<int64_t> knownConstantValue)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == (size_t)rank);
|
||||
assert(knownConstancy.size() == (size_t)rank);
|
||||
constancy(knownConstancy), constantValue(knownConstantValue),
|
||||
rank(contiguity.size()) {
|
||||
assert(knownContiguity.size() == static_cast<size_t>(rank));
|
||||
assert(knownDivisibility.size() == static_cast<size_t>(rank));
|
||||
assert(knownConstancy.size() == static_cast<size_t>(rank));
|
||||
}
|
||||
|
||||
// Accessors
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
/// Accessors
|
||||
int64_t getContiguity(size_t dim) const { return contiguity[dim]; }
|
||||
const DimVectorT &getContiguity() const { return contiguity; }
|
||||
|
||||
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||
int64_t getDivisibility(size_t dim) const { return divisibility[dim]; }
|
||||
const DimVectorT &getDivisibility() const { return divisibility; }
|
||||
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
|
||||
const DimVectorT &getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
// Comparison
|
||||
std::optional<int64_t> getConstantValue() const { return constantValue; }
|
||||
|
||||
/// Comparison
|
||||
bool operator==(const AxisInfo &other) const {
|
||||
return (contiguity == other.contiguity) &&
|
||||
(divisibility == other.divisibility) &&
|
||||
(constancy == other.constancy);
|
||||
(constancy == other.constancy) &&
|
||||
(constantValue == other.constantValue) && (rank == other.rank);
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
@@ -57,13 +67,18 @@ public:
|
||||
}
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
/// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of contiguous integers along it
|
||||
/// sequence of contiguous integers along it.
|
||||
/// Suppose we have an array of N elements,
|
||||
/// with a contiguity value C,
|
||||
/// the array can be divided into a list of
|
||||
/// N/C sequences of C contiguous elements.
|
||||
/// Since we have N = 2^k, C must be a power of two.
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
@@ -97,42 +112,147 @@ private:
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of constant integer along it. This is
|
||||
/// particularly useful to infer the contiguity
|
||||
/// of operations (e.g., add) involving a constant
|
||||
/// of operations (e.g., add) involving a constant.
|
||||
/// Suppose we have an array of N elements,
|
||||
/// with a constancy value C,
|
||||
/// the array can be divided into a list of
|
||||
/// N/C sequences of C elements with the same value.
|
||||
/// Since we have N = 2^k, C must be a power of two.
|
||||
/// For example
|
||||
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
DimVectorT constancy;
|
||||
|
||||
/// The constant value of the lattice if we can infer it.
|
||||
std::optional<int64_t> constantValue;
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank;
|
||||
int rank{};
|
||||
};
|
||||
|
||||
class AxisInfoVisitor {
|
||||
public:
|
||||
AxisInfoVisitor() = default;
|
||||
virtual ~AxisInfoVisitor() = default;
|
||||
|
||||
static bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape,
|
||||
int dim) {
|
||||
return info.getContiguity(dim) == shape[dim];
|
||||
}
|
||||
|
||||
static bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape,
|
||||
int dim) {
|
||||
return info.getConstancy(dim) == shape[dim];
|
||||
}
|
||||
|
||||
virtual AxisInfo
|
||||
getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0;
|
||||
|
||||
virtual bool match(Operation *op) = 0;
|
||||
};
|
||||
|
||||
/// Base class for all operations
|
||||
template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
|
||||
public:
|
||||
using AxisInfoVisitor::AxisInfoVisitor;
|
||||
|
||||
AxisInfo getAxisInfo(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) final {
|
||||
return getAxisInfo(cast<OpTy>(op), operands);
|
||||
}
|
||||
|
||||
bool match(Operation *op) final { return isa<OpTy>(op); }
|
||||
|
||||
virtual AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
llvm_unreachable("Unimplemented getAxisInfo");
|
||||
}
|
||||
};
|
||||
|
||||
/// Binary operations
|
||||
template <typename OpTy>
|
||||
class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto rank = lhsInfo.getRank();
|
||||
assert(operands.size() == 2 && "Expected two operands");
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
|
||||
for (auto d = 0; d < rank; ++d) {
|
||||
if (constantValue.has_value()) {
|
||||
contiguity.push_back(1);
|
||||
constancy.push_back(
|
||||
std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
|
||||
divisibility.push_back(highestPowOf2Divisor(constantValue.value()));
|
||||
} else {
|
||||
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
|
||||
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
|
||||
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
|
||||
}
|
||||
}
|
||||
return AxisInfo(contiguity, divisibility, constancy, constantValue);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class AxisInfoVisitorList {
|
||||
public:
|
||||
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||
void append() {
|
||||
(visitors.emplace_back(std::make_unique<Ts>()), ...);
|
||||
}
|
||||
|
||||
AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
for (auto &visitor : visitors)
|
||||
if (visitor->match(op))
|
||||
return visitor->getAxisInfo(op, operands);
|
||||
return AxisInfo();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
|
||||
};
|
||||
|
||||
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
|
||||
private:
|
||||
static const int maxPow2Divisor = 65536;
|
||||
|
||||
int highestPowOf2Divisor(int n) {
|
||||
if (n == 0)
|
||||
return maxPow2Divisor;
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
AxisInfo visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
|
||||
AxisInfoVisitorList visitors;
|
||||
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||
AxisInfoAnalysis(MLIRContext *context);
|
||||
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
unsigned getPtrVectorSize(Value ptr);
|
||||
unsigned getPtrContiguity(Value ptr);
|
||||
|
||||
unsigned getPtrAlignment(Value ptr);
|
||||
|
||||
|
||||
@@ -77,6 +77,15 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T> T highestPowOf2Divisor(T n) {
|
||||
if (n == 0) {
|
||||
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
|
||||
}
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
SmallVector<unsigned> getContigPerThread(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
@@ -9,20 +8,16 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y) {
|
||||
static int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) {
|
||||
// Base Case
|
||||
if (a == 0) {
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
}
|
||||
int x1, y1; // To store results of recursive call
|
||||
int gcd = gcd_impl(b % a, a, &x1, &y1);
|
||||
int64_t x1, y1; // To store results of recursive call
|
||||
int64_t gcd = gcdImpl(b % a, a, &x1, &y1);
|
||||
// Update x and y using results of
|
||||
// recursive call
|
||||
*x = y1 - (b / a) * x1;
|
||||
@@ -30,16 +25,30 @@ static int gcd_impl(int a, int b, int *x, int *y) {
|
||||
return gcd;
|
||||
}
|
||||
|
||||
static int gcd(int a, int b) {
|
||||
int x, y;
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
static int64_t gcd(int64_t a, int64_t b) {
|
||||
if (a == 0)
|
||||
return b;
|
||||
if (b == 0)
|
||||
return a;
|
||||
int64_t x, y;
|
||||
return gcdImpl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
static constexpr int log2Int(int64_t num) {
|
||||
return (num > 1) ? 1 + log2Int(num / 2) : 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
size_t rank = 1;
|
||||
auto rank = 1;
|
||||
if (TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
rank = ty.getRank();
|
||||
int divHint = 1;
|
||||
auto contiHint = 1;
|
||||
auto divHint = 1;
|
||||
auto constHint = 1;
|
||||
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
|
||||
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
|
||||
Operation *op = blockArg.getOwner()->getParentOp();
|
||||
@@ -53,139 +62,342 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if (attr)
|
||||
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||
} else {
|
||||
// Derive the divisibility of the induction variable only when
|
||||
// the step and the lower bound are both constants
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
||||
if (blockArg == forOp.getInductionVar()) {
|
||||
if (auto lowerBound =
|
||||
forOp.getLowerBound().getDefiningOp<arith::ConstantOp>()) {
|
||||
if (auto step =
|
||||
forOp.getStep().getDefiningOp<arith::ConstantOp>()) {
|
||||
auto lowerBoundVal = lowerBound.getValue()
|
||||
.cast<IntegerAttr>()
|
||||
.getValue()
|
||||
.getZExtValue();
|
||||
auto stepVal =
|
||||
step.getValue().cast<IntegerAttr>().getValue().getZExtValue();
|
||||
auto k = gcd(lowerBoundVal, stepVal);
|
||||
if (k != 0)
|
||||
divHint = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
DimVectorT contiguity(rank, 1);
|
||||
DimVectorT divisibility(rank, divHint);
|
||||
DimVectorT constancy(rank, 1);
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
|
||||
return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
|
||||
/*knownDivisibility=*/DimVectorT(rank, divHint),
|
||||
/*knownConstancy=*/DimVectorT(rank, constHint));
|
||||
}
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
DimVectorT retContiguity;
|
||||
DimVectorT retDivisibility;
|
||||
DimVectorT retConstancy;
|
||||
for (int d = 0; d < lhs.getRank(); ++d) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(
|
||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||
DimVectorT contiguity;
|
||||
DimVectorT divisibility;
|
||||
DimVectorT constancy;
|
||||
for (auto d = 0; d < lhs.getRank(); ++d) {
|
||||
contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||
}
|
||||
return AxisInfo(retContiguity, retDivisibility, retConstancy);
|
||||
std::optional<int64_t> constantValue;
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value() &&
|
||||
lhs.getConstantValue() == rhs.getConstantValue())
|
||||
constantValue = lhs.getConstantValue();
|
||||
return AxisInfo(contiguity, divisibility, constancy, constantValue);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfoAnalysis
|
||||
// AxisInfoVisitor
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::DimVectorT newContiguity;
|
||||
AxisInfo::DimVectorT newDivisibility;
|
||||
AxisInfo::DimVectorT newConstancy;
|
||||
for (int d = 0; d < rank; ++d) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
template <typename OpTy>
|
||||
class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo curr;
|
||||
// This preserves the input axes (e.g., cast):
|
||||
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
|
||||
triton::PtrToIntOp, triton::IntToPtrOp,
|
||||
triton::gpu::ConvertLayoutOp>(op))
|
||||
curr = operands[0]->getValue();
|
||||
// Constant ranges
|
||||
if (triton::MakeRangeOp make_range =
|
||||
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
|
||||
int start = make_range.start();
|
||||
int end = make_range.end();
|
||||
AxisInfo::DimVectorT contiguity = {end - start};
|
||||
AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
|
||||
AxisInfo::DimVectorT constancy = {1};
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
return operands[0]->getValue();
|
||||
}
|
||||
// Constant
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
|
||||
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
|
||||
if (intAttr) {
|
||||
size_t val = intAttr.getValue().getZExtValue();
|
||||
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
|
||||
};
|
||||
|
||||
class MakeRangeOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<triton::MakeRangeOp> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::MakeRangeOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
auto start = op.start();
|
||||
auto end = op.end();
|
||||
return AxisInfo(/*contiguity=*/{end - start},
|
||||
/*divisibility=*/{highestPowOf2Divisor(start)},
|
||||
/*constancy=*/{1});
|
||||
}
|
||||
};
|
||||
|
||||
class ConstantOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<arith::ConstantOp> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(arith::ConstantOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
|
||||
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
|
||||
if (intAttr || boolAttr) {
|
||||
int64_t value{};
|
||||
if (intAttr)
|
||||
value = intAttr.getValue().getZExtValue();
|
||||
else
|
||||
value = boolAttr.getValue() ? 1 : 0;
|
||||
return AxisInfo(/*contiguity=*/{1},
|
||||
/*divisibility=*/{highestPowOf2Divisor(value)},
|
||||
/*constancy=*/{1},
|
||||
/*knownConstantValue=*/{value});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
|
||||
auto value = splatAttr.getSplatValue<int>();
|
||||
auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
|
||||
int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
curr = AxisInfo(
|
||||
AxisInfo::DimVectorT(ty.getRank(), 1),
|
||||
return AxisInfo(
|
||||
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
|
||||
/*divisibility=*/
|
||||
AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
|
||||
/*constancy=*/
|
||||
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()),
|
||||
/*knownConstantValue=*/{value});
|
||||
}
|
||||
return AxisInfo();
|
||||
}
|
||||
// TODO: refactor & complete binary ops
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
};
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)),
|
||||
gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)));
|
||||
}
|
||||
// Multiplication
|
||||
if (llvm::isa<arith::MulIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
|
||||
int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
// lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs)
|
||||
// rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
|
||||
// lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) *
|
||||
// gcd(d_lhs, d_rhs)
|
||||
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
|
||||
}
|
||||
// Remainder
|
||||
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
|
||||
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
}
|
||||
// TODO: All other binary ops
|
||||
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
|
||||
std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value()) {
|
||||
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
|
||||
std::is_same_v<OpTy, triton::AddPtrOp>) {
|
||||
return {lhs.getConstantValue().value() +
|
||||
rhs.getConstantValue().value()};
|
||||
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
|
||||
return {lhs.getConstantValue().value() -
|
||||
rhs.getConstantValue().value()};
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
// Splat
|
||||
if (llvm::isa<triton::SplatOp>(op)) {
|
||||
};
|
||||
|
||||
class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<arith::MulIOp>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) override {
|
||||
// lhs * 1 = lhs
|
||||
auto lhsContiguity =
|
||||
rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1
|
||||
? lhs.getContiguity(dim)
|
||||
: 1;
|
||||
// 1 * rhs = rhs
|
||||
auto rhsContiguity =
|
||||
lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1
|
||||
? rhs.getContiguity(dim)
|
||||
: 1;
|
||||
return std::max(lhsContiguity, rhsContiguity);
|
||||
}
|
||||
|
||||
int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) override {
|
||||
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
}
|
||||
|
||||
int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) override {
|
||||
// lhs = k * d_lhs
|
||||
// rhs = p * d_rhs
|
||||
// lhs * rhs = k * d_lhs * p * d_rhs = k * p * d_lhs * d_rhs
|
||||
return lhs.getDivisibility(dim) * rhs.getDivisibility(dim);
|
||||
}
|
||||
|
||||
std::optional<int64_t> getConstantValue(arith::MulIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value())
|
||||
return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
// lhs / 1 = lhs
|
||||
return rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 1
|
||||
? lhs.getContiguity(dim)
|
||||
: 1;
|
||||
}
|
||||
|
||||
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
|
||||
auto shape = resTy.getShape();
|
||||
// Case 1: both lhs and rhs are constants.
|
||||
auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
// Case 2: lhs contiguous, rhs constant.
|
||||
// lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
|
||||
// rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
|
||||
// lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
|
||||
// ..., (d_lhs * k + n) / (d_rhs * p)
|
||||
// Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0,
|
||||
// the minimal constancy is gcd(d_lhs, d_rhs).
|
||||
// Since gcd(d_lhs, d_rhs) maybe > len(lhs),
|
||||
// we need to use another gcd to get the actual constancy.
|
||||
if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) &&
|
||||
AxisInfoVisitor::isConstantDim(rhs, shape, dim)) {
|
||||
constancy = std::max(constancy, gcd(lhs.getContiguity(dim),
|
||||
gcd(lhs.getDivisibility(dim),
|
||||
rhs.getDivisibility(dim))));
|
||||
}
|
||||
return constancy;
|
||||
}
|
||||
|
||||
int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
// lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs)
|
||||
// rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
|
||||
// lhs / rhs = k * k' * gcd(d_lhs, d_rhs) / (p * p' * gcd(d_lhs, d_rhs))
|
||||
// = k / p * k' / p'
|
||||
// gcd(k', p') = divisibility(d_lhs / gcd(d_lhs, d_rhs), d_rhs / gcd(d_lhs,
|
||||
// d_rhs))
|
||||
auto lhsDivisibility = lhs.getDivisibility(dim);
|
||||
auto rhsDivisibility = rhs.getDivisibility(dim);
|
||||
auto initGcd = gcd(lhsDivisibility, rhsDivisibility);
|
||||
return std::max(lhsDivisibility / initGcd, rhsDivisibility / initGcd);
|
||||
};
|
||||
|
||||
std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value())
|
||||
return {lhs.getConstantValue().value() / rhs.getConstantValue().value()};
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
|
||||
auto shape = resTy.getShape();
|
||||
int64_t contiguity = 1;
|
||||
// lhs contiguous, rhs constant
|
||||
// lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
|
||||
// rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
|
||||
// lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p),
|
||||
// ..., (d_lhs * k + n) % (d_rhs * p)
|
||||
// Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0,
|
||||
// The minimal contiguity is gcd(d_lhs, d_rhs).
|
||||
// Since gcd(d_lhs, d_rhs) maybe > len(lhs),
|
||||
// we need to use another gcd to get the actual contiguity.
|
||||
if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) &&
|
||||
AxisInfoVisitor::isConstantDim(rhs, shape, dim)) {
|
||||
contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim),
|
||||
gcd(lhs.getDivisibility(dim),
|
||||
rhs.getDivisibility(dim))));
|
||||
}
|
||||
return contiguity;
|
||||
}
|
||||
|
||||
int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
// lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k''
|
||||
// rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p''
|
||||
// lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r
|
||||
// r must be divisible by gcd(d_lhs, d_rhs)
|
||||
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
|
||||
};
|
||||
|
||||
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
|
||||
auto shape = resTy.getShape();
|
||||
// lhs % 1 = 0
|
||||
return rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 1
|
||||
? shape[dim]
|
||||
: gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
}
|
||||
|
||||
std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value())
|
||||
return {lhs.getConstantValue().value() % rhs.getConstantValue().value()};
|
||||
else if (rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 1)
|
||||
return {0};
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class SplatOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<triton::SplatOp> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::SplatOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
@@ -197,21 +409,37 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
return AxisInfo(contiguity, divisibility, constancy,
|
||||
operands[0]->getValue().getConstantValue());
|
||||
}
|
||||
// expandDims
|
||||
if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
};
|
||||
|
||||
class ExpandDimsOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<triton::ExpandDimsOp> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::ExpandDimsOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
|
||||
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
|
||||
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
|
||||
contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
|
||||
divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
|
||||
constancy.insert(constancy.begin() + expandDims.axis(), 1);
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
contiguity.insert(contiguity.begin() + op.axis(), 1);
|
||||
divisibility.insert(divisibility.begin() + op.axis(), 1);
|
||||
constancy.insert(constancy.begin() + op.axis(), 1);
|
||||
return AxisInfo(contiguity, divisibility, constancy,
|
||||
operands[0]->getValue().getConstantValue());
|
||||
}
|
||||
// Broadcast
|
||||
if (llvm::isa<triton::BroadcastOp>(op)) {
|
||||
};
|
||||
|
||||
class BroadcastOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<triton::BroadcastOp> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::BroadcastOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -228,42 +456,362 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d]
|
||||
: opInfo.getConstancy(d));
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
return AxisInfo(contiguity, divisibility, constancy,
|
||||
operands[0]->getValue().getConstantValue());
|
||||
}
|
||||
};
|
||||
|
||||
// CmpI
|
||||
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
|
||||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
|
||||
op->getResult(0).getType().dyn_cast<TensorType>()) {
|
||||
auto resTy = op->getResult(0).getType().cast<TensorType>();
|
||||
template <typename OpTy>
|
||||
class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return AxisInfo();
|
||||
auto shape = resTy.getShape();
|
||||
short rank = resTy.getRank();
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto shape = resTy.getShape();
|
||||
|
||||
AxisInfo::DimVectorT contiguity, divisibility, constancy;
|
||||
std::optional<int64_t> constantValue;
|
||||
for (short d = 0; d < rank; ++d) {
|
||||
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
|
||||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
|
||||
constancy.push_back(
|
||||
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
|
||||
else
|
||||
constancy.push_back(1);
|
||||
int64_t constHint = 1;
|
||||
if (lhsInfo.getConstantValue().has_value() &&
|
||||
rhsInfo.getConstantValue().has_value()) {
|
||||
constHint = lhsInfo.getConstancy(d);
|
||||
constantValue =
|
||||
compare(getPredicate(op), lhsInfo.getConstantValue().value(),
|
||||
rhsInfo.getConstantValue().value())
|
||||
? 1
|
||||
: 0;
|
||||
} else {
|
||||
// Case 1: lhs and rhs are both partial constants
|
||||
constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d));
|
||||
// Case 2: lhs all constant, rhs all contiguous
|
||||
// NOTE:
|
||||
// lhs: 4 4 4 4
|
||||
// rhs: 4 5 6 7
|
||||
// lhs ge rhs: 1, 0, 0, 0
|
||||
// Case 3: lhs all contiguous, rhs all constant
|
||||
// NOTE
|
||||
// lhs: 4 5 6 7
|
||||
// rhs: 4 4 4 4
|
||||
// lhs sle rhs: 1, 0, 0, 0
|
||||
if (/*Case 2=*/(
|
||||
notGePredicate(getPredicate(op)) &&
|
||||
(AxisInfoVisitor::isConstantDim(lhsInfo, shape, d) &&
|
||||
AxisInfoVisitor::isContiguousDim(rhsInfo, shape, d))) ||
|
||||
/*Case 3=*/(notLePredicate(getPredicate(op)) &&
|
||||
(AxisInfoVisitor::isContiguousDim(lhsInfo, shape, d) &&
|
||||
AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)))) {
|
||||
constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d),
|
||||
gcd(lhsInfo.getDivisibility(d),
|
||||
rhsInfo.getDivisibility(d))));
|
||||
}
|
||||
}
|
||||
|
||||
divisibility.push_back(shape[d]);
|
||||
constancy.push_back(constHint);
|
||||
divisibility.push_back(1);
|
||||
contiguity.push_back(1);
|
||||
}
|
||||
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
return AxisInfo(contiguity, divisibility, constancy, constantValue);
|
||||
}
|
||||
|
||||
// UnrealizedConversionCast
|
||||
private:
|
||||
static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
|
||||
return op.predicate();
|
||||
}
|
||||
|
||||
static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
|
||||
return op.getPredicate();
|
||||
}
|
||||
|
||||
static bool notGePredicate(arith::CmpIPredicate predicate) {
|
||||
return predicate != arith::CmpIPredicate::sge &&
|
||||
predicate != arith::CmpIPredicate::uge;
|
||||
}
|
||||
|
||||
static bool notLePredicate(arith::CmpIPredicate predicate) {
|
||||
return predicate != arith::CmpIPredicate::sle &&
|
||||
predicate != arith::CmpIPredicate::ule;
|
||||
}
|
||||
|
||||
static bool compare(arith::CmpIPredicate predicate, int64_t lhs,
|
||||
int64_t rhs) {
|
||||
switch (predicate) {
|
||||
case arith::CmpIPredicate::eq:
|
||||
return lhs == rhs;
|
||||
case arith::CmpIPredicate::ne:
|
||||
return lhs != rhs;
|
||||
case arith::CmpIPredicate::slt:
|
||||
return lhs < rhs;
|
||||
case arith::CmpIPredicate::sle:
|
||||
return lhs <= rhs;
|
||||
case arith::CmpIPredicate::sgt:
|
||||
return lhs > rhs;
|
||||
case arith::CmpIPredicate::sge:
|
||||
return lhs >= rhs;
|
||||
case arith::CmpIPredicate::ult:
|
||||
return (uint64_t)lhs < (uint64_t)rhs;
|
||||
case arith::CmpIPredicate::ule:
|
||||
return (uint64_t)lhs <= (uint64_t)rhs;
|
||||
case arith::CmpIPredicate::ugt:
|
||||
return (uint64_t)lhs > (uint64_t)rhs;
|
||||
case arith::CmpIPredicate::uge:
|
||||
return (uint64_t)lhs >= (uint64_t)rhs;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("unknown comparison predicate");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return AxisInfo();
|
||||
auto shape = resTy.getShape();
|
||||
auto rank = shape.size();
|
||||
auto condConstancy = operands[0]->getValue().getConstancy();
|
||||
auto lhsInfo = operands[1]->getValue();
|
||||
auto rhsInfo = operands[2]->getValue();
|
||||
|
||||
AxisInfo::DimVectorT contiguity, divisibility, constancy;
|
||||
std::optional<int64_t> constantValue;
|
||||
if (operands[0]->getValue().getConstantValue().has_value()) {
|
||||
if (operands[0]->getValue().getConstantValue() == 0) {
|
||||
contiguity = rhsInfo.getContiguity();
|
||||
divisibility = rhsInfo.getDivisibility();
|
||||
constancy = rhsInfo.getConstancy();
|
||||
constantValue = rhsInfo.getConstantValue();
|
||||
} else {
|
||||
contiguity = lhsInfo.getContiguity();
|
||||
divisibility = lhsInfo.getDivisibility();
|
||||
constancy = lhsInfo.getConstancy();
|
||||
constantValue = lhsInfo.getConstantValue();
|
||||
}
|
||||
} else {
|
||||
for (auto d = 0; d < rank; ++d) {
|
||||
constancy.push_back(
|
||||
std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]),
|
||||
gcd(rhsInfo.getConstancy(d), condConstancy[d])));
|
||||
divisibility.push_back(
|
||||
std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
|
||||
contiguity.push_back(
|
||||
std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]),
|
||||
gcd(rhsInfo.getContiguity(d), condConstancy[d])));
|
||||
}
|
||||
if (lhsInfo.getConstantValue().has_value() &&
|
||||
rhsInfo.getConstantValue().has_value() &&
|
||||
lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
|
||||
constantValue = lhsInfo.getConstantValue();
|
||||
}
|
||||
|
||||
return AxisInfo(contiguity, divisibility, constancy, constantValue);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
}
|
||||
|
||||
std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value()) {
|
||||
if constexpr (std::is_same<OpTy, arith::AndIOp>::value) {
|
||||
return {lhs.getConstantValue().value() &
|
||||
rhs.getConstantValue().value()};
|
||||
} else if constexpr (std::is_same<OpTy, arith::OrIOp>::value) {
|
||||
return {lhs.getConstantValue().value() |
|
||||
rhs.getConstantValue().value()};
|
||||
} else if constexpr (std::is_same<OpTy, arith::XOrIOp>::value) {
|
||||
return {lhs.getConstantValue().value() ^
|
||||
rhs.getConstantValue().value()};
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<arith::ShLIOp>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) override {
|
||||
if (rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 0)
|
||||
return lhs.getContiguity(dim);
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
|
||||
int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) override {
|
||||
auto shift = rhs.getConstantValue().has_value()
|
||||
? rhs.getConstantValue().value()
|
||||
: rhs.getDivisibility(dim);
|
||||
auto numBits = log2Int(lhs.getDivisibility(dim));
|
||||
auto maxBits = log2Int(highestPowOf2Divisor<int64_t>(0));
|
||||
// Make sure the return value doesn't exceed highestPowOf2Divisor<int64>(0)
|
||||
if (shift + numBits > maxBits)
|
||||
return highestPowOf2Divisor<int64_t>(0);
|
||||
return lhs.getDivisibility(dim) << shift;
|
||||
}
|
||||
|
||||
int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs, int dim) override {
|
||||
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
}
|
||||
|
||||
std::optional<int64_t> getConstantValue(arith::ShLIOp op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value())
|
||||
return {lhs.getConstantValue().value() << rhs.getConstantValue().value()};
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
|
||||
public:
|
||||
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
|
||||
|
||||
private:
|
||||
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
if (rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 0)
|
||||
return lhs.getContiguity(dim);
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
|
||||
int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
if (rhs.getConstantValue().has_value())
|
||||
return std::max<int64_t>(1, lhs.getDivisibility(dim) /
|
||||
(1 << rhs.getConstantValue().value()));
|
||||
else
|
||||
return std::max<int64_t>(1, lhs.getDivisibility(dim) /
|
||||
(1 << rhs.getDivisibility(dim)));
|
||||
}
|
||||
|
||||
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
int dim) override {
|
||||
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
|
||||
}
|
||||
|
||||
std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) override {
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value())
|
||||
return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()};
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
std::optional<int64_t> constantValue;
|
||||
if (lhsInfo.getConstantValue().has_value() &&
|
||||
rhsInfo.getConstantValue().has_value()) {
|
||||
if constexpr (std::is_same_v<OpTy, arith::MaxSIOp> ||
|
||||
std::is_same_v<OpTy, arith::MaxUIOp>) {
|
||||
constantValue = {std::max(lhsInfo.getConstantValue().value(),
|
||||
rhsInfo.getConstantValue().value())};
|
||||
} else if constexpr (std::is_same_v<OpTy, arith::MinSIOp> ||
|
||||
std::is_same_v<OpTy, arith::MinUIOp>) {
|
||||
constantValue = {std::min(lhsInfo.getConstantValue().value(),
|
||||
rhsInfo.getConstantValue().value())};
|
||||
}
|
||||
}
|
||||
auto rank = lhsInfo.getRank();
|
||||
return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
|
||||
/*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
|
||||
/*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
|
||||
/*constantValue=*/constantValue);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfoAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
||||
: ForwardDataFlowAnalysis<AxisInfo>(context) {
|
||||
// UnrealizedConversionCast:
|
||||
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
||||
// in the process of a PartialConversion, where UnrealizedConversionCast
|
||||
// may exist
|
||||
if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
|
||||
curr = operands[0]->getValue();
|
||||
}
|
||||
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
|
||||
CastOpAxisInfoVisitor<arith::ExtUIOp>,
|
||||
CastOpAxisInfoVisitor<arith::TruncIOp>,
|
||||
CastOpAxisInfoVisitor<arith::IndexCastOp>,
|
||||
CastOpAxisInfoVisitor<triton::PtrToIntOp>,
|
||||
CastOpAxisInfoVisitor<triton::IntToPtrOp>,
|
||||
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
|
||||
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
|
||||
CastOpAxisInfoVisitor<triton::BitcastOp>>();
|
||||
visitors.append<MakeRangeOpAxisInfoVisitor>();
|
||||
visitors.append<ConstantOpAxisInfoVisitor>();
|
||||
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
|
||||
AddSubOpAxisInfoVisitor<arith::AddIOp>,
|
||||
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
|
||||
visitors.append<MulIOpAxisInfoVisitor>();
|
||||
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
|
||||
DivOpAxisInfoVisitor<arith::DivUIOp>>();
|
||||
visitors.append<RemOpAxisInfoVisitor<arith::RemSIOp>,
|
||||
RemOpAxisInfoVisitor<arith::RemUIOp>>();
|
||||
visitors.append<BroadcastOpAxisInfoVisitor>();
|
||||
visitors.append<SplatOpAxisInfoVisitor>();
|
||||
visitors.append<ExpandDimsOpAxisInfoVisitor>();
|
||||
visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>,
|
||||
CmpOpAxisInfoVisitor<triton::gpu::CmpIOp>>();
|
||||
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
|
||||
LogicalOpAxisInfoVisitor<arith::OrIOp>,
|
||||
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
|
||||
visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>,
|
||||
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
|
||||
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
|
||||
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
|
||||
visitors.append<MaxMinOpAxisInfoVisitor<arith::MaxSIOp>,
|
||||
MaxMinOpAxisInfoVisitor<arith::MaxUIOp>,
|
||||
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
|
||||
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
|
||||
}
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo curr = visitors.apply(op, operands);
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
@@ -276,7 +824,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
|
||||
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
@@ -289,10 +837,10 @@ unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
|
||||
unsigned align = getPtrAlignment(ptr);
|
||||
|
||||
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
|
||||
unsigned vec = std::min(align, contigPerThread);
|
||||
vec = std::min<unsigned>(shape[order[0]], vec);
|
||||
contigPerThread = std::min(align, contigPerThread);
|
||||
contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);
|
||||
|
||||
return vec;
|
||||
return contigPerThread;
|
||||
}
|
||||
|
||||
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
||||
@@ -302,8 +850,8 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
||||
auto axisInfo = lookupLatticeElement(ptr)->getValue();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
|
||||
unsigned maxContig = axisInfo.getContiguity(order[0]);
|
||||
auto maxMultiple = axisInfo.getDivisibility(order[0]);
|
||||
auto maxContig = axisInfo.getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
return alignment;
|
||||
}
|
||||
|
||||
@@ -165,6 +165,19 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
dotOperandLayout.getParent() == mmaLayout;
|
||||
}
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
|
||||
return tensorTy.getNumElements() == 1;
|
||||
// TODO: Handle other cases.
|
||||
// For example, when ptr is a tensor of single value.
|
||||
// It means that ptr is a resultant of broadcast or generated through
|
||||
// a chain of broadcast and other operations.
|
||||
// Rematerialize it without considering contiguous memory access pattern is
|
||||
// fine.
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// A data structure similar to SetVector but maintains
|
||||
|
||||
@@ -31,8 +31,26 @@ struct LoadStoreConversionBase {
|
||||
return valueVals;
|
||||
}
|
||||
|
||||
unsigned getContiguity(Value ptr) const {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
return axisAnalysisPass.getPtrContiguity(ptr);
|
||||
}
|
||||
|
||||
unsigned getVectorSize(Value ptr) const {
|
||||
return axisAnalysisPass.getPtrVectorSize(ptr);
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto contiguity = getContiguity(ptr);
|
||||
unsigned numElemBits = 0;
|
||||
auto ptrTy = tensorTy.getElementType().cast<triton::PointerType>();
|
||||
auto pointeeType = ptrTy.getPointeeType();
|
||||
numElemBits = pointeeType.isa<triton::Float8Type>()
|
||||
? 8
|
||||
: pointeeType.getIntOrFloatBitWidth();
|
||||
// The maximum vector size is 128 bits on NVIDIA GPUs.
|
||||
return std::min<unsigned>(128 / numElemBits, contiguity);
|
||||
}
|
||||
|
||||
unsigned getMaskAlignment(Value mask) const {
|
||||
@@ -734,7 +752,10 @@ struct InsertSliceAsyncOpConversion
|
||||
assert(srcElems.size() == otherElems.size());
|
||||
}
|
||||
|
||||
unsigned inVec = getVectorSize(src);
|
||||
// We don't use getVec() here because we are copying from memory to memory.
|
||||
// If contiguity > vector size, we can have one pointer maintaining the
|
||||
// start of the vector and the other pointer moving to the next vector.
|
||||
unsigned inVec = getContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned numElems = getElemsPerThread(srcTy);
|
||||
|
||||
@@ -342,7 +342,7 @@ private:
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
|
||||
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
|
||||
@@ -143,7 +143,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout) {
|
||||
SmallVector<unsigned> getContigPerThread(const Attribute &layout) {
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
|
||||
return {1, 2};
|
||||
|
||||
@@ -281,26 +281,35 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO: Interface
|
||||
LogicalResult getForwardEncoding(Attribute sourceEncoding, Operation *op,
|
||||
Attribute &ret) {
|
||||
if (op->hasTrait<mlir::OpTrait::Elementwise>()) {
|
||||
ret = sourceEncoding;
|
||||
return success();
|
||||
inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
// Case 1: A size 1 tensor is not expensive since all threads will load the
|
||||
// same
|
||||
if (isSingleValue(op->getOperand(0)))
|
||||
return false;
|
||||
auto ptr = op->getOperand(0);
|
||||
if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
|
||||
auto encoding = tensorTy.getEncoding();
|
||||
// Case 2: Different type conversion is expensive (e.g., mma <-> block)
|
||||
if (encoding.getTypeID() != targetEncoding.getTypeID())
|
||||
return true;
|
||||
auto sizePerThread = triton::gpu::getSizePerThread(encoding);
|
||||
auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
|
||||
auto order = triton::gpu::getOrder(encoding);
|
||||
auto targetOrder = triton::gpu::getOrder(targetEncoding);
|
||||
// Case 3: The targeEncoding may expose more vectorization opportunities
|
||||
return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
|
||||
}
|
||||
if (isa<triton::ReduceOp>(op)) {
|
||||
ret = Attribute();
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool expensiveToRemat(Operation *op, const Attribute &targetEncoding) {
|
||||
inline bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
if (!op)
|
||||
return true;
|
||||
if (isa<triton::LoadOp, triton::StoreOp>(op))
|
||||
return expensiveLoadOrStore(op, targetEncoding);
|
||||
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
|
||||
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
|
||||
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
|
||||
triton::AtomicCASOp, triton::DotOp>(op))
|
||||
return true;
|
||||
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
|
||||
op))
|
||||
@@ -509,6 +518,7 @@ public:
|
||||
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto dstEncoding =
|
||||
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
// XXX: why is this needed?
|
||||
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
|
||||
return failure();
|
||||
SetVector<Operation *> cvtSlices;
|
||||
|
||||
@@ -168,7 +168,7 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
for (Operation &op : *loop)
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
|
||||
auto ptr = loadOp.ptr();
|
||||
unsigned vec = axisInfoAnalysis.getPtrVectorSize(ptr);
|
||||
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
|
||||
auto ty = getElementTypeOrSelf(ptr.getType())
|
||||
.cast<triton::PointerType>()
|
||||
.getPointeeType();
|
||||
|
||||
@@ -1,51 +1,336 @@
|
||||
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
// CHECK-LABEL: cast
|
||||
func @cast() {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
|
||||
%cst = arith.constant 1 : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
|
||||
%0 = arith.extsi %cst : i32 to i64
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: add
|
||||
func @add() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%2 = arith.addi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127]
|
||||
%3 = arith.constant dense<127> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
%4 = arith.addi %1, %3 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: sub
|
||||
func @sub() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%2 = arith.subi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129]
|
||||
%3 = arith.constant dense<129> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
%4 = arith.subi %3, %1 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: mul
|
||||
func @mul() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%2 = arith.muli %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
%3 = arith.constant dense<128> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
%4 = arith.muli %3, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2]
|
||||
%5 = arith.constant dense<2> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256]
|
||||
%6 = arith.muli %4, %5 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: div
|
||||
func @div() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%3 = arith.divui %1, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
%4 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
|
||||
%5 = arith.divsi %0, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%6 = arith.divsi %4, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
|
||||
%7 = arith.constant dense<66> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [536870912] ; Constancy: [2] ; ConstantValue: [None]
|
||||
%8 = arith.divui %0, %7 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: rem
|
||||
func @rem() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%2 = arith.remsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%3 = arith.remui %1, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
%4 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%5 = arith.remsi %0, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%6 = arith.remsi %4, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
|
||||
%7 = arith.constant dense<66> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%8 = arith.remui %0, %7 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: broadcast
|
||||
func @broadcast() {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
%0 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64]
|
||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64]
|
||||
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: splat
|
||||
func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None]
|
||||
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: cmp
|
||||
func @cmp() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%1 = arith.constant dense<0> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%4 = arith.cmpi sle, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%5 = arith.cmpi sge, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
%6 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: logic
|
||||
func @logic() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
%1 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
|
||||
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
%3 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None]
|
||||
%4 = arith.divsi %0, %3 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%5 = arith.andi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%6 = arith.ori %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%7 = arith.xori %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
%8 = arith.andi %2, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
%9 = arith.ori %2, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
%10 = arith.xori %2, %4 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: select
|
||||
func @select() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%1 = arith.constant dense<0> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
|
||||
%4 = arith.constant 0 : i1
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%7 = tt.splat %4 : (i1) -> tensor<128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%5 = select %4, %3, %7 : tensor<128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @shift() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
%1 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
|
||||
%2 = arith.constant dense<4> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%3 = arith.shli %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%4 = arith.shrsi %0, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
%5 = arith.shli %1, %2 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @max_min() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%2 = arith.maxsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%3 = arith.minsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
%4 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
|
||||
%5 = arith.constant dense<4> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8]
|
||||
%6 = arith.maxsi %4, %5 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: for
|
||||
func @for() {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0]
|
||||
%a_init = arith.constant dense<0> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1]
|
||||
%b_init = arith.constant dense<1> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
|
||||
%c_init = arith.constant dense<4> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
|
||||
%ub = arith.constant 128 : index
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
|
||||
%lb = arith.constant 0 : index
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16]
|
||||
%step = arith.constant 16 : index
|
||||
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%t = arith.index_cast %iv : index to i32
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
|
||||
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: permute_2d
|
||||
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1]
|
||||
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%4 = arith.muli %2, %3 : tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
|
||||
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
|
||||
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None]
|
||||
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
return
|
||||
@@ -56,28 +341,29 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
module {
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
||||
// CHECK-LABEL: store_constant_align
|
||||
func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%1 = arith.muli %pid, %c128_i32 : i32
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%4 = arith.addi %3, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
|
||||
%9 = tt.splat %n : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None]
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
%cst = arith.constant dense<0.0> : tensor<128xf32>
|
||||
tt.store %5, %cst, %mask : tensor<128xf32>
|
||||
return
|
||||
@@ -89,6 +375,7 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
|
||||
// CHECK-LABEL: vecadd_mask_align_16
|
||||
func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
@@ -101,13 +388,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
tt.store %15, %13, %mask : tensor<64xf32>
|
||||
return
|
||||
@@ -117,6 +404,7 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
|
||||
// CHECK-LABEL: vecadd_mask_align_1
|
||||
func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
@@ -129,7 +417,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
#layout2 = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
|
||||
|
||||
// CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
@@ -54,6 +55,61 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
// CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remat_load_store
|
||||
func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout1>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout1>
|
||||
tt.store %5, %4 : tensor<64xi32, #layout1>
|
||||
return
|
||||
}
|
||||
|
||||
// Don't rematerialize vectorized loads
|
||||
// CHECK-LABEL: remat_expensive
|
||||
func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout1>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout1>) -> tensor<64xi32, #layout0>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout1>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
tt.store %5, %4 : tensor<64xi32, #layout0>
|
||||
return
|
||||
}
|
||||
|
||||
// Don't rematerialize loads when original and target layouts are different
|
||||
// CHECK-LABEL: remat_multi_layout
|
||||
func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout2>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout2>
|
||||
tt.store %5, %4 : tensor<64xi32, #layout2>
|
||||
return
|
||||
}
|
||||
|
||||
// Always rematerialize single value loads
|
||||
// CHECK-LABEL: remat_single_value
|
||||
func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
|
||||
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #layout1>) -> tensor<1xi32, #layout0>
|
||||
%3 = triton_gpu.convert_layout %0 : (tensor<1x!tt.ptr<i32>, #layout1>) -> tensor<1x!tt.ptr<i32>, #layout0>
|
||||
tt.store %3, %2 : tensor<1xi32, #layout0>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if
|
||||
func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
|
||||
@@ -36,8 +36,8 @@ struct TestAliasPass
|
||||
void runOnOperation() override {
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << op_name << "\n";
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
|
||||
SharedMemoryAliasAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
|
||||
@@ -19,9 +19,9 @@ struct TestAllocationPass
|
||||
void runOnOperation() override {
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
// Convert to std::string can remove quotes from op_name
|
||||
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << op_name << "\n";
|
||||
// Convert to std::string can remove quotes from opName
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
Allocation allocation(operation);
|
||||
operation->walk([&](Operation *op) {
|
||||
auto scratchBufferId = allocation.getBufferId(op);
|
||||
|
||||
@@ -11,7 +11,7 @@ struct TestAxisInfoPass
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||
|
||||
void print(const std::string &name, raw_ostream &os, ArrayRef<int> vals) {
|
||||
void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
|
||||
os << name << ": [";
|
||||
for (size_t d = 0; d < vals.size(); d++) {
|
||||
if (d != 0)
|
||||
@@ -29,7 +29,8 @@ struct TestAxisInfoPass
|
||||
void runOnOperation() override {
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
os << "Testing: " << operation->getName() << "\n";
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
AxisInfoAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
operation->walk([&](Operation *op) {
|
||||
@@ -51,7 +52,14 @@ struct TestAxisInfoPass
|
||||
print("Divisibility", os, info.getDivisibility());
|
||||
os << " ; ";
|
||||
print("Constancy", os, info.getConstancy());
|
||||
os << " ( ";
|
||||
os << " ; ";
|
||||
auto constantValue = info.getConstantValue();
|
||||
os << "ConstantValue: [";
|
||||
if (constantValue.has_value())
|
||||
os << constantValue.value();
|
||||
else
|
||||
os << "None";
|
||||
os << "] ( ";
|
||||
result.print(os);
|
||||
os << " ) ";
|
||||
os << "\n";
|
||||
|
||||
@@ -23,8 +23,8 @@ struct TestMembarPass
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
// Convert to std::string can remove quotes from op_name
|
||||
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << op_name << "\n";
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
Allocation allocation(operation);
|
||||
MembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
Reference in New Issue
Block a user