feat: implement maxpool2d operation

This commit is contained in:
Umut
2023-02-17 16:46:46 +01:00
parent 56bdb05be3
commit bc69c87d62
24 changed files with 1873 additions and 18 deletions

View File

@@ -323,6 +323,41 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
let summary = "Get maximum of two encrypted integers.";
let description = [{
Get maximum of two encrypted integers using the formula, 'max(x, y) == max(x - y, 0) + y'.
Type of inputs and the output should be the same.
If `x - y`` inside the max overflows or underflows, the behavior is undefined.
So to support the full range, you should increase the bit-width by 1 manually.
Example:
```mlir
// ok
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
"FHE.max_eint"(%x, %y) : (!FHE.esint<3>, !FHE.esint<3>) -> !FHE.esint<3>
// error
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<3>) -> !FHE.eint<2>
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.esint<2>
"FHE.max_eint"(%x, %y) : (!FHE.esint<2>, !FHE.eint<2>) -> !FHE.eint<2>
```
}];
let arguments = (ins FHE_AnyEncryptedInteger:$x, FHE_AnyEncryptedInteger:$y);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$x, "Value":$y), [{
build($_builder, $_state, x.getType(), x, y);
}]>
];
let hasVerifier = 1;
}
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
let summary = "Cast an unsigned integer to a signed one";

View File

@@ -4,3 +4,4 @@ add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
add_subdirectory(BigInt)
add_subdirectory(Boolean)
add_subdirectory(Max)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Max.td)
mlir_tablegen(Max.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEMaxPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEMaxPassIncGen)

View File

@@ -0,0 +1,23 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_FHE_MAX_PASS_H
#define CONCRETELANG_FHE_MAX_PASS_H
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<>> createFHEMaxTransformPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_FHE_MAX_PASS
#define CONCRETELANG_FHE_MAX_PASS
include "mlir/Pass/PassBase.td"
def FHEMaxTransform : Pass<"fhe-max-transform"> {
let summary = "Transform max operation to basic operations";
let constructor = "mlir::concretelang::createFHEMaxTransformPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
}
#endif

View File

@@ -944,6 +944,18 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
let hasVerifier = 1;
}
def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> {
let summary = "Returns the 2D maxpool of a tensor in the form NCHW";
let arguments = (ins
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input,
I64ElementsAttr:$kernel_shape,
OptionalAttr<I64ElementsAttr>:$strides,
OptionalAttr<I64ElementsAttr>:$dilations
);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
let summary = "Returns a tensor that contains the transposition of the input tensor.";