// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// Lowering of SDFG dialect operations to the stream-emulator runtime: each
// pattern below forward-declares the runtime function it needs and replaces
// the SDFG operation with a new op built from that function name (presumably
// a call op; its type is one of the missing template arguments noted below).
//
// NOTE(review): in this copy of the file every template argument list
// (`<...>`) and the targets of the first three `#include` directives are
// missing (e.g. `dyn_cast_or_null()`, `rewriter.replaceOpWithNewOp(`,
// `patterns.insert(`), so it will not compile as shown. Restore them from the
// upstream source; the comments below describe only what the remaining tokens
// show and hedge wherever the missing arguments matter.

#include
#include
#include
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
#include "concretelang/Runtime/stream_emulator_api.h"

namespace SDFG = mlir::concretelang::SDFG;

namespace {
// The conversion pass. Its base class template arguments are missing in this
// copy; the base is presumably generated from the pass definitions reached
// through Passes.h.
struct SDFGToStreamEmulatorPass : public SDFGToStreamEmulatorBase {
  void runOnOperation() final;
};

// Symbol names of the stream-emulator runtime entry points. Each array is
// passed as the function name to the forward-declaration helpers and as the
// callee name of the replacement op built in the patterns below.
char stream_emulator_init[] = "stream_emulator_init";
char stream_emulator_run[] = "stream_emulator_run";
char stream_emulator_delete[] = "stream_emulator_delete";
char stream_emulator_make_memref_add_lwe_ciphertexts_u64_process[] =
    "stream_emulator_make_memref_add_lwe_ciphertexts_u64_process";
char stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process[] =
    "stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process[] =
    "stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_negate_lwe_ciphertext_u64_process[] =
    "stream_emulator_make_memref_negate_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_keyswitch_lwe_u64_process[] =
    "stream_emulator_make_memref_keyswitch_lwe_u64_process";
char stream_emulator_make_memref_bootstrap_lwe_u64_process[] =
    "stream_emulator_make_memref_bootstrap_lwe_u64_process";
char stream_emulator_make_memref_stream[] =
    "stream_emulator_make_memref_stream";
char stream_emulator_put_memref[] = "stream_emulator_put_memref";
char stream_emulator_make_uint64_stream[] =
    "stream_emulator_make_uint64_stream";
char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
char stream_emulator_get_uint64[] =
    "stream_emulator_get_uint64";

/// Returns a ranked tensor type of rank `rank` with i64 elements whose every
/// dimension is -1.
/// NOTE(review): -1 is presumably the dynamic-dimension marker of the MLIR
/// version this builds against; newer MLIR exposes `ShapedType::kDynamic`
/// with a different value — confirm before upgrading MLIR.
mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
  std::vector shape(rank, -1);
  return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
}

/// Maps `oldTy` to the type used in runtime function signatures:
///  - a tensor type (the exact type tested is missing here; `getRank()` is
///    called on it) becomes getDynamicTensor() of the same rank — note the
///    original element type is replaced by i64;
///  - a stream type (tested type likewise missing; the result is rebuilt with
///    SDFG::StreamType::get) has its element type mapped recursively;
///  - any other type is returned unchanged.
mlir::Type makeDynamicTensorTypes(mlir::OpBuilder &rewriter, mlir::Type oldTy) {
  if (auto ttype = oldTy.dyn_cast_or_null())
    return getDynamicTensor(rewriter, ttype.getRank());
  if (auto stTy = oldTy.dyn_cast_or_null())
    return SDFG::StreamType::get(
        rewriter.getContext(),
        makeDynamicTensorTypes(rewriter, stTy.getElementType()));
  return oldTy;
}

/// Forward-declares `funcName` with the function type obtained by passing
/// every type in `opTys` (inputs) and `resTys` (results) through
/// makeDynamicTensorTypes(). Delegates to insertForwardDeclaration (declared
/// outside this file, presumably in Tools.h) and returns its result.
mlir::LogicalResult insertGenericForwardDeclaration(mlir::Operation *op,
                                                    mlir::OpBuilder &rewriter,
                                                    llvm::StringRef funcName,
                                                    mlir::TypeRange opTys,
                                                    mlir::TypeRange resTys) {
  mlir::SmallVector operands;
  for (mlir::Type opTy : opTys)
    operands.push_back(makeDynamicTensorTypes(rewriter, opTy));
  mlir::SmallVector results;
  for (mlir::Type resTy : resTys)
    results.push_back(makeDynamicTensorTypes(rewriter, resTy));
  mlir::FunctionType funcType =
      mlir::FunctionType::get(rewriter.getContext(), operands, results);
  return insertForwardDeclaration(op, rewriter, funcName, funcType);
}

/// Appends to `newOps` one value per entry of `operands`, in order. Values
/// of tensor type (exact type tested is missing here) are wrapped in a newly
/// created op (op type missing; presumably a tensor cast) whose result type
/// is getDynamicTensor() of the same rank; all other values are appended
/// unchanged. `op` is only used for the location of the created ops.
void castDynamicTensorOps(mlir::Operation *op, mlir::OpBuilder &rewriter,
                          mlir::ValueRange operands,
                          mlir::SmallVector &newOps) {
  for (auto val : operands) {
    auto oldTy = val.getType();
    if (auto ttype = oldTy.dyn_cast_or_null())
      newOps.push_back(rewriter.create(
          op->getLoc(), getDynamicTensor(rewriter, ttype.getRank()), val));
    else
      newOps.push_back(val);
  }
}

/// Replaces SDFG::Init with a call to `stream_emulator_init`, declared with
/// type () -> DFG, whose single result has the DFG type.
struct LowerSDFGInit : public mlir::OpRewritePattern {
  LowerSDFGInit(::mlir::MLIRContext *context,
                mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::Init initOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::FunctionType funcType = mlir::FunctionType::get(
        rewriter.getContext(), {}, {SDFG::DFGType::get(rewriter.getContext())});
    if (insertForwardDeclaration(initOp, rewriter, stream_emulator_init,
                                 funcType)
            .failed())
      return ::mlir::failure();
    rewriter.replaceOpWithNewOp(
        initOp, stream_emulator_init,
        mlir::TypeRange{SDFG::DFGType::get(rewriter.getContext())});
    return ::mlir::success();
  };
};

/// Replaces SDFG::Start with a call to `stream_emulator_run`, declared with
/// type (DFG) -> (), forwarding the op's operands unchanged.
struct LowerSDFGStart : public mlir::OpRewritePattern {
  LowerSDFGStart(::mlir::MLIRContext *context,
                 mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::Start startOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::FunctionType funcType = mlir::FunctionType::get(
        rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {});
    if (insertForwardDeclaration(startOp, rewriter, stream_emulator_run,
                                 funcType)
            .failed())
      return ::mlir::failure();
    rewriter.replaceOpWithNewOp(startOp, stream_emulator_run,
                                mlir::TypeRange{},
                                startOp.getOperation()->getOperands());
    return ::mlir::success();
  };
};

/// Replaces SDFG::Shutdown with a call to `stream_emulator_delete`, declared
/// with type (DFG) -> (), forwarding the op's operands unchanged.
struct LowerSDFGShutdown : public mlir::OpRewritePattern {
  LowerSDFGShutdown(::mlir::MLIRContext *context,
                    mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::Shutdown desOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::FunctionType funcType = mlir::FunctionType::get(
        rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {});
    if (insertForwardDeclaration(desOp, rewriter, stream_emulator_delete,
                                 funcType)
            .failed())
      return ::mlir::failure();
    rewriter.replaceOpWithNewOp(desOp, stream_emulator_delete,
                                mlir::TypeRange{},
                                desOp.getOperation()->getOperands());
    return ::mlir::success();
  };
};

/// Replaces SDFG::MakeProcess with a call to the runtime constructor that
/// matches the process kind, passing the op's own operands first. keyswitch
/// and bootstrap processes additionally pass their parameters, read from the
/// op's attributes and materialized as new ops (op type missing here;
/// presumably constants), followed by the value returned by
/// getContextArgument (defined outside this file).
struct LowerSDFGMakeProcess : public mlir::OpRewritePattern {
  LowerSDFGMakeProcess(::mlir::MLIRContext *context,
                       mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(
            context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    // NOTE(review): `funcName` is not initialized and the switch below has
    // no default, so a ProcessKind value not listed here would leave it
    // uninitialized when it is read afterwards.
    const char *funcName;
    mlir::SmallVector operands(mpOp->getOperands());
    switch (mpOp.type()) {
    case SDFG::ProcessKind::add_eint:
      funcName = stream_emulator_make_memref_add_lwe_ciphertexts_u64_process;
      break;
    case SDFG::ProcessKind::add_eint_int:
      funcName =
          stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process;
      break;
    case SDFG::ProcessKind::mul_eint_int:
      funcName =
          stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process;
      break;
    case SDFG::ProcessKind::neg_eint:
      funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process;
      break;
    case SDFG::ProcessKind::keyswitch:
      funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
      // NOTE(review): attribute names here mix snake_case ("lwe_dim_in",
      // "lwe_dim_out") and camelCase ("baseLog"), while bootstrap below uses
      // camelCase throughout; a misspelt name presumably yields a null
      // attribute — confirm these match what the op producer sets.
      // level
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("level")));
      // base_log
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("baseLog")));
      // lwe_dim_in
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("lwe_dim_in")));
      // lwe_dim_out
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("lwe_dim_out")));
      // context
      operands.push_back(getContextArgument(mpOp));
      break;
    case SDFG::ProcessKind::bootstrap:
      funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
      // input_lwe_dim
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("inputLweDim")));
      // poly_size
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("polySize")));
      // level
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("level")));
      // base_log
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("baseLog")));
      // glwe_dim
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("glweDimension")));
      // out_precision
      operands.push_back(rewriter.create(
          mpOp.getLoc(), mpOp->getAttrOfType("outPrecision")));
      // context
      operands.push_back(getContextArgument(mpOp));
      break;
    }
    // The declaration's input types are taken from the final operand list,
    // including the parameters appended above.
    if (insertGenericForwardDeclaration(mpOp, rewriter, funcName,
                                        mlir::ValueRange{operands}.getTypes(),
                                        mpOp->getResultTypes())
            .failed())
      return ::mlir::failure();
    rewriter.replaceOpWithNewOp(
        mpOp, funcName, mpOp->getResultTypes(), operands);
    return ::mlir::success();
  };
};

/// Replaces SDFG::MakeStream with a call to
/// `stream_emulator_make_memref_stream` or
/// `stream_emulator_make_uint64_stream`. The choice depends on whether the
/// stream element type passes the first `isa` check (type argument missing
/// here). Otherwise the element type is asserted to be the integer case.
/// The call takes two i64 values: the constant 0 and the `stream_type`
/// value selected from the stream kind. Its result type is the op's result
/// type mapped through makeDynamicTensorTypes().
struct LowerSDFGMakeStream : public mlir::OpRewritePattern {
  LowerSDFGMakeStream(::mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(
            context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::MakeStream msOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    const char *funcName;
    // NOTE(review): `t` is left uninitialized if the switch matches no case.
    stream_type t;
    switch (msOp.type()) {
    case SDFG::StreamKind::host_to_device:
      t = TS_STREAM_TYPE_X86_TO_TOPO_LSAP;
      break;
    case SDFG::StreamKind::on_device:
      t = TS_STREAM_TYPE_TOPO_TO_TOPO_LSAP;
      break;
    case SDFG::StreamKind::device_to_host:
      t = TS_STREAM_TYPE_TOPO_TO_X86_LSAP;
      break;
    }
    auto sType = msOp->getResultTypes()[0].dyn_cast_or_null();
    assert(sType && "SDFG MakeStream operation should return a stream type");
    if (sType.getElementType().isa()) {
      funcName = stream_emulator_make_memref_stream;
    } else {
      assert(sType.getElementType().isa() &&
             "SDFG streams only support memrefs and integers.");
      funcName = stream_emulator_make_uint64_stream;
    }
    if (insertGenericForwardDeclaration(
            msOp, rewriter, funcName,
            {rewriter.getI64Type(), rewriter.getI64Type()},
            msOp->getResultTypes())
            .failed())
      return ::mlir::failure();
    // First runtime argument: constant 0. The name suggests a null name
    // pointer for the runtime — NOTE(review): confirm against
    // stream_emulator_api.h.
    mlir::Value nullStringPtr = rewriter.create(
        msOp.getLoc(), rewriter.getI64IntegerAttr(0));
    mlir::Value streamTypeCst = rewriter.create(
        msOp.getLoc(), rewriter.getI64IntegerAttr((int)t));
    auto callop = rewriter.replaceOpWithNewOp(
        msOp, funcName,
        makeDynamicTensorTypes(rewriter, msOp->getResultTypes()[0]),
        mlir::ValueRange{nullStringPtr, streamTypeCst});
    // Redirect every remaining use of msOp to the call result.
    // NOTE(review): this mutates uses directly instead of going through the
    // rewriter. It is presumably needed because the call's result type can
    // differ from msOp's, since the element tensor type is made dynamic.
    // Confirm it is compatible with the conversion driver.
    for (auto &use : llvm::make_early_inc_range(msOp->getUses()))
      use.set(callop.getResult(0));
    return ::mlir::success();
  };
};

/// Replaces SDFG::Put with a call to `stream_emulator_put_memref` or
/// `stream_emulator_put_uint64`, chosen by the element type of the stream
/// operand (operand 0). Tensor-typed operands are passed through
/// castDynamicTensorOps() so they match the declaration, whose tensor types
/// were mapped by makeDynamicTensorTypes().
struct LowerSDFGPut : public mlir::OpRewritePattern {
  LowerSDFGPut(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::Put putOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    const char *funcName;
    auto sType = putOp->getOperandTypes()[0].dyn_cast_or_null();
    assert(sType &&
           "SDFG Put operation must take a stream type as first parameter.");
    if (sType.getElementType().isa()) {
      funcName = stream_emulator_put_memref;
    } else {
      assert(sType.getElementType().isa() &&
             "SDFG streams only support memrefs and integers.");
      funcName = stream_emulator_put_uint64;
    }
    if (insertGenericForwardDeclaration(putOp, rewriter, funcName,
                                        putOp->getOperandTypes(),
                                        putOp->getResultTypes())
            .failed())
      return ::mlir::failure();
    mlir::SmallVector newOps;
    castDynamicTensorOps(putOp, rewriter, putOp->getOperands(), newOps);
    rewriter.replaceOpWithNewOp(
        putOp, funcName, putOp->getResultTypes(), newOps);
    return ::mlir::success();
  };
};

/// Replaces SDFG::Get on integer streams with a call to
/// `stream_emulator_get_uint64`, forwarding the op's operands. Gets on the
/// other (memref) stream kind are left in place (see the TODO below).
struct LowerSDFGGet : public mlir::OpRewritePattern {
  LowerSDFGGet(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::SDFG::Get getOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    const char *funcName;
    auto sType = getOp->getOperandTypes()[0].dyn_cast_or_null();
    assert(sType &&
           "SDFG Get operation must take a stream type as first parameter.");
    if (sType.getElementType().isa()) {
      // TODO: SDFG.Get for memref streams is lowered during bufferization
      // as returning a memref requires allocation for now
      // NOTE(review): this reports success() without replacing, erasing or
      // updating getOp. MLIR's dialect-conversion driver expects a successful
      // pattern to do one of those, and assertion-enabled builds check it.
      // Returning failure() is the usual way to leave the op in place.
      // Whether that is safe depends on getOp's legality in runOnOperation,
      // which is not visible in this copy. Confirm before changing.
      return ::mlir::success();
    } else {
      assert(sType.getElementType().isa() &&
             "SDFG streams only support memrefs and integers.");
      funcName = stream_emulator_get_uint64;
    }
    if (insertGenericForwardDeclaration(getOp, rewriter, funcName,
                                        getOp->getOperandTypes(),
                                        getOp->getResultTypes())
            .failed())
      return ::mlir::failure();
    rewriter.replaceOpWithNewOp(
        getOp, funcName, getOp->getResultTypes(), getOp->getOperands());
    return ::mlir::success();
  };
};
} // namespace

/// Registers the lowering patterns and applies them to the pass's operation
/// with a partial conversion; signals pass failure if the conversion fails.
/// The pattern list, the illegal-op list and the legal dialect/op lists are
/// among the template arguments missing from this copy.
void SDFGToStreamEmulatorPass::runOnOperation() {
  auto op = this->getOperation();

  mlir::ConversionTarget target(getContext());
  mlir::RewritePatternSet patterns(&getContext());

  patterns.insert(&getContext());

  target.addIllegalOp();

  // All BConcrete ops are legal after the conversion
  target.addLegalDialect();
  target.addLegalDialect();
  target.addLegalOp();

  // Apply conversion
  if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
    this->signalPassFailure();
  }
}

namespace mlir {
namespace concretelang {
/// Factory returning a newly allocated SDFGToStreamEmulatorPass. The template
/// arguments of the return type and of make_unique are missing here.
std::unique_ptr> createConvertSDFGToStreamEmulatorPass() {
  return std::make_unique();
}
} // namespace concretelang
} // namespace mlir