diff --git a/generate_data.py b/generate_data.py index 28a6848b1..26a0e933e 100644 --- a/generate_data.py +++ b/generate_data.py @@ -1,3 +1,4 @@ +import sys from estimator_new import * from sage.all import oo, save, load, ceil from math import log2 @@ -91,7 +92,8 @@ def automated_param_select_n(params, target_security=128): # get an estimate based on the prev. model print("n = {}".format(params.n)) - n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q)) + n_start = old_models(target_security, log2( + params.Xe.stddev), log2(params.q)) # n_start = max(n_start, 450) # TODO: think about throwing an error if the required n < 450 @@ -103,7 +105,7 @@ def automated_param_select_n(params, target_security=128): # we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing) while z * security_level < z * target_security: # TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization) - params = params.updated(n = params.n + z * 8) + params = params.updated(n=params.n + z * 8) costs = estimate(params) security_level = get_security_level(costs, 2) @@ -120,8 +122,10 @@ def automated_param_select_n(params, target_security=128): security_level = get_security_level(costs, 2) print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n, - log2(params.Xe.stddev), - log2(params.q), + log2( + params.Xe.stddev), + log2( + params.q), security_level)) if security_level < target_security: @@ -143,7 +147,8 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], for sd in range(sd_min, sd_max + 1): print("run for {}".format(lam, sd)) Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd) - (params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam) + (params_out, sec) = automated_param_select_n( + params_in.updated(Xe=Xe_new), target_security=lam) try: results = load("{}.sobj".format(name)) @@ -151,7 +156,8 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], results = dict() results["{}".format(lam)] = [] - results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec)) + results["{}".format(lam)].append( + (params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec)) save(results, "{}.sobj".format(name)) return results @@ -170,18 +176,20 @@ def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name= D = ND.DiscreteGaussian vals = range(sd_range[0], sd_range[1]) pool = multiprocessing.Pool(2) - init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params') - inputs = [(init_params, (val, val), target_security_levels, name) for val in vals] + init_params = LWE.Parameters( + n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params') + inputs = [(init_params, (val, val), target_security_levels, name) + for val in vals] res = pool.starmap(generate_parameter_matrix, inputs) return "done" # The script runs the following commands -import sys # grab values of the command-line input arguments a = int(sys.argv[1]) b = int(sys.argv[2]) c = int(sys.argv[3]) # run the code -generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a)) \ No newline at end of file +generate_zama_curves64(sd_range=(b, c), target_security_levels=[ + a], name="{}".format(a)) diff --git a/verify_curves.py b/verify_curves.py index bad73c93c..9dfe15ab1 100644 --- a/verify_curves.py +++ b/verify_curves.py @@ -9,7 +9,7 @@ def sort_data(security_level): X = load("{}.sobj".format(security_level)) # step 2. sort by SD - x = sorted(X["{}".format(security_level)], key = itemgetter(2)) + x = sorted(X["{}".format(security_level)], key=itemgetter(2)) # step3. replace the sorted value X["{}".format(security_level)] = x @@ -28,14 +28,14 @@ def generate_curve(security_level): for x in X["{}".format(security_level)]: N.append(x[0]) SD.append(x[2] + 0.5) - + # step 3. perform interpolation and return coefficients - (a,b) = np.polyfit(N, SD, 1) + (a, b) = np.polyfit(N, SD, 1) return a, b - -def verify_curve(security_level, a = None, b = None): + +def verify_curve(security_level, a=None, b=None): # step 1. get the table and max values of n, sd X = sort_data(security_level) @@ -53,9 +53,9 @@ def verify_curve(security_level, a = None, b = None): if n < n_val: pass else: - j = i + j = i break - + # now j is the correct index, we return the corresponding sd return table[j][2] @@ -67,7 +67,7 @@ def verify_curve(security_level, a = None, b = None): for n in range(n_max, n_min, - 1): model_sd = f_model(a, b, n) table_sd = f_table(X["{}".format(security_level)], n) - print(n , table_sd, model_sd, model_sd >= table_sd) + print(n, table_sd, model_sd, model_sd >= table_sd) if table_sd > model_sd: print("MODEL FAILS at n = {}".format(n)) @@ -76,7 +76,7 @@ def verify_curve(security_level, a = None, b = None): return "PASS", n_min -def generate_and_verify(security_levels, log_q, name = "verified_curves"): +def generate_and_verify(security_levels, log_q, name="verified_curves"): data = [] @@ -89,9 +89,10 @@ def generate_and_verify(security_levels, log_q, name = "verified_curves"): # append the information into a list data.append((a_sec, b_sec - log_q, sec, res[0], res[1])) save(data, "{}.sobj".format(name)) - + return data -data = generate_and_verify([80, 96, 112, 128, 144, 160, 176, 192, 256], log_q = 64) +data = generate_and_verify( + [80, 96, 112, 128, 144, 160, 176, 192, 256], log_q=64) print(data)