diff --git a/.github/workflows/pep8.yml b/.github/workflows/pep8.yml index f4656a9a2..d5d0aea9e 100644 --- a/.github/workflows/pep8.yml +++ b/.github/workflows/pep8.yml @@ -19,4 +19,5 @@ jobs: - name: PEP8 run: | pip install --upgrade pyproject-flake8 - flake8 new_scripts.py + flake8 generate_data.py + flake8 verify_curves.py diff --git a/112.sobj b/112.sobj new file mode 100644 index 000000000..a50b2f74c Binary files /dev/null and b/112.sobj differ diff --git a/128.sobj b/128.sobj new file mode 100644 index 000000000..79a6242f3 Binary files /dev/null and b/128.sobj differ diff --git a/144.sobj b/144.sobj new file mode 100644 index 000000000..fb6e74386 Binary files /dev/null and b/144.sobj differ diff --git a/160.sobj b/160.sobj new file mode 100644 index 000000000..005181dd1 Binary files /dev/null and b/160.sobj differ diff --git a/176.sobj b/176.sobj new file mode 100644 index 000000000..cfc0d5971 Binary files /dev/null and b/176.sobj differ diff --git a/192.sobj b/192.sobj new file mode 100644 index 000000000..213644dde Binary files /dev/null and b/192.sobj differ diff --git a/256.sobj b/256.sobj new file mode 100644 index 000000000..168914b0a Binary files /dev/null and b/256.sobj differ diff --git a/80.sobj b/80.sobj new file mode 100644 index 000000000..14697f236 Binary files /dev/null and b/80.sobj differ diff --git a/96.sobj b/96.sobj new file mode 100644 index 000000000..aa776fb42 Binary files /dev/null and b/96.sobj differ diff --git a/README.rst b/README.rst index b05d44c55..4f94197b2 100644 --- a/README.rst +++ b/README.rst @@ -3,12 +3,12 @@ Parameter curves for Concrete This Github repository contains the code needed to generate the Parameter curves used inside Zama. The repository contains the following files: -- cpp/, Python scripts to generate a cpp file containing the parameter curves +- cpp/, Python scripts to generate a cpp file containing the parameter curves (needs updating) - data/, a folder containing the data generated for previous curves. 
-- estimator/, Zama's internal version of the LWE Estimator -- figs/, a folder containing various figures related to the parameter curves -- scripts.py, a copy of all scripts required to generate the parameter curves -- a variety of other python files, used for estimating the security of previous Concrete parameter sets +- estimator_new/, the Lattice estimator (TODO: add as a submodule and use dependabot to alert for new commits) +- old_files/, legacy files used for previous versions +- generate_data.py, functions to gather raw data from the lattice estimator +- verifiy_curves.py, functions to generate and verify curves from raw data .. image:: logo.svg :align: center @@ -21,52 +21,43 @@ This is an example of how to generate the parameter curves, and save them to fil :: - sage: load("scripts.py") - sage: results = get_zama_curves() - sage: save(results, "v0.sobj") + ./job.sh :: -We can load results files, and find the interpolants. +This will generate several data files, {80, 96, 112, 128, 144, 160, 176, 192, 256}.sobj + +To generate the parameter curves from the data files, we run + +`sage verify_curves.py` + +this will generate a list of the form: :: - sage: load("scripts.py") - sage: interps = [] - sage: results = load("v0.sobj") - sage: for result in results: - sage: interps.append(interpolate_result(result, log_q = 64)) - sage: interps - [(-0.040476778656126484, 1.143346508563902), - (-0.03417207792207793, 1.4805194805194737), - (-0.029681716023268107, 1.752723426758335), - (-0.0263748887657055, 2.0121439233304894), - (-0.023730136557783763, 2.1537066948924095), - (-0.021604493958972515, 2.2696862472846204), - (-0.019897520946588438, 2.4423829771964796), - (-0.018504919354426233, 2.6634073426215745), - (-0.017254242957361113, 2.7353702447139026), - (-0.016178309410530816, 2.8493969373734758), - (-0.01541034709414119, 3.1982749283836283), - (-0.014327640360322604, 2.899270827311096)] + [(-0.04042633119364589, 1.6609788641436722, 80, 'PASS', 450), + 
(-0.03414780360867051, 2.017310258660345, 96, 'PASS', 450), + (-0.029670137081135885, 2.162463714083856, 112, 'PASS', 450), + (-0.02640502876522622, 2.4826422691043177, 128, 'PASS', 450), + (-0.023821437305989134, 2.7177789440636673, 144, 'PASS', 450), + (-0.02174358218716036, 2.938810548493322, 160, 'PASS', 498), + (-0.019904056582117684, 2.8161252801542247, 176, 'PASS', 551), + (-0.018610403247590085, 3.2996236848399008, 192, 'PASS', 606), + (-0.014606812351714953, 3.8493629234693003, 256, 'PASS', 826)] :: -Finding the value of n_{alpha} is done manually. We can also verify the interpolants which are generated at the same time: +each element is a tuple (a, b, security, P, n_min), where (a,b) are the model +parameters, security is the security level, P is a boolean value denoting PASS or +FAIL of the verification, and n_min is the smallest reccomended value of `n` to be used. + +Each model outputs a value of sigma, and is of the form: :: - # verify the interpolant used for lambda = 256 (which is interps[-1]) - sage: z = verify_interpolants(interps[-1], (128,2048), 64) - [... code runs, can take ~10 mins ...] - # find the index corresponding to n_alpha, which is where security drops below the target security level (256 here) - sage: n_alpha = find_nalpha(z, 256) - 653 - - # so the model in this case is - (-0.014327640360322604, 2.899270827311096, 653) - # which corresponds to - # sd(n) = max(-0.014327640360322604 * n + 2.899270827311096, -logq + 2), n >= 653 + f(a, b, n) = max(ceil(a * n + b), -log2(q) + 2) :: +where the -log2(q) + 2 term ensures that we are always using at least two bits of noise. + Version History ------------------- @@ -76,14 +67,7 @@ Data for the curves are kept in /data. 
The following files are present: v0: generated using the {usvp, dual, decoding} attacks v0.1: generated using the {mitm, usvp, dual, decoding} attacks + v0.2: generated using the lattice estimator :: -TODO List -------------------- -There are several updates which are still required. - 1. Consider Hybrid attacks (WIP, Michael + Ben are coding up hybrid-dual/hybrid-decoding estimates) - 2. CI/CD stuff for new pushes to the external LWE Estimator. - 3. Fully automate the process of finding n_{alpha} for each curve. - 4. Functionality for q =! 64? This is covered by the curve, but we currently don't account for it in the models, and it needs to be done manually. - 5. cpp file generation diff --git a/data/v0.2/112.sobj b/data/v0.2/112.sobj new file mode 100644 index 000000000..a50b2f74c Binary files /dev/null and b/data/v0.2/112.sobj differ diff --git a/data/v0.2/128.sobj b/data/v0.2/128.sobj new file mode 100644 index 000000000..79a6242f3 Binary files /dev/null and b/data/v0.2/128.sobj differ diff --git a/data/v0.2/144.sobj b/data/v0.2/144.sobj new file mode 100644 index 000000000..fb6e74386 Binary files /dev/null and b/data/v0.2/144.sobj differ diff --git a/data/v0.2/160.sobj b/data/v0.2/160.sobj new file mode 100644 index 000000000..005181dd1 Binary files /dev/null and b/data/v0.2/160.sobj differ diff --git a/data/v0.2/176.sobj b/data/v0.2/176.sobj new file mode 100644 index 000000000..cfc0d5971 Binary files /dev/null and b/data/v0.2/176.sobj differ diff --git a/data/v0.2/192.sobj b/data/v0.2/192.sobj new file mode 100644 index 000000000..213644dde Binary files /dev/null and b/data/v0.2/192.sobj differ diff --git a/data/v0.2/256.sobj b/data/v0.2/256.sobj new file mode 100644 index 000000000..168914b0a Binary files /dev/null and b/data/v0.2/256.sobj differ diff --git a/data/v0.2/80.sobj b/data/v0.2/80.sobj new file mode 100644 index 000000000..14697f236 Binary files /dev/null and b/data/v0.2/80.sobj differ diff --git a/data/v0.2/96.sobj b/data/v0.2/96.sobj new file 
mode 100644 index 000000000..aa776fb42 Binary files /dev/null and b/data/v0.2/96.sobj differ diff --git a/estimator_new/__pycache__/lwe_primal.cpython-38.pyc b/estimator_new/__pycache__/lwe_primal.cpython-38.pyc index f00bc1cb7..0326ade84 100644 Binary files a/estimator_new/__pycache__/lwe_primal.cpython-38.pyc and b/estimator_new/__pycache__/lwe_primal.cpython-38.pyc differ diff --git a/new_scripts.py b/generate_data.py similarity index 53% rename from new_scripts.py rename to generate_data.py index 84ffd46d8..26a0e933e 100644 --- a/new_scripts.py +++ b/generate_data.py @@ -1,8 +1,11 @@ +import sys from estimator_new import * -from sage.all import oo, save +from sage.all import oo, save, load, ceil from math import log2 +import multiprocessing -def old_models(security_level, sd, logq = 32): + +def old_models(security_level, sd, logq=32): """ Use the old model as a starting point for the data gathering step :param security_level: the security level under consideration @@ -10,12 +13,11 @@ def old_models(security_level, sd, logq = 32): :param logq : the (base 2 log) value of the LWE modulus q """ - def evaluate_model(sd, a, b): - return (sd - b)/a + def evaluate_model(a, b, stddev=sd): + return (stddev - b)/a models = dict() - # TODO: figure out a way to import these from a datafile, for future version models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq) models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq) models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq) @@ -30,32 +32,37 @@ def old_models(security_level, sd, logq = 32): models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq) (a, b) = models["{}".format(security_level)] - n_est = evaluate_model(sd, a, b) + n_est = evaluate_model(a, b, sd) return round(n_est) -def estimate(params, red_cost_model = RC.BDGL16): +def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")): """ Retrieve an estimate using the Lattice Estimator, for a given set of input 
parameters :param params: the input LWE parameters + :param red_cost_model: the lattice reduction cost model + :param skip: attacks to skip """ - est = LWE.estimate(params, deny_list=("arora-gb", "bkw"), red_cost_model=red_cost_model) + est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) + return est -def get_security_level(est, dp = 2): +def get_security_level(est, dp=2): """ Get the security level lambda from a Lattice Estimator output :param est: the Lattice Estimator output - :param dp : the number of decimal places to consider + :param dp: the number of decimal places to consider """ attack_costs = [] - for key in est.keys(): + # note: key does not need to be specified est vs est.keys() + for key in est: attack_costs.append(est[key]["rop"]) # get the security level correct to 'dp' decimal places security_level = round(log2(min(attack_costs)), dp) + return security_level @@ -83,120 +90,106 @@ def automated_param_select_n(params, target_security=128): 456 """ - # get an initial estimate - # costs = estimate(params) - # security_level = get_security_level(costs, 2) - # determine if we are above or below the target security level - # z = inequality(security_level, target_security) - # get an estimate based on the prev. model print("n = {}".format(params.n)) - n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q)) - # TODO -- is this how we want to deal with the small n issue? Shouldn't the model have this baked in? 
- # we want to start no lower than n = 450 - n_start = max(n_start, 450) - print("n_start = {}".format(n_start)) + n_start = old_models(target_security, log2( + params.Xe.stddev), log2(params.q)) + # n_start = max(n_start, 450) + # TODO: think about throwing an error if the required n < 450 + params = params.updated(n=n_start) - print(params) costs2 = estimate(params) security_level = get_security_level(costs2, 2) z = inequality(security_level, target_security) - # we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing) while z * security_level < z * target_security: - # if params.n > 1024: - # we only need to consider powers-of-two in this case - # TODO: fill in this case! For n > 1024 we only need to consider every 256 - - - - params = params.updated(n = params.n + z * 8) + # TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization) + params = params.updated(n=params.n + z * 8) costs = estimate(params) security_level = get_security_level(costs, 2) if -1 * params.Xe.stddev > 0: - print("target security level is unatainable") + print("target security level is unattainable") break # final estimate (we went too far in the above loop) if security_level < target_security: - # TODO: we should somehow keep the previous estimate stored so that we don't need to compute it twice - # if we do this we need to make sure that it works for both sides (i.e. 
if (i-1) is above or below the - # security level - - params = params.updated(n = params.n - z * 8) + # we make n larger + print("we make n larger") + params = params.updated(n=params.n + 8) costs = estimate(params) security_level = get_security_level(costs, 2) print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n, - log2(params.Xe.stddev), - log2(params.q), - security_level)) + log2( + params.Xe.stddev), + log2( + params.q), + security_level)) - # final sanity check so we don't return insecure (or inf) parameters - # TODO: figure out inf in new estimator - # or security_level == oo: if security_level < target_security: params.updated(n=None) - return params + return params, security_level -def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128]): + +def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"): """ + :param params_in: a initial set of LWE parameters :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters - :param params: the standard deviation of the LWE error - :param target_security: the target number of bits of security, 128 is default - - EXAMPLE: - sage: X = generate_parameter_matrix() - sage: X + :param target_security_levels: a list of the target number of bits of security, 128 is default + :param name: a name to save the file """ - results = dict() - - # grab min and max value/s of n (sd_min, sd_max) = sd_range - for lam in target_security_levels: - results["{}".format(lam)] = [] for sd in range(sd_min, sd_max + 1): + print("run for {}".format(lam, sd)) Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd) - params_out = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam) - results["{}".format(lam)].append((params_out.n, params_out.q, params_out.Xe.stddev)) + (params_out, sec) = automated_param_select_n( + params_in.updated(Xe=Xe_new), 
target_security=lam) + + try: + results = load("{}.sobj".format(name)) + except: + results = dict() + results["{}".format(lam)] = [] + + results["{}".format(lam)].append( + (params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec)) + save(results, "{}.sobj".format(name)) return results +def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"): + """ + The top level function which we use to run the experiment -def test_it(): + :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters + :param target_security_levels: a list of the target number of bits of security, 128 is default + :param name: a name to save the file + """ + if __name__ == '__main__': - D = nd.NoiseDistribution.DiscreteGaussian - DEFAULT_PARAMETERS = LWE.Parameters(n=1024, q=2**64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT') + D = ND.DiscreteGaussian + vals = range(sd_range[0], sd_range[1]) + pool = multiprocessing.Pool(2) + init_params = LWE.Parameters( + n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params') + inputs = [(init_params, (val, val), target_security_levels, name) + for val in vals] + res = pool.starmap(generate_parameter_matrix, inputs) + + return "done" - # x = estimate(params) - # y = get_security_level(x, 2) - # print(y) - #z1 = automated_param_select_n(schemes.TFHE630.updated(n=786), 128) - #print(z1) - sd_range = [1,4] - print("working...") - z3 = generate_parameter_matrix(DEFAULT_PARAMETERS, sd_range=[5, 6], target_security_levels=[128, 192, 256]) - # TODO: in this function call the initial guess for n is way off (security is ~60-bits instead of close to 128). 
- print(z3) - save(z3, "123.sobj") - - return z3 - - -def generate_zama_curves64(sd_range=[2, 60], target_security_levels=[128, 192, 256]): - - D = ND.DiscreteGaussian - init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT') - raw_data = generate_parameter_matrix(init_params, sd_range=sd_range, target_security_levels=target_security_levels) - - return raw_data - - -generate_zama_curves64() \ No newline at end of file +# The script runs the following commands +# grab values of the command-line input arguments +a = int(sys.argv[1]) +b = int(sys.argv[2]) +c = int(sys.argv[3]) +# run the code +generate_zama_curves64(sd_range=(b, c), target_security_levels=[ + a], name="{}".format(a)) diff --git a/job.sh b/job.sh new file mode 100755 index 000000000..def3e07a2 --- /dev/null +++ b/job.sh @@ -0,0 +1,66 @@ +#!/bin/sh +# 80-bits +sage generate_data.py 80 2 12 +sage generate_data.py 80 12 22 +sage generate_data.py 80 22 32 +sage generate_data.py 80 32 42 +sage generate_data.py 80 42 52 +sage generate_data.py 80 52 59 +# 96-bits +sage generate_data.py 96 2 12 +sage generate_data.py 96 12 22 +sage generate_data.py 96 22 32 +sage generate_data.py 96 32 42 +sage generate_data.py 96 42 52 +sage generate_data.py 96 52 59 +# 112-bits +sage generate_data.py 112 2 12 +sage generate_data.py 112 12 22 +sage generate_data.py 112 22 32 +sage generate_data.py 112 32 42 +sage generate_data.py 112 42 52 +sage generate_data.py 112 52 59 +# 128-bits +sage generate_data.py 128 2 12 +sage generate_data.py 128 12 22 +sage generate_data.py 128 22 32 +sage generate_data.py 128 32 42 +sage generate_data.py 128 42 52 +sage generate_data.py 128 52 59 +# 144-bits +sage generate_data.py 144 2 12 +sage generate_data.py 144 12 22 +sage generate_data.py 144 22 32 +sage generate_data.py 144 32 42 +sage generate_data.py 144 42 52 +sage generate_data.py 144 52 59 +# 160-bits +sage generate_data.py 160 2 12 +sage generate_data.py 160 12 22 +sage 
generate_data.py 160 22 32 +sage generate_data.py 160 32 42 +sage generate_data.py 160 42 52 +sage generate_data.py 160 52 59 +# 176-bits +sage generate_data.py 176 2 12 +sage generate_data.py 176 12 22 +sage generate_data.py 176 22 32 +sage generate_data.py 176 32 42 +sage generate_data.py 176 42 52 +sage generate_data.py 176 52 59 +# 192-bits +sage generate_data.py 192 2 12 +sage generate_data.py 192 12 22 +sage generate_data.py 192 22 32 +sage generate_data.py 192 32 42 +sage generate_data.py 192 42 52 +sage generate_data.py 192 52 59 +# 256-bits +sage generate_data.py 256 2 12 +sage generate_data.py 256 12 22 +sage generate_data.py 256 22 32 +sage generate_data.py 256 32 42 +sage generate_data.py 256 42 52 +sage generate_data.py 256 52 59 + + diff --git a/concrete_params.py b/old_files/concrete_params.py similarity index 100% rename from concrete_params.py rename to old_files/concrete_params.py diff --git a/estimate_oldparams.py b/old_files/estimate_oldparams.py similarity index 100% rename from estimate_oldparams.py rename to old_files/estimate_oldparams.py diff --git a/figs/iso.png b/old_files/figs/iso.png similarity index 100% rename from figs/iso.png rename to old_files/figs/iso.png diff --git a/figs/plot.png b/old_files/figs/plot.png similarity index 100% rename from figs/plot.png rename to old_files/figs/plot.png diff --git a/figs/plot2.png b/old_files/figs/plot2.png similarity index 100% rename from figs/plot2.png rename to old_files/figs/plot2.png diff --git a/figs/sieve.png b/old_files/figs/sieve.png similarity index 100% rename from figs/sieve.png rename to old_files/figs/sieve.png diff --git a/figs/uSVP.png b/old_files/figs/uSVP.png similarity index 100% rename from figs/uSVP.png rename to old_files/figs/uSVP.png diff --git a/hybrid_decoding.py b/old_files/hybrid_decoding.py similarity index 100% rename from hybrid_decoding.py rename to old_files/hybrid_decoding.py diff --git a/old_files/memory_tests/test.py b/old_files/memory_tests/test.py new file 
mode 100644 index 000000000..22a58d63e --- /dev/null +++ b/old_files/memory_tests/test.py @@ -0,0 +1,17 @@ +from estimator_new import * +from sage.all import oo, save + +def test(): + + # code + D = ND.DiscreteGaussian + params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT') + + names = [params, params.updated(n=761), params.updated(q=2 ** 65), params.updated(n=762)] + + for name in names: + x = LWE.estimate(name, deny_list=("arora-gb", "bkw")) + + return 0 + +test() \ No newline at end of file diff --git a/old_files/memory_tests/test2.py b/old_files/memory_tests/test2.py new file mode 100644 index 000000000..4ad476806 --- /dev/null +++ b/old_files/memory_tests/test2.py @@ -0,0 +1,31 @@ + +from multiprocessing import * +from estimator_new import * +from sage.all import oo, save + + +def test_memory(x): + print("doing job...") + print(x) + y = LWE.estimate(x, deny_list=("arora-gb", "bkw")) + return y + +if __name__ == "__main__": + D = ND.DiscreteGaussian + params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT') + + names = [params, params.updated(n=761), params.updated(q=2**65), params.updated(n=762)] + procs = [] + proc = Process(target=print_func) + procs.append(proc) + proc.start() + p = Pool(1) + + for name in names: + proc = Process(target=test_memory, args=(name,)) + procs.append(proc) + proc.start() + proc.join() + + for proc in procs: + proc.join() \ No newline at end of file diff --git a/old_files/new_scripts.py b/old_files/new_scripts.py new file mode 100644 index 000000000..db1543550 --- /dev/null +++ b/old_files/new_scripts.py @@ -0,0 +1,471 @@ +from estimator_new import * +from sage.all import oo, save, load, ceil +from math import log2 +import multiprocessing + + +def old_models(security_level, sd, logq=32): + """ + Use the old model as a starting point for the data gathering step + :param security_level: the security level under consideration + :param sd : the 
standard deviation of the LWE error distribution Xe + :param logq : the (base 2 log) value of the LWE modulus q + """ + + def evaluate_model(a, b, stddev=sd): + return (stddev - b)/a + + models = dict() + + models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq) + models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq) + models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq) + models["128"] = (-0.026361288425133814, 2.0014671315214696 + logq) + models["144"] = (-0.023744534465622812, 2.1710601038230712 + logq) + models["160"] = (-0.021667220727651954, 2.3565507936475476 + logq) + models["176"] = (-0.019947662046189942, 2.5109588704235803 + logq) + models["192"] = (-0.018552804646747204, 2.7168913723130816 + logq) + models["208"] = (-0.017291091126923574, 2.7956961446214326 + logq) + models["224"] = (-0.016257546811508806, 2.9582401000615226 + logq) + models["240"] = (-0.015329741032015766, 3.0744579055889782 + logq) + models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq) + + (a, b) = models["{}".format(security_level)] + n_est = evaluate_model(a, b, sd) + + return round(n_est) + + +def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")): + """ + Retrieve an estimate using the Lattice Estimator, for a given set of input parameters + :param params: the input LWE parameters + :param red_cost_model: the lattice reduction cost model + :param skip: attacks to skip + """ + + est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip) + + return est + + +def get_security_level(est, dp=2): + """ + Get the security level lambda from a Lattice Estimator output + :param est: the Lattice Estimator output + :param dp: the number of decimal places to consider + """ + attack_costs = [] + # note: key does not need to be specified est vs est.keys() + for key in est: + attack_costs.append(est[key]["rop"]) + # get the security level correct to 'dp' decimal places + security_level = round(log2(min(attack_costs)), 
dp) + + return security_level + + +def inequality(x, y): + """ A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier + :param x: the LHS of the inequality + :param y: the RHS of the inequality + """ + if x <= y: + return 1 + + if x > y: + return -1 + + +def automated_param_select_n(params, target_security=128): + """ A function used to generate the smallest value of n which allows for + target_security bits of security, for the input values of (params.Xe.stddev,params.q) + :param params: the standard deviation of the error + :param target_security: the target number of bits of security, 128 is default + + EXAMPLE: + sage: X = automated_param_select_n(Kyber512, target_security = 128) + sage: X + 456 + """ + + # get an estimate based on the prev. model + print("n = {}".format(params.n)) + n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q)) + # n_start = max(n_start, 450) + # TODO: think about throwing an error if the required n < 450 + + params = params.updated(n=n_start) + costs2 = estimate(params) + security_level = get_security_level(costs2, 2) + z = inequality(security_level, target_security) + + # we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing) + while z * security_level < z * target_security: + # TODO: fill in this case! 
For n > 1024 we only need to consider every 256 (optimization) + params = params.updated(n = params.n + z * 8) + costs = estimate(params) + security_level = get_security_level(costs, 2) + + if -1 * params.Xe.stddev > 0: + print("target security level is unattainable") + break + + # final estimate (we went too far in the above loop) + if security_level < target_security: + # we make n larger + print("we make n larger") + params = params.updated(n=params.n + 8) + costs = estimate(params) + security_level = get_security_level(costs, 2) + + print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n, + log2(params.Xe.stddev), + log2(params.q), + security_level)) + + if security_level < target_security: + params.updated(n=None) + + return params, security_level + + +def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"): + """ + :param params_in: a initial set of LWE parameters + :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters + :param target_security_levels: a list of the target number of bits of security, 128 is default + :param name: a name to save the file + """ + + (sd_min, sd_max) = sd_range + for lam in target_security_levels: + for sd in range(sd_min, sd_max + 1): + print("run for {}".format(lam, sd)) + Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd) + (params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam) + + try: + results = load("{}.sobj".format(name)) + except: + results = dict() + results["{}".format(lam)] = [] + + results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec)) + save(results, "{}.sobj".format(name)) + + return results + + +def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"): + """ + The top level function which we use to run the experiment + + :param 
sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters + :param target_security_levels: a list of the target number of bits of security, 128 is default + :param name: a name to save the file + """ + if __name__ == '__main__': + + D = ND.DiscreteGaussian + vals = range(sd_range[0], sd_range[1]) + pool = multiprocessing.Pool(2) + init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params') + inputs = [(init_params, (val, val), target_security_levels, name) for val in vals] + res = pool.starmap(generate_parameter_matrix, inputs) + + return "done" + + +# The script runs the following commands +import sys +# grab values of the command-line input arguments +a = int(sys.argv[1]) +b = int(sys.argv[2]) +c = int(sys.argv[3]) +# run the code +generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a)) + +from estimator_new import * +from sage.all import oo, save, load +from math import log2 +import multiprocessing + + +def old_models(security_level, sd, logq=32): + """ + Use the old model as a starting point for the data gathering step + :param security_level: the security level under consideration + :param sd : the standard deviation of the LWE error distribution Xe + :param logq : the (base 2 log) value of the LWE modulus q + """ + + def evaluate_model(a, b, stddev=sd): + return (stddev - b)/a + + models = dict() + + models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq) + models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq) + models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq) + models["128"] = (-0.026361288425133814, 2.0014671315214696 + logq) + models["144"] = (-0.023744534465622812, 2.1710601038230712 + logq) + models["160"] = (-0.021667220727651954, 2.3565507936475476 + logq) + models["176"] = (-0.019947662046189942, 2.5109588704235803 + logq) + models["192"] = (-0.018552804646747204, 2.7168913723130816 + logq) + models["208"] = 
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
    """Retrieve an estimate using the Lattice Estimator for a set of parameters.

    :param params: the input LWE parameters
    :param red_cost_model: the lattice reduction cost model
    :param skip: names of attacks to exclude from the estimate
    :return: the Lattice Estimator output (mapping attack name -> cost dict)
    """
    est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)

    return est


def get_security_level(est, dp=2):
    """Get the security level lambda from a Lattice Estimator output.

    :param est: the Lattice Estimator output
    :param dp: the number of decimal places to round to
    :return: log2 of the cheapest attack cost, rounded to dp places
    """
    # note: iterating est directly iterates its keys
    attack_costs = [est[key]["rop"] for key in est]
    # the security level is determined by the cheapest attack
    security_level = round(log2(min(attack_costs)), dp)

    return security_level


def inequality(x, y):
    """Compress the conditions x <= y and x > y into a single +/-1 multiplier.

    :param x: the LHS of the inequality
    :param y: the RHS of the inequality
    :return: 1 if x <= y, otherwise -1
    """
    return 1 if x <= y else -1


def automated_param_select_n(params, target_security=128):
    """Find the smallest n reaching target_security bits of security for the
    given values of (params.Xe.stddev, params.q).

    :param params: an initial LWE parameter set (n is adjusted by the search)
    :param target_security: the target number of bits of security, 128 is default
    :return: a pair (params, security_level); params.n is None when the target
             security level could not be reached

    EXAMPLE:
    sage: X = automated_param_select_n(Kyber512, target_security = 128)
    sage: X
    456
    """
    # get a starting estimate for n based on the previous generation of models
    print("n = {}".format(params.n))
    n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
    # n_start = max(n_start, 450)
    # TODO: think about throwing an error if the required n < 450

    params = params.updated(n=n_start)
    costs2 = estimate(params)
    security_level = get_security_level(costs2, 2)
    # z = +1 when n must grow to reach the target, -1 when n can shrink
    z = inequality(security_level, target_security)

    # we keep n > 2 * target_security as a rough baseline for mitm security
    # (on binary key guessing)
    while z * security_level < z * target_security:
        # TODO: fill in this case! For n > 1024 we only need to consider
        # every 256 (optimization)
        params = params.updated(n=params.n + z * 8)
        costs = estimate(params)
        security_level = get_security_level(costs, 2)

        # NOTE(review): stddev is non-negative, so this guard can never fire;
        # confirm what condition was actually intended here
        if -1 * params.Xe.stddev > 0:
            print("target security level is unattainable")
            break

    # final estimate (the loop may have overshot while shrinking n)
    if security_level < target_security:
        # we make n larger
        print("we make n larger")
        params = params.updated(n=params.n + 8)
        costs = estimate(params)
        security_level = get_security_level(costs, 2)

    print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, "
          "with a security level of {}-bits".format(params.n,
                                                    log2(params.Xe.stddev),
                                                    log2(params.q),
                                                    security_level))

    if security_level < target_security:
        # BUG FIX: updated() returns a new object; the result was previously
        # discarded, so a failed search was returned as if it had succeeded
        params = params.updated(n=None)

    return params, security_level


def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
    """Generate parameters for each sd in sd_range at each target security
    level, appending the results to <name>.sobj after every run.

    :param params_in: an initial set of LWE parameters
    :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
    :param target_security_levels: a list of the target number of bits of security, 128 is default
    :param name: a name to save the file
    :return: the accumulated results dictionary
    """
    (sd_min, sd_max) = sd_range
    for lam in target_security_levels:
        for sd in range(sd_min, sd_max + 1):
            # BUG FIX: the format string had one placeholder for two arguments
            print("run for lambda = {}, sd = {}".format(lam, sd))
            Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
            (params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)

            try:
                results = load("{}.sobj".format(name))
            except (IOError, OSError):
                # first run for this name: start a fresh results table
                # (was a bare except, which hid unrelated failures)
                results = dict()
            # guard against a results file that predates this security level
            results.setdefault("{}".format(lam), [])

            results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
            save(results, "{}.sobj".format(name))

    return results


def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
    """The top level function which we use to run the experiment, fanning
    generate_parameter_matrix out over a small process pool.

    :param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
    :param target_security_levels: a list of the target number of bits of security, 128 is default
    :param name: a name to save the file
    :return: the string "done"
    """
    # the __main__ guard stops worker processes (which re-import this module)
    # from recursively spawning their own pools
    if __name__ == '__main__':
        D = ND.DiscreteGaussian
        vals = range(sd_range[0], sd_range[1])
        pool = multiprocessing.Pool(2)
        init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
        inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
        res = pool.starmap(generate_parameter_matrix, inputs)

    return "done"


# The script runs the following commands
import sys
# grab values of the command-line input arguments
a = int(sys.argv[1])  # target security level
b = int(sys.argv[2])  # minimum log2(sd)
c = int(sys.argv[3])  # maximum log2(sd)
# run the code
generate_zama_curves64(sd_range=(b, c), target_security_levels=[a], name="{}".format(a))

import numpy as np
# BUG FIX: ceil (used by verify_curve below) was missing from this import
from sage.all import save, load, ceil


def sort_data(security_level):
    """Load <security_level>.sobj and sort its rows by log2(stddev).

    :param security_level: the security level whose data file is loaded
    :return: the data dictionary with its row list sorted by the sd column
    """
    from operator import itemgetter

    # step 1. load the data
    X = load("{}.sobj".format(security_level))

    # step 2. sort by SD (column index 2)
    x = sorted(X["{}".format(security_level)], key=itemgetter(2))

    # step 3. replace the sorted value
    X["{}".format(security_level)] = x

    return X


def generate_curve(security_level):
    """Fit a linear model sd = a * n + b to the data for one security level.

    :param security_level: the security level whose data is interpolated
    :return: the pair (a, b) of fitted line coefficients
    """
    # step 1. get the data
    X = sort_data(security_level)

    # step 2. group the n and sigma data into lists
    N = []
    SD = []
    for x in X["{}".format(security_level)]:
        N.append(x[0])
        # the +0.5 shifts the fit to sit above the raw data points
        SD.append(x[2] + 0.5)

    # step 3. perform interpolation and return coefficients
    (a, b) = np.polyfit(N, SD, 1)

    return a, b


def verify_curve(security_level, a=None, b=None):
    """Check the fitted model (a, b) against every data point for a level.

    :param security_level: the security level to verify
    :param a: the slope of the fitted line
    :param b: the intercept of the fitted line
    :return: ("PASS", n_min) on success, ("FAIL", n) at the first failing n
    """
    # step 1. get the table; rows are sorted by sd ascending, so the first
    # row carries the largest n
    X = sort_data(security_level)
    table = X["{}".format(security_level)]
    n_max = table[0][0]

    # step 2. a function to get model values
    def f_model(a, b, n):
        return ceil(a * n + b)

    # step 3. a function to get table values
    def f_table(table, n):
        # rows have n descending: take the first row whose n is <= the query.
        # BUG FIX: j defaults to the last row instead of being left unbound
        # when every row has a larger n
        j = len(table) - 1
        for i in range(len(table)):
            if n >= table[i][0]:
                j = i
                break

        # now j is the correct index, we return the corresponding sd
        return table[j][2]

    # step 4. for each n, check whether we satisfy the table
    n_min = max(2 * security_level, 450, table[-1][0])
    print(n_min)
    print(n_max)

    for n in range(n_max, n_min, -1):
        model_sd = f_model(a, b, n)
        table_sd = f_table(table, n)
        print(n, table_sd, model_sd, model_sd >= table_sd)

        if table_sd > model_sd:
            print("MODEL FAILS at n = {}".format(n))
            # BUG FIX: return a (status, n) pair; a bare "FAIL" string made
            # the caller's res[0], res[1] silently read 'F' and 'A'
            return "FAIL", n

    return "PASS", n_min


def generate_and_verify(security_levels, log_q, name="verified_curves"):
    """Generate and verify a curve for each security level, saving the data.

    :param security_levels: the security levels to process
    :param log_q: log2 of the modulus, subtracted from each intercept
    :param name: file name stem for the saved results
    :return: a list of (a, b - log_q, security_level, status, n) tuples
    """
    data = []

    for sec in security_levels:
        print("WE GO FOR {}".format(sec))
        # generate the model for security level sec
        (a_sec, b_sec) = generate_curve(sec)
        # verify the model for security level sec
        res = verify_curve(sec, a_sec, b_sec)
        # append the information into a list
        data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
        save(data, "{}.sobj".format(name))

    return data


# To verify the curves we use
generate_and_verify([80, 96, 112, 128, 144, 160, 176, 192, 256], log_q=64)
import numpy as np
from sage.all import save, load, ceil


def sort_data(security_level):
    """Load <security_level>.sobj and sort its rows by log2(stddev).

    :param security_level: the security level whose data file is loaded
    :return: the data dictionary with its row list sorted by the sd column
    """
    from operator import itemgetter

    # step 1. load the data
    X = load("{}.sobj".format(security_level))

    # step 2. sort by SD (column index 2)
    x = sorted(X["{}".format(security_level)], key=itemgetter(2))

    # step 3. replace the sorted value
    X["{}".format(security_level)] = x

    return X


def generate_curve(security_level):
    """Fit a linear model sd = a * n + b to the data for one security level.

    :param security_level: the security level whose data is interpolated
    :return: the pair (a, b) of fitted line coefficients
    """
    # step 1. get the data
    X = sort_data(security_level)

    # step 2. group the n and sigma data into lists
    N = []
    SD = []
    for x in X["{}".format(security_level)]:
        N.append(x[0])
        # the +0.5 shifts the fit to sit above the raw data points
        SD.append(x[2] + 0.5)

    # step 3. perform interpolation and return coefficients
    (a, b) = np.polyfit(N, SD, 1)

    return a, b


def verify_curve(security_level, a=None, b=None):
    """Check the fitted model (a, b) against every data point for a level.

    :param security_level: the security level to verify
    :param a: the slope of the fitted line
    :param b: the intercept of the fitted line
    :return: ("PASS", n_min) on success, ("FAIL", n) at the first failing n
    """
    # step 1. get the table; rows are sorted by sd ascending, so the first
    # row carries the largest n (the unused sd_max local was dropped)
    X = sort_data(security_level)
    table = X["{}".format(security_level)]
    n_max = table[0][0]

    # step 2. a function to get model values
    def f_model(a, b, n):
        return ceil(a * n + b)

    # step 3. a function to get table values
    def f_table(table, n):
        # rows have n descending: take the first row whose n is <= the query.
        # BUG FIX: j defaults to the last row instead of being left unbound
        # when every row has a larger n
        j = len(table) - 1
        for i in range(len(table)):
            if n >= table[i][0]:
                j = i
                break

        # now j is the correct index, we return the corresponding sd
        return table[j][2]

    # step 4. for each n, check whether we satisfy the table
    n_min = max(2 * security_level, 450, table[-1][0])
    print(n_min)
    print(n_max)

    for n in range(n_max, n_min, -1):
        model_sd = f_model(a, b, n)
        table_sd = f_table(table, n)
        print(n, table_sd, model_sd, model_sd >= table_sd)

        if table_sd > model_sd:
            print("MODEL FAILS at n = {}".format(n))
            # BUG FIX: return a (status, n) pair; a bare "FAIL" string made
            # the caller's res[0], res[1] silently read 'F' and 'A'
            return "FAIL", n

    return "PASS", n_min


def generate_and_verify(security_levels, log_q, name="verified_curves"):
    """Generate and verify a curve for each security level, saving the data.

    :param security_levels: the security levels to process
    :param log_q: log2 of the modulus, subtracted from each intercept
    :param name: file name stem for the saved results
    :return: a list of (a, b - log_q, security_level, status, n) tuples
    """
    data = []

    for sec in security_levels:
        print("WE GO FOR {}".format(sec))
        # generate the model for security level sec
        (a_sec, b_sec) = generate_curve(sec)
        # verify the model for security level sec
        res = verify_curve(sec, a_sec, b_sec)
        # append the information into a list
        data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
        save(data, "{}.sobj".format(name))

    return data


if __name__ == "__main__":
    # guard the script entry so importing this module has no side effects;
    # `sage verify_curves.py` behaves exactly as before
    data = generate_and_verify(
        [80, 96, 112, 128, 144, 160, 176, 192, 256], log_q=64)
    print(data)