mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix tl.cat when the number of threads > the size of a tensor (#1751)
`tl.cat(tensor<64>, tensor<64>) -> tensor<128>`: because `cat` concatenates its operands' elements within a single thread, a 128-element result with 128 threads requires each thread to own at least 2 elements. With this PR, we also disable rematerialization of the cat op in some cases.
This commit is contained in:
@@ -87,7 +87,7 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
|
||||
/// Ceiling division: smallest Int >= m / n.
/// NOTE(review): assumes n > 0 and that m + n - 1 does not overflow Int;
/// for negative m the result follows C++ truncating division — confirm
/// callers only pass non-negative values.
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
|
||||
|
||||
// output[i] = input[order[i]]
|
||||
/// output[i] = input[order[i]]
|
||||
template <typename T, typename RES_T = T>
|
||||
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
@@ -99,6 +99,7 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Get the highest power of 2 divisor of an integer.
/// For n == 0 there is no finite answer, so return the largest power of 2
/// representable without touching the sign bit: 2^(bits-2).
/// For n != 0, n & ~(n - 1) isolates the lowest set bit, which is exactly
/// the largest power-of-2 divisor of n.
template <typename T> T highestPowOf2Divisor(T n) {
  if (n == 0) {
    return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
  }
  return (n & (~(n - 1)));
}
|
||||
|
||||
/// Get the next power of 2 for an integer (or the integer itself if it is a
/// power of 2). 0 maps to 1.
/// Classic bit-smearing: subtract 1, OR the value with ever-larger right
/// shifts until every bit below the highest set bit is 1, then add 1.
template <typename T> T nextPowOf2(T n) {
  if (n == 0) {
    return 1;
  }
  T v = n - 1;
  // Propagate the highest set bit into all lower positions.
  for (unsigned shift = 1; shift < sizeof(T) * 8; shift <<= 1) {
    v |= v >> shift;
  }
  return v + 1;
}
|
||||
|
||||
bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
@@ -23,6 +23,9 @@ namespace gpu {
|
||||
|
||||
unsigned getTotalElemsPerThread(Type type);
|
||||
|
||||
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
Type eltTy);
|
||||
|
||||
SmallVector<unsigned> getElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
|
||||
@@ -72,6 +75,8 @@ SmallVector<unsigned> getOrder(Attribute layout);
|
||||
|
||||
bool isaDistributedLayout(Attribute layout);
|
||||
|
||||
bool expensiveCat(triton::CatOp cat, Attribute &targetEncoding);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
@@ -309,11 +310,43 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
||||
LogicalResult
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // The cat op satisfies two conditions:
  // 1. output.numel = lhs.numel + rhs.numel
  // 2. output.total_elems_per_thread =
  //    next_power_of_2(lhs.total_elems_per_thread +
  //                    rhs.total_elems_per_thread)
  // For now, this behaves like generic, but this
  // will evolve when we add support for `can_reorder=False`.
  auto retType = this->getTypeConverter()
                     ->convertType(op.getType())
                     .cast<RankedTensorType>();
  auto retEncoding =
      retType.getEncoding().cast<triton::gpu::BlockedEncodingAttr>();
  auto lhsType = adaptor.getLhs().getType().cast<RankedTensorType>();
  auto rhsType = adaptor.getRhs().getType().cast<RankedTensorType>();
  auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType);
  auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType);
  auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType);
  auto retShape = retType.getShape();
  auto retOrder = retEncoding.getOrder();
  auto retSizePerThread = retEncoding.getSizePerThread();
  auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
  auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
  // Get new retSizePerThread if ret elems per thread is not enough.
  // We have to round it up to the next power of 2 due to triton's tensor
  // size constraint.
  auto newRetTotalElemsPerThread =
      nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
  auto newRetSizePerThread = retSizePerThread.vec();
  // Grow the size-per-thread along the fastest-varying dimension so each
  // thread owns enough elements to hold both operands' contributions.
  newRetSizePerThread[retOrder[0]] *=
      newRetTotalElemsPerThread / retTotalElemsPerThread;
  triton::gpu::BlockedEncodingAttr newRetEncoding =
      triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread,
                                            retThreadsPerWarp, retWarpsPerCTA,
                                            retOrder);
  auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
                                          newRetEncoding);
  addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
                    op, newRetType, adaptor.getOperands()),
                adaptor.getAttributes());
  return success();
}
|
||||
|
||||
@@ -354,6 +354,19 @@ bool isaDistributedLayout(Attribute layout) {
|
||||
layout.isa<SliceEncodingAttr>();
|
||||
}
|
||||
|
||||
bool expensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
  // If the new elements per thread is less than the old one, we will need to
  // do convert encoding that goes through shared memory anyway. So we
  // consider it as expensive.
  auto resultTy = cat.getResult().getType().cast<RankedTensorType>();
  auto oldElemsPerThread = triton::gpu::getTotalElemsPerThread(resultTy);
  auto newElemsPerThread = triton::gpu::getTotalElemsPerThread(
      targetEncoding, resultTy.getShape(), resultTy.getElementType());
  return newElemsPerThread < oldElemsPerThread;
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
@@ -1127,6 +1140,10 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
}
|
||||
// cvt(cat) -> cat
|
||||
if (auto cat = dyn_cast<triton::CatOp>(arg)) {
|
||||
auto encoding =
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding();
|
||||
if (triton::gpu::expensiveCat(cat, encoding))
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<triton::CatOp>(op, op->getResult(0).getType(),
|
||||
cat.getOperands());
|
||||
return mlir::success();
|
||||
|
||||
@@ -359,7 +359,7 @@ public:
|
||||
|
||||
for (Operation *op : cvtSlices) {
|
||||
// don't rematerialize anything expensive
|
||||
if (expensiveToRemat(op, dstEncoding))
|
||||
if (expensiveToRemat(op, srcEncoding))
|
||||
return failure();
|
||||
// don't rematerialize non-element-wise
|
||||
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
|
||||
@@ -112,6 +112,8 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
return true;
|
||||
if (isa<triton::LoadOp, triton::StoreOp>(op))
|
||||
return expensiveLoadOrStore(op, targetEncoding);
|
||||
if (isa<triton::CatOp>(op))
|
||||
return triton::gpu::expensiveCat(cast<triton::CatOp>(op), targetEncoding);
|
||||
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
|
||||
triton::AtomicCASOp, triton::DotOp>(op))
|
||||
@@ -122,10 +124,11 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool canFoldConversion(Operation *op) {
|
||||
bool canFoldConversion(Operation *op, Attribute &targetEncoding) {
|
||||
if (isa<triton::CatOp>(op))
|
||||
return !triton::gpu::expensiveCat(cast<triton::CatOp>(op), targetEncoding);
|
||||
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp,
|
||||
triton::CatOp>(*op);
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
|
||||
}
|
||||
|
||||
int simulateBackwardRematerialization(
|
||||
@@ -173,7 +176,7 @@ int simulateBackwardRematerialization(
|
||||
continue;
|
||||
// If the conversion can be folded into opArgI then
|
||||
// we don't count this conversion as expensive
|
||||
if (canFoldConversion(opArgI))
|
||||
if (canFoldConversion(opArgI, newEncoding))
|
||||
continue;
|
||||
|
||||
// We add one expensive conversion for the current operand
|
||||
|
||||
@@ -1147,6 +1147,28 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
assert to_numpy(z_tri) == z_ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]])
def test_cat(dtype_str, num_warps):
    # tl.cat of two 128-element tensors into 256 elements; with 8 warps
    # (256 threads) this exercises the case where threads outnumber the
    # operand tensor's elements.
    check_type_supported(dtype_str)

    @triton.jit
    def kernel(X, Y, Z, N: tl.constexpr):
        offs = tl.arange(0, N)
        z = tl.cat(tl.load(X + offs), tl.load(Y + offs), can_reorder=True)
        tl.store(Z + tl.arange(0, 2 * N), z)

    torch_dtype = getattr(torch, dtype_str)
    x = torch.arange(0, 128, device='cuda').to(torch_dtype)
    y = torch.arange(-128, 0, device='cuda').to(torch_dtype)
    z_ref = torch.cat([x, y], dim=0).sum()
    z = torch.zeros((256,), dtype=torch_dtype, device='cuda')
    kernel[(1, )](x, y, z, N=128, num_warps=num_warps)
    # can_reorder=True allows arbitrary ordering, so compare order-insensitive
    # aggregates: the sum must match the reference ...
    assert z.sum() == z_ref
    # ... and there must be no duplicate value in z.
    assert z.unique().size(0) == z.size(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
|
||||
def test_store_constant(dtype_str):
|
||||
check_type_supported(dtype_str)
|
||||
|
||||
Reference in New Issue
Block a user