mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: lower and exec boolean ops
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -375,15 +375,17 @@ def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> (!FHE.ebool)
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> (!FHE.ebool)
|
||||
|
||||
// error
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi4>) -> (!FHE.ebool)
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi1>) -> (!FHE.ebool)
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi64>) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[I1]>:$truth_table);
|
||||
// The reason the truth table is of AnyInteger and not I1 is that in lowering passes, the truth_table is meant to be passed
|
||||
// to an LUT operation which requires the table to be of type I64. Whenever lowering passes are no more restrictive, this
|
||||
// can be set to I1 to reflect the boolean logic.
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[AnyInteger]>:$truth_table);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
@@ -498,16 +500,17 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
|
||||
let description = [{
|
||||
Cast an unsigned integer to a boolean.
|
||||
|
||||
The input must necessarily be of width 1, in order to put that single
|
||||
bit into a boolean.
|
||||
The input must necessarily be of width 1 or 2. 2 being the current representation
|
||||
of an encrypted boolean, leaving one bit for the carry.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<1>) -> !FHE.ebool
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
|
||||
// error
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<3>) -> !FHE.ebool
|
||||
```
|
||||
}];
|
||||
|
||||
|
||||
@@ -78,6 +78,11 @@ def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",
|
||||
let description = [{
|
||||
An encrypted boolean.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns the required number of bits to represent an encrypted boolean
|
||||
static size_t getWidth() { return 2; }
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_BOOLEAN_PASS_H
|
||||
#define CONCRETELANG_FHE_BOOLEAN_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHEBooleanTransformPass();
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_FHE_BOOLEAN_PASS
|
||||
#define CONCRETELANG_FHE_BOOLEAN_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHEBooleanTransform : Pass<"fhe-boolean-transform"> {
|
||||
let summary = "Transform FHE boolean operations to integer operations";
|
||||
let constructor = "mlir::concretelang::createFHEBooleanTransformPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Boolean.td)
|
||||
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)
|
||||
@@ -40,6 +40,10 @@ lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelize, bool batch);
|
||||
|
||||
mlir::LogicalResult
|
||||
transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -78,6 +78,11 @@ public:
|
||||
addConversion([](FHE::EncryptedIntegerType type) {
|
||||
return convertEint(type.getContext(), type);
|
||||
});
|
||||
addConversion([](FHE::EncryptedBooleanType type) {
|
||||
return TFHE::GLWECipherTextType::get(
|
||||
type.getContext(), -1, -1, -1,
|
||||
mlir::concretelang::FHE::EncryptedBooleanType::getWidth());
|
||||
});
|
||||
addConversion([](mlir::RankedTensorType type) {
|
||||
return maybeConvertEintTensor(type.getContext(), type);
|
||||
});
|
||||
@@ -363,6 +368,51 @@ private:
|
||||
concretelang::ScalarLoweringParameters loweringParameters;
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::to_bool` operation.
|
||||
struct ToBoolOpPattern : public mlir::OpRewritePattern<FHE::ToBoolOp> {
|
||||
ToBoolOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<FHE::ToBoolOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ToBoolOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto width = op.input()
|
||||
.getType()
|
||||
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
|
||||
.getWidth();
|
||||
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
|
||||
rewriter.replaceOp(op, op.input());
|
||||
return mlir::success();
|
||||
}
|
||||
// TODO
|
||||
op->emitError("only support conversion with width 2 for the moment");
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::from_bool` operation.
|
||||
struct FromBoolOpPattern : public mlir::OpRewritePattern<FHE::FromBoolOp> {
|
||||
FromBoolOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<FHE::FromBoolOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::FromBoolOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto width = op.getResult()
|
||||
.getType()
|
||||
.dyn_cast<mlir::concretelang::FHE::EncryptedIntegerType>()
|
||||
.getWidth();
|
||||
if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) {
|
||||
rewriter.replaceOp(op, op.input());
|
||||
return mlir::success();
|
||||
}
|
||||
// TODO
|
||||
op->emitError("only support conversion with width 2 for the moment");
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
||||
struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
@@ -431,6 +481,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
// |_ `FHE::neg_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::NegEintOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
// |_ `FHE::not`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::BoolNotOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
// |_ `FHE::add_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::AddEintOp,
|
||||
TFHE::AddGLWEOp>>(
|
||||
@@ -449,6 +502,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
|
||||
// Patterns for boolean conversion ops
|
||||
patterns.add<lowering::FromBoolOpPattern, lowering::ToBoolOpPattern>(
|
||||
&getContext());
|
||||
|
||||
// Patterns for the relics of the `FHELinalg` dialect operations.
|
||||
// |_ `linalg::generic` turned to nested `scf::for`
|
||||
patterns
|
||||
|
||||
@@ -412,6 +412,22 @@ static llvm::APInt getSqMANP(
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.not` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::BoolNotOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -1124,10 +1140,15 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto negEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::NegEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(negEintOp, operands);
|
||||
} else if (auto boolNotOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::BoolNotOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(boolNotOp, operands);
|
||||
} else if (auto mulEintIntOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
} else if (llvm::isa<mlir::concretelang::FHE::ZeroEintOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::ToBoolOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::FromBoolOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::ZeroTensorOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::ApplyLookupTableEintOp>(op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
|
||||
@@ -15,6 +15,7 @@ namespace utils {
|
||||
bool isEncryptedValue(mlir::Value value) {
|
||||
return (
|
||||
value.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() ||
|
||||
value.getType().isa<mlir::concretelang::FHE::EncryptedBooleanType>() ||
|
||||
(value.getType().isa<mlir::TensorType>() &&
|
||||
value.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
@@ -22,14 +23,20 @@ bool isEncryptedValue(mlir::Value value) {
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()));
|
||||
}
|
||||
|
||||
/// Returns the bit width of `value` if `value` is an encrypted integer
|
||||
/// or the bit width of the elements if `value` is a tensor of
|
||||
/// Returns the bit width of `value` if `value` is an encrypted integer,
|
||||
/// or the number of bits to represent a boolean if `value` is an encrypted
|
||||
/// boolean, or the bit width of the elements if `value` is a tensor of
|
||||
/// encrypted integers.
|
||||
unsigned int getEintPrecision(mlir::Value value) {
|
||||
if (auto ty = value.getType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
return ty.getWidth();
|
||||
}
|
||||
if (auto ty = value.getType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedBooleanType>()) {
|
||||
return mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
||||
} else if (auto tensorTy =
|
||||
value.getType().dyn_cast_or_null<mlir::TensorType>()) {
|
||||
if (auto ty = tensorTy.getElementType()
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -206,9 +206,9 @@ mlir::LogicalResult ToUnsignedOp::verify() {
|
||||
mlir::LogicalResult ToBoolOp::verify() {
|
||||
auto input = this->input().getType().cast<EncryptedIntegerType>();
|
||||
|
||||
if (input.getWidth() != 1) {
|
||||
this->emitOpError(
|
||||
"should have 1 as the width of encrypted input to cast to a boolean");
|
||||
if (input.getWidth() != 1 && input.getWidth() != 2) {
|
||||
this->emitOpError("should have 1 or 2 as the width of encrypted input to "
|
||||
"cast to a boolean");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
|
||||
126
compiler/lib/Dialect/FHE/Transforms/Boolean.cpp
Normal file
126
compiler/lib/Dialect/FHE/Transforms/Boolean.cpp
Normal file
@@ -0,0 +1,126 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Rewrite an `FHE.gen_gate` operation as an LUT operation by composing a
|
||||
/// single index from the two boolean inputs.
|
||||
class GenGatePattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::FHE::GenGateOp> {
|
||||
public:
|
||||
GenGatePattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::FHE::GenGateOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::FHE::GenGateOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get(
|
||||
rewriter.getContext(), 2);
|
||||
auto left = rewriter
|
||||
.create<mlir::concretelang::FHE::FromBoolOp>(
|
||||
op.getLoc(), eint2, op.left())
|
||||
.getResult();
|
||||
auto right = rewriter
|
||||
.create<mlir::concretelang::FHE::FromBoolOp>(
|
||||
op.getLoc(), eint2, op.right())
|
||||
.getResult();
|
||||
auto cst_two =
|
||||
rewriter.create<mlir::arith::ConstantIntOp>(op.getLoc(), 2, 3)
|
||||
.getResult();
|
||||
auto leftMulTwo = rewriter
|
||||
.create<mlir::concretelang::FHE::MulEintIntOp>(
|
||||
op.getLoc(), left, cst_two)
|
||||
.getResult();
|
||||
auto newIndex = rewriter
|
||||
.create<mlir::concretelang::FHE::AddEintOp>(
|
||||
op.getLoc(), leftMulTwo, right)
|
||||
.getResult();
|
||||
auto lut_result =
|
||||
rewriter.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
|
||||
op.getLoc(), eint2, newIndex, op.truth_table());
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::ToBoolOp>(
|
||||
op,
|
||||
mlir::concretelang::FHE::EncryptedBooleanType::get(
|
||||
rewriter.getContext()),
|
||||
lut_result);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrite an FHE GateOp (e.g. And/Or) into a GenGate with the given truth
|
||||
/// table.
|
||||
template <typename GateOp>
|
||||
class GeneralizeGatePattern : public mlir::OpRewritePattern<GateOp> {
|
||||
public:
|
||||
GeneralizeGatePattern(mlir::MLIRContext *context,
|
||||
llvm::SmallVector<uint64_t, 4> truth_table_vector)
|
||||
: mlir::OpRewritePattern<GateOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
|
||||
truth_table_vector(truth_table_vector) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(GateOp op, mlir::PatternRewriter &rewriter) const override {
|
||||
auto truth_table_attr = mlir::DenseElementsAttr::get(
|
||||
mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)),
|
||||
{llvm::APInt(1, this->truth_table_vector[0], false),
|
||||
llvm::APInt(1, this->truth_table_vector[1], false),
|
||||
llvm::APInt(1, this->truth_table_vector[2], false),
|
||||
llvm::APInt(1, this->truth_table_vector[3], false)});
|
||||
auto truth_table =
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), truth_table_attr);
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::GenGateOp>(
|
||||
op, op.getResult().getType(), op.left(), op.right(), truth_table);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::SmallVector<uint64_t, 4> truth_table_vector;
|
||||
};
|
||||
|
||||
/// Perfoms the transformation of boolean operations
|
||||
class FHEBooleanTransformPass
|
||||
: public FHEBooleanTransformBase<FHEBooleanTransformPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
mlir::Operation *op = getOperation();
|
||||
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<GenGatePattern>(&getContext());
|
||||
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolAndOp>>(
|
||||
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 0, 0, 1}));
|
||||
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolNandOp>>(
|
||||
&getContext(), llvm::SmallVector<uint64_t, 4>({1, 1, 1, 0}));
|
||||
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolOrOp>>(
|
||||
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 1, 1, 1}));
|
||||
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolXorOp>>(
|
||||
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 1, 1, 0}));
|
||||
|
||||
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHEBooleanTransformPass() {
|
||||
return std::make_unique<FHEBooleanTransformPass>();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
12
compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt
Normal file
12
compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_mlir_library(
|
||||
FHEDialectTransforms
|
||||
Boolean.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
DEPENDS
|
||||
FHEDialect
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
FHEDialect)
|
||||
@@ -21,6 +21,7 @@ add_mlir_library(
|
||||
FHELinalgDialect
|
||||
FHELinalgDialectTransforms
|
||||
FHETensorOpsToLinalg
|
||||
FHEDialectTransforms
|
||||
FHEToTFHECrt
|
||||
FHEToTFHEScalar
|
||||
ExtractSDFGOps
|
||||
|
||||
@@ -283,6 +283,12 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (target == Target::ROUND_TRIP)
|
||||
return std::move(res);
|
||||
|
||||
if (mlir::concretelang::pipeline::transformFHEBoolean(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Transforming FHE boolean ops failed");
|
||||
}
|
||||
|
||||
// FHE High level pass to determine FHE parameters
|
||||
if (auto err = this->determineFHEParameters(res))
|
||||
return std::move(err);
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
|
||||
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
|
||||
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
|
||||
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
|
||||
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <concretelang/Support/Pipeline.h>
|
||||
@@ -207,6 +208,15 @@ lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createFHEBooleanTransformPass(), enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -85,6 +85,28 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedBooleanType>()) {
|
||||
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ std::vector<int64_t>(),
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ false,
|
||||
},
|
||||
};
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance,
|
||||
|
||||
@@ -8,3 +8,11 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{2}>
|
||||
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool
|
||||
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %arg2) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
|
||||
func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
|
||||
func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 1]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
|
||||
func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[1, 1, 1, 0]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool
|
||||
func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 0]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V5]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
@@ -16,9 +16,9 @@ func.func @zero_plaintext() -> tensor<4x9xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @to_bool(%arg0: !FHE.eint<2>) -> !FHE.ebool {
|
||||
func.func @to_bool(%arg0: !FHE.eint<3>) -> !FHE.ebool {
|
||||
// expected-error @+1 {{'FHE.to_bool' op}}
|
||||
%1 = "FHE.to_bool"(%arg0): (!FHE.eint<2>) -> (!FHE.ebool)
|
||||
%1 = "FHE.to_bool"(%arg0): (!FHE.eint<3>) -> (!FHE.ebool)
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<5xi1>) -
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi2>) -> !FHE.ebool {
|
||||
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<6xi64>) -> !FHE.ebool {
|
||||
// expected-error @+1 {{'FHE.gen_gate' op}}
|
||||
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi2>) -> !FHE.ebool
|
||||
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<6xi64>) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
@@ -232,12 +232,12 @@ func.func @from_bool(%arg0: !FHE.ebool) -> !FHE.eint<1> {
|
||||
return %1: !FHE.eint<1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool
|
||||
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool
|
||||
// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool
|
||||
func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V1]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool
|
||||
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
|
||||
@@ -183,3 +183,151 @@ tests:
|
||||
outputs:
|
||||
- tensor: [9,4,7,7,10,9,9,4,7,7,10,9]
|
||||
shape: [4,3]
|
||||
---
|
||||
description: boolean_and
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: boolean_or
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: boolean_nand
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 0
|
||||
---
|
||||
description: boolean_xor
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 0
|
||||
---
|
||||
description: boolean_gen_gate
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool {
|
||||
%1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- tensor: [1, 0, 1, 1]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
- tensor: [0, 1, 1, 1]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
- tensor: [0, 0, 1, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
- tensor: [0, 0, 0, 1]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- scalar: 1
|
||||
|
||||
Reference in New Issue
Block a user