mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: create FHELinalg.from_element operation
This commit is introduced because python bindings for `tensor.from_elements` are not generated automatically. Previously, we overcame this with string manipulation, but with the latest version of the compiler, it became a problem. This commit should be reverted eventually. See https://discourse.llvm.org/t/cannot-create-tensor-from-elements-operation-from-python-bindings/4768 for the discussion in LLVM forums.
This commit is contained in:
@@ -1119,5 +1119,21 @@ def TransposeOp : FHELinalg_Op<"transpose", []> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FromElementOp : FHELinalg_Op<"from_element", []> {
|
||||
let summary = "Creates a tensor with a single element.";
|
||||
|
||||
let description = [{
|
||||
Creates a tensor with a single element.
|
||||
|
||||
```mlir
|
||||
"FHELinalg.from_element"(%a) : (Type) -> tensor<1xType>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyType);
|
||||
let results = (outs Type<And<[TensorOf<[AnyType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1324,6 +1324,41 @@ struct TransposeToLinalgGeneric
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of operators
|
||||
/// `FHELinalg.from_element` to an instance of `tensor.from_elements`.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// %result = "FHELinalg.from_element"(%x) : (Type) -> tensor<1xType>
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// %result = tensor.from_elements %x : (Type) -> tensor<1xType>
|
||||
///
|
||||
struct FromElementToTensorFromElements
|
||||
: public ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::FHELinalg::FromElementOp> {
|
||||
|
||||
FromElementToTensorFromElements(::mlir::MLIRContext *context)
|
||||
: ::mlir::OpRewritePattern<
|
||||
::mlir::concretelang::FHELinalg::FromElementOp>(
|
||||
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(::mlir::concretelang::FHELinalg::FromElementOp op,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
auto in = op.getOperand();
|
||||
auto out = op.getResult();
|
||||
|
||||
mlir::Value result =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), out.getType(), in)
|
||||
.getResult();
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of operators
|
||||
/// `FHELinalg.concat` to instances of `tensor.insert_slice`
|
||||
///
|
||||
@@ -1647,6 +1682,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
|
||||
patterns.insert<ConcatRewritePattern>(&getContext());
|
||||
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
|
||||
patterns.insert<TransposeToLinalgGeneric>(&getContext());
|
||||
patterns.insert<FromElementToTensorFromElements>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
.failed())
|
||||
|
||||
@@ -1030,6 +1030,18 @@ static llvm::APInt getSqMANP(
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
FHELinalg::FromElementOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
auto manp = operandMANPs[0]->getValue().getMANP();
|
||||
if (manp.hasValue()) {
|
||||
return manp.getValue();
|
||||
}
|
||||
|
||||
return llvm::APInt{1, 1, false};
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::FromElementsOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
@@ -1367,6 +1379,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(conv2dOp, operands);
|
||||
} else if (auto fromElementOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::FromElementOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(fromElementOp, operands);
|
||||
} else if (auto transposeOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
|
||||
op)) {
|
||||
|
||||
@@ -999,6 +999,23 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult FromElementOp::verify() {
|
||||
mlir::Value in = this->getOperand();
|
||||
mlir::Value out = this->getResult();
|
||||
|
||||
auto inType = in.getType();
|
||||
auto outType = out.getType().dyn_cast<mlir::TensorType>();
|
||||
|
||||
auto expectedOutType = outType.cloneWith({1}, inType);
|
||||
if (outType != expectedOutType) {
|
||||
this->emitOpError() << "has invalid output type (expected "
|
||||
<< expectedOutType << ", got " << outType << ")";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implementation of FhelinalgConv2DNchwFchwOp
|
||||
// This is a generated functions from `make generate_conv_op`, and some helpers
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
// RUN: concretecompiler %s --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.from_elements %[[a0]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
|
||||
%1 = "FHELinalg.from_element"(%arg0) : (!FHE.eint<7>) -> tensor<1x!FHE.eint<7>>
|
||||
return %1 : tensor<1x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: i8) -> tensor<1xi8> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.from_elements %[[a0]] : tensor<1xi8>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<1xi8>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%arg0: i8) -> tensor<1xi8> {
|
||||
%1 = "FHELinalg.from_element"(%arg0) : (i8) -> tensor<1xi8>
|
||||
return %1 : tensor<1xi8>
|
||||
}
|
||||
Reference in New Issue
Block a user