feat: lower FHELinalg.transpose to linalg.generic

This commit is contained in:
youben11
2022-03-22 11:20:02 +01:00
committed by Ayoub Benaissa
parent 4e64b9e12a
commit 77356fa374

View File

@@ -1243,6 +1243,86 @@ struct SumToLinalgGeneric
};
};
// This rewrite pattern transforms any instance of operators
// `FHELinalg.transpose` to an instance of `linalg.generic`.
//
// Example:
//
// %result = "FHELinalg.transpose"(%input: tensor<d0xd1x...xdNx!FHE.eint<p>>)
// -> tensor<dNx...xd1xd0x!FHE.eint<p>
//
// becomes:
//
// #map0 = affine_map<(i0, i1, ..., iN) -> (iN, ..., i1, i0)>
// #map1 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
//
// %accumulator = "FHE.zero_tensor"() : () ->
// tensor<dNx...xd1xd0x!FHE.eint<6>> %result = linalg.generic
// {
// indexing_maps = [#map0, #map1],
// iterator_types = ["parallel", "parallel", ..., "parallel"]
// }
// ins(%input : tensor<d0xd1x...xdNx!FHE.eint<7>>)
// outs(%accumulator : tensor<dNx...xd1xd0x!FHE.eint<7>>)
// {
// ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
// linalg.yield %a : !FHE.eint<7>
// } -> tensor<dNx...xd1xd0x!FHE.eint<7>>
//
struct TransposeToLinalgGeneric
: public ::mlir::OpRewritePattern<
mlir::concretelang::FHELinalg::TransposeOp> {
TransposeToLinalgGeneric(::mlir::MLIRContext *context)
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::TransposeOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
::mlir::LogicalResult
matchAndRewrite(::mlir::concretelang::FHELinalg::TransposeOp transposeOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::Value input = transposeOp.getOperand();
mlir::Value output = transposeOp.getResult();
auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
auto outputType = output.getType().dyn_cast<mlir::RankedTensorType>();
mlir::Location location = transposeOp.getLoc();
// Initialize empty tensor to fill with transpose result
mlir::Value zeroTensor =
rewriter.create<FHE::ZeroTensorOp>(location, outputType).getResult();
// Inverted dimensions to create a transposition
std::vector<unsigned int> perms = {};
auto n_dim = inputType.getShape().size();
for (int i = n_dim - 1; i >= 0; i--)
perms.push_back(i);
llvm::SmallVector<mlir::Type, 1> resultTypes{zeroTensor.getType()};
auto ins = llvm::SmallVector<mlir::Value, 1>{input};
auto outs = llvm::SmallVector<mlir::Value, 1>{zeroTensor};
llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getPermutationMap(perms, this->getContext()),
mlir::AffineMap::getMultiDimIdentityMap(n_dim, this->getContext()),
};
auto iteratorTypes = parallelIteratorType(n_dim);
// The maps will be responsible for changing item positions, we just return
// items here
auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::Value item = blockArgs[0];
nestedBuilder.create<linalg::YieldOp>(location, item);
};
mlir::Value result =
rewriter
.create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
iteratorTypes, regionBuilder)
.getResult(0);
rewriter.replaceOp(transposeOp, {result});
return mlir::success();
};
};
// This rewrite pattern transforms any instance of operators
// `FHELinalg.concat` to instances of `tensor.insert_slice`
//
@@ -1558,6 +1638,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
patterns.insert<SumToLinalgGeneric>(&getContext());
patterns.insert<ConcatRewritePattern>(&getContext());
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
patterns.insert<TransposeToLinalgGeneric>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())