fix: new python lints

This commit is contained in:
Mayeul@Zama
2022-12-21 17:30:41 +01:00
committed by Quentin Bourgerie
parent 6ceed0f1f3
commit 2f877e80de
6 changed files with 47 additions and 46 deletions

View File

@@ -1,4 +1,6 @@
import sys, json;
import json
import sys
def print_curve(data):
print(f'\tSecurityCurve({data["security_level"]},{data["slope"]}, {data["bias"]}, {data["minimal_lwe_dimension"]}, KeyFormat::BINARY),')
@@ -9,4 +11,4 @@ def print_cpp_curves_declaration(datas):
print_curve(data)
print("}\n")
print_cpp_curves_declaration(json.load(sys.stdin))
print_cpp_curves_declaration(json.load(sys.stdin))

View File

@@ -1,4 +1,6 @@
import sys, json;
import json
import sys
def print_curve(data):
print(f' ({data["security_level"]}, SecurityWeights {{ slope: {data["slope"]}, bias: {data["bias"]}, minimal_lwe_dimension: {data["minimal_lwe_dimension"]} }}),')
@@ -10,4 +12,4 @@ def print_rust_curves_declaration(datas):
print_curve(data)
print("];")
print_rust_curves_declaration(json.load(sys.stdin))
print_rust_curves_declaration(json.load(sys.stdin))

View File

@@ -1,10 +1,12 @@
from estimator import LWE, ND
from sage.all import oo, load, floor
from generate_data import estimate, get_security_level
import argparse
import os
import sys
sys.path.insert(1, 'lattice-estimator')
from estimator import LWE, ND
from generate_data import estimate, get_security_level
from sage.all import floor, load, oo
sys.path.insert(1, "lattice-estimator")
LOG_N_MAX = 17 + 1
@@ -21,6 +23,7 @@ def get_index(sec, curves):
for i in range(len(curves)):
if curves[i][2] == sec:
return i
return None
def estimate_security_with_lattice_estimator(lwe_dimension, std_dev, log_q):
@@ -45,9 +48,8 @@ def get_minimal_lwe_dimension(curve, security_level, log_q):
:param security_level:
:param log_q:
:return:
"""
minimal_lwe_dim = curve[-1]
return minimal_lwe_dim
"""
return curve[-1]
def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q):
@@ -65,8 +67,7 @@ def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q):
a = curve[0]
b = curve[1] + log_q
stddev = minimal_stddev(a, b, lwe_dimension)
return stddev
return minimal_stddev(a, b, lwe_dimension)
def compare_curve_and_estimator(security_level, log_q, curves_dir):
@@ -137,5 +138,5 @@ if __name__ == "__main__":
args = CLI.parse_args()
for security_level in args.security_levels:
if not(compare_curve_and_estimator(security_level, args.log_q, args.curves_dir)):
exit(1)
exit(0)
sys.exit(1)
sys.exit(0)

View File

@@ -1,11 +1,13 @@
from estimator import RC, LWE, ND
from sage.all import oo, save, load
from math import log2
import multiprocessing
import argparse
import multiprocessing
import os
import sys
sys.path.insert(1, 'lattice-estimator')
from math import log2
from estimator import LWE, ND, RC
from sage.all import load, oo, save
sys.path.insert(1, "lattice-estimator")
old_models_sobj = ""
@@ -25,6 +27,7 @@ def old_models(security_level, sd, logq=32):
for i in range(len(curves)):
if curves[i][2] == sec:
return i
return None
if old_models_sobj is None or not(os.path.exists(old_models_sobj)):
return 450
@@ -48,10 +51,7 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
:param skip: attacks to skip
"""
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
return LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
def get_security_level(est, dp=2):
"""
@@ -64,10 +64,7 @@ def get_security_level(est, dp=2):
for key in est:
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
security_level = round(log2(min(attack_costs)), dp)
return security_level
return round(log2(min(attack_costs)), dp)
def inequality(x, y):
"""A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
@@ -80,6 +77,8 @@ def inequality(x, y):
if x > y:
return -1
return None
def automated_param_select_n(params, target_security=128):
"""A function used to generate the smallest value of n which allows for
@@ -96,7 +95,6 @@ def automated_param_select_n(params, target_security=128):
# get an estimate based on the prev. model
print("n = {}".format(params.n))
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
# n_start = max(n_start, 450)
# TODO: think about throwing an error if the required n < 450
params = params.updated(n=n_start)
@@ -108,7 +106,6 @@ def automated_param_select_n(params, target_security=128):
# (on binary key guessing)
while z * security_level < z * target_security:
# TODO: fill in this case! For n > 1024 we only need to consider every
# 256 (optimization)
params = params.updated(n=params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
@@ -138,7 +135,7 @@ def automated_param_select_n(params, target_security=128):
def generate_parameter_matrix(
params_in, sd_range, target_security_levels=[128], name="default_name"
params_in, sd_range, target_security_levels=(128), name="default_name"
):
"""
:param params_in: a initial set of LWE parameters
@@ -159,7 +156,7 @@ def generate_parameter_matrix(
try:
results = load("{}.sobj".format(name))
except BaseException:
results = dict()
results = {}
results["{}".format(lam)] = []
results["{}".format(lam)].append(
@@ -171,7 +168,7 @@ def generate_parameter_matrix(
def generate_zama_curves64(
sd_range=[2, 58], target_security_levels=[128], name="default_name"
sd_range=(2, 58), target_security_levels=(128), name="default_name"
):
"""
The top level function which we use to run the experiment

View File

@@ -1,23 +1,25 @@
import numpy as np
from sage.all import save, load, ceil
import argparse
import json
import os
import argparse
import sys
import numpy as np
from sage.all import ceil, load, save
def sort_data(security_level, curves_dir):
from operator import itemgetter
# step 1. load the data
X = load(os.path.join(curves_dir, f"{security_level}.sobj"))
x = load(os.path.join(curves_dir, f"{security_level}.sobj"))
# step 2. sort by SD
x = sorted(X["{}".format(security_level)], key=itemgetter(2))
x = sorted(x["{}".format(security_level)], key=itemgetter(2))
# step3. replace the sorted value
X["{}".format(security_level)] = x
x["{}".format(security_level)] = x
return X
return x
def generate_curve(security_level, curves_dir):
@@ -63,16 +65,12 @@ def verify_curve(security_level, a, b, curves_dir):
# step 3. for each n, check whether we satisfy the table
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
#print(n_min)
#print(n_max)
for n in range(n_max, n_min, -1):
model_sd = f_model(a, b, n)
table_sd = f_table(X["{}".format(security_level)], n)
#print(n, table_sd, model_sd, model_sd >= table_sd)
if table_sd > model_sd:
#print("MODEL FAILS at n = {}".format(n))
return False
return True, n_min
@@ -85,7 +83,6 @@ def generate_and_verify(security_levels, log_q, curves_dir, name="verified_curve
fail = []
for sec in security_levels:
#print("WE GO FOR {}".format(sec))
# generate the model for security level sec
(a_sec, b_sec) = generate_curve(sec, curves_dir)
# verify the model for security level sec
@@ -127,6 +124,6 @@ if __name__ == "__main__":
if (fail):
print("FAILURE: Fail to verify the following curves")
print(json.dumps(fail))
exit(1)
sys.exit(1)
print(json.dumps(success))

View File

@@ -1,2 +1,4 @@
[tool.ruff]
line-length = 169
select = ["F", "E", "W", "C90", "I", "UP", "N", "YTT", "S", "BLE", "FBT", "B", "A", "C4", "T10", "EM", "ICN", "Q", "RET", "SIM", "TID", "ARG", "DTZ", "ERA", "PD", "PGH", "PLC", "PLE", "PLR", "PLW", "RUF"]
ignore = ["D", "T20", "ANN", "N806", "ARG001", "S101", "BLE001"]