From 2a304c4f60f2bb171d5110675a79ecf42d85d4a4 Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 21 Jun 2022 17:35:12 +0100 Subject: [PATCH] tidy finalized script --- new_scripts.py | 78 +++++++++++++++----------------------------------- 1 file changed, 23 insertions(+), 55 deletions(-) diff --git a/new_scripts.py b/new_scripts.py index baaf2ddf6..251bc75ce 100644 --- a/new_scripts.py +++ b/new_scripts.py @@ -1,11 +1,7 @@ -import gc -import multiprocessing - from estimator_new import * from sage.all import oo, save, load from math import log2 -import gc -from multiprocessing import * +import multiprocessing def old_models(security_level, sd, logq=32): @@ -49,6 +45,7 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")): """ est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) + return est @@ -64,6 +61,7 @@ def get_security_level(est, dp=2): attack_costs.append(est[key]["rop"]) # get the security level correct to 'dp' decimal places security_level = round(log2(min(attack_costs)), dp) + return security_level @@ -91,47 +89,26 @@ def automated_param_select_n(params, target_security=128): 456 """ - # get an initial estimate - # costs = estimate(params) - # security_level = get_security_level(costs, 2) - # determine if we are above or below the target security level - # z = inequality(security_level, target_security) - # get an estimate based on the prev. model print("n = {}".format(params.n)) n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q)) - # TODO -- is this how we want to deal with the small n issue? Shouldn't the model have this baked in? - # we want to start no lower than n = 450 n_start = max(n_start, 450) + # TODO: think about throwing an error if the required n < 450 - #if n_start > 1024: - # we only consider powers-of-two for now, in this range - # n_log = log2(n_start) - # n_start = 2**round(n_log) - - print("n_start = {}".format(n_start)) params = params.updated(n=n_start) - print(params) - # costs2 = estimate(params) security_level = get_security_level(costs2, 2) - costs2 = None z = inequality(security_level, target_security) # we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing) while z * security_level < z * target_security: - # if params.n > 1024: - # we only need to consider powers-of-two in this case - # TODO: fill in this case! For n > 1024 we only need to consider every 256 + # TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization) params = params.updated(n = params.n + z * 8) costs = estimate(params) security_level = get_security_level(costs, 2) - # try none with delete, try none without delete - # test the list of objects that are in memory before end of program - costs = None if -1 * params.Xe.stddev > 0: - print("target security level is unatainable") + print("target security level is unattainable") break # final estimate (we went too far in the above loop) @@ -147,34 +124,25 @@ def automated_param_select_n(params, target_security=128): log2(params.q), security_level)) - # final sanity check so we don't return insecure (or inf) parameters - # TODO: figure out inf in new estimator - # or security_level == oo: if security_level < target_security: params.updated(n=None) - return (params, security_level) + return params, security_level -def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="v0.sobj"): +def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"): """ + :param params_in: a initial set of LWE parameters :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters - :param params: the standard deviation of the LWE error - :param target_security: the target number of bits of security, 128 is default - - EXAMPLE: - sage: X = generate_parameter_matrix() - sage: X + :param target_security_levels: a list of the target number of bits of security, 128 is default + :param name: a name to save the file """ - # grab min and max value/s of n (sd_min, sd_max) = sd_range for lam in target_security_levels: - print("LAM = {}".format(lam)) for sd in range(sd_min, sd_max + 1): Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd) (params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam) - print("PARAMS OUT = {}".format(params_out)) try: results = load("{}.sobj".format(name)) @@ -185,36 +153,36 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec)) save(results, "{}.sobj".format(name)) - del(params_out) - gc.collect() return results -def generate_zama_curves64(sd_range=range(5,9), target_security_levels=[256], name="default"): +def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"): + """ + The top level function which we use to run the experiment + + :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters + :param target_security_levels: a list of the target number of bits of security, 128 is default + :param name: a name to save the file + """ if __name__ == '__main__': D = ND.DiscreteGaussian vals = range(sd_range[0], sd_range[1]) - procs = [] pool = multiprocessing.Pool(2) - init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='TFHE_DEFAULT') + init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params') inputs = [(init_params, (val, val), target_security_levels, name) for val in vals] - print(inputs[0]) res = pool.starmap(generate_parameter_matrix, inputs) return "done" -def wrap(*args): - return generate_parameter_matrix(*args) - +# The script runs the following commands import sys +# grab values of the command-line input arguments a = int(sys.argv[1]) b = int(sys.argv[2]) c = int(sys.argv[3]) -print(b) -D = ND.DiscreteGaussian -init_params = LWE.Parameters(n=1024, q=2 ** 32, Xs=ND.UniformMod(2), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT') +# run the code generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))