[BACKEND] Verify the same operand and result element type for convert_layout (#1081)

And a hotfix for incorrect convert_layout construction in the GPU
combine pass.
This commit is contained in:
Keren Zhou
2023-01-22 11:59:24 -05:00
committed by GitHub
parent c59fb4acca
commit b5d32896b1
4 changed files with 68 additions and 15 deletions

View File

@@ -96,8 +96,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);

View File

@@ -16,7 +16,9 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape, NoSideEffect]> {
[SameOperandsAndResultShape,
SameOperandsAndResultElementType,
NoSideEffect]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
@@ -87,8 +89,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
let description = [{}];

View File

@@ -498,8 +498,6 @@ public:
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::Elementwise>())
return failure();
Attribute dstEncoding =
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
// don't rematerialize if it adds an extra conversion that can't
// be removed
for (Value arg : op->getOperands()) {
@@ -509,7 +507,7 @@ public:
llvm::MapVector<Value, Attribute> toConvert;
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
failed(simulateBackwardRematerialization(argOp, processed, layout,
toConvert, dstEncoding))) {
toConvert, srcEncoding))) {
return failure();
}
}
@@ -521,8 +519,11 @@ public:
if (arg.getDefiningOp() == cvt)
mapping.map(arg, cvt.getOperand());
else {
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
arg.getLoc(), cvt.getOperand().getType(), arg);
auto oldType = arg.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(), srcEncoding);
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(arg.getLoc(),
newType, arg);
if (Operation *argOp = arg.getDefiningOp())
cvtI->moveAfter(argOp);
mapping.map(arg, cvtI);
@@ -531,14 +532,12 @@ public:
rewriter.setInsertionPoint(op);
Operation *newOp = rewriter.clone(*op, mapping);
auto oldType = op->getResult(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(),
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding());
auto newType = RankedTensorType::get(oldType.getShape(),
oldType.getElementType(), srcEncoding);
newOp->getResult(0).setType(newType);
auto newCvtType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(),
cvt.getResult().getType().cast<RankedTensorType>().getEncoding());
oldType.getShape(), oldType.getElementType(), dstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
newOp->getLoc(), newCvtType, newOp->getResult(0));
rewriter.replaceOp(op, newCvt->getResults());

View File

@@ -50,6 +50,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
}
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
@@ -183,3 +184,54 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
tt.store %21, %22 : tensor<256xf32, #layout1>
return
}
// Select has args with different element types
// CHECK-LABEL: select
func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// Regression test for the combine pass: after rematerialization, no layout
// conversion ops should remain in this function (enforced by CHECK-NOT above),
// and every rewritten op must keep its original element type (here f64/i1/i32).
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
%c512 = arith.constant 512 : index
%c30000 = arith.constant 30000 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
// Row index: program id plus an expanded 1-element range, compared against 2048
// to build the per-row mask %7.
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0>
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked0>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1x1xi32, #blocked1>
%4 = triton_gpu.convert_layout %3 : (tensor<1x1xi32, #blocked1>) -> tensor<1x1xi32, #blocked2>
%5 = tt.splat %0 : (i32) -> tensor<1x1xi32, #blocked2>
%6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked2>
%7 = "triton_gpu.cmpi"(%6, %cst_1) {predicate = 2 : i64} : (tensor<1x1xi32, #blocked2>, tensor<1x1xi32, #blocked2>) -> tensor<1x1xi1, #blocked2>
// Column indices for one 1x512 tile, and the pointer base into %arg0
// offset by row * 30000.
%8 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked0>
%9 = triton_gpu.convert_layout %8 : (tensor<512xi32, #blocked0>) -> tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x512xi32, #blocked2>
%11 = arith.muli %6, %cst : tensor<1x1xi32, #blocked2>
%12 = tt.broadcast %11 : (tensor<1x1xi32, #blocked2>) -> tensor<1x512xi32, #blocked2>
%13 = tt.splat %arg0 : (!tt.ptr<f64>) -> tensor<1x512x!tt.ptr<f64>, #blocked2>
%14 = tt.broadcast %7 : (tensor<1x1xi1, #blocked2>) -> tensor<1x512xi1, #blocked2>
// Walk the 30000-element row in 512-element chunks, carrying an f64
// accumulator tile %arg4 (initialized to 0.0).
%15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) {
%16 = arith.index_cast %arg3 : index to i32
%17 = tt.splat %16 : (i32) -> tensor<1x512xi32, #blocked2>
%18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2>
// Bounds mask for this chunk (%19), combined with the row mask (%14)
// for the masked load below.
%19 = "triton_gpu.cmpi"(%18, %cst_0) {predicate = 2 : i64} : (tensor<1x512xi32, #blocked2>, tensor<1x512xi32, #blocked2>) -> tensor<1x512xi1, #blocked2>
%20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2>
%21 = tt.addptr %13, %20 : tensor<1x512x!tt.ptr<f64>, #blocked2>, tensor<1x512xi32, #blocked2>
%22 = arith.andi %19, %14 : tensor<1x512xi1, #blocked2>
%23 = triton_gpu.convert_layout %21 : (tensor<1x512x!tt.ptr<f64>, #blocked2>) -> tensor<1x512x!tt.ptr<f64>, #blocked3>
%24 = triton_gpu.convert_layout %22 : (tensor<1x512xi1, #blocked2>) -> tensor<1x512xi1, #blocked3>
%25 = tt.load %23, %24 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf64, #blocked3>
%26 = triton_gpu.convert_layout %25 : (tensor<1x512xf64, #blocked3>) -> tensor<1x512xf64, #blocked2>
%27 = arith.andi %14, %19 : tensor<1x512xi1, #blocked2>
// The op under test: select mixes an i1 condition with f64 operands, so
// rematerialization must not assume all operands share one element type.
%28 = "triton_gpu.cmpf"(%arg4, %26) {predicate = 4 : i64} : (tensor<1x512xf64, #blocked2>, tensor<1x512xf64, #blocked2>) -> tensor<1x512xi1, #blocked2>
%29 = arith.andi %27, %28 : tensor<1x512xi1, #blocked2>
%30 = "triton_gpu.select"(%29, %26, %arg4) : (tensor<1x512xi1, #blocked2>, tensor<1x512xf64, #blocked2>, tensor<1x512xf64, #blocked2>) -> tensor<1x512xf64, #blocked2>
%31 = triton_gpu.convert_layout %21 : (tensor<1x512x!tt.ptr<f64>, #blocked2>) -> tensor<1x512x!tt.ptr<f64>, #blocked3>
%32 = triton_gpu.convert_layout %30 : (tensor<1x512xf64, #blocked2>) -> tensor<1x512xf64, #blocked3>
%33 = triton_gpu.convert_layout %27 : (tensor<1x512xi1, #blocked2>) -> tensor<1x512xi1, #blocked3>
tt.store %31, %32, %33 : tensor<1x512xf64, #blocked3>
scf.yield %30 : tensor<1x512xf64, #blocked2>
}
return
}