[OPTIMIZER] Improved layout simplifications heuristics (#1168)
@@ -161,10 +161,13 @@ def TT_StoreOp : TT_Op<"store",
                      "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
  let summary = "store";

  let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
  let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
                       DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
                       DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);

  let builders = [
      OpBuilder<(ins "Value":$ptr, "Value":$value)>,
      OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache,
                     "triton::EvictionPolicy":$evict)>,
  ];

  // let assemblyFormat = "operands attr-dict `:` type($value)";

@@ -85,6 +85,23 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
        }
      }
    }
  } else if (Operation *op = value.getDefiningOp()) {
    DimVectorT knownContiguity(rank, 1);
    DimVectorT knownDivisibility(rank, 1);
    DimVectorT knownConstancy(rank, 1);
    if (Attribute attr = op->getAttr("tt.divisibility")) {
      auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
      knownDivisibility = DimVectorT(vals.begin(), vals.end());
    }
    if (Attribute attr = op->getAttr("tt.contiguity")) {
      auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
      knownContiguity = DimVectorT(vals.begin(), vals.end());
    }
    if (Attribute attr = op->getAttr("tt.constancy")) {
      auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
      knownConstancy = DimVectorT(vals.begin(), vals.end());
    }
    return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
  }

  return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
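Editor's note: the hunks above (and the visitOperation hunk that follows) read optimizer hints from tt.contiguity, tt.divisibility, and tt.constancy attributes. A minimal kernel-level sketch of where such hints originate, assuming the tl.multiple_of / tl.max_contiguous builtins attach these attributes to the defining op:

import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Assumed to surface as tt.divisibility / tt.contiguity attributes, which the
    # pessimistic state and the visitOperation override below pick up.
    offs = tl.multiple_of(offs, 16)
    offs = tl.max_contiguous(offs, 16)
    tl.store(y_ptr + offs, tl.load(x_ptr + offs))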
@@ -818,7 +835,24 @@ ChangeResult AxisInfoAnalysis::visitOperation(
  if (curr.getRank() == 0) {
    return markAllPessimisticFixpoint(op->getResults());
  }

  // override with hint
  auto newContiguity = curr.getContiguity();
  auto newDivisibility = curr.getDivisibility();
  auto newConstancy = curr.getConstancy();
  if (Attribute attr = op->getAttr("tt.contiguity")) {
    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
    newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
  }
  if (Attribute attr = op->getAttr("tt.divisibility")) {
    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
    newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
  }
  if (Attribute attr = op->getAttr("tt.constancy")) {
    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
    newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
  }
  curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
                        curr.getConstantValue());
  // join all lattice elements
  ChangeResult result = ChangeResult::NoChange;
  for (Value value : op->getResults()) {

@@ -345,7 +345,8 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::StoreOp>(
        op, adaptor.ptr(), adaptor.value(), adaptor.mask());
        op, adaptor.ptr(), adaptor.value(), adaptor.mask(), adaptor.cache(),
        adaptor.evict());
    return success();
  }
};

@@ -149,8 +149,11 @@ bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,

//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value value) {
  StoreOp::build(builder, state, ptr, value, mlir::Value());
                    ::mlir::Value ptr, ::mlir::Value value,
                    ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict) {
  return StoreOp::build(builder, state, ptr, value, mlir::Value(), cache,
                        evict);
}

//-- LoadOp --

@@ -169,8 +169,9 @@ struct CanonicalizeMaskedStorePattern

    if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
      // mask = splat(1)
      rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
                                                   storeOp.value());
      rewriter.replaceOpWithNewOp<triton::StoreOp>(
          storeOp, storeOp.ptr(), storeOp.value(), storeOp.cache(),
          storeOp.evict());
    } else {
      // mask = splat(0)
      rewriter.eraseOp(storeOp);

@@ -154,6 +154,12 @@ public:
    // block argument
    if (!arg)
      return mlir::failure();
    // cvt(view) -> view
    if (auto view = dyn_cast<triton::ViewOp>(arg)) {
      rewriter.replaceOpWithNewOp<triton::ViewOp>(
          op, op->getResult(0).getType(), view.getResult());
      return mlir::success();
    }
    // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
    auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
    if (alloc_tensor) {

@@ -278,6 +284,9 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
      return failure();
    ret = sliceEncoding.getParent();
  }
  if (auto view = dyn_cast<triton::ViewOp>(op)) {
    return failure();
  }
  return success();
}

@@ -287,16 +296,23 @@ inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
  if (isSingleValue(op->getOperand(0)))
    return false;
  auto ptr = op->getOperand(0);
  // Case 2: We assume that `evict_last` loads/stores have high hit rate
  if (auto load = dyn_cast<triton::LoadOp>(op))
    if (load.evict() == triton::EvictionPolicy::EVICT_LAST)
      return false;
  if (auto store = dyn_cast<triton::StoreOp>(op))
    if (store.evict() == triton::EvictionPolicy::EVICT_LAST)
      return false;
  if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
    auto encoding = tensorTy.getEncoding();
    // Case 2: Different type conversion is expensive (e.g., mma <-> block)
    // Case 3: Different type conversion is expensive (e.g., mma <-> block)
    if (encoding.getTypeID() != targetEncoding.getTypeID())
      return true;
    auto sizePerThread = triton::gpu::getSizePerThread(encoding);
    auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
    auto order = triton::gpu::getOrder(encoding);
    auto targetOrder = triton::gpu::getOrder(targetEncoding);
    // Case 3: The targeEncoding may expose more vectorization opportunities
    // Case 4: The targeEncoding may expose more vectorization opportunities
    return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
  }
  return false;
@@ -365,6 +381,9 @@ LogicalResult simulateBackwardRematerialization(
    if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
            triton::MakeRangeOp, triton::SplatOp>(*opArgI))
      continue;
    if (auto view = dyn_cast<triton::ViewOp>(opArgI))
      continue;

    // We add one expensive conversion for the current operand
    numCvts += 1;
    queue.emplace_back(opArgI, newEncoding);

@@ -383,9 +402,9 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              BlockAndValueMapping &mapping) {
  Operation *newOp = rewriter.clone(*op, mapping);
  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
  auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
  auto newType = RankedTensorType::get(
      origType.getShape(), origType.getElementType(),
      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
      origType.getShape(), origType.getElementType(), argType.getEncoding());
  newOp->getResult(0).setType(newType);
  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
  if (typeInfer) {

@@ -425,6 +444,11 @@ void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
    }
  }
  rewriter.setInsertionPoint(op);
  if (op->getNumResults() == 0) {
    Operation *newOp = rewriter.clone(*op, mapping);
    rewriter.eraseOp(op);
    return;
  }
  auto *newOp = cloneWithInferType(rewriter, op, mapping);
  auto newType = newOp->getResult(0).getType().cast<RankedTensorType>();
  auto newCvtType = RankedTensorType::get(

@@ -564,17 +588,22 @@ public:
          !isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
    };
    mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
    if (cvtSlices.empty())
    if (cvtSlices.empty()) {
      return failure();
    }

    llvm::MapVector<Value, Attribute> toConvert;
    for (Operation *op : cvtSlices) {
      // don't rematerialize anything expensive
      if (expensiveToRemat(op, srcEncoding))
      if (expensiveToRemat(op, dstEncoding)) {
        return failure();
      }
      // don't rematerialize non-element-wise
      if (!op->hasTrait<mlir::OpTrait::Elementwise>())
      if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
          !op->hasTrait<mlir::OpTrait::Elementwise>() &&
          !isa<triton::StoreOp>(op)) {
        return failure();
      }
      // don't rematerialize if it adds an extra conversion that can't
      // be removed
      for (Value arg : op->getOperands()) {

@@ -169,7 +169,10 @@ LogicalResult LoopPipeliner::initialize() {
    if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
      auto ptr = loadOp.ptr();
      unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
      auto ty = getElementTypeOrSelf(ptr.getType())
      auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
      if (!tensorTy)
        continue;
      auto ty = tensorTy.getElementType()
                    .cast<triton::PointerType>()
                    .getPointeeType();
      unsigned width = vec * ty.getIntOrFloatBitWidth();

@@ -190,7 +190,14 @@ void init_triton_ir(py::module &&m) {
             if (mlir::Operation *definingOp = self.getDefiningOp())
               definingOp->setAttr(name, attr);
             else {
               /* issue a warning */
               auto arg = self.cast<mlir::BlockArgument>();
               int id = arg.getArgNumber();
               std::string attrName = name + "_arg" + std::to_string(id);
               mlir::Block *owner = arg.getOwner();
               if (owner->isEntryBlock() &&
                   !mlir::isa<mlir::FuncOp>(owner->getParentOp())) {
                 owner->getParentOp()->setAttr(attrName, attr);
               }
             }
           })
      .def("get_context", &mlir::Value::getContext)

@@ -1082,10 +1089,12 @@ void init_triton_ir(py::module &&m) {
                 loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
           })
      .def("create_store",
           [](mlir::OpBuilder &self, mlir::Value &ptrs,
              mlir::Value &value) -> void {
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value,
              mlir::triton::CacheModifier cacheModifier,
              mlir::triton::EvictionPolicy evictionPolicy) -> void {
             auto loc = self.getUnknownLoc();
             self.create<mlir::triton::StoreOp>(loc, ptrs, value);
             self.create<mlir::triton::StoreOp>(loc, ptrs, value, cacheModifier,
                                                evictionPolicy);
           })
      .def("create_masked_load",
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,

@@ -1100,9 +1109,11 @@ void init_triton_ir(py::module &&m) {
           })
      .def("create_masked_store",
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
              mlir::Value &mask) -> void {
              mlir::Value &mask, mlir::triton::CacheModifier cacheModifier,
              mlir::triton::EvictionPolicy evictionPolicy) -> void {
             auto loc = self.getUnknownLoc();
             self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
             self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask,
                                                cacheModifier, evictionPolicy);
           })
      .def("create_view",
           [](mlir::OpBuilder &self, mlir::Value &arg,

@@ -179,6 +179,18 @@ class CodeGenerator(ast.NodeVisitor):
                break
        return stmts and isinstance(stmt, ast.Return)

    def contains_return_op(self, node):
        if isinstance(node, ast.Return):
            return True
        elif isinstance(node, ast.If):
            pred = lambda s: self.contains_return_op(s)
            ret = any(pred(s) for s in node.body)
            if node.orelse:
                ret = ret or any(pred(s) for s in node.orelse)
            return ret
        else:
            return False

    def visit_Module(self, node):
        ast.NodeVisitor.generic_visit(self, node)

@@ -475,7 +487,7 @@ class CodeGenerator(ast.NodeVisitor):
        cond = self.visit(node.test)
        if isinstance(cond, triton.language.tensor):
            cond = cond.to(triton.language.int1, _builder=self.builder)
            if self.scf_stack:
            if self.scf_stack or not self.contains_return_op(node):
                self.visit_if_scf(cond, node)
            else:
                self.visit_if_top_level(cond, node)
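Editor's note: a sketch of the kernel shape this routing enables, assuming the top-level path accepts an early return inside an if (hypothetical example):

import triton
import triton.language as tl

@triton.jit
def guarded_copy(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    if pid * BLOCK >= n:
        # contains_return_op detects this return, so the branch is lowered via
        # visit_if_top_level instead of an scf.if region.
        return
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)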
@@ -873,7 +873,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",


@builtin
def store(pointer, value, mask=None, _builder=None):
def store(pointer, value, mask=None, cache_modifier="", eviction_policy="", _builder=None):
    """
    Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.

@@ -890,7 +890,9 @@ def store(pointer, value, mask=None, _builder=None):
    value = _to_tensor(value, _builder)
    if _constexpr_to_value(mask) is not None:
        mask = _to_tensor(mask, _builder)
    return semantic.store(pointer, value, mask, _builder)
    cache_modifier = _constexpr_to_value(cache_modifier)
    eviction_policy = _constexpr_to_value(eviction_policy)
    return semantic.store(pointer, value, mask, cache_modifier, eviction_policy, _builder)
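Editor's note: with the extended signature, a kernel can pass the same cache/eviction hints to stores as to loads; a minimal sketch (parameter values are illustrative):

import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, cache_modifier=".cg")
    # New keyword arguments threaded through semantic.store down to the
    # StoreOp cache/evict attributes added above.
    tl.store(y_ptr + offs, 2 * x, mask=mask,
             cache_modifier=".cg", eviction_policy="evict_last")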

# -----------------------

@@ -18,6 +18,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
    """
    for _ in tl.static_range(n_rounds):
    # for _ in range(n_rounds):
        # update random state
        A = PHILOX_ROUND_A
        B = PHILOX_ROUND_B

@@ -747,6 +747,30 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//


def str_to_cache_modifier(cache_modifier):
    cache = ir.CACHE_MODIFIER.NONE  # default
    if cache_modifier:
        if cache_modifier == ".ca":
            cache = ir.CACHE_MODIFIER.CA
        elif cache_modifier == ".cg":
            cache = ir.CACHE_MODIFIER.CG
        else:
            raise ValueError(f"Cache modifier {cache_modifier} not supported")
    return cache


def str_to_eviction_policy(eviction_policy):
    eviction = ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy:
        if eviction_policy == "evict_last":
            eviction = ir.EVICTION_POLICY.EVICT_LAST
        elif eviction_policy == "evict_first":
            eviction = ir.EVICTION_POLICY.EVICT_FIRST
        else:
            raise ValueError(f"Eviction policy {eviction_policy} not supported")
    return eviction
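Editor's note: a small usage sketch of the two helpers, assuming they are importable from triton.language.semantic and that the ir bindings expose the enums as shown above:

from triton._C.libtriton.triton import ir
from triton.language.semantic import str_to_cache_modifier, str_to_eviction_policy

# Empty strings fall through to the defaults; unknown strings raise ValueError.
assert str_to_cache_modifier("") == ir.CACHE_MODIFIER.NONE
assert str_to_cache_modifier(".cg") == ir.CACHE_MODIFIER.CG
assert str_to_eviction_policy("evict_last") == ir.EVICTION_POLICY.EVICT_LAST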

def load(ptr: tl.tensor,
         mask: Optional[tl.tensor],
         other: Optional[tl.tensor],

@@ -775,24 +799,6 @@ def load(ptr: tl.tensor,
        other = cast(other, elt_ty, builder)

    # cache modifier
    cache = ir.CACHE_MODIFIER.NONE  # default
    if cache_modifier:
        if cache_modifier == ".ca":
            cache = ir.CACHE_MODIFIER.CA
        elif cache_modifier == ".cg":
            cache = ir.CACHE_MODIFIER.CG
        else:
            raise ValueError(f"Cache modifier {cache_modifier} not supported")

    # eviction policy
    eviction = ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy:
        if eviction_policy == "evict_last":
            eviction = ir.EVICTION_POLICY.EVICT_LAST
        elif eviction_policy == "evict_first":
            eviction = ir.EVICTION_POLICY.EVICT_FIRST
        else:
            raise ValueError(f"Eviction policy {eviction_policy} not supported")

    if ptr.type.is_block():
        shape = ptr.type.get_block_shapes()

@@ -800,6 +806,9 @@ def load(ptr: tl.tensor,
    else:
        dst_ty = elt_ty

    cache = str_to_cache_modifier(cache_modifier)
    eviction = str_to_eviction_policy(eviction_policy)

    if not mask:
        if other:
            raise ValueError("`other` cannot be provided without `mask`")

@@ -816,6 +825,8 @@ def load(ptr: tl.tensor,
def store(ptr: tl.tensor,
          val: tl.tensor,
          mask: Optional[tl.tensor],
          cache_modifier: str,
          eviction_policy: str,
          builder: ir.builder) -> tl.tensor:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())

@@ -830,14 +841,16 @@ def store(ptr: tl.tensor,
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    # attributes
    cache = str_to_cache_modifier(cache_modifier)
    eviction = str_to_eviction_policy(eviction_policy)
    # cast to target data-type
    val = cast(val, elt_ty, builder)
    if not mask:
        return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
        return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
    if not mask.type.scalar.is_bool():
        raise ValueError("Mask must have boolean scalar type")
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)

#########
# atomic

@@ -2,7 +2,7 @@

#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout2 = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#layout2 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>

// CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>