diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index a0245f049..dc2e9d6dc 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -323,6 +323,41 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> { let hasVerifier = 1; } +def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> { + let summary = "Get maximum of two encrypted integers."; + + let description = [{ + Get maximum of two encrypted integers using the formula, 'max(x, y) == max(x - y, 0) + y'. + Type of inputs and the output should be the same. + + If `x - y`` inside the max overflows or underflows, the behavior is undefined. + So to support the full range, you should increase the bit-width by 1 manually. + + Example: + ```mlir + // ok + "FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + "FHE.max_eint"(%x, %y) : (!FHE.esint<3>, !FHE.esint<3>) -> !FHE.esint<3> + + // error + "FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<3>) -> !FHE.eint<2> + "FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.esint<2> + "FHE.max_eint"(%x, %y) : (!FHE.esint<2>, !FHE.eint<2>) -> !FHE.eint<2> + ``` + }]; + + let arguments = (ins FHE_AnyEncryptedInteger:$x, FHE_AnyEncryptedInteger:$y); + let results = (outs FHE_AnyEncryptedInteger); + + let builders = [ + OpBuilder<(ins "Value":$x, "Value":$y), [{ + build($_builder, $_state, x.getType(), x, y); + }]> + ]; + + let hasVerifier = 1; +} + def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> { let summary = "Cast an unsigned integer to a signed one"; diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt index b237690f5..64b6fcf2d 100644 --- a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt @@ -4,3 
+4,4 @@ add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen) add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen) add_subdirectory(BigInt) add_subdirectory(Boolean) +add_subdirectory(Max) diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Max/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/Max/CMakeLists.txt new file mode 100644 index 000000000..cea521777 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Max/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Max.td) +mlir_tablegen(Max.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangFHEMaxPassIncGen) +add_dependencies(mlir-headers ConcretelangFHEMaxPassIncGen) diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Max/Max.h b/compiler/include/concretelang/Dialect/FHE/Transforms/Max/Max.h new file mode 100644 index 000000000..ca6a4c2ac --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Max/Max.h @@ -0,0 +1,23 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. 
+ +#ifndef CONCRETELANG_FHE_MAX_PASS_H +#define CONCRETELANG_FHE_MAX_PASS_H + +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace concretelang { + +std::unique_ptr> createFHEMaxTransformPass(); + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Max/Max.td b/compiler/include/concretelang/Dialect/FHE/Transforms/Max/Max.td new file mode 100644 index 000000000..7f102ee57 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Max/Max.td @@ -0,0 +1,13 @@ +#ifndef CONCRETELANG_FHE_MAX_PASS +#define CONCRETELANG_FHE_MAX_PASS + +include "mlir/Pass/PassBase.td" + +def FHEMaxTransform : Pass<"fhe-max-transform"> { + let summary = "Transform max operation to basic operations"; + let constructor = "mlir::concretelang::createFHEMaxTransformPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ]; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index fe50fb0ac..f509b16f4 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -944,6 +944,18 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { let hasVerifier = 1; } +def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> { + let summary = "Returns the 2D maxpool of a tensor in the form NCHW"; + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$input, + I64ElementsAttr:$kernel_shape, + OptionalAttr:$strides, + OptionalAttr:$dilations + ); + let results = (outs Type.predicate, HasStaticShapePred]>>); + let hasVerifier = 1; +} + def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> { let summary = "Returns a tensor that contains the transposition of the input tensor."; diff --git 
a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index abbb9a80f..c5b4f9297 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1771,6 +1771,81 @@ struct FHELinalgConv2dToLinalgConv2d }; }; +/// This rewrite pattern transforms all instances +/// of `FHELinalg.maxpool2d` to `linalg.pooling_ncw_max`. +struct FHELinalgMaxpool2dToLinalgMaxpool2d + : public mlir::OpRewritePattern { + + FHELinalgMaxpool2dToLinalgMaxpool2d(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(FHELinalg::Maxpool2dOp maxpool2dOp, + mlir::PatternRewriter &rewriter) const override { + + const mlir::Location loc = maxpool2dOp->getLoc(); + + const mlir::NamedAttribute maxOpAttr = rewriter.getNamedAttr( + "max_signed", + rewriter.getStringAttr(FHE::MaxEintOp::getOperationName())); + + const auto outputTy = + maxpool2dOp->getResult(0).getType().cast(); + const auto outputElementTy = + outputTy.getElementType().cast(); + + mlir::Value output = + rewriter.create(loc, outputTy).getResult(); + + if (outputElementTy.isSigned()) { + const int64_t outputBitWidth = outputElementTy.getWidth(); + const int64_t offsetValue = 1 << (outputBitWidth - 2); + + const mlir::Type offsetType = + mlir::IntegerType::get(this->getContext(), outputBitWidth + 1); + const mlir::Type offsetTensorType = + mlir::RankedTensorType::get({1}, offsetType); + + const llvm::SmallVector offsetTensorAttr = { + mlir::IntegerAttr::get(offsetType, offsetValue)}; + const mlir::Attribute offsetAttr = + mlir::DenseElementsAttr::get(offsetTensorType, offsetTensorAttr); + + const mlir::Value offset = + rewriter.create(loc, offsetAttr); + + output = rewriter.create(loc, output, offset); + } + + const mlir::DenseElementsAttr kernelShapeAttr = maxpool2dOp.kernel_shape(); + const auto kernelShape = + 
llvm::SmallVector(kernelShapeAttr.value_begin(), + kernelShapeAttr.value_end()); + + const mlir::Value kernel = + rewriter + .create( + loc, kernelShape, + mlir::IntegerType::get(this->getContext(), 64)) + .getResult(); + + const mlir::DenseIntElementsAttr defaultAttr = + rewriter.getI64VectorAttr({1, 1}); + + const mlir::DenseIntElementsAttr stridesAttr = + maxpool2dOp.dilations().getValueOr(defaultAttr); + const mlir::DenseIntElementsAttr dilationsAttr = + maxpool2dOp.dilations().getValueOr(defaultAttr); + + rewriter.replaceOpWithNewOp( + maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.input(), kernel}, + output, stridesAttr, dilationsAttr, + llvm::ArrayRef({maxOpAttr})); + + return mlir::success(); + }; +}; + /// This template rewrite pattern transforms any instance of /// operators `FHELinalg.to_signed` to an instance of `linalg.generic` with an /// appropriate region using `FHE.to_signed` operation, an appropriate @@ -2019,6 +2094,7 @@ void FHETensorOpsToLinalg::runOnOperation() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index c64c88c79..7a02bad8b 100644 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -77,6 +77,10 @@ struct FunctionToDag { for (auto &bb : func.getBody().getBlocks()) { for (auto &op : bb.getOperations()) { addOperation(dag, op); + } + } + for (auto &bb : func.getBody().getBlocks()) { + for (auto &op : bb.getOperations()) { op.removeAttr("SMANP"); } } @@ -145,6 +149,14 @@ struct FunctionToDag { // If can't find weights return default leveled op DEBUG("Replace Dot by LevelledOp on " << op); } + if (auto max = asMax(op)) { + addMax(dag, max, encrypted_inputs, 
precision); + return; + } + if (auto maxpool2d = asMaxpool2d(op)) { + addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision); + return; + } // default addLevelledOp(dag, op, encrypted_inputs); } @@ -207,6 +219,103 @@ struct FunctionToDag { manp, slice(out_shape), comment); } + void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs, + int precision) { + mlir::Value result = maxOp.getResult(); + const std::vector resultShape = getShape(result); + + Operation *xOp = maxOp.x().getDefiningOp(); + Operation *yOp = maxOp.y().getDefiningOp(); + + const double fixedCost = NEGLIGIBLE_COMPLEXITY; + const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; + + llvm::APInt xSmanp = llvm::APInt{1, 1, false}; + if (xOp != nullptr) { + const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); + assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); + xSmanp = xSmanpAttr.getValue(); + } + + llvm::APInt ySmanp = llvm::APInt{1, 1, false}; + if (yOp != nullptr) { + const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); + assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); + ySmanp = ySmanpAttr.getValue(); + } + + const double subManp = + sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); + + auto loc = loc_to_string(maxOp.getLoc()); + auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc; + + auto subNode = + dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + subManp, slice(resultShape), comment); + + const double tluNodeManp = 1; + const std::vector unknownFunction; + auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision); + + const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble()); + const std::vector addInputs = { + tluNode, inputs[1]}; + index[result] = + dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost, + addManp, slice(resultShape), comment); + } + + void addMaxpool2d(optimizer::Dag &dag, FHELinalg::Maxpool2dOp &maxpool2dOp, + Inputs &inputs, int precision) { + 
mlir::Value result = maxpool2dOp.getResult(); + const std::vector resultShape = getShape(result); + + // all TLUs are flattened into a dimension + // to create a single TLU node in optimizer dag + std::vector fakeShape = resultShape; + + uint64_t numberOfComparisons = 1; + for (auto dimensionSize : maxpool2dOp.kernel_shape().getValues()) { + numberOfComparisons *= dimensionSize; + } + fakeShape.push_back(numberOfComparisons); + + Operation *inputOp = maxpool2dOp.input().getDefiningOp(); + + const double fixedCost = NEGLIGIBLE_COMPLEXITY; + const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; + + llvm::APInt inputSmanp = llvm::APInt{1, 1, false}; + if (inputOp != nullptr) { + const auto inputSmanpAttr = + inputOp->getAttrOfType("SMANP"); + assert(inputSmanpAttr && "Missing SMANP value on a crypto operation"); + inputSmanp = inputSmanpAttr.getValue(); + } + + const double subManp = sqrt(2 * inputSmanp.roundToDouble() + 1); + + auto loc = loc_to_string(maxpool2dOp.getLoc()); + auto comment = + std::string(maxpool2dOp->getName().getStringRef()) + " " + loc; + + auto subNode = + dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + subManp, slice(fakeShape), comment); + + const std::vector unknownFunction; + auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision); + + const double addManp = sqrt(inputSmanp.roundToDouble() + 1); + const std::vector addInputs = { + tluNode, inputs[1]}; + + index[result] = + dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost, + addManp, slice(resultShape), comment); + } + Inputs encryptedInputs(mlir::Operation &op) { Inputs inputs; for (auto operand : op.getOperands()) { @@ -237,6 +346,14 @@ struct FunctionToDag { return llvm::dyn_cast(op); } + mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + + mlir::concretelang::FHELinalg::Maxpool2dOp asMaxpool2d(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + bool isReturn(mlir::Operation &op) { 
return llvm::isa(op); } diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 2bda48f8c..d69de6079 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -510,6 +510,30 @@ static llvm::APInt getSqMANP( return eNorm; } +/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +/// that is equivalent to an `FHE.max_eint` operation. +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::MaxEintOp op, + llvm::ArrayRef *> operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + // max(x, y) = max(x - y, 0) + y + + const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue(); + const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue(); + + const llvm::APInt sub = APIntWidthExtendUAdd(x, y); + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt add = APIntWidthExtendUAdd(tlu, y); + + // this is not optimal as it can increase the resulting noise unnecessarily + return APIntUMax(add, sub); +} + /// Calculates the squared Minimal Arithmetic Noise Padding of an /// `FHELinalg.add_eint_int` operation. 
static llvm::APInt getSqMANP( @@ -1153,6 +1177,28 @@ static llvm::APInt getSqMANP( return accNorm; } +static llvm::APInt getSqMANP( + mlir::concretelang::FHELinalg::Maxpool2dOp op, + llvm::ArrayRef *> operandMANPs) { + + // maximum between two value is calculated using + // - max(x - y, 0) + y + + // max is calculated with a TLU so MANP is {1, 1, false} + // y on the other hand comes from the input or from the previous result + + // in the current implementation, it's the input + // so the resulting MANP is `{1, 1, false} + MANP input` + + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt input = operandMANPs[0]->getValue().getMANP().getValue(); + + const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, input); + const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, input); + + return APIntUMax(forIntermediate, forResult); +} + struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; MANPAnalysis(mlir::MLIRContext *ctx, bool debug) @@ -1202,6 +1248,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto roundOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(roundOp, operands); + } else if (auto maxEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(maxEintOp, operands); } else if (llvm::isa(op) || llvm::isa(op) || llvm::isa(op) || @@ -1264,6 +1313,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(conv2dOp, operands); + } else if (auto maxpool2dOp = + llvm::dyn_cast( + op)) { + norm2SqEquiv = getSqMANP(maxpool2dOp, operands); } else if (auto fromElementOp = llvm::dyn_cast( op)) { diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 0e043c2bb..dfab4b8c9 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -192,6 +192,24 @@ mlir::LogicalResult MulEintOp::verify() { return ::mlir::success(); } 
+mlir::LogicalResult MaxEintOp::verify() { + auto xTy = this->x().getType().dyn_cast(); + auto yTy = this->y().getType().dyn_cast(); + auto outTy = this->getResult().getType().dyn_cast(); + + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), + xTy, outTy)) { + return mlir::failure(); + } + + if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), xTy, + yTy)) { + return mlir::failure(); + } + + return mlir::success(); +} + mlir::LogicalResult ToSignedOp::verify() { auto input = this->input().getType().cast(); auto output = this->getResult().getType().cast(); diff --git a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt index 51b71f1f2..dc2b5cd41 100644 --- a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library( FHEDialectTransforms BigInt.cpp Boolean.cpp + Max.cpp EncryptedMulToDoubleTLU.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE diff --git a/compiler/lib/Dialect/FHE/Transforms/Max.cpp b/compiler/lib/Dialect/FHE/Transforms/Max.cpp new file mode 100644 index 000000000..4c7bff444 --- /dev/null +++ b/compiler/lib/Dialect/FHE/Transforms/Max.cpp @@ -0,0 +1,111 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. 
+ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "concretelang/Dialect/FHE/IR/FHEOps.h" +#include "concretelang/Dialect/FHE/Transforms/Max/Max.h" + +namespace arith = mlir::arith; +namespace func = mlir::func; + +namespace FHE = mlir::concretelang::FHE; + +/// This rewrite pattern transforms all instances +/// of `FHE.max_eint` to `max(x - y, 0) + y`. +struct MaxEintPattern : public mlir::OpRewritePattern { + MaxEintPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(FHE::MaxEintOp maxEintOp, + mlir::PatternRewriter &rewriter) const override { + + const mlir::Location loc = maxEintOp->getLoc(); + + const FHE::FheIntegerInterface outputTy = + maxEintOp->getResult(0).getType().cast(); + const int64_t outputBitWidth = outputTy.getWidth(); + + mlir::Value x = maxEintOp.x(); + mlir::Value y = maxEintOp.y(); + + const auto xTy = x.getType().cast(); + const auto yTy = y.getType().cast(); + + const auto signedTy = FHE::EncryptedSignedIntegerType::get( + this->getContext(), outputBitWidth); + + if (xTy.isUnsigned()) { + x = rewriter.create(loc, signedTy, x).getResult(); + } + if (yTy.isUnsigned()) { + y = rewriter.create(loc, signedTy, y).getResult(); + } + + const mlir::Value sub = + rewriter.create(loc, x, y).getResult(); + + const int64_t lutSize = 1 << outputBitWidth; + + auto lutValues = std::vector(); + for (int64_t i = 0; i < lutSize / 2; i++) { + lutValues.push_back(i); + } + for (int64_t i = 0; i < lutSize / 2; i++) { + lutValues.push_back(0); + } + + const mlir::Attribute lutAttr = rewriter.getI64TensorAttr(lutValues); + const mlir::Value lut = + rewriter.create(loc, lutAttr).getResult(); + + const mlir::Value max = + rewriter.create(loc, outputTy, sub, lut) + .getResult(); + + const mlir::Value add = + rewriter.create(loc, max, 
maxEintOp.y()).getResult(); + + rewriter.replaceOp(maxEintOp, {add}); + return mlir::success(); + }; +}; + +namespace { + +struct FHEMaxTransform : public FHEMaxTransformBase { + void runOnOperation() final; +}; + +void FHEMaxTransform::runOnOperation() { + auto target = mlir::ConversionTarget(this->getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalOp(); + + auto patterns = mlir::RewritePatternSet(&this->getContext()); + patterns.insert(&this->getContext()); + + mlir::Operation *op = this->getOperation(); + if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) { + this->signalPassFailure(); + } +} + +} // namespace + +namespace mlir { +namespace concretelang { + +std::unique_ptr> createFHEMaxTransformPass() { + return std::make_unique(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 39635a3da..ed4704ed7 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1048,6 +1048,154 @@ mlir::LogicalResult Conv2dOp::verify() { return mlir::success(); } +mlir::LogicalResult Maxpool2dOp::verify() { + const mlir::RankedTensorType inputTy = + this->input().getType().cast(); + const mlir::RankedTensorType outputTy = + this->getResult().getType().cast(); + + const FHE::FheIntegerInterface inputElementTy = inputTy.getElementType(); + const FHE::FheIntegerInterface outputElementTy = outputTy.getElementType(); + + if (inputElementTy != outputElementTy) { + this->emitOpError() << "expected output element type " + << "(" << outputElementTy << ") " + << "to be the same with input element type " + << "(" << inputElementTy << ") " + << "but it is not"; + return mlir::failure(); + } + + const llvm::ArrayRef inputShape = inputTy.getShape(); + const llvm::ArrayRef outputShape = outputTy.getShape(); + + if (inputShape.size() != 4) { + 
this->emitOpError() << "expected input to have 4 dimensions (N*C*H*W) " + << "but it has " << inputShape.size(); + return mlir::failure(); + } + if (outputShape.size() != 4) { + this->emitOpError() << "expected output to have 4 dimensions (N*C*H*W) " + << "but it has " << outputShape.size(); + return mlir::failure(); + } + + const int64_t inputN = inputShape[0]; + const int64_t inputC = inputShape[1]; + const int64_t inputH = inputShape[2]; + const int64_t inputW = inputShape[3]; + + const mlir::DenseIntElementsAttr kernelShapeAttr = this->kernel_shape(); + const mlir::RankedTensorType kernelShapeAttrTy = + kernelShapeAttr.getType().cast(); + const llvm::ArrayRef kernelShapeAttrShape = + kernelShapeAttrTy.getShape(); + + if (kernelShapeAttrShape.size() != 1 || kernelShapeAttrShape[0] != 2) { + this->emitOpError() << "expected kernel shape to be of shape " + << "(2) " + << "but it is of shape " + << "(" << kernelShapeAttrShape << ")"; + return mlir::failure(); + } + + mlir::SmallVector kernelShape; + kernelShape.append(kernelShapeAttr.value_begin(), + kernelShapeAttr.value_end()); + + const int64_t kernelShapeH = kernelShape[0]; + const int64_t kernelShapeW = kernelShape[1]; + + mlir::SmallVector strides; + const llvm::Optional maybeStridesAttr = + this->strides(); + if (maybeStridesAttr.hasValue()) { + const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.getValue(); + const mlir::RankedTensorType stridesAttrTy = + stridesAttr.getType().cast(); + const llvm::ArrayRef stridesAttrShape = stridesAttrTy.getShape(); + + if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) { + this->emitOpError() << "expected strides to be of shape " + << "(2) " + << "but it is of shape " + << "(" << stridesAttrShape << ")"; + return mlir::failure(); + } + + strides.append(stridesAttr.value_begin(), + stridesAttr.value_end()); + } else { + strides.append({1, 1}); + } + for (size_t i = 0; i < 2; i++) { + if (strides[i] < 1) { + this->emitOpError() << "expected elements 
of strides to be positive " + << "but strides[" << i << "] is " << strides[i]; + return mlir::failure(); + } + } + + const int64_t stridesH = strides[0]; + const int64_t stridesW = strides[1]; + + mlir::SmallVector dilations; + const llvm::Optional maybeDilationsAttr = + this->dilations(); + if (maybeDilationsAttr.hasValue()) { + const mlir::DenseIntElementsAttr dilationsAttr = + maybeDilationsAttr.getValue(); + const mlir::RankedTensorType dilationsAttrTy = + dilationsAttr.getType().cast(); + const llvm::ArrayRef dilationsAttrShape = + dilationsAttrTy.getShape(); + + if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) { + this->emitOpError() << "expected dilations to be of shape " + << "(2) " + << "but it is of shape " + << "(" << dilationsAttrShape << ")"; + return mlir::failure(); + } + + dilations.append(dilationsAttr.value_begin(), + dilationsAttr.value_end()); + } else { + dilations.append({1, 1}); + } + for (size_t i = 0; i < 2; i++) { + if (dilations[i] < 1) { + this->emitOpError() << "expected elements of dilations to be positive " + << "but dilations[" << i << "] is " << dilations[i]; + return mlir::failure(); + } + } + + const int64_t dilationsH = dilations[0]; + const int64_t dilationsW = dilations[1]; + + const int64_t expectedOutputH = + floor((inputH - dilationsH * (kernelShapeH - 1) - 1) / stridesH) + 1; + const int64_t expectedOutputW = + floor((inputW - dilationsW * (kernelShapeW - 1) - 1) / stridesW) + 1; + const mlir::SmallVector expectedOutputShape = { + inputN, + inputC, + expectedOutputH, + expectedOutputW, + }; + + if (outputShape != llvm::makeArrayRef(expectedOutputShape)) { + this->emitOpError() << "expected output to be of shape " + << "(" << expectedOutputShape << ") " + << "but it is of shape " + << "(" << outputShape << ")"; + return mlir::failure(); + } + + return mlir::success(); +} + mlir::LogicalResult FromElementOp::verify() { mlir::Value in = this->getOperand(); mlir::Value out = this->getResult(); diff --git 
a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 78d338503..b253aca70 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -293,13 +293,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return errorDiag("Transforming FHE boolean ops failed"); } - // Encrypted mul rewriting - if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext, - module, enablePass) - .failed()) { - return StreamStringError("Rewriting of encrypted mul failed"); - } - if (options.chunkIntegers) { if (mlir::concretelang::pipeline::transformFHEBigInt( mlirContext, module, enablePass, options.chunkSize, @@ -344,6 +337,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { .failed()) { return errorDiag("Lowering from FHELinalg to FHE failed"); } + + if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext, + module, enablePass) + .failed()) { + return StreamStringError("Rewriting of high level fhe ops failed"); + } + if (target == Target::FHE_NO_LINALG) return std::move(res); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index a82802ab3..41b83ee8d 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -181,7 +182,9 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { mlir::PassManager pm(&context); pipelinePrinting("transformHighLevelFHEOps", pm, context); + addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass); + addPotentiallyNestedPass(pm, createFHEMaxTransformPass(), enablePass); return pm.run(module.getOperation()); } diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/maxpool2d.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/maxpool2d.mlir new file mode 
100644 index 000000000..ff31daf41 --- /dev/null +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/maxpool2d.mlir @@ -0,0 +1,38 @@ +// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg %s 2>&1 | FileCheck %s + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x6x9x!FHE.eint<5>> +// CHECK-NEXT: %[[v1:.*]] = linalg.init_tensor [3, 2] : tensor<3x2xi64> +// CHECK-NEXT: %[[v2:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %1 : tensor<1x1x8x10x!FHE.eint<5>>, tensor<3x2xi64>) outs(%0 : tensor<1x1x6x9x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> +// CHECK-NEXT: return %[[v2]] : tensor<1x1x6x9x!FHE.eint<5>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> { + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[3, 2]> : tensor<2xi64> } : (tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> + return %0 : tensor<1x1x6x9x!FHE.eint<5>> +} + +// ----- + + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (0)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>> +// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7> +// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>> +// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) { +// CHECK-NEXT: 
^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.esint<6> +// CHECK-NEXT: } -> tensor<1x1x5x3x!FHE.esint<6>> +// CHECK-NEXT: %[[v4:.*]] = linalg.init_tensor [2, 3] : tensor<2x3xi64> +// CHECK-NEXT: %[[v5:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %[[v4]] : tensor<1x1x6x5x!FHE.esint<6>>, tensor<2x3xi64>) outs(%[[v3]] : tensor<1x1x5x3x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> +// CHECK-NEXT: return %[[v5]] : tensor<1x1x5x3x!FHE.esint<6>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> { + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[2, 3]> : tensor<2xi64> } : (tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> + return %0 : tensor<1x1x5x3x!FHE.esint<6>> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/Transforms/max_eint.mlir b/compiler/tests/check_tests/Dialect/FHE/Transforms/max_eint.mlir new file mode 100644 index 000000000..671d4f368 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/Transforms/max_eint.mlir @@ -0,0 +1,31 @@ +// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-max-transform %s 2>&1 | FileCheck %s + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: !FHE.eint<5>, %[[a1:.*]]: !FHE.eint<5>) -> !FHE.eint<5> { +// CHECK-NEXT: %[[v0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64> +// CHECK-NEXT: %[[v1:.*]] = "FHE.to_signed"(%[[a0]]) : (!FHE.eint<5>) -> !FHE.esint<5> +// CHECK-NEXT: %[[v2:.*]] = "FHE.to_signed"(%[[a1]]) : (!FHE.eint<5>) -> !FHE.esint<5> +// CHECK-NEXT: %[[v3:.*]] = "FHE.sub_eint"(%[[v1]], %[[v2]]) : (!FHE.esint<5>, !FHE.esint<5>) 
-> !FHE.esint<5> +// CHECK-NEXT: %[[v4:.*]] = "FHE.apply_lookup_table"(%[[v3]], %[[v0]]) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<5> +// CHECK-NEXT: %[[v5:.*]] = "FHE.add_eint"(%[[v4]], %[[a1]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> +// CHECK-NEXT: return %[[v5]] : !FHE.eint<5> +// CHECK-NEXT: } +func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> +} + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: !FHE.esint<5>, %[[a1:.*]]: !FHE.esint<5>) -> !FHE.esint<5> { +// CHECK-NEXT: %[[v0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64> +// CHECK-NEXT: %[[v1:.*]] = "FHE.sub_eint"(%[[a0]], %[[a1]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5> +// CHECK-NEXT: %[[v2:.*]] = "FHE.apply_lookup_table"(%[[v1]], %[[v0]]) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<5> +// CHECK-NEXT: %[[v3:.*]] = "FHE.add_eint"(%[[v2]], %[[a1]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5> +// CHECK-NEXT: return %[[v3]] : !FHE.esint<5> +// CHECK-NEXT: } +func.func @main(%arg0: !FHE.esint<5>, %arg1: !FHE.esint<5>) -> !FHE.esint<5> { + %0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5> + return %0 : !FHE.esint<5> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir b/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir index 0b11156ec..4b75cf074 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Transforms/mul_eint.mlir @@ -1,15 +1,15 @@ -// RUN: concretecompiler --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --action=dump-tfhe --passes EncryptedMulToDoubleTLU --split-input-file %s 2>&1 | FileCheck %s -// CHECK: func.func @simple_eint(%arg0: !FHE.eint<3>, 
%arg1: !FHE.eint<3>) -> !FHE.eint<3> { -// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64> -// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64> -// CHECK-NEXT: %0 = "FHE.add_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> -// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst_0) {MANP = 1 : ui1} : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3> -// CHECK-NEXT: %2 = "FHE.sub_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> -// CHECK-NEXT: %3 = "FHE.to_signed"(%2) {MANP = 2 : ui3} : (!FHE.eint<3>) -> !FHE.esint<3> -// CHECK-NEXT: %4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3> -// CHECK-NEXT: %5 = "FHE.sub_eint"(%1, %4) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> -// CHECK-NEXT: return %5 : !FHE.eint<3> +// CHECK: func.func @simple_eint(%[[a0:.*]]: !FHE.eint<3>, %[[a1:.*]]: !FHE.eint<3>) -> !FHE.eint<3> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.add_eint"(%[[a0]], %[[a1]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> +// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64> +// CHECK-NEXT: %[[v2:.*]] = "FHE.apply_lookup_table"(%[[v0]], %[[v1]]) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3> +// CHECK-NEXT: %[[v3:.*]] = "FHE.sub_eint"(%[[a0]], %[[a1]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> +// CHECK-NEXT: %[[v4:.*]] = "FHE.to_signed"(%[[v3]]) : (!FHE.eint<3>) -> !FHE.esint<3> +// CHECK-NEXT: %[[v5:.*]] = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64> +// CHECK-NEXT: %[[v6:.*]] = "FHE.apply_lookup_table"(%[[v4]], %[[v5]]) : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3> +// CHECK-NEXT: %[[v7:.*]] = "FHE.sub_eint"(%[[v2]], %[[v6]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> +// CHECK-NEXT: return %[[v7]] : !FHE.eint<3> // CHECK-NEXT: } func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: 
!FHE.eint<3>) -> !FHE.eint<3> { %0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>) diff --git a/compiler/tests/check_tests/Dialect/FHE/max_eint.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/max_eint.invalid.mlir new file mode 100644 index 000000000..99934c0af --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/max_eint.invalid.mlir @@ -0,0 +1,33 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// ----- + +// CHECK-LABEL: error: 'FHE.max_eint' op should have the width of encrypted inputs equal +func.func @bad_inputs_width(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<3>) -> !FHE.eint<5> { + %1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<3>) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.max_eint' op should have the signedness of encrypted inputs equal +func.func @bad_inputs_signedness(%arg0: !FHE.eint<5>, %arg1: !FHE.esint<5>) -> !FHE.eint<5> { + %1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.esint<5>) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.max_eint' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<3> { + %1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.max_eint' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.esint<5> { + %1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.esint<5>) + return %1: !FHE.esint<5> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.mlir index 96cb3e0f3..9a00b33ea 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.mlir +++ 
b/compiler/tests/check_tests/Dialect/FHE/ops.mlir @@ -223,6 +223,24 @@ func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @max_eint(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> +func.func @max_eint(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> { + // CHECK-NEXT: %[[v0:.*]] = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4> + %0 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<4>, !FHE.eint<4>) -> (!FHE.eint<4>) + + // CHECK-NEXT: return %[[v0]] : !FHE.eint<4> + return %0: !FHE.eint<4> +} + +// CHECK-LABEL: func.func @max_esint(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> +func.func @max_esint(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> { + // CHECK-NEXT: %[[v0:.*]] = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<4>, !FHE.esint<4>) -> !FHE.esint<4> + %0 = "FHE.max_eint"(%arg0, %arg1): (!FHE.esint<4>, !FHE.esint<4>) -> (!FHE.esint<4>) + + // CHECK-NEXT: return %[[v0]] : !FHE.esint<4> + return %0: !FHE.esint<4> +} + // CHECK-LABEL: func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool { // CHECK-NEXT: %[[V1:.*]] = "FHE.to_bool"(%arg0) : (!FHE.eint<1>) -> !FHE.ebool diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.invalid.mlir new file mode 100644 index 000000000..7f94e3e10 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.invalid.mlir @@ -0,0 +1,73 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s + +// ----- + +func.func @different_input_and_output_bit_widths(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<5>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected output element type ('!FHE.eint<5>') to be the same with input element type ('!FHE.eint<7>') but it is not}} + %0 = 
"FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<5>> + return %0 : tensor<1x1x13x9x!FHE.eint<5>> +} + +// ----- + +func.func @bad_input_dimensions(%arg0: tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected input to have 4 dimensions (N*C*H*W) but it has 2}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func @bad_output_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected output to have 4 dimensions (N*C*H*W) but it has 3}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x13x9x!FHE.eint<7>> + return %0 : tensor<1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func @bad_kernel_shape_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected kernel shape to be of shape (2) but it is of shape (3)}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2, 3]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func @bad_strides_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected strides to be of shape (2) but it is of shape (3)}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func 
@bad_strides_values(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected elements of strides to be positive but strides[0] is -1}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, strides = dense<[-1, 1]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func @bad_dilations_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected dilations to be of shape (2) but it is of shape (3)}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, dilations = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func @bad_dilations_values(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected elements of dilations to be positive but dilations[1] is -1}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, dilations = dense<[1, -1]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} + +// ----- + +func.func @bad_output_shape(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x10x5x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.maxpool2d' op expected output to be of shape (1, 1, 13, 9) but it is of shape (1, 1, 10, 5)}} + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x10x5x!FHE.eint<7>> + return %0 : tensor<1x1x10x5x!FHE.eint<7>> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.mlir 
b/compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.mlir new file mode 100644 index 000000000..4170b2742 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.mlir @@ -0,0 +1,12 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.maxpool2d"(%[[a0]]) {kernel_shape = dense<[4, 2]> : tensor<2xi64>} : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x1x13x9x!FHE.eint<7>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> { + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> + return %0 : tensor<1x1x13x9x!FHE.eint<7>> +} diff --git a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml index 99fb81bc1..dc71b7feb 100644 --- a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml +++ b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml @@ -1,3 +1,977 @@ +description: max_eint_unsigned +program: | + func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> { + %0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4> + return %0 : !FHE.eint<4> + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + + - inputs: + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 1 + + - inputs: + - scalar: 0 + - scalar: 2 + outputs: + - scalar: 2 + + - inputs: + - scalar: 0 + - scalar: 3 + outputs: + - scalar: 3 + + - inputs: + - scalar: 0 + - scalar: 4 + outputs: + - scalar: 4 + + - inputs: + - scalar: 0 + - scalar: 5 + outputs: + - scalar: 5 + + - inputs: + - scalar: 0 + - scalar: 6 + outputs: + - 
scalar: 6 + + - inputs: + - scalar: 0 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 1 + + - inputs: + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 1 + + - inputs: + - scalar: 1 + - scalar: 2 + outputs: + - scalar: 2 + + - inputs: + - scalar: 1 + - scalar: 3 + outputs: + - scalar: 3 + + - inputs: + - scalar: 1 + - scalar: 4 + outputs: + - scalar: 4 + + - inputs: + - scalar: 1 + - scalar: 5 + outputs: + - scalar: 5 + + - inputs: + - scalar: 1 + - scalar: 6 + outputs: + - scalar: 6 + + - inputs: + - scalar: 1 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 2 + - scalar: 0 + outputs: + - scalar: 2 + + - inputs: + - scalar: 2 + - scalar: 1 + outputs: + - scalar: 2 + + - inputs: + - scalar: 2 + - scalar: 2 + outputs: + - scalar: 2 + + - inputs: + - scalar: 2 + - scalar: 3 + outputs: + - scalar: 3 + + - inputs: + - scalar: 2 + - scalar: 4 + outputs: + - scalar: 4 + + - inputs: + - scalar: 2 + - scalar: 5 + outputs: + - scalar: 5 + + - inputs: + - scalar: 2 + - scalar: 6 + outputs: + - scalar: 6 + + - inputs: + - scalar: 2 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 3 + - scalar: 0 + outputs: + - scalar: 3 + + - inputs: + - scalar: 3 + - scalar: 1 + outputs: + - scalar: 3 + + - inputs: + - scalar: 3 + - scalar: 2 + outputs: + - scalar: 3 + + - inputs: + - scalar: 3 + - scalar: 3 + outputs: + - scalar: 3 + + - inputs: + - scalar: 3 + - scalar: 4 + outputs: + - scalar: 4 + + - inputs: + - scalar: 3 + - scalar: 5 + outputs: + - scalar: 5 + + - inputs: + - scalar: 3 + - scalar: 6 + outputs: + - scalar: 6 + + - inputs: + - scalar: 3 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 4 + - scalar: 0 + outputs: + - scalar: 4 + + - inputs: + - scalar: 4 + - scalar: 1 + outputs: + - scalar: 4 + + - inputs: + - scalar: 4 + - scalar: 2 + outputs: + - scalar: 4 + + - inputs: + - scalar: 4 + - scalar: 3 + outputs: + - scalar: 4 + + - inputs: + - scalar: 4 + - scalar: 4 + 
outputs: + - scalar: 4 + + - inputs: + - scalar: 4 + - scalar: 5 + outputs: + - scalar: 5 + + - inputs: + - scalar: 4 + - scalar: 6 + outputs: + - scalar: 6 + + - inputs: + - scalar: 4 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 5 + - scalar: 0 + outputs: + - scalar: 5 + + - inputs: + - scalar: 5 + - scalar: 1 + outputs: + - scalar: 5 + + - inputs: + - scalar: 5 + - scalar: 2 + outputs: + - scalar: 5 + + - inputs: + - scalar: 5 + - scalar: 3 + outputs: + - scalar: 5 + + - inputs: + - scalar: 5 + - scalar: 4 + outputs: + - scalar: 5 + + - inputs: + - scalar: 5 + - scalar: 5 + outputs: + - scalar: 5 + + - inputs: + - scalar: 5 + - scalar: 6 + outputs: + - scalar: 6 + + - inputs: + - scalar: 5 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 6 + - scalar: 0 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 1 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 2 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 3 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 4 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 5 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 6 + outputs: + - scalar: 6 + + - inputs: + - scalar: 6 + - scalar: 7 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 0 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 1 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 2 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 3 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 4 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 5 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 6 + outputs: + - scalar: 7 + + - inputs: + - scalar: 7 + - scalar: 7 + outputs: + - scalar: 7 +--- +description: max_eint_signed +program: | + func.func @main(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> { + %0 = "FHE.max_eint"(%arg0, %arg1) : 
(!FHE.esint<4>, !FHE.esint<4>) -> !FHE.esint<4> + return %0 : !FHE.esint<4> + } +tests: + - inputs: + - scalar: -4 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: -4 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: -3 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: -2 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: -4 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: -3 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: -3 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: -2 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: -3 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: -4 + signed: true + outputs: + - 
scalar: -2 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: -2 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: -2 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: -2 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: -1 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: -1 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: 0 + signed: 
true + - scalar: -2 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 0 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 0 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 1 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 1 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: 2 + signed: true + + - 
inputs: + - scalar: 2 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 2 + signed: true + + - inputs: + - scalar: 2 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: -4 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: -3 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: -2 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: -1 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: 0 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: 1 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: 2 + signed: true + outputs: + - scalar: 3 + signed: true + + - inputs: + - scalar: 3 + signed: true + - scalar: 3 + signed: true + outputs: + - scalar: 3 + signed: true +--- # TODO: Rewrite/Remove # The FHE.neg_eint op doesn't come with a well defined semantics as FHE.eint # has an undefined behavior for under/overflow. 
diff --git a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml index ed08c2e1b..634b74c8b 100644 --- a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml +++ b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml @@ -380,3 +380,64 @@ tests: 20, 32, 44, 20, 32, 44, 20, 32, 44, ] shape: [1, 3, 3, 3] +--- +description: maxpool2d_unsigned_1x1x8x10_kernel_3x2 +program: | + func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> { + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[3, 2]> : tensor<2xi64> } + : (tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> + return %0 : tensor<1x1x6x9x!FHE.eint<5>> + } +tests: + - inputs: + - tensor: [ + 3, 3, 11, 9, 0, 3, 14, 10, 5, 6, + 10, 6, 4, 1, 10, 9, 11, 4, 0, 9, + 8, 4, 10, 12, 11, 10, 9, 3, 10, 2, + 8, 0, 11, 7, 5, 10, 8, 13, 9, 9, + 9, 1, 15, 0, 6, 8, 6, 6, 6, 3, + 6, 9, 10, 6, 0, 9, 13, 12, 6, 9, + 12, 13, 7, 15, 7, 1, 9, 3, 13, 6, + 2, 11, 14, 8, 11, 1, 11, 0, 0, 15, + ] + shape: [1, 1, 8, 10] + outputs: + - tensor: [ + 10, 11, 12, 12, 11, 14, 14, 10, 10, + 10, 11, 12, 12, 11, 11, 13, 13, 10, + 9, 15, 15, 12, 11, 10, 13, 13, 10, + 9, 15, 15, 7, 10, 13, 13, 13, 9, + 13, 15, 15, 15, 9, 13, 13, 13, 13, + 13, 14, 15, 15, 11, 13, 13, 13, 15, + ] + shape: [1, 1, 6, 9] +--- +description: maxpool2d_signed_1x1x6x5_kernel_2x3 +program: | + func.func @main(%arg0: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> { + %0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[2, 3]> : tensor<2xi64> } + : (tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> + return %0 : tensor<1x1x5x3x!FHE.esint<6>> + } +tests: + - inputs: + - tensor: [ + -8, -12, -8, -12, -10, + 1, -9, -15, -16, 14, + 9, 14, 2, 2, -15, + 9, -12, 0, -4, -5, + -7, -11, -15, -4, 6, + 15, -3, 7, -13, -13, + ] + shape: [1, 1, 6, 5] + signed: true + outputs: + - tensor: [ 
+ 1, -8, 14, + 14, 14, 14, + 14, 14, 2, + 9, 0, 6, + 15, 7, 7, + ] + shape: [1, 1, 5, 3] + signed: true