mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Support block pointer semantics (#1392)
This PR introduces a new semantics: **block pointer**, which makes users easier & faster to load a block from a parent tensor. Below is a detailed API change by an example: ``` # Make a block pointer, which points to a block in the parent shape # `base`: the parent tensor # `shape`: the shape of the parent tensor # `strides`: the strides of the parent tensor # `offsets`: the offsets of the block in the parent tensor # `order`: the order of the data arrangement in memory # Below is an example loading a 2D column-major matrix block_ptr = tl.make_block_ptr(base=ptr, shape=(M, N), strides=(stride_m, stride_n), offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) # Advance the offsets; note that the striding information is already saved in `block_ptr` # `base`: the block pointer to be advanced # `offsets`: the offsets for each dimension block_ptr = tl.advance(base=block_ptr, offsets=(BLOCK_M, -BLOCK_N)) block_ptr = tl.advance(base=block_ptr, offsets=(-BLOCK_M, BLOCK_N)) # Load from a block pointer, the output type is the dereferenced type of `block_ptr`, e.g. ptr<tensor<32x32xf32>> -> tensor<32x32xf32> # `ptr`: the block pointer to be loaded # `boundary_check`: a tuple of dimensions to check the boundary # `padding`: padding strategy for elements out of bound val = tl.load(ptr=block_ptr, boundary_check=(0, 1), padding="zero") # Store by a block pointer, in which the pointer and the value tensor should have the same shape # `ptr`: the block pointer to be stored # `boundary_check`: a tuple of dimensions to check the boundary (no-write if out of bound) tl.store(ptr=block_ptr, value=val, boundary_check=(0, 1)) ``` --------- Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
#ifndef TRITON_IR_TRAITS_H_
|
||||
#define TRITON_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <iostream>
|
||||
@@ -12,11 +11,9 @@ namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
|
||||
LogicalResult verifySameOperandsEncoding(Operation *op);
|
||||
// The rationale for this trait is to prevent users from creating programs
|
||||
// that would have catastrophic register pressure and cause the compiler to
|
||||
// hang.
|
||||
@@ -25,7 +22,22 @@ LogicalResult verifySameOperandsEncoding(Operation *op);
|
||||
// but we probably should limit number of elements (rather than bytes) to
|
||||
// keep specs simple
|
||||
int constexpr maxTensorNumElements = 1048576;
|
||||
|
||||
LogicalResult verifyTensorSize(Operation *op);
|
||||
|
||||
LogicalResult verifySameOperandsEncoding(Operation *op,
|
||||
bool allowTensorPointerType = false);
|
||||
|
||||
LogicalResult
|
||||
verifySameOperandsAndResultEncoding(Operation *op,
|
||||
bool allowTensorPointerType = false);
|
||||
|
||||
LogicalResult verifySameLoadStoreOperandsShape(Operation *op);
|
||||
|
||||
LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op);
|
||||
|
||||
bool verifyLoadStorePointerAndValueType(Type valueType, Type ptrType);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
template <class ConcreteType>
|
||||
@@ -54,6 +66,44 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameLoadStoreOperandsShape
|
||||
: public TraitBase<ConcreteType, SameLoadStoreOperandsShape> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameLoadStoreOperandsShape(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameLoadStoreOperandsAndResultShape
|
||||
: public TraitBase<ConcreteType, SameLoadStoreOperandsAndResultShape> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameLoadStoreOperandsAndResultShape(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameLoadStoreOperandsEncoding
|
||||
: public TraitBase<ConcreteType, SameLoadStoreOperandsEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsEncoding(op,
|
||||
/*allowTensorPointerType=*/true);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameLoadStoreOperandsAndResultEncoding
|
||||
: public TraitBase<ConcreteType, SameLoadStoreOperandsAndResultEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsAndResultEncoding(
|
||||
op, /*allowTensorPointerType=*/true);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
// Attrs for LoadOp
|
||||
// Attributes for LoadOp
|
||||
def TT_CacheModifierAttr : I32EnumAttr<
|
||||
"CacheModifier", "",
|
||||
[
|
||||
@@ -13,6 +13,7 @@ def TT_CacheModifierAttr : I32EnumAttr<
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
"EvictionPolicy", "",
|
||||
[
|
||||
@@ -23,6 +24,16 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_PaddingOptionAttr : I32EnumAttr<
|
||||
"PaddingOption", "",
|
||||
[
|
||||
I32EnumAttrCase<"PAD_ZERO", 1, "zero">,
|
||||
// We can not set the string value to "NAN" because it is a keyword in C++
|
||||
I32EnumAttrCase<"PAD_NAN", 2, "nan">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// reduction
|
||||
def TT_RedOpAttr : I32EnumAttr<
|
||||
/*name*/"RedOp", /*summary*/"",
|
||||
|
||||
@@ -4,8 +4,11 @@
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
|
||||
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
|
||||
|
||||
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
|
||||
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
|
||||
def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">;
|
||||
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
|
||||
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
|
||||
|
||||
#endif // TRITON_INTERFACES
|
||||
|
||||
@@ -107,66 +107,104 @@ def TT_AddPtrOp : TT_Op<"addptr",
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
|
||||
}
|
||||
|
||||
def TT_AdvanceOp : TT_Op<"advance",
|
||||
[Pure,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">]> {
|
||||
let summary = "Advance a tensor pointer by offsets";
|
||||
|
||||
let arguments = (ins TT_TensorPtr:$ptr, Variadic<I32>:$offsets);
|
||||
|
||||
let results = (outs TT_TensorPtr:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
[SameLoadStoreOperandsAndResultShape,
|
||||
SameLoadStoreOperandsAndResultEncoding,
|
||||
AttrSizedOperandSegments,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr", "getPointerTypeSameShape($_self)">,
|
||||
"result", "ptr", "$_self",
|
||||
"mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">,
|
||||
TypesMatchWith<"infer mask type from result type or none",
|
||||
"result", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from result type or none",
|
||||
"result", "other", "$_self",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "load";
|
||||
let summary = "Load from a tensor of pointers or from a tensor pointer";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile);
|
||||
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, Optional<TT_BoolLike>:$mask,
|
||||
Optional<TT_Type>:$other, OptionalAttr<DenseI32ArrayAttr>:$boundaryCheck,
|
||||
OptionalAttr<TT_PaddingOptionAttr>:$padding, TT_CacheModifierAttr:$cache,
|
||||
TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
// A tensor of pointers or a pointer to a scalar
|
||||
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
// A tensor pointer with boundary check and padding
|
||||
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
|
||||
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
// A tensor of pointers or a pointer to a scalar with mask
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
// A tensor of pointers or a pointer to a scalar with mask and other
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
// A utility function to build the operation with all attributes
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "Optional<ArrayRef<int32_t>>":$boundaryCheck,
|
||||
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||
];
|
||||
|
||||
// Format: `tt.load operands attrs : optional(type(ptr)) -> type(result)`
|
||||
// We need an extra `optional(type(ptr))` for inferring the tensor pointer type with back compatibility
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
[SameOperandsShape,
|
||||
SameOperandsEncoding,
|
||||
[SameLoadStoreOperandsShape,
|
||||
SameLoadStoreOperandsEncoding,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"value", "ptr",
|
||||
"getPointerTypeSameShape($_self)">,
|
||||
"value", "ptr", "$_self",
|
||||
"mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"value", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "store";
|
||||
let summary = "Store by a tensor of pointers or by a tensor pointer";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
|
||||
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
|
||||
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);
|
||||
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
|
||||
OptionalAttr<DenseI32ArrayAttr>:$boundaryCheck,
|
||||
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
|
||||
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache,
|
||||
// A tensor of pointers or a pointer to a scalar
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>,
|
||||
// A tensor of pointers or a pointer to a scalar with mask
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict)>,
|
||||
// A tensor pointer with boundary check
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<int32_t>":$boundaryCheck, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict)>
|
||||
];
|
||||
|
||||
// Format: `tt.store operands attrs : optional(type(ptr)), type(val)
|
||||
// We need an extra `optional(type(ptr))` for inferring the tensor pointer type with back compatibility
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
@@ -213,7 +251,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
@@ -438,4 +476,46 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
|
||||
let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)";
|
||||
}
|
||||
|
||||
//
|
||||
// Make a Tensor Pointer
|
||||
//
|
||||
def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
|
||||
[Pure,
|
||||
SameVariadicOperandSize,
|
||||
TypesMatchWith<"infer pointer type from the result type",
|
||||
"result", "base",
|
||||
"getPointerType(getElementTypeOfTensorPointerType($_self))">]> {
|
||||
let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified";
|
||||
|
||||
let description = [{
|
||||
`tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a
|
||||
pointer to the block tensor, e.g. returns a type of `tt.ptr<tensor<8x8xf16>>`.
|
||||
}];
|
||||
|
||||
// TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints.
|
||||
let arguments = (ins
|
||||
TT_Ptr:$base,
|
||||
Variadic<I64>:$shape,
|
||||
Variadic<I64>:$strides,
|
||||
Variadic<I32>:$offsets,
|
||||
DenseI32ArrayAttr:$order
|
||||
);
|
||||
|
||||
let results = (outs TT_TensorPtr:$result);
|
||||
|
||||
// Add additional `[]` to increase readability and split variadic lists
|
||||
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)";
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins
|
||||
"Value":$base,
|
||||
"ValueRange":$shape,
|
||||
"ValueRange":$strides,
|
||||
"ValueRange":$offsets,
|
||||
"ArrayRef<int32_t>":$tensorShape,
|
||||
"ArrayRef<int32_t>":$order
|
||||
)>
|
||||
];
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
||||
|
||||
@@ -31,19 +31,28 @@ def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
|
||||
// I32 Type
|
||||
// TT_I32 -> I32
|
||||
// TT_I32Tensor -> I32Tensor
|
||||
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
|
||||
def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>;
|
||||
|
||||
// I64 Type
|
||||
// TT_I64 -> I64
|
||||
// TT_I64Tensor -> I64Tensor
|
||||
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
|
||||
def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>;
|
||||
|
||||
// Pointer Type
|
||||
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
// Pointer Type in TableGen
|
||||
class TT_PtrOf<list<Type> pointeeTypes> :
|
||||
DialectType<Triton_Dialect,
|
||||
And<[CPred<"$_self.isa<::mlir::triton::PointerType>()">,
|
||||
Concat<"[](::mlir::Type pointeeType) { return ",
|
||||
SubstLeaves<"$_self", "pointeeType", AnyTypeOf<pointeeTypes>.predicate>,
|
||||
"; }($_self.cast<::mlir::triton::PointerType>().getPointeeType())">]>,
|
||||
"ptr", "::mlir::triton::PointerType">;
|
||||
|
||||
// Pointer Type in C++ (corresponding to `TT_PtrOf`)
|
||||
def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system";
|
||||
|
||||
let description = [{
|
||||
Triton PointerType
|
||||
Pointer type in Triton IR type system, which could be pointing to scalars or tensors.
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||
@@ -58,14 +67,27 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
// Scalar Pointer Type: `ptr<>`
|
||||
def TT_Ptr : TT_PtrOf<[AnyType]>;
|
||||
|
||||
// Tensor of Pointer Type
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
|
||||
// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`
|
||||
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
|
||||
|
||||
// Tensor Type
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
|
||||
// Pointer Type to Tensor Type: `ptr<tensor<>>`
|
||||
def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
|
||||
|
||||
// Any Type in Triton IR
|
||||
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;
|
||||
|
||||
#endif
|
||||
|
||||
@@ -12,6 +12,8 @@ namespace mlir {
|
||||
|
||||
unsigned getPointeeBitWidth(RankedTensorType tensorTy);
|
||||
|
||||
}
|
||||
bool isTensorPointerType(Type type);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
||||
|
||||
@@ -8,6 +8,9 @@ namespace triton {
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createRewriteTensorPointerPass(int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
|
||||
@@ -19,4 +19,23 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
||||
let dependentDialects = ["mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
|
||||
let description = [{
|
||||
This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
|
||||
semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute
|
||||
the pointer/mask/other for each load/store.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::createRewriteTensorPointerPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -412,6 +412,8 @@ private:
|
||||
auto loadOp = builder.create<triton::LoadOp>(
|
||||
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
|
||||
insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
|
||||
// TODO(Chenggang): confirm `boundaryCheck` and `padding`
|
||||
/*boundaryCheck=*/nullptr, /*padding=*/nullptr,
|
||||
insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
|
||||
insertSliceAsyncOp.getIsVolatile());
|
||||
|
||||
|
||||
@@ -355,6 +355,7 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(),
|
||||
adaptor.getBoundaryCheckAttr(), adaptor.getPaddingAttr(),
|
||||
adaptor.getCache(), adaptor.getEvict(),
|
||||
adaptor.getIsVolatile()),
|
||||
adaptor.getAttributes());
|
||||
|
||||
@@ -36,66 +36,116 @@ static Type getPointerTypeSameShape(Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
static Type getPointerType(Type type) { return PointerType::get(type, 1); }
|
||||
|
||||
static Type getElementTypeOfTensorPointerType(Type type) {
|
||||
if (auto ptrType = type.dyn_cast<PointerType>())
|
||||
if (auto tensorType = ptrType.getPointeeType().dyn_cast<RankedTensorType>())
|
||||
return tensorType.getElementType();
|
||||
return {};
|
||||
}
|
||||
|
||||
// Parser & printer for assembly forms
|
||||
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Parse operands
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
|
||||
Type resultTypes[1];
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseCustomTypeWithFallback(resultTypes[0]))
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
|
||||
return failure();
|
||||
|
||||
result.addTypes(resultTypes);
|
||||
|
||||
// Operand types
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
|
||||
|
||||
// Parse `optional(type(ptr)) -> type(result)`
|
||||
Type ptrType, resultType;
|
||||
if (parser.parseType(resultType))
|
||||
return failure();
|
||||
if (parser.parseOptionalArrow().succeeded()) {
|
||||
ptrType = resultType;
|
||||
if (parser.parseType(resultType))
|
||||
return failure();
|
||||
operandTypes.push_back(ptrType);
|
||||
result.addTypes(resultType);
|
||||
} else {
|
||||
operandTypes.push_back(getPointerTypeSameShape(resultType));
|
||||
result.addTypes(resultType);
|
||||
}
|
||||
|
||||
// Determine `mask` and `other`
|
||||
int hasMask = 0, hasOther = 0;
|
||||
if (allOperands.size() >= 2) {
|
||||
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
|
||||
operandTypes.push_back(getI1SameShape(resultType));
|
||||
hasMask = 1;
|
||||
}
|
||||
if (allOperands.size() >= 3) {
|
||||
operandTypes.push_back(resultTypes[0]); // other
|
||||
operandTypes.push_back(resultType);
|
||||
hasOther = 1;
|
||||
}
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
// Deduce operand_segment_sizes from the number of the operands.
|
||||
auto operand_segment_sizesAttrName =
|
||||
|
||||
// Deduce `operandSegmentSizes` from the number of the operands
|
||||
auto operandSegmentSizesAttrName =
|
||||
LoadOp::getOperandSegmentSizesAttrName(result.name);
|
||||
result.addAttribute(
|
||||
operand_segment_sizesAttrName,
|
||||
operandSegmentSizesAttrName,
|
||||
parser.getBuilder().getDenseI32ArrayAttr({1, hasMask, hasOther}));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void LoadOp::print(OpAsmPrinter &printer) {
|
||||
printer << " ";
|
||||
printer << getOperation()->getOperands();
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
|
||||
// `operandSegmentSizes` can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(getOperation()->getAttrs(),
|
||||
{getOperandSegmentSizesAttrName()});
|
||||
|
||||
// `type(ptr) -> type(result)`
|
||||
printer << " : ";
|
||||
// `type(ptr)` is optional during parsing, we only print for tensor pointers
|
||||
if (isTensorPointerType(getPtr().getType())) {
|
||||
printer.printStrippedAttrOrType(getPtr().getType());
|
||||
printer << " -> ";
|
||||
}
|
||||
printer.printStrippedAttrOrType(getResult().getType());
|
||||
}
|
||||
|
||||
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Parse operands
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
|
||||
Type valueType;
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseCustomTypeWithFallback(valueType))
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
|
||||
return failure();
|
||||
|
||||
// Operand types
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
|
||||
operandTypes.push_back(valueType); // value
|
||||
|
||||
// Parse `optional(type(ptr)), type(val)`
|
||||
// Pointer type
|
||||
Type ptrType, valType;
|
||||
if (parser.parseType(valType))
|
||||
return failure();
|
||||
if (parser.parseOptionalComma().succeeded()) {
|
||||
ptrType = valType;
|
||||
if (parser.parseType(valType))
|
||||
return failure();
|
||||
operandTypes.push_back(ptrType);
|
||||
} else {
|
||||
operandTypes.push_back(getPointerTypeSameShape(valType));
|
||||
}
|
||||
|
||||
// Value type
|
||||
operandTypes.push_back(valType);
|
||||
|
||||
// Determine `mask`
|
||||
if (allOperands.size() >= 3)
|
||||
operandTypes.push_back(getI1SameShape(valueType)); // mask
|
||||
operandTypes.push_back(getI1SameShape(valType));
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
@@ -107,7 +157,14 @@ void StoreOp::print(OpAsmPrinter &printer) {
|
||||
printer << " ";
|
||||
printer << getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{});
|
||||
|
||||
// `type(ptr), type(value)`
|
||||
printer << " : ";
|
||||
// `type(ptr)` is optional during parsing, we only print for tensor pointers
|
||||
if (isTensorPointerType(getPtr().getType())) {
|
||||
printer.printStrippedAttrOrType(getPtr().getType());
|
||||
printer << ", ";
|
||||
}
|
||||
printer.printStrippedAttrOrType(getValue().getType());
|
||||
}
|
||||
|
||||
@@ -123,15 +180,6 @@ void StoreOp::print(OpAsmPrinter &printer) {
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict) {
|
||||
return StoreOp::build(builder, state, ptr, value, mlir::Value(), cache,
|
||||
evict);
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
|
||||
auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
|
||||
@@ -146,24 +194,42 @@ static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
|
||||
isVolatile);
|
||||
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{},
|
||||
/*boundaryCheck=*/{}, /*padding=*/{}, cache, evict, isVolatile);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ArrayRef<int32_t> boundaryCheck,
|
||||
std::optional<::mlir::triton::PaddingOption> padding,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck,
|
||||
padding, cache, evict, isVolatile);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
|
||||
isVolatile);
|
||||
LoadOp::build(builder, state, ptr, mask, /*other=*/{}, /*boundaryCheck=*/{},
|
||||
/*padding=*/{}, cache, evict, isVolatile);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
Type resultType = getLoadOpResultType(builder, ptr.getType());
|
||||
LoadOp::build(builder, state, ptr, mask, other, /*boundaryCheck=*/{},
|
||||
/*padding=*/{}, cache, evict, isVolatile);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
||||
std::optional<ArrayRef<int32_t>> boundaryCheck,
|
||||
std::optional<::mlir::triton::PaddingOption> padding,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
// Operands
|
||||
state.addOperands(ptr);
|
||||
if (mask) {
|
||||
state.addOperands(mask);
|
||||
@@ -171,9 +237,20 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
state.addOperands(other);
|
||||
}
|
||||
}
|
||||
|
||||
// Attributes
|
||||
state.addAttribute(
|
||||
getOperandSegmentSizesAttrName(state.name),
|
||||
builder.getDenseI32ArrayAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
|
||||
if (boundaryCheck.has_value()) {
|
||||
state.addAttribute(getBoundaryCheckAttrName(state.name),
|
||||
builder.getDenseI32ArrayAttr(boundaryCheck.value()));
|
||||
}
|
||||
if (padding.has_value()) {
|
||||
state.addAttribute(getPaddingAttrName(state.name),
|
||||
::mlir::triton::PaddingOptionAttr::get(
|
||||
builder.getContext(), padding.value()));
|
||||
}
|
||||
state.addAttribute(
|
||||
getCacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
@@ -182,9 +259,39 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(getIsVolatileAttrName(state.name),
|
||||
builder.getBoolAttr(isVolatile));
|
||||
|
||||
// Result type
|
||||
Type resultType = getLoadOpResultType(builder, ptr.getType());
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
//-- StoreOp --
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict) {
|
||||
return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
|
||||
/*boundaryCheck=*/{}, cache, evict);
|
||||
}
|
||||
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value, ::mlir::Value mask,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict) {
|
||||
return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{},
|
||||
cache, evict);
|
||||
}
|
||||
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value,
|
||||
ArrayRef<int32_t> boundaryCheck,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict) {
|
||||
return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
|
||||
builder.getDenseI32ArrayAttr(boundaryCheck), cache,
|
||||
evict);
|
||||
}
|
||||
|
||||
//-- TransOp --
|
||||
mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
@@ -424,5 +531,27 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
}
|
||||
|
||||
//-- MakeTensorPtrOp --
|
||||
void MakeTensorPtrOp::build(::mlir::OpBuilder &builder,
                            ::mlir::OperationState &state, ::mlir::Value base,
                            ::mlir::ValueRange shape,
                            ::mlir::ValueRange strides,
                            ::mlir::ValueRange offsets,
                            ArrayRef<int32_t> tensorShape,
                            ArrayRef<int32_t> order) {
  // The result type is derived from `base`: a pointer to a tensor with the
  // requested block shape whose element type is the pointee type of `base`.
  auto pointerType = base.getType().cast<PointerType>();
  assert(pointerType != nullptr);

  // Build `tt.ptr<tensor<tensorShape x pointeeType>>` (address space 1,
  // matching the original builder).
  SmallVector<int64_t> blockShape(tensorShape.begin(), tensorShape.end());
  auto tensorType =
      RankedTensorType::get(blockShape, pointerType.getPointeeType());
  auto resultType = PointerType::get(tensorType, 1);

  build(builder, state, resultType, base, shape, strides, offsets,
        builder.getDenseI32ArrayAttr(order));
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,42 +1,59 @@
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
|
||||
static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
|
||||
using namespace mlir;
|
||||
auto encA = tyA.dyn_cast<RankedTensorType>();
|
||||
auto encB = tyB.dyn_cast<RankedTensorType>();
|
||||
if (!encA || !encB)
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Two types have "the same encoding" iff, after looking through block
// pointers (when allowed), both are ranked tensors with equal layout
// encodings. Types without an encoding impose no constraint.
static LogicalResult verifySameEncoding(Type typeA, Type typeB,
                                        bool allowTensorPointerType) {
  auto extractEncoding = [=](Type type) -> Attribute {
    auto rankedType = type.dyn_cast<RankedTensorType>();
    if (allowTensorPointerType) {
      // Unwrap `tt.ptr<tensor<...>>` to its pointee tensor.
      if (auto ptrType = type.dyn_cast<triton::PointerType>())
        rankedType = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
    } else {
      assert(!isTensorPointerType(type));
    }
    return rankedType ? rankedType.getEncoding() : Attribute();
  };
  Attribute encodingA = extractEncoding(typeA);
  Attribute encodingB = extractEncoding(typeB);
  if (!encodingA || !encodingB)
    return success();
  return encodingA == encodingB ? success() : failure();
}
|
||||
|
||||
mlir::LogicalResult
|
||||
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
|
||||
LogicalResult
OpTrait::impl::verifySameOperandsEncoding(Operation *op,
                                          bool allowTensorPointerType) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  // Compare every remaining operand's encoding against the first operand.
  auto referenceType = op->getOperand(0).getType();
  for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1))
    if (failed(verifySameEncoding(operandType, referenceType,
                                  allowTensorPointerType)))
      return op->emitOpError() << "requires the same encoding for all operands";

  return success();
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding(
    Operation *op, bool allowTensorPointerType) {
  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();

  // Every result must carry the same encoding as the first operand...
  auto referenceType = op->getOperand(0).getType();
  for (auto resultType : op->getResultTypes())
    if (failed(verifySameEncoding(resultType, referenceType,
                                  allowTensorPointerType)))
      return op->emitOpError()
             << "requires the same encoding for all operands and results";

  // ...and the operands must also agree among themselves.
  return verifySameOperandsEncoding(op, allowTensorPointerType);
}
|
||||
|
||||
mlir::LogicalResult
|
||||
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
|
||||
if (failed(verifyAtLeastNOperands(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getOperand(0).getType();
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
|
||||
if (failed(verifySameEncoding(opType, type)))
|
||||
return op->emitOpError() << "requires the same encoding for all operands";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||
LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||
int64_t numElements = 1;
|
||||
@@ -69,3 +86,55 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Returns the tensor shape associated with `type`: block pointers
// (`tt.ptr<tensor<...>>`) contribute their pointee tensor's shape, ranked
// tensors their own shape, and anything else an empty shape.
static ArrayRef<int64_t> getTypeShape(Type type) {
  if (auto ptrType = type.dyn_cast<triton::PointerType>()) {
    if (auto pointee = ptrType.getPointeeType().dyn_cast<RankedTensorType>())
      return pointee.getShape();
    return {};
  }
  if (auto rankedType = type.dyn_cast<RankedTensorType>())
    return rankedType.getShape();
  return {};
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  // All operands must be shape-compatible with the first one; block
  // pointers are compared through their pointee tensor shape.
  auto referenceShape = getTypeShape(op->getOperand(0).getType());
  for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1))
    if (failed(
            verifyCompatibleShape(getTypeShape(operandType), referenceShape)))
      return op->emitOpError() << "requires the same shape for all operands";

  return success();
}
|
||||
|
||||
LogicalResult
OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();

  // Every result must be shape-compatible with the first operand...
  auto referenceShape = getTypeShape(op->getOperand(0).getType());
  for (auto resultType : op->getResultTypes())
    if (failed(verifyCompatibleShape(getTypeShape(resultType), referenceShape)))
      return op->emitOpError()
             << "requires the same shape for all operands and results";

  // ...and the operands themselves must agree as well.
  return verifySameLoadStoreOperandsShape(op);
}
|
||||
|
||||
// Checks that `valueType` is exactly the value type implied by `ptrType`:
//   * block pointer `tt.ptr<tensor<T>>`      -> tensor<T>
//   * tensor of pointers `tensor<ptr<T>>`    -> tensor<T> (same shape/encoding)
//   * scalar pointer `ptr<T>`                -> T
bool OpTrait::impl::verifyLoadStorePointerAndValueType(Type valueType,
                                                       Type ptrType) {
  if (isTensorPointerType(ptrType))
    return ptrType.cast<triton::PointerType>().getPointeeType() == valueType;

  if (auto rankedType = ptrType.dyn_cast<RankedTensorType>()) {
    auto elementPtrType =
        dyn_cast<triton::PointerType>(rankedType.getElementType());
    if (!elementPtrType)
      return false;
    auto inferredValueType = RankedTensorType::get(
        rankedType.getShape(), elementPtrType.getPointeeType(),
        rankedType.getEncoding());
    return inferredValueType == valueType;
  }

  if (auto scalarPtrType = ptrType.dyn_cast<triton::PointerType>())
    return scalarPtrType.getPointeeType() == valueType;

  return false;
}
|
||||
|
||||
@@ -46,4 +46,10 @@ unsigned getPointeeBitWidth(RankedTensorType tensorTy) {
|
||||
return pointeeType.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
// A "tensor pointer" (block pointer) is a `tt.ptr` whose pointee is a
// ranked tensor, i.e. `tt.ptr<tensor<...>>`.
bool isTensorPointerType(Type type) {
  auto ptrType = type.dyn_cast<PointerType>();
  return ptrType && ptrType.getPointeeType().isa<RankedTensorType>();
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen)
|
||||
|
||||
add_mlir_dialect_library(TritonTransforms
|
||||
Combine.cpp
|
||||
RewriteTensorPointer.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonTransformsIncGen
|
||||
|
||||
@@ -94,7 +94,8 @@ public:
|
||||
return mlir::failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, loadOp.getPtr(), loadOp.getMask(), falseValue, loadOp.getCache(),
|
||||
op, loadOp.getPtr(), loadOp.getMask(), falseValue,
|
||||
loadOp.getBoundaryCheck(), loadOp.getPadding(), loadOp.getCache(),
|
||||
loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -127,6 +128,7 @@ struct CanonicalizeMaskedLoadPattern
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
|
||||
503
lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Normal file
503
lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Normal file
@@ -0,0 +1,503 @@
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
/// An additional struct to record the meta information of operations
|
||||
/// with tensor pointers
|
||||
struct RewritedInfo {
|
||||
private:
|
||||
Value base;
|
||||
SmallVector<Value> shape;
|
||||
SmallVector<Value> strides;
|
||||
SmallVector<Value> offsets;
|
||||
ArrayRef<int64_t> tensorShape;
|
||||
|
||||
// A cache to avoid generating the same offset with range
|
||||
DenseMap<unsigned, Value> cachedOffsetWithRange;
|
||||
|
||||
public:
|
||||
RewritedInfo() = default;
|
||||
|
||||
RewritedInfo(const RewritedInfo &other) = default;
|
||||
|
||||
RewritedInfo(Value base, const SmallVector<Value> &shape,
|
||||
const SmallVector<Value> &strides,
|
||||
const SmallVector<Value> &offsets,
|
||||
const ArrayRef<int64_t> &tensorShape)
|
||||
: base(base), shape(shape), strides(strides), offsets(offsets),
|
||||
tensorShape(tensorShape) {
|
||||
assert(shape.size() == strides.size() && shape.size() == offsets.size() &&
|
||||
shape.size() == tensorShape.size());
|
||||
}
|
||||
|
||||
unsigned int length() const { return shape.size(); }
|
||||
|
||||
Value getOffset(unsigned i) { return offsets[i]; }
|
||||
|
||||
SmallVector<Value> getOffsets() { return offsets; }
|
||||
|
||||
void setOffset(unsigned i, Value newOffset) {
|
||||
offsets[i] = newOffset;
|
||||
cachedOffsetWithRange.clear();
|
||||
}
|
||||
|
||||
void setOffsets(const SmallVector<Value> &newOffsets) {
|
||||
offsets = newOffsets;
|
||||
cachedOffsetWithRange.clear();
|
||||
}
|
||||
|
||||
Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
|
||||
unsigned i) {
|
||||
if (cachedOffsetWithRange.count(i))
|
||||
return cachedOffsetWithRange[i];
|
||||
|
||||
// Add range
|
||||
auto indexI32RowType =
|
||||
RankedTensorType::get({tensorShape[i]}, builder.getI32Type());
|
||||
auto indexRowType =
|
||||
RankedTensorType::get({tensorShape[i]}, builder.getI64Type());
|
||||
Value splatOffset =
|
||||
builder.create<triton::SplatOp>(loc, indexRowType, offsets[i]);
|
||||
Value range = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
|
||||
tensorShape[i]);
|
||||
Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);
|
||||
|
||||
// Expand dimensions
|
||||
Value expandedResult =
|
||||
builder.create<arith::AddIOp>(loc, splatOffset, i64Range);
|
||||
for (int j = 0; j < tensorShape.size(); ++j) {
|
||||
if (j == i)
|
||||
continue;
|
||||
expandedResult =
|
||||
builder.create<triton::ExpandDimsOp>(loc, expandedResult, j);
|
||||
}
|
||||
|
||||
return cachedOffsetWithRange[i] = expandedResult;
|
||||
}
|
||||
|
||||
Value generatePtr(OpBuilder &builder, const Location &loc) {
|
||||
assert(tensorShape.size() == offsets.size() &&
|
||||
tensorShape.size() == strides.size());
|
||||
auto indexTensorType =
|
||||
RankedTensorType::get(tensorShape, builder.getI64Type());
|
||||
auto ptrType = base.getType().cast<triton::PointerType>();
|
||||
auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType);
|
||||
|
||||
// Generate offsets per dimension
|
||||
Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, base);
|
||||
for (unsigned i = 0; i < tensorShape.size(); ++i) {
|
||||
auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i);
|
||||
|
||||
// We must splat strides into the expanded shape not a row for retaining
|
||||
// the divisibility information given by strides
|
||||
Value splatStride = builder.create<triton::SplatOp>(
|
||||
loc, offsetWithRange.getType(), strides[i]);
|
||||
Value offsetWithStride =
|
||||
builder.create<arith::MulIOp>(loc, offsetWithRange, splatStride);
|
||||
Value broadcasted = builder.create<triton::BroadcastOp>(
|
||||
loc, indexTensorType, offsetWithStride);
|
||||
|
||||
// Add to the pointer
|
||||
ptr = builder.create<triton::AddPtrOp>(loc, ptrTensorType, ptr,
|
||||
broadcasted);
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
Value generateMask(OpBuilder &builder, const Location &loc,
|
||||
const std::optional<ArrayRef<int32_t>> &boundaryCheck) {
|
||||
if (!boundaryCheck.has_value())
|
||||
return {};
|
||||
|
||||
// Generate mask per dimension
|
||||
auto maskTensorType =
|
||||
RankedTensorType::get(tensorShape, builder.getI1Type());
|
||||
Value mask;
|
||||
for (auto i : boundaryCheck.value()) {
|
||||
auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i);
|
||||
|
||||
// Compare with lower bound
|
||||
Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
|
||||
loc, 0, builder.getI64Type());
|
||||
Value splatLowerBound = builder.create<triton::SplatOp>(
|
||||
loc, offsetWithRange.getType(), lowerBound);
|
||||
Value cmpLower = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound);
|
||||
|
||||
// Compare with upper bound
|
||||
Value splatUpperBound = builder.create<triton::SplatOp>(
|
||||
loc, offsetWithRange.getType(), shape[i]);
|
||||
Value cmpUpper = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
|
||||
|
||||
// And and broadcast
|
||||
Value andResult = builder.create<arith::AndIOp>(loc, cmpLower, cmpUpper);
|
||||
Value broadcasted =
|
||||
builder.create<triton::BroadcastOp>(loc, maskTensorType, andResult);
|
||||
|
||||
// And up all results
|
||||
if (!mask) {
|
||||
mask = broadcasted;
|
||||
} else {
|
||||
mask = builder.create<arith::AndIOp>(loc, mask, broadcasted);
|
||||
}
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
|
||||
|
||||
Value generateOther(OpBuilder &builder, const Location &loc,
|
||||
const std::optional<triton::PaddingOption> &padding) {
|
||||
if (!padding.has_value())
|
||||
return Value();
|
||||
|
||||
// Create element attribute
|
||||
auto elementType =
|
||||
base.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);
|
||||
|
||||
// Set zero padding value
|
||||
Attribute attr =
|
||||
elementType.isIntOrIndex()
|
||||
? builder.getIntegerAttr(elementType, 0).cast<Attribute>()
|
||||
: builder.getFloatAttr(elementType, 0).cast<Attribute>();
|
||||
|
||||
// Float NaN padding case
|
||||
if (padding.value() == triton::PaddingOption::PAD_NAN) {
|
||||
assert(!elementType.isIntOrIndex());
|
||||
auto apNaN = llvm::APFloat::getNaN(
|
||||
attr.cast<FloatAttr>().getValue().getSemantics());
|
||||
attr = builder.getFloatAttr(elementType, apNaN);
|
||||
}
|
||||
|
||||
// Create tensor
|
||||
Value constant = builder.create<arith::ConstantOp>(loc, attr);
|
||||
return builder.create<triton::SplatOp>(loc, otherTensorType, constant);
|
||||
}
|
||||
};
|
||||
|
||||
class RewriteTensorPointerPass
|
||||
: public TritonRewriteTensorPointerBase<RewriteTensorPointerPass> {
|
||||
private:
|
||||
int computeCapability;
|
||||
DenseMap<Value, RewritedInfo> rewritedInfo;
|
||||
|
||||
public:
|
||||
explicit RewriteTensorPointerPass(int computeCapability)
|
||||
: computeCapability(computeCapability) {}
|
||||
|
||||
static bool needRewrite(Operation *op) {
|
||||
return std::any_of(
|
||||
op->getOperands().begin(), op->getOperands().end(),
|
||||
[](Value operand) { return isTensorPointerType(operand.getType()); });
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
generateNewOperands(const SmallVector<Value> &oldOperands, unsigned index,
|
||||
const SmallVector<Value> &newValues) {
|
||||
assert(index < oldOperands.size());
|
||||
SmallVector<Value> newOperands;
|
||||
for (int i = 0; i < index; ++i)
|
||||
newOperands.push_back(oldOperands[i]);
|
||||
for (auto value : newValues)
|
||||
newOperands.push_back(value);
|
||||
for (auto i = index + 1; i < oldOperands.size(); ++i)
|
||||
newOperands.push_back(oldOperands[i]);
|
||||
return newOperands;
|
||||
}
|
||||
|
||||
Operation *rewriteMakeTensorPtrOp(OpBuilder &builder,
|
||||
triton::MakeTensorPtrOp op,
|
||||
std::stack<Operation *> &eraser) {
|
||||
// Save info for later use
|
||||
auto ptrType = op.getResult().getType().cast<triton::PointerType>();
|
||||
auto tensorType = ptrType.getPointeeType().cast<RankedTensorType>();
|
||||
|
||||
// Cast I32 offsets into I64
|
||||
SmallVector<Value> i64Offsets;
|
||||
for (auto offset : op.getOffsets()) {
|
||||
auto i64Offset = builder.create<arith::ExtSIOp>(
|
||||
op.getLoc(), builder.getI64Type(), offset);
|
||||
i64Offsets.push_back(i64Offset);
|
||||
}
|
||||
|
||||
// Save information
|
||||
rewritedInfo[op.getResult()] =
|
||||
RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets,
|
||||
tensorType.getShape());
|
||||
|
||||
// Erase the original operation
|
||||
eraser.push(op);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op,
|
||||
std::stack<Operation *> &eraser) {
|
||||
// Get info from previous results
|
||||
assert(rewritedInfo.count(op.getPtr()));
|
||||
auto info = rewritedInfo[op.getPtr()];
|
||||
|
||||
// Calculate new offsets
|
||||
assert(info.length() == op.getOffsets().size());
|
||||
SmallVector<Value> newOffsets;
|
||||
for (int i = 0; i < info.length(); ++i) {
|
||||
Value i64Offset = builder.create<arith::ExtSIOp>(
|
||||
op.getLoc(), builder.getI64Type(), op.getOffsets()[i]);
|
||||
Value newOffset = builder.create<arith::AddIOp>(
|
||||
op.getLoc(), info.getOffset(i), i64Offset);
|
||||
newOffsets.push_back(newOffset);
|
||||
}
|
||||
|
||||
// Save info for later use
|
||||
info.setOffsets(newOffsets);
|
||||
rewritedInfo[op.getResult()] = info;
|
||||
|
||||
// Erase the original operation
|
||||
eraser.push(op);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op,
|
||||
std::stack<Operation *> &eraser) {
|
||||
assert(isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op));
|
||||
|
||||
// We only have to rewrite load/stores with tensor pointers
|
||||
auto ptr = op->getOperand(0);
|
||||
if (!isTensorPointerType(ptr.getType()))
|
||||
return nullptr;
|
||||
|
||||
// Get info from previous results
|
||||
assert(rewritedInfo.count(ptr));
|
||||
auto info = rewritedInfo[ptr];
|
||||
|
||||
// Load/store with tensor pointers implicitly will check the bound while
|
||||
// accessing memory, so we should set `mask` and `other` (according to the
|
||||
// padding). Also note that load with tensor pointers do not have `mask` and
|
||||
// `other` while building IR from Python AST
|
||||
std::optional<ArrayRef<int>> boundaryCheck;
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
|
||||
assert(!loadOp.getMask() && !loadOp.getOther());
|
||||
boundaryCheck = loadOp.getBoundaryCheck();
|
||||
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
|
||||
assert(!storeOp.getMask());
|
||||
boundaryCheck = storeOp.getBoundaryCheck();
|
||||
}
|
||||
|
||||
// Generate new `ptr`, `mask` and `other`
|
||||
auto newPtr = info.generatePtr(builder, op->getLoc());
|
||||
auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck);
|
||||
Value newOther;
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(op))
|
||||
newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding());
|
||||
|
||||
// Create a new operation
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
|
||||
auto newResult = builder.create<triton::LoadOp>(
|
||||
loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(),
|
||||
loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
op->getResult(0).replaceAllUsesWith(newResult);
|
||||
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
|
||||
builder.create<triton::StoreOp>(storeOp.getLoc(), newPtr,
|
||||
storeOp.getValue(), newMask,
|
||||
storeOp.getCache(), storeOp.getEvict());
|
||||
}
|
||||
|
||||
// Erase the original operation
|
||||
eraser.push(op);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op,
|
||||
std::stack<Operation *> &eraser) {
|
||||
// Generate new iteration operands and set rewrited information
|
||||
SmallVector<Value> oldIterOperands = op.getIterOperands();
|
||||
SmallVector<Value> newIterOperands = op.getIterOperands();
|
||||
for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
|
||||
++i, ++oldI) {
|
||||
if (!isTensorPointerType(newIterOperands[i].getType()))
|
||||
continue;
|
||||
|
||||
// Expand the tensor pointer into offsets
|
||||
assert(rewritedInfo.count(newIterOperands[i]));
|
||||
auto info = rewritedInfo[newIterOperands[i]];
|
||||
newIterOperands =
|
||||
generateNewOperands(newIterOperands, i, info.getOffsets());
|
||||
i += info.length() - 1;
|
||||
size += info.length() - 1;
|
||||
}
|
||||
|
||||
// Rebuild the loop type
|
||||
auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
|
||||
op.getUpperBound(), op.getStep(),
|
||||
newIterOperands);
|
||||
|
||||
// Create value mapping. Note that for tensor pointers, we use identity
|
||||
// mapping. It may refer to a value in the old loop, but we will rewrite it
|
||||
// later
|
||||
IRMapping mapping;
|
||||
for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
|
||||
++i, ++oldI) {
|
||||
auto oldRegionIterArg = op.getRegionIterArg(oldI);
|
||||
if (isTensorPointerType(oldRegionIterArg.getType())) {
|
||||
// Pass rewrited info inside
|
||||
assert(rewritedInfo.count(oldIterOperands[oldI]));
|
||||
auto info = rewritedInfo[oldIterOperands[oldI]];
|
||||
mapping.map(oldRegionIterArg, oldRegionIterArg);
|
||||
for (unsigned j = 0; j < info.length(); ++j)
|
||||
info.setOffset(j, newForOp.getRegionIterArg(i + j));
|
||||
rewritedInfo[oldRegionIterArg] = info;
|
||||
i += info.length() - 1;
|
||||
} else {
|
||||
mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i));
|
||||
}
|
||||
}
|
||||
mapping.map(op.getInductionVar(), newForOp.getInductionVar());
|
||||
|
||||
// Clone body
|
||||
builder.setInsertionPointToStart(newForOp.getBody());
|
||||
for (auto &opInFor : *op.getBody()) {
|
||||
auto *newOp = builder.clone(opInFor, mapping);
|
||||
for (unsigned i = 0; i < opInFor.getNumResults(); ++i)
|
||||
mapping.map(op->getResult(i), newOp->getResult(i));
|
||||
}
|
||||
|
||||
// Replace later usages
|
||||
assert(op.getNumResults() == op.getNumIterOperands());
|
||||
for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
|
||||
auto oldResult = op.getResult(oldI);
|
||||
if (isTensorPointerType(oldResult.getType())) {
|
||||
// Pack new offsets into rewrited info
|
||||
assert(rewritedInfo.count(oldIterOperands[oldI]));
|
||||
auto info = rewritedInfo[oldIterOperands[oldI]];
|
||||
for (unsigned j = 0; j < info.length(); ++j)
|
||||
info.setOffset(j, newForOp.getResult(i + j));
|
||||
i += info.length() - 1;
|
||||
rewritedInfo[oldResult] = info;
|
||||
} else {
|
||||
oldResult.replaceAllUsesWith(newForOp.getResult(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Erase later
|
||||
eraser.push(op);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op,
|
||||
std::stack<Operation *> &eraser) {
|
||||
// Replace tensor pointers with offsets
|
||||
SmallVector<Value> newOperands = op->getOperands();
|
||||
for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) {
|
||||
if (!isTensorPointerType(newOperands[i].getType()))
|
||||
continue;
|
||||
|
||||
assert(rewritedInfo.count(newOperands[i]));
|
||||
auto info = rewritedInfo[newOperands[i]];
|
||||
newOperands = generateNewOperands(newOperands, i, info.getOffsets());
|
||||
i += info.length() - 1;
|
||||
size += info.length() - 1;
|
||||
}
|
||||
op->setOperands(newOperands);
|
||||
|
||||
// No need to erase
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Operation *rewriteOp(Operation *op, std::stack<Operation *> &eraser) {
|
||||
OpBuilder builder(op);
|
||||
|
||||
// Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers
|
||||
// Rewriting functions return the next operation to visit, if there is no
|
||||
// next one, simply return `nullptr`
|
||||
std::pair<Value, RewritedInfo> rewrited;
|
||||
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
|
||||
return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser);
|
||||
} else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
|
||||
return rewriteAdvanceOp(builder, advanceOp, eraser);
|
||||
} else if (isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op)) {
|
||||
return rewriteLoadStoreOp(builder, op, eraser);
|
||||
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
|
||||
return rewriteLoadStoreOp(builder, op, eraser);
|
||||
} else if (op->getDialect()->getNamespace() == "scf" ||
|
||||
op->getDialect()->getNamespace() == "cf") {
|
||||
if (!needRewrite(op))
|
||||
return op;
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
||||
return rewriteForOp(builder, forOp, eraser);
|
||||
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
|
||||
return rewriteYieldOp(builder, yieldOp, eraser);
|
||||
} else {
|
||||
llvm_unreachable("Currently we only support tensor pointer usages "
|
||||
"inside a `scf::ForOp`, others such as `scf::IfOp`,"
|
||||
"`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` "
|
||||
"are not supported yet");
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise return the original one
|
||||
return op;
|
||||
}
|
||||
|
||||
void visitOperation(Operation *op, std::stack<Operation *> &eraser) {
|
||||
for (auto ®ion : op->getRegions()) {
|
||||
for (auto &block : region) {
|
||||
// We need an extra copy because erasing operations may break the
|
||||
// iterator behavior
|
||||
SmallVector<Operation *> blockCopy;
|
||||
for (auto &nestedOp : block)
|
||||
blockCopy.push_back(&nestedOp);
|
||||
|
||||
// Rewrite and recursively visit
|
||||
for (auto &nestedOp : blockCopy) {
|
||||
if (auto newOp = rewriteOp(nestedOp, eraser))
|
||||
visitOperation(newOp, eraser);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
// Only rewrite if the hardware does not support
|
||||
if (computeCapability >= 90)
|
||||
return;
|
||||
|
||||
// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
|
||||
// MLIR does not support one-multiple value mapping. For example, if we use
|
||||
// `ConversionPatternRewriter`, we can not make a type converter, which
|
||||
// converts `ptr<tensor>` into multiple types `ptr<>, int64, int64, ...`
|
||||
// (containing the base/offsets/strides...). What we can do is to convert
|
||||
// `ptr<tensor>` into a single type `Tuple<ptr<>, int64, int64, ...>`. But
|
||||
// in this way, we also have to define `PackTuple` and `UnpackTuple`
|
||||
// operations and make a canonicalization pass to optimize, which is much
|
||||
// So here we recursively build the IR, to be specific, we have to rewrite
|
||||
// `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`,
|
||||
// `scf.for` (tensor pointer usages may be in a loop fashion)
|
||||
std::stack<Operation *> eraser;
|
||||
visitOperation(getOperation(), eraser);
|
||||
|
||||
// The operation could not be erased during visit, because they may have
|
||||
// later usages, so we erase after visit
|
||||
rewritedInfo.clear();
|
||||
while (!eraser.empty()) {
|
||||
auto op = eraser.top();
|
||||
eraser.pop();
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
triton::createRewriteTensorPointerPass(int computeCapability) {
|
||||
return std::make_unique<RewriteTensorPointerPass>(computeCapability);
|
||||
}
|
||||
@@ -407,8 +407,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
newOp = builder.create<triton::LoadOp>(
|
||||
loadOp.getLoc(), loadOp.getResult().getType(),
|
||||
lookupOrDefault(loadOp.getPtr(), stage), newMask,
|
||||
lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(),
|
||||
loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
lookupOrDefault(loadOp.getOther(), stage),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
addNamedAttrs(newOp, op->getAttrDictionary());
|
||||
} else {
|
||||
newOp = builder.clone(*op);
|
||||
@@ -630,8 +631,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
nextOp = builder.create<triton::LoadOp>(
|
||||
loadOp.getLoc(), loadOp.getResult().getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.getPtr()), newMask,
|
||||
nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(),
|
||||
loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
nextMapping.lookupOrDefault(loadOp.getOther()),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
addNamedAttrs(nextOp, op->getAttrDictionary());
|
||||
nextMapping.map(loadOp.getResult(), nextOp->getResult(0));
|
||||
} else {
|
||||
|
||||
@@ -78,6 +78,11 @@ void init_triton_ir(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
using namespace pybind11::literals;
|
||||
|
||||
py::enum_<mlir::triton::PaddingOption>(m, "PADDING_OPTION")
|
||||
.value("PAD_ZERO", mlir::triton::PaddingOption::PAD_ZERO)
|
||||
.value("PAD_NAN", mlir::triton::PaddingOption::PAD_NAN)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::CacheModifier>(m, "CACHE_MODIFIER")
|
||||
.value("NONE", mlir::triton::CacheModifier::NONE)
|
||||
.value("CA", mlir::triton::CacheModifier::CA)
|
||||
@@ -1124,6 +1129,27 @@ void init_triton_ir(py::module &&m) {
|
||||
self.create<mlir::triton::StoreOp>(loc, ptrs, value, cacheModifier,
|
||||
evictionPolicy);
|
||||
})
|
||||
.def("create_tensor_pointer_load",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
std::vector<int32_t> &boundaryCheck,
|
||||
std::optional<mlir::triton::PaddingOption> paddingOption,
|
||||
mlir::triton::CacheModifier cacheModifier,
|
||||
mlir::triton::EvictionPolicy evictionPolicy,
|
||||
bool isVolatile) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, ptr, boundaryCheck, paddingOption, cacheModifier,
|
||||
evictionPolicy, isVolatile);
|
||||
})
|
||||
.def("create_tensor_pointer_store",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &val,
|
||||
std::vector<int32_t> &boundaryCheck,
|
||||
mlir::triton::CacheModifier cacheModifier,
|
||||
mlir::triton::EvictionPolicy evictionPolicy) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::triton::StoreOp>(loc, ptr, val, boundaryCheck,
|
||||
cacheModifier, evictionPolicy);
|
||||
})
|
||||
.def("create_masked_load",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
|
||||
std::optional<mlir::Value> &other,
|
||||
@@ -1390,10 +1416,31 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<::mlir::LLVM::UndefOp>(loc, type);
|
||||
})
|
||||
// Force GPU barrier
|
||||
.def("create_barrier", [](mlir::OpBuilder &self) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::gpu::BarrierOp>(loc);
|
||||
});
|
||||
.def("create_barrier",
|
||||
[](mlir::OpBuilder &self) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::gpu::BarrierOp>(loc);
|
||||
})
|
||||
// Make a block pointer (tensor pointer in Triton IR)
|
||||
.def("create_make_block_ptr",
|
||||
[](mlir::OpBuilder &self, mlir::Value &base,
|
||||
std::vector<mlir::Value> &shape,
|
||||
std::vector<mlir::Value> &strides,
|
||||
std::vector<mlir::Value> &offsets,
|
||||
std::vector<int32_t> &tensorShape,
|
||||
std::vector<int32_t> &order) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::MakeTensorPtrOp>(
|
||||
loc, base, shape, strides, offsets, tensorShape, order);
|
||||
})
|
||||
// Advance a block pointer
|
||||
.def("create_advance",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
std::vector<mlir::Value> &offsets) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::AdvanceOp>(loc, ptr.getType(),
|
||||
ptr, offsets);
|
||||
});
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
@@ -1448,6 +1495,11 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createCombineOpsPass());
|
||||
})
|
||||
.def("add_rewrite_tensor_pointer_pass",
|
||||
[](mlir::PassManager &self, int computeCapability) {
|
||||
self.addPass(mlir::triton::createRewriteTensorPointerPass(
|
||||
computeCapability));
|
||||
})
|
||||
.def("add_convert_triton_to_tritongpu_pass",
|
||||
[](mlir::PassManager &self, int numWarps) {
|
||||
self.addPass(
|
||||
|
||||
102
python/test/unit/language/test_block_pointer.py
Normal file
102
python/test/unit/language/test_block_pointer.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
# We only copy half of the data to see if the padding works
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
|
||||
block_shape=(BLOCK_SIZE, ), order=(0, ))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
|
||||
block_shape=(BLOCK_SIZE, ), order=(0, ))
|
||||
a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
|
||||
tl.store(b_block_ptr, a, boundary_check=(0, ))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, n, padding_option",
|
||||
[(dtype_str, n, padding) for dtype_str in ("bool", "int16", "float16")
|
||||
for n in (64, 128, 256, 512, 1024)
|
||||
for padding in ("zero", "nan")])
|
||||
def test_block_copy(dtype_str, n, padding_option):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] >= 9:
|
||||
pytest.skip("Hopper support is working in progress")
|
||||
|
||||
dtype = getattr(torch, dtype_str)
|
||||
if dtype_str in ("bool", "int16"):
|
||||
if padding_option == "nan":
|
||||
pytest.skip("Padding with NaN is not supported for integer types")
|
||||
a = torch.randint(0, 2, (n, ), device="cuda", dtype=dtype)
|
||||
else:
|
||||
a = torch.randn((n, ), device="cuda", dtype=dtype)
|
||||
b = torch.zeros((n, ), device="cuda", dtype=dtype)
|
||||
|
||||
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
|
||||
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
|
||||
|
||||
assert torch.all(a[0: n // 2] == b[0: n // 2])
|
||||
if padding_option == "zero":
|
||||
assert torch.all(b[n // 2: n] == 0)
|
||||
else:
|
||||
assert torch.all(torch.isnan(b[n // 2: n]))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_no_scf_with_advance_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
|
||||
):
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
|
||||
# Below two lines are just for testing negative offsets for the `advance` API, which could be removed
|
||||
a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))
|
||||
a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))
|
||||
a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero")
|
||||
b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero")
|
||||
|
||||
c = tl.dot(a, b)
|
||||
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
|
||||
tl.store(c_ptrs, c)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape, num_warps", [
|
||||
(shape, num_warps)
|
||||
for shape in [
|
||||
[64, 64, 16],
|
||||
[64, 64, 32],
|
||||
[64, 64, 64],
|
||||
]
|
||||
for num_warps in [4, 8]
|
||||
])
|
||||
def test_block_ptr_matmul_no_scf(shape, num_warps):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] >= 9:
|
||||
pytest.skip("Hopper support is working in progress")
|
||||
|
||||
m, n, k = shape
|
||||
a = torch.randn((m, k), device="cuda", dtype=torch.float16)
|
||||
b = torch.randn((k, n), device="cuda", dtype=torch.float16)
|
||||
c = torch.empty((m, n), device="cuda", dtype=torch.float32)
|
||||
|
||||
grid = lambda META: (1, )
|
||||
matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=m, N=n, K=k,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
|
||||
num_warps=num_warps)
|
||||
golden = torch.matmul(a, b)
|
||||
torch.testing.assert_allclose(c, golden)
|
||||
@@ -1087,10 +1087,17 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
|
||||
return ret, generator
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
def inline_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
@@ -1100,13 +1107,25 @@ def optimize_triton_ir(mod):
|
||||
return mod
|
||||
|
||||
|
||||
def ast_to_ttir(fn, signature, specialization, constants, debug=False):
|
||||
def ttir_compute_capability_rewrite(mod, compute_capability):
|
||||
# For hardware without support, we must rewrite all load/store with block (tensor) pointers into legacy load/store
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_rewrite_tensor_pointer_pass(compute_capability)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def ast_to_ttir(fn, signature, specialization, constants, compute_capability, debug=False):
|
||||
mod, _ = build_triton_ir(fn, signature, specialization, constants, debug)
|
||||
mod = inline_triton_ir(mod)
|
||||
mod = ttir_compute_capability_rewrite(mod, compute_capability)
|
||||
return optimize_triton_ir(mod)
|
||||
|
||||
|
||||
def ttir_to_ttgir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
@@ -1902,7 +1921,7 @@ def compile(fn, **kwargs):
|
||||
stages = {
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants, capability)),
|
||||
"ttgir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_text(),
|
||||
@@ -1916,7 +1935,7 @@ def compile(fn, **kwargs):
|
||||
stages = {
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants, capability, debug)),
|
||||
"ttgir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_text(),
|
||||
@@ -1926,7 +1945,6 @@ def compile(fn, **kwargs):
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, capability))
|
||||
}
|
||||
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs.get("configs", None)
|
||||
|
||||
@@ -8,6 +8,7 @@ from ..impl import (
|
||||
from . import math
|
||||
from .core import (
|
||||
abs,
|
||||
advance,
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
@@ -48,6 +49,7 @@ from .core import (
|
||||
int8,
|
||||
load,
|
||||
log,
|
||||
make_block_ptr,
|
||||
max,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
@@ -101,6 +103,7 @@ from .random import (
|
||||
|
||||
__all__ = [
|
||||
"abs",
|
||||
"advance",
|
||||
"arange",
|
||||
"argmin",
|
||||
"argmax",
|
||||
@@ -144,6 +147,7 @@ __all__ = [
|
||||
"math",
|
||||
"load",
|
||||
"log",
|
||||
"make_block_ptr",
|
||||
"max",
|
||||
"max_contiguous",
|
||||
"maximum",
|
||||
|
||||
@@ -884,55 +884,87 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None):
|
||||
def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option="", cache_modifier="",
|
||||
eviction_policy="", volatile=False, _builder=None):
|
||||
"""
|
||||
Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
|
||||
(1) `pointer` could be a single element pointer, then a scalar will be loaded
|
||||
- `mask` and `other` must be scalar too
|
||||
- `other` is implicitly typecast to `pointer.dtype.element_ty`
|
||||
- `boundary_check` and `padding_option` must be empty
|
||||
(2) `pointer` could be element-wise tensor of pointers, in which case:
|
||||
- `mask` and `other` are implicitly broadcast to `pointer.shape`
|
||||
- `other` is implicitly typecast to `pointer.dtype.element_ty`
|
||||
- `boundary_check` and `padding_option` must be empty
|
||||
(3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
|
||||
- `mask` and `other` must be None
|
||||
- `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access
|
||||
|
||||
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
|
||||
|
||||
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
|
||||
|
||||
:param pointer: Pointers to the data to be loaded.
|
||||
:type pointer: Block of dtype=triton.PointerDType
|
||||
:param mask: if mask[idx] is false, do not load the data at address :code:`pointer[idx]`.
|
||||
:type mask: Block of triton.int1, optional
|
||||
:param other: if mask[idx] is false, return other[idx]
|
||||
:param pointer: Pointer to the data to be loaded
|
||||
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
|
||||
:param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
|
||||
(must be `None` with block pointers)
|
||||
:type mask: Block of `triton.int1`, optional
|
||||
:param other: if `mask[idx]` is false, return `other[idx]`
|
||||
:type other: Block, optional
|
||||
:param cache_modifier: changes cache option in nvidia ptx
|
||||
:param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
|
||||
:type boundary_check: tuple of ints, optional
|
||||
:param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound
|
||||
:param cache_modifier: changes cache option in NVIDIA PTX
|
||||
:type cache_modifier: str, optional
|
||||
:param eviction_policy: changes eviction policy in NVIDIA PTX
|
||||
:type eviction_policy: str, optional
|
||||
:param volatile: changes volatile option in NVIDIA PTX
|
||||
:type volatile: bool, optional
|
||||
"""
|
||||
# mask, other can be constexpr
|
||||
# `mask` and `other` can be constexpr
|
||||
if _constexpr_to_value(mask) is not None:
|
||||
mask = _to_tensor(mask, _builder)
|
||||
if _constexpr_to_value(other) is not None:
|
||||
other = _to_tensor(other, _builder)
|
||||
padding_option = _constexpr_to_value(padding_option)
|
||||
cache_modifier = _constexpr_to_value(cache_modifier)
|
||||
eviction_policy = _constexpr_to_value(eviction_policy)
|
||||
volatile = _constexpr_to_value(volatile)
|
||||
return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
|
||||
return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
|
||||
volatile, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def store(pointer, value, mask=None, cache_modifier="", eviction_policy="", _builder=None):
|
||||
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
|
||||
"""
|
||||
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||
Store a tensor of data into memory locations defined by `pointer`:
|
||||
(1) `pointer` could be a single element pointer, then a scalar will be stored
|
||||
- `mask` must be scalar too
|
||||
- `boundary_check` and `padding_option` must be empty
|
||||
(2) `pointer` could be element-wise tensor of pointers, in which case:
|
||||
- `mask` is implicitly broadcast to `pointer.shape`
|
||||
- `boundary_check` must be empty
|
||||
(3) or `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
|
||||
- `mask` must be None
|
||||
- `boundary_check` can be specified to control the behavior of out-of-bound access
|
||||
`value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
|
||||
|
||||
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
|
||||
|
||||
:param pointer: The memory locations where the elements of :code:`value` are stored.
|
||||
:type pointer: Block of dtype=triton.PointerDType
|
||||
:param value: The tensor of elements to be stored.
|
||||
:param pointer: The memory location where the elements of `value` are stored
|
||||
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
|
||||
:param value: The tensor of elements to be stored
|
||||
:type value: Block
|
||||
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
|
||||
:param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
|
||||
:type mask: Block of triton.int1, optional
|
||||
:param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
|
||||
:type boundary_check: tuple of ints, optional
|
||||
:param cache_modifier: changes cache option in NVIDIA PTX
|
||||
:type cache_modifier: str, optional
|
||||
:param eviction_policy: changes eviction policy in NVIDIA PTX
|
||||
:type eviction_policy: str, optional
|
||||
"""
|
||||
# value can be constexpr
|
||||
# `value` can be constexpr
|
||||
value = _to_tensor(value, _builder)
|
||||
if _constexpr_to_value(mask) is not None:
|
||||
mask = _to_tensor(mask, _builder)
|
||||
cache_modifier = _constexpr_to_value(cache_modifier)
|
||||
eviction_policy = _constexpr_to_value(eviction_policy)
|
||||
return semantic.store(pointer, value, mask, cache_modifier, eviction_policy, _builder)
|
||||
return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
@@ -1387,6 +1419,33 @@ def device_assert(cond, msg="", _builder=None):
|
||||
lineno = frame.f_back.f_lineno
|
||||
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
|
||||
"""
|
||||
Returns a pointer to a block in a parent tensor
|
||||
|
||||
:param base: The base pointer to the parent tensor
|
||||
:param shape: The shape of the parent tensor
|
||||
:param strides: The strides of the parent tensor
|
||||
:param offsets: The offsets to the block
|
||||
:param block_shape: The shape of the block
|
||||
:param order: The order of the original data format
|
||||
"""
|
||||
return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def advance(base: tensor, offsets, _builder=None):
|
||||
"""
|
||||
Advance a block pointer
|
||||
|
||||
:param base: the block pointer to advance
|
||||
:param offsets: the offsets to advance, a tuple by dimension
|
||||
"""
|
||||
return semantic.advance(base, offsets, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Iterators
|
||||
# -----------------------
|
||||
|
||||
@@ -771,7 +771,7 @@ def cast(input: tl.tensor,
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def str_to_cache_modifier(cache_modifier):
|
||||
def _str_to_cache_modifier(cache_modifier):
|
||||
cache = ir.CACHE_MODIFIER.NONE # default
|
||||
if cache_modifier:
|
||||
if cache_modifier == ".ca":
|
||||
@@ -783,7 +783,7 @@ def str_to_cache_modifier(cache_modifier):
|
||||
return cache
|
||||
|
||||
|
||||
def str_to_eviction_policy(eviction_policy):
|
||||
def _str_to_eviction_policy(eviction_policy):
|
||||
eviction = ir.EVICTION_POLICY.NORMAL # default
|
||||
if eviction_policy:
|
||||
if eviction_policy == "evict_last":
|
||||
@@ -795,97 +795,219 @@ def str_to_eviction_policy(eviction_policy):
|
||||
return eviction
|
||||
|
||||
|
||||
def load(ptr: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
other: Optional[tl.tensor],
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
is_volatile: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def _str_to_padding_option(padding_option):
|
||||
padding = None # default
|
||||
if padding_option:
|
||||
if padding_option == "zero":
|
||||
padding = ir.PADDING_OPTION.PAD_ZERO
|
||||
elif padding_option == "nan":
|
||||
padding = ir.PADDING_OPTION.PAD_NAN
|
||||
else:
|
||||
raise ValueError(f"Padding option {padding_option} not supported")
|
||||
return padding
|
||||
|
||||
|
||||
def _canonicalize_boundary_check(boundary_check, block_shape):
|
||||
if boundary_check:
|
||||
if not hasattr(boundary_check, "__iter__"):
|
||||
boundary_check = [boundary_check]
|
||||
boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
|
||||
for dim in boundary_check:
|
||||
assert isinstance(dim, int) and 0 <= dim < len(block_shape)
|
||||
assert len(boundary_check) > 0
|
||||
assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
|
||||
return sorted(boundary_check)
|
||||
return tuple()
|
||||
|
||||
|
||||
def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
|
||||
# Load by a block pointer: `pointer_type<block_type<>>`
|
||||
# Block pointer can not have `mask` and `other` arguments
|
||||
if mask or other:
|
||||
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
|
||||
|
||||
elt_ty = ptr.type.element_ty.element_ty
|
||||
assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`"
|
||||
if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
|
||||
raise ValueError("Padding option `nan` is not supported for integer block pointers")
|
||||
|
||||
# `dst_ty` is de-referenced type of the pointer type
|
||||
dst_ty = ptr.type.element_ty
|
||||
|
||||
# Check `boundary_check` argument
|
||||
boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
|
||||
|
||||
# Build IR
|
||||
return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction,
|
||||
is_volatile), dst_ty)
|
||||
|
||||
|
||||
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
|
||||
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__())
|
||||
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
|
||||
|
||||
# Check `mask`, `other`, `boundary_check`, and `padding` arguments
|
||||
if not mask and other:
|
||||
raise ValueError("`other` cannot be provided without `mask`")
|
||||
if padding or boundary_check:
|
||||
raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
|
||||
"pointers or loading a scalar. Because the compiler does not know the boundary; please "
|
||||
"use block pointers (defined by `make_block_ptr`) instead")
|
||||
|
||||
# For a pointer of scalar, check the type of `mask` and `other`
|
||||
if not ptr.type.is_block():
|
||||
if mask and mask.type.is_block():
|
||||
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
|
||||
if other and other.type.is_block():
|
||||
raise ValueError("Other argument cannot be block type if pointer argument is not a block")
|
||||
|
||||
# Make `mask` and `other` into the same shape as `ptr`
|
||||
if ptr.type.is_block():
|
||||
if mask:
|
||||
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
|
||||
if other:
|
||||
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
|
||||
|
||||
# Get `pointer_type<elt_ty>` and `elt_ty`
|
||||
ptr_ty = ptr.type.scalar
|
||||
elt_ty = ptr_ty.element_ty
|
||||
|
||||
# treat bool* as tl.int8*
|
||||
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
|
||||
if elt_ty == tl.int1:
|
||||
elt_ty = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
||||
ptr = cast(ptr, ptr_ty, builder)
|
||||
|
||||
# Cast `other` into `ele_ty` type
|
||||
if other:
|
||||
other = cast(other, elt_ty, builder)
|
||||
|
||||
# cache modifier
|
||||
|
||||
# Create loaded result type `dst_ty`
|
||||
if ptr.type.is_block():
|
||||
shape = ptr.type.get_block_shapes()
|
||||
dst_ty = tl.block_type(elt_ty, shape)
|
||||
else:
|
||||
# Load by de-referencing the pointer of scalar
|
||||
dst_ty = elt_ty
|
||||
|
||||
cache = str_to_cache_modifier(cache_modifier)
|
||||
eviction = str_to_eviction_policy(eviction_policy)
|
||||
|
||||
# Build IR
|
||||
if not mask:
|
||||
if other:
|
||||
raise ValueError("`other` cannot be provided without `mask`")
|
||||
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle,
|
||||
mask.handle,
|
||||
other.handle if other else None,
|
||||
cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
|
||||
eviction, is_volatile), dst_ty)
|
||||
|
||||
|
||||
def store(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def load(ptr: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
other: Optional[tl.tensor],
|
||||
boundary_check,
|
||||
padding_option: str,
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
is_volatile: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
# Cache, eviction and padding options
|
||||
cache = _str_to_cache_modifier(cache_modifier)
|
||||
eviction = _str_to_eviction_policy(eviction_policy)
|
||||
padding = _str_to_padding_option(padding_option)
|
||||
|
||||
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
||||
# Load by a block pointer: `pointer_type<block_type<>>`
|
||||
return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
|
||||
else:
|
||||
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
||||
return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
|
||||
|
||||
|
||||
def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
|
||||
# Store by a block pointer: `pointer_type<block_type<>>`
|
||||
# Block pointers can not have the `mask` argument
|
||||
if mask:
|
||||
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
|
||||
|
||||
# Check same shape
|
||||
block_shape = ptr.type.element_ty.get_block_shapes()
|
||||
if not val.type.is_block():
|
||||
val = broadcast_impl_shape(val, block_shape, builder)
|
||||
assert val.type.is_block(), "Value argument must be block type or a scalar"
|
||||
assert block_shape == val.type.get_block_shapes(), "Block shape and value shape mismatch"
|
||||
|
||||
elt_ty = ptr.type.element_ty.element_ty
|
||||
assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`"
|
||||
|
||||
# Check `boundary_check` argument
|
||||
boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
|
||||
|
||||
# Build IR
|
||||
return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction),
|
||||
tl.void)
|
||||
|
||||
|
||||
def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
|
||||
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
|
||||
|
||||
# Check `boundary_check` argument
|
||||
if boundary_check:
|
||||
raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
|
||||
"scalar. Because the compiler does not know the boundary; please use block pointers "
|
||||
"(defined by `make_block_ptr`) instead")
|
||||
|
||||
# For a pointer of scalar, check the type of `val` and `mask`
|
||||
if not ptr.type.is_block():
|
||||
if val.type.is_block():
|
||||
raise ValueError("Value argument cannot be block type if pointer argument is not a block")
|
||||
if mask and mask.type.is_block():
|
||||
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
|
||||
|
||||
# Make `mask` and `val` into the same shape as `ptr`
|
||||
if ptr.type.is_block():
|
||||
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
|
||||
if mask and ptr.type.is_block():
|
||||
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
|
||||
if mask:
|
||||
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
|
||||
|
||||
ptr_ty = ptr.type.scalar
|
||||
elt_ty = ptr_ty.element_ty
|
||||
# treat bool* as tl.int8*
|
||||
|
||||
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
|
||||
if elt_ty == tl.int1:
|
||||
elt_ty = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
||||
ptr = cast(ptr, ptr_ty, builder)
|
||||
# attributes
|
||||
cache = str_to_cache_modifier(cache_modifier)
|
||||
eviction = str_to_eviction_policy(eviction_policy)
|
||||
# cast to target data-type
|
||||
|
||||
# Cast to target data type
|
||||
val = cast(val, elt_ty, builder)
|
||||
|
||||
# Build IR
|
||||
if not mask:
|
||||
return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
|
||||
if not mask.type.scalar.is_bool():
|
||||
raise ValueError("Mask must have boolean scalar type")
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
|
||||
|
||||
|
||||
def store(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
boundary_check,
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
# Cache and eviction options
|
||||
cache = _str_to_cache_modifier(cache_modifier)
|
||||
eviction = _str_to_eviction_policy(eviction_policy)
|
||||
|
||||
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
||||
# Store by a block pointer: `pointer_type<block_type<>>`
|
||||
return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
|
||||
else:
|
||||
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
||||
return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder)
|
||||
|
||||
|
||||
#########
|
||||
# atomic
|
||||
#########
|
||||
@@ -1265,3 +1387,70 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
|
||||
|
||||
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)
|
||||
|
||||
|
||||
def _convert_elem_to_ir_value(builder, elem, require_i64):
|
||||
if isinstance(elem, tl.constexpr):
|
||||
return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value)
|
||||
elif isinstance(elem, tl.tensor):
|
||||
assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
|
||||
assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
|
||||
if elem.dtype != tl.int64 and require_i64:
|
||||
return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed())
|
||||
elif elem.dtype != tl.int32:
|
||||
return builder.create_int_cast(elem.handle, builder.get_int32_ty(), elem.dtype.is_int_signed())
|
||||
return elem.handle
|
||||
assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
|
||||
|
||||
|
||||
def _convert_to_ir_values(builder, list_like, require_i64=True):
|
||||
if hasattr(list_like, "__iter__"):
|
||||
return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like]
|
||||
return [_convert_elem_to_ir_value(builder, list_like, require_i64)]
|
||||
|
||||
|
||||
def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor:
    """Create a block pointer into the parent tensor addressed by `base`.

    :param base: scalar pointer to the parent tensor (must not already be a block pointer)
    :param shape: parent tensor shape (dynamic, converted to i64)
    :param strides: parent tensor strides (dynamic, converted to i64)
    :param offsets: block offsets within the parent (dynamic, converted to i32)
    :param block_shape: static block shape (compile-time `int32_t` constants)
    :param order: permutation describing the in-memory data arrangement
    :return: a tensor of type `pointer_type<block_type<...>>`
    :raises ValueError: if `base` is not a plain scalar pointer
    """
    # Materialize the dynamic arguments as IR values.
    # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
    shape = _convert_to_ir_values(builder, shape)
    strides = _convert_to_ir_values(builder, strides)
    offsets = _convert_to_ir_values(builder, offsets, require_i64=False)

    # `base` must be a scalar pointer — not a block pointer or anything else.
    if not base.type.is_ptr() or base.type.element_ty.is_block():
        raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")

    # `tl.int1` has no standalone memory representation: view it as `tl.int8`.
    if base.type.element_ty == tl.int1:
        base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder)

    # `block_shape` must be statically known, within `int32_t` range.
    if not hasattr(block_shape, "__iter__"):
        block_shape = [block_shape]
    block_shape = [dim.value if isinstance(dim, tl.constexpr) else dim for dim in block_shape]
    assert all(isinstance(dim, int) and -2**31 <= dim < 2**31 for dim in block_shape), \
        "Expected a list of constant integers (`int32_t` range) in `block_shape`"

    # `order` must be a permutation of 0..rank-1.
    if not hasattr(order, "__iter__"):
        order = [order]
    order = [dim.value if isinstance(dim, tl.constexpr) else dim for dim in order]
    assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"

    # Every per-dimension argument must agree on the rank.
    rank = len(block_shape)
    assert all(len(arg) == rank for arg in (shape, strides, offsets, order)), \
        "Expected shape/strides/offsets/block_shape to have the same length"

    # Result type:
    #   `pointer_type<blocked<shape, element_type>>` in Python
    #   `tt.ptr<tensor<shape, element_type>>` in MLIR
    handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
    return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
|
||||
|
||||
|
||||
def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
    """Advance a block pointer by per-dimension `offsets`.

    :param base: the block pointer to advance
    :param offsets: per-dimension deltas (dynamic, converted to i32)
    :return: a new block pointer of the same type as `base`
    """
    ir_offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    # Advancing never changes the block-pointer type, only its offsets.
    return tl.tensor(builder.create_advance(base.handle, ir_offsets), base.type)
|
||||
|
||||
83
test/Triton/rewrite-tensor-pointer.mlir
Normal file
83
test/Triton/rewrite-tensor-pointer.mlir
Normal file
@@ -0,0 +1,83 @@
|
||||
// RUN: triton-opt %s -triton-rewrite-tensor-pointer | FileCheck %s
// Regression test: after -triton-rewrite-tensor-pointer, no block-pointer ops
// (tt.make_tensor_ptr / tt.advance) may remain; block loads become masked loads.
func.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
  %c31_i32 = arith.constant 31 : i32
  %c127_i32 = arith.constant 127 : i32
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32>
  %c0_i32 = arith.constant 0 : i32
  %c1_i64 = arith.constant 1 : i64
  %c32_i32 = arith.constant 32 : i32
  %c128_i32 = arith.constant 128 : i32
  %c8_i32 = arith.constant 8 : i32
  // Grouped program-id → tile-coordinate computation (%16, %17 are tile origins).
  %0 = tt.get_program_id {axis = 0 : i32} : i32
  %1 = tt.get_program_id {axis = 1 : i32} : i32
  %2 = arith.addi %arg3, %c127_i32 : i32
  %3 = arith.divsi %2, %c128_i32 : i32
  %4 = arith.addi %arg4, %c31_i32 : i32
  %5 = arith.divsi %4, %c32_i32 : i32
  %6 = arith.muli %5, %c8_i32 : i32
  %7 = arith.divsi %0, %6 : i32
  %8 = arith.muli %7, %c8_i32 : i32
  %9 = arith.subi %3, %8 : i32
  %10 = arith.cmpi slt, %9, %c8_i32 : i32
  %11 = arith.select %10, %9, %c8_i32 : i32
  %12 = arith.remsi %0, %11 : i32
  %13 = arith.addi %8, %12 : i32
  %14 = arith.remsi %0, %6 : i32
  %15 = arith.divsi %14, %11 : i32
  %16 = arith.muli %13, %c128_i32 : i32
  %17 = arith.muli %1, %c32_i32 : i32
  %18 = arith.extsi %arg3 : i32 to i64
  %19 = arith.extsi %arg5 : i32 to i64
  %20 = arith.extsi %arg6 : i32 to i64
  // The pass must eliminate both block-pointer constructions below.
  // CHECK-NOT: tt.make_tensor_ptr
  %21 = tt.make_tensor_ptr %arg0, [%18, %19], [%20, %c1_i64], [%16, %17] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
  %22 = arith.muli %15, %c32_i32 : i32
  %23 = arith.extsi %arg4 : i32 to i64
  %24 = arith.extsi %arg7 : i32 to i64
  // CHECK-NOT: tt.make_tensor_ptr
  %25 = tt.make_tensor_ptr %arg1, [%19, %23], [%24, %c1_i64], [%17, %22] {order = array<i32: 1, 0>} : !tt.ptr<tensor<32x32xf16>>
  %26 = arith.addi %arg5, %c31_i32 : i32
  %27 = arith.divsi %26, %c32_i32 : i32
  %28 = arith.index_cast %27 : i32 to index
  // K-loop carrying the accumulator and both block pointers.
  %29:3 = scf.for %arg9 = %c0 to %28 step %c1 iter_args(%arg10 = %cst, %arg11 = %21, %arg12 = %25) -> (tensor<128x32xf32>, !tt.ptr<tensor<128x32xf16>>, !tt.ptr<tensor<32x32xf16>>) {
    // Block loads must be rewritten into masked tensor-of-pointer loads.
    // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16>
    %55 = tt.load %arg11 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr<tensor<128x32xf16>> -> tensor<128x32xf16>
    // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
    %56 = tt.load %arg12 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr<tensor<32x32xf16>> -> tensor<32x32xf16>
    %57 = tt.dot %55, %56, %arg10 {allowTF32 = true} : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32>
    // Pointer advances must be rewritten into plain offset arithmetic.
    // CHECK-NOT: tt.advance
    %58 = tt.advance %arg11, [%c0_i32, %c32_i32] : !tt.ptr<tensor<128x32xf16>>
    // CHECK-NOT: tt.advance
    %59 = tt.advance %arg12, [%c32_i32, %c0_i32] : !tt.ptr<tensor<32x32xf16>>
    scf.yield %57, %58, %59 : tensor<128x32xf32>, !tt.ptr<tensor<128x32xf16>>, !tt.ptr<tensor<32x32xf16>>
  }
  // Epilogue: legacy (tensor-of-pointers) masked store of the f16 result.
  %30 = arith.truncf %29#0 : tensor<128x32xf32> to tensor<128x32xf16>
  %31 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
  %32 = tt.splat %16 : (i32) -> tensor<128xi32>
  %33 = arith.addi %32, %31 : tensor<128xi32>
  %34 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
  %35 = tt.splat %22 : (i32) -> tensor<32xi32>
  %36 = arith.addi %35, %34 : tensor<32xi32>
  %37 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
  %38 = tt.splat %arg8 : (i32) -> tensor<128x1xi32>
  %39 = arith.muli %37, %38 : tensor<128x1xi32>
  %40 = tt.expand_dims %36 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
  %41 = tt.broadcast %39 : (tensor<128x1xi32>) -> tensor<128x32xi32>
  %42 = tt.broadcast %40 : (tensor<1x32xi32>) -> tensor<128x32xi32>
  %43 = arith.addi %41, %42 : tensor<128x32xi32>
  %44 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>>
  %45 = tt.addptr %44, %43 : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi32>
  %46 = tt.splat %arg3 : (i32) -> tensor<128xi32>
  %47 = arith.cmpi slt, %33, %46 : tensor<128xi32>
  %48 = tt.expand_dims %47 {axis = 1 : i32} : (tensor<128xi1>) -> tensor<128x1xi1>
  %49 = tt.splat %arg4 : (i32) -> tensor<32xi32>
  %50 = arith.cmpi slt, %36, %49 : tensor<32xi32>
  %51 = tt.expand_dims %50 {axis = 0 : i32} : (tensor<32xi1>) -> tensor<1x32xi1>
  %52 = tt.broadcast %48 : (tensor<128x1xi1>) -> tensor<128x32xi1>
  %53 = tt.broadcast %51 : (tensor<1x32xi1>) -> tensor<128x32xi1>
  %54 = arith.andi %52, %53 : tensor<128x32xi1>
  tt.store %45, %30, %54 {cache = 1 : i32, evict = 1 : i32} : tensor<128x32xf16>
  return
}
|
||||
@@ -1032,9 +1032,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
|
||||
%c3136_i32 = arith.constant 3136 : i32
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c3136_i32 = arith.constant 3136 : index
|
||||
%c256_i32 = arith.constant 256 : index
|
||||
%c0_i32 = arith.constant 0 : index
|
||||
%cst_0 = arith.constant dense<3.136000e+03> : tensor<1x1xf32, #blocked>
|
||||
%cst_1 = arith.constant dense<50176> : tensor<1x256xi32, #blocked>
|
||||
%cst_2 = arith.constant dense<196> : tensor<1x1xi32, #blocked>
|
||||
@@ -1056,8 +1056,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%12 = tt.broadcast %11 : (tensor<1x1xi32, #blocked>) -> tensor<1x256xi32, #blocked>
|
||||
%13 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1x256x!tt.ptr<f32>, #blocked>
|
||||
%14 = tt.broadcast %7 : (tensor<1x1xi1, #blocked>) -> tensor<1x256xi1, #blocked>
|
||||
%15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) : i32 {
|
||||
%43 = tt.splat %arg5 : (i32) -> tensor<1x256xi32, #blocked>
|
||||
%15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) {
|
||||
%42 = arith.index_cast %arg5 : index to i32
|
||||
%43 = tt.splat %42 : (i32) -> tensor<1x256xi32, #blocked>
|
||||
%44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked>
|
||||
%45 = "triton_gpu.cmpi"(%44, %cst_4) {predicate = 2 : i64} : (tensor<1x256xi32, #blocked>, tensor<1x256xi32, #blocked>) -> tensor<1x256xi1, #blocked>
|
||||
%46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked>
|
||||
|
||||
Reference in New Issue
Block a user