From e24dbec24920653ad07ff31e005a39c29c207123 Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 5 Jul 2022 15:59:39 +0200 Subject: [PATCH] feat: create FHELinalg.from_element operation This commit is introduced because python bindings for `tensor.from_elements` are not generated automatically. Previously, we overcame this with string manipulation, but with the latest version of the compiler, it became a problem. This commit should be reverted eventually. See https://discourse.llvm.org/t/cannot-create-tensor-from-elements-operation-from-python-bindings/4768 for the discussion in LLVM forums. --- .../Dialect/FHELinalg/IR/FHELinalgOps.td | 16 +++++++++ .../TensorOpsToLinalg.cpp | 36 +++++++++++++++++++ compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 16 +++++++++ .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 17 +++++++++ .../FHELinalgToLinalg/from_element.mlir | 23 ++++++++++++ 5 files changed, 108 insertions(+) create mode 100644 compiler/tests/check_tests/Conversion/FHELinalgToLinalg/from_element.mlir diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 271b1313e..8c64c080c 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -1119,5 +1119,21 @@ def TransposeOp : FHELinalg_Op<"transpose", []> { let hasVerifier = 1; } +def FromElementOp : FHELinalg_Op<"from_element", []> { + let summary = "Creates a tensor with a single element."; + + let description = [{ + Creates a tensor with a single element. + + ```mlir + "FHELinalg.from_element"(%a) : (Type) -> tensor<1xType> + ``` + }]; + + let arguments = (ins AnyType); + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let hasVerifier = 1; +} #endif diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 2539a46f7..1e4126fec 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1324,6 +1324,41 @@ struct TransposeToLinalgGeneric }; }; +/// This rewrite pattern transforms any instance of operators +/// `FHELinalg.from_element` to an instance of `tensor.from_elements`. +/// +/// Example: +/// +/// %result = "FHELinalg.from_element"(%x) : (Type) -> tensor<1xType> +/// +/// becomes: +/// +/// %result = tensor.from_elements %x : (Type) -> tensor<1xType> +/// +struct FromElementToTensorFromElements + : public ::mlir::OpRewritePattern< + mlir::concretelang::FHELinalg::FromElementOp> { + + FromElementToTensorFromElements(::mlir::MLIRContext *context) + : ::mlir::OpRewritePattern< + ::mlir::concretelang::FHELinalg::FromElementOp>( + context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(::mlir::concretelang::FHELinalg::FromElementOp op, + ::mlir::PatternRewriter &rewriter) const override { + auto in = op.getOperand(); + auto out = op.getResult(); + + mlir::Value result = + rewriter.create(op.getLoc(), out.getType(), in) + .getResult(); + + rewriter.replaceOp(op, {result}); + return mlir::success(); + }; +}; + /// This rewrite pattern transforms any instance of operators /// `FHELinalg.concat` to instances of `tensor.insert_slice` /// @@ -1647,6 +1682,7 @@ void FHETensorOpsToLinalg::runOnOperation() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); if (mlir::applyPartialConversion(function, target, std::move(patterns)) .failed()) diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 79fd4eb6a..39bce1179 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -1030,6 +1030,18 @@ static llvm::APInt getSqMANP( return eNorm; } +static llvm::APInt getSqMANP( + FHELinalg::FromElementOp op, + llvm::ArrayRef *> operandMANPs) { + + auto manp = operandMANPs[0]->getValue().getMANP(); + if (manp.hasValue()) { + return manp.getValue(); + } + + return llvm::APInt{1, 1, false}; +} + static llvm::APInt getSqMANP( mlir::tensor::FromElementsOp op, llvm::ArrayRef *> operandMANPs) { @@ -1367,6 +1379,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(conv2dOp, operands); + } else if (auto fromElementOp = + llvm::dyn_cast( + op)) { + norm2SqEquiv = getSqMANP(fromElementOp, operands); } else if (auto transposeOp = llvm::dyn_cast( op)) { diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 9bcbcf615..e60cffb41 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -999,6 +999,23 @@ mlir::LogicalResult Conv2dOp::verify() { return mlir::success(); } +mlir::LogicalResult FromElementOp::verify() { + mlir::Value in = this->getOperand(); + mlir::Value out = this->getResult(); + + auto inType = in.getType(); + auto outType = out.getType().dyn_cast(); + + auto expectedOutType = outType.cloneWith({1}, inType); + if (outType != expectedOutType) { + this->emitOpError() << "has invalid output type (expected " + << expectedOutType << ", got " << outType << ")"; + return mlir::failure(); + } + + return mlir::success(); +} + //===----------------------------------------------------------------------===// // Implementation of FhelinalgConv2DNchwFchwOp // This is a generated functions from `make generate_conv_op`, and some helpers diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/from_element.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/from_element.mlir new file mode 100644 index 000000000..78543be84 --- /dev/null +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/from_element.mlir @@ -0,0 +1,23 @@ +// RUN: concretecompiler %s --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s + +// ----- + +// CHECK: func @main(%[[a0:.*]]: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.from_elements %[[a0]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> { + %1 = "FHELinalg.from_element"(%arg0) : (!FHE.eint<7>) -> tensor<1x!FHE.eint<7>> + return %1 : tensor<1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: i8) -> tensor<1xi8> { +// CHECK-NEXT: %[[v0:.*]] = tensor.from_elements %[[a0]] : tensor<1xi8> +// CHECK-NEXT: return %[[v0]] : tensor<1xi8> +// CHECK-NEXT: } +func @main(%arg0: i8) -> tensor<1xi8> { + %1 = "FHELinalg.from_element"(%arg0) : (i8) -> tensor<1xi8> + return %1 : tensor<1xi8> +}