feat: implement maxpool2d operation

This commit is contained in:
Umut
2023-02-17 16:46:46 +01:00
parent 56bdb05be3
commit bc69c87d62
24 changed files with 1873 additions and 18 deletions

View File

@@ -323,6 +323,41 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
let hasVerifier = 1;
}
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
  let summary = "Get maximum of two encrypted integers.";

  let description = [{
    Get maximum of two encrypted integers using the formula, 'max(x, y) == max(x - y, 0) + y'.
    Type of inputs and the output should be the same.

    If `x - y` inside the max overflows or underflows, the behavior is undefined.
    So to support the full range, you should increase the bit-width by 1 manually.

    Example:
    ```mlir
    // ok
    "FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
    "FHE.max_eint"(%x, %y) : (!FHE.esint<3>, !FHE.esint<3>) -> !FHE.esint<3>

    // error
    "FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<3>) -> !FHE.eint<2>
    "FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.esint<2>
    "FHE.max_eint"(%x, %y) : (!FHE.esint<2>, !FHE.eint<2>) -> !FHE.eint<2>
    ```
  }];

  let arguments = (ins FHE_AnyEncryptedInteger:$x, FHE_AnyEncryptedInteger:$y);
  let results = (outs FHE_AnyEncryptedInteger);

  // Convenience builder: the result type defaults to the type of `x`
  // (the verifier enforces that x, y and the result agree anyway).
  let builders = [
    OpBuilder<(ins "Value":$x, "Value":$y), [{
      build($_builder, $_state, x.getType(), x, y);
    }]>
  ];

  let hasVerifier = 1;
}
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
let summary = "Cast an unsigned integer to a signed one";

View File

@@ -4,3 +4,4 @@ add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
add_subdirectory(BigInt)
add_subdirectory(Boolean)
add_subdirectory(Max)

View File

@@ -0,0 +1,4 @@
# Generate the pass declarations (Max.h.inc) for the FHE Max transform
# pass from Max.td, and make mlir-headers depend on them so any target
# including the headers builds the generated file first.
set(LLVM_TARGET_DEFINITIONS Max.td)
mlir_tablegen(Max.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEMaxPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEMaxPassIncGen)

View File

@@ -0,0 +1,23 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_FHE_MAX_PASS_H
#define CONCRETELANG_FHE_MAX_PASS_H

// IWYU: this header returns a std::unique_ptr, so include <memory>
// directly instead of relying on a transitive include.
#include <memory>

#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <mlir/Pass/Pass.h>

#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h.inc>

namespace mlir {
namespace concretelang {
/// Creates the pass that rewrites `FHE.max_eint` operations into more
/// basic FHE operations (see the `fhe-max-transform` pass definition).
std::unique_ptr<mlir::OperationPass<>> createFHEMaxTransformPass();
} // namespace concretelang
} // namespace mlir

#endif // CONCRETELANG_FHE_MAX_PASS_H

View File

@@ -0,0 +1,13 @@
// TableGen definition of the FHE Max transform pass.
#ifndef CONCRETELANG_FHE_MAX_PASS
#define CONCRETELANG_FHE_MAX_PASS
include "mlir/Pass/PassBase.td"
// Pass `fhe-max-transform`: rewrites `FHE.max_eint` into basic FHE
// operations. It may create new FHE ops, hence the dependent dialect.
def FHEMaxTransform : Pass<"fhe-max-transform"> {
let summary = "Transform max operation to basic operations";
let constructor = "mlir::concretelang::createFHEMaxTransformPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
}
#endif

View File

@@ -944,6 +944,18 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
let hasVerifier = 1;
}
// 2D max pooling over an encrypted NCHW tensor.
// - `input`: statically-shaped tensor of encrypted integers.
// - `kernel_shape`: required i64 elements attribute (the pooling window).
// - `strides` / `dilations`: optional i64 elements attributes; the
//   verifier (hasVerifier = 1) checks shapes and value ranges.
def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> {
let summary = "Returns the 2D maxpool of a tensor in the form NCHW";
let arguments = (ins
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input,
I64ElementsAttr:$kernel_shape,
OptionalAttr<I64ElementsAttr>:$strides,
OptionalAttr<I64ElementsAttr>:$dilations
);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
let summary = "Returns a tensor that contains the transposition of the input tensor.";

View File

@@ -1771,6 +1771,81 @@ struct FHELinalgConv2dToLinalgConv2d
};
};
/// This rewrite pattern transforms all instances
/// of `FHELinalg.maxpool2d` to `linalg.pooling_nchw_max`.
struct FHELinalgMaxpool2dToLinalgMaxpool2d
    : public mlir::OpRewritePattern<FHELinalg::Maxpool2dOp> {
  FHELinalgMaxpool2dToLinalgMaxpool2d(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<FHELinalg::Maxpool2dOp>(context) {}

  mlir::LogicalResult
  matchAndRewrite(FHELinalg::Maxpool2dOp maxpool2dOp,
                  mlir::PatternRewriter &rewriter) const override {
    const mlir::Location loc = maxpool2dOp->getLoc();

    // Tag the pooling op with the name of the FHE op that must be used
    // to compute the maximum during later lowerings.
    const mlir::NamedAttribute maxOpAttr = rewriter.getNamedAttr(
        "max_signed",
        rewriter.getStringAttr(FHE::MaxEintOp::getOperationName()));

    const auto outputTy =
        maxpool2dOp->getResult(0).getType().cast<mlir::RankedTensorType>();
    const auto outputElementTy =
        outputTy.getElementType().cast<FHE::FheIntegerInterface>();

    // The pooling accumulator is initialized to an encrypted zero tensor.
    mlir::Value output =
        rewriter.create<FHE::ZeroTensorOp>(loc, outputTy).getResult();
    if (outputElementTy.isSigned()) {
      // For signed element types the accumulator is biased downwards by
      // 2^(w-2), encoded as a (w+1)-bit constant, before pooling.
      // NOTE(review): presumably this makes zero-init behave as a small
      // enough starting point for signed max — confirm against the
      // semantics of FHE.zero_tensor for signed types.
      const int64_t outputBitWidth = outputElementTy.getWidth();
      const int64_t offsetValue = 1 << (outputBitWidth - 2);
      const mlir::Type offsetType =
          mlir::IntegerType::get(this->getContext(), outputBitWidth + 1);
      const mlir::Type offsetTensorType =
          mlir::RankedTensorType::get({1}, offsetType);
      const llvm::SmallVector<mlir::Attribute> offsetTensorAttr = {
          mlir::IntegerAttr::get(offsetType, offsetValue)};
      const mlir::Attribute offsetAttr =
          mlir::DenseElementsAttr::get(offsetTensorType, offsetTensorAttr);
      const mlir::Value offset =
          rewriter.create<mlir::arith::ConstantOp>(loc, offsetAttr);
      output = rewriter.create<FHELinalg::SubEintIntOp>(loc, output, offset);
    }

    // The kernel operand of linalg pooling ops only conveys the window
    // shape; an uninitialized tensor of the right shape is sufficient.
    const mlir::DenseElementsAttr kernelShapeAttr = maxpool2dOp.kernel_shape();
    const auto kernelShape =
        llvm::SmallVector<int64_t, 2>(kernelShapeAttr.value_begin<int64_t>(),
                                      kernelShapeAttr.value_end<int64_t>());
    const mlir::Value kernel =
        rewriter
            .create<mlir::linalg::InitTensorOp>(
                loc, kernelShape,
                mlir::IntegerType::get(this->getContext(), 64))
            .getResult();

    const mlir::DenseIntElementsAttr defaultAttr =
        rewriter.getI64VectorAttr({1, 1});
    // BUGFIX: strides were previously read from `dilations()` (copy/paste
    // error), which silently ignored any user-provided strides.
    const mlir::DenseIntElementsAttr stridesAttr =
        maxpool2dOp.strides().getValueOr(defaultAttr);
    const mlir::DenseIntElementsAttr dilationsAttr =
        maxpool2dOp.dilations().getValueOr(defaultAttr);

    rewriter.replaceOpWithNewOp<mlir::linalg::PoolingNchwMaxOp>(
        maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.input(), kernel},
        output, stridesAttr, dilationsAttr,
        llvm::ArrayRef<mlir::NamedAttribute>({maxOpAttr}));
    return mlir::success();
  };
};
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalg.to_signed` to an instance of `linalg.generic` with an
/// appropriate region using `FHE.to_signed` operation, an appropriate
@@ -2019,6 +2094,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
patterns.insert<SumToLinalgGeneric>(&getContext());
patterns.insert<ConcatRewritePattern>(&getContext());
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
patterns.insert<FHELinalgMaxpool2dToLinalgMaxpool2d>(&getContext());
patterns.insert<TransposeToLinalgGeneric>(&getContext());
patterns.insert<FromElementToTensorFromElements>(&getContext());
patterns.insert<FHELinalgToSignedToLinalgGeneric>(&getContext());

View File

@@ -77,6 +77,10 @@ struct FunctionToDag {
for (auto &bb : func.getBody().getBlocks()) {
for (auto &op : bb.getOperations()) {
addOperation(dag, op);
}
}
for (auto &bb : func.getBody().getBlocks()) {
for (auto &op : bb.getOperations()) {
op.removeAttr("SMANP");
}
}
@@ -145,6 +149,14 @@ struct FunctionToDag {
// If can't find weights return default leveled op
DEBUG("Replace Dot by LevelledOp on " << op);
}
if (auto max = asMax(op)) {
addMax(dag, max, encrypted_inputs, precision);
return;
}
if (auto maxpool2d = asMaxpool2d(op)) {
addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision);
return;
}
// default
addLevelledOp(dag, op, encrypted_inputs);
}
@@ -207,6 +219,103 @@ struct FunctionToDag {
manp, slice(out_shape), comment);
}
// Adds an `FHE.max_eint` op to the optimizer dag as three nodes that
// mirror its lowering `max(x, y) == max(x - y, 0) + y`:
//   1. a levelled op for the subtraction `x - y`,
//   2. a TLU node for `max(., 0)` (table contents unknown here),
//   3. a levelled op for the final addition with `y`.
void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
int precision) {
mlir::Value result = maxOp.getResult();
const std::vector<uint64_t> resultShape = getShape(result);
Operation *xOp = maxOp.x().getDefiningOp();
Operation *yOp = maxOp.y().getDefiningOp();
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
// Default squared MANP of 1 when the operand has no defining op
// (i.e. it is a block argument such as a function input).
llvm::APInt xSmanp = llvm::APInt{1, 1, false};
if (xOp != nullptr) {
const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
xSmanp = xSmanpAttr.getValue();
}
llvm::APInt ySmanp = llvm::APInt{1, 1, false};
if (yOp != nullptr) {
const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
ySmanp = ySmanpAttr.getValue();
}
// MANP of the subtraction: sqrt of the sum of the squared MANPs.
const double subManp =
sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());
auto loc = loc_to_string(maxOp.getLoc());
auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc;
auto subNode =
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
subManp, slice(resultShape), comment);
// The TLU resets the noise; its output MANP is 1. The exact table is
// not known at this point, hence the empty "unknown function".
const double tluNodeManp = 1;
const std::vector<std::uint64_t> unknownFunction;
auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision);
// MANP of the addition of the TLU output with `y` (inputs[1]).
const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble());
const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
tluNode, inputs[1]};
index[result] =
dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost,
addManp, slice(resultShape), comment);
}
/// Adds an `FHELinalg.maxpool2d` op to the optimizer dag.
///
/// Each output cell results from a chain of `max(x - y, 0) + y` steps
/// (one per kernel cell), each evaluated with a TLU. All those TLUs are
/// modeled as a single TLU node over a "fake" shape carrying one extra
/// dimension that holds the number of comparisons per output cell.
void addMaxpool2d(optimizer::Dag &dag, FHELinalg::Maxpool2dOp &maxpool2dOp,
                  Inputs &inputs, int precision) {
  mlir::Value result = maxpool2dOp.getResult();
  const std::vector<uint64_t> resultShape = getShape(result);

  // all TLUs are flattened into a dimension
  // to create a single TLU node in optimizer dag
  std::vector<uint64_t> fakeShape = resultShape;
  uint64_t numberOfComparisons = 1;
  for (auto dimensionSize : maxpool2dOp.kernel_shape().getValues<int64_t>()) {
    numberOfComparisons *= dimensionSize;
  }
  fakeShape.push_back(numberOfComparisons);

  Operation *inputOp = maxpool2dOp.input().getDefiningOp();
  const double fixedCost = NEGLIGIBLE_COMPLEXITY;
  const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;

  // Default squared MANP of 1 when the input is a block argument
  // (no defining op to carry the SMANP attribute).
  llvm::APInt inputSmanp = llvm::APInt{1, 1, false};
  if (inputOp != nullptr) {
    const auto inputSmanpAttr =
        inputOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
    assert(inputSmanpAttr && "Missing SMANP value on a crypto operation");
    inputSmanp = inputSmanpAttr.getValue();
  }

  // MANP of the `x - y` step: both operands derive from the input.
  const double subManp = sqrt(2 * inputSmanp.roundToDouble() + 1);

  auto loc = loc_to_string(maxpool2dOp.getLoc());
  auto comment =
      std::string(maxpool2dOp->getName().getStringRef()) + " " + loc;

  auto subNode =
      dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
                           subManp, slice(fakeShape), comment);

  // The exact lookup table is unknown at this point.
  const std::vector<std::uint64_t> unknownFunction;
  auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision);

  // MANP of the final `tlu + y` step.
  const double addManp = sqrt(inputSmanp.roundToDouble() + 1);

  // BUGFIX: `maxpool2d` has a single encrypted operand, so the add takes
  // `inputs[0]`. The previous `inputs[1]` (copied from the binary
  // `FHE.max_eint` case in addMax) indexed past the end of `inputs`.
  const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
      tluNode, inputs[0]};

  index[result] =
      dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost,
                           addManp, slice(resultShape), comment);
}
Inputs encryptedInputs(mlir::Operation &op) {
Inputs inputs;
for (auto operand : op.getOperands()) {
@@ -237,6 +346,14 @@ struct FunctionToDag {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
}
// Returns the op as an `FHE.max_eint` if it is one, null otherwise.
mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op);
}
// Returns the op as an `FHELinalg.maxpool2d` if it is one, null otherwise.
mlir::concretelang::FHELinalg::Maxpool2dOp asMaxpool2d(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(op);
}
bool isReturn(mlir::Operation &op) {
return llvm::isa<mlir::func::ReturnOp>(op);
}

View File

@@ -510,6 +510,30 @@ static llvm::APInt getSqMANP(
return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE.max_eint` operation, based on its lowering
/// `max(x, y) == max(x - y, 0) + y`.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MaxEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
// max(x, y) = max(x - y, 0) + y
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
// MANP of the subtraction `x - y`.
const llvm::APInt sub = APIntWidthExtendUAdd(x, y);
// The TLU computing `max(., 0)` resets the noise to MANP 1; the final
// addition with `y` then has MANP `1 + y`.
const llvm::APInt tlu = {1, 1, false};
const llvm::APInt add = APIntWidthExtendUAdd(tlu, y);
// this is not optimal as it can increase the resulting noise unnecessarily
return APIntUMax(add, sub);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.add_eint_int` operation.
static llvm::APInt getSqMANP(
@@ -1153,6 +1177,28 @@ static llvm::APInt getSqMANP(
return accNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.maxpool2d` operation.
static llvm::APInt getSqMANP(
    mlir::concretelang::FHELinalg::Maxpool2dOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  // CONSISTENCY FIX: assert the operand MANP is available before calling
  // getValue() on the Optional, as the `FHE.max_eint` overload does.
  assert(operandMANPs.size() >= 1 &&
         operandMANPs[0]->getValue().getMANP().hasValue() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  // maximum between two value is calculated using
  // - max(x - y, 0) + y
  //
  // max is calculated with a TLU so MANP is {1, 1, false}
  // y on the other hand comes from the input or from the previous result
  // in the current implementation, it's the input
  //
  // so the resulting MANP is `{1, 1, false} + MANP input`
  const llvm::APInt tlu = {1, 1, false};
  const llvm::APInt input = operandMANPs[0]->getValue().getMANP().getValue();

  const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, input);
  const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, input);

  return APIntUMax(forIntermediate, forResult);
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -1202,6 +1248,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto roundOp =
llvm::dyn_cast<mlir::concretelang::FHE::RoundEintOp>(op)) {
norm2SqEquiv = getSqMANP(roundOp, operands);
} else if (auto maxEintOp =
llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op)) {
norm2SqEquiv = getSqMANP(maxEintOp, operands);
} else if (llvm::isa<mlir::concretelang::FHE::ZeroEintOp>(op) ||
llvm::isa<mlir::concretelang::FHE::ToBoolOp>(op) ||
llvm::isa<mlir::concretelang::FHE::FromBoolOp>(op) ||
@@ -1264,6 +1313,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
op)) {
norm2SqEquiv = getSqMANP(conv2dOp, operands);
} else if (auto maxpool2dOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(
op)) {
norm2SqEquiv = getSqMANP(maxpool2dOp, operands);
} else if (auto fromElementOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::FromElementOp>(
op)) {

View File

@@ -192,6 +192,24 @@ mlir::LogicalResult MulEintOp::verify() {
return ::mlir::success();
}
/// Verifies an `FHE.max_eint`: both inputs and the result must share the
/// same encrypted-integer width and signedness.
mlir::LogicalResult MaxEintOp::verify() {
  mlir::Operation &op = *this->getOperation();
  const auto xTy = this->x().getType().dyn_cast<FheIntegerInterface>();
  const auto yTy = this->y().getType().dyn_cast<FheIntegerInterface>();
  const auto outTy = this->getResult().getType().dyn_cast<FheIntegerInterface>();
  // Short-circuit: the inputs/result check runs first, then the x/y
  // consistency check, mirroring the order diagnostics are emitted in.
  const bool consistent =
      verifyEncryptedIntegerInputAndResultConsistency(op, xTy, outTy) &&
      verifyEncryptedIntegerInputsConsistency(op, xTy, yTy);
  return consistent ? mlir::success() : mlir::failure();
}
mlir::LogicalResult ToSignedOp::verify() {
auto input = this->input().getType().cast<EncryptedIntegerType>();
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();

View File

@@ -2,6 +2,7 @@ add_mlir_library(
FHEDialectTransforms
BigInt.cpp
Boolean.cpp
Max.cpp
EncryptedMulToDoubleTLU.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE

View File

@@ -0,0 +1,111 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/Transforms/Max/Max.h"
namespace arith = mlir::arith;
namespace func = mlir::func;
namespace FHE = mlir::concretelang::FHE;
/// This rewrite pattern transforms all instances
/// of `FHE.max_eint` to `max(x - y, 0) + y`.
/// Rewrites `FHE.max_eint` into the equivalent formula
/// `max(x - y, 0) + y`, where `max(v, 0)` is evaluated with a lookup
/// table applied to the (signed) difference.
struct MaxEintPattern : public mlir::OpRewritePattern<FHE::MaxEintOp> {
  MaxEintPattern(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<FHE::MaxEintOp>(context) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::MaxEintOp maxEintOp,
                  mlir::PatternRewriter &rewriter) const override {
    const mlir::Location loc = maxEintOp->getLoc();

    const FHE::FheIntegerInterface resultTy =
        maxEintOp->getResult(0).getType().cast<FHE::FheIntegerInterface>();
    const int64_t width = resultTy.getWidth();

    mlir::Value lhs = maxEintOp.x();
    mlir::Value rhs = maxEintOp.y();
    const auto lhsTy = lhs.getType().cast<FHE::FheIntegerInterface>();
    const auto rhsTy = rhs.getType().cast<FHE::FheIntegerInterface>();

    // The difference can be negative, so unsigned operands are first
    // reinterpreted as signed values of the same width.
    const auto signedTy =
        FHE::EncryptedSignedIntegerType::get(this->getContext(), width);
    if (lhsTy.isUnsigned()) {
      lhs = rewriter.create<FHE::ToSignedOp>(loc, signedTy, lhs).getResult();
    }
    if (rhsTy.isUnsigned()) {
      rhs = rewriter.create<FHE::ToSignedOp>(loc, signedTy, rhs).getResult();
    }

    const mlir::Value diff =
        rewriter.create<FHE::SubEintOp>(loc, lhs, rhs).getResult();

    // Table for `max(v, 0)` on a signed w-bit input: identity on the
    // non-negative half [0, 2^(w-1)), zero on the negative half.
    const int64_t tableSize = 1 << width;
    std::vector<int64_t> tableValues(tableSize, 0);
    for (int64_t v = 0; v < tableSize / 2; v++) {
      tableValues[v] = v;
    }

    const mlir::Attribute tableAttr = rewriter.getI64TensorAttr(tableValues);
    const mlir::Value table =
        rewriter.create<arith::ConstantOp>(loc, tableAttr).getResult();
    const mlir::Value positivePart =
        rewriter.create<FHE::ApplyLookupTableEintOp>(loc, resultTy, diff, table)
            .getResult();

    // Final `max(x - y, 0) + y`, using the original (uncast) `y`.
    const mlir::Value sum =
        rewriter.create<FHE::AddEintOp>(loc, positivePart, maxEintOp.y())
            .getResult();

    rewriter.replaceOp(maxEintOp, {sum});
    return mlir::success();
  };
};
namespace {
// Pass wrapper generated from Max.td (FHEMaxTransformBase comes from
// Max.h.inc); the actual rewriting is done by MaxEintPattern.
struct FHEMaxTransform : public FHEMaxTransformBase<FHEMaxTransform> {
void runOnOperation() final;
};
void FHEMaxTransform::runOnOperation() {
// `FHE.max_eint` is the only illegal op; everything else in the arith
// and FHE dialects stays legal. NOTE(review): `target` is built but the
// pass uses the greedy driver below, which does not consume a
// ConversionTarget — presumably kept for documentation/consistency.
auto target = mlir::ConversionTarget(this->getContext());
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<FHE::FHEDialect>();
target.addIllegalOp<FHE::MaxEintOp>();
auto patterns = mlir::RewritePatternSet(&this->getContext());
patterns.insert<MaxEintPattern>(&this->getContext());
mlir::Operation *op = this->getOperation();
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
} // namespace
namespace mlir {
namespace concretelang {
/// Factory for the `fhe-max-transform` pass (declared in Max.h and
/// referenced from the TableGen pass definition in Max.td).
std::unique_ptr<mlir::OperationPass<>> createFHEMaxTransformPass() {
return std::make_unique<FHEMaxTransform>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1048,6 +1048,154 @@ mlir::LogicalResult Conv2dOp::verify() {
return mlir::success();
}
// Verifies an `FHELinalg.maxpool2d`:
// - input/output element types match,
// - input and output are 4D (NCHW),
// - kernel_shape / strides / dilations attributes have shape (2),
// - strides and dilations are strictly positive,
// - the output shape matches the computed pooling output shape.
mlir::LogicalResult Maxpool2dOp::verify() {
const mlir::RankedTensorType inputTy =
this->input().getType().cast<mlir::RankedTensorType>();
const mlir::RankedTensorType outputTy =
this->getResult().getType().cast<mlir::RankedTensorType>();
// NOTE(review): these rely on the TypeInterface conversion from
// mlir::Type, which asserts at runtime if the element type does not
// implement FheIntegerInterface — the ODS type predicate should have
// guaranteed that already; confirm.
const FHE::FheIntegerInterface inputElementTy = inputTy.getElementType();
const FHE::FheIntegerInterface outputElementTy = outputTy.getElementType();
if (inputElementTy != outputElementTy) {
this->emitOpError() << "expected output element type "
<< "(" << outputElementTy << ") "
<< "to be the same with input element type "
<< "(" << inputElementTy << ") "
<< "but it is not";
return mlir::failure();
}
// Rank checks: both sides must be NCHW.
const llvm::ArrayRef<int64_t> inputShape = inputTy.getShape();
const llvm::ArrayRef<int64_t> outputShape = outputTy.getShape();
if (inputShape.size() != 4) {
this->emitOpError() << "expected input to have 4 dimensions (N*C*H*W) "
<< "but it has " << inputShape.size();
return mlir::failure();
}
if (outputShape.size() != 4) {
this->emitOpError() << "expected output to have 4 dimensions (N*C*H*W) "
<< "but it has " << outputShape.size();
return mlir::failure();
}
const int64_t inputN = inputShape[0];
const int64_t inputC = inputShape[1];
const int64_t inputH = inputShape[2];
const int64_t inputW = inputShape[3];
// kernel_shape must be a 1D elements attribute of exactly 2 values
// (kernel height, kernel width).
// NOTE(review): kernel dimensions are not checked for positivity, unlike
// strides/dilations below — consider adding the same check.
const mlir::DenseIntElementsAttr kernelShapeAttr = this->kernel_shape();
const mlir::RankedTensorType kernelShapeAttrTy =
kernelShapeAttr.getType().cast<mlir::RankedTensorType>();
const llvm::ArrayRef<int64_t> kernelShapeAttrShape =
kernelShapeAttrTy.getShape();
if (kernelShapeAttrShape.size() != 1 || kernelShapeAttrShape[0] != 2) {
this->emitOpError() << "expected kernel shape to be of shape "
<< "(2) "
<< "but it is of shape "
<< "(" << kernelShapeAttrShape << ")";
return mlir::failure();
}
mlir::SmallVector<int64_t, 2> kernelShape;
kernelShape.append(kernelShapeAttr.value_begin<int64_t>(),
kernelShapeAttr.value_end<int64_t>());
const int64_t kernelShapeH = kernelShape[0];
const int64_t kernelShapeW = kernelShape[1];
// strides is optional; defaults to (1, 1) when absent.
mlir::SmallVector<int64_t, 2> strides;
const llvm::Optional<mlir::DenseIntElementsAttr> maybeStridesAttr =
this->strides();
if (maybeStridesAttr.hasValue()) {
const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.getValue();
const mlir::RankedTensorType stridesAttrTy =
stridesAttr.getType().cast<mlir::RankedTensorType>();
const llvm::ArrayRef<int64_t> stridesAttrShape = stridesAttrTy.getShape();
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
this->emitOpError() << "expected strides to be of shape "
<< "(2) "
<< "but it is of shape "
<< "(" << stridesAttrShape << ")";
return mlir::failure();
}
strides.append(stridesAttr.value_begin<int64_t>(),
stridesAttr.value_end<int64_t>());
} else {
strides.append({1, 1});
}
for (size_t i = 0; i < 2; i++) {
if (strides[i] < 1) {
this->emitOpError() << "expected elements of strides to be positive "
<< "but strides[" << i << "] is " << strides[i];
return mlir::failure();
}
}
const int64_t stridesH = strides[0];
const int64_t stridesW = strides[1];
// dilations is optional; defaults to (1, 1) when absent.
mlir::SmallVector<int64_t, 2> dilations;
const llvm::Optional<mlir::DenseIntElementsAttr> maybeDilationsAttr =
this->dilations();
if (maybeDilationsAttr.hasValue()) {
const mlir::DenseIntElementsAttr dilationsAttr =
maybeDilationsAttr.getValue();
const mlir::RankedTensorType dilationsAttrTy =
dilationsAttr.getType().cast<mlir::RankedTensorType>();
const llvm::ArrayRef<int64_t> dilationsAttrShape =
dilationsAttrTy.getShape();
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
this->emitOpError() << "expected dilations to be of shape "
<< "(2) "
<< "but it is of shape "
<< "(" << dilationsAttrShape << ")";
return mlir::failure();
}
dilations.append(dilationsAttr.value_begin<int64_t>(),
dilationsAttr.value_end<int64_t>());
} else {
dilations.append({1, 1});
}
for (size_t i = 0; i < 2; i++) {
if (dilations[i] < 1) {
this->emitOpError() << "expected elements of dilations to be positive "
<< "but dilations[" << i << "] is " << dilations[i];
return mlir::failure();
}
}
const int64_t dilationsH = dilations[0];
const int64_t dilationsW = dilations[1];
// Expected output spatial dims, standard pooling formula:
// floor((in - dilation * (kernel - 1) - 1) / stride) + 1.
// NOTE(review): the operands are integers, so `/` already truncates and
// `floor` is a no-op here; for negative intermediates (kernel larger
// than input) truncation differs from floor — confirm intended behavior.
const int64_t expectedOutputH =
floor((inputH - dilationsH * (kernelShapeH - 1) - 1) / stridesH) + 1;
const int64_t expectedOutputW =
floor((inputW - dilationsW * (kernelShapeW - 1) - 1) / stridesW) + 1;
const mlir::SmallVector<int64_t, 4> expectedOutputShape = {
inputN,
inputC,
expectedOutputH,
expectedOutputW,
};
if (outputShape != llvm::makeArrayRef(expectedOutputShape)) {
this->emitOpError() << "expected output to be of shape "
<< "(" << expectedOutputShape << ") "
<< "but it is of shape "
<< "(" << outputShape << ")";
return mlir::failure();
}
return mlir::success();
}
mlir::LogicalResult FromElementOp::verify() {
mlir::Value in = this->getOperand();
mlir::Value out = this->getResult();

View File

@@ -293,13 +293,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return errorDiag("Transforming FHE boolean ops failed");
}
// Encrypted mul rewriting
if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext,
module, enablePass)
.failed()) {
return StreamStringError("Rewriting of encrypted mul failed");
}
if (options.chunkIntegers) {
if (mlir::concretelang::pipeline::transformFHEBigInt(
mlirContext, module, enablePass, options.chunkSize,
@@ -344,6 +337,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
.failed()) {
return errorDiag("Lowering from FHELinalg to FHE failed");
}
if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext,
module, enablePass)
.failed()) {
return StreamStringError("Rewriting of high level fhe ops failed");
}
if (target == Target::FHE_NO_LINALG)
return std::move(res);

View File

@@ -38,6 +38,7 @@
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h>
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
#include <concretelang/Support/Pipeline.h>
@@ -181,7 +182,9 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("transformHighLevelFHEOps", pm, context);
addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass);
addPotentiallyNestedPass(pm, createFHEMaxTransformPass(), enablePass);
return pm.run(module.getOperation());
}

View File

@@ -0,0 +1,38 @@
// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg %s 2>&1 | FileCheck %s
// -----
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x6x9x!FHE.eint<5>>
// CHECK-NEXT: %[[v1:.*]] = linalg.init_tensor [3, 2] : tensor<3x2xi64>
// CHECK-NEXT: %[[v2:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %1 : tensor<1x1x8x10x!FHE.eint<5>>, tensor<3x2xi64>) outs(%0 : tensor<1x1x6x9x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
// CHECK-NEXT: return %[[v2]] : tensor<1x1x6x9x!FHE.eint<5>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[3, 2]> : tensor<2xi64> } : (tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
return %0 : tensor<1x1x6x9x!FHE.eint<5>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (0)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7>
// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6>
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.esint<6>
// CHECK-NEXT: } -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v4:.*]] = linalg.init_tensor [2, 3] : tensor<2x3xi64>
// CHECK-NEXT: %[[v5:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %[[v4]] : tensor<1x1x6x5x!FHE.esint<6>>, tensor<2x3xi64>) outs(%[[v3]] : tensor<1x1x5x3x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: return %[[v5]] : tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[2, 3]> : tensor<2xi64> } : (tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
return %0 : tensor<1x1x5x3x!FHE.esint<6>>
}

View File

@@ -0,0 +1,31 @@
// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-max-transform %s 2>&1 | FileCheck %s
// -----
// CHECK: func.func @main(%[[a0:.*]]: !FHE.eint<5>, %[[a1:.*]]: !FHE.eint<5>) -> !FHE.eint<5> {
// CHECK-NEXT: %[[v0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64>
// CHECK-NEXT: %[[v1:.*]] = "FHE.to_signed"(%[[a0]]) : (!FHE.eint<5>) -> !FHE.esint<5>
// CHECK-NEXT: %[[v2:.*]] = "FHE.to_signed"(%[[a1]]) : (!FHE.eint<5>) -> !FHE.esint<5>
// CHECK-NEXT: %[[v3:.*]] = "FHE.sub_eint"(%[[v1]], %[[v2]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
// CHECK-NEXT: %[[v4:.*]] = "FHE.apply_lookup_table"(%[[v3]], %[[v0]]) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<5>
// CHECK-NEXT: %[[v5:.*]] = "FHE.add_eint"(%[[v4]], %[[a1]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
// CHECK-NEXT: return %[[v5]] : !FHE.eint<5>
// CHECK-NEXT: }
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}
// -----
// CHECK: func.func @main(%[[a0:.*]]: !FHE.esint<5>, %[[a1:.*]]: !FHE.esint<5>) -> !FHE.esint<5> {
// CHECK-NEXT: %[[v0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64>
// CHECK-NEXT: %[[v1:.*]] = "FHE.sub_eint"(%[[a0]], %[[a1]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
// CHECK-NEXT: %[[v2:.*]] = "FHE.apply_lookup_table"(%[[v1]], %[[v0]]) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<5>
// CHECK-NEXT: %[[v3:.*]] = "FHE.add_eint"(%[[v2]], %[[a1]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
// CHECK-NEXT: return %[[v3:.*]] : !FHE.esint<5>
// CHECK-NEXT: }
func.func @main(%arg0: !FHE.esint<5>, %arg1: !FHE.esint<5>) -> !FHE.esint<5> {
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
return %0 : !FHE.esint<5>
}

View File

@@ -1,15 +1,15 @@
// RUN: concretecompiler --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
// RUN: concretecompiler --action=dump-tfhe --passes EncryptedMulToDoubleTLU --split-input-file %s 2>&1 | FileCheck %s
// CHECK: func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64>
// CHECK-NEXT: %0 = "FHE.add_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst_0) {MANP = 1 : ui1} : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
// CHECK-NEXT: %2 = "FHE.sub_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: %3 = "FHE.to_signed"(%2) {MANP = 2 : ui3} : (!FHE.eint<3>) -> !FHE.esint<3>
// CHECK-NEXT: %4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3>
// CHECK-NEXT: %5 = "FHE.sub_eint"(%1, %4) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: return %5 : !FHE.eint<3>
// CHECK: func.func @simple_eint(%[[a0:.*]]: !FHE.eint<3>, %[[a1:.*]]: !FHE.eint<3>) -> !FHE.eint<3> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.add_eint"(%[[a0]], %[[a1]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64>
// CHECK-NEXT: %[[v2:.*]] = "FHE.apply_lookup_table"(%[[v0]], %[[v1]]) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
// CHECK-NEXT: %[[v3:.*]] = "FHE.sub_eint"(%[[a0]], %[[a1]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: %[[v4:.*]] = "FHE.to_signed"(%[[v3]]) : (!FHE.eint<3>) -> !FHE.esint<3>
// CHECK-NEXT: %[[v5:.*]] = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64>
// CHECK-NEXT: %[[v6:.*]] = "FHE.apply_lookup_table"(%[[v4]], %[[v5]]) : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3>
// CHECK-NEXT: %[[v7:.*]] = "FHE.sub_eint"(%[[v2]], %[[v6]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: return %[[v7]] : !FHE.eint<3>
// CHECK-NEXT: }
func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
%0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)

View File

@@ -0,0 +1,33 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// Invalid uses of FHE.max_eint. The op verifier requires both operands and the
// result to agree in bit-width and in signedness; each split-input case below
// violates exactly one of those constraints and checks the emitted diagnostic.
// -----
// CHECK-LABEL: error: 'FHE.max_eint' op should have the width of encrypted inputs equal
// Operands of different widths (eint<5> vs eint<3>) are rejected.
func.func @bad_inputs_width(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<3>) -> !FHE.eint<5> {
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<3>) -> (!FHE.eint<5>)
return %1: !FHE.eint<5>
}
// -----
// CHECK-LABEL: error: 'FHE.max_eint' op should have the signedness of encrypted inputs equal
// Mixing an unsigned and a signed operand is rejected.
func.func @bad_inputs_signedness(%arg0: !FHE.eint<5>, %arg1: !FHE.esint<5>) -> !FHE.eint<5> {
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.esint<5>) -> (!FHE.eint<5>)
return %1: !FHE.eint<5>
}
// -----
// CHECK-LABEL: error: 'FHE.max_eint' op should have the width of encrypted inputs and result equal
// Result width must match the operand width (5 vs 3 here).
func.func @bad_result_width(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<3> {
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.max_eint' op should have the signedness of encrypted inputs and result equal
// Result signedness must match the operand signedness.
func.func @bad_result_signedness(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.esint<5> {
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.esint<5>)
return %1: !FHE.esint<5>
}

View File

@@ -223,6 +223,24 @@ func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE
return %1: !FHE.eint<2>
}
// Roundtrip of FHE.max_eint on unsigned operands: parse + reprint must be
// the identity.
// CHECK-LABEL: func.func @max_eint(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4>
func.func @max_eint(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%0 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<4>, !FHE.eint<4>) -> (!FHE.eint<4>)
// CHECK-NEXT: return %[[v0]] : !FHE.eint<4>
return %0: !FHE.eint<4>
}
// Roundtrip of FHE.max_eint on signed operands — the same op accepts both
// eint and esint, as long as operand and result types all agree.
// CHECK-LABEL: func.func @max_esint(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4>
func.func @max_esint(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<4>, !FHE.esint<4>) -> !FHE.esint<4>
%0 = "FHE.max_eint"(%arg0, %arg1): (!FHE.esint<4>, !FHE.esint<4>) -> (!FHE.esint<4>)
// CHECK-NEXT: return %[[v0]] : !FHE.esint<4>
return %0: !FHE.esint<4>
}
// CHECK-LABEL: func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool
func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool {
// CHECK-NEXT: %[[V1:.*]] = "FHE.to_bool"(%arg0) : (!FHE.eint<1>) -> !FHE.ebool

View File

@@ -0,0 +1,73 @@
// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
// Invalid uses of FHELinalg.maxpool2d: each split-input case triggers one
// verifier error (element type mismatch, wrong tensor rank, malformed
// kernel_shape/strides/dilations attributes, or a wrong inferred output
// shape). The `expected-error @+1` annotations are positional and must stay
// immediately above the offending op.
// -----
// Output element type must be identical to the input element type.
func.func @different_input_and_output_bit_widths(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<5>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected output element type ('!FHE.eint<5>') to be the same with input element type ('!FHE.eint<7>') but it is not}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<5>>
return %0 : tensor<1x1x13x9x!FHE.eint<5>>
}
// -----
// Input must be a 4D N*C*H*W tensor.
func.func @bad_input_dimensions(%arg0: tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected input to have 4 dimensions (N*C*H*W) but it has 2}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}
// -----
// Output must be a 4D N*C*H*W tensor.
func.func @bad_output_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected output to have 4 dimensions (N*C*H*W) but it has 3}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x13x9x!FHE.eint<7>>
return %0 : tensor<1x13x9x!FHE.eint<7>>
}
// -----
// kernel_shape must hold exactly 2 elements (kernel height, kernel width).
func.func @bad_kernel_shape_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected kernel shape to be of shape (2) but it is of shape (3)}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2, 3]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}
// -----
// strides, when given, must hold exactly 2 elements.
func.func @bad_strides_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected strides to be of shape (2) but it is of shape (3)}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}
// -----
// Every stride must be strictly positive.
func.func @bad_strides_values(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected elements of strides to be positive but strides[0] is -1}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, strides = dense<[-1, 1]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}
// -----
// dilations, when given, must hold exactly 2 elements.
func.func @bad_dilations_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected dilations to be of shape (2) but it is of shape (3)}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, dilations = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}
// -----
// Every dilation must be strictly positive.
func.func @bad_dilations_values(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected elements of dilations to be positive but dilations[1] is -1}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, dilations = dense<[1, -1]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}
// -----
// Output spatial shape must match the inferred one: (16-4+1, 10-2+1) = (13, 9).
func.func @bad_output_shape(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x10x5x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected output to be of shape (1, 1, 13, 9) but it is of shape (1, 1, 10, 5)}}
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x10x5x!FHE.eint<7>>
return %0 : tensor<1x1x10x5x!FHE.eint<7>>
}

View File

@@ -0,0 +1,12 @@
// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// -----
// Roundtrip of FHELinalg.maxpool2d with only the mandatory kernel_shape
// attribute. The output spatial dims (13, 9) correspond to
// (16 - 4 + 1, 10 - 2 + 1), i.e. unit strides and dilations.
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.maxpool2d"(%[[a0]]) {kernel_shape = dense<[4, 2]> : tensor<2xi64>} : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<1x1x13x9x!FHE.eint<7>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
}

View File

@@ -1,3 +1,977 @@
description: max_eint_unsigned
program: |
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %0 : !FHE.eint<4>
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 0
- scalar: 2
outputs:
- scalar: 2
- inputs:
- scalar: 0
- scalar: 3
outputs:
- scalar: 3
- inputs:
- scalar: 0
- scalar: 4
outputs:
- scalar: 4
- inputs:
- scalar: 0
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 0
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 0
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 1
- scalar: 0
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 2
outputs:
- scalar: 2
- inputs:
- scalar: 1
- scalar: 3
outputs:
- scalar: 3
- inputs:
- scalar: 1
- scalar: 4
outputs:
- scalar: 4
- inputs:
- scalar: 1
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 1
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 1
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 2
- scalar: 0
outputs:
- scalar: 2
- inputs:
- scalar: 2
- scalar: 1
outputs:
- scalar: 2
- inputs:
- scalar: 2
- scalar: 2
outputs:
- scalar: 2
- inputs:
- scalar: 2
- scalar: 3
outputs:
- scalar: 3
- inputs:
- scalar: 2
- scalar: 4
outputs:
- scalar: 4
- inputs:
- scalar: 2
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 2
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 2
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 3
- scalar: 0
outputs:
- scalar: 3
- inputs:
- scalar: 3
- scalar: 1
outputs:
- scalar: 3
- inputs:
- scalar: 3
- scalar: 2
outputs:
- scalar: 3
- inputs:
- scalar: 3
- scalar: 3
outputs:
- scalar: 3
- inputs:
- scalar: 3
- scalar: 4
outputs:
- scalar: 4
- inputs:
- scalar: 3
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 3
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 3
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 4
- scalar: 0
outputs:
- scalar: 4
- inputs:
- scalar: 4
- scalar: 1
outputs:
- scalar: 4
- inputs:
- scalar: 4
- scalar: 2
outputs:
- scalar: 4
- inputs:
- scalar: 4
- scalar: 3
outputs:
- scalar: 4
- inputs:
- scalar: 4
- scalar: 4
outputs:
- scalar: 4
- inputs:
- scalar: 4
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 4
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 4
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 5
- scalar: 0
outputs:
- scalar: 5
- inputs:
- scalar: 5
- scalar: 1
outputs:
- scalar: 5
- inputs:
- scalar: 5
- scalar: 2
outputs:
- scalar: 5
- inputs:
- scalar: 5
- scalar: 3
outputs:
- scalar: 5
- inputs:
- scalar: 5
- scalar: 4
outputs:
- scalar: 5
- inputs:
- scalar: 5
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 5
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 5
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 6
- scalar: 0
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 1
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 2
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 3
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 4
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 5
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 6
outputs:
- scalar: 6
- inputs:
- scalar: 6
- scalar: 7
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 0
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 1
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 2
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 3
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 4
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 5
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 6
outputs:
- scalar: 7
- inputs:
- scalar: 7
- scalar: 7
outputs:
- scalar: 7
---
description: max_eint_signed
program: |
func.func @main(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> {
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<4>, !FHE.esint<4>) -> !FHE.esint<4>
return %0 : !FHE.esint<4>
}
tests:
- inputs:
- scalar: -4
signed: true
- scalar: -4
signed: true
outputs:
- scalar: -4
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: -3
signed: true
outputs:
- scalar: -3
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: -2
signed: true
outputs:
- scalar: -2
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: -1
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: -4
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: -4
signed: true
outputs:
- scalar: -3
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: -3
signed: true
outputs:
- scalar: -3
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: -2
signed: true
outputs:
- scalar: -2
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: -1
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: -3
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: -4
signed: true
outputs:
- scalar: -2
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: -3
signed: true
outputs:
- scalar: -2
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: -2
signed: true
outputs:
- scalar: -2
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: -1
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: -2
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: -4
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: -3
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: -2
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: -1
signed: true
outputs:
- scalar: -1
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: -1
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: -4
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: -3
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: -2
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: -1
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 0
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 0
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: -4
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: -3
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: -2
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: -1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 1
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 1
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: -4
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: -3
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: -2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: -1
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 2
signed: true
- inputs:
- scalar: 2
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: -4
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: -3
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: -2
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: -1
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: 0
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: 1
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: 2
signed: true
outputs:
- scalar: 3
signed: true
- inputs:
- scalar: 3
signed: true
- scalar: 3
signed: true
outputs:
- scalar: 3
signed: true
---
# TODO: Rewrite/Remove
# The FHE.neg_eint op doesn't come with a well defined semantics as FHE.eint
# has an undefined behavior for under/overflow.

View File

@@ -380,3 +380,64 @@ tests:
20, 32, 44, 20, 32, 44, 20, 32, 44,
]
shape: [1, 3, 3, 3]
---
description: maxpool2d_unsigned_1x1x8x10_kernel_3x2
program: |
func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[3, 2]> : tensor<2xi64> }
: (tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
return %0 : tensor<1x1x6x9x!FHE.eint<5>>
}
tests:
- inputs:
- tensor: [
3, 3, 11, 9, 0, 3, 14, 10, 5, 6,
10, 6, 4, 1, 10, 9, 11, 4, 0, 9,
8, 4, 10, 12, 11, 10, 9, 3, 10, 2,
8, 0, 11, 7, 5, 10, 8, 13, 9, 9,
9, 1, 15, 0, 6, 8, 6, 6, 6, 3,
6, 9, 10, 6, 0, 9, 13, 12, 6, 9,
12, 13, 7, 15, 7, 1, 9, 3, 13, 6,
2, 11, 14, 8, 11, 1, 11, 0, 0, 15,
]
shape: [1, 1, 8, 10]
outputs:
- tensor: [
10, 11, 12, 12, 11, 14, 14, 10, 10,
10, 11, 12, 12, 11, 11, 13, 13, 10,
9, 15, 15, 12, 11, 10, 13, 13, 10,
9, 15, 15, 7, 10, 13, 13, 13, 9,
13, 15, 15, 15, 9, 13, 13, 13, 13,
13, 14, 15, 15, 11, 13, 13, 13, 15,
]
shape: [1, 1, 6, 9]
---
description: maxpool2d_signed_1x1x6x5_kernel_2x3
program: |
func.func @main(%arg0: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[2, 3]> : tensor<2xi64> }
: (tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
return %0 : tensor<1x1x5x3x!FHE.esint<6>>
}
tests:
- inputs:
- tensor: [
-8, -12, -8, -12, -10,
1, -9, -15, -16, 14,
9, 14, 2, 2, -15,
9, -12, 0, -4, -5,
-7, -11, -15, -4, 6,
15, -3, 7, -13, -13,
]
shape: [1, 1, 6, 5]
signed: true
outputs:
- tensor: [
1, -8, 14,
14, 14, 14,
14, 14, 2,
9, 0, 6,
15, 7, 7,
]
shape: [1, 1, 5, 3]
signed: true