diff --git a/scripts.py b/scripts.py index 515e83461..d208a393a 100644 --- a/scripts.py +++ b/scripts.py @@ -1,8 +1,11 @@ import matplotlib.pyplot as plt -import numpy as np from sage.stats.distributions.discrete_gaussian_lattice import DiscreteGaussianDistributionIntegerSampler from concrete_params import concrete_LWE_params, concrete_RLWE_params +import numpy as np from pytablewriter import MarkdownTableWriter +from hybrid_decoding import parameter_search +from random import uniform +from mpl_toolkits import mplot3d # easier to just load the estimator load("estimator.py") @@ -34,7 +37,7 @@ def get_security_level(estimate, decimal_places = 2): """ Function to get the security level from an LWE Estimator output, i.e. returns only the bit-security level (without the attack params) :param estimate: the input estimate - :param decimal_places: the number of decimal places"%.2f" % + :param decimal_places: the number of decimal places EXAMPLE: sage: x = estimate_lwe(n = 256, q = 2**32, alpha = RR(8/2**32)) @@ -97,7 +100,7 @@ def get_all_security_levels(params): sd = 2 ** (x["sd"]) * q alpha = sqrt(2 * pi) * sd / RR(q) secret_distribution = (0, 1) - # assume access to an infinite number of papers + # assume access to an infinite number of samples m = oo for model in cost_models: @@ -105,14 +108,51 @@ def get_all_security_levels(params): model = model[0] except: model = model - estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, - reduction_cost_model=model, m=oo, skip = {"bkw","dec","arora-gb","mitm"}) + estimate = parameter_search(mitm = True, reduction_cost_model = est.BKZ.sieve, n = n, q = q, alpha = alpha, m = m, secret_distribution = secret_distribution) results.append(get_security_level(estimate)) RESULTS.append(results) return RESULTS +def get_hybrid_security_levels(params): + """ A function which gets the security levels of a collection of TFHE parameters, + using the four cost models: classical, quantum, classical_conservative, and + quantum_conservative + :param params: a dictionary of LWE parameter sets (see concrete_params) + + EXAMPLE: + sage: X = get_all_security_levels(concrete_LWE_params) + sage: X + [['LWE128_256', + 126.692189756144, + 117.566189756144, + 98.6960000000000, + 89.5700000000000], ...] + """ + + RESULTS = [] + + for param in params: + + results = [param] + x = params["{}".format(param)] + n = x["n"] * x["k"] + q = 2 ** 32 + sd = 2 ** (x["sd"]) * q + alpha = sqrt(2 * pi) * sd / RR(q) + secret_distribution = (0, 1) + # assume access to an infinite number of papers + m = oo + + model = est.BKZ.sieve + estimate = parameter_search(mitm = True, reduction_cost_model = est.BKZ.sieve, n = n, q = q, alpha = alpha, m = m, secret_distribution = secret_distribution) + results.append(get_security_level(estimate)) + + RESULTS.append(results) + + return RESULTS + def latexit(results): """ @@ -210,7 +250,7 @@ def automated_param_select_n(sd, n=None, q=2 ** 32, reduction_cost_model=BKZ.sie return ZZ(n) -def automated_param_select_sd(n, sd=None, q=2 ** 32, reduction_cost_model=BKZ.sieve, secret_distribution=(0, 1), +def automated_param_select_sd(n, sd=None, q=2**32, reduction_cost_model=BKZ.sieve, secret_distribution=(0, 1), target_security=128): """ A function used to generate the smallest value of sd which allows for target_security bits of security, for the input values of (n,q) @@ -231,32 +271,55 @@ def automated_param_select_sd(n, sd=None, q=2 ** 32, reduction_cost_model=BKZ.si # pick some random sd which gets us close (based on concrete_LWE_params) sd = round(n * 80 / (target_security * (-25))) - sd_ = 2 ** sd * q + # make sure sd satisfies q * sd > 1 + sd = max(sd, -(log(q,2) - 2)) + + sd_ = (2 ** sd) * q alpha = sqrt(2 * pi) * sd_ / RR(q) # initial estimate, to determine if we are above or below the target security level print("estimating for n, q, sd = {}".format(log(sd_,2))) - estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, - reduction_cost_model=reduction_cost_model, m=oo, skip = {"bkw","dec","arora-gb","mitm"}) + try: + estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, + reduction_cost_model=reduction_cost_model, m=oo, + skip={"bkw", "dec", "arora-gb", "mitm"}) + except: + estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, + reduction_cost_model=reduction_cost_model, m=oo, + skip={"bkw", "dec", "arora-gb", "mitm", "dual"}) security_level = get_security_level(estimate) z = inequality(security_level, target_security) - while z * security_level < z * target_security: + while z * security_level < z * target_security and sd > -log(q,2): sd += z * 1 - sd_ = 2 ** sd * q + sd_ = (2 ** sd) * q alpha = sqrt(2 * pi) * sd_ / RR(q) print("estimating for n, q, sd = {}".format(log(sd_,2))) - estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, + try: + estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, reduction_cost_model=reduction_cost_model, m=oo, skip = {"bkw","dec","arora-gb","mitm"}) + except: + estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, + reduction_cost_model=reduction_cost_model, m=oo, + skip={"bkw", "dec", "arora-gb", "mitm", "dual"}) security_level = get_security_level(estimate) + if (-1 * sd > log(q, 2)): + print("target security level is unatainable") + break + # final estimate (we went too far in the above loop) if security_level < target_security: sd -= z * 1 - sd_ = 2 ** sd * q + sd_ = (2 ** sd) * q alpha = sqrt(2 * pi) * sd_ / RR(q) - estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, - reduction_cost_model=reduction_cost_model, m=oo) + try: + estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, + reduction_cost_model=reduction_cost_model, m=oo, skip = {"bkw","dec","arora-gb","mitm"}) + except: + estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution, + reduction_cost_model=reduction_cost_model, m=oo, + skip={"bkw", "dec", "arora-gb", "mitm", "dual"}) security_level = get_security_level(estimate) print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(n, @@ -268,7 +331,7 @@ def automated_param_select_sd(n, sd=None, q=2 ** 32, reduction_cost_model=BKZ.si return sd -def generate_parameter_matrix(n_range, sd=None, q=2 ** 32, reduction_cost_model=BKZ.sieve, +def generate_parameter_matrix(n_range, sd=None, q=2**32, reduction_cost_model=BKZ.sieve, secret_distribution=(0, 1), target_security=128): """ :param n_range: a tuple (n_min, n_max) giving the values of n for which to generate parameters @@ -300,12 +363,10 @@ def generate_parameter_matrix(n_range, sd=None, q=2 ** 32, reduction_cost_model= sd_ = sd RESULTS.append((n, q, sd)) - - return RESULTS -def generate_parameter_step(results): +def generate_parameter_step(results, label = None, torus_sd = True): """ Plot results :param results: an output of generate_parameter_matrix @@ -323,14 +384,18 @@ def generate_parameter_step(results): for (n, q, sd) in results: N.append(n) - SD.append(sd) + if torus_sd: + SD.append(sd) + else: + SD.append(sd + log(q,2)) - plt.scatter(N, SD) + plt.plot(N, SD, label = label) + plt.legend(loc = "upper right") return plt -def test_rounded_gaussian(sigma, number_samples): +def test_rounded_gaussian(sigma, number_samples, q = None): """ TODO: actually use a _rounded_ gaussian to match Concrete @@ -351,8 +416,10 @@ def test_rounded_gaussian(sigma, number_samples): samples = [] for i in range(number_samples): - samples.append(D()) - + if q: + samples.append(D() % q) + else: + samples.append(D()) # now create a histogram hist = [] for val in set(samples): @@ -363,5 +430,109 @@ def test_rounded_gaussian(sigma, number_samples): return hist +def test_uniform(number_samples, q): + """ + TODO: actually use a _rounded_ gaussian to match Concrete + + A function which simulates sampling from a Discrete Gaussian distribution + :param sigma: the standard deviation + :param number_samples: the number of samples to draw + + returns: a list of (value, count) pairs (essentially a histogram) + + EXAMPLE: + + sage: X = test_rounded_gaussian(2/3, 100000) + sage: X + [(-3, 2), (-2, 714), (-1, 19495), (0, 59658), (1, 19452), (2, 678), (3, 1)] + """ + + samples = [] + + for i in range(number_samples): + samples.append(round(uniform(0, q))) + # now create a histogram + hist = [] + for val in set(samples): + hist.append((val, samples.count(val))) + + # sort (values) + hist.sort(key=lambda x: x[0]) + return hist + +# dual bug example +# n = 256; q = 2**32; sd = 2**(-4); reduction_cost_model = BKZ.sieve +# _ = estimate_lwe(n, alpha, q, reduction_cost_model) + +def test_params(n, q, sd, secret_distribution): + + sd = sd * q + alpha = RR(sqrt(2*pi) * sd / q) + + est = estimate_lwe(n, alpha, q, secret_distribution = secret_distribution, reduction_cost_model = BKZ.sieve, skip = ("arora-gb", "bkw", "mitm", "dec")) + + return est + +def generate_iso_lines(N = [256, 2048], SD = [0, 32], q = 2**32): + + RESULTS = [] + + for n in range(N[0], N[1] + 1, 1): + for sd in range(SD[0], SD[1] + 1, 1): + sd = 2**sd + alpha = sqrt(2*pi) * sd / q + try: + est = estimate_lwe(n, alpha, q, secret_distribution = (0,1), reduction_cost_model = BKZ.sieve, skip = ("bkw", "mitm", "arora-gb", "dec")) + est = get_security_level(est, 2) + except: + est = estimate_lwe(n, alpha, q, secret_distribution = (0,1), reduction_cost_model = BKZ.sieve, skip = ("bkw", "mitm", "arora-gb", "dual", "dec")) + est = get_security_level(est, 2) + RESULTS.append((n, sd, est)) + + return RESULTS + +def plot_iso_lines(results): + + x1 = [] + x2 = [] + x3 = [] + + for z in results: + x1.append(z[0]) + # use log(q) + # use -ve values to match Pascal's diagram + x2.append(-1 * log(z[1],2)) + x3.append(z[3]) + + plt.scatter(x1, x2, c = x3) + plt.colorbar() + + return plt +def test_multiple_sd(n, q, secret_distribution, reduction_cost_model, split = 33, m = oo): + est = [] + Y = [] + for sd_ in np.linspace(0,32,split): + Y.append(sd_) + sd = (2** (-1 * sd_))* q + alpha = sqrt(2*pi) * sd / q + try: + es = estimate_lwe(n=512, alpha=alpha, q=q, secret_distribution=(0, 1), reduction_cost_model = reduction_cost_model, + skip=("bkw", "mitm", "dec", "arora-gb"), m = m) + except: + print("except") + es = estimate_lwe(n=512, alpha=alpha, q=q, secret_distribution=(0, 1), reduction_cost_model = reduction_cost_model, + skip=("bkw", "mitm", "dec", "arora-gb", "dual"), m = m) + est.append(get_security_level(es,2)) + + return est, Y + + +def estimate_lwe_sd(n, sd, q, secret_distribution, reduction_cost_model, skip = ("bkw","mitm","dec","arora-gb"), m = oo): + + alpha = sqrt(2*pi) * sd/q + x = estimate_lwe(n = n, alpha = alpha , q = q, m = m, secret_distribution = secret_distribution, reduction_cost_model = reduction_cost_model, skip = skip) + + return x +