mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
Rebase Triton to LLVM-15. (#1070)
This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are mechanical, except for the analysis framework changes.
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
cmake_minimum_required(VERSION 3.6)
|
||||
|
||||
cmake_policy(SET CMP0116 OLD)
|
||||
|
||||
include(ExternalProject)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -155,7 +158,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# # Triton
|
||||
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
|
||||
@@ -212,7 +214,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
||||
# optimizations
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRLLVMIR
|
||||
MLIRLLVMDialect
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRExecutionEngine
|
||||
|
||||
@@ -48,7 +48,7 @@ llvm_update_compile_flags(triton-translate)
|
||||
# MLIR core
|
||||
MLIROptLib
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLLVMDialect
|
||||
MLIRPass
|
||||
MLIRSupport
|
||||
MLIRTransforms
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Process.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/WithColor.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cmath>
|
||||
@@ -360,6 +361,8 @@ static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
|
||||
return "bad-not";
|
||||
case Check::CheckBadCount:
|
||||
return "bad-count";
|
||||
case Check::CheckMisspelled:
|
||||
return "misspelled";
|
||||
case Check::CheckNone:
|
||||
llvm_unreachable("invalid FileCheckType");
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
@@ -33,8 +33,8 @@ int main(int argc, char **argv) {
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
|
||||
mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
|
||||
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
|
||||
|
||||
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Parser/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
@@ -38,7 +38,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
|
||||
mlir::math::MathDialect, arith::ArithmeticDialect,
|
||||
StandardOpsDialect, scf::SCFDialect>();
|
||||
scf::SCFDialect>();
|
||||
|
||||
context.appendDialectRegistry(registry);
|
||||
|
||||
@@ -50,7 +50,8 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
||||
context.loadAllAvailableDialects();
|
||||
context.allowUnregisteredDialects();
|
||||
|
||||
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
|
||||
OwningOpRef<ModuleOp> module =
|
||||
parseSourceFile<ModuleOp>(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Parse MLIR file failed.";
|
||||
return nullptr;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#define TRITON_ANALYSIS_ALIAS_H
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -21,7 +21,7 @@ public:
|
||||
}
|
||||
|
||||
/// The pessimistic value state of a value without alias
|
||||
static AliasInfo getPessimisticValueState(MLIRContext *context) {
|
||||
static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) {
|
||||
return AliasInfo();
|
||||
}
|
||||
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
|
||||
@@ -29,6 +29,10 @@ public:
|
||||
/// The union of both arguments
|
||||
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
|
||||
|
||||
void print(raw_ostream &os) const {
|
||||
llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); });
|
||||
}
|
||||
|
||||
private:
|
||||
/// The set of allocated values that are aliased by this lattice.
|
||||
/// For now, we only consider aliased value produced by the following
|
||||
@@ -58,9 +62,13 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Alias Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
|
||||
class SharedMemoryAliasAnalysis
|
||||
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
|
||||
using dataflow::SparseDataFlowAnalysis<
|
||||
dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
|
||||
using dataflow::SparseDataFlowAnalysis<
|
||||
dataflow::Lattice<AliasInfo>>::getLatticeElement;
|
||||
|
||||
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
|
||||
/// Given two values, returns their aliasing behavior.
|
||||
@@ -70,9 +78,10 @@ public:
|
||||
ModRefResult getModRef(Operation *op, Value location);
|
||||
|
||||
/// Computes if the alloc set of the results are changed.
|
||||
ChangeResult
|
||||
void
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
|
||||
ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -188,6 +188,8 @@ private:
|
||||
friend class triton::AllocationAnalysis;
|
||||
};
|
||||
|
||||
template <typename T> Interval(T, T) -> Interval<T>;
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_ALLOCATION_H
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
||||
#define TRITON_ANALYSIS_AXISINFO_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
@@ -62,7 +63,7 @@ public:
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context) {
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) {
|
||||
return AxisInfo();
|
||||
}
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
@@ -70,6 +71,22 @@ public:
|
||||
/// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
||||
|
||||
void print(raw_ostream &os) const {
|
||||
auto print = [&](StringRef name, DimVectorT vec) {
|
||||
os << name << " = [";
|
||||
llvm::interleaveComma(vec, os);
|
||||
os << "]";
|
||||
};
|
||||
print("contiguity", contiguity);
|
||||
print(", divisibility", divisibility);
|
||||
print(", constancy", constancy);
|
||||
os << ", constant_value = ";
|
||||
if (constantValue)
|
||||
os << *constantValue;
|
||||
else
|
||||
os << "<none>";
|
||||
}
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
@@ -147,7 +164,8 @@ public:
|
||||
}
|
||||
|
||||
virtual AxisInfo
|
||||
getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0;
|
||||
getAxisInfo(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
|
||||
|
||||
virtual bool match(Operation *op) = 0;
|
||||
};
|
||||
@@ -157,15 +175,16 @@ template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
|
||||
public:
|
||||
using AxisInfoVisitor::AxisInfoVisitor;
|
||||
|
||||
AxisInfo getAxisInfo(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) final {
|
||||
AxisInfo
|
||||
getAxisInfo(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) final {
|
||||
return getAxisInfo(cast<OpTy>(op), operands);
|
||||
}
|
||||
|
||||
bool match(Operation *op) final { return isa<OpTy>(op); }
|
||||
|
||||
virtual AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
virtual AxisInfo
|
||||
getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
|
||||
llvm_unreachable("Unimplemented getAxisInfo");
|
||||
}
|
||||
};
|
||||
@@ -176,8 +195,9 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto rank = lhsInfo.getRank();
|
||||
@@ -230,7 +250,8 @@ public:
|
||||
(visitors.emplace_back(std::make_unique<Ts>()), ...);
|
||||
}
|
||||
|
||||
AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo apply(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
|
||||
for (auto &visitor : visitors)
|
||||
if (visitor->match(op))
|
||||
return visitor->getAxisInfo(op, operands);
|
||||
@@ -241,16 +262,19 @@ private:
|
||||
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
|
||||
};
|
||||
|
||||
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
class AxisInfoAnalysis
|
||||
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
|
||||
private:
|
||||
AxisInfoVisitorList visitors;
|
||||
|
||||
public:
|
||||
AxisInfoAnalysis(MLIRContext *context);
|
||||
AxisInfoAnalysis(DataFlowSolver &solver);
|
||||
using dataflow::SparseDataFlowAnalysis<
|
||||
dataflow::Lattice<AxisInfo>>::getLatticeElement;
|
||||
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
void visitOperation(Operation *op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
|
||||
|
||||
unsigned getPtrContiguity(Value ptr);
|
||||
|
||||
@@ -261,4 +285,4 @@ public:
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITON_ANALYSIS_UTILITY_H
|
||||
#define TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
@@ -12,7 +13,7 @@ namespace mlir {
|
||||
class ReduceOpHelper {
|
||||
public:
|
||||
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
|
||||
srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
srcTy = op.getOperand().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
|
||||
@@ -103,6 +104,9 @@ SetVector<Operation *>
|
||||
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
|
||||
TransitiveFilter forwardFilter = nullptr);
|
||||
|
||||
// Create a basic DataFlowSolver with constant and dead code analysis included.
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
@@ -12,7 +12,6 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::math::MathDialect",
|
||||
"mlir::StandardOpsDialect",
|
||||
// TODO: Does this pass depend on SCF?
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
@@ -41,8 +40,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
"mlir::tensor::TensorDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::NVVM::NVVMDialect",
|
||||
"mlir::StandardOpsDialect"];
|
||||
"mlir::NVVM::NVVMDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
|
||||
@@ -25,12 +25,9 @@ def Triton_Dialect : Dialect {
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"math::MathDialect",
|
||||
"StandardOpsDialect",
|
||||
"scf::SCFDialect",
|
||||
|
||||
// Since LLVM 15
|
||||
// "cf::ControlFlowDialect",
|
||||
// "func::FuncDialect"
|
||||
"cf::ControlFlowDialect",
|
||||
"func::FuncDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
@@ -38,6 +35,7 @@ def Triton_Dialect : Dialect {
|
||||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
|
||||
@@ -141,11 +141,7 @@ def TT_LoadOp : TT_Op<"load",
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
@@ -170,11 +166,7 @@ def TT_StoreOp : TT_Op<"store",
|
||||
"triton::EvictionPolicy":$evict)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "operands attr-dict `:` type($value)";
|
||||
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITON_TYPES
|
||||
#define TRITON_TYPES
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
|
||||
//
|
||||
@@ -58,6 +59,7 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
|
||||
@@ -16,8 +16,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
||||
|
||||
let constructor = "mlir::triton::createCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect"];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITONGPU_ATTRDEFS
|
||||
#define TRITONGPU_ATTRDEFS
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
@@ -136,6 +137,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -273,6 +275,7 @@ for
|
||||
// ArrayRefParameter<"unsigned">:$sizePerCTA
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -422,6 +425,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
static constexpr int numBitsToHoldMmaV1ID{5};
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
@@ -456,6 +460,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
template<class T>
|
||||
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
|
||||
@@ -492,6 +498,7 @@ section 9.7.13.4.1 for more details.
|
||||
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ def TritonGPU_Dialect : Dialect {
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -59,7 +59,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
||||
// This is needed because these ops don't
|
||||
// handle encodings
|
||||
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "integer comparison operation";
|
||||
@@ -73,7 +73,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
||||
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "floating-point comparison operation";
|
||||
@@ -88,8 +88,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
||||
}
|
||||
|
||||
// TODO: migrate to arith::SelectOp on LLVM16
|
||||
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "select operation";
|
||||
|
||||
@@ -188,10 +188,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
}
|
||||
}];
|
||||
|
||||
// The custom parser could be replaced with oilist in LLVM-16
|
||||
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
||||
|
||||
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
|
||||
|
||||
@@ -18,8 +18,9 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) {
|
||||
void SharedMemoryAliasAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
|
||||
AliasInfo aliasInfo;
|
||||
bool pessimistic = true;
|
||||
if (maybeSharedAllocationOp(op)) {
|
||||
@@ -44,14 +45,11 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
||||
}
|
||||
|
||||
if (pessimistic) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
return markAllPessimisticFixpoint(results);
|
||||
}
|
||||
// Join all lattice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(aliasInfo);
|
||||
}
|
||||
return result;
|
||||
for (auto *result : results)
|
||||
propagateIfChanged(result, result->join(aliasInfo));
|
||||
}
|
||||
|
||||
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
@@ -33,10 +34,8 @@ constexpr int kPtrBitWidth = 64;
|
||||
|
||||
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
|
||||
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
|
||||
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
|
||||
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
|
||||
assert(!(srcMmaLayout && dstMmaLayout) &&
|
||||
@@ -224,14 +223,12 @@ private:
|
||||
}
|
||||
|
||||
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
|
||||
LatticeElement<AliasInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(value);
|
||||
if (latticeElement) {
|
||||
auto &info = latticeElement->getValue();
|
||||
if (!info.getAllocs().empty()) {
|
||||
for (auto alloc : info.getAllocs()) {
|
||||
allocation->addAlias(value, alloc);
|
||||
}
|
||||
dataflow::Lattice<AliasInfo> *latticeElement =
|
||||
analysis.getLatticeElement(value);
|
||||
if (latticeElement && !latticeElement->isUninitialized()) {
|
||||
AliasInfo &info = latticeElement->getValue();
|
||||
for (auto alloc : info.getAllocs()) {
|
||||
allocation->addAlias(value, alloc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -244,14 +241,19 @@ private:
|
||||
getScratchValueSize(op);
|
||||
});
|
||||
// Get the alias values
|
||||
SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
|
||||
aliasAnalysis.run(operation);
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
SharedMemoryAliasAnalysis *aliasAnalysis =
|
||||
solver->load<SharedMemoryAliasAnalysis>();
|
||||
if (failed(solver->initializeAndRun(operation))) {
|
||||
// TODO: return error instead of bailing out..
|
||||
llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
|
||||
}
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
getValueAlias(operand, aliasAnalysis);
|
||||
getValueAlias(operand, *aliasAnalysis);
|
||||
}
|
||||
for (auto value : op->getResults()) {
|
||||
getValueAlias(value, aliasAnalysis);
|
||||
getValueAlias(value, *aliasAnalysis);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
@@ -52,7 +52,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
|
||||
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
|
||||
Operation *op = blockArg.getOwner()->getParentOp();
|
||||
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
|
||||
if (func::FuncOp fun = dyn_cast<func::FuncOp>(op)) {
|
||||
Attribute attr =
|
||||
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if (attr)
|
||||
@@ -136,8 +136,9 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
return operands[0]->getValue();
|
||||
}
|
||||
};
|
||||
@@ -147,8 +148,9 @@ class MakeRangeOpAxisInfoVisitor final
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::MakeRangeOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(triton::MakeRangeOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto start = op.start();
|
||||
auto end = op.end();
|
||||
return AxisInfo(/*contiguity=*/{end - start},
|
||||
@@ -162,8 +164,9 @@ class ConstantOpAxisInfoVisitor final
|
||||
public:
|
||||
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(arith::ConstantOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(arith::ConstantOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
|
||||
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
|
||||
if (intAttr || boolAttr) {
|
||||
@@ -416,8 +419,9 @@ class SplatOpAxisInfoVisitor final
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::SplatOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(triton::SplatOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
@@ -439,8 +443,9 @@ class ExpandDimsOpAxisInfoVisitor final
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::ExpandDimsOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(triton::ExpandDimsOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
|
||||
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
|
||||
@@ -458,8 +463,9 @@ class BroadcastOpAxisInfoVisitor final
|
||||
public:
|
||||
using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(triton::BroadcastOp op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(triton::BroadcastOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -486,8 +492,9 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return AxisInfo();
|
||||
@@ -596,8 +603,9 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return AxisInfo();
|
||||
@@ -757,8 +765,9 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo getAxisInfo(OpTy op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
std::optional<int64_t> constantValue;
|
||||
@@ -786,8 +795,8 @@ public:
|
||||
// AxisInfoAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
||||
: ForwardDataFlowAnalysis<AxisInfo>(context) {
|
||||
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
||||
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
|
||||
// UnrealizedConversionCast:
|
||||
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
||||
// in the process of a PartialConversion, where UnrealizedConversionCast
|
||||
@@ -819,7 +828,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
||||
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
|
||||
LogicalOpAxisInfoVisitor<arith::OrIOp>,
|
||||
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
|
||||
visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>,
|
||||
visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>,
|
||||
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
|
||||
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
|
||||
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
|
||||
@@ -829,11 +838,12 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
||||
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
|
||||
}
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
void AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
|
||||
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
|
||||
AxisInfo curr = visitors.apply(op, operands);
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
return markAllPessimisticFixpoint(results);
|
||||
}
|
||||
// override with hint
|
||||
auto newContiguity = curr.getContiguity();
|
||||
@@ -854,11 +864,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
|
||||
curr.getConstantValue());
|
||||
// join all lattice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(curr);
|
||||
}
|
||||
return result;
|
||||
for (auto *result : results)
|
||||
propagateIfChanged(result, result->join(curr));
|
||||
}
|
||||
|
||||
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
|
||||
@@ -884,7 +891,10 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto axisInfo = lookupLatticeElement(ptr)->getValue();
|
||||
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
|
||||
if (!latticeElement || latticeElement->isUninitialized())
|
||||
return 1;
|
||||
auto axisInfo = latticeElement->getValue();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
|
||||
@@ -900,8 +910,11 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
|
||||
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
|
||||
if (!latticeElement || latticeElement->isUninitialized())
|
||||
return 1;
|
||||
auto maskAxis = latticeElement->getValue();
|
||||
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
|
||||
auto maskAxis = lookupLatticeElement(mask)->getValue();
|
||||
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
|
||||
return alignment;
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ add_mlir_library(TritonAnalysis
|
||||
DEPENDS
|
||||
TritonTableGen
|
||||
TritonGPUAttrDefsIncGen
|
||||
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAnalysis
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <deque>
|
||||
@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
|
||||
return multiRootTopologicalSort(slice);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
|
||||
// interacts with constant propagation, but SparseConstantPropagation
|
||||
// doesn't seem to be sufficient.
|
||||
struct ConstantAnalysis : public DataFlowAnalysis {
|
||||
using DataFlowAnalysis::DataFlowAnalysis;
|
||||
|
||||
LogicalResult initialize(Operation *top) override {
|
||||
WalkResult result = top->walk([&](Operation *op) {
|
||||
if (failed(visit(op)))
|
||||
return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return success(!result.wasInterrupted());
|
||||
}
|
||||
|
||||
LogicalResult visit(ProgramPoint point) override {
|
||||
Operation *op = point.get<Operation *>();
|
||||
Attribute value;
|
||||
if (matchPattern(op, m_Constant(&value))) {
|
||||
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
|
||||
op->getResult(0));
|
||||
propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
|
||||
value, op->getDialect())));
|
||||
return success();
|
||||
}
|
||||
setAllToUnknownConstants(op->getResults());
|
||||
for (Region &region : op->getRegions())
|
||||
setAllToUnknownConstants(region.getArguments());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Set all given values as not constants.
|
||||
void setAllToUnknownConstants(ValueRange values) {
|
||||
dataflow::ConstantValue unknownConstant(nullptr, nullptr);
|
||||
for (Value value : values) {
|
||||
auto *constant =
|
||||
getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
|
||||
propagateIfChanged(constant, constant->join(unknownConstant));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
|
||||
auto solver = std::make_unique<DataFlowSolver>();
|
||||
solver->load<dataflow::DeadCodeAnalysis>();
|
||||
solver->load<ConstantAnalysis>();
|
||||
return solver;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -159,9 +159,6 @@ private:
|
||||
Value smemBase) const {
|
||||
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
||||
auto layout = type.getEncoding();
|
||||
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
|
||||
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
|
||||
auto rank = type.getRank();
|
||||
auto sizePerThread = getSizePerThread(layout);
|
||||
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
||||
|
||||
@@ -7,10 +7,8 @@
|
||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
@@ -422,9 +420,9 @@ struct MMA16816ConversionHelper {
|
||||
MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
|
||||
Value thread, ConversionPatternRewriter &rewriter,
|
||||
TypeConverter *typeConverter, Location loc)
|
||||
: mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
|
||||
rewriter(rewriter), typeConverter(typeConverter), loc(loc),
|
||||
ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
|
||||
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread),
|
||||
helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter),
|
||||
loc(loc), ctx(mmaLayout.getContext()) {
|
||||
helper.deduceMmaType(dotOperand);
|
||||
|
||||
Value _32 = i32_val(32);
|
||||
|
||||
@@ -115,8 +115,6 @@ private:
|
||||
auto DTensorTy = D.getType().cast<RankedTensorType>();
|
||||
auto AShape = ATensorTy.getShape();
|
||||
auto BShape = BTensorTy.getShape();
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
@@ -221,7 +219,6 @@ private:
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto threadId = getThreadId(rewriter, loc);
|
||||
|
||||
auto A = op.a();
|
||||
auto B = op.b();
|
||||
@@ -230,12 +227,10 @@ private:
|
||||
|
||||
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto cTensorTy = C.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = D.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShape = aTensorTy.getShape();
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto cShape = cTensorTy.getShape();
|
||||
|
||||
BlockedEncodingAttr dLayout =
|
||||
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
|
||||
@@ -61,7 +61,6 @@ struct FpToFpOpConversion
|
||||
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
auto ctx = rewriter.getContext();
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
@@ -153,7 +152,6 @@ struct FpToFpOpConversion
|
||||
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
auto ctx = rewriter.getContext();
|
||||
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
||||
Value bf16x2Vec0 = undef(bf16x2VecTy);
|
||||
Value bf16x2Vec1 = undef(bf16x2VecTy);
|
||||
|
||||
@@ -109,7 +109,8 @@ struct LoadOpConversion
|
||||
DenseElementsAttr constAttr;
|
||||
int64_t splatVal = 0;
|
||||
if (other && valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
|
||||
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() &&
|
||||
constAttr.getElementType().isa<IntegerType>()) {
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
@@ -333,7 +334,6 @@ struct StoreOpConversion
|
||||
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
|
||||
elem = bitcast(elem, valueElemTy);
|
||||
|
||||
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
|
||||
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
|
||||
}
|
||||
llWord = bitcast(llWord, valArgTy);
|
||||
@@ -387,7 +387,6 @@ struct AtomicCASOpConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
Value ptr = op.ptr();
|
||||
|
||||
Value llPtr = adaptor.ptr();
|
||||
Value llCmp = adaptor.cmp();
|
||||
|
||||
@@ -286,7 +286,6 @@ private:
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcRank = srcTy.getRank();
|
||||
auto order = getOrder(srcLayout);
|
||||
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
||||
@@ -351,7 +350,6 @@ private:
|
||||
|
||||
Value zero = i32_val(0);
|
||||
Value laneZero = icmp_eq(laneIdAxis, zero);
|
||||
Value warpZero = icmp_eq(warpIdAxis, zero);
|
||||
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
|
||||
@@ -11,11 +11,11 @@ using ::mlir::LLVM::getStructFromElements;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
@@ -476,7 +476,6 @@ struct ExtractSliceOpConversion
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
auto resTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
|
||||
strideVals, offsetVals);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
// TODO: refactor so that it doesn't fail if Allocation.h
|
||||
// is included after utility.h (due to conflict in `store` macro
|
||||
// and <atomic>
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
|
||||
//
|
||||
@@ -39,15 +40,15 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
|
||||
// All the rights are reserved by the LLVM community.
|
||||
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
|
||||
private:
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
bool filterArgAttrs,
|
||||
static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
for (const auto &attr : attrs) {
|
||||
|
||||
for (const auto &attr : op->getAttrs()) {
|
||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||
attr.getName() == "std.varargs" ||
|
||||
@@ -65,27 +66,27 @@ private:
|
||||
}
|
||||
|
||||
protected:
|
||||
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
|
||||
using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
||||
// to this legalization pattern.
|
||||
LLVM::LLVMFuncOp
|
||||
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
|
||||
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Convert the original function arguments. They are converted using the
|
||||
// LLVMTypeConverter provided to this legalization pattern.
|
||||
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
|
||||
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
||||
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
||||
funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
|
||||
result);
|
||||
if (!llvmType)
|
||||
return nullptr;
|
||||
|
||||
// Propagate argument/result attributes to all converted arguments/result
|
||||
// obtained after converting a given original argument/result.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
|
||||
attributes);
|
||||
filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes);
|
||||
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
|
||||
assert(!resAttrDicts.empty() && "expected array to be non-empty");
|
||||
auto newResAttrDicts =
|
||||
@@ -131,7 +132,7 @@ protected:
|
||||
}
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal*/ false, attributes);
|
||||
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
||||
@@ -191,8 +192,8 @@ public:
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
IndexCacheInfo indexCacheInfo)
|
||||
: converter(&typeConverter), indexCacheInfo(indexCacheInfo),
|
||||
allocation(allocation), smem(smem) {}
|
||||
: converter(&typeConverter), allocation(allocation), smem(smem),
|
||||
indexCacheInfo(indexCacheInfo) {}
|
||||
|
||||
LLVMTypeConverter *getTypeConverter() const { return converter; }
|
||||
|
||||
@@ -861,7 +862,6 @@ private:
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto parent = sliceLayout.getParent();
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
size_t rank = shape.size();
|
||||
auto parentIndices =
|
||||
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
|
||||
unsigned numIndices = parentIndices.size();
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -40,7 +41,6 @@ public:
|
||||
addIllegalDialect<triton::TritonDialect>();
|
||||
addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
addIllegalDialect<mlir::gpu::GPUDialect>();
|
||||
addIllegalDialect<mlir::StandardOpsDialect>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
@@ -51,7 +51,7 @@ public:
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
addIllegalOp<mlir::FuncOp>();
|
||||
addIllegalOp<mlir::func::FuncOp>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
@@ -69,7 +69,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
||||
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
|
||||
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
||||
if (!newFuncOp)
|
||||
@@ -133,7 +133,8 @@ public:
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
|
||||
// Step 2
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
if (failed(decomposeInsertSliceAsyncOp(mod)))
|
||||
return signalPassFailure();
|
||||
|
||||
// Step 3
|
||||
Allocation allocation(mod);
|
||||
@@ -142,7 +143,7 @@ public:
|
||||
|
||||
// Step 4
|
||||
RewritePatternSet scf_patterns(context);
|
||||
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
||||
mlir::populateSCFToControlFlowConversionPatterns(scf_patterns);
|
||||
mlir::ConversionTarget scf_target(*context);
|
||||
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
|
||||
scf::WhileOp, scf::ExecuteRegionOp>();
|
||||
@@ -159,8 +160,10 @@ public:
|
||||
return signalPassFailure();
|
||||
|
||||
// Step 6 - get axis and shared memory info
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(mod)))
|
||||
return signalPassFailure();
|
||||
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
|
||||
mod->setAttr("triton_gpu.shared",
|
||||
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
|
||||
@@ -178,38 +181,39 @@ public:
|
||||
|
||||
// Normal conversions
|
||||
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// ConvertLayoutOp
|
||||
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// DotOp
|
||||
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
// ElementwiseOp
|
||||
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
// LoadStoreOp
|
||||
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// ReduceOp
|
||||
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// ViewOp
|
||||
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
*axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
|
||||
// Add arith/math's patterns to help convert scalar expression to LLVM.
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
@@ -306,9 +310,11 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(mod)))
|
||||
return failure();
|
||||
// TODO(Keren): This is a hacky knob that may cause performance regression
|
||||
// when decomposition has been performed. We should remove this knob once we
|
||||
// have thorough analysis on async wait. Currently, we decompose
|
||||
@@ -342,7 +348,7 @@ private:
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
|
||||
unsigned inVec = axisInfoAnalysis->getPtrContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
@@ -400,11 +406,11 @@ private:
|
||||
} else if (decomposed) {
|
||||
// Wait for all previous async ops
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
auto newAsyncWaitOp =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -220,10 +220,7 @@ struct SharedMemoryObject {
|
||||
ConversionPatternRewriter &rewriter)
|
||||
: base(base) {
|
||||
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
|
||||
|
||||
for (auto idx : order) {
|
||||
offsets.emplace_back(i32_val(0));
|
||||
}
|
||||
offsets.append(order.size(), i32_val(0));
|
||||
}
|
||||
|
||||
SmallVector<Value> getElems() const {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
@@ -59,10 +59,13 @@ public:
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
assert(value);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retType,
|
||||
value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
);
|
||||
if (value.getElementType().isInteger(1) && value.isSplat())
|
||||
// Workaround until https://reviews.llvm.org/D133743 is included.
|
||||
value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
|
||||
else
|
||||
// This is a hack. We just want to add encoding
|
||||
value = value.reshape(retType);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -127,12 +130,12 @@ void populateArithmeticPatternsAndLegality(
|
||||
}
|
||||
|
||||
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
||||
class StdSelectPattern : public OpConversionPattern<SelectOp> {
|
||||
class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
|
||||
public:
|
||||
using OpConversionPattern<SelectOp>::OpConversionPattern;
|
||||
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
|
||||
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
@@ -148,8 +151,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Rewrite rule
|
||||
patterns.add<StdSelectPattern>(typeConverter, context);
|
||||
target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
|
||||
// by the frontend
|
||||
target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are
|
||||
// inlined by the frontend
|
||||
}
|
||||
|
||||
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
@@ -455,18 +458,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::FpToFpOp>,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
|
||||
TritonAtomicRMWPattern>(typeConverter, context);
|
||||
patterns
|
||||
.insert< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::FpToFpOp>,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
|
||||
TritonAtomicRMWPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -623,29 +627,28 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
|
||||
// CF
|
||||
|
||||
class CFBranchPattern : public OpConversionPattern<BranchOp> {
|
||||
class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
|
||||
public:
|
||||
using OpConversionPattern<BranchOp>::OpConversionPattern;
|
||||
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(BranchOp op, BranchOp::Adaptor adaptor,
|
||||
matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<BranchOp>(op, op.getSuccessor(),
|
||||
adaptor.getOperands());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
|
||||
op, op.getSuccessor(), adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class CFCondBranchPattern : public OpConversionPattern<CondBranchOp> {
|
||||
class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> {
|
||||
public:
|
||||
using OpConversionPattern<CondBranchOp>::OpConversionPattern;
|
||||
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CondBranchOp op, CondBranchOp::Adaptor adaptor,
|
||||
matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<CondBranchOp>(
|
||||
auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
|
||||
op, adaptor.getCondition(), op.getTrueDest(),
|
||||
adaptor.getTrueDestOperands(), op.getFalseDest(),
|
||||
adaptor.getFalseDestOperands());
|
||||
|
||||
@@ -10,11 +10,7 @@ add_mlir_dialect_library(TritonIR
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRArithmetic
|
||||
MLIRSCF
|
||||
|
||||
# Since LLVM 15
|
||||
# MLIRFunc
|
||||
# else
|
||||
MLIRStandard
|
||||
MLIRArithmeticDialect
|
||||
MLIRSCFDialect
|
||||
MLIRFuncDialect
|
||||
)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
@@ -38,8 +37,8 @@ static Type getPointerTypeSameShape(Type type) {
|
||||
}
|
||||
|
||||
// Parser & printer for assembly forms
|
||||
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
||||
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
|
||||
Type resultTypes[1];
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
@@ -73,18 +72,18 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
|
||||
void LoadOp::print(OpAsmPrinter &printer) {
|
||||
printer << " ";
|
||||
printer << loadOp.getOperation()->getOperands();
|
||||
printer << getOperation()->getOperands();
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(loadOp->getAttrs(),
|
||||
{loadOp.operand_segment_sizesAttrName()});
|
||||
printer.printOptionalAttrDict(getOperation()->getAttrs(),
|
||||
{operand_segment_sizesAttrName()});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(loadOp.result().getType());
|
||||
printer.printStrippedAttrOrType(getResult().getType());
|
||||
}
|
||||
|
||||
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
||||
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
|
||||
Type valueType;
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
@@ -104,12 +103,12 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
|
||||
void StoreOp::print(OpAsmPrinter &printer) {
|
||||
printer << " ";
|
||||
printer << storeOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
|
||||
printer << getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(storeOp.value().getType());
|
||||
printer.printStrippedAttrOrType(value().getType());
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
@@ -319,7 +318,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!constOperand)
|
||||
return {};
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||
auto ret = SplatElementsAttr::get(
|
||||
shapedType, ArrayRef<Attribute>(constOperand.getValue()));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@@ -57,13 +57,13 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
|
||||
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
|
||||
{triton::LoadOp::getOperationName()}) {}
|
||||
: mlir::RewritePattern(mlir::arith::SelectOp::getOperationName(), 3,
|
||||
context, {triton::LoadOp::getOperationName()}) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
|
||||
auto selectOp = llvm::dyn_cast<mlir::arith::SelectOp>(op);
|
||||
if (!selectOp)
|
||||
return mlir::failure();
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#ifndef TRITON_PATTERNS
|
||||
#define TRITON_PATTERNS
|
||||
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
|
||||
include "triton/Dialect/Triton/IR/TritonOps.td"
|
||||
include "mlir/IR/PatternBase.td"
|
||||
|
||||
|
||||
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton::gpu;
|
||||
|
||||
@@ -366,7 +366,6 @@ template SmallVector<int64_t>
|
||||
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
|
||||
|
||||
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
size_t rank = shape.size();
|
||||
auto parent = getParent();
|
||||
return ::getElemsPerThread(parent, paddedShape(shape));
|
||||
}
|
||||
@@ -655,9 +654,9 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
// InsertSliceAsyncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 8> allOperands;
|
||||
ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 8> allOperands;
|
||||
Type srcType, dstType;
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
@@ -696,18 +695,16 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
||||
InsertSliceAsyncOp insertSliceAsyncOp) {
|
||||
void InsertSliceAsyncOp::print(OpAsmPrinter &printer) {
|
||||
printer << " ";
|
||||
printer << insertSliceAsyncOp.getOperation()->getOperands();
|
||||
printer << getOperation()->getOperands();
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(
|
||||
insertSliceAsyncOp->getAttrs(),
|
||||
{insertSliceAsyncOp.operand_segment_sizesAttrName()});
|
||||
printer.printOptionalAttrDict(getOperation()->getAttrs(),
|
||||
{operand_segment_sizesAttrName()});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
|
||||
printer.printStrippedAttrOrType(src().getType());
|
||||
printer << " -> ";
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||
printer.printStrippedAttrOrType(result().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -27,7 +27,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
auto origType = ptr.getType().cast<RankedTensorType>();
|
||||
// Get the shape of the tensor.
|
||||
size_t rank = origType.getRank();
|
||||
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
|
||||
dataflow::Lattice<AxisInfo> *latticeElement =
|
||||
axisInfo.getLatticeElement(ptr);
|
||||
AxisInfo info = latticeElement && !latticeElement->isUninitialized()
|
||||
? latticeElement->getValue()
|
||||
: AxisInfo();
|
||||
// Get the contiguity order of `ptr`
|
||||
auto order = argSort(info.getContiguity());
|
||||
// The desired divisibility is the maximum divisibility
|
||||
@@ -40,7 +44,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
for (Value val : op->getResults()) {
|
||||
if (val.getType() != origType)
|
||||
continue;
|
||||
auto valInfo = axisInfo.lookupLatticeElement(val);
|
||||
auto valInfo = axisInfo.getLatticeElement(val);
|
||||
auto currOrder = argSort(valInfo->getValue().getContiguity());
|
||||
if (order == currOrder)
|
||||
withSameOrder.insert(val);
|
||||
@@ -55,7 +59,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
|
||||
unsigned perThread = 1;
|
||||
for (Value val : withSameOrder) {
|
||||
AxisInfo info = axisInfo.lookupLatticeElement(val)->getValue();
|
||||
AxisInfo info = axisInfo.getLatticeElement(val)->getValue();
|
||||
unsigned maxMultipleBytes = info.getDivisibility(order[0]);
|
||||
unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
|
||||
unsigned maxContig = info.getContiguity(order[0]);
|
||||
@@ -123,8 +127,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
// Run axis info analysis
|
||||
AxisInfoAnalysis axisInfo(&getContext());
|
||||
axisInfo.run(op);
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfo = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(op)))
|
||||
return signalPassFailure();
|
||||
|
||||
// For each i/o operation, we determine what layout
|
||||
// the pointers should have for best memory coalescing
|
||||
@@ -146,10 +152,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!ty || !ty.getElementType().isa<PointerType>())
|
||||
return;
|
||||
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
|
||||
AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue();
|
||||
auto mod = curr->getParentOfType<ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
|
||||
auto convertType = getTypeConverter(*axisInfo, ptr, numWarps);
|
||||
layoutMap[ptr] = convertType;
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
@@ -3,5 +3,6 @@
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
|
||||
include "triton/Dialect/Triton/IR/TritonOps.td"
|
||||
include "mlir/IR/PatternBase.td"
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
@@ -160,15 +161,18 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
LogicalResult LoopPipeliner::initialize() {
|
||||
Block *loop = forOp.getBody();
|
||||
|
||||
AxisInfoAnalysis axisInfoAnalysis(forOp.getContext());
|
||||
axisInfoAnalysis.run(forOp->getParentOfType<ModuleOp>());
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(forOp->getParentOfType<ModuleOp>()))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// can we use forOp.walk(...) here?
|
||||
SmallVector<triton::LoadOp, 2> allLoads;
|
||||
for (Operation &op : *loop)
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
|
||||
auto ptr = loadOp.ptr();
|
||||
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
|
||||
unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
continue;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
@@ -82,12 +82,12 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
|
||||
triton::TritonDialect, StandardOpsDialect,
|
||||
scf::SCFDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
triton::TritonDialect, scf::SCFDialect>(
|
||||
[&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
// We have requirements for the data layouts
|
||||
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -118,8 +118,8 @@ void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
|
||||
.get("value")
|
||||
.dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (attr) {
|
||||
auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer(
|
||||
newType, attr.getRawData(), true);
|
||||
auto newAttr =
|
||||
mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData());
|
||||
op->setAttr("value", newAttr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ add_mlir_translation_library(TritonLLVMIR
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRSCFToStandard
|
||||
MLIRLLVMDialect
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include <optional>
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
#include "llvm/Pass.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
|
||||
|
||||
@@ -57,19 +57,10 @@ def get_pybind11_package_info():
|
||||
def get_llvm_package_info():
|
||||
# download if nothing is installed
|
||||
system = platform.system()
|
||||
if system == "Darwin":
|
||||
system_suffix = "apple-darwin"
|
||||
elif system == "Linux":
|
||||
vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
|
||||
vglibc = vglibc[0] * 100 + vglibc[1]
|
||||
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
|
||||
system_suffix = f"linux-gnu-{linux_suffix}"
|
||||
else:
|
||||
raise RuntimeError(f"unsupported system: {system}")
|
||||
system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
|
||||
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
||||
name = f'llvm+mlir-14.0.6-x86_64-{system_suffix}-{release_suffix}'
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-14.0.6-f28c006a5895/{name}.tar.xz"
|
||||
name = 'llvm+mlir-15.0.7-x86_64-{}-{}'.format(system_suffix, "assert" if use_assert_enabled_llvm else "release")
|
||||
url = "https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-15.0.7-8dfdcc7b7bf6/{}.tar.xz".format(name)
|
||||
return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
|
||||
|
||||
|
||||
@@ -8,9 +8,10 @@
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Parser/Parser.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
@@ -195,7 +196,7 @@ void init_triton_ir(py::module &&m) {
|
||||
std::string attrName = name + "_arg" + std::to_string(id);
|
||||
mlir::Block *owner = arg.getOwner();
|
||||
if (owner->isEntryBlock() &&
|
||||
!mlir::isa<mlir::FuncOp>(owner->getParentOp())) {
|
||||
!mlir::isa<mlir::func::FuncOp>(owner->getParentOp())) {
|
||||
owner->getParentOp()->setAttr(attrName, attr);
|
||||
}
|
||||
}
|
||||
@@ -348,7 +349,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return str;
|
||||
})
|
||||
.def("push_back",
|
||||
[](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
||||
[](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
})
|
||||
.def("has_function",
|
||||
@@ -358,16 +359,18 @@ void init_triton_ir(py::module &&m) {
|
||||
return false;
|
||||
})
|
||||
.def("get_function",
|
||||
[](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
|
||||
return self.lookupSymbol<mlir::FuncOp>(funcName);
|
||||
[](mlir::ModuleOp &self,
|
||||
std::string &funcName) -> mlir::func::FuncOp {
|
||||
return self.lookupSymbol<mlir::func::FuncOp>(funcName);
|
||||
})
|
||||
.def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
|
||||
llvm::SmallVector<mlir::FuncOp> funcs;
|
||||
self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
|
||||
if (funcs.size() != 1)
|
||||
throw std::runtime_error("Expected a single function");
|
||||
return funcs[0];
|
||||
});
|
||||
.def("get_single_function",
|
||||
[](mlir::ModuleOp &self) -> mlir::func::FuncOp {
|
||||
llvm::SmallVector<mlir::func::FuncOp> funcs;
|
||||
self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); });
|
||||
if (funcs.size() != 1)
|
||||
throw std::runtime_error("Expected a single function");
|
||||
return funcs[0];
|
||||
});
|
||||
|
||||
m.def("make_attr",
|
||||
[](const std::vector<int> &values, mlir::MLIRContext &context) {
|
||||
@@ -388,47 +391,48 @@ void init_triton_ir(py::module &&m) {
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect,
|
||||
mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
|
||||
mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
|
||||
mlir::func::FuncDialect, mlir::scf::SCFDialect>();
|
||||
context.appendDialectRegistry(registry);
|
||||
context.loadAllAvailableDialects();
|
||||
|
||||
// parse module
|
||||
mlir::OwningOpRef<mlir::ModuleOp> module(
|
||||
mlir::parseSourceFile(inputFilename, &context));
|
||||
mlir::OwningOpRef<mlir::ModuleOp> module =
|
||||
mlir::parseSourceFile<mlir::ModuleOp>(inputFilename, &context);
|
||||
if (!module)
|
||||
throw std::runtime_error("Parse MLIR file failed.");
|
||||
// locations are incompatible with ptx < 7.5 !
|
||||
module->walk([](mlir::Operation *op) {
|
||||
op->setLoc(mlir::UnknownLoc::get(op->getContext()));
|
||||
});
|
||||
if (!module)
|
||||
throw std::runtime_error("Parse MLIR file failed.");
|
||||
|
||||
return module->clone();
|
||||
},
|
||||
ret::take_ownership);
|
||||
|
||||
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
|
||||
py::class_<mlir::func::FuncOp, mlir::OpState>(m, "function")
|
||||
// .def_property_readonly("attrs", &ir::function::attrs)
|
||||
// .def("add_attr", &ir::function::add_attr);
|
||||
.def("args",
|
||||
[](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
||||
[](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
||||
return self.getArgument(idx);
|
||||
})
|
||||
.def(
|
||||
"add_entry_block",
|
||||
[](mlir::FuncOp &self) -> mlir::Block * {
|
||||
[](mlir::func::FuncOp &self) -> mlir::Block * {
|
||||
return self.addEntryBlock();
|
||||
},
|
||||
ret::reference)
|
||||
.def(
|
||||
"set_arg_attr",
|
||||
[](mlir::FuncOp &self, int arg_no, const std::string &name, int val) {
|
||||
[](mlir::func::FuncOp &self, int arg_no, const std::string &name,
|
||||
int val) {
|
||||
// set arg attributes "name" to value "val"
|
||||
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
||||
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||
},
|
||||
ret::reference)
|
||||
.def_property_readonly("type", &mlir::FuncOp::getType)
|
||||
.def("reset_type", &mlir::FuncOp::setType);
|
||||
.def_property_readonly("type", &mlir::func::FuncOp::getFunctionType)
|
||||
.def("reset_type", &mlir::func::FuncOp::setType);
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
|
||||
@@ -445,13 +449,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("ret",
|
||||
[](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::ReturnOp>(loc, vals);
|
||||
self.create<mlir::func::ReturnOp>(loc, vals);
|
||||
})
|
||||
.def("call",
|
||||
[](mlir::OpBuilder &self, mlir::FuncOp &func,
|
||||
[](mlir::OpBuilder &self, mlir::func::FuncOp &func,
|
||||
std::vector<mlir::Value> &args) -> mlir::OpState {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::CallOp>(loc, func, args);
|
||||
return self.create<mlir::func::CallOp>(loc, func, args);
|
||||
})
|
||||
// insertion block/point
|
||||
.def("set_insertion_point_to_start",
|
||||
@@ -618,15 +622,16 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_or_insert_function",
|
||||
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
||||
std::string &funcName, mlir::Type &funcType,
|
||||
std::string &visibility) -> mlir::FuncOp {
|
||||
std::string &visibility) -> mlir::func::FuncOp {
|
||||
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
||||
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
||||
return llvm::dyn_cast<mlir::func::FuncOp>(funcOperation);
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
llvm::SmallVector<mlir::NamedAttribute> attrs = {
|
||||
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
||||
self.getStringAttr(visibility))};
|
||||
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
|
||||
return self.create<mlir::func::FuncOp>(loc, funcName, funcTy,
|
||||
attrs);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
@@ -658,15 +663,15 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::Value condition,
|
||||
mlir::Block *trueDest, mlir::Block *falseDest) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::CondBranchOp>(loc, condition, trueDest,
|
||||
falseDest);
|
||||
self.create<mlir::cf::CondBranchOp>(loc, condition, trueDest,
|
||||
falseDest);
|
||||
return;
|
||||
})
|
||||
.def("create_branch",
|
||||
[](mlir::OpBuilder &self, mlir::Block *dest,
|
||||
std::vector<mlir::Value> &args) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::BranchOp>(loc, dest, args);
|
||||
self.create<mlir::cf::BranchOp>(loc, dest, args);
|
||||
return;
|
||||
})
|
||||
// Structured control flow
|
||||
@@ -792,14 +797,14 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_to_index",
|
||||
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::IndexCastOp>(loc, input,
|
||||
self.getIndexType());
|
||||
return self.create<mlir::arith::IndexCastOp>(
|
||||
loc, self.getIndexType(), input);
|
||||
})
|
||||
.def("create_index_to_si",
|
||||
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::IndexCastOp>(loc, input,
|
||||
self.getI32Type());
|
||||
return self.create<mlir::arith::IndexCastOp>(
|
||||
loc, self.getI32Type(), input);
|
||||
})
|
||||
.def("create_fmul",
|
||||
[](mlir::OpBuilder &self, mlir::Value &lhs,
|
||||
@@ -1316,8 +1321,8 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::Value &condition,
|
||||
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::SelectOp>(loc, condition, trueValue,
|
||||
falseValue);
|
||||
return self.create<mlir::arith::SelectOp>(loc, condition,
|
||||
trueValue, falseValue);
|
||||
})
|
||||
.def("create_printf",
|
||||
[](mlir::OpBuilder &self, const std::string &prefix,
|
||||
@@ -1429,7 +1434,7 @@ void init_triton_ir(py::module &&m) {
|
||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||
})
|
||||
.def("add_scf_to_cfg", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createLowerToCFGPass());
|
||||
self.addPass(mlir::createConvertSCFToCFPass());
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1918,7 +1918,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
func.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
|
||||
|
||||
@@ -1514,14 +1514,14 @@ def make_hash(fn, **kwargs):
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
# and any following whitespace
|
||||
# - (public\s+)? : optionally match the keyword public and any following whitespace
|
||||
# - (@\w+) : match an @ symbol followed by one or more word characters
|
||||
# (letters, digits, or underscores), and capture it as group 1 (the function name)
|
||||
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
||||
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
||||
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
||||
prototype_pattern = {
|
||||
"ttir": mlir_prototype_pattern,
|
||||
|
||||
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
Executable file
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
Executable file
Binary file not shown.
@@ -11,7 +11,7 @@
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any aliasing with the dot op encoding.
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
@@ -36,7 +36,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func @alloc(%A : !tt.ptr<f16>) {
|
||||
func.func @alloc(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
@@ -46,7 +46,7 @@ func @alloc(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: convert
|
||||
func @convert(%A : !tt.ptr<f16>) {
|
||||
func.func @convert(%A : !tt.ptr<f16>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
// CHECK: %0 -> %0
|
||||
%cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
||||
@@ -54,7 +54,7 @@ func @convert(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func @trans(%A : !tt.ptr<f16>) {
|
||||
func.func @trans(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
// CHECK: %0 -> %cst
|
||||
@@ -63,7 +63,7 @@ func @trans(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async
|
||||
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -76,7 +76,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice
|
||||
func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -90,7 +90,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extract_slice
|
||||
func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
func.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
@@ -100,7 +100,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_cat
|
||||
func @if_cat(%i1 : i1) {
|
||||
func.func @if_cat(%i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: %cst_0 -> %cst_0
|
||||
@@ -119,7 +119,7 @@ func @if_cat(%i1 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_alias
|
||||
func @if_alias(%i1 : i1) {
|
||||
func.func @if_alias(%i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -134,7 +134,7 @@ func @if_alias(%i1 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for
|
||||
func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -154,7 +154,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_if
|
||||
func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -180,7 +180,7 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !t
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_if_for
|
||||
func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
|
||||
@@ -1,288 +1,288 @@
|
||||
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
|
||||
// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: cast
|
||||
func @cast() {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
|
||||
// CHECK-LABEL: @cast
|
||||
func.func @cast() {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%cst = arith.constant 1 : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%0 = arith.extsi %cst : i32 to i64
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: add
|
||||
func @add() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @add
|
||||
func.func @add() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%2 = arith.addi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127
|
||||
%3 = arith.constant dense<127> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%4 = arith.addi %1, %3 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: sub
|
||||
func @sub() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @sub
|
||||
func.func @sub() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%2 = arith.subi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129
|
||||
%3 = arith.constant dense<129> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%4 = arith.subi %3, %1 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: mul
|
||||
func @mul() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @mul
|
||||
func.func @mul() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%2 = arith.muli %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%3 = arith.constant dense<128> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%4 = arith.muli %3, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2
|
||||
%5 = arith.constant dense<2> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256
|
||||
%6 = arith.muli %4, %5 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: div
|
||||
func @div() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @div
|
||||
func.func @div() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%3 = arith.divui %1, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%4 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
||||
%5 = arith.divsi %0, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%6 = arith.divsi %4, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%7 = arith.divsi %4, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66
|
||||
%8 = arith.constant dense<66> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [2] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = <none>
|
||||
%9 = arith.divui %0, %8 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [8192] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
|
||||
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [64] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
|
||||
%11 = arith.divsi %10, %4 : tensor<128xi32>
|
||||
return
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: rem
|
||||
func @rem() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @rem
|
||||
func.func @rem() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
%2 = arith.remsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%3 = arith.remui %1, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%4 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = <none>
|
||||
%5 = arith.remsi %0, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none>
|
||||
%6 = arith.remsi %4, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66
|
||||
%7 = arith.constant dense<66> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none>
|
||||
%8 = arith.remui %0, %7 : tensor<128xi32>
|
||||
return
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: broadcast
|
||||
func @broadcast() {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
// CHECK-LABEL: @broadcast
|
||||
func.func @broadcast() {
|
||||
// CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%0 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
|
||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
|
||||
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: splat
|
||||
func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @splat
|
||||
func.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: cmp
|
||||
func @cmp() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @cmp
|
||||
func.func @cmp() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
%1 = arith.constant dense<0> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%4 = arith.cmpi sle, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%5 = arith.cmpi sge, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
%6 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0
|
||||
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: logic
|
||||
func @logic() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @logic
|
||||
func.func @logic() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%1 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
||||
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
%3 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
|
||||
%4 = arith.divsi %0, %3 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%5 = arith.andi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%6 = arith.ori %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%7 = arith.xori %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%8 = arith.andi %2, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%9 = arith.ori %2, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%10 = arith.xori %2, %4 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: select
|
||||
func @select() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @select
|
||||
func.func @select() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
%1 = arith.constant dense<0> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
|
||||
%4 = arith.constant 0 : i1
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
%7 = tt.splat %4 : (i1) -> tensor<128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
||||
%5 = select %4, %3, %7 : tensor<128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
%5 = arith.select %4, %3, %7 : tensor<128xi1>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @shift() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
func.func @shift() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
%1 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
|
||||
%2 = arith.constant dense<4> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = <none>
|
||||
%3 = arith.shli %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = <none>
|
||||
%4 = arith.shrsi %0, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%5 = arith.shli %1, %2 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @max_min() {
|
||||
// CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
func.func @max_min() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = <none>
|
||||
%1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%2 = arith.maxsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%3 = arith.minsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
%4 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
|
||||
%5 = arith.constant dense<4> : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8
|
||||
%6 = arith.maxsi %4, %5 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: for
|
||||
func @for() {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0]
|
||||
// CHECK-LABEL: @for
|
||||
func.func @for() {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0
|
||||
%a_init = arith.constant dense<0> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1
|
||||
%b_init = arith.constant dense<1> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
|
||||
%c_init = arith.constant dense<4> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
||||
%ub = arith.constant 128 : index
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
|
||||
%lb = arith.constant 0 : index
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
|
||||
%step = arith.constant 16 : index
|
||||
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
|
||||
%t = arith.index_cast %iv : index to i32
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
|
||||
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
|
||||
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
|
||||
// CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
|
||||
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
|
||||
}
|
||||
return
|
||||
@@ -290,53 +290,53 @@ func @for() {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: permute_2d
|
||||
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1]
|
||||
// CHECK-LABEL: @permute_2d
|
||||
func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1
|
||||
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
||||
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = <none>
|
||||
%4 = arith.muli %2, %3 : tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
|
||||
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
|
||||
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
|
||||
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = <none>
|
||||
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
|
||||
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
|
||||
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
|
||||
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
|
||||
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
|
||||
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = <none>
|
||||
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
|
||||
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
return
|
||||
@@ -347,29 +347,29 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
module {
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment. It simply stores a constant to a buffer.
|
||||
// CHECK-LABEL: store_constant_align
|
||||
func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-LABEL: @store_constant_align
|
||||
func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
|
||||
%1 = arith.muli %pid, %c128_i32 : i32
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = <none>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value = <none>
|
||||
%4 = arith.addi %3, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none>
|
||||
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value = <none>
|
||||
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none>
|
||||
%9 = tt.splat %n : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%cst = arith.constant dense<0.0> : tensor<128xf32>
|
||||
tt.store %5, %cst, %mask : tensor<128xf32>
|
||||
return
|
||||
@@ -381,8 +381,8 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
|
||||
// CHECK-LABEL: vecadd_mask_align_16
|
||||
func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-LABEL: @vecadd_mask_align_16
|
||||
func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
@@ -394,13 +394,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
// CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
|
||||
// CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
tt.store %15, %13, %mask : tensor<64xf32>
|
||||
return
|
||||
@@ -410,8 +410,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note: there is no divisibility hint for %n_elements, so Triton should assume its divisibility to be 1 by default.
|
||||
// CHECK-LABEL: vecadd_mask_align_1
|
||||
func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
// CHECK-LABEL: @vecadd_mask_align_1
|
||||
func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
@@ -423,7 +423,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
// CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
||||
@@ -46,7 +46,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
|
||||
// Shared memory is available after a tensor's liveness range ends
|
||||
// CHECK-LABEL: reusable
|
||||
func @reusable(%A : !tt.ptr<f16>) {
|
||||
func.func @reusable(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
|
||||
@@ -78,7 +78,7 @@ func @reusable(%A : !tt.ptr<f16>) {
|
||||
// %cst1->%cst4
|
||||
// %cst3->%g->%h->%i
|
||||
// CHECK-LABEL: preallocate
|
||||
func @preallocate(%A : !tt.ptr<f16>) {
|
||||
func.func @preallocate(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 1024, size = 512
|
||||
@@ -113,7 +113,7 @@ func @preallocate(%A : !tt.ptr<f16>) {
|
||||
|
||||
// Unused tensors are immediately released
|
||||
// CHECK-LABEL: unused
|
||||
func @unused(%A : !tt.ptr<f16>) {
|
||||
func.func @unused(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 1024
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 0, size = 512
|
||||
@@ -128,7 +128,7 @@ func @unused(%A : !tt.ptr<f16>) {
|
||||
|
||||
// cst0 is alive through the entire function, it cannot be released before the end of the function
|
||||
// CHECK-LABEL: longlive
|
||||
func @longlive(%A : !tt.ptr<f16>) {
|
||||
func.func @longlive(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 512
|
||||
@@ -156,7 +156,7 @@ func @longlive(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func @alloc(%A : !tt.ptr<f16>) {
|
||||
func.func @alloc(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
@@ -167,7 +167,7 @@ func @alloc(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: scratch
|
||||
func @scratch() {
|
||||
func.func @scratch() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
// CHECK: scratch offset = 0, size = 512
|
||||
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
@@ -176,7 +176,7 @@ func @scratch() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func @trans(%A : !tt.ptr<f16>) {
|
||||
func.func @trans(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 1024
|
||||
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
||||
@@ -184,7 +184,7 @@ func @trans(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async
|
||||
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -197,7 +197,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extract_slice
|
||||
func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
func.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
@@ -209,7 +209,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// B0 -> (B1) -> B0
|
||||
// Memory used by B1 can be reused by B0.
|
||||
// CHECK-LABEL: if
|
||||
func @if(%i1 : i1) {
|
||||
func.func @if(%i1 : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 512
|
||||
@@ -233,7 +233,7 @@ func @if(%i1 : i1) {
|
||||
// B0 -> (B1) -> (B2) -> B0
|
||||
// Memory used by B0 cannot be reused by B1 or B2.
|
||||
// CHECK-LABEL: if_else
|
||||
func @if_else(%i1 : i1) {
|
||||
func.func @if_else(%i1 : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 512
|
||||
@@ -260,7 +260,7 @@ func @if_else(%i1 : i1) {
|
||||
// Block arguments and yields are memory aliases that do not trigger a new
|
||||
// allocation.
|
||||
// CHECK-LABEL: for
|
||||
func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -275,7 +275,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_if_slice
|
||||
func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -296,7 +296,7 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
|
||||
|
||||
// c0 cannot be released in the loop
|
||||
// CHECK-LABEL: for_use_ancestor
|
||||
func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -316,7 +316,7 @@ func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
||||
// The liveness ranges of a_shared_init, b_shared_init, and c_shared_init span the entire function before cst2.
|
||||
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
|
||||
// CHECK-LABEL: for_if_for
|
||||
func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
|
||||
@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any membar with the dot op encoding.
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
||||
@@ -42,7 +42,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
}
|
||||
|
||||
// CHECK-LABEL: raw_single_block
|
||||
func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
func.func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
@@ -54,7 +54,7 @@ func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: war_single_block
|
||||
func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
func.func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
@@ -70,7 +70,7 @@ func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: scratch
|
||||
func @scratch() {
|
||||
func.func @scratch() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 1
|
||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
@@ -81,7 +81,7 @@ func @scratch() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: async_wait
|
||||
func @async_wait() {
|
||||
func.func @async_wait() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 1
|
||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
@@ -92,7 +92,7 @@ func @async_wait() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func @alloc() {
|
||||
func.func @alloc() {
|
||||
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 2
|
||||
@@ -101,7 +101,7 @@ func @alloc() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extract_slice
|
||||
func @extract_slice() {
|
||||
func.func @extract_slice() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
@@ -113,14 +113,14 @@ func @extract_slice() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func @trans() {
|
||||
func.func @trans() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async
|
||||
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -135,7 +135,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice
|
||||
func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -153,7 +153,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
|
||||
// If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region
|
||||
// CHECK-LABEL: multi_blocks
|
||||
func @multi_blocks(%i1 : i1) {
|
||||
func.func @multi_blocks(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -174,7 +174,7 @@ func @multi_blocks(%i1 : i1) {
|
||||
|
||||
// Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region
|
||||
// CHECK-LABEL: multi_blocks_join_barrier
|
||||
func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
func.func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -192,7 +192,7 @@ func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
|
||||
// Read yielded tensor requires a barrier
|
||||
// CHECK-LABEL: multi_blocks_yield
|
||||
func @multi_blocks_yield(%i1 : i1) {
|
||||
func.func @multi_blocks_yield(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
|
||||
@@ -212,7 +212,7 @@ func @multi_blocks_yield(%i1 : i1) {
|
||||
|
||||
// Conservatively add a barrier as if the branch (%i1) is never taken
|
||||
// CHECK-LABEL: multi_blocks_noelse
|
||||
func @multi_blocks_noelse(%i1 : i1) {
|
||||
func.func @multi_blocks_noelse(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -226,7 +226,7 @@ func @multi_blocks_noelse(%i1 : i1) {
|
||||
|
||||
// Conservatively add a barrier as if the branch (%i2) is never taken
|
||||
// CHECK-LABEL: multi_blocks_nested_scf
|
||||
func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -247,7 +247,7 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for
|
||||
func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
@@ -262,7 +262,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
|
||||
// Although a_shared and b_shared are synced before entering the loop,
|
||||
// they are reassociated with aliases (c_shared) and thus require a barrier.
|
||||
// CHECK-LABEL: for_alias
|
||||
func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
@@ -282,7 +282,7 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
|
||||
// So we need a barrier both before and after cst1
|
||||
// CHECK-LABEL: for_reuse
|
||||
func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
@@ -302,7 +302,7 @@ func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
|
||||
|
||||
// CHECK-LABEL: for_reuse_nested
|
||||
func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: triton-opt %s | FileCheck %s
|
||||
|
||||
func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
// scalar -> scalar
|
||||
// CHECK: i64 -> !tt.ptr<f32>
|
||||
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
|
||||
@@ -35,7 +35,7 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
return
|
||||
}
|
||||
|
||||
func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
// scalar -> scalar
|
||||
// CHECK: !tt.ptr<f32>
|
||||
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
|
||||
@@ -54,7 +54,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
return
|
||||
}
|
||||
|
||||
func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
|
||||
func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
|
||||
// Test if Load/Store ops can handle scalar values
|
||||
%other = arith.constant 0.0e+0 : f32
|
||||
|
||||
@@ -76,7 +76,7 @@ func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ma
|
||||
return
|
||||
}
|
||||
|
||||
func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
// Test if reduce ops infer types correctly
|
||||
|
||||
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
|
||||
@@ -101,7 +101,7 @@ func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
||||
func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
||||
// Test if dot ops infer types correctly
|
||||
%v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
|
||||
%v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||
|
||||
func @ops() {
|
||||
func.func @ops() {
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||
@@ -11,7 +11,7 @@ func @ops() {
|
||||
|
||||
// -----
|
||||
|
||||
func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if LoadOp is lowered properly (see #771)
|
||||
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
%mask = arith.constant dense<true> : tensor<128xi1>
|
||||
@@ -30,7 +30,7 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
|
||||
// -----
|
||||
|
||||
func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if the total number of threadsPerWarp is 32
|
||||
// Test if the total number of warps is 2
|
||||
// CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
|
||||
@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
|
||||
// Here the 128 comes from the num-warps value (4) in the module attribute multiplied by 32
|
||||
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
|
||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
// CHECK: llvm.return
|
||||
return
|
||||
}
|
||||
@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_load
|
||||
func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK: llvm.inline_asm
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: vectorized_load
|
||||
func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ld.global.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: vectorized_load_f16
|
||||
func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
|
||||
func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ld.global.b16
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: masked_load_const_other
|
||||
func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
return
|
||||
@@ -72,7 +72,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: masked_load_const_other_vec
|
||||
func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
return
|
||||
@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_no_vec
|
||||
func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
||||
func.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -128,7 +128,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec4
|
||||
func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
func.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
// Note: %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes its divisibility is 1. This should affect the mask's alignment and further restrict the load/store ops' vector width to 1.
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
func.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
@@ -195,7 +195,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec2
|
||||
func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
||||
func.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -240,7 +240,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec8
|
||||
func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
func.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -283,7 +283,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_view_broadcast
|
||||
func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
|
||||
func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: %[[T0:.*]] = llvm.extractvalue
|
||||
// CHECK: %[[T1:.*]] = llvm.extractvalue
|
||||
@@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_make_range
|
||||
func @basic_make_range() {
|
||||
func.func @basic_make_range() {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: llvm.insertvalue
|
||||
@@ -322,7 +322,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_addf
|
||||
func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
|
||||
func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
|
||||
// CHECK: llvm.fadd
|
||||
// CHECK: llvm.fadd
|
||||
%1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
|
||||
@@ -335,7 +335,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_addi
|
||||
func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.add
|
||||
// CHECK: llvm.add
|
||||
%1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
|
||||
@@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_program_id
|
||||
func @basic_program_id() {
|
||||
func.func @basic_program_id() {
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
return
|
||||
@@ -359,7 +359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_addptr
|
||||
func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.getelementptr
|
||||
// CHECK: llvm.getelementptr
|
||||
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
@@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem
|
||||
// CHECK-LABEL: basic_alloc_tensor
|
||||
func @basic_alloc_tensor() {
|
||||
func.func @basic_alloc_tensor() {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK-NEXT: llvm.bitcast
|
||||
// CHECK-NEXT: llvm.mlir.constant
|
||||
@@ -390,7 +390,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem
|
||||
// CHECK-LABEL: basic_extract_slice
|
||||
func @basic_extract_slice() {
|
||||
func.func @basic_extract_slice() {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.extractvalue
|
||||
// CHECK-NEXT: llvm.extractvalue
|
||||
@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_async_wait
|
||||
func @basic_async_wait() {
|
||||
func.func @basic_async_wait() {
|
||||
// CHECK: cp.async.wait_group 0x4
|
||||
triton_gpu.async_wait {num = 4: i32}
|
||||
return
|
||||
@@ -442,7 +442,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_fallback
|
||||
func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
||||
func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
@@ -481,7 +481,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_v4
|
||||
func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
|
||||
func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
@@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_v1
|
||||
func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
@@ -568,7 +568,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
|
||||
func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
|
||||
@@ -619,7 +619,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: basic_splat
|
||||
func @basic_splat(%ptr: !tt.ptr<f32>) {
|
||||
func.func @basic_splat(%ptr: !tt.ptr<f32>) {
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.insertvalue
|
||||
@@ -633,7 +633,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_store
|
||||
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||
func.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -650,7 +650,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_blocked
|
||||
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
|
||||
@@ -697,7 +697,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_blocked_vec
|
||||
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
@@ -720,7 +720,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
|
||||
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
@@ -751,7 +751,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_dot
|
||||
func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -775,7 +775,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// TODO: problems in MLIR's parser on slice layout
|
||||
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// func @make_range_sliced_layout() {
|
||||
// func.func @make_range_sliced_layout() {
|
||||
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
// return
|
||||
// }
|
||||
@@ -788,7 +788,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav2_block
|
||||
func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||
func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
@@ -808,7 +808,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav1_block
|
||||
func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
||||
func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
@@ -831,7 +831,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_shared
|
||||
func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
@@ -847,7 +847,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked1d_to_slice0
|
||||
func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
||||
func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
||||
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
return
|
||||
@@ -860,7 +860,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked1d_to_slice1
|
||||
func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
||||
func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
||||
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
return
|
||||
@@ -873,7 +873,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked_to_blocked_ptr
|
||||
func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
|
||||
func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
|
||||
// CHECK: llvm.ptrtoint
|
||||
// CHECK: llvm.store
|
||||
// CHECK: nvvm.barrier0
|
||||
@@ -892,7 +892,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
|
||||
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
@@ -918,7 +918,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
|
||||
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
@@ -941,7 +941,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
|
||||
// CHECK: llvm.intr.fmuladd
|
||||
@@ -965,7 +965,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: matmul_tf32dot
|
||||
func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32
|
||||
func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
@@ -1012,7 +1012,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
||||
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
||||
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
||||
@@ -1032,7 +1032,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
func.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
// CHECK: nvvm.read.ptx.sreg.nctaid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.nctaid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.nctaid.z
|
||||
@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: test_index_cache
|
||||
func @test_index_cache() {
|
||||
func.func @test_index_cache() {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||
@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_base_index_cache
|
||||
func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||
@@ -1080,7 +1080,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_index_cache_different_block
|
||||
func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
||||
func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
scf.if %arg1 {
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
||||
// CHECK: define void @test_empty_kernel
|
||||
// CHECK: !nvvm.annotations
|
||||
// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
|
||||
// CHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_combine_dot_add_pattern
|
||||
func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
||||
// CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
|
||||
func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
||||
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
|
||||
%a = arith.constant dense<1.0> : tensor<128x128xf32>
|
||||
%b = arith.constant dense<2.0> : tensor<128x128xf32>
|
||||
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
|
||||
@@ -24,7 +24,7 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
||||
|
||||
|
||||
// COM: CHECK-LABEL: @test_combine_addptr_pattern
|
||||
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%off0 = arith.constant 10 : i32
|
||||
%off1 = arith.constant 15 : i32
|
||||
|
||||
@@ -47,46 +47,46 @@ func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_pattern
|
||||
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%0 = select %cond, %x, %false_val : tensor<8xf32>
|
||||
%0 = arith.select %cond, %x, %false_val : tensor<8xf32>
|
||||
|
||||
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%1 = select %cond, %y, %false_val : tensor<8xf32>
|
||||
%1 = arith.select %cond, %y, %false_val : tensor<8xf32>
|
||||
|
||||
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
|
||||
return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
|
||||
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%0 = select %cond0, %dummy_load, %false_val : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32>
|
||||
|
||||
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
|
||||
%real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%1 = select %cond0, %real_load0, %false_val : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32>
|
||||
|
||||
// Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized.
|
||||
%cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1>
|
||||
%real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%2 = select %cond1, %real_load1, %false_val : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
|
||||
|
||||
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
|
||||
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
|
||||
%const = arith.constant dense<1.0> : tensor<8xf32>
|
||||
%bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
|
||||
@@ -96,7 +96,7 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
|
||||
func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
@@ -117,7 +117,7 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
|
||||
func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
|
||||
@@ -130,7 +130,7 @@ func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
|
||||
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
|
||||
@@ -144,7 +144,7 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
|
||||
func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
||||
func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
||||
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
|
||||
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
tt.store %ptr, %val, %mask : tensor<8xf32>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: triton-opt %s -verify-diagnostics
|
||||
|
||||
module {
|
||||
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -43,7 +43,7 @@ module {
|
||||
}
|
||||
}
|
||||
// module {
|
||||
// func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
// %c64 = arith.constant 64 : index
|
||||
// %c32 = arith.constant 32 : index
|
||||
// %c0 = arith.constant 0 : index
|
||||
|
||||
@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
|
||||
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
|
||||
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
|
||||
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg1: i32 {tt.divisibility = 16 : i32},
|
||||
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK-LABEL: cst
|
||||
func @cst() -> tensor<1024xi32, #layout1> {
|
||||
func.func @cst() -> tensor<1024xi32, #layout1> {
|
||||
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
@@ -18,7 +18,7 @@ func @cst() -> tensor<1024xi32, #layout1> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: range
|
||||
func @range() -> tensor<1024xi32, #layout1> {
|
||||
func.func @range() -> tensor<1024xi32, #layout1> {
|
||||
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
@@ -27,7 +27,7 @@ func @range() -> tensor<1024xi32, #layout1> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: splat
|
||||
func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
%0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
@@ -36,7 +36,7 @@ func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remat
|
||||
func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
||||
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
||||
%2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
|
||||
@@ -56,7 +56,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remat_load_store
|
||||
func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
||||
@@ -70,7 +70,7 @@ func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
|
||||
// Don't rematerialize vectorized loads
|
||||
// CHECK-LABEL: remat_expensive
|
||||
func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
|
||||
@@ -85,7 +85,7 @@ func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
|
||||
// Don't rematerialize loads when original and target layouts are different
|
||||
// CHECK-LABEL: remat_multi_layout
|
||||
func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
||||
@@ -100,7 +100,7 @@ func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
|
||||
// Always rematerialize single value loads
|
||||
// CHECK-LABEL: remat_single_value
|
||||
func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
|
||||
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
@@ -111,7 +111,7 @@ func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if
|
||||
func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
@@ -128,7 +128,7 @@ func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_convert_else_not
|
||||
func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
@@ -149,7 +149,7 @@ func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_not_else_convert
|
||||
func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
@@ -170,7 +170,7 @@ func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_else_both_convert
|
||||
func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
@@ -200,7 +200,7 @@ func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
|
||||
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
|
||||
// CHECK-LABEL: transpose
|
||||
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
|
||||
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
|
||||
@@ -241,7 +241,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
||||
}
|
||||
|
||||
// CHECK-LABEL: loop
|
||||
func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
func.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
|
||||
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
|
||||
@@ -295,7 +295,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
||||
}
|
||||
|
||||
// CHECK-LABEL: vecadd
|
||||
func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
@@ -327,7 +327,7 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
|
||||
|
||||
// Select has args with different element types
|
||||
// CHECK-LABEL: select
|
||||
func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
||||
func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
|
||||
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
|
||||
@@ -378,7 +378,7 @@ func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f6
|
||||
|
||||
// Make sure the following IR doesn't hang the compiler.
|
||||
// CHECK-LABEL: long_func
|
||||
func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
||||
func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
|
||||
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
|
||||
%cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
|
||||
@@ -775,7 +775,7 @@ func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1:
|
||||
// A mnist model from torch inductor.
|
||||
// Check if topological sort is working correctly and there's no unnecessary convert
|
||||
// CHECK-LABEL: mnist
|
||||
func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
|
||||
%cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
|
||||
@@ -862,7 +862,7 @@ func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.
|
||||
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
// cmpf and cmpi have different operands and result types
|
||||
// CHECK-LABEL: cmp
|
||||
func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64 = arith.constant 64 : index
|
||||
%c2048 = arith.constant 2048 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
|
||||
// CHECK: func @matmul_loop
|
||||
// CHECK: func.func @matmul_loop
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
@@ -46,8 +46,8 @@
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
// A ptrs
|
||||
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
@@ -61,7 +61,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
|
||||
%b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
|
||||
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
|
||||
|
||||
|
||||
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
|
||||
@@ -88,7 +88,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @matmul_loop_nested
|
||||
// CHECK: func.func @matmul_loop_nested
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
@@ -118,8 +118,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||
func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
scf.for %iv0 = %lb to %ub step %step {
|
||||
// A ptrs
|
||||
@@ -134,7 +134,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
|
||||
%b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
|
||||
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
|
||||
|
||||
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
|
||||
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
|
||||
@@ -161,7 +161,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @matmul_loop_single_pipeline
|
||||
// CHECK: func.func @matmul_loop_single_pipeline
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
@@ -183,8 +183,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||
func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
// A ptrs
|
||||
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// CHECK: offset = 49152, size = 49152
|
||||
// CHECK: size = 98304
|
||||
module {
|
||||
func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
||||
func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1>
|
||||
%c64 = arith.constant 64 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
@@ -22,7 +22,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%7 = arith.muli %6, %c8_i32 : i32
|
||||
%8 = arith.subi %2, %7 : i32
|
||||
%9 = arith.cmpi slt, %8, %c8_i32 : i32
|
||||
%10 = select %9, %8, %c8_i32 : i32
|
||||
%10 = arith.select %9, %8, %c8_i32 : i32
|
||||
%11 = arith.remsi %0, %10 : i32
|
||||
%12 = arith.addi %7, %11 : i32
|
||||
%13 = arith.remsi %0, %5 : i32
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
|
||||
|
||||
// CHECK: func @matmul_loop
|
||||
// CHECK: func.func @matmul_loop
|
||||
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16]
|
||||
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
|
||||
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128]
|
||||
@@ -28,7 +28,7 @@
|
||||
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128]
|
||||
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
|
||||
module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
||||
// CHECK-LABEL: dot_mmav1
|
||||
func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
||||
func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
||||
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
|
||||
@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
||||
// CHECK-LABEL: dot_mmav1
|
||||
func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
||||
func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
||||
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
|
||||
|
||||
@@ -9,10 +9,10 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestAliasPass
|
||||
: public PassWrapper<TestAliasPass, OperationPass<FuncOp>> {
|
||||
: public PassWrapper<TestAliasPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
||||
static void print(StringRef name, SmallVector<std::string, 4> &vals,
|
||||
raw_ostream &os) {
|
||||
if (vals.empty())
|
||||
@@ -39,23 +39,24 @@ struct TestAliasPass
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
|
||||
SharedMemoryAliasAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
SharedMemoryAliasAnalysis *analysis =
|
||||
solver->load<SharedMemoryAliasAnalysis>();
|
||||
if (failed(solver->initializeAndRun(operation)))
|
||||
return signalPassFailure();
|
||||
|
||||
AsmState state(operation->getParentOfType<ModuleOp>());
|
||||
// Get operation ids of value's aliases
|
||||
auto getAllocOpNames = [&](Value value) {
|
||||
LatticeElement<AliasInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(value);
|
||||
dataflow::Lattice<AliasInfo> *latticeElement =
|
||||
analysis->getLatticeElement(value);
|
||||
SmallVector<std::string, 4> opNames;
|
||||
if (latticeElement) {
|
||||
if (latticeElement && !latticeElement->isUninitialized()) {
|
||||
auto &info = latticeElement->getValue();
|
||||
if (!info.getAllocs().empty()) {
|
||||
for (auto &alias : info.getAllocs()) {
|
||||
auto opName =
|
||||
getValueOperandName(alias.getDefiningOp()->getResult(0), state);
|
||||
opNames.push_back(std::move(opName));
|
||||
}
|
||||
for (auto &alias : info.getAllocs()) {
|
||||
auto opName =
|
||||
getValueOperandName(alias.getDefiningOp()->getResult(0), state);
|
||||
opNames.push_back(std::move(opName));
|
||||
}
|
||||
}
|
||||
// Ensure deterministic output
|
||||
|
||||
@@ -6,10 +6,9 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestAllocationPass
|
||||
: public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {
|
||||
: public PassWrapper<TestAllocationPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
|
||||
|
||||
StringRef getArgument() const final { return "test-print-allocation"; }
|
||||
StringRef getDescription() const final {
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestAxisInfoPass
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||
|
||||
void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
|
||||
os << name << ": [";
|
||||
for (size_t d = 0; d < vals.size(); d++) {
|
||||
if (d != 0)
|
||||
os << ", ";
|
||||
os << vals[d];
|
||||
}
|
||||
os << "]";
|
||||
}
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
|
||||
|
||||
StringRef getArgument() const final { return "test-print-alignment"; }
|
||||
StringRef getDescription() const final {
|
||||
@@ -30,38 +20,19 @@ struct TestAxisInfoPass
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
AxisInfoAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
os << "@" << opName << "\n";
|
||||
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(operation)))
|
||||
return signalPassFailure();
|
||||
operation->walk([&](Operation *op) {
|
||||
if (op->getNumResults() < 1)
|
||||
return;
|
||||
for (Value result : op->getResults()) {
|
||||
// std::ostringstream oss;
|
||||
// result.print(oss);
|
||||
// os << " => ";
|
||||
LatticeElement<AxisInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(result);
|
||||
if (!latticeElement) {
|
||||
os << "None\n";
|
||||
return;
|
||||
}
|
||||
AxisInfo &info = latticeElement->getValue();
|
||||
print("Contiguity", os, info.getContiguity());
|
||||
os << " ; ";
|
||||
print("Divisibility", os, info.getDivisibility());
|
||||
os << " ; ";
|
||||
print("Constancy", os, info.getConstancy());
|
||||
os << " ; ";
|
||||
auto constantValue = info.getConstantValue();
|
||||
os << "ConstantValue: [";
|
||||
if (constantValue.has_value())
|
||||
os << constantValue.value();
|
||||
else
|
||||
os << "None";
|
||||
os << "] ( ";
|
||||
result.print(os);
|
||||
os << " ) ";
|
||||
os << " => ";
|
||||
analysis->getLatticeElement(result)->getValue().print(os);
|
||||
os << "\n";
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
@@ -9,10 +9,9 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestMembarPass
|
||||
: public PassWrapper<TestMembarPass, OperationPass<FuncOp>> {
|
||||
: public PassWrapper<TestMembarPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
|
||||
|
||||
StringRef getArgument() const final { return "test-print-membar"; }
|
||||
StringRef getDescription() const final {
|
||||
|
||||
Reference in New Issue
Block a user