[BACKEND] Replace Func Dialect with custom triton ops (func, call, return) (#1502)

MLIR currently only supports a custom inlining interface per dialect, so
we cannot change the inlining decision for `func.func`.


https://discourse.llvm.org/t/avoid-inlining-some-functions-using-the-func-dialect/69830/3

We could revert this change once upstream MLIR provides a better inliner interface.

Inlining attributes will be implemented in the next PR since this PR is
already huge.
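
For illustration, a minimal sketch of the replacement ops in use, in the same shape as the examples in the new TableGen descriptions below (the helper and kernel names are invented, not from this PR):

```mlir
// Hypothetical helper and kernel using the new ops in place of
// func.func / func.call / func.return.
tt.func @add_one(%x: f32) -> f32 {
  %c1 = arith.constant 1.0 : f32
  %y = arith.addf %x, %c1 : f32
  tt.return %y : f32
}

tt.func public @kernel(%arg0: f32) {
  %0 = tt.call @add_one(%arg0) : (f32) -> f32
  tt.return
}
```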
Keren Zhou
2023-04-10 21:08:40 -07:00
committed by GitHub
parent 640f3c3921
commit 6d0ed41307
37 changed files with 834 additions and 502 deletions

View File

@@ -32,7 +32,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect>();
}

View File

@@ -19,15 +19,16 @@ def Triton_Dialect : Dialect {
* Math:
* exp, sin, cos, log, ...
* StructuredControlFlow:
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
* for, if, while, yield, condition
* ControlFlow:
* br, cond_br
}];
let dependentDialects = [
"arith::ArithDialect",
"math::MathDialect",
"scf::SCFDialect",
"cf::ControlFlowDialect",
"func::FuncDialect"
"cf::ControlFlowDialect"
];
let extraClassDeclaration = [{

View File

@@ -6,11 +6,15 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
//
// Op Base
@@ -20,7 +24,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
}
//
// CastOps
// Cast Ops
//
// Use cast ops in arith:
// bitcast
@@ -93,7 +97,6 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
//
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[Pure,
SameOperandsAndResultShape,
@@ -209,7 +212,7 @@ def TT_StoreOp : TT_Op<"store",
}
//
// Atomic Op
// Atomic Ops
//
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
@@ -256,7 +259,6 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
let results = (outs TT_Type:$result);
}
//
// Shape Manipulation Ops
//
@@ -408,7 +410,7 @@ def TT_ReduceOp : TT_Op<"reduce", [Pure,
}
//
// External elementwise op
// External Elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
@@ -449,7 +451,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
}
//
// Make PrintOp
// Print Op
//
def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
@@ -464,7 +466,7 @@ def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
}
//
// Make AssertOp
// Assert Op
//
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
let summary = "Device-side assert, as in CUDA for correctness checking";
@@ -477,7 +479,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
}
//
// Make a Tensor Pointer
// Make Tensor Pointer Op
//
def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
[Pure,
@@ -518,4 +520,199 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
];
}
// The following ops, including `call`, `func`, and `return` are copied and modified from
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
// We could revert it back once MLIR has a better inliner interface.
//
// Function Ops
//
def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "call operation";
let description = [{
The `tt.call` operation represents a direct call to a function that is
within the same symbol scope as the call. The operands and result types of
the call must match the specified function type. The callee is encoded as a
symbol reference attribute named "callee".
Example:
```mlir
%2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32
```
}];
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", SymbolRefAttr::get(callee));
$_state.addTypes(callee.getFunctionType().getResults());
}]>,
OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
$_state.addTypes(results);
}]>,
OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, SymbolRefAttr::get(callee), results, operands);
}]>,
OpBuilder<(ins "StringRef":$callee, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
results, operands);
}]>];
let extraClassDeclaration = [{
FunctionType getCalleeType() {
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
}
/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
/// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
}];
let assemblyFormat = [{
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
}];
}
def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> {
let summary = "An operation with a name containing a single `SSACFG` region";
let description = [{
Operations within the function cannot implicitly capture values defined
outside of the function, i.e., functions are `IsolatedFromAbove`. All
external references must use function arguments or attributes that establish
a symbolic connection (e.g. symbols referenced by name via a string
attribute like SymbolRefAttr). An external function declaration (used when
referring to a function declared in some other module) has no body. While
the MLIR textual form provides a nice inline syntax for function arguments,
they are internally represented as “block arguments” to the first block in
the region.
Only dialect attribute names may be specified in the attribute dictionaries
for function arguments, results, or the function itself.
Example:
```mlir
// External function declarations.
tt.func @abort()
tt.func @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64
// A function that returns its argument twice:
tt.func @count(%x: i64) -> (i64, i64)
attributes {fruit: "banana"} {
tt.return %x, %x : i64, i64
}
// A function with an argument attribute
tt.func @example_fn_arg(%x: i32 {swift.self = unit})
// A function with a result attribute
tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})
// A function with an attribute
tt.func @example_fn_attr() attributes {dialectName.attrName = false}
```
}];
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<StrAttr>:$sym_visibility,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
>];
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
/// Returns the region on the current operation that is callable. This may
/// return null in the case of an external callable object, e.g. an external
/// function.
::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
//===------------------------------------------------------------------===//
// SymbolOpInterface Methods
//===------------------------------------------------------------------===//
bool isDeclaration() { return isExternal(); }
}];
let hasCustomAssemblyFormat = 1;
}
def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> {
let summary = "Function return operation";
let description = [{
The `tt.return` operation represents a return operation within a function.
The operation takes a variable number of operands and produces no results.
The operand number and types must match the signature of the function
that contains the operation.
Example:
```mlir
tt.func @foo() -> (i32, f8) {
...
tt.return %0, %1 : i32, f8
}
```
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<(ins), [{
build($_builder, $_state, std::nullopt);
}]>];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let hasVerifier = 1;
}
#endif // Triton_OPS

View File

@@ -77,7 +77,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (auto fun = dyn_cast<func::FuncOp>(op))
if (auto fun = dyn_cast<triton::FuncOp>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);

View File

@@ -74,7 +74,7 @@ void MembarAnalysis::visitTerminator(Operation *op,
return;
}
// Otherwise, it could be a return op
assert(isa<func::ReturnOp>(op) && "Unknown terminator");
assert(isa<triton::ReturnOp>(op) && "Unknown terminator");
}
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,

View File

@@ -9,11 +9,11 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();

View File

@@ -4,7 +4,6 @@
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>)
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Analysis/Allocation.h"
#include "TypeConverter.h"
@@ -41,12 +40,12 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
// All the rights are reserved by the LLVM community.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<triton::FuncOp> {
private:
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs,
static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : op->getAttrs()) {
@@ -66,12 +65,12 @@ private:
}
protected:
using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<triton::FuncOp>::ConvertOpToLLVMPattern;
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
LLVM::LLVMFuncOp
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
convertFuncOpToLLVMFuncOp(triton::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.

View File

@@ -51,16 +51,15 @@ public:
} else {
addLegalDialect<NVVM::NVVMDialect>();
}
addIllegalOp<mlir::func::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();
@@ -86,7 +85,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp) {

View File

@@ -163,8 +163,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
MLIRContext *context = patterns.getContext();
// Rewrite rule
patterns.add<StdSelectPattern>(typeConverter, context);
target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are
// inlined by the frontend
target.addLegalOp<triton::ReturnOp>(); // this is ok because all functions are
// inlined by the frontend
}
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
@@ -721,15 +721,15 @@ public:
}
};
class FuncOpPattern : public OpConversionPattern<func::FuncOp> {
class FuncOpPattern : public OpConversionPattern<triton::FuncOp> {
public:
using OpConversionPattern<func::FuncOp>::OpConversionPattern;
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = getTypeConverter();
auto newOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
op, op.getName(), op.getFunctionType());
addNamedAttrs(newOp, adaptor.getAttributes());
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),

View File

@@ -12,5 +12,4 @@ add_mlir_dialect_library(TritonIR
MLIRIR
MLIRArithDialect
MLIRSCFDialect
MLIRFuncDialect
)

View File

@@ -1,6 +1,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -21,6 +22,10 @@ using namespace mlir::triton;
namespace {
struct TritonInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
}
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
IRMapping &valueMapping) const final {
return true;
@@ -29,6 +34,37 @@ struct TritonInlinerInterface : public DialectInlinerInterface {
IRMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op, Block *newDest) const final {
// Only return needs to be handled here.
auto returnOp = dyn_cast<triton::ReturnOp>(op);
if (!returnOp)
return;
// Replace the return with a branch to the dest.
OpBuilder builder(op);
builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
returnOp.getOperands());
op->erase();
}
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
// Only return needs to be handled here.
auto returnOp = cast<triton::ReturnOp>(op);
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
};
} // namespace
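
A hedged sketch of what the first `handleTerminator` hook above does once an inlined callee body has been spliced into the caller; the block names and values here are illustrative, not from this PR:

```mlir
// Cloned callee body right after inlining; the terminator is still tt.return:
^inlined_entry:
  %res = arith.addf %a, %b : f32
  tt.return %res : f32

// After handleTerminator: the return becomes a branch that forwards the
// returned value to the block that continues after the call site.
^inlined_entry:
  %res = arith.addf %a, %b : f32
  cf.br ^after_call(%res : f32)
```

The second overload covers the case where the callee body can be merged into the caller's block: the call's results are replaced directly with the return operands and no branch is created.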

View File

@@ -1,6 +1,8 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
@@ -516,5 +518,105 @@ void MakeTensorPtrOp::build(::mlir::OpBuilder &builder,
builder.getDenseI32ArrayAttr(order));
}
// The following ops, including `call`, `func`, and `return` are copied and
// modified from
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp
// We could revert it back once MLIR has a better inliner interface.
//-- FuncOp --
void triton::FuncOp::build(OpBuilder &builder, OperationState &state,
StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
function_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
ParseResult triton::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void triton::FuncOp::print(OpAsmPrinter &printer) {
function_interface_impl::printFunctionOp(
printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}
// -- CallOp --
LogicalResult
triton::CallOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return emitOpError("requires a 'callee' symbol reference attribute");
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
if (!fn)
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
// Verify that the operand and result types match the callee.
auto fnType = fn.getFunctionType();
if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
if (getOperand(i).getType() != fnType.getInput(i))
return emitOpError("operand type mismatch: expected operand type ")
<< fnType.getInput(i) << ", but provided "
<< getOperand(i).getType() << " for operand number " << i;
if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
if (getResult(i).getType() != fnType.getResult(i)) {
auto diag = emitOpError("result type mismatch at index ") << i;
diag.attachNote() << " op result types: " << getResultTypes();
diag.attachNote() << "function result types: " << fnType.getResults();
return diag;
}
return success();
}
// -- ReturnOp --
LogicalResult triton::ReturnOp::verify() {
auto function = cast<triton::FuncOp>((*this)->getParentOp());
// The operand number and types must match the function signature.
const auto &results = function.getFunctionType().getResults();
if (getNumOperands() != results.size())
return emitOpError("has ")
<< getNumOperands() << " operands, but enclosing function (@"
<< function.getName() << ") returns " << results.size();
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (getOperand(i).getType() != results[i])
return emitError() << "type of return operand " << i << " ("
<< getOperand(i).getType()
<< ") doesn't match function result type ("
<< results[i] << ")"
<< " in function @" << function.getName();
return success();
}
} // namespace triton
} // namespace mlir
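
As a hedged illustration of the new verifier, `triton::ReturnOp::verify` above would reject IR like the following (the function name is invented; the comment paraphrases the diagnostic built above):

```mlir
tt.func @bad(%x: i32) -> f32 {
  // Invalid: the type of return operand 0 (i32) doesn't match the
  // function result type (f32), so verification fails.
  tt.return %x : i32
}
```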

View File

@@ -81,18 +81,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
func::FuncDialect, triton::TritonDialect,
cf::ControlFlowDialect, scf::SCFDialect>(
[&](Operation *op) {
bool hasLegalRegions = true;
for (auto &region : op->getRegions()) {
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
}
if (hasLegalRegions && typeConverter.isLegal(op)) {
return true;
}
return false;
});
triton::TritonDialect, cf::ControlFlowDialect,
scf::SCFDialect>([&](Operation *op) {
bool hasLegalRegions = true;
for (auto &region : op->getRegions()) {
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
}
if (hasLegalRegions && typeConverter.isLegal(op)) {
return true;
}
return false;
});
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {

View File

@@ -208,7 +208,7 @@ void init_triton_ir(py::module &&m) {
std::string attrName = name + "_arg" + std::to_string(id);
mlir::Block *owner = arg.getOwner();
if (owner->isEntryBlock() &&
!mlir::isa<mlir::func::FuncOp>(owner->getParentOp())) {
!mlir::isa<mlir::triton::FuncOp>(owner->getParentOp())) {
owner->getParentOp()->setAttr(attrName, attr);
}
}
@@ -361,7 +361,7 @@ void init_triton_ir(py::module &&m) {
return str;
})
.def("push_back",
[](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void {
[](mlir::ModuleOp &self, mlir::triton::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function",
@@ -372,13 +372,14 @@ void init_triton_ir(py::module &&m) {
})
.def("get_function",
[](mlir::ModuleOp &self,
std::string &funcName) -> mlir::func::FuncOp {
return self.lookupSymbol<mlir::func::FuncOp>(funcName);
std::string &funcName) -> mlir::triton::FuncOp {
return self.lookupSymbol<mlir::triton::FuncOp>(funcName);
})
.def("get_single_function",
[](mlir::ModuleOp &self) -> mlir::func::FuncOp {
llvm::SmallVector<mlir::func::FuncOp> funcs;
self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); });
[](mlir::ModuleOp &self) -> mlir::triton::FuncOp {
llvm::SmallVector<mlir::triton::FuncOp> funcs;
self.walk(
[&](mlir::triton::FuncOp func) { funcs.push_back(func); });
if (funcs.size() != 1)
throw std::runtime_error("Expected a single function");
return funcs[0];
@@ -400,12 +401,11 @@ void init_triton_ir(py::module &&m) {
// initialize registry
// note: we initialize llvm for undef
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::index::IndexDialect, mlir::func::FuncDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::LLVM::LLVMDialect>();
registry.insert<
mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::index::IndexDialect, mlir::scf::SCFDialect,
mlir::cf::ControlFlowDialect, mlir::LLVM::LLVMDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
@@ -423,30 +423,30 @@ void init_triton_ir(py::module &&m) {
},
ret::take_ownership);
py::class_<mlir::func::FuncOp, mlir::OpState>(m, "function")
py::class_<mlir::triton::FuncOp, mlir::OpState>(m, "function")
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
.def("args",
[](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
[](mlir::triton::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
return self.getArgument(idx);
})
.def(
"add_entry_block",
[](mlir::func::FuncOp &self) -> mlir::Block * {
[](mlir::triton::FuncOp &self) -> mlir::Block * {
return self.addEntryBlock();
},
ret::reference)
.def(
"set_arg_attr",
[](mlir::func::FuncOp &self, int arg_no, const std::string &name,
[](mlir::triton::FuncOp &self, int arg_no, const std::string &name,
int val) {
// set arg attributes "name" to value "val"
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
.def_property_readonly("type", &mlir::func::FuncOp::getFunctionType)
.def("reset_type", &mlir::func::FuncOp::setType);
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
.def("reset_type", &mlir::triton::FuncOp::setType);
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -463,13 +463,13 @@ void init_triton_ir(py::module &&m) {
.def("ret",
[](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::func::ReturnOp>(loc, vals);
self.create<mlir::triton::ReturnOp>(loc, vals);
})
.def("call",
[](mlir::OpBuilder &self, mlir::func::FuncOp &func,
[](mlir::OpBuilder &self, mlir::triton::FuncOp &func,
std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
return self.create<mlir::func::CallOp>(loc, func, args);
return self.create<mlir::triton::CallOp>(loc, func, args);
})
// insertion block/point
.def("set_insertion_point_to_start",
@@ -651,16 +651,16 @@ void init_triton_ir(py::module &&m) {
.def("get_or_insert_function",
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType,
std::string &visibility) -> mlir::func::FuncOp {
std::string &visibility) -> mlir::triton::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::func::FuncOp>(funcOperation);
return llvm::dyn_cast<mlir::triton::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
llvm::SmallVector<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))};
return self.create<mlir::func::FuncOp>(loc, funcName, funcTy,
attrs);
return self.create<mlir::triton::FuncOp>(loc, funcName, funcTy,
attrs);
}
throw std::runtime_error("invalid function type");
})

View File

@@ -2238,7 +2238,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
@@ -2256,7 +2256,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
tt.return
}
}
"""

View File

@@ -267,14 +267,14 @@ def make_hash(fn, **kwargs):
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword tt.func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,

View File

@@ -11,7 +11,7 @@
// CHECK-LABEL: matmul_loop
// There shouldn't be any aliasing with the dot op encoding.
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
@@ -32,38 +32,38 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return
tt.return
}
// CHECK-LABEL: alloc
func.func @alloc(%A : !tt.ptr<f16>) {
tt.func @alloc(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: %0 -> %0
%cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: convert
func.func @convert(%A : !tt.ptr<f16>) {
tt.func @convert(%A : !tt.ptr<f16>) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: %0 -> %0
%cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: trans
func.func @trans(%A : !tt.ptr<f16>) {
tt.func @trans(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK: %0 -> %cst
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
return
tt.return
}
// CHECK-LABEL: insert_slice_async
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
tt.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -72,11 +72,11 @@ func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%index = arith.constant 0 : i32
// CHECK: %2 -> %cst_0
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: insert_slice
func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
tt.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -86,21 +86,21 @@ func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
%a = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
// CHECK: %inserted_slice -> %cst_0
%b = tensor.insert_slice %a into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: extract_slice
func.func @extract_slice(%A : !tt.ptr<f16>) {
tt.func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
// CHECK-NEXT: %0 -> %cst
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: if_cat
func.func @if_cat(%i1 : i1) {
tt.func @if_cat(%i1 : i1) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: %cst_0 -> %cst_0
@@ -115,11 +115,11 @@ func.func @if_cat(%i1 : i1) {
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield %b : tensor<32x16xf16, #A_SHARED>
}
return
tt.return
}
// CHECK-LABEL: if_alias
func.func @if_alias(%i1 : i1) {
tt.func @if_alias(%i1 : i1) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -130,11 +130,11 @@ func.func @if_alias(%i1 : i1) {
} else {
scf.yield %cst1 : tensor<16x16xf16, #A_SHARED>
}
return
tt.return
}
// CHECK-LABEL: for
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -150,11 +150,11 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
}
// CHECK-LABEL: for_if
func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -176,11 +176,11 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
}
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
}
// CHECK-LABEL: for_for_if
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -211,11 +211,11 @@ func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
}
scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
}
// CHECK-LABEL: cf_for
func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) {
tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -242,5 +242,5 @@ func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>,
gpu.barrier
// CHECK-NEXT: %9 -> %9
%9 = tt.cat %0, %0 {axis = 0 : i64} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
tt.return
}

View File

@@ -1,7 +1,7 @@
// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s
// CHECK-LABEL: @cast
func.func @cast() {
tt.func @cast() {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
@@ -10,13 +10,13 @@ func.func @cast() {
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
return
tt.return
}
// -----
// CHECK-LABEL: @add
func.func @add() {
tt.func @add() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
@@ -27,13 +27,13 @@ func.func @add() {
%3 = arith.constant dense<127> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%4 = arith.addi %1, %3 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @addptr
func.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
tt.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst1 = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
@@ -84,13 +84,13 @@ func.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.pt
%21 = tt.addptr %16, %12 : tensor<128x128x!tt.ptr<i32>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [8, 16], constancy = [128, 1], constant_value = <none>
%22 = tt.addptr %17, %12 : tensor<128x128x!tt.ptr<i64>>, tensor<128x128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @sub
func.func @sub() {
tt.func @sub() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
@@ -101,13 +101,13 @@ func.func @sub() {
%3 = arith.constant dense<129> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%4 = arith.subi %3, %1 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @mul
func.func @mul() {
tt.func @mul() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
@@ -122,13 +122,13 @@ func.func @mul() {
%5 = arith.constant dense<2> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256
%6 = arith.muli %4, %5 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @div
func.func @div() {
tt.func @div() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
@@ -153,14 +153,14 @@ func.func @div() {
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%11 = arith.divsi %10, %4 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @rem
func.func @rem() {
tt.func @rem() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
@@ -179,35 +179,35 @@ func.func @rem() {
%7 = arith.constant dense<66> : tensor<128xi32>
// CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none>
%8 = arith.remui %0, %7 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @broadcast
func.func @broadcast() {
tt.func @broadcast() {
// CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%0 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @splat
func.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
tt.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
return
tt.return
}
// -----
// CHECK-LABEL: @cmp
func.func @cmp() {
tt.func @cmp() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
@@ -226,13 +226,13 @@ func.func @cmp() {
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @logic
func.func @logic() {
tt.func @logic() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
@@ -255,13 +255,13 @@ func.func @logic() {
%9 = arith.ori %2, %4 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%10 = arith.xori %2, %4 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @select
func.func @select() {
tt.func @select() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
@@ -278,12 +278,12 @@ func.func @select() {
%5 = arith.select %4, %3, %7 : tensor<128xi1>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
return
tt.return
}
// -----
func.func @shift() {
tt.func @shift() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
@@ -296,12 +296,12 @@ func.func @shift() {
%4 = arith.shrsi %0, %2 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%5 = arith.shli %1, %2 : tensor<128xi32>
return
tt.return
}
// -----
func.func @max_min() {
tt.func @max_min() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = <none>
@@ -316,13 +316,13 @@ func.func @max_min() {
%5 = arith.constant dense<4> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8
%6 = arith.maxsi %4, %5 : tensor<128xi32>
return
tt.return
}
// -----
// CHECK-LABEL: @for
func.func @for() {
tt.func @for() {
// CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0
%a_init = arith.constant dense<0> : tensor<128x32xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1
@@ -343,13 +343,13 @@ func.func @for() {
// CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
}
return
tt.return
}
// -----
// CHECK-LABEL: @permute_2d
func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1
%cst = arith.constant dense<true> : tensor<128x128xi1>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
@@ -397,7 +397,7 @@ func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
tt.store %19, %20, %cst : tensor<128x128xf32>
return
tt.return
}
// -----
@@ -406,7 +406,7 @@ module {
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
// CHECK-LABEL: @store_constant_align
func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%pid = tt.get_program_id {axis = 0 : i32} : i32
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
@@ -430,7 +430,7 @@ func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%cst = arith.constant dense<0.0> : tensor<128xf32>
tt.store %5, %cst, %mask : tensor<128xf32>
return
tt.return
}
}
@@ -440,7 +440,7 @@ func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
// This IR is dumped from vecadd test.
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
// CHECK-LABEL: @vecadd_mask_align_16
func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c64_i32 : i32
@@ -461,7 +461,7 @@ func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
// CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %15, %13, %mask : tensor<64xf32>
return
tt.return
}
// -----
@@ -469,7 +469,7 @@ func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
// This IR is dumped from vecadd test.
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
// CHECK-LABEL: @vecadd_mask_align_1
func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c64_i32 : i32
@@ -489,5 +489,5 @@ func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %15, %13, %10 : tensor<64xf32>
return
tt.return
}

View File

@@ -13,7 +13,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -40,13 +40,13 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return
tt.return
// CHECK-NEXT: size = 4608
}
// Shared memory is available after a tensor's liveness range ends
// CHECK-LABEL: reusable
func.func @reusable(%A : !tt.ptr<f16>) {
tt.func @reusable(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
@@ -69,7 +69,7 @@ func.func @reusable(%A : !tt.ptr<f16>) {
// CHECK-NEXT: offset = 0, size = 1152
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
return
tt.return
// CHECK-NEXT: size = 4608
}
@@ -78,7 +78,7 @@ func.func @reusable(%A : !tt.ptr<f16>) {
// %cst1->%cst4
// %cst3->%g->%h->%i
// CHECK-LABEL: preallocate
func.func @preallocate(%A : !tt.ptr<f16>) {
tt.func @preallocate(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 1024, size = 512
@@ -107,13 +107,13 @@ func.func @preallocate(%A : !tt.ptr<f16>) {
%h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 2048, size = 4096
%i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 12288
}
// Unused tensors are immediately released
// CHECK-LABEL: unused
func.func @unused(%A : !tt.ptr<f16>) {
tt.func @unused(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 0, size = 512
@@ -122,13 +122,13 @@ func.func @unused(%A : !tt.ptr<f16>) {
%cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 1024, size = 1024
%a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
tt.return
// CHECK: size = 2048
}
// cst0 is alive through the entire function, it cannot be released before the end of the function
// CHECK-LABEL: longlive
func.func @longlive(%A : !tt.ptr<f16>) {
tt.func @longlive(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 512
@@ -151,13 +151,13 @@ func.func @longlive(%A : !tt.ptr<f16>) {
%c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 1024
%d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 2560
}
// This example triggers graph coloring with > 1 colors.
// CHECK-LABEL: multi_color
func.func @multi_color(%A : !tt.ptr<f16>) {
tt.func @multi_color(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 64
%cst = arith.constant dense<0.000000e+00> : tensor<4x8xf16, #A_SHARED>
// CHECK-NEXT: offset = 1216, size = 32
@@ -199,39 +199,39 @@ func.func @multi_color(%A : !tt.ptr<f16>) {
%cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL>
// CHECK-NEXT: size = 2656
return
tt.return
}
// CHECK-LABEL: alloc
func.func @alloc(%A : !tt.ptr<f16>) {
tt.func @alloc(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: offset = 0, size = 512
%cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: scratch
func.func @scratch() {
tt.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: scratch offset = 0, size = 512
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
return
tt.return
// CHECK-NEXT: size = 512
}
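// NOTE (illustrative, not from the diff): the reduce operand lives in
// registers (#AL), yet the reduction still needs a shared-memory scratch
// buffer of 16 * 16 * 2 = 512 bytes for the 16x16 f16 tile, which is what the
// "scratch offset" check above records.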
// CHECK-LABEL: trans
func.func @trans(%A : !tt.ptr<f16>) {
tt.func @trans(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
return
tt.return
}
// CHECK-LABEL: insert_slice_async
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
tt.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -239,24 +239,24 @@ func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: extract_slice
func.func @extract_slice(%A : !tt.ptr<f16>) {
tt.func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 512
}
// B0 -> (B1) -> B0
// Memory used by B1 can be reused by B0.
// CHECK-LABEL: if
func.func @if(%i1 : i1) {
tt.func @if(%i1 : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 512
@@ -273,14 +273,14 @@ func.func @if(%i1 : i1) {
%cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 1024, size = 1024
%a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 2048
}
// B0 -> (B1) -> (B2) -> B0
// Memory used by B0 cannot be reused by B1 or B2.
// CHECK-LABEL: if_else
func.func @if_else(%i1 : i1) {
tt.func @if_else(%i1 : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 512
@@ -300,14 +300,14 @@ func.func @if_else(%i1 : i1) {
}
// CHECK-NEXT: offset = 1024, size = 1024
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 3072
}
// Block arguments and yields are memory aliases that do not trigger a new
// allocation.
// CHECK-LABEL: for
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -317,12 +317,12 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
// CHECK-NEXT: size = 24576
}
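// NOTE (illustrative, not from the diff): each tensor<128x32xf16> buffer is
// 128 * 32 * 2 = 8192 bytes. The loop-carried block arguments and yields alias
// the three init buffers instead of allocating new ones, so the total stays at
// 3 * 8192 = 24576 bytes.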
// CHECK-LABEL: for_if_slice
func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -337,13 +337,13 @@ func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f1
}
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
// CHECK-NEXT: size = 24576
}
// c0 cannot be released in the loop
// CHECK-LABEL: for_use_ancestor
func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -356,14 +356,14 @@ func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.pt
%c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
// CHECK-NEXT: size = 32768
}
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges span the entire function up to cst2.
// So they cannot be reused by cst0 and cst1, but they can be reused by cst2.
// CHECK-LABEL: for_for_if
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -387,7 +387,7 @@ func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
}
// CHECK-NEXT: offset = 0, size = 8192
%cst2 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
return
tt.return
// CHECK-NEXT: size = 40960
}
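// NOTE (illustrative, not from the diff): every hunk in these test files
// applies the same mechanical rewrite from the func dialect to the new Triton
// ops. A hypothetical minimal kernel before and after:
//   func.func @kernel(%A : !tt.ptr<f16>) { return }      // before
//   tt.func @kernel(%A : !tt.ptr<f16>) { tt.return }     // after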

View File

@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
// There shouldn't be any membar with the dot op encoding.
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -38,11 +38,11 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return
tt.return
}
// CHECK-LABEL: raw_single_block
func.func @raw_single_block(%A : !tt.ptr<f16>) {
tt.func @raw_single_block(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -51,11 +51,11 @@ func.func @raw_single_block(%A : !tt.ptr<f16>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED>
return
tt.return
}
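// NOTE (illustrative, not from the diff): this is the read-after-write (RAW)
// hazard the membar pass guards against: once a convert_layout has written the
// shared buffer, a gpu.barrier must land before %3 reads it back.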
// CHECK-LABEL: war_single_block
func.func @war_single_block(%A : !tt.ptr<f16>) {
tt.func @war_single_block(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -67,11 +67,11 @@ func.func @war_single_block(%A : !tt.ptr<f16>) {
// CHECK: gpu.barrier
// CHECK-NEXT: %4 = triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
return
tt.return
}
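// NOTE (illustrative, not from the diff): the mirror case, write-after-read
// (WAR): %4 overwrites shared memory that an earlier op may still be reading,
// so the barrier goes before the overwriting convert_layout.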
// CHECK-LABEL: scratch
func.func @scratch() {
tt.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
@@ -80,11 +80,11 @@ func.func @scratch() {
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
return
tt.return
}
// CHECK-LABEL: async_wait
func.func @async_wait() {
tt.func @async_wait() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
@@ -93,21 +93,21 @@ func.func @async_wait() {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
return
tt.return
}
// CHECK-LABEL: alloc
func.func @alloc() {
tt.func @alloc() {
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%1 = tt.cat %0, %0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %1 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
return
tt.return
}
// CHECK-LABEL: extract_slice
func.func @extract_slice() {
tt.func @extract_slice() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
%0 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
@@ -117,19 +117,19 @@ func.func @extract_slice() {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: trans
func.func @trans() {
tt.func @trans() {
// CHECK-NOT: gpu.barrier
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
return
tt.return
}
// CHECK-LABEL: insert_slice_async_op
func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
tt.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -142,11 +142,11 @@ func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: insert_slice_op
func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
tt.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -162,12 +162,12 @@ func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
return
tt.return
}
// If the `if` branch inserted a barrier for %cst0 and %cst1 but the `else` branch didn't, the barrier should be inserted in the parent region
// CHECK-LABEL: multi_blocks
func.func @multi_blocks(%i1 : i1) {
tt.func @multi_blocks(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -186,12 +186,12 @@ func.func @multi_blocks(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%2 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
tt.return
}
// If both branches inserted a barrier for %cst0 and %cst1, the barrier doesn't need to be inserted in the parent region
// CHECK-LABEL: multi_blocks_join_barrier
func.func @multi_blocks_join_barrier(%i1 : i1) {
tt.func @multi_blocks_join_barrier(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -206,12 +206,12 @@ func.func @multi_blocks_join_barrier(%i1 : i1) {
scf.yield
}
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
return
tt.return
}
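// NOTE (illustrative, not from the diff): contrast with multi_blocks above:
// when only one branch synchronizes, the barrier is hoisted to the parent
// region; when both branches synchronize, as here, no extra barrier is needed
// after the scf.if.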
// Reading a yielded tensor requires a barrier
// CHECK-LABEL: multi_blocks_yield
func.func @multi_blocks_yield(%i1 : i1) {
tt.func @multi_blocks_yield(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
@@ -229,12 +229,12 @@ func.func @multi_blocks_yield(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%4 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
return
tt.return
}
// Even though the entry block doesn't have a barrier, the successors should have barriers
// CHECK-LABEL: multi_blocks_entry_no_shared
func.func @multi_blocks_entry_no_shared(%i1 : i1) {
tt.func @multi_blocks_entry_no_shared(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
// CHECK: gpu.barrier
@@ -251,12 +251,12 @@ func.func @multi_blocks_entry_no_shared(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
return
tt.return
}
// Conservatively add a barrier as if the branch (%i1) is never taken
// CHECK-LABEL: multi_blocks_noelse
func.func @multi_blocks_noelse(%i1 : i1) {
tt.func @multi_blocks_noelse(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -268,12 +268,12 @@ func.func @multi_blocks_noelse(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
return
tt.return
}
// Conservatively add a barrier as if the branch (%i2) is never taken
// CHECK-LABEL: multi_blocks_nested_scf
func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -293,11 +293,11 @@ func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
return
tt.return
}
// CHECK-LABEL: for
func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
@@ -307,13 +307,13 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%5 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
}
// Although a_shared and b_shared are synced before entering the loop,
// they are reassociated with aliases (c_shared) and thus require a barrier.
// CHECK-LABEL: for_alias
func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
@@ -330,13 +330,13 @@ func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
tt.return
}
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
// So we need a barrier both before and after cst1
// CHECK-LABEL: for_reuse
func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
@@ -355,11 +355,11 @@ func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
tt.return
}
// CHECK-LABEL: for_reuse_nested
func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
@@ -381,12 +381,12 @@ func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.pt
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%15 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
tt.return
}
// repeatedly write to the same shared memory addresses
// CHECK-LABEL: for_for_if
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
@@ -407,12 +407,12 @@ func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
}
scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
}
// c_block_next can be converted from either c_shared_init or c_shared_next_next
// CHECK-LABEL: for_if_for
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
@@ -438,11 +438,11 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
%b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
tt.return
}
// CHECK-LABEL: cf_if
func.func @cf_if(%i1 : i1) {
tt.func @cf_if(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.cond_br %i1, ^bb1, ^bb2
@@ -455,10 +455,10 @@ func.func @cf_if(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
return
tt.return
}
func.func @cf_if_else(%i1 : i1) {
tt.func @cf_if_else(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.cond_br %i1, ^bb1, ^bb2
@@ -479,10 +479,10 @@ func.func @cf_if_else(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<64x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
return
tt.return
}
func.func @cf_if_else_return(%i1 : i1) {
tt.func @cf_if_else_return(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.cond_br %i1, ^bb1, ^bb2
@@ -490,12 +490,12 @@ func.func @cf_if_else_return(%i1 : i1) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
return
tt.return
^bb2: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
return
tt.return
}
}

View File

@@ -1,6 +1,6 @@
// RUN: triton-opt %s | FileCheck %s
func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
tt.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
// scalar -> scalar
// CHECK: i64 -> !tt.ptr<f32>
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
@@ -32,10 +32,10 @@ func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i6
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
// CHECK: tensor<16xf32> to tensor<16xf16>
%8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
return
tt.return
}
func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
tt.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
// scalar -> scalar
// CHECK: !tt.ptr<f32>
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
@@ -51,10 +51,10 @@ func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
%tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
// CHECK: tensor<16x!tt.ptr<f32>>
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
return
tt.return
}
func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
tt.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
// Test if Load/Store ops can handle scalar values
%other = arith.constant 0.0e+0 : f32
@@ -73,10 +73,10 @@ func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
tt.store %ptr, %b, %mask : f32
// CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} {cache = 1 : i32, evict = 1 : i32} : f32
tt.store %ptr, %c, %mask : f32
return
tt.return
}
func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
tt.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
// Test if reduce ops infer types correctly
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
@@ -98,10 +98,10 @@ func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
tt.store %ptr1x2, %c : tensor<1x2xf32>
tt.store %ptr1, %e : tensor<1xf32>
tt.store %ptr, %g : f32
return
tt.return
}
func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
// Test if dot ops infer types correctly
%v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
%v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
@@ -128,5 +128,5 @@ func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
tt.store %ptr32x32, %r2 : tensor<32x32xf32>
tt.store %ptr128x128, %r3 : tensor<128x128xf32>
tt.store %ptr1x1, %r4 : tensor<1x1xf32>
return
tt.return
}

View File

@@ -1,17 +1,17 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
func.func @ops() {
tt.func @ops() {
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
return
tt.return
}
// -----
func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test if LoadOp is lowered properly (see #771)
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%mask = arith.constant dense<true> : tensor<128xi1>
@@ -25,12 +25,12 @@ func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
tt.store %ptrs, %a : tensor<128xf32>
tt.store %ptrs, %b : tensor<128xf32>
tt.store %ptrs, %c : tensor<128xf32>
return
tt.return
}
// -----
func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test that the total number of threads per warp is 32
// Test if the total number of warps is 2
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
@@ -49,5 +49,5 @@ func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
%c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
return
tt.return
}

View File

@@ -4,9 +4,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}}
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
tt.return
}
} // end module
@@ -15,11 +15,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_load
func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK: llvm.inline_asm
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -28,13 +28,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: vectorized_load
func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
tt.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b32
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -43,13 +43,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: vectorized_load_f16
func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
tt.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b16
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b16
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
return
tt.return
}
}
@@ -59,10 +59,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
tt.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -72,10 +72,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other_vec
func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
tt.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_no_vec
func.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
tt.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -127,7 +127,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -136,7 +136,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
func.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
tt.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -163,7 +163,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// Store 4 elements to global memory with a single vectorized store instruction
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
tt.return
}
}
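// NOTE (illustrative, not from the diff): tt.divisibility = 16 means the
// pointers are 16-byte aligned, which is exactly one 4 x f32 transaction per
// thread (sizePerThread = [4]), hence the single st.global.v4.b32 above.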
@@ -173,7 +173,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
// Note: %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes its divisibility is 1; this affects the mask's alignment and further restricts the load/store ops' vector width to 1.
module attributes {"triton_gpu.num-warps" = 2 : i32} {
func.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
tt.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c64_i32 : i32
@@ -194,7 +194,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
tt.store %15, %13, %10 : tensor<64xf32, #blocked>
return
tt.return
}
}
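// NOTE (illustrative, not from the diff): the pointers here are 16-byte
// aligned, but since %n_elements carries no divisibility hint the mask can
// change at any element boundary, which restricts the load/store vector width
// to 1 (hence the test name).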
@@ -203,7 +203,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec2
func.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
tt.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -239,7 +239,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -248,7 +248,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
func.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
tt.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -278,7 +278,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -291,7 +291,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_view_broadcast
func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
@@ -306,7 +306,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
return
tt.return
}
}
@@ -315,13 +315,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_make_range
func.func @basic_make_range() {
tt.func @basic_make_range() {
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
return
tt.return
}
}
@@ -330,11 +330,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addf
func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
tt.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
// CHECK: llvm.fadd
// CHECK: llvm.fadd
%1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
return
tt.return
}
}
@@ -343,11 +343,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addi
func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
tt.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.add
// CHECK: llvm.add
%1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
return
tt.return
}
}
@@ -355,10 +355,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
func.func @basic_program_id() {
tt.func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
return
tt.return
}
}
@@ -367,11 +367,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addptr
func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
tt.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.getelementptr
// CHECK: llvm.getelementptr
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
return
tt.return
}
}
@@ -381,14 +381,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: basic_alloc_tensor
func.func @basic_alloc_tensor() {
tt.func @basic_alloc_tensor() {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK-NEXT: llvm.bitcast
// CHECK-NEXT: llvm.mlir.constant
// CHECK-NEXT: llvm.getelementptr
// CHECK-NEXT: llvm.bitcast
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
return
tt.return
}
}
@@ -398,7 +398,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: basic_extract_slice
func.func @basic_extract_slice() {
tt.func @basic_extract_slice() {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%index = arith.constant 1 : i32
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
return
tt.return
}
}
@@ -431,10 +431,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_async_wait
func.func @basic_async_wait() {
tt.func @basic_async_wait() {
// CHECK: cp.async.wait_group 0x4
triton_gpu.async_wait {num = 4: i32}
return
tt.return
}
}
@@ -450,7 +450,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_fallback
func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
tt.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -473,7 +473,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16>, #AL> -> tensor<2x16x64xf16, #A>
return
tt.return
}
}
@@ -489,7 +489,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v4
func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -515,7 +515,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
triton_gpu.async_commit_group
return
tt.return
}
}
@@ -531,7 +531,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v1
func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -561,7 +561,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
triton_gpu.async_commit_group
return
tt.return
}
}
@@ -576,7 +576,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
tt.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
@@ -618,7 +618,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
triton_gpu.async_commit_group
return
tt.return
}
}
@@ -627,12 +627,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: basic_splat
func.func @basic_splat(%ptr: !tt.ptr<f32>) {
tt.func @basic_splat(%ptr: !tt.ptr<f32>) {
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
%0 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>,#blocked0>
return
tt.return
}
}
@@ -641,13 +641,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
tt.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -658,7 +658,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked
func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
@@ -694,7 +694,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
tt.return
}
}
@@ -705,7 +705,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
@@ -717,7 +717,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
tt.return
}
}
@@ -728,7 +728,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
@@ -746,7 +746,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
tt.return
}
}
@@ -759,7 +759,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot
func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
// CHECK: llvm.inline_asm
@@ -776,14 +776,14 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
return
tt.return
}
}
// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
// func.func @make_range_sliced_layout() {
// tt.func @make_range_sliced_layout() {
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
// return
// }
@@ -796,7 +796,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav2_block
func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -805,7 +805,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
return
tt.return
}
}
@@ -816,7 +816,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -829,7 +829,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
return
tt.return
}
}
@@ -839,13 +839,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_shared
func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
return
tt.return
}
}
@@ -855,10 +855,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice0
func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
return
tt.return
}
}
@@ -868,10 +868,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
return
tt.return
}
}
@@ -881,14 +881,14 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked_to_blocked_ptr
func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
// CHECK: llvm.ptrtoint
// CHECK: llvm.store
// CHECK: nvvm.barrier0
// CHECK: llvm.inttoptr
// CHECK-COUNT-4: llvm.insertvalue
%cvt = triton_gpu.convert_layout %src : (tensor<32x!tt.ptr<f32>, #blocked0>) -> tensor<32x!tt.ptr<f32>, #blocked1>
return
tt.return
}
}
@@ -900,7 +900,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
@@ -913,7 +913,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<128x256xf32, #blocked>
return
tt.return
}
}
@@ -926,7 +926,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
@@ -938,7 +938,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x64x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<32x64xf32, #blocked>
return
tt.return
}
}
@@ -949,7 +949,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
// CHECK: llvm.intr.fmuladd
@@ -960,7 +960,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
tt.store %36, %28 : tensor<32x32xf32, #blocked>
return
tt.return
}
}
@@ -973,7 +973,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
func.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK: llvm.inline_asm
@@ -999,7 +999,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<32x32xf32, #blocked>
return
tt.return
}
}
@@ -1008,14 +1008,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -1023,12 +1023,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32_scalar
func.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
return
tt.return
}
}
@@ -1037,14 +1037,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
func.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
return
tt.return
}
}
@@ -1052,12 +1052,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32_scalar
func.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
tt.store %arg0, %arg1 : f32
return
tt.return
}
}
@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
%blockidx = tt.get_program_id {axis=0:i32} : i32
%blockidy = tt.get_program_id {axis=1:i32} : i32
%blockidz = tt.get_program_id {axis=2:i32} : i32
@@ -1078,7 +1078,7 @@ func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
tt.store %a, %0 : tensor<32xi32, #blocked0>
return
tt.return
}
}
@@ -1086,7 +1086,7 @@ func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
tt.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z
@@ -1098,7 +1098,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
tt.store %a, %0 : tensor<32xi32, #blocked0>
return
tt.return
}
}
@@ -1106,12 +1106,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: test_index_cache
func.func @test_index_cache() {
tt.func @test_index_cache() {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
return
tt.return
}
}
@@ -1120,12 +1120,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_base_index_cache
func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
return
tt.return
}
}
@@ -1134,7 +1134,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_index_cache_different_block
func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
@@ -1143,6 +1143,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
cf.br ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
return
tt.return
}
}
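
The hunks above and below are the mechanical core of the change: every `func.func` becomes `tt.func` and every `return` becomes `tt.return`, with the test bodies otherwise untouched. As a minimal sketch of the new spelling (the kernel name and signature here are illustrative, not taken from the test suite):

module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // tt.func replaces func.func; arguments, attributes, and body stay the same.
  tt.func public @example_kernel(%arg0: !tt.ptr<f32>) {
    // tt.return replaces the func-dialect return as the terminator.
    tt.return
  }
}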

@@ -8,9 +8,9 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
return
tt.return
}
}

@@ -6,9 +6,9 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
return
tt.return
}
}

@@ -2,7 +2,7 @@
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
// CHECK-LABEL: @test_combine_dot_add_pattern
func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
@@ -19,12 +19,12 @@ func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x12
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// COM: CHECK-LABEL: @test_combine_addptr_pattern
func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
tt.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
@@ -42,12 +42,12 @@ func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
return %ptr1 : tensor<8x!tt.ptr<f32>>
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_combine_select_masked_load_pattern
func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
tt.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
%false_val = arith.constant dense<0.0> : tensor<8xf32>
@@ -59,12 +59,12 @@ func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>,
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%1 = arith.select %cond, %y, %false_val : tensor<8xf32>
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
return %0, %1 : tensor<8xf32>, tensor<8xf32>
// CHECK: tt.return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
tt.return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
@@ -82,21 +82,21 @@ func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8xf32>
%bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
// CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
return %bst_out : tensor<8x2xf32>
// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
tt.return %bst_out : tensor<8x2xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
%false_mask = arith.constant dense<false> : tensor<8xi1>
%other_val = arith.constant dense<0.0> : tensor<8xf32>
@@ -112,12 +112,12 @@ func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -
// false_mask with other. It should become "other" (i.e., %y)
%z = tt.load %ptr, %false_mask, %y {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
// CHECK: return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
// CHECK: tt.return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
tt.return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
tt.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
%other_val = arith.constant dense<0.0> : tensor<8xf32>
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
@@ -126,11 +126,11 @@ func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
return %x, %y: tensor<8xf32>, tensor<8xf32>
tt.return %x, %y: tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
tt.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
%false_mask = arith.constant dense<false> : tensor<8xi1>
@@ -138,31 +138,31 @@ func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>,
tt.store %ptr, %val, %true_mask : tensor<8xf32>
// The following store should disappear.
// CHECK-NEXT: return
// CHECK-NEXT: tt.return
tt.store %ptr, %val, %false_mask : tensor<8xf32>
return
tt.return
}
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
tt.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
tt.store %ptr, %val, %mask : tensor<8xf32>
return
tt.return
}
// CHECK-LABEL: @test_canonicalize_expand_dims
func.func @test_canonicalize_expand_dims(%arg0: tensor<f32>) -> (tensor<1x8xf32>) {
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>) -> (tensor<1x8xf32>) {
%splat = tt.splat %arg0 : (tensor<f32>) -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg0 : (tensor<f32>) -> tensor<1x8xf32>
%ed = tt.expand_dims %splat {axis = 0 : i32} : (tensor<8xf32>) -> tensor<1x8xf32>
return %ed : tensor<1x8xf32>
tt.return %ed : tensor<1x8xf32>
}
// CHECK-LABEL: @test_canonicalize_view
func.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
%view0 = tt.view %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.view %arg0 : (tensor<8xf32>) -> tensor<4x2xf32>
%view1 = tt.view %view0 : (tensor<2x4xf32>) -> tensor<4x2xf32>
@@ -175,11 +175,11 @@ func.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
%add = arith.addf %view3, %arg0 : tensor<8xf32>
return %view1, %view2, %add : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>
tt.return %view1, %view2, %add : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_broadcast
func.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f32>) -> (tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>) {
tt.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f32>) -> (tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>) {
%broadcast0 = tt.broadcast %arg0 : (tensor<1x1x8xf32>) -> tensor<1x2x8xf32>
// CHECK: %{{.*}} = tt.broadcast %arg0 : (tensor<1x1x8xf32>) -> tensor<4x2x8xf32>
%broadcast1 = tt.broadcast %broadcast0 : (tensor<1x2x8xf32>) -> tensor<4x2x8xf32>
@@ -192,11 +192,11 @@ func.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<1x1x8xf32>
%add = arith.addf %broadcast3, %arg0 : tensor<1x1x8xf32>
return %broadcast1, %broadcast2, %add : tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>
tt.return %broadcast1, %broadcast2, %add : tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>
}
// CHECK-LABEL: @test_fold_views
func.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>) {
tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>) {
%a = arith.constant dense<1.0> : tensor<1x128xf32>
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32>
@@ -208,5 +208,5 @@ func.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<1x1x128xf32>
%d = tt.expand_dims %a {axis = 0: i32} : (tensor<1x128xf32>) -> tensor<1x1x128xf32>
return %b, %c, %d : tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>
tt.return %b, %c, %d : tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>
}
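
These canonicalization tests only exercise `tt.func` and `tt.return`; the third op in the set, `tt.call`, is not covered here. A hedged sketch of what a call site would look like, assuming `tt.call` keeps the `func.call` assembly format (the `@helper` function below is hypothetical):

module {
  tt.func @helper(%x: tensor<8xf32>) -> tensor<8xf32> {
    tt.return %x : tensor<8xf32>
  }
  tt.func public @caller(%arg0: tensor<8xf32>) -> tensor<8xf32> {
    // tt.call replaces func.call; the callee is a flat symbol reference.
    %0 = tt.call @helper(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
    tt.return %0 : tensor<8xf32>
  }
}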

@@ -1,5 +1,5 @@
// RUN: triton-opt %s -triton-rewrite-tensor-pointer | FileCheck %s
func.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
%c31_i32 = arith.constant 31 : i32
%c127_i32 = arith.constant 127 : i32
%c1 = arith.constant 1 : index
@@ -79,5 +79,5 @@ func.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}
%53 = tt.broadcast %51 : (tensor<1x32xi1>) -> tensor<128x32xi1>
%54 = arith.andi %52, %53 : tensor<128x32xi1>
tt.store %45, %30, %54 {cache = 1 : i32, evict = 1 : i32} : tensor<128x32xf16>
return
tt.return
}

@@ -1,7 +1,7 @@
// RUN: triton-opt %s -verify-diagnostics
module {
func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%c256_i32 = arith.constant 256 : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -39,11 +39,11 @@ module {
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
tt.store %17, %15#0, %6 : tensor<256xf32>
return
tt.return
}
}
// module {
// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
// tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
// %c64 = arith.constant 64 : index
// %c32 = arith.constant 32 : index
// %c0 = arith.constant 0 : index
@@ -125,6 +125,6 @@ module {
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// return
// tt.return
// }
// }

@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32}) {
@@ -47,7 +47,7 @@ func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
return
tt.return
}
}

@@ -9,34 +9,34 @@
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: cst
func.func @cst() -> tensor<1024xi32, #layout1> {
tt.func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: return %cst : tensor<1024xi32, [[$target_layout]]>
return %1: tensor<1024xi32, #layout1>
// CHECK: tt.return %cst : tensor<1024xi32, [[$target_layout]]>
tt.return %1: tensor<1024xi32, #layout1>
}
// CHECK-LABEL: range
func.func @range() -> tensor<1024xi32, #layout1> {
tt.func @range() -> tensor<1024xi32, #layout1> {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: return %0 : tensor<1024xi32, [[$target_layout]]>
return %1: tensor<1024xi32, #layout1>
// CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]>
tt.return %1: tensor<1024xi32, #layout1>
}
// CHECK-LABEL: splat
func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
tt.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: return %0 : tensor<1024xi32, [[$target_layout]]>
return %1: tensor<1024xi32, #layout1>
// CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]>
tt.return %1: tensor<1024xi32, #layout1>
}
// CHECK-LABEL: remat
func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
@@ -44,7 +44,7 @@ func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%4 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
%6 = arith.addi %3, %5 : tensor<1024xi32, #layout1>
return %6: tensor<1024xi32, #layout1>
tt.return %6: tensor<1024xi32, #layout1>
// CHECK: %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
// CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
// CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
@@ -52,24 +52,24 @@ func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]>
// CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]>
// CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]>
// CHECK: return %6 : tensor<1024xi32, [[$target_layout]]>
// CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]>
}
// Always rematerialize single value loads
// CHECK-LABEL: remat_single_value
func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #layout1>) -> tensor<1xi32, #layout0>
%3 = triton_gpu.convert_layout %0 : (tensor<1x!tt.ptr<i32>, #layout1>) -> tensor<1x!tt.ptr<i32>, #layout0>
tt.store %3, %2 : tensor<1xi32, #layout0>
return
tt.return
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
%2 = tt.addptr %0, %1 : tensor<16x!tt.ptr<i32>, #layout1>, tensor<16xi32, #layout1>
@@ -78,12 +78,12 @@ func.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr<i32>, #layout1>) -> tensor<16x!tt.ptr<i32>, #layout0>
tt.store %5, %4 : tensor<16xi32, #layout0>
return
tt.return
}
}
// CHECK-LABEL: if
func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -96,11 +96,11 @@ func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%6 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout1>) -> tensor<1024xi32, #layout0>
tt.store %5, %6 : tensor<1024xi32, #layout0>
}
return
tt.return
}
// CHECK-LABEL: if_convert_else_not
func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -117,11 +117,11 @@ func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
}
// CHECK-NOT: triton_gpu.convert_layout
tt.store %5, %8 : tensor<1024xi32, #layout1>
return
tt.return
}
// CHECK-LABEL: if_not_else_convert
func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -138,11 +138,11 @@ func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
}
// CHECK-NOT: triton_gpu.convert_layout
tt.store %5, %8 : tensor<1024xi32, #layout1>
return
tt.return
}
// CHECK-LABEL: if_else_both_convert
func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -161,7 +161,7 @@ func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
// disabledCHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.convert_layout
tt.store %5, %8 : tensor<1024xi32, #layout1>
return
tt.return
}
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -173,12 +173,12 @@ func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: transpose
func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[$row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[$col_layout]]>
// CHECK: return
// CHECK: tt.return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
@@ -210,65 +210,65 @@ func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i3
%25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
%26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
return
tt.return
}
// CHECK-LABEL: loop
func.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[$row_layout]]>
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]>
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]>
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]>
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
}
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
return
tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[$row_layout]]>
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]>
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]>
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]>
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
}
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
tt.return
}
// CHECK-LABEL: vecadd
func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -295,12 +295,12 @@ func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.p
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #layout1>
tt.store %21, %22 : tensor<256xf32, #layout1>
return
tt.return
}
// Select has args with different element types
// CHECK-LABEL: select
func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
@@ -346,12 +346,12 @@ func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.p
tt.store %31, %32, %33 : tensor<1x512xf64, #blocked3>
scf.yield %30 : tensor<1x512xf64, #blocked2>
}
return
tt.return
}
// Make sure the following IR doesn't hang the compiler.
// CHECK-LABEL: long_func
func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
%cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
@@ -742,13 +742,13 @@ func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %a
%365 = triton_gpu.convert_layout %364 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%366 = triton_gpu.convert_layout %343 : (tensor<1024xf64, #blocked0>) -> tensor<1024xf64, #blocked0>
tt.store %365, %366 : tensor<1024xf64, #blocked0>
return
tt.return
}
// A mnist model from torch inductor.
// Check that topological sort is working correctly and there's no unnecessary convert
// CHECK-LABEL: mnist
func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
%cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
@@ -822,7 +822,7 @@ func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1:
%62 = triton_gpu.convert_layout %58 : (tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked4>
%63 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4>
tt.store %61, %62, %63 : tensor<16x16xf32, #blocked4>
return
tt.return
}
// -----
@@ -835,7 +835,7 @@ func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1:
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
// cmpf and cmpi have different operands and result types
// CHECK-LABEL: cmp
func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
@@ -968,14 +968,14 @@ func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !
%82 = triton_gpu.convert_layout %54 : (tensor<64x64xi1, #blocked2>) -> tensor<64x64xi1, #blocked4>
tt.store %80, %81, %82 : tensor<64x64xf16, #blocked4>
}
return
tt.return
}
// -----
// Just make sure it doesn't crash on non-tensor types.
// CHECK-LABEL: if_no_tensor
func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c-1_i64 = arith.constant -1 : i64
%cst = arith.constant 0.000000e+00 : f32
@@ -996,7 +996,7 @@ func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%8 = tt.load %5, %7, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
%9 = tt.addptr %arg1, %0 : !tt.ptr<f32>, i32
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
return
tt.return
}
// -----
@@ -1009,7 +1009,7 @@ func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
func.func public @reduce_cvt1(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) {
tt.func public @reduce_cvt1(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) {
%cst = arith.constant dense<0> : tensor<1x2xi32, #blocked>
%cst_0 = arith.constant dense<2> : tensor<1x2xi32, #blocked>
%0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked1>
@@ -1029,7 +1029,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%14 = triton_gpu.convert_layout %12 : (tensor<1x2xi64, #blocked>) -> tensor<1x2xi64, #blocked3>
%15 = triton_gpu.convert_layout %3 : (tensor<1x2xi1, #blocked>) -> tensor<1x2xi1, #blocked3>
tt.store %13, %14, %15 {cache = 1 : i32, evict = 1 : i32} : tensor<1x2xi64, #blocked3>
return
tt.return
}
}
@@ -1045,7 +1045,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
tt.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
%c3136_i32 = arith.constant 3136 : index
%c256_i32 = arith.constant 256 : index
@@ -1104,6 +1104,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%25 = triton_gpu.convert_layout %21 : (tensor<1x1xf32, #blocked>) -> tensor<1x1xf32, #blocked>
%26 = triton_gpu.convert_layout %7 : (tensor<1x1xi1, #blocked>) -> tensor<1x1xi1, #blocked>
tt.store %24, %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32, #blocked>
return
tt.return
}
}
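
One note on the CHECK-line updates in this file: FileCheck matches patterns as substrings of a line, so an old `// CHECK: return %cst` directive would still pass against `tt.return %cst`. The hunks above tighten such directives to `// CHECK: tt.return ...` so the tests assert the new spelling exactly rather than passing by accident.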

@@ -11,7 +11,7 @@
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
// CHECK: func.func @matmul_loop
// CHECK: tt.func @matmul_loop
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -47,7 +47,7 @@
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop(%lb : index, %ub : index, %step : index,
tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
// A ptrs
@@ -88,10 +88,10 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return %loop#2: tensor<128x128xf32, #C>
tt.return %loop#2: tensor<128x128xf32, #C>
}
// CHECK: func.func @matmul_loop_nested
// CHECK: tt.func @matmul_loop_nested
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -120,7 +120,7 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{
@@ -162,11 +162,11 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
scf.yield %loop2#2 : tensor<128x128xf32, #C>
}
return %loop1#0 : tensor<128x128xf32, #C>
tt.return %loop1#0 : tensor<128x128xf32, #C>
}
// CHECK: func.func @matmul_loop_single_pipeline
// CHECK: tt.func @matmul_loop_single_pipeline
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -187,7 +187,7 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
// A ptrs
@@ -222,10 +222,10 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return %loop#1 : tensor<128x128xf32, #C>
tt.return %loop#1 : tensor<128x128xf32, #C>
}
// CHECK: func.func @lut_bmm_scalar
// CHECK: tt.func @lut_bmm_scalar
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.insert_slice_async
@@ -239,7 +239,7 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]]
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
func.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
%76: index,
%49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%75: !tt.ptr<i64>,
@@ -265,10 +265,10 @@ func.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
%92 = tt.addptr %arg21, %c1_i32 : !tt.ptr<i64>, i32
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, !tt.ptr<i64>
}
return %79#0 : tensor<16x16xf32, #C>
tt.return %79#0 : tensor<16x16xf32, #C>
}
// CHECK: func.func @lut_bmm_vector
// CHECK: tt.func @lut_bmm_vector
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.insert_slice_async
@@ -283,7 +283,7 @@ func.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]]
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
func.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32},
tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32},
%76: index,
%49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%75: tensor<16x!tt.ptr<i64>, #BLs1>,
@@ -311,5 +311,5 @@ func.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32,
%92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr<i64>, #BLs1>, tensor<16xi32, #BLs1>
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x!tt.ptr<i64>, #BLs1>
}
return %79#0 : tensor<16x16xf32, #C>
tt.return %79#0 : tensor<16x16xf32, #C>
}
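For reference, the rewrite applied uniformly across these pipeline lit tests is purely mechanical: `func.func` becomes `tt.func` and `return` becomes `tt.return`, while bodies, result types, and argument attributes such as `tt.divisibility` are untouched. A minimal sketch on a hypothetical kernel (not taken from this PR; `#C` assumed defined as in the tests above):

  // Before (Func dialect):
  //   func.func @kernel(%x: tensor<128x128xf32, #C>) -> tensor<128x128xf32, #C> {
  //     return %x : tensor<128x128xf32, #C>
  //   }
  // After (custom Triton ops):
  tt.func @kernel(%x: tensor<128x128xf32, #C>) -> tensor<128x128xf32, #C> {
    tt.return %x : tensor<128x128xf32, #C>
  }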

View File

@@ -4,7 +4,7 @@
// CHECK: offset = 49152, size = 49152
// CHECK: size = 98304
module {
func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1>
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -101,6 +101,6 @@ func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32_
%74 = tt.broadcast %72 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%75 = arith.andi %73, %74 : tensor<64x64xi1>
tt.store %66, %47#0, %75 : tensor<64x64xf32>
return
tt.return
}
}
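A note on the allocation CHECK values above, as one plausible reading rather than anything the test itself asserts: taking the `12c64_13c64_14c64` suffix of the kernel name to mean 64x64 f32 tiles, one tile is 64 * 64 * 4 B = 16384 B, and 49152 B = 3 * 16384 B; two such buffers at offsets 0 and 49152 then account for the total `size = 98304`.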

View File

@@ -11,7 +11,7 @@
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
// CHECK: func.func @matmul_loop
// CHECK: tt.func @matmul_loop
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
@@ -28,7 +28,7 @@
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -60,5 +60,5 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
}
return
tt.return
}
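The CHECK-DAG lines above encode the prefetcher's output shape; a hedged sketch of the emitted IR, with value names and exact operand syntax assumed rather than copied from the pass:

  // Peel a 128x16 slice of the 128x32 A operand out of shared memory and
  // convert it to the dot-operand layout before the loop body consumes it:
  //   %a0_slice = triton_gpu.extract_slice %a0[0, 0] [128, 16] ...
  //   %a0_pf    = triton_gpu.convert_layout %a0_slice : ... -> #A_OP
  // B is handled symmetrically with a 16x128 slice, and the loop yields
  // the next prefetched slices (%[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]).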

View File

@@ -9,7 +9,7 @@ using namespace mlir;
namespace {
struct TestAliasPass
: public PassWrapper<TestAliasPass, OperationPass<func::FuncOp>> {
: public PassWrapper<TestAliasPass, OperationPass<triton::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);

View File

@@ -6,7 +6,7 @@ using namespace mlir;
namespace {
struct TestAllocationPass
: public PassWrapper<TestAllocationPass, OperationPass<func::FuncOp>> {
: public PassWrapper<TestAllocationPass, OperationPass<triton::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

View File

@@ -7,7 +7,7 @@ using namespace mlir;
namespace {
struct TestAxisInfoPass
: public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
: public PassWrapper<TestAxisInfoPass, OperationPass<triton::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);

View File

@@ -11,7 +11,7 @@ using namespace mlir;
namespace {
struct TestMembarPass
: public PassWrapper<TestMembarPass, OperationPass<func::FuncOp>> {
: public PassWrapper<TestMembarPass, OperationPass<triton::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
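Each of these unit-test passes changes only its anchor op: the pass now runs on the Triton dialect's own function op instead of `func::FuncOp`. A minimal sketch of the resulting shape, assuming the standard MLIR `PassWrapper` API; `TestExamplePass`, `test-example`, and the include paths are hypothetical placeholders, not names from this PR:

  #include "mlir/Pass/Pass.h"
  #include "triton/Dialect/Triton/IR/Dialect.h"

  using namespace mlir;

  namespace {
  struct TestExamplePass
      : public PassWrapper<TestExamplePass, OperationPass<triton::FuncOp>> {
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExamplePass);

    StringRef getArgument() const final { return "test-example"; }

    void runOnOperation() override {
      // getOperation() now yields the Triton function op; the analyses
      // under test (alias, allocation, axis info, membar) walk its body
      // exactly as they previously walked func::FuncOp.
      triton::FuncOp func = getOperation();
      func.walk([](Operation *op) {
        (void)op; // inspect or print per-op analysis results here
      });
    }
  };
  } // namespace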