// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include <chrono>
#include <cmath>
#include <initializer_list>
#include <vector>

#include "boost/outcome.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include "concrete-optimizer.hpp"

#include "concretelang/Common/Error.h"
#include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h"
#include "concretelang/Dialect/FHE/Analysis/utils.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Support/V0Parameters.h"
#include "concretelang/Support/logging.h"

#define GEN_PASS_CLASSES
#include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h.inc"

namespace mlir {
namespace concretelang {
namespace optimizer {

namespace {

template <typename T> rust::Slice<const T> slice(const std::vector<T> &vec) {
  return rust::Slice<const T>(vec.data(), vec.size());
}

template <typename T> rust::Slice<const T> slice(const llvm::ArrayRef<T> &vec) {
  return rust::Slice<const T>(vec.data(), vec.size());
}
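
// Note on the two helpers above: rust::Slice is the borrowed-slice type of
// the cxx FFI bridge to the Rust concrete-optimizer (included through
// concrete-optimizer.hpp). A slice only borrows its buffer, so the vector or
// ArrayRef passed in must stay alive for the duration of the FFI call; the
// call sites below slice locals that outlive the dag->add_* calls.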

struct FunctionToDag {
  // Inputs of operators
  using Inputs = std::vector<concrete_optimizer::dag::OperatorIndex>;

  const double NEGLIGIBLE_COMPLEXITY = 0.0;

  mlir::func::FuncOp func;
  optimizer::Config config;
  llvm::DenseMap<mlir::Value, concrete_optimizer::dag::OperatorIndex> index;

  FunctionToDag(mlir::func::FuncOp func, optimizer::Config config)
      : func(func), config(config) {}

#define DEBUG(MSG)                                                            \
  if (mlir::concretelang::isVerbose()) {                                      \
    mlir::concretelang::log_verbose() << MSG << "\n";                         \
  }

  outcome::checked<llvm::Optional<optimizer::Dag>,
                   ::concretelang::error::StringError>
  build() {
    auto dag = concrete_optimizer::dag::empty();
    // Convert the function arguments to optimizer inputs
    for (auto &arg : func.getArguments()) {
      addArg(dag, arg);
    }
    // Convert the operations
    for (auto &bb : func.getBody().getBlocks()) {
      for (auto &op : bb.getOperations()) {
        addOperation(dag, op);
      }
    }
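    // The MANP analysis stored the squared Minimal Arithmetic Noise Padding
    // of each operation in the "SMANP" attribute. It has been consumed while
    // building the dag, so strip it from the IR here.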
    for (auto &bb : func.getBody().getBlocks()) {
      for (auto &op : bb.getOperations()) {
        op.removeAttr("SMANP");
      }
    }
    if (index.empty()) {
      // Dag is empty <=> classical function without encryption
      DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName()
                                                        << "\n");
      return llvm::None;
    }
    DEBUG(std::string(dag->dump()));
    return std::move(dag);
  }

  void addArg(optimizer::Dag &dag, mlir::Value &arg) {
    DEBUG("Arg " << arg << " " << arg.getType());
    if (!fhe::utils::isEncryptedValue(arg)) {
      return;
    }
    auto precision = fhe::utils::getEintPrecision(arg);
    auto shape = getShape(arg);
    auto opI = dag->add_input(precision, slice(shape));
    index[arg] = opI;
  }

  bool hasEncryptedResult(mlir::Operation &op) {
    for (auto val : op.getResults()) {
      if (fhe::utils::isEncryptedValue(val)) {
        return true;
      }
    }
    return false;
  }

  void addOperation(optimizer::Dag &dag, mlir::Operation &op) {
    DEBUG("Instr " << op);

    if (isReturn(op)) {
      // This op has no result
      return;
    }

    auto encrypted_inputs = encryptedInputs(op);
    if (!hasEncryptedResult(op)) {
      // This op is unrelated to FHE
      assert(encrypted_inputs.empty() ||
             mlir::isa<mlir::concretelang::Tracing::TraceCiphertextOp>(op));
      return;
    }
    assert(op.getNumResults() == 1);
    auto val = op.getResult(0);
    auto precision = fhe::utils::getEintPrecision(val);
    if (isLut(op)) {
      addLut(dag, val, encrypted_inputs, precision);
      return;
    }
    if (isRound(op)) {
      addRound(dag, val, encrypted_inputs, precision);
      return;
    }
    if (auto dot = asDot(op)) {
      auto weightsOpt = dotWeights(dot);
      if (weightsOpt) {
        addDot(dag, val, encrypted_inputs, weightsOpt.getValue());
        return;
      }
      // If the weights can't be resolved, fall back to a generic levelled op
      DEBUG("Replace Dot by LevelledOp on " << op);
    }
    if (auto max = asMax(op)) {
      addMax(dag, max, encrypted_inputs, precision);
      return;
    }
    if (auto maxpool2d = asMaxpool2d(op)) {
      addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision);
      return;
    }
    // default
    addLevelledOp(dag, op, encrypted_inputs);
  }

  void addLut(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs,
              int precision) {
    assert(encrypted_inputs.size() == 1);
    // No need to distinguish the different lut kinds until we implement the
    // approximate paradigm on outputs
    auto encrypted_input = encrypted_inputs[0];
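    // An empty placeholder table is passed to the optimizer: its cost model
    // presumably only needs the input/output precisions of the lookup, not
    // the table contents (hence the name `unknownFunction`).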
    std::vector<std::uint64_t> unknownFunction;
    index[val] =
        dag->add_lut(encrypted_input, slice(unknownFunction), precision);
  }

  void addRound(optimizer::Dag &dag, mlir::Value &val,
                Inputs &encrypted_inputs, int rounded_precision) {
    assert(encrypted_inputs.size() == 1);
    // `rounded_precision` is the precision of the output once the low-order
    // bits have been dropped
    auto encrypted_input = encrypted_inputs[0];
    index[val] = dag->add_round_op(encrypted_input, rounded_precision);
  }

  void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs,
              std::vector<std::int64_t> &weights_vector) {
    assert(encrypted_inputs.size() == 1);
    auto weights = concrete_optimizer::weights::vector(slice(weights_vector));
    index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights));
  }

  std::string loc_to_string(mlir::Location location) {
    std::string loc;
    llvm::raw_string_ostream loc_stream(loc);
    location.print(loc_stream);
    return loc;
  }

  void addLevelledOp(optimizer::Dag &dag, mlir::Operation &op, Inputs &inputs) {
    auto val = op.getResult(0);
    auto out_shape = getShape(val);
    if (inputs.empty()) {
      // Trivial encrypted constants encoding:
      // they are converted to input + levelled op
      auto precision = fhe::utils::getEintPrecision(val);
      auto opI = dag->add_input(precision, slice(out_shape));
      inputs.push_back(opI);
    }
    // Default complexity is negligible
    double fixed_cost = NEGLIGIBLE_COMPLEXITY;
    double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY;
    auto smanp_int = op.getAttrOfType<mlir::IntegerAttr>("SMANP");
    auto loc = loc_to_string(op.getLoc());
    assert(smanp_int && "Missing manp value on a crypto operation");
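    // The attribute holds the *squared* 2-norm, so the dag receives its
    // square root as the MANP estimate, e.g. SMANP = 25 yields manp = 5.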
    // TODO: use APIFloat.sqrt when it's available
    double manp = sqrt(smanp_int.getValue().roundToDouble());
    auto comment = std::string(op.getName().getStringRef()) + " " + loc;
    index[val] =
        dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost,
                             manp, slice(out_shape), comment);
  }

  void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
              int precision) {
    mlir::Value result = maxOp.getResult();
    const std::vector<uint64_t> resultShape = getShape(result);

    Operation *xOp = maxOp.x().getDefiningOp();
    Operation *yOp = maxOp.y().getDefiningOp();

    const double fixedCost = NEGLIGIBLE_COMPLEXITY;
    const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;

    llvm::APInt xSmanp = llvm::APInt{1, 1, false};
    if (xOp != nullptr) {
      const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
      assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
      xSmanp = xSmanpAttr.getValue();
    }

    llvm::APInt ySmanp = llvm::APInt{1, 1, false};
    if (yOp != nullptr) {
      const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
      assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
      ySmanp = ySmanpAttr.getValue();
    }
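
    // max(x, y) is modeled as max(x - y, 0) + y: a levelled subtraction
    // (whose squared 2-norm is the sum of the operands' SMANPs), a table
    // lookup for the max-with-zero, and a final levelled addition of y.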
    const double subManp =
        sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());

    auto loc = loc_to_string(maxOp.getLoc());
    auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc;

    auto subNode =
        dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
                             subManp, slice(resultShape), comment);

    const double tluNodeManp = 1;
    const std::vector<std::uint64_t> unknownFunction;
    auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision);

    const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble());
    const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
        tluNode, inputs[1]};
    index[result] =
        dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost,
                             addManp, slice(resultShape), comment);
  }

  void addMaxpool2d(optimizer::Dag &dag, FHELinalg::Maxpool2dOp &maxpool2dOp,
                    Inputs &inputs, int precision) {
    mlir::Value result = maxpool2dOp.getResult();
    const std::vector<uint64_t> resultShape = getShape(result);

    // All the TLUs are flattened into one extra dimension so that a single
    // TLU node can represent them in the optimizer dag
    std::vector<uint64_t> fakeShape = resultShape;

    uint64_t numberOfComparisons = 1;
    for (auto dimensionSize : maxpool2dOp.kernel_shape().getValues<int64_t>()) {
      numberOfComparisons *= dimensionSize;
    }
    fakeShape.push_back(numberOfComparisons);

    Operation *inputOp = maxpool2dOp.input().getDefiningOp();

    const double fixedCost = NEGLIGIBLE_COMPLEXITY;
    const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;

    llvm::APInt inputSmanp = llvm::APInt{1, 1, false};
    if (inputOp != nullptr) {
      const auto inputSmanpAttr =
          inputOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
      assert(inputSmanpAttr && "Missing SMANP value on a crypto operation");
      inputSmanp = inputSmanpAttr.getValue();
    }
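
    // As in addMax, each comparison is modeled as a subtraction, a table
    // lookup, and an addition; the fake trailing dimension added above
    // accounts for the kernel_shape comparisons feeding every output cell.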
    const double subManp = sqrt(2 * inputSmanp.roundToDouble() + 1);

    auto loc = loc_to_string(maxpool2dOp.getLoc());
    auto comment =
        std::string(maxpool2dOp->getName().getStringRef()) + " " + loc;

    auto subNode =
        dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
                             subManp, slice(fakeShape), comment);

    const std::vector<std::uint64_t> unknownFunction;
    auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision);

    const double addManp = sqrt(inputSmanp.roundToDouble() + 1);
    // maxpool2d has a single encrypted operand, so the addition takes the
    // input node rather than the out-of-range inputs[1]
    const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
        tluNode, inputs[0]};

    index[result] =
        dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost,
                             addManp, slice(resultShape), comment);
  }

  Inputs encryptedInputs(mlir::Operation &op) {
    Inputs inputs;
    for (auto operand : op.getOperands()) {
      auto entry = index.find(operand);
      if (entry == index.end()) {
        assert(!fhe::utils::isEncryptedValue(operand));
        DEBUG("Ignoring as input " << operand);
        continue;
      }
      inputs.push_back(entry->getSecond());
    }
    return inputs;
  }

  bool isLut(mlir::Operation &op) {
    return llvm::isa<
        mlir::concretelang::FHE::ApplyLookupTableEintOp,
        mlir::concretelang::FHELinalg::ApplyLookupTableEintOp,
        mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp,
        mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op);
  }

  bool isRound(mlir::Operation &op) {
    return llvm::isa<mlir::concretelang::FHE::RoundEintOp>(op);
  }

  mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) {
    return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
  }

  mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) {
    return llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op);
  }

  mlir::concretelang::FHELinalg::Maxpool2dOp asMaxpool2d(mlir::Operation &op) {
    return llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(op);
  }

  bool isReturn(mlir::Operation &op) {
    return llvm::isa<mlir::func::ReturnOp>(op);
  }

  bool isConst(mlir::Operation &op) {
    return llvm::isa<mlir::arith::ConstantOp>(op);
  }

  bool isArg(const mlir::Value &value) {
    return value.isa<mlir::BlockArgument>();
  }

  llvm::Optional<std::vector<std::int64_t>>
  resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
    std::vector<std::int64_t> values;
    mlir::DenseIntElementsAttr denseVals =
        cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
    if (!denseVals) {
      // Not a dense integer constant: let the caller fall back to MANP
      return llvm::None;
    }

    for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
      if (val.getActiveBits() > 64) {
        // The weight does not fit in a 64-bit integer
        return llvm::None;
      }
      values.push_back(val.getSExtValue());
    }
    return values;
  }

  llvm::Optional<std::vector<std::int64_t>>
  resolveConstantWeights(mlir::Value &value) {
    if (auto cstOp = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
            value.getDefiningOp())) {
      auto shape = getShape(value);
      switch (shape.size()) {
      case 1:
        return resolveConstantVectorWeights(cstOp);
      default:
        DEBUG("High-Rank tensor: rely on MANP and levelledOp");
        return llvm::None;
      }
    } else {
      DEBUG("Dynamic Weights: rely on MANP and levelledOp");
      return llvm::None;
    }
  }

  llvm::Optional<std::vector<std::int64_t>>
  dotWeights(mlir::concretelang::FHELinalg::Dot &dot) {
    if (dot.getOperands().size() != 2) {
      return llvm::None;
    }
    auto weights = dot.getOperands()[1];
    return resolveConstantWeights(weights);
  }

  std::vector<std::uint64_t> getShape(mlir::Value &value) {
    return getShape(value.getType());
  }

  std::vector<std::uint64_t> getShape(mlir::Type type_) {
    if (auto ranked_tensor = type_.dyn_cast_or_null<mlir::RankedTensorType>()) {
      std::vector<std::uint64_t> shape;
      for (auto v : ranked_tensor.getShape()) {
        shape.push_back(v);
      }
      return shape;
    } else {
      return {};
    }
  }
};

} // namespace

struct DagPass : ConcreteOptimizerBase<DagPass> {
  optimizer::Config config;
  optimizer::FunctionsDag &dags;

  void runOnOperation() override {
    mlir::func::FuncOp func = getOperation();
    auto name = std::string(func.getName());
    DEBUG("ConcreteOptimizer Dag: " << name);
    auto dag = FunctionToDag(func, config).build();
    if (dag) {
      dags.insert(
          optimizer::FunctionsDag::value_type(name, std::move(dag.value())));
    } else {
      this->signalPassFailure();
    }
  }

  DagPass() = delete;
  DagPass(optimizer::Config config, optimizer::FunctionsDag &dags)
      : config(config), dags(dags) {}
};

// Create an instance of the concrete-optimizer dag-building pass.
// The pass communicates its result globally through `dags`, which maps each
// function name to the (possibly absent) optimizer dag built for it.
std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
                                          optimizer::FunctionsDag &dags) {
  return std::make_unique<optimizer::DagPass>(config, dags);
}
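
// A minimal wiring sketch (hypothetical pipeline setup; the real driver
// lives elsewhere in the compiler): the pass runs nested on each
// func::FuncOp and fills `dags` as a side channel.
//
//   mlir::PassManager pm(&context);
//   mlir::concretelang::optimizer::FunctionsDag dags;
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::concretelang::optimizer::createDagPass(config, dags));
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }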

} // namespace optimizer
} // namespace concretelang
} // namespace mlir