tidy finalized script

This commit is contained in:
Ben
2022-06-21 17:35:12 +01:00
parent 7feb1f599e
commit 2a304c4f60

View File

@@ -1,11 +1,7 @@
import gc
import multiprocessing
from estimator_new import *
from sage.all import oo, save, load
from math import log2
import gc
from multiprocessing import *
import multiprocessing
def old_models(security_level, sd, logq=32):
@@ -49,6 +45,7 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
"""
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
@@ -64,6 +61,7 @@ def get_security_level(est, dp=2):
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
security_level = round(log2(min(attack_costs)), dp)
return security_level
@@ -91,47 +89,26 @@ def automated_param_select_n(params, target_security=128):
456
"""
# get an initial estimate
# costs = estimate(params)
# security_level = get_security_level(costs, 2)
# determine if we are above or below the target security level
# z = inequality(security_level, target_security)
# get an estimate based on the prev. model
print("n = {}".format(params.n))
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
# TODO -- is this how we want to deal with the small n issue? Shouldn't the model have this baked in?
# we want to start no lower than n = 450
n_start = max(n_start, 450)
# TODO: think about throwing an error if the required n < 450
#if n_start > 1024:
# we only consider powers-of-two for now, in this range
# n_log = log2(n_start)
# n_start = 2**round(n_log)
print("n_start = {}".format(n_start))
params = params.updated(n=n_start)
print(params)
#
costs2 = estimate(params)
security_level = get_security_level(costs2, 2)
costs2 = None
z = inequality(security_level, target_security)
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
while z * security_level < z * target_security:
# if params.n > 1024:
# we only need to consider powers-of-two in this case
# TODO: fill in this case! For n > 1024 we only need to consider every 256
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
params = params.updated(n = params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
# try none with delete, try none without delete
# test the list of objects that are in memory before end of program
costs = None
if -1 * params.Xe.stddev > 0:
print("target security level is unatainable")
print("target security level is unattainable")
break
# final estimate (we went too far in the above loop)
@@ -147,34 +124,25 @@ def automated_param_select_n(params, target_security=128):
log2(params.q),
security_level))
# final sanity check so we don't return insecure (or inf) parameters
# TODO: figure out inf in new estimator
# or security_level == oo:
if security_level < target_security:
params.updated(n=None)
return (params, security_level)
return params, security_level
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="v0.sobj"):
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
"""
:param params_in: a initial set of LWE parameters
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param params: the standard deviation of the LWE error
:param target_security: the target number of bits of security, 128 is default
EXAMPLE:
sage: X = generate_parameter_matrix()
sage: X
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
# grab min and max value/s of n
(sd_min, sd_max) = sd_range
for lam in target_security_levels:
print("LAM = {}".format(lam))
for sd in range(sd_min, sd_max + 1):
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
(params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
print("PARAMS OUT = {}".format(params_out))
try:
results = load("{}.sobj".format(name))
@@ -185,36 +153,36 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128],
results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
save(results, "{}.sobj".format(name))
del(params_out)
gc.collect()
return results
def generate_zama_curves64(sd_range=range(5,9), target_security_levels=[256], name="default"):
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
"""
The top level function which we use to run the experiment
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
if __name__ == '__main__':
D = ND.DiscreteGaussian
vals = range(sd_range[0], sd_range[1])
procs = []
pool = multiprocessing.Pool(2)
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='TFHE_DEFAULT')
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
print(inputs[0])
res = pool.starmap(generate_parameter_matrix, inputs)
return "done"
def wrap(*args):
return generate_parameter_matrix(*args)
# The script runs the following commands
import sys
# grab values of the command-line input arguments
a = int(sys.argv[1])
b = int(sys.argv[2])
c = int(sys.argv[3])
print(b)
D = ND.DiscreteGaussian
init_params = LWE.Parameters(n=1024, q=2 ** 32, Xs=ND.UniformMod(2), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
# run the code
generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))