// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #include #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/IR/RTOps.h" namespace { struct ConcreteToBConcretePass : public ConcreteToBConcreteBase { void runOnOperation() final; }; } // namespace /// ConcreteToBConcreteTypeConverter is a TypeConverter that transform /// `Concrete.lwe_ciphertext` to `tensor>` /// `tensor<...xConcrete.lwe_ciphertext>` to /// `tensor<...xdimension+1, i64>>` class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter { public: ConcreteToBConcreteTypeConverter() { addConversion([](mlir::Type type) { return type; }); addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) { assert(type.getDimension() != -1); return mlir::RankedTensorType::get( {type.getDimension() + 1}, mlir::IntegerType::get(type.getContext(), 64)); }); addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) { assert(type.getGlweDimension() != -1); assert(type.getPolynomialSize() != -1); return mlir::RankedTensorType::get( {type.getPolynomialSize() * (type.getGlweDimension() + 1)}, mlir::IntegerType::get(type.getContext(), 64)); }); addConversion([&](mlir::RankedTensorType type) { auto lwe = 
type.getElementType() .dyn_cast_or_null< mlir::concretelang::Concrete::LweCiphertextType>(); if (lwe == nullptr) { return (mlir::Type)(type); } assert(lwe.getDimension() != -1); mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); newShape.append(type.getShape().begin(), type.getShape().end()); newShape.push_back(lwe.getDimension() + 1); mlir::Type r = mlir::RankedTensorType::get( newShape, mlir::IntegerType::get(type.getContext(), 64)); return r; }); addConversion([&](mlir::MemRefType type) { auto lwe = type.getElementType() .dyn_cast_or_null< mlir::concretelang::Concrete::LweCiphertextType>(); if (lwe == nullptr) { return (mlir::Type)(type); } assert(lwe.getDimension() != -1); mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); newShape.append(type.getShape().begin(), type.getShape().end()); newShape.push_back(lwe.getDimension() + 1); mlir::Type r = mlir::MemRefType::get( newShape, mlir::IntegerType::get(type.getContext(), 64)); return r; }); } }; // This rewrite pattern transforms any instance of `Concrete.zero_tensor` // operators. // // Example: // // ```mlir // %0 = "Concrete.zero_tensor" () : // tensor<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %0 = tensor.generate { // ^bb0(... 
: index): // %c0 = arith.constant 0 : i64 // tensor.yield %z // }: tensor<...xlweDim+1xi64> // i64> // ``` template struct ZeroOpPattern : public mlir::OpRewritePattern { ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(ZeroOp zeroOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = zeroOp.getType(); auto newResultTy = converter.convertType(resultTy); auto generateBody = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { // %c0 = 0 : i64 auto cstOp = nestedBuilder.create( nestedLoc, nestedBuilder.getI64IntegerAttr(1)); // tensor.yield %z : !FHE.eint

nestedBuilder.create(nestedLoc, cstOp.getResult()); }; // tensor.generate rewriter.replaceOpWithNewOp( zeroOp, newResultTy, mlir::ValueRange{}, generateBody); return ::mlir::success(); }; }; // This template rewrite pattern transforms any instance of // `ConcreteOp` to an instance of `BConcreteOp`. // // Example: // // %0 = "ConcreteOp"(%arg0, ...) : // (!Concrete.lwe_ciphertext, ...) -> // (!Concrete.lwe_ciphertext) // // becomes: // // %0 = linalg.init_tensor [dimension+1] : tensor // "BConcreteOp"(%0, %arg0, ...) : (tensor>, // tensor>, ..., ) -> () // // A reference to the preallocated output is always passed as the first // argument. template struct LowToBConcrete : public mlir::OpRewritePattern { LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(ConcreteOp concreteOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; mlir::concretelang::Concrete::LweCiphertextType resultTy = ((mlir::Type)concreteOp->getResult(0).getType()) .cast(); auto newResultTy = converter.convertType(resultTy).cast(); // %0 = linalg.init_tensor [dimension+1] : tensor mlir::Value init = rewriter.replaceOpWithNewOp( concreteOp, newResultTy.getShape(), newResultTy.getElementType()); // "BConcreteOp"(%0, %arg0, ...) : (tensor>, // tensor>, ..., ) -> () mlir::SmallVector newOperands{init}; newOperands.append(concreteOp.getOperation()->getOperands().begin(), concreteOp.getOperation()->getOperands().end()); llvm::ArrayRef<::mlir::NamedAttribute> attributes = concreteOp.getOperation()->getAttrs(); rewriter.create(concreteOp.getLoc(), mlir::SmallVector{}, newOperands, attributes); return ::mlir::success(); }; }; // This rewrite pattern transforms any instance of // `Concrete.glwe_from_table` operators. 
// // Example: // // ```mlir // %0 = "Concrete.glwe_from_table"(%tlu) // : (tensor<$Dxi64>) -> // !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p> // ``` // // with $D = 2^$p // // becomes: // // ```mlir // %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)] // : tensor // "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu) // : tensor, i64, i64, tensor<$Dxi64> // ``` struct GlweFromTablePattern : public mlir::OpRewritePattern< mlir::concretelang::Concrete::GlweFromTable> { GlweFromTablePattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern( context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = op.result() .getType() .cast(); auto newResultTy = converter.convertType(resultTy).cast(); // %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)] // : tensor mlir::Value init = rewriter.replaceOpWithNewOp( op, newResultTy.getShape(), newResultTy.getElementType()); // "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, // %tlu) // polynomialSize*(glweDimension+1) auto polySize = resultTy.getPolynomialSize(); auto glweDimension = resultTy.getGlweDimension(); auto outPrecision = resultTy.getP(); rewriter.create( op.getLoc(), init, polySize, glweDimension, outPrecision, op.table()); return ::mlir::success(); }; }; // This rewrite pattern transforms any instance of // `tensor.extract_slice` operators that operates on tensor of lwe ciphertext. // // Example: // // ```mlir // %0 = tensor.extract_slice %arg0 // [offsets...] [sizes...] [strides...] 
// : tensor<...x!Concrete.lwe_ciphertext> to // tensor<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %0 = tensor.extract_slice %arg0 // [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] // : tensor<...xlweDimension+1,i64> to // tensor<...xlweDimension+1,i64> // ``` struct ExtractSliceOpPattern : public mlir::OpRewritePattern { ExtractSliceOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = extractSliceOp.result().getType(); auto resultEltTy = resultTy.cast() .getElementType() .cast(); auto newResultTy = converter.convertType(resultTy); // add 0 to the static_offsets mlir::SmallVector staticOffsets; staticOffsets.append(extractSliceOp.static_offsets().begin(), extractSliceOp.static_offsets().end()); staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); // add the lweSize to the sizes mlir::SmallVector staticSizes; staticSizes.append(extractSliceOp.static_sizes().begin(), extractSliceOp.static_sizes().end()); staticSizes.push_back( rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1)); // add 1 to the strides mlir::SmallVector staticStrides; staticStrides.append(extractSliceOp.static_strides().begin(), extractSliceOp.static_strides().end()); staticStrides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.extract_slice to the new one rewriter.replaceOpWithNewOp( extractSliceOp, newResultTy, extractSliceOp.source(), extractSliceOp.offsets(), extractSliceOp.sizes(), extractSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); return ::mlir::success(); }; }; // This rewrite pattern transforms any instance of // `tensor.extract` operators that operates on tensor of lwe ciphertext. 
// // Example: // // ```mlir // %0 = tensor.extract %t[offsets...] // : tensor<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %1 = tensor.extract_slice %arg0 // [offsets...] [1..., lweDimension+1] [1...] // : tensor<...xlweDimension+1,i64> to // tensor<1...xlweDimension+1,i64> // %0 = linalg.tensor_collapse_shape %0 [[...]] : // tensor<1x1xlweDimension+1xi64> into tensor // ``` // // TODO: since they are a bug on lowering extract_slice with rank reduction we // add a linalg.tensor_collapse_shape after the extract_slice without rank // reduction. See // https://github.com/zama-ai/concrete-compiler-internal/issues/396. struct ExtractOpPattern : public mlir::OpRewritePattern { ExtractOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractOp extractOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto lweResultTy = extractOp.result() .getType() .dyn_cast_or_null< mlir::concretelang::Concrete::LweCiphertextType>(); if (lweResultTy == nullptr) { return mlir::failure(); } auto newResultTy = converter.convertType(lweResultTy).cast(); auto rankOfResult = extractOp.indices().size() + 1; // [min..., 0] for static_offsets () mlir::SmallVector staticOffsets( rankOfResult, rewriter.getI64IntegerAttr(std::numeric_limits::min())); staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0); // [1..., lweDimension+1] for static_sizes mlir::SmallVector staticSizes( rankOfResult, rewriter.getI64IntegerAttr(1)); staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1)); // [1...] 
for static_strides mlir::SmallVector staticStrides( rankOfResult, rewriter.getI64IntegerAttr(1)); // replace tensor.extract_slice to the new one mlir::SmallVector extractedSliceShape( extractOp.indices().size() + 1, 0); extractedSliceShape.reserve(extractOp.indices().size() + 1); for (size_t i = 0; i < extractedSliceShape.size() - 1; i++) { extractedSliceShape[i] = 1; } extractedSliceShape[extractedSliceShape.size() - 1] = newResultTy.getDimSize(0); auto extractedSliceType = mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type()); auto extractedSlice = rewriter.create( extractOp.getLoc(), extractedSliceType, extractOp.tensor(), extractOp.indices(), mlir::SmallVector{}, mlir::SmallVector{}, rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); mlir::ReassociationIndices reassociation; for (int64_t i = 0; i < extractedSliceType.getRank(); i++) { reassociation.push_back(i); } rewriter.replaceOpWithNewOp( extractOp, newResultTy, extractedSlice, mlir::SmallVector{reassociation}); return ::mlir::success(); }; }; // This rewrite pattern transforms any instance of // `tensor.insert_slice` operators that operates on tensor of lwe ciphertext. // // Example: // // ```mlir // %0 = tensor.insert_slice %arg1 // into %arg0[offsets...] [sizes...] [strides...] 
// : tensor<...x!Concrete.lwe_ciphertext> into // tensor<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %0 = tensor.insert_slice %arg1 // into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] // : tensor<...xlweDimension+1xi64> into // tensor<...xlweDimension+1xi64> // ``` struct InsertSliceOpPattern : public mlir::OpRewritePattern { InsertSliceOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = insertSliceOp.result().getType(); auto newResultTy = converter.convertType(resultTy).cast(); // add 0 to static_offsets mlir::SmallVector staticOffsets; staticOffsets.append(insertSliceOp.static_offsets().begin(), insertSliceOp.static_offsets().end()); staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); // add lweDimension+1 to static_sizes mlir::SmallVector staticSizes; staticSizes.append(insertSliceOp.static_sizes().begin(), insertSliceOp.static_sizes().end()); staticSizes.push_back(rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1))); // add 1 to the strides mlir::SmallVector staticStrides; staticStrides.append(insertSliceOp.static_strides().begin(), insertSliceOp.static_strides().end()); staticStrides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.insert_slice with the new one rewriter.replaceOpWithNewOp( insertSliceOp, newResultTy, insertSliceOp.source(), insertSliceOp.dest(), insertSliceOp.offsets(), insertSliceOp.sizes(), insertSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); return ::mlir::success(); }; }; // This rewrite pattern transforms any instance of // `tensor.from_elements` operators that operates on tensor of lwe ciphertext. 
// // Example: // // ```mlir // %0 = tensor.from_elements %e0, ..., %e(n-1) // : tensor> // ``` // // becomes: // // ```mlir // %m = memref.alloc() : memref // %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref // %m0 = memref.buffer_cast %e0 : memref // memref.copy %m0, s0 : memref to memref // ... // %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1] // : memref // %m(n-1) = memref.buffer_cast %e(n-1) : memref // memref.copy %e(n-1), s(n-1) // : memref to memref // %0 = memref.tensor_load %m : memref // ``` struct FromElementsOpPattern : public mlir::OpRewritePattern { FromElementsOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = fromElementsOp.result().getType(); if (converter.isLegal(resultTy)) { return mlir::failure(); } auto eltResultTy = resultTy.cast() .getElementType() .cast(); auto newTensorResultTy = converter.convertType(resultTy).cast(); auto newMemrefResultTy = mlir::MemRefType::get( newTensorResultTy.getShape(), newTensorResultTy.getElementType()); // %m = memref.alloc() : memref auto mOp = rewriter.create(fromElementsOp.getLoc(), newMemrefResultTy); // for i = 0 to n-1 // %si = memref.subview %m[i, 0][1, lweDim+1][1, 1] : memref // %mi = memref.buffer_cast %ei : memref // memref.copy %mi, si : memref to memref auto subviewResultTy = mlir::MemRefType::get( {eltResultTy.getDimension() + 1}, newMemrefResultTy.getElementType()); auto offset = 0; for (auto eiOp : fromElementsOp.elements()) { mlir::SmallVector staticOffsets{ rewriter.getI64IntegerAttr(offset), rewriter.getI64IntegerAttr(0)}; mlir::SmallVector staticSizes{ rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(eltResultTy.getDimension() + 1)}; mlir::SmallVector staticStrides{ rewriter.getI64IntegerAttr(1), 
rewriter.getI64IntegerAttr(1)}; auto siOp = rewriter.create( fromElementsOp.getLoc(), subviewResultTy, mOp, mlir::ValueRange{}, mlir::ValueRange{}, mlir::ValueRange{}, rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); auto miOp = rewriter.create( fromElementsOp.getLoc(), subviewResultTy, eiOp); rewriter.create(fromElementsOp.getLoc(), miOp, siOp); offset++; } // Go back to tensor world // %0 = memref.tensor_load %m : memref rewriter.replaceOpWithNewOp(fromElementsOp, mOp); return ::mlir::success(); }; }; // This template rewrite pattern transforms any instance of // `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the // lwe size as a size of the tensor result and by adding a trivial reassociation // at the end of the reassociations map. // // Example: // // ```mlir // %0 = "ShapeOp" %arg0 [reassocations...] // : tensor<...x!Concrete.lwe_ciphertext> into // tensor<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] // : tensor<...xlweDimesion+1xi64> into // tensor<...xlweDimesion+1xi64> // ``` template struct TensorShapeOpPattern : public mlir::OpRewritePattern { TensorShapeOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(ShapeOp shapeOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = shapeOp.result().getType(); auto newResultTy = ((mlir::Type)converter.convertType(resultTy)).cast(); // add [rank] to reassociations auto oldReassocs = shapeOp.getReassociationIndices(); mlir::SmallVector newReassocs; newReassocs.append(oldReassocs.begin(), oldReassocs.end()); mlir::ReassociationIndices lweAssoc; auto reassocTy = ((mlir::Type)converter.convertType( (inRank ? 
shapeOp.src() : shapeOp.result()).getType())) .cast(); lweAssoc.push_back(reassocTy.getRank() - 1); newReassocs.push_back(lweAssoc); rewriter.replaceOpWithNewOp(shapeOp, newResultTy, shapeOp.src(), newReassocs); return ::mlir::success(); }; }; // Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp` // to the patterns set and populate the conversion target. template void insertTensorShapeOpPattern(mlir::MLIRContext &context, mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target) { patterns.insert>(&context); target.addDynamicallyLegalOp([&](ShapeOp op) { ConcreteToBConcreteTypeConverter converter; return converter.isLegal(op.result().getType()); }); } // This template rewrite pattern transforms any instance of // `MemrefOp` operators that returns a memref of lwe ciphertext to the same // operator but which returns the bufferized lwe ciphertext. // // Example: // // ```mlir // %0 = "MemrefOp"(...) : ... -> memref<...x!Concrete.lwe_ciphertext> // ``` // // becomes: // // ```mlir // %0 = "MemrefOp"(...) : ... 
-> memref<...xlweDim+1xi64> // ``` template struct MemrefOpPattern : public mlir::OpRewritePattern { MemrefOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : mlir::OpRewritePattern(context, benefit) {} mlir::LogicalResult matchAndRewrite(MemrefOp memrefOp, mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; mlir::SmallVector convertedTypes; if (converter.convertTypes(memrefOp->getResultTypes(), convertedTypes) .failed()) { return mlir::failure(); } rewriter.replaceOpWithNewOp(memrefOp, convertedTypes, memrefOp->getOperands(), memrefOp->getAttrs()); return ::mlir::success(); }; }; template void insertMemrefOpPatternImpl(mlir::MLIRContext &context, mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target) { patterns.insert>(&context); target.addDynamicallyLegalOp([&](MemrefOp op) { ConcreteToBConcreteTypeConverter converter; return converter.isLegal(op->getResultTypes()); }); } // Add the instantiated MemrefOpPattern rewrite pattern with the `MemrefOp` // to the patterns set and populate the conversion target. 
// NOTE(review): in this copy of the file every `<...>` template
// parameter/argument list appears to have been stripped (e.g. `template void`,
// `std::initializer_list{`, `builder.create(`, `patterns.insert>(`); the code
// below cannot compile until they are restored — TODO confirm against the
// upstream file.
template void insertMemrefOpPattern(mlir::MLIRContext &context,
                                    mlir::RewritePatternSet &patterns,
                                    mlir::ConversionTarget &target) {
  // Pack expansion: the braced initializer forces one call to
  // insertMemrefOpPatternImpl per element of the (presumably variadic)
  // template parameter pack, evaluated left to right.
  (void)std::initializer_list{
      0, (insertMemrefOpPatternImpl(context, patterns, target), 0)...};
}

// cc from Loops.cpp
// For each result expression of `map`, builds a single-result map, simplifies
// it together with its operands `vals` via canonicalizeMapAndOperands, and
// materializes it with `b` (presumably as an affine.apply — the created op
// type is not visible here). Returns one value per map result, or an empty
// vector when `map` is empty.
static mlir::SmallVector
makeCanonicalAffineApplies(mlir::OpBuilder &b, mlir::Location loc,
                           mlir::AffineMap map, mlir::ArrayRef vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  mlir::SmallVector res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = mlir::AffineMap::get(dims, map.getNumSymbols(), e);
    mlir::SmallVector operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create(loc, exprMap, operands));
  }
  return res;
}

// Materializes, inside a loop nest, the element of `operand` addressed by the
// induction variables `allIvs` mapped through the operand's indexing map.
//  - Memref of LWE ciphertexts: creates a subview selecting one ciphertext
//    (all offsets dynamic, sizes and strides 1) typed as a memref with the
//    shape/element type of the converted ciphertext, and returns
//    {subview, value built from the subview}.
//  - Any other memref: returns {nullptr, value loaded at the computed
//    indices}.
// The exact op types created by `builder.create` are not visible here
// (presumably memref.subview / memref.tensor_load / memref.load).
static std::pair
makeOperandLoadOrSubview(mlir::OpBuilder &builder, mlir::Location loc,
                         mlir::ArrayRef allIvs,
                         mlir::linalg::LinalgOp linalgOp,
                         mlir::OpOperand *operand) {
  ConcreteToBConcreteTypeConverter converter;

  mlir::Value opVal = operand->get();
  mlir::MemRefType opTy = opVal.getType().cast();

  if (auto lweType =
          opTy.getElementType()
              .dyn_cast_or_null<
                  mlir::concretelang::Concrete::LweCiphertextType>()) {
    // For memref of ciphertexts operands create the inner memref
    // subview to the ciphertext, and go back to the tensor type as BConcrete
    // operators work with tensors.
    // %op : memref<...x!Concrete.lwe_ciphertext<lweDim,p>>
    // %opInner = memref.subview %op[offsets...][1...][1,...]
    //   : memref<...x!Concrete.lwe_ciphertext<lweDim,p>> to
    //     memref<lweDim+1xi64>
    auto tensorizedLweTy = converter.convertType(lweType).cast();
    auto subviewResultTy = mlir::MemRefType::get(
        tensorizedLweTy.getShape(), tensorizedLweTy.getElementType());
    auto offsets = makeCanonicalAffineApplies(
        builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs);
    // Every offset is given by `offsets` (INT64_MIN is used as the dynamic
    // marker — presumably ShapedType::kDynamicStrideOrOffset, TODO confirm);
    // sizes and strides are all 1.
    mlir::SmallVector staticOffsets(
        opTy.getRank(),
        builder.getI64IntegerAttr(std::numeric_limits::min()));
    mlir::SmallVector staticSizes(opTy.getRank(),
                                  builder.getI64IntegerAttr(1));
    mlir::SmallVector staticStrides(opTy.getRank(),
                                    builder.getI64IntegerAttr(1));

    auto subViewOp = builder.create(
        loc, subviewResultTy, opVal, offsets, mlir::ValueRange{},
        mlir::ValueRange{}, builder.getArrayAttr(staticOffsets),
        builder.getArrayAttr(staticSizes), builder.getArrayAttr(staticStrides));
    return std::pair(subViewOp, builder.create(loc, subViewOp));
  } else {
    // For memref of non ciphertexts load the value from the memref.
    // with %op : memref<...xT>
    // %opInner = memref.load %op[offsets...] : memref<...xT>
    auto offsets = makeCanonicalAffineApplies(
        builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs);
    return std::pair(nullptr, builder.create(loc, operand->get(), offsets));
  }
}

// Clones the first block of `linalgOp`'s region 0 (without its terminator) at
// the builder's insertion point, replacing the block arguments by
// `indexedValues`. Then, for each terminator operand, emits a store (presumably
// memref.tensor_store) of the mapped value into the output buffer with the same
// operand index.
static void
inlineRegionAndEmitTensorStore(mlir::OpBuilder &builder, mlir::Location loc,
                               mlir::linalg::LinalgOp linalgOp,
                               llvm::ArrayRef indexedValues,
                               mlir::ValueRange outputBuffers) {
  // Clone the block with the new operands
  auto &block = linalgOp->getRegion(0).front();
  mlir::BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = builder.clone(op, map);
    // NOTE(review): cloning with a mapping should already record the result
    // mapping; this explicit re-mapping looks redundant — TODO confirm.
    map.map(op.getResults(), newOp->getResults());
  }
  // Create memref.tensor_store operation for each terminator operands
  auto *terminator = block.getTerminator();
  for (mlir::OpOperand &operand : terminator->getOpOperands()) {
    mlir::Value toStore = map.lookupOrDefault(operand.get());
    builder.create(loc, toStore, outputBuffers[operand.getOperandNumber()]);
  }
}

// Conversion pattern lowering a LinalgOp with buffer semantics to an explicit
// loop nest built by linalg::GenerateLoopNest. In the innermost body, each
// input and output operand is accessed at the current iteration point
// (makeOperandLoadOrSubview), the op's region is inlined, and its yielded
// values are stored into the output subviews. The linalg op is then erased.
// The `operands` argument of matchAndRewrite is not used.
template class LinalgRewritePattern
    : public mlir::OpInterfaceConversionPattern {
public:
  using mlir::OpInterfaceConversionPattern<
      mlir::linalg::LinalgOp>::OpInterfaceConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::LinalgOp linalgOp, mlir::ArrayRef operands,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    assert(linalgOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");

    auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
    auto iteratorTypes =
        llvm::to_vector<4>(linalgOp.iterator_types().getValue());

    mlir::SmallVector allIvs;
    mlir::linalg::GenerateLoopNest::doit(
        rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange ivs,
            mlir::ValueRange operandValuesToUse) -> mlir::scf::ValueVector {
          // Keep indexed values to replace the linalg.generic block arguments
          // by them
          mlir::SmallVector indexedValues;
          indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

          assert(
              operandValuesToUse == linalgOp->getOperands() &&
              "expect operands are captured and not passed by loop argument");
          allIvs.append(ivs.begin(), ivs.end());

          // For all input operands create the inner operand
          for (mlir::OpOperand *inputOperand : linalgOp.getInputOperands()) {
            auto innerOperand = makeOperandLoadOrSubview(
                builder, loc, allIvs, linalgOp, inputOperand);
            indexedValues.push_back(innerOperand.second);
          }

          // For all output operands create the inner operand; outputs must be
          // memrefs of ciphertexts so that a subview is available to store
          // into.
          assert(linalgOp.getOutputOperands() ==
                     linalgOp.getOutputBufferOperands() &&
                 "expect only memref as output operands");
          mlir::SmallVector outputBuffers;
          for (mlir::OpOperand *outputOperand : linalgOp.getOutputOperands()) {
            auto innerOperand = makeOperandLoadOrSubview(
                builder, loc, allIvs, linalgOp, outputOperand);
            indexedValues.push_back(innerOperand.second);
            assert(innerOperand.first != nullptr &&
                   "Expected a memref subview as output buffer");
            outputBuffers.push_back(innerOperand.first);
          }

          // Finally inline the linalgOp region
          inlineRegionAndEmitTensorStore(builder, loc, linalgOp, indexedValues,
                                         outputBuffers);

          return mlir::scf::ValueVector{};
        });
    rewriter.eraseOp(linalgOp);
    return mlir::success();
  };
};

// Runs the pass in two partial conversions:
//  1. bufferizes linalg operations (linalg bufferize patterns with
//     mlir::BufferizeTypeConverter);
//  2. rewrites Concrete ciphertexts and tensors/memrefs of ciphertexts to
//     i64 tensors/memrefs with the ciphertext as trailing dimension, lowers
//     Concrete ops to BConcrete ops, lowers linalg ops to loop nests, and
//     converts function signatures and RT dataflow tasks.
// Signals pass failure if either conversion fails.
void ConcreteToBConcretePass::runOnOperation() {
  auto op = this->getOperation();

  // First of all we transform LinalgOp that work on tensor of ciphertext to
  // work on memref.
  {
    mlir::ConversionTarget target(getContext());
    mlir::BufferizeTypeConverter converter;

    // Mark all Standard operations legal.
    target.addLegalDialect();

    // Mark all Linalg operations illegal as long as they work on encrypted
    // tensors.
    // NOTE(review): legality is decided by mlir::BufferizeTypeConverter, which
    // presumably rejects any tensor type, not only encrypted ones — TODO
    // confirm the "encrypted" qualifier above.
    target.addDynamicallyLegalOp(
        [&](mlir::Operation *op) { return converter.isLegal(op); });

    mlir::RewritePatternSet patterns(&getContext());
    mlir::linalg::populateLinalgBufferizePatterns(converter, patterns);
    if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Then convert ciphertext to tensor or add a dimension to tensor of
  // ciphertext and memref of ciphertext
  {
    mlir::ConversionTarget target(getContext());
    ConcreteToBConcreteTypeConverter converter;
    mlir::OwningRewritePatternList patterns(&getContext());

    // All BConcrete ops are legal after the conversion
    target.addLegalDialect();

    // All Concrete ops are illegal after the conversion except those which
    // are explicitly marked as legal (more or less operators that don't work
    // on ciphertexts)
    target.addIllegalDialect();
    target.addLegalOp();
    target.addLegalOp();

    // Add patterns to convert the zero ops to tensor.generate
    patterns.insert, ZeroOpPattern>(&getContext());
    target.addLegalOp();

    // Add patterns to trivially convert Concrete ops to the equivalent
    // BConcrete ops
    // NOTE(review): the MulCleartextLweCiphertextOp -> MulCleartextLweBufferOp
    // pattern is registered twice in the list below; one entry is redundant.
    target.addLegalOp();
    patterns.insert<
        LowToBConcrete,
        LowToBConcrete<
            mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp,
            mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>,
        LowToBConcrete<
            mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
            mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
        LowToBConcrete<
            mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
            mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
        LowToBConcrete, LowToBConcrete, LowToBConcrete>(&getContext());
    patterns.insert(&getContext());

    // Add patterns to rewrite tensor operators that work on encrypted
    // tensors; those ops stay legal once their result type is legal.
    patterns.insert(&getContext());
    target.addDynamicallyLegalOp<
        mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
        mlir::tensor::InsertSliceOp, mlir::tensor::FromElementsOp>(
        [&](mlir::Operation *op) {
          return converter.isLegal(op->getResult(0).getType());
        });
    target.addLegalOp();

    // Add patterns to rewrite some of memref ops that were introduced by the
    // linalg bufferization of encrypted tensor (first conversion of this pass)
    insertTensorShapeOpPattern(getContext(), patterns, target);
    insertTensorShapeOpPattern(getContext(), patterns, target);

    // Add patterns to rewrite linalg op to nested loops with views on
    // ciphertexts
    patterns.insert>(converter, &getContext());
    target.addLegalOp();

    // Add patterns to do the conversion of func; a function is legal once
    // both its signature and its body types are legal.
    mlir::populateFuncOpTypeConversionPattern(patterns, converter);
    target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) {
      return converter.isSignatureLegal(funcOp.getType()) &&
             converter.isLegal(&funcOp.getBody());
    });

    // Add patterns to convert some memref operators that are generated by
    // previous step
    insertMemrefOpPattern(getContext(), patterns, target);

    // Conversion of RT Dialect Ops
    patterns.add>(patterns.getContext(), converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DataflowTaskOp>(target, converter);

    // Apply conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }
}

namespace mlir {
namespace concretelang {
// Factory for the Concrete-to-BConcrete conversion pass.
std::unique_ptr> createConvertConcreteToBConcretePass() {
  return std::make_unique();
}
} // namespace concretelang
} // namespace mlir