mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
This adds a new option `--unroll-loops-with-sdfg-convertible-ops`, which causes loops containing SDFG-convertible operations to be fully unrolled upon the extraction of SDFG-operations using the `--emit-sdfg-ops` switch. This avoids constant roundtrips between an SDFG-capable accelerator and the host during execution of a loop. The option is limited to `scf.for` loops with static bounds and a static step size. Since full unrolling of loops with large bounds results in a large number of operations, the option is disabled by default.
267 lines
9.4 KiB
C++
267 lines
9.4 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include "concretelang/Conversion/Passes.h"
|
|
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
|
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
|
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
|
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/Visitors.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
|
#include <mlir/Dialect/SCF/IR/SCF.h>
|
|
|
|
namespace SDFG = mlir::concretelang::SDFG;
|
|
|
|
namespace {
|
|
enum class StreamMappingKind { ON_DEVICE, TO_DEVICE, SPLICE, TO_HOST, NONE };
|
|
|
|
SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder,
|
|
SDFG::StreamKind kind, mlir::Type type,
|
|
mlir::Value dfg, unsigned &streamNumber) {
|
|
SDFG::StreamType streamType = builder.getType<SDFG::StreamType>(type);
|
|
mlir::StringAttr name = builder.getStringAttr(llvm::Twine("stream") +
|
|
llvm::Twine(streamNumber++));
|
|
|
|
return builder.create<SDFG::MakeStream>(streamType, dfg, name, kind);
|
|
}
|
|
|
|
/// Unrolls entirely all scf loops, which directly contain an
|
|
/// SDFG-convertible operation and whose bounds are static.
|
|
void unrollLoopsWithSDFGConvertibleOps(mlir::func::FuncOp func) {
|
|
mlir::DenseSet<mlir::scf::ForOp> unrollCandidates;
|
|
|
|
// Identify loops with SDFG-convertible ops
|
|
func.walk([&](SDFG::SDFGConvertibleOpInterface convertible) {
|
|
for (mlir::Operation *parent = convertible->getParentOp(); parent;
|
|
parent = parent->getParentOp()) {
|
|
if (mlir::scf::ForOp forOp = llvm::dyn_cast<mlir::scf::ForOp>(parent)) {
|
|
unrollCandidates.insert(forOp);
|
|
}
|
|
}
|
|
});
|
|
|
|
// Fully unroll all the loops if its bounds are static
|
|
for (mlir::scf::ForOp forOp : unrollCandidates) {
|
|
mlir::arith::ConstantIndexOp lb =
|
|
forOp.getLowerBound().getDefiningOp<mlir::arith::ConstantIndexOp>();
|
|
mlir::arith::ConstantIndexOp ub =
|
|
forOp.getUpperBound().getDefiningOp<mlir::arith::ConstantIndexOp>();
|
|
mlir::arith::ConstantIndexOp step =
|
|
forOp.getStep().getDefiningOp<mlir::arith::ConstantIndexOp>();
|
|
|
|
if (!lb || !ub || !step)
|
|
continue;
|
|
|
|
int64_t ilb = lb.value();
|
|
int64_t iub = ub.value();
|
|
int64_t istep = step.value();
|
|
|
|
// Unrolling requires positive bounds and step
|
|
if (ilb < 0 || iub < 0 || istep <= 0)
|
|
continue;
|
|
|
|
int64_t unrollFactor = ((iub - ilb) + (istep - 1)) / istep;
|
|
|
|
if (unrollFactor == 0)
|
|
continue;
|
|
|
|
if (mlir::loopUnrollByFactor(forOp, (uint64_t)unrollFactor).failed())
|
|
continue;
|
|
}
|
|
}
|
|
|
|
StreamMappingKind determineStreamMappingKind(mlir::Value v) {
|
|
// Determine stream type for operands:
|
|
//
|
|
// - If an operand is produced by a non-convertible op, there
|
|
// needs to be just a host-to-device stream
|
|
//
|
|
// - If an operand is produced by a convertible op and there
|
|
// are no other consumers, a device-to-device stream may be
|
|
// used
|
|
//
|
|
// - If an operand is produced by a convertible op and there is
|
|
// at least one other non-convertible consumer, there needs
|
|
// to be a device-to-host stream, and a host-to-device stream
|
|
//
|
|
if (llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
|
v.getDefiningOp())) {
|
|
// All convertible consumers?
|
|
if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) {
|
|
return !!llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
|
o.getOwner());
|
|
})) {
|
|
return StreamMappingKind::ON_DEVICE;
|
|
}
|
|
// All non-convertible consumers?
|
|
else if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) {
|
|
return !llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
|
o.getOwner());
|
|
})) {
|
|
return StreamMappingKind::TO_HOST;
|
|
}
|
|
// Mix of convertible and non-convertible users
|
|
else {
|
|
return StreamMappingKind::SPLICE;
|
|
}
|
|
} else {
|
|
if (llvm::any_of(v.getUses(), [](mlir::OpOperand &o) {
|
|
return !!llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
|
o.getOwner());
|
|
})) {
|
|
return StreamMappingKind::TO_DEVICE;
|
|
} else {
|
|
return StreamMappingKind::NONE;
|
|
}
|
|
}
|
|
}
|
|
|
|
void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder,
|
|
mlir::Value v,
|
|
mlir::OpBuilder::InsertPoint &pos) {
|
|
if (v.isa<mlir::BlockArgument>())
|
|
builder.restoreInsertionPoint(pos);
|
|
else
|
|
builder.setInsertionPointAfterValue(v);
|
|
}
|
|
|
|
struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
|
|
bool unroll;
|
|
|
|
ExtractSDFGOpsPass(bool unroll) : unroll(unroll) {}
|
|
|
|
void runOnOperation() override {
|
|
mlir::func::FuncOp func = getOperation();
|
|
|
|
if (unroll)
|
|
unrollLoopsWithSDFGConvertibleOps(func);
|
|
|
|
mlir::IRRewriter rewriter(func.getContext());
|
|
|
|
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processOutMapping;
|
|
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processInMapping;
|
|
mlir::DenseMap<mlir::Value, mlir::Value> replacementMapping;
|
|
|
|
llvm::SmallVector<SDFG::SDFGConvertibleOpInterface> convertibleOps;
|
|
|
|
unsigned streamNumber = 0;
|
|
|
|
func.walk([&](SDFG::SDFGConvertibleOpInterface op) {
|
|
convertibleOps.push_back(op);
|
|
});
|
|
|
|
if (convertibleOps.size() == 0)
|
|
return;
|
|
|
|
// Insert Prelude
|
|
rewriter.setInsertionPointToStart(&func.getBlocks().front());
|
|
mlir::Value dfg = rewriter.create<SDFG::Init>(
|
|
func.getLoc(), rewriter.getType<SDFG::DFGType>());
|
|
SDFG::Start start = rewriter.create<SDFG::Start>(func.getLoc(), dfg);
|
|
|
|
rewriter.setInsertionPoint(func.getBlocks().front().getTerminator());
|
|
rewriter.create<SDFG::Shutdown>(func.getLoc(), dfg);
|
|
|
|
mlir::ImplicitLocOpBuilder ilb(func.getLoc(), rewriter);
|
|
|
|
auto mapValueToStreams = [&](mlir::Value v) {
|
|
if (processInMapping.find(v) != processInMapping.end() ||
|
|
processOutMapping.find(v) != processOutMapping.end())
|
|
return;
|
|
|
|
StreamMappingKind smk = determineStreamMappingKind(v);
|
|
|
|
SDFG::MakeStream prodOutStream;
|
|
SDFG::MakeStream consInStream;
|
|
|
|
if (smk == StreamMappingKind::SPLICE ||
|
|
smk == StreamMappingKind::TO_HOST) {
|
|
ilb.setInsertionPoint(start);
|
|
prodOutStream = makeStream(ilb, SDFG::StreamKind::device_to_host,
|
|
v.getType(), dfg, streamNumber);
|
|
processOutMapping.insert({v, prodOutStream});
|
|
|
|
ilb.setInsertionPointAfter(start);
|
|
mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint();
|
|
setInsertionPointAfterValueOrRestore(ilb, v, pos);
|
|
mlir::Value newOutVal =
|
|
ilb.create<SDFG::Get>(v.getType(), prodOutStream.getResult());
|
|
replacementMapping.insert({v, newOutVal});
|
|
} else if (smk == StreamMappingKind::ON_DEVICE) {
|
|
ilb.setInsertionPoint(start);
|
|
prodOutStream = makeStream(ilb, SDFG::StreamKind::on_device,
|
|
v.getType(), dfg, streamNumber);
|
|
processOutMapping.insert({v, prodOutStream});
|
|
}
|
|
|
|
if (smk == StreamMappingKind::ON_DEVICE) {
|
|
processInMapping.insert({v, prodOutStream});
|
|
} else if (smk == StreamMappingKind::SPLICE ||
|
|
smk == StreamMappingKind::TO_DEVICE ||
|
|
smk == StreamMappingKind::ON_DEVICE) {
|
|
ilb.setInsertionPoint(start);
|
|
consInStream = makeStream(ilb, SDFG::StreamKind::host_to_device,
|
|
v.getType(), dfg, streamNumber);
|
|
processInMapping.insert({v, consInStream});
|
|
|
|
if (smk == StreamMappingKind::TO_DEVICE) {
|
|
ilb.setInsertionPointAfter(start);
|
|
mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint();
|
|
setInsertionPointAfterValueOrRestore(ilb, v, pos);
|
|
ilb.create<SDFG::Put>(consInStream.getResult(), v);
|
|
}
|
|
}
|
|
};
|
|
|
|
for (SDFG::SDFGConvertibleOpInterface convertibleOp : convertibleOps) {
|
|
llvm::SmallVector<mlir::Value> ins;
|
|
llvm::SmallVector<mlir::Value> outs;
|
|
ilb.setLoc(convertibleOp.getLoc());
|
|
|
|
for (mlir::Value res : convertibleOp->getResults()) {
|
|
mapValueToStreams(res);
|
|
outs.push_back(processOutMapping.find(res)->second.getResult());
|
|
}
|
|
|
|
for (mlir::Value operand : convertibleOp->getOperands()) {
|
|
mapValueToStreams(operand);
|
|
ins.push_back(processInMapping.find(operand)->second.getResult());
|
|
}
|
|
|
|
ilb.setInsertionPoint(start);
|
|
SDFG::MakeProcess process = convertibleOp.convert(ilb, dfg, ins, outs);
|
|
|
|
assert(process && "Conversion to SDFG operation failed");
|
|
}
|
|
|
|
for (auto it : replacementMapping) {
|
|
it.first.replaceAllUsesWith(it.second);
|
|
}
|
|
|
|
(void)mlir::simplifyRegions(rewriter, func->getRegions());
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
|
|
createExtractSDFGOpsPass(bool unroll) {
|
|
return std::make_unique<ExtractSDFGOpsPass>(unroll);
|
|
}
|
|
} // namespace concretelang
|
|
} // namespace mlir
|