// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "concretelang/Conversion/Tools.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" #include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h" #include #include #include using namespace mlir; using namespace mlir::bufferization; using namespace mlir::tensor; namespace BConcrete = mlir::concretelang::BConcrete; namespace mlir { namespace concretelang { namespace BConcrete { namespace {} // namespace } // namespace BConcrete } // namespace concretelang } // namespace mlir namespace { // Returns a map with a symbolic offset for each dimension, i.e., for N // dimensions, it returns // // [d1, d2, ..., dN](s1, s2, ..., sN) -> (d1 + s1, d2 + s2, ..., dN + sN) // AffineMap getMultiDimSymbolicOffsetMap(mlir::RewriterBase &rewriter, unsigned rank) { SmallVector dimExprs; dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(rewriter.getAffineDimExpr(i) + rewriter.getAffineSymbolExpr(i)); return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, dimExprs, rewriter.getContext()); } mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, size_t rank) { mlir::MLIRContext *ctx = rewriter.getContext(); std::vector shape(rank, -1); return mlir::MemRefType::get(shape, rewriter.getI64Type(), getMultiDimSymbolicOffsetMap(rewriter, rank)); } // Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, mlir::Value value) { mlir::Type valueType = value.getType(); if (auto memrefTy = valueType.dyn_cast_or_null()) { return rewriter.create( loc, getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), value); } else { return value; } } char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64"; char memref_add_plaintext_lwe_ciphertext_u64[] = "memref_add_plaintext_lwe_ciphertext_u64"; char memref_mul_cleartext_lwe_ciphertext_u64[] = "memref_mul_cleartext_lwe_ciphertext_u64"; char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64"; char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64"; char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64"; char memref_expand_lut_in_trivial_glwe_ct_u64[] = "memref_expand_lut_in_trivial_glwe_ct_u64"; char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer"; mlir::LogicalResult insertForwardDeclarationOfTheCAPI( mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) { auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2); auto contextType = mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); mlir::FunctionType funcType; if (funcName == memref_add_lwe_ciphertexts_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {}); } else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, rewriter.getI64Type()}, {}); } else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, rewriter.getI64Type()}, {}); } else if (funcName == memref_negate_lwe_ciphertext_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), {memref1DType, memref1DType}, {}); } else if (funcName == memref_keyswitch_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, contextType}, {}); } else if (funcName == memref_bootstrap_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType, contextType}, {}); } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), { memref1DType, rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI32Type(), memref1DType, }, {}); } else if (funcName == memref_wop_pbs_crt_buffer) { funcType = mlir::FunctionType::get(rewriter.getContext(), { memref2DType, memref2DType, memref1DType, contextType, }, {}); } else { op->emitError("unknwon external function") << funcName; return mlir::failure(); } return insertForwardDeclaration(op, rewriter, funcName, funcType); } /// Returns the value of the context argument from the enclosing func mlir::Value getContextArgument(mlir::Operation *op) { mlir::Block *block = op->getBlock(); while (block != nullptr) { if (llvm::isa(block->getParentOp())) { auto context = std::find_if(block->getArguments().rbegin(), block->getArguments().rend(), [](BlockArgument &arg) { return arg.getType() .isa(); }); assert(context != block->getArguments().rend() && "Cannot find the Concrete.context"); return *context; } block = block->getParentOp()->getBlock(); } assert("can't find a function that enclose the op"); return nullptr; } template struct BufferizableWithCallOpInterface : public BufferizableOpInterface::ExternalModel< BufferizableWithCallOpInterface, Op> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); // For now we always alloc for the result, we didn't have the in place // operators yet. auto resTensorType = castOp.result().getType().template cast(); auto outMemrefType = MemRefType::get(resTensorType.getShape(), resTensorType.getElementType()); auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); if (mlir::failed(outMemref)) { return mlir::failure(); } // The first operand is the result mlir::SmallVector operands{ getCastedMemRef(rewriter, loc, *outMemref), }; // For all tensor operand get the corresponding casted buffer for (auto &operand : op->getOpOperands()) { if (!operand.get().getType().isa()) { operands.push_back(operand.get()); } else { auto memrefOperand = bufferization::getBuffer(rewriter, operand.get(), options); operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand)); } } // Append the context argument if (withContext) { operands.push_back(getContextArgument(op)); } // Insert forward declaration of the function if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { return mlir::failure(); } rewriter.create(loc, funcName, mlir::TypeRange{}, operands); replaceOpWithBufferizedValues(rewriter, op, *outMemref); return success(); } }; struct BufferizableGlweFromTableOpInterface : public BufferizableOpInterface::ExternalModel< BufferizableGlweFromTableOpInterface, BConcrete::FillGlweFromTable> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } /// Bufferize GlweFromTable /// ``` /// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1, /// polynomialSize=2048, outPrecision=3} : /// (tensor<4096xi64>, tensor<32xi64>) -> () /// ``` /// /// to /// /// ``` /// %glweDim = arith.constant 1 : i32 /// %polySize = arith.constant 2048 : i32 /// %outPrecision = arith.constant 3 : i32 /// %glwe_ = memref.cast %glwe : memref<4096xi64> to memref /// %lut_ = memref.cast %lut : memref<32xi64> to memref /// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, /// %outPrecision, %lut_) : /// (tensor, i32, i32, tensor) -> () /// ``` LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto loc = op->getLoc(); auto castOp = cast(op); auto glweOp = getCastedMemRef(rewriter, loc, bufferization::getBuffer( rewriter, castOp->getOpOperand(0).get(), options)); auto lutOp = getCastedMemRef(rewriter, loc, bufferization::getBuffer( rewriter, castOp->getOpOperand(1).get(), options)); auto polySizeOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize())); auto glweDimensionOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.glweDimension())); auto outPrecisionOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.outPrecision())); mlir::SmallVector operands{glweOp, polySizeOp, glweDimensionOp, outPrecisionOp, lutOp}; // Insert forward declaration of the function if (insertForwardDeclarationOfTheCAPI( op, rewriter, memref_expand_lut_in_trivial_glwe_ct_u64) .failed()) { return mlir::failure(); } rewriter.create( loc, memref_expand_lut_in_trivial_glwe_ct_u64, mlir::TypeRange{}, operands); replaceOpWithBufferizedValues(rewriter, op, {}); return success(); } }; } // namespace void mlir::concretelang::BConcrete:: registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, BConcrete::BConcreteDialect *dialect) { BConcrete::AddLweBuffersOp::attachInterface>(*ctx); BConcrete::AddPlaintextLweBufferOp::attachInterface< BufferizableWithCallOpInterface< BConcrete::AddPlaintextLweBufferOp, memref_add_plaintext_lwe_ciphertext_u64>>(*ctx); BConcrete::MulCleartextLweBufferOp::attachInterface< BufferizableWithCallOpInterface< BConcrete::MulCleartextLweBufferOp, memref_mul_cleartext_lwe_ciphertext_u64>>(*ctx); BConcrete::NegateLweBufferOp::attachInterface< BufferizableWithCallOpInterface>( *ctx); BConcrete::KeySwitchLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); BConcrete::BootstrapLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); // TODO(16bits): hack BConcrete::WopPBSCRTLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); BConcrete::FillGlweFromTable::attachInterface< BufferizableGlweFromTableOpInterface>(*ctx); }); }