[BACKEND] Fix tl.cat when the number of threads > the size of a tensor (#1751)

`tl.cat(tensor<64>, tensor<64>) -> tensor<128>` concatenates elements into a
single thread; if the number of threads is 128, each thread should own at
least 2 elements.
With this PR, we also disable remat of the cat op in some cases.
This commit is contained in:
Keren Zhou
2023-06-07 15:42:38 -07:00
committed by GitHub
parent 8faa47a810
commit 4fbadf6f6f
7 changed files with 104 additions and 10 deletions

View File

@@ -87,7 +87,7 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
/// Integer ceiling division: smallest integer >= m / n (n must be positive).
template <typename Int> Int ceil(Int m, Int n) {
  return (m + n - 1) / n;
}
// output[i] = input[order[i]]
/// output[i] = input[order[i]]
template <typename T, typename RES_T = T>
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
size_t rank = order.size();
@@ -99,6 +99,7 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
return result;
}
/// Get the highest power of 2 divisor of an integer.
template <typename T> T highestPowOf2Divisor(T n) {
if (n == 0) {
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
@@ -106,6 +107,19 @@ template <typename T> T highestPowOf2Divisor(T n) {
return (n & (~(n - 1)));
}
/// Round an integer up to the next power of 2, returning the integer itself
/// when it is already a power of 2. nextPowOf2(0) yields 1.
template <typename T> T nextPowOf2(T n) {
  if (n == 0)
    return 1;
  // Classic bit-smearing: subtract one, propagate the highest set bit into
  // every lower position, then add one back.
  T v = n - 1;
  unsigned shift = 1;
  while (shift < sizeof(T) * 8) {
    v |= v >> shift;
    shift <<= 1;
  }
  return v + 1;
}
bool isSingleValue(Value value);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

View File

@@ -23,6 +23,9 @@ namespace gpu {
unsigned getTotalElemsPerThread(Type type);
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
Type eltTy);
SmallVector<unsigned> getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
@@ -72,6 +75,8 @@ SmallVector<unsigned> getOrder(Attribute layout);
bool isaDistributedLayout(Attribute layout);
bool expensiveCat(triton::CatOp cat, Attribute &targetEncoding);
} // namespace gpu
} // namespace triton

View File

@@ -6,6 +6,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
@@ -309,11 +310,43 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
LogicalResult
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// For now, this behaves like generic, but this will evolve when
// we add support for `can_reorder=False`
Type retType = this->getTypeConverter()->convertType(op.getType());
// The cat op satisfies two conditions:
// 1. output.numel = lhs.numel + rhs.numel
// 2. output.total_elems_per_thread =
// next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread)
// For now, this behaves like generic, but this
// will evolve when we add support for `can_reorder=False`.
auto retType = this->getTypeConverter()
->convertType(op.getType())
.cast<RankedTensorType>();
auto retEncoding =
retType.getEncoding().cast<triton::gpu::BlockedEncodingAttr>();
auto lhsType = adaptor.getLhs().getType().cast<RankedTensorType>();
auto rhsType = adaptor.getRhs().getType().cast<RankedTensorType>();
auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType);
auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType);
auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType);
auto retShape = retType.getShape();
auto retOrder = retEncoding.getOrder();
auto retSizePerThread = retEncoding.getSizePerThread();
auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
// Get new retSizePerThread if ret elems per thread is not enough.
// We have to round it up to the next power of 2 due to triton's tensor size
// constraint.
auto newRetTotalElemsPerThread =
nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
auto newRetSizePerThread = retSizePerThread.vec();
newRetSizePerThread[retOrder[0]] *=
newRetTotalElemsPerThread / retTotalElemsPerThread;
triton::gpu::BlockedEncodingAttr newRetEncoding =
triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread,
retThreadsPerWarp, retWarpsPerCTA,
retOrder);
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
newRetEncoding);
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
op, retType, adaptor.getOperands()),
op, newRetType, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}

View File

@@ -354,6 +354,19 @@ bool isaDistributedLayout(Attribute layout) {
layout.isa<SliceEncodingAttr>();
}
/// Return true when converting `cat`'s result to `targetEncoding` is
/// considered expensive: if the target layout holds fewer elements per thread
/// than the current one, the layout conversion has to go through shared
/// memory anyway, so folding/rematerializing the cat is not worthwhile.
bool expensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
  auto resultTy = cat.getResult().getType().cast<RankedTensorType>();
  auto oldElemsPerThread = triton::gpu::getTotalElemsPerThread(resultTy);
  auto newElemsPerThread = triton::gpu::getTotalElemsPerThread(
      targetEncoding, resultTy.getShape(), resultTy.getElementType());
  return newElemsPerThread < oldElemsPerThread;
}
} // namespace gpu
} // namespace triton
@@ -1127,6 +1140,10 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
}
// cvt(cat) -> cat
if (auto cat = dyn_cast<triton::CatOp>(arg)) {
auto encoding =
op->getResult(0).getType().cast<RankedTensorType>().getEncoding();
if (triton::gpu::expensiveCat(cat, encoding))
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::CatOp>(op, op->getResult(0).getType(),
cat.getOperands());
return mlir::success();

View File

@@ -359,7 +359,7 @@ public:
for (Operation *op : cvtSlices) {
// don't rematerialize anything expensive
if (expensiveToRemat(op, dstEncoding))
if (expensiveToRemat(op, srcEncoding))
return failure();
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&

View File

@@ -112,6 +112,8 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return expensiveLoadOrStore(op, targetEncoding);
if (isa<triton::CatOp>(op))
return triton::gpu::expensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
@@ -122,10 +124,11 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
return false;
}
bool canFoldConversion(Operation *op) {
bool canFoldConversion(Operation *op, Attribute &targetEncoding) {
if (isa<triton::CatOp>(op))
return !triton::gpu::expensiveCat(cast<triton::CatOp>(op), targetEncoding);
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp,
triton::CatOp>(*op);
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
}
int simulateBackwardRematerialization(
@@ -173,7 +176,7 @@ int simulateBackwardRematerialization(
continue;
// If the conversion can be folded into opArgI then
// we don't count this conversion as expensive
if (canFoldConversion(opArgI))
if (canFoldConversion(opArgI, newEncoding))
continue;
// We add one expensive conversion for the current operand

View File

@@ -1147,6 +1147,28 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
assert to_numpy(z_tri) == z_ref
@pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]])
def test_cat(dtype_str, num_warps):
    # Regression test for tl.cat when the number of threads exceeds the size
    # of a single input tensor: cat(tensor<64>, tensor<64>) -> tensor<128>,
    # so with num_warps=8 there are more threads than input elements.
    check_type_supported(dtype_str)

    @triton.jit
    def kernel(X, Y, Z, N: tl.constexpr):
        offs = tl.arange(0, N)
        x = tl.load(X + offs)
        y = tl.load(Y + offs)
        z = tl.cat(x, y, can_reorder=True)
        tl.store(Z + tl.arange(0, 2 * N), z)

    x = torch.arange(0, 128, device='cuda').to(getattr(torch, dtype_str))
    y = torch.arange(-128, 0, device='cuda').to(getattr(torch, dtype_str))
    # can_reorder=True leaves the element order unspecified, so compare the
    # sum of the output rather than element-wise equality ...
    z_ref = torch.cat([x, y], dim=0).sum()
    z = torch.zeros((256,), dtype=getattr(torch, dtype_str), device='cuda')
    kernel[(1, )](x, y, z, N=128, num_warps=num_warps)
    assert z.sum() == z_ref
    # ... and check there's no duplicate value in z (every input element
    # appears exactly once).
    assert z.unique().size(0) == z.size(0)
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
def test_store_constant(dtype_str):
check_type_supported(dtype_str)