diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h index 0428ef19d..829222b0b 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h @@ -65,6 +65,12 @@ mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, mlir::Value iv, mlir::OpFoldResult lb, mlir::OpFoldResult step); +llvm::SmallVector +normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder, + mlir::ValueRange ivs, + llvm::ArrayRef lbs, + llvm::ArrayRef steps); + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt index 4f7494893..306b43968 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Analysis) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt new file mode 100644 index 000000000..0661dc36d --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name RT) +add_public_tablegen_target(RTTransformsIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h new file mode 100644 index 000000000..371a7544c --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.h @@ -0,0 +1,25 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H +#define CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" + +#include "concretelang/Dialect/RT/IR/RTDialect.h" + +#define GEN_PASS_CLASSES +#include "concretelang/Dialect/RT/Transforms/Passes.h.inc" + +namespace mlir { +namespace concretelang { +std::unique_ptr> createHoistAwaitFuturePass(); +} // namespace concretelang +} // namespace mlir + +#endif // CONCRETELANG_DIALECT_RT_TRANSFORMS_PASSES_H diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td new file mode 100644 index 000000000..f638400d5 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/RT/Transforms/Passes.td @@ -0,0 +1,82 @@ +#ifndef MLIR_DIALECT_RT_TRANSFORMS_PASSES +#define MLIR_DIALECT_RT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def HoistAwaitFuturePass : Pass<"hoist-await-future", "mlir::func::FuncOp"> { + let summary = "Hoists `RT.await_future` operations whose results are yielded " + "by `scf.forall` operations out of the loops"; + let description = [{ + Hoists `RT.await_future` operations whose results are yielded by + scf.forall operations out of the loops in order to avoid + over-synchronization of data-flow tasks. + + E.g., the following IR: + + ``` + scf.forall (%arg) in (16) + shared_outs(%o1 = %sometensor, %o2 = %someothertensor) + -> (tensor<...>, tensor<...>) + { + ... + %rph = "RT.build_return_ptr_placeholder"() : + () -> !RT.rtptr>> + "RT.create_async_task"(..., %rph, ...) { ... } : ... + %future = "RT.deref_return_ptr_placeholder"(%rph) : + (!RT.rtptr>) -> !RT.future> + %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> + ... + scf.forall.in_parallel { + ... + tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] : + tensor<...> into tensor<...> + ... + } + } + ``` + + is transformed into: + + ``` + %tensoroffutures = tensor.empty() : tensor<16x!RT.future>> + + scf.forall (%arg) in (16) + shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor) + -> (tensor<...>, tensor<...>) + { + ... + %rph = "RT.build_return_ptr_placeholder"() : + () -> !RT.rtptr>> + "RT.create_async_task"(..., %rph, ...) { ... } : ... + %future = "RT.deref_return_ptr_placeholder"(%rph) : + (!RT.rtptr>) -> !RT.future> + %wrappedfuture = tensor.from_elements %future : + tensor<1x!RT.future>> + ... + scf.forall.in_parallel { + ... + tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] : + tensor<1xRT.future>> into tensor<16x!RT.future>> + ... + } + } + + scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) { + %future = tensor.extract %tensoroffutures[%arg] : + tensor<4x!RT.future>> + %res = "RT.await_future"(%future) : (!RT.future>) -> tensor<...> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] : + tensor<...> into tensor<...> + } + } + ``` + }]; + let constructor = "mlir::concretelang::createHoistAwaitFuturePass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", "mlir::concretelang::RT::RTDialect" + ]; +} + +#endif // MLIR_DIALECT_RT_TRANSFORMS_PASSES diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp index 8383d9bb5..2515dbf08 100644 --- a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp +++ b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp @@ -472,5 +472,19 @@ mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, return normalizedIV; } +llvm::SmallVector +normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder, + mlir::ValueRange ivs, + llvm::ArrayRef lbs, + llvm::ArrayRef steps) { + llvm::SmallVector normalizedIVs; + + for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) { + normalizedIVs.push_back(normalizeInductionVar(builder, iv, lb, step)); + } + + return normalizedIVs; +} + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 4f9942f3f..74d24dd03 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -397,15 +397,13 @@ void TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::Tracing::TraceCiphertextOp>, mlir::concretelang::GenericTypeConverterPattern, - mlir::concretelang::GenericTypeConverterPattern>( - &getContext(), converter); + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern< + mlir::tensor::ParallelInsertSliceOp>>(&getContext(), converter); mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); - mlir::concretelang::GenericTypeConverterPattern< - mlir::tensor::ParallelInsertSliceOp>(&getContext(), converter); - // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp index f8288d7c5..f6c59b3c1 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp @@ -377,6 +377,11 @@ void TFHEKeyNormalizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, typeConverter); + patterns.add>(&getContext(), typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::tensor::ParallelInsertSliceOp>(target, typeConverter); + patterns.add>( &getContext(), typeConverter); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt index 993e05688..5168d47ea 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/CMakeLists.txt @@ -1,9 +1,11 @@ add_mlir_dialect_library( RTDialectTransforms BufferizableOpInterfaceImpl.cpp + HoistAwaitFuturePass.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/RT DEPENDS + RTTransformsIncGen mlir-headers LINK_LIBS PUBLIC diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp new file mode 100644 index 000000000..f5e978405 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp @@ -0,0 +1,259 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include +#include + +#include + +#include +#include + +namespace { +struct HoistAwaitFuturePass + : public HoistAwaitFuturePassBase { + // Checks if all values of `a` are sizes of a non-dynamic dimensions + bool allStatic(llvm::ArrayRef a) { + return llvm::all_of( + a, [](int64_t r) { return !mlir::ShapedType::isDynamic(r); }); + } + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + + llvm::SmallVector opsToErase; + + func.walk([&](mlir::concretelang::RT::AwaitFutureOp awaitFutureOp) { + // Make sure there are no other consumers that rely on the + // synchronization + if (!awaitFutureOp.getResult().hasOneUse()) + return; + + mlir::scf::ForallOp forallOp = + llvm::dyn_cast(awaitFutureOp->getParentOp()); + + if (!forallOp) + return; + + mlir::tensor::ParallelInsertSliceOp parallelInsertSliceOp = + llvm::dyn_cast( + awaitFutureOp.getResult().getUses().begin()->getOwner()); + + if (!parallelInsertSliceOp) + return; + + // Make sure that the original tensor into which the + // synchronized values are inserted is a region out argument of + // the forall op and thus being written to concurrently + mlir::Value dst = parallelInsertSliceOp.getDest(); + + if (!llvm::any_of(forallOp.getRegionOutArgs(), + [=](mlir::Value output) { return output == dst; })) + return; + + // Currently, the tensor storing the futures must have a static + // shape, so only loops with static trip counts are supported + if (!(allStatic(forallOp.getStaticLowerBound()) && + allStatic(forallOp.getStaticUpperBound()) && + allStatic(forallOp.getStaticStep()))) + return; + + llvm::SmallVector tripCounts; + + for (auto [lb, ub, step] : llvm::zip_equal(forallOp.getStaticLowerBound(), + forallOp.getStaticUpperBound(), + forallOp.getStaticStep())) { + tripCounts.push_back( + mlir::concretelang::getStaticTripCount(lb, ub, step)); + } + + mlir::IRRewriter rewriter(&getContext()); + rewriter.setInsertionPoint(forallOp); + + mlir::Value tensorOfFutures = rewriter.create( + forallOp.getLoc(), tripCounts, awaitFutureOp.getInput().getType()); + + // Assemble the list of shared outputs that are to be preserved + // after the output storing the results of the `RT.await_future` + // has been removed + llvm::SmallVector newOutputs; + mlir::Value tensorOfValues; + size_t i = 0; + size_t oldResultIdx; + for (auto [output, regionOutArg] : llvm::zip_equal( + forallOp.getOutputs(), forallOp.getRegionOutArgs())) { + if (regionOutArg != dst) { + newOutputs.push_back(output); + } else { + tensorOfValues = output; + oldResultIdx = i; + } + + i++; + } + + newOutputs.push_back(tensorOfFutures); + + // Create a new forall loop with the same shared outputs except + // for the one previously storing the contents of the + // `RT.await_future` ops is replaced with a tensor of futures + rewriter.setInsertionPointAfter(forallOp); + mlir::scf::ForallOp newForallOp = rewriter.create( + forallOp.getLoc(), forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOutputs, + std::nullopt); + + // Move all operations from the old forall op to the new one + auto &newOperations = newForallOp.getBody()->getOperations(); + mlir::Block *oldBody = forallOp.getBody(); + + newOperations.splice(newOperations.begin(), oldBody->getOperations(), + oldBody->begin(), std::prev(oldBody->end())); + + // Wrap future in a tensor of one element, so that it can be + // stored in the new shared output tensor of futures using + // `tensor.parallel_insert_slice` + rewriter.setInsertionPointAfter(awaitFutureOp); + mlir::Value futureAsTensor = + rewriter.create( + awaitFutureOp.getLoc(), + mlir::ValueRange{awaitFutureOp.getInput()}); + + // Move all operations from the old `scf.forall.in_parallel` + // terminator to the new one + mlir::scf::InParallelOp oldTerminator = forallOp.getTerminator(); + mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator(); + + mlir::Block::OpListType &oldTerminatorOps = + oldTerminator.getRegion().getBlocks().begin()->getOperations(); + mlir::Block::OpListType &newTerminatorOps = + newTerminator.getRegion().getBlocks().begin()->getOperations(); + + newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps, + oldTerminatorOps.begin(), oldTerminatorOps.end()); + + // Remap IVs and out args + for (auto [oldIV, newIV] : llvm::zip(forallOp.getInductionVars(), + newForallOp.getInductionVars())) { + oldIV.replaceAllUsesWith(newIV); + } + + { + size_t offs = 0; + for (auto it : llvm::enumerate(forallOp.getRegionOutArgs())) { + mlir::Value oldRegionOutArg = it.value(); + + if (oldRegionOutArg != dst) { + oldRegionOutArg.replaceAllUsesWith( + newForallOp.getRegionOutArgs()[it.index() - offs]); + } else { + offs++; + } + } + } + + // Create new `tensor.parallel_inset_slice` operation inserting + // the future into the tensor of futures + llvm::SmallVector ones(tripCounts.size(), + rewriter.getI64IntegerAttr(1)); + + mlir::Value tensorOfFuturesRegionOutArg = + newForallOp.getRegionOutArgs().back(); + + mlir::ImplicitLocOpBuilder ilob(parallelInsertSliceOp.getLoc(), rewriter); + + rewriter.setInsertionPointAfter(parallelInsertSliceOp); + rewriter.create( + parallelInsertSliceOp.getLoc(), futureAsTensor, + tensorOfFuturesRegionOutArg, + mlir::getAsOpFoldResult(mlir::concretelang::normalizeInductionVars( + ilob, newForallOp.getInductionVars(), + newForallOp.getMixedLowerBound(), newForallOp.getMixedStep())), + ones, ones); + + // Create a new forall loop, that invokes `RT.await_future` on + // all futures stored in the tensor of futures and writes the + // contents into the otiginal tensor with the results + rewriter.setInsertionPointAfter(newForallOp); + mlir::scf::ForallOp syncForallOp = rewriter.create( + forallOp.getLoc(), forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), + mlir::ValueRange{tensorOfValues}, std::nullopt); + + mlir::Value resultTensorOfFutures = newForallOp.getResults().back(); + + rewriter.setInsertionPointToStart(syncForallOp.getBody()); + mlir::Value extractedFuture = rewriter.create( + awaitFutureOp.getLoc(), resultTensorOfFutures, + syncForallOp.getInductionVars()); + mlir::concretelang::RT::AwaitFutureOp newAwaitFutureOp = + rewriter.create( + awaitFutureOp.getLoc(), awaitFutureOp.getResult().getType(), + extractedFuture); + + mlir::IRMapping syncMapping; + + for (auto [oldIV, newIV] : + llvm::zip_equal(newForallOp.getInductionVars(), + syncForallOp.getInductionVars())) { + syncMapping.map(oldIV, newIV); + } + + syncMapping.map(dst, syncForallOp.getOutputBlockArguments().back()); + syncMapping.map(parallelInsertSliceOp.getSource(), + newAwaitFutureOp.getResult()); + + mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator(); + rewriter.setInsertionPointToStart(syncTerminator.getBody()); + rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping); + + // Replace uses of the results of the original forall loop with: + // either the corresponding result from the new forall loop if + // this is a result unrelated to the futures or with the result + // of the forall loop synchronizing the futures + { + size_t offs = 0; + for (size_t i = 0; i < forallOp.getNumResults(); i++) { + if (i == oldResultIdx) { + forallOp.getResult(i).replaceAllUsesWith(syncForallOp.getResult(0)); + offs = 1; + } else { + forallOp.getResult(i).replaceAllUsesWith( + newForallOp.getResult(i - offs)); + } + } + } + + // Replace the use of the shared output with the results of the + // original forall loop with the tensor outside of the loop so + // that there are no more references to values that were local + // to the original forall loop, enabling safe erasing of the old + // operations within the original forall loop + dst.replaceAllUsesWith( + forallOp.getOutputs().drop_front(oldResultIdx).front()); + parallelInsertSliceOp->erase(); + awaitFutureOp.erase(); + + // Defer erasing the original parallel loop that contained the + // `RT.await_future` operation until later in order to not + // confuse the walk relying on the parent operation + opsToErase.push_back(forallOp); + }); + + for (mlir::Operation *op : opsToErase) + op->erase(); + } +}; +} // namespace + +namespace mlir { +namespace concretelang { +std::unique_ptr> createHoistAwaitFuturePass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index c1f6187b1..210d2ea03 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -44,6 +44,7 @@ #include "concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h" #include "concretelang/Dialect/FHELinalg/Transforms/Tiling.h" #include "concretelang/Dialect/RT/Analysis/Autopar.h" +#include "concretelang/Dialect/RT/Transforms/Passes.h" #include "concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h" #include "concretelang/Dialect/TFHE/Transforms/Transforms.h" #include "concretelang/Support/CompilerEngine.h" @@ -183,6 +184,8 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, pm, mlir::concretelang::createBuildDataflowTaskGraphPass(), enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass); + addPotentiallyNestedPass(pm, mlir::concretelang::createHoistAwaitFuturePass(), + enablePass); return pm.run(module.getOperation()); }