Revert "fix: new python lints"

This reverts commit 2f877e80de.
This commit is contained in:
Quentin Bourgerie
2023-01-11 17:00:11 +01:00
parent e0f0274f35
commit 6ca8a55b66
6 changed files with 46 additions and 47 deletions

View File

@@ -1,6 +1,4 @@
import json
import sys
import sys, json;
def print_curve(data):
print(f'\tSecurityCurve({data["security_level"]},{data["slope"]}, {data["bias"]}, {data["minimal_lwe_dimension"]}, KeyFormat::BINARY),')
@@ -11,4 +9,4 @@ def print_cpp_curves_declaration(datas):
print_curve(data)
print("}\n")
print_cpp_curves_declaration(json.load(sys.stdin))
print_cpp_curves_declaration(json.load(sys.stdin))

View File

@@ -1,6 +1,4 @@
import json
import sys
import sys, json;
def print_curve(data):
print(f' ({data["security_level"]}, SecurityWeights {{ slope: {data["slope"]}, bias: {data["bias"]}, minimal_lwe_dimension: {data["minimal_lwe_dimension"]} }}),')
@@ -12,4 +10,4 @@ def print_rust_curves_declaration(datas):
print_curve(data)
print("];")
print_rust_curves_declaration(json.load(sys.stdin))
print_rust_curves_declaration(json.load(sys.stdin))

View File

@@ -1,12 +1,10 @@
from estimator import LWE, ND
from sage.all import oo, load, floor
from generate_data import estimate, get_security_level
import argparse
import os
import sys
from estimator import LWE, ND
from generate_data import estimate, get_security_level
from sage.all import floor, load, oo
sys.path.insert(1, "lattice-estimator")
sys.path.insert(1, 'lattice-estimator')
LOG_N_MAX = 17 + 1
@@ -23,7 +21,6 @@ def get_index(sec, curves):
for i in range(len(curves)):
if curves[i][2] == sec:
return i
return None
def estimate_security_with_lattice_estimator(lwe_dimension, std_dev, log_q):
@@ -48,8 +45,9 @@ def get_minimal_lwe_dimension(curve, security_level, log_q):
:param security_level:
:param log_q:
:return:
"""
return curve[-1]
"""
minimal_lwe_dim = curve[-1]
return minimal_lwe_dim
def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q):
@@ -67,7 +65,8 @@ def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q):
a = curve[0]
b = curve[1] + log_q
return minimal_stddev(a, b, lwe_dimension)
stddev = minimal_stddev(a, b, lwe_dimension)
return stddev
def compare_curve_and_estimator(security_level, log_q, curves_dir):
@@ -138,5 +137,5 @@ if __name__ == "__main__":
args = CLI.parse_args()
for security_level in args.security_levels:
if not(compare_curve_and_estimator(security_level, args.log_q, args.curves_dir)):
sys.exit(1)
sys.exit(0)
exit(1)
exit(0)

View File

@@ -1,13 +1,11 @@
import argparse
from estimator import RC, LWE, ND
from sage.all import oo, save, load
from math import log2
import multiprocessing
import argparse
import os
import sys
from math import log2
from estimator import LWE, ND, RC
from sage.all import load, oo, save
sys.path.insert(1, "lattice-estimator")
sys.path.insert(1, 'lattice-estimator')
old_models_sobj = ""
@@ -27,7 +25,6 @@ def old_models(security_level, sd, logq=32):
for i in range(len(curves)):
if curves[i][2] == sec:
return i
return None
if old_models_sobj is None or not(os.path.exists(old_models_sobj)):
return 450
@@ -51,7 +48,10 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
:param skip: attacks to skip
"""
return LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
def get_security_level(est, dp=2):
"""
@@ -64,7 +64,10 @@ def get_security_level(est, dp=2):
for key in est:
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
return round(log2(min(attack_costs)), dp)
security_level = round(log2(min(attack_costs)), dp)
return security_level
def inequality(x, y):
"""A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
@@ -77,8 +80,6 @@ def inequality(x, y):
if x > y:
return -1
return None
def automated_param_select_n(params, target_security=128):
"""A function used to generate the smallest value of n which allows for
@@ -95,6 +96,7 @@ def automated_param_select_n(params, target_security=128):
# get an estimate based on the prev. model
print("n = {}".format(params.n))
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
# n_start = max(n_start, 450)
# TODO: think about throwing an error if the required n < 450
params = params.updated(n=n_start)
@@ -106,6 +108,7 @@ def automated_param_select_n(params, target_security=128):
# (on binary key guessing)
while z * security_level < z * target_security:
# TODO: fill in this case! For n > 1024 we only need to consider every
# 256 (optimization)
params = params.updated(n=params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
@@ -135,7 +138,7 @@ def automated_param_select_n(params, target_security=128):
def generate_parameter_matrix(
params_in, sd_range, target_security_levels=(128), name="default_name"
params_in, sd_range, target_security_levels=[128], name="default_name"
):
"""
:param params_in: a initial set of LWE parameters
@@ -156,7 +159,7 @@ def generate_parameter_matrix(
try:
results = load("{}.sobj".format(name))
except BaseException:
results = {}
results = dict()
results["{}".format(lam)] = []
results["{}".format(lam)].append(
@@ -168,7 +171,7 @@ def generate_parameter_matrix(
def generate_zama_curves64(
sd_range=(2, 58), target_security_levels=(128), name="default_name"
sd_range=[2, 58], target_security_levels=[128], name="default_name"
):
"""
The top level function which we use to run the experiment

View File

@@ -1,25 +1,23 @@
import argparse
import numpy as np
from sage.all import save, load, ceil
import json
import os
import sys
import numpy as np
from sage.all import ceil, load, save
import argparse
def sort_data(security_level, curves_dir):
from operator import itemgetter
# step 1. load the data
x = load(os.path.join(curves_dir, f"{security_level}.sobj"))
X = load(os.path.join(curves_dir, f"{security_level}.sobj"))
# step 2. sort by SD
x = sorted(x["{}".format(security_level)], key=itemgetter(2))
x = sorted(X["{}".format(security_level)], key=itemgetter(2))
# step3. replace the sorted value
x["{}".format(security_level)] = x
X["{}".format(security_level)] = x
return x
return X
def generate_curve(security_level, curves_dir):
@@ -65,12 +63,16 @@ def verify_curve(security_level, a, b, curves_dir):
# step 3. for each n, check whether we satisfy the table
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
#print(n_min)
#print(n_max)
for n in range(n_max, n_min, -1):
model_sd = f_model(a, b, n)
table_sd = f_table(X["{}".format(security_level)], n)
#print(n, table_sd, model_sd, model_sd >= table_sd)
if table_sd > model_sd:
#print("MODEL FAILS at n = {}".format(n))
return False
return True, n_min
@@ -83,6 +85,7 @@ def generate_and_verify(security_levels, log_q, curves_dir, name="verified_curve
fail = []
for sec in security_levels:
#print("WE GO FOR {}".format(sec))
# generate the model for security level sec
(a_sec, b_sec) = generate_curve(sec, curves_dir)
# verify the model for security level sec
@@ -124,6 +127,6 @@ if __name__ == "__main__":
if (fail):
print("FAILURE: Fail to verify the following curves")
print(json.dumps(fail))
sys.exit(1)
exit(1)
print(json.dumps(success))

View File

@@ -1,4 +1,2 @@
[tool.ruff]
line-length = 169
select = ["F", "E", "W", "C90", "I", "UP", "N", "YTT", "S", "BLE", "FBT", "B", "A", "C4", "T10", "EM", "ICN", "Q", "RET", "SIM", "TID", "ARG", "DTZ", "ERA", "PD", "PGH", "PLC", "PLE", "PLR", "PLW", "RUF"]
ignore = ["D", "T20", "ANN", "N806", "ARG001", "S101", "BLE001"]