// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"

namespace arith = mlir::arith;
namespace tensor = mlir::tensor;
namespace bufferization = mlir::bufferization;
namespace scf = mlir::scf;
namespace BConcrete = mlir::concretelang::BConcrete;
namespace crt = concretelang::clientlib::crt;

namespace {

// Name of the function that is forward-declared (insertForwardDeclaration)
// and called to encode a plaintext that is not a compile-time constant.
char encode_crt[] = "encode_crt";

// Rewrite pattern for unary CRT ops: an instance of `BConcreteCRTOp` is
// replaced by a loop that applies `BConcreteOp` to each block, i.e. to each
// rank-1 slice taken along the first dimension of the rank-2 operand.
//
// NOTE(review): the template argument lists (template parameters, and the op /
// type arguments of `create` and `cast`) are not visible in this copy of the
// file; the op names in the comments follow the IR sketch below — confirm.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0) {crtDecomposition = [...]}
//      : (tensor) -> (tensor)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor) {
//   %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//        : tensor
//   %tmp = "BConcreteOp"(%blockArg)
//        : (tensor) -> (tensor)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//        : tensor into tensor
//   scf.yield %res : tensor
// }
// ```
template struct BConcreteCRTUnaryOpPattern
    : public mlir::OpRewritePattern {
  // context: MLIR context forwarded to the base pattern.
  // benefit: pattern benefit forwarded to the base pattern (defaults to 1).
  BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context,
                             mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern(context, benefit) {}

  // Replaces `op` by the block-wise loop sketched above.
  // Always returns mlir::success(); the result type is only checked to be
  // rank 2 through an assert.
  mlir::LogicalResult
  matchAndRewrite(BConcreteCRTOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy = ((mlir::Type)op.getResult().getType()).cast();
    auto loc = op.getLoc();
    // The result is expected to be [nbBlocks, lweSize].
    assert(resultTy.getShape().size() == 2);
    auto shape = resultTy.getShape();

    // Loop bounds: iterate over the first dimension of the result.
    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create(loc, 0);
    auto c1 = rewriter.create(loc, 1);
    auto cB = rewriter.create(loc, shape[0]);

    // Initial value of the loop-carried accumulator, of the result type.
    // %init = linalg.tensor_init [B, lweSize] : tensor
    mlir::Value init = rewriter.create(
        op.getLoc(), resultTy, mlir::ValueRange{});

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    // (tensor) {
    rewriter.replaceOpWithNewOp(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // Attributes are built through `rewriter`; the ops of the loop
          // body are created through `builder`.
          // [%i, 0]
          mlir::SmallVector offsets{
              i, rewriter.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector sizes{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector strides{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(1)};

          // Rank-1 type of a single block: [lweSize] with the element type of
          // the result.
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          //   : tensor
          auto blockArg = builder.create(
              loc, blockTy, op.ciphertext(), offsets, sizes, strides);
          // %tmp = "BConcreteOp"(%blockArg)
          //   : (tensor) -> (tensor)
          auto tmp = builder.create(loc, blockTy, blockArg);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
          // 1] : tensor into tensor
          auto res = builder.create(
              loc, tmp, iterArgs[0], offsets, sizes, strides);
          // scf.yield %res : tensor
          builder.create(loc, (mlir::Value)res);
        });

    return mlir::success();
  }
};

// Rewrite pattern for binary CRT ops: an instance of `BConcreteCRTOp` is
// replaced by a loop that applies `BConcreteOp` to each pair of blocks, i.e.
// to the rank-1 slices taken at the same index along the first dimension of
// both rank-2 operands.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
//      : (tensor, tensor) ->
//      (tensor)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//        : tensor
//   %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
//        : tensor
//   %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
//        : (tensor, tensor) ->
//        (tensor)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//        : tensor into tensor
//   scf.yield %res : tensor
// }
// ```
template struct BConcreteCRTBinaryOpPattern
    : public mlir::OpRewritePattern {
  // context: MLIR context forwarded to the base pattern.
  // benefit: pattern benefit forwarded to the base pattern (defaults to 1).
  BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context,
                              mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern(context, benefit) {}

  // Replaces `op` by the block-wise loop sketched above.
  // Always returns mlir::success(); the result type is only checked to be
  // rank 2 through an assert.
  mlir::LogicalResult
  matchAndRewrite(BConcreteCRTOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy =
        ((mlir::Type)op.getResult().getType()).cast();
    auto loc = op.getLoc();
    // The result is expected to be [nbBlocks, lweSize].
    assert(resultTy.getShape().size() == 2);
    auto shape = resultTy.getShape();

    // Loop bounds: iterate over the first dimension of the result.
    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create(loc, 0);
    auto c1 = rewriter.create(loc, 1);
    auto cB = rewriter.create(loc, shape[0]);

    // Initial value of the loop-carried accumulator, of the result type.
    // %init = linalg.tensor_init [B, lweSize] : tensor
    mlir::Value init = rewriter.create(
        op.getLoc(), resultTy, mlir::ValueRange{});

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    // (tensor) {
    rewriter.replaceOpWithNewOp(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // Attributes are built through `rewriter`; the ops of the loop
          // body are created through `builder`.
          // [%i, 0]
          mlir::SmallVector offsets{
              i, rewriter.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector sizes{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector strides{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(1)};

          // Rank-1 type of a single block: [lweSize] with the element type of
          // the result.
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          //   : tensor
          auto blockArg0 = builder.create(
              loc, blockTy, op.lhs(), offsets, sizes, strides);
          // %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
          //   : tensor
          auto blockArg1 = builder.create(
              loc, blockTy, op.rhs(), offsets, sizes, strides);
          // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
          //   : (tensor, tensor) ->
          //   (tensor)
          auto tmp = builder.create(loc, blockTy, blockArg0, blockArg1);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
          // 1] : tensor into tensor
          auto res = builder.create(
              loc, tmp, iterArgs[0], offsets, sizes, strides);
          // scf.yield %res : tensor
          builder.create(loc, (mlir::Value)res);
        });

    return mlir::success();
  }
};

// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block with the crt
// decomposition of the cleartext.
//
// Concretely, the pattern below handles `BConcrete::AddPlaintextCRTLweBufferOp`:
// the plaintext operand is encoded once per modulus of `crtDecomposition`, and
// the i-th encoded value is combined with the i-th block of the ciphertext.
//
// NOTE(review): the template argument lists (the op / type arguments of
// `create`, `cast` and `dyn_cast_or_null`) are not visible in this copy of the
// file; the op names in the comments follow the IR sketch below — confirm.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
//      : (tensor, integer) -> (tensor)
// ```
//
// becomes:
//
// ```mlir
// // Build the decomposition of the plaintext, one element per modulus di,
// // with P = d0 * ... * dn:
// //  - if %x is defined by a constant op, %xi is a constant holding
// //    crt::encode(x, di, P), computed while the pattern runs;
// //  - otherwise %x is converted to i64 and
// //    %xi = call @encode_crt(%x_i64, di, P) : (i64, i64, i64) -> i64
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor
// // Loop on blocks
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//        : tensor
//   %blockArg1 = tensor.extract %x_decomp[%i] : tensor
//   %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
//        : (tensor, i64) -> (tensor)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//        : tensor into tensor
//   scf.yield %res : tensor
// }
// ```
struct AddPlaintextCRTLweBufferOpPattern
    : public mlir::OpRewritePattern {
  // context: MLIR context forwarded to the base pattern.
  // benefit: pattern benefit forwarded to the base pattern (defaults to 1).
  AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern(context, benefit) {
  }

  // Replaces `op` by the plaintext encoding followed by the block-wise loop
  // sketched above.
  // Returns mlir::failure() only when the forward declaration of `encode_crt`
  // cannot be inserted; mlir::success() otherwise.
  mlir::LogicalResult
  matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy = ((mlir::Type)op.getResult().getType()).cast();
    auto loc = op.getLoc();
    // The result is expected to be [nbBlocks, lweSize].
    assert(resultTy.getShape().size() == 2);
    auto shape = resultTy.getShape();

    auto rhs = op.rhs();
    // One encoded plaintext value per modulus of the decomposition.
    mlir::SmallVector plaintextElements;
    // Product of all the moduli, passed to every encoding.
    // NOTE(review): plain uint64_t multiplication, no overflow check — assumes
    // the product of the moduli fits in 64 bits; confirm.
    uint64_t moduliProduct = 1;
    for (mlir::Attribute di : op.crtDecomposition()) {
      moduliProduct *= di.cast().getValue().getZExtValue();
    }
    if (auto cst = mlir::dyn_cast_or_null(rhs.getDefiningOp())) {
      auto apCst = cst.getValue().cast().getValue();
      // The constant is read as a signed value.
      auto value = apCst.getSExtValue();
      // constant value, encode at compile time
      for (mlir::Attribute di : op.crtDecomposition()) {
        auto modulus = di.cast().getValue().getZExtValue();
        auto encoded = crt::encode(value, modulus, moduliProduct);
        plaintextElements.push_back(
            rewriter.create(loc, encoded, 64));
      }
    } else {
      // dynamic value, encode at runtime
      // Make sure `encode_crt : (i64, i64, i64) -> i64` is declared before
      // emitting calls to it.
      if (insertForwardDeclaration(
              op, rewriter, encode_crt,
              mlir::FunctionType::get(rewriter.getContext(),
                                      {rewriter.getI64Type(),
                                       rewriter.getI64Type(),
                                       rewriter.getI64Type()},
                                      {rewriter.getI64Type()}))
              .failed()) {
        return mlir::failure();
      }
      // The plaintext converted to i64, and the moduli product as a constant;
      // both are shared by all the calls below.
      auto extOp = rewriter.create(loc, rewriter.getI64Type(), rhs);
      auto moduliProductOp = rewriter.create(loc, moduliProduct, 64);
      for (mlir::Attribute di : op.crtDecomposition()) {
        auto modulus = di.cast().getValue().getZExtValue();
        auto modulusOp = rewriter.create(loc, modulus, 64);
        // call @encode_crt(%x_i64, %modulus, %moduliProduct) -> i64
        plaintextElements.push_back(
            rewriter
                .create(
                    loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()},
                    mlir::ValueRange{extOp, modulusOp, moduliProductOp})
                .getResult(0));
      }
    }
    // %x_decomp = tensor.from_elements %x0, ..., %xn : tensor
    auto x_decomp = rewriter.create(loc, plaintextElements);

    // Loop bounds: iterate over the first dimension of the result.
    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create(loc, 0);
    auto c1 = rewriter.create(loc, 1);
    auto cB = rewriter.create(loc, shape[0]);

    // Initial value of the loop-carried accumulator, of the result type.
    // %init = linalg.tensor_init [B, lweSize] : tensor
    mlir::Value init = rewriter.create(
        op.getLoc(), resultTy, mlir::ValueRange{});

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    // (tensor) {
    rewriter.replaceOpWithNewOp(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // Attributes are built through `rewriter`; the ops of the loop
          // body are created through `builder`.
          // [%i, 0]
          mlir::SmallVector offsets{
              i, rewriter.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector sizes{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector strides{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(1)};

          // Rank-1 type of a single block: [lweSize] with the element type of
          // the result.
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          //   : tensor
          auto blockArg0 = builder.create(
              loc, blockTy, op.lhs(), offsets, sizes, strides);
          // %blockArg1 = tensor.extract %x_decomp[%i] : tensor
          auto blockArg1 = builder.create(loc, x_decomp, i);
          // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
          //   : (tensor, i64) -> (tensor)
          auto tmp = builder.create(
              loc, blockTy, blockArg0, blockArg1);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
          // 1] : tensor into tensor
          auto res = builder.create(
              loc, tmp, iterArgs[0], offsets, sizes, strides);
          // scf.yield %res : tensor
          builder.create(loc, (mlir::Value)res);
        });

    return mlir::success();
  }
};

// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block with the crt decomposition of the cleartext.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
//      : (tensor, i64) -> (tensor)
// ```
//
// becomes:
//
// ```mlir
// // Build the decomposition of the plaintext
// %x0_a = arith.constant 64/d0 : f64
// %x0_b = arith.mulf %x, %x0_a : i64
// %x0 = arith.fptoui %x0_b : f64 to i64
// ...
// %xn_a = arith.constant 64/dn : f64
// %xn_b = arith.mulf %x, %xn_a : i64
// %xn = arith.fptoui %xn_b : f64 to i64
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor
// // Loop on blocks
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//        : tensor
//   %blockArg1 = tensor.extract %x_decomp[%i] : tensor
//   %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
//        : (tensor, i64) -> (tensor)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//        : tensor into tensor
//   scf.yield %res : tensor
// }
// ```
//
// NOTE(review): the "decomposition of the plaintext" part of the sketch above
// does not match the code of MulCleartextCRTLweBufferOpPattern below: no
// %x_decomp is built. The cleartext operand is converted once to i64 (`rhs`)
// and that same value is passed to the per-block op in every iteration.
//
// NOTE(review): the template argument lists (the op / type arguments of
// `create`, `cast`, `addIllegalOp`, ...) are not visible in this copy of the
// file; the op names in the comments follow the existing IR sketches — confirm.
struct MulCleartextCRTLweBufferOpPattern
    : public mlir::OpRewritePattern {
  // context: MLIR context forwarded to the base pattern.
  // benefit: pattern benefit forwarded to the base pattern (defaults to 1).
  MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern(context, benefit) {
  }

  // Replaces `op` by a loop over the blocks of `op.lhs()`, each block being
  // combined with the (converted) cleartext `op.rhs()`.
  // Always returns mlir::success(); the result type is only checked to be
  // rank 2 through an assert.
  mlir::LogicalResult
  matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy = ((mlir::Type)op.getResult().getType()).cast();
    auto loc = op.getLoc();
    // The result is expected to be [nbBlocks, lweSize].
    assert(resultTy.getShape().size() == 2);
    auto shape = resultTy.getShape();

    // Loop bounds: iterate over the first dimension of the result.
    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create(loc, 0);
    auto c1 = rewriter.create(loc, 1);
    auto cB = rewriter.create(loc, shape[0]);

    // Initial value of the loop-carried accumulator, of the result type.
    // %init = linalg.tensor_init [B, lweSize] : tensor
    mlir::Value init = rewriter.create(
        op.getLoc(), resultTy, mlir::ValueRange{});

    // Cleartext converted to i64 once, outside the loop; reused by every
    // iteration.
    auto rhs = rewriter.create(op.getLoc(), rewriter.getI64Type(),
                               op.rhs());

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    // (tensor) {
    rewriter.replaceOpWithNewOp(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // Attributes are built through `rewriter`; the ops of the loop
          // body are created through `builder`.
          // [%i, 0]
          mlir::SmallVector offsets{
              i,
              rewriter.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector sizes{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector strides{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(1)};

          // Rank-1 type of a single block: [lweSize] with the element type of
          // the result.
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          //   : tensor
          auto blockArg0 = builder.create(
              loc, blockTy, op.lhs(), offsets, sizes, strides);

          // %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
          //   : (tensor, i64) -> (tensor)
          auto tmp = builder.create(
              loc, blockTy, blockArg0, rhs);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
          // 1] : tensor into tensor
          auto res = builder.create(
              loc, tmp, iterArgs[0], offsets, sizes, strides);
          // scf.yield %res : tensor
          builder.create(loc, (mlir::Value)res);
        });

    return mlir::success();
  }
};

// Pass that rewrites the CRT ops into loops over per-block ops, using the
// rewrite patterns defined in this file.
struct EliminateCRTOpsPass : public EliminateCRTOpsBase {
  void runOnOperation() final;
};

// Declares the CRT ops illegal, registers one rewrite pattern per CRT op and
// applies a partial conversion on the operation the pass runs on. The pass is
// signalled as failed when the conversion fails.
void EliminateCRTOpsPass::runOnOperation() {
  auto op = getOperation();

  mlir::ConversionTarget target(getContext());
  mlir::RewritePatternSet patterns(&getContext());

  // add_crt_lwe_buffers
  target.addIllegalOp();
  patterns.add>(
      &getContext());

  // add_plaintext_crt_lwe_buffers
  target.addIllegalOp();
  patterns.add(&getContext());

  // mul_cleartext_crt_lwe_buffer
  target.addIllegalOp();
  patterns.add(&getContext());

  // NOTE(review): presumably the unary CRT op handled through
  // BConcreteCRTUnaryOpPattern — confirm.
  target.addIllegalOp();
  patterns.add>(
      &getContext());

  // These dialects are used to transform CRT ops into BConcrete ops, so the
  // ops they provide must stay legal.
  target
      .addLegalDialect();

  // Apply the conversion
  if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
    this->signalPassFailure();
    return;
  }
}
} // namespace

namespace mlir {
namespace concretelang {
// Factory: returns a newly allocated instance of the pass defined above.
std::unique_ptr> createEliminateCRTOps() {
  return std::make_unique();
}
} // namespace concretelang
} // namespace mlir