start getting multiprocessing to work

This commit is contained in:
Ben
2022-06-09 22:27:37 +02:00
parent 4d09db8b74
commit 15963ff51f

View File

@@ -4,9 +4,10 @@ from estimator_new import *
from sage.all import oo, save
from math import log2
import gc
from multiprocessing import *
def old_models(security_level, sd, logq = 32):
def old_models(security_level, sd, logq=32):
"""
Use the old model as a starting point for the data gathering step
:param security_level: the security level under consideration
@@ -14,12 +15,11 @@ def old_models(security_level, sd, logq = 32):
:param logq : the (base 2 log) value of the LWE modulus q
"""
def evaluate_model(sd, a, b):
return (sd - b)/a
def evaluate_model(a, b, stddev=sd):
return (stddev - b)/a
models = dict()
# TODO: figure out a way to import these from a datafile, for future version
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
@@ -34,32 +34,38 @@ def old_models(security_level, sd, logq = 32):
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
(a, b) = models["{}".format(security_level)]
n_est = evaluate_model(sd, a, b)
n_est = evaluate_model(a, b, sd)
return round(n_est)
def estimate(params, red_cost_model = RC.BDGL16):
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
"""
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
:param params: the input LWE parameters
:param red_cost_model: the lattice reduction cost model
:param skip: attacks to skip
"""
est = LWE.estimate(params, deny_list=("arora-gb", "bkw"), red_cost_model=red_cost_model)
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
def get_security_level(est, dp = 2):
def get_security_level(est, dp=2):
"""
Get the security level lambda from a Lattice Estimator output
:param est: the Lattice Estimator output
:param dp : the number of decimal places to consider
:param dp: the number of decimal places to consider
"""
attack_costs = []
for key in est.keys():
# note: key does not need to be specified est vs est.keys()
for key in est:
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
security_level = round(log2(min(attack_costs)), dp)
return security_level
def inequality(x, y):
""" A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
:param x: the LHS of the inequality
@@ -71,6 +77,7 @@ def inequality(x, y):
if x > y:
return -1
def automated_param_select_n(params, target_security=128):
""" A function used to generate the smallest value of n which allows for
target_security bits of security, for the input values of (params.Xe.stddev,params.q)
@@ -101,15 +108,15 @@ def automated_param_select_n(params, target_security=128):
# n_log = log2(n_start)
# n_start = 2**round(n_log)
print("n_start = {}".format(n_start))
params = params.updated(n=n_start)
print(params)
#
costs2 = estimate(params)
security_level = get_security_level(costs2, 2)
costs2 = None
z = inequality(security_level, target_security)
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
while z * security_level < z * target_security:
# if params.n > 1024:
@@ -118,6 +125,9 @@ def automated_param_select_n(params, target_security=128):
params = params.updated(n = params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
# try none with delete, try none without delete
# test the list of objects that are in memory before end of program
costs = None
if -1 * params.Xe.stddev > 0:
print("target security level is unatainable")
@@ -127,14 +137,14 @@ def automated_param_select_n(params, target_security=128):
if security_level < target_security:
# we make n larger
print("we make n larger")
params = params.updated(n = params.n + 8)
params = params.updated(n=params.n + 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
log2(params.Xe.stddev),
log2(params.q),
security_level))
security_level))
# final sanity check so we don't return insecure (or inf) parameters
# TODO: figure out inf in new estimator
@@ -142,12 +152,9 @@ def automated_param_select_n(params, target_security=128):
if security_level < target_security:
params.updated(n=None)
del(costs)
del(costs2)
gc.collect()
return params
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="v0.sobj"):
"""
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
@@ -175,26 +182,107 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128],
gc.collect()
return results
def generate_zama_curves64(sd_range=[2, 56], target_security_levels=[256], name="v0256.sobj"):
D = ND.DiscreteGaussian
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
raw_data = generate_parameter_matrix(init_params, sd_range=sd_range, target_security_levels=target_security_levels, name=name)
def generate_parameter_matrix_para(params_in, sd_range, target_security_levels=[128], name="v0.sobj"):
"""
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param params: the standard deviation of the LWE error
:param target_security: the target number of bits of security, 128 is default
return raw_data
EXAMPLE:
sage: X = generate_parameter_matrix()
sage: X
"""
if __name__ == "__main__":
def plota_curve(raw_data, security_level):
results = dict()
data = raw_data["{}".format(security_level)]
def test_memory(x):
print("doing job...")
print(x)
y = LWE.estimate(x, deny_list=("arora-gb", "bkw"))
return y
# grab min and max value/s of n
(sd_min, sd_max) = sd_range
print(sd_range)
for lam in target_security_levels:
results["{}".format(lam)] = []
names = range(sd_min, sd_max + 1)
procs = []
proc = Process(target=automated_param_select_n)
procs.append(proc)
proc.start()
p = Pool(1)
for name in names:
proc = Process(target=test_memory, args=(name,))
procs.append(proc)
proc.start()
proc.join()
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
params_out = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
results["{}".format(lam)].append((params_out.n, params_out.q, params_out.Xe.stddev))
save(results, "{}.sobj".format(name))
params_out = None
del(params_out)
gc.collect()
return results
# what we run
def generate_zama_curves64(sd_range=[2, 56], target_security_levels=[256], name="default", pools = 1):
if __name__ == '__main__':
D = ND.DiscreteGaussian
vals = sd_range
p = Pool(pools)
procs = []
for val in vals:
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**55), m=oo, tag='TFHE_DEFAULT')
proc = Process(target=generate_parameter_matrix, args=(init_params, [val, val + 1], target_security_levels, name))
procs.append(proc)
proc.start()
return "done"
import sys
a = int(sys.argv[1])
print(a)
print("input arg is {}".format(a))
D = ND.DiscreteGaussian
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=ND.UniformMod(2), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
generate_zama_curves64(target_security_levels=[a], name="{}".format(a))
init_params = LWE.Parameters(n=1024, q=2 ** 32, Xs=ND.UniformMod(2), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
#automated_param_select_n(init_params, target_security=128)
#automated_param_select_n(init_params, target_security=192)
generate_zama_curves64(sd_range=[50, 53], target_security_levels=[a], name="{}".format("testing"))
#if __name__ == "__main__":
# D = ND.DiscreteGaussian
# params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT')
#
# names = [params, params.updated(n=761), params.updated(q=2**65), params.updated(n=762)]
# procs = []
# proc = Process(target=print_func)
# procs.append(proc)
# proc.start()
# p = Pool(1)
#
# for name in names:
# proc = Process(target=test_memory, args=(name,))
# procs.append(proc)
# proc.start()
# proc.join()
#
# for proc in procs:
# proc.join()