mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat: implement maxpool2d operation
This commit is contained in:
@@ -323,6 +323,41 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
|
||||
let summary = "Get maximum of two encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
Get maximum of two encrypted integers using the formula, 'max(x, y) == max(x - y, 0) + y'.
|
||||
Type of inputs and the output should be the same.
|
||||
|
||||
If `x - y`` inside the max overflows or underflows, the behavior is undefined.
|
||||
So to support the full range, you should increase the bit-width by 1 manually.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.esint<3>, !FHE.esint<3>) -> !FHE.esint<3>
|
||||
|
||||
// error
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<3>) -> !FHE.eint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.esint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.esint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$x, FHE_AnyEncryptedInteger:$y);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$x, "Value":$y), [{
|
||||
build($_builder, $_state, x.getType(), x, y);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
|
||||
let summary = "Cast an unsigned integer to a signed one";
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@ add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_subdirectory(BigInt)
|
||||
add_subdirectory(Boolean)
|
||||
add_subdirectory(Max)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Max.td)
|
||||
mlir_tablegen(Max.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEMaxPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEMaxPassIncGen)
|
||||
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_MAX_PASS_H
|
||||
#define CONCRETELANG_FHE_MAX_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHEMaxTransformPass();
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_FHE_MAX_PASS
|
||||
#define CONCRETELANG_FHE_MAX_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHEMaxTransform : Pass<"fhe-max-transform"> {
|
||||
let summary = "Transform max operation to basic operations";
|
||||
let constructor = "mlir::concretelang::createFHEMaxTransformPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -944,6 +944,18 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> {
|
||||
let summary = "Returns the 2D maxpool of a tensor in the form NCHW";
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input,
|
||||
I64ElementsAttr:$kernel_shape,
|
||||
OptionalAttr<I64ElementsAttr>:$strides,
|
||||
OptionalAttr<I64ElementsAttr>:$dilations
|
||||
);
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
|
||||
let summary = "Returns a tensor that contains the transposition of the input tensor.";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user