Add view_slice ttgir instruction (#427)
* Add view_slice op in ttgir

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com>
Co-authored-by: Ognjen <oplavsic@luxoft.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
@@ -105,6 +105,11 @@ unsigned getNumCTAs(Attribute layout);

bool isaDistributedLayout(Attribute layout);

bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
                          BlockedEncodingAttr blockedB);

bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB);

bool isSharedEncoding(Value value);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

@@ -14,6 +14,7 @@ namespace OpTrait {
// instantiated/duplicated.
namespace impl {
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
LogicalResult verifyOperandAndResultHaveSameEncoding(Operation *op);
} // namespace impl

template <typename ConcreteType>
@@ -25,6 +26,14 @@ public:
  }
};

template <typename ConcreteType>
class OperandAndResultHaveSameEncoding
    : public TraitBase<ConcreteType, OperandAndResultHaveSameEncoding> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOperandAndResultHaveSameEncoding(op);
  }
};
} // namespace OpTrait
} // namespace mlir

@@ -14,6 +14,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
def OperandAndResultHaveSameEncoding: NativeOpTrait<"OperandAndResultHaveSameEncoding">;

class TTG_Op<string mnemonic, list<Trait> traits = []> :
    Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -179,6 +180,67 @@ def TTG_InsertSliceOp : TTG_Op<"insert_slice",
}


def TTG_ViewSliceOp : TTG_Op<"view_slice",
                             [AttrSizedOperandSegments,
                              OperandAndResultHaveSameEncoding,
                              Pure,
                              OffsetSizeAndStrideOpInterface
                             ]> {
  let summary = "view slice operation";
  let description = [{
    Represents a view of a slice of a tensor held in registers. The syntax of
    the operation is the same as that of the extract_slice op. However, unlike
    'extract_slice', which slices a tensor in shared memory, 'view_slice'
    slices a tensor held in registers. The slice is required to have the same
    layout as the original tensor; in this sense, the semantics of
    'view_slice' combine the semantics of the 'extract_slice' and 'view'
    operations.
  }];

  let arguments = (ins
    AnyRankedTensor:$source,
    Variadic<I32>:$offsets,
    Variadic<I32>:$sizes,
    Variadic<I32>:$strides,
    DenseI64ArrayAttr:$static_offsets,
    DenseI64ArrayAttr:$static_sizes,
    DenseI64ArrayAttr:$static_strides
  );
  let results = (outs AnyRankedTensor:$result);

  let builders = [
    // Build a ViewSliceOp with mixed static and dynamic entries and a custom
    // result type. If the type passed is nullptr, it is inferred.
    OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
      "ArrayRef<OpFoldResult>":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
  ];

  let extraClassDeclaration = [{
    /// Return the number of leading operands before the `offsets`, `sizes`
    /// and `strides` operands.
    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }

    /// Returns the type of the base tensor operand.
    RankedTensorType getSourceType() {
      return getSource().getType().cast<RankedTensorType>();
    }

    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getSourceType().getRank();
      return {rank, rank, rank};
    }
  }];

  let assemblyFormat = [{
    $source ``
    custom<DynamicIndexList>($offsets, $static_offsets)
    custom<DynamicIndexList>($sizes, $static_sizes)
    custom<DynamicIndexList>($strides, $static_strides)
    attr-dict `:` type($source) `to` type($result)
  }];
}
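
For reference, the textual form produced by this assembly format looks as follows. This is a hypothetical instance (the #blocked encoding and the shapes are illustrative, adapted from the Python test added below); note that the source and the result carry the same encoding, and that the offsets and sizes are multiples of the layout's CTA tile shape [64, 32]:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
%1 = triton_gpu.view_slice %0[0, 64] [128, 32] [1, 1] : tensor<128x128xf16, #blocked> to tensor<128x32xf16, #blocked>
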
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
                                [AttrSizedOperandSegments,
                                 ResultsAreSharedEncoding,

@@ -849,6 +849,135 @@ struct ExtractSliceOpConversion
  }
};

// clang-format off
/***
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
#  W0  #  W1  #                                   |                                 #
#      #      #                                   |                                 #
#  #   #  #   #                                   |                                 #
#  W2  #  W3  #   ....                            |                                 #
#      #      #                                   |  SkipElems                      #
#  #   #  #   #                                   |                                 #
#                                                 |                                 #
#                 Slice                           |                                 #
#             .   /                          \    |                                 #
#            .   /                            \   |                                 #
#           .   /                              \  |                                 #
#          .   /                                \ |                                 #
#         .   /                                  \|                                 #
#              #   #   #   #   #                                                    #
#              #  W0   #  W1   #                                                    #
#              #       #       #                                                    #
#              #   #   #   #   #                       tensorStride                 #
#              #  W2   #  W3   # ---------------------------------------------------#
#              #       #       #                                                    #
#              #   #   #   #   #                                                    #
# tensorStride                        #   #   #   #   #                             #
# ------------------------------------#  W0   #  W1   #                             #
#                                     #       #       #                             #
#                                     #   #   #   #   #                             #
#                                     #  W2   #  W3   #                             #
#                                     #       #       #                             #
#                                     #   #   #   #   #   ---> lastIdx              #
#                  .                                                                #
#                  .                                                                #
#                  .                                                                #
#                                                                                   #
#                                                                                   #
#                                                                                   #
#                                                                                   #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
***/
// clang-format on
struct ViewSliceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp> {
  using OpAdaptor = typename triton::gpu::ViewSliceOp::Adaptor;
  explicit ViewSliceOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
                                 PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp>(typeConverter,
                                                                  benefit) {}

  LogicalResult
  processBlockedLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter) const {
    Location loc = op->getLoc();
    auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
    assert(
        srcLayout &&
        "Currently only blocked layout is supported in view_slice instruction");
    auto srcShape = srcTy.getShape();
    auto resultTy = op.getType().template cast<RankedTensorType>();
    auto vals = this->getTypeConverter()->unpackLLElements(
        loc, adaptor.getSource(), rewriter, srcTy);

    auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy);
    auto sizePerThread = srcLayout.getSizePerThread();
    auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
    auto order = srcLayout.getOrder();
    auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape);
    shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]);
    shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]);

    auto offsets = op.getStaticOffsets();
    auto sizes = op.getStaticSizes();

    // ViewSlice only supports slicing where offsets and sizes are multiples
    // of shapePerCTA. This condition ensures that the slice has the same
    // layout as the original tensor.
    assert(offsets[0] % shapePerCTA[0] == 0);
    assert(offsets[1] % shapePerCTA[1] == 0);
    assert(sizes[0] % shapePerCTA[0] == 0);
    assert(sizes[1] % shapePerCTA[1] == 0);
    assert(op.hasUnitStride() &&
           "Only unit stride supported by ViewSliceOpConversion");

    // Calculate offsets and sizes in terms of CTA units.
    std::vector<long int> CTAOffsets{offsets[0] / shapePerCTA[0],
                                     offsets[1] / shapePerCTA[1]};
    std::vector<long int> CTASizes{sizes[0] / shapePerCTA[0],
                                   sizes[1] / shapePerCTA[1]};
    std::vector<long int> CTAPerShape{srcShape[0] / shapePerCTA[0],
                                      srcShape[1] / shapePerCTA[1]};

    SmallVector<Value> resultVals;
    // The diagram above illustrates the skipElems, tensorStride, and lastIdx
    // variables.
    auto skipElems = CTAOffsets[order[1]] *
                         (elemsPerThread[order[0]] * sizePerThread[order[1]]) +
                     CTAOffsets[order[0]] * totalSizePerThread;
    auto tensorStride =
        (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
    auto lastIdx =
        (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
            elemsPerThread[order[0]] * sizePerThread[order[1]] +
        (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;

    assert(lastIdx <= vals.size());
    for (int i = skipElems; i < lastIdx; i += tensorStride) {
      for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
        assert(i < lastIdx);
        resultVals.push_back(vals[i]);
      }
    }

    Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
                                                         rewriter, resultTy);
    rewriter.replaceOp(op, ret);
    return success();
  }

  LogicalResult
  matchAndRewrite(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
    if (srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>()) {
      return processBlockedLayout(op, adaptor, rewriter);
    } else {
      assert(false && "unsupported layout in viewSlice");
      return failure();
    }
  }
};

struct AsyncWaitOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<

@@ -954,6 +1083,7 @@ void populateTritonGPUToLLVMPatterns(
  patterns.add<BroadcastOpConversion>(typeConverter, benefit);
  patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
                                         benefit);
  patterns.add<ViewSliceOpConversion>(typeConverter, benefit);
  patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
  patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
  patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);

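To make the index arithmetic in processBlockedLayout concrete, here is a small standalone sketch (not part of the change). It reproduces the skipElems, tensorStride, and lastIdx computation for one configuration exercised by the test below: a 128x128 tensor with sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0], sliced to 128x32 at offset [0, 64]. The simplified per-dimension relations for shapePerCTA and elemsPerThread are assumptions that hold for this layout (shape is a multiple of the CTA tile in both dimensions):

#include <array>
#include <cstdio>

int main() {
  std::array<long, 2> shape = {128, 128};   // source tensor shape
  std::array<long, 2> offsets = {0, 64};    // slice offset
  std::array<long, 2> sizes = {128, 32};    // slice shape
  std::array<long, 2> sizePerThread = {1, 8};
  std::array<long, 2> order = {1, 0};       // dim 1 is the fastest-varying
  // shapePerCTA[d] = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]
  std::array<long, 2> shapePerCTA = {1 * 16 * 4, 8 * 4 * 1}; // {64, 32}
  // elemsPerThread[d] = shape[d] / shapePerCTA[d] * sizePerThread[d]
  std::array<long, 2> elemsPerThread = {
      shape[0] / shapePerCTA[0] * sizePerThread[0],   // 2
      shape[1] / shapePerCTA[1] * sizePerThread[1]};  // 32
  long totalSizePerThread = sizePerThread[0] * sizePerThread[1]; // 8

  // Offsets, sizes, and tensor extent in units of CTA tiles.
  std::array<long, 2> CTAOffsets = {offsets[0] / shapePerCTA[0],
                                    offsets[1] / shapePerCTA[1]};  // {0, 2}
  std::array<long, 2> CTASizes = {sizes[0] / shapePerCTA[0],
                                  sizes[1] / shapePerCTA[1]};      // {2, 1}
  std::array<long, 2> CTAPerShape = {shape[0] / shapePerCTA[0],
                                     shape[1] / shapePerCTA[1]};   // {2, 4}

  // Same formulas as in processBlockedLayout above.
  long skipElems = CTAOffsets[order[1]] *
                       (elemsPerThread[order[0]] * sizePerThread[order[1]]) +
                   CTAOffsets[order[0]] * totalSizePerThread;
  long tensorStride =
      (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
  long lastIdx =
      (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
          elemsPerThread[order[0]] * sizePerThread[order[1]] +
      (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;

  // Prints: skipElems=16 tensorStride=24 lastIdx=56. Each thread holds
  // 2 * 32 = 64 register values for the source tensor; the copy loop skips
  // the first 16, then copies runs of 8 values separated by strides of 24,
  // stopping at index 56, for the 16 values of the 128x32 result.
  std::printf("skipElems=%ld tensorStride=%ld lastIdx=%ld\n", skipElems,
              tensorStride, lastIdx);
  return 0;
}
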
@@ -691,6 +691,59 @@ bool isaDistributedLayout(Attribute layout) {
         layout.isa<MfmaEncodingAttr>() || layout.isa<SliceEncodingAttr>();
}

bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
                          BlockedEncodingAttr blockedB) {
  auto sizePerThreadA = blockedA.getSizePerThread();
  auto threadsPerWarpA = blockedA.getThreadsPerWarp();
  auto warpsPerCTAA = blockedA.getWarpsPerCTA();
  auto orderA = blockedA.getOrder();
  size_t rankA = orderA.size();

  auto sizePerThreadB = blockedB.getSizePerThread();
  auto threadsPerWarpB = blockedB.getThreadsPerWarp();
  auto warpsPerCTAB = blockedB.getWarpsPerCTA();
  auto orderB = blockedB.getOrder();
  size_t rankB = orderB.size();

  if (rankA != rankB) {
    return false;
  }
  for (size_t i = 0; i < rankA; ++i) {
    if (sizePerThreadA[i] != sizePerThreadB[i] ||
        threadsPerWarpA[i] != threadsPerWarpB[i] ||
        warpsPerCTAA[i] != warpsPerCTAB[i] || orderA[i] != orderB[i]) {
      return false;
    }
  }
  return true;
}

bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
  auto nonKDimA = mfmaA.getNonKDim();
  auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
  auto isTransposedA = mfmaA.getIsTransposed();

  auto nonKDimB = mfmaB.getNonKDim();
  auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
  auto isTransposedB = mfmaB.getIsTransposed();

  if (nonKDimA != nonKDimB || isTransposedA != isTransposedB) {
    return false;
  }

  if (warpsPerCTAA.size() != warpsPerCTAB.size()) {
    return false;
  }

  auto rank = warpsPerCTAA.size();
  for (size_t i = 0; i < rank; ++i) {
    if (warpsPerCTAA[i] != warpsPerCTAB[i]) {
      return false;
    }
  }
  return true;
}

bool isSharedEncoding(Value value) {
  auto type = value.getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {

@@ -12,3 +12,51 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {

  return success();
};

mlir::LogicalResult
mlir::OpTrait::impl::verifyOperandAndResultHaveSameEncoding(Operation *op) {
  if (op->getNumOperands() != 1 || op->getNumResults() != 1) {
    return failure();
  }

  auto operandType = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
  auto resultType = op->getResult(0).getType().dyn_cast<RankedTensorType>();

  if (!operandType || !resultType) {
    return failure();
  }
  auto operandLayout = operandType.getEncoding();
  auto resultLayout = resultType.getEncoding();

  if (auto blockedLayoutSrc =
          dyn_cast<triton::gpu::BlockedEncodingAttr>(operandLayout)) {
    auto blockedLayoutRes =
        dyn_cast<triton::gpu::BlockedEncodingAttr>(resultLayout);
    if (!blockedLayoutRes) {
      return op->emitOpError()
             << "requires operand and result to have same layout";
    }

    if (!triton::gpu::sameBlockedEncodings(blockedLayoutSrc,
                                           blockedLayoutRes)) {
      return op->emitOpError()
             << "requires operand and result to have same layout";
    }
  } else if (auto mfmaLayoutSrc =
                 dyn_cast<triton::gpu::MfmaEncodingAttr>(operandLayout)) {
    auto mfmaLayoutRes = dyn_cast<triton::gpu::MfmaEncodingAttr>(resultLayout);
    if (!mfmaLayoutRes) {
      return op->emitOpError()
             << "requires operand and result to have same layout";
    }
    if (!triton::gpu::sameMfmaEncodings(mfmaLayoutSrc, mfmaLayoutRes)) {
      return op->emitOpError()
             << "requires operand and result to have same layout";
    }
  } else {
    assert(false &&
           "Unexpected layout in verifyOperandAndResultHaveSameEncoding");
  }

  return success();
};

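To illustrate the check above, a hypothetical view_slice whose result encoding differs from its source encoding would be rejected with the diagnostic emitted here. Both encodings are adapted from the layouts in the test below; the IR is illustrative, not part of the commit:

#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// error: 'triton_gpu.view_slice' op requires operand and result to have same layout
%1 = triton_gpu.view_slice %0[0, 0] [64, 32] [1, 1] : tensor<128x128xf16, #blocked1> to tensor<64x32xf16, #blocked2>
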
@@ -2933,6 +2933,71 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
    assert torch.equal(z, x)


layouts = [
    BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
    BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
    BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
    BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
]

@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [256, 256, 128, 64, 64, 128], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, src_layout, device='cuda'):
    if torch.version.hip is None:
        pytest.skip("view_slice is an AMD-specific instruction.")

    ir = f"""
#src = {src_layout}
""" + """
module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + f""" : i32}} {{
  tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
    %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
    %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #src>
    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
    %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
    %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
    %2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
    %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
    %43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M_tile_size}x1xi32, #src>
    %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
    %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #src>
    %6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{M}xi32, #src>
    %7 = tt.broadcast %6 : (tensor<1x{M}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
    %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
    %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
    %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
    %34 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>
    %37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N_tile_size}xi32, #src>
    %38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
    %39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
    %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #src>
    %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
    %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
    %12 = triton_gpu.view_slice %11[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #src> to tensor<{M_tile_size}x{N_tile_size}xf16, #src>
    %13 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>, tensor<{M_tile_size}x{N_tile_size}xi32, #src>
    tt.store %13, %12 : tensor<{M_tile_size}x{N_tile_size}xf16, #src>
    tt.return
  }}
}}
"""

    x_numpy = numpy_random((M, N), dtype_str=dtype)
    z_numpy = x_numpy[M_tile_offset:M_tile_offset + M_tile_size, N_tile_offset:N_tile_offset + N_tile_size]
    x = to_triton(x_numpy)

    # write the IR to a temporary file and compile it
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        kernel = triton.compile(f.name)

    z = np.zeros((M_tile_size, N_tile_size)).astype('float16')
    z_tri = torch.tensor(z, device=device)

    kernel[(1, 1, 1)](x.data_ptr(), z_tri)
    np.testing.assert_equal(z_numpy, to_numpy(z_tri))

if torch.version.hip is not None and _get_warp_size() == 64:
    layouts = [
        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),