From 26084a68aa6b6045d6d064ae633a90f660c48620 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Tue, 28 Jun 2022 14:26:06 +0100 Subject: [PATCH] fix(compiler): delay all memref deallocation calls introduced by the bufferizer and that are made into futures until after the synchronization point. --- .../RT/Analysis/LowerDataflowTasksToRT.cpp | 20 +++++++++++++++++++ .../end_to_end_jit_auto_parallelization.cc | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index d5e44b07f..68762e069 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -325,6 +325,26 @@ struct LowerDataflowTasksPass ArrayRef()); } }); + + // Delay memref deallocations when memrefs are made into futures + module.walk([&](Operation *op) { + if (isa(*op) && + op->getOperand(0).getType().isa()) { + for (auto &use : + llvm::make_early_inc_range(op->getOperand(0).getUses())) { + if (isa(use.getOwner())) { + OpBuilder builder(use.getOwner() + ->getParentOfType() + .getBody() + .back() + .getTerminator()); + builder.clone(*use.getOwner()); + use.getOwner()->erase(); + } + } + } + return WalkResult::advance(); + }); } LowerDataflowTasksPass(bool debug) : debug(debug){}; diff --git a/compiler/tests/unittest/end_to_end_jit_auto_parallelization.cc b/compiler/tests/unittest/end_to_end_jit_auto_parallelization.cc index 40c064ce1..dc83c371c 100644 --- a/compiler/tests/unittest/end_to_end_jit_auto_parallelization.cc +++ b/compiler/tests/unittest/end_to_end_jit_auto_parallelization.cc @@ -9,7 +9,7 @@ // Auto-parallelize independent FHE ops ///////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -TEST(ParallelizeAndRunFHE, DISABLED_add_eint_tree) { +TEST(ParallelizeAndRunFHE, add_eint_tree) { checkedJit(lambda, R"XXX( func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %arg3: !FHE.eint<7>) -> !FHE.eint<7> { %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)