mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(concrete-compiler): add new ciphertext multiplication operator
This commit is contained in:
@@ -285,6 +285,44 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
|
||||
let summary = "Multiplies two encrypted integers";
|
||||
|
||||
let description = [{
|
||||
Multiplies two encrypted integers.
|
||||
|
||||
The encrypted integers and the result must have the same width and
|
||||
signedness. Also, due to the current implementation, one supplementary
|
||||
bit of width must be provided, in addition to the number of bits needed
|
||||
to encode the largest output value.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.esint<3>, !FHE.esint<3>) -> (!FHE.esint<3>)
|
||||
|
||||
// error
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
|
||||
let summary = "Cast an unsigned integer to a signed one";
|
||||
|
||||
|
||||
@@ -2,3 +2,7 @@ set(LLVM_TARGET_DEFINITIONS Boolean.td)
|
||||
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)
|
||||
set(LLVM_TARGET_DEFINITIONS EncryptedMulToDoubleTLU.td)
|
||||
mlir_tablegen(EncryptedMulToDoubleTLU.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS_H
|
||||
#define CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
|
||||
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
|
||||
createEncryptedMulToDoubleTLUPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS
|
||||
#define CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def EncryptedMulToDoubleTLU : Pass<"EncryptedMulToDoubleTLU", "::mlir::func::FuncOp"> {
|
||||
let summary = "Replaces encrypted multiplication with a double table lookup.";
|
||||
let constructor = "mlir::concretelang::createEncryptedMulToDoubleTLUPass()";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -34,6 +34,10 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::ArrayRef<int64_t> tileSizes,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -458,6 +458,42 @@ struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::to_signed` operation.
|
||||
struct ToSignedOpPattern : public CrtOpPattern<FHE::ToSignedOp> {
|
||||
ToSignedOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::ToSignedOp>(context, params, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::to_unsigned` operation.
|
||||
struct ToUnsignedOpPattern : public CrtOpPattern<FHE::ToUnsignedOp> {
|
||||
ToUnsignedOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::ToUnsignedOp>(context, params, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::mul_eint_int` operation.
|
||||
struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
|
||||
|
||||
@@ -937,6 +973,10 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
lowering::NegEintOpPattern,
|
||||
// |_ `FHE::mul_eint_int`
|
||||
lowering::MulEintIntOpPattern,
|
||||
// |_ `FHE::to_unsigned`
|
||||
lowering::ToUnsignedOpPattern,
|
||||
// |_ `FHE::to_signed`
|
||||
lowering::ToSignedOpPattern,
|
||||
// |_ `FHE::apply_lookup_table`
|
||||
lowering::ApplyLookupTableEintOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
|
||||
@@ -273,6 +273,41 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::to_signed` operation.
|
||||
struct ToSignedOpPattern : public ScalarOpPattern<FHE::ToSignedOp> {
|
||||
ToSignedOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::ToSignedOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter;
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::to_unsigned` operation.
|
||||
struct ToUnsignedOpPattern : public ScalarOpPattern<FHE::ToUnsignedOp> {
|
||||
ToUnsignedOpPattern(mlir::TypeConverter &converter,
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::ToUnsignedOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter;
|
||||
rewriter.replaceOp(op, {adaptor.input()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::apply_lookup_table` operation.
|
||||
struct ApplyLookupTableEintOpPattern
|
||||
: public ScalarOpPattern<FHE::ApplyLookupTableEintOp> {
|
||||
@@ -474,7 +509,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
// |_ `FHE::sub_eint`
|
||||
lowering::SubEintOpPattern,
|
||||
// |_ `FHE::mul_eint_int`
|
||||
lowering::MulEintIntOpPattern>(converter, &getContext());
|
||||
lowering::MulEintIntOpPattern,
|
||||
// |_ `FHE::to_signed`
|
||||
lowering::ToSignedOpPattern,
|
||||
// |_ `FHE::to_unsigned`
|
||||
lowering::ToUnsignedOpPattern>(converter, &getContext());
|
||||
// |_ `FHE::apply_lookup_table`
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern>(
|
||||
converter, &getContext(), loweringParameters);
|
||||
|
||||
@@ -424,6 +424,34 @@ static llvm::APInt getSqMANP(
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::ToSignedOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::ToUnsignedOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -1139,6 +1167,12 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto boolNotOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::BoolNotOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(boolNotOp, operands);
|
||||
} else if (auto toSignedOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::ToSignedOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(toSignedOp, operands);
|
||||
} else if (auto toUnsignedOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::ToUnsignedOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(toUnsignedOp, operands);
|
||||
} else if (auto mulEintIntOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
|
||||
@@ -177,6 +177,22 @@ mlir::LogicalResult MulEintIntOp::verify() {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult MulEintOp::verify() {
|
||||
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
|
||||
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
|
||||
|
||||
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult ToSignedOp::verify() {
|
||||
auto input = this->input().getType().cast<EncryptedIntegerType>();
|
||||
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
add_mlir_library(
|
||||
FHEDialectTransforms
|
||||
Boolean.cpp
|
||||
EncryptedMulToDoubleTLU.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
DEPENDS
|
||||
|
||||
178
compiler/lib/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.cpp
Normal file
178
compiler/lib/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.cpp
Normal file
@@ -0,0 +1,178 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/Dialect/FHE/Analysis/utils.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
#include <unordered_set>
|
||||
|
||||
using namespace mlir::concretelang::FHE;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace {
|
||||
|
||||
class EncryptedMulOpPattern : public mlir::OpConversionPattern<FHE::MulEintOp> {
|
||||
public:
|
||||
EncryptedMulOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpConversionPattern<FHE::MulEintOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::MulEintOp op, FHE::MulEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto inputType = adaptor.a().getType();
|
||||
auto bitWidth = inputType.cast<FHE::FheIntegerInterface>().getWidth();
|
||||
auto isSigned = inputType.cast<FHE::FheIntegerInterface>().isSigned();
|
||||
mlir::Type signedType =
|
||||
FHE::EncryptedSignedIntegerType::get(op->getContext(), bitWidth);
|
||||
|
||||
// Note:
|
||||
// -----
|
||||
//
|
||||
// The signedness of a value is only important:
|
||||
// + when used as function input / output, because it changes the
|
||||
// encoding/decoding used.
|
||||
// + when used as tlu input, because it changes the encoding of the lut.
|
||||
//
|
||||
// Otherwise, for the leveled operations, the semantics are compatible. We
|
||||
// just have to please the verifier that usually requires the same
|
||||
// signedness for inputs and outputs.
|
||||
|
||||
// s = a + b
|
||||
mlir::Value sum =
|
||||
rewriter.create<FHE::AddEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
|
||||
|
||||
// se = (s)^2/4
|
||||
// Depending on whether a,b,s are signed or not, we need a different lut to
|
||||
// compute (.)^2/4.
|
||||
mlir::SmallVector<uint64_t> rawSumLut;
|
||||
if (isSigned) {
|
||||
rawSumLut = generateSignedLut(bitWidth);
|
||||
} else {
|
||||
rawSumLut = generateUnsignedLut(bitWidth);
|
||||
}
|
||||
mlir::Value sumLut = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(
|
||||
rawSumLut.size(), rewriter.getIntegerType(64)),
|
||||
rawSumLut));
|
||||
mlir::Value sumTluOutput = rewriter.create<FHE::ApplyLookupTableEintOp>(
|
||||
op->getLoc(), inputType, sum, sumLut);
|
||||
|
||||
// d = a - b
|
||||
mlir::Value diff =
|
||||
rewriter.create<FHE::SubEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
|
||||
|
||||
// de = (d)^2/4
|
||||
// Here, the tlu must be performed with signed encoded lut, to properly
|
||||
// bootstrap negative values that may arise in the computation of d. If the
|
||||
// inputs are not signed, we cast the output to a signed encrypted integer.
|
||||
mlir::Value diffO;
|
||||
if (isSigned) {
|
||||
diffO = diff;
|
||||
} else {
|
||||
diff = rewriter.create<FHE::ToSignedOp>(op->getLoc(), signedType, diff);
|
||||
}
|
||||
mlir::SmallVector<uint64_t> rawDiffLut = generateSignedLut(bitWidth);
|
||||
mlir::Value diffLut = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(
|
||||
rawDiffLut.size(), rewriter.getIntegerType(64)),
|
||||
rawDiffLut));
|
||||
mlir::Value diffTluOutput = rewriter.create<FHE::ApplyLookupTableEintOp>(
|
||||
op->getLoc(), inputType, diff, diffLut);
|
||||
|
||||
// o = se - de
|
||||
mlir::Value output = rewriter.create<FHE::SubEintOp>(
|
||||
op->getLoc(), inputType, sumTluOutput, diffTluOutput);
|
||||
|
||||
rewriter.replaceOp(op, {output});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
static mlir::SmallVector<uint64_t> generateUnsignedLut(unsigned bitWidth) {
|
||||
mlir::SmallVector<uint64_t> rawLut;
|
||||
uint64_t lutLen = 1 << bitWidth;
|
||||
for (uint64_t i = 0; i < lutLen; ++i) {
|
||||
rawLut.push_back((i * i) / 4);
|
||||
}
|
||||
return rawLut;
|
||||
}
|
||||
|
||||
static mlir::SmallVector<uint64_t> generateSignedLut(unsigned bitWidth) {
|
||||
mlir::SmallVector<uint64_t> rawLut;
|
||||
uint64_t lutLen = 1 << bitWidth;
|
||||
for (uint64_t i = 0; i < lutLen / 2; ++i) {
|
||||
rawLut.push_back((i * i) / 4);
|
||||
}
|
||||
for (uint64_t i = lutLen / 2; i > 0; --i) {
|
||||
rawLut.push_back((i * i) / 4);
|
||||
}
|
||||
return rawLut;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// This pass rewrites an `FHE::MulEintOp` into a set of ops of the `FHE`
|
||||
/// dialects.
|
||||
///
|
||||
/// It relies on the observation that `x*y` can be turned into `((x+y)^2)/4 -
|
||||
/// ((x-y)^2)/4`, which uses operations already available in the `FHE` dialect:
|
||||
/// + `x+y` can be computed with the leveled operation `add_eint`
|
||||
/// + `x-y` can be computed with the leveled operation `sub_eint`
|
||||
/// + `(a^2)/4` can be computed with a table lookup `apply_table_lookup`
|
||||
///
|
||||
/// Gotchas:
|
||||
/// --------
|
||||
///
|
||||
/// + Since we use the leveled addition and subtraction, we have to increment
|
||||
/// the bitwidth of the inputs to properly
|
||||
/// encode the carry of the computation. This change in bitwidth must then be
|
||||
/// propagated to the whole graph, both upstream and downstream.
|
||||
/// + This graph-wide update may reach existing `apply_lookup_table` operations,
|
||||
/// which in turn will necessitate an
|
||||
/// update of the size of the lookup table.
|
||||
class EncryptedMulToDoubleTLU
|
||||
: public EncryptedMulToDoubleTLUBase<EncryptedMulToDoubleTLU> {
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
mlir::func::FuncOp funcOp = getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<FHE::FHEDialect>();
|
||||
target.addIllegalOp<FHE::MulEintOp>();
|
||||
|
||||
mlir::RewritePatternSet patterns(funcOp->getContext());
|
||||
patterns.add<EncryptedMulOpPattern>(funcOp->getContext());
|
||||
if (mlir::applyPartialConversion(funcOp, target, std::move(patterns))
|
||||
.failed()) {
|
||||
funcOp->emitError("Failed to rewrite FHE mul_eint operation.");
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>>
|
||||
createEncryptedMulToDoubleTLUPass() {
|
||||
return std::make_unique<EncryptedMulToDoubleTLU>();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -24,6 +24,7 @@ add_mlir_library(
|
||||
ExtractSDFGOps
|
||||
MLIRLowerableDialectsToLLVM
|
||||
FHEDialectAnalysis
|
||||
FHEDialectTransforms
|
||||
RTDialectAnalysis
|
||||
ConcretelangTransforms
|
||||
ConcretelangBConcreteTransforms
|
||||
|
||||
@@ -289,6 +289,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return errorDiag("Transforming FHE boolean ops failed");
|
||||
}
|
||||
|
||||
// Encrypted mul rewriting
|
||||
if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext,
|
||||
module, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Rewriting of encrypted mul failed");
|
||||
}
|
||||
|
||||
// FHE High level pass to determine FHE parameters
|
||||
if (auto err = this->determineFHEParameters(res))
|
||||
return std::move(err);
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
|
||||
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
|
||||
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
|
||||
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <concretelang/Support/Pipeline.h>
|
||||
@@ -174,6 +175,16 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("transformHighLevelFHEOps", pm, context);
|
||||
addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// RUN: concretecompiler --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %0 = "FHE.add_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst_0) {MANP = 1 : ui1} : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %2 = "FHE.sub_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %3 = "FHE.to_signed"(%2) {MANP = 2 : ui3} : (!FHE.eint<3>) -> !FHE.esint<3>
|
||||
// CHECK-NEXT: %4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: %5 = "FHE.sub_eint"(%1, %4) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: return %5 : !FHE.eint<3>
|
||||
// CHECK-NEXT: }
|
||||
func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)
|
||||
return %0: !FHE.eint<3>
|
||||
}
|
||||
15
compiler/tests/check_tests/Dialect/FHE/mul_eint.invalid.mlir
Normal file
15
compiler/tests/check_tests/Dialect/FHE/mul_eint.invalid.mlir
Normal file
@@ -0,0 +1,15 @@
|
||||
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.mul_eint' op should have the width of encrypted inputs equal
|
||||
func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> {
|
||||
%1 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.mul_eint' op should have the width of encrypted inputs and result equal
|
||||
func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
@@ -176,6 +176,15 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func.func @mul_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V0]] : !FHE.eint<2>
|
||||
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
return %0: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2>
|
||||
func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import random
|
||||
|
||||
MIN_PRECISON = 1
|
||||
from end_to_end_linalg_leveled_gen import P_ERROR
|
||||
|
||||
@@ -302,6 +304,56 @@ def main():
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
may_check_error_rate()
|
||||
print("---")
|
||||
# mul_eint
|
||||
if p <= 15:
|
||||
def gen_random_encodable():
|
||||
while True:
|
||||
a = random.randint(1, max_value)
|
||||
b = random.randint(1, max_value)
|
||||
if a*b <= max_value:
|
||||
return a, b
|
||||
|
||||
print("description: mul_eint_{0}bits".format(p+1))
|
||||
print("program: |")
|
||||
print(
|
||||
" func.func @main(%arg0: !FHE.eint<{0}>, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p+1))
|
||||
print(
|
||||
" %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.eint<{0}>, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p+1))
|
||||
print(" return %1: !FHE.eint<{0}>".format(p+1))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" - scalar: 0")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" - scalar: 0")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 1")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" - scalar: 1")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
inp = gen_random_encodable()
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(inp[0]))
|
||||
print(" - scalar: {0}".format(inp[1]))
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(inp[0]*inp[1]))
|
||||
print("---")
|
||||
# signed
|
||||
for p in range(MIN_PRECISON, MAX_PRECISION+1):
|
||||
print("---")
|
||||
@@ -877,6 +929,141 @@ def main():
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
# mul_eint
|
||||
if 2 <= p <= 15:
|
||||
def gen_random_encodable(p):
|
||||
while True:
|
||||
a = random.randint(min_value, max_value)
|
||||
b = random.randint(min_value, max_value)
|
||||
if min_value <= a*b <= max_value:
|
||||
if p == 3:
|
||||
return a, b
|
||||
if not (a in [-1, 1, 0] or b in [-1, 1, 0]):
|
||||
return a, b
|
||||
|
||||
print("description: signed_mul_eint_{0}bits".format(p+1))
|
||||
print("program: |")
|
||||
print(
|
||||
" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p+1))
|
||||
print(
|
||||
" %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p+1))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p+1))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 1")
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 1")
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: 1")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: 1")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: -1")
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: -1")
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value+1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-(min_value+1)))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: -1")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value+1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: -1")
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-(min_value+1)))
|
||||
print(" signed: true")
|
||||
inp = gen_random_encodable(p+1)
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(inp[0]))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(inp[1]))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(inp[0]*inp[1]))
|
||||
print(" signed: true")
|
||||
print("---")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user