[BACKEND] Clean up code (#1768)

- Remove unused header files.
- Get numThreads/numWarps from the triton module.
- Move transforms/utility.h to the include directory.
This commit is contained in:
Keren Zhou
2023-06-12 17:40:33 -07:00
committed by GitHub
parent ac15d00ef4
commit 58a8e8a914
17 changed files with 100 additions and 95 deletions

View File

@@ -69,9 +69,11 @@ bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
Type getElementType(Value value);
bool isSingleValue(Value value);
std::string getValueOperandName(Value value, AsmState &state);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
Type getElementType(Value value);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
@@ -120,10 +122,6 @@ template <typename T> T nextPowOf2(T n) {
return n + 1;
}
bool isSingleValue(Value value);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.

View File

@@ -75,13 +75,11 @@ SmallVector<unsigned> getOrder(Attribute layout);
bool isaDistributedLayout(Attribute layout);
bool expensiveCat(triton::CatOp cat, Attribute &targetEncoding);
bool isSharedEncoding(Value value);
} // namespace gpu
} // namespace triton
bool isSharedEncoding(Value value);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -1,9 +1,13 @@
#ifndef TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#define TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/MapVector.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
LogicalResult fixupLoops(ModuleOp mod);
@@ -12,9 +16,11 @@ LogicalResult fixupLoops(ModuleOp mod);
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret);
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
bool expensiveToRemat(Operation *op, Attribute &targetEncoding);
bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding);
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);
// skipInit is True when we only consider the operands of the initOp but
// not the initOp itself.
@@ -36,4 +42,4 @@ LogicalResult canMoveOutOfLoop(BlockArgument arg,
} // namespace mlir
#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -1,19 +0,0 @@
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
#define TRITON_TARGET_AMDGCNTRANSLATION_H
#include <string>
#include <tuple>
// Forward-declare llvm::Module to avoid pulling in heavyweight LLVM headers.
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate LLVM IR to AMDGCN code.
// NOTE(review): returns two strings — presumably the generated AMDGCN
// assembly plus an auxiliary artifact (and `cc` names the target GFX
// architecture); confirm against the implementation.
std::tuple<std::string, std::string>
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);
} // namespace triton
#endif // TRITON_TARGET_AMDGCNTRANSLATION_H

View File

@@ -38,7 +38,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isSharedEncoding(result)) {
} else if (triton::gpu::isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;
}

View File

@@ -151,7 +151,7 @@ private:
}
for (Value result : op->getResults()) {
if (isSharedEncoding(result)) {
if (triton::gpu::isSharedEncoding(result)) {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto tensorType = result.getType().dyn_cast<RankedTensorType>();

View File

@@ -163,13 +163,6 @@ Type getElementType(Value value) {
return type;
}
// Render `value` the way it would appear as an operand in printed IR
// (its SSA name), using `state` so the numbering matches the enclosing
// module's textual form.
std::string getValueOperandName(Value value, AsmState &state) {
std::string opName;
// raw_string_ostream writes directly into opName's backing string.
llvm::raw_string_ostream ss(opName);
value.printAsOperand(ss, state);
return opName;
}
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>

View File

@@ -7,6 +7,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -354,22 +355,6 @@ bool isaDistributedLayout(Attribute layout) {
layout.isa<SliceEncodingAttr>();
}
// Returns true when converting the result of `cat` to `targetEncoding`
// would be expensive, i.e. when the conversion cannot stay in registers.
bool expensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
auto tensorTy = cat.getResult().getType().cast<RankedTensorType>();
auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto elemTy = tensorTy.getElementType();
// Elements-per-thread under the proposed target encoding, for the same
// tensor shape and element type.
auto newTotalElemsPerThread =
triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy);
return newTotalElemsPerThread < totalElemsPerThread;
}
} // namespace gpu
} // namespace triton
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
@@ -379,6 +364,9 @@ bool isSharedEncoding(Value value) {
return false;
}
} // namespace gpu
} // namespace triton
} // namespace mlir
static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
@@ -1142,7 +1130,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
if (auto cat = dyn_cast<triton::CatOp>(arg)) {
auto encoding =
op->getResult(0).getType().cast<RankedTensorType>().getEncoding();
if (triton::gpu::expensiveCat(cat, encoding))
if (isExpensiveCat(cat, encoding))
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::CatOp>(op, op->getResult(0).getType(),
cat.getOperands());
@@ -1151,7 +1139,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
@@ -1161,7 +1149,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
@@ -1183,7 +1171,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
@@ -1213,12 +1201,13 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(op.getOperand()) &&
!isSharedEncoding(op.getResult())) {
!triton::gpu::isSharedEncoding(arg->getOperand(0)) &&
triton::gpu::isSharedEncoding(op.getOperand()) &&
!triton::gpu::isSharedEncoding(op.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) {
if (triton::gpu::isSharedEncoding(op.getOperand()) &&
triton::gpu::isSharedEncoding(op.getResult())) {
return mlir::failure();
}
auto srcType = op.getOperand().getType().cast<RankedTensorType>();

View File

@@ -7,7 +7,7 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
return failure();
for (auto result : op->getResults())
if (!isSharedEncoding(result))
if (!triton::gpu::isSharedEncoding(result))
return op->emitOpError() << "requires all results to be shared encoding";
return success();

View File

@@ -1,4 +1,3 @@
#include "Utility.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -6,6 +5,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <memory>
using namespace mlir;

View File

@@ -1,4 +1,3 @@
#include "Utility.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
@@ -8,6 +7,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/MapVector.h"
//===----------------------------------------------------------------------===//

View File

@@ -172,7 +172,7 @@ LogicalResult Prefetcher::initialize() {
break;
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op))
if (isSharedEncoding(cvt.getOperand())) {
if (triton::gpu::isSharedEncoding(cvt.getOperand())) {
foundConvertFromShared = true;
break;
}

View File

@@ -1,4 +1,3 @@
#include "Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -16,6 +15,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <memory>
@@ -359,7 +359,7 @@ public:
for (Operation *op : cvtSlices) {
// don't rematerialize anything expensive
if (expensiveToRemat(op, srcEncoding))
if (isExpensiveToRemat(op, srcEncoding))
return failure();
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
@@ -408,8 +408,8 @@ public:
if (!op)
return mlir::failure();
// we don't want to rematerialize any conversion to/from shared
if (isSharedEncoding(cvt->getResults()[0]) ||
isSharedEncoding(cvt->getOperand(0)))
if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) ||
triton::gpu::isSharedEncoding(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accommodate fused attention

View File

@@ -1,10 +1,10 @@
#include "Utility.h"
#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace mlir {
@@ -88,7 +88,7 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
return success();
}
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1: A size 1 tensor is not expensive since all threads will load the
// same
if (isSingleValue(op->getOperand(0)))
@@ -96,24 +96,34 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 2: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
IntegerAttr numWarps =
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
"triton_gpu.num-warps");
if (numWarps) {
int sizePerThread = triton::gpu::getTotalElemsPerThread(ptrType);
if (ptrType.getNumElements() < numWarps.getInt() * 32)
return false;
}
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
if (ptrType.getNumElements() < numWarps * threadsPerWarp)
return false;
return true;
}
bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
// Decide whether folding a layout conversion into `cat` is expensive.
// A cat is considered expensive when the target encoding packs fewer
// elements into each thread than the current encoding does: shrinking the
// per-thread element count forces a redistribution that goes through
// shared memory anyway, so the fold buys nothing.
bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
  auto resultTy = cat.getResult().getType().cast<RankedTensorType>();
  // Per-thread element counts under the current and the proposed encodings.
  auto oldElemsPerThread = triton::gpu::getTotalElemsPerThread(resultTy);
  auto newElemsPerThread = triton::gpu::getTotalElemsPerThread(
      targetEncoding, resultTy.getShape(), resultTy.getElementType());
  return newElemsPerThread < oldElemsPerThread;
}
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return expensiveLoadOrStore(op, targetEncoding);
return isExpensiveLoadOrStore(op, targetEncoding);
if (isa<triton::CatOp>(op))
return triton::gpu::expensiveCat(cast<triton::CatOp>(op), targetEncoding);
return isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
@@ -126,7 +136,7 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
bool canFoldConversion(Operation *op, Attribute &targetEncoding) {
if (isa<triton::CatOp>(op))
return !triton::gpu::expensiveCat(cast<triton::CatOp>(op), targetEncoding);
return !isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
}
@@ -148,7 +158,7 @@ int simulateBackwardRematerialization(
queue.pop_back();
// If the current operation is expensive to rematerialize,
// we stop everything
if (expensiveToRemat(currOp, currLayout))
if (isExpensiveToRemat(currOp, currLayout))
break;
// A conversion will be removed here (i.e. transferred to operands)
numCvts -= 1;

View File

@@ -9,6 +9,8 @@
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: cst
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -67,8 +69,6 @@ tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.return
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
@@ -80,7 +80,6 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.store %5, %4 : tensor<16xi32, #layout0>
tt.return
}
}
// CHECK-LABEL: if
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
@@ -164,6 +163,8 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
tt.return
}
}
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
@@ -173,6 +174,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: transpose
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[$row_layout]]>
@@ -212,8 +214,10 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
tt.return
}
}
// CHECK-LABEL: loop
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
@@ -266,8 +270,10 @@ tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32,
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
tt.return
}
}
// CHECK-LABEL: loop_if
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
@@ -318,8 +324,10 @@ tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i3
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
tt.return
}
}
// CHECK-LABEL: vecadd
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
@@ -349,9 +357,11 @@ tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
tt.store %21, %22 : tensor<256xf32, #layout1>
tt.return
}
}
// Select has args with different element types
// CHECK-LABEL: select
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
@@ -400,9 +410,11 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
}
tt.return
}
}
// Make sure the following IR doesn't hang the compiler.
// CHECK-LABEL: long_func
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
@@ -796,10 +808,12 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
tt.store %365, %366 : tensor<1024xf64, #blocked0>
tt.return
}
}
// A mnist model from torch inductor.
// Check if topological sort is working correct and there's no unnecessary convert
// CHECK-LABEL: mnist
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
@@ -884,17 +898,19 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
tt.store %61, %62, %63 : tensor<16x16xf32, #blocked4>
tt.return
}
}
// -----
// cmpf and cmpi have different operands and result types
// CHECK-LABEL: cmp
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
// cmpf and cmpi have different operands and result types
// CHECK-LABEL: cmp
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
@@ -1034,11 +1050,13 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
}
tt.return
}
}
// -----
// Just make sure it doesn't crash on non-tensor types.
// CHECK-LABEL: if_no_tensor
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c-1_i64 = arith.constant -1 : i64
@@ -1062,6 +1080,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
tt.return
}
}
// -----

View File

@@ -11,6 +11,8 @@
#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: tt.func @push_elementwise1
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
@@ -122,3 +124,5 @@ tt.func @push_elementwise5(
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
tt.return %newc : tensor<16x16xf32, #Cv1>
}
}

View File

@@ -13,6 +13,13 @@ struct TestAliasPass
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
// Produce the SSA operand name of `value` as it appears in printed IR,
// reusing `state` so numbering is consistent across calls.
static std::string getValueOperandName(Value value, AsmState &state) {
  std::string name;
  llvm::raw_string_ostream stream(name);
  value.printAsOperand(stream, state);
  // str() flushes the stream and hands back the accumulated text.
  return stream.str();
}
static void print(StringRef name, SmallVector<std::string, 4> &vals,
raw_ostream &os) {
if (vals.empty())