mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Replace Func Dialect with custom triton ops (func, call, return) (#1502)
MLIR current only supports a custom inlining interface per dialect, so we cannot change the inlining decision of `func.func`. https://discourse.llvm.org/t/avoid-inlining-some-functions-using-the-func-dialect/69830/3 Could revert it back once they've designed a better inliner interface. Inlining attributes will be implemented in the next PR since this PR is already huge.
This commit is contained in:
@@ -32,7 +32,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) {
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
|
||||
mlir::math::MathDialect, mlir::arith::ArithDialect,
|
||||
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
|
||||
mlir::gpu::GPUDialect>();
|
||||
}
|
||||
|
||||
@@ -19,15 +19,16 @@ def Triton_Dialect : Dialect {
|
||||
* Math:
|
||||
* exp, sin, cos, log, ...
|
||||
* StructuredControlFlow:
|
||||
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
|
||||
* for, if, while, yield, condition
|
||||
* ControlFlow:
|
||||
* br, cond_br
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"arith::ArithDialect",
|
||||
"math::MathDialect",
|
||||
"scf::SCFDialect",
|
||||
"cf::ControlFlowDialect",
|
||||
"func::FuncDialect"
|
||||
"cf::ControlFlowDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
@@ -6,11 +6,15 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
|
||||
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
|
||||
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
|
||||
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
|
||||
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
|
||||
|
||||
//
|
||||
// Op Base
|
||||
@@ -20,7 +24,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
}
|
||||
|
||||
//
|
||||
// CastOps
|
||||
// Cast Ops
|
||||
//
|
||||
// Use cast ops in arith:
|
||||
// bitcast
|
||||
@@ -93,7 +97,6 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
//
|
||||
// Pointer Arith Ops
|
||||
//
|
||||
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[Pure,
|
||||
SameOperandsAndResultShape,
|
||||
@@ -209,7 +212,7 @@ def TT_StoreOp : TT_Op<"store",
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic Op
|
||||
// Atomic Ops
|
||||
//
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
@@ -256,7 +259,6 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
@@ -408,7 +410,7 @@ def TT_ReduceOp : TT_Op<"reduce", [Pure,
|
||||
}
|
||||
|
||||
//
|
||||
// External elementwise op
|
||||
// External Elementwise op
|
||||
//
|
||||
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
@@ -449,7 +451,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
|
||||
}
|
||||
|
||||
//
|
||||
// Make PrintOp
|
||||
// Print Op
|
||||
//
|
||||
def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
|
||||
Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
|
||||
@@ -464,7 +466,7 @@ def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
|
||||
}
|
||||
|
||||
//
|
||||
// Make AssertOp
|
||||
// Assert Op
|
||||
//
|
||||
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
|
||||
let summary = "Device-side assert, as in CUDA for correctness checking";
|
||||
@@ -477,7 +479,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
|
||||
}
|
||||
|
||||
//
|
||||
// Make a Tensor Pointer
|
||||
// Make Tensor Pointer Op
|
||||
//
|
||||
def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
|
||||
[Pure,
|
||||
@@ -518,4 +520,199 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
|
||||
];
|
||||
}
|
||||
|
||||
// The following ops, including `call`, `func`, and `return` are copied and modified from
|
||||
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
|
||||
// We could revert it back once MLIR has a better inliner interface.
|
||||
//
|
||||
// Function Ops
|
||||
//
|
||||
def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
let summary = "call operation";
|
||||
let description = [{
|
||||
The `tt.call` operation represents a direct call to a function that is
|
||||
within the same symbol scope as the call. The operands and result types of
|
||||
the call must match the specified function type. The callee is encoded as a
|
||||
symbol reference attribute named "callee".
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
|
||||
$_state.addOperands(operands);
|
||||
$_state.addAttribute("callee", SymbolRefAttr::get(callee));
|
||||
$_state.addTypes(callee.getFunctionType().getResults());
|
||||
}]>,
|
||||
OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
$_state.addOperands(operands);
|
||||
$_state.addAttribute("callee", callee);
|
||||
$_state.addTypes(results);
|
||||
}]>,
|
||||
OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
build($_builder, $_state, SymbolRefAttr::get(callee), results, operands);
|
||||
}]>,
|
||||
OpBuilder<(ins "StringRef":$callee, "TypeRange":$results,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
|
||||
results, operands);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
FunctionType getCalleeType() {
|
||||
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() {
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
/// Return the callee of this operation.
|
||||
CallInterfaceCallable getCallableForCallee() {
|
||||
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> {
|
||||
let summary = "An operation with a name containing a single `SSACFG` region";
|
||||
let description = [{
|
||||
Operations within the function cannot implicitly capture values defined
|
||||
outside of the function, i.e. Functions are `IsolatedFromAbove`. All
|
||||
external references must use function arguments or attributes that establish
|
||||
a symbolic connection (e.g. symbols referenced by name via a string
|
||||
attribute like SymbolRefAttr). An external function declaration (used when
|
||||
referring to a function declared in some other module) has no body. While
|
||||
the MLIR textual form provides a nice inline syntax for function arguments,
|
||||
they are internally represented as “block arguments” to the first block in
|
||||
the region.
|
||||
|
||||
Only dialect attribute names may be specified in the attribute dictionaries
|
||||
for function arguments, results, or the function itself.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// External function definitions.
|
||||
tt.func @abort()
|
||||
tt.func @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64
|
||||
|
||||
// A function that returns its argument twice:
|
||||
tt.func @count(%x: i64) -> (i64, i64)
|
||||
attributes {fruit: "banana"} {
|
||||
return %x, %x: i64, i64
|
||||
}
|
||||
|
||||
// A function with an argument attribute
|
||||
tt.func @example_fn_arg(%x: i32 {swift.self = unit})
|
||||
|
||||
// A function with a result attribute
|
||||
tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})
|
||||
|
||||
// A function with an attribute
|
||||
tt.func @example_fn_attr() attributes {dialectName.attrName = false}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SymbolNameAttr:$sym_name,
|
||||
TypeAttrOf<FunctionType>:$function_type,
|
||||
OptionalAttr<StrAttr>:$sym_visibility,
|
||||
OptionalAttr<DictArrayAttr>:$arg_attrs,
|
||||
OptionalAttr<DictArrayAttr>:$res_attrs);
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
let builders = [OpBuilder<(ins
|
||||
"StringRef":$name, "FunctionType":$type,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
|
||||
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
|
||||
>];
|
||||
let extraClassDeclaration = [{
|
||||
//===------------------------------------------------------------------===//
|
||||
// CallableOpInterface
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the region on the current operation that is callable. This may
|
||||
/// return null in the case of an external callable object, e.g. an external
|
||||
/// function.
|
||||
::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
|
||||
|
||||
/// Returns the results types that the callable region produces when
|
||||
/// executed.
|
||||
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the argument types of this function.
|
||||
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
|
||||
|
||||
/// Returns the result types of this function.
|
||||
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// SymbolOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
bool isDeclaration() { return isExternal(); }
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> {
|
||||
let summary = "Function return operation";
|
||||
let description = [{
|
||||
The `tt.return` operation represents a return operation within a function.
|
||||
The operation takes variable number of operands and produces no results.
|
||||
The operand number and types must match the signature of the function
|
||||
that contains the operation.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
tt.func @foo() : (i32, f8) {
|
||||
...
|
||||
tt.return %0, %1 : i32, f8
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
|
||||
let builders = [OpBuilder<(ins), [{
|
||||
build($_builder, $_state, std::nullopt);
|
||||
}]>];
|
||||
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
||||
|
||||
@@ -77,7 +77,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
|
||||
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
|
||||
Operation *op = blockArg.getOwner()->getParentOp();
|
||||
if (auto fun = dyn_cast<func::FuncOp>(op))
|
||||
if (auto fun = dyn_cast<triton::FuncOp>(op))
|
||||
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
|
||||
&knownContiguity, &knownDivisibility,
|
||||
&knownConstancy);
|
||||
|
||||
@@ -74,7 +74,7 @@ void MembarAnalysis::visitTerminator(Operation *op,
|
||||
return;
|
||||
}
|
||||
// Otherwise, it could be a return op
|
||||
assert(isa<func::ReturnOp>(op) && "Unknown terminator");
|
||||
assert(isa<triton::ReturnOp>(op) && "Unknown terminator");
|
||||
}
|
||||
|
||||
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
|
||||
@@ -9,11 +9,11 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
// TODO: refactor so that it doesn't fail if Allocation.h
|
||||
// is included after utility.h (due to conflict in `store` macro
|
||||
// and <atomic>
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
|
||||
#include "TypeConverter.h"
|
||||
@@ -41,12 +40,12 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
|
||||
// All the rights are reserved by the LLVM community.
|
||||
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<triton::FuncOp> {
|
||||
private:
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs,
|
||||
static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
|
||||
for (const auto &attr : op->getAttrs()) {
|
||||
@@ -66,12 +65,12 @@ private:
|
||||
}
|
||||
|
||||
protected:
|
||||
using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
|
||||
using ConvertOpToLLVMPattern<triton::FuncOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
||||
// to this legalization pattern.
|
||||
LLVM::LLVMFuncOp
|
||||
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
|
||||
convertFuncOpToLLVMFuncOp(triton::FuncOp funcOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Convert the original function arguments. They are converted using the
|
||||
// LLVMTypeConverter provided to this legalization pattern.
|
||||
|
||||
@@ -51,16 +51,15 @@ public:
|
||||
} else {
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
}
|
||||
addIllegalOp<mlir::func::FuncOp>();
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
@@ -86,7 +85,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
||||
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
||||
if (!newFuncOp) {
|
||||
|
||||
@@ -163,8 +163,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Rewrite rule
|
||||
patterns.add<StdSelectPattern>(typeConverter, context);
|
||||
target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are
|
||||
// inlined by the frontend
|
||||
target.addLegalOp<triton::ReturnOp>(); // this is ok because all functions are
|
||||
// inlined by the frontend
|
||||
}
|
||||
|
||||
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
@@ -721,15 +721,15 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class FuncOpPattern : public OpConversionPattern<func::FuncOp> {
|
||||
class FuncOpPattern : public OpConversionPattern<triton::FuncOp> {
|
||||
public:
|
||||
using OpConversionPattern<func::FuncOp>::OpConversionPattern;
|
||||
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
|
||||
op, op.getName(), op.getFunctionType());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
|
||||
|
||||
@@ -12,5 +12,4 @@ add_mlir_dialect_library(TritonIR
|
||||
MLIRIR
|
||||
MLIRArithDialect
|
||||
MLIRSCFDialect
|
||||
MLIRFuncDialect
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
@@ -21,6 +22,10 @@ using namespace mlir::triton;
|
||||
namespace {
|
||||
struct TritonInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Operation *call, Operation *callable,
|
||||
bool wouldBeCloned) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
IRMapping &valueMapping) const final {
|
||||
return true;
|
||||
@@ -29,6 +34,37 @@ struct TritonInlinerInterface : public DialectInlinerInterface {
|
||||
IRMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Transformation Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Handle the given inlined terminator by replacing it with a new operation
|
||||
/// as necessary.
|
||||
void handleTerminator(Operation *op, Block *newDest) const final {
|
||||
// Only return needs to be handled here.
|
||||
auto returnOp = dyn_cast<triton::ReturnOp>(op);
|
||||
if (!returnOp)
|
||||
return;
|
||||
|
||||
// Replace the return with a branch to the dest.
|
||||
OpBuilder builder(op);
|
||||
builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
|
||||
returnOp.getOperands());
|
||||
op->erase();
|
||||
}
|
||||
|
||||
/// Handle the given inlined terminator by replacing it with a new operation
|
||||
/// as necessary.
|
||||
void handleTerminator(Operation *op,
|
||||
ArrayRef<Value> valuesToRepl) const final {
|
||||
// Only return needs to be handled here.
|
||||
auto returnOp = cast<triton::ReturnOp>(op);
|
||||
|
||||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
@@ -516,5 +518,105 @@ void MakeTensorPtrOp::build(::mlir::OpBuilder &builder,
|
||||
builder.getDenseI32ArrayAttr(order));
|
||||
}
|
||||
|
||||
// The following ops, including `call`, `func`, and `return` are copied and
|
||||
// modified from
|
||||
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp
|
||||
// We could revert it back once MLIR has a better inliner interface.
|
||||
//-- FuncOp --
|
||||
void triton::FuncOp::build(OpBuilder &builder, OperationState &state,
|
||||
StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<DictionaryAttr> argAttrs) {
|
||||
state.addAttribute(SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(name));
|
||||
state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
|
||||
state.attributes.append(attrs.begin(), attrs.end());
|
||||
state.addRegion();
|
||||
|
||||
if (argAttrs.empty())
|
||||
return;
|
||||
assert(type.getNumInputs() == argAttrs.size());
|
||||
function_interface_impl::addArgAndResultAttrs(
|
||||
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
|
||||
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
|
||||
}
|
||||
|
||||
ParseResult triton::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
auto buildFuncType =
|
||||
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
||||
function_interface_impl::VariadicFlag,
|
||||
std::string &) { return builder.getFunctionType(argTypes, results); };
|
||||
|
||||
return function_interface_impl::parseFunctionOp(
|
||||
parser, result, /*allowVariadic=*/false,
|
||||
getFunctionTypeAttrName(result.name), buildFuncType,
|
||||
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
|
||||
}
|
||||
|
||||
void triton::FuncOp::print(OpAsmPrinter &printer) {
|
||||
function_interface_impl::printFunctionOp(
|
||||
printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
|
||||
getArgAttrsAttrName(), getResAttrsAttrName());
|
||||
}
|
||||
|
||||
// -- CallOp --
|
||||
LogicalResult
|
||||
triton::CallOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
|
||||
// Check that the callee attribute was specified.
|
||||
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return emitOpError("requires a 'callee' symbol reference attribute");
|
||||
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
|
||||
if (!fn)
|
||||
return emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto fnType = fn.getFunctionType();
|
||||
if (fnType.getNumInputs() != getNumOperands())
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
||||
if (getOperand(i).getType() != fnType.getInput(i))
|
||||
return emitOpError("operand type mismatch: expected operand type ")
|
||||
<< fnType.getInput(i) << ", but provided "
|
||||
<< getOperand(i).getType() << " for operand number " << i;
|
||||
|
||||
if (fnType.getNumResults() != getNumResults())
|
||||
return emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
||||
if (getResult(i).getType() != fnType.getResult(i)) {
|
||||
auto diag = emitOpError("result type mismatch at index ") << i;
|
||||
diag.attachNote() << " op result types: " << getResultTypes();
|
||||
diag.attachNote() << "function result types: " << fnType.getResults();
|
||||
return diag;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// -- ReturnOp --
|
||||
LogicalResult triton::ReturnOp::verify() {
|
||||
auto function = cast<triton::FuncOp>((*this)->getParentOp());
|
||||
|
||||
// The operand number and types must match the function signature.
|
||||
const auto &results = function.getFunctionType().getResults();
|
||||
if (getNumOperands() != results.size())
|
||||
return emitOpError("has ")
|
||||
<< getNumOperands() << " operands, but enclosing function (@"
|
||||
<< function.getName() << ") returns " << results.size();
|
||||
|
||||
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
||||
if (getOperand(i).getType() != results[i])
|
||||
return emitError() << "type of return operand " << i << " ("
|
||||
<< getOperand(i).getType()
|
||||
<< ") doesn't match function result type ("
|
||||
<< results[i] << ")"
|
||||
<< " in function @" << function.getName();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -81,18 +81,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
|
||||
func::FuncDialect, triton::TritonDialect,
|
||||
cf::ControlFlowDialect, scf::SCFDialect>(
|
||||
[&](Operation *op) {
|
||||
bool hasLegalRegions = true;
|
||||
for (auto ®ion : op->getRegions()) {
|
||||
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion);
|
||||
}
|
||||
if (hasLegalRegions && typeConverter.isLegal(op)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
triton::TritonDialect, cf::ControlFlowDialect,
|
||||
scf::SCFDialect>([&](Operation *op) {
|
||||
bool hasLegalRegions = true;
|
||||
for (auto ®ion : op->getRegions()) {
|
||||
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion);
|
||||
}
|
||||
if (hasLegalRegions && typeConverter.isLegal(op)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// We have requirements for the data layouts
|
||||
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
||||
|
||||
@@ -208,7 +208,7 @@ void init_triton_ir(py::module &&m) {
|
||||
std::string attrName = name + "_arg" + std::to_string(id);
|
||||
mlir::Block *owner = arg.getOwner();
|
||||
if (owner->isEntryBlock() &&
|
||||
!mlir::isa<mlir::func::FuncOp>(owner->getParentOp())) {
|
||||
!mlir::isa<mlir::triton::FuncOp>(owner->getParentOp())) {
|
||||
owner->getParentOp()->setAttr(attrName, attr);
|
||||
}
|
||||
}
|
||||
@@ -361,7 +361,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return str;
|
||||
})
|
||||
.def("push_back",
|
||||
[](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void {
|
||||
[](mlir::ModuleOp &self, mlir::triton::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
})
|
||||
.def("has_function",
|
||||
@@ -372,13 +372,14 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("get_function",
|
||||
[](mlir::ModuleOp &self,
|
||||
std::string &funcName) -> mlir::func::FuncOp {
|
||||
return self.lookupSymbol<mlir::func::FuncOp>(funcName);
|
||||
std::string &funcName) -> mlir::triton::FuncOp {
|
||||
return self.lookupSymbol<mlir::triton::FuncOp>(funcName);
|
||||
})
|
||||
.def("get_single_function",
|
||||
[](mlir::ModuleOp &self) -> mlir::func::FuncOp {
|
||||
llvm::SmallVector<mlir::func::FuncOp> funcs;
|
||||
self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); });
|
||||
[](mlir::ModuleOp &self) -> mlir::triton::FuncOp {
|
||||
llvm::SmallVector<mlir::triton::FuncOp> funcs;
|
||||
self.walk(
|
||||
[&](mlir::triton::FuncOp func) { funcs.push_back(func); });
|
||||
if (funcs.size() != 1)
|
||||
throw std::runtime_error("Expected a single function");
|
||||
return funcs[0];
|
||||
@@ -400,12 +401,11 @@ void init_triton_ir(py::module &&m) {
|
||||
// initialize registry
|
||||
// note: we initialize llvm for undef
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect,
|
||||
mlir::math::MathDialect, mlir::arith::ArithDialect,
|
||||
mlir::index::IndexDialect, mlir::func::FuncDialect,
|
||||
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
|
||||
mlir::LLVM::LLVMDialect>();
|
||||
registry.insert<
|
||||
mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
|
||||
mlir::math::MathDialect, mlir::arith::ArithDialect,
|
||||
mlir::index::IndexDialect, mlir::scf::SCFDialect,
|
||||
mlir::cf::ControlFlowDialect, mlir::LLVM::LLVMDialect>();
|
||||
context.appendDialectRegistry(registry);
|
||||
context.loadAllAvailableDialects();
|
||||
|
||||
@@ -423,30 +423,30 @@ void init_triton_ir(py::module &&m) {
|
||||
},
|
||||
ret::take_ownership);
|
||||
|
||||
py::class_<mlir::func::FuncOp, mlir::OpState>(m, "function")
|
||||
py::class_<mlir::triton::FuncOp, mlir::OpState>(m, "function")
|
||||
// .def_property_readonly("attrs", &ir::function::attrs)
|
||||
// .def("add_attr", &ir::function::add_attr);
|
||||
.def("args",
|
||||
[](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
||||
[](mlir::triton::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
||||
return self.getArgument(idx);
|
||||
})
|
||||
.def(
|
||||
"add_entry_block",
|
||||
[](mlir::func::FuncOp &self) -> mlir::Block * {
|
||||
[](mlir::triton::FuncOp &self) -> mlir::Block * {
|
||||
return self.addEntryBlock();
|
||||
},
|
||||
ret::reference)
|
||||
.def(
|
||||
"set_arg_attr",
|
||||
[](mlir::func::FuncOp &self, int arg_no, const std::string &name,
|
||||
[](mlir::triton::FuncOp &self, int arg_no, const std::string &name,
|
||||
int val) {
|
||||
// set arg attributes "name" to value "val"
|
||||
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
||||
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||
},
|
||||
ret::reference)
|
||||
.def_property_readonly("type", &mlir::func::FuncOp::getFunctionType)
|
||||
.def("reset_type", &mlir::func::FuncOp::setType);
|
||||
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
|
||||
.def("reset_type", &mlir::triton::FuncOp::setType);
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
|
||||
@@ -463,13 +463,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("ret",
|
||||
[](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::func::ReturnOp>(loc, vals);
|
||||
self.create<mlir::triton::ReturnOp>(loc, vals);
|
||||
})
|
||||
.def("call",
|
||||
[](mlir::OpBuilder &self, mlir::func::FuncOp &func,
|
||||
[](mlir::OpBuilder &self, mlir::triton::FuncOp &func,
|
||||
std::vector<mlir::Value> &args) -> mlir::OpState {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::func::CallOp>(loc, func, args);
|
||||
return self.create<mlir::triton::CallOp>(loc, func, args);
|
||||
})
|
||||
// insertion block/point
|
||||
.def("set_insertion_point_to_start",
|
||||
@@ -651,16 +651,16 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_or_insert_function",
|
||||
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
||||
std::string &funcName, mlir::Type &funcType,
|
||||
std::string &visibility) -> mlir::func::FuncOp {
|
||||
std::string &visibility) -> mlir::triton::FuncOp {
|
||||
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
||||
return llvm::dyn_cast<mlir::func::FuncOp>(funcOperation);
|
||||
return llvm::dyn_cast<mlir::triton::FuncOp>(funcOperation);
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
llvm::SmallVector<mlir::NamedAttribute> attrs = {
|
||||
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
||||
self.getStringAttr(visibility))};
|
||||
return self.create<mlir::func::FuncOp>(loc, funcName, funcTy,
|
||||
attrs);
|
||||
return self.create<mlir::triton::FuncOp>(loc, funcName, funcTy,
|
||||
attrs);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
|
||||
@@ -2238,7 +2238,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
|
||||
@@ -2256,7 +2256,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
||||
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -267,14 +267,14 @@ def make_hash(fn, **kwargs):
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
# and any following whitespace
|
||||
# - (public\s+)? : optionally match the keyword public and any following whitespace
|
||||
# - (@\w+) : match an @ symbol followed by one or more word characters
|
||||
# (letters, digits, or underscores), and capture it as group 1 (the function name)
|
||||
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
||||
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
||||
mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
||||
prototype_pattern = {
|
||||
"ttir": mlir_prototype_pattern,
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any aliasing with the dot op encoding.
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
@@ -32,38 +32,38 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func.func @alloc(%A : !tt.ptr<f16>) {
|
||||
tt.func @alloc(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK: %0 -> %0
|
||||
%cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: convert
|
||||
func.func @convert(%A : !tt.ptr<f16>) {
|
||||
tt.func @convert(%A : !tt.ptr<f16>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
// CHECK: %0 -> %0
|
||||
%cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func.func @trans(%A : !tt.ptr<f16>) {
|
||||
tt.func @trans(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
// CHECK: %0 -> %cst
|
||||
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async
|
||||
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -72,11 +72,11 @@ func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%index = arith.constant 0 : i32
|
||||
// CHECK: %2 -> %cst_0
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice
|
||||
func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -86,21 +86,21 @@ func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
|
||||
// CHECK: %inserted_slice -> %cst_0
|
||||
%b = tensor.insert_slice %a into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extract_slice
|
||||
func.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
tt.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
// CHECK-NEXT: %0 -> %cst
|
||||
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_cat
|
||||
func.func @if_cat(%i1 : i1) {
|
||||
tt.func @if_cat(%i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: %cst_0 -> %cst_0
|
||||
@@ -115,11 +115,11 @@ func.func @if_cat(%i1 : i1) {
|
||||
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %b : tensor<32x16xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_alias
|
||||
func.func @if_alias(%i1 : i1) {
|
||||
tt.func @if_alias(%i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -130,11 +130,11 @@ func.func @if_alias(%i1 : i1) {
|
||||
} else {
|
||||
scf.yield %cst1 : tensor<16x16xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for
|
||||
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -150,11 +150,11 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_if
|
||||
func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -176,11 +176,11 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
}
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_for_if
|
||||
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -211,11 +211,11 @@ func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
|
||||
}
|
||||
scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cf_for
|
||||
func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) {
|
||||
tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -242,5 +242,5 @@ func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>,
|
||||
gpu.barrier
|
||||
// CHECK-NEXT: %9 -> %9
|
||||
%9 = tt.cat %0, %0 {axis = 0 : i64} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @cast
|
||||
func.func @cast() {
|
||||
tt.func @cast() {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%cst = arith.constant 1 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
@@ -10,13 +10,13 @@ func.func @cast() {
|
||||
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @add
|
||||
func.func @add() {
|
||||
tt.func @add() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
@@ -27,13 +27,13 @@ func.func @add() {
|
||||
%3 = arith.constant dense<127> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%4 = arith.addi %1, %3 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @addptr
|
||||
func.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%cst1 = arith.constant 1 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
@@ -84,13 +84,13 @@ func.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.pt
|
||||
%21 = tt.addptr %16, %12 : tensor<128x128x!tt.ptr<i32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [8, 16], constancy = [128, 1], constant_value = <none>
|
||||
%22 = tt.addptr %17, %12 : tensor<128x128x!tt.ptr<i64>>, tensor<128x128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @sub
|
||||
func.func @sub() {
|
||||
tt.func @sub() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
@@ -101,13 +101,13 @@ func.func @sub() {
|
||||
%3 = arith.constant dense<129> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%4 = arith.subi %3, %1 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul
|
||||
func.func @mul() {
|
||||
tt.func @mul() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
@@ -122,13 +122,13 @@ func.func @mul() {
|
||||
%5 = arith.constant dense<2> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256
|
||||
%6 = arith.muli %4, %5 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @div
|
||||
func.func @div() {
|
||||
tt.func @div() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
@@ -153,14 +153,14 @@ func.func @div() {
|
||||
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%11 = arith.divsi %10, %4 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @rem
|
||||
func.func @rem() {
|
||||
tt.func @rem() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
||||
@@ -179,35 +179,35 @@ func.func @rem() {
|
||||
%7 = arith.constant dense<66> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none>
|
||||
%8 = arith.remui %0, %7 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @broadcast
|
||||
func.func @broadcast() {
|
||||
tt.func @broadcast() {
|
||||
// CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%0 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
|
||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
|
||||
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @splat
|
||||
func.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @cmp
|
||||
func.func @cmp() {
|
||||
tt.func @cmp() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
@@ -226,13 +226,13 @@ func.func @cmp() {
|
||||
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0
|
||||
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @logic
|
||||
func.func @logic() {
|
||||
tt.func @logic() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
@@ -255,13 +255,13 @@ func.func @logic() {
|
||||
%9 = arith.ori %2, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%10 = arith.xori %2, %4 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @select
|
||||
func.func @select() {
|
||||
tt.func @select() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
||||
@@ -278,12 +278,12 @@ func.func @select() {
|
||||
%5 = arith.select %4, %3, %7 : tensor<128xi1>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
||||
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @shift() {
|
||||
tt.func @shift() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
@@ -296,12 +296,12 @@ func.func @shift() {
|
||||
%4 = arith.shrsi %0, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
||||
%5 = arith.shli %1, %2 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @max_min() {
|
||||
tt.func @max_min() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = <none>
|
||||
@@ -316,13 +316,13 @@ func.func @max_min() {
|
||||
%5 = arith.constant dense<4> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8
|
||||
%6 = arith.maxsi %4, %5 : tensor<128xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @for
|
||||
func.func @for() {
|
||||
tt.func @for() {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0
|
||||
%a_init = arith.constant dense<0> : tensor<128x32xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1
|
||||
@@ -343,13 +343,13 @@ func.func @for() {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
|
||||
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @permute_2d
|
||||
func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1
|
||||
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
||||
@@ -397,7 +397,7 @@ func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -406,7 +406,7 @@ module {
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
||||
// CHECK-LABEL: @store_constant_align
|
||||
func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
||||
@@ -430,7 +430,7 @@ func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%cst = arith.constant dense<0.0> : tensor<128xf32>
|
||||
tt.store %5, %cst, %mask : tensor<128xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
@@ -440,7 +440,7 @@ func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
|
||||
// CHECK-LABEL: @vecadd_mask_align_16
|
||||
func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
@@ -461,7 +461,7 @@ func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
|
||||
// CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
tt.store %15, %13, %mask : tensor<64xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -469,7 +469,7 @@ func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
|
||||
// CHECK-LABEL: @vecadd_mask_align_1
|
||||
func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
@@ -489,5 +489,5 @@ func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
tt.store %15, %13, %10 : tensor<64xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
||||
@@ -40,13 +40,13 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 4608
|
||||
}
|
||||
|
||||
// Shared memory is available after a tensor's liveness range ends
|
||||
// CHECK-LABEL: reusable
|
||||
func.func @reusable(%A : !tt.ptr<f16>) {
|
||||
tt.func @reusable(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
|
||||
@@ -69,7 +69,7 @@ func.func @reusable(%A : !tt.ptr<f16>) {
|
||||
// CHECK-NEXT: offset = 0, size = 1152
|
||||
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
|
||||
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 4608
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func.func @reusable(%A : !tt.ptr<f16>) {
|
||||
// %cst1->%cst4
|
||||
// %cst3->%g->%h->%i
|
||||
// CHECK-LABEL: preallocate
|
||||
func.func @preallocate(%A : !tt.ptr<f16>) {
|
||||
tt.func @preallocate(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 1024, size = 512
|
||||
@@ -107,13 +107,13 @@ func.func @preallocate(%A : !tt.ptr<f16>) {
|
||||
%h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 2048, size = 4096
|
||||
%i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 12288
|
||||
}
|
||||
|
||||
// Unused tensors are immediately released
|
||||
// CHECK-LABEL: unused
|
||||
func.func @unused(%A : !tt.ptr<f16>) {
|
||||
tt.func @unused(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 1024
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 0, size = 512
|
||||
@@ -122,13 +122,13 @@ func.func @unused(%A : !tt.ptr<f16>) {
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 1024, size = 1024
|
||||
%a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK: size = 2048
|
||||
}
|
||||
|
||||
// cst0 is alive through the entire function, it cannot be released before the end of the function
|
||||
// CHECK-LABEL: longlive
|
||||
func.func @longlive(%A : !tt.ptr<f16>) {
|
||||
tt.func @longlive(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 512
|
||||
@@ -151,13 +151,13 @@ func.func @longlive(%A : !tt.ptr<f16>) {
|
||||
%c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 1024
|
||||
%d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 2560
|
||||
}
|
||||
|
||||
// This example triggers graph coloring with > 1 colors.
|
||||
// CHECK-LABEL: multi_color
|
||||
func.func @multi_color(%A : !tt.ptr<f16>) {
|
||||
tt.func @multi_color(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 64
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<4x8xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 1216, size = 32
|
||||
@@ -199,39 +199,39 @@ func.func @multi_color(%A : !tt.ptr<f16>) {
|
||||
%cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL>
|
||||
%cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL>
|
||||
// CHECK-NEXT: size = 2656
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func.func @alloc(%A : !tt.ptr<f16>) {
|
||||
tt.func @alloc(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK-NEXT: offset = 0, size = 512
|
||||
%cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// CHECK-LABEL: scratch
|
||||
func.func @scratch() {
|
||||
tt.func @scratch() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
// CHECK: scratch offset = 0, size = 512
|
||||
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func.func @trans(%A : !tt.ptr<f16>) {
|
||||
tt.func @trans(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 1024
|
||||
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async
|
||||
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -239,24 +239,24 @@ func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extract_slice
|
||||
func.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
tt.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// B0 -> (B1) -> B0
|
||||
// Memory used by B1 can be reused by B0.
|
||||
// CHECK-LABEL: if
|
||||
func.func @if(%i1 : i1) {
|
||||
tt.func @if(%i1 : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 512
|
||||
@@ -273,14 +273,14 @@ func.func @if(%i1 : i1) {
|
||||
%cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 1024, size = 1024
|
||||
%a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 2048
|
||||
}
|
||||
|
||||
// B0 -> (B1) -> (B2) -> B0
|
||||
// Memory used by B0 cannot be reused by B1 or B2.
|
||||
// CHECK-LABEL: if_else
|
||||
func.func @if_else(%i1 : i1) {
|
||||
tt.func @if_else(%i1 : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 512, size = 512
|
||||
@@ -300,14 +300,14 @@ func.func @if_else(%i1 : i1) {
|
||||
}
|
||||
// CHECK-NEXT: offset = 1024, size = 1024
|
||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 3072
|
||||
}
|
||||
|
||||
// Block arguments and yields are memory aliases that do not trigger a new
|
||||
// allocation.
|
||||
// CHECK-LABEL: for
|
||||
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -317,12 +317,12 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 24576
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_if_slice
|
||||
func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -337,13 +337,13 @@ func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f1
|
||||
}
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 24576
|
||||
}
|
||||
|
||||
// c0 cannot be released in the loop
|
||||
// CHECK-LABEL: for_use_ancestor
|
||||
func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -356,14 +356,14 @@ func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.pt
|
||||
%c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 32768
|
||||
}
|
||||
|
||||
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2.
|
||||
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
|
||||
// CHECK-LABEL: for_for_if
|
||||
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
@@ -387,7 +387,7 @@ func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
|
||||
}
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
%cst2 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 40960
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any membar with the dot op encoding.
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
||||
@@ -38,11 +38,11 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: raw_single_block
|
||||
func.func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
tt.func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
@@ -51,11 +51,11 @@ func.func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: war_single_block
|
||||
func.func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
tt.func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
@@ -67,11 +67,11 @@ func.func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: %4 = triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: scratch
|
||||
func.func @scratch() {
|
||||
tt.func @scratch() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
@@ -80,11 +80,11 @@ func.func @scratch() {
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: async_wait
|
||||
func.func @async_wait() {
|
||||
tt.func @async_wait() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
@@ -93,21 +93,21 @@ func.func @async_wait() {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func.func @alloc() {
|
||||
tt.func @alloc() {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%1 = tt.cat %0, %0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extract_slice
|
||||
func.func @extract_slice() {
|
||||
tt.func @extract_slice() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
%0 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
@@ -117,19 +117,19 @@ func.func @extract_slice() {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func.func @trans() {
|
||||
tt.func @trans() {
|
||||
// CHECK-NOT: gpu.barrier
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async_op
|
||||
func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -142,11 +142,11 @@ func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_op
|
||||
func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
@@ -162,12 +162,12 @@ func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region
|
||||
// CHECK-LABEL: multi_blocks
|
||||
func.func @multi_blocks(%i1 : i1) {
|
||||
tt.func @multi_blocks(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -186,12 +186,12 @@ func.func @multi_blocks(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%2 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region
|
||||
// CHECK-LABEL: multi_blocks_join_barrier
|
||||
func.func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
tt.func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -206,12 +206,12 @@ func.func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
scf.yield
|
||||
}
|
||||
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Read yielded tensor requires a barrier
|
||||
// CHECK-LABEL: multi_blocks_yield
|
||||
func.func @multi_blocks_yield(%i1 : i1) {
|
||||
tt.func @multi_blocks_yield(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
|
||||
@@ -229,12 +229,12 @@ func.func @multi_blocks_yield(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%4 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Even though the entry block doesn't have a barrier, the successors should have barriers
|
||||
// CHECK-LABEL: multi_blocks_entry_no_shared
|
||||
func.func @multi_blocks_entry_no_shared(%i1 : i1) {
|
||||
tt.func @multi_blocks_entry_no_shared(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
|
||||
// CHECK: gpu.barrier
|
||||
@@ -251,12 +251,12 @@ func.func @multi_blocks_entry_no_shared(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Conservatively add a barrier as if the branch (%i1) is never taken
|
||||
// CHECK-LABEL: multi_blocks_noelse
|
||||
func.func @multi_blocks_noelse(%i1 : i1) {
|
||||
tt.func @multi_blocks_noelse(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -268,12 +268,12 @@ func.func @multi_blocks_noelse(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Conservatively add a barrier as if the branch (%i2) is never taken
|
||||
// CHECK-LABEL: multi_blocks_nested_scf
|
||||
func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
@@ -293,11 +293,11 @@ func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for
|
||||
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
@@ -307,13 +307,13 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
%5 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Although a_shared and b_shared are synced before entering the loop,
|
||||
// they are reassociated with aliases (c_shared) and thus require a barrier.
|
||||
// CHECK-LABEL: for_alias
|
||||
func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
@@ -330,13 +330,13 @@ func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
|
||||
// So we need a barrier both before and after cst1
|
||||
// CHECK-LABEL: for_reuse
|
||||
func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
@@ -355,11 +355,11 @@ func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_reuse_nested
|
||||
func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
@@ -381,12 +381,12 @@ func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.pt
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%15 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// repeatedly write to the same shared memory addresses
|
||||
// CHECK-LABEL: for_for_if
|
||||
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
@@ -407,12 +407,12 @@ func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
|
||||
}
|
||||
scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// c_block_next can either be converted from c_shared_init or c_shared_next_next
|
||||
// CHECK-LABEL: for_if_for
|
||||
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
@@ -438,11 +438,11 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
|
||||
%b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
|
||||
scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cf_if
|
||||
func.func @cf_if(%i1 : i1) {
|
||||
tt.func @cf_if(%i1 : i1) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.cond_br %i1, ^bb1, ^bb2
|
||||
@@ -455,10 +455,10 @@ func.func @cf_if(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
func.func @cf_if_else(%i1 : i1) {
|
||||
tt.func @cf_if_else(%i1 : i1) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.cond_br %i1, ^bb1, ^bb2
|
||||
@@ -479,10 +479,10 @@ func.func @cf_if_else(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<64x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
func.func @cf_if_else_return(%i1 : i1) {
|
||||
tt.func @cf_if_else_return(%i1 : i1) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.cond_br %i1, ^bb1, ^bb2
|
||||
@@ -490,12 +490,12 @@ func.func @cf_if_else_return(%i1 : i1) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
return
|
||||
tt.return
|
||||
^bb2: // pred: ^bb0
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: triton-opt %s | FileCheck %s
|
||||
|
||||
func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
tt.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
// scalar -> scalar
|
||||
// CHECK: i64 -> !tt.ptr<f32>
|
||||
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
|
||||
@@ -32,10 +32,10 @@ func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i6
|
||||
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
|
||||
// CHECK: tensor<16xf32> to tensor<16xf16>
|
||||
%8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
tt.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
// scalar -> scalar
|
||||
// CHECK: !tt.ptr<f32>
|
||||
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
|
||||
@@ -51,10 +51,10 @@ func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
%tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
|
||||
// CHECK: tensor<16x!tt.ptr<f32>>
|
||||
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
|
||||
tt.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
|
||||
// Test if Load/Store ops can handle scalar values
|
||||
%other = arith.constant 0.0e+0 : f32
|
||||
|
||||
@@ -73,10 +73,10 @@ func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
|
||||
tt.store %ptr, %b, %mask : f32
|
||||
// CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} {cache = 1 : i32, evict = 1 : i32} : f32
|
||||
tt.store %ptr, %c, %mask : f32
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
tt.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
// Test if reduce ops infer types correctly
|
||||
|
||||
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
|
||||
@@ -98,10 +98,10 @@ func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
||||
tt.store %ptr1x2, %c : tensor<1x2xf32>
|
||||
tt.store %ptr1, %e : tensor<1xf32>
|
||||
tt.store %ptr, %g : f32
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
||||
tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
||||
// Test if reduce ops infer types correctly
|
||||
%v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
|
||||
%v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
|
||||
@@ -128,5 +128,5 @@ func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
||||
tt.store %ptr32x32, %r2 : tensor<32x32xf32>
|
||||
tt.store %ptr128x128, %r3 : tensor<128x128xf32>
|
||||
tt.store %ptr1x1, %r4 : tensor<1x1xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||
|
||||
func.func @ops() {
|
||||
tt.func @ops() {
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
||||
%0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if LoadOp is lowered properly (see #771)
|
||||
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
%mask = arith.constant dense<true> : tensor<128xi1>
|
||||
@@ -25,12 +25,12 @@ func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
tt.store %ptrs, %a : tensor<128xf32>
|
||||
tt.store %ptrs, %b : tensor<128xf32>
|
||||
tt.store %ptrs, %c : tensor<128xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if the total number of threadsPerWarp is 32
|
||||
// Test if the total number of warps is 2
|
||||
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
@@ -49,5 +49,5 @@ func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
|
||||
%c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
|
||||
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||
// Here the 128 comes from the 4 in module attribute multiples 32
|
||||
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}}
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
// CHECK: llvm.return
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
} // end module
|
||||
|
||||
@@ -15,11 +15,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_load
|
||||
func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK: llvm.inline_asm
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,13 +28,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: vectorized_load
|
||||
func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
tt.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ld.global.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ld.global.b32
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,13 +43,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: vectorized_load_f16
|
||||
func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
|
||||
tt.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ld.global.b16
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ld.global.b16
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,10 +59,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: masked_load_const_other
|
||||
func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
tt.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,10 +72,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: masked_load_const_other_vec
|
||||
func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
tt.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_no_vec
|
||||
func.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
||||
tt.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -127,7 +127,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
||||
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
||||
tt.store %13, %11 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec4
|
||||
func.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
tt.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -163,7 +163,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// Store 4 elements to global with single one vectorized store instruction
|
||||
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
|
||||
tt.store %13, %11 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
// Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1.
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
func.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
tt.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
@@ -194,7 +194,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
|
||||
tt.store %15, %13, %10 : tensor<64xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec2
|
||||
func.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
||||
tt.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -239,7 +239,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
|
||||
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
|
||||
tt.store %13, %11 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec8
|
||||
func.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
tt.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -278,7 +278,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
|
||||
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
|
||||
tt.store %13, %11 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +291,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_view_broadcast
|
||||
func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
|
||||
tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: %[[T0:.*]] = llvm.extractvalue
|
||||
// CHECK: %[[T1:.*]] = llvm.extractvalue
|
||||
@@ -306,7 +306,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.insertvalue %[[T0]]
|
||||
// CHECK: llvm.insertvalue %[[T1]]
|
||||
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,13 +315,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_make_range
|
||||
func.func @basic_make_range() {
|
||||
tt.func @basic_make_range() {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.insertvalue
|
||||
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,11 +330,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_addf
|
||||
func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
|
||||
tt.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
|
||||
// CHECK: llvm.fadd
|
||||
// CHECK: llvm.fadd
|
||||
%1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,11 +343,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_addi
|
||||
func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
tt.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.add
|
||||
// CHECK: llvm.add
|
||||
%1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,10 +355,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_program_id
|
||||
func.func @basic_program_id() {
|
||||
tt.func @basic_program_id() {
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,11 +367,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_addptr
|
||||
func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
tt.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.getelementptr
|
||||
// CHECK: llvm.getelementptr
|
||||
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -381,14 +381,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem
|
||||
// CHECK-LABEL: basic_alloc_tensor
|
||||
func.func @basic_alloc_tensor() {
|
||||
tt.func @basic_alloc_tensor() {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK-NEXT: llvm.bitcast
|
||||
// CHECK-NEXT: llvm.mlir.constant
|
||||
// CHECK-NEXT: llvm.getelementptr
|
||||
// CHECK-NEXT: llvm.bitcast
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,7 +398,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem
|
||||
// CHECK-LABEL: basic_extract_slice
|
||||
func.func @basic_extract_slice() {
|
||||
tt.func @basic_extract_slice() {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.extractvalue
|
||||
// CHECK-NEXT: llvm.extractvalue
|
||||
@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%index = arith.constant 1 : i32
|
||||
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
|
||||
%1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -431,10 +431,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_async_wait
|
||||
func.func @basic_async_wait() {
|
||||
tt.func @basic_async_wait() {
|
||||
// CHECK: cp.async.wait_group 0x4
|
||||
triton_gpu.async_wait {num = 4: i32}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_fallback
|
||||
func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
||||
tt.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
@@ -473,7 +473,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16>, #AL> -> tensor<2x16x64xf16, #A>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,7 +489,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_v4
|
||||
func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
|
||||
tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
@@ -515,7 +515,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-SAME: cp.async.commit_group
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
|
||||
triton_gpu.async_commit_group
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -531,7 +531,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_v1
|
||||
func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
@@ -561,7 +561,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-SAME: cp.async.commit_group
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
|
||||
triton_gpu.async_commit_group
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -576,7 +576,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
|
||||
func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
tt.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
|
||||
@@ -618,7 +618,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-SAME: cp.async.commit_group
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
|
||||
triton_gpu.async_commit_group
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -627,12 +627,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: basic_splat
|
||||
func.func @basic_splat(%ptr: !tt.ptr<f32>) {
|
||||
tt.func @basic_splat(%ptr: !tt.ptr<f32>) {
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.insertvalue
|
||||
%0 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>,#blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,13 +641,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_store
|
||||
func.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||
tt.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
||||
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -658,7 +658,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_blocked
|
||||
func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
|
||||
@@ -694,7 +694,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -705,7 +705,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_blocked_vec
|
||||
func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
@@ -717,7 +717,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -728,7 +728,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
|
||||
func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
@@ -746,7 +746,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -759,7 +759,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_dot
|
||||
func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -776,14 +776,14 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
|
||||
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
|
||||
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: problems in MLIR's parser on slice layout
|
||||
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// func.func @make_range_sliced_layout() {
|
||||
// tt.func @make_range_sliced_layout() {
|
||||
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
// return
|
||||
// }
|
||||
@@ -796,7 +796,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav2_block
|
||||
func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||
tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
@@ -805,7 +805,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -816,7 +816,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav1_block
|
||||
func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
||||
tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
@@ -829,7 +829,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -839,13 +839,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_blocked_shared
|
||||
func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
|
||||
// CHECK: llvm.store
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -855,10 +855,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked1d_to_slice0
|
||||
func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
||||
tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
||||
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -868,10 +868,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked1d_to_slice1
|
||||
func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
||||
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
||||
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -881,14 +881,14 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked_to_blocked_ptr
|
||||
func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
|
||||
tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
|
||||
// CHECK: llvm.ptrtoint
|
||||
// CHECK: llvm.store
|
||||
// CHECK: nvvm.barrier0
|
||||
// CHECK: llvm.inttoptr
|
||||
// CHECK-COUNT-4: llvm.insertvalue
|
||||
%cvt = triton_gpu.convert_layout %src : (tensor<32x!tt.ptr<f32>, #blocked0>) -> tensor<32x!tt.ptr<f32>, #blocked1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -900,7 +900,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
|
||||
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
@@ -913,7 +913,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %38 : tensor<128x256xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -926,7 +926,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
|
||||
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
@@ -938,7 +938,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x64x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %38 : tensor<32x64xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -949,7 +949,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
|
||||
// CHECK: llvm.intr.fmuladd
|
||||
@@ -960,7 +960,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %28 : tensor<32x32xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -973,7 +973,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: matmul_tf32dot
|
||||
func.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
|
||||
// CHECK: llvm.inline_asm
|
||||
@@ -999,7 +999,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %38 : tensor<32x32xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1008,14 +1008,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32
|
||||
func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1023,12 +1023,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32_scalar
|
||||
func.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
// CHECK: llvm.icmp "eq"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1037,14 +1037,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32
|
||||
func.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
|
||||
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1052,12 +1052,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32_scalar
|
||||
func.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
|
||||
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
tt.store %arg0, %arg1 : f32
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
||||
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
||||
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
||||
@@ -1078,7 +1078,7 @@ func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
||||
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1086,7 +1086,7 @@ func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
tt.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
// CHECK: nvvm.read.ptx.sreg.nctaid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.nctaid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.nctaid.z
|
||||
@@ -1098,7 +1098,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
||||
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1106,12 +1106,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: test_index_cache
|
||||
func.func @test_index_cache() {
|
||||
tt.func @test_index_cache() {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1120,12 +1120,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_base_index_cache
|
||||
func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1134,7 +1134,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_index_cache_different_block
|
||||
func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
||||
tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
||||
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
@@ -1143,6 +1143,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
cf.br ^bb2
|
||||
^bb2: // 2 preds: ^bb0, ^bb1
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_combine_dot_add_pattern
|
||||
func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
||||
tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
||||
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
|
||||
@@ -19,12 +19,12 @@ func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x12
|
||||
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
|
||||
|
||||
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
|
||||
|
||||
|
||||
// COM: CHECK-LABEL: @test_combine_addptr_pattern
|
||||
func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
tt.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%off0 = arith.constant 10 : i32
|
||||
%off1 = arith.constant 15 : i32
|
||||
|
||||
@@ -42,12 +42,12 @@ func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
|
||||
return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_pattern
|
||||
func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
tt.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
@@ -59,12 +59,12 @@ func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>,
|
||||
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%1 = arith.select %cond, %y, %false_val : tensor<8xf32>
|
||||
|
||||
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
|
||||
return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
||||
// CHECK: tt.return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
|
||||
tt.return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
|
||||
func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
|
||||
@@ -82,21 +82,21 @@ func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f
|
||||
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
|
||||
|
||||
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
|
||||
func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
|
||||
%const = arith.constant dense<1.0> : tensor<8xf32>
|
||||
%bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
|
||||
|
||||
// CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
|
||||
return %bst_out : tensor<8x2xf32>
|
||||
// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
|
||||
tt.return %bst_out : tensor<8x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
|
||||
func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
@@ -112,12 +112,12 @@ func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -
|
||||
// false_mask with other. It should become "other" (i.e., %y)
|
||||
%z = tt.load %ptr, %false_mask, %y {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// CHECK: return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
// CHECK: tt.return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
tt.return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
|
||||
func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
tt.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
|
||||
@@ -126,11 +126,11 @@ func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32
|
||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
return %x, %y: tensor<8xf32>, tensor<8xf32>
|
||||
tt.return %x, %y: tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
|
||||
func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
tt.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
|
||||
@@ -138,31 +138,31 @@ func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>,
|
||||
tt.store %ptr, %val, %true_mask : tensor<8xf32>
|
||||
|
||||
// The following store should disappear.
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: tt.return
|
||||
tt.store %ptr, %val, %false_mask : tensor<8xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
|
||||
func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
||||
tt.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
||||
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
|
||||
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
tt.store %ptr, %val, %mask : tensor<8xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_expand_dims
|
||||
func.func @test_canonicalize_expand_dims(%arg0: tensor<f32>) -> (tensor<1x8xf32>) {
|
||||
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>) -> (tensor<1x8xf32>) {
|
||||
%splat = tt.splat %arg0 : (tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK: %{{.*}} = tt.splat %arg0 : (tensor<f32>) -> tensor<1x8xf32>
|
||||
%ed = tt.expand_dims %splat {axis = 0 : i32} : (tensor<8xf32>) -> tensor<1x8xf32>
|
||||
|
||||
return %ed : tensor<1x8xf32>
|
||||
tt.return %ed : tensor<1x8xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_view
|
||||
func.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
|
||||
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
|
||||
%view0 = tt.view %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %{{.*}} = tt.view %arg0 : (tensor<8xf32>) -> tensor<4x2xf32>
|
||||
%view1 = tt.view %view0 : (tensor<2x4xf32>) -> tensor<4x2xf32>
|
||||
@@ -175,11 +175,11 @@ func.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (
|
||||
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
|
||||
%add = arith.addf %view3, %arg0 : tensor<8xf32>
|
||||
|
||||
return %view1, %view2, %add : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>
|
||||
tt.return %view1, %view2, %add : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_broadcast
|
||||
func.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f32>) -> (tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>) {
|
||||
tt.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f32>) -> (tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>) {
|
||||
%broadcast0 = tt.broadcast %arg0 : (tensor<1x1x8xf32>) -> tensor<1x2x8xf32>
|
||||
// CHECK: %{{.*}} = tt.broadcast %arg0 : (tensor<1x1x8xf32>) -> tensor<4x2x8xf32>
|
||||
%broadcast1 = tt.broadcast %broadcast0 : (tensor<1x2x8xf32>) -> tensor<4x2x8xf32>
|
||||
@@ -192,11 +192,11 @@ func.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f
|
||||
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<1x1x8xf32>
|
||||
%add = arith.addf %broadcast3, %arg0 : tensor<1x1x8xf32>
|
||||
|
||||
return %broadcast1, %broadcast2, %add : tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>
|
||||
tt.return %broadcast1, %broadcast2, %add : tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_fold_views
|
||||
func.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>) {
|
||||
tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>) {
|
||||
%a = arith.constant dense<1.0> : tensor<1x128xf32>
|
||||
|
||||
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32>
|
||||
@@ -208,5 +208,5 @@ func.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x
|
||||
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<1x1x128xf32>
|
||||
%d = tt.expand_dims %a {axis = 0: i32} : (tensor<1x128xf32>) -> tensor<1x1x128xf32>
|
||||
|
||||
return %b, %c, %d : tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>
|
||||
tt.return %b, %c, %d : tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// RUN: triton-opt %s -triton-rewrite-tensor-pointer | FileCheck %s
|
||||
func.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c31_i32 = arith.constant 31 : i32
|
||||
%c127_i32 = arith.constant 127 : i32
|
||||
%c1 = arith.constant 1 : index
|
||||
@@ -79,5 +79,5 @@ func.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}
|
||||
%53 = tt.broadcast %51 : (tensor<1x32xi1>) -> tensor<128x32xi1>
|
||||
%54 = arith.andi %52, %53 : tensor<128x32xi1>
|
||||
tt.store %45, %30, %54 {cache = 1 : i32, evict = 1 : i32} : tensor<128x32xf16>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: triton-opt %s -verify-diagnostics
|
||||
|
||||
module {
|
||||
func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
@@ -39,11 +39,11 @@ module {
|
||||
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
|
||||
tt.store %17, %15#0, %6 : tensor<256xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
// module {
|
||||
// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
// tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
// %c64 = arith.constant 64 : index
|
||||
// %c32 = arith.constant 32 : index
|
||||
// %c0 = arith.constant 0 : index
|
||||
@@ -125,6 +125,6 @@ module {
|
||||
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// return
|
||||
// tt.return
|
||||
// }
|
||||
// }
|
||||
|
||||
@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
|
||||
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
|
||||
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
|
||||
func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg1: i32 {tt.divisibility = 16 : i32},
|
||||
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
@@ -47,7 +47,7 @@ func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
|
||||
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,34 +9,34 @@
|
||||
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK-LABEL: cst
|
||||
func.func @cst() -> tensor<1024xi32, #layout1> {
|
||||
tt.func @cst() -> tensor<1024xi32, #layout1> {
|
||||
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: return %cst : tensor<1024xi32, [[$target_layout]]>
|
||||
return %1: tensor<1024xi32, #layout1>
|
||||
// CHECK: tt.return %cst : tensor<1024xi32, [[$target_layout]]>
|
||||
tt.return %1: tensor<1024xi32, #layout1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: range
|
||||
func.func @range() -> tensor<1024xi32, #layout1> {
|
||||
tt.func @range() -> tensor<1024xi32, #layout1> {
|
||||
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: return %0 : tensor<1024xi32, [[$target_layout]]>
|
||||
return %1: tensor<1024xi32, #layout1>
|
||||
// CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]>
|
||||
tt.return %1: tensor<1024xi32, #layout1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: splat
|
||||
func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
tt.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
%0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: return %0 : tensor<1024xi32, [[$target_layout]]>
|
||||
return %1: tensor<1024xi32, #layout1>
|
||||
// CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]>
|
||||
tt.return %1: tensor<1024xi32, #layout1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remat
|
||||
func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
||||
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
||||
%2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
|
||||
@@ -44,7 +44,7 @@ func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
%4 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
%6 = arith.addi %3, %5 : tensor<1024xi32, #layout1>
|
||||
return %6: tensor<1024xi32, #layout1>
|
||||
tt.return %6: tensor<1024xi32, #layout1>
|
||||
// CHECK: %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
|
||||
// CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
|
||||
// CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
|
||||
@@ -52,24 +52,24 @@ func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
// CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]>
|
||||
// CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]>
|
||||
// CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]>
|
||||
// CHECK: return %6 : tensor<1024xi32, [[$target_layout]]>
|
||||
// CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]>
|
||||
}
|
||||
|
||||
// Always rematerialize single value loads
|
||||
// CHECK-LABEL: remat_single_value
|
||||
func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
|
||||
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #layout1>) -> tensor<1xi32, #layout0>
|
||||
%3 = triton_gpu.convert_layout %0 : (tensor<1x!tt.ptr<i32>, #layout1>) -> tensor<1x!tt.ptr<i32>, #layout0>
|
||||
tt.store %3, %2 : tensor<1xi32, #layout0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
|
||||
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
|
||||
%2 = tt.addptr %0, %1 : tensor<16x!tt.ptr<i32>, #layout1>, tensor<16xi32, #layout1>
|
||||
@@ -78,12 +78,12 @@ func.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr<i32>, #layout1>) -> tensor<16x!tt.ptr<i32>, #layout0>
|
||||
tt.store %5, %4 : tensor<16xi32, #layout0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if
|
||||
func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
@@ -96,11 +96,11 @@ func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%6 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout1>) -> tensor<1024xi32, #layout0>
|
||||
tt.store %5, %6 : tensor<1024xi32, #layout0>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_convert_else_not
|
||||
func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
@@ -117,11 +117,11 @@ func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
|
||||
}
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
tt.store %5, %8 : tensor<1024xi32, #layout1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_not_else_convert
|
||||
func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
@@ -138,11 +138,11 @@ func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
|
||||
}
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
tt.store %5, %8 : tensor<1024xi32, #layout1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_else_both_convert
|
||||
func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
@@ -161,7 +161,7 @@ func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
|
||||
// disabledCHECK: triton_gpu.convert_layout
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
tt.store %5, %8 : tensor<1024xi32, #layout1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
@@ -173,12 +173,12 @@ func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
|
||||
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
|
||||
// CHECK-LABEL: transpose
|
||||
func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[$row_layout]]>
|
||||
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout]]>
|
||||
// CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[$col_layout]]>
|
||||
// CHECK: return
|
||||
// CHECK: tt.return
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||
@@ -210,65 +210,65 @@ func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i3
|
||||
%25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
|
||||
%26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
|
||||
tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: loop
|
||||
func.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
|
||||
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[$row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]>
|
||||
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
|
||||
%c1 = arith.constant 1 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
|
||||
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
|
||||
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
||||
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
||||
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
||||
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
|
||||
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
}
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
|
||||
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
|
||||
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
|
||||
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[$row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]>
|
||||
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
|
||||
%c1 = arith.constant 1 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
|
||||
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
|
||||
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
||||
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
||||
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
||||
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
|
||||
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
}
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
|
||||
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
|
||||
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: vecadd
|
||||
func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
@@ -295,12 +295,12 @@ func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.p
|
||||
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
|
||||
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #layout1>
|
||||
tt.store %21, %22 : tensor<256xf32, #layout1>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Select has args with different element types
|
||||
// CHECK-LABEL: select
|
||||
func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
|
||||
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
|
||||
@@ -346,12 +346,12 @@ func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.p
|
||||
tt.store %31, %32, %33 : tensor<1x512xf64, #blocked3>
|
||||
scf.yield %30 : tensor<1x512xf64, #blocked2>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// Make sure the following IR doesn't hang the compiler.
|
||||
// CHECK-LABEL: long_func
|
||||
func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
|
||||
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
|
||||
%cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
|
||||
@@ -742,13 +742,13 @@ func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %a
|
||||
%365 = triton_gpu.convert_layout %364 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
|
||||
%366 = triton_gpu.convert_layout %343 : (tensor<1024xf64, #blocked0>) -> tensor<1024xf64, #blocked0>
|
||||
tt.store %365, %366 : tensor<1024xf64, #blocked0>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// A mnist model from torch inductor.
|
||||
// Check if topological sort is working correct and there's no unnecessary convert
|
||||
// CHECK-LABEL: mnist
|
||||
func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
|
||||
%cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
|
||||
@@ -822,7 +822,7 @@ func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1:
|
||||
%62 = triton_gpu.convert_layout %58 : (tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked4>
|
||||
%63 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4>
|
||||
tt.store %61, %62, %63 : tensor<16x16xf32, #blocked4>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -835,7 +835,7 @@ func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1:
|
||||
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
// cmpf and cmpi have different operands and result types
|
||||
// CHECK-LABEL: cmp
|
||||
func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64 = arith.constant 64 : index
|
||||
%c2048 = arith.constant 2048 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
@@ -968,14 +968,14 @@ func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !
|
||||
%82 = triton_gpu.convert_layout %54 : (tensor<64x64xi1, #blocked2>) -> tensor<64x64xi1, #blocked4>
|
||||
tt.store %80, %81, %82 : tensor<64x64xf16, #blocked4>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Just make sure it doesn't crash on non-tensor types.
|
||||
// CHECK-LABEL: if_no_tensor
|
||||
func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c-1_i64 = arith.constant -1 : i64
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
@@ -996,7 +996,7 @@ func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%8 = tt.load %5, %7, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
|
||||
%9 = tt.addptr %arg1, %0 : !tt.ptr<f32>, i32
|
||||
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -1009,7 +1009,7 @@ func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
|
||||
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
func.func public @reduce_cvt1(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) {
|
||||
tt.func public @reduce_cvt1(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) {
|
||||
%cst = arith.constant dense<0> : tensor<1x2xi32, #blocked>
|
||||
%cst_0 = arith.constant dense<2> : tensor<1x2xi32, #blocked>
|
||||
%0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked1>
|
||||
@@ -1029,7 +1029,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%14 = triton_gpu.convert_layout %12 : (tensor<1x2xi64, #blocked>) -> tensor<1x2xi64, #blocked3>
|
||||
%15 = triton_gpu.convert_layout %3 : (tensor<1x2xi1, #blocked>) -> tensor<1x2xi1, #blocked3>
|
||||
tt.store %13, %14, %15 {cache = 1 : i32, evict = 1 : i32} : tensor<1x2xi64, #blocked3>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1045,7 +1045,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
|
||||
tt.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
|
||||
%c3136_i32 = arith.constant 3136 : index
|
||||
%c256_i32 = arith.constant 256 : index
|
||||
@@ -1104,6 +1104,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%25 = triton_gpu.convert_layout %21 : (tensor<1x1xf32, #blocked>) -> tensor<1x1xf32, #blocked>
|
||||
%26 = triton_gpu.convert_layout %7 : (tensor<1x1xi1, #blocked>) -> tensor<1x1xi1, #blocked>
|
||||
tt.store %24, %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32, #blocked>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
|
||||
// CHECK: func.func @matmul_loop
|
||||
// CHECK: tt.func @matmul_loop
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
@@ -47,7 +47,7 @@
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
|
||||
// A ptrs
|
||||
@@ -88,10 +88,10 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return %loop#2: tensor<128x128xf32, #C>
|
||||
tt.return %loop#2: tensor<128x128xf32, #C>
|
||||
}
|
||||
|
||||
// CHECK: func.func @matmul_loop_nested
|
||||
// CHECK: tt.func @matmul_loop_nested
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
@@ -120,7 +120,7 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||
func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{
|
||||
|
||||
@@ -162,11 +162,11 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
|
||||
scf.yield %loop2#2 : tensor<128x128xf32, #C>
|
||||
}
|
||||
return %loop1#0 : tensor<128x128xf32, #C>
|
||||
tt.return %loop1#0 : tensor<128x128xf32, #C>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func.func @matmul_loop_single_pipeline
|
||||
// CHECK: tt.func @matmul_loop_single_pipeline
|
||||
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
||||
@@ -187,7 +187,7 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||
func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
|
||||
// A ptrs
|
||||
@@ -222,10 +222,10 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return %loop#1 : tensor<128x128xf32, #C>
|
||||
tt.return %loop#1 : tensor<128x128xf32, #C>
|
||||
}
|
||||
|
||||
// CHECK: func.func @lut_bmm_scalar
|
||||
// CHECK: tt.func @lut_bmm_scalar
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
@@ -239,7 +239,7 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]]
|
||||
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]]
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
func.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
|
||||
tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
|
||||
%76: index,
|
||||
%49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%75: !tt.ptr<i64>,
|
||||
@@ -265,10 +265,10 @@ func.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
|
||||
%92 = tt.addptr %arg21, %c1_i32 : !tt.ptr<i64>, i32
|
||||
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, !tt.ptr<i64>
|
||||
}
|
||||
return %79#0 : tensor<16x16xf32, #C>
|
||||
tt.return %79#0 : tensor<16x16xf32, #C>
|
||||
}
|
||||
|
||||
// CHECK: func.func @lut_bmm_vector
|
||||
// CHECK: tt.func @lut_bmm_vector
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
@@ -283,7 +283,7 @@ func.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
|
||||
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]]
|
||||
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]]
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
func.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32},
|
||||
tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32},
|
||||
%76: index,
|
||||
%49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
|
||||
%75: tensor<16x!tt.ptr<i64>, #BLs1>,
|
||||
@@ -311,5 +311,5 @@ func.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32,
|
||||
%92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr<i64>, #BLs1>, tensor<16xi32, #BLs1>
|
||||
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x!tt.ptr<i64>, #BLs1>
|
||||
}
|
||||
return %79#0 : tensor<16x16xf32, #C>
|
||||
tt.return %79#0 : tensor<16x16xf32, #C>
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// CHECK: offset = 49152, size = 49152
|
||||
// CHECK: size = 98304
|
||||
module {
|
||||
func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
||||
tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1>
|
||||
%c64 = arith.constant 64 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
@@ -101,6 +101,6 @@ func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32_
|
||||
%74 = tt.broadcast %72 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
||||
%75 = arith.andi %73, %74 : tensor<64x64xi1>
|
||||
tt.store %66, %47#0, %75 : tensor<64x64xf32>
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
|
||||
|
||||
// CHECK: func.func @matmul_loop
|
||||
// CHECK: tt.func @matmul_loop
|
||||
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
|
||||
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
|
||||
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
|
||||
@@ -28,7 +28,7 @@
|
||||
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
|
||||
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
|
||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
|
||||
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
||||
@@ -60,5 +60,5 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
||||
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestAliasPass
|
||||
: public PassWrapper<TestAliasPass, OperationPass<func::FuncOp>> {
|
||||
: public PassWrapper<TestAliasPass, OperationPass<triton::FuncOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestAllocationPass
|
||||
: public PassWrapper<TestAllocationPass, OperationPass<func::FuncOp>> {
|
||||
: public PassWrapper<TestAllocationPass, OperationPass<triton::FuncOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestAxisInfoPass
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<triton::FuncOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct TestMembarPass
|
||||
: public PassWrapper<TestMembarPass, OperationPass<func::FuncOp>> {
|
||||
: public PassWrapper<TestMembarPass, OperationPass<triton::FuncOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user