diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 85b2153a5..97df64aa7 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -1243,6 +1243,88 @@ struct SumToLinalgGeneric
   };
 };
 
+// This rewrite pattern transforms any instance of operators
+// `FHELinalg.transpose` to an instance of `linalg.generic`.
+//
+// Example:
+//
+//   %result = "FHELinalg.transpose"(%input: tensor<2x10x!FHE.eint<7>>)
+//     -> tensor<10x2x!FHE.eint<7>>
+//
+// becomes:
+//
+// #map0 = affine_map<(i0, i1, ..., iN) -> (iN, ..., i1, i0)>
+// #map1 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
+//
+// %accumulator = "FHE.zero_tensor"() : () -> tensor<10x2x!FHE.eint<7>>
+// %result = linalg.generic
+//   {
+//     indexing_maps = [#map0, #map1],
+//     iterator_types = ["parallel", "parallel", ..., "parallel"]
+//   }
+//   ins(%input : tensor<2x10x!FHE.eint<7>>)
+//   outs(%accumulator : tensor<10x2x!FHE.eint<7>>)
+//   {
+//     ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
+//       linalg.yield %a : !FHE.eint<7>
+//   } -> tensor<10x2x!FHE.eint<7>>
+//
+struct TransposeToLinalgGeneric
+    : public ::mlir::OpRewritePattern<
+          mlir::concretelang::FHELinalg::TransposeOp> {
+  TransposeToLinalgGeneric(::mlir::MLIRContext *context)
+      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::TransposeOp>(
+            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(::mlir::concretelang::FHELinalg::TransposeOp transposeOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+
+    mlir::Value input = transposeOp.getOperand();
+    mlir::Value output = transposeOp.getResult();
+    auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<mlir::RankedTensorType>();
+
+    mlir::Location location = transposeOp.getLoc();
+    // Initialize empty tensor to fill with transpose result
+    mlir::Value zeroTensor =
+        rewriter
+            .create<mlir::concretelang::FHE::ZeroTensorOp>(location, outputType)
+            .getResult();
+
+    // Inverted dimensions to create a transposition
+    std::vector<unsigned int> perms = {};
+    auto n_dim = inputType.getShape().size();
+    for (int i = n_dim - 1; i >= 0; i--)
+      perms.push_back(i);
+
+    llvm::SmallVector<mlir::Type, 1> resultTypes{zeroTensor.getType()};
+    auto ins = llvm::SmallVector<mlir::Value, 1>{input};
+    auto outs = llvm::SmallVector<mlir::Value, 1>{zeroTensor};
+    llvm::SmallVector<mlir::AffineMap, 2> maps{
+        mlir::AffineMap::getPermutationMap(perms, this->getContext()),
+        mlir::AffineMap::getMultiDimIdentityMap(n_dim, this->getContext()),
+    };
+    auto iteratorTypes = parallelIteratorType(n_dim);
+    // The maps will be responsible for changing item positions, we just return
+    // items here
+    auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
+                             mlir::Location nestedLoc,
+                             mlir::ValueRange blockArgs) {
+      mlir::Value item = blockArgs[0];
+      nestedBuilder.create<mlir::linalg::YieldOp>(location, item);
+    };
+    mlir::Value result =
+        rewriter
+            .create<mlir::linalg::GenericOp>(location, resultTypes, ins, outs,
+                                             maps, iteratorTypes, regionBuilder)
+            .getResult(0);
+
+    rewriter.replaceOp(transposeOp, {result});
+    return mlir::success();
+  };
+};
+
 // This rewrite pattern transforms any instance of operators
 // `FHELinalg.concat` to instances of `tensor.insert_slice`
 //
@@ -1558,6 +1640,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
   patterns.insert(&getContext());
   patterns.insert(&getContext());
   patterns.insert(&getContext());
+  patterns.insert<TransposeToLinalgGeneric>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())