[BACKEND] Clean up type inference functions (#1451)

Also removes a duplicate function definition.
This commit is contained in:
Keren Zhou
2023-03-30 23:07:32 -07:00
committed by GitHub
parent 109bfca5c0
commit 28ea484dab
10 changed files with 91 additions and 88 deletions

View File

@@ -10,10 +10,26 @@
namespace mlir {
unsigned getPointeeBitWidth(RankedTensorType tensorTy);
namespace triton {
bool isTensorPointerType(Type type);
unsigned getPointeeBitWidth(Type type);
Type getPointeeType(Type type);
Type getPointerType(Type type);
Type getElementTypeOfTensorPointerType(Type type);
Type getI1SameShape(Type type);
Type getI32SameShape(Type type);
Type getPointerTypeSameShape(Type type);
} // namespace triton
} // namespace mlir
#endif // TRITON_IR_TYPES_H_

View File

@@ -910,7 +910,7 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
auto maxContig = axisInfo.getContiguity(order[0]);
auto elemNumBits = getPointeeBitWidth(tensorTy);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
unsigned alignment = std::min(maxMultiple, maxContig);

View File

@@ -28,7 +28,7 @@ struct LoadStoreConversionBase {
if (!tensorTy)
return 1;
auto contiguity = getContiguity(ptr);
auto pointeeBitWidth = getPointeeBitWidth(tensorTy);
auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy);
// The maximum vector size is 128 bits on NVIDIA GPUs.
return std::min<unsigned>(128 / pointeeBitWidth, contiguity);
}

View File

@@ -8,43 +8,6 @@
namespace mlir {
namespace triton {
// Type inference
// Build an i1 type mirroring the shape and encoding of `type`; plain scalar
// i1 when `type` is not a ranked tensor.
static Type getI1SameShape(Type type) {
  auto boolTy = IntegerType::get(type.getContext(), 1);
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return boolTy;
  return RankedTensorType::get(rankedTy.getShape(), boolTy,
                               rankedTy.getEncoding());
}
// Build an i32 type mirroring the shape and encoding of `type`; plain scalar
// i32 when `type` is not a ranked tensor.
static Type getI32SameShape(Type type) {
  auto intTy = IntegerType::get(type.getContext(), 32);
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return intTy;
  return RankedTensorType::get(rankedTy.getShape(), intTy,
                               rankedTy.getEncoding());
}
// Wrap the element type of `type` in a pointer, keeping tensor shape and
// encoding; a non-tensor `type` becomes a plain pointer to itself.
static Type getPointerTypeSameShape(Type type) {
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return PointerType::get(type, 1);
  auto elemPtrTy = PointerType::get(rankedTy.getElementType(), 1);
  return RankedTensorType::get(rankedTy.getShape(), elemPtrTy,
                               rankedTy.getEncoding());
}
// Wrap `type` in a pointer with address-space operand 1 (presumably the
// default/global space — confirm against Triton's PointerType definition).
static Type getPointerType(Type type) {
  return PointerType::get(type, 1);
}
// For a pointer-to-ranked-tensor ("tensor pointer"), return the tensor's
// element type; any other input yields a null Type.
static Type getElementTypeOfTensorPointerType(Type type) {
  auto ptrTy = type.dyn_cast<PointerType>();
  if (!ptrTy)
    return {};
  auto pointeeTensor = ptrTy.getPointeeType().dyn_cast<RankedTensorType>();
  return pointeeTensor ? pointeeTensor.getElementType() : Type();
}
// Parser & printer for assembly forms
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse operands

View File

@@ -13,7 +13,7 @@ static LogicalResult verifySameEncoding(Type typeA, Type typeB,
if (auto ptrType = type.dyn_cast<triton::PointerType>())
rankedType = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
} else {
assert(!isTensorPointerType(type));
assert(!triton::isTensorPointerType(type));
}
return rankedType ? rankedType.getEncoding() : Attribute();
};
@@ -123,7 +123,7 @@ OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) {
bool OpTrait::impl::verifyLoadStorePointerAndValueType(Type valueType,
Type ptrType) {
if (isTensorPointerType(ptrType)) {
if (triton::isTensorPointerType(ptrType)) {
return ptrType.cast<triton::PointerType>().getPointeeType() == valueType;
} else if (auto rankedType = ptrType.dyn_cast<RankedTensorType>()) {
if (auto elementPtrType =

View File

@@ -40,16 +40,72 @@ void PointerType::print(AsmPrinter &printer) const {
namespace mlir {
unsigned getPointeeBitWidth(RankedTensorType tensorTy) {
auto ptrTy = tensorTy.getElementType().cast<triton::PointerType>();
auto pointeeType = ptrTy.getPointeeType();
namespace triton {
// Bit width of the value that `type` (a pointer or tensor-of-pointers type)
// points to; delegates the pointer-stripping to getPointeeType.
unsigned getPointeeBitWidth(Type type) {
  auto pointee = getPointeeType(type);
  auto rankedTy = pointee.dyn_cast<RankedTensorType>();
  if (rankedTy)
    return rankedTy.getElementType().getIntOrFloatBitWidth();
  return pointee.getIntOrFloatBitWidth();
}
// Build an i1 type mirroring the shape and encoding of `type`; plain scalar
// i1 when `type` is not a ranked tensor.
Type getI1SameShape(Type type) {
  auto boolTy = IntegerType::get(type.getContext(), 1);
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return boolTy;
  return RankedTensorType::get(rankedTy.getShape(), boolTy,
                               rankedTy.getEncoding());
}
// Strip one level of pointer indirection from `type`:
// - tensor of pointers -> tensor of pointee elements (shape/encoding kept);
// - scalar pointer     -> its pointee type;
// - anything else      -> returned unchanged.
Type getPointeeType(Type type) {
  if (auto tensorTy = type.dyn_cast<RankedTensorType>()) {
    // Tensor of pointers.
    auto ptrType = tensorTy.getElementType().dyn_cast<PointerType>();
    // Guard: the original dereferenced a null dyn_cast result when the
    // tensor's elements are not pointers; treat that case like any other
    // non-pointer input and return the type unchanged.
    if (!ptrType)
      return type;
    return RankedTensorType::get(tensorTy.getShape(),
                                 ptrType.getPointeeType(),
                                 tensorTy.getEncoding());
  } else if (auto ptrType = type.dyn_cast<PointerType>()) {
    // Scalar pointer.
    return ptrType.getPointeeType();
  }
  return type;
}
// Build an i32 type mirroring the shape and encoding of `type`; plain scalar
// i32 when `type` is not a ranked tensor.
Type getI32SameShape(Type type) {
  auto intTy = IntegerType::get(type.getContext(), 32);
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return intTy;
  return RankedTensorType::get(rankedTy.getShape(), intTy,
                               rankedTy.getEncoding());
}
// Wrap the element type of `type` in a pointer, keeping tensor shape and
// encoding; a non-tensor `type` becomes a plain pointer to itself.
Type getPointerTypeSameShape(Type type) {
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return PointerType::get(type, 1);
  auto elemPtrTy = PointerType::get(rankedTy.getElementType(), 1);
  return RankedTensorType::get(rankedTy.getShape(), elemPtrTy,
                               rankedTy.getEncoding());
}
// Wrap `type` in a pointer with address-space operand 1 (presumably the
// default/global space — confirm against Triton's PointerType definition).
Type getPointerType(Type type) {
  return PointerType::get(type, 1);
}
// A "tensor pointer" is a scalar pointer whose pointee is a ranked tensor
// (as opposed to a tensor of scalar pointers).
bool isTensorPointerType(Type type) {
  auto ptrTy = type.dyn_cast<PointerType>();
  return ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>();
}
// For a pointer-to-ranked-tensor ("tensor pointer"), return the tensor's
// element type; any other input yields a null Type.
Type getElementTypeOfTensorPointerType(Type type) {
  auto ptrTy = type.dyn_cast<PointerType>();
  if (!ptrTy)
    return {};
  auto pointeeTensor = ptrTy.getPointeeType().dyn_cast<RankedTensorType>();
  return pointeeTensor ? pointeeTensor.getElementType() : Type();
}
} // namespace triton
} // namespace mlir

View File

@@ -197,9 +197,10 @@ public:
: computeCapability(computeCapability) {}
static bool needRewrite(Operation *op) {
return std::any_of(
op->getOperands().begin(), op->getOperands().end(),
[](Value operand) { return isTensorPointerType(operand.getType()); });
return std::any_of(op->getOperands().begin(), op->getOperands().end(),
[](Value operand) {
return triton::isTensorPointerType(operand.getType());
});
}
static SmallVector<Value>
@@ -273,7 +274,7 @@ public:
// We only have to rewrite load/stores with tensor pointers
auto ptr = op->getOperand(0);
if (!isTensorPointerType(ptr.getType()))
if (!triton::isTensorPointerType(ptr.getType()))
return nullptr;
// Get info from previous results
@@ -324,7 +325,7 @@ public:
SmallVector<Value> newIterOperands = op.getIterOperands();
for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
++i, ++oldI) {
if (!isTensorPointerType(newIterOperands[i].getType()))
if (!triton::isTensorPointerType(newIterOperands[i].getType()))
continue;
// Expand the tensor pointer into offsets
@@ -348,7 +349,7 @@ public:
for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
++i, ++oldI) {
auto oldRegionIterArg = op.getRegionIterArg(oldI);
if (isTensorPointerType(oldRegionIterArg.getType())) {
if (triton::isTensorPointerType(oldRegionIterArg.getType())) {
// Pass rewrited info inside
assert(rewritedInfo.count(oldIterOperands[oldI]));
auto info = rewritedInfo[oldIterOperands[oldI]];
@@ -375,7 +376,7 @@ public:
assert(op.getNumResults() == op.getNumIterOperands());
for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
auto oldResult = op.getResult(oldI);
if (isTensorPointerType(oldResult.getType())) {
if (triton::isTensorPointerType(oldResult.getType())) {
// Pack new offsets into rewrited info
assert(rewritedInfo.count(oldIterOperands[oldI]));
auto info = rewritedInfo[oldIterOperands[oldI]];
@@ -398,7 +399,7 @@ public:
// Replace tensor pointers with offsets
SmallVector<Value> newOperands = op->getOperands();
for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) {
if (!isTensorPointerType(newOperands[i].getType()))
if (!triton::isTensorPointerType(newOperands[i].getType()))
continue;
assert(rewritedInfo.count(newOperands[i]));

View File

@@ -16,30 +16,6 @@ using namespace mlir::triton::gpu;
namespace mlir {
namespace triton {
// Type inference
// Build an i1 type mirroring the shape and encoding of a ranked tensor.
// NOTE(review): unlike the other getI1SameShape copies in this codebase,
// this one returns a null Type for non-tensor inputs — confirm callers
// depend on that before unifying the duplicates.
static Type getI1SameShape(Type type) {
  auto boolTy = IntegerType::get(type.getContext(), 1);
  if (auto rankedTy = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(rankedTy.getShape(), boolTy,
                                 rankedTy.getEncoding());
  return Type();
}
// Strip one level of pointer indirection from `type`:
// - tensor of pointers -> tensor of pointee elements (shape/encoding kept);
// - scalar pointer     -> its pointee type;
// - anything else      -> null Type.
static Type getPointeeType(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    // Tensor of pointers
    auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
    // Guard: the original dereferenced a null dyn_cast result when the
    // tensor's elements are not pointers; return a null Type instead,
    // matching this function's fallthrough for other unsupported inputs.
    if (!ptrType)
      return Type();
    return RankedTensorType::get(tensorType.getShape(),
                                 ptrType.getPointeeType(),
                                 tensorType.getEncoding());
  } else if (auto ptrType = type.dyn_cast<PointerType>()) {
    // scalar pointer
    return ptrType.getPointeeType();
  }
  return Type();
}
namespace gpu {
// TODO: Inheritance of layout attributes

View File

@@ -53,7 +53,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
int numElemsPerThread = std::max(numElems / numThreads, 1);
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
unsigned elemNumBits = getPointeeBitWidth(origType);
unsigned elemNumBits = triton::getPointeeBitWidth(origType);
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
unsigned perThread = 1;
for (Value val : withSameOrder) {

View File

@@ -22,15 +22,6 @@ namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
// Build an i1 mask type mirroring the shape and encoding of value `v`'s
// type; plain scalar i1 when it is not a ranked tensor.
static Type getI1SameShape(Value v) {
  Type srcTy = v.getType();
  auto boolTy = IntegerType::get(srcTy.getContext(), 1);
  auto rankedTy = srcTy.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return boolTy;
  return RankedTensorType::get(rankedTy.getShape(), boolTy,
                               rankedTy.getEncoding());
}
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
for (const NamedAttribute attr : dictAttrs.getValue())
@@ -321,7 +312,7 @@ LogicalResult LoopPipeliner::initialize() {
Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
Value loopCond, OpBuilder &builder) {
Type maskType = getI1SameShape(loadOp);
Type maskType = triton::getI1SameShape(loadOp.getType());
Value mask = loadOp.getMask();
Value newMask;
if (mask) {