#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/raw_ostream.h"

#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include <algorithm>
#include <numeric>
#include <optional>
#include <type_traits>

namespace mlir {

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

// Greatest common divisor of `a` and `b`. The result is never negative;
// gcd(x, 0) == |x| and gcd(0, 0) == 0.
static int64_t gcd(int64_t a, int64_t b) { return std::gcd(a, b); }

// floor(log2(num)) for num >= 1, and 0 for num <= 1.
static constexpr int log2Int(int64_t num) {
  return (num > 1) ? 1 + log2Int(num / 2) : 0;
}

// Largest divisibility value the analysis tracks: what highestPowOf2Divisor
// reports for 0 (0 is divisible by every power of two).
static int64_t maxDivisor() { return highestPowOf2Divisor(int64_t{0}); }

// Product of two divisibility values, saturated at maxDivisor() so that it
// can never overflow int64_t. Non-positive inputs (which are not meaningful
// divisibilities) conservatively yield 1.
static int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
  if (lhs <= 0 || rhs <= 0)
    return 1;
  int64_t limit = maxDivisor();
  if (lhs > limit / rhs)
    return limit;
  return lhs * rhs;
}

// If `op` carries the dense integer attribute `name` (one of the user hints
// "tt.contiguity" / "tt.divisibility" / "tt.constancy"), overwrite `dims`
// with its elements; otherwise leave `dims` untouched.
static void readDimHint(Operation *op, StringRef name,
                        AxisInfo::DimVectorT &dims) {
  if (Attribute attr = op->getAttr(name)) {
    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
    dims = AxisInfo::DimVectorT(vals.begin(), vals.end());
  }
}

//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//

// Per dimension d, an AxisInfo records (as used by the visitors below):
//   contiguity[d]   - length of runs of consecutive integers v, v+1, ...
//   divisibility[d] - a divisor of the first element of each such run
//   constancy[d]    - length of runs of identical values
// plus an optional value when the whole value is a known constant.

// Starting state of `value` before any transfer function has run.
// - Entry-block arguments of func/LLVM functions pick up a "tt.divisibility"
//   argument attribute.
// - scf.for induction variables with constant lower bound and step get
//   divisibility gcd(lowerBound, step).
// - Values defined by an op use that op's tt.* hint attributes.
// Everything else is the most conservative state (all ones).
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
  auto rank = 1;
  if (TensorType ty = value.getType().dyn_cast<TensorType>())
    rank = ty.getRank();
  int64_t contiHint = 1;
  int64_t divHint = 1; // int64_t: hints are read as 64-bit values
  int64_t constHint = 1;

  BlockArgument blockArg = value.dyn_cast<BlockArgument>();
  if (blockArg && blockArg.getOwner()->isEntryBlock()) {
    Operation *op = blockArg.getOwner()->getParentOp();
    // Function arguments may carry a "tt.divisibility" argument attribute.
    auto readArgDivisibility = [&](auto fun) {
      if (Attribute attr =
              fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"))
        divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
    };
    if (auto fun = dyn_cast<func::FuncOp>(op)) {
      readArgDivisibility(fun);
    } else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) {
      readArgDivisibility(fun);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      // Derive the divisibility of the induction variable only when the step
      // and the lower bound are both constants: every value it takes is
      // lowerBound + i * step, which gcd(lowerBound, step) divides.
      if (blockArg == forOp.getInductionVar()) {
        auto lowerBound =
            forOp.getLowerBound().getDefiningOp<arith::ConstantOp>();
        auto step = forOp.getStep().getDefiningOp<arith::ConstantOp>();
        if (lowerBound && step) {
          auto lowerBoundVal = lowerBound.getValue()
                                   .cast<IntegerAttr>()
                                   .getValue()
                                   .getZExtValue();
          auto stepVal =
              step.getValue().cast<IntegerAttr>().getValue().getZExtValue();
          auto k = gcd(lowerBoundVal, stepVal);
          if (k != 0)
            divHint = k;
        }
      }
    }
  } else if (Operation *op = value.getDefiningOp()) {
    DimVectorT knownContiguity(rank, 1);
    DimVectorT knownDivisibility(rank, 1);
    DimVectorT knownConstancy(rank, 1);
    readDimHint(op, "tt.divisibility", knownDivisibility);
    readDimHint(op, "tt.contiguity", knownContiguity);
    readDimHint(op, "tt.constancy", knownConstancy);
    return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
  }

  return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
                  /*knownDivisibility=*/DimVectorT(rank, divHint),
                  /*knownConstancy=*/DimVectorT(rank, constHint));
}

// Lattice join: the per-dimension gcd of both arguments. The constant value
// survives only if both sides agree on it.
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
  // If one argument is not initialized, return the other.
  if (lhs.getRank() == 0)
    return rhs;
  if (rhs.getRank() == 0)
    return lhs;
  DimVectorT contiguity;
  DimVectorT divisibility;
  DimVectorT constancy;
  for (auto d = 0; d < lhs.getRank(); ++d) {
    contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
    divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
    constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
  }
  std::optional<int64_t> constantValue;
  if (lhs.getConstantValue().has_value() &&
      rhs.getConstantValue().has_value() &&
      lhs.getConstantValue() == rhs.getConstantValue())
    constantValue = lhs.getConstantValue();
  return AxisInfo(contiguity, divisibility, constancy, constantValue);
}

//===----------------------------------------------------------------------===//
// AxisInfoVisitor
//===----------------------------------------------------------------------===//

// Value-preserving casts: the result has exactly the operand's AxisInfo.
template <typename OpTy>
class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
  using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(OpTy op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    return operands[0]->getValue();
  }
};

// tt.make_range [start, end): one contiguous run of length end - start whose
// first element is `start`.
class MakeRangeOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<triton::MakeRangeOp> {
public:
  using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(triton::MakeRangeOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    auto start = op.getStart();
    auto end = op.getEnd();
    return AxisInfo(/*contiguity=*/{end - start},
                    /*divisibility=*/{highestPowOf2Divisor(start)},
                    /*constancy=*/{1});
  }
};

// arith.constant: scalar integer/bool constants and integer splat tensors.
// Other constants yield an uninitialized AxisInfo.
class ConstantOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<arith::ConstantOp> {
public:
  using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(arith::ConstantOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
    auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
    if (intAttr || boolAttr) {
      int64_t value{};
      if (intAttr)
        value = intAttr.getValue().getZExtValue();
      else
        value = boolAttr.getValue() ? 1 : 0;
      return AxisInfo(/*contiguity=*/{1},
                      /*divisibility=*/{highestPowOf2Divisor(value)},
                      /*constancy=*/{1},
                      /*knownConstantValue=*/{value});
    }
    // TODO: generalize to dense attr
    auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
    if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
      int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
      TensorType ty = splatAttr.getType().cast<TensorType>();
      return AxisInfo(
          /*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
          /*divisibility=*/
          AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
          /*constancy=*/
          AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()),
          /*knownConstantValue=*/{value});
    }
    return AxisInfo();
  }
};

// addptr / addi / subi.
template <typename OpTy>
class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  // A contiguous run plus (or minus) a constant run stays contiguous for as
  // long as both runs overlap.
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)),
                    gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)));
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs)
    // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
    // lhs +/- rhs = (k * k' +/- p * p') * gcd(d_lhs, d_rhs)
    return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    if (lhs.getConstantValue().has_value() &&
        rhs.getConstantValue().has_value()) {
      if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
                    std::is_same_v<OpTy, triton::AddPtrOp>) {
        return {lhs.getConstantValue().value() +
                rhs.getConstantValue().value()};
      } else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
        return {lhs.getConstantValue().value() -
                rhs.getConstantValue().value()};
      }
    }
    return {};
  }
};

// arith.muli.
class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
public:
  using BinaryOpVisitorImpl<arith::MulIOp>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs,
                        const AxisInfo &rhs, int dim) override {
    // lhs * 1 = lhs
    auto lhsContiguity =
        rhs.getConstantValue() == 1 ? lhs.getContiguity(dim) : 1;
    // 1 * rhs = rhs
    auto rhsContiguity =
        lhs.getConstantValue() == 1 ? rhs.getContiguity(dim) : 1;
    return std::max(lhsContiguity, rhsContiguity);
  }

  int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
                       const AxisInfo &rhs, int dim) override {
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
                          const AxisInfo &rhs, int dim) override {
    // lhs = k * d_lhs, rhs = p * d_rhs  =>  lhs * rhs = k * p * d_lhs * d_rhs.
    // Unless multiplied by the constant 1, a contiguous operand
    // [d*k, d*k + 1, ...] loses contiguity (see getContiguity). The result's
    // divisibility must then hold for every element, and for that operand
    // only 1 is guaranteed.
    int64_t lhsDivisibility = lhs.getDivisibility(dim);
    if (lhs.getContiguity(dim) > 1 && rhs.getConstantValue() != 1)
      lhsDivisibility = 1;
    int64_t rhsDivisibility = rhs.getDivisibility(dim);
    if (rhs.getContiguity(dim) > 1 && lhs.getConstantValue() != 1)
      rhsDivisibility = 1;
    // Saturating product: two large divisibilities (e.g. of constant 0s)
    // would otherwise overflow int64_t.
    return multiplyDivisor(lhsDivisibility, rhsDivisibility);
  }

  std::optional<int64_t> getConstantValue(arith::MulIOp op,
                                          const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    if (lhs.getConstantValue().has_value() &&
        rhs.getConstantValue().has_value())
      return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
    return {};
  }
};

// arith.divsi / arith.divui.
template <typename OpTy>
class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    // lhs / 1 = lhs
    return rhs.getConstantValue() == 1 ? lhs.getContiguity(dim) : 1;
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
    auto shape = resTy.getShape();
    // Case 1: both lhs and rhs are constants.
    auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
    // Case 2: lhs contiguous, rhs constant.
    // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
    // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
    // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
    // ..., (d_lhs * k + n) / (d_rhs * p)
    // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0,
    // the minimal constancy is gcd(d_lhs, d_rhs).
    // Since gcd(d_lhs, d_rhs) may be > len(lhs),
    // we need to use another gcd to get the actual constancy.
    if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) &&
        AxisInfoVisitor::isConstantDim(rhs, shape, dim)) {
      constancy = std::max(constancy, gcd(lhs.getContiguity(dim),
                                          gcd(lhs.getDivisibility(dim),
                                              rhs.getDivisibility(dim))));
    }
    return constancy;
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    // Case 1: lhs is 0
    if (lhs.getConstantValue() == 0)
      return lhs.getDivisibility(dim);
    // Case 2: rhs is 1
    if (rhs.getConstantValue() == 1)
      return lhs.getDivisibility(dim);
    // otherwise: return 1
    return 1;
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    auto lhsValue = lhs.getConstantValue();
    auto rhsValue = rhs.getConstantValue();
    // Never fold a division by zero: it would crash the compiler itself.
    if (lhsValue.has_value() && rhsValue.has_value() && *rhsValue != 0)
      return {*lhsValue / *rhsValue};
    return {};
  }
};

// arith.remsi / arith.remui.
template <typename OpTy>
class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
    auto shape = resTy.getShape();
    int64_t contiguity = 1;
    // lhs contiguous, rhs constant
    // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
    // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
    // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p),
    // ..., (d_lhs * k + n) % (d_rhs * p)
    // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0,
    // the minimal contiguity is gcd(d_lhs, d_rhs).
    // Since gcd(d_lhs, d_rhs) may be > len(lhs),
    // we need to use another gcd to get the actual contiguity.
    if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) &&
        AxisInfoVisitor::isConstantDim(rhs, shape, dim)) {
      contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim),
                                            gcd(lhs.getDivisibility(dim),
                                                rhs.getDivisibility(dim))));
    }
    return contiguity;
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k''
    // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p''
    // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r
    // r must be divisible by gcd(d_lhs, d_rhs)
    return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
    auto shape = resTy.getShape();
    // lhs % 1 = 0
    return rhs.getConstantValue() == 1
               ? shape[dim]
               : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    auto lhsValue = lhs.getConstantValue();
    auto rhsValue = rhs.getConstantValue();
    if (!rhsValue.has_value())
      return {};
    // lhs % 1 = 0, whatever lhs is.
    if (*rhsValue == 1)
      return {0};
    // Never fold a modulo by zero: it would crash the compiler itself.
    if (lhsValue.has_value() && *rhsValue != 0)
      return {*lhsValue % *rhsValue};
    return {};
  }
};

// tt.splat: every element equals the scalar operand.
class SplatOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<triton::SplatOp> {
public:
  using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(triton::SplatOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    Type resultType = *op->result_type_begin();
    TensorType retTy = resultType.cast<TensorType>();
    AxisInfo opInfo = operands[0]->getValue();
    AxisInfo::DimVectorT contiguity;
    AxisInfo::DimVectorT divisibility;
    AxisInfo::DimVectorT constancy;
    for (int d = 0; d < retTy.getRank(); ++d) {
      contiguity.push_back(1);
      divisibility.push_back(opInfo.getDivisibility(0));
      constancy.push_back(retTy.getShape()[d]);
    }
    return AxisInfo(contiguity, divisibility, constancy,
                    opInfo.getConstantValue());
  }
};

// tt.expand_dims: inserts a new axis whose entries are all 1 (the most
// conservative value).
class ExpandDimsOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<triton::ExpandDimsOp> {
public:
  using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(triton::ExpandDimsOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    AxisInfo opInfo = operands[0]->getValue();
    AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
    AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
    AxisInfo::DimVectorT constancy = opInfo.getConstancy();
    auto axis = op.getAxis();
    contiguity.insert(contiguity.begin() + axis, 1);
    divisibility.insert(divisibility.begin() + axis, 1);
    constancy.insert(constancy.begin() + axis, 1);
    return AxisInfo(contiguity, divisibility, constancy,
                    opInfo.getConstantValue());
  }
};

// tt.broadcast: dimensions of size 1 in the operand become fully constant
// along the broadcast result dimension.
class BroadcastOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<triton::BroadcastOp> {
public:
  using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(triton::BroadcastOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    Type resultType = *op->result_type_begin();
    Type operandType = *op->operand_type_begin();
    TensorType retTy = resultType.cast<TensorType>();
    TensorType opTy = operandType.cast<TensorType>();
    ArrayRef<int64_t> retShape = retTy.getShape();
    ArrayRef<int64_t> opShape = opTy.getShape();
    AxisInfo opInfo = operands[0]->getValue();
    AxisInfo::DimVectorT contiguity;
    AxisInfo::DimVectorT divisibility;
    AxisInfo::DimVectorT constancy;
    for (int d = 0; d < retTy.getRank(); ++d) {
      bool broadcasted = opShape[d] == 1;
      contiguity.push_back(broadcasted ? 1 : opInfo.getContiguity(d));
      divisibility.push_back(opInfo.getDivisibility(d));
      constancy.push_back(broadcasted ? retShape[d] : opInfo.getConstancy(d));
    }
    return AxisInfo(contiguity, divisibility, constancy,
                    opInfo.getConstantValue());
  }
};

// Integer comparisons (arith.cmpi / triton_gpu.cmpi). Only constancy is
// tracked; contiguity and divisibility of the i1 result are always 1.
template <typename OpTy>
class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
  using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(OpTy op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return AxisInfo();
    auto shape = resTy.getShape();
    int rank = resTy.getRank();
    auto lhsInfo = operands[0]->getValue();
    auto rhsInfo = operands[1]->getValue();
    auto predicate = getPredicate(op);

    AxisInfo::DimVectorT contiguity, divisibility, constancy;
    std::optional<int64_t> constantValue;
    for (int d = 0; d < rank; ++d) {
      int64_t constHint = 1;
      if (lhsInfo.getConstantValue().has_value() &&
          rhsInfo.getConstantValue().has_value()) {
        constHint = lhsInfo.getConstancy(d);
        constantValue = compare(predicate, lhsInfo.getConstantValue().value(),
                                rhsInfo.getConstantValue().value())
                            ? 1
                            : 0;
      } else {
        // Case 1: lhs and rhs are both partial constants
        constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d));
        // With g = gcd(d_lhs, d_rhs), a contiguous run starting at a multiple
        // of g compared against a constant multiple of g keeps its truth value
        // over aligned groups of g elements, but only for predicates whose
        // truth value flips exactly *at* the constant.
        //
        // Case 2: lhs all constant, rhs all contiguous
        // lhs: 4 4 4 4
        // rhs: 4 5 6 7
        // lhs gt rhs: 0, 0, 0, 0   lhs le rhs: 1, 1, 1, 1   (group-constant)
        // lhs ge rhs: 1, 0, 0, 0   lhs lt rhs: 0, 1, 1, 1   (not)
        // lhs eq rhs: 1, 0, 0, 0   lhs ne rhs: 0, 1, 1, 1   (not)
        //
        // Case 3: lhs all contiguous, rhs all constant
        // lhs: 4 5 6 7
        // rhs: 4 4 4 4
        // lhs lt rhs: 0, 0, 0, 0   lhs ge rhs: 1, 1, 1, 1   (group-constant)
        // lhs le rhs: 1, 0, 0, 0   lhs gt rhs: 0, 1, 1, 1   (not)
        // lhs eq rhs: 1, 0, 0, 0   lhs ne rhs: 0, 1, 1, 1   (not)
        int64_t alignment =
            gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d));
        if (gtOrLePredicate(predicate) &&
            AxisInfoVisitor::isConstantDim(lhsInfo, shape, d) &&
            AxisInfoVisitor::isContiguousDim(rhsInfo, shape, d)) {
          // The contiguous operand is rhs, so its contiguity bounds the run.
          constHint =
              std::max(constHint, gcd(rhsInfo.getContiguity(d), alignment));
        } else if (ltOrGePredicate(predicate) &&
                   AxisInfoVisitor::isContiguousDim(lhsInfo, shape, d) &&
                   AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) {
          constHint =
              std::max(constHint, gcd(lhsInfo.getContiguity(d), alignment));
        }
      }

      constancy.push_back(constHint);
      divisibility.push_back(1);
      contiguity.push_back(1);
    }

    return AxisInfo(contiguity, divisibility, constancy, constantValue);
  }

private:
  static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
    return op.getPredicate();
  }

  static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
    return op.getPredicate();
  }

  // Predicates for which "constant OP contiguous" is constant over aligned
  // groups (Case 2 above).
  static bool gtOrLePredicate(arith::CmpIPredicate predicate) {
    return predicate == arith::CmpIPredicate::sgt ||
           predicate == arith::CmpIPredicate::ugt ||
           predicate == arith::CmpIPredicate::sle ||
           predicate == arith::CmpIPredicate::ule;
  }

  // Predicates for which "contiguous OP constant" is constant over aligned
  // groups (Case 3 above).
  static bool ltOrGePredicate(arith::CmpIPredicate predicate) {
    return predicate == arith::CmpIPredicate::slt ||
           predicate == arith::CmpIPredicate::ult ||
           predicate == arith::CmpIPredicate::sge ||
           predicate == arith::CmpIPredicate::uge;
  }

  static bool compare(arith::CmpIPredicate predicate, int64_t lhs,
                      int64_t rhs) {
    switch (predicate) {
    case arith::CmpIPredicate::eq:
      return lhs == rhs;
    case arith::CmpIPredicate::ne:
      return lhs != rhs;
    case arith::CmpIPredicate::slt:
      return lhs < rhs;
    case arith::CmpIPredicate::sle:
      return lhs <= rhs;
    case arith::CmpIPredicate::sgt:
      return lhs > rhs;
    case arith::CmpIPredicate::sge:
      return lhs >= rhs;
    case arith::CmpIPredicate::ult:
      return (uint64_t)lhs < (uint64_t)rhs;
    case arith::CmpIPredicate::ule:
      return (uint64_t)lhs <= (uint64_t)rhs;
    case arith::CmpIPredicate::ugt:
      return (uint64_t)lhs > (uint64_t)rhs;
    case arith::CmpIPredicate::uge:
      return (uint64_t)lhs >= (uint64_t)rhs;
    default:
      break;
    }
    llvm_unreachable("unknown comparison predicate");
  }
};

// arith.select / triton_gpu.select.
template <typename OpTy>
class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
  using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(OpTy op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!resTy)
      return AxisInfo();
    auto shape = resTy.getShape();
    auto rank = shape.size();
    auto condInfo = operands[0]->getValue();
    auto lhsInfo = operands[1]->getValue();
    auto rhsInfo = operands[2]->getValue();

    // A scalar (non-tensor) condition picks the same operand for every
    // element, so it is constant along every result dimension. Its lattice
    // value does not have one entry per result dimension, so it must not be
    // indexed per dimension.
    AxisInfo::DimVectorT condConstancy(shape.begin(), shape.end());
    if (op->getOperand(0).getType().template isa<RankedTensorType>())
      condConstancy = condInfo.getConstancy();

    AxisInfo::DimVectorT contiguity, divisibility, constancy;
    std::optional<int64_t> constantValue;
    if (condInfo.getConstantValue().has_value()) {
      // Statically known condition: the result is exactly one operand.
      const AxisInfo &chosen =
          condInfo.getConstantValue() == 0 ? rhsInfo : lhsInfo;
      contiguity = chosen.getContiguity();
      divisibility = chosen.getDivisibility();
      constancy = chosen.getConstancy();
      constantValue = chosen.getConstantValue();
    } else {
      for (size_t d = 0; d < rank; ++d) {
        constancy.push_back(
            std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]),
                     gcd(rhsInfo.getConstancy(d), condConstancy[d])));
        divisibility.push_back(
            std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
        contiguity.push_back(
            std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]),
                     gcd(rhsInfo.getContiguity(d), condConstancy[d])));
      }
      if (lhsInfo.getConstantValue().has_value() &&
          rhsInfo.getConstantValue().has_value() &&
          lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
        constantValue = lhsInfo.getConstantValue();
    }

    return AxisInfo(contiguity, divisibility, constancy, constantValue);
  }
};

// arith.andi / arith.ori / arith.xori: only constancy and constant folding.
template <typename OpTy>
class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    if (lhs.getConstantValue().has_value() &&
        rhs.getConstantValue().has_value()) {
      if constexpr (std::is_same<OpTy, arith::AndIOp>::value) {
        return {lhs.getConstantValue().value() &
                rhs.getConstantValue().value()};
      } else if constexpr (std::is_same<OpTy, arith::OrIOp>::value) {
        return {lhs.getConstantValue().value() |
                rhs.getConstantValue().value()};
      } else if constexpr (std::is_same<OpTy, arith::XOrIOp>::value) {
        return {lhs.getConstantValue().value() ^
                rhs.getConstantValue().value()};
      }
    }
    return {};
  }
};

// arith.shli.
class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
public:
  using BinaryOpVisitorImpl<arith::ShLIOp>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs,
                        const AxisInfo &rhs, int dim) override {
    if (rhs.getConstantValue() == 0)
      return lhs.getContiguity(dim);
    return 1;
  }

  int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs,
                          const AxisInfo &rhs, int dim) override {
    std::optional<int64_t> shift = rhs.getConstantValue();
    int64_t lhsDivisibility = lhs.getDivisibility(dim);
    // Unless the shift is the constant 0, the result loses contiguity (see
    // getContiguity). Its divisibility must then hold for every element,
    // which a contiguous lhs [2^n, 2^n + 1, ...] only guarantees for 1.
    if (shift != 0 && lhs.getContiguity(dim) > 1)
      lhsDivisibility = 1;
    // A non-constant shift amount may be 0 at runtime, so only a known
    // positive constant contributes extra factors of two.
    int64_t minShift = shift.has_value() && *shift > 0 ? *shift : 0;
    auto numBits = log2Int(lhsDivisibility);
    auto maxBits = log2Int(maxDivisor());
    // Make sure the return value doesn't exceed maxDivisor()
    if (minShift + numBits > maxBits)
      return maxDivisor();
    return lhsDivisibility << minShift;
  }

  int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
                       const AxisInfo &rhs, int dim) override {
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(arith::ShLIOp op,
                                          const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    auto lhsValue = lhs.getConstantValue();
    auto rhsValue = rhs.getConstantValue();
    // Shift amounts outside [0, 63] are undefined behavior in C++; the shift
    // is done on the unsigned representation to avoid signed-overflow UB.
    if (lhsValue.has_value() && rhsValue.has_value() && *rhsValue >= 0 &&
        *rhsValue < 64)
      return {static_cast<int64_t>(static_cast<uint64_t>(*lhsValue)
                                   << *rhsValue)};
    return {};
  }
};

// arith.shrui / arith.shrsi.
template <typename OpTy>
class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
  using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
  int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
    if (rhs.getConstantValue() == 0)
      return lhs.getContiguity(dim);
    return 1;
  }

  int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                          int dim) override {
    std::optional<int64_t> shift = rhs.getConstantValue();
    // A non-constant shift amount can be arbitrarily large, so nothing is
    // known about the result. Constant shifts past the largest tracked
    // divisor leave nothing either.
    if (!shift.has_value() || *shift < 0 || *shift > log2Int(maxDivisor()))
      return 1;
    int64_t lhsDivisibility = lhs.getDivisibility(dim);
    // A non-zero shift breaks contiguity (see getContiguity). The
    // divisibility must then hold for every element, which a contiguous lhs
    // [2^n, 2^n + 1, ...] only guarantees for 1.
    if (*shift != 0 && lhs.getContiguity(dim) > 1)
      lhsDivisibility = 1;
    return std::max<int64_t>(1, lhsDivisibility / (int64_t{1} << *shift));
  }

  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                       int dim) override {
    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
  }

  std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                          const AxisInfo &rhs) override {
    auto lhsValue = lhs.getConstantValue();
    auto rhsValue = rhs.getConstantValue();
    // Shift amounts outside [0, 63] are undefined behavior in C++.
    if (lhsValue.has_value() && rhsValue.has_value() && *rhsValue >= 0 &&
        *rhsValue < 64)
      return {*lhsValue >> *rhsValue};
    return {};
  }
};

// arith.max{s,u}i / arith.min{s,u}i: only constant folding; everything else
// is conservatively 1.
template <typename OpTy>
class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
  using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(OpTy op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    auto lhsInfo = operands[0]->getValue();
    auto rhsInfo = operands[1]->getValue();
    std::optional<int64_t> constantValue;
    if (lhsInfo.getConstantValue().has_value() &&
        rhsInfo.getConstantValue().has_value()) {
      if constexpr (std::is_same_v<OpTy, arith::MaxSIOp> ||
                    std::is_same_v<OpTy, arith::MaxUIOp>) {
        constantValue = {std::max(lhsInfo.getConstantValue().value(),
                                  rhsInfo.getConstantValue().value())};
      } else if constexpr (std::is_same_v<OpTy, arith::MinSIOp> ||
                           std::is_same_v<OpTy, arith::MinUIOp>) {
        constantValue = {std::min(lhsInfo.getConstantValue().value(),
                                  rhsInfo.getConstantValue().value())};
      }
    }
    auto rank = lhsInfo.getRank();
    return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
                    /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
                    /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
                    /*constantValue=*/constantValue);
  }
};

//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//

AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
    : dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
  // UnrealizedConversionCast:
  // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
  // in the process of a PartialConversion, where UnrealizedConversionCast
  // may exist
  visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
                  CastOpAxisInfoVisitor<arith::ExtUIOp>,
                  CastOpAxisInfoVisitor<arith::TruncIOp>,
                  CastOpAxisInfoVisitor<arith::IndexCastOp>,
                  CastOpAxisInfoVisitor<triton::PtrToIntOp>,
                  CastOpAxisInfoVisitor<triton::IntToPtrOp>,
                  CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
                  CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
                  CastOpAxisInfoVisitor<triton::BitcastOp>>();
  visitors.append<ConstantOpAxisInfoVisitor>();
  visitors.append<MakeRangeOpAxisInfoVisitor>();
  visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
                  AddSubOpAxisInfoVisitor<arith::AddIOp>,
                  AddSubOpAxisInfoVisitor<arith::SubIOp>>();
  visitors.append<MulIOpAxisInfoVisitor>();
  visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
                  DivOpAxisInfoVisitor<arith::DivUIOp>>();
  visitors.append<RemOpAxisInfoVisitor<arith::RemSIOp>,
                  RemOpAxisInfoVisitor<arith::RemUIOp>>();
  visitors.append<BroadcastOpAxisInfoVisitor>();
  visitors.append<SplatOpAxisInfoVisitor>();
  visitors.append<ExpandDimsOpAxisInfoVisitor>();
  visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>,
                  CmpOpAxisInfoVisitor<triton::gpu::CmpIOp>>();
  visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
                  LogicalOpAxisInfoVisitor<arith::OrIOp>,
                  LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
  visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>,
                  SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
  visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
                  ShROpAxisInfoVisitor<arith::ShRSIOp>>();
  visitors.append<MaxMinOpAxisInfoVisitor<arith::MaxSIOp>,
                  MaxMinOpAxisInfoVisitor<arith::MaxUIOp>,
                  MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
                  MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
}

// Transfer function.
// 1. Run the registered visitor for `op`.
// 2. Override the result with any tt.* hint attributes on `op`.
// 3. Join the outcome into every result lattice.
// If no visitor produced information, all results go to their entry state.
void AxisInfoAnalysis::visitOperation(
    Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
    ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
  // TODO: For sure not the right way to do this
  // but why is scf.if not initialized otherwise?
  for (const dataflow::Lattice<AxisInfo> *operand : operands)
    if (operand->getValue().getRank() == 0)
      setToEntryState(const_cast<dataflow::Lattice<AxisInfo> *>(operand));
  AxisInfo curr = visitors.apply(op, operands);
  if (curr.getRank() == 0)
    return setAllToEntryStates(results);
  // override with hint
  auto newContiguity = curr.getContiguity();
  auto newDivisibility = curr.getDivisibility();
  auto newConstancy = curr.getConstancy();
  readDimHint(op, "tt.contiguity", newContiguity);
  readDimHint(op, "tt.divisibility", newDivisibility);
  readDimHint(op, "tt.constancy", newConstancy);
  curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
                        curr.getConstantValue());
  // join all lattice elements
  for (auto *result : results)
    propagateIfChanged(result, result->join(curr));
}

// Number of contiguous elements each thread can access for the pointer
// tensor `ptr`. This is the minimum of the pointer alignment, the layout's
// sizePerThread along the fastest-varying dimension, and that dimension's
// extent. Non-tensor pointers give 1.
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
  auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;
  auto layout = tensorTy.getEncoding();
  auto shape = tensorTy.getShape();

  // Here order should be ordered by contiguous first, so the first element
  // should have the largest contiguous.
  auto order = triton::gpu::getOrder(layout);
  unsigned align = getPtrAlignment(ptr);

  unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
  contigPerThread = std::min(align, contigPerThread);
  contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);

  return contigPerThread;
}

// Alignment, in elements, of the pointer tensor `ptr` along its fastest-
// varying dimension. This is the byte divisibility converted to elements,
// capped by the contiguity.
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
  auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;
  dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
  if (!latticeElement)
    return 1;
  auto axisInfo = latticeElement->getValue();
  auto layout = tensorTy.getEncoding();
  auto order = triton::gpu::getOrder(layout);
  auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
  auto maxContig = axisInfo.getContiguity(order[0]);
  auto elemNumBits = getPointeeBitWidth(tensorTy);
  auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
  auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
  unsigned alignment = std::min(maxMultiple, maxContig);
  return alignment;
}

// Constancy of the mask tensor `mask` along its fastest-varying dimension
// (at least 1). Non-tensor masks give 1.
unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
  auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;
  dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
  if (!latticeElement)
    return 1;
  auto maskAxis = latticeElement->getValue();
  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
  auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
  return alignment;
}

} // namespace mlir