mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(concrete-compiler): implement a new round operator in the fhe dialect
This commit is contained in:
@@ -269,7 +269,7 @@ $(FIXTURE_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py
|
||||
$(FIXTURE_CPU_DIR)/bug_report.yaml:
|
||||
unzip -o $(FIXTURE_CPU_DIR)/bug_report.zip -d $(FIXTURE_CPU_DIR)
|
||||
|
||||
generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml
|
||||
generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml $(FIXTURE_CPU_DIR)/end_to_end_round.yaml
|
||||
|
||||
SECURITY_TO_TEST=80 128
|
||||
run-end-to-end-tests: build-end-to-end-tests generate-cpu-tests
|
||||
@@ -293,6 +293,7 @@ $(FIXTURE_GPU_DIR)/end_to_end_apply_lookup_table.yaml: tests/end_to_end_fixture/
|
||||
$(FIXTURE_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py
|
||||
$(Python3_EXECUTABLE) $< --bitwidth 1 2 3 4 5 6 7 > $@
|
||||
|
||||
|
||||
generate-gpu-tests: $(FIXTURE_GPU_DIR) $(FIXTURE_GPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml
|
||||
|
||||
run-end-to-end-tests-gpu: build-end-to-end-test generate-gpu-tests
|
||||
|
||||
@@ -399,6 +399,35 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> {
|
||||
|
||||
let summary = "Rounds a ciphertext to a smaller precision.";
|
||||
|
||||
let description = [{
|
||||
Assuming a ciphertext whose message is implemented over `p` bits, this
|
||||
operation rounds it to fit to `q` bits with `p>q`.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.round"(%a): (!FHE.eint<6>) -> (!FHE.eint<5>)
|
||||
"FHE.round"(%a): (!FHE.eint<5>) -> (!FHE.eint<3>)
|
||||
"FHE.round"(%a): (!FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.round"(%a): (!FHE.esint<3>) -> (!FHE.esint<2>)
|
||||
|
||||
// error
|
||||
"FHE.round"(%a): (!FHE.eint<6>) -> (!FHE.eint<6>)
|
||||
"FHE.round"(%a): (!FHE.eint<4>) -> (!FHE.eint<5>)
|
||||
"FHE.round"(%a): (!FHE.eint<4>) -> (!FHE.esint<5>)
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$input);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// FHE Boolean Operations
|
||||
|
||||
def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
|
||||
@@ -445,8 +474,6 @@ def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> {
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
|
||||
|
||||
def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies an AND gate to two encrypted boolean values";
|
||||
|
||||
@@ -382,6 +382,210 @@ private:
|
||||
concretelang::ScalarLoweringParameters loweringParameters;
|
||||
};
|
||||
|
||||
struct RoundEintOpPattern : public ScalarOpPattern<FHE::RoundEintOp> {
|
||||
RoundEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
|
||||
concretelang::ScalarLoweringParameters loweringParams,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::RoundEintOp>(converter, context, benefit),
|
||||
loweringParameters(loweringParams) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::RoundEintOp op, FHE::RoundEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// The round operator allows to move from a given precision to a smaller one
|
||||
// by rounding the most significant bits of the message. For example a 5
|
||||
// bits message:
|
||||
// 101_11 (23)
|
||||
// would be rounded to a 3 bit message:
|
||||
// 110 (6)
|
||||
//
|
||||
// The following procedure can be homomorphically applied to implement this
|
||||
// semantic:
|
||||
// 1) Propagate the carry of the round around 2^(n_before-n_after)
|
||||
// performed with a homomorphic adddition.
|
||||
// 2) For each bits to be discarded we truncate it:
|
||||
// -> Extract a ciphertext of only the bit to be discarded by
|
||||
// performing a left shift and a pbs.
|
||||
// -> Subtract this one from the input by performing a
|
||||
// homomorphic subtraction.
|
||||
|
||||
mlir::Value input = adaptor.input();
|
||||
auto inputType = op.input().getType().cast<FHE::FheIntegerInterface>();
|
||||
mlir::Value output = op.getResult();
|
||||
uint64_t inputBitwidth = inputType.getWidth();
|
||||
uint64_t outputBitwidth =
|
||||
output.getType().cast<FHE::FheIntegerInterface>().getWidth();
|
||||
uint64_t bitwidthDelta = inputBitwidth - outputBitwidth;
|
||||
|
||||
typing::TypeConverter converter;
|
||||
auto inputTy =
|
||||
converter.convertType(inputType).cast<TFHE::GLWECipherTextType>();
|
||||
|
||||
//-------------------------------------------------------- CARRY PROPAGATION
|
||||
// The first step we take is to propagate the carry of the round in the
|
||||
// msbs. This we perform with an addition of cleartext correctly encoded.
|
||||
// Say we have a 5 bits message that we want to round for 3 bits, we
|
||||
// perform the following addition:
|
||||
//
|
||||
// input = |0101|11| .... |
|
||||
// carryCst = |0000|10| .... |
|
||||
// input + carryCst = |0110|01| .... |
|
||||
|
||||
uint64_t rawCarryCst = ((uint64_t)1) << (bitwidthDelta - 1);
|
||||
mlir::Value carryCst = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(bitwidthDelta + 1),
|
||||
rawCarryCst));
|
||||
mlir::Value encodedCarryCst = writePlaintextShiftEncoding(
|
||||
op.getLoc(), carryCst, inputBitwidth, rewriter);
|
||||
mlir::Value carryPropagatedVal = rewriter.create<TFHE::AddGLWEIntOp>(
|
||||
op.getLoc(), inputTy, input, encodedCarryCst);
|
||||
|
||||
//--------------------------------------------------------------- TRUNCATION
|
||||
// The second step is to truncate every lsbs to be removed, from the least
|
||||
// significant one to the most significant one. For example:
|
||||
//
|
||||
// previousOutput = |0110|01| .... | (t_0)
|
||||
// previousOutput = |0110|00| .... | (t_1)
|
||||
// ^
|
||||
// previousOutput = |0110|00| .... | (t_2)
|
||||
// ^
|
||||
//
|
||||
// For this, we have to generate a ciphertext that contains only the bit to
|
||||
// be truncated:
|
||||
//
|
||||
// bitToRemove = |0000|01| .... | (t_1)
|
||||
// ^
|
||||
// bitToRemove = |0000|00| .... | (t_1)
|
||||
// ^
|
||||
|
||||
mlir::Value previousOutput = carryPropagatedVal;
|
||||
TFHE::GLWECipherTextType truncationInputTy = inputTy;
|
||||
for (uint64_t i = 0; i < bitwidthDelta; ++i) {
|
||||
//---------------------------------------------------------- BIT ISOLATION
|
||||
// To extract the bit to truncate, we use a PBS that look up on the
|
||||
// padding bit. We first begin by isolating the bit in question on the
|
||||
// padding bit. This is performed with a homomorphic multiplication (left
|
||||
// shift basically) of the proper amount. For example:
|
||||
//
|
||||
// previousOutput = |0110|01| .... |
|
||||
// ^
|
||||
// shiftCst = | 100000|
|
||||
// previousOutput * shiftCst = |1| .... |
|
||||
|
||||
uint64_t rawShiftCst = ((uint64_t)1) << (inputBitwidth - i);
|
||||
mlir::Value shiftCst = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(rawShiftCst));
|
||||
mlir::Value shiftedInput = rewriter.create<TFHE::MulGLWEIntOp>(
|
||||
op.getLoc(), truncationInputTy, previousOutput, shiftCst);
|
||||
|
||||
//-------------------------------------------------------- LUT PREPARATION
|
||||
// To perform the right shift (kind of), we use a PBS that acts on the
|
||||
// padding bit. We expect is the following function to be applied (for the
|
||||
// first round of our example):
|
||||
//
|
||||
// f(|0| .... |) = |0000|00| .... |
|
||||
// f(|1| .... |) = |0000|01| .... |
|
||||
//
|
||||
// That being said, a PBS on the padding bit can only encode a symmetric
|
||||
// function (that is f(1) = -f(0)), by encoding f(0) in the whole table.
|
||||
// To implement our semantic, we then rely on a trick. We encode the
|
||||
// following function in the bootstrap:
|
||||
//
|
||||
// f(|0| .... |) = |1111|11|1 .... |
|
||||
// f(|1| .... |) = |0000|00|1 .... |
|
||||
//
|
||||
// And add a correction constant:
|
||||
//
|
||||
// corrCst = |0000|00|1 .... |
|
||||
// f(|0| .... |) + corrCst = |0000|00| .... |
|
||||
// f(|1| .... |) + corrCst = |0000|01| .... |
|
||||
//
|
||||
// Hence the following constant lut.
|
||||
|
||||
llvm::SmallVector<int64_t> rawLut(loweringParameters.polynomialSize,
|
||||
((uint64_t)0 - 1)
|
||||
<< (64 - (inputBitwidth + 2 - i)));
|
||||
mlir::Value lut = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(
|
||||
rawLut.size(), rewriter.getIntegerType(64)),
|
||||
rawLut));
|
||||
|
||||
//-------------------------------------------------- CIPHERTEXT ALIGNEMENT
|
||||
// In practice, TFHE ciphertexts are normally distributed around a value.
|
||||
// That means that if the lookup is performed _as is_, we have almost .5
|
||||
// probability to return the wrong value. Imagine a ciphertext centered
|
||||
// around (|0| .... |):
|
||||
//
|
||||
// | 0000001... | 1111111... | Virtual lookup table
|
||||
// _
|
||||
// / \
|
||||
// _______________/ \_________________________ Ciphertext distribution
|
||||
//
|
||||
// |0| ... | Ciphertexts mean
|
||||
//
|
||||
// If the error of the ciphertext is negative, this means that the lookup
|
||||
// will wrap, and fall on the wrong mega-case...
|
||||
//
|
||||
// This is usually taken care of on the lookup table side, but we can also
|
||||
// slightly shift the ciphertext to center its distribution with the
|
||||
// center of the mega-case. That is, end up with a situation like this:
|
||||
|
||||
//
|
||||
// | 1111111... | 0000001... | Virtual lookup table
|
||||
// _
|
||||
// / \
|
||||
// ______/ \_________________________ Ciphertext distribution
|
||||
//
|
||||
// |0| ... | Ciphertexts mean
|
||||
//
|
||||
// This is performed by adding |0|1 .... | to the ciphertext.
|
||||
|
||||
uint64_t rawRotationCst = (((uint64_t)1) << 62);
|
||||
mlir::Value rotationCst = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(rawRotationCst));
|
||||
mlir::Value shiftedRotatedInput = rewriter.create<TFHE::AddGLWEIntOp>(
|
||||
op.getLoc(), truncationInputTy, shiftedInput, rotationCst);
|
||||
|
||||
//-------------------------------------------------------------------- PBS
|
||||
// The lookup is performed ...
|
||||
|
||||
mlir::Value keyswitched = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
op.getLoc(), truncationInputTy, shiftedRotatedInput, -1, -1);
|
||||
mlir::Value bootstrapped = rewriter.create<TFHE::BootstrapGLWEOp>(
|
||||
op.getLoc(), truncationInputTy, keyswitched, lut, -1, -1, -1, -1);
|
||||
|
||||
//------------------------------------------------------------- CORRECTION
|
||||
// The correction is performed to achieve our right shift semantic.
|
||||
|
||||
uint64_t rawCorrCst = ((uint64_t)1) << (64 - (inputBitwidth + 2 - i));
|
||||
mlir::Value corrCst = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(rawCorrCst));
|
||||
mlir::Value extractedBit = rewriter.create<TFHE::AddGLWEIntOp>(
|
||||
op.getLoc(), truncationInputTy, bootstrapped, corrCst);
|
||||
|
||||
//------------------------------------------------------------- TRUNCATION
|
||||
// Finally, the extracted bit is subtracted from the input.
|
||||
|
||||
mlir::Value minusIsolatedBit = rewriter.create<TFHE::NegGLWEOp>(
|
||||
op.getLoc(), truncationInputTy, extractedBit);
|
||||
truncationInputTy = TFHE::GLWECipherTextType::get(
|
||||
rewriter.getContext(), -1, -1, -1, truncationInputTy.getP() - 1);
|
||||
mlir::Value truncationOutput = rewriter.create<TFHE::AddGLWEOp>(
|
||||
op.getLoc(), truncationInputTy, previousOutput, minusIsolatedBit);
|
||||
previousOutput = truncationOutput;
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {previousOutput});
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
concretelang::ScalarLoweringParameters loweringParameters;
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::to_bool` operation.
|
||||
struct ToBoolOpPattern : public mlir::OpRewritePattern<FHE::ToBoolOp> {
|
||||
ToBoolOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
@@ -517,8 +721,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
// |_ `FHE::to_unsigned`
|
||||
lowering::ToUnsignedOpPattern>(converter, &getContext());
|
||||
// |_ `FHE::apply_lookup_table`
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern>(
|
||||
converter, &getContext(), loweringParameters);
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern,
|
||||
// |_ `FHE::round`
|
||||
lowering::RoundEintOpPattern>(converter, &getContext(),
|
||||
loweringParameters);
|
||||
|
||||
// Patterns for boolean conversion ops
|
||||
patterns.add<lowering::FromBoolOpPattern, lowering::ToBoolOpPattern>(
|
||||
|
||||
@@ -132,6 +132,10 @@ struct FunctionToDag {
|
||||
addLut(dag, val, encrypted_inputs, precision);
|
||||
return;
|
||||
}
|
||||
if (isRound(op)) {
|
||||
addRound(dag, val, encrypted_inputs, precision);
|
||||
return;
|
||||
}
|
||||
if (auto dot = asDot(op)) {
|
||||
auto weightsOpt = dotWeights(dot);
|
||||
if (weightsOpt) {
|
||||
@@ -156,6 +160,15 @@ struct FunctionToDag {
|
||||
dag->add_lut(encrypted_input, slice(unknowFunction), precision);
|
||||
}
|
||||
|
||||
void addRound(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs,
|
||||
int rounded_precision) {
|
||||
assert(encrypted_inputs.size() == 1);
|
||||
// No need to distinguish different lut kind until we do approximate
|
||||
// paradigm on outputs
|
||||
auto encrypted_input = encrypted_inputs[0];
|
||||
index[val] = dag->add_round_op(encrypted_input, rounded_precision);
|
||||
}
|
||||
|
||||
void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs,
|
||||
std::vector<std::int64_t> &weights_vector) {
|
||||
assert(encrypted_inputs.size() == 1);
|
||||
@@ -216,6 +229,10 @@ struct FunctionToDag {
|
||||
mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op);
|
||||
}
|
||||
|
||||
bool isRound(mlir::Operation &op) {
|
||||
return llvm::isa<mlir::concretelang::FHE::RoundEintOp>(op);
|
||||
}
|
||||
|
||||
mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
|
||||
}
|
||||
|
||||
@@ -487,6 +487,29 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUMul(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.round` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::RoundEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
uint64_t inputWidth =
|
||||
op.getOperand().getType().cast<FHE::FheIntegerInterface>().getWidth();
|
||||
uint64_t outputWidth =
|
||||
op.getResult().getType().cast<FHE::FheIntegerInterface>().getWidth();
|
||||
uint64_t clearedBits = inputWidth - outputWidth;
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
eNorm += clearedBits;
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
/// `FHELinalg.add_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -1176,6 +1199,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto mulEintIntOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
} else if (auto roundOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::RoundEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(roundOp, operands);
|
||||
} else if (llvm::isa<mlir::concretelang::FHE::ZeroEintOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::ToBoolOp>(op) ||
|
||||
llvm::isa<mlir::concretelang::FHE::FromBoolOp>(op) ||
|
||||
|
||||
@@ -262,6 +262,25 @@ mlir::LogicalResult GenGateOp::verify() {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult RoundEintOp::verify() {
|
||||
auto input = this->input().getType().cast<FheIntegerInterface>();
|
||||
auto output = this->getResult().getType().cast<FheIntegerInterface>();
|
||||
|
||||
if (input.getWidth() <= output.getWidth()) {
|
||||
this->emitOpError(
|
||||
"should have the input width larger than the output width.");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (input.isSigned() != output.isSigned()) {
|
||||
this->emitOpError(
|
||||
"should have the signedness of encrypted inputs and result equal");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Avoid addition with constant 0
|
||||
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
|
||||
@@ -105,10 +105,6 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "bits");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getP() != b.getP() || a.getP() != result.getP()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -303,3 +303,12 @@ func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<3>
|
||||
func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<3> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<3>
|
||||
// CHECK-NEXT: return %[[V1]] : !FHE.eint<3>
|
||||
|
||||
%1 = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<3>
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
|
||||
23
compiler/tests/check_tests/Dialect/FHE/round.invalid.mlir
Normal file
23
compiler/tests/check_tests/Dialect/FHE/round.invalid.mlir
Normal file
@@ -0,0 +1,23 @@
|
||||
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.round' op should have the input width larger than the output width.
|
||||
func.func @equal_width(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.round"(%arg0): (!FHE.eint<3>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.round' op should have the input width larger than the output width.
|
||||
func.func @larger_output_width(%arg0: !FHE.eint<3>) -> !FHE.eint<4> {
|
||||
%1 = "FHE.round"(%arg0): (!FHE.eint<3>) -> (!FHE.eint<4>)
|
||||
return %1: !FHE.eint<4>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: error: 'FHE.round' op should have the signedness of encrypted inputs and result equal
|
||||
func.func @signed_input(%arg0: !FHE.esint<3>) -> !FHE.eint<2> {
|
||||
%1 = "FHE.round"(%arg0): (!FHE.esint<3>) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
@@ -1,23 +1,5 @@
|
||||
// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter result
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{6}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'p' parameter}}
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{1024,12,64}{7}>, !TFHE.glwe<{1024,12,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{6}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{6}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE p parameter inputs
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,12,64}{6}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'p' parameter}}
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{1024,12,64}{7}>, !TFHE.glwe<{1024,12,64}{6}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE dimension parameter result
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{512,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'dimension' parameter}}
|
||||
|
||||
76
compiler/tests/end_to_end_fixture/end_to_end_round_gen.py
Normal file
76
compiler/tests/end_to_end_fixture/end_to_end_round_gen.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import argparse
|
||||
from platform import mac_ver
|
||||
|
||||
import numpy as np
|
||||
|
||||
from end_to_end_linalg_leveled_gen import P_ERROR
|
||||
|
||||
|
||||
def round(val, p_start, p_end, signed=False):
|
||||
p_delta = p_start - p_end
|
||||
carry_mask = 1 << (p_delta - 1)
|
||||
if val & carry_mask != 0:
|
||||
val += carry_mask << 1
|
||||
output = val >> p_delta
|
||||
if signed:
|
||||
if output >= (1 << (p_end - 1)):
|
||||
output = -output
|
||||
return output
|
||||
|
||||
|
||||
def generate(args):
|
||||
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
|
||||
print("# /!\ THIS FILE HAS BEEN GENERATED")
|
||||
np.random.seed(0)
|
||||
# unsigned_unsigned
|
||||
for from_p in args.bitwidth:
|
||||
for to_p in range(2, from_p):
|
||||
max_value = (2 ** from_p) - 1
|
||||
print(f"description: unsigned_round_{from_p}to{to_p}bits")
|
||||
print("program: |")
|
||||
print(f" func.func @main(%arg0: !FHE.eint<{from_p}>) -> !FHE.eint<{to_p}> {{")
|
||||
print(f" %1 = \"FHE.round\"(%arg0) : (!FHE.eint<{from_p}>) -> !FHE.eint<{to_p}>")
|
||||
print(f" return %1: !FHE.eint<{to_p}>")
|
||||
print(" }")
|
||||
print(f"p-error: {P_ERROR}")
|
||||
print("tests:")
|
||||
for i in range(8):
|
||||
val = np.random.randint(max_value)
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {val}")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {round(val, from_p, to_p)}")
|
||||
print("---")
|
||||
# signed_signed
|
||||
for from_p in args.bitwidth:
|
||||
for to_p in range(2, from_p):
|
||||
min_value = -(2 ** (from_p - 1))
|
||||
max_value = abs(min_value) - 1
|
||||
print(f"description: signed_round_from_{from_p}to{to_p}bits")
|
||||
print("program: |")
|
||||
print(f" func.func @main(%arg0: !FHE.esint<{from_p}>) -> !FHE.esint<{to_p}> {{")
|
||||
print(f" %1 = \"FHE.round\"(%arg0) : (!FHE.esint<{from_p}>) -> !FHE.esint<{to_p}>")
|
||||
print(f" return %1: !FHE.esint<{to_p}>")
|
||||
print(" }")
|
||||
print(f"p-error: {P_ERROR}")
|
||||
print("tests:")
|
||||
for i in range(8):
|
||||
val = np.random.randint(min_value, max_value)
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {val}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {round(val, from_p, to_p, True)}")
|
||||
print(f" signed: true")
|
||||
print("---")
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI = argparse.ArgumentParser()
|
||||
CLI.add_argument(
|
||||
"--bitwidth",
|
||||
help="Specify the list of bitwidth to generate",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=list(range(3,9)),
|
||||
)
|
||||
generate(CLI.parse_args())
|
||||
Reference in New Issue
Block a user