mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Clean up type inference functions (#1451)
And remove duplicate function definition.
This commit is contained in:
@@ -10,10 +10,26 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
unsigned getPointeeBitWidth(RankedTensorType tensorTy);
|
||||
namespace triton {
|
||||
|
||||
bool isTensorPointerType(Type type);
|
||||
|
||||
unsigned getPointeeBitWidth(Type type);
|
||||
|
||||
Type getPointeeType(Type type);
|
||||
|
||||
Type getPointerType(Type type);
|
||||
|
||||
Type getElementTypeOfTensorPointerType(Type type);
|
||||
|
||||
Type getI1SameShape(Type type);
|
||||
|
||||
Type getI32SameShape(Type type);
|
||||
|
||||
Type getPointerTypeSameShape(Type type);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
||||
|
||||
@@ -910,7 +910,7 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
|
||||
auto maxContig = axisInfo.getContiguity(order[0]);
|
||||
auto elemNumBits = getPointeeBitWidth(tensorTy);
|
||||
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
|
||||
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
|
||||
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
|
||||
@@ -28,7 +28,7 @@ struct LoadStoreConversionBase {
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto contiguity = getContiguity(ptr);
|
||||
auto pointeeBitWidth = getPointeeBitWidth(tensorTy);
|
||||
auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy);
|
||||
// The maximum vector size is 128 bits on NVIDIA GPUs.
|
||||
return std::min<unsigned>(128 / pointeeBitWidth, contiguity);
|
||||
}
|
||||
|
||||
@@ -8,43 +8,6 @@
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Type inference
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return i32Type;
|
||||
}
|
||||
|
||||
static Type getPointerTypeSameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
Type elementType = tensorType.getElementType();
|
||||
auto shape = tensorType.getShape();
|
||||
PointerType ptrType = PointerType::get(elementType, 1);
|
||||
return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
|
||||
} else {
|
||||
return PointerType::get(type, 1);
|
||||
}
|
||||
}
|
||||
|
||||
static Type getPointerType(Type type) { return PointerType::get(type, 1); }
|
||||
|
||||
static Type getElementTypeOfTensorPointerType(Type type) {
|
||||
if (auto ptrType = type.dyn_cast<PointerType>())
|
||||
if (auto tensorType = ptrType.getPointeeType().dyn_cast<RankedTensorType>())
|
||||
return tensorType.getElementType();
|
||||
return {};
|
||||
}
|
||||
|
||||
// Parser & printer for assembly forms
|
||||
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Parse operands
|
||||
|
||||
@@ -13,7 +13,7 @@ static LogicalResult verifySameEncoding(Type typeA, Type typeB,
|
||||
if (auto ptrType = type.dyn_cast<triton::PointerType>())
|
||||
rankedType = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
|
||||
} else {
|
||||
assert(!isTensorPointerType(type));
|
||||
assert(!triton::isTensorPointerType(type));
|
||||
}
|
||||
return rankedType ? rankedType.getEncoding() : Attribute();
|
||||
};
|
||||
@@ -123,7 +123,7 @@ OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) {
|
||||
|
||||
bool OpTrait::impl::verifyLoadStorePointerAndValueType(Type valueType,
|
||||
Type ptrType) {
|
||||
if (isTensorPointerType(ptrType)) {
|
||||
if (triton::isTensorPointerType(ptrType)) {
|
||||
return ptrType.cast<triton::PointerType>().getPointeeType() == valueType;
|
||||
} else if (auto rankedType = ptrType.dyn_cast<RankedTensorType>()) {
|
||||
if (auto elementPtrType =
|
||||
|
||||
@@ -40,16 +40,72 @@ void PointerType::print(AsmPrinter &printer) const {
|
||||
|
||||
namespace mlir {
|
||||
|
||||
unsigned getPointeeBitWidth(RankedTensorType tensorTy) {
|
||||
auto ptrTy = tensorTy.getElementType().cast<triton::PointerType>();
|
||||
auto pointeeType = ptrTy.getPointeeType();
|
||||
namespace triton {
|
||||
|
||||
unsigned getPointeeBitWidth(Type type) {
|
||||
auto pointeeType = getPointeeType(type);
|
||||
if (auto tensorTy = pointeeType.dyn_cast<RankedTensorType>())
|
||||
return tensorTy.getElementType().getIntOrFloatBitWidth();
|
||||
return pointeeType.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorTy = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorTy.getShape(), i1Type,
|
||||
tensorTy.getEncoding());
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
Type getPointeeType(Type type) {
|
||||
if (auto tensorTy = type.dyn_cast<RankedTensorType>()) {
|
||||
// Tensor of pointers
|
||||
auto shape = tensorTy.getShape();
|
||||
auto ptrType = tensorTy.getElementType().dyn_cast<PointerType>();
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding());
|
||||
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
||||
// scalar pointer
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return pointeeType;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
Type getI32SameShape(Type type) {
|
||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||
if (auto tensorTy = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorTy.getShape(), i32Type,
|
||||
tensorTy.getEncoding());
|
||||
return i32Type;
|
||||
}
|
||||
|
||||
Type getPointerTypeSameShape(Type type) {
|
||||
if (auto tensorTy = type.dyn_cast<RankedTensorType>()) {
|
||||
Type elementType = tensorTy.getElementType();
|
||||
auto shape = tensorTy.getShape();
|
||||
PointerType ptrType = PointerType::get(elementType, 1);
|
||||
return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding());
|
||||
} else {
|
||||
return PointerType::get(type, 1);
|
||||
}
|
||||
}
|
||||
|
||||
Type getPointerType(Type type) { return PointerType::get(type, 1); }
|
||||
|
||||
bool isTensorPointerType(Type type) {
|
||||
if (auto ptrType = type.dyn_cast<PointerType>())
|
||||
return ptrType.getPointeeType().isa<RankedTensorType>();
|
||||
return false;
|
||||
}
|
||||
|
||||
Type getElementTypeOfTensorPointerType(Type type) {
|
||||
if (auto ptrType = type.dyn_cast<PointerType>())
|
||||
if (auto tensorTy = ptrType.getPointeeType().dyn_cast<RankedTensorType>())
|
||||
return tensorTy.getElementType();
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -197,9 +197,10 @@ public:
|
||||
: computeCapability(computeCapability) {}
|
||||
|
||||
static bool needRewrite(Operation *op) {
|
||||
return std::any_of(
|
||||
op->getOperands().begin(), op->getOperands().end(),
|
||||
[](Value operand) { return isTensorPointerType(operand.getType()); });
|
||||
return std::any_of(op->getOperands().begin(), op->getOperands().end(),
|
||||
[](Value operand) {
|
||||
return triton::isTensorPointerType(operand.getType());
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
@@ -273,7 +274,7 @@ public:
|
||||
|
||||
// We only have to rewrite load/stores with tensor pointers
|
||||
auto ptr = op->getOperand(0);
|
||||
if (!isTensorPointerType(ptr.getType()))
|
||||
if (!triton::isTensorPointerType(ptr.getType()))
|
||||
return nullptr;
|
||||
|
||||
// Get info from previous results
|
||||
@@ -324,7 +325,7 @@ public:
|
||||
SmallVector<Value> newIterOperands = op.getIterOperands();
|
||||
for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
|
||||
++i, ++oldI) {
|
||||
if (!isTensorPointerType(newIterOperands[i].getType()))
|
||||
if (!triton::isTensorPointerType(newIterOperands[i].getType()))
|
||||
continue;
|
||||
|
||||
// Expand the tensor pointer into offsets
|
||||
@@ -348,7 +349,7 @@ public:
|
||||
for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
|
||||
++i, ++oldI) {
|
||||
auto oldRegionIterArg = op.getRegionIterArg(oldI);
|
||||
if (isTensorPointerType(oldRegionIterArg.getType())) {
|
||||
if (triton::isTensorPointerType(oldRegionIterArg.getType())) {
|
||||
// Pass rewrited info inside
|
||||
assert(rewritedInfo.count(oldIterOperands[oldI]));
|
||||
auto info = rewritedInfo[oldIterOperands[oldI]];
|
||||
@@ -375,7 +376,7 @@ public:
|
||||
assert(op.getNumResults() == op.getNumIterOperands());
|
||||
for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
|
||||
auto oldResult = op.getResult(oldI);
|
||||
if (isTensorPointerType(oldResult.getType())) {
|
||||
if (triton::isTensorPointerType(oldResult.getType())) {
|
||||
// Pack new offsets into rewrited info
|
||||
assert(rewritedInfo.count(oldIterOperands[oldI]));
|
||||
auto info = rewritedInfo[oldIterOperands[oldI]];
|
||||
@@ -398,7 +399,7 @@ public:
|
||||
// Replace tensor pointers with offsets
|
||||
SmallVector<Value> newOperands = op->getOperands();
|
||||
for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) {
|
||||
if (!isTensorPointerType(newOperands[i].getType()))
|
||||
if (!triton::isTensorPointerType(newOperands[i].getType()))
|
||||
continue;
|
||||
|
||||
assert(rewritedInfo.count(newOperands[i]));
|
||||
|
||||
@@ -16,30 +16,6 @@ using namespace mlir::triton::gpu;
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Type inference
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getPointeeType(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
// Tensor of pointers
|
||||
auto shape = tensorType.getShape();
|
||||
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
||||
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
||||
// scalar pointer
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return pointeeType;
|
||||
}
|
||||
return Type();
|
||||
}
|
||||
|
||||
namespace gpu {
|
||||
|
||||
// TODO: Inheritance of layout attributes
|
||||
|
||||
@@ -53,7 +53,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
int numElemsPerThread = std::max(numElems / numThreads, 1);
|
||||
// Thread tile size depends on memory alignment
|
||||
SmallVector<unsigned, 4> sizePerThread(rank, 1);
|
||||
unsigned elemNumBits = getPointeeBitWidth(origType);
|
||||
unsigned elemNumBits = triton::getPointeeBitWidth(origType);
|
||||
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
|
||||
unsigned perThread = 1;
|
||||
for (Value val : withSameOrder) {
|
||||
|
||||
@@ -22,15 +22,6 @@ namespace ttg = triton::gpu;
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
static Type getI1SameShape(Value v) {
|
||||
Type vType = v.getType();
|
||||
auto i1Type = IntegerType::get(vType.getContext(), 1);
|
||||
if (auto tensorType = vType.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
|
||||
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
|
||||
for (const NamedAttribute attr : dictAttrs.getValue())
|
||||
@@ -321,7 +312,7 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
|
||||
Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
|
||||
Value loopCond, OpBuilder &builder) {
|
||||
Type maskType = getI1SameShape(loadOp);
|
||||
Type maskType = triton::getI1SameShape(loadOp.getType());
|
||||
Value mask = loadOp.getMask();
|
||||
Value newMask;
|
||||
if (mask) {
|
||||
|
||||
Reference in New Issue
Block a user