mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-19 00:44:22 -05:00
The Concrete Optimizer is invoked on a representation of the program in the high-level FHELinalg / FHE Dialects and yields a solution with a one-to-one mapping of operations to keys. However, the abstractions used by these dialects do not allow for references to keys and the application of the solution is delayed until the pipeline reaches a representation of the program in the lower-level TFHE dialect. Various transformations applied by the pipeline along the way may break the one-to-one mapping and add indirections into producer-consumer relationships, resulting in ambiguous or partial mappings of TFHE operations to the keys. In particular, explicit frontiers between optimizer partitions may not be recovered. This commit preserves explicit frontiers between optimizer partitions as `optimizer.partition_frontier` operations and lowers these to keyswitch operations before parametrization of TFHE operations.
144 lines
4.8 KiB
C++
144 lines
4.8 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
|
|
|
|
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
|
|
|
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
|
#include <concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h>
|
|
#include <concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h>
|
|
|
|
#include <llvm/Support/Debug.h>
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
struct OptimizerPartitionFrontierMaterializationPass
|
|
: public OptimizerPartitionFrontierMaterializationPassBase<
|
|
OptimizerPartitionFrontierMaterializationPass> {
|
|
|
|
OptimizerPartitionFrontierMaterializationPass(
|
|
const optimizer::CircuitSolution &solverSolution)
|
|
: solverSolution(solverSolution) {}
|
|
|
|
enum class OperationKind { PRODUCER, CONSUMER };
|
|
|
|
std::optional<uint64_t> getOid(mlir::Operation *op, OperationKind kind) {
|
|
if (mlir::IntegerAttr oidAttr =
|
|
op->getAttrOfType<mlir::IntegerAttr>("TFHE.OId")) {
|
|
return oidAttr.getInt();
|
|
} else if (mlir::DenseI32ArrayAttr oidArrayAttr =
|
|
op->getAttrOfType<mlir::DenseI32ArrayAttr>("TFHE.OId")) {
|
|
assert(oidArrayAttr.size() > 0);
|
|
|
|
if (kind == OperationKind::CONSUMER) {
|
|
return oidArrayAttr[0];
|
|
} else {
|
|
// All operations with a `TFHE.OId` array attribute store the
|
|
// OId of the result at the last position, except
|
|
// multiplications, which use the 6th element (at index 5),
|
|
// see `mlir::concretelang::optimizer::FunctionToDag::addMul`.
|
|
if (llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op) ||
|
|
llvm::dyn_cast<mlir::concretelang::FHELinalg::MulEintOp>(op)) {
|
|
assert(oidArrayAttr.size() >= 6);
|
|
return oidArrayAttr[5];
|
|
} else {
|
|
return oidArrayAttr[oidArrayAttr.size() - 1];
|
|
}
|
|
}
|
|
} else {
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
|
|
void runOnOperation() final {
|
|
mlir::func::FuncOp func = this->getOperation();
|
|
|
|
func.walk([&](mlir::Operation *producer) {
|
|
std::optional<uint64_t> producerOid =
|
|
getOid(producer, OperationKind::PRODUCER);
|
|
|
|
if (!producerOid.has_value())
|
|
return;
|
|
|
|
assert(*producerOid < solverSolution.instructions_keys.size());
|
|
|
|
auto &eck =
|
|
solverSolution.instructions_keys[*producerOid].extra_conversion_keys;
|
|
|
|
if (eck.size() == 0)
|
|
return;
|
|
|
|
assert(eck.size() == 1);
|
|
assert(eck[0] <
|
|
solverSolution.circuit_keys.conversion_keyswitch_keys.size());
|
|
|
|
uint64_t producerOutKeyID =
|
|
solverSolution.instructions_keys[*producerOid].output_key;
|
|
|
|
uint64_t conversionOutKeyID =
|
|
solverSolution.circuit_keys.conversion_keyswitch_keys[eck[0]]
|
|
.output_key.identifier;
|
|
|
|
mlir::IRRewriter rewriter(producer->getContext());
|
|
rewriter.setInsertionPointAfter(producer);
|
|
|
|
for (mlir::Value res : producer->getResults()) {
|
|
mlir::Value resConverted;
|
|
|
|
for (mlir::OpOperand &operand :
|
|
llvm::make_early_inc_range(res.getUses())) {
|
|
mlir::Operation *consumer = operand.getOwner();
|
|
|
|
std::optional<uint64_t> consumerOid =
|
|
getOid(consumer, OperationKind::CONSUMER);
|
|
|
|
// By default, all consumers need the converted value,
|
|
// unless it is explicitly specified that the original value
|
|
// is needed
|
|
bool needsConvertedValue = true;
|
|
|
|
if (consumerOid.has_value()) {
|
|
assert(*consumerOid < solverSolution.instructions_keys.size());
|
|
|
|
uint64_t consumerInKeyID =
|
|
solverSolution.instructions_keys[*consumerOid].input_key;
|
|
|
|
if (consumerInKeyID == producerOutKeyID) {
|
|
needsConvertedValue = false;
|
|
} else {
|
|
assert(consumerInKeyID == conversionOutKeyID &&
|
|
"Consumer needs converted value, but with a key that is "
|
|
"not the extra conversion key of the producer");
|
|
}
|
|
}
|
|
|
|
if (needsConvertedValue) {
|
|
if (!resConverted) {
|
|
resConverted = rewriter.create<Optimizer::PartitionFrontierOp>(
|
|
producer->getLoc(), res.getType(), res, producerOutKeyID,
|
|
conversionOutKeyID);
|
|
}
|
|
|
|
operand.set(resConverted);
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
protected:
|
|
const optimizer::CircuitSolution &solverSolution;
|
|
};
|
|
|
|
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
|
|
createOptimizerPartitionFrontierMaterializationPass(
|
|
const optimizer::CircuitSolution &solverSolution) {
|
|
return std::make_unique<OptimizerPartitionFrontierMaterializationPass>(
|
|
solverSolution);
|
|
}
|
|
|
|
} // namespace concretelang
|
|
} // namespace mlir
|