// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "concretelang/Conversion/Tools.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" #include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h" #include "concretelang/Support/CompilerEngine.h" #include #include #include using namespace mlir; using namespace mlir::bufferization; using namespace mlir::tensor; namespace { namespace BConcrete = mlir::concretelang::BConcrete; template struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel< TensorToMemrefOp, TensorOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); auto resTensorType = castOp.result().getType().template cast(); auto outMemrefType = MemRefType::get(resTensorType.getShape(), resTensorType.getElementType()); auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); if (mlir::failed(outMemref)) { return mlir::failure(); } // The first operand is the result mlir::SmallVector operands{ *outMemref, }; for (auto &operand : op->getOpOperands()) { if (!operand.get().getType().isa()) { operands.push_back(operand.get()); } else { operands.push_back( bufferization::getBuffer(rewriter, operand.get(), options)); } } rewriter.create(loc, mlir::TypeRange{}, operands, op->getAttrs()); replaceOpWithBufferizedValues(rewriter, op, *outMemref); return success(); } }; } // namespace void mlir::concretelang::BConcrete:: registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, BConcrete::BConcreteDialect *dialect) { // add_lwe_tensor => add_lwe_buffer BConcrete::AddLweTensorOp::attachInterface< TensorToMemrefOp>( *ctx); // add_plaintext_lwe_tensor => add_plaintext_lwe_buffer BConcrete::AddPlaintextLweTensorOp::attachInterface< TensorToMemrefOp>(*ctx); // mul_cleartext_lwe_tensor => mul_cleartext_lwe_buffer BConcrete::MulCleartextLweTensorOp::attachInterface< TensorToMemrefOp>(*ctx); // negate_cleartext_lwe_tensor => negate_cleartext_lwe_buffer BConcrete::NegateLweTensorOp::attachInterface>(*ctx); // negate_cleartext_lwe_tensor => negate_cleartext_lwe_buffer BConcrete::NegateLweTensorOp::attachInterface>(*ctx); // keyswitch_lwe_tensor => keyswitch_lwe_buffer BConcrete::KeySwitchLweTensorOp::attachInterface>( *ctx); // bootstrap_lwe_tensor => bootstrap_lwe_buffer BConcrete::BootstrapLweTensorOp::attachInterface>( *ctx); // batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer BConcrete::BatchedKeySwitchLweTensorOp::attachInterface< TensorToMemrefOp>(*ctx); // batched_bootstrap_lwe_tensor => batched_bootstrap_lwe_buffer BConcrete::BatchedBootstrapLweTensorOp::attachInterface< TensorToMemrefOp>(*ctx); // wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer BConcrete::WopPBSCRTLweTensorOp::attachInterface>( *ctx); // encode_plaintext_with_crt_tensor => encode_plaintext_with_crt_buffer BConcrete::EncodePlaintextWithCrtTensorOp::attachInterface< TensorToMemrefOp>(*ctx); // encode_expand_lut_for_bootstrap_tensor => // encode_expand_lut_for_bootstrap_buffer BConcrete::EncodeExpandLutForBootstrapTensorOp::attachInterface< TensorToMemrefOp>(*ctx); // encode_expand_lut_for_woppbs_tensor => // encode_expand_lut_for_woppbs_buffer BConcrete::EncodeExpandLutForWopPBSTensorOp::attachInterface< TensorToMemrefOp>(*ctx); }); }