[BACKEND] Add error reporting to report non-kernel-argument (#2552)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Dongdong Li
2023-11-02 08:22:10 +08:00
committed by GitHub
parent 702cde0d6f
commit d0098da7b1
2 changed files with 13 additions and 4 deletions

View File

@@ -877,6 +877,11 @@ private:
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (!isa<BlockArgument>(v) &&
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
v.getDefiningOp()))
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));
@@ -1771,6 +1776,11 @@ private:
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (!isa<BlockArgument>(v) &&
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
v.getDefiningOp()))
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));

View File

@@ -1701,8 +1701,6 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
if (auto resEnc = resTy.getEncoding().dyn_cast<ttg::MmaEncodingAttr>()) {
if (resEnc && resEnc.isHopper()) {
// Don't pipeline valid dots that depend on ops other than scf.yield
// and scf.for
auto dot = dotOp.getResult();
bool valid = true;
@@ -1713,7 +1711,7 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
valid = false;
Operation *firstUse = nullptr;
selfDepend(dotOp, forOp, &firstUse);
auto depend = selfDepend(dotOp, forOp, &firstUse);
bool selfDirectDepend = (dotOp == firstUse);
for (auto tempInAll : allDots) {
auto iter = std::find(dots.begin(), dots.end(), tempInAll);
@@ -1726,7 +1724,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
hasSyncDot = true;
}
auto CArg = dotOp.getOperand(2);
if (!(selfDirectDepend || (!selfDirectDepend && hasSyncDot)) ||
if (!(selfDirectDepend ||
(depend && !selfDirectDepend && hasSyncDot)) ||
!CArg.hasOneUse())
valid = false;