diff --git a/compiler/Makefile b/compiler/Makefile index e2cd82d88..dc925a142 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -269,7 +269,7 @@ $(FIXTURE_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py $(FIXTURE_CPU_DIR)/bug_report.yaml: unzip -o $(FIXTURE_CPU_DIR)/bug_report.zip -d $(FIXTURE_CPU_DIR) -generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml +generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml $(FIXTURE_CPU_DIR)/end_to_end_round.yaml SECURITY_TO_TEST=80 128 run-end-to-end-tests: build-end-to-end-tests generate-cpu-tests @@ -293,6 +293,7 @@ $(FIXTURE_GPU_DIR)/end_to_end_apply_lookup_table.yaml: tests/end_to_end_fixture/ $(FIXTURE_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py $(Python3_EXECUTABLE) $< --bitwidth 1 2 3 4 5 6 7 > $@ + generate-gpu-tests: $(FIXTURE_GPU_DIR) $(FIXTURE_GPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml run-end-to-end-tests-gpu: build-end-to-end-test generate-gpu-tests diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index cf6d195e4..a0245f049 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -399,6 +399,35 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> { let hasVerifier = 1; } +def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> { + + let summary = "Rounds a ciphertext to a smaller precision."; + + let description = [{ + Assuming a ciphertext whose message is implemented over `p` bits, this + operation rounds it to fit to `q` bits with `p>q`. + + Example: + ```mlir + // ok + "FHE.round"(%a): (!FHE.eint<6>) -> (!FHE.eint<5>) + "FHE.round"(%a): (!FHE.eint<5>) -> (!FHE.eint<3>) + "FHE.round"(%a): (!FHE.eint<3>) -> (!FHE.eint<2>) + "FHE.round"(%a): (!FHE.esint<3>) -> (!FHE.esint<2>) + + // error + "FHE.round"(%a): (!FHE.eint<6>) -> (!FHE.eint<6>) + "FHE.round"(%a): (!FHE.eint<4>) -> (!FHE.eint<5>) + "FHE.round"(%a): (!FHE.eint<4>) -> (!FHE.esint<5>) + + ``` + }]; + + let arguments = (ins FHE_AnyEncryptedInteger:$input); + let results = (outs FHE_AnyEncryptedInteger); + let hasVerifier = 1; +} + // FHE Boolean Operations def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> { @@ -445,8 +474,6 @@ def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> { let results = (outs FHE_EncryptedBooleanType); } - - def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> { let summary = "Applies an AND gate to two encrypted boolean values"; diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 801a92b33..aa59b36d4 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -382,6 +382,210 @@ private: concretelang::ScalarLoweringParameters loweringParameters; }; +struct RoundEintOpPattern : public ScalarOpPattern { + RoundEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + concretelang::ScalarLoweringParameters loweringParams, + mlir::PatternBenefit benefit = 1) + : ScalarOpPattern(converter, context, benefit), + loweringParameters(loweringParams) {} + + ::mlir::LogicalResult + matchAndRewrite(FHE::RoundEintOp op, FHE::RoundEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // The round operator allows to move from a given precision to a smaller one + // by rounding the most significant bits of the message. For example a 5 + // bits message: + // 101_11 (23) + // would be rounded to a 3 bit message: + // 110 (6) + // + // The following procedure can be homomorphically applied to implement this + // semantic: + // 1) Propagate the carry of the round around 2^(n_before-n_after) + // performed with a homomorphic adddition. + // 2) For each bits to be discarded we truncate it: + // -> Extract a ciphertext of only the bit to be discarded by + // performing a left shift and a pbs. + // -> Subtract this one from the input by performing a + // homomorphic subtraction. + + mlir::Value input = adaptor.input(); + auto inputType = op.input().getType().cast(); + mlir::Value output = op.getResult(); + uint64_t inputBitwidth = inputType.getWidth(); + uint64_t outputBitwidth = + output.getType().cast().getWidth(); + uint64_t bitwidthDelta = inputBitwidth - outputBitwidth; + + typing::TypeConverter converter; + auto inputTy = + converter.convertType(inputType).cast(); + + //-------------------------------------------------------- CARRY PROPAGATION + // The first step we take is to propagate the carry of the round in the + // msbs. This we perform with an addition of cleartext correctly encoded. + // Say we have a 5 bits message that we want to round for 3 bits, we + // perform the following addition: + // + // input = |0101|11| .... | + // carryCst = |0000|10| .... | + // input + carryCst = |0110|01| .... | + + uint64_t rawCarryCst = ((uint64_t)1) << (bitwidthDelta - 1); + mlir::Value carryCst = rewriter.create( + op->getLoc(), + rewriter.getIntegerAttr(rewriter.getIntegerType(bitwidthDelta + 1), + rawCarryCst)); + mlir::Value encodedCarryCst = writePlaintextShiftEncoding( + op.getLoc(), carryCst, inputBitwidth, rewriter); + mlir::Value carryPropagatedVal = rewriter.create( + op.getLoc(), inputTy, input, encodedCarryCst); + + //--------------------------------------------------------------- TRUNCATION + // The second step is to truncate every lsbs to be removed, from the least + // significant one to the most significant one. For example: + // + // previousOutput = |0110|01| .... | (t_0) + // previousOutput = |0110|00| .... | (t_1) + // ^ + // previousOutput = |0110|00| .... | (t_2) + // ^ + // + // For this, we have to generate a ciphertext that contains only the bit to + // be truncated: + // + // bitToRemove = |0000|01| .... | (t_1) + // ^ + // bitToRemove = |0000|00| .... | (t_1) + // ^ + + mlir::Value previousOutput = carryPropagatedVal; + TFHE::GLWECipherTextType truncationInputTy = inputTy; + for (uint64_t i = 0; i < bitwidthDelta; ++i) { + //---------------------------------------------------------- BIT ISOLATION + // To extract the bit to truncate, we use a PBS that look up on the + // padding bit. We first begin by isolating the bit in question on the + // padding bit. This is performed with a homomorphic multiplication (left + // shift basically) of the proper amount. For example: + // + // previousOutput = |0110|01| .... | + // ^ + // shiftCst = | 100000| + // previousOutput * shiftCst = |1| .... | + + uint64_t rawShiftCst = ((uint64_t)1) << (inputBitwidth - i); + mlir::Value shiftCst = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(rawShiftCst)); + mlir::Value shiftedInput = rewriter.create( + op.getLoc(), truncationInputTy, previousOutput, shiftCst); + + //-------------------------------------------------------- LUT PREPARATION + // To perform the right shift (kind of), we use a PBS that acts on the + // padding bit. We expect is the following function to be applied (for the + // first round of our example): + // + // f(|0| .... |) = |0000|00| .... | + // f(|1| .... |) = |0000|01| .... | + // + // That being said, a PBS on the padding bit can only encode a symmetric + // function (that is f(1) = -f(0)), by encoding f(0) in the whole table. + // To implement our semantic, we then rely on a trick. We encode the + // following function in the bootstrap: + // + // f(|0| .... |) = |1111|11|1 .... | + // f(|1| .... |) = |0000|00|1 .... | + // + // And add a correction constant: + // + // corrCst = |0000|00|1 .... | + // f(|0| .... |) + corrCst = |0000|00| .... | + // f(|1| .... |) + corrCst = |0000|01| .... | + // + // Hence the following constant lut. + + llvm::SmallVector rawLut(loweringParameters.polynomialSize, + ((uint64_t)0 - 1) + << (64 - (inputBitwidth + 2 - i))); + mlir::Value lut = rewriter.create( + op.getLoc(), mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + rawLut.size(), rewriter.getIntegerType(64)), + rawLut)); + + //-------------------------------------------------- CIPHERTEXT ALIGNEMENT + // In practice, TFHE ciphertexts are normally distributed around a value. + // That means that if the lookup is performed _as is_, we have almost .5 + // probability to return the wrong value. Imagine a ciphertext centered + // around (|0| .... |): + // + // | 0000001... | 1111111... | Virtual lookup table + // _ + // / \ + // _______________/ \_________________________ Ciphertext distribution + // + // |0| ... | Ciphertexts mean + // + // If the error of the ciphertext is negative, this means that the lookup + // will wrap, and fall on the wrong mega-case... + // + // This is usually taken care of on the lookup table side, but we can also + // slightly shift the ciphertext to center its distribution with the + // center of the mega-case. That is, end up with a situation like this: + + // + // | 1111111... | 0000001... | Virtual lookup table + // _ + // / \ + // ______/ \_________________________ Ciphertext distribution + // + // |0| ... | Ciphertexts mean + // + // This is performed by adding |0|1 .... | to the ciphertext. + + uint64_t rawRotationCst = (((uint64_t)1) << 62); + mlir::Value rotationCst = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(rawRotationCst)); + mlir::Value shiftedRotatedInput = rewriter.create( + op.getLoc(), truncationInputTy, shiftedInput, rotationCst); + + //-------------------------------------------------------------------- PBS + // The lookup is performed ... + + mlir::Value keyswitched = rewriter.create( + op.getLoc(), truncationInputTy, shiftedRotatedInput, -1, -1); + mlir::Value bootstrapped = rewriter.create( + op.getLoc(), truncationInputTy, keyswitched, lut, -1, -1, -1, -1); + + //------------------------------------------------------------- CORRECTION + // The correction is performed to achieve our right shift semantic. + + uint64_t rawCorrCst = ((uint64_t)1) << (64 - (inputBitwidth + 2 - i)); + mlir::Value corrCst = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(rawCorrCst)); + mlir::Value extractedBit = rewriter.create( + op.getLoc(), truncationInputTy, bootstrapped, corrCst); + + //------------------------------------------------------------- TRUNCATION + // Finally, the extracted bit is subtracted from the input. + + mlir::Value minusIsolatedBit = rewriter.create( + op.getLoc(), truncationInputTy, extractedBit); + truncationInputTy = TFHE::GLWECipherTextType::get( + rewriter.getContext(), -1, -1, -1, truncationInputTy.getP() - 1); + mlir::Value truncationOutput = rewriter.create( + op.getLoc(), truncationInputTy, previousOutput, minusIsolatedBit); + previousOutput = truncationOutput; + } + + rewriter.replaceOp(op, {previousOutput}); + + return mlir::success(); + }; + +private: + concretelang::ScalarLoweringParameters loweringParameters; +}; + /// Rewriter for the `FHE::to_bool` operation. struct ToBoolOpPattern : public mlir::OpRewritePattern { ToBoolOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) @@ -517,8 +721,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { // |_ `FHE::to_unsigned` lowering::ToUnsignedOpPattern>(converter, &getContext()); // |_ `FHE::apply_lookup_table` - patterns.add( - converter, &getContext(), loweringParameters); + patterns.add(converter, &getContext(), + loweringParameters); // Patterns for boolean conversion ops patterns.add( diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 69110193c..c64c88c79 100644 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -132,6 +132,10 @@ struct FunctionToDag { addLut(dag, val, encrypted_inputs, precision); return; } + if (isRound(op)) { + addRound(dag, val, encrypted_inputs, precision); + return; + } if (auto dot = asDot(op)) { auto weightsOpt = dotWeights(dot); if (weightsOpt) { @@ -156,6 +160,15 @@ struct FunctionToDag { dag->add_lut(encrypted_input, slice(unknowFunction), precision); } + void addRound(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, + int rounded_precision) { + assert(encrypted_inputs.size() == 1); + // No need to distinguish different lut kind until we do approximate + // paradigm on outputs + auto encrypted_input = encrypted_inputs[0]; + index[val] = dag->add_round_op(encrypted_input, rounded_precision); + } + void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, std::vector &weights_vector) { assert(encrypted_inputs.size() == 1); @@ -216,6 +229,10 @@ struct FunctionToDag { mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op); } + bool isRound(mlir::Operation &op) { + return llvm::isa(op); + } + mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) { return llvm::dyn_cast(op); } diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 0c489367a..2bda48f8c 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -487,6 +487,29 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUMul(sqNorm, eNorm); } +/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +/// that is equivalent to an `FHE.round` operation. +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::RoundEintOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() == 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + uint64_t inputWidth = + op.getOperand().getType().cast().getWidth(); + uint64_t outputWidth = + op.getResult().getType().cast().getWidth(); + uint64_t clearedBits = inputWidth - outputWidth; + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + eNorm += clearedBits; + + return eNorm; +} + /// Calculates the squared Minimal Arithmetic Noise Padding of an /// `FHELinalg.add_eint_int` operation. static llvm::APInt getSqMANP( @@ -1176,6 +1199,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto mulEintIntOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (auto roundOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(roundOp, operands); } else if (llvm::isa(op) || llvm::isa(op) || llvm::isa(op) || diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index bee590615..0e043c2bb 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -262,6 +262,25 @@ mlir::LogicalResult GenGateOp::verify() { return mlir::success(); } +mlir::LogicalResult RoundEintOp::verify() { + auto input = this->input().getType().cast(); + auto output = this->getResult().getType().cast(); + + if (input.getWidth() <= output.getWidth()) { + this->emitOpError( + "should have the input width larger than the output width."); + return mlir::failure(); + } + + if (input.isSigned() != output.isSigned()) { + this->emitOpError( + "should have the signedness of encrypted inputs and result equal"); + return mlir::failure(); + } + + return mlir::success(); +} + /// Avoid addition with constant 0 OpFoldResult AddEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); diff --git a/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp b/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp index 29b858630..b4775cff0 100644 --- a/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp +++ b/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp @@ -105,10 +105,6 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) { emitOpErrorForIncompatibleGLWEParameter(op, "bits"); return mlir::failure(); } - if (a.getP() != b.getP() || a.getP() != result.getP()) { - emitOpErrorForIncompatibleGLWEParameter(op, "p"); - return mlir::failure(); - } return mlir::success(); } diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.mlir index 418f40e51..96cb3e0f3 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.mlir @@ -303,3 +303,12 @@ func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool { %1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool return %1: !FHE.ebool } + +// CHECK-LABEL: func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<3> +func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<3> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<3> + // CHECK-NEXT: return %[[V1]] : !FHE.eint<3> + + %1 = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<3> + return %1: !FHE.eint<3> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/round.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/round.invalid.mlir new file mode 100644 index 000000000..bc6a8ba0c --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/round.invalid.mlir @@ -0,0 +1,23 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.round' op should have the input width larger than the output width. +func.func @equal_width(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { + %1 = "FHE.round"(%arg0): (!FHE.eint<3>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.round' op should have the input width larger than the output width. +func.func @larger_output_width(%arg0: !FHE.eint<3>) -> !FHE.eint<4> { + %1 = "FHE.round"(%arg0): (!FHE.eint<3>) -> (!FHE.eint<4>) + return %1: !FHE.eint<4> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.round' op should have the signedness of encrypted inputs and result equal +func.func @signed_input(%arg0: !FHE.esint<3>) -> !FHE.eint<2> { + %1 = "FHE.round"(%arg0): (!FHE.esint<3>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} diff --git a/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir b/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir index d55bba524..ed5445627 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir @@ -1,23 +1,5 @@ // RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s -// GLWE p parameter result -func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{6}> { - // expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'p' parameter}} - %1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{1024,12,64}{7}>, !TFHE.glwe<{1024,12,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{6}>) - return %1: !TFHE.glwe<{1024,12,64}{6}> -} - -// ----- - -// GLWE p parameter inputs -func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,12,64}{6}>) -> !TFHE.glwe<{1024,12,64}{7}> { - // expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'p' parameter}} - %1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{1024,12,64}{7}>, !TFHE.glwe<{1024,12,64}{6}>) -> (!TFHE.glwe<{1024,12,64}{7}>) - return %1: !TFHE.glwe<{1024,12,64}{7}> -} - -// ----- - // GLWE dimension parameter result func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{512,12,64}{7}> { // expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'dimension' parameter}} diff --git a/compiler/tests/end_to_end_fixture/end_to_end_round_gen.py b/compiler/tests/end_to_end_fixture/end_to_end_round_gen.py new file mode 100644 index 000000000..74984ec7d --- /dev/null +++ b/compiler/tests/end_to_end_fixture/end_to_end_round_gen.py @@ -0,0 +1,76 @@ +import argparse +from platform import mac_ver + +import numpy as np + +from end_to_end_linalg_leveled_gen import P_ERROR + + +def round(val, p_start, p_end, signed=False): + p_delta = p_start - p_end + carry_mask = 1 << (p_delta - 1) + if val & carry_mask != 0: + val += carry_mask << 1 + output = val >> p_delta + if signed: + if output >= (1 << (p_end - 1)): + output = -output + return output + + +def generate(args): + print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY") + print("# /!\ THIS FILE HAS BEEN GENERATED") + np.random.seed(0) + # unsigned_unsigned + for from_p in args.bitwidth: + for to_p in range(2, from_p): + max_value = (2 ** from_p) - 1 + print(f"description: unsigned_round_{from_p}to{to_p}bits") + print("program: |") + print(f" func.func @main(%arg0: !FHE.eint<{from_p}>) -> !FHE.eint<{to_p}> {{") + print(f" %1 = \"FHE.round\"(%arg0) : (!FHE.eint<{from_p}>) -> !FHE.eint<{to_p}>") + print(f" return %1: !FHE.eint<{to_p}>") + print(" }") + print(f"p-error: {P_ERROR}") + print("tests:") + for i in range(8): + val = np.random.randint(max_value) + print(" - inputs:") + print(f" - scalar: {val}") + print(" outputs:") + print(f" - scalar: {round(val, from_p, to_p)}") + print("---") + # signed_signed + for from_p in args.bitwidth: + for to_p in range(2, from_p): + min_value = -(2 ** (from_p - 1)) + max_value = abs(min_value) - 1 + print(f"description: signed_round_from_{from_p}to{to_p}bits") + print("program: |") + print(f" func.func @main(%arg0: !FHE.esint<{from_p}>) -> !FHE.esint<{to_p}> {{") + print(f" %1 = \"FHE.round\"(%arg0) : (!FHE.esint<{from_p}>) -> !FHE.esint<{to_p}>") + print(f" return %1: !FHE.esint<{to_p}>") + print(" }") + print(f"p-error: {P_ERROR}") + print("tests:") + for i in range(8): + val = np.random.randint(min_value, max_value) + print(" - inputs:") + print(f" - scalar: {val}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {round(val, from_p, to_p, True)}") + print(f" signed: true") + print("---") + +if __name__ == "__main__": + CLI = argparse.ArgumentParser() + CLI.add_argument( + "--bitwidth", + help="Specify the list of bitwidth to generate", + nargs="+", + type=int, + default=list(range(3,9)), + ) + generate(CLI.parse_args())