[OPTIMIZER] Improved layout simplification heuristics (#1168)

Philippe Tillet
2023-02-09 20:17:25 -08:00
committed by GitHub
parent c61c8a123f
commit 2aba985daa
13 changed files with 160 additions and 47 deletions

View File

@@ -161,10 +161,13 @@ def TT_StoreOp : TT_Op<"store",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
-let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
+let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
+DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
+DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict)>,
];
// let assemblyFormat = "operands attr-dict `:` type($value)";
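
With the two new attributes above, tt.store can carry the same cache hints that tt.load already had. A minimal kernel-level sketch of the extended Python API (exercised through the tl.store signature changed later in this commit; kernel and pointer names are illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        x = tl.load(src_ptr + offs, cache_modifier=".cg",
                    eviction_policy="evict_first")
        # stores now accept the same two hints; when omitted they default
        # to CacheModifier::NONE and EvictionPolicy::NORMAL
        tl.store(dst_ptr + offs, x, cache_modifier=".cg",
                 eviction_policy="evict_last")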

View File

@@ -85,6 +85,23 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
}
}
}
+} else if (Operation *op = value.getDefiningOp()) {
+DimVectorT knownContiguity(rank, 1);
+DimVectorT knownDivisibility(rank, 1);
+DimVectorT knownConstancy(rank, 1);
+if (Attribute attr = op->getAttr("tt.divisibility")) {
+auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+knownDivisibility = DimVectorT(vals.begin(), vals.end());
+}
+if (Attribute attr = op->getAttr("tt.contiguity")) {
+auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+knownContiguity = DimVectorT(vals.begin(), vals.end());
+}
+if (Attribute attr = op->getAttr("tt.constancy")) {
+auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+knownConstancy = DimVectorT(vals.begin(), vals.end());
+}
+return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
}
return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
@@ -818,7 +835,24 @@ ChangeResult AxisInfoAnalysis::visitOperation(
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}
+// override with hint
+auto newContiguity = curr.getContiguity();
+auto newDivisibility = curr.getDivisibility();
+auto newConstancy = curr.getConstancy();
+if (Attribute attr = op->getAttr("tt.contiguity")) {
+auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
+}
+if (Attribute attr = op->getAttr("tt.divisibility")) {
+auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
+}
+if (Attribute attr = op->getAttr("tt.constancy")) {
+auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
+}
+curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
+curr.getConstantValue());
// join all lattice elements
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
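
Both hunks make the analysis honor hint attributes attached to an op: once when seeding the pessimistic value state, once when visiting the op. Assuming the usual mapping of the Python hint builtins onto these attributes (tl.multiple_of sets tt.divisibility, tl.max_contiguous sets tt.contiguity), a sketch of a kernel that now benefits:

    @triton.jit
    def hinted_copy(src_ptr, dst_ptr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # these hints should surface as tt.divisibility / tt.contiguity
        # attributes, which AxisInfo now folds into its lattice values
        offs = tl.multiple_of(offs, 16)
        offs = tl.max_contiguous(offs, 16)
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs))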

View File

@@ -345,7 +345,8 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::StoreOp>(
-op, adaptor.ptr(), adaptor.value(), adaptor.mask());
+op, adaptor.ptr(), adaptor.value(), adaptor.mask(), adaptor.cache(),
+adaptor.evict());
return success();
}
};

View File

@@ -149,8 +149,11 @@ bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
-::mlir::Value ptr, ::mlir::Value value) {
-StoreOp::build(builder, state, ptr, value, mlir::Value());
+::mlir::Value ptr, ::mlir::Value value,
+::mlir::triton::CacheModifier cache,
+::mlir::triton::EvictionPolicy evict) {
+return StoreOp::build(builder, state, ptr, value, mlir::Value(), cache,
+evict);
}
//-- LoadOp --

View File

@@ -169,8 +169,9 @@ struct CanonicalizeMaskedStorePattern
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
-rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
-storeOp.value());
+rewriter.replaceOpWithNewOp<triton::StoreOp>(
+storeOp, storeOp.ptr(), storeOp.value(), storeOp.cache(),
+storeOp.evict());
} else {
// mask = splat(0)
rewriter.eraseOp(storeOp);

View File

@@ -154,6 +154,12 @@ public:
// block argument
if (!arg)
return mlir::failure();
+// cvt(view) -> view
+if (auto view = dyn_cast<triton::ViewOp>(arg)) {
+rewriter.replaceOpWithNewOp<triton::ViewOp>(
+op, op->getResult(0).getType(), view.getResult());
+return mlir::success();
+}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
@@ -278,6 +284,9 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
return failure();
ret = sliceEncoding.getParent();
}
+if (auto view = dyn_cast<triton::ViewOp>(op)) {
+return failure();
+}
return success();
}
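
Both changes make the pass conservative around tt.view: a layout conversion of a view's result folds into the view itself (a view places no constraint on its result layout), and invertEncoding refuses to propagate an encoding backward through one. At the source level a view is just a reshape; a hedged sketch with illustrative shapes:

    @triton.jit
    def reshape_kernel(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr):
        x = tl.load(x_ptr + tl.arange(0, M * N))
        # any convert_layout of `y` can now be folded into the view
        y = tl.view(x, (M, N))
        offs2d = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
        tl.store(y_ptr + offs2d, y)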
@@ -287,16 +296,23 @@ inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
if (isSingleValue(op->getOperand(0)))
return false;
auto ptr = op->getOperand(0);
+// Case 2: We assume that `evict_last` loads/stores have high hit rate
+if (auto load = dyn_cast<triton::LoadOp>(op))
+if (load.evict() == triton::EvictionPolicy::EVICT_LAST)
+return false;
+if (auto store = dyn_cast<triton::StoreOp>(op))
+if (store.evict() == triton::EvictionPolicy::EVICT_LAST)
+return false;
if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
auto encoding = tensorTy.getEncoding();
-// Case 2: Different type conversion is expensive (e.g., mma <-> block)
+// Case 3: Different type conversion is expensive (e.g., mma <-> block)
if (encoding.getTypeID() != targetEncoding.getTypeID())
return true;
auto sizePerThread = triton::gpu::getSizePerThread(encoding);
auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
auto order = triton::gpu::getOrder(encoding);
auto targetOrder = triton::gpu::getOrder(targetEncoding);
-// Case 3: The targetEncoding may expose more vectorization opportunities
+// Case 4: The targetEncoding may expose more vectorization opportunities
return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
}
return false;
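
The new Case 2 assumes that evict_last accesses mostly hit in L2, so re-executing them during layout rematerialization is cheap. A kernel can opt into this, sketched here with illustrative names:

    @triton.jit
    def scale_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # evict_last marks the data as likely L2-resident, so this load is
        # no longer counted as expensive to re-materialize
        x = tl.load(x_ptr + offs, eviction_policy="evict_last")
        tl.store(y_ptr + offs, x * 2.0)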
@@ -365,6 +381,9 @@ LogicalResult simulateBackwardRematerialization(
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
+if (auto view = dyn_cast<triton::ViewOp>(opArgI))
+continue;
// We add one expensive conversion for the current operand
numCvts += 1;
queue.emplace_back(opArgI, newEncoding);
@@ -383,9 +402,9 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
BlockAndValueMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
+auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
-origType.getShape(), origType.getElementType(),
-newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
+origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
if (typeInfer) {
@@ -425,6 +444,11 @@ void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
}
}
rewriter.setInsertionPoint(op);
+if (op->getNumResults() == 0) {
+Operation *newOp = rewriter.clone(*op, mapping);
+rewriter.eraseOp(op);
+return;
+}
auto *newOp = cloneWithInferType(rewriter, op, mapping);
auto newType = newOp->getResult(0).getType().cast<RankedTensorType>();
auto newCvtType = RankedTensorType::get(
@@ -564,17 +588,22 @@ public:
!isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
-if (cvtSlices.empty())
+if (cvtSlices.empty()) {
return failure();
+}
llvm::MapVector<Value, Attribute> toConvert;
for (Operation *op : cvtSlices) {
// don't rematerialize anything expensive
-if (expensiveToRemat(op, srcEncoding))
+if (expensiveToRemat(op, dstEncoding)) {
return failure();
+}
// don't rematerialize non-element-wise
-if (!op->hasTrait<mlir::OpTrait::Elementwise>())
+if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
+!op->hasTrait<mlir::OpTrait::Elementwise>() &&
+!isa<triton::StoreOp>(op)) {
return failure();
+}
// don't rematerialize if it adds an extra conversion that can't
// be removed
for (Value arg : op->getOperands()) {

View File

@@ -169,7 +169,10 @@ LogicalResult LoopPipeliner::initialize() {
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
auto ptr = loadOp.ptr();
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
-auto ty = getElementTypeOrSelf(ptr.getType())
+auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
+if (!tensorTy)
+continue;
+auto ty = tensorTy.getElementType()
.cast<triton::PointerType>()
.getPointeeType();
unsigned width = vec * ty.getIntOrFloatBitWidth();
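
The guard matters because tt.load also accepts a scalar pointer, in which case the type is not a RankedTensorType and there is no vector width to compute; such loads are now skipped rather than mis-cast. A sketch of the kind of loop this protects (names illustrative):

    @triton.jit
    def scaled_sum(s_ptr, x_ptr, out_ptr, K, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        for k in range(0, K):
            # scalar load inside the loop: no tensor of pointers here,
            # so the pipeliner leaves it alone
            scale = tl.load(s_ptr + k)
            acc += scale * tl.load(x_ptr + k * BLOCK + offs)
        tl.store(out_ptr + offs, acc)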

View File

@@ -190,7 +190,14 @@ void init_triton_ir(py::module &&m) {
if (mlir::Operation *definingOp = self.getDefiningOp())
definingOp->setAttr(name, attr);
else {
-/* issue a warning */
+auto arg = self.cast<mlir::BlockArgument>();
+int id = arg.getArgNumber();
+std::string attrName = name + "_arg" + std::to_string(id);
+mlir::Block *owner = arg.getOwner();
+if (owner->isEntryBlock() &&
+!mlir::isa<mlir::FuncOp>(owner->getParentOp())) {
+owner->getParentOp()->setAttr(attrName, attr);
+}
}
})
.def("get_context", &mlir::Value::getContext)
@@ -1082,10 +1089,12 @@ void init_triton_ir(py::module &&m) {
loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
})
.def("create_store",
-[](mlir::OpBuilder &self, mlir::Value &ptrs,
-mlir::Value &value) -> void {
+[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value,
+mlir::triton::CacheModifier cacheModifier,
+mlir::triton::EvictionPolicy evictionPolicy) -> void {
auto loc = self.getUnknownLoc();
-self.create<mlir::triton::StoreOp>(loc, ptrs, value);
+self.create<mlir::triton::StoreOp>(loc, ptrs, value, cacheModifier,
+evictionPolicy);
})
.def("create_masked_load",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
@@ -1100,9 +1109,11 @@ void init_triton_ir(py::module &&m) {
})
.def("create_masked_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
-mlir::Value &mask) -> void {
+mlir::Value &mask, mlir::triton::CacheModifier cacheModifier,
+mlir::triton::EvictionPolicy evictionPolicy) -> void {
auto loc = self.getUnknownLoc();
-self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
+self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask,
+cacheModifier, evictionPolicy);
})
.def("create_view",
[](mlir::OpBuilder &self, mlir::Value &arg,

View File

@@ -179,6 +179,18 @@ class CodeGenerator(ast.NodeVisitor):
break
return stmts and isinstance(stmt, ast.Return)
+
+def contains_return_op(self, node):
+    if isinstance(node, ast.Return):
+        return True
+    elif isinstance(node, ast.If):
+        pred = lambda s: self.contains_return_op(s)
+        ret = any(pred(s) for s in node.body)
+        if node.orelse:
+            ret = ret or any(pred(s) for s in node.orelse)
+        return ret
+    else:
+        return False
def visit_Module(self, node):
ast.NodeVisitor.generic_visit(self, node)
@@ -475,7 +487,7 @@ class CodeGenerator(ast.NodeVisitor):
cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor):
cond = cond.to(triton.language.int1, _builder=self.builder)
-if self.scf_stack:
+if self.scf_stack or not self.contains_return_op(node):
self.visit_if_scf(cond, node)
else:
self.visit_if_top_level(cond, node)
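
With contains_return_op, only if statements that can actually return take the non-SCF top-level lowering; everything else keeps producing scf.if, which later passes handle better. The two cases, sketched:

    @triton.jit
    def early_exit(x_ptr, flag, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        if flag == 0:  # contains a return -> lowered as a top-level if
            return
        tl.store(x_ptr + offs, tl.load(x_ptr + offs) + 1)

    @triton.jit
    def no_exit(x_ptr, flag, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        if flag == 0:  # no return inside -> still lowered to scf.if
            x = x * 2
        tl.store(x_ptr + offs, x)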

View File

@@ -873,7 +873,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
@builtin
-def store(pointer, value, mask=None, _builder=None):
+def store(pointer, value, mask=None, cache_modifier="", eviction_policy="", _builder=None):
"""
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -890,7 +890,9 @@ def store(pointer, value, mask=None, _builder=None):
value = _to_tensor(value, _builder)
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
-return semantic.store(pointer, value, mask, _builder)
+cache_modifier = _constexpr_to_value(cache_modifier)
+eviction_policy = _constexpr_to_value(eviction_policy)
+return semantic.store(pointer, value, mask, cache_modifier, eviction_policy, _builder)
# -----------------------

View File

@@ -18,6 +18,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
"""
for _ in tl.static_range(n_rounds):
+# for _ in range(n_rounds):
# update random state
A = PHILOX_ROUND_A
B = PHILOX_ROUND_B

View File

@@ -747,6 +747,30 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//
+
+
+def str_to_cache_modifier(cache_modifier):
+    cache = ir.CACHE_MODIFIER.NONE  # default
+    if cache_modifier:
+        if cache_modifier == ".ca":
+            cache = ir.CACHE_MODIFIER.CA
+        elif cache_modifier == ".cg":
+            cache = ir.CACHE_MODIFIER.CG
+        else:
+            raise ValueError(f"Cache modifier {cache_modifier} not supported")
+    return cache
+
+
+def str_to_eviction_policy(eviction_policy):
+    eviction = ir.EVICTION_POLICY.NORMAL  # default
+    if eviction_policy:
+        if eviction_policy == "evict_last":
+            eviction = ir.EVICTION_POLICY.EVICT_LAST
+        elif eviction_policy == "evict_first":
+            eviction = ir.EVICTION_POLICY.EVICT_FIRST
+        else:
+            raise ValueError(f"Eviction policy {eviction_policy} not supported")
+    return eviction
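
Hoisting the parsing into these two helpers lets load and store share one code path; the mapping itself is unchanged. For reference:

    str_to_cache_modifier("")             # ir.CACHE_MODIFIER.NONE (default)
    str_to_cache_modifier(".cg")          # ir.CACHE_MODIFIER.CG
    str_to_eviction_policy("evict_last")  # ir.EVICTION_POLICY.EVICT_LAST
    str_to_eviction_policy("foo")         # raises ValueError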
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
@@ -775,24 +799,6 @@ def load(ptr: tl.tensor,
other = cast(other, elt_ty, builder)
-    # cache modifier
-    cache = ir.CACHE_MODIFIER.NONE  # default
-    if cache_modifier:
-        if cache_modifier == ".ca":
-            cache = ir.CACHE_MODIFIER.CA
-        elif cache_modifier == ".cg":
-            cache = ir.CACHE_MODIFIER.CG
-        else:
-            raise ValueError(f"Cache modifier {cache_modifier} not supported")
-    # eviction policy
-    eviction = ir.EVICTION_POLICY.NORMAL  # default
-    if eviction_policy:
-        if eviction_policy == "evict_last":
-            eviction = ir.EVICTION_POLICY.EVICT_LAST
-        elif eviction_policy == "evict_first":
-            eviction = ir.EVICTION_POLICY.EVICT_FIRST
-        else:
-            raise ValueError(f"Eviction policy {eviction_policy} not supported")
if ptr.type.is_block():
shape = ptr.type.get_block_shapes()
@@ -800,6 +806,9 @@ def load(ptr: tl.tensor,
else:
dst_ty = elt_ty
+cache = str_to_cache_modifier(cache_modifier)
+eviction = str_to_eviction_policy(eviction_policy)
if not mask:
if other:
raise ValueError("`other` cannot be provided without `mask`")
@@ -816,6 +825,8 @@ def load(ptr: tl.tensor,
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
+cache_modifier: str,
+eviction_policy: str,
builder: ir.builder) -> tl.tensor:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -830,14 +841,16 @@ def store(ptr: tl.tensor,
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
+# attributes
+cache = str_to_cache_modifier(cache_modifier)
+eviction = str_to_eviction_policy(eviction_policy)
# cast to target data-type
val = cast(val, elt_ty, builder)
if not mask:
-return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
+return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
if not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
-return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
+return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
#########
# atomic

View File

@@ -2,7 +2,7 @@
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-#layout2 = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
+#layout2 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
// CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>