[FRONTEND] Support block pointer semantics (#1392)

This PR introduces new semantics: the **block pointer**, which makes it easier
and faster for users to load a block from a parent tensor.

Below is a detailed walkthrough of the API changes, by example:
```
# Make a block pointer, which points to a block in the parent tensor
# `base`: the parent tensor
# `shape`: the shape of the parent tensor
# `strides`: the strides of the parent tensor
# `offsets`: the offsets of the block in the parent tensor
# `block_shape`: the shape of the block
# `order`: the order of the data arrangement in memory
# Below is an example loading a 2D row-major matrix
block_ptr = tl.make_block_ptr(base=ptr, shape=(M, N), strides=(stride_m, stride_n), offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

# Advance the offsets; note that the striding information is already saved in `block_ptr`
# `base`: the block pointer to be advanced
# `offsets`: the offsets for each dimension
block_ptr = tl.advance(base=block_ptr, offsets=(BLOCK_M, -BLOCK_N))
block_ptr = tl.advance(base=block_ptr, offsets=(-BLOCK_M, BLOCK_N))

# Load from a block pointer; the output type is the dereferenced type of `block_ptr`, e.g. ptr<tensor<32x32xf32>> -> tensor<32x32xf32>
# `ptr`: the block pointer to load from
# `boundary_check`: a tuple of dimensions on which to check the boundary
# `padding_option`: padding strategy for out-of-bounds elements
val = tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

# Store through a block pointer; the pointer and the value tensor must have the same shape
# `ptr`: the block pointer to store to
# `value`: the tensor of values to store
# `boundary_check`: a tuple of dimensions on which to check the boundary (out-of-bounds elements are not written)
tl.store(block_ptr, val, boundary_check=(0, 1))
```
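
For orientation, here is a minimal sketch of a complete kernel that ties these calls together. It is modeled on the block-copy test added later in this PR; the kernel itself and its launch parameters are illustrative, not part of the change:

```
import triton
import triton.language as tl

@triton.jit
def copy_2d_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program instance copies one (BLOCK_M, BLOCK_N) tile of a row-major M x N matrix
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    in_block_ptr = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(stride_m, stride_n),
                                     offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                     block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    out_block_ptr = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, stride_n),
                                      offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                      block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    # Out-of-bounds elements are padded with zero on load and skipped on store
    tile = tl.load(in_block_ptr, boundary_check=(0, 1), padding_option="zero")
    tl.store(out_block_ptr, tile, boundary_check=(0, 1))
```

A launch such as `copy_2d_kernel[(triton.cdiv(M, 32), triton.cdiv(N, 32))](x, y, M, N, x.stride(0), x.stride(1), BLOCK_M=32, BLOCK_N=32)` then copies `x` into `y` tile by tile.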

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
Chenggang Zhao, 2023-03-28 07:46:49 +08:00 (committed by GitHub)
parent e3a763872b
commit 72b071253e
25 changed files with 1584 additions and 171 deletions

View File

@@ -1,9 +1,8 @@
#ifndef TRITON_IR_TRAITS_H_
#define TRITON_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
#include <iostream>
@@ -12,11 +11,9 @@ namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
LogicalResult verifySameOperandsEncoding(Operation *op);
// The rationale for this trait is to prevent users from creating programs
// that would have catastrophic register pressure and cause the compiler to
// hang.
@@ -25,7 +22,22 @@ LogicalResult verifySameOperandsEncoding(Operation *op);
// but we probably should limit number of elements (rather than bytes) to
// keep specs simple
int constexpr maxTensorNumElements = 1048576;
LogicalResult verifyTensorSize(Operation *op);
LogicalResult verifySameOperandsEncoding(Operation *op,
bool allowTensorPointerType = false);
LogicalResult
verifySameOperandsAndResultEncoding(Operation *op,
bool allowTensorPointerType = false);
LogicalResult verifySameLoadStoreOperandsShape(Operation *op);
LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op);
bool verifyLoadStorePointerAndValueType(Type valueType, Type ptrType);
} // namespace impl
template <class ConcreteType>
@@ -54,6 +66,44 @@ public:
}
};
template <typename ConcreteType>
class SameLoadStoreOperandsShape
: public TraitBase<ConcreteType, SameLoadStoreOperandsShape> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameLoadStoreOperandsShape(op);
}
};
template <typename ConcreteType>
class SameLoadStoreOperandsAndResultShape
: public TraitBase<ConcreteType, SameLoadStoreOperandsAndResultShape> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameLoadStoreOperandsAndResultShape(op);
}
};
template <typename ConcreteType>
class SameLoadStoreOperandsEncoding
: public TraitBase<ConcreteType, SameLoadStoreOperandsEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsEncoding(op,
/*allowTensorPointerType=*/true);
}
};
template <typename ConcreteType>
class SameLoadStoreOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameLoadStoreOperandsAndResultEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsAndResultEncoding(
op, /*allowTensorPointerType=*/true);
}
};
} // namespace OpTrait
} // namespace mlir

View File

@@ -3,7 +3,7 @@
include "mlir/IR/EnumAttr.td"
// Attrs for LoadOp
// Attributes for LoadOp
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
@@ -13,6 +13,7 @@ def TT_CacheModifierAttr : I32EnumAttr<
]> {
let cppNamespace = "::mlir::triton";
}
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
@@ -23,6 +24,16 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}
def TT_PaddingOptionAttr : I32EnumAttr<
"PaddingOption", "",
[
I32EnumAttrCase<"PAD_ZERO", 1, "zero">,
// We can not set the string value to "NAN" because it is a keyword in C++
I32EnumAttrCase<"PAD_NAN", 2, "nan">
]> {
let cppNamespace = "::mlir::triton";
}
// reduction
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",

View File

@@ -4,8 +4,11 @@
include "mlir/IR/OpBase.td"
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">;
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
#endif // TRITON_INTERFACES

View File

@@ -107,66 +107,104 @@ def TT_AddPtrOp : TT_Op<"addptr",
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
}
def TT_AdvanceOp : TT_Op<"advance",
[Pure,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
let summary = "Advance a tensor pointer by offsets";
let arguments = (ins TT_TensorPtr:$ptr, Variadic<I32>:$offsets);
let results = (outs TT_TensorPtr:$result);
let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)";
}
//
// Load/Store Ops
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
[SameLoadStoreOperandsAndResultShape,
SameLoadStoreOperandsAndResultEncoding,
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr", "getPointerTypeSameShape($_self)">,
"result", "ptr", "$_self",
"mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from result type or none",
"result", "other", "$_self",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let summary = "Load from a tensor of pointers or from a tensor pointer";
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, Optional<TT_BoolLike>:$mask,
Optional<TT_Type>:$other, OptionalAttr<DenseI32ArrayAttr>:$boundaryCheck,
OptionalAttr<TT_PaddingOptionAttr>:$padding, TT_CacheModifierAttr:$cache,
TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
let builders = [
// A tensor of pointers or a pointer to a scalar
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor pointer with boundary check and padding
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask and other
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A utility function to build the operation with all attributes
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "Optional<ArrayRef<int32_t>>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
];
// Format: `tt.load operands attrs : optional(type(ptr)) -> type(result)`
// We need an extra `optional(type(ptr))` for inferring the tensor pointer type with back compatibility
let hasCustomAssemblyFormat = 1;
let hasCanonicalizer = 1;
}
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
SameOperandsEncoding,
[SameLoadStoreOperandsShape,
SameLoadStoreOperandsEncoding,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeSameShape($_self)">,
"value", "ptr", "$_self",
"mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">,
TypesMatchWith<"infer mask type from value type",
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let summary = "Store by a tensor of pointers or by a tensor pointer";
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
OptionalAttr<DenseI32ArrayAttr>:$boundaryCheck,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache,
// A tensor of pointers or a pointer to a scalar
OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict)>,
// A tensor pointer with boundary check
OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<int32_t>":$boundaryCheck, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict)>
];
// Format: `tt.store operands attrs : optional(type(ptr)), type(val)
// We need an extra `optional(type(ptr))` for inferring the tensor pointer type with back compatibility
let hasCustomAssemblyFormat = 1;
let hasCanonicalizer = 1;
}
@@ -213,7 +251,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
return $old
}];
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}
@@ -438,4 +476,46 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)";
}
//
// Make a Tensor Pointer
//
def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
[Pure,
SameVariadicOperandSize,
TypesMatchWith<"infer pointer type from the result type",
"result", "base",
"getPointerType(getElementTypeOfTensorPointerType($_self))">]> {
let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified";
let description = [{
`tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a
pointer to the block tensor, e.g. returns a type of `tt.ptr<tensor<8x8xf16>>`.
}];
// TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints.
let arguments = (ins
TT_Ptr:$base,
Variadic<I64>:$shape,
Variadic<I64>:$strides,
Variadic<I32>:$offsets,
DenseI32ArrayAttr:$order
);
let results = (outs TT_TensorPtr:$result);
// Add additional `[]` to increase readability and split variadic lists
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)";
let builders = [
OpBuilder<(ins
"Value":$base,
"ValueRange":$shape,
"ValueRange":$strides,
"ValueRange":$offsets,
"ArrayRef<int32_t>":$tensorShape,
"ArrayRef<int32_t>":$order
)>
];
}
#endif // Triton_OPS

View File

@@ -31,19 +31,28 @@ def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
// I32 Type
// TT_I32 -> I32
// TT_I32Tensor -> I32Tensor
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>;
// I64 Type
// TT_I64 -> I64
// TT_I64Tensor -> I64Tensor
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>;
// Pointer Type
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
// Pointer Type in TableGen
class TT_PtrOf<list<Type> pointeeTypes> :
DialectType<Triton_Dialect,
And<[CPred<"$_self.isa<::mlir::triton::PointerType>()">,
Concat<"[](::mlir::Type pointeeType) { return ",
SubstLeaves<"$_self", "pointeeType", AnyTypeOf<pointeeTypes>.predicate>,
"; }($_self.cast<::mlir::triton::PointerType>().getPointeeType())">]>,
"ptr", "::mlir::triton::PointerType">;
// Pointer Type in C++ (corresponding to `TT_PtrOf`)
def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system";
let description = [{
Triton PointerType
Pointer type in Triton IR type system, which could be pointing to scalars or tensors.
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
@@ -58,14 +67,27 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
// Scalar Pointer Type: `ptr<>`
def TT_Ptr : TT_PtrOf<[AnyType]>;
// Tensor of Pointer Type
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
// Tensor Type
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
// Pointer Type to Tensor Type: `ptr<tensor<>>`
def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
// Any Type in Triton IR
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;
#endif

View File

@@ -12,6 +12,8 @@ namespace mlir {
unsigned getPointeeBitWidth(RankedTensorType tensorTy);
}
bool isTensorPointerType(Type type);
} // namespace mlir
#endif // TRITON_IR_TYPES_H_

View File

@@ -8,6 +8,9 @@ namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
std::unique_ptr<Pass>
createRewriteTensorPointerPass(int computeCapability = 80);
} // namespace triton
#define GEN_PASS_REGISTRATION

View File

@@ -19,4 +19,23 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
let dependentDialects = ["mlir::arith::ArithDialect"];
}
def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
let description = [{
This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute
the pointer/mask/other for each load/store.
}];
let constructor = "mlir::triton::createRewriteTensorPointerPass()";
let dependentDialects = ["mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
#endif
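
For intuition about what this rewrite produces, the lowering of a single boundary-checked 2-D load can be sketched at the Triton language level as below. This is a hedged sketch only: the pass emits the equivalent splat/range/arith operations directly in MLIR, and the helper name and variables here are illustrative:

```
import triton
import triton.language as tl

@triton.jit
def legacy_equivalent_load(base, M, N, stride_m, stride_n, off_m, off_n,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Roughly what tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") becomes
    # when block_ptr was built from (base, shape=(M, N), strides, offsets=(off_m, off_n))
    offs_m = off_m + tl.arange(0, BLOCK_M)  # each offset is expanded with a per-dimension range
    offs_n = off_n + tl.arange(0, BLOCK_N)
    ptrs = base + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    mask_m = (offs_m[:, None] >= 0) & (offs_m[:, None] < M)  # lower and upper bound per checked dim
    mask_n = (offs_n[None, :] >= 0) & (offs_n[None, :] < N)
    # padding_option="zero" maps to other=0; "nan" would map to a NaN fill value
    return tl.load(ptrs, mask=mask_m & mask_n, other=0.0)
```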

View File

@@ -412,6 +412,8 @@ private:
auto loadOp = builder.create<triton::LoadOp>(
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
// TODO(Chenggang): confirm `boundaryCheck` and `padding`
/*boundaryCheck=*/nullptr, /*padding=*/nullptr,
insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
insertSliceAsyncOp.getIsVolatile());

View File

@@ -355,6 +355,7 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(),
adaptor.getBoundaryCheckAttr(), adaptor.getPaddingAttr(),
adaptor.getCache(), adaptor.getEvict(),
adaptor.getIsVolatile()),
adaptor.getAttributes());

View File

@@ -36,66 +36,116 @@ static Type getPointerTypeSameShape(Type type) {
}
}
static Type getPointerType(Type type) { return PointerType::get(type, 1); }
static Type getElementTypeOfTensorPointerType(Type type) {
if (auto ptrType = type.dyn_cast<PointerType>())
if (auto tensorType = ptrType.getPointeeType().dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return {};
}
// Parser & printer for assembly forms
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse operands
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
Type resultTypes[1];
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(resultTypes[0]))
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
return failure();
result.addTypes(resultTypes);
// Operand types
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
// Parse `optional(type(ptr)) -> type(result)`
Type ptrType, resultType;
if (parser.parseType(resultType))
return failure();
if (parser.parseOptionalArrow().succeeded()) {
ptrType = resultType;
if (parser.parseType(resultType))
return failure();
operandTypes.push_back(ptrType);
result.addTypes(resultType);
} else {
operandTypes.push_back(getPointerTypeSameShape(resultType));
result.addTypes(resultType);
}
// Determine `mask` and `other`
int hasMask = 0, hasOther = 0;
if (allOperands.size() >= 2) {
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
operandTypes.push_back(getI1SameShape(resultType));
hasMask = 1;
}
if (allOperands.size() >= 3) {
operandTypes.push_back(resultTypes[0]); // other
operandTypes.push_back(resultType);
hasOther = 1;
}
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
auto operand_segment_sizesAttrName =
// Deduce `operandSegmentSizes` from the number of the operands
auto operandSegmentSizesAttrName =
LoadOp::getOperandSegmentSizesAttrName(result.name);
result.addAttribute(
operand_segment_sizesAttrName,
operandSegmentSizesAttrName,
parser.getBuilder().getDenseI32ArrayAttr({1, hasMask, hasOther}));
return success();
}
void LoadOp::print(OpAsmPrinter &printer) {
printer << " ";
printer << getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
// `operandSegmentSizes` can be deduced, so we don't print it.
printer.printOptionalAttrDict(getOperation()->getAttrs(),
{getOperandSegmentSizesAttrName()});
// `type(ptr) -> type(result)`
printer << " : ";
// `type(ptr)` is optional during parsing, we only print for tensor pointers
if (isTensorPointerType(getPtr().getType())) {
printer.printStrippedAttrOrType(getPtr().getType());
printer << " -> ";
}
printer.printStrippedAttrOrType(getResult().getType());
}
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse operands
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
Type valueType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(valueType))
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
return failure();
// Operand types
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
operandTypes.push_back(valueType); // value
// Parse `optional(type(ptr)), type(val)`
// Pointer type
Type ptrType, valType;
if (parser.parseType(valType))
return failure();
if (parser.parseOptionalComma().succeeded()) {
ptrType = valType;
if (parser.parseType(valType))
return failure();
operandTypes.push_back(ptrType);
} else {
operandTypes.push_back(getPointerTypeSameShape(valType));
}
// Value type
operandTypes.push_back(valType);
// Determine `mask`
if (allOperands.size() >= 3)
operandTypes.push_back(getI1SameShape(valueType)); // mask
operandTypes.push_back(getI1SameShape(valType));
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
@@ -107,7 +157,14 @@ void StoreOp::print(OpAsmPrinter &printer) {
printer << " ";
printer << getOperation()->getOperands();
printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{});
// `type(ptr), type(value)`
printer << " : ";
// `type(ptr)` is optional during parsing, we only print for tensor pointers
if (isTensorPointerType(getPtr().getType())) {
printer.printStrippedAttrOrType(getPtr().getType());
printer << ", ";
}
printer.printStrippedAttrOrType(getValue().getType());
}
@@ -123,15 +180,6 @@ void StoreOp::print(OpAsmPrinter &printer) {
namespace mlir {
namespace triton {
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict) {
return StoreOp::build(builder, state, ptr, value, mlir::Value(), cache,
evict);
}
//-- LoadOp --
static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
@@ -146,24 +194,42 @@ static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
isVolatile);
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{},
/*boundaryCheck=*/{}, /*padding=*/{}, cache, evict, isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ArrayRef<int32_t> boundaryCheck,
std::optional<::mlir::triton::PaddingOption> padding,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck,
padding, cache, evict, isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
isVolatile);
LoadOp::build(builder, state, ptr, mask, /*other=*/{}, /*boundaryCheck=*/{},
/*padding=*/{}, cache, evict, isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
Type resultType = getLoadOpResultType(builder, ptr.getType());
LoadOp::build(builder, state, ptr, mask, other, /*boundaryCheck=*/{},
/*padding=*/{}, cache, evict, isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
std::optional<ArrayRef<int32_t>> boundaryCheck,
std::optional<::mlir::triton::PaddingOption> padding,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
// Operands
state.addOperands(ptr);
if (mask) {
state.addOperands(mask);
@@ -171,9 +237,20 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
state.addOperands(other);
}
}
// Attributes
state.addAttribute(
getOperandSegmentSizesAttrName(state.name),
builder.getDenseI32ArrayAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
if (boundaryCheck.has_value()) {
state.addAttribute(getBoundaryCheckAttrName(state.name),
builder.getDenseI32ArrayAttr(boundaryCheck.value()));
}
if (padding.has_value()) {
state.addAttribute(getPaddingAttrName(state.name),
::mlir::triton::PaddingOptionAttr::get(
builder.getContext(), padding.value()));
}
state.addAttribute(
getCacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
@@ -182,9 +259,39 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(getIsVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
// Result type
Type resultType = getLoadOpResultType(builder, ptr.getType());
state.addTypes({resultType});
}
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict) {
return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
/*boundaryCheck=*/{}, cache, evict);
}
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value, ::mlir::Value mask,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict) {
return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{},
cache, evict);
}
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value,
ArrayRef<int32_t> boundaryCheck,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict) {
return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
builder.getDenseI32ArrayAttr(boundaryCheck), cache,
evict);
}
//-- TransOp --
mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
@@ -424,5 +531,27 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
return {};
}
//-- MakeTensorPtrOp --
void MakeTensorPtrOp::build(::mlir::OpBuilder &builder,
::mlir::OperationState &state, ::mlir::Value base,
::mlir::ValueRange shape,
::mlir::ValueRange strides,
::mlir::ValueRange offsets,
ArrayRef<int32_t> tensorShape,
ArrayRef<int32_t> order) {
// Get pointer type from `base`
auto pointerType = base.getType().cast<PointerType>();
assert(pointerType != nullptr);
// Build type `tt.ptr<tensor<tensorShape, base.pointeeType>>`
auto tensorType = RankedTensorType::get(
SmallVector<int64_t>(tensorShape.begin(), tensorShape.end()),
pointerType.getPointeeType());
auto result = PointerType::get(tensorType, 1);
return build(builder, state, result, base, shape, strides, offsets,
builder.getDenseI32ArrayAttr(order));
}
} // namespace triton
} // namespace mlir

View File

@@ -1,42 +1,59 @@
#include "triton/Dialect/Triton/IR/Traits.h"
static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
using namespace mlir;
auto encA = tyA.dyn_cast<RankedTensorType>();
auto encB = tyB.dyn_cast<RankedTensorType>();
if (!encA || !encB)
#include "mlir/IR/TypeUtilities.h"
#include "triton/Dialect/Triton/IR/Types.h"
using namespace mlir;
static LogicalResult verifySameEncoding(Type typeA, Type typeB,
bool allowTensorPointerType) {
auto getEncoding = [=](Type type) -> Attribute {
auto rankedType = type.dyn_cast<RankedTensorType>();
if (allowTensorPointerType) {
if (auto ptrType = type.dyn_cast<triton::PointerType>())
rankedType = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
} else {
assert(!isTensorPointerType(type));
}
return rankedType ? rankedType.getEncoding() : Attribute();
};
auto encodingA = getEncoding(typeA);
auto encodingB = getEncoding(typeB);
if (!encodingA || !encodingB)
return success();
return encA.getEncoding() == encB.getEncoding() ? success() : failure();
return encodingA == encodingB ? success() : failure();
}
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
LogicalResult
OpTrait::impl::verifySameOperandsEncoding(Operation *op,
bool allowTensorPointerType) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
if (failed(verifySameEncoding(opType, type, allowTensorPointerType)))
return op->emitOpError() << "requires the same encoding for all operands";
return success();
}
LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding(
Operation *op, bool allowTensorPointerType) {
if (failed(verifyAtLeastNOperands(op, 1)) ||
failed(verifyAtLeastNResults(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto resultType : op->getResultTypes())
if (failed(verifySameEncoding(resultType, type)))
if (failed(verifySameEncoding(resultType, type, allowTensorPointerType)))
return op->emitOpError()
<< "requires the same encoding for all operands and results";
return verifySameOperandsEncoding(op);
return verifySameOperandsEncoding(op, allowTensorPointerType);
}
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
if (failed(verifySameEncoding(opType, type)))
return op->emitOpError() << "requires the same encoding for all operands";
return success();
}
mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
for (auto opType : op->getOperandTypes()) {
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
int64_t numElements = 1;
@@ -69,3 +86,55 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
}
return success();
}
static ArrayRef<int64_t> getTypeShape(Type type) {
auto rankedType = type.dyn_cast<RankedTensorType>();
if (auto ptrType = type.dyn_cast<triton::PointerType>())
rankedType = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
return rankedType ? rankedType.getShape() : ArrayRef<int64_t>();
}
LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
auto firstOperandShape = getTypeShape(op->getOperand(0).getType());
for (auto type : llvm::drop_begin(op->getOperandTypes(), 1))
if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape)))
return op->emitOpError() << "requires the same shape for all operands";
return success();
}
LogicalResult
OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)) ||
failed(verifyAtLeastNResults(op, 1)))
return failure();
auto firstOperandShape = getTypeShape(op->getOperand(0).getType());
for (auto type : op->getResultTypes())
if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
return verifySameLoadStoreOperandsShape(op);
}
bool OpTrait::impl::verifyLoadStorePointerAndValueType(Type valueType,
Type ptrType) {
if (isTensorPointerType(ptrType)) {
return ptrType.cast<triton::PointerType>().getPointeeType() == valueType;
} else if (auto rankedType = ptrType.dyn_cast<RankedTensorType>()) {
if (auto elementPtrType =
dyn_cast<triton::PointerType>(rankedType.getElementType())) {
auto inferValueType = RankedTensorType::get(
rankedType.getShape(), elementPtrType.getPointeeType(),
rankedType.getEncoding());
return inferValueType == valueType;
}
} else if (auto scalarPtrType = ptrType.dyn_cast<triton::PointerType>()) {
return scalarPtrType.getPointeeType() == valueType;
}
return false;
}

View File

@@ -46,4 +46,10 @@ unsigned getPointeeBitWidth(RankedTensorType tensorTy) {
return pointeeType.getIntOrFloatBitWidth();
}
bool isTensorPointerType(Type type) {
if (auto ptrType = type.dyn_cast<PointerType>())
return ptrType.getPointeeType().isa<RankedTensorType>();
return false;
}
} // namespace mlir

View File

@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen)
add_mlir_dialect_library(TritonTransforms
Combine.cpp
RewriteTensorPointer.cpp
DEPENDS
TritonTransformsIncGen

View File

@@ -94,7 +94,8 @@ public:
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, loadOp.getPtr(), loadOp.getMask(), falseValue, loadOp.getCache(),
op, loadOp.getPtr(), loadOp.getMask(), falseValue,
loadOp.getBoundaryCheck(), loadOp.getPadding(), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
return mlir::success();
}
@@ -127,6 +128,7 @@ struct CanonicalizeMaskedLoadPattern
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::LoadOp>(
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
} else {
// mask = splat(0)

View File

@@ -0,0 +1,503 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include <memory>
#include <stack>
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
/// An additional struct to record the meta information of operations
/// with tensor pointers
struct RewritedInfo {
private:
Value base;
SmallVector<Value> shape;
SmallVector<Value> strides;
SmallVector<Value> offsets;
ArrayRef<int64_t> tensorShape;
// A cache to avoid generating the same offset with range
DenseMap<unsigned, Value> cachedOffsetWithRange;
public:
RewritedInfo() = default;
RewritedInfo(const RewritedInfo &other) = default;
RewritedInfo(Value base, const SmallVector<Value> &shape,
const SmallVector<Value> &strides,
const SmallVector<Value> &offsets,
const ArrayRef<int64_t> &tensorShape)
: base(base), shape(shape), strides(strides), offsets(offsets),
tensorShape(tensorShape) {
assert(shape.size() == strides.size() && shape.size() == offsets.size() &&
shape.size() == tensorShape.size());
}
unsigned int length() const { return shape.size(); }
Value getOffset(unsigned i) { return offsets[i]; }
SmallVector<Value> getOffsets() { return offsets; }
void setOffset(unsigned i, Value newOffset) {
offsets[i] = newOffset;
cachedOffsetWithRange.clear();
}
void setOffsets(const SmallVector<Value> &newOffsets) {
offsets = newOffsets;
cachedOffsetWithRange.clear();
}
Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
unsigned i) {
if (cachedOffsetWithRange.count(i))
return cachedOffsetWithRange[i];
// Add range
auto indexI32RowType =
RankedTensorType::get({tensorShape[i]}, builder.getI32Type());
auto indexRowType =
RankedTensorType::get({tensorShape[i]}, builder.getI64Type());
Value splatOffset =
builder.create<triton::SplatOp>(loc, indexRowType, offsets[i]);
Value range = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
tensorShape[i]);
Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);
// Expand dimensions
Value expandedResult =
builder.create<arith::AddIOp>(loc, splatOffset, i64Range);
for (int j = 0; j < tensorShape.size(); ++j) {
if (j == i)
continue;
expandedResult =
builder.create<triton::ExpandDimsOp>(loc, expandedResult, j);
}
return cachedOffsetWithRange[i] = expandedResult;
}
Value generatePtr(OpBuilder &builder, const Location &loc) {
assert(tensorShape.size() == offsets.size() &&
tensorShape.size() == strides.size());
auto indexTensorType =
RankedTensorType::get(tensorShape, builder.getI64Type());
auto ptrType = base.getType().cast<triton::PointerType>();
auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType);
// Generate offsets per dimension
Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, base);
for (unsigned i = 0; i < tensorShape.size(); ++i) {
auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i);
// We must splat strides into the expanded shape not a row for retaining
// the divisibility information given by strides
Value splatStride = builder.create<triton::SplatOp>(
loc, offsetWithRange.getType(), strides[i]);
Value offsetWithStride =
builder.create<arith::MulIOp>(loc, offsetWithRange, splatStride);
Value broadcasted = builder.create<triton::BroadcastOp>(
loc, indexTensorType, offsetWithStride);
// Add to the pointer
ptr = builder.create<triton::AddPtrOp>(loc, ptrTensorType, ptr,
broadcasted);
}
return ptr;
}
Value generateMask(OpBuilder &builder, const Location &loc,
const std::optional<ArrayRef<int32_t>> &boundaryCheck) {
if (!boundaryCheck.has_value())
return {};
// Generate mask per dimension
auto maskTensorType =
RankedTensorType::get(tensorShape, builder.getI1Type());
Value mask;
for (auto i : boundaryCheck.value()) {
auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i);
// Compare with lower bound
Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
loc, 0, builder.getI64Type());
Value splatLowerBound = builder.create<triton::SplatOp>(
loc, offsetWithRange.getType(), lowerBound);
Value cmpLower = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound);
// Compare with upper bound
Value splatUpperBound = builder.create<triton::SplatOp>(
loc, offsetWithRange.getType(), shape[i]);
Value cmpUpper = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
// And and broadcast
Value andResult = builder.create<arith::AndIOp>(loc, cmpLower, cmpUpper);
Value broadcasted =
builder.create<triton::BroadcastOp>(loc, maskTensorType, andResult);
// And up all results
if (!mask) {
mask = broadcasted;
} else {
mask = builder.create<arith::AndIOp>(loc, mask, broadcasted);
}
}
return mask;
}
Value generateOther(OpBuilder &builder, const Location &loc,
const std::optional<triton::PaddingOption> &padding) {
if (!padding.has_value())
return Value();
// Create element attribute
auto elementType =
base.getType().cast<triton::PointerType>().getPointeeType();
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);
// Set zero padding value
Attribute attr =
elementType.isIntOrIndex()
? builder.getIntegerAttr(elementType, 0).cast<Attribute>()
: builder.getFloatAttr(elementType, 0).cast<Attribute>();
// Float NaN padding case
if (padding.value() == triton::PaddingOption::PAD_NAN) {
assert(!elementType.isIntOrIndex());
auto apNaN = llvm::APFloat::getNaN(
attr.cast<FloatAttr>().getValue().getSemantics());
attr = builder.getFloatAttr(elementType, apNaN);
}
// Create tensor
Value constant = builder.create<arith::ConstantOp>(loc, attr);
return builder.create<triton::SplatOp>(loc, otherTensorType, constant);
}
};
class RewriteTensorPointerPass
: public TritonRewriteTensorPointerBase<RewriteTensorPointerPass> {
private:
int computeCapability;
DenseMap<Value, RewritedInfo> rewritedInfo;
public:
explicit RewriteTensorPointerPass(int computeCapability)
: computeCapability(computeCapability) {}
static bool needRewrite(Operation *op) {
return std::any_of(
op->getOperands().begin(), op->getOperands().end(),
[](Value operand) { return isTensorPointerType(operand.getType()); });
}
static SmallVector<Value>
generateNewOperands(const SmallVector<Value> &oldOperands, unsigned index,
const SmallVector<Value> &newValues) {
assert(index < oldOperands.size());
SmallVector<Value> newOperands;
for (int i = 0; i < index; ++i)
newOperands.push_back(oldOperands[i]);
for (auto value : newValues)
newOperands.push_back(value);
for (auto i = index + 1; i < oldOperands.size(); ++i)
newOperands.push_back(oldOperands[i]);
return newOperands;
}
Operation *rewriteMakeTensorPtrOp(OpBuilder &builder,
triton::MakeTensorPtrOp op,
std::stack<Operation *> &eraser) {
// Save info for later use
auto ptrType = op.getResult().getType().cast<triton::PointerType>();
auto tensorType = ptrType.getPointeeType().cast<RankedTensorType>();
// Cast I32 offsets into I64
SmallVector<Value> i64Offsets;
for (auto offset : op.getOffsets()) {
auto i64Offset = builder.create<arith::ExtSIOp>(
op.getLoc(), builder.getI64Type(), offset);
i64Offsets.push_back(i64Offset);
}
// Save information
rewritedInfo[op.getResult()] =
RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets,
tensorType.getShape());
// Erase the original operation
eraser.push(op);
return nullptr;
}
Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op,
std::stack<Operation *> &eraser) {
// Get info from previous results
assert(rewritedInfo.count(op.getPtr()));
auto info = rewritedInfo[op.getPtr()];
// Calculate new offsets
assert(info.length() == op.getOffsets().size());
SmallVector<Value> newOffsets;
for (int i = 0; i < info.length(); ++i) {
Value i64Offset = builder.create<arith::ExtSIOp>(
op.getLoc(), builder.getI64Type(), op.getOffsets()[i]);
Value newOffset = builder.create<arith::AddIOp>(
op.getLoc(), info.getOffset(i), i64Offset);
newOffsets.push_back(newOffset);
}
// Save info for later use
info.setOffsets(newOffsets);
rewritedInfo[op.getResult()] = info;
// Erase the original operation
eraser.push(op);
return nullptr;
}
Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op,
std::stack<Operation *> &eraser) {
assert(isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op));
// We only have to rewrite load/stores with tensor pointers
auto ptr = op->getOperand(0);
if (!isTensorPointerType(ptr.getType()))
return nullptr;
// Get info from previous results
assert(rewritedInfo.count(ptr));
auto info = rewritedInfo[ptr];
// Load/store with tensor pointers implicitly will check the bound while
// accessing memory, so we should set `mask` and `other` (according to the
// padding). Also note that load with tensor pointers do not have `mask` and
// `other` while building IR from Python AST
std::optional<ArrayRef<int>> boundaryCheck;
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
assert(!loadOp.getMask() && !loadOp.getOther());
boundaryCheck = loadOp.getBoundaryCheck();
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
assert(!storeOp.getMask());
boundaryCheck = storeOp.getBoundaryCheck();
}
// Generate new `ptr`, `mask` and `other`
auto newPtr = info.generatePtr(builder, op->getLoc());
auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck);
Value newOther;
if (auto loadOp = dyn_cast<triton::LoadOp>(op))
newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding());
// Create a new operation
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
auto newResult = builder.create<triton::LoadOp>(
loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
op->getResult(0).replaceAllUsesWith(newResult);
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
builder.create<triton::StoreOp>(storeOp.getLoc(), newPtr,
storeOp.getValue(), newMask,
storeOp.getCache(), storeOp.getEvict());
}
// Erase the original operation
eraser.push(op);
return nullptr;
}
Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op,
std::stack<Operation *> &eraser) {
// Generate new iteration operands and set rewrited information
SmallVector<Value> oldIterOperands = op.getIterOperands();
SmallVector<Value> newIterOperands = op.getIterOperands();
for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
++i, ++oldI) {
if (!isTensorPointerType(newIterOperands[i].getType()))
continue;
// Expand the tensor pointer into offsets
assert(rewritedInfo.count(newIterOperands[i]));
auto info = rewritedInfo[newIterOperands[i]];
newIterOperands =
generateNewOperands(newIterOperands, i, info.getOffsets());
i += info.length() - 1;
size += info.length() - 1;
}
// Rebuild the loop type
auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
op.getUpperBound(), op.getStep(),
newIterOperands);
// Create value mapping. Note that for tensor pointers, we use identity
// mapping. It may refer to a value in the old loop, but we will rewrite it
// later
IRMapping mapping;
for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
++i, ++oldI) {
auto oldRegionIterArg = op.getRegionIterArg(oldI);
if (isTensorPointerType(oldRegionIterArg.getType())) {
// Pass rewrited info inside
assert(rewritedInfo.count(oldIterOperands[oldI]));
auto info = rewritedInfo[oldIterOperands[oldI]];
mapping.map(oldRegionIterArg, oldRegionIterArg);
for (unsigned j = 0; j < info.length(); ++j)
info.setOffset(j, newForOp.getRegionIterArg(i + j));
rewritedInfo[oldRegionIterArg] = info;
i += info.length() - 1;
} else {
mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i));
}
}
mapping.map(op.getInductionVar(), newForOp.getInductionVar());
// Clone body
builder.setInsertionPointToStart(newForOp.getBody());
for (auto &opInFor : *op.getBody()) {
auto *newOp = builder.clone(opInFor, mapping);
for (unsigned i = 0; i < opInFor.getNumResults(); ++i)
mapping.map(op->getResult(i), newOp->getResult(i));
}
// Replace later usages
assert(op.getNumResults() == op.getNumIterOperands());
for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
auto oldResult = op.getResult(oldI);
if (isTensorPointerType(oldResult.getType())) {
// Pack new offsets into rewrited info
assert(rewritedInfo.count(oldIterOperands[oldI]));
auto info = rewritedInfo[oldIterOperands[oldI]];
for (unsigned j = 0; j < info.length(); ++j)
info.setOffset(j, newForOp.getResult(i + j));
i += info.length() - 1;
rewritedInfo[oldResult] = info;
} else {
oldResult.replaceAllUsesWith(newForOp.getResult(i));
}
}
// Erase later
eraser.push(op);
return newForOp;
}
Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op,
std::stack<Operation *> &eraser) {
// Replace tensor pointers with offsets
SmallVector<Value> newOperands = op->getOperands();
for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) {
if (!isTensorPointerType(newOperands[i].getType()))
continue;
assert(rewritedInfo.count(newOperands[i]));
auto info = rewritedInfo[newOperands[i]];
newOperands = generateNewOperands(newOperands, i, info.getOffsets());
i += info.length() - 1;
size += info.length() - 1;
}
op->setOperands(newOperands);
// No need to erase
return nullptr;
}
Operation *rewriteOp(Operation *op, std::stack<Operation *> &eraser) {
OpBuilder builder(op);
// Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers
// Rewriting functions return the next operation to visit, if there is no
// next one, simply return `nullptr`
std::pair<Value, RewritedInfo> rewrited;
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser);
} else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
return rewriteAdvanceOp(builder, advanceOp, eraser);
} else if (isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op)) {
return rewriteLoadStoreOp(builder, op, eraser);
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
return rewriteLoadStoreOp(builder, op, eraser);
} else if (op->getDialect()->getNamespace() == "scf" ||
op->getDialect()->getNamespace() == "cf") {
if (!needRewrite(op))
return op;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
return rewriteForOp(builder, forOp, eraser);
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
return rewriteYieldOp(builder, yieldOp, eraser);
} else {
llvm_unreachable("Currently we only support tensor pointer usages "
"inside a `scf::ForOp`, others such as `scf::IfOp`,"
"`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` "
"are not supported yet");
}
}
// Otherwise return the original one
return op;
}
void visitOperation(Operation *op, std::stack<Operation *> &eraser) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// We need an extra copy because erasing operations may break the
// iterator behavior
SmallVector<Operation *> blockCopy;
for (auto &nestedOp : block)
blockCopy.push_back(&nestedOp);
// Rewrite and recursively visit
for (auto &nestedOp : blockCopy) {
if (auto newOp = rewriteOp(nestedOp, eraser))
visitOperation(newOp, eraser);
}
}
}
}
void runOnOperation() override {
// Only rewrite if the hardware does not support
if (computeCapability >= 90)
return;
// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
// MLIR does not support one-multiple value mapping. For example, if we use
// `ConversionPatternRewriter`, we can not make a type converter, which
// converts `ptr<tensor>` into multiple types `ptr<>, int64, int64, ...`
// (containing the base/offsets/strides...). What we can do is to convert
// `ptr<tensor>` into a single type `Tuple<ptr<>, int64, int64, ...>`. But
// in this way, we also have to define `PackTuple` and `UnpackTuple`
// operations and make a canonicalization pass to optimize, which is much
// So here we recursively build the IR, to be specific, we have to rewrite
// `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`,
// `scf.for` (tensor pointer usages may be in a loop fashion)
std::stack<Operation *> eraser;
visitOperation(getOperation(), eraser);
// The operation could not be erased during visit, because they may have
// later usages, so we erase after visit
rewritedInfo.clear();
while (!eraser.empty()) {
auto op = eraser.top();
eraser.pop();
op->erase();
}
}
};
std::unique_ptr<Pass>
triton::createRewriteTensorPointerPass(int computeCapability) {
return std::make_unique<RewriteTensorPointerPass>(computeCapability);
}

View File

@@ -407,8 +407,9 @@ void LoopPipeliner::emitPrologue() {
newOp = builder.create<triton::LoadOp>(
loadOp.getLoc(), loadOp.getResult().getType(),
lookupOrDefault(loadOp.getPtr(), stage), newMask,
lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
lookupOrDefault(loadOp.getOther(), stage),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
addNamedAttrs(newOp, op->getAttrDictionary());
} else {
newOp = builder.clone(*op);
@@ -630,8 +631,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextOp = builder.create<triton::LoadOp>(
loadOp.getLoc(), loadOp.getResult().getType(),
nextMapping.lookupOrDefault(loadOp.getPtr()), newMask,
nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
nextMapping.lookupOrDefault(loadOp.getOther()),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
addNamedAttrs(nextOp, op->getAttrDictionary());
nextMapping.map(loadOp.getResult(), nextOp->getResult(0));
} else {

View File

@@ -78,6 +78,11 @@ void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy;
using namespace pybind11::literals;
py::enum_<mlir::triton::PaddingOption>(m, "PADDING_OPTION")
.value("PAD_ZERO", mlir::triton::PaddingOption::PAD_ZERO)
.value("PAD_NAN", mlir::triton::PaddingOption::PAD_NAN)
.export_values();
py::enum_<mlir::triton::CacheModifier>(m, "CACHE_MODIFIER")
.value("NONE", mlir::triton::CacheModifier::NONE)
.value("CA", mlir::triton::CacheModifier::CA)
@@ -1124,6 +1129,27 @@ void init_triton_ir(py::module &&m) {
self.create<mlir::triton::StoreOp>(loc, ptrs, value, cacheModifier,
evictionPolicy);
})
.def("create_tensor_pointer_load",
[](mlir::OpBuilder &self, mlir::Value &ptr,
std::vector<int32_t> &boundaryCheck,
std::optional<mlir::triton::PaddingOption> paddingOption,
mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::LoadOp>(
loc, ptr, boundaryCheck, paddingOption, cacheModifier,
evictionPolicy, isVolatile);
})
.def("create_tensor_pointer_store",
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &val,
std::vector<int32_t> &boundaryCheck,
mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, ptr, val, boundaryCheck,
cacheModifier, evictionPolicy);
})
.def("create_masked_load",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
std::optional<mlir::Value> &other,
@@ -1390,10 +1416,31 @@ void init_triton_ir(py::module &&m) {
return self.create<::mlir::LLVM::UndefOp>(loc, type);
})
// Force GPU barrier
.def("create_barrier", [](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::gpu::BarrierOp>(loc);
});
.def("create_barrier",
[](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::gpu::BarrierOp>(loc);
})
// Make a block pointer (tensor pointer in Triton IR)
.def("create_make_block_ptr",
[](mlir::OpBuilder &self, mlir::Value &base,
std::vector<mlir::Value> &shape,
std::vector<mlir::Value> &strides,
std::vector<mlir::Value> &offsets,
std::vector<int32_t> &tensorShape,
std::vector<int32_t> &order) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::MakeTensorPtrOp>(
loc, base, shape, strides, offsets, tensorShape, order);
})
// Advance a block pointer
.def("create_advance",
[](mlir::OpBuilder &self, mlir::Value &ptr,
std::vector<mlir::Value> &offsets) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::AdvanceOp>(loc, ptr.getType(),
ptr, offsets);
});
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
@@ -1448,6 +1495,11 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());
})
.def("add_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(
computeCapability));
})
.def("add_convert_triton_to_tritongpu_pass",
[](mlir::PassManager &self, int numWarps) {
self.addPass(

View File

@@ -0,0 +1,102 @@
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):
pid = tl.program_id(0)
# We only copy half of the data to see if the padding works
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
block_shape=(BLOCK_SIZE, ), order=(0, ))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
block_shape=(BLOCK_SIZE, ), order=(0, ))
a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
tl.store(b_block_ptr, a, boundary_check=(0, ))
@pytest.mark.parametrize("dtype_str, n, padding_option",
[(dtype_str, n, padding) for dtype_str in ("bool", "int16", "float16")
for n in (64, 128, 256, 512, 1024)
for padding in ("zero", "nan")])
def test_block_copy(dtype_str, n, padding_option):
capability = torch.cuda.get_device_capability()
if capability[0] >= 9:
pytest.skip("Hopper support is working in progress")
dtype = getattr(torch, dtype_str)
if dtype_str in ("bool", "int16"):
if padding_option == "nan":
pytest.skip("Padding with NaN is not supported for integer types")
a = torch.randint(0, 2, (n, ), device="cuda", dtype=dtype)
else:
a = torch.randn((n, ), device="cuda", dtype=dtype)
b = torch.zeros((n, ), device="cuda", dtype=dtype)
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
assert torch.all(a[0: n // 2] == b[0: n // 2])
if padding_option == "zero":
assert torch.all(b[n // 2: n] == 0)
else:
assert torch.all(torch.isnan(b[n // 2: n]))
@triton.jit
def matmul_no_scf_with_advance_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
# The two lines below only test negative offsets for the `advance` API; they cancel out and could be removed
a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))
a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))
a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero")
b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero")
c = tl.dot(a, b)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, c)
@pytest.mark.parametrize("shape, num_warps", [
(shape, num_warps)
for shape in [
[64, 64, 16],
[64, 64, 32],
[64, 64, 64],
]
for num_warps in [4, 8]
])
def test_block_ptr_matmul_no_scf(shape, num_warps):
capability = torch.cuda.get_device_capability()
if capability[0] >= 9:
pytest.skip("Hopper support is working in progress")
m, n, k = shape
a = torch.randn((m, k), device="cuda", dtype=torch.float16)
b = torch.randn((k, n), device="cuda", dtype=torch.float16)
c = torch.empty((m, n), device="cuda", dtype=torch.float32)
grid = lambda META: (1, )
matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
M=m, N=n, K=k,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
num_warps=num_warps)
golden = torch.matmul(a, b)
torch.testing.assert_allclose(c, golden)
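The kernel above deliberately avoids a loop (hence "no_scf"). A K-loop variant that advances both block pointers every iteration would look roughly like the sketch below; it is hypothetical and not part of the test suite:
```
@triton.jit
def matmul_with_advance_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                               stride_am, stride_ak, stride_bk, stride_bn,
                               stride_cm, stride_cn,
                               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(a, b)
        # Striding information lives inside the block pointers, so advancing
        # only needs per-dimension offsets
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)
```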


@@ -1087,10 +1087,17 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
return ret, generator
def optimize_triton_ir(mod):
def inline_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.run(mod)
return mod
def optimize_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
@@ -1100,13 +1107,25 @@ def optimize_triton_ir(mod):
return mod
def ast_to_ttir(fn, signature, specialization, constants, debug=False):
def ttir_compute_capability_rewrite(mod, compute_capability):
# For hardware without native support, we must rewrite all loads/stores with block (tensor) pointers into legacy loads/stores
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_rewrite_tensor_pointer_pass(compute_capability)
pm.run(mod)
return mod
def ast_to_ttir(fn, signature, specialization, constants, compute_capability, debug=False):
mod, _ = build_triton_ir(fn, signature, specialization, constants, debug)
mod = inline_triton_ir(mod)
mod = ttir_compute_capability_rewrite(mod, compute_capability)
return optimize_triton_ir(mod)
def ttir_to_ttgir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
@@ -1902,7 +1921,7 @@ def compile(fn, **kwargs):
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
lambda src: ast_to_ttir(src, signature, configs[0], constants, capability)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
@@ -1916,7 +1935,7 @@ def compile(fn, **kwargs):
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
lambda src: ast_to_ttir(src, signature, configs[0], constants, capability, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
@@ -1926,7 +1945,6 @@ def compile(fn, **kwargs):
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
}
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)


@@ -8,6 +8,7 @@ from ..impl import (
from . import math
from .core import (
abs,
advance,
arange,
argmin,
argmax,
@@ -48,6 +49,7 @@ from .core import (
int8,
load,
log,
make_block_ptr,
max,
max_contiguous,
maximum,
@@ -101,6 +103,7 @@ from .random import (
__all__ = [
"abs",
"advance",
"arange",
"argmin",
"argmax",
@@ -144,6 +147,7 @@ __all__ = [
"math",
"load",
"log",
"make_block_ptr",
"max",
"max_contiguous",
"maximum",


@@ -884,55 +884,87 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
@builtin
def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None):
def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option="", cache_modifier="",
eviction_policy="", volatile=False, _builder=None):
"""
Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
(1) `pointer` could be a single element pointer, then a scalar will be loaded
- `mask` and `other` must be scalar too
- `other` is implicitly typecast to `pointer.dtype.element_ty`
- `boundary_check` and `padding_option` must be empty
(2) `pointer` could be an element-wise tensor of pointers, in which case:
- `mask` and `other` are implicitly broadcast to `pointer.shape`
- `other` is implicitly typecast to `pointer.dtype.element_ty`
- `boundary_check` and `padding_option` must be empty
(3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
- `mask` and `other` must be None
- `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
:param pointer: Pointers to the data to be loaded.
:type pointer: Block of dtype=triton.PointerDType
:param mask: if mask[idx] is false, do not load the data at address :code:`pointer[idx]`.
:type mask: Block of triton.int1, optional
:param other: if mask[idx] is false, return other[idx]
:param pointer: Pointer to the data to be loaded
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
:param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
(must be `None` with block pointers)
:type mask: Block of `triton.int1`, optional
:param other: if `mask[idx]` is false, return `other[idx]`
:type other: Block, optional
:param cache_modifier: changes cache option in nvidia ptx
:param boundary_check: tuple of integers, indicating the dimensions on which to perform the boundary check
:type boundary_check: tuple of ints, optional
:param padding_option: should be one of {"", "zero", "nan"}, the padding applied to out-of-bound elements
:param cache_modifier: changes cache option in NVIDIA PTX
:type cache_modifier: str, optional
:param eviction_policy: changes eviction policy in NVIDIA PTX
:type eviction_policy: str, optional
:param volatile: changes volatile option in NVIDIA PTX
:type volatile: bool, optional
"""
# mask, other can be constexpr
# `mask` and `other` can be constexpr
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
if _constexpr_to_value(other) is not None:
other = _to_tensor(other, _builder)
padding_option = _constexpr_to_value(padding_option)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
volatile = _constexpr_to_value(volatile)
return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
volatile, _builder)
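Illustrative calls for the three `pointer` cases described in the docstring (a sketch; `scalar_ptr`, `base_ptr`, `offs`, `n` and `block_ptr` are assumed to already exist inside a kernel):
```
x0 = tl.load(scalar_ptr)                                              # case (1): single element
x1 = tl.load(base_ptr + offs, mask=offs < n, other=0.0)               # case (2): tensor of pointers
x2 = tl.load(block_ptr, boundary_check=(0,), padding_option="zero")   # case (3): block pointer
```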
@builtin
def store(pointer, value, mask=None, cache_modifier="", eviction_policy="", _builder=None):
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
"""
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
Store a tensor of data into memory locations defined by `pointer`:
(1) `pointer` could be a single element pointer, then a scalar will be stored
- `mask` must be scalar too
- `boundary_check` must be empty
(2) `pointer` could be an element-wise tensor of pointers, in which case:
- `mask` is implicitly broadcast to `pointer.shape`
- `boundary_check` must be empty
(3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
- `mask` must be None
- `boundary_check` can be specified to control the behavior of out-of-bound access
`value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
:param pointer: The memory locations where the elements of :code:`value` are stored.
:type pointer: Block of dtype=triton.PointerDType
:param value: The tensor of elements to be stored.
:param pointer: The memory location where the elements of `value` are stored
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
:param value: The tensor of elements to be stored
:type value: Block
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
:param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
:type mask: Block of triton.int1, optional
:param boundary_check: tuple of integers, indicating the dimensions on which to perform the boundary check
:type boundary_check: tuple of ints, optional
:param cache_modifier: changes cache option in NVIDIA PTX
:type cache_modifier: str, optional
:param eviction_policy: changes eviction policy in NVIDIA PTX
:type eviction_policy: str, optional
"""
# value can be constexpr
# `value` can be constexpr
value = _to_tensor(value, _builder)
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
return semantic.store(pointer, value, mask, cache_modifier, eviction_policy, _builder)
return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
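The same three cases apply to `store` (sketch only, reusing the assumed names from the `load` example above):
```
tl.store(scalar_ptr, x0)                              # case (1): single element
tl.store(base_ptr + offs, x1, mask=offs < n)          # case (2): tensor of pointers
tl.store(block_ptr, x2, boundary_check=(0,))          # case (3): block pointer, no write out of bounds
```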
# -----------------------
@@ -1387,6 +1419,33 @@ def device_assert(cond, msg="", _builder=None):
lineno = frame.f_back.f_lineno
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
@builtin
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
"""
Returns a pointer to a block in a parent tensor
:param base: The base pointer to the parent tensor
:param shape: The shape of the parent tensor
:param strides: The strides of the parent tensor
:param offsets: The offsets of the block within the parent tensor
:param block_shape: The shape of the block
:param order: The order of the block's data arrangement in memory
"""
return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
@builtin
def advance(base: tensor, offsets, _builder=None):
"""
Advance a block pointer
:param base: the block pointer to advance
:param offsets: the offsets to advance by, one per dimension
"""
return semantic.advance(base, offsets, _builder)
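A minimal 1D illustration of the two builtins used together (sketch only; `x_ptr`, `N` and `BLOCK` are assumed kernel arguments):
```
ptr = tl.make_block_ptr(base=x_ptr, shape=(N,), strides=(1,), offsets=(0,),
                        block_shape=(BLOCK,), order=(0,))
first = tl.load(ptr, boundary_check=(0,), padding_option="zero")    # elements [0, BLOCK)
ptr = tl.advance(ptr, (BLOCK,))                                     # shift the block by BLOCK elements
second = tl.load(ptr, boundary_check=(0,), padding_option="zero")   # elements [BLOCK, 2 * BLOCK)
```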
# -----------------------
# Iterators
# -----------------------


@@ -771,7 +771,7 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//
def str_to_cache_modifier(cache_modifier):
def _str_to_cache_modifier(cache_modifier):
cache = ir.CACHE_MODIFIER.NONE # default
if cache_modifier:
if cache_modifier == ".ca":
@@ -783,7 +783,7 @@ def str_to_cache_modifier(cache_modifier):
return cache
def str_to_eviction_policy(eviction_policy):
def _str_to_eviction_policy(eviction_policy):
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
if eviction_policy == "evict_last":
@@ -795,97 +795,219 @@ def str_to_eviction_policy(eviction_policy):
return eviction
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
cache_modifier: str,
eviction_policy: str,
is_volatile: bool,
builder: ir.builder) -> tl.tensor:
def _str_to_padding_option(padding_option):
padding = None # default
if padding_option:
if padding_option == "zero":
padding = ir.PADDING_OPTION.PAD_ZERO
elif padding_option == "nan":
padding = ir.PADDING_OPTION.PAD_NAN
else:
raise ValueError(f"Padding option {padding_option} not supported")
return padding
def _canonicalize_boundary_check(boundary_check, block_shape):
if boundary_check:
if not hasattr(boundary_check, "__iter__"):
boundary_check = [boundary_check]
boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
for dim in boundary_check:
assert isinstance(dim, int) and 0 <= dim < len(block_shape)
assert len(boundary_check) > 0
assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
return sorted(boundary_check)
return tuple()
def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
# Load by a block pointer: `pointer_type<block_type<>>`
# Block pointers cannot have `mask` or `other` arguments
if mask or other:
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
elt_ty = ptr.type.element_ty.element_ty
assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
raise ValueError("Padding option `nan` is not supported for integer block pointers")
# `dst_ty` is the dereferenced type of the block pointer
dst_ty = ptr.type.element_ty
# Check `boundary_check` argument
boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
# Build IR
return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction,
is_volatile), dst_ty)
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__())
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
# Check `mask`, `other`, `boundary_check`, and `padding` arguments
if not mask and other:
raise ValueError("`other` cannot be provided without `mask`")
if padding or boundary_check:
raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
"pointers or loading a scalar. Because the compiler does not know the boundary; please "
"use block pointers (defined by `make_block_ptr`) instead")
# For a pointer of scalar, check the type of `mask` and `other`
if not ptr.type.is_block():
if mask and mask.type.is_block():
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
if other and other.type.is_block():
raise ValueError("Other argument cannot be block type if pointer argument is not a block")
# Make `mask` and `other` into the same shape as `ptr`
if ptr.type.is_block():
if mask:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if other:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
# Get `pointer_type<elt_ty>` and `elt_ty`
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# Cast `other` into the `elt_ty` type
if other:
other = cast(other, elt_ty, builder)
# cache modifier
# Create loaded result type `dst_ty`
if ptr.type.is_block():
shape = ptr.type.get_block_shapes()
dst_ty = tl.block_type(elt_ty, shape)
else:
# Load by dereferencing a pointer to a scalar
dst_ty = elt_ty
cache = str_to_cache_modifier(cache_modifier)
eviction = str_to_eviction_policy(eviction_policy)
# Build IR
if not mask:
if other:
raise ValueError("`other` cannot be provided without `mask`")
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
dst_ty)
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
else:
return tl.tensor(builder.create_masked_load(ptr.handle,
mask.handle,
other.handle if other else None,
cache, eviction, is_volatile),
dst_ty)
return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
eviction, is_volatile), dst_ty)
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
cache_modifier: str,
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
boundary_check,
padding_option: str,
cache_modifier: str,
eviction_policy: str,
is_volatile: bool,
builder: ir.builder) -> tl.tensor:
# Cache, eviction and padding options
cache = _str_to_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
padding = _str_to_padding_option(padding_option)
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
# Load by a block pointer: `pointer_type<block_type<>>`
return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
else:
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
# Store by a block pointer: `pointer_type<block_type<>>`
# Block pointers cannot have the `mask` argument
if mask:
raise ValueError("`mask` argument cannot be specified for storing by block pointers")
# Check same shape
block_shape = ptr.type.element_ty.get_block_shapes()
if not val.type.is_block():
val = broadcast_impl_shape(val, block_shape, builder)
assert val.type.is_block(), "Value argument must be block type or a scalar"
assert block_shape == val.type.get_block_shapes(), "Block shape and value shape mismatch"
elt_ty = ptr.type.element_ty.element_ty
assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
# Check `boundary_check` argument
boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
# Build IR
return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction),
tl.void)
def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
# Check `boundary_check` argument
if boundary_check:
raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
"scalar. Because the compiler does not know the boundary; please use block pointers "
"(defined by `make_block_ptr`) instead")
# For a pointer of scalar, check the type of `val` and `mask`
if not ptr.type.is_block():
if val.type.is_block():
raise ValueError("Value argument cannot be block type if pointer argument is not a block")
if mask and mask.type.is_block():
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
# Make `mask` and `val` into the same shape as `ptr`
if ptr.type.is_block():
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
if mask and ptr.type.is_block():
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if mask:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# attributes
cache = str_to_cache_modifier(cache_modifier)
eviction = str_to_eviction_policy(eviction_policy)
# cast to target data-type
# Cast to target data type
val = cast(val, elt_ty, builder)
# Build IR
if not mask:
return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
if not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
boundary_check,
cache_modifier: str,
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
# Cache and eviction options
cache = _str_to_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
# Store by a block pointer: `pointer_type<block_type<>>`
return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
else:
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder)
#########
# atomic
#########
@@ -1265,3 +1387,70 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)
def _convert_elem_to_ir_value(builder, elem, require_i64):
if isinstance(elem, tl.constexpr):
return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value)
elif isinstance(elem, tl.tensor):
assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
if elem.dtype != tl.int64 and require_i64:
return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed())
elif elem.dtype != tl.int32:
return builder.create_int_cast(elem.handle, builder.get_int32_ty(), elem.dtype.is_int_signed())
return elem.handle
assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
def _convert_to_ir_values(builder, list_like, require_i64=True):
if hasattr(list_like, "__iter__"):
return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like]
return [_convert_elem_to_ir_value(builder, list_like, require_i64)]
def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor:
# Convert dynamic arguments to IR values
# NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
shape = _convert_to_ir_values(builder, shape)
strides = _convert_to_ir_values(builder, strides)
offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
# Check `base` type
if not base.type.is_ptr() or base.type.element_ty.is_block():
raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
if base.type.element_ty == tl.int1:
base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder)
# Check whether `block_shape` is static
if not hasattr(block_shape, "__iter__"):
block_shape = [block_shape]
block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape]
assert all([isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape]), \
"Expected a list of constant integers (`int32_t` range) in `block_shape`"
# Check `order`
if not hasattr(order, "__iter__"):
order = [order]
order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order]
assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"
# Must have same length
assert all([len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]]), \
"Expected shape/strides/offsets/order to have the same length as block_shape"
# Build value, the type is:
# `pointer_type<blocked<shape, element_type>>` in Python
# `tt.ptr<tensor<shape, element_type>>` in MLIR
handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
# Convert dynamic offsets to IR values
offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
# Advanced block pointer type is the same as before
return tl.tensor(builder.create_advance(base.handle, offsets), base.type)


@@ -0,0 +1,83 @@
// RUN: triton-opt %s -triton-rewrite-tensor-pointer | FileCheck %s
func.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
%c31_i32 = arith.constant 31 : i32
%c127_i32 = arith.constant 127 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%2 = arith.addi %arg3, %c127_i32 : i32
%3 = arith.divsi %2, %c128_i32 : i32
%4 = arith.addi %arg4, %c31_i32 : i32
%5 = arith.divsi %4, %c32_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %0, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.cmpi slt, %9, %c8_i32 : i32
%11 = arith.select %10, %9, %c8_i32 : i32
%12 = arith.remsi %0, %11 : i32
%13 = arith.addi %8, %12 : i32
%14 = arith.remsi %0, %6 : i32
%15 = arith.divsi %14, %11 : i32
%16 = arith.muli %13, %c128_i32 : i32
%17 = arith.muli %1, %c32_i32 : i32
%18 = arith.extsi %arg3 : i32 to i64
%19 = arith.extsi %arg5 : i32 to i64
%20 = arith.extsi %arg6 : i32 to i64
// CHECK-NOT: tt.make_tensor_ptr
%21 = tt.make_tensor_ptr %arg0, [%18, %19], [%20, %c1_i64], [%16, %17] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
%22 = arith.muli %15, %c32_i32 : i32
%23 = arith.extsi %arg4 : i32 to i64
%24 = arith.extsi %arg7 : i32 to i64
// CHECK-NOT: tt.make_tensor_ptr
%25 = tt.make_tensor_ptr %arg1, [%19, %23], [%24, %c1_i64], [%17, %22] {order = array<i32: 1, 0>} : !tt.ptr<tensor<32x32xf16>>
%26 = arith.addi %arg5, %c31_i32 : i32
%27 = arith.divsi %26, %c32_i32 : i32
%28 = arith.index_cast %27 : i32 to index
%29:3 = scf.for %arg9 = %c0 to %28 step %c1 iter_args(%arg10 = %cst, %arg11 = %21, %arg12 = %25) -> (tensor<128x32xf32>, !tt.ptr<tensor<128x32xf16>>, !tt.ptr<tensor<32x32xf16>>) {
// CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16>
%55 = tt.load %arg11 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr<tensor<128x32xf16>> -> tensor<128x32xf16>
// CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
%56 = tt.load %arg12 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr<tensor<32x32xf16>> -> tensor<32x32xf16>
%57 = tt.dot %55, %56, %arg10 {allowTF32 = true} : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32>
// CHECK-NOT: tt.advance
%58 = tt.advance %arg11, [%c0_i32, %c32_i32] : !tt.ptr<tensor<128x32xf16>>
// CHECK-NOT: tt.advance
%59 = tt.advance %arg12, [%c32_i32, %c0_i32] : !tt.ptr<tensor<32x32xf16>>
scf.yield %57, %58, %59 : tensor<128x32xf32>, !tt.ptr<tensor<128x32xf16>>, !tt.ptr<tensor<32x32xf16>>
}
%30 = arith.truncf %29#0 : tensor<128x32xf32> to tensor<128x32xf16>
%31 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%32 = tt.splat %16 : (i32) -> tensor<128xi32>
%33 = arith.addi %32, %31 : tensor<128xi32>
%34 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%35 = tt.splat %22 : (i32) -> tensor<32xi32>
%36 = arith.addi %35, %34 : tensor<32xi32>
%37 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
%38 = tt.splat %arg8 : (i32) -> tensor<128x1xi32>
%39 = arith.muli %37, %38 : tensor<128x1xi32>
%40 = tt.expand_dims %36 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
%41 = tt.broadcast %39 : (tensor<128x1xi32>) -> tensor<128x32xi32>
%42 = tt.broadcast %40 : (tensor<1x32xi32>) -> tensor<128x32xi32>
%43 = arith.addi %41, %42 : tensor<128x32xi32>
%44 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>>
%45 = tt.addptr %44, %43 : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi32>
%46 = tt.splat %arg3 : (i32) -> tensor<128xi32>
%47 = arith.cmpi slt, %33, %46 : tensor<128xi32>
%48 = tt.expand_dims %47 {axis = 1 : i32} : (tensor<128xi1>) -> tensor<128x1xi1>
%49 = tt.splat %arg4 : (i32) -> tensor<32xi32>
%50 = arith.cmpi slt, %36, %49 : tensor<32xi32>
%51 = tt.expand_dims %50 {axis = 0 : i32} : (tensor<32xi1>) -> tensor<1x32xi1>
%52 = tt.broadcast %48 : (tensor<128x1xi1>) -> tensor<128x32xi1>
%53 = tt.broadcast %51 : (tensor<1x32xi1>) -> tensor<128x32xi1>
%54 = arith.andi %52, %53 : tensor<128x32xi1>
tt.store %45, %30, %54 {cache = 1 : i32, evict = 1 : i32} : tensor<128x32xf16>
return
}


@@ -1032,9 +1032,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
%c3136_i32 = arith.constant 3136 : i32
%c256_i32 = arith.constant 256 : i32
%c0_i32 = arith.constant 0 : i32
%c3136_i32 = arith.constant 3136 : index
%c256_i32 = arith.constant 256 : index
%c0_i32 = arith.constant 0 : index
%cst_0 = arith.constant dense<3.136000e+03> : tensor<1x1xf32, #blocked>
%cst_1 = arith.constant dense<50176> : tensor<1x256xi32, #blocked>
%cst_2 = arith.constant dense<196> : tensor<1x1xi32, #blocked>
@@ -1056,8 +1056,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%12 = tt.broadcast %11 : (tensor<1x1xi32, #blocked>) -> tensor<1x256xi32, #blocked>
%13 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1x256x!tt.ptr<f32>, #blocked>
%14 = tt.broadcast %7 : (tensor<1x1xi1, #blocked>) -> tensor<1x256xi1, #blocked>
%15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) : i32 {
%43 = tt.splat %arg5 : (i32) -> tensor<1x256xi32, #blocked>
%15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) {
%42 = arith.index_cast %arg5 : index to i32
%43 = tt.splat %42 : (i32) -> tensor<1x256xi32, #blocked>
%44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked>
%45 = "triton_gpu.cmpi"(%44, %cst_4) {predicate = 2 : i64} : (tensor<1x256xi32, #blocked>, tensor<1x256xi32, #blocked>) -> tensor<1x256xi1, #blocked>
%46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked>