diff --git a/concrete-security-curves-cpp/gen_header.py b/concrete-security-curves-cpp/gen_header.py index 38b0a5b73..bfd86e924 100644 --- a/concrete-security-curves-cpp/gen_header.py +++ b/concrete-security-curves-cpp/gen_header.py @@ -1,4 +1,6 @@ -import sys, json; +import json +import sys + def print_curve(data): print(f'\tSecurityCurve({data["security_level"]},{data["slope"]}, {data["bias"]}, {data["minimal_lwe_dimension"]}, KeyFormat::BINARY),') @@ -9,4 +11,4 @@ def print_cpp_curves_declaration(datas): print_curve(data) print("}\n") -print_cpp_curves_declaration(json.load(sys.stdin)) \ No newline at end of file +print_cpp_curves_declaration(json.load(sys.stdin)) diff --git a/concrete-security-curves-rust/gen_table.py b/concrete-security-curves-rust/gen_table.py index 4646cbbc5..07c9817d9 100644 --- a/concrete-security-curves-rust/gen_table.py +++ b/concrete-security-curves-rust/gen_table.py @@ -1,4 +1,6 @@ -import sys, json; +import json +import sys + def print_curve(data): print(f' ({data["security_level"]}, SecurityWeights {{ slope: {data["slope"]}, bias: {data["bias"]}, minimal_lwe_dimension: {data["minimal_lwe_dimension"]} }}),') @@ -10,4 +12,4 @@ def print_rust_curves_declaration(datas): print_curve(data) print("];") -print_rust_curves_declaration(json.load(sys.stdin)) \ No newline at end of file +print_rust_curves_declaration(json.load(sys.stdin)) diff --git a/lattice-scripts/compare_curves_and_estimator.py b/lattice-scripts/compare_curves_and_estimator.py index 598d74335..05eb2b3dc 100644 --- a/lattice-scripts/compare_curves_and_estimator.py +++ b/lattice-scripts/compare_curves_and_estimator.py @@ -1,10 +1,12 @@ -from estimator import LWE, ND -from sage.all import oo, load, floor -from generate_data import estimate, get_security_level import argparse import os import sys -sys.path.insert(1, 'lattice-estimator') + +from estimator import LWE, ND +from generate_data import estimate, get_security_level +from sage.all import floor, load, oo + +sys.path.insert(1, "lattice-estimator") LOG_N_MAX = 17 + 1 @@ -21,6 +23,7 @@ def get_index(sec, curves): for i in range(len(curves)): if curves[i][2] == sec: return i + return None def estimate_security_with_lattice_estimator(lwe_dimension, std_dev, log_q): @@ -45,9 +48,8 @@ def get_minimal_lwe_dimension(curve, security_level, log_q): :param security_level: :param log_q: :return: - """ - minimal_lwe_dim = curve[-1] - return minimal_lwe_dim + """ + return curve[-1] def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q): @@ -65,8 +67,7 @@ def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q): a = curve[0] b = curve[1] + log_q - stddev = minimal_stddev(a, b, lwe_dimension) - return stddev + return minimal_stddev(a, b, lwe_dimension) def compare_curve_and_estimator(security_level, log_q, curves_dir): @@ -137,5 +138,5 @@ if __name__ == "__main__": args = CLI.parse_args() for security_level in args.security_levels: if not(compare_curve_and_estimator(security_level, args.log_q, args.curves_dir)): - exit(1) - exit(0) \ No newline at end of file + sys.exit(1) + sys.exit(0) diff --git a/lattice-scripts/generate_data.py b/lattice-scripts/generate_data.py index 5d2693d2f..c675da015 100644 --- a/lattice-scripts/generate_data.py +++ b/lattice-scripts/generate_data.py @@ -1,11 +1,13 @@ -from estimator import RC, LWE, ND -from sage.all import oo, save, load -from math import log2 -import multiprocessing import argparse +import multiprocessing import os import sys -sys.path.insert(1, 'lattice-estimator') +from math import log2 + +from estimator import LWE, ND, RC +from sage.all import load, oo, save + +sys.path.insert(1, "lattice-estimator") old_models_sobj = "" @@ -25,6 +27,7 @@ def old_models(security_level, sd, logq=32): for i in range(len(curves)): if curves[i][2] == sec: return i + return None if old_models_sobj is None or not(os.path.exists(old_models_sobj)): return 450 @@ -48,10 +51,7 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")): :param skip: attacks to skip """ - est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) - - return est - + return LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) def get_security_level(est, dp=2): """ @@ -64,10 +64,7 @@ def get_security_level(est, dp=2): for key in est: attack_costs.append(est[key]["rop"]) # get the security level correct to 'dp' decimal places - security_level = round(log2(min(attack_costs)), dp) - - return security_level - + return round(log2(min(attack_costs)), dp) def inequality(x, y): """A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier @@ -80,6 +77,8 @@ def inequality(x, y): if x > y: return -1 + return None + def automated_param_select_n(params, target_security=128): """A function used to generate the smallest value of n which allows for @@ -96,7 +95,6 @@ def automated_param_select_n(params, target_security=128): # get an estimate based on the prev. model print("n = {}".format(params.n)) n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q)) - # n_start = max(n_start, 450) # TODO: think about throwing an error if the required n < 450 params = params.updated(n=n_start) @@ -108,7 +106,6 @@ def automated_param_select_n(params, target_security=128): # (on binary key guessing) while z * security_level < z * target_security: # TODO: fill in this case! For n > 1024 we only need to consider every - # 256 (optimization) params = params.updated(n=params.n + z * 8) costs = estimate(params) security_level = get_security_level(costs, 2) @@ -138,7 +135,7 @@ def automated_param_select_n(params, target_security=128): def generate_parameter_matrix( - params_in, sd_range, target_security_levels=[128], name="default_name" + params_in, sd_range, target_security_levels=(128), name="default_name" ): """ :param params_in: a initial set of LWE parameters @@ -159,7 +156,7 @@ def generate_parameter_matrix( try: results = load("{}.sobj".format(name)) except BaseException: - results = dict() + results = {} results["{}".format(lam)] = [] results["{}".format(lam)].append( @@ -171,7 +168,7 @@ def generate_parameter_matrix( def generate_zama_curves64( - sd_range=[2, 58], target_security_levels=[128], name="default_name" + sd_range=(2, 58), target_security_levels=(128), name="default_name" ): """ The top level function which we use to run the experiment diff --git a/lattice-scripts/verify_curves.py b/lattice-scripts/verify_curves.py index d63824511..c130e4b11 100644 --- a/lattice-scripts/verify_curves.py +++ b/lattice-scripts/verify_curves.py @@ -1,23 +1,25 @@ -import numpy as np -from sage.all import save, load, ceil +import argparse import json import os -import argparse +import sys + +import numpy as np +from sage.all import ceil, load, save def sort_data(security_level, curves_dir): from operator import itemgetter # step 1. load the data - X = load(os.path.join(curves_dir, f"{security_level}.sobj")) + x = load(os.path.join(curves_dir, f"{security_level}.sobj")) # step 2. sort by SD - x = sorted(X["{}".format(security_level)], key=itemgetter(2)) + x = sorted(x["{}".format(security_level)], key=itemgetter(2)) # step3. replace the sorted value - X["{}".format(security_level)] = x + x["{}".format(security_level)] = x - return X + return x def generate_curve(security_level, curves_dir): @@ -63,16 +65,12 @@ def verify_curve(security_level, a, b, curves_dir): # step 3. for each n, check whether we satisfy the table n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0]) - #print(n_min) - #print(n_max) for n in range(n_max, n_min, -1): model_sd = f_model(a, b, n) table_sd = f_table(X["{}".format(security_level)], n) - #print(n, table_sd, model_sd, model_sd >= table_sd) if table_sd > model_sd: - #print("MODEL FAILS at n = {}".format(n)) return False return True, n_min @@ -85,7 +83,6 @@ def generate_and_verify(security_levels, log_q, curves_dir, name="verified_curve fail = [] for sec in security_levels: - #print("WE GO FOR {}".format(sec)) # generate the model for security level sec (a_sec, b_sec) = generate_curve(sec, curves_dir) # verify the model for security level sec @@ -127,6 +124,6 @@ if __name__ == "__main__": if (fail): print("FAILURE: Fail to verify the following curves") print(json.dumps(fail)) - exit(1) + sys.exit(1) print(json.dumps(success)) diff --git a/pyproject.toml b/pyproject.toml index 54825080d..ff1c27cf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,4 @@ [tool.ruff] line-length = 169 +select = ["F", "E", "W", "C90", "I", "UP", "N", "YTT", "S", "BLE", "FBT", "B", "A", "C4", "T10", "EM", "ICN", "Q", "RET", "SIM", "TID", "ARG", "DTZ", "ERA", "PD", "PGH", "PLC", "PLE", "PLR", "PLW", "RUF"] +ignore = ["D", "T20", "ANN", "N806", "ARG001", "S101", "BLE001"]