diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index 6e9ec2df4..38cbb8065 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1468,7 +1468,9 @@ struct InsertSliceOpConversion
            "Only support in-place insert_slice for now");
     auto srcTy = src.getType().dyn_cast<RankedTensorType>();
-    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
+    assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
+            "Unexpected srcLayout in InsertSliceOpConversion"));
     auto srcShape = srcTy.getShape();
     assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index 034eff593..789aa7be0 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -512,7 +512,7 @@ public:
                                  ConversionPatternRewriter &rewriter) const {
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto srcShape = srcTy.getShape();
-    assert(srcShape.size() == 2 &&
+    assert((srcShape.size() == 1 || srcShape.size() == 2) &&
            "Unexpected rank of storeDistributedToShared");
     auto dstTy = dst.getType().cast<RankedTensorType>();
     auto srcDistributedLayout = srcTy.getEncoding();
@@ -538,8 +538,12 @@ public:
     auto wordTy = vec_ty(elemTy, minVec);
     Value word;

-    SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};
-    SmallVector<Value> offsetVals = {i32_val(0), i32_val(0)};
+    SmallVector<Value> srcStrides;
+    SmallVector<Value> offsetVals;
+    for (int i = 0; i < srcShape.size(); i++) {
+      srcStrides.push_back(dstStrides[i]);
+      offsetVals.push_back(i32_val(0));
+    }
     SharedMemoryObject smemObj(smemBase, srcStrides, offsetVals);

     DenseMap<unsigned, Value> sharedPtrs =
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
index 1e4d6443f..104401e64 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -935,8 +935,10 @@ private:
     auto mask = insertSliceAsyncOp.getMask();
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto dstTy = dst.getType().cast<RankedTensorType>();
-    auto srcBlocked =
-        srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
+    assert((srcLayout.isa<triton::gpu::BlockedEncodingAttr, triton::gpu::SliceEncodingAttr>() &&
+            "Unexpected srcLayout"));
+
     auto resSharedLayout =
         dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
     auto resElemTy = dstTy.getElementType();
@@ -966,7 +968,7 @@ private:

     // load
     auto tmpTy =
-        RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
+        RankedTensorType::get(srcTy.getShape(), resElemTy, srcLayout);
     auto loadOp = builder.create<triton::LoadOp>(
         insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
         insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
@@ -999,8 +1001,12 @@ private:
     });

     mod.walk([&](triton::gpu::AsyncCommitGroupOp asyncCommitGroupOp) -> void {
+#ifdef USE_ROCM
+      asyncCommitGroupOp.erase();
+#else
       if (!triton::gpu::AsyncCommitGroupOp::isSupported(computeCapability))
         asyncCommitGroupOp.erase();
+#endif
     });

     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 784183945..2c5ff02a5 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -877,16 +877,46 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0>
     %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0>
     %71 = triton_gpu.alloc_tensor : tensor<2x64xi64, #shared>
-    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
-    // CHECK-NEXT: cp.async.commit_group
+
+    // This test is PTX specific, GCN targets decompose async operations into ordinary load/stores.
+
+    // PTX: llvm.inline_asm has_side_effects asm_dialect = att
+    // PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
+    // PTX-NEXT: cp.async.commit_group
+
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr
+    // GCN: llvm.load {{.*}} : !llvm.ptr
+    // GCN: llvm.bitcast {{.*}} : i64 to vector<1xi64>
+    // GCN-COUNT-8: llvm.store {{.*}} : !llvm.ptr, 3>
+
     %73 = triton_gpu.insert_slice_async %66, %71, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x!tt.ptr, #slice1d0> -> tensor<2x64xi64, #shared>
     triton_gpu.async_commit_group
     tt.return
@@ -963,9 +993,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %index = arith.constant 1 : i32

     // This test is PTX specific, GCN targets decompose async operations into oridinary load/stores.
-    // TODO: Fix AMD compilation.
-    // last operation (commit_group) is still emitted by AMD pipeline,
-    // It is left to catch changes in AMD compilation pipeline.

     // PTX: llvm.inline_asm has_side_effects asm_dialect = att
     // PTX-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10
@@ -999,7 +1026,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // GCN: llvm.load {{.*}} : !llvm.ptr
     // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32>
     // GCN: llvm.store {{.*}} : !llvm.ptr, 3>
-    // GCN: llvm.inline_asm {{.*}}cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #AL> -> tensor<2x16x64xf32, #A>
     triton_gpu.async_commit_group
     tt.return
@@ -1037,9 +1063,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %index = arith.constant 1 : i32

     // This test is PTX specific, GCN targets decompose async operations into oridinary load/stores.
-    // TODO: Fix AMD compilation.
-    // last operation (commit_group) is still emitted by AMD pipeline,
-    // It is left to catch changes in AMD compilation pipeline.

     // PTX: llvm.inline_asm
     // PTX: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
@@ -1065,7 +1088,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // GCN: llvm.load {{.*}} : !llvm.ptr
     // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32>
     // GCN-COUNT-4: llvm.store {{.*}} : !llvm.ptr, 3>
-    // GCN: llvm.inline_asm {{.*}}cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr, #AL> -> tensor<2x16x32xf32, #A>
     triton_gpu.async_commit_group
     tt.return
@@ -1102,9 +1124,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %index = arith.constant 1 : i32

     // This test is PTX specific, GCN targets decompose async operations into oridinary load/stores.
-    // TODO: Fix AMD compilation.
-    // last operation (commit_group) is still emitted by AMD pipeline,
-    // It is left to catch changes in AMD compilation pipeline.
     //
     // PTX: llvm.mlir.constant(0 : i32) : i32
     // PTX: llvm.mlir.constant(16 : i32) : i32
@@ -1154,7 +1173,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // GCN: llvm.load {{.*}} : !llvm.ptr
     // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32>
     // GCN-COUNT-8: llvm.store {{.*}} : !llvm.ptr, 3>
-    // GCN: llvm.inline_asm {{.*}}cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr, #AL> -> tensor<2x32x32xf32, #A>
     triton_gpu.async_commit_group
     tt.return