mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Add error reporting to report non-kernel-argument (#2552)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -877,6 +877,11 @@ private:
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (!isa<BlockArgument>(v) &&
|
||||
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
|
||||
v.getDefiningOp()))
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
@@ -1771,6 +1776,11 @@ private:
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (!isa<BlockArgument>(v) &&
|
||||
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
|
||||
v.getDefiningOp()))
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
|
||||
@@ -1701,8 +1701,6 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (auto resEnc = resTy.getEncoding().dyn_cast<ttg::MmaEncodingAttr>()) {
|
||||
if (resEnc && resEnc.isHopper()) {
|
||||
// Don't pipeline valid dots that depend on ops other than scf.yield
|
||||
// and scf.for
|
||||
auto dot = dotOp.getResult();
|
||||
bool valid = true;
|
||||
|
||||
@@ -1713,7 +1711,7 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
valid = false;
|
||||
|
||||
Operation *firstUse = nullptr;
|
||||
selfDepend(dotOp, forOp, &firstUse);
|
||||
auto depend = selfDepend(dotOp, forOp, &firstUse);
|
||||
bool selfDirectDepend = (dotOp == firstUse);
|
||||
for (auto tempInAll : allDots) {
|
||||
auto iter = std::find(dots.begin(), dots.end(), tempInAll);
|
||||
@@ -1726,7 +1724,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
hasSyncDot = true;
|
||||
}
|
||||
auto CArg = dotOp.getOperand(2);
|
||||
if (!(selfDirectDepend || (!selfDirectDepend && hasSyncDot)) ||
|
||||
if (!(selfDirectDepend ||
|
||||
(depend && !selfDirectDepend && hasSyncDot)) ||
|
||||
!CArg.hasOneUse())
|
||||
valid = false;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user