mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: implement maxpool2d operation
This commit is contained in:
@@ -323,6 +323,41 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
|
||||
let summary = "Get maximum of two encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
Get maximum of two encrypted integers using the formula, 'max(x, y) == max(x - y, 0) + y'.
|
||||
Type of inputs and the output should be the same.
|
||||
|
||||
If `x - y`` inside the max overflows or underflows, the behavior is undefined.
|
||||
So to support the full range, you should increase the bit-width by 1 manually.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.esint<3>, !FHE.esint<3>) -> !FHE.esint<3>
|
||||
|
||||
// error
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<3>) -> !FHE.eint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.esint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.esint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$x, FHE_AnyEncryptedInteger:$y);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$x, "Value":$y), [{
|
||||
build($_builder, $_state, x.getType(), x, y);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
|
||||
let summary = "Cast an unsigned integer to a signed one";
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@ add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_subdirectory(BigInt)
|
||||
add_subdirectory(Boolean)
|
||||
add_subdirectory(Max)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Max.td)
|
||||
mlir_tablegen(Max.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEMaxPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEMaxPassIncGen)
|
||||
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_MAX_PASS_H
|
||||
#define CONCRETELANG_FHE_MAX_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHEMaxTransformPass();
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_FHE_MAX_PASS
|
||||
#define CONCRETELANG_FHE_MAX_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHEMaxTransform : Pass<"fhe-max-transform"> {
|
||||
let summary = "Transform max operation to basic operations";
|
||||
let constructor = "mlir::concretelang::createFHEMaxTransformPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -944,6 +944,18 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> {
|
||||
let summary = "Returns the 2D maxpool of a tensor in the form NCHW";
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input,
|
||||
I64ElementsAttr:$kernel_shape,
|
||||
OptionalAttr<I64ElementsAttr>:$strides,
|
||||
OptionalAttr<I64ElementsAttr>:$dilations
|
||||
);
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
|
||||
let summary = "Returns a tensor that contains the transposition of the input tensor.";
|
||||
|
||||
|
||||
@@ -1771,6 +1771,81 @@ struct FHELinalgConv2dToLinalgConv2d
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms all instances
|
||||
/// of `FHELinalg.maxpool2d` to `linalg.pooling_ncw_max`.
|
||||
struct FHELinalgMaxpool2dToLinalgMaxpool2d
|
||||
: public mlir::OpRewritePattern<FHELinalg::Maxpool2dOp> {
|
||||
|
||||
FHELinalgMaxpool2dToLinalgMaxpool2d(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<FHELinalg::Maxpool2dOp>(context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHELinalg::Maxpool2dOp maxpool2dOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
const mlir::Location loc = maxpool2dOp->getLoc();
|
||||
|
||||
const mlir::NamedAttribute maxOpAttr = rewriter.getNamedAttr(
|
||||
"max_signed",
|
||||
rewriter.getStringAttr(FHE::MaxEintOp::getOperationName()));
|
||||
|
||||
const auto outputTy =
|
||||
maxpool2dOp->getResult(0).getType().cast<mlir::RankedTensorType>();
|
||||
const auto outputElementTy =
|
||||
outputTy.getElementType().cast<FHE::FheIntegerInterface>();
|
||||
|
||||
mlir::Value output =
|
||||
rewriter.create<FHE::ZeroTensorOp>(loc, outputTy).getResult();
|
||||
|
||||
if (outputElementTy.isSigned()) {
|
||||
const int64_t outputBitWidth = outputElementTy.getWidth();
|
||||
const int64_t offsetValue = 1 << (outputBitWidth - 2);
|
||||
|
||||
const mlir::Type offsetType =
|
||||
mlir::IntegerType::get(this->getContext(), outputBitWidth + 1);
|
||||
const mlir::Type offsetTensorType =
|
||||
mlir::RankedTensorType::get({1}, offsetType);
|
||||
|
||||
const llvm::SmallVector<mlir::Attribute> offsetTensorAttr = {
|
||||
mlir::IntegerAttr::get(offsetType, offsetValue)};
|
||||
const mlir::Attribute offsetAttr =
|
||||
mlir::DenseElementsAttr::get(offsetTensorType, offsetTensorAttr);
|
||||
|
||||
const mlir::Value offset =
|
||||
rewriter.create<mlir::arith::ConstantOp>(loc, offsetAttr);
|
||||
|
||||
output = rewriter.create<FHELinalg::SubEintIntOp>(loc, output, offset);
|
||||
}
|
||||
|
||||
const mlir::DenseElementsAttr kernelShapeAttr = maxpool2dOp.kernel_shape();
|
||||
const auto kernelShape =
|
||||
llvm::SmallVector<int64_t, 2>(kernelShapeAttr.value_begin<int64_t>(),
|
||||
kernelShapeAttr.value_end<int64_t>());
|
||||
|
||||
const mlir::Value kernel =
|
||||
rewriter
|
||||
.create<mlir::linalg::InitTensorOp>(
|
||||
loc, kernelShape,
|
||||
mlir::IntegerType::get(this->getContext(), 64))
|
||||
.getResult();
|
||||
|
||||
const mlir::DenseIntElementsAttr defaultAttr =
|
||||
rewriter.getI64VectorAttr({1, 1});
|
||||
|
||||
const mlir::DenseIntElementsAttr stridesAttr =
|
||||
maxpool2dOp.dilations().getValueOr(defaultAttr);
|
||||
const mlir::DenseIntElementsAttr dilationsAttr =
|
||||
maxpool2dOp.dilations().getValueOr(defaultAttr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::linalg::PoolingNchwMaxOp>(
|
||||
maxpool2dOp, outputTy, mlir::ValueRange{maxpool2dOp.input(), kernel},
|
||||
output, stridesAttr, dilationsAttr,
|
||||
llvm::ArrayRef<mlir::NamedAttribute>({maxOpAttr}));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This template rewrite pattern transforms any instance of
|
||||
/// operators `FHELinalg.to_signed` to an instance of `linalg.generic` with an
|
||||
/// appropriate region using `FHE.to_signed` operation, an appropriate
|
||||
@@ -2019,6 +2094,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
|
||||
patterns.insert<SumToLinalgGeneric>(&getContext());
|
||||
patterns.insert<ConcatRewritePattern>(&getContext());
|
||||
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
|
||||
patterns.insert<FHELinalgMaxpool2dToLinalgMaxpool2d>(&getContext());
|
||||
patterns.insert<TransposeToLinalgGeneric>(&getContext());
|
||||
patterns.insert<FromElementToTensorFromElements>(&getContext());
|
||||
patterns.insert<FHELinalgToSignedToLinalgGeneric>(&getContext());
|
||||
|
||||
@@ -77,6 +77,10 @@ struct FunctionToDag {
|
||||
for (auto &bb : func.getBody().getBlocks()) {
|
||||
for (auto &op : bb.getOperations()) {
|
||||
addOperation(dag, op);
|
||||
}
|
||||
}
|
||||
for (auto &bb : func.getBody().getBlocks()) {
|
||||
for (auto &op : bb.getOperations()) {
|
||||
op.removeAttr("SMANP");
|
||||
}
|
||||
}
|
||||
@@ -145,6 +149,14 @@ struct FunctionToDag {
|
||||
// If can't find weights return default leveled op
|
||||
DEBUG("Replace Dot by LevelledOp on " << op);
|
||||
}
|
||||
if (auto max = asMax(op)) {
|
||||
addMax(dag, max, encrypted_inputs, precision);
|
||||
return;
|
||||
}
|
||||
if (auto maxpool2d = asMaxpool2d(op)) {
|
||||
addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision);
|
||||
return;
|
||||
}
|
||||
// default
|
||||
addLevelledOp(dag, op, encrypted_inputs);
|
||||
}
|
||||
@@ -207,6 +219,103 @@ struct FunctionToDag {
|
||||
manp, slice(out_shape), comment);
|
||||
}
|
||||
|
||||
void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
|
||||
int precision) {
|
||||
mlir::Value result = maxOp.getResult();
|
||||
const std::vector<uint64_t> resultShape = getShape(result);
|
||||
|
||||
Operation *xOp = maxOp.x().getDefiningOp();
|
||||
Operation *yOp = maxOp.y().getDefiningOp();
|
||||
|
||||
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
|
||||
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
|
||||
|
||||
llvm::APInt xSmanp = llvm::APInt{1, 1, false};
|
||||
if (xOp != nullptr) {
|
||||
const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
|
||||
assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
|
||||
xSmanp = xSmanpAttr.getValue();
|
||||
}
|
||||
|
||||
llvm::APInt ySmanp = llvm::APInt{1, 1, false};
|
||||
if (yOp != nullptr) {
|
||||
const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
|
||||
assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
|
||||
ySmanp = ySmanpAttr.getValue();
|
||||
}
|
||||
|
||||
const double subManp =
|
||||
sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());
|
||||
|
||||
auto loc = loc_to_string(maxOp.getLoc());
|
||||
auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc;
|
||||
|
||||
auto subNode =
|
||||
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
|
||||
subManp, slice(resultShape), comment);
|
||||
|
||||
const double tluNodeManp = 1;
|
||||
const std::vector<std::uint64_t> unknownFunction;
|
||||
auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision);
|
||||
|
||||
const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble());
|
||||
const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
|
||||
tluNode, inputs[1]};
|
||||
index[result] =
|
||||
dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost,
|
||||
addManp, slice(resultShape), comment);
|
||||
}
|
||||
|
||||
void addMaxpool2d(optimizer::Dag &dag, FHELinalg::Maxpool2dOp &maxpool2dOp,
|
||||
Inputs &inputs, int precision) {
|
||||
mlir::Value result = maxpool2dOp.getResult();
|
||||
const std::vector<uint64_t> resultShape = getShape(result);
|
||||
|
||||
// all TLUs are flattened into a dimension
|
||||
// to create a single TLU node in optimizer dag
|
||||
std::vector<uint64_t> fakeShape = resultShape;
|
||||
|
||||
uint64_t numberOfComparisons = 1;
|
||||
for (auto dimensionSize : maxpool2dOp.kernel_shape().getValues<int64_t>()) {
|
||||
numberOfComparisons *= dimensionSize;
|
||||
}
|
||||
fakeShape.push_back(numberOfComparisons);
|
||||
|
||||
Operation *inputOp = maxpool2dOp.input().getDefiningOp();
|
||||
|
||||
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
|
||||
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
|
||||
|
||||
llvm::APInt inputSmanp = llvm::APInt{1, 1, false};
|
||||
if (inputOp != nullptr) {
|
||||
const auto inputSmanpAttr =
|
||||
inputOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
|
||||
assert(inputSmanpAttr && "Missing SMANP value on a crypto operation");
|
||||
inputSmanp = inputSmanpAttr.getValue();
|
||||
}
|
||||
|
||||
const double subManp = sqrt(2 * inputSmanp.roundToDouble() + 1);
|
||||
|
||||
auto loc = loc_to_string(maxpool2dOp.getLoc());
|
||||
auto comment =
|
||||
std::string(maxpool2dOp->getName().getStringRef()) + " " + loc;
|
||||
|
||||
auto subNode =
|
||||
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
|
||||
subManp, slice(fakeShape), comment);
|
||||
|
||||
const std::vector<std::uint64_t> unknownFunction;
|
||||
auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision);
|
||||
|
||||
const double addManp = sqrt(inputSmanp.roundToDouble() + 1);
|
||||
const std::vector<concrete_optimizer::dag::OperatorIndex> addInputs = {
|
||||
tluNode, inputs[1]};
|
||||
|
||||
index[result] =
|
||||
dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost,
|
||||
addManp, slice(resultShape), comment);
|
||||
}
|
||||
|
||||
Inputs encryptedInputs(mlir::Operation &op) {
|
||||
Inputs inputs;
|
||||
for (auto operand : op.getOperands()) {
|
||||
@@ -237,6 +346,14 @@ struct FunctionToDag {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
|
||||
}
|
||||
|
||||
mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op);
|
||||
}
|
||||
|
||||
mlir::concretelang::FHELinalg::Maxpool2dOp asMaxpool2d(mlir::Operation &op) {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(op);
|
||||
}
|
||||
|
||||
bool isReturn(mlir::Operation &op) {
|
||||
return llvm::isa<mlir::func::ReturnOp>(op);
|
||||
}
|
||||
|
||||
@@ -510,6 +510,30 @@ static llvm::APInt getSqMANP(
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.max_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::MaxEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
// max(x, y) = max(x - y, 0) + y
|
||||
|
||||
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
|
||||
const llvm::APInt sub = APIntWidthExtendUAdd(x, y);
|
||||
const llvm::APInt tlu = {1, 1, false};
|
||||
const llvm::APInt add = APIntWidthExtendUAdd(tlu, y);
|
||||
|
||||
// this is not optimal as it can increase the resulting noise unnecessarily
|
||||
return APIntUMax(add, sub);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
/// `FHELinalg.add_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -1153,6 +1177,28 @@ static llvm::APInt getSqMANP(
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::Maxpool2dOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
// maximum between two value is calculated using
|
||||
// - max(x - y, 0) + y
|
||||
|
||||
// max is calculated with a TLU so MANP is {1, 1, false}
|
||||
// y on the other hand comes from the input or from the previous result
|
||||
|
||||
// in the current implementation, it's the input
|
||||
// so the resulting MANP is `{1, 1, false} + MANP input`
|
||||
|
||||
const llvm::APInt tlu = {1, 1, false};
|
||||
const llvm::APInt input = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, input);
|
||||
const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, input);
|
||||
|
||||
return APIntUMax(forIntermediate, forResult);
|
||||
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
@@ -1202,6 +1248,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto roundOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::RoundEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(roundOp, operands);
|
||||
} else if (auto maxEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(maxEintOp, operands);
|
||||
} else if (llvm::isa<mlir::concretelang::FHE::ZeroEintOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::ToBoolOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::FromBoolOp>(op) ||
|
||||
@@ -1264,6 +1313,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(conv2dOp, operands);
|
||||
} else if (auto maxpool2dOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(maxpool2dOp, operands);
|
||||
} else if (auto fromElementOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::FromElementOp>(
|
||||
op)) {
|
||||
|
||||
@@ -192,6 +192,24 @@ mlir::LogicalResult MulEintOp::verify() {
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult MaxEintOp::verify() {
|
||||
auto xTy = this->x().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto yTy = this->y().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto outTy = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(),
|
||||
xTy, outTy)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), xTy,
|
||||
yTy)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult ToSignedOp::verify() {
|
||||
auto input = this->input().getType().cast<EncryptedIntegerType>();
|
||||
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();
|
||||
|
||||
@@ -2,6 +2,7 @@ add_mlir_library(
|
||||
FHEDialectTransforms
|
||||
BigInt.cpp
|
||||
Boolean.cpp
|
||||
Max.cpp
|
||||
EncryptedMulToDoubleTLU.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
111
compiler/lib/Dialect/FHE/Transforms/Max.cpp
Normal file
111
compiler/lib/Dialect/FHE/Transforms/Max.cpp
Normal file
@@ -0,0 +1,111 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/Transforms/Max/Max.h"
|
||||
|
||||
namespace arith = mlir::arith;
|
||||
namespace func = mlir::func;
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
|
||||
/// This rewrite pattern transforms all instances
|
||||
/// of `FHE.max_eint` to `max(x - y, 0) + y`.
|
||||
struct MaxEintPattern : public mlir::OpRewritePattern<FHE::MaxEintOp> {
|
||||
MaxEintPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<FHE::MaxEintOp>(context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::MaxEintOp maxEintOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
const mlir::Location loc = maxEintOp->getLoc();
|
||||
|
||||
const FHE::FheIntegerInterface outputTy =
|
||||
maxEintOp->getResult(0).getType().cast<FHE::FheIntegerInterface>();
|
||||
const int64_t outputBitWidth = outputTy.getWidth();
|
||||
|
||||
mlir::Value x = maxEintOp.x();
|
||||
mlir::Value y = maxEintOp.y();
|
||||
|
||||
const auto xTy = x.getType().cast<FHE::FheIntegerInterface>();
|
||||
const auto yTy = y.getType().cast<FHE::FheIntegerInterface>();
|
||||
|
||||
const auto signedTy = FHE::EncryptedSignedIntegerType::get(
|
||||
this->getContext(), outputBitWidth);
|
||||
|
||||
if (xTy.isUnsigned()) {
|
||||
x = rewriter.create<FHE::ToSignedOp>(loc, signedTy, x).getResult();
|
||||
}
|
||||
if (yTy.isUnsigned()) {
|
||||
y = rewriter.create<FHE::ToSignedOp>(loc, signedTy, y).getResult();
|
||||
}
|
||||
|
||||
const mlir::Value sub =
|
||||
rewriter.create<FHE::SubEintOp>(loc, x, y).getResult();
|
||||
|
||||
const int64_t lutSize = 1 << outputBitWidth;
|
||||
|
||||
auto lutValues = std::vector<int64_t>();
|
||||
for (int64_t i = 0; i < lutSize / 2; i++) {
|
||||
lutValues.push_back(i);
|
||||
}
|
||||
for (int64_t i = 0; i < lutSize / 2; i++) {
|
||||
lutValues.push_back(0);
|
||||
}
|
||||
|
||||
const mlir::Attribute lutAttr = rewriter.getI64TensorAttr(lutValues);
|
||||
const mlir::Value lut =
|
||||
rewriter.create<arith::ConstantOp>(loc, lutAttr).getResult();
|
||||
|
||||
const mlir::Value max =
|
||||
rewriter.create<FHE::ApplyLookupTableEintOp>(loc, outputTy, sub, lut)
|
||||
.getResult();
|
||||
|
||||
const mlir::Value add =
|
||||
rewriter.create<FHE::AddEintOp>(loc, max, maxEintOp.y()).getResult();
|
||||
|
||||
rewriter.replaceOp(maxEintOp, {add});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
struct FHEMaxTransform : public FHEMaxTransformBase<FHEMaxTransform> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void FHEMaxTransform::runOnOperation() {
|
||||
auto target = mlir::ConversionTarget(this->getContext());
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<FHE::FHEDialect>();
|
||||
target.addIllegalOp<FHE::MaxEintOp>();
|
||||
|
||||
auto patterns = mlir::RewritePatternSet(&this->getContext());
|
||||
patterns.insert<MaxEintPattern>(&this->getContext());
|
||||
|
||||
mlir::Operation *op = this->getOperation();
|
||||
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHEMaxTransformPass() {
|
||||
return std::make_unique<FHEMaxTransform>();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1048,6 +1048,154 @@ mlir::LogicalResult Conv2dOp::verify() {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult Maxpool2dOp::verify() {
|
||||
const mlir::RankedTensorType inputTy =
|
||||
this->input().getType().cast<mlir::RankedTensorType>();
|
||||
const mlir::RankedTensorType outputTy =
|
||||
this->getResult().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
const FHE::FheIntegerInterface inputElementTy = inputTy.getElementType();
|
||||
const FHE::FheIntegerInterface outputElementTy = outputTy.getElementType();
|
||||
|
||||
if (inputElementTy != outputElementTy) {
|
||||
this->emitOpError() << "expected output element type "
|
||||
<< "(" << outputElementTy << ") "
|
||||
<< "to be the same with input element type "
|
||||
<< "(" << inputElementTy << ") "
|
||||
<< "but it is not";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
const llvm::ArrayRef<int64_t> inputShape = inputTy.getShape();
|
||||
const llvm::ArrayRef<int64_t> outputShape = outputTy.getShape();
|
||||
|
||||
if (inputShape.size() != 4) {
|
||||
this->emitOpError() << "expected input to have 4 dimensions (N*C*H*W) "
|
||||
<< "but it has " << inputShape.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
if (outputShape.size() != 4) {
|
||||
this->emitOpError() << "expected output to have 4 dimensions (N*C*H*W) "
|
||||
<< "but it has " << outputShape.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
const int64_t inputN = inputShape[0];
|
||||
const int64_t inputC = inputShape[1];
|
||||
const int64_t inputH = inputShape[2];
|
||||
const int64_t inputW = inputShape[3];
|
||||
|
||||
const mlir::DenseIntElementsAttr kernelShapeAttr = this->kernel_shape();
|
||||
const mlir::RankedTensorType kernelShapeAttrTy =
|
||||
kernelShapeAttr.getType().cast<mlir::RankedTensorType>();
|
||||
const llvm::ArrayRef<int64_t> kernelShapeAttrShape =
|
||||
kernelShapeAttrTy.getShape();
|
||||
|
||||
if (kernelShapeAttrShape.size() != 1 || kernelShapeAttrShape[0] != 2) {
|
||||
this->emitOpError() << "expected kernel shape to be of shape "
|
||||
<< "(2) "
|
||||
<< "but it is of shape "
|
||||
<< "(" << kernelShapeAttrShape << ")";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
mlir::SmallVector<int64_t, 2> kernelShape;
|
||||
kernelShape.append(kernelShapeAttr.value_begin<int64_t>(),
|
||||
kernelShapeAttr.value_end<int64_t>());
|
||||
|
||||
const int64_t kernelShapeH = kernelShape[0];
|
||||
const int64_t kernelShapeW = kernelShape[1];
|
||||
|
||||
mlir::SmallVector<int64_t, 2> strides;
|
||||
const llvm::Optional<mlir::DenseIntElementsAttr> maybeStridesAttr =
|
||||
this->strides();
|
||||
if (maybeStridesAttr.hasValue()) {
|
||||
const mlir::DenseIntElementsAttr stridesAttr = maybeStridesAttr.getValue();
|
||||
const mlir::RankedTensorType stridesAttrTy =
|
||||
stridesAttr.getType().cast<mlir::RankedTensorType>();
|
||||
const llvm::ArrayRef<int64_t> stridesAttrShape = stridesAttrTy.getShape();
|
||||
|
||||
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
|
||||
this->emitOpError() << "expected strides to be of shape "
|
||||
<< "(2) "
|
||||
<< "but it is of shape "
|
||||
<< "(" << stridesAttrShape << ")";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
strides.append(stridesAttr.value_begin<int64_t>(),
|
||||
stridesAttr.value_end<int64_t>());
|
||||
} else {
|
||||
strides.append({1, 1});
|
||||
}
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
if (strides[i] < 1) {
|
||||
this->emitOpError() << "expected elements of strides to be positive "
|
||||
<< "but strides[" << i << "] is " << strides[i];
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t stridesH = strides[0];
|
||||
const int64_t stridesW = strides[1];
|
||||
|
||||
mlir::SmallVector<int64_t, 2> dilations;
|
||||
const llvm::Optional<mlir::DenseIntElementsAttr> maybeDilationsAttr =
|
||||
this->dilations();
|
||||
if (maybeDilationsAttr.hasValue()) {
|
||||
const mlir::DenseIntElementsAttr dilationsAttr =
|
||||
maybeDilationsAttr.getValue();
|
||||
const mlir::RankedTensorType dilationsAttrTy =
|
||||
dilationsAttr.getType().cast<mlir::RankedTensorType>();
|
||||
const llvm::ArrayRef<int64_t> dilationsAttrShape =
|
||||
dilationsAttrTy.getShape();
|
||||
|
||||
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
|
||||
this->emitOpError() << "expected dilations to be of shape "
|
||||
<< "(2) "
|
||||
<< "but it is of shape "
|
||||
<< "(" << dilationsAttrShape << ")";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
dilations.append(dilationsAttr.value_begin<int64_t>(),
|
||||
dilationsAttr.value_end<int64_t>());
|
||||
} else {
|
||||
dilations.append({1, 1});
|
||||
}
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
if (dilations[i] < 1) {
|
||||
this->emitOpError() << "expected elements of dilations to be positive "
|
||||
<< "but dilations[" << i << "] is " << dilations[i];
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t dilationsH = dilations[0];
|
||||
const int64_t dilationsW = dilations[1];
|
||||
|
||||
const int64_t expectedOutputH =
|
||||
floor((inputH - dilationsH * (kernelShapeH - 1) - 1) / stridesH) + 1;
|
||||
const int64_t expectedOutputW =
|
||||
floor((inputW - dilationsW * (kernelShapeW - 1) - 1) / stridesW) + 1;
|
||||
const mlir::SmallVector<int64_t, 4> expectedOutputShape = {
|
||||
inputN,
|
||||
inputC,
|
||||
expectedOutputH,
|
||||
expectedOutputW,
|
||||
};
|
||||
|
||||
if (outputShape != llvm::makeArrayRef(expectedOutputShape)) {
|
||||
this->emitOpError() << "expected output to be of shape "
|
||||
<< "(" << expectedOutputShape << ") "
|
||||
<< "but it is of shape "
|
||||
<< "(" << outputShape << ")";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult FromElementOp::verify() {
|
||||
mlir::Value in = this->getOperand();
|
||||
mlir::Value out = this->getResult();
|
||||
|
||||
@@ -293,13 +293,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return errorDiag("Transforming FHE boolean ops failed");
|
||||
}
|
||||
|
||||
// Encrypted mul rewriting
|
||||
if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext,
|
||||
module, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Rewriting of encrypted mul failed");
|
||||
}
|
||||
|
||||
if (options.chunkIntegers) {
|
||||
if (mlir::concretelang::pipeline::transformFHEBigInt(
|
||||
mlirContext, module, enablePass, options.chunkSize,
|
||||
@@ -344,6 +337,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from FHELinalg to FHE failed");
|
||||
}
|
||||
|
||||
if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext,
|
||||
module, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Rewriting of high level fhe ops failed");
|
||||
}
|
||||
|
||||
if (target == Target::FHE_NO_LINALG)
|
||||
return std::move(res);
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h>
|
||||
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
|
||||
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <concretelang/Support/Pipeline.h>
|
||||
@@ -181,7 +182,9 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("transformHighLevelFHEOps", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, createFHEMaxTransformPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg %s 2>&1 | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x6x9x!FHE.eint<5>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = linalg.init_tensor [3, 2] : tensor<3x2xi64>
|
||||
// CHECK-NEXT: %[[v2:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %1 : tensor<1x1x8x10x!FHE.eint<5>>, tensor<3x2xi64>) outs(%0 : tensor<1x1x6x9x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<1x1x6x9x!FHE.eint<5>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[3, 2]> : tensor<2xi64> } : (tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
|
||||
return %0 : tensor<1x1x6x9x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (0)>
|
||||
|
||||
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7>
|
||||
// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6>
|
||||
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.esint<6>
|
||||
// CHECK-NEXT: } -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: %[[v4:.*]] = linalg.init_tensor [2, 3] : tensor<2x3xi64>
|
||||
// CHECK-NEXT: %[[v5:.*]] = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, max_signed = "FHE.max_eint", strides = dense<1> : vector<2xi64>} ins(%arg0, %[[v4]] : tensor<1x1x6x5x!FHE.esint<6>>, tensor<2x3xi64>) outs(%[[v3]] : tensor<1x1x5x3x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: return %[[v5]] : tensor<1x1x5x3x!FHE.esint<6>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[2, 3]> : tensor<2xi64> } : (tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
return %0 : tensor<1x1x5x3x!FHE.esint<6>>
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-max-transform %s 2>&1 | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%[[a0:.*]]: !FHE.eint<5>, %[[a1:.*]]: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64>
|
||||
// CHECK-NEXT: %[[v1:.*]] = "FHE.to_signed"(%[[a0]]) : (!FHE.eint<5>) -> !FHE.esint<5>
|
||||
// CHECK-NEXT: %[[v2:.*]] = "FHE.to_signed"(%[[a1]]) : (!FHE.eint<5>) -> !FHE.esint<5>
|
||||
// CHECK-NEXT: %[[v3:.*]] = "FHE.sub_eint"(%[[v1]], %[[v2]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
|
||||
// CHECK-NEXT: %[[v4:.*]] = "FHE.apply_lookup_table"(%[[v3]], %[[v0]]) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<5>
|
||||
// CHECK-NEXT: %[[v5:.*]] = "FHE.add_eint"(%[[v4]], %[[a1]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
// CHECK-NEXT: return %[[v5]] : !FHE.eint<5>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
return %0 : !FHE.eint<5>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%[[a0:.*]]: !FHE.esint<5>, %[[a1:.*]]: !FHE.esint<5>) -> !FHE.esint<5> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64>
|
||||
// CHECK-NEXT: %[[v1:.*]] = "FHE.sub_eint"(%[[a0]], %[[a1]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
|
||||
// CHECK-NEXT: %[[v2:.*]] = "FHE.apply_lookup_table"(%[[v1]], %[[v0]]) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<5>
|
||||
// CHECK-NEXT: %[[v3:.*]] = "FHE.add_eint"(%[[v2]], %[[a1]]) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
|
||||
// CHECK-NEXT: return %[[v3:.*]] : !FHE.esint<5>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: !FHE.esint<5>, %arg1: !FHE.esint<5>) -> !FHE.esint<5> {
|
||||
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<5>, !FHE.esint<5>) -> !FHE.esint<5>
|
||||
return %0 : !FHE.esint<5>
|
||||
}
|
||||
@@ -1,15 +1,15 @@
|
||||
// RUN: concretecompiler --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-tfhe --passes EncryptedMulToDoubleTLU --split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %0 = "FHE.add_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst_0) {MANP = 1 : ui1} : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %2 = "FHE.sub_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %3 = "FHE.to_signed"(%2) {MANP = 2 : ui3} : (!FHE.eint<3>) -> !FHE.esint<3>
|
||||
// CHECK-NEXT: %4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %5 = "FHE.sub_eint"(%1, %4) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: return %5 : !FHE.eint<3>
|
||||
// CHECK: func.func @simple_eint(%[[a0:.*]]: !FHE.eint<3>, %[[a1:.*]]: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.add_eint"(%[[a0]], %[[a1]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %[[v2:.*]] = "FHE.apply_lookup_table"(%[[v0]], %[[v1]]) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %[[v3:.*]] = "FHE.sub_eint"(%[[a0]], %[[a1]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %[[v4:.*]] = "FHE.to_signed"(%[[v3]]) : (!FHE.eint<3>) -> !FHE.esint<3>
|
||||
// CHECK-NEXT: %[[v5:.*]] = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %[[v6:.*]] = "FHE.apply_lookup_table"(%[[v4]], %[[v5]]) : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %[[v7:.*]] = "FHE.sub_eint"(%[[v2]], %[[v6]]) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: return %[[v7]] : !FHE.eint<3>
|
||||
// CHECK-NEXT: }
|
||||
func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)
|
||||
|
||||
33
compiler/tests/check_tests/Dialect/FHE/max_eint.invalid.mlir
Normal file
33
compiler/tests/check_tests/Dialect/FHE/max_eint.invalid.mlir
Normal file
@@ -0,0 +1,33 @@
|
||||
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.max_eint' op should have the width of encrypted inputs equal
|
||||
func.func @bad_inputs_width(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<3>) -> !FHE.eint<5> {
|
||||
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<3>) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.max_eint' op should have the signedness of encrypted inputs equal
|
||||
func.func @bad_inputs_signedness(%arg0: !FHE.eint<5>, %arg1: !FHE.esint<5>) -> !FHE.eint<5> {
|
||||
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.esint<5>) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.max_eint' op should have the width of encrypted inputs and result equal
|
||||
func.func @bad_result_width(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.max_eint' op should have the signedness of encrypted inputs and result equal
|
||||
func.func @bad_result_signedness(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.esint<5> {
|
||||
%1 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.esint<5>)
|
||||
return %1: !FHE.esint<5>
|
||||
}
|
||||
@@ -223,6 +223,24 @@ func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @max_eint(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4>
|
||||
func.func @max_eint(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
%0 = "FHE.max_eint"(%arg0, %arg1): (!FHE.eint<4>, !FHE.eint<4>) -> (!FHE.eint<4>)
|
||||
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<4>
|
||||
return %0: !FHE.eint<4>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @max_esint(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4>
|
||||
func.func @max_esint(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<4>, !FHE.esint<4>) -> !FHE.esint<4>
|
||||
%0 = "FHE.max_eint"(%arg0, %arg1): (!FHE.esint<4>, !FHE.esint<4>) -> (!FHE.esint<4>)
|
||||
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.esint<4>
|
||||
return %0: !FHE.esint<4>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool
|
||||
func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.to_bool"(%arg0) : (!FHE.eint<1>) -> !FHE.ebool
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
|
||||
|
||||
// -----
|
||||
|
||||
func.func @different_input_and_output_bit_widths(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<5>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected output element type ('!FHE.eint<5>') to be the same with input element type ('!FHE.eint<7>') but it is not}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<5>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_input_dimensions(%arg0: tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected input to have 4 dimensions (N*C*H*W) but it has 2}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_output_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected output to have 4 dimensions (N*C*H*W) but it has 3}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_kernel_shape_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected kernel shape to be of shape (2) but it is of shape (3)}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2, 3]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_strides_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected strides to be of shape (2) but it is of shape (3)}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_strides_values(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected elements of strides to be positive but strides[0] is -1}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, strides = dense<[-1, 1]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_dilations_dimensions(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected dilations to be of shape (2) but it is of shape (3)}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, dilations = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_dilations_values(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected elements of dilations to be positive but dilations[1] is -1}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64>, dilations = dense<[1, -1]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_output_shape(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x10x5x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.maxpool2d' op expected output to be of shape (1, 1, 13, 9) but it is of shape (1, 1, 10, 5)}}
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x10x5x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x10x5x!FHE.eint<7>>
|
||||
}
|
||||
12
compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.mlir
Normal file
12
compiler/tests/check_tests/Dialect/FHELinalg/maxpool2d.mlir
Normal file
@@ -0,0 +1,12 @@
|
||||
// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.maxpool2d"(%[[a0]]) {kernel_shape = dense<[4, 2]> : tensor<2xi64>} : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[4, 2]> : tensor<2xi64> } : (tensor<1x1x16x10x!FHE.eint<7>>) -> tensor<1x1x13x9x!FHE.eint<7>>
|
||||
return %0 : tensor<1x1x13x9x!FHE.eint<7>>
|
||||
}
|
||||
@@ -1,3 +1,977 @@
|
||||
description: max_eint_unsigned
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
return %0 : !FHE.eint<4>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 2
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 2
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 2
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 2
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 2
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 3
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 4
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 5
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 6
|
||||
|
||||
- inputs:
|
||||
- scalar: 6
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 3
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 7
|
||||
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 7
|
||||
---
|
||||
description: max_eint_signed
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.esint<4>, %arg1: !FHE.esint<4>) -> !FHE.esint<4> {
|
||||
%0 = "FHE.max_eint"(%arg0, %arg1) : (!FHE.esint<4>, !FHE.esint<4>) -> !FHE.esint<4>
|
||||
return %0 : !FHE.esint<4>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -4
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -3
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -2
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: -1
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: -4
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: -3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: -2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: -1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: 0
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: 1
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: 2
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
- scalar: 3
|
||||
signed: true
|
||||
outputs:
|
||||
- scalar: 3
|
||||
signed: true
|
||||
---
|
||||
# TODO: Rewrite/Remove
|
||||
# The FHE.neg_eint op doesn't come with a well defined semantics as FHE.eint
|
||||
# has an undefined behavior for under/overflow.
|
||||
|
||||
@@ -380,3 +380,64 @@ tests:
|
||||
20, 32, 44, 20, 32, 44, 20, 32, 44,
|
||||
]
|
||||
shape: [1, 3, 3, 3]
|
||||
---
|
||||
description: maxpool2d_unsigned_1x1x8x10_kernel_3x2
|
||||
program: |
|
||||
func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>> {
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[3, 2]> : tensor<2xi64> }
|
||||
: (tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.eint<5>>
|
||||
return %0 : tensor<1x1x6x9x!FHE.eint<5>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [
|
||||
3, 3, 11, 9, 0, 3, 14, 10, 5, 6,
|
||||
10, 6, 4, 1, 10, 9, 11, 4, 0, 9,
|
||||
8, 4, 10, 12, 11, 10, 9, 3, 10, 2,
|
||||
8, 0, 11, 7, 5, 10, 8, 13, 9, 9,
|
||||
9, 1, 15, 0, 6, 8, 6, 6, 6, 3,
|
||||
6, 9, 10, 6, 0, 9, 13, 12, 6, 9,
|
||||
12, 13, 7, 15, 7, 1, 9, 3, 13, 6,
|
||||
2, 11, 14, 8, 11, 1, 11, 0, 0, 15,
|
||||
]
|
||||
shape: [1, 1, 8, 10]
|
||||
outputs:
|
||||
- tensor: [
|
||||
10, 11, 12, 12, 11, 14, 14, 10, 10,
|
||||
10, 11, 12, 12, 11, 11, 13, 13, 10,
|
||||
9, 15, 15, 12, 11, 10, 13, 13, 10,
|
||||
9, 15, 15, 7, 10, 13, 13, 13, 9,
|
||||
13, 15, 15, 15, 9, 13, 13, 13, 13,
|
||||
13, 14, 15, 15, 11, 13, 13, 13, 15,
|
||||
]
|
||||
shape: [1, 1, 6, 9]
|
||||
---
|
||||
description: maxpool2d_signed_1x1x6x5_kernel_2x3
|
||||
program: |
|
||||
func.func @main(%arg0: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
|
||||
%0 = "FHELinalg.maxpool2d"(%arg0) { kernel_shape = dense<[2, 3]> : tensor<2xi64> }
|
||||
: (tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>>
|
||||
return %0 : tensor<1x1x5x3x!FHE.esint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [
|
||||
-8, -12, -8, -12, -10,
|
||||
1, -9, -15, -16, 14,
|
||||
9, 14, 2, 2, -15,
|
||||
9, -12, 0, -4, -5,
|
||||
-7, -11, -15, -4, 6,
|
||||
15, -3, 7, -13, -13,
|
||||
]
|
||||
shape: [1, 1, 6, 5]
|
||||
signed: true
|
||||
outputs:
|
||||
- tensor: [
|
||||
1, -8, 14,
|
||||
14, 14, 14,
|
||||
14, 14, 2,
|
||||
9, 0, 6,
|
||||
15, 7, 7,
|
||||
]
|
||||
shape: [1, 1, 5, 3]
|
||||
signed: true
|
||||
|
||||
Reference in New Issue
Block a user