mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 08:01:20 -05:00
pep8
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from estimator_new import *
|
||||
from sage.all import oo, save, load, ceil
|
||||
from math import log2
|
||||
@@ -91,7 +92,8 @@ def automated_param_select_n(params, target_security=128):
|
||||
|
||||
# get an estimate based on the prev. model
|
||||
print("n = {}".format(params.n))
|
||||
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
|
||||
n_start = old_models(target_security, log2(
|
||||
params.Xe.stddev), log2(params.q))
|
||||
# n_start = max(n_start, 450)
|
||||
# TODO: think about throwing an error if the required n < 450
|
||||
|
||||
@@ -103,7 +105,7 @@ def automated_param_select_n(params, target_security=128):
|
||||
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
|
||||
while z * security_level < z * target_security:
|
||||
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
|
||||
params = params.updated(n = params.n + z * 8)
|
||||
params = params.updated(n=params.n + z * 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
@@ -120,8 +122,10 @@ def automated_param_select_n(params, target_security=128):
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
|
||||
log2(params.Xe.stddev),
|
||||
log2(params.q),
|
||||
log2(
|
||||
params.Xe.stddev),
|
||||
log2(
|
||||
params.q),
|
||||
security_level))
|
||||
|
||||
if security_level < target_security:
|
||||
@@ -143,7 +147,8 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128],
|
||||
for sd in range(sd_min, sd_max + 1):
|
||||
print("run for {}".format(lam, sd))
|
||||
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
|
||||
(params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
|
||||
(params_out, sec) = automated_param_select_n(
|
||||
params_in.updated(Xe=Xe_new), target_security=lam)
|
||||
|
||||
try:
|
||||
results = load("{}.sobj".format(name))
|
||||
@@ -151,7 +156,8 @@ def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128],
|
||||
results = dict()
|
||||
results["{}".format(lam)] = []
|
||||
|
||||
results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
|
||||
results["{}".format(lam)].append(
|
||||
(params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
|
||||
save(results, "{}.sobj".format(name))
|
||||
|
||||
return results
|
||||
@@ -170,18 +176,20 @@ def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name=
|
||||
D = ND.DiscreteGaussian
|
||||
vals = range(sd_range[0], sd_range[1])
|
||||
pool = multiprocessing.Pool(2)
|
||||
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
|
||||
inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
|
||||
init_params = LWE.Parameters(
|
||||
n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
|
||||
inputs = [(init_params, (val, val), target_security_levels, name)
|
||||
for val in vals]
|
||||
res = pool.starmap(generate_parameter_matrix, inputs)
|
||||
|
||||
return "done"
|
||||
|
||||
|
||||
# The script runs the following commands
|
||||
import sys
|
||||
# grab values of the command-line input arguments
|
||||
a = int(sys.argv[1])
|
||||
b = int(sys.argv[2])
|
||||
c = int(sys.argv[3])
|
||||
# run the code
|
||||
generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))
|
||||
generate_zama_curves64(sd_range=(b, c), target_security_levels=[
|
||||
a], name="{}".format(a))
|
||||
|
||||
@@ -9,7 +9,7 @@ def sort_data(security_level):
|
||||
X = load("{}.sobj".format(security_level))
|
||||
|
||||
# step 2. sort by SD
|
||||
x = sorted(X["{}".format(security_level)], key = itemgetter(2))
|
||||
x = sorted(X["{}".format(security_level)], key=itemgetter(2))
|
||||
|
||||
# step3. replace the sorted value
|
||||
X["{}".format(security_level)] = x
|
||||
@@ -28,14 +28,14 @@ def generate_curve(security_level):
|
||||
for x in X["{}".format(security_level)]:
|
||||
N.append(x[0])
|
||||
SD.append(x[2] + 0.5)
|
||||
|
||||
|
||||
# step 3. perform interpolation and return coefficients
|
||||
(a,b) = np.polyfit(N, SD, 1)
|
||||
(a, b) = np.polyfit(N, SD, 1)
|
||||
|
||||
return a, b
|
||||
|
||||
|
||||
def verify_curve(security_level, a = None, b = None):
|
||||
|
||||
def verify_curve(security_level, a=None, b=None):
|
||||
|
||||
# step 1. get the table and max values of n, sd
|
||||
X = sort_data(security_level)
|
||||
@@ -53,9 +53,9 @@ def verify_curve(security_level, a = None, b = None):
|
||||
if n < n_val:
|
||||
pass
|
||||
else:
|
||||
j = i
|
||||
j = i
|
||||
break
|
||||
|
||||
|
||||
# now j is the correct index, we return the corresponding sd
|
||||
return table[j][2]
|
||||
|
||||
@@ -67,7 +67,7 @@ def verify_curve(security_level, a = None, b = None):
|
||||
for n in range(n_max, n_min, - 1):
|
||||
model_sd = f_model(a, b, n)
|
||||
table_sd = f_table(X["{}".format(security_level)], n)
|
||||
print(n , table_sd, model_sd, model_sd >= table_sd)
|
||||
print(n, table_sd, model_sd, model_sd >= table_sd)
|
||||
|
||||
if table_sd > model_sd:
|
||||
print("MODEL FAILS at n = {}".format(n))
|
||||
@@ -76,7 +76,7 @@ def verify_curve(security_level, a = None, b = None):
|
||||
return "PASS", n_min
|
||||
|
||||
|
||||
def generate_and_verify(security_levels, log_q, name = "verified_curves"):
|
||||
def generate_and_verify(security_levels, log_q, name="verified_curves"):
|
||||
|
||||
data = []
|
||||
|
||||
@@ -89,9 +89,10 @@ def generate_and_verify(security_levels, log_q, name = "verified_curves"):
|
||||
# append the information into a list
|
||||
data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
|
||||
save(data, "{}.sobj".format(name))
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
data = generate_and_verify([80, 96, 112, 128, 144, 160, 176, 192, 256], log_q = 64)
|
||||
data = generate_and_verify(
|
||||
[80, 96, 112, 128, 144, 160, 176, 192, 256], log_q=64)
|
||||
print(data)
|
||||
|
||||
Reference in New Issue
Block a user