feat(compiler): Add support for full unrolling of loops with SDFG-convertible ops

This adds a new option `--unroll-loops-with-sdfg-convertible-ops`,
which causes loops containing SDFG-convertible operations to be fully
unrolled when SDFG operations are extracted via the
`--emit-sdfg-ops` switch. This avoids constant round trips between an
SDFG-capable accelerator and the host during execution of a loop.

The option is limited to `scf.for` loops with static bounds and a
static step size. Since full unrolling of loops with large bounds
results in a large number of operations, the option is disabled by
default.
This commit is contained in:
Andi Drebes
2022-12-07 10:26:29 +01:00
committed by Andi Drebes
parent 2dce654406
commit e2e6df322e
9 changed files with 160 additions and 17 deletions

View File

@@ -11,7 +11,8 @@
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass();
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createExtractSDFGOpsPass(bool unroll);
} // namespace concretelang
} // namespace mlir

View File

@@ -60,6 +60,7 @@ struct CompilationOptions {
bool loopParallelize;
bool batchConcreteOps;
bool emitSDFGOps;
bool unrollLoopsWithSDFGConvertibleOps;
bool dataflowParallelize;
bool optimizeConcrete;
/// use GPU during execution by generating GPU operations if possible
@@ -73,8 +74,9 @@ struct CompilationOptions {
CompilationOptions()
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
emitSDFGOps(false), dataflowParallelize(false), optimizeConcrete(true),
emitGPUOps(false), clientParametersFuncName(llvm::None),
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
clientParametersFuncName(llvm::None),
optimizerConfig(optimizer::DEFAULT_CONFIG){};
CompilationOptions(std::string funcname) : CompilationOptions() {

View File

@@ -58,9 +58,10 @@ mlir::LogicalResult
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool unrollLoops);
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,

View File

@@ -12,6 +12,7 @@ add_mlir_dialect_library(
ConcretelangSDFGInterfaces
PUBLIC
MLIRIR
MLIRSCFUtils
MLIRTransforms)
target_link_libraries(ExtractSDFGOps PUBLIC MLIRIR)

View File

@@ -9,6 +9,7 @@
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
@@ -18,6 +19,8 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
namespace SDFG = mlir::concretelang::SDFG;
@@ -34,6 +37,51 @@ SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder,
return builder.create<SDFG::MakeStream>(streamType, dfg, name, kind);
}
/// Fully unrolls all `scf.for` loops that (transitively) contain an
/// SDFG-convertible operation and whose lower bound, upper bound and
/// step are static `arith.constant` index values.
///
/// Loops with non-constant bounds, negative bounds, a non-positive
/// step, or an empty iteration space are left untouched.
void unrollLoopsWithSDFGConvertibleOps(mlir::func::FuncOp func) {
  mlir::DenseSet<mlir::scf::ForOp> unrollCandidates;

  // Identify all scf.for loops enclosing an SDFG-convertible op, at any
  // nesting depth.
  func.walk([&](SDFG::SDFGConvertibleOpInterface convertible) {
    for (mlir::Operation *parent = convertible->getParentOp(); parent;
         parent = parent->getParentOp()) {
      if (mlir::scf::ForOp forOp = llvm::dyn_cast<mlir::scf::ForOp>(parent)) {
        unrollCandidates.insert(forOp);
      }
    }
  });

  // Process innermost loops first: fully unrolling an outer loop erases
  // and clones the loops nested inside it, which would leave dangling
  // scf.ForOp handles for candidates processed later.
  llvm::SmallVector<mlir::scf::ForOp> sortedCandidates(
      unrollCandidates.begin(), unrollCandidates.end());
  auto nestingDepth = [](mlir::Operation *op) {
    unsigned depth = 0;
    for (mlir::Operation *parent = op->getParentOp(); parent;
         parent = parent->getParentOp())
      depth++;
    return depth;
  };
  llvm::sort(sortedCandidates, [&](mlir::scf::ForOp a, mlir::scf::ForOp b) {
    return nestingDepth(a) > nestingDepth(b);
  });

  // Fully unroll each loop whose bounds and step are static
  for (mlir::scf::ForOp forOp : sortedCandidates) {
    mlir::arith::ConstantIndexOp lb =
        forOp.getLowerBound().getDefiningOp<mlir::arith::ConstantIndexOp>();
    mlir::arith::ConstantIndexOp ub =
        forOp.getUpperBound().getDefiningOp<mlir::arith::ConstantIndexOp>();
    mlir::arith::ConstantIndexOp step =
        forOp.getStep().getDefiningOp<mlir::arith::ConstantIndexOp>();

    // Non-constant bounds cannot be evaluated statically; skip the loop
    if (!lb || !ub || !step)
      continue;

    int64_t ilb = lb.value();
    int64_t iub = ub.value();
    int64_t istep = step.value();

    // Unrolling requires non-negative bounds and a positive step
    if (ilb < 0 || iub < 0 || istep <= 0)
      continue;

    // Trip count: ceil((ub - lb) / step)
    int64_t unrollFactor = ((iub - ilb) + (istep - 1)) / istep;

    // A zero or *negative* trip count (iub <= ilb) must not reach
    // loopUnrollByFactor: the cast to uint64_t below would turn a
    // negative count into a gigantic unroll factor.
    if (unrollFactor <= 0)
      continue;

    // Best effort: if unrolling fails, leave the loop as-is
    if (mlir::loopUnrollByFactor(forOp, (uint64_t)unrollFactor).failed())
      continue;
  }
}
StreamMappingKind determineStreamMappingKind(mlir::Value v) {
// Determine stream type for operands:
//
@@ -90,11 +138,16 @@ void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder,
}
struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
bool unroll;
ExtractSDFGOpsPass() {}
ExtractSDFGOpsPass(bool unroll) : unroll(unroll) {}
void runOnOperation() override {
mlir::func::FuncOp func = getOperation();
if (unroll)
unrollLoopsWithSDFGConvertibleOps(func);
mlir::IRRewriter rewriter(func.getContext());
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processOutMapping;
@@ -205,8 +258,9 @@ struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass() {
return std::make_unique<ExtractSDFGOpsPass>();
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createExtractSDFGOpsPass(bool unroll) {
return std::make_unique<ExtractSDFGOpsPass>(unroll);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -399,8 +399,9 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// Extract SDFG data flow graph from BConcrete representation
if (options.emitSDFGOps) {
if (mlir::concretelang::pipeline::extractSDFGOps(mlirContext, module,
enablePass)
if (mlir::concretelang::pipeline::extractSDFGOps(
mlirContext, module, enablePass,
options.unrollLoopsWithSDFGConvertibleOps)
.failed()) {
return errorDiag(
"Extraction of SDFG operations from BConcrete representation failed");

View File

@@ -274,14 +274,17 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool unroll) {
mlir::PassManager pm(&context);
pipelinePrinting("extract SDFG ops from BConcrete", pm, context);
addPotentiallyNestedPass(pm, mlir::concretelang::createExtractSDFGOpsPass(),
enablePass);
return pm.run(module.getOperation());
addPotentiallyNestedPass(
pm, mlir::concretelang::createExtractSDFGOpsPass(unroll), enablePass);
LogicalResult res = pm.run(module.getOperation());
return res;
}
mlir::LogicalResult

View File

@@ -183,6 +183,12 @@ llvm::cl::opt<bool> emitSDFGOps(
" graphs and emit them."),
llvm::cl::init(false));
llvm::cl::opt<bool> unrollLoopsWithSDFGConvertibleOps(
"unroll-loops-with-sdfg-convertible-ops",
llvm::cl::desc("Causes loops containing SDFG-convertible operations to be "
"fully unrolled."),
llvm::cl::init(false));
llvm::cl::opt<bool> dataflowParallelize(
"parallelize-dataflow",
llvm::cl::desc(
@@ -316,6 +322,8 @@ cmdlineCompilationOptions() {
options.dataflowParallelize = cmdline::dataflowParallelize;
options.batchConcreteOps = cmdline::batchConcreteOps;
options.emitSDFGOps = cmdline::emitSDFGOps;
options.unrollLoopsWithSDFGConvertibleOps =
cmdline::unrollLoopsWithSDFGConvertibleOps;
options.optimizeConcrete = cmdline::optimizeConcrete;
options.emitGPUOps = cmdline::emitGPUOps;

View File

@@ -0,0 +1,72 @@
// RUN: concretecompiler --action=dump-sdfg --emit-sdfg-ops --unroll-loops-with-sdfg-convertible-ops --split-input-file %s 2>&1| FileCheck %s
// CHECK: func.func @main(%[[Varg0:.*]]: tensor<4x513xi64>, %[[Varg1:.*]]: tensor<4x513xi64>) -> tensor<4x513xi64> {
// CHECK-NEXT: %[[V0:.*]] = "SDFG.init"() : () -> !SDFG.dfg
// CHECK-NEXT: %[[V1:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream0", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V2:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream1", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V3:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream2", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V2]], %[[V3]], %[[V1]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
// CHECK-NEXT: %[[V4:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream3", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V5:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream4", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V6:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream5", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V5]], %[[V6]], %[[V4]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
// CHECK-NEXT: %[[V7:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream6", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V8:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream7", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V9:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream8", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V8]], %[[V9]], %[[V7]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
// CHECK-NEXT: %[[V10:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream9", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V11:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream10", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: %[[V12:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream11", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V11]], %[[V12]], %[[V10]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
// CHECK-NEXT: "SDFG.start"(%[[V0]]) : (!SDFG.dfg) -> ()
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[V13:.*]] = bufferization.alloc_tensor() : tensor<4x513xi64>
// CHECK-NEXT: %[[V14:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Vc0]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V15:.*]] = tensor.collapse_shape %[[V14]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V2]], %[[V15]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V16:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[Vc0]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V17:.*]] = tensor.collapse_shape %[[V16]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V3]], %[[V17]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V18:.*]] = "SDFG.get"(%[[V1]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
// CHECK-NEXT: %[[V19:.*]] = tensor.insert_slice %[[V18]] into %[[V13]]{{\[}}%[[Vc0]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
// CHECK-NEXT: %[[Vc1_0:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[V20:.*]] = arith.muli %[[Vc1]], %[[Vc1_0]] : index
// CHECK-NEXT: %[[V21:.*]] = arith.addi %[[Vc0]], %[[V20]] : index
// CHECK-NEXT: %[[V22:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V21]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V23:.*]] = tensor.collapse_shape %[[V22]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V5]], %[[V23]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V24:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[V21]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V25:.*]] = tensor.collapse_shape %[[V24]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V6]], %[[V25]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V26:.*]] = "SDFG.get"(%[[V4]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
// CHECK-NEXT: %[[V27:.*]] = tensor.insert_slice %[[V26]] into %[[V19]]{{\[}}%[[V21]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index
// CHECK-NEXT: %[[V28:.*]] = arith.muli %[[Vc1]], %[[Vc2]] : index
// CHECK-NEXT: %[[V29:.*]] = arith.addi %[[Vc0]], %[[V28]] : index
// CHECK-NEXT: %[[V30:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V29]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V31:.*]] = tensor.collapse_shape %[[V30]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V8]], %[[V31]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V32:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[V29]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V33:.*]] = tensor.collapse_shape %[[V32]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V9]], %[[V33]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V34:.*]] = "SDFG.get"(%[[V7]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
// CHECK-NEXT: %[[V35:.*]] = tensor.insert_slice %[[V34]] into %[[V27]]{{\[}}%[[V29]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
// CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
// CHECK-NEXT: %[[V36:.*]] = arith.muli %[[Vc1]], %[[Vc3]] : index
// CHECK-NEXT: %[[V37:.*]] = arith.addi %[[Vc0]], %[[V36]] : index
// CHECK-NEXT: %[[V38:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V37]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V39:.*]] = tensor.collapse_shape %[[V38]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V11]], %[[V39]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V40:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[V37]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
// CHECK-NEXT: %[[V41:.*]] = tensor.collapse_shape %[[V40]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
// CHECK-NEXT: "SDFG.put"(%[[V12]], %[[V41]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
// CHECK-NEXT: %[[V42:.*]] = "SDFG.get"(%[[V10]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
// CHECK-NEXT: %[[V43:.*]] = tensor.insert_slice %[[V42]] into %[[V35]]{{\[}}%[[V37]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
// CHECK-NEXT: "SDFG.shutdown"(%[[V0]]) : (!SDFG.dfg) -> ()
// CHECK-NEXT: return %[[V43]] : tensor<4x513xi64>
// CHECK-NEXT: }
// Input: a single element-wise FHELinalg.add_eint over tensors of four
// encrypted integers. Presumably this is lowered to an scf.for loop over
// the four rows before SDFG extraction — TODO confirm against the
// lowering pipeline; the CHECK lines above expect four fully unrolled
// put/get iterations, one per row, with no loop remaining.
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}