mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: lower FHELinalg.transpose to linalg.generic
This commit is contained in:
@@ -1243,6 +1243,86 @@ struct SumToLinalgGeneric
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of operators
|
||||
// `FHELinalg.transpose` to an instance of `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %result = "FHELinalg.transpose"(%input: tensor<d0xd1x...xdNx!FHE.eint<p>>)
|
||||
// -> tensor<dNx...xd1xd0x!FHE.eint<p>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// #map0 = affine_map<(i0, i1, ..., iN) -> (iN, ..., i1, i0)>
|
||||
// #map1 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
|
||||
//
|
||||
// %accumulator = "FHE.zero_tensor"() : () ->
|
||||
// tensor<dNx...xd1xd0x!FHE.eint<6>> %result = linalg.generic
|
||||
// {
|
||||
// indexing_maps = [#map0, #map1],
|
||||
// iterator_types = ["parallel", "parallel", ..., "parallel"]
|
||||
// }
|
||||
// ins(%input : tensor<d0xd1x...xdNx!FHE.eint<7>>)
|
||||
// outs(%accumulator : tensor<dNx...xd1xd0x!FHE.eint<7>>)
|
||||
// {
|
||||
// ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
|
||||
// linalg.yield %a : !FHE.eint<7>
|
||||
// } -> tensor<dNx...xd1xd0x!FHE.eint<7>>
|
||||
//
|
||||
struct TransposeToLinalgGeneric
|
||||
: public ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::FHELinalg::TransposeOp> {
|
||||
TransposeToLinalgGeneric(::mlir::MLIRContext *context)
|
||||
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::TransposeOp>(
|
||||
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(::mlir::concretelang::FHELinalg::TransposeOp transposeOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Value input = transposeOp.getOperand();
|
||||
mlir::Value output = transposeOp.getResult();
|
||||
auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto outputType = output.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::Location location = transposeOp.getLoc();
|
||||
// Initialize empty tensor to fill with transpose result
|
||||
mlir::Value zeroTensor =
|
||||
rewriter.create<FHE::ZeroTensorOp>(location, outputType).getResult();
|
||||
|
||||
// Inverted dimensions to create a transposition
|
||||
std::vector<unsigned int> perms = {};
|
||||
auto n_dim = inputType.getShape().size();
|
||||
for (int i = n_dim - 1; i >= 0; i--)
|
||||
perms.push_back(i);
|
||||
|
||||
llvm::SmallVector<mlir::Type, 1> resultTypes{zeroTensor.getType()};
|
||||
auto ins = llvm::SmallVector<mlir::Value, 1>{input};
|
||||
auto outs = llvm::SmallVector<mlir::Value, 1>{zeroTensor};
|
||||
llvm::SmallVector<mlir::AffineMap, 2> maps{
|
||||
mlir::AffineMap::getPermutationMap(perms, this->getContext()),
|
||||
mlir::AffineMap::getMultiDimIdentityMap(n_dim, this->getContext()),
|
||||
};
|
||||
auto iteratorTypes = parallelIteratorType(n_dim);
|
||||
// The maps will be responsible for changing item positions, we just return
|
||||
// items here
|
||||
auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::Value item = blockArgs[0];
|
||||
nestedBuilder.create<linalg::YieldOp>(location, item);
|
||||
};
|
||||
mlir::Value result =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
|
||||
iteratorTypes, regionBuilder)
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOp(transposeOp, {result});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of operators
|
||||
// `FHELinalg.concat` to instances of `tensor.insert_slice`
|
||||
//
|
||||
@@ -1558,6 +1638,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
|
||||
patterns.insert<SumToLinalgGeneric>(&getContext());
|
||||
patterns.insert<ConcatRewritePattern>(&getContext());
|
||||
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
|
||||
patterns.insert<TransposeToLinalgGeneric>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
.failed())
|
||||
|
||||
Reference in New Issue
Block a user