// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "concretelang/Conversion/Tools.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" #include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h" #include #include #include using namespace mlir; using namespace mlir::bufferization; using namespace mlir::tensor; namespace BConcrete = mlir::concretelang::BConcrete; namespace mlir { namespace concretelang { namespace BConcrete { namespace {} // namespace } // namespace BConcrete } // namespace concretelang } // namespace mlir namespace { mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, size_t rank) { std::vector shape(rank, -1); mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0); for (size_t i = 0; i < rank; i++) { expr = expr + (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1)); } return mlir::MemRefType::get( shape, rewriter.getI64Type(), mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext())); } // Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, mlir::Value value) { mlir::Type valueType = value.getType(); if (auto memrefTy = valueType.dyn_cast_or_null()) { return rewriter.create( loc, getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), value); } else { return value; } } char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64"; char memref_add_plaintext_lwe_ciphertext_u64[] = "memref_add_plaintext_lwe_ciphertext_u64"; char memref_mul_cleartext_lwe_ciphertext_u64[] = "memref_mul_cleartext_lwe_ciphertext_u64"; char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64"; char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64"; char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64"; char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64"; char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64"; char memref_await_future[] = "memref_await_future"; char memref_expand_lut_in_trivial_glwe_ct_u64[] = "memref_expand_lut_in_trivial_glwe_ct_u64"; char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer"; mlir::LogicalResult insertForwardDeclarationOfTheCAPI( mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) { auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2); auto futureType = mlir::concretelang::RT::FutureType::get(rewriter.getIndexType()); auto contextType = mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); mlir::FunctionType funcType; if (funcName == memref_add_lwe_ciphertexts_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {}); } else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, rewriter.getI64Type()}, {}); } else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, rewriter.getI64Type()}, {}); } else if (funcName == memref_negate_lwe_ciphertext_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), {memref1DType, memref1DType}, {}); } else if (funcName == memref_keyswitch_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, contextType}, {}); } else if (funcName == memref_bootstrap_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType, contextType}, {}); } else if (funcName == memref_keyswitch_async_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, contextType}, {futureType}); } else if (funcName == memref_bootstrap_async_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType, contextType}, {futureType}); } else if (funcName == memref_await_future) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, futureType, memref1DType, memref1DType}, {}); } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), { memref1DType, rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI32Type(), memref1DType, }, {}); } else if (funcName == memref_wop_pbs_crt_buffer) { funcType = mlir::FunctionType::get(rewriter.getContext(), { memref2DType, memref2DType, memref1DType, memref1DType, rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI32Type(), contextType, }, {}); } else { op->emitError("unknwon external function") << funcName; return mlir::failure(); } return insertForwardDeclaration(op, rewriter, funcName, funcType); } /// Returns the value of the context argument from the enclosing func mlir::Value getContextArgument(mlir::Operation *op) { mlir::Block *block = op->getBlock(); while (block != nullptr) { if (llvm::isa(block->getParentOp())) { block = &mlir::cast(block->getParentOp()) .getBody() .front(); auto context = std::find_if(block->getArguments().rbegin(), block->getArguments().rend(), [](BlockArgument &arg) { return arg.getType() .isa(); }); assert(context != block->getArguments().rend() && "Cannot find the Concrete.context"); return *context; } block = block->getParentOp()->getBlock(); } assert("can't find a function that enclose the op"); return nullptr; }; template void pushAdditionalArgs(Op op, mlir::SmallVector &operands, RewriterBase &rewriter); template struct BufferizableWithCallOpInterface : public BufferizableOpInterface::ExternalModel< BufferizableWithCallOpInterface, Op> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); // For now we always alloc for the result, we didn't have the in place // operators yet. auto resTensorType = castOp.result().getType().template cast(); auto outMemrefType = MemRefType::get(resTensorType.getShape(), resTensorType.getElementType()); auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); if (mlir::failed(outMemref)) { return mlir::failure(); } // The first operand is the result mlir::SmallVector operands{ getCastedMemRef(rewriter, loc, *outMemref), }; // For all tensor operand get the corresponding casted buffer for (auto &operand : op->getOpOperands()) { if (!operand.get().getType().isa()) { operands.push_back(operand.get()); } else { auto memrefOperand = bufferization::getBuffer(rewriter, operand.get(), options); operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand)); } } // Append additional argument pushAdditionalArgs(castOp, operands, rewriter); // Insert forward declaration of the function if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { return mlir::failure(); } rewriter.create(loc, funcName, mlir::TypeRange{}, operands); replaceOpWithBufferizedValues(rewriter, op, *outMemref); return success(); } }; struct BufferizableGlweFromTableOpInterface : public BufferizableOpInterface::ExternalModel< BufferizableGlweFromTableOpInterface, BConcrete::FillGlweFromTable> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } /// Bufferize GlweFromTable /// ``` /// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1, /// polynomialSize=2048, outPrecision=3} : /// (tensor<4096xi64>, tensor<32xi64>) -> () /// ``` /// /// to /// /// ``` /// %glweDim = arith.constant 1 : i32 /// %polySize = arith.constant 2048 : i32 /// %outPrecision = arith.constant 3 : i32 /// %glwe_ = memref.cast %glwe : memref<4096xi64> to memref /// %lut_ = memref.cast %lut : memref<32xi64> to memref /// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, /// %outPrecision, %lut_) : /// (tensor, i32, i32, tensor) -> () /// ``` LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); auto glweOp = getCastedMemRef(rewriter, loc, bufferization::getBuffer( rewriter, castOp->getOpOperand(0).get(), options)); auto lutOp = getCastedMemRef(rewriter, loc, bufferization::getBuffer( rewriter, castOp->getOpOperand(1).get(), options)); auto polySizeOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize())); auto glweDimensionOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.glweDimension())); auto outPrecisionOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.outPrecision())); mlir::SmallVector operands{glweOp, polySizeOp, glweDimensionOp, outPrecisionOp, lutOp}; // Insert forward declaration of the function if (insertForwardDeclarationOfTheCAPI( op, rewriter, memref_expand_lut_in_trivial_glwe_ct_u64) .failed()) { return mlir::failure(); } rewriter.create( loc, memref_expand_lut_in_trivial_glwe_ct_u64, mlir::TypeRange{}, operands); replaceOpWithBufferizedValues(rewriter, op, {}); return success(); } }; template struct BufferizableWithAsyncCallOpInterface : public BufferizableOpInterface::ExternalModel< BufferizableWithAsyncCallOpInterface, Op> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); // For now we always alloc for the result, we didn't have the in place // operators yet. auto resTensorType = castOp.result() .getType() .template cast() .getElementType() .template cast(); auto outMemrefType = MemRefType::get(resTensorType.getShape(), resTensorType.getElementType()); auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); if (mlir::failed(outMemref)) { return mlir::failure(); } // The first operand is the result mlir::SmallVector operands{ getCastedMemRef(rewriter, loc, *outMemref), }; // For all tensor operand get the corresponding casted buffer for (auto &operand : op->getOpOperands()) { if (!operand.get().getType().isa()) { operands.push_back(operand.get()); } else { auto memrefOperand = bufferization::getBuffer(rewriter, operand.get(), options); operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand)); } } // Append the context argument if (withContext) { operands.push_back(getContextArgument(op)); } // Insert forward declaration of the function if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { return mlir::failure(); } auto result = rewriter.create( loc, funcName, mlir::TypeRange{ mlir::concretelang::RT::FutureType::get(rewriter.getIndexType())}, operands); replaceOpWithBufferizedValues(rewriter, op, result.getResult(0)); return success(); } }; template struct BufferizableWithSyncCallOpInterface : public BufferizableOpInterface::ExternalModel< BufferizableWithSyncCallOpInterface, Op> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); auto resTensorType = castOp.result().getType().template cast(); auto outMemrefType = MemRefType::get(resTensorType.getShape(), resTensorType.getElementType()); auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); if (mlir::failed(outMemref)) { return mlir::failure(); } // The first operand is the result mlir::SmallVector operands{ getCastedMemRef(rewriter, loc, *outMemref), }; // Then add the future operand operands.push_back(op->getOpOperand(0).get()); // Finally add a dependence on the memref covered by the future to // prevent early deallocation auto def = op->getOpOperand(0).get().getDefiningOp(); operands.push_back(def->getOpOperand(0).get()); operands.push_back(def->getOpOperand(1).get()); // Insert forward declaration of the function if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { return mlir::failure(); } rewriter.create(loc, funcName, mlir::TypeRange{}, operands); replaceOpWithBufferizedValues(rewriter, op, *outMemref); return success(); } }; template <> void pushAdditionalArgs(BConcrete::AddPlaintextLweBufferOp op, mlir::SmallVector &operands, RewriterBase &rewriter){}; template <> void pushAdditionalArgs(BConcrete::AddLweBuffersOp op, mlir::SmallVector &operands, RewriterBase &rewriter){}; template <> void pushAdditionalArgs(BConcrete::MulCleartextLweBufferOp op, mlir::SmallVector &operands, RewriterBase &rewriter){}; template <> void pushAdditionalArgs(BConcrete::NegateLweBufferOp op, mlir::SmallVector &operands, RewriterBase &rewriter){}; template <> void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op, mlir::SmallVector &operands, RewriterBase &rewriter) { // context operands.push_back(getContextArgument(op)); }; template <> void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op, mlir::SmallVector &operands, RewriterBase &rewriter) { // context operands.push_back(getContextArgument(op)); }; template <> void pushAdditionalArgs(BConcrete::WopPBSCRTLweBufferOp op, mlir::SmallVector &operands, RewriterBase &rewriter) { mlir::Type crtType = mlir::RankedTensorType::get( {(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type()); std::vector values; for (auto a : op.crtDecomposition()) { values.push_back(a.cast().getValue().getZExtValue()); } auto attr = rewriter.getI64TensorAttr(values); auto x = rewriter.create(op.getLoc(), attr, crtType); auto globalMemref = bufferization::getGlobalFor(x, 0); assert(!failed(globalMemref)); auto globalRef = rewriter.create( op.getLoc(), (*globalMemref).type(), (*globalMemref).getName()); operands.push_back(getCastedMemRef(rewriter, op.getLoc(), globalRef)); // lwe_small_size operands.push_back(rewriter.create( op.getLoc(), op.packingKeySwitchInputLweDimensionAttr())); // cbs_level_count operands.push_back(rewriter.create( op.getLoc(), op.circuitBootstrapLevelAttr())); // cbs_base_log operands.push_back(rewriter.create( op.getLoc(), op.circuitBootstrapBaseLogAttr())); // polynomial_size operands.push_back(rewriter.create( op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr())); // context operands.push_back(getContextArgument(op)); }; } // namespace void mlir::concretelang::BConcrete:: registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, BConcrete::BConcreteDialect *dialect) { BConcrete::AddLweBuffersOp::attachInterface>(*ctx); BConcrete::AddPlaintextLweBufferOp::attachInterface< BufferizableWithCallOpInterface< BConcrete::AddPlaintextLweBufferOp, memref_add_plaintext_lwe_ciphertext_u64>>(*ctx); BConcrete::MulCleartextLweBufferOp::attachInterface< BufferizableWithCallOpInterface< BConcrete::MulCleartextLweBufferOp, memref_mul_cleartext_lwe_ciphertext_u64>>(*ctx); BConcrete::NegateLweBufferOp::attachInterface< BufferizableWithCallOpInterface>( *ctx); BConcrete::KeySwitchLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); BConcrete::BootstrapLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); BConcrete::WopPBSCRTLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface< BufferizableWithAsyncCallOpInterface< BConcrete::KeySwitchLweBufferAsyncOffloadOp, memref_keyswitch_async_lwe_u64, true>>(*ctx); BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface< BufferizableWithAsyncCallOpInterface< BConcrete::BootstrapLweBufferAsyncOffloadOp, memref_bootstrap_async_lwe_u64, true>>(*ctx); BConcrete::AwaitFutureOp::attachInterface< BufferizableWithSyncCallOpInterface>(*ctx); BConcrete::FillGlweFromTable::attachInterface< BufferizableGlweFromTableOpInterface>(*ctx); }); }