// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include <vector>

#include <llvm/ADT/APInt.h>

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/DialectConversion.h>

#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
#include <concretelang/Support/Constants.h>

namespace mlir {
namespace concretelang {

/// Construct a table lookup to extract the carry bit
mlir::Value getTruthTableCarryExtract(mlir::PatternRewriter &rewriter,
                                      mlir::Location loc,
                                      unsigned int chunkSize,
                                      unsigned int chunkWidth) {
  auto tableSize = 1 << chunkSize;
  std::vector<llvm::APInt> values;
  values.reserve(tableSize);
  for (auto i = 0; i < tableSize; i++) {
    // table entries are 64-bit to match the i64 element type of the table
    // tensor below
    if (i < 1 << chunkWidth)
      values.push_back(llvm::APInt(64, 0, false));
    else
      values.push_back(llvm::APInt(64, 1, false));
  }
  auto truthTableAttr = mlir::DenseElementsAttr::get(
      mlir::RankedTensorType::get({tableSize}, rewriter.getIntegerType(64)),
      values);
  auto truthTable =
      rewriter.create<mlir::arith::ConstantOp>(loc, truthTableAttr);
  return truthTable.getResult();
}

namespace {
namespace typing {

/// Converts `FHE::ChunkedEncryptedInteger` into a tensor of
/// `FHE::EncryptedInteger`.
mlir::RankedTensorType
convertChunkedEint(mlir::MLIRContext *context,
                   FHE::ChunkedEncryptedIntegerType chunkedEint,
                   unsigned int chunkSize, unsigned int chunkWidth) {
  auto eint = FHE::EncryptedIntegerType::get(context, chunkSize);
  auto bigIntWidth = chunkedEint.getWidth();
  assert(bigIntWidth % chunkWidth == 0 &&
         "chunkWidth must divide the width of the big integer");
  auto numberOfChunks = bigIntWidth / chunkWidth;
  std::vector<int64_t> shape({numberOfChunks});
  return mlir::RankedTensorType::get(shape, eint);
}

/// The type converter used to transform `FHE` ops on chunked integers
class TypeConverter : public mlir::TypeConverter {
public:
  TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) {
    addConversion([](mlir::Type type) { return type; });
    addConversion(
        [chunkSize, chunkWidth](FHE::ChunkedEncryptedIntegerType type) {
          return convertChunkedEint(type.getContext(), type, chunkSize,
                                    chunkWidth);
        });
  }
};

} // namespace typing

class AddEintPattern : public mlir::OpConversionPattern<FHE::AddEintOp> {
public:
  AddEintPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
                 unsigned int chunkSize, unsigned int chunkWidth)
      : mlir::OpConversionPattern<FHE::AddEintOp>(
            converter, context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
        chunkSize(chunkSize), chunkWidth(chunkWidth) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto tensorType = adaptor.a().getType().dyn_cast<mlir::RankedTensorType>();
    auto shape = tensorType.getShape();
    assert(shape.size() == 1 &&
           "chunked integers should be converted to flat tensors, but the "
           "tensor has more than one dimension");
    auto eintChunkWidth = tensorType.getElementType()
                              .dyn_cast<FHE::EncryptedIntegerType>()
                              .getWidth();
    assert(eintChunkWidth == chunkSize && "wrong tensor elements width");
    auto numberOfChunks = shape[0];

    // running carry between chunks, initialized to an encrypted zero
    mlir::Value carry =
        rewriter
            .create<FHE::ZeroEintOp>(op.getLoc(),
                                     FHE::EncryptedIntegerType::get(
                                         rewriter.getContext(), chunkSize))
            .getResult();
    mlir::Value resultTensor =
        rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.a().getType())
            .getResult();
    // used to shift the carry bit to the left
    mlir::Value twoPowerChunkSizeCst =
        rewriter
            .create<mlir::arith::ConstantIntOp>(op.getLoc(), 1 << chunkWidth,
                                                chunkSize + 1)
            .getResult();

    // Create the loop
    int64_t lb = 0, step = 1;
    auto forOp = rewriter.create<mlir::AffineForOp>(
        op.getLoc(), lb, numberOfChunks, step, resultTensor,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
            mlir::ValueRange args) {
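          // Worked example of the per-chunk step below (an illustrative
          // sketch; chunkSize = 4 and chunkWidth = 2 are assumed values, not
          // taken from this pass): each chunk stores a 2-bit digit of the big
          // integer in a 4-bit eint, so digit + digit + carry is at most
          // 3 + 3 + 1 = 7 and never overflows the 4-bit message space. The
          // table built by getTruthTableCarryExtract maps any value
          // >= 2^2 = 4 to a carry of 1 (and anything below to 0), and
          // subtracting carry * 2^2 afterwards keeps only the low 2-bit
          // digit in the output chunk.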
          // add inputs with the previous carry (init to 0)
          mlir::Value leftEint =
              builder.create<mlir::tensor::ExtractOp>(loc, adaptor.a(), iter);
          mlir::Value rightEint =
              builder.create<mlir::tensor::ExtractOp>(loc, adaptor.b(), iter);
          mlir::Value result =
              builder.create<FHE::AddEintOp>(loc, leftEint, rightEint)
                  .getResult();
          mlir::Value resultWithCarry =
              builder.create<FHE::AddEintOp>(loc, result, carry).getResult();
          // compute the new carry: either 1 or 0
          carry = rewriter.create<FHE::ApplyLookupTableEintOp>(
              op.getLoc(),
              FHE::EncryptedIntegerType::get(rewriter.getContext(), chunkSize),
              resultWithCarry,
              getTruthTableCarryExtract(rewriter, op.getLoc(), chunkSize,
                                        chunkWidth));
          // remove the carry bit from the result
          mlir::Value shiftedCarry =
              builder
                  .create<FHE::MulEintIntOp>(loc, carry, twoPowerChunkSizeCst)
                  .getResult();
          mlir::Value finalResult =
              builder
                  .create<FHE::SubEintOp>(loc, resultWithCarry, shiftedCarry)
                  .getResult();
          // insert the result in the result tensor
          mlir::Value tensorResult = builder.create<mlir::tensor::InsertOp>(
              loc, finalResult, args[0], iter);
          builder.create<mlir::AffineYieldOp>(loc, tensorResult);
        });
    rewriter.replaceOp(op, forOp.getResult(0));
    return mlir::success();
  }

private:
  unsigned int chunkSize, chunkWidth;
};

/// Performs the transformation of big integer operations
class FHEBigIntTransformPass
    : public FHEBigIntTransformBase<FHEBigIntTransformPass> {
public:
  FHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth)
      : chunkSize(chunkSize), chunkWidth(chunkWidth){};

  void runOnOperation() override {
    mlir::Operation *op = getOperation();

    mlir::ConversionTarget target(getContext());
    mlir::RewritePatternSet patterns(&getContext());
    typing::TypeConverter converter(chunkSize, chunkWidth);

    // Legal ops created during pattern application
    target.addLegalOp<FHE::ZeroEintOp, FHE::ZeroTensorOp,
                      FHE::ApplyLookupTableEintOp, FHE::MulEintIntOp,
                      FHE::SubEintOp, mlir::arith::ConstantOp,
                      mlir::tensor::ExtractOp, mlir::tensor::InsertOp,
                      mlir::AffineForOp, mlir::AffineYieldOp>();
    concretelang::addDynamicallyLegalTypeOp<FHE::AddEintOp>(target, converter);

    // Func ops are only legal with converted types
    target.addDynamicallyLegalOp<mlir::func::FuncOp>(
        [&](mlir::func::FuncOp funcOp) {
          return converter.isSignatureLegal(funcOp.getFunctionType()) &&
                 converter.isLegal(&funcOp.getBody());
        });
    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
        patterns, converter);

    patterns.add<GenericTypeConverterPattern<mlir::func::ReturnOp>>(
        patterns.getContext(), converter);
    concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
                                                                  converter);

    patterns.add<AddEintPattern>(converter, &getContext(), chunkSize,
                                 chunkWidth);

    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }

private:
  unsigned int chunkSize, chunkWidth;
};

} // end anonymous namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth) {
  assert(chunkSize >= chunkWidth + 1 &&
         "chunkSize must be greater than chunkWidth");
  return std::make_unique<FHEBigIntTransformPass>(chunkSize, chunkWidth);
}

} // namespace concretelang
} // namespace mlir
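
// Illustrative usage (a hypothetical driver, not part of this file): the pass
// is meant to be scheduled on a module through a pass manager, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::concretelang::createFHEBigIntTransformPass(
//       /*chunkSize=*/4, /*chunkWidth=*/2));
//   if (mlir::failed(pm.run(module)))
//     llvm::errs() << "big integer lowering failed\n";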