feat: lower and exec boolean ops

This commit is contained in:
youben11
2023-01-23 13:38:08 +01:00
committed by Ayoub Benaissa
parent ee00672996
commit 36f51ba0c2
23 changed files with 572 additions and 20 deletions

View File

@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -375,15 +375,17 @@ def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
Example:
```mlir
// ok
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> (!FHE.ebool)
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> (!FHE.ebool)
// error
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi4>) -> (!FHE.ebool)
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi1>) -> (!FHE.ebool)
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi64>) -> (!FHE.ebool)
```
}];
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[I1]>:$truth_table);
// The reason the truth table is of AnyInteger and not I1 is that in lowering passes, the truth_table is meant to be passed
// to an LUT operation which requires the table to be of type I64. Whenever lowering passes are no more restrictive, this
// can be set to I1 to reflect the boolean logic.
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[AnyInteger]>:$truth_table);
let results = (outs FHE_EncryptedBooleanType);
let hasVerifier = 1;
}
@@ -498,16 +500,17 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
let description = [{
Cast an unsigned integer to a boolean.
The input must necessarily be of width 1, in order to put that single
bit into a boolean.
The input must necessarily be of width 1 or 2. 2 being the current representation
of an encrypted boolean, leaving one bit for the carry.
Examples:
```mlir
// ok
"FHE.to_bool"(%x) : (!FHE.eint<1>) -> !FHE.ebool
"FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool
// error
"FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool
"FHE.to_bool"(%x) : (!FHE.eint<3>) -> !FHE.ebool
```
}];

View File

@@ -78,6 +78,11 @@ def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",
let description = [{
An encrypted boolean.
}];
let extraClassDeclaration = [{
/// Returns the required number of bits to represent an encrypted boolean
static size_t getWidth() { return 2; }
}];
}
#endif

View File

@@ -0,0 +1,23 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_FHE_BOOLEAN_PASS_H
#define CONCRETELANG_FHE_BOOLEAN_PASS_H
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/Boolean.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<>> createFHEBooleanTransformPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_FHE_BOOLEAN_PASS
#define CONCRETELANG_FHE_BOOLEAN_PASS
include "mlir/Pass/PassBase.td"
def FHEBooleanTransform : Pass<"fhe-boolean-transform"> {
let summary = "Transform FHE boolean operations to integer operations";
let constructor = "mlir::concretelang::createFHEBooleanTransformPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
}
#endif

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Boolean.td)
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)

View File

@@ -40,6 +40,10 @@ lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelize, bool batch);
mlir::LogicalResult
transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -78,6 +78,11 @@ public:
addConversion([](FHE::EncryptedIntegerType type) {
return convertEint(type.getContext(), type);
});
addConversion([](FHE::EncryptedBooleanType type) {
return TFHE::GLWECipherTextType::get(
type.getContext(), -1, -1, -1,
mlir::concretelang::FHE::EncryptedBooleanType::getWidth());
});
addConversion([](mlir::RankedTensorType type) {
return maybeConvertEintTensor(type.getContext(), type);
});
@@ -363,6 +368,51 @@ private:
concretelang::ScalarLoweringParameters loweringParameters;
};
/// Rewriter for the `FHE::to_bool` operation.
struct ToBoolOpPattern : public mlir::OpRewritePattern<FHE::ToBoolOp> {
ToBoolOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<FHE::ToBoolOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::ToBoolOp op,
mlir::PatternRewriter &rewriter) const override {
auto width = op.input()
.getType()
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
.getWidth();
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
rewriter.replaceOp(op, op.input());
return mlir::success();
}
// TODO
op->emitError("only support conversion with width 2 for the moment");
return mlir::failure();
}
};
/// Rewriter for the `FHE::from_bool` operation.
struct FromBoolOpPattern : public mlir::OpRewritePattern<FHE::FromBoolOp> {
FromBoolOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<FHE::FromBoolOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::FromBoolOp op,
mlir::PatternRewriter &rewriter) const override {
auto width = op.getResult()
.getType()
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
.getWidth();
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
rewriter.replaceOp(op, op.input());
return mlir::success();
}
// TODO
op->emitError("only support conversion with width 2 for the moment");
return mlir::failure();
}
};
} // namespace lowering
struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
@@ -431,6 +481,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
// |_ `FHE::neg_eint`
concretelang::GenericTypeAndOpConverterPattern<FHE::NegEintOp,
TFHE::NegGLWEOp>,
// |_ `FHE::not`
concretelang::GenericTypeAndOpConverterPattern<FHE::BoolNotOp,
TFHE::NegGLWEOp>,
// |_ `FHE::add_eint`
concretelang::GenericTypeAndOpConverterPattern<FHE::AddEintOp,
TFHE::AddGLWEOp>>(
@@ -449,6 +502,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
patterns.add<lowering::ApplyLookupTableEintOpPattern>(&getContext(),
loweringParameters);
// Patterns for boolean conversion ops
patterns.add<lowering::FromBoolOpPattern, lowering::ToBoolOpPattern>(
&getContext());
// Patterns for the relics of the `FHELinalg` dialect operations.
// |_ `linalg::generic` turned to nested `scf::for`
patterns

View File

@@ -412,6 +412,22 @@ static llvm::APInt getSqMANP(
return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.not` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::BoolNotOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
@@ -1124,10 +1140,15 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto negEintOp =
llvm::dyn_cast<mlir::concretelang::FHE::NegEintOp>(op)) {
norm2SqEquiv = getSqMANP(negEintOp, operands);
} else if (auto boolNotOp =
llvm::dyn_cast<mlir::concretelang::FHE::BoolNotOp>(op)) {
norm2SqEquiv = getSqMANP(boolNotOp, operands);
} else if (auto mulEintIntOp =
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
} else if (llvm::isa<mlir::concretelang::FHE::ZeroEintOp>(op) ||
llvm::isa<mlir::concretelang::FHE::ToBoolOp>(op) ||
llvm::isa<mlir::concretelang::FHE::FromBoolOp>(op) ||
llvm::isa<mlir::concretelang::FHE::ZeroTensorOp>(op) ||
llvm::isa<mlir::concretelang::FHE::ApplyLookupTableEintOp>(op)) {
norm2SqEquiv = llvm::APInt{1, 1, false};

View File

@@ -15,6 +15,7 @@ namespace utils {
bool isEncryptedValue(mlir::Value value) {
return (
value.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() ||
value.getType().isa<mlir::concretelang::FHE::EncryptedBooleanType>() ||
(value.getType().isa<mlir::TensorType>() &&
value.getType()
.cast<mlir::TensorType>()
@@ -22,14 +23,20 @@ bool isEncryptedValue(mlir::Value value) {
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()));
}
/// Returns the bit width of `value` if `value` is an encrypted integer
/// or the bit width of the elements if `value` is a tensor of
/// Returns the bit width of `value` if `value` is an encrypted integer,
/// or the number of bits to represent a boolean if `value` is an encrypted
/// boolean, or the bit width of the elements if `value` is a tensor of
/// encrypted integers.
unsigned int getEintPrecision(mlir::Value value) {
if (auto ty = value.getType()
.dyn_cast_or_null<
mlir::concretelang::FHE::EncryptedIntegerType>()) {
return ty.getWidth();
}
if (auto ty = value.getType()
.dyn_cast_or_null<
mlir::concretelang::FHE::EncryptedBooleanType>()) {
return mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
} else if (auto tensorTy =
value.getType().dyn_cast_or_null<mlir::TensorType>()) {
if (auto ty = tensorTy.getElementType()

View File

@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -206,9 +206,9 @@ mlir::LogicalResult ToUnsignedOp::verify() {
mlir::LogicalResult ToBoolOp::verify() {
auto input = this->input().getType().cast<EncryptedIntegerType>();
if (input.getWidth() != 1) {
this->emitOpError(
"should have 1 as the width of encrypted input to cast to a boolean");
if (input.getWidth() != 1 && input.getWidth() != 2) {
this->emitOpError("should have 1 or 2 as the width of encrypted input to "
"cast to a boolean");
return mlir::failure();
}

View File

@@ -0,0 +1,126 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
#include <concretelang/Support/Constants.h>
namespace mlir {
namespace concretelang {
namespace {
/// Rewrite an `FHE.gen_gate` operation as an LUT operation by composing a
/// single index from the two boolean inputs.
class GenGatePattern
: public mlir::OpRewritePattern<mlir::concretelang::FHE::GenGateOp> {
public:
GenGatePattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<mlir::concretelang::FHE::GenGateOp>(
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::FHE::GenGateOp op,
mlir::PatternRewriter &rewriter) const override {
auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get(
rewriter.getContext(), 2);
auto left = rewriter
.create<mlir::concretelang::FHE::FromBoolOp>(
op.getLoc(), eint2, op.left())
.getResult();
auto right = rewriter
.create<mlir::concretelang::FHE::FromBoolOp>(
op.getLoc(), eint2, op.right())
.getResult();
auto cst_two =
rewriter.create<mlir::arith::ConstantIntOp>(op.getLoc(), 2, 3)
.getResult();
auto leftMulTwo = rewriter
.create<mlir::concretelang::FHE::MulEintIntOp>(
op.getLoc(), left, cst_two)
.getResult();
auto newIndex = rewriter
.create<mlir::concretelang::FHE::AddEintOp>(
op.getLoc(), leftMulTwo, right)
.getResult();
auto lut_result =
rewriter.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
op.getLoc(), eint2, newIndex, op.truth_table());
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::ToBoolOp>(
op,
mlir::concretelang::FHE::EncryptedBooleanType::get(
rewriter.getContext()),
lut_result);
return mlir::success();
}
};
/// Rewrite an FHE GateOp (e.g. And/Or) into a GenGate with the given truth
/// table.
template <typename GateOp>
class GeneralizeGatePattern : public mlir::OpRewritePattern<GateOp> {
public:
GeneralizeGatePattern(mlir::MLIRContext *context,
llvm::SmallVector<uint64_t, 4> truth_table_vector)
: mlir::OpRewritePattern<GateOp>(
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
truth_table_vector(truth_table_vector) {}
mlir::LogicalResult
matchAndRewrite(GateOp op, mlir::PatternRewriter &rewriter) const override {
auto truth_table_attr = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)),
{llvm::APInt(1, this->truth_table_vector[0], false),
llvm::APInt(1, this->truth_table_vector[1], false),
llvm::APInt(1, this->truth_table_vector[2], false),
llvm::APInt(1, this->truth_table_vector[3], false)});
auto truth_table =
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), truth_table_attr);
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::GenGateOp>(
op, op.getResult().getType(), op.left(), op.right(), truth_table);
return mlir::success();
}
private:
llvm::SmallVector<uint64_t, 4> truth_table_vector;
};
/// Perfoms the transformation of boolean operations
class FHEBooleanTransformPass
: public FHEBooleanTransformBase<FHEBooleanTransformPass> {
public:
void runOnOperation() override {
mlir::Operation *op = getOperation();
mlir::RewritePatternSet patterns(&getContext());
patterns.add<GenGatePattern>(&getContext());
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolAndOp>>(
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 0, 0, 1}));
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolNandOp>>(
&getContext(), llvm::SmallVector<uint64_t, 4>({1, 1, 1, 0}));
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolOrOp>>(
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 1, 1, 1}));
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolXorOp>>(
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 1, 1, 0}));
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
};
} // end anonymous namespace
std::unique_ptr<mlir::OperationPass<>> createFHEBooleanTransformPass() {
return std::make_unique<FHEBooleanTransformPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,12 @@
add_mlir_library(
FHEDialectTransforms
Boolean.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
FHEDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
FHEDialect)

View File

@@ -21,6 +21,7 @@ add_mlir_library(
FHELinalgDialect
FHELinalgDialectTransforms
FHETensorOpsToLinalg
FHEDialectTransforms
FHEToTFHECrt
FHEToTFHEScalar
ExtractSDFGOps

View File

@@ -283,6 +283,12 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::ROUND_TRIP)
return std::move(res);
if (mlir::concretelang::pipeline::transformFHEBoolean(mlirContext, module,
enablePass)
.failed()) {
return errorDiag("Transforming FHE boolean ops failed");
}
// FHE High level pass to determine FHE parameters
if (auto err = this->determineFHEParameters(res))
return std::move(err);

View File

@@ -35,6 +35,7 @@
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
#include <concretelang/Support/Pipeline.h>
@@ -207,6 +208,15 @@ lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createFHEBooleanTransformPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -85,6 +85,28 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
},
};
}
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::FHE::EncryptedBooleanType>()) {
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ std::vector<int64_t>(),
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/*.sign = */ false,
},
};
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance,

View File

@@ -8,3 +8,11 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{2}>
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}

View File

@@ -0,0 +1,80 @@
// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %arg2) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
return %1: !FHE.ebool
}
// CHECK-LABEL: func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64>
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
%1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
// CHECK-LABEL: func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 1]> : tensor<4xi64>
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
%1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
// CHECK-LABEL: func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[1, 1, 1, 0]> : tensor<4xi64>
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
%1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
// CHECK-LABEL: func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 0]> : tensor<4xi64>
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
%1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}

View File

@@ -16,9 +16,9 @@ func.func @zero_plaintext() -> tensor<4x9xi32> {
// -----
func.func @to_bool(%arg0: !FHE.eint<2>) -> !FHE.ebool {
func.func @to_bool(%arg0: !FHE.eint<3>) -> !FHE.ebool {
// expected-error @+1 {{'FHE.to_bool' op}}
%1 = "FHE.to_bool"(%arg0): (!FHE.eint<2>) -> (!FHE.ebool)
%1 = "FHE.to_bool"(%arg0): (!FHE.eint<3>) -> (!FHE.ebool)
return %1: !FHE.ebool
}
@@ -32,8 +32,8 @@ func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<5xi1>) -
// -----
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi2>) -> !FHE.ebool {
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<6xi64>) -> !FHE.ebool {
// expected-error @+1 {{'FHE.gen_gate' op}}
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi2>) -> !FHE.ebool
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<6xi64>) -> !FHE.ebool
return %1: !FHE.ebool
}

View File

@@ -232,12 +232,12 @@ func.func @from_bool(%arg0: !FHE.ebool) -> !FHE.eint<1> {
return %1: !FHE.eint<1>
}
// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool {
// CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool
// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
// CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
// CHECK-NEXT: return %[[V1]] : !FHE.ebool
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
return %1: !FHE.ebool
}

View File

@@ -183,3 +183,151 @@ tests:
outputs:
- tensor: [9,4,7,7,10,9,9,4,7,7,10,9]
shape: [4,3]
---
description: boolean_and
program: |
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
%1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 1
outputs:
- scalar: 0
- inputs:
- scalar: 1
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 1
- scalar: 1
outputs:
- scalar: 1
---
description: boolean_or
program: |
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
%1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 0
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
outputs:
- scalar: 1
---
description: boolean_nand
program: |
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
%1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 1
- inputs:
- scalar: 0
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 0
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
outputs:
- scalar: 0
---
description: boolean_xor
program: |
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
%1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 0
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
outputs:
- scalar: 0
---
description: boolean_gen_gate
program: |
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
return %1: !FHE.ebool
}
tests:
- inputs:
- scalar: 0
- scalar: 1
- tensor: [1, 0, 1, 1]
shape: [4]
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 0
- tensor: [0, 1, 1, 1]
shape: [4]
outputs:
- scalar: 0
- inputs:
- scalar: 1
- scalar: 0
- tensor: [0, 0, 1, 0]
shape: [4]
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
- tensor: [0, 0, 0, 1]
shape: [4]
outputs:
- scalar: 1