Files
concrete/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Optimizer.cpp
Andi Drebes 9b6878316f fix(compiler): Preserve explicit optimizer partition boundaries through the pipeline
The Concrete Optimizer is invoked on a representation of the program
in the high-level FHELinalg / FHE dialects and yields a solution with
a one-to-one mapping of operations to keys. However, the abstractions
used by these dialects do not allow for references to keys and the
application of the solution is delayed until the pipeline reaches a
representation of the program in the lower-level TFHE dialect. Various
transformations applied by the pipeline along the way may break the
one-to-one mapping and add indirections into producer-consumer
relationships, resulting in ambiguous or partial mappings of TFHE
operations to the keys. In particular, explicit frontiers between
optimizer partitions might not be recoverable.

This commit preserves explicit frontiers between optimizer partitions
as `optimizer.partition_frontier` operations and lowers these to
keyswitch operations before parametrization of TFHE operations.
2024-03-07 15:42:26 +01:00

144 lines
4.8 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h>
#include <concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Debug.h>
#include <optional>
namespace mlir {
namespace concretelang {
/// Materializes explicit frontiers between Concrete Optimizer partitions.
///
/// For every operation carrying a `TFHE.OId` attribute whose instruction in
/// the optimizer solution has an extra conversion key, an
/// `Optimizer::PartitionFrontierOp` converting from the producer's output key
/// to the conversion key's output key is inserted right after the producer.
/// Every use of the producer's results whose consumer does not explicitly
/// expect the producer's output key is redirected to that frontier op.
///
/// NOTE(review): the pass only keeps a reference to `solverSolution`; the
/// caller must keep the solution alive for as long as the pass may run.
struct OptimizerPartitionFrontierMaterializationPass
    : public OptimizerPartitionFrontierMaterializationPassBase<
          OptimizerPartitionFrontierMaterializationPass> {
  OptimizerPartitionFrontierMaterializationPass(
      const optimizer::CircuitSolution &solverSolution)
      : solverSolution(solverSolution) {}

  /// Selects which OId is extracted from a `TFHE.OId` array attribute.
  enum class OperationKind { PRODUCER, CONSUMER };

  /// Returns the optimizer operation ID (OId) attached to `op` through its
  /// `TFHE.OId` attribute, or `std::nullopt` if there is no such attribute.
  ///
  /// A scalar integer attribute is returned as is. For an array attribute,
  /// consumers use the first element, while producers use the element that
  /// holds the OId of the result (index 5 for multiplications, the last
  /// element otherwise).
  static std::optional<uint64_t> getOid(mlir::Operation *op,
                                        OperationKind kind) {
    if (mlir::IntegerAttr oidAttr =
            op->getAttrOfType<mlir::IntegerAttr>("TFHE.OId"))
      return oidAttr.getInt();

    mlir::DenseI32ArrayAttr oidArrayAttr =
        op->getAttrOfType<mlir::DenseI32ArrayAttr>("TFHE.OId");
    if (!oidArrayAttr)
      return std::nullopt;

    assert(oidArrayAttr.size() > 0);
    if (kind == OperationKind::CONSUMER)
      return oidArrayAttr[0];

    // All operations with a `TFHE.OId` array attribute store the
    // OId of the result at the last position, except
    // multiplications, which use the 6th element (at index 5),
    // see `mlir::concretelang::optimizer::FunctionToDag::addMul`.
    if (llvm::isa<mlir::concretelang::FHE::MulEintOp,
                  mlir::concretelang::FHELinalg::MulEintOp>(op)) {
      assert(oidArrayAttr.size() >= 6);
      return oidArrayAttr[5];
    }
    return oidArrayAttr[oidArrayAttr.size() - 1];
  }

  /// Returns whether `consumer` must read the value converted to
  /// `conversionOutKeyID` rather than the producer's original value, which is
  /// encrypted under `producerOutKeyID`.
  bool needsConvertedValue(mlir::Operation *consumer, uint64_t producerOutKeyID,
                           uint64_t conversionOutKeyID) const {
    std::optional<uint64_t> consumerOid =
        getOid(consumer, OperationKind::CONSUMER);

    // By default, all consumers need the converted value,
    // unless it is explicitly specified that the original value
    // is needed
    if (!consumerOid.has_value())
      return true;

    assert(*consumerOid < solverSolution.instructions_keys.size());
    uint64_t consumerInKeyID =
        solverSolution.instructions_keys[*consumerOid].input_key;
    if (consumerInKeyID == producerOutKeyID)
      return false;

    assert(consumerInKeyID == conversionOutKeyID &&
           "Consumer needs converted value, but with a key that is "
           "not the extra conversion key of the producer");
    (void)conversionOutKeyID; // Only referenced by the assertion above.
    return true;
  }

  /// Inserts a partition frontier after each producer whose results must be
  /// converted to the producer's extra conversion key for some consumers.
  void runOnOperation() final {
    mlir::func::FuncOp func = this->getOperation();
    mlir::IRRewriter rewriter(func.getContext());

    func.walk([&](mlir::Operation *producer) {
      std::optional<uint64_t> producerOid =
          getOid(producer, OperationKind::PRODUCER);
      if (!producerOid.has_value())
        return;

      assert(*producerOid < solverSolution.instructions_keys.size());
      const auto &producerKeys = solverSolution.instructions_keys[*producerOid];
      const auto &eck = producerKeys.extra_conversion_keys;
      if (eck.size() == 0)
        return;

      assert(eck.size() == 1);
      assert(eck[0] <
             solverSolution.circuit_keys.conversion_keyswitch_keys.size());
      uint64_t producerOutKeyID = producerKeys.output_key;
      uint64_t conversionOutKeyID =
          solverSolution.circuit_keys.conversion_keyswitch_keys[eck[0]]
              .output_key.identifier;

      // Set once per producer so that frontier ops for successive results
      // are emitted in result order right after the producer.
      rewriter.setInsertionPointAfter(producer);

      for (mlir::Value res : producer->getResults()) {
        // Collect the uses to redirect *before* creating the frontier op:
        // the frontier op itself uses `res`, and it must never be visited
        // and rewired to consume its own result, regardless of how MLIR
        // orders the use-list.
        llvm::SmallVector<mlir::OpOperand *> usesToConvert;
        for (mlir::OpOperand &use : res.getUses()) {
          if (needsConvertedValue(use.getOwner(), producerOutKeyID,
                                  conversionOutKeyID))
            usesToConvert.push_back(&use);
        }

        if (usesToConvert.empty())
          continue;

        mlir::Value resConverted =
            rewriter.create<Optimizer::PartitionFrontierOp>(
                producer->getLoc(), res.getType(), res, producerOutKeyID,
                conversionOutKeyID);

        for (mlir::OpOperand *use : usesToConvert)
          use->set(resConverted);
      }
    });
  }

protected:
  /// Solution computed by the Concrete Optimizer for the function being
  /// processed (not owned by the pass).
  const optimizer::CircuitSolution &solverSolution;
};
/// Builds the pass that materializes explicit optimizer partition frontiers
/// as `optimizer.partition_frontier` operations.
///
/// The returned pass only references `solverSolution`; the caller must keep
/// the solution alive for as long as the pass may run.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createOptimizerPartitionFrontierMaterializationPass(
    const optimizer::CircuitSolution &solverSolution) {
  using FrontierPass = OptimizerPartitionFrontierMaterializationPass;
  std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> pass =
      std::make_unique<FrontierPass>(solverSolution);
  return pass;
}
} // namespace concretelang
} // namespace mlir