mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Verify the same operand and result element type for convert_layout (#1081)
Also includes a hotfix for incorrect convert_layout construction in the GPU combine pass.
This commit is contained in:
@@ -96,8 +96,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">]> {
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
|
||||
|
||||
@@ -16,7 +16,9 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
[SameOperandsAndResultShape, NoSideEffect]> {
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultElementType,
|
||||
NoSideEffect]> {
|
||||
let summary = "convert layout";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
@@ -87,8 +89,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
||||
|
||||
// TODO: migrate to arith::SelectOp on LLVM16
|
||||
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "select operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
@@ -498,8 +498,6 @@ public:
|
||||
// don't rematerialize non-element-wise
|
||||
if (!op->hasTrait<mlir::OpTrait::Elementwise>())
|
||||
return failure();
|
||||
Attribute dstEncoding =
|
||||
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
|
||||
// don't rematerialize if it adds an extra conversion that can't
|
||||
// be removed
|
||||
for (Value arg : op->getOperands()) {
|
||||
@@ -509,7 +507,7 @@ public:
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
|
||||
failed(simulateBackwardRematerialization(argOp, processed, layout,
|
||||
toConvert, dstEncoding))) {
|
||||
toConvert, srcEncoding))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
@@ -521,8 +519,11 @@ public:
|
||||
if (arg.getDefiningOp() == cvt)
|
||||
mapping.map(arg, cvt.getOperand());
|
||||
else {
|
||||
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
arg.getLoc(), cvt.getOperand().getType(), arg);
|
||||
auto oldType = arg.getType().cast<RankedTensorType>();
|
||||
auto newType = RankedTensorType::get(
|
||||
oldType.getShape(), oldType.getElementType(), srcEncoding);
|
||||
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(arg.getLoc(),
|
||||
newType, arg);
|
||||
if (Operation *argOp = arg.getDefiningOp())
|
||||
cvtI->moveAfter(argOp);
|
||||
mapping.map(arg, cvtI);
|
||||
@@ -531,14 +532,12 @@ public:
|
||||
rewriter.setInsertionPoint(op);
|
||||
Operation *newOp = rewriter.clone(*op, mapping);
|
||||
auto oldType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto newType = RankedTensorType::get(
|
||||
oldType.getShape(), oldType.getElementType(),
|
||||
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding());
|
||||
auto newType = RankedTensorType::get(oldType.getShape(),
|
||||
oldType.getElementType(), srcEncoding);
|
||||
|
||||
newOp->getResult(0).setType(newType);
|
||||
auto newCvtType = RankedTensorType::get(
|
||||
oldType.getShape(), oldType.getElementType(),
|
||||
cvt.getResult().getType().cast<RankedTensorType>().getEncoding());
|
||||
oldType.getShape(), oldType.getElementType(), dstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newOp->getLoc(), newCvtType, newOp->getResult(0));
|
||||
rewriter.replaceOp(op, newCvt->getResults());
|
||||
|
||||
@@ -50,6 +50,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
// CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
|
||||
}
|
||||
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
|
||||
@@ -183,3 +184,54 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
|
||||
tt.store %21, %22 : tensor<256xf32, #layout1>
|
||||
return
|
||||
}
|
||||
|
||||
// Select has operands with different element types (an i1 condition and f64 values)
|
||||
// CHECK-LABEL: select
|
||||
func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
|
||||
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
|
||||
%c512 = arith.constant 512 : index
|
||||
%c30000 = arith.constant 30000 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
|
||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked0>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1x1xi32, #blocked1>
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<1x1xi32, #blocked1>) -> tensor<1x1xi32, #blocked2>
|
||||
%5 = tt.splat %0 : (i32) -> tensor<1x1xi32, #blocked2>
|
||||
%6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked2>
|
||||
%7 = "triton_gpu.cmpi"(%6, %cst_1) {predicate = 2 : i64} : (tensor<1x1xi32, #blocked2>, tensor<1x1xi32, #blocked2>) -> tensor<1x1xi1, #blocked2>
|
||||
%8 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked0>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<512xi32, #blocked0>) -> tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x512xi32, #blocked2>
|
||||
%11 = arith.muli %6, %cst : tensor<1x1xi32, #blocked2>
|
||||
%12 = tt.broadcast %11 : (tensor<1x1xi32, #blocked2>) -> tensor<1x512xi32, #blocked2>
|
||||
%13 = tt.splat %arg0 : (!tt.ptr<f64>) -> tensor<1x512x!tt.ptr<f64>, #blocked2>
|
||||
%14 = tt.broadcast %7 : (tensor<1x1xi1, #blocked2>) -> tensor<1x512xi1, #blocked2>
|
||||
%15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) {
|
||||
%16 = arith.index_cast %arg3 : index to i32
|
||||
%17 = tt.splat %16 : (i32) -> tensor<1x512xi32, #blocked2>
|
||||
%18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2>
|
||||
%19 = "triton_gpu.cmpi"(%18, %cst_0) {predicate = 2 : i64} : (tensor<1x512xi32, #blocked2>, tensor<1x512xi32, #blocked2>) -> tensor<1x512xi1, #blocked2>
|
||||
%20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2>
|
||||
%21 = tt.addptr %13, %20 : tensor<1x512x!tt.ptr<f64>, #blocked2>, tensor<1x512xi32, #blocked2>
|
||||
%22 = arith.andi %19, %14 : tensor<1x512xi1, #blocked2>
|
||||
%23 = triton_gpu.convert_layout %21 : (tensor<1x512x!tt.ptr<f64>, #blocked2>) -> tensor<1x512x!tt.ptr<f64>, #blocked3>
|
||||
%24 = triton_gpu.convert_layout %22 : (tensor<1x512xi1, #blocked2>) -> tensor<1x512xi1, #blocked3>
|
||||
%25 = tt.load %23, %24 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf64, #blocked3>
|
||||
%26 = triton_gpu.convert_layout %25 : (tensor<1x512xf64, #blocked3>) -> tensor<1x512xf64, #blocked2>
|
||||
%27 = arith.andi %14, %19 : tensor<1x512xi1, #blocked2>
|
||||
%28 = "triton_gpu.cmpf"(%arg4, %26) {predicate = 4 : i64} : (tensor<1x512xf64, #blocked2>, tensor<1x512xf64, #blocked2>) -> tensor<1x512xi1, #blocked2>
|
||||
%29 = arith.andi %27, %28 : tensor<1x512xi1, #blocked2>
|
||||
%30 = "triton_gpu.select"(%29, %26, %arg4) : (tensor<1x512xi1, #blocked2>, tensor<1x512xf64, #blocked2>, tensor<1x512xf64, #blocked2>) -> tensor<1x512xf64, #blocked2>
|
||||
%31 = triton_gpu.convert_layout %21 : (tensor<1x512x!tt.ptr<f64>, #blocked2>) -> tensor<1x512x!tt.ptr<f64>, #blocked3>
|
||||
%32 = triton_gpu.convert_layout %30 : (tensor<1x512xf64, #blocked2>) -> tensor<1x512xf64, #blocked3>
|
||||
%33 = triton_gpu.convert_layout %27 : (tensor<1x512xi1, #blocked2>) -> tensor<1x512xi1, #blocked3>
|
||||
tt.store %31, %32, %33 : tensor<1x512xf64, #blocked3>
|
||||
scf.yield %30 : tensor<1x512xf64, #blocked2>
|
||||
}
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user