feat: create FHELinalg.from_element operation

This commit is needed because Python bindings for `tensor.from_elements` are not generated automatically. Previously, we worked around this with string manipulation, but with the latest version of the compiler that approach became a problem. This commit should eventually be reverted. See https://discourse.llvm.org/t/cannot-create-tensor-from-elements-operation-from-python-bindings/4768 for the discussion on the LLVM Discourse forum.
Author: Umut
Date:   2022-07-05 15:59:39 +02:00
Commit: e24dbec249 (parent f4166a4973)

5 changed files with 108 additions and 0 deletions
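The idea is that `FHELinalg.from_element` gets its Python bindings generated along with the rest of the FHELinalg dialect and is then rewritten to the existing `tensor` op during lowering. A minimal sketch of the round trip, using the element type from the new test below (illustrative, not compiler output):

```mlir
// The new dedicated op:
%t = "FHELinalg.from_element"(%x) : (!FHE.eint<7>) -> tensor<1x!FHE.eint<7>>

// After the fhe-tensor-ops-to-linalg lowering, it becomes the upstream op
// whose Python bindings are not generated:
%t = tensor.from_elements %x : tensor<1x!FHE.eint<7>>
```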


@@ -1119,5 +1119,21 @@ def TransposeOp : FHELinalg_Op<"transpose", []> {
  let hasVerifier = 1;
}

def FromElementOp : FHELinalg_Op<"from_element", []> {
  let summary = "Creates a tensor with a single element.";

  let description = [{
    Creates a tensor with a single element.

    ```mlir
    "FHELinalg.from_element"(%a) : (Type) -> tensor<1xType>
    ```
  }];

  let arguments = (ins AnyType);
  let results = (outs Type<And<[TensorOf<[AnyType]>.predicate, HasStaticShapePred]>>);

  let hasVerifier = 1;
}
#endif


@@ -1324,6 +1324,41 @@ struct TransposeToLinalgGeneric
  };
};
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.from_element` to an instance of `tensor.from_elements`.
///
/// Example:
///
///   %result = "FHELinalg.from_element"(%x) : (Type) -> tensor<1xType>
///
/// becomes:
///
///   %result = tensor.from_elements %x : tensor<1xType>
///
struct FromElementToTensorFromElements
    : public ::mlir::OpRewritePattern<
          mlir::concretelang::FHELinalg::FromElementOp> {
  FromElementToTensorFromElements(::mlir::MLIRContext *context)
      : ::mlir::OpRewritePattern<
            ::mlir::concretelang::FHELinalg::FromElementOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(::mlir::concretelang::FHELinalg::FromElementOp op,
                  ::mlir::PatternRewriter &rewriter) const override {
    auto in = op.getOperand();
    auto out = op.getResult();

    mlir::Value result =
        rewriter.create<tensor::FromElementsOp>(op.getLoc(), out.getType(), in)
            .getResult();

    rewriter.replaceOp(op, {result});

    return mlir::success();
  };
};
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.concat` to instances of `tensor.insert_slice`
///
@@ -1647,6 +1682,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
  patterns.insert<ConcatRewritePattern>(&getContext());
  patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
  patterns.insert<TransposeToLinalgGeneric>(&getContext());
  patterns.insert<FromElementToTensorFromElements>(&getContext());

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())

@@ -1030,6 +1030,18 @@ static llvm::APInt getSqMANP(
  return eNorm;
}
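/// The squared MANP of a one-element tensor is simply that of the element it
/// was built from; if the operand carries no MANP value (e.g., it is not an
/// encrypted value), it defaults to 1.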
static llvm::APInt getSqMANP(
    FHELinalg::FromElementOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  auto manp = operandMANPs[0]->getValue().getMANP();

  if (manp.hasValue()) {
    return manp.getValue();
  }

  return llvm::APInt{1, 1, false};
}

static llvm::APInt getSqMANP(
    mlir::tensor::FromElementsOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -1367,6 +1379,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
                       op)) {
      norm2SqEquiv = getSqMANP(conv2dOp, operands);
    } else if (auto fromElementOp =
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::FromElementOp>(
                       op)) {
      norm2SqEquiv = getSqMANP(fromElementOp, operands);
    } else if (auto transposeOp =
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
                       op)) {


@@ -999,6 +999,23 @@ mlir::LogicalResult Conv2dOp::verify() {
  return mlir::success();
}

mlir::LogicalResult FromElementOp::verify() {
  mlir::Value in = this->getOperand();
  mlir::Value out = this->getResult();

  auto inType = in.getType();
  auto outType = out.getType().dyn_cast<mlir::TensorType>();

  auto expectedOutType = outType.cloneWith({1}, inType);
  if (outType != expectedOutType) {
    this->emitOpError() << "has invalid output type (expected "
                        << expectedOutType << ", got " << outType << ")";
    return mlir::failure();
  }

  return mlir::success();
}
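To illustrate what this verifier enforces (the snippet below is illustrative and not part of the commit), the result must be a statically shaped one-element tensor of the operand's type:

```mlir
// Accepted: one-element tensor of the operand type.
%ok  = "FHELinalg.from_element"(%x) : (!FHE.eint<7>) -> tensor<1x!FHE.eint<7>>

// Rejected by FromElementOp::verify(): expected tensor<1x!FHE.eint<7>>.
%bad = "FHELinalg.from_element"(%x) : (!FHE.eint<7>) -> tensor<2x!FHE.eint<7>>
```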
//===----------------------------------------------------------------------===//
// Implementation of FhelinalgConv2DNchwFchwOp
// This is a generated functions from `make generate_conv_op`, and some helpers


@@ -0,0 +1,23 @@
// RUN: concretecompiler %s --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

// -----

// CHECK: func @main(%[[a0:.*]]: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.from_elements %[[a0]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
  %1 = "FHELinalg.from_element"(%arg0) : (!FHE.eint<7>) -> tensor<1x!FHE.eint<7>>
  return %1 : tensor<1x!FHE.eint<7>>
}

// -----

// CHECK: func @main(%[[a0:.*]]: i8) -> tensor<1xi8> {
// CHECK-NEXT: %[[v0:.*]] = tensor.from_elements %[[a0]] : tensor<1xi8>
// CHECK-NEXT: return %[[v0]] : tensor<1xi8>
// CHECK-NEXT: }
func @main(%arg0: i8) -> tensor<1xi8> {
  %1 = "FHELinalg.from_element"(%arg0) : (i8) -> tensor<1xi8>
  return %1 : tensor<1xi8>
}