// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"

namespace {

namespace Concrete = mlir::concretelang::Concrete;
namespace arith = mlir::arith;
namespace func = mlir::func;
namespace memref = mlir::memref;

// Names of the C API wrapper functions the Concrete operations are lowered
// to. The `callee` template parameter of the rewrite patterns below points
// at these arrays, so the comparisons in `insertForwardDeclarationOfTheCAPI`
// are pointer-identity checks.
char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64";
char memref_add_plaintext_lwe_ciphertext_u64[] =
    "memref_add_plaintext_lwe_ciphertext_u64";
char memref_mul_cleartext_lwe_ciphertext_u64[] =
    "memref_mul_cleartext_lwe_ciphertext_u64";
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
char memref_await_future[] = "memref_await_future";
char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
char memref_batched_keyswitch_lwe_cuda_u64[] =
    "memref_batched_keyswitch_lwe_cuda_u64";
char memref_batched_bootstrap_lwe_cuda_u64[] =
    "memref_batched_bootstrap_lwe_cuda_u64";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
    "memref_expand_lut_in_trivial_glwe_ct_u64";
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
char memref_encode_plaintext_with_crt[] = "memref_encode_plaintext_with_crt";
char memref_encode_expand_lut_for_bootstrap[] =
    "memref_encode_expand_lut_for_bootstrap";
char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs";
char memref_trace[] = "memref_trace";

// Returns a memref type of i64 with `rank` dynamic dimensions and a fully
// dynamic strided layout (unknown offset and strides), so any concrete
// buffer shape can flow through a single C API declaration.
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
                                             size_t rank) {
  std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
  mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
  for (size_t i = 0; i < rank; i++) {
    expr = expr +
           (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
  }
  return mlir::MemRefType::get(
      shape, rewriter.getI64Type(),
      mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}

// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
  mlir::Type valueType = value.getType();

  if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
    return rewriter.create<mlir::memref::CastOp>(
        value.getLoc(),
        getDynamicMemrefWithUnknownOffset(rewriter,
                                          memrefTy.getShape().size()),
        value);
  } else {
    return value;
  }
}
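// For illustration, a sketch of the IR produced by `getCastedMemRef` for a
// hypothetical rank-1 buffer of five i64 values (the affine map is the
// rank-1 layout built by `getDynamicMemrefWithUnknownOffset`):
//
//   %casted = memref.cast %buf : memref<5xi64>
//       to memref<?xi64, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
//
// Erasing the static shape and layout this way lets every call site reuse
// the single forward declaration inserted below.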
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
    mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {

  auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
  auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
  auto futureType =
      mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
  auto contextType =
      mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
  auto i32Type = rewriter.getI32Type();

  mlir::FunctionType funcType;

  if (funcName == memref_add_lwe_ciphertexts_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {});
  } else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, rewriter.getI64Type()}, {});
  } else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, rewriter.getI64Type()}, {});
  } else if (funcName == memref_negate_lwe_ciphertext_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {memref1DType, memref1DType}, {});
  } else if (funcName == memref_keyswitch_lwe_u64 ||
             funcName == memref_keyswitch_lwe_cuda_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {memref1DType, memref1DType, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_bootstrap_lwe_u64 ||
             funcName == memref_bootstrap_lwe_cuda_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, memref1DType, i32Type, i32Type, i32Type,
         i32Type, i32Type, i32Type, contextType},
        {});
  } else if (funcName == memref_keyswitch_async_lwe_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(), {memref1DType, memref1DType, contextType},
        {futureType});
  } else if (funcName == memref_bootstrap_async_lwe_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, memref1DType, i32Type, i32Type, i32Type,
         i32Type, i32Type, i32Type, contextType},
        {futureType});
  } else if (funcName == memref_batched_keyswitch_lwe_u64 ||
             funcName == memref_batched_keyswitch_lwe_cuda_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {memref2DType, memref2DType, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_batched_bootstrap_lwe_u64 ||
             funcName == memref_batched_bootstrap_lwe_cuda_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref2DType, memref2DType, memref1DType, i32Type, i32Type, i32Type,
         i32Type, i32Type, i32Type, contextType},
        {});
  } else if (funcName == memref_await_future) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, futureType, memref1DType, memref1DType}, {});
  } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {
                                           memref1DType,
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           memref1DType,
                                       },
                                       {});
  } else if (funcName == memref_wop_pbs_crt_buffer) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {
                                           memref2DType,
                                           memref2DType,
                                           memref2DType,
                                           memref1DType,
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           rewriter.getI32Type(),
                                           contextType,
                                       },
                                       {});
  } else if (funcName == memref_encode_plaintext_with_crt) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, rewriter.getI64Type(), memref1DType,
         rewriter.getI64Type()},
        {});
  } else if (funcName == memref_encode_expand_lut_for_bootstrap) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, rewriter.getI32Type(),
         rewriter.getI32Type(), rewriter.getI1Type()},
        {});
  } else if (funcName == memref_encode_lut_for_crt_woppbs) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref2DType, memref1DType, memref1DType, memref1DType,
         rewriter.getI32Type(), rewriter.getI1Type()},
        {});
  } else if (funcName == memref_trace) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
         rewriter.getI32Type(), rewriter.getI32Type()},
        {});
  } else {
    op->emitError("unknown external function") << funcName;
    return mlir::failure();
  }

  return insertForwardDeclaration(op, rewriter, funcName, funcType);
}
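// As an illustration, the declaration emitted for the first branch above
// would render in MLIR roughly as follows (the fully dynamic layout map is
// abbreviated to #map):
//
//   func.func private @memref_add_lwe_ciphertexts_u64(
//       memref<?xi64, #map>, memref<?xi64, #map>, memref<?xi64, #map>)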
// Default `addOperands` callback that appends no extra arguments.
template <typename ConcreteOp>
void addNoOperands(ConcreteOp op, mlir::SmallVector<mlir::Value> &operands,
                   mlir::RewriterBase &rewriter) {}

// Rewrites a Concrete buffer operation into a call to the C API function
// `callee`, casting every memref operand to the erased type expected by the
// forward declaration and letting `addOperands` append the scalar and
// context arguments.
template <typename ConcreteOp, char const *callee>
struct ConcreteToCAPICallPattern : public mlir::OpRewritePattern<ConcreteOp> {
  ConcreteToCAPICallPattern(
      ::mlir::MLIRContext *context,
      std::function<void(ConcreteOp, mlir::SmallVector<mlir::Value> &,
                         mlir::RewriterBase &)>
          addOperands = addNoOperands<ConcreteOp>,
      mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        addOperands(addOperands) {}

  ::mlir::LogicalResult
  matchAndRewrite(ConcreteOp bOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    // Create the operands
    mlir::SmallVector<mlir::Value> operands;
    // For all memref operands get the corresponding casted buffer
    for (auto &operand : bOp->getOpOperands()) {
      mlir::Type type = operand.get().getType();
      if (!type.isa<mlir::MemRefType>()) {
        operands.push_back(operand.get());
      } else {
        operands.push_back(getCastedMemRef(rewriter, operand.get()));
      }
    }
    // Append the additional arguments
    addOperands(bOp, operands, rewriter);
    // Insert the forward declaration of the function
    if (insertForwardDeclarationOfTheCAPI(bOp, rewriter, callee).failed()) {
      return mlir::failure();
    }
    rewriter.replaceOpWithNewOp<func::CallOp>(bOp, callee, mlir::TypeRange{},
                                              operands);
    return ::mlir::success();
  }

private:
  std::function<void(ConcreteOp, mlir::SmallVector<mlir::Value> &,
                     mlir::RewriterBase &)>
      addOperands;
};
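// A sketch of the rewrite this pattern performs, assuming a keyswitch over
// illustrative shapes and attribute values (not taken from a real trace):
//
//   "Concrete.keyswitch_lwe_buffer"(%out, %in)
//       {level = 3 : i32, baseLog = 4 : i32,
//        lwe_dim_in = 600 : i32, lwe_dim_out = 512 : i32}
//       : (memref<513xi64>, memref<601xi64>) -> ()
//
// becomes, after casting the two buffers and letting `keyswitchAddOperands`
// append the scalar and context arguments:
//
//   func.call @memref_keyswitch_lwe_u64(%out_cast, %in_cast, %level,
//       %base_log, %lwe_dim_in, %lwe_dim_out, %ctx) : (...) -> ()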
template <typename KeySwitchOp>
void keyswitchAddOperands(KeySwitchOp op,
                          mlir::SmallVector<mlir::Value> &operands,
                          mlir::RewriterBase &rewriter) {
  // level
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLevelAttr()));
  // base_log
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBaseLogAttr()));
  // lwe_dim_in
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimInAttr()));
  // lwe_dim_out
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLweDimOutAttr()));
  // context
  operands.push_back(getContextArgument(op));
}

template <typename BootstrapOp>
void bootstrapAddOperands(BootstrapOp op,
                          mlir::SmallVector<mlir::Value> &operands,
                          mlir::RewriterBase &rewriter) {
  // input_lwe_dim
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getInputLweDimAttr()));
  // poly_size
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getPolySizeAttr()));
  // level
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getLevelAttr()));
  // base_log
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getBaseLogAttr()));
  // glwe_dim
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getGlweDimensionAttr()));
  // out_precision
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getOutPrecisionAttr()));
  // context
  operands.push_back(getContextArgument(op));
}

void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
                       mlir::SmallVector<mlir::Value> &operands,
                       mlir::RewriterBase &rewriter) {
  // crt_decomposition: materialized as a global memref constant
  mlir::Type crtType = mlir::RankedTensorType::get(
      {(int)op.getCrtDecompositionAttr().size()}, rewriter.getI64Type());
  std::vector<int64_t> values;
  for (auto a : op.getCrtDecomposition()) {
    values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
  }
  auto attr = rewriter.getI64TensorAttr(values);
  auto x = rewriter.create<arith::ConstantOp>(op.getLoc(), attr, crtType);
  auto globalMemref = mlir::bufferization::getGlobalFor(x, 0);
  rewriter.eraseOp(x);
  assert(!failed(globalMemref));
  auto globalRef = rewriter.create<memref::GetGlobalOp>(
      op.getLoc(), (*globalMemref).getType(), (*globalMemref).getName());
  operands.push_back(getCastedMemRef(rewriter, globalRef));
  // lwe_small_size
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getPackingKeySwitchInputLweDimensionAttr()));
  // cbs_level_count
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getCircuitBootstrapLevelAttr()));
  // cbs_base_log
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getCircuitBootstrapBaseLogAttr()));
  // ksk_level_count
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getKeyswitchLevelAttr()));
  // ksk_base_log
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getKeyswitchBaseLogAttr()));
  // bsk_level_count
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getBootstrapLevelAttr()));
  // bsk_base_log
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getBootstrapBaseLogAttr()));
  // fpksk_level_count
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getPackingKeySwitchLevelAttr()));
  // fpksk_base_log
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getPackingKeySwitchBaseLogAttr()));
  // polynomial_size
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getPackingKeySwitchoutputPolynomialSizeAttr()));
  // context
  operands.push_back(getContextArgument(op));
}
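// The helper above and the ones below share the same trick for passing an
// array attribute (CRT decomposition, moduli, bit widths) to the C API: the
// attribute is materialized as a transient `arith.constant` tensor so that
// `bufferization::getGlobalFor` can promote it to a `memref.global`, the
// constant is erased, and a `memref.get_global` of the promoted buffer is
// passed instead. Sketched in MLIR (symbol name illustrative):
//
//   memref.global "private" constant @__constant_3xi64
//       : memref<3xi64> = dense<[2, 3, 5]>
//   ...
//   %g = memref.get_global @__constant_3xi64 : memref<3xi64>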
void encodePlaintextWithCrtAddOperands(
    Concrete::EncodePlaintextWithCrtBufferOp op,
    mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
  // mods
  mlir::Type modsType = mlir::RankedTensorType::get(
      {(int)op.getModsAttr().size()}, rewriter.getI64Type());
  std::vector<int64_t> modsValues;
  for (auto a : op.getMods()) {
    modsValues.push_back(
        a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
  }
  auto modsAttr = rewriter.getI64TensorAttr(modsValues);
  auto modsOp =
      rewriter.create<arith::ConstantOp>(op.getLoc(), modsAttr, modsType);
  auto modsGlobalMemref = mlir::bufferization::getGlobalFor(modsOp, 0);
  rewriter.eraseOp(modsOp);
  assert(!failed(modsGlobalMemref));
  auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
      op.getLoc(), (*modsGlobalMemref).getType(),
      (*modsGlobalMemref).getName());
  operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
  // mods_prod
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getModsProdAttr()));
}

void encodeExpandLutForBootstrapAddOperands(
    Concrete::EncodeExpandLutForBootstrapBufferOp op,
    mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
  // poly_size
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getPolySizeAttr()));
  // output bits
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getOutputBitsAttr()));
  // is_signed
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getIsSignedAttr()));
}

void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
                                   mlir::SmallVector<mlir::Value> &operands,
                                   mlir::RewriterBase &rewriter) {
  // crt_decomposition
  mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
      {(int)op.getCrtDecompositionAttr().size()}, rewriter.getI64Type());
  std::vector<int64_t> crtDecompositionValues;
  for (auto a : op.getCrtDecomposition()) {
    crtDecompositionValues.push_back(
        a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
  }
  auto crtDecompositionAttr =
      rewriter.getI64TensorAttr(crtDecompositionValues);
  auto crtDecompositionOp = rewriter.create<arith::ConstantOp>(
      op.getLoc(), crtDecompositionAttr, crtDecompositionType);
  auto crtDecompositionGlobalMemref =
      mlir::bufferization::getGlobalFor(crtDecompositionOp, 0);
  rewriter.eraseOp(crtDecompositionOp);
  assert(!failed(crtDecompositionGlobalMemref));
  auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
      op.getLoc(), (*crtDecompositionGlobalMemref).getType(),
      (*crtDecompositionGlobalMemref).getName());
  operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
  // crt_bits
  mlir::Type crtBitsType = mlir::RankedTensorType::get(
      {(int)op.getCrtBitsAttr().size()}, rewriter.getI64Type());
  std::vector<int64_t> crtBitsValues;
  for (auto a : op.getCrtBits()) {
    crtBitsValues.push_back(
        a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
  }
  auto crtBitsAttr = rewriter.getI64TensorAttr(crtBitsValues);
  auto crtBitsOp = rewriter.create<arith::ConstantOp>(op.getLoc(), crtBitsAttr,
                                                      crtBitsType);
  auto crtBitsGlobalMemref = mlir::bufferization::getGlobalFor(crtBitsOp, 0);
  rewriter.eraseOp(crtBitsOp);
  assert(!failed(crtBitsGlobalMemref));
  auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
      op.getLoc(), (*crtBitsGlobalMemref).getType(),
      (*crtBitsGlobalMemref).getName());
  operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
  // modulus_product
  operands.push_back(rewriter.create<arith::ConstantOp>(
      op.getLoc(), op.getModulusProductAttr()));
  // is_signed
  operands.push_back(
      rewriter.create<arith::ConstantOp>(op.getLoc(), op.getIsSignedAttr()));
}

struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
  ConcreteToCAPIPass(bool gpu) : gpu(gpu) {}

  void runOnOperation() override {
    auto op = this->getOperation();

    mlir::ConversionTarget target(getContext());
    mlir::RewritePatternSet patterns(&getContext());

    // Mark ops from the target dialects as legal operations
    target.addLegalDialect<mlir::func::FuncDialect>();
    target.addLegalDialect<mlir::memref::MemRefDialect>();
    target.addLegalDialect<mlir::arith::ArithDialect>();
    target.addLegalDialect<mlir::LLVM::LLVMDialect>();

    // Make sure that no ops from `Concrete` remain after the lowering
    target.addIllegalDialect<Concrete::ConcreteDialect>();

    // Add patterns to transform Concrete operators to CAPI call
    patterns.add<ConcreteToCAPICallPattern<Concrete::AddLweBufferOp,
                                           memref_add_lwe_ciphertexts_u64>>(
        &getContext());
    patterns.add<ConcreteToCAPICallPattern<
        Concrete::AddPlaintextLweBufferOp,
        memref_add_plaintext_lwe_ciphertext_u64>>(&getContext());
    patterns.add<ConcreteToCAPICallPattern<
        Concrete::MulCleartextLweBufferOp,
        memref_mul_cleartext_lwe_ciphertext_u64>>(&getContext());
    patterns.add<ConcreteToCAPICallPattern<Concrete::NegateLweBufferOp,
                                           memref_negate_lwe_ciphertext_u64>>(
        &getContext());
    patterns.add<ConcreteToCAPICallPattern<
        Concrete::EncodePlaintextWithCrtBufferOp,
        memref_encode_plaintext_with_crt>>(&getContext(),
                                           encodePlaintextWithCrtAddOperands);
    patterns.add<ConcreteToCAPICallPattern<
        Concrete::EncodeExpandLutForBootstrapBufferOp,
        memref_encode_expand_lut_for_bootstrap>>(
        &getContext(), encodeExpandLutForBootstrapAddOperands);
    patterns.add<ConcreteToCAPICallPattern<
        Concrete::EncodeLutForCrtWopPBSBufferOp,
        memref_encode_lut_for_crt_woppbs>>(&getContext(),
                                           encodeLutForWopPBSAddOperands);
    if (gpu) {
      // On GPU, (batched) keyswitch and bootstrap lower to the CUDA variants
      patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
                                             memref_keyswitch_lwe_cuda_u64>>(
          &getContext(), keyswitchAddOperands<Concrete::KeySwitchLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<Concrete::BootstrapLweBufferOp,
                                             memref_bootstrap_lwe_cuda_u64>>(
          &getContext(), bootstrapAddOperands<Concrete::BootstrapLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<
          Concrete::BatchedKeySwitchLweBufferOp,
          memref_batched_keyswitch_lwe_cuda_u64>>(
          &getContext(),
          keyswitchAddOperands<Concrete::BatchedKeySwitchLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<
          Concrete::BatchedBootstrapLweBufferOp,
          memref_batched_bootstrap_lwe_cuda_u64>>(
          &getContext(),
          bootstrapAddOperands<Concrete::BatchedBootstrapLweBufferOp>);
    } else {
      patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
                                             memref_keyswitch_lwe_u64>>(
          &getContext(), keyswitchAddOperands<Concrete::KeySwitchLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<Concrete::BootstrapLweBufferOp,
                                             memref_bootstrap_lwe_u64>>(
          &getContext(), bootstrapAddOperands<Concrete::BootstrapLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<
          Concrete::BatchedKeySwitchLweBufferOp,
          memref_batched_keyswitch_lwe_u64>>(
          &getContext(),
          keyswitchAddOperands<Concrete::BatchedKeySwitchLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<
          Concrete::BatchedBootstrapLweBufferOp,
          memref_batched_bootstrap_lwe_u64>>(
          &getContext(),
          bootstrapAddOperands<Concrete::BatchedBootstrapLweBufferOp>);
    }
    patterns.add<ConcreteToCAPICallPattern<Concrete::WopPBSCRTLweBufferOp,
                                           memref_wop_pbs_crt_buffer>>(
        &getContext(), wopPBSAddOperands);

    // Apply conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }

private:
  bool gpu;
};

} // namespace

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToCAPIPass(bool gpu) {
  return std::make_unique<ConcreteToCAPIPass>(gpu);
}
} // namespace concretelang
} // namespace mlir