mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-15 07:05:09 -05:00
@@ -1,6 +1,4 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import sys, json;
|
||||
|
||||
def print_curve(data):
|
||||
print(f'\tSecurityCurve({data["security_level"]},{data["slope"]}, {data["bias"]}, {data["minimal_lwe_dimension"]}, KeyFormat::BINARY),')
|
||||
@@ -11,4 +9,4 @@ def print_cpp_curves_declaration(datas):
|
||||
print_curve(data)
|
||||
print("}\n")
|
||||
|
||||
print_cpp_curves_declaration(json.load(sys.stdin))
|
||||
print_cpp_curves_declaration(json.load(sys.stdin))
|
||||
@@ -1,6 +1,4 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import sys, json;
|
||||
|
||||
def print_curve(data):
|
||||
print(f' ({data["security_level"]}, SecurityWeights {{ slope: {data["slope"]}, bias: {data["bias"]}, minimal_lwe_dimension: {data["minimal_lwe_dimension"]} }}),')
|
||||
@@ -12,4 +10,4 @@ def print_rust_curves_declaration(datas):
|
||||
print_curve(data)
|
||||
print("];")
|
||||
|
||||
print_rust_curves_declaration(json.load(sys.stdin))
|
||||
print_rust_curves_declaration(json.load(sys.stdin))
|
||||
@@ -1,12 +1,10 @@
|
||||
from estimator import LWE, ND
|
||||
from sage.all import oo, load, floor
|
||||
from generate_data import estimate, get_security_level
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from estimator import LWE, ND
|
||||
from generate_data import estimate, get_security_level
|
||||
from sage.all import floor, load, oo
|
||||
|
||||
sys.path.insert(1, "lattice-estimator")
|
||||
sys.path.insert(1, 'lattice-estimator')
|
||||
|
||||
|
||||
LOG_N_MAX = 17 + 1
|
||||
@@ -23,7 +21,6 @@ def get_index(sec, curves):
|
||||
for i in range(len(curves)):
|
||||
if curves[i][2] == sec:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def estimate_security_with_lattice_estimator(lwe_dimension, std_dev, log_q):
|
||||
@@ -48,8 +45,9 @@ def get_minimal_lwe_dimension(curve, security_level, log_q):
|
||||
:param security_level:
|
||||
:param log_q:
|
||||
:return:
|
||||
"""
|
||||
return curve[-1]
|
||||
"""
|
||||
minimal_lwe_dim = curve[-1]
|
||||
return minimal_lwe_dim
|
||||
|
||||
|
||||
def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q):
|
||||
@@ -67,7 +65,8 @@ def estimate_stddev_with_current_curve(curve, lwe_dimension, log_q):
|
||||
a = curve[0]
|
||||
b = curve[1] + log_q
|
||||
|
||||
return minimal_stddev(a, b, lwe_dimension)
|
||||
stddev = minimal_stddev(a, b, lwe_dimension)
|
||||
return stddev
|
||||
|
||||
|
||||
def compare_curve_and_estimator(security_level, log_q, curves_dir):
|
||||
@@ -138,5 +137,5 @@ if __name__ == "__main__":
|
||||
args = CLI.parse_args()
|
||||
for security_level in args.security_levels:
|
||||
if not(compare_curve_and_estimator(security_level, args.log_q, args.curves_dir)):
|
||||
sys.exit(1)
|
||||
sys.exit(0)
|
||||
exit(1)
|
||||
exit(0)
|
||||
@@ -1,13 +1,11 @@
|
||||
import argparse
|
||||
from estimator import RC, LWE, ND
|
||||
from sage.all import oo, save, load
|
||||
from math import log2
|
||||
import multiprocessing
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from math import log2
|
||||
|
||||
from estimator import LWE, ND, RC
|
||||
from sage.all import load, oo, save
|
||||
|
||||
sys.path.insert(1, "lattice-estimator")
|
||||
sys.path.insert(1, 'lattice-estimator')
|
||||
|
||||
|
||||
old_models_sobj = ""
|
||||
@@ -27,7 +25,6 @@ def old_models(security_level, sd, logq=32):
|
||||
for i in range(len(curves)):
|
||||
if curves[i][2] == sec:
|
||||
return i
|
||||
return None
|
||||
|
||||
if old_models_sobj is None or not(os.path.exists(old_models_sobj)):
|
||||
return 450
|
||||
@@ -51,7 +48,10 @@ def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
|
||||
:param skip: attacks to skip
|
||||
"""
|
||||
|
||||
return LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
|
||||
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
|
||||
|
||||
return est
|
||||
|
||||
|
||||
def get_security_level(est, dp=2):
|
||||
"""
|
||||
@@ -64,7 +64,10 @@ def get_security_level(est, dp=2):
|
||||
for key in est:
|
||||
attack_costs.append(est[key]["rop"])
|
||||
# get the security level correct to 'dp' decimal places
|
||||
return round(log2(min(attack_costs)), dp)
|
||||
security_level = round(log2(min(attack_costs)), dp)
|
||||
|
||||
return security_level
|
||||
|
||||
|
||||
def inequality(x, y):
|
||||
"""A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
|
||||
@@ -77,8 +80,6 @@ def inequality(x, y):
|
||||
if x > y:
|
||||
return -1
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def automated_param_select_n(params, target_security=128):
|
||||
"""A function used to generate the smallest value of n which allows for
|
||||
@@ -95,6 +96,7 @@ def automated_param_select_n(params, target_security=128):
|
||||
# get an estimate based on the prev. model
|
||||
print("n = {}".format(params.n))
|
||||
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
|
||||
# n_start = max(n_start, 450)
|
||||
# TODO: think about throwing an error if the required n < 450
|
||||
|
||||
params = params.updated(n=n_start)
|
||||
@@ -106,6 +108,7 @@ def automated_param_select_n(params, target_security=128):
|
||||
# (on binary key guessing)
|
||||
while z * security_level < z * target_security:
|
||||
# TODO: fill in this case! For n > 1024 we only need to consider every
|
||||
# 256 (optimization)
|
||||
params = params.updated(n=params.n + z * 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
@@ -135,7 +138,7 @@ def automated_param_select_n(params, target_security=128):
|
||||
|
||||
|
||||
def generate_parameter_matrix(
|
||||
params_in, sd_range, target_security_levels=(128), name="default_name"
|
||||
params_in, sd_range, target_security_levels=[128], name="default_name"
|
||||
):
|
||||
"""
|
||||
:param params_in: a initial set of LWE parameters
|
||||
@@ -156,7 +159,7 @@ def generate_parameter_matrix(
|
||||
try:
|
||||
results = load("{}.sobj".format(name))
|
||||
except BaseException:
|
||||
results = {}
|
||||
results = dict()
|
||||
results["{}".format(lam)] = []
|
||||
|
||||
results["{}".format(lam)].append(
|
||||
@@ -168,7 +171,7 @@ def generate_parameter_matrix(
|
||||
|
||||
|
||||
def generate_zama_curves64(
|
||||
sd_range=(2, 58), target_security_levels=(128), name="default_name"
|
||||
sd_range=[2, 58], target_security_levels=[128], name="default_name"
|
||||
):
|
||||
"""
|
||||
The top level function which we use to run the experiment
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
from sage.all import save, load, ceil
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from sage.all import ceil, load, save
|
||||
import argparse
|
||||
|
||||
|
||||
def sort_data(security_level, curves_dir):
|
||||
from operator import itemgetter
|
||||
|
||||
# step 1. load the data
|
||||
x = load(os.path.join(curves_dir, f"{security_level}.sobj"))
|
||||
X = load(os.path.join(curves_dir, f"{security_level}.sobj"))
|
||||
|
||||
# step 2. sort by SD
|
||||
x = sorted(x["{}".format(security_level)], key=itemgetter(2))
|
||||
x = sorted(X["{}".format(security_level)], key=itemgetter(2))
|
||||
|
||||
# step3. replace the sorted value
|
||||
x["{}".format(security_level)] = x
|
||||
X["{}".format(security_level)] = x
|
||||
|
||||
return x
|
||||
return X
|
||||
|
||||
|
||||
def generate_curve(security_level, curves_dir):
|
||||
@@ -65,12 +63,16 @@ def verify_curve(security_level, a, b, curves_dir):
|
||||
|
||||
# step 3. for each n, check whether we satisfy the table
|
||||
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
|
||||
#print(n_min)
|
||||
#print(n_max)
|
||||
|
||||
for n in range(n_max, n_min, -1):
|
||||
model_sd = f_model(a, b, n)
|
||||
table_sd = f_table(X["{}".format(security_level)], n)
|
||||
#print(n, table_sd, model_sd, model_sd >= table_sd)
|
||||
|
||||
if table_sd > model_sd:
|
||||
#print("MODEL FAILS at n = {}".format(n))
|
||||
return False
|
||||
|
||||
return True, n_min
|
||||
@@ -83,6 +85,7 @@ def generate_and_verify(security_levels, log_q, curves_dir, name="verified_curve
|
||||
fail = []
|
||||
|
||||
for sec in security_levels:
|
||||
#print("WE GO FOR {}".format(sec))
|
||||
# generate the model for security level sec
|
||||
(a_sec, b_sec) = generate_curve(sec, curves_dir)
|
||||
# verify the model for security level sec
|
||||
@@ -124,6 +127,6 @@ if __name__ == "__main__":
|
||||
if (fail):
|
||||
print("FAILURE: Fail to verify the following curves")
|
||||
print(json.dumps(fail))
|
||||
sys.exit(1)
|
||||
exit(1)
|
||||
|
||||
print(json.dumps(success))
|
||||
|
||||
@@ -1,4 +1,2 @@
|
||||
[tool.ruff]
|
||||
line-length = 169
|
||||
select = ["F", "E", "W", "C90", "I", "UP", "N", "YTT", "S", "BLE", "FBT", "B", "A", "C4", "T10", "EM", "ICN", "Q", "RET", "SIM", "TID", "ARG", "DTZ", "ERA", "PD", "PGH", "PLC", "PLE", "PLR", "PLW", "RUF"]
|
||||
ignore = ["D", "T20", "ANN", "N806", "ARG001", "S101", "BLE001"]
|
||||
|
||||
Reference in New Issue
Block a user