Add view_slice ttgir instruction (#427)

* Add view_slice op in ttgir

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com>
Co-authored-by: Ognjen <oplavsic@luxoft.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
oplavsic
2024-01-02 22:40:11 +01:00
committed by GitHub
parent 98589ac013
commit 6a520566a3
7 changed files with 372 additions and 0 deletions

View File

@@ -105,6 +105,11 @@ unsigned getNumCTAs(Attribute layout);
bool isaDistributedLayout(Attribute layout);
bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
BlockedEncodingAttr blockedB);
bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB);
bool isSharedEncoding(Value value);
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

View File

@@ -14,6 +14,7 @@ namespace OpTrait {
// instantiated/duplicated.
namespace impl {
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
LogicalResult verifyOperandAndResultHaveSameEncoding(Operation *op);
} // namespace impl
template <typename ConcreteType>
@@ -25,6 +26,14 @@ public:
}
};
template <typename ConcreteType>
class OperandAndResultHaveSameEncoding
: public TraitBase<ConcreteType, OperandAndResultHaveSameEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOperandAndResultHaveSameEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir

View File

@@ -14,6 +14,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
def OperandAndResultHaveSameEncoding: NativeOpTrait<"OperandAndResultHaveSameEncoding">;
class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -179,6 +180,67 @@ def TTG_InsertSliceOp : TTG_Op<"insert_slice",
}
def TTG_ViewSliceOp : TTG_Op<"view_slice",
[AttrSizedOperandSegments,
OperandAndResultHaveSameEncoding,
Pure,
OffsetSizeAndStrideOpInterface
]> {
let summary = "view slice operation";
let description = [{
Represents view of the slice of the tensor in registers. Syntax of the operation is the same
as for extract_slice op. However, unlike 'extract_slice' which slices in shared memory,
'view_slice' specifically slices within registers.
Slice of the tensor is required to have the same layout as the original tensor.
In a way, semantics of the 'view_slice' operation is a combination of the 'extract_slice' and 'view' operations semantics.
}];
let arguments = (ins
AnyRankedTensor:$source,
Variadic<I32>:$offsets,
Variadic<I32>:$sizes,
Variadic<I32>:$strides,
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
let builders = [
// Build an ExtractSliceOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
let extraClassDeclaration = [{
/// Return the number of leading operands before the `offsets`, `sizes` and
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
/// Returns the type of the base tensor operand.
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned rank = getSourceType().getRank();
return {rank, rank, rank};
}
}];
let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $static_offsets)
custom<DynamicIndexList>($sizes, $static_sizes)
custom<DynamicIndexList>($strides, $static_strides)
attr-dict `:` type($source) `to` type($result)
}];
}
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,

View File

@@ -849,6 +849,135 @@ struct ExtractSliceOpConversion
}
};
// clang-format off
/***
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# WO # W1 # | #
# # # | #
# # # # # | #
# W2 # W3 # .... | #
# # # | SkipElems #
# # # # # | #
# | #
# Slice | #
# . / \ | #
# . / \ | #
# . / \| #
# # # # # # #
# # W0 # W1 # #
# # # # #
# # # # # # tensorStride #
# # W2 # W3 # --------------------------------#
# # # # #
# # # # # # #
# tensorStride # W0 # W1 # #
# ---------------------------------- # # # #
# # # # # # #
# # W2 # W3 # #
# # # # #
# # # # # # ---> lastIdx #
# . #
# . #
# . #
# #
# #
# #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
***/
// clang-format on
struct ViewSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp> {
using OpAdaptor = typename triton::gpu::ViewSliceOp::Adaptor;
explicit ViewSliceOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp>(typeConverter,
benefit) {}
LogicalResult
processBlockedLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(
srcLayout &&
"Currently only blocked layout is supported in view_slice instruction");
auto srcShape = srcTy.getShape();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSource(), rewriter, srcTy);
auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy);
auto sizePerThread = srcLayout.getSizePerThread();
auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
auto order = srcLayout.getOrder();
auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]);
shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]);
auto offsets = op.getStaticOffsets();
auto sizes = op.getStaticSizes();
// ViewSlice only supports slicing where offsets and sizes are multiples of
// shapePerCTA. This condition ensures that slice has the same layout as the
// original tensor.
assert(offsets[0] % shapePerCTA[0] == 0);
assert(offsets[1] % shapePerCTA[1] == 0);
assert(sizes[0] % shapePerCTA[0] == 0);
assert(sizes[1] % shapePerCTA[1] == 0);
assert(op.hasUnitStride() &&
"Only unit stride supported by ViewSliceOpConversion");
// Calculate offsets and sizes in terms of CTA units.
std::vector<long int> CTAOffsets{offsets[0] / shapePerCTA[0],
offsets[1] / shapePerCTA[1]};
std::vector<long int> CTASizes{sizes[0] / shapePerCTA[0],
sizes[1] / shapePerCTA[1]};
std::vector<long int> CTAPerShape{srcShape[0] / shapePerCTA[0],
srcShape[1] / shapePerCTA[1]};
SmallVector<Value> resultVals;
// The diagram above illustrates the graphical representation of the
// skipElems, tensorStride, and lastIdx variables.
auto skipElems = CTAOffsets[order[1]] *
(elemsPerThread[order[0]] * sizePerThread[order[1]]) +
CTAOffsets[order[0]] * totalSizePerThread;
auto tensorStride =
(CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
auto lastIdx =
(CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
elemsPerThread[order[0]] * sizePerThread[order[1]] +
(CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;
assert(lastIdx <= vals.size());
for (int i = skipElems; i < lastIdx; i += tensorStride) {
for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
assert(i < lastIdx);
resultVals.push_back(vals[i]);
}
}
Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}
LogicalResult
matchAndRewrite(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
if (srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>()) {
return processBlockedLayout(op, adaptor, rewriter);
} else {
assert(false && "unsupported layout in viewSlice");
return failure();
}
}
};
struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -954,6 +1083,7 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<ViewSliceOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);

View File

@@ -691,6 +691,59 @@ bool isaDistributedLayout(Attribute layout) {
layout.isa<MfmaEncodingAttr>() || layout.isa<SliceEncodingAttr>();
}
bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
BlockedEncodingAttr blockedB) {
auto sizePerThreadA = blockedA.getSizePerThread();
auto threadsPerWarpA = blockedA.getThreadsPerWarp();
auto warpsPerCTAA = blockedA.getWarpsPerCTA();
auto orderA = blockedA.getOrder();
size_t rankA = orderA.size();
auto sizePerThreadB = blockedB.getSizePerThread();
auto threadsPerWarpB = blockedB.getThreadsPerWarp();
auto warpsPerCTAB = blockedB.getWarpsPerCTA();
auto orderB = blockedB.getOrder();
size_t rankB = orderB.size();
if (rankA != rankB) {
return false;
}
for (size_t i = 0; i < rankA; ++i) {
if (sizePerThreadA[i] != sizePerThreadB[i] ||
threadsPerWarpA[i] != threadsPerWarpB[i] ||
warpsPerCTAA[i] != warpsPerCTAB[i] || orderA[i] != orderB[i]) {
return false;
}
}
return true;
}
bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
auto nonKDimA = mfmaA.getNonKDim();
auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
auto isTransposedA = mfmaA.getIsTransposed();
auto nonKDimB = mfmaB.getNonKDim();
auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
auto isTransposedB = mfmaB.getIsTransposed();
if (nonKDimA != nonKDimB || isTransposedA != isTransposedB) {
return false;
}
if (warpsPerCTAA.size() != warpsPerCTAB.size()) {
return false;
}
auto rank = warpsPerCTAA.size();
for (size_t i = 0; i < rank; ++i) {
if (warpsPerCTAA[i] != warpsPerCTAB[i]) {
return false;
}
}
return true;
}
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {

View File

@@ -12,3 +12,51 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
return success();
};
mlir::LogicalResult
mlir::OpTrait::impl::verifyOperandAndResultHaveSameEncoding(Operation *op) {
if (op->getNumOperands() != 1 || op->getNumResults() != 1) {
return failure();
}
auto operandType = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
auto resultType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!operandType || !resultType) {
return failure();
}
auto operandLayout = operandType.getEncoding();
auto resultLayout = resultType.getEncoding();
if (auto blockedLayoutSrc =
dyn_cast<triton::gpu::BlockedEncodingAttr>(operandLayout)) {
auto blockedLayoutRes =
dyn_cast<triton::gpu::BlockedEncodingAttr>(resultLayout);
if (!blockedLayoutRes) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
if (!triton::gpu::sameBlockedEncodings(blockedLayoutSrc,
blockedLayoutRes)) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
} else if (auto mfmaLayoutSrc =
dyn_cast<triton::gpu::MfmaEncodingAttr>(operandLayout)) {
auto mfmaLayoutRes = dyn_cast<triton::gpu::MfmaEncodingAttr>(resultLayout);
if (!mfmaLayoutRes) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
if (!triton::gpu::sameMfmaEncodings(mfmaLayoutSrc, mfmaLayoutRes)) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
} else {
assert(false &&
"Unexpected Layout in verifyOperandAndResultHaveSmeEncoding");
}
return success();
};

View File

@@ -2933,6 +2933,71 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
assert torch.equal(z, x)
layouts = [
BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
]
@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [256, 256, 128, 64, 64, 128], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, src_layout, device='cuda'):
if torch.version.hip is None:
pytest.skip("view_slice is AMD specific instruction.")
ir = f"""
#src = {src_layout}
""" + """
module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + f""" : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M_tile_size}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{M}xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x{M}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%34 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>
%37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N_tile_size}xi32, #src>
%38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
%12 = triton_gpu.view_slice %11[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #src> to tensor<{M_tile_size}x{N_tile_size}xf16, #src>
%13 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>, tensor<{M_tile_size}x{N_tile_size}xi32, #src>
tt.store %13, %12 : tensor<{M_tile_size}x{N_tile_size}xf16, #src>
tt.return
}}
}}
"""
x_numpy = numpy_random((M, N), dtype_str=dtype)
z_numpy = x_numpy[M_tile_offset:M_tile_offset + M_tile_size, N_tile_offset:N_tile_offset + N_tile_size]
x = to_triton(x_numpy)
# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
z = np.zeros((M_tile_size, N_tile_size)).astype('float16')
z_tri = torch.tensor(z, device=device)
kernel[(1, 1, 1)](x.data_ptr(), z_tri)
np.testing.assert_equal(z_numpy, to_numpy(z_tri))
if torch.version.hip is not None and _get_warp_size() == 64:
layouts = [
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),