feat(concrete-compiler): add new ciphertext multiplication operator

This commit is contained in:
aPere3
2022-10-06 10:55:10 +02:00
committed by Alexandre Péré
parent 117e15cc05
commit fb680340f9
18 changed files with 637 additions and 1 deletions

View File

@@ -285,6 +285,44 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
let hasFolder = 1;
}
def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
let summary = "Multiplies two encrypted integers";
let description = [{
Multiplies two encrypted integers.
The encrypted integers and the result must have the same width and
signedness. Also, due to the current implementation, one supplementary
bit of width must be provided, in addition to the number of bits needed
to encode the largest output value.
Example:
```mlir
// ok
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
"FHE.mul_eint"(%a, %b): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)
"FHE.mul_eint"(%a, %b): (!FHE.esint<3>, !FHE.esint<3>) -> (!FHE.esint<3>)
// error
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
"FHE.mul_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
```
}];
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
build($_builder, $_state, a.getType(), a, b);
}]>
];
let hasVerifier = 1;
}
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
let summary = "Cast an unsigned integer to a signed one";

View File

@@ -2,3 +2,7 @@ set(LLVM_TARGET_DEFINITIONS Boolean.td)
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)
set(LLVM_TARGET_DEFINITIONS EncryptedMulToDoubleTLU.td)
mlir_tablegen(EncryptedMulToDoubleTLU.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)

View File

@@ -0,0 +1,24 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS_H
#define CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS_H
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createEncryptedMulToDoubleTLUPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,11 @@
#ifndef CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS
#define CONCRETELANG_FHE_ENCRYPTED_MUL_TO_DOUBLE_TLU_PASS
include "mlir/Pass/PassBase.td"
def EncryptedMulToDoubleTLU : Pass<"EncryptedMulToDoubleTLU", "::mlir::func::FuncOp"> {
let summary = "Replaces encrypted multiplication with a double table lookup.";
let constructor = "mlir::concretelang::createEncryptedMulToDoubleTLUPass()";
}
#endif

View File

@@ -34,6 +34,10 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::ArrayRef<int64_t> tileSizes,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -458,6 +458,42 @@ struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
}
};
/// Rewriter for the `FHE::to_signed` operation.
struct ToSignedOpPattern : public CrtOpPattern<FHE::ToSignedOp> {
ToSignedOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<FHE::ToSignedOp>(context, params, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter{loweringParameters};
rewriter.replaceOp(op, {adaptor.input()});
return mlir::success();
}
};
/// Rewriter for the `FHE::to_unsigned` operation.
struct ToUnsignedOpPattern : public CrtOpPattern<FHE::ToUnsignedOp> {
ToUnsignedOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<FHE::ToUnsignedOp>(context, params, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter{loweringParameters};
rewriter.replaceOp(op, {adaptor.input()});
return mlir::success();
}
};
/// Rewriter for the `FHE::mul_eint_int` operation.
struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
@@ -937,6 +973,10 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
lowering::NegEintOpPattern,
// |_ `FHE::mul_eint_int`
lowering::MulEintIntOpPattern,
// |_ `FHE::to_unsigned`
lowering::ToUnsignedOpPattern,
// |_ `FHE::to_signed`
lowering::ToSignedOpPattern,
// |_ `FHE::apply_lookup_table`
lowering::ApplyLookupTableEintOpPattern>(&getContext(),
loweringParameters);

View File

@@ -273,6 +273,41 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
}
};
/// Rewriter for the `FHE::to_signed` operation.
struct ToSignedOpPattern : public ScalarOpPattern<FHE::ToSignedOp> {
ToSignedOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::ToSignedOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter;
rewriter.replaceOp(op, {adaptor.input()});
return mlir::success();
}
};
/// Rewriter for the `FHE::to_unsigned` operation.
struct ToUnsignedOpPattern : public ScalarOpPattern<FHE::ToUnsignedOp> {
ToUnsignedOpPattern(mlir::TypeConverter &converter,
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::ToUnsignedOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter;
rewriter.replaceOp(op, {adaptor.input()});
return mlir::success();
}
};
/// Rewriter for the `FHE::apply_lookup_table` operation.
struct ApplyLookupTableEintOpPattern
: public ScalarOpPattern<FHE::ApplyLookupTableEintOp> {
@@ -474,7 +509,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
// |_ `FHE::sub_eint`
lowering::SubEintOpPattern,
// |_ `FHE::mul_eint_int`
lowering::MulEintIntOpPattern>(converter, &getContext());
lowering::MulEintIntOpPattern,
// |_ `FHE::to_signed`
lowering::ToSignedOpPattern,
// |_ `FHE::to_unsigned`
lowering::ToUnsignedOpPattern>(converter, &getContext());
// |_ `FHE::apply_lookup_table`
patterns.add<lowering::ApplyLookupTableEintOpPattern>(
converter, &getContext(), loweringParameters);

View File

@@ -424,6 +424,34 @@ static llvm::APInt getSqMANP(
return eNorm;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::ToSignedOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
return eNorm;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::ToUnsignedOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
@@ -1139,6 +1167,12 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto boolNotOp =
llvm::dyn_cast<mlir::concretelang::FHE::BoolNotOp>(op)) {
norm2SqEquiv = getSqMANP(boolNotOp, operands);
} else if (auto toSignedOp =
llvm::dyn_cast<mlir::concretelang::FHE::ToSignedOp>(op)) {
norm2SqEquiv = getSqMANP(toSignedOp, operands);
} else if (auto toUnsignedOp =
llvm::dyn_cast<mlir::concretelang::FHE::ToUnsignedOp>(op)) {
norm2SqEquiv = getSqMANP(toUnsignedOp, operands);
} else if (auto mulEintIntOp =
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);

View File

@@ -177,6 +177,22 @@ mlir::LogicalResult MulEintIntOp::verify() {
return mlir::success();
}
mlir::LogicalResult MulEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
}
return ::mlir::success();
}
mlir::LogicalResult ToSignedOp::verify() {
auto input = this->input().getType().cast<EncryptedIntegerType>();
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();

View File

@@ -1,6 +1,7 @@
add_mlir_library(
FHEDialectTransforms
Boolean.cpp
EncryptedMulToDoubleTLU.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS

View File

@@ -0,0 +1,178 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang/Dialect/FHE/Analysis/utils.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
#include <concretelang/Support/Constants.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Transforms/DialectConversion.h>
#include <unordered_set>
using namespace mlir::concretelang::FHE;
namespace mlir {
namespace concretelang {
namespace {
class EncryptedMulOpPattern : public mlir::OpConversionPattern<FHE::MulEintOp> {
public:
EncryptedMulOpPattern(mlir::MLIRContext *context)
: mlir::OpConversionPattern<FHE::MulEintOp>(
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
mlir::LogicalResult
matchAndRewrite(FHE::MulEintOp op, FHE::MulEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto inputType = adaptor.a().getType();
auto bitWidth = inputType.cast<FHE::FheIntegerInterface>().getWidth();
auto isSigned = inputType.cast<FHE::FheIntegerInterface>().isSigned();
mlir::Type signedType =
FHE::EncryptedSignedIntegerType::get(op->getContext(), bitWidth);
// Note:
// -----
//
// The signedness of a value is only important:
// + when used as function input / output, because it changes the
// encoding/decoding used.
// + when used as tlu input, because it changes the encoding of the lut.
//
// Otherwise, for the leveled operations, the semantics are compatible. We
// just have to please the verifier that usually requires the same
// signedness for inputs and outputs.
// s = a + b
mlir::Value sum =
rewriter.create<FHE::AddEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
// se = (s)^2/4
// Depending on whether a,b,s are signed or not, we need a different lut to
// compute (.)^2/4.
mlir::SmallVector<uint64_t> rawSumLut;
if (isSigned) {
rawSumLut = generateSignedLut(bitWidth);
} else {
rawSumLut = generateUnsignedLut(bitWidth);
}
mlir::Value sumLut = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(
rawSumLut.size(), rewriter.getIntegerType(64)),
rawSumLut));
mlir::Value sumTluOutput = rewriter.create<FHE::ApplyLookupTableEintOp>(
op->getLoc(), inputType, sum, sumLut);
// d = a - b
mlir::Value diff =
rewriter.create<FHE::SubEintOp>(op->getLoc(), adaptor.a(), adaptor.b());
// de = (d)^2/4
// Here, the tlu must be performed with signed encoded lut, to properly
// bootstrap negative values that may arise in the computation of d. If the
// inputs are not signed, we cast the output to a signed encrypted integer.
mlir::Value diffO;
if (isSigned) {
diffO = diff;
} else {
diff = rewriter.create<FHE::ToSignedOp>(op->getLoc(), signedType, diff);
}
mlir::SmallVector<uint64_t> rawDiffLut = generateSignedLut(bitWidth);
mlir::Value diffLut = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(
rawDiffLut.size(), rewriter.getIntegerType(64)),
rawDiffLut));
mlir::Value diffTluOutput = rewriter.create<FHE::ApplyLookupTableEintOp>(
op->getLoc(), inputType, diff, diffLut);
// o = se - de
mlir::Value output = rewriter.create<FHE::SubEintOp>(
op->getLoc(), inputType, sumTluOutput, diffTluOutput);
rewriter.replaceOp(op, {output});
return mlir::success();
}
private:
static mlir::SmallVector<uint64_t> generateUnsignedLut(unsigned bitWidth) {
mlir::SmallVector<uint64_t> rawLut;
uint64_t lutLen = 1 << bitWidth;
for (uint64_t i = 0; i < lutLen; ++i) {
rawLut.push_back((i * i) / 4);
}
return rawLut;
}
static mlir::SmallVector<uint64_t> generateSignedLut(unsigned bitWidth) {
mlir::SmallVector<uint64_t> rawLut;
uint64_t lutLen = 1 << bitWidth;
for (uint64_t i = 0; i < lutLen / 2; ++i) {
rawLut.push_back((i * i) / 4);
}
for (uint64_t i = lutLen / 2; i > 0; --i) {
rawLut.push_back((i * i) / 4);
}
return rawLut;
}
};
} // namespace
/// This pass rewrites an `FHE::MulEintOp` into a set of ops of the `FHE`
/// dialects.
///
/// It relies on the observation that `x*y` can be turned into `((x+y)^2)/4 -
/// ((x-y)^2)/4`, which uses operations already available in the `FHE` dialect:
/// + `x+y` can be computed with the leveled operation `add_eint`
/// + `x-y` can be computed with the leveled operation `sub_eint`
/// + `(a^2)/4` can be computed with a table lookup `apply_table_lookup`
///
/// Gotchas:
/// --------
///
/// + Since we use the leveled addition and subtraction, we have to increment
/// the bitwidth of the inputs to properly
/// encode the carry of the computation. This change in bitwidth must then be
/// propagated to the whole graph, both upstream and downstream.
/// + This graph-wide update may reach existing `apply_lookup_table` operations,
/// which in turn will necessitate an
/// update of the size of the lookup table.
class EncryptedMulToDoubleTLU
: public EncryptedMulToDoubleTLUBase<EncryptedMulToDoubleTLU> {
public:
void runOnOperation() override {
mlir::func::FuncOp funcOp = getOperation();
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<FHE::FHEDialect>();
target.addIllegalOp<FHE::MulEintOp>();
mlir::RewritePatternSet patterns(funcOp->getContext());
patterns.add<EncryptedMulOpPattern>(funcOp->getContext());
if (mlir::applyPartialConversion(funcOp, target, std::move(patterns))
.failed()) {
funcOp->emitError("Failed to rewrite FHE mul_eint operation.");
this->signalPassFailure();
}
}
};
std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>>
createEncryptedMulToDoubleTLUPass() {
return std::make_unique<EncryptedMulToDoubleTLU>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -24,6 +24,7 @@ add_mlir_library(
ExtractSDFGOps
MLIRLowerableDialectsToLLVM
FHEDialectAnalysis
FHEDialectTransforms
RTDialectAnalysis
ConcretelangTransforms
ConcretelangBConcreteTransforms

View File

@@ -289,6 +289,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return errorDiag("Transforming FHE boolean ops failed");
}
// Encrypted mul rewriting
if (mlir::concretelang::pipeline::transformHighLevelFHEOps(mlirContext,
module, enablePass)
.failed()) {
return StreamStringError("Rewriting of encrypted mul failed");
}
// FHE High level pass to determine FHE parameters
if (auto err = this->determineFHEParameters(res))
return std::move(err);

View File

@@ -36,6 +36,7 @@
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
#include <concretelang/Support/Pipeline.h>
@@ -174,6 +175,16 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("transformHighLevelFHEOps", pm, context);
addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -0,0 +1,17 @@
// RUN: concretecompiler --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
// CHECK: func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 2, 4, 2, 1, 0]> : tensor<8xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 0, 1, 2, 4, 6, 9, 12]> : tensor<8xi64>
// CHECK-NEXT: %0 = "FHE.add_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst_0) {MANP = 1 : ui1} : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
// CHECK-NEXT: %2 = "FHE.sub_eint"(%arg0, %arg1) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: %3 = "FHE.to_signed"(%2) {MANP = 2 : ui3} : (!FHE.eint<3>) -> !FHE.esint<3>
// CHECK-NEXT: %4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.esint<3>, tensor<8xi64>) -> !FHE.eint<3>
// CHECK-NEXT: %5 = "FHE.sub_eint"(%1, %4) {MANP = 2 : ui3} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
// CHECK-NEXT: return %5 : !FHE.eint<3>
// CHECK-NEXT: }
func.func @simple_eint(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>) -> !FHE.eint<3> {
%0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)
return %0: !FHE.eint<3>
}

View File

@@ -0,0 +1,15 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.mul_eint' op should have the width of encrypted inputs equal
func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> {
%1 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.mul_eint' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> {
%1 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -176,6 +176,15 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @mul_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
func.func @mul_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V0:.*]] = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: return %[[V0]] : !FHE.eint<2>
%0 = "FHE.mul_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
return %0: !FHE.eint<2>
}
// CHECK-LABEL: func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2>
func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3

View File

@@ -1,3 +1,5 @@
import random
MIN_PRECISON = 1
from end_to_end_linalg_leveled_gen import P_ERROR
@@ -302,6 +304,56 @@ def main():
print(" - scalar: {0}".format(max_value))
may_check_error_rate()
print("---")
# mul_eint
if p <= 15:
def gen_random_encodable():
while True:
a = random.randint(1, max_value)
b = random.randint(1, max_value)
if a*b <= max_value:
return a, b
print("description: mul_eint_{0}bits".format(p+1))
print("program: |")
print(
" func.func @main(%arg0: !FHE.eint<{0}>, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p+1))
print(
" %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.eint<{0}>, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p+1))
print(" return %1: !FHE.eint<{0}>".format(p+1))
print(" }")
print("tests:")
print(" - inputs:")
print(" - scalar: 0")
print(" - scalar: 0")
print(" outputs:")
print(" - scalar: 0")
print(" - inputs:")
print(" - scalar: {0}".format(max_value))
print(" - scalar: 0")
print(" outputs:")
print(" - scalar: 0")
print(" - inputs:")
print(" - scalar: 0")
print(" - scalar: {0}".format(max_value))
print(" outputs:")
print(" - scalar: 0")
print(" - inputs:")
print(" - scalar: 1")
print(" - scalar: {0}".format(max_value))
print(" outputs:")
print(" - scalar: {0}".format(max_value))
print(" - inputs:")
print(" - scalar: {0}".format(max_value))
print(" - scalar: 1")
print(" outputs:")
print(" - scalar: {0}".format(max_value))
inp = gen_random_encodable()
print(" - inputs:")
print(" - scalar: {0}".format(inp[0]))
print(" - scalar: {0}".format(inp[1]))
print(" outputs:")
print(" - scalar: {0}".format(inp[0]*inp[1]))
print("---")
# signed
for p in range(MIN_PRECISON, MAX_PRECISION+1):
print("---")
@@ -877,6 +929,141 @@ def main():
may_check_error_rate()
print("---")
# mul_eint
if 2 <= p <= 15:
def gen_random_encodable(p):
while True:
a = random.randint(min_value, max_value)
b = random.randint(min_value, max_value)
if min_value <= a*b <= max_value:
if p == 3:
return a, b
if not (a in [-1, 1, 0] or b in [-1, 1, 0]):
return a, b
print("description: signed_mul_eint_{0}bits".format(p+1))
print("program: |")
print(
" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p+1))
print(
" %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p+1))
print(" return %1: !FHE.esint<{0}>".format(p+1))
print(" }")
print("tests:")
print(" - inputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - scalar: 0")
print(" signed: true")
print(" outputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - inputs:")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" - scalar: 0")
print(" signed: true")
print(" outputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - inputs:")
print(" - scalar: {0}".format(min_value))
print(" signed: true")
print(" - scalar: 0")
print(" signed: true")
print(" outputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - inputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" outputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - inputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - scalar: {0}".format(min_value))
print(" signed: true")
print(" outputs:")
print(" - scalar: 0")
print(" signed: true")
print(" - inputs:")
print(" - scalar: 1")
print(" signed: true")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" - inputs:")
print(" - scalar: 1")
print(" signed: true")
print(" - scalar: {0}".format(min_value))
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(min_value))
print(" signed: true")
print(" - inputs:")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" - scalar: 1")
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" - inputs:")
print(" - scalar: {0}".format(min_value))
print(" signed: true")
print(" - scalar: 1")
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(min_value))
print(" signed: true")
print(" - inputs:")
print(" - scalar: -1")
print(" signed: true")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(-max_value))
print(" signed: true")
print(" - inputs:")
print(" - scalar: -1")
print(" signed: true")
print(" - scalar: {0}".format(min_value+1))
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(-(min_value+1)))
print(" signed: true")
print(" - inputs:")
print(" - scalar: {0}".format(max_value))
print(" signed: true")
print(" - scalar: -1")
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(-max_value))
print(" signed: true")
print(" - inputs:")
print(" - scalar: {0}".format(min_value+1))
print(" signed: true")
print(" - scalar: -1")
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(-(min_value+1)))
print(" signed: true")
inp = gen_random_encodable(p+1)
print(" - inputs:")
print(" - scalar: {0}".format(inp[0]))
print(" signed: true")
print(" - scalar: {0}".format(inp[1]))
print(" signed: true")
print(" outputs:")
print(" - scalar: {0}".format(inp[0]*inp[1]))
print(" signed: true")
print("---")
if __name__ == "__main__":
main()