// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// NOTE(review): the header names of the following includes are missing in
// this copy of the file, and many template argument lists below (e.g.
// `template struct ZeroOpPattern`, `.cast()`, `rewriter.create(`) appear to
// have lost their `<...>` contents — restore them before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"

namespace Concrete = ::mlir::concretelang::Concrete;
namespace BConcrete = ::mlir::concretelang::BConcrete;

namespace {
/// Pass lowering Concrete ops/types to their BConcrete (tensor-of-i64)
/// equivalents. `runOnOperation` is defined further down in this file.
struct ConcreteToBConcretePass : public ConcreteToBConcreteBase {
  void runOnOperation() final;
};
} // namespace

/// ConcreteToBConcreteTypeConverter is a TypeConverter that transforms
/// - `!Concrete.plaintext` and `!Concrete.cleartext` to `i64`;
/// - `!Concrete.lwe_ciphertext` of dimension d to `tensor<(d+1)xi64>`;
/// - `!Concrete.glwe_ciphertext` to a 1-D i64 tensor of
///   `polynomialSize * (glweDimension + 1)` elements;
/// - `tensor<...x!Concrete.lwe_ciphertext>` (dimension d) to
///   `tensor<...x(d+1)xi64>`, i.e. appends one trailing dimension;
/// - `!RT.future` / `!RT.pointer` by converting their element type;
/// and leaves every other type unchanged.
class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter {

public:
  ConcreteToBConcreteTypeConverter() {
    // Identity fallback. mlir::TypeConverter tries the most recently added
    // conversion first, so this one only applies when no other matches.
    addConversion([](mlir::Type type) { return type; });
    addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
    addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
    // LWE ciphertext -> 1-D i64 tensor of size dimension + 1.
    addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
      // A dimension of -1 presumably means "not parametrized yet"; the
      // conversion requires concrete parameters.
      assert(type.getDimension() != -1);
      llvm::SmallVector shape;
      shape.push_back(type.getDimension() + 1);
      return mlir::RankedTensorType::get(
          shape, mlir::IntegerType::get(type.getContext(), 64));
    });
    // GLWE ciphertext -> 1-D i64 tensor of polySize * (glweDimension + 1).
    addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) {
      assert(type.getGlweDimension() != -1);
      assert(type.getPolynomialSize() != -1);
      return mlir::RankedTensorType::get(
          {type.getPolynomialSize() * (type.getGlweDimension() + 1)},
          mlir::IntegerType::get(type.getContext(), 64));
    });
    // Ranked tensor of LWE ciphertexts -> same shape plus a trailing
    // (dimension + 1) axis of i64. Tensors of any other element type are
    // returned unchanged.
    addConversion([&](mlir::RankedTensorType type) {
      auto lwe = type.getElementType()
                     .dyn_cast_or_null<
                         mlir::concretelang::Concrete::LweCiphertextType>();
      if (lwe == nullptr) {
        return (mlir::Type)(type);
      }
      assert(lwe.getDimension() != -1);
      mlir::SmallVector newShape;
      newShape.reserve(type.getShape().size() + 1);
      newShape.append(type.getShape().begin(), type.getShape().end());
      newShape.push_back(lwe.getDimension() + 1);
      mlir::Type r = mlir::RankedTensorType::get(
          newShape, mlir::IntegerType::get(type.getContext(), 64));
      return r;
    });
    // RT futures/pointers wrap another type: convert the wrapped type
    // recursively.
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
  }
};

/// Rewrites a "zero" op (the `ZeroOp` template parameter; presumably
/// Concrete's zero LWE / zero tensor ops — confirm at the instantiation site)
/// into a `tensor.generate` of the converted result type whose body yields the
/// constant `0 : i64` for every element.
template struct ZeroOpPattern : public mlir::OpRewritePattern {
  ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(ZeroOp zeroOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto resultTy = zeroOp.getType();
    auto newResultTy = converter.convertType(resultTy);

    // Body of the tensor.generate region: every element is 0.
    auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
                            mlir::Location nestedLoc,
                            mlir::ValueRange blockArgs) {
      // %c0 = 0 : i64
      auto cstOp = nestedBuilder.create(
          nestedLoc, nestedBuilder.getI64IntegerAttr(0));
      // tensor.yield %c0 : i64
      nestedBuilder.create(nestedLoc, cstOp.getResult());
    };
    // tensor.generate (no dynamic extents: empty ValueRange)
    rewriter.replaceOpWithNewOp(
        zeroOp, newResultTy, mlir::ValueRange{}, generateBody);
    return ::mlir::success();
  };
};

/// Generic one-to-one lowering: replaces a `ConcreteOp` with the BConcrete op
/// named by the other template parameter, forwarding the original result
/// types, operands and attributes, then converts the new op's operand and
/// result types in place with ConcreteToBConcreteTypeConverter.
template struct LowToBConcrete : public mlir::OpRewritePattern {
  LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(ConcreteOp concreteOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    mlir::TypeRange resultTyRange = concreteOp->getResultTypes();

    llvm::ArrayRef<::mlir::NamedAttribute> attributes =
        concreteOp.getOperation()->getAttrs();

    mlir::Operation *bConcreteOp;
    bConcreteOp = rewriter.replaceOpWithNewOp(
        concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
        attributes);

    // The replacement was built with the old (Concrete) types; fix them up.
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Lowers `Concrete::KeySwitchLweOp` to `BConcrete::KeySwitchLweTensorOp`,
/// materializing the input and output LWE dimensions as i32 attributes.
struct LowerKeySwitch : public mlir::OpRewritePattern<
                            mlir::concretelang::Concrete::KeySwitchLweOp> {
  LowerKeySwitch(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp ksOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    // construct attributes for in/out dimensions
    mlir::concretelang::Concrete::LweCiphertextType outType = ksOp.getType();
    auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension());
    // The input dimension is recovered from the converted tensor: its last
    // extent is dimension + 1.
    auto inputType = converter.convertType(ksOp.ciphertext().getType())
                         .cast();
    auto inputDimension = inputType.getShape().back() - 1;
    mlir::IntegerAttr inputDimAttr = rewriter.getI32IntegerAttr(inputDimension);

    mlir::Operation *bKeySwitchOp = rewriter.replaceOpWithNewOp<
        mlir::concretelang::BConcrete::KeySwitchLweTensorOp>(
        ksOp, outType, ksOp.ciphertext(),
        ksOp.levelAttr(), ksOp.baseLogAttr(), inputDimAttr, outDimAttr);

    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Lowers `Concrete::BatchedKeySwitchLweOp` to
/// `BConcrete::BatchedKeySwitchLweTensorOp`. The in/out LWE dimensions are
/// read from the element types of the (unconverted) input and result tensors.
struct LowerBatchedKeySwitch
    : public mlir::OpRewritePattern<
          mlir::concretelang::Concrete::BatchedKeySwitchLweOp> {
  LowerBatchedKeySwitch(::mlir::MLIRContext *context,
                        mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<
            mlir::concretelang::Concrete::BatchedKeySwitchLweOp>(context,
                                                                 benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::Concrete::BatchedKeySwitchLweOp bksOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;

    mlir::concretelang::Concrete::LweCiphertextType outType =
        bksOp.getType()
            .cast()
            .getElementType()
            .cast();
    auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension());
    auto inputType =
        bksOp.ciphertexts()
            .getType()
            .cast()
            .getElementType()
            .cast();
    mlir::IntegerAttr inputDimAttr =
        rewriter.getI32IntegerAttr(inputType.getDimension());

    mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp<
        mlir::concretelang::BConcrete::BatchedKeySwitchLweTensorOp>(
        bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(),
        bksOp.baseLogAttr(), inputDimAttr, outDimAttr);

    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bBatchedKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Lowers `Concrete::BootstrapLweOp` to `BConcrete::BootstrapLweTensorOp`,
/// adding the input LWE dimension and the output precision (`getP()` of the
/// result type) as i32 attributes and forwarding the other bootstrap
/// parameters.
struct LowerBootstrap : public mlir::OpRewritePattern<
                            mlir::concretelang::Concrete::BootstrapLweOp> {
  LowerBootstrap(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp bsOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    mlir::concretelang::Concrete::LweCiphertextType outType
        = bsOp.getType();
    // Input dimension = last extent of the converted tensor minus one.
    auto inputType =
        converter.convertType(bsOp.input_ciphertext().getType())
            .cast();
    auto inputDimension = inputType.getShape().back() - 1;
    mlir::IntegerAttr inputDimAttr = rewriter.getI32IntegerAttr(inputDimension);
    auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
    mlir::Operation *bBootstrapOp = rewriter.replaceOpWithNewOp<
        mlir::concretelang::BConcrete::BootstrapLweTensorOp>(
        bsOp, outType, bsOp.input_ciphertext(), bsOp.lookup_table(),
        inputDimAttr, bsOp.polySizeAttr(), bsOp.levelAttr(),
        bsOp.baseLogAttr(), bsOp.glweDimensionAttr(), outputPrecisionAttr);

    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Lowers `Concrete::BatchedBootstrapLweOp` to
/// `BConcrete::BatchedBootstrapLweTensorOp`. The input dimension and output
/// precision come from the element types of the unconverted tensors.
struct LowerBatchedBootstrap
    : public mlir::OpRewritePattern<
          mlir::concretelang::Concrete::BatchedBootstrapLweOp> {
  LowerBatchedBootstrap(::mlir::MLIRContext *context,
                        mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<
            mlir::concretelang::Concrete::BatchedBootstrapLweOp>(context,
                                                                 benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::Concrete::BatchedBootstrapLweOp bbsOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;

    mlir::concretelang::Concrete::LweCiphertextType outType =
        bbsOp.getType()
            .cast()
            .getElementType()
            .cast();
    auto inputType =
        bbsOp.input_ciphertexts()
            .getType()
            .cast()
            .getElementType()
            .cast();

    auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
    auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());

    mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp<
        mlir::concretelang::BConcrete::BatchedBootstrapLweTensorOp>(
        bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(),
        inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(),
        bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr);

    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bBatchedBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Lowers `Concrete::AddPlaintextLweCiphertextOp` to its BConcrete
/// counterpart, built directly with the converted result type, the `lhs`/`rhs`
/// operands and the original attributes.
struct AddPlaintextLweCiphertextOpPattern
    : public mlir::OpRewritePattern {
  AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context,
                                     mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    mlir::concretelang::Concrete::LweCiphertextType resultTy =
        ((mlir::Type)concreteOp->getResult(0).getType())
            .cast();
    auto newResultTy =
        converter.convertType(resultTy).cast();

    llvm::ArrayRef<::mlir::NamedAttribute> attributes =
        concreteOp.getOperation()->getAttrs();

    mlir::Operation *bConcreteOp;
    bConcreteOp =
        rewriter.replaceOpWithNewOp(
            concreteOp, newResultTy,
            mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes);

    // Operands still carry Concrete types; convert them.
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Lowers `Concrete::MulCleartextLweCiphertextOp` to its BConcrete
/// counterpart, built directly with the converted result type, the `lhs`/`rhs`
/// operands and the original attributes.
struct MulCleartextLweCiphertextOpPattern
    : public mlir::OpRewritePattern {
  MulCleartextLweCiphertextOpPattern(::mlir::MLIRContext *context,
                                     mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    mlir::concretelang::Concrete::LweCiphertextType resultTy =
        ((mlir::Type)concreteOp->getResult(0).getType())
            .cast();
    auto newResultTy =
        converter.convertType(resultTy).cast();

    llvm::ArrayRef<::mlir::NamedAttribute> attributes =
        concreteOp.getOperation()->getAttrs();

    mlir::Operation *bConcreteOp;
    bConcreteOp =
        rewriter.replaceOpWithNewOp(
            concreteOp, newResultTy,
            mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes);

    // Operands still carry Concrete types; convert them.
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

/// Rewrites a `tensor.extract_slice` so it slices the converted i64 tensor:
/// the original static offsets/sizes/strides are kept and extended with
/// offset 0, size lweDimension+1 (last extent of the converted result) and
/// stride 1 for the trailing ciphertext dimension.
struct ExtractSliceOpPattern
    : public mlir::OpRewritePattern {
  ExtractSliceOpPattern(::mlir::MLIRContext *context,
                        mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto resultTy = extractSliceOp.result().getType();
    auto newResultTy =
        converter.convertType(resultTy).cast();

    // add 0 to the static_offsets
    mlir::SmallVector staticOffsets;
    staticOffsets.append(extractSliceOp.static_offsets().begin(),
                         extractSliceOp.static_offsets().end());
    staticOffsets.push_back(rewriter.getI64IntegerAttr(0));

    // add the lweSize to the sizes
    mlir::SmallVector staticSizes;
    staticSizes.append(extractSliceOp.static_sizes().begin(),
                       extractSliceOp.static_sizes().end());
    staticSizes.push_back(rewriter.getI64IntegerAttr(
        newResultTy.getDimSize(newResultTy.getRank() - 1)));

    // add 1 to the strides
    mlir::SmallVector staticStrides;
    staticStrides.append(extractSliceOp.static_strides().begin(),
                         extractSliceOp.static_strides().end());
    staticStrides.push_back(rewriter.getI64IntegerAttr(1));

    // replace tensor.extract_slice with the new one; the dynamic
    // offsets/sizes/strides operands are forwarded unchanged
    mlir::tensor::ExtractSliceOp extractOp =
        rewriter.replaceOpWithNewOp(
            extractSliceOp, newResultTy, extractSliceOp.source(),
            extractSliceOp.offsets(), extractSliceOp.sizes(),
            extractSliceOp.strides(), rewriter.getArrayAttr(staticOffsets),
            rewriter.getArrayAttr(staticSizes),
            rewriter.getArrayAttr(staticStrides));

    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, extractOp, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    return ::mlir::success();
  };
};

// TODO: since there is
a bug on lowering extract_slice with rank reduction we // add a linalg.tensor_collapse_shape after the extract_slice without rank // reduction. See // https://github.com/zama-ai/concrete-compiler-internal/issues/396. struct ExtractOpPattern : public mlir::OpRewritePattern { ExtractOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractOp extractOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto lweResultTy = extractOp.result() .getType() .dyn_cast_or_null< mlir::concretelang::Concrete::LweCiphertextType>(); if (lweResultTy == nullptr) { return mlir::failure(); } auto newResultTy = converter.convertType(lweResultTy).cast(); auto rankOfResult = extractOp.indices().size() + 1; // [min..., 0] for static_offsets () mlir::SmallVector staticOffsets( rankOfResult, rewriter.getI64IntegerAttr(std::numeric_limits::min())); staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0); // [1..., lweDimension+1] for static_sizes or // [1..., nbBlock, lweDimension+1] mlir::SmallVector staticSizes( rankOfResult, rewriter.getI64IntegerAttr(1)); staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1)); // [1...] 
for static_strides mlir::SmallVector staticStrides( rankOfResult, rewriter.getI64IntegerAttr(1)); // replace tensor.extract_slice to the new one mlir::SmallVector extractedSliceShape(rankOfResult, 1); extractedSliceShape[extractedSliceShape.size() - 1] = newResultTy.getDimSize(0); auto extractedSliceType = mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type()); auto extractedSlice = rewriter.create( extractOp.getLoc(), extractedSliceType, extractOp.tensor(), extractOp.indices(), mlir::SmallVector{}, mlir::SmallVector{}, rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); mlir::concretelang::convertOperandAndResultTypes( rewriter, extractedSlice, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); mlir::ReassociationIndices reassociation; for (int64_t i = 0; i < extractedSliceType.getRank(); i++) { reassociation.push_back(i); } mlir::SmallVector reassocs{reassociation}; mlir::tensor::CollapseShapeOp collapseOp = rewriter.replaceOpWithNewOp( extractOp, newResultTy, extractedSlice, reassocs); mlir::concretelang::convertOperandAndResultTypes( rewriter, collapseOp, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); return ::mlir::success(); }; }; struct InsertSliceOpPattern : public mlir::OpRewritePattern { InsertSliceOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = insertSliceOp.result().getType(); auto lweResultTy = resultTy.cast() .getElementType() .cast(); if (lweResultTy == nullptr) { return mlir::failure(); } auto newResultTy = converter.convertType(resultTy).cast(); // add 0 to static_offsets mlir::SmallVector staticOffsets; 
staticOffsets.append(insertSliceOp.static_offsets().begin(), insertSliceOp.static_offsets().end()); staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); // add lweDimension+1 to static_sizes mlir::SmallVector staticSizes; staticSizes.append(insertSliceOp.static_sizes().begin(), insertSliceOp.static_sizes().end()); staticSizes.push_back(rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1))); // add 1 to the strides mlir::SmallVector staticStrides; staticStrides.append(insertSliceOp.static_strides().begin(), insertSliceOp.static_strides().end()); staticStrides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.insert_slice with the new one auto newOp = rewriter.replaceOpWithNewOp( insertSliceOp, newResultTy, insertSliceOp.source(), insertSliceOp.dest(), insertSliceOp.offsets(), insertSliceOp.sizes(), insertSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); mlir::concretelang::convertOperandAndResultTypes( rewriter, newOp, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); return ::mlir::success(); }; }; struct InsertOpPattern : public mlir::OpRewritePattern { InsertOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::InsertOp insertOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = insertOp.result().getType().dyn_cast_or_null(); auto lweResultTy = resultTy.getElementType() .dyn_cast_or_null(); if (lweResultTy == nullptr) { return mlir::failure(); }; mlir::RankedTensorType newResultTy = converter.convertType(resultTy).cast(); // add zeros to static_offsets mlir::SmallVector offsets; offsets.append(insertOp.indices().begin(), insertOp.indices().end()); offsets.push_back(rewriter.getIndexAttr(0)); // Inserting a smaller tensor into a 
(potentially) bigger one. Set // dimensions for all leading dimensions of the target tensor not // present in the source to 1. mlir::SmallVector sizes(insertOp.indices().size(), rewriter.getI64IntegerAttr(1)); // Add size for the bufferized source element sizes.push_back(rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1))); // Set stride of all dimensions to 1 mlir::SmallVector strides( newResultTy.getRank(), rewriter.getI64IntegerAttr(1)); // replace tensor.insert_slice with the new one mlir::tensor::InsertSliceOp insertSliceOp = rewriter.replaceOpWithNewOp( insertOp, insertOp.getOperand(0), insertOp.dest(), offsets, sizes, strides); mlir::concretelang::convertOperandAndResultTypes( rewriter, insertSliceOp, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); return ::mlir::success(); }; }; /// FromElementsOpPatterns transform each tensor.from_elements that operates on /// Concrete.lwe_ciphertext /// /// refs: check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir struct FromElementsOpPattern : public mlir::OpRewritePattern { FromElementsOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = fromElementsOp.result().getType(); if (converter.isLegal(resultTy)) { return mlir::failure(); } auto oldTensorResultTy = resultTy.cast(); auto oldRank = oldTensorResultTy.getRank(); auto newTensorResultTy = converter.convertType(resultTy).cast(); auto newRank = newTensorResultTy.getRank(); auto newShape = newTensorResultTy.getShape(); mlir::Value tensor = rewriter.create( fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{}); // sizes are [1, ..., 1, diffShape...] 
llvm::SmallVector sizes(oldRank, rewriter.getI64IntegerAttr(1)); for (auto i = newRank - oldRank; i > 0; i--) { sizes.push_back(rewriter.getI64IntegerAttr(*(newShape.end() - i))); } // strides are [1, ..., 1] llvm::SmallVector oneStrides( newShape.size(), rewriter.getI64IntegerAttr(1)); // start with offets [0, ..., 0] llvm::SmallVector currentOffsets(newRank, 0); // for each elements insert_slice with right offet for (auto elt : llvm::enumerate(fromElementsOp.elements())) { // Just create offsets as attributes llvm::SmallVector offsets; offsets.reserve(currentOffsets.size()); std::transform(currentOffsets.begin(), currentOffsets.end(), std::back_inserter(offsets), [&](auto v) { return rewriter.getI64IntegerAttr(v); }); mlir::tensor::InsertSliceOp insOp = rewriter.create( fromElementsOp.getLoc(), /* src: */ elt.value(), /* dst: */ tensor, /* offs: */ offsets, /* sizes: */ sizes, /* strides: */ oneStrides); mlir::concretelang::convertOperandAndResultTypes( rewriter, insOp, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); tensor = insOp.getResult(); // Increment the offsets for (auto i = newRank - 2; i >= 0; i--) { if (currentOffsets[i] == newShape[i] - 1) { currentOffsets[i] = 0; continue; } currentOffsets[i]++; break; } } rewriter.replaceOp(fromElementsOp, tensor); return ::mlir::success(); }; }; // This template rewrite pattern transforms any instance of // `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the // lwe size as a size of the tensor result and by adding a trivial // reassociation at the end of the reassociations map. // // Example: // // ```mlir // %0 = "ShapeOp" %arg0 [reassocations...] 
// : tensor<...x!Concrete.lwe_ciphertext> into // tensor<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] // : tensor<...xlweDimesion+1xi64> into // tensor<...xlweDimesion+1xi64> // ``` template struct TensorShapeOpPattern : public mlir::OpRewritePattern { TensorShapeOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(ShapeOp shapeOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast(); auto newResultTy = ((mlir::Type)converter.convertType(resultTy)).cast(); auto reassocTy = ((mlir::Type)converter.convertType( (inRank ? shapeOp.src() : shapeOp.result()).getType())) .cast(); auto oldReassocs = shapeOp.getReassociationIndices(); mlir::SmallVector newReassocs; newReassocs.append(oldReassocs.begin(), oldReassocs.end()); // add [rank] to reassociations { mlir::ReassociationIndices lweAssoc; lweAssoc.push_back(reassocTy.getRank() - 1); newReassocs.push_back(lweAssoc); } ShapeOp op = rewriter.replaceOpWithNewOp( shapeOp, newResultTy, shapeOp.src(), newReassocs); // fix operand types mlir::concretelang::convertOperandAndResultTypes( rewriter, op, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); return ::mlir::success(); }; }; /// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp` /// to the patterns set and populate the conversion target. 
template void insertTensorShapeOpPattern(mlir::MLIRContext &context, mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target) { patterns.insert>(&context); target.addDynamicallyLegalOp([&](mlir::Operation *op) { ConcreteToBConcreteTypeConverter converter; return converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getOperandTypes()); }); } struct AllocTensorOpPattern : public mlir::OpRewritePattern { AllocTensorOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::bufferization::AllocTensorOp allocTensorOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; mlir::RankedTensorType resultTy = allocTensorOp.getType().dyn_cast(); if (!resultTy || !resultTy.hasStaticShape()) return mlir::failure(); mlir::RankedTensorType newResultTy = converter.convertType(resultTy).dyn_cast(); if (resultTy.getShape().size() != newResultTy.getShape().size()) { rewriter.replaceOpWithNewOp( allocTensorOp, newResultTy, mlir::ValueRange{}); } return ::mlir::success(); }; }; struct ForOpPattern : public mlir::OpRewritePattern { ForOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::scf::ForOp forOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; // TODO: Check if there is a cleaner way to modify the types in // place through appropriate interfaces or by reconstructing the // ForOp with the right types. 
rewriter.updateRootInPlace(forOp, [&] { for (mlir::Value initArg : forOp.getInitArgs()) { mlir::Type convertedType = converter.convertType(initArg.getType()); initArg.setType(convertedType); } for (mlir::Value &blockArg : forOp.getBody()->getArguments()) { mlir::Type convertedType = converter.convertType(blockArg.getType()); blockArg.setType(convertedType); } for (mlir::OpResult result : forOp.getResults()) { mlir::Type convertedType = converter.convertType(result.getType()); result.setType(convertedType); } }); return ::mlir::success(); }; }; void ConcreteToBConcretePass::runOnOperation() { auto op = this->getOperation(); // Then convert ciphertext to tensor or add a dimension to tensor of // ciphertext and memref of ciphertext { mlir::ConversionTarget target(getContext()); ConcreteToBConcreteTypeConverter converter; mlir::RewritePatternSet patterns(&getContext()); // All BConcrete ops are legal after the conversion target.addLegalDialect(); // Add Concrete ops are illegal after the conversion target.addIllegalDialect(); target.addLegalDialect(); // Add patterns to convert the zero ops to tensor.generate patterns .insert, ZeroOpPattern>( &getContext()); target.addLegalOp(); // Add patterns to trivialy convert Concrete op to the equivalent // BConcrete op patterns.insert< LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch, LowerBatchedKeySwitch, LowToBConcrete, AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern, LowToBConcrete< mlir::concretelang::Concrete::EncodeExpandLutForBootstrapOp, mlir::concretelang::BConcrete::EncodeExpandLutForBootstrapTensorOp>, LowToBConcrete< mlir::concretelang::Concrete::EncodeExpandLutForWopPBSOp, mlir::concretelang::BConcrete::EncodeExpandLutForWopPBSTensorOp>, LowToBConcrete< mlir::concretelang::Concrete::EncodePlaintextWithCrtOp, mlir::concretelang::BConcrete::EncodePlaintextWithCrtTensorOp>, LowToBConcrete, LowToBConcrete>( &getContext()); // Add patterns to rewrite tensor operators that works on encrypted 
// tensors patterns .insert(&getContext()); target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getOperandTypes()); }); patterns.insert(&getContext()); target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return converter.isLegal(op->getResult(0).getType()); }); target.addLegalOp(); patterns.insert(&getContext()); // Add patterns to rewrite some of memref ops that was introduced by the // linalg bufferization of encrypted tensor (first conversion of this // pass) insertTensorShapeOpPattern(getContext(), patterns, target); insertTensorShapeOpPattern(getContext(), patterns, target); insertTensorShapeOpPattern(getContext(), patterns, target); insertTensorShapeOpPattern(getContext(), patterns, target); target.addDynamicallyLegalOp< mlir::arith::ConstantOp, mlir::scf::ForOp, mlir::scf::ParallelOp, mlir::scf::YieldOp, mlir::AffineApplyOp, mlir::memref::SubViewOp, mlir::memref::LoadOp, mlir::memref::TensorStoreOp>( [&](mlir::Operation *op) { return converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getOperandTypes()); }); // Add patterns to do the conversion of func mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); target.addDynamicallyLegalOp( [&](mlir::func::FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); target.addDynamicallyLegalOp( [&](mlir::func::ConstantOp op) { return FunctionConstantOpConversion< ConcreteToBConcreteTypeConverter>::isLegal(op, converter); }); patterns .insert>( &getContext(), converter); target.addDynamicallyLegalOp([&](mlir::scf::ForOp forOp) { return converter.isLegal(forOp.getInitArgs().getTypes()) && converter.isLegal(forOp.getResults().getTypes()); }); // Add pattern for return op target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getOperandTypes()); }); // Conversion of 
RT Dialect Ops patterns.add< mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::MakeReadyFutureOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::AwaitFutureOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::CreateAsyncTaskOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::WorkFunctionReturnOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::AwaitFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { 
this->signalPassFailure(); } } } namespace mlir { namespace concretelang { std::unique_ptr> createConvertConcreteToBConcretePass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir