#include "concretelang/Support/CompilationFeedback.h"
#include <concretelang/Analysis/StaticLoops.h>
#include <concretelang/Analysis/Utils.h>
#include <concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h>
#include <concretelang/Dialect/TFHE/IR/TFHEOps.h>
#include <llvm/ADT/STLExtras.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Pass/Pass.h>

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

using namespace mlir::concretelang;
using namespace mlir;

namespace mlir {
namespace concretelang {
namespace TFHE {

// Calls the `on_enter` overload for `type` when `op` is a `type`, and returns
// early with the error it reports, if any. Meant to be expanded inside
// `ExtractTFHEStatisticsPass::enter`, where `op` and `*this` are in scope.
#define DISPATCH_ENTER(type)                                                   \
  if (auto typedOp = llvm::dyn_cast<type>(op)) {                               \
    std::optional<StringError> error = on_enter(typedOp, *this);               \
    if (error.has_value()) {                                                   \
      return error;                                                            \
    }                                                                          \
  }

// Same as DISPATCH_ENTER, for the `on_exit` overloads used by
// `ExtractTFHEStatisticsPass::exit`.
#define DISPATCH_EXIT(type)                                                    \
  if (auto typedOp = llvm::dyn_cast<type>(op)) {                               \
    std::optional<StringError> error = on_exit(typedOp, *this);                \
    if (error.has_value()) {                                                   \
      return error;                                                            \
    }                                                                          \
  }

/// For every circuit listed in a ProgramCompilationFeedback, walks the
/// function of the same name and appends one Statistic per primitive TFHE
/// operation to that circuit's feedback. Each statistic carries the op's
/// location, the keys it uses, and the trip count tracked (via
/// TripCountTracker) for the `scf.for` loops enclosing it.
struct ExtractTFHEStatisticsPass
    : public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>>,
      public TripCountTracker {
  /// Keys attached to one statistic, as (key kind, key index) pairs.
  using KeyList = std::vector<std::pair<KeyType, int64_t>>;

  /// Program-level feedback; referenced, not owned, so it must outlive the
  /// pass.
  ProgramCompilationFeedback &feedback;
  /// Circuit whose function is currently being walked; set by
  /// runOnOperation before each walk.
  CircuitCompilationFeedback *circuitFeedback;

  explicit ExtractTFHEStatisticsPass(ProgramCompilationFeedback &feedback)
      : feedback{feedback}, circuitFeedback{nullptr} {}

  /// For each circuit in `feedback`, walks the function with the same name and
  /// records its statistics. Fails the pass if a circuit has no matching
  /// function, or if a handler reports an error.
  void runOnOperation() override {
    auto module = getOperation();
    auto funcs = module.getOps<mlir::func::FuncOp>();
    for (CircuitCompilationFeedback &circuit : feedback.circuitFeedbacks) {
      auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp candidate) {
        return candidate.getName() == circuit.name;
      });
      // Checked at runtime rather than only via assert(): in NDEBUG builds,
      // dereferencing the end iterator below would be undefined behavior.
      if (funcOp == funcs.end()) {
        module.emitError() << "Cannot find function `" << circuit.name
                           << "` to extract statistics from";
        signalPassFailure();
        return;
      }
      circuitFeedback = &circuit;

      WalkResult walk =
          (*funcOp)->walk([&](Operation *op, const WalkStage &stage) {
            if (stage.isBeforeAllRegions()) {
              if (std::optional<StringError> error = this->enter(op)) {
                op->emitError() << error->mesg;
                return WalkResult::interrupt();
              }
            }
            if (stage.isAfterAllRegions()) {
              if (std::optional<StringError> error = this->exit(op)) {
                op->emitError() << error->mesg;
                return WalkResult::interrupt();
              }
            }
            return WalkResult::advance();
          });
      if (walk.wasInterrupted()) {
        signalPassFailure();
        return;
      }
    }
  }

  /// Dispatches `op` to the matching `on_enter` handler; ops without a
  /// handler are ignored.
  std::optional<StringError> enter(mlir::Operation *op) {
    DISPATCH_ENTER(scf::ForOp)
    DISPATCH_ENTER(TFHE::AddGLWEOp)
    DISPATCH_ENTER(TFHE::AddGLWEIntOp)
    DISPATCH_ENTER(TFHE::BootstrapGLWEOp)
    DISPATCH_ENTER(TFHE::KeySwitchGLWEOp)
    DISPATCH_ENTER(TFHE::MulGLWEIntOp)
    DISPATCH_ENTER(TFHE::NegGLWEOp)
    DISPATCH_ENTER(TFHE::SubGLWEIntOp)
    DISPATCH_ENTER(TFHE::WopPBSGLWEOp)
    return std::nullopt;
  }

  /// Dispatches `op` to the matching `on_exit` handler; only `scf.for` has
  /// one.
  std::optional<StringError> exit(mlir::Operation *op) {
    DISPATCH_EXIT(scf::ForOp)
    return std::nullopt;
  }

  // ############
  // Loop nesting
  // ############

  /// Pushes the loop's static trip count (or an empty value, with a warning,
  /// if it cannot be determined) so that nested ops are counted accordingly.
  static std::optional<StringError> on_enter(scf::ForOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
    if (!tripCount.has_value()) {
      emitWarning(op.getLoc(), "Cannot determine static trip count");
    }
    pass.pushTripCount(op, tripCount);
    return std::nullopt;
  }

  /// Pops the trip count pushed by the matching on_enter.
  static std::optional<StringError> on_exit(scf::ForOp &op,
                                            ExtractTFHEStatisticsPass &pass) {
    std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
    pass.popTripCount(op, tripCount);
    return std::nullopt;
  }

  // #######
  // Helpers
  // #######

  /// Appends one statistic per element of `operations` to the circuit being
  /// walked. All entries share `op`'s location, `keys`, and the current trip
  /// count.
  static void
  recordStatistics(mlir::Operation *op, ExtractTFHEStatisticsPass &pass,
                   std::initializer_list<PrimitiveOperation> operations,
                   const KeyList &keys) {
    auto location = locationString(op->getLoc());
    auto count = pass.getTripCount();
    for (PrimitiveOperation operation : operations) {
      pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
          location,
          operation,
          keys,
          count,
      });
    }
  }

  /// Records `operations` for a GLWE op whose only key is the normalized
  /// secret key of its result type.
  ///
  /// Returns an error, and records nothing, if that key has not been
  /// normalized; dereferencing the empty result would be undefined behavior.
  template <typename GLWEOp>
  static std::optional<StringError>
  recordSecretKeyStatistics(GLWEOp &op, ExtractTFHEStatisticsPass &pass,
                            std::initializer_list<PrimitiveOperation> operations) {
    auto resultingKey = op.getType().getKey().getNormalized();
    if (!resultingKey) {
      return StringError("Expected the result secret key to be normalized");
    }
    KeyList keys = {std::make_pair(
        KeyType::SECRET, static_cast<int64_t>(resultingKey->index))};
    recordStatistics(op.getOperation(), pass, operations, keys);
    return std::nullopt;
  }

  // #############
  // TFHE.add_glwe
  // #############

  static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    return recordSecretKeyStatistics(op, pass,
                                     {PrimitiveOperation::ENCRYPTED_ADDITION});
  }

  // #################
  // TFHE.add_glwe_int
  // #################

  static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    return recordSecretKeyStatistics(op, pass,
                                     {PrimitiveOperation::CLEAR_ADDITION});
  }

  // ###################
  // TFHE.bootstrap_glwe
  // ###################

  static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    auto bsk = op.getKey();
    recordStatistics(op.getOperation(), pass, {PrimitiveOperation::PBS},
                     {std::make_pair(KeyType::BOOTSTRAP,
                                     static_cast<int64_t>(bsk.getIndex()))});
    return std::nullopt;
  }

  // ###################
  // TFHE.keyswitch_glwe
  // ###################

  static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    auto ksk = op.getKey();
    recordStatistics(op.getOperation(), pass, {PrimitiveOperation::KEY_SWITCH},
                     {std::make_pair(KeyType::KEY_SWITCH,
                                     static_cast<int64_t>(ksk.getIndex()))});
    return std::nullopt;
  }

  // #################
  // TFHE.mul_glwe_int
  // #################

  static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    return recordSecretKeyStatistics(op, pass,
                                     {PrimitiveOperation::CLEAR_MULTIPLICATION});
  }

  // #############
  // TFHE.neg_glwe
  // #############

  static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    return recordSecretKeyStatistics(op, pass,
                                     {PrimitiveOperation::ENCRYPTED_NEGATION});
  }

  // #################
  // TFHE.sub_int_glwe
  // #################

  static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    // clear - encrypted = clear + neg(encrypted)
    return recordSecretKeyStatistics(op, pass,
                                     {PrimitiveOperation::ENCRYPTED_NEGATION,
                                      PrimitiveOperation::CLEAR_ADDITION});
  }

  // #################
  // TFHE.wop_pbs_glwe
  // #################

  static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
                                             ExtractTFHEStatisticsPass &pass) {
    auto bsk = op.getBsk();
    auto ksk = op.getKsk();
    auto pksk = op.getPksk();
    recordStatistics(
        op.getOperation(), pass, {PrimitiveOperation::WOP_PBS},
        {
            std::make_pair(KeyType::BOOTSTRAP,
                           static_cast<int64_t>(bsk.getIndex())),
            std::make_pair(KeyType::KEY_SWITCH,
                           static_cast<int64_t>(ksk.getIndex())),
            std::make_pair(KeyType::PACKING_KEY_SWITCH,
                           static_cast<int64_t>(pksk.getIndex())),
        });
    return std::nullopt;
  }
};

#undef DISPATCH_ENTER
#undef DISPATCH_EXIT

} // namespace TFHE

/// Creates the TFHE statistics-extraction pass. Statistics are appended to the
/// circuits of `feedback`, which is held by reference and must outlive the
/// returned pass.
std::unique_ptr<OperationPass<ModuleOp>>
createStatisticExtractionPass(ProgramCompilationFeedback &feedback) {
  return std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback);
}

} // namespace concretelang
} // namespace mlir