update scripts

This commit is contained in:
Ben
2021-06-17 13:33:52 +01:00
parent 8650a0969a
commit 72f0af5cf4

View File

@@ -1,8 +1,11 @@
import matplotlib.pyplot as plt
import numpy as np
from sage.stats.distributions.discrete_gaussian_lattice import DiscreteGaussianDistributionIntegerSampler
from concrete_params import concrete_LWE_params, concrete_RLWE_params
import numpy as np
from pytablewriter import MarkdownTableWriter
from hybrid_decoding import parameter_search
from random import uniform
from mpl_toolkits import mplot3d
# easier to just load the estimator
load("estimator.py")
@@ -34,7 +37,7 @@ def get_security_level(estimate, decimal_places = 2):
""" Function to get the security level from an LWE Estimator output,
i.e. returns only the bit-security level (without the attack params)
:param estimate: the input estimate
:param decimal_places: the number of decimal places"%.2f" %
:param decimal_places: the number of decimal places
EXAMPLE:
sage: x = estimate_lwe(n = 256, q = 2**32, alpha = RR(8/2**32))
@@ -97,7 +100,7 @@ def get_all_security_levels(params):
sd = 2 ** (x["sd"]) * q
alpha = sqrt(2 * pi) * sd / RR(q)
secret_distribution = (0, 1)
# assume access to an infinite number of papers
# assume access to an infinite number of samples
m = oo
for model in cost_models:
@@ -105,14 +108,51 @@ def get_all_security_levels(params):
model = model[0]
except:
model = model
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=model, m=oo, skip = {"bkw","dec","arora-gb","mitm"})
estimate = parameter_search(mitm = True, reduction_cost_model = est.BKZ.sieve, n = n, q = q, alpha = alpha, m = m, secret_distribution = secret_distribution)
results.append(get_security_level(estimate))
RESULTS.append(results)
return RESULTS
def get_hybrid_security_levels(params):
""" A function which gets the security levels of a collection of TFHE parameters,
using the four cost models: classical, quantum, classical_conservative, and
quantum_conservative
:param params: a dictionary of LWE parameter sets (see concrete_params)
EXAMPLE:
sage: X = get_all_security_levels(concrete_LWE_params)
sage: X
[['LWE128_256',
126.692189756144,
117.566189756144,
98.6960000000000,
89.5700000000000], ...]
"""
RESULTS = []
for param in params:
results = [param]
x = params["{}".format(param)]
n = x["n"] * x["k"]
q = 2 ** 32
sd = 2 ** (x["sd"]) * q
alpha = sqrt(2 * pi) * sd / RR(q)
secret_distribution = (0, 1)
# assume access to an infinite number of papers
m = oo
model = est.BKZ.sieve
estimate = parameter_search(mitm = True, reduction_cost_model = est.BKZ.sieve, n = n, q = q, alpha = alpha, m = m, secret_distribution = secret_distribution)
results.append(get_security_level(estimate))
RESULTS.append(results)
return RESULTS
def latexit(results):
"""
@@ -210,7 +250,7 @@ def automated_param_select_n(sd, n=None, q=2 ** 32, reduction_cost_model=BKZ.sie
return ZZ(n)
def automated_param_select_sd(n, sd=None, q=2 ** 32, reduction_cost_model=BKZ.sieve, secret_distribution=(0, 1),
def automated_param_select_sd(n, sd=None, q=2**32, reduction_cost_model=BKZ.sieve, secret_distribution=(0, 1),
target_security=128):
""" A function used to generate the smallest value of sd which allows for
target_security bits of security, for the input values of (n,q)
@@ -231,32 +271,55 @@ def automated_param_select_sd(n, sd=None, q=2 ** 32, reduction_cost_model=BKZ.si
# pick some random sd which gets us close (based on concrete_LWE_params)
sd = round(n * 80 / (target_security * (-25)))
sd_ = 2 ** sd * q
# make sure sd satisfies q * sd > 1
sd = max(sd, -(log(q,2) - 2))
sd_ = (2 ** sd) * q
alpha = sqrt(2 * pi) * sd_ / RR(q)
# initial estimate, to determine if we are above or below the target security level
print("estimating for n, q, sd = {}".format(log(sd_,2)))
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo, skip = {"bkw","dec","arora-gb","mitm"})
try:
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo,
skip={"bkw", "dec", "arora-gb", "mitm"})
except:
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo,
skip={"bkw", "dec", "arora-gb", "mitm", "dual"})
security_level = get_security_level(estimate)
z = inequality(security_level, target_security)
while z * security_level < z * target_security:
while z * security_level < z * target_security and sd > -log(q,2):
sd += z * 1
sd_ = 2 ** sd * q
sd_ = (2 ** sd) * q
alpha = sqrt(2 * pi) * sd_ / RR(q)
print("estimating for n, q, sd = {}".format(log(sd_,2)))
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
try:
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo, skip = {"bkw","dec","arora-gb","mitm"})
except:
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo,
skip={"bkw", "dec", "arora-gb", "mitm", "dual"})
security_level = get_security_level(estimate)
if (-1 * sd > log(q, 2)):
print("target security level is unatainable")
break
# final estimate (we went too far in the above loop)
if security_level < target_security:
sd -= z * 1
sd_ = 2 ** sd * q
sd_ = (2 ** sd) * q
alpha = sqrt(2 * pi) * sd_ / RR(q)
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo)
try:
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo, skip = {"bkw","dec","arora-gb","mitm"})
except:
estimate = estimate_lwe(n, alpha, q, secret_distribution=secret_distribution,
reduction_cost_model=reduction_cost_model, m=oo,
skip={"bkw", "dec", "arora-gb", "mitm", "dual"})
security_level = get_security_level(estimate)
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(n,
@@ -268,7 +331,7 @@ def automated_param_select_sd(n, sd=None, q=2 ** 32, reduction_cost_model=BKZ.si
return sd
def generate_parameter_matrix(n_range, sd=None, q=2 ** 32, reduction_cost_model=BKZ.sieve,
def generate_parameter_matrix(n_range, sd=None, q=2**32, reduction_cost_model=BKZ.sieve,
secret_distribution=(0, 1), target_security=128):
"""
:param n_range: a tuple (n_min, n_max) giving the values of n for which to generate parameters
@@ -300,12 +363,10 @@ def generate_parameter_matrix(n_range, sd=None, q=2 ** 32, reduction_cost_model=
sd_ = sd
RESULTS.append((n, q, sd))
return RESULTS
def generate_parameter_step(results):
def generate_parameter_step(results, label = None, torus_sd = True):
"""
Plot results
:param results: an output of generate_parameter_matrix
@@ -323,14 +384,18 @@ def generate_parameter_step(results):
for (n, q, sd) in results:
N.append(n)
SD.append(sd)
if torus_sd:
SD.append(sd)
else:
SD.append(sd + log(q,2))
plt.scatter(N, SD)
plt.plot(N, SD, label = label)
plt.legend(loc = "upper right")
return plt
def test_rounded_gaussian(sigma, number_samples):
def test_rounded_gaussian(sigma, number_samples, q = None):
"""
TODO: actually use a _rounded_ gaussian to match Concrete
@@ -351,8 +416,10 @@ def test_rounded_gaussian(sigma, number_samples):
samples = []
for i in range(number_samples):
samples.append(D())
if q:
samples.append(D() % q)
else:
samples.append(D())
# now create a histogram
hist = []
for val in set(samples):
@@ -363,5 +430,109 @@ def test_rounded_gaussian(sigma, number_samples):
return hist
def test_uniform(number_samples, q):
"""
TODO: actually use a _rounded_ gaussian to match Concrete
A function which simulates sampling from a Discrete Gaussian distribution
:param sigma: the standard deviation
:param number_samples: the number of samples to draw
returns: a list of (value, count) pairs (essentially a histogram)
EXAMPLE:
sage: X = test_rounded_gaussian(2/3, 100000)
sage: X
[(-3, 2), (-2, 714), (-1, 19495), (0, 59658), (1, 19452), (2, 678), (3, 1)]
"""
samples = []
for i in range(number_samples):
samples.append(round(uniform(0, q)))
# now create a histogram
hist = []
for val in set(samples):
hist.append((val, samples.count(val)))
# sort (values)
hist.sort(key=lambda x: x[0])
return hist
# dual bug example
# n = 256; q = 2**32; sd = 2**(-4); reduction_cost_model = BKZ.sieve
# _ = estimate_lwe(n, alpha, q, reduction_cost_model)
def test_params(n, q, sd, secret_distribution):
sd = sd * q
alpha = RR(sqrt(2*pi) * sd / q)
est = estimate_lwe(n, alpha, q, secret_distribution = secret_distribution, reduction_cost_model = BKZ.sieve, skip = ("arora-gb", "bkw", "mitm", "dec"))
return est
def generate_iso_lines(N = [256, 2048], SD = [0, 32], q = 2**32):
RESULTS = []
for n in range(N[0], N[1] + 1, 1):
for sd in range(SD[0], SD[1] + 1, 1):
sd = 2**sd
alpha = sqrt(2*pi) * sd / q
try:
est = estimate_lwe(n, alpha, q, secret_distribution = (0,1), reduction_cost_model = BKZ.sieve, skip = ("bkw", "mitm", "arora-gb", "dec"))
est = get_security_level(est, 2)
except:
est = estimate_lwe(n, alpha, q, secret_distribution = (0,1), reduction_cost_model = BKZ.sieve, skip = ("bkw", "mitm", "arora-gb", "dual", "dec"))
est = get_security_level(est, 2)
RESULTS.append((n, sd, est))
return RESULTS
def plot_iso_lines(results):
x1 = []
x2 = []
x3 = []
for z in results:
x1.append(z[0])
# use log(q)
# use -ve values to match Pascal's diagram
x2.append(-1 * log(z[1],2))
x3.append(z[3])
plt.scatter(x1, x2, c = x3)
plt.colorbar()
return plt
def test_multiple_sd(n, q, secret_distribution, reduction_cost_model, split = 33, m = oo):
est = []
Y = []
for sd_ in np.linspace(0,32,split):
Y.append(sd_)
sd = (2** (-1 * sd_))* q
alpha = sqrt(2*pi) * sd / q
try:
es = estimate_lwe(n=512, alpha=alpha, q=q, secret_distribution=(0, 1), reduction_cost_model = reduction_cost_model,
skip=("bkw", "mitm", "dec", "arora-gb"), m = m)
except:
print("except")
es = estimate_lwe(n=512, alpha=alpha, q=q, secret_distribution=(0, 1), reduction_cost_model = reduction_cost_model,
skip=("bkw", "mitm", "dec", "arora-gb", "dual"), m = m)
est.append(get_security_level(es,2))
return est, Y
def estimate_lwe_sd(n, sd, q, secret_distribution, reduction_cost_model, skip = ("bkw","mitm","dec","arora-gb"), m = oo):
alpha = sqrt(2*pi) * sd/q
x = estimate_lwe(n = n, alpha = alpha , q = q, m = m, secret_distribution = secret_distribution, reduction_cost_model = reduction_cost_model, skip = skip)
return x