mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(compiler): fix lowering of tasks generated from loop tiling pass.
This commit is contained in:
@@ -64,6 +64,10 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
Region &DFTOpBody = DFTOp.getBody();
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
|
||||
SetVector<Value> operands;
|
||||
getUsedValuesDefinedAbove(DFTOpBody, operands);
|
||||
DFTOp->setOperands(operands.takeVector());
|
||||
|
||||
// Instead of outlining with the same operands/results, we pass all
|
||||
// results as operands as well. For now we preserve the results'
|
||||
// types, which will be changed to use an indirection when lowering.
|
||||
@@ -591,6 +595,23 @@ struct FinalizeTaskCreationPass
|
||||
op->setOperand(1, clone);
|
||||
}
|
||||
});
|
||||
|
||||
module.walk([&](RT::WorkFunctionReturnOp op) {
|
||||
OpBuilder builder(op);
|
||||
|
||||
Value val = op.getOperand(0);
|
||||
if (val.getType().isa<mlir::MemRefType>() &&
|
||||
isa<RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
val.getDefiningOp())) {
|
||||
Value newval =
|
||||
builder
|
||||
.create<mlir::memref::AllocOp>(
|
||||
val.getLoc(), val.getType().dyn_cast<mlir::MemRefType>())
|
||||
.getResult();
|
||||
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
|
||||
op->setOperand(0, newval);
|
||||
}
|
||||
});
|
||||
}
|
||||
FinalizeTaskCreationPass(bool debug) : debug(debug){};
|
||||
|
||||
|
||||
@@ -561,6 +561,9 @@ mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
|
||||
pm, mlir::bufferization::createBufferDeallocationPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createStartStopPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createCanonicalizerPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createBufferizationToMemRefPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
|
||||
|
||||
@@ -589,9 +592,13 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
pipelinePrinting("StdToLLVM", pm, context);
|
||||
|
||||
// Convert to MLIR LLVM Dialect
|
||||
addPotentiallyNestedPass(pm, mlir::arith::createArithExpandOpsPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createReconcileUnrealizedCastsPass(),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
@@ -68,6 +68,9 @@ public:
|
||||
ASSERT_OUTCOME_HAS_VALUE(maybeRes);
|
||||
auto result = maybeRes.value();
|
||||
|
||||
if (!mlir::concretelang::dfr::_dfr_is_root_node())
|
||||
return;
|
||||
|
||||
/* Check results */
|
||||
bool allgood = true;
|
||||
for (size_t i = 0; i < desc.outputs.size(); i++) {
|
||||
|
||||
Reference in New Issue
Block a user