// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" #include #include "concretelang/Conversion/Passes.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" #include "concretelang/Support/Constants.h" namespace arith = mlir::arith; namespace linalg = mlir::linalg; namespace tensor = mlir::tensor; namespace bufferization = mlir::bufferization; namespace FHE = mlir::concretelang::FHE; namespace FHELinalg = mlir::concretelang::FHELinalg; struct DotToLinalgGeneric : public ::mlir::OpRewritePattern { DotToLinalgGeneric(::mlir::MLIRContext *context) : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Dot>( context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} /// This rewrite pattern transforms any instance of /// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an /// appropriate region using `FHE.mul_eint_int` and /// `FHE.add_eint` operations, an appropriate specification for the /// iteration dimensions and appropriate operations managing the /// accumulator of `linalg.generic`. 
/// /// Example: /// /// %o = "FHELinalg.dot_eint_int"(%arg0, %arg1) : /// (tensor<4x!FHE.eint<0>>, /// tensor<4xi32>) -> (!FHE.eint<0>) /// /// becomes: /// /// %0 = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<0>> /// %1 = linalg.generic { /// indexing_maps = [#map0, #map0, #map1], /// iterator_types = ["reduction"] /// } /// ins(%arg0, %arg1 : tensor<2x!FHE.eint<0>>, tensor<2xi32>) /// outs(%0 : tensor<1x!FHE.eint<0>>) { /// ^bb0(%arg2: !FHE.eint<0>, %arg3: i32, %arg4: !FHE.eint<0>): /// %4 = "FHE.mul_eint_int"(%arg2, %arg3) : /// (!FHE.eint<0>, i32) -> !FHE.eint<0> /// /// %5 = "FHE.add_eint"(%4, %arg4) : /// (!FHE.eint<0>, !FHE.eint<0>) -> !FHE.eint<0> /// /// linalg.yield %5 : !FHE.eint<0> /// } -> tensor<1x!FHE.eint<0>> /// /// %c0 = constant 0 : index /// %o = tensor.extract %1[%c0] : tensor<1x!FHE.eint<0>> /// ::mlir::LogicalResult matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp, ::mlir::PatternRewriter &rewriter) const override { auto zeroTensorOp = rewriter.create( dotOp.getLoc(), mlir::RankedTensorType::get({1}, dotOp.getType())); // Create `linalg.generic` op llvm::SmallVector resTypes{zeroTensorOp.getType()}; llvm::SmallVector ins{dotOp.lhs(), dotOp.rhs()}; llvm::SmallVector outs{zeroTensorOp}; llvm::SmallVector maps{ mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()), mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()), mlir::AffineMap::get(1, 0, {rewriter.getAffineConstantExpr(0)}, this->getContext())}; llvm::SmallVector itTypes{"reduction"}; llvm::StringRef doc{""}; llvm::StringRef call{""}; auto regBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::concretelang::FHE::MulEintIntOp mul = nestedBuilder.create( dotOp.getLoc(), blockArgs[0], blockArgs[1]); mlir::concretelang::FHE::AddEintOp add = nestedBuilder.create( dotOp.getLoc(), mul, blockArgs[2]); nestedBuilder.create(dotOp.getLoc(), add.getResult()); }; mlir::linalg::GenericOp gop = rewriter.create( dotOp.getLoc(), resTypes, ins, outs, maps, itTypes, doc, call, regBuilder); // Return value is still a 1-dimensional tensor; extract first // element and use it as a replacement for the result of the dot // operation mlir::Value idx0 = rewriter.create(dotOp.getLoc(), 0); llvm::SmallVector indexes{idx0}; mlir::Value res = rewriter.create( dotOp.getLoc(), gop.getResult(0), indexes); rewriter.replaceOp(dotOp, {res}); return ::mlir::success(); }; }; mlir::AffineMap getBroadcastedAffineMap(const mlir::RankedTensorType &resultType, const mlir::RankedTensorType &operandType, ::mlir::PatternRewriter &rewriter) { mlir::SmallVector affineExprs; auto resultShape = resultType.getShape(); auto operandShape = operandType.getShape(); affineExprs.reserve(operandShape.size()); size_t deltaNumDim = resultShape.size() - operandShape.size(); for (size_t i = 0; i < operandShape.size(); i++) { if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) { affineExprs.push_back(rewriter.getAffineConstantExpr(0)); } else { affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim)); } } return mlir::AffineMap::get(resultShape.size(), 0, affineExprs, rewriter.getContext()); } /// This create an affine map following the broadcasting rules, but also takes /// out one specific element of the LUT from the LUT dimension, which should be /// the last. 
/// /// Example: /// /// resultType: 4x2x5, operandType: 4x2x8, lut_index: 3 /// return: affine_map<(d0, d1, d2) -> (d0, d1, 3) /// last dimension of the operand is the lut size, and we take the map takes out /// the element at index 3 mlir::AffineMap getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType, const mlir::RankedTensorType &operandType, const int64_t lut_index, ::mlir::PatternRewriter &rewriter) { mlir::SmallVector affineExprs; auto resultShape = resultType.getShape(); auto operandShape = operandType.getShape(); affineExprs.reserve(operandShape.size()); // Don't take the lut dimension into account size_t deltaNumDim = resultShape.size() - operandShape.size() + 1; for (size_t i = 0; i < operandShape.size() - 1; i++) { if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) { affineExprs.push_back(rewriter.getAffineConstantExpr(0)); } else { affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim)); } } // Index a specific element of the LUT affineExprs.push_back(rewriter.getAffineConstantExpr(lut_index)); return mlir::AffineMap::get(resultShape.size(), 0, affineExprs, rewriter.getContext()); } /// This template rewrite pattern transforms any instance of /// operators `FHELinalgOp` that implements the broadasting rules to an /// instance of `linalg.generic` with an appropriate region using `FHEOp` /// operation, an appropriate specification for the iteration dimensions and /// appropriate operations managing the accumulator of `linalg.generic`. /// /// Example: /// /// %res = FHELinalg.op(%lhs, %rhs): /// (tensor>, tensor) /// -> tensor> /// /// becomes: /// /// #maps_0 = [ /// affine_map<(a$R", ..., a$A, ..., a1) -> /// (dim(lhs, $A) == 1 ? 0 : a$A,..., dim(lhs, 1) == 1 ? 0 : a1)>, /// affine_map<(a$R", ..., a1) -> /// (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>, /// affine_map<(a$R", ..., a1) -> (a$R", ..., a1) /// ] /// #attributes_0 { /// indexing_maps = #maps_0, /// iterator_types = ["parallel", ..., "parallel"], // $R" parallel /// } /// %init = linalg.init_tensor [DR",...,D1"] /// : tensor> /// %res = linalg.generic { /// ins(%lhs, %rhs: tensor>,tensor) /// outs(%init : tensor>) /// { /// ^bb0(%arg0: !FHE.eint

, %arg1: T):
///             %0 = FHE.op(%arg0, %arg1): !FHE.eint, T -> !FHE.eint
///             linalg.yield %0 : !FHE.eint
/// } /// } /// template struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern { FHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHELinalgOp linalgOp, ::mlir::PatternRewriter &rewriter) const override { mlir::RankedTensorType resultTy = ((mlir::Type)linalgOp->getResult(0).getType()) .cast(); mlir::RankedTensorType lhsTy = ((mlir::Type)linalgOp.lhs().getType()).cast(); mlir::RankedTensorType rhsTy = ((mlir::Type)linalgOp.rhs().getType()).cast(); // linalg.init_tensor for initial value mlir::Value init = rewriter.create( linalgOp.getLoc(), resultTy, mlir::ValueRange{}); // Create the affine #maps_0 llvm::SmallVector maps{ getBroadcastedAffineMap(resultTy, lhsTy, rewriter), getBroadcastedAffineMap(resultTy, rhsTy, rewriter), getBroadcastedAffineMap(resultTy, resultTy, rewriter), }; // Create the iterator_types llvm::SmallVector iteratorTypes(resultTy.getShape().size(), "parallel"); // Create the body of the `linalg.generic` op auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { FHEOp fheOp = nestedBuilder.create(linalgOp.getLoc(), blockArgs[0], blockArgs[1]); nestedBuilder.create(linalgOp.getLoc(), fheOp.getResult()); }; // Create the `linalg.generic` op llvm::SmallVector resTypes{init.getType()}; llvm::SmallVector ins{linalgOp.lhs(), linalgOp.rhs()}; llvm::SmallVector outs{init}; llvm::StringRef doc{""}; llvm::StringRef call{""}; mlir::linalg::GenericOp genericOp = rewriter.create(linalgOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call, bodyBuilder); rewriter.replaceOp(linalgOp, {genericOp.getResult(0)}); return ::mlir::success(); }; }; template inline mlir::RankedTensorType getRankedTensorType(T v) { return ((mlir::Type)v.getType()).cast(); } llvm::SmallVector parallelIteratorType(int n) { return llvm::SmallVector(n, "parallel"); } /// This class rewrite pattern transforms any instance of /// operators `FHELinalg.ApplyMappedLookupTableEintOp` that implements the /// broadasting rules to an instance of `linalg.generic` with an appropriate /// region using `FHE.ApplyLookupTableEintOp` operation, an appropriate /// specification for the iteration dimensions and appropriate operations /// managing the accumulator of `linalg.generic`. 
/// /// Example: /// %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) /// : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>) /// -> tensor<2x3x!FHE.eint<2>> /// /// becomes: /// /// #map = affine_map<(d0, d1) -> (d0, d1)> /// %init = linalg.init_tensor [2, 3] : tensor<2x3x!TFHE.glwe<{_,_,_}{2}>> /// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types /// = ["parallel", "parallel"]} ins(%arg0, %arg2 : /// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 : /// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) { /// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5: /// !TFHE.glwe<{_,_,_}{2}>): // no predecessors /// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1] /// : tensor<5x4xi64> to tensor<4xi64> /// %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]]) /// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, /// glweDimension = -1 : i32, /// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = /// -1 : i32, polynomialSize = -1 : i32} /// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> /// !TFHE.glwe<{_,_,_}{2}> linalg.yield %res : /// !TFHE.glwe<{_,_,_}{2}> /// } -> tensor<2x3x!TFHE.glwe<{_,_,_}{2}>> namespace FHELinalg = mlir::concretelang::FHELinalg; struct FHELinalgApplyMappedLookupTableToLinalgGeneric : public mlir::OpRewritePattern { FHELinalgApplyMappedLookupTableToLinalgGeneric( ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern( context, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHELinalg::ApplyMappedLookupTableEintOp mappedLookup, ::mlir::PatternRewriter &rewriter) const override { namespace arith = mlir::arith; namespace linalg = mlir::linalg; namespace tensor = mlir::tensor; namespace FHE = mlir::concretelang::FHE; using Values = llvm::SmallVector; using Types = llvm::SmallVector; using AffineMaps = llvm::SmallVector; using sliceArg = llvm::SmallVector; auto input = mappedLookup.t(); auto luts = mappedLookup.luts(); auto map = mappedLookup.map(); auto loc = mappedLookup.getLoc(); auto tensorTy = getRankedTensorType(input); auto lutsTy = getRankedTensorType(luts); auto resultTy = getRankedTensorType(mappedLookup->getResult(0)); auto elementTy = resultTy.getElementType(); auto lutElmtTy = lutsTy.getElementType(); auto lutsShape = lutsTy.getShape(); auto lutSize = lutsShape[lutsShape.size() - 1]; auto resultShape = resultTy.getShape(); auto integer = [&](auto v) -> mlir::Attribute { return rewriter.getI64IntegerAttr(v); }; auto _0_ = integer(0); auto _1_ = integer(1); auto lutSizeValue = integer(lutSize); // Create the body of the `linalg.generic` op // %arg0 is an element of t (encrypted int) // %arg1 is the lut index (i64) // %arg2 is the output element auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { auto tElmt = blockArgs[0]; auto lutIdx = blockArgs[1]; // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] : // tensor to tensor sliceArg offsets{lutIdx, _0_}; sliceArg sizes{_1_, lutSizeValue}; sliceArg strides{_1_, _1_}; auto lutTy = mlir::RankedTensorType::get({static_cast(lutSize)}, lutElmtTy); mlir::Value lut = nestedBuilder.create( loc, lutTy, luts, offsets, sizes, strides); // %res1 = apply_lookup_table %arg0 %lut auto lookup = nestedBuilder.create( loc, elementTy, tElmt, lut); // linalg.yield %res1 : !FHE.eint<2> nestedBuilder.create(loc, lookup.getResult()); }; auto output = rewriter.create( loc, resultTy, mlir::ValueRange{}); // Create the `linalg.g eneric` op Types resTys{resultTy}; Values 
ins{input, map}; Values outs{output}; auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter); AffineMaps affineMaps{indexOfInput, indexOfInput, indexOfInput}; auto iteratorTypes = parallelIteratorType(resultShape.size()); auto genericOp = rewriter.create( loc, resTys, ins, outs, affineMaps, iteratorTypes, lambdaBlock); rewriter.replaceOp(mappedLookup, {genericOp.getResult(0)}); return ::mlir::success(); }; }; /// This class rewrite pattern transforms any instance of /// operators `FHELinalg.ApplyMultiLookupTableEintOp` that implements the /// broadasting rules to an instance of `linalg.generic` with an appropriate /// region using `FHE.ApplyLookupTableEintOp` operation, an appropriate /// specification for the iteration dimensions and appropriate operaztions /// managing the accumulator of `linalg.generic`. /// /// Example: /// /// %res = "FHELinalg.apply_multi_lookup_table"(%t, %luts): /// (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> /// /// becomes: /// /// #maps_0 = [ /// affine_map<(d0, d1) -> (d0, d1)> /// ] /// #attributes_0 { /// indexing_maps = #maps_0, /// iterator_types = ["parallel", "parallel"], /// } /// %init = linalg.init_tensor [4, 3] /// : tensor<4x3x!FHE.eint<2>> /// %res = linalg.generic { /// ins(%t, %luts: tensor<4x3x!FHE.eint

>) /// outs(%init : tensor<4x3x!FHE.eint<2>>) /// { /// ^bb0(%arg0: !FHE.eint<2>): /// %i_lut = linalg.index 0 ; index /// %lut = tensor.extract_slice %arg21[%i_lut, 0] [1, lut_size] [1, /// 1] : ... tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, /// %lut) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension /// = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS /// = -1 : i32, polynomialSize = -1 : i32} : /// (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> /// !TFHE.glwe<{_,_,_}{2}> /// linalg.yield %0 : !FHE.eint<2> /// } /// } /// struct FHELinalgApplyMultiLookupTableToLinalgGeneric : public mlir::OpRewritePattern< mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp> { FHELinalgApplyMultiLookupTableToLinalgGeneric( ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern< mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp>( context, benefit) {} ::mlir::LogicalResult matchAndRewrite( mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp fheLinalgLutOp, ::mlir::PatternRewriter &rewriter) const override { mlir::RankedTensorType resultTy = ((mlir::Type)fheLinalgLutOp->getResult(0).getType()) .cast(); mlir::RankedTensorType tensorTy = ((mlir::Type)fheLinalgLutOp.t().getType()) .cast(); auto luts = fheLinalgLutOp.luts(); mlir::RankedTensorType lutsTy = getRankedTensorType(luts); auto lutElmtTy = lutsTy.getElementType(); // linalg.init_tensor for initial value mlir::Value init = rewriter.create( fheLinalgLutOp.getLoc(), resultTy, mlir::ValueRange{}); auto lutsShape = lutsTy.getShape(); auto lut_size = lutsShape[lutsShape.size() - 1]; auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter); // Create the affine maps llvm::SmallVector maps{indexOfInput, indexOfInput}; // Create the iterator_types auto iteratorTypes = parallelIteratorType(resultTy.getShape().size()); auto integer = [&](auto v) -> mlir::Attribute { return rewriter.getI64IntegerAttr(v); }; // We need to know with linalg.generic index to use for lut // In broadcast case the lut index is inner dimensions of the tensor index auto tensorShape = tensorTy.getShape(); auto tensorRank = tensorTy.getShape().size(); auto lutsRank = lutsShape.size() - 1; // do not count inner dim of luts auto lutIndexDimAt = tensorRank - lutsRank; llvm::SmallVector indexLutsToLinalg(lutsRank); for (auto lutsIndex = 0u; lutsIndex < lutsRank; lutsIndex++) { auto tensorIndex = lutIndexDimAt + lutsIndex; if (tensorShape[tensorIndex] != lutsShape[lutsIndex]) { llvm::errs() << "ERROR: Broadcast only works by having more outer " "dims.\nConflict: " << tensorShape[tensorIndex] << " (tensor dim " << tensorIndex << ") is not compatible with " << lutsShape[lutsIndex] << " (luts dim " << lutsIndex << ")\n\n"; return ::mlir::LogicalResult::failure(); }; indexLutsToLinalg[lutsIndex] = tensorIndex; } auto _0_ = integer(0); auto _1_ = integer(1); auto lutSizeValue = integer(lut_size); // Create the body of the `linalg.generic` op auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { auto loc = fheLinalgLutOp.getLoc(); auto tElmt = blockArgs[0]; // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] : // tensor to tensor auto sliceArgDim = lutsShape.size(); using sliceArg = llvm::SmallVector; sliceArg offsets(sliceArgDim, _0_); auto lutsIndex = 0; for (auto index : indexLutsToLinalg) { auto offset = nestedBuilder.create(loc, index); offsets[lutsIndex++] = (mlir::OpFoldResult)offset; 
} sliceArg sizes(sliceArgDim, _1_); sizes[sliceArgDim - 1] = lutSizeValue; sliceArg strides(sliceArgDim, _1_); auto lutTy = mlir::RankedTensorType::get({static_cast(lut_size)}, lutElmtTy); mlir::Value lut = nestedBuilder.create( loc, lutTy, luts, offsets, sizes, strides); auto lutOp = nestedBuilder.create( loc, resultTy.getElementType(), tElmt, lut); nestedBuilder.create(loc, lutOp.getResult()); }; // Create the `linalg.generic` op llvm::SmallVector resTypes{init.getType()}; llvm::SmallVector ins{fheLinalgLutOp.t()}; llvm::SmallVector outs{init}; llvm::StringRef doc{""}; llvm::StringRef call{""}; mlir::linalg::GenericOp genericOp = rewriter.create( fheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call, bodyBuilder); rewriter.replaceOp(fheLinalgLutOp, {genericOp.getResult(0)}); return ::mlir::success(); }; }; /// This template rewrite pattern transforms any instance of /// operators `FHELinalg.apply_lookup_table` that implements the broadasting /// rules to an instance of `linalg.generic` with an appropriate region using /// `FHE.apply_lookup_table` operation, an appropriate specification for the /// iteration dimensions and appropriate operations managing the accumulator of /// `linalg.generic`. /// /// Example: /// /// FHELinalg.apply_lookup_table(%t, %lut): /// tensor>, tensor /// -> tensor> /// /// becomes: /// /// #maps_0 = [ /// affine_map<(aN, ..., a1) -> (aN, ..., a1)>, /// affine_map<(aN, ..., a1) -> (aN, ..., a1)> /// ] /// #attributes_0 { /// indexing_maps = #maps_0, /// iterator_types = ["parallel",..],//N parallel /// } /// %init = linalg.init_tensor [DN,...,D1] /// : tensor> /// %res = linalg.generic { /// ins(%t: tensor>) /// outs(%init : tensor>) /// { /// ^bb0(%arg0: !FHE.eint

):
///             %0 = FHE.apply_lookup_table(%arg0, %lut): !FHE.eint
, /// tensor<4xi64> -> !FHE.eint /// linalg.yield %0 : !FHE.eint /// } /// } /// struct FHELinalgApplyLookupTableToLinalgGeneric : public mlir::OpRewritePattern< mlir::concretelang::FHELinalg::ApplyLookupTableEintOp> { FHELinalgApplyLookupTableToLinalgGeneric( ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern< mlir::concretelang::FHELinalg::ApplyLookupTableEintOp>(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::concretelang::FHELinalg::ApplyLookupTableEintOp lutOp, ::mlir::PatternRewriter &rewriter) const override { mlir::RankedTensorType resultTy = ((mlir::Type)lutOp->getResult(0).getType()) .cast(); mlir::RankedTensorType tTy = ((mlir::Type)lutOp.t().getType()).cast(); // linalg.init_tensor for initial value mlir::Value init = rewriter.create( lutOp.getLoc(), resultTy, mlir::ValueRange{}); // Create the affine #maps_0 llvm::SmallVector maps{ mlir::AffineMap::getMultiDimIdentityMap(tTy.getShape().size(), this->getContext()), mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(), this->getContext()), }; // Create the iterator_types llvm::SmallVector iteratorTypes(resultTy.getShape().size(), "parallel"); // Create the body of the `linalg.generic` op auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::concretelang::FHE::ApplyLookupTableEintOp fheOp = nestedBuilder.create( lutOp.getLoc(), resultTy.getElementType(), blockArgs[0], lutOp.lut()); nestedBuilder.create(lutOp.getLoc(), fheOp.getResult()); }; // Create the `linalg.generic` op llvm::SmallVector resTypes{init.getType()}; llvm::SmallVector ins{lutOp.t()}; llvm::SmallVector outs{init}; llvm::StringRef doc{""}; llvm::StringRef call{""}; mlir::linalg::GenericOp genericOp = rewriter.create(lutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call, bodyBuilder); rewriter.replaceOp(lutOp, {genericOp.getResult(0)}); return ::mlir::success(); }; }; /// This template rewrite pattern transforms any instance of /// operators `FHELinalg.neg_eint` to an instance of `linalg.generic` with an /// appropriate region using `FHE.neg_eint` operation, an appropriate /// specification for the iteration dimensions and appropriate operations /// managing the accumulator of `linalg.generic`. /// /// Example: /// /// FHELinalg.neg_eint(%tensor): /// tensor> -> tensor> /// /// becomes: /// /// #maps_0 = [ /// affine_map<(aN, ..., a1) -> (aN, ..., a1)>, /// affine_map<(aN, ..., a1) -> (aN, ..., a1)> /// ] /// #attributes_0 { /// indexing_maps = #maps_0, /// iterator_types = ["parallel",..],//N parallel /// } /// %init = linalg.init_tensor [DN,...,D1] /// : tensor> /// %res = linalg.generic { /// ins(%tensor: tensor>) /// outs(%init : tensor>) /// { /// ^bb0(%arg0: !FHE.eint

):
///             %0 = FHE.neg_eint(%arg0): !FHE.eint
-> !FHE.eint /// linalg.yield %0 : !FHE.eint /// } /// } /// struct FHELinalgNegEintToLinalgGeneric : public mlir::OpRewritePattern { FHELinalgNegEintToLinalgGeneric( ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern( context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::concretelang::FHELinalg::NegEintOp negEintOp, ::mlir::PatternRewriter &rewriter) const override { mlir::RankedTensorType resultTy = ((mlir::Type)negEintOp->getResult(0).getType()) .cast(); mlir::RankedTensorType tensorTy = ((mlir::Type)negEintOp.tensor().getType()) .cast(); // linalg.init_tensor for initial value mlir::Value init = rewriter.create( negEintOp.getLoc(), resultTy, mlir::ValueRange{}); // Create the affine #maps_0 llvm::SmallVector maps{ mlir::AffineMap::getMultiDimIdentityMap(tensorTy.getShape().size(), this->getContext()), mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(), this->getContext()), }; // Create the iterator_types llvm::SmallVector iteratorTypes(resultTy.getShape().size(), "parallel"); // Create the body of the `linalg.generic` op auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::concretelang::FHE::NegEintOp fheOp = nestedBuilder.create( negEintOp.getLoc(), resultTy.getElementType(), blockArgs[0]); nestedBuilder.create(negEintOp.getLoc(), fheOp.getResult()); }; // Create the `linalg.generic` op llvm::SmallVector resTypes{init.getType()}; llvm::SmallVector ins{negEintOp.tensor()}; llvm::SmallVector outs{init}; llvm::StringRef doc{""}; llvm::StringRef call{""}; mlir::linalg::GenericOp genericOp = rewriter.create(negEintOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call, bodyBuilder); rewriter.replaceOp(negEintOp, {genericOp.getResult(0)}); return ::mlir::success(); }; }; /// This template rewrite pattern transforms any instance of /// operators `FHELinalgMatmulOp` to an instance of `linalg.generic` /// with an appropriate region using a builder that create the multiplication /// operators and `FHE.add_eint` operation, an appropriate specification for /// the iteration dimensions and appropriate operations managing the accumulator /// of `linalg.generic`. /// /// Example: /// /// "FHELinalg.matmul_eint_int(%a, %b) : /// (tensor>, tensor) -> /// tensor>" /// /// becomes: /// /// #maps_0 = [ /// (m, n, p) -> (m, p), /// (m, n, p) -> (p, n), /// (m, n, p) -> (m, n) /// ] /// #attributes_0 = { /// indexing_maps = #maps_0, /// iterator_types = ["parallel", "parallel", "reduction"] /// } /// %init = FHE.zero_tensor : tensor> /// linalg.generic #attributes_0 /// ins(%A, %B : tensor>, /// tensor) /// outs(%C : tensor>) /// { /// ^bb0(%a: !FHE.eint

, %b: ip', %c: !FHE.eint) :
///             %d = createMulOp(%a, %b): !FHE.eint
///             %e = "FHE.add_eint"(%c, %d):
///                 (!FHE.eint, !FHE.eint) -> !FHE.eint
///             linalg.yield %e : !FHE.eint
/// } /// template struct FHELinalgMatmulToLinalgGeneric : public mlir::OpRewritePattern { FHELinalgMatmulToLinalgGeneric( mlir::MLIRContext *context, std::function createMulOp, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit), createMulOp(createMulOp) {} mlir::LogicalResult matchAndRewrite(FHELinalgMatmulOp matmulOp, mlir::PatternRewriter &rewriter) const override { mlir::Location location = matmulOp.getLoc(); mlir::Value lhs = matmulOp.lhs(); mlir::Value rhs = matmulOp.rhs(); mlir::Value out = matmulOp.getResult(); auto lhsType = ((mlir::Type)lhs.getType()).cast(); auto rhsType = ((mlir::Type)rhs.getType()).cast(); auto outType = ((mlir::Type)out.getType()).cast(); llvm::ArrayRef lhsShape = lhsType.getShape(); llvm::ArrayRef rhsShape = rhsType.getShape(); llvm::ArrayRef outShape = outType.getShape(); int64_t lhsDims = (int64_t)lhsShape.size(); int64_t rhsDims = (int64_t)rhsShape.size(); int64_t outDims = (int64_t)outShape.size(); mlir::Value zeros = rewriter.create(location, outType).getResult(); auto ins = llvm::SmallVector{lhs, rhs}; auto outs = llvm::SmallVector{zeros}; auto iteratorTypes = llvm::SmallVector{}; auto lhsAffineExpressions = llvm::SmallVector{}; auto rhsAffineExpressions = llvm::SmallVector{}; auto outAffineExpressions = llvm::SmallVector{}; if (lhsDims >= 2 && rhsDims >= 2) { // here are some example shapes to help understand the logic below // notation: lhs.shape @ rhs.shape -> output.shape // MxN @ NxP -> MxP // KxLxMxN @ NxP -> KxLxMxP // KxLxMxN @ LxNxP -> KxLxMxP // Kx1xMxN @ LxNxP -> KxLxMxP // MxN @ KxLxNxP -> KxLxMxP // LxMxN @ KxLxNxP -> KxLxMxP // 1xMxN @ KxLxNxP -> KxLxMxP // make iterator types // ["parallel", "parallel", ..., "parallel", "reduction"] // --------------------------------------- // output.shape.size() times // // think of it as // // - 1st iterator is for the 1st dimension of the output (K in examples) // - 2nd iterator is for the 2nd dimension of the output (L in examples) // - ... 
// - Nth iterator is for the Nth dimension of the output // - Last iterator is for the reduced dimension (N in the examples) for (int64_t i = 0; i < outDims; i++) { iteratorTypes.push_back(mlir::getParallelIteratorTypeName()); } iteratorTypes.push_back(mlir::getReductionIteratorTypeName()); // we need to put appropriate affine dimension expressions // that match lhs.shape on iterator types array // in KxLxMxN @ NxP -> KxLxMxP, we need to create the following map // // (dK, dL, dM, dP, dN) -> (dK, dL, dM, dN) // == // (d0, d1, d2, d3, d4) -> (d0, d1, d2, d4) // in LxMxN @ KxLxNxP -> KxLxMxP, we need to create the following map // // (dK, dL, dM, dP, dN) -> (dL, dM, dN) // == // (d0, d1, d2, d3, d4) -> (d1, d2, d4) // in MxN @ KxLxNxP -> KxLxMxP, we need to create the following map // // (dK, dL, dM, dP, dN) -> (dM, dN) // == // (d0, d1, d2, d3, d4) -> (d2, d4) // so the first AffineDimExpr we need to create is // output.shape.size() - lhs.shape.size() == outDims - lhsDims // then we need to add all dims in the output except it's last dim // so, we iterate up to output.shape.size() - 1 == outDims - 1 // and finally, we add the AffineDimExpr corresponding to `N` // which is at the last index of `iteratorTypes` int64_t lhsDim = 0; for (int64_t outDim = outDims - lhsDims; outDim < outDims - 1; outDim++) { if (lhsDim < lhsDims - 2 && lhsShape[lhsDim] == 1) { // broadcasted so current `dim` will always be indexed with `0` lhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0)); } else { assert(lhsShape[lhsDim] == outShape[outDim]); lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim)); } lhsDim++; } lhsAffineExpressions.push_back( rewriter.getAffineDimExpr(iteratorTypes.size() - 1)); // we need to put appropriate affine dimension expressions // that match rhs.shape on iterator types array // in KxLxMxN @ NxP -> KxLxMxP, we need to create the following map // // (dK, dL, dM, dP, dN) -> (dN, dP) // == // (d0, d1, d2, d3, d4) -> (d4, d3) // in KxLxMxN @ LxNxP -> KxLxMxP, we need to create the following map // // (dK, dL, dM, dP, dN) -> (dL, dN, dP) // == // (d0, d1, d2, d3, d4) -> (d1, d4, d3) // in LxMxN @ KxLxNxP -> KxLxMxP, we need to create the following map // // (dK, dL, dM, dP, dN) -> (dK, dL, dN, dP) // == // (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3) // so the first AffineDimExpr we need to create is // output.shape.size() - rhs.shape.size() == outDims - rhsDims // then we need to add all dims in the output except it's last 2 dims // so, we iterate up to output.shape.size() - 2 == outDims - 2 // and finally, we add the AffineDimExpr corresponding to `N` and `P` // which is at the last and one before last indices of `iteratorTypes` int64_t rhsDim = 0; for (int64_t outDim = outDims - rhsDims; outDim < outDims - 2; outDim++) { if (rhsShape[rhsDim] == 1) { // broadcasted so current `dim` will always be indexed with `0` rhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0)); } else { assert(rhsShape[rhsDim] == outShape[outDim]); rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim)); } rhsDim++; } rhsAffineExpressions.push_back( rewriter.getAffineDimExpr(iteratorTypes.size() - 1)); rhsAffineExpressions.push_back( rewriter.getAffineDimExpr(iteratorTypes.size() - 2)); for (int64_t i = 0; i < outDims; i++) { outAffineExpressions.push_back(rewriter.getAffineDimExpr(i)); } } else if (lhsDims == 1 && rhsDims >= 2) { // here are some example shapes to help understand the logic below // notation: lhs.shape @ rhs.shape -> output.shape // N @ NxP -> P // N @ LxNxP 
-> LxP // N @ KxLxNxP -> KxLxP int64_t commonDim = rhsDims - 2; for (int64_t i = 0; i < rhsDims; i++) { if (i == commonDim) { iteratorTypes.push_back(mlir::getReductionIteratorTypeName()); } else { iteratorTypes.push_back(mlir::getParallelIteratorTypeName()); } } lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(commonDim)); for (int64_t i = 0; i < rhsDims; i++) { rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(i)); } for (int64_t i = 0; i < rhsDims; i++) { if (i != commonDim) { outAffineExpressions.push_back(rewriter.getAffineDimExpr(i)); } } } else if (lhsDims >= 2 && rhsDims == 1) { // here are some example shapes to help understand the logic below // notation: lhs.shape @ rhs.shape -> output.shape // MxN @ N -> M // LxMxN @ N -> LxM // KxLxMxN @ N -> KxLxM for (int64_t i = 0; i < lhsDims - 1; i++) { iteratorTypes.push_back(mlir::getParallelIteratorTypeName()); } iteratorTypes.push_back(mlir::getReductionIteratorTypeName()); for (int64_t i = 0; i < lhsDims; i++) { lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(i)); } rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(lhsDims - 1)); for (int64_t i = 0; i < lhsDims - 1; i++) { outAffineExpressions.push_back(rewriter.getAffineDimExpr(i)); } } auto maps = llvm::SmallVector{ mlir::AffineMap::get(iteratorTypes.size(), 0, lhsAffineExpressions, rewriter.getContext()), mlir::AffineMap::get(iteratorTypes.size(), 0, rhsAffineExpressions, rewriter.getContext()), mlir::AffineMap::get(iteratorTypes.size(), 0, outAffineExpressions, rewriter.getContext()), }; mlir::Type outElementType = outType.getElementType(); auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::Value multiplication = createMulOp(nestedBuilder, location, outElementType, blockArgs[0], blockArgs[1]) .getResult(); mlir::Value addition = nestedBuilder .create(location, outElementType, blockArgs[2], multiplication) .getResult(); nestedBuilder.create(location, addition); }; auto resultTypes = llvm::SmallVector{outType}; mlir::Value result = rewriter .create(location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder) .getResult(0); rewriter.replaceOp(matmulOp, {result}); return mlir::success(); }; private: std::function createMulOp; }; /// This rewrite pattern transforms any instance of operators /// `FHELinalg.sum` to an instance of `linalg.generic`. /// /// Example: /// /// %result = "FHELinalg.sum"(%input) : /// tensor>() -> !FHE.eint

/// /// becomes: /// /// #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)> /// #map1 = affine_map<(i0, i1, ..., iN) -> (0)> /// /// %accumulator = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> /// %accumulation = linalg.generic /// { /// indexing_maps = [#map0, #map1], /// iterator_types = ["reduction", "reduction", ..., "reduction"] /// } /// ins(%input : tensor>) /// outs(%accumulator : tensor<1x!FHE.eint<7>>) /// { /// ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>): /// %c = "FHE.add_eint"(%a, %b) : /// (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> /// linalg.yield %c : !FHE.eint<7> /// } -> tensor<1x!FHE.eint<7>> /// /// %index = arith.constant 0 : index /// %result = tensor.extract %index : tensor<1x!FHE.eint<7>> /// struct SumToLinalgGeneric : public ::mlir::OpRewritePattern { SumToLinalgGeneric(::mlir::MLIRContext *context) : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::SumOp>( context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult matchAndRewrite(::mlir::concretelang::FHELinalg::SumOp sumOp, ::mlir::PatternRewriter &rewriter) const override { mlir::Location location = sumOp.getLoc(); mlir::Value input = sumOp.getOperand(); mlir::Value output = sumOp.getResult(); auto inputType = input.getType().dyn_cast(); mlir::Type outputType = output.getType(); llvm::ArrayRef inputShape = inputType.getShape(); int64_t inputDimensions = inputShape.size(); bool outputIsTensor = outputType.isa(); for (int64_t size : inputShape) { if (size == 0) { mlir::Value result; if (outputIsTensor) { result = rewriter.create(location, outputType) .getResult(); } else { result = rewriter.create(location, outputType) .getResult(); } rewriter.replaceOp(sumOp, {result}); return mlir::success(); } } auto axesToDestroy = std::unordered_set{}; for (mlir::Attribute axisAttribute : sumOp.axes()) { int64_t axis = axisAttribute.cast().getInt(); axesToDestroy.insert(axis); } if (axesToDestroy.empty()) { for (int64_t i = 0; i < inputDimensions; i++) { axesToDestroy.insert(i); } } mlir::Type accumulatorType = outputType; if (!outputIsTensor) { int64_t accumulatorShape[1] = {1}; accumulatorType = // tensor of shape (1,) mlir::RankedTensorType::get(accumulatorShape, outputType); } mlir::Value accumulator = rewriter.create(location, accumulatorType) .getResult(); auto ins = llvm::SmallVector{input}; auto outs = llvm::SmallVector{accumulator}; mlir::AffineMap inputMap = mlir::AffineMap::getMultiDimIdentityMap( inputDimensions, this->getContext()); auto outputAffineExpressions = llvm::SmallVector{}; if (outputIsTensor) { for (int64_t i = 0; i < inputDimensions; i++) { bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end(); if (!ithAxisIsDestroyed) { outputAffineExpressions.push_back(rewriter.getAffineDimExpr(i)); } else if (sumOp.keep_dims()) { outputAffineExpressions.push_back(rewriter.getAffineConstantExpr(0)); } } } else { outputAffineExpressions.push_back(rewriter.getAffineConstantExpr(0)); } mlir::AffineMap outputMap = mlir::AffineMap::get( inputDimensions, 0, outputAffineExpressions, rewriter.getContext()); auto maps = llvm::SmallVector{inputMap, outputMap}; auto iteratorTypes = llvm::SmallVector( inputDimensions, mlir::getParallelIteratorTypeName()); for (int64_t axis : axesToDestroy) { iteratorTypes[axis] = mlir::getReductionIteratorTypeName(); } auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::Value lhs = blockArgs[0]; mlir::Value rhs = blockArgs[1]; mlir::Value addition = 
nestedBuilder.create(location, lhs, rhs).getResult(); nestedBuilder.create(location, addition); }; auto resultTypes = llvm::SmallVector{accumulatorType}; mlir::Value accumulation = rewriter .create(location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder) .getResult(0); mlir::Value result = accumulation; if (!outputIsTensor) { auto indices = llvm::SmallVector{ rewriter.create(location, 0).getResult(), }; result = rewriter.create(location, accumulation, indices) .getResult(); } rewriter.replaceOp(sumOp, {result}); return mlir::success(); }; }; /// This rewrite pattern transforms any instance of operators /// `FHELinalg.transpose` to an instance of `linalg.generic`. /// /// Example: /// /// %result = "FHELinalg.transpose"(%input: tensor>) /// -> tensor /// /// becomes: /// /// #map0 = affine_map<(i0, i1, ..., iN) -> (iN, ..., i1, i0)> /// #map1 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)> /// /// %accumulator = "FHE.zero_tensor"() : () -> /// tensor> %result = linalg.generic /// { /// indexing_maps = [#map0, #map1], /// iterator_types = ["parallel", "parallel", ..., "parallel"] /// } /// ins(%input : tensor>) /// outs(%accumulator : tensor>) /// { /// ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>): /// linalg.yield %a : !FHE.eint<7> /// } -> tensor> /// struct TransposeToLinalgGeneric : public ::mlir::OpRewritePattern< mlir::concretelang::FHELinalg::TransposeOp> { TransposeToLinalgGeneric(::mlir::MLIRContext *context) : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::TransposeOp>( context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult matchAndRewrite(::mlir::concretelang::FHELinalg::TransposeOp transposeOp, ::mlir::PatternRewriter &rewriter) const override { mlir::Value input = transposeOp.getOperand(); mlir::Value output = transposeOp.getResult(); auto inputType = input.getType().dyn_cast(); auto outputType = output.getType().dyn_cast(); auto n_dim = inputType.getShape().size(); mlir::Location location = transposeOp.getLoc(); // Initialize empty tensor to fill with transpose result mlir::Value zeroTensor = rewriter.create(location, outputType).getResult(); std::vector perms = {}; mlir::ArrayAttr axes = transposeOp.axes(); if (axes.empty()) { for (int i = n_dim - 1; i >= 0; i--) { perms.push_back(i); } } else { for (mlir::Attribute axisAttribute : axes) { int64_t axis = axisAttribute.cast().getInt(); perms.push_back(axis); } } llvm::SmallVector resultTypes{zeroTensor.getType()}; auto ins = llvm::SmallVector{input}; auto outs = llvm::SmallVector{zeroTensor}; llvm::SmallVector maps{ mlir::AffineMap::getMultiDimIdentityMap(n_dim, this->getContext()), mlir::AffineMap::getPermutationMap(perms, this->getContext()), }; auto iteratorTypes = parallelIteratorType(n_dim); // The maps will be responsible for changing item positions, we just return // items here auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::Value item = blockArgs[0]; nestedBuilder.create(location, item); }; mlir::Value result = rewriter .create(location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder) .getResult(0); rewriter.replaceOp(transposeOp, {result}); return mlir::success(); }; }; /// This rewrite pattern transforms any instance of operators /// `FHELinalg.from_element` to an instance of `tensor.from_elements`. 
/// /// Example: /// /// %result = "FHELinalg.from_element"(%x) : (Type) -> tensor<1xType> /// /// becomes: /// /// %result = tensor.from_elements %x : (Type) -> tensor<1xType> /// struct FromElementToTensorFromElements : public ::mlir::OpRewritePattern< mlir::concretelang::FHELinalg::FromElementOp> { FromElementToTensorFromElements(::mlir::MLIRContext *context) : ::mlir::OpRewritePattern< ::mlir::concretelang::FHELinalg::FromElementOp>( context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult matchAndRewrite(::mlir::concretelang::FHELinalg::FromElementOp op, ::mlir::PatternRewriter &rewriter) const override { auto in = op.getOperand(); auto out = op.getResult(); mlir::Value result = rewriter.create(op.getLoc(), out.getType(), in) .getResult(); rewriter.replaceOp(op, {result}); return mlir::success(); }; }; /// This rewrite pattern transforms any instance of operators /// `FHELinalg.concat` to instances of `tensor.insert_slice` /// /// Example: /// /// %result = "FHELinalg.concat"(%x, %y) { axis = 1 } : /// (tensor<2x3x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) /// -> tensor<2x7x!FHE.eint<4>> /// /// becomes: /// /// %empty = "FHE.zero_tensor"() : () -> tensor<2x7x!FHE.eint<4>> /// /// %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1] /// : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>> /// /// %y_copied = tensor.insert_slice %y into %x_copied[0, 3] [2, 4] [1, 1] /// : tensor<2x4x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>> /// struct ConcatRewritePattern : public mlir::OpRewritePattern { ConcatRewritePattern(mlir::MLIRContext *context) : mlir::OpRewritePattern( context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} mlir::LogicalResult matchAndRewrite(FHELinalg::ConcatOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); size_t axis = op.axis(); mlir::Value output = op.getResult(); auto outputType = output.getType().dyn_cast(); llvm::ArrayRef outputShape = outputType.getShape(); size_t outputDimensions = outputShape.size(); mlir::Value result = rewriter.create(location, outputType).getResult(); auto offsets = llvm::SmallVector{}; auto sizes = llvm::SmallVector{}; auto strides = llvm::SmallVector{}; // set up the initial values of offsets, sizes, and strides // each one has exactly `outputDimensions` number of elements // - offsets will be [0, 0, 0, ..., 0, 0, 0] // - strides will be [1, 1, 1, ..., 1, 1, 1] // - sizes will be the output shape except at the 'axis' which will be 0 for (size_t i = 0; i < outputDimensions; i++) { offsets.push_back(0); if (i == axis) { sizes.push_back(0); } else { sizes.push_back(outputShape[i]); } strides.push_back(1); } // these are not used, but they are required // for the creation of InsertSliceOp operation auto dynamicOffsets = llvm::ArrayRef{}; auto dynamicSizes = llvm::ArrayRef{}; auto dynamicStrides = llvm::ArrayRef{}; for (mlir::Value input : op.getOperands()) { auto inputType = input.getType().dyn_cast(); int64_t axisSize = inputType.getShape()[axis]; // offsets and sizes will be modified for each input tensor // if we have: // "FHELinalg.concat"(%x, %y, %z) : // ( // tensor<3x!FHE.eint<7>>, // tensor<4x!FHE.eint<7>>, // tensor<2x!FHE.eint<7>>, // ) // -> tensor<9x!FHE.eint<7>> // // for the first copy: // offsets = [0], sizes = [3], strides = [1] // // for the second copy: // offsets = [3], sizes = [4], strides = [1] // // for the third copy: // offsets = [7], sizes = [2], strides = [1] // // so in each iteration: // - the size is set to the axis size of the input 
// - the offset is increased by the size of the previous input sizes[axis] = axisSize; // these arrays are copied, so it's fine to modify and use them again mlir::ArrayAttr offsetsAttr = rewriter.getI64ArrayAttr(offsets); mlir::ArrayAttr sizesAttr = rewriter.getI64ArrayAttr(sizes); mlir::ArrayAttr stridesAttr = rewriter.getI64ArrayAttr(strides); offsets[axis] += axisSize; result = rewriter .create( location, outputType, input, result, dynamicOffsets, dynamicSizes, dynamicStrides, offsetsAttr, sizesAttr, stridesAttr) .getResult(); } rewriter.replaceOp(op, {result}); return mlir::success(); }; }; static mlir::SmallVector getAsOpFoldResult(mlir::OpBuilder &b, mlir::Location loc, mlir::SmallVectorImpl &ints) { return llvm::to_vector<4>( llvm::map_range(ints, [&](int64_t val) -> mlir::OpFoldResult { return b.getIndexAttr(val); })); } /// Helper function to get the padding tensor given the padding int values, and /// the value to pad with static mlir::Value getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input, mlir::SmallVectorImpl &lowPaddingInts, mlir::SmallVectorImpl &highPaddingInts, mlir::Value pad) { assert(input.getType().isa() && "input must be RankedTensorType"); mlir::Location loc = op->getLoc(); mlir::Type rankedTensorType = mlir::tensor::PadOp::inferResultType( input.getType().cast(), lowPaddingInts, highPaddingInts); mlir::SmallVector lowPaddings = getAsOpFoldResult(b, loc, lowPaddingInts); mlir::SmallVector highPaddings = getAsOpFoldResult(b, loc, highPaddingInts); mlir::Value paddedInput = mlir::tensor::createPadScalarOp( rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings, /*packing=*/false, loc, b); return paddedInput; } mlir::Value extractContiguous4DSlice(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value input, mlir::RankedTensorType resultType, llvm::SmallVector sizes, llvm::SmallVector offsets) { return rewriter .create( loc, resultType, input, // offset llvm::SmallVector{ rewriter.getI64IntegerAttr(offsets[0]), rewriter.getI64IntegerAttr(offsets[1]), rewriter.getI64IntegerAttr(offsets[2]), rewriter.getI64IntegerAttr(offsets[3]), }, // sizes llvm::SmallVector{ rewriter.getI64IntegerAttr(sizes[0]), rewriter.getI64IntegerAttr(sizes[1]), rewriter.getI64IntegerAttr(sizes[2]), rewriter.getI64IntegerAttr(sizes[3]), }, // strides llvm::SmallVector{ rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), }) .getResult(); } /// Create operations for grouped convolution. This will slice the input, /// weight, and output tensors to apply separate conv2d operations. 
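/// For each group `g`, the slice of input channels and the slice of weight
/// filters belonging to that group are extracted, a convolution is applied on
/// the two slices, and the per-group result is inserted back into the final
/// output tensor at the corresponding channel offset.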
mlir::LogicalResult createGroupedConv2D(mlir::PatternRewriter &rewriter, mlir::concretelang::FHELinalg::Conv2dOp &conv2dOp, mlir::Value paddedInput, mlir::Value weight, mlir::Value outputTensor, mlir::DenseIntElementsAttr stridesAttr, mlir::DenseIntElementsAttr dilationsAttr, int64_t group) { mlir::RankedTensorType inputTy = paddedInput.getType().cast(); mlir::Type inputElemTy = inputTy.getElementType(); llvm::ArrayRef inputShape = inputTy.getShape(); llvm::SmallVector inputSliceSizes( {inputShape[0], inputShape[1] / group, inputShape[2], inputShape[3]}); mlir::RankedTensorType weightTy = weight.getType().cast(); mlir::Type weightElemTy = weightTy.getElementType(); llvm::ArrayRef weightShape = weightTy.getShape(); llvm::SmallVector weightSliceSizes( {weightShape[0] / group, weightShape[1], weightShape[2], weightShape[3]}); mlir::RankedTensorType resultTy = conv2dOp.getResult().getType().cast(); llvm::ArrayRef resultShape = resultTy.getShape(); llvm::SmallVector sliceResultSizes = { resultShape[0], weightSliceSizes[0], resultShape[2], resultShape[3]}; mlir::RankedTensorType sliceResultType = mlir::RankedTensorType::get(sliceResultSizes, inputElemTy); // slice the input, weight, and output to apply different convolutions and // store their outputs in a single result found in `finalResult` mlir::Value finalResult = outputTensor; for (int g = 0; g < group; g++) { // input[:][g * (input_C / group) : (g + 1) * (input_C / group)][:][:] mlir::Value inputSlice = extractContiguous4DSlice( rewriter, conv2dOp.getLoc(), paddedInput, mlir::RankedTensorType::get(inputSliceSizes, inputElemTy), inputSliceSizes, {0, g * inputSliceSizes[1], 0, 0}); // weight[g * (weight_F / group) : (g + 1) * (weight_F / group)][:][:][:] mlir::Value weightSlice = extractContiguous4DSlice( rewriter, conv2dOp.getLoc(), weight, mlir::RankedTensorType::get(weightSliceSizes, weightElemTy), weightSliceSizes, {g * weightSliceSizes[0], 0, 0, 0}); // bias[:][g * (weight_F / group) : (g + 1) * (weight_F / group)][:][:] mlir::Value biasSlice = extractContiguous4DSlice( rewriter, conv2dOp.getLoc(), outputTensor, sliceResultType, sliceResultSizes, {0, g * sliceResultSizes[1], 0, 0}); // attributes for custom linalg named op auto addOpAttr = rewriter.getNamedAttr( "add", rewriter.getStringAttr( mlir::concretelang::FHE::AddEintOp::getOperationName())); auto mulOpAttr = rewriter.getNamedAttr( "mul", rewriter.getStringAttr( mlir::concretelang::FHE::MulEintIntOp::getOperationName())); // slices are currently causing issues during scf bufferization, so we are // trying to avoid slices here by creating a new tensor and adding the bias // to it mlir::RankedTensorType biasSliceType = biasSlice.getType().cast(); mlir::Value biasUnsliced = rewriter .create( conv2dOp.getLoc(), mlir::RankedTensorType::get(biasSliceType.getShape(), biasSliceType.getElementType())) .getResult(); biasUnsliced = rewriter .create( conv2dOp.getLoc(), biasUnsliced, biasSlice) .getResult(); // apply conv mlir::Value convResult = rewriter .create( conv2dOp.getLoc(), sliceResultType, mlir::ValueRange{inputSlice, weightSlice}, biasUnsliced, stridesAttr, dilationsAttr, llvm::ArrayRef({addOpAttr, mulOpAttr})) .getResult(0); // insert result of a single conv in the final result finalResult = rewriter .create( conv2dOp.getLoc(), convResult, finalResult, llvm::SmallVector{ rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(g * sliceResultSizes[1]), rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(0), }, llvm::SmallVector{ 
rewriter.getI64IntegerAttr(sliceResultSizes[0]), rewriter.getI64IntegerAttr(sliceResultSizes[1]), rewriter.getI64IntegerAttr(sliceResultSizes[2]), rewriter.getI64IntegerAttr(sliceResultSizes[3]), }, llvm::SmallVector{ rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), }) .getResult(); } rewriter.replaceOp(conv2dOp, finalResult); return mlir::success(); } /// This rewrite pattern transforms any instance of operators /// `FHELinalg.conv2d` to one or multiple instances of /// `linalg.conv_2d_nchw_fchw`. The transformation consists of padding the input /// tensor, and initializing the output tensor with bias values if any. Multiple /// linalng conv operations can be generated, and their output concatenated in /// the case of grouped convolution struct FHELinalgConv2dToLinalgConv2d : public ::mlir::OpRewritePattern { FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context) : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Conv2dOp>( context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult matchAndRewrite(::mlir::concretelang::FHELinalg::Conv2dOp conv2dOp, ::mlir::PatternRewriter &rewriter) const override { mlir::Location loc = conv2dOp->getLoc(); mlir::Value input = conv2dOp.input(); /* shape: Batch*Channels*Height*Width */ mlir::Value weight = conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */ mlir::Type inputElementType = input.getType().cast().getElementType(); // Attriutes are assumed to be correct after passing the verification mlir::SmallVector paddingInts = mlir::concretelang::FHELinalg::getPaddingFromConv2d(conv2dOp); mlir::SmallVector stridesInts = mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp); mlir::SmallVector dilationsInts = mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp); int64_t group = mlir::concretelang::FHELinalg::getGroupFromConv2d(conv2dOp); // Pad the input tensor according to padding. mlir::SmallVector lowPaddingIncludingNC = {0, 0}; lowPaddingIncludingNC.insert(lowPaddingIncludingNC.end(), paddingInts.begin() + 2, paddingInts.end()); mlir::SmallVector highPaddingIncludingNC = {0, 0}; highPaddingIncludingNC.insert(highPaddingIncludingNC.end(), paddingInts.begin(), paddingInts.begin() + 2); mlir::Value paddingValue = rewriter.create( loc, input.getType().cast().getElementType()); mlir::Value paddedInput = getPaddedTensor(conv2dOp, rewriter, input, lowPaddingIncludingNC, highPaddingIncludingNC, paddingValue); // TODO(Optimization): output tensor is being constructed in two different // ways, depending of whether there is a bias or not: // 1- There is no bias: we initialize the output tensor to encryptions of // zero // 2- There is a bias: we initialize the output tensor to encryptions of // zeros, then we add bias values. 
// For the second case, it can be done by initializing the output to // encryption of bias values directly mlir::Value initTensor = rewriter.create( loc, mlir::RankedTensorType::get(conv2dOp.getResult() .getType() .cast() .getShape(), inputElementType)); // Since linalg doesn't support a bias in the conv operation, we initialize // the output tensor to the bias values, so that conv results get // accumulated to it mlir::Value bias = conv2dOp.bias(); /* optional of shape: Filters */ mlir::Value biasInitTensor; if (!bias) { // no bias was used biasInitTensor = initTensor; } else { // Fill the output tensor with bias values auto resultRank = initTensor.getType().cast().getRank(); mlir::SmallVector indexingMaps = { mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1), rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultRank)}; mlir::SmallVector iteratorTypes(resultRank, "parallel"); biasInitTensor = rewriter .create( loc, initTensor.getType(), bias, initTensor, indexingMaps, iteratorTypes, [](mlir::OpBuilder &b, mlir::Location loc, mlir::ValueRange args) { mlir::Value encryptedBias = b.create( loc, args[1], args[0]) .getResult(); b.create(loc, encryptedBias); }) .getResult(0); } auto stridesAttr = rewriter.getI64VectorAttr(stridesInts); auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts); // we can directly use linalg::Conv2DNchwFchwOp if group is equal to 1, but // since there is no support for groups in linalg conv operations, we need // to slice the different tensors and apply multiple convolution in case // group is greater than 1 if (group == 1) { auto addOpAttr = rewriter.getNamedAttr( "add", rewriter.getStringAttr( mlir::concretelang::FHE::AddEintOp::getOperationName())); auto mulOpAttr = rewriter.getNamedAttr( "mul", rewriter.getStringAttr( mlir::concretelang::FHE::MulEintIntOp::getOperationName())); rewriter.replaceOpWithNewOp( conv2dOp, biasInitTensor.getType(), mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr, dilationsAttr, llvm::ArrayRef({addOpAttr, mulOpAttr})); return mlir::success(); } return createGroupedConv2D(rewriter, conv2dOp, paddedInput, weight, biasInitTensor, stridesAttr, dilationsAttr, group); }; }; /// This rewrite pattern transforms all instances /// of `FHELinalg.maxpool2d` to `linalg.pooling_ncw_max`. 
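/// For signed element types, a constant offset is first added to the
/// zero-initialized pooling accumulator before the pooling operation is
/// created.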
struct FHELinalgMaxpool2dToLinalgMaxpool2d
    : public mlir::OpRewritePattern<FHELinalg::Maxpool2dOp> {
  FHELinalgMaxpool2dToLinalgMaxpool2d(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<FHELinalg::Maxpool2dOp>(context) {}

  mlir::LogicalResult
  matchAndRewrite(FHELinalg::Maxpool2dOp maxpool2dOp,
                  mlir::PatternRewriter &rewriter) const override {
    const mlir::Location loc = maxpool2dOp->getLoc();

    const mlir::NamedAttribute maxOpAttr = rewriter.getNamedAttr(
        "max_signed",
        rewriter.getStringAttr(FHE::MaxEintOp::getOperationName()));

    const auto outputTy =
        maxpool2dOp->getResult(0).getType().cast<mlir::RankedTensorType>();
    const auto outputElementTy =
        outputTy.getElementType().cast<FHE::FheIntegerInterface>();

    mlir::Value output =
        rewriter.create<FHE::ZeroTensorOp>(loc, outputTy).getResult();

    if (outputElementTy.isSigned()) {
      // For signed outputs, offset the zero-initialized accumulator before
      // pooling.
      const int64_t outputBitWidth = outputElementTy.getWidth();
      const int64_t offsetValue = 1 << (outputBitWidth - 2);

      const mlir::Type offsetType =
          mlir::IntegerType::get(this->getContext(), outputBitWidth + 1);
      const mlir::RankedTensorType offsetTensorType =
          mlir::RankedTensorType::get({1}, offsetType);

      const llvm::SmallVector<mlir::Attribute> offsetTensorAttr = {
          mlir::IntegerAttr::get(offsetType, offsetValue)};
      const mlir::Attribute offsetAttr =
          mlir::DenseElementsAttr::get(offsetTensorType, offsetTensorAttr);

      const mlir::Value offset =
          rewriter.create<arith::ConstantOp>(loc, offsetAttr);

      output = rewriter.create<FHELinalg::AddEintIntOp>(loc, output, offset);
    }

    const mlir::DenseElementsAttr kernelShapeAttr = maxpool2dOp.kernel_shape();
    const auto kernelShape =
        llvm::SmallVector<int64_t>(kernelShapeAttr.value_begin<int64_t>(),
                                   kernelShapeAttr.value_end<int64_t>());

    const mlir::Value kernel =
        rewriter
            .create<linalg::InitTensorOp>(
                loc, kernelShape,
                mlir::IntegerType::get(this->getContext(), 64))
            .getResult();

    const mlir::DenseIntElementsAttr defaultAttr =
        rewriter.getI64VectorAttr({1, 1});
    const mlir::DenseIntElementsAttr stridesAttr =
        maxpool2dOp.strides().getValueOr(defaultAttr);
    const mlir::DenseIntElementsAttr dilationsAttr =
        maxpool2dOp.dilations().getValueOr(defaultAttr);

    rewriter.replaceOpWithNewOp<linalg::PoolingNchwMaxOp>(
        maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.input(), kernel},
        output, stridesAttr, dilationsAttr,
        llvm::ArrayRef<mlir::NamedAttribute>({maxOpAttr}));

    return mlir::success();
  };
};

/// This template rewrite pattern transforms any instance of the
/// `FHELinalg.to_signed` operator to an instance of `linalg.generic` with an
/// appropriate region using the `FHE.to_signed` operation, an appropriate
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// Example:
///
///   FHELinalg.to_signed(%tensor):
///     tensor<DNx...xD1x!FHE.eint<p>> -> tensor<DNx...xD1x!FHE.esint<p>>
///
/// becomes:
///
///   #maps = [
///     affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
///     affine_map<(aN, ..., a1) -> (aN, ..., a1)>
///   ]
///   #attributes {
///     indexing_maps = #maps,
///     iterator_types = ["parallel", "parallel"],
///   }
///
///   %init = linalg.init_tensor [DN,...,D1] : tensor<DNx...xD1x!FHE.esint<p>>
///   %result = linalg.generic {
///     ins(%tensor: tensor<DNx...xD1x!FHE.eint<p>>)
///     outs(%init: tensor<DNx...xD1x!FHE.esint<p>>)
///     {
///       ^bb0(%arg0: !FHE.eint<p>):
///         %0 = FHE.to_signed(%arg0): !FHE.eint<p> -> !FHE.esint<p>
///         linalg.yield %0 : !FHE.esint<p>
///     }
///   }
///
struct FHELinalgToSignedToLinalgGeneric
    : public mlir::OpRewritePattern<FHELinalg::ToSignedOp> {
  FHELinalgToSignedToLinalgGeneric(
      mlir::MLIRContext *context,
      mlir::PatternBenefit benefit =
          mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
      : mlir::OpRewritePattern<FHELinalg::ToSignedOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHELinalg::ToSignedOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType inputTy =
        op.input().getType().cast<mlir::RankedTensorType>();
    mlir::RankedTensorType resultTy =
        op->getResult(0).getType().cast<mlir::RankedTensorType>();

    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        op.getLoc(), resultTy, mlir::ValueRange{});

    llvm::SmallVector<mlir::AffineMap> maps{
        mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
                                                this->getContext()),
        mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
                                                this->getContext()),
    };

    llvm::SmallVector<llvm::StringRef> iteratorTypes(
        resultTy.getShape().size(), "parallel");

    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      auto fheOp = nestedBuilder.create<FHE::ToSignedOp>(
          op.getLoc(), resultTy.getElementType(), blockArgs[0]);

      nestedBuilder.create<linalg::YieldOp>(op.getLoc(), fheOp.getResult());
    };

    llvm::SmallVector<mlir::Type> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value> ins{op.input()};
    llvm::SmallVector<mlir::Value> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        op.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call,
        bodyBuilder);

    rewriter.replaceOp(op, {genericOp.getResult(0)});

    return mlir::success();
  };
};

/// This template rewrite pattern transforms any instance of the
/// `FHELinalg.to_unsigned` operator to an instance of `linalg.generic` with
/// an appropriate region using the `FHE.to_unsigned` operation, an
/// appropriate specification for the iteration dimensions and appropriate
/// operations managing the accumulator of `linalg.generic`.
///
/// Example:
///
///   FHELinalg.to_unsigned(%tensor):
///     tensor<DNx...xD1x!FHE.esint<p>> -> tensor<DNx...xD1x!FHE.eint<p>>
///
/// becomes:
///
///   #maps = [
///     affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
///     affine_map<(aN, ..., a1) -> (aN, ..., a1)>
///   ]
///   #attributes {
///     indexing_maps = #maps,
///     iterator_types = ["parallel", "parallel"],
///   }
///
///   %init = linalg.init_tensor [DN,...,D1] : tensor<DNx...xD1x!FHE.eint<p>>
///   %result = linalg.generic {
///     ins(%tensor: tensor<DNx...xD1x!FHE.esint<p>>)
///     outs(%init: tensor<DNx...xD1x!FHE.eint<p>>)
///     {
///       ^bb0(%arg0: !FHE.esint<p>):
///         %0 = FHE.to_unsigned(%arg0): !FHE.esint<p> -> !FHE.eint<p>
///         linalg.yield %0 : !FHE.eint<p>
///     }
///   }
///
struct FHELinalgToUnsignedToLinalgGeneric
    : public mlir::OpRewritePattern<FHELinalg::ToUnsignedOp> {
  FHELinalgToUnsignedToLinalgGeneric(
      mlir::MLIRContext *context,
      mlir::PatternBenefit benefit =
          mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
      : mlir::OpRewritePattern<FHELinalg::ToUnsignedOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHELinalg::ToUnsignedOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType inputTy =
        op.input().getType().cast<mlir::RankedTensorType>();
    mlir::RankedTensorType resultTy =
        op->getResult(0).getType().cast<mlir::RankedTensorType>();

    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        op.getLoc(), resultTy, mlir::ValueRange{});

    llvm::SmallVector<mlir::AffineMap> maps{
        mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
                                                this->getContext()),
        mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
                                                this->getContext()),
    };

    llvm::SmallVector<llvm::StringRef> iteratorTypes(
        resultTy.getShape().size(), "parallel");

    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      auto fheOp = nestedBuilder.create<FHE::ToUnsignedOp>(
          op.getLoc(), resultTy.getElementType(), blockArgs[0]);

      nestedBuilder.create<linalg::YieldOp>(op.getLoc(), fheOp.getResult());
    };

    llvm::SmallVector<mlir::Type> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value> ins{op.input()};
    llvm::SmallVector<mlir::Value> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        op.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call,
        bodyBuilder);

    rewriter.replaceOp(op, {genericOp.getResult(0)});

    return mlir::success();
  };
};

namespace {
struct FHETensorOpsToLinalg
    : public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
  void runOnOperation() final;
};

void FHETensorOpsToLinalg::runOnOperation() {
  mlir::func::FuncOp function = this->getOperation();

  mlir::ConversionTarget target(getContext());

  target.addLegalDialect();
  target.addLegalDialect();
  target.addLegalDialect();
  target.addLegalDialect();
  target.addLegalDialect();
  target.addIllegalOp();
  target.addIllegalDialect();
  target.addLegalOp();

  mlir::RewritePatternSet patterns(&getContext());
  patterns.insert(&getContext());
  patterns.insert<FHELinalgOpToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgOpToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgOpToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgOpToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgOpToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgOpToLinalgGeneric>(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert>(
      &getContext(),
      [](mlir::OpBuilder &builder, mlir::Location loc, mlir::Type type,
         mlir::Value arg0, mlir::Value arg1) {
        return builder.create(loc, type, arg0, arg1);
      });
  patterns.insert>(
      &getContext(),
      [](mlir::OpBuilder &builder, mlir::Location loc, mlir::Type type,
         mlir::Value arg0, mlir::Value arg1) {
        return builder.create(loc, type, arg1, arg0);
      });
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());
  patterns.insert(&getContext());

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())
    this->signalPassFailure();
}
} // namespace

namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createConvertFHETensorOpsToLinalg() {
  return std::make_unique<FHETensorOpsToLinalg>();
}
} // namespace concretelang
} // namespace mlir
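
// Usage sketch (an illustrative assumption, not part of this file's build):
// the pass created above is typically scheduled on `func::FuncOp` operations
// from a pass pipeline, e.g. with a hypothetical `context` and `module`:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::concretelang::createConvertFHETensorOpsToLinalg());
//   if (mlir::failed(pm.run(module)))
//     llvm::errs() << "FHELinalg -> linalg lowering failed\n";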