fix(compiler): fix lowering of tasks generated from loop tiling pass.

This commit is contained in:
Antoniu Pop
2024-01-18 13:15:19 +00:00
committed by Andi Drebes
parent a5afb1f0a6
commit 2da8644e57
3 changed files with 31 additions and 0 deletions

View File

@@ -64,6 +64,10 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
Region &DFTOpBody = DFTOp.getBody();
OpBuilder::InsertionGuard guard(builder);
SetVector<Value> operands;
getUsedValuesDefinedAbove(DFTOpBody, operands);
DFTOp->setOperands(operands.takeVector());
// Instead of outlining with the same operands/results, we pass all
// results as operands as well. For now we preserve the results'
// types, which will be changed to use an indirection when lowering.
@@ -591,6 +595,23 @@ struct FinalizeTaskCreationPass
op->setOperand(1, clone);
}
});
module.walk([&](RT::WorkFunctionReturnOp op) {
OpBuilder builder(op);
Value val = op.getOperand(0);
if (val.getType().isa<mlir::MemRefType>() &&
isa<RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
val.getDefiningOp())) {
Value newval =
builder
.create<mlir::memref::AllocOp>(
val.getLoc(), val.getType().dyn_cast<mlir::MemRefType>())
.getResult();
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
op->setOperand(0, newval);
}
});
}
FinalizeTaskCreationPass(bool debug) : debug(debug){};

View File

@@ -561,6 +561,9 @@ mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
pm, mlir::bufferization::createBufferDeallocationPass(), enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createStartStopPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createCanonicalizerPass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createBufferizationToMemRefPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
@@ -589,9 +592,13 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
pipelinePrinting("StdToLLVM", pm, context);
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(pm, mlir::arith::createArithExpandOpsPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createReconcileUnrealizedCastsPass(),
enablePass);
return pm.run(module);
}

View File

@@ -68,6 +68,9 @@ public:
ASSERT_OUTCOME_HAS_VALUE(maybeRes);
auto result = maybeRes.value();
if (!mlir::concretelang::dfr::_dfr_is_root_node())
return;
/* Check results */
bool allgood = true;
for (size_t i = 0; i < desc.outputs.size(); i++) {