diff --git a/concrete-security-curves-cpp/gen_header.py b/concrete-security-curves-cpp/gen_header.py index bfd86e924..38b0a5b73 100644 --- a/concrete-security-curves-cpp/gen_header.py +++ b/concrete-security-curves-cpp/gen_header.py @@ -1,6 +1,4 @@ -import json -import sys - +import sys, json; def print_curve(data): print(f'\tSecurityCurve({data["security_level"]},{data["slope"]}, {data["bias"]}, {data["minimal_lwe_dimension"]}, KeyFormat::BINARY),') @@ -11,4 +9,4 @@ def print_cpp_curves_declaration(datas): print_curve(data) print("}\n") -print_cpp_curves_declaration(json.load(sys.stdin)) +print_cpp_curves_declaration(json.load(sys.stdin)) \ No newline at end of file diff --git a/concrete-security-curves-rust/gen_table.py b/concrete-security-curves-rust/gen_table.py index 07c9817d9..4646cbbc5 100644 --- a/concrete-security-curves-rust/gen_table.py +++ b/concrete-security-curves-rust/gen_table.py @@ -1,6 +1,4 @@ -import json -import sys - +import sys, json; def print_curve(data): print(f' ({data["security_level"]}, SecurityWeights {{ slope: {data["slope"]}, bias: {data["bias"]}, minimal_lwe_dimension: {data["minimal_lwe_dimension"]} }}),') @@ -12,4 +10,4 @@ def print_rust_curves_declaration(datas): print_curve(data) print("];") -print_rust_curves_declaration(json.load(sys.stdin)) +print_rust_curves_declaration(json.load(sys.stdin)) \ No newline at end of file diff --git a/lattice-scripts/compare_curves_and_estimator.py b/lattice-scripts/compare_curves_and_estimator.py index 05eb2b3dc..598d74335 100644 --- a/lattice-scripts/compare_curves_and_estimator.py +++ b/lattice-scripts/compare_curves_and_estimator.py @@ -1,12 +1,10 @@ +from estimator import LWE, ND +from sage.all import oo, load, floor +from generate_data import estimate, get_security_level import argparse import os import sys - -from estimator import LWE, ND -from generate_data import estimate, get_security_level -from sage.all import floor, load, oo - -sys.path.insert(1, "lattice-estimator") +sys.path.insert(1, 'lattice-estimator') LOG_N_MAX = 17 + 1 @@ -23,7 +21,6 @@ def get_index(sec, curves): for i in range(len(curves)): if curves[i][2] == sec: return i - return None def estimate_security_with_lattice_estimator(lwe_dimension, std_dev, log_q): @@ -48,8 +45,9 @@ def get_minimal_lwe_dimension(curve, security_level, log_q): :param security_level: :param log_q: :return: - """ - return curve[-1] + """ + minimal_lwe_dim = curve[-1] + return minimal_lwe_dim def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q): @@ -67,7 +65,8 @@ def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q): a = curve[0] b = curve[1] + log_q - return minimal_stddev(a, b, lwe_dimension) + stddev = minimal_stddev(a, b, lwe_dimension) + return stddev def compare_curve_and_estimator(security_level, log_q, curves_dir): @@ -138,5 +137,5 @@ if __name__ == "__main__": args = CLI.parse_args() for security_level in args.security_levels: if not(compare_curve_and_estimator(security_level, args.log_q, args.curves_dir)): - sys.exit(1) - sys.exit(0) + exit(1) + exit(0) \ No newline at end of file diff --git a/lattice-scripts/generate_data.py b/lattice-scripts/generate_data.py index c675da015..5d2693d2f 100644 --- a/lattice-scripts/generate_data.py +++ b/lattice-scripts/generate_data.py @@ -1,13 +1,11 @@ -import argparse +from estimator import RC, LWE, ND +from sage.all import oo, save, load +from math import log2 import multiprocessing +import argparse import os import sys -from math import log2 - -from estimator import LWE, ND, RC -from sage.all import load, oo, save - -sys.path.insert(1, "lattice-estimator") +sys.path.insert(1, 'lattice-estimator') old_models_sobj = "" @@ -27,7 +25,6 @@ def old_models(security_level, sd, logq=32): for i in range(len(curves)): if curves[i][2] == sec: return i - return None if old_models_sobj is None or not(os.path.exists(old_models_sobj)): return 450 @@ -51,7 +48,10 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")): :param skip: attacks to skip """ - return LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) + est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) + + return est + def get_security_level(est, dp=2): """ @@ -64,7 +64,10 @@ def get_security_level(est, dp=2): for key in est: attack_costs.append(est[key]["rop"]) # get the security level correct to 'dp' decimal places - return round(log2(min(attack_costs)), dp) + security_level = round(log2(min(attack_costs)), dp) + + return security_level + def inequality(x, y): """A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier @@ -77,8 +80,6 @@ def inequality(x, y): if x > y: return -1 - return None - def automated_param_select_n(params, target_security=128): """A function used to generate the smallest value of n which allows for @@ -95,6 +96,7 @@ def automated_param_select_n(params, target_security=128): # get an estimate based on the prev. model print("n = {}".format(params.n)) n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q)) + # n_start = max(n_start, 450) # TODO: think about throwing an error if the required n < 450 params = params.updated(n=n_start) @@ -106,6 +108,7 @@ def automated_param_select_n(params, target_security=128): # (on binary key guessing) while z * security_level < z * target_security: # TODO: fill in this case! For n > 1024 we only need to consider every + # 256 (optimization) params = params.updated(n=params.n + z * 8) costs = estimate(params) security_level = get_security_level(costs, 2) @@ -135,7 +138,7 @@ def automated_param_select_n(params, target_security=128): def generate_parameter_matrix( - params_in, sd_range, target_security_levels=(128), name="default_name" + params_in, sd_range, target_security_levels=[128], name="default_name" ): """ :param params_in: a initial set of LWE parameters @@ -156,7 +159,7 @@ def generate_parameter_matrix( try: results = load("{}.sobj".format(name)) except BaseException: - results = {} + results = dict() results["{}".format(lam)] = [] results["{}".format(lam)].append( @@ -168,7 +171,7 @@ def generate_parameter_matrix( def generate_zama_curves64( - sd_range=(2, 58), target_security_levels=(128), name="default_name" + sd_range=[2, 58], target_security_levels=[128], name="default_name" ): """ The top level function which we use to run the experiment diff --git a/lattice-scripts/verify_curves.py b/lattice-scripts/verify_curves.py index c130e4b11..d63824511 100644 --- a/lattice-scripts/verify_curves.py +++ b/lattice-scripts/verify_curves.py @@ -1,25 +1,23 @@ -import argparse +import numpy as np +from sage.all import save, load, ceil import json import os -import sys - -import numpy as np -from sage.all import ceil, load, save +import argparse def sort_data(security_level, curves_dir): from operator import itemgetter # step 1. load the data - x = load(os.path.join(curves_dir, f"{security_level}.sobj")) + X = load(os.path.join(curves_dir, f"{security_level}.sobj")) # step 2. sort by SD - x = sorted(x["{}".format(security_level)], key=itemgetter(2)) + x = sorted(X["{}".format(security_level)], key=itemgetter(2)) # step3. replace the sorted value - x["{}".format(security_level)] = x + X["{}".format(security_level)] = x - return x + return X def generate_curve(security_level, curves_dir): @@ -65,12 +63,16 @@ def verify_curve(security_level, a, b, curves_dir): # step 3. for each n, check whether we satisfy the table n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0]) + #print(n_min) + #print(n_max) for n in range(n_max, n_min, -1): model_sd = f_model(a, b, n) table_sd = f_table(X["{}".format(security_level)], n) + #print(n, table_sd, model_sd, model_sd >= table_sd) if table_sd > model_sd: + #print("MODEL FAILS at n = {}".format(n)) return False return True, n_min @@ -83,6 +85,7 @@ def generate_and_verify(security_levels, log_q, curves_dir, name="verified_curve fail = [] for sec in security_levels: + #print("WE GO FOR {}".format(sec)) # generate the model for security level sec (a_sec, b_sec) = generate_curve(sec, curves_dir) # verify the model for security level sec @@ -124,6 +127,6 @@ if __name__ == "__main__": if (fail): print("FAILURE: Fail to verify the following curves") print(json.dumps(fail)) - sys.exit(1) + exit(1) print(json.dumps(success)) diff --git a/pyproject.toml b/pyproject.toml index ff1c27cf7..54825080d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,2 @@ [tool.ruff] line-length = 169 -select = ["F", "E", "W", "C90", "I", "UP", "N", "YTT", "S", "BLE", "FBT", "B", "A", "C4", "T10", "EM", "ICN", "Q", "RET", "SIM", "TID", "ARG", "DTZ", "ERA", "PD", "PGH", "PLC", "PLE", "PLR", "PLW", "RUF"] -ignore = ["D", "T20", "ANN", "N806", "ARG001", "S101", "BLE001"]