3
.github/workflows/pep8.yml
vendored
@@ -19,4 +19,5 @@ jobs:
|
||||
- name: PEP8
|
||||
run: |
|
||||
pip install --upgrade pyproject-flake8
|
||||
flake8 new_scripts.py
|
||||
flake8 generate_data.py
|
||||
flake8 verify_curves.py
|
||||
|
||||
78
README.rst
@@ -3,12 +3,12 @@ Parameter curves for Concrete
|
||||
|
||||
This Github repository contains the code needed to generate the Parameter curves used inside Zama. The repository contains the following files:
|
||||
|
||||
- cpp/, Python scripts to generate a cpp file containing the parameter curves
|
||||
- cpp/, Python scripts to generate a cpp file containing the parameter curves (needs updating)
|
||||
- data/, a folder containing the data generated for previous curves.
|
||||
- estimator/, Zama's internal version of the LWE Estimator
|
||||
- figs/, a folder containing various figures related to the parameter curves
|
||||
- scripts.py, a copy of all scripts required to generate the parameter curves
|
||||
- a variety of other python files, used for estimating the security of previous Concrete parameter sets
|
||||
- estimator_new/, the Lattice estimator (TODO: add as a submodule and use dependabot to alert for new commits)
|
||||
- old_files/, legacy files used for previous versions
|
||||
- generate_data.py, functions to gather raw data from the lattice estimator
|
||||
- verifiy_curves.py, functions to generate and verify curves from raw data
|
||||
|
||||
.. image:: logo.svg
|
||||
:align: center
|
||||
@@ -21,52 +21,43 @@ This is an example of how to generate the parameter curves, and save them to fil
|
||||
|
||||
::
|
||||
|
||||
sage: load("scripts.py")
|
||||
sage: results = get_zama_curves()
|
||||
sage: save(results, "v0.sobj")
|
||||
./job.sh
|
||||
::
|
||||
|
||||
We can load results files, and find the interpolants.
|
||||
This will generate several data files, {80, 96, 112, 128, 144, 160, 176, 192, 256}.sobj
|
||||
|
||||
To generate the parameter curves from the data files, we run
|
||||
|
||||
`sage verify_curves.py`
|
||||
|
||||
this will generate a list of the form:
|
||||
|
||||
::
|
||||
|
||||
sage: load("scripts.py")
|
||||
sage: interps = []
|
||||
sage: results = load("v0.sobj")
|
||||
sage: for result in results:
|
||||
sage: interps.append(interpolate_result(result, log_q = 64))
|
||||
sage: interps
|
||||
[(-0.040476778656126484, 1.143346508563902),
|
||||
(-0.03417207792207793, 1.4805194805194737),
|
||||
(-0.029681716023268107, 1.752723426758335),
|
||||
(-0.0263748887657055, 2.0121439233304894),
|
||||
(-0.023730136557783763, 2.1537066948924095),
|
||||
(-0.021604493958972515, 2.2696862472846204),
|
||||
(-0.019897520946588438, 2.4423829771964796),
|
||||
(-0.018504919354426233, 2.6634073426215745),
|
||||
(-0.017254242957361113, 2.7353702447139026),
|
||||
(-0.016178309410530816, 2.8493969373734758),
|
||||
(-0.01541034709414119, 3.1982749283836283),
|
||||
(-0.014327640360322604, 2.899270827311096)]
|
||||
[(-0.04042633119364589, 1.6609788641436722, 80, 'PASS', 450),
|
||||
(-0.03414780360867051, 2.017310258660345, 96, 'PASS', 450),
|
||||
(-0.029670137081135885, 2.162463714083856, 112, 'PASS', 450),
|
||||
(-0.02640502876522622, 2.4826422691043177, 128, 'PASS', 450),
|
||||
(-0.023821437305989134, 2.7177789440636673, 144, 'PASS', 450),
|
||||
(-0.02174358218716036, 2.938810548493322, 160, 'PASS', 498),
|
||||
(-0.019904056582117684, 2.8161252801542247, 176, 'PASS', 551),
|
||||
(-0.018610403247590085, 3.2996236848399008, 192, 'PASS', 606),
|
||||
(-0.014606812351714953, 3.8493629234693003, 256, 'PASS', 826)]
|
||||
::
|
||||
|
||||
Finding the value of n_{alpha} is done manually. We can also verify the interpolants which are generated at the same time:
|
||||
each element is a tuple (a, b, security, P, n_min), where (a,b) are the model
|
||||
parameters, security is the security level, P is a boolean value denoting PASS or
|
||||
FAIL of the verification, and n_min is the smallest reccomended value of `n` to be used.
|
||||
|
||||
Each model outputs a value of sigma, and is of the form:
|
||||
|
||||
::
|
||||
|
||||
# verify the interpolant used for lambda = 256 (which is interps[-1])
|
||||
sage: z = verify_interpolants(interps[-1], (128,2048), 64)
|
||||
[... code runs, can take ~10 mins ...]
|
||||
# find the index corresponding to n_alpha, which is where security drops below the target security level (256 here)
|
||||
sage: n_alpha = find_nalpha(z, 256)
|
||||
653
|
||||
|
||||
# so the model in this case is
|
||||
(-0.014327640360322604, 2.899270827311096, 653)
|
||||
# which corresponds to
|
||||
# sd(n) = max(-0.014327640360322604 * n + 2.899270827311096, -logq + 2), n >= 653
|
||||
f(a, b, n) = max(ceil(a * n + b), -log2(q) + 2)
|
||||
::
|
||||
|
||||
where the -log2(q) + 2 term ensures that we are always using at least two bits of noise.
|
||||
|
||||
Version History
|
||||
-------------------
|
||||
|
||||
@@ -76,14 +67,7 @@ Data for the curves are kept in /data. The following files are present:
|
||||
|
||||
v0: generated using the {usvp, dual, decoding} attacks
|
||||
v0.1: generated using the {mitm, usvp, dual, decoding} attacks
|
||||
v0.2: generated using the lattice estimator
|
||||
::
|
||||
|
||||
TODO List
|
||||
-------------------
|
||||
|
||||
There are several updates which are still required.
|
||||
1. Consider Hybrid attacks (WIP, Michael + Ben are coding up hybrid-dual/hybrid-decoding estimates)
|
||||
2. CI/CD stuff for new pushes to the external LWE Estimator.
|
||||
3. Fully automate the process of finding n_{alpha} for each curve.
|
||||
4. Functionality for q =! 64? This is covered by the curve, but we currently don't account for it in the models, and it needs to be done manually.
|
||||
5. cpp file generation
|
||||
|
||||
BIN
data/v0.2/112.sobj
Normal file
BIN
data/v0.2/128.sobj
Normal file
BIN
data/v0.2/144.sobj
Normal file
BIN
data/v0.2/160.sobj
Normal file
BIN
data/v0.2/176.sobj
Normal file
BIN
data/v0.2/192.sobj
Normal file
BIN
data/v0.2/256.sobj
Normal file
BIN
data/v0.2/80.sobj
Normal file
BIN
data/v0.2/96.sobj
Normal file
@@ -1,8 +1,11 @@
|
||||
import sys
|
||||
from estimator_new import *
|
||||
from sage.all import oo, save
|
||||
from sage.all import oo, save, load, ceil
|
||||
from math import log2
|
||||
import multiprocessing
|
||||
|
||||
def old_models(security_level, sd, logq = 32):
|
||||
|
||||
def old_models(security_level, sd, logq=32):
|
||||
"""
|
||||
Use the old model as a starting point for the data gathering step
|
||||
:param security_level: the security level under consideration
|
||||
@@ -10,12 +13,11 @@ def old_models(security_level, sd, logq = 32):
|
||||
:param logq : the (base 2 log) value of the LWE modulus q
|
||||
"""
|
||||
|
||||
def evaluate_model(sd, a, b):
|
||||
return (sd - b)/a
|
||||
def evaluate_model(a, b, stddev=sd):
|
||||
return (stddev - b)/a
|
||||
|
||||
models = dict()
|
||||
|
||||
# TODO: figure out a way to import these from a datafile, for future version
|
||||
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
|
||||
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
|
||||
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
|
||||
@@ -30,32 +32,37 @@ def old_models(security_level, sd, logq = 32):
|
||||
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
|
||||
|
||||
(a, b) = models["{}".format(security_level)]
|
||||
n_est = evaluate_model(sd, a, b)
|
||||
n_est = evaluate_model(a, b, sd)
|
||||
|
||||
return round(n_est)
|
||||
|
||||
|
||||
def estimate(params, red_cost_model = RC.BDGL16):
|
||||
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
|
||||
"""
|
||||
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
|
||||
:param params: the input LWE parameters
|
||||
:param red_cost_model: the lattice reduction cost model
|
||||
:param skip: attacks to skip
|
||||
"""
|
||||
|
||||
est = LWE.estimate(params, deny_list=("arora-gb", "bkw"), red_cost_model=red_cost_model)
|
||||
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
|
||||
|
||||
return est
|
||||
|
||||
|
||||
def get_security_level(est, dp = 2):
|
||||
def get_security_level(est, dp=2):
|
||||
"""
|
||||
Get the security level lambda from a Lattice Estimator output
|
||||
:param est: the Lattice Estimator output
|
||||
:param dp : the number of decimal places to consider
|
||||
:param dp: the number of decimal places to consider
|
||||
"""
|
||||
attack_costs = []
|
||||
for key in est.keys():
|
||||
# note: key does not need to be specified est vs est.keys()
|
||||
for key in est:
|
||||
attack_costs.append(est[key]["rop"])
|
||||
# get the security level correct to 'dp' decimal places
|
||||
security_level = round(log2(min(attack_costs)), dp)
|
||||
|
||||
return security_level
|
||||
|
||||
|
||||
@@ -83,120 +90,106 @@ def automated_param_select_n(params, target_security=128):
|
||||
456
|
||||
"""
|
||||
|
||||
# get an initial estimate
|
||||
# costs = estimate(params)
|
||||
# security_level = get_security_level(costs, 2)
|
||||
# determine if we are above or below the target security level
|
||||
# z = inequality(security_level, target_security)
|
||||
|
||||
# get an estimate based on the prev. model
|
||||
print("n = {}".format(params.n))
|
||||
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
|
||||
# TODO -- is this how we want to deal with the small n issue? Shouldn't the model have this baked in?
|
||||
# we want to start no lower than n = 450
|
||||
n_start = max(n_start, 450)
|
||||
print("n_start = {}".format(n_start))
|
||||
n_start = old_models(target_security, log2(
|
||||
params.Xe.stddev), log2(params.q))
|
||||
# n_start = max(n_start, 450)
|
||||
# TODO: think about throwing an error if the required n < 450
|
||||
|
||||
params = params.updated(n=n_start)
|
||||
print(params)
|
||||
costs2 = estimate(params)
|
||||
security_level = get_security_level(costs2, 2)
|
||||
z = inequality(security_level, target_security)
|
||||
|
||||
|
||||
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
|
||||
while z * security_level < z * target_security:
|
||||
# if params.n > 1024:
|
||||
# we only need to consider powers-of-two in this case
|
||||
# TODO: fill in this case! For n > 1024 we only need to consider every 256
|
||||
|
||||
|
||||
|
||||
params = params.updated(n = params.n + z * 8)
|
||||
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
|
||||
params = params.updated(n=params.n + z * 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
if -1 * params.Xe.stddev > 0:
|
||||
print("target security level is unatainable")
|
||||
print("target security level is unattainable")
|
||||
break
|
||||
|
||||
# final estimate (we went too far in the above loop)
|
||||
if security_level < target_security:
|
||||
# TODO: we should somehow keep the previous estimate stored so that we don't need to compute it twice
|
||||
# if we do this we need to make sure that it works for both sides (i.e. if (i-1) is above or below the
|
||||
# security level
|
||||
|
||||
params = params.updated(n = params.n - z * 8)
|
||||
# we make n larger
|
||||
print("we make n larger")
|
||||
params = params.updated(n=params.n + 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
|
||||
log2(params.Xe.stddev),
|
||||
log2(params.q),
|
||||
security_level))
|
||||
log2(
|
||||
params.Xe.stddev),
|
||||
log2(
|
||||
params.q),
|
||||
security_level))
|
||||
|
||||
# final sanity check so we don't return insecure (or inf) parameters
|
||||
# TODO: figure out inf in new estimator
|
||||
# or security_level == oo:
|
||||
if security_level < target_security:
|
||||
params.updated(n=None)
|
||||
|
||||
return params
|
||||
return params, security_level
|
||||
|
||||
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128]):
|
||||
|
||||
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
|
||||
"""
|
||||
:param params_in: a initial set of LWE parameters
|
||||
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
|
||||
:param params: the standard deviation of the LWE error
|
||||
:param target_security: the target number of bits of security, 128 is default
|
||||
|
||||
EXAMPLE:
|
||||
sage: X = generate_parameter_matrix()
|
||||
sage: X
|
||||
:param target_security_levels: a list of the target number of bits of security, 128 is default
|
||||
:param name: a name to save the file
|
||||
"""
|
||||
|
||||
results = dict()
|
||||
|
||||
# grab min and max value/s of n
|
||||
(sd_min, sd_max) = sd_range
|
||||
|
||||
for lam in target_security_levels:
|
||||
results["{}".format(lam)] = []
|
||||
for sd in range(sd_min, sd_max + 1):
|
||||
print("run for {}".format(lam, sd))
|
||||
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
|
||||
params_out = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
|
||||
results["{}".format(lam)].append((params_out.n, params_out.q, params_out.Xe.stddev))
|
||||
(params_out, sec) = automated_param_select_n(
|
||||
params_in.updated(Xe=Xe_new), target_security=lam)
|
||||
|
||||
try:
|
||||
results = load("{}.sobj".format(name))
|
||||
except:
|
||||
results = dict()
|
||||
results["{}".format(lam)] = []
|
||||
|
||||
results["{}".format(lam)].append(
|
||||
(params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
|
||||
save(results, "{}.sobj".format(name))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
|
||||
"""
|
||||
The top level function which we use to run the experiment
|
||||
|
||||
def test_it():
|
||||
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
|
||||
:param target_security_levels: a list of the target number of bits of security, 128 is default
|
||||
:param name: a name to save the file
|
||||
"""
|
||||
if __name__ == '__main__':
|
||||
|
||||
D = nd.NoiseDistribution.DiscreteGaussian
|
||||
DEFAULT_PARAMETERS = LWE.Parameters(n=1024, q=2**64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
|
||||
D = ND.DiscreteGaussian
|
||||
vals = range(sd_range[0], sd_range[1])
|
||||
pool = multiprocessing.Pool(2)
|
||||
init_params = LWE.Parameters(
|
||||
n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
|
||||
inputs = [(init_params, (val, val), target_security_levels, name)
|
||||
for val in vals]
|
||||
res = pool.starmap(generate_parameter_matrix, inputs)
|
||||
|
||||
return "done"
|
||||
|
||||
|
||||
# x = estimate(params)
|
||||
# y = get_security_level(x, 2)
|
||||
# print(y)
|
||||
#z1 = automated_param_select_n(schemes.TFHE630.updated(n=786), 128)
|
||||
#print(z1)
|
||||
sd_range = [1,4]
|
||||
print("working...")
|
||||
z3 = generate_parameter_matrix(DEFAULT_PARAMETERS, sd_range=[5, 6], target_security_levels=[128, 192, 256])
|
||||
# TODO: in this function call the initial guess for n is way off (security is ~60-bits instead of close to 128).
|
||||
print(z3)
|
||||
save(z3, "123.sobj")
|
||||
|
||||
return z3
|
||||
|
||||
|
||||
def generate_zama_curves64(sd_range=[2, 60], target_security_levels=[128, 192, 256]):
|
||||
|
||||
D = ND.DiscreteGaussian
|
||||
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
|
||||
raw_data = generate_parameter_matrix(init_params, sd_range=sd_range, target_security_levels=target_security_levels)
|
||||
|
||||
return raw_data
|
||||
|
||||
|
||||
generate_zama_curves64()
|
||||
# The script runs the following commands
|
||||
# grab values of the command-line input arguments
|
||||
a = int(sys.argv[1])
|
||||
b = int(sys.argv[2])
|
||||
c = int(sys.argv[3])
|
||||
# run the code
|
||||
generate_zama_curves64(sd_range=(b, c), target_security_levels=[
|
||||
a], name="{}".format(a))
|
||||
66
job.sh
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/bin/sh
|
||||
# 80-bits
|
||||
sage generate_data.py 80 2 12
|
||||
sage generate_data.py 80 12 22
|
||||
sage generate_data.py 80 22 32
|
||||
sage generate_data.py 80 32 42
|
||||
sage generate_data.py 80 42 52
|
||||
sage generate_data.py 80 52 59
|
||||
# 96-bits
|
||||
sage generate_data.py 96 2 12
|
||||
sage generate_data.py 96 12 22
|
||||
sage generate_data.py 96 22 32
|
||||
sage generate_data.py 96 32 42
|
||||
sage generate_data.py 96 42 52
|
||||
sage generate_data.py 96 52 59
|
||||
# 112-bits
|
||||
sage generate_data.py 112 2 12
|
||||
sage generate_data.py 112 12 22
|
||||
sage generate_data.py 112 22 32
|
||||
sage generate_data.py 112 32 42
|
||||
sage generate_data.py 112 42 52
|
||||
sage generate_data.py 112 52 59
|
||||
# 128-bits
|
||||
sage generate_data.py 128 2 12
|
||||
sage generate_data.py 128 12 22
|
||||
sage generate_data.py 128 22 32
|
||||
sage generate_data.py 128 32 42
|
||||
sage generate_data.py 128 42 52
|
||||
sage generate_data.py 128 52 59
|
||||
# 144-bits
|
||||
sage generate_data.py 144 2 12
|
||||
sage generate_data.py 144 12 22
|
||||
sage generate_data.py 144 22 32
|
||||
sage generate_data.py 144 32 42
|
||||
sage generate_data.py 144 42 52
|
||||
sage generate_data.py 144 52 59
|
||||
# 160-bits
|
||||
sage generate_data.py 160 2 12
|
||||
sage generate_data.py 160 12 22
|
||||
sage generate_data.py 160 22 32
|
||||
sage generate_data.py 160 32 42
|
||||
sage generate_data.py 160 42 52
|
||||
sage generate_data.py 160 52 59
|
||||
# 176-bits
|
||||
sage generate_data.py 176 2 12
|
||||
sage generate_data.py 176 12 22
|
||||
sage generate_data.py 176 22 32
|
||||
sage generate_data.py 176 32 42
|
||||
sage generate_data.py 176 42 52
|
||||
sage generate_data.py 176 52 59
|
||||
# 192-bits
|
||||
sage generate_data.py 192 2 12
|
||||
sage generate_data.py 192 12 22
|
||||
sage generate_data.py 192 22 32
|
||||
sage generate_data.py 192 32 42
|
||||
sage generate_data.py 192 42 52
|
||||
sage generate_data.py 192 52 59
|
||||
# 256-bits
|
||||
sage generate_data.py 256 2 12
|
||||
sage generate_data.py 256 12 22
|
||||
sage generate_data.py 256 22 32
|
||||
sage generate_data.py 256 32 42
|
||||
sage generate_data.py 256 42 52
|
||||
sage generate_data.py 256 52 59
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 74 KiB After Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 34 KiB After Width: | Height: | Size: 34 KiB |
|
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 32 KiB |
|
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 26 KiB |
|
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 20 KiB |
17
old_files/memory_tests/test.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from estimator_new import *
|
||||
from sage.all import oo, save
|
||||
|
||||
def test():
|
||||
|
||||
# code
|
||||
D = ND.DiscreteGaussian
|
||||
params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT')
|
||||
|
||||
names = [params, params.updated(n=761), params.updated(q=2 ** 65), params.updated(n=762)]
|
||||
|
||||
for name in names:
|
||||
x = LWE.estimate(name, deny_list=("arora-gb", "bkw"))
|
||||
|
||||
return 0
|
||||
|
||||
test()
|
||||
31
old_files/memory_tests/test2.py
Normal file
@@ -0,0 +1,31 @@
|
||||
|
||||
from multiprocessing import *
|
||||
from estimator_new import *
|
||||
from sage.all import oo, save
|
||||
|
||||
|
||||
def test_memory(x):
|
||||
print("doing job...")
|
||||
print(x)
|
||||
y = LWE.estimate(x, deny_list=("arora-gb", "bkw"))
|
||||
return y
|
||||
|
||||
if __name__ == "__main__":
|
||||
D = ND.DiscreteGaussian
|
||||
params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT')
|
||||
|
||||
names = [params, params.updated(n=761), params.updated(q=2**65), params.updated(n=762)]
|
||||
procs = []
|
||||
proc = Process(target=print_func)
|
||||
procs.append(proc)
|
||||
proc.start()
|
||||
p = Pool(1)
|
||||
|
||||
for name in names:
|
||||
proc = Process(target=test_memory, args=(name,))
|
||||
procs.append(proc)
|
||||
proc.start()
|
||||
proc.join()
|
||||
|
||||
for proc in procs:
|
||||
proc.join()
|
||||
471
old_files/new_scripts.py
Normal file
@@ -0,0 +1,471 @@
|
||||
from estimator_new import *
|
||||
from sage.all import oo, save, load, ceil
|
||||
from math import log2
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def old_models(security_level, sd, logq=32):
|
||||
"""
|
||||
Use the old model as a starting point for the data gathering step
|
||||
:param security_level: the security level under consideration
|
||||
:param sd : the standard deviation of the LWE error distribution Xe
|
||||
:param logq : the (base 2 log) value of the LWE modulus q
|
||||
"""
|
||||
|
||||
def evaluate_model(a, b, stddev=sd):
|
||||
return (stddev - b)/a
|
||||
|
||||
models = dict()
|
||||
|
||||
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
|
||||
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
|
||||
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
|
||||
models["128"] = (-0.026361288425133814, 2.0014671315214696 + logq)
|
||||
models["144"] = (-0.023744534465622812, 2.1710601038230712 + logq)
|
||||
models["160"] = (-0.021667220727651954, 2.3565507936475476 + logq)
|
||||
models["176"] = (-0.019947662046189942, 2.5109588704235803 + logq)
|
||||
models["192"] = (-0.018552804646747204, 2.7168913723130816 + logq)
|
||||
models["208"] = (-0.017291091126923574, 2.7956961446214326 + logq)
|
||||
models["224"] = (-0.016257546811508806, 2.9582401000615226 + logq)
|
||||
models["240"] = (-0.015329741032015766, 3.0744579055889782 + logq)
|
||||
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
|
||||
|
||||
(a, b) = models["{}".format(security_level)]
|
||||
n_est = evaluate_model(a, b, sd)
|
||||
|
||||
return round(n_est)
|
||||
|
||||
|
||||
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
|
||||
"""
|
||||
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
|
||||
:param params: the input LWE parameters
|
||||
:param red_cost_model: the lattice reduction cost model
|
||||
:param skip: attacks to skip
|
||||
"""
|
||||
|
||||
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
|
||||
|
||||
return est
|
||||
|
||||
|
||||
def get_security_level(est, dp=2):
|
||||
"""
|
||||
Get the security level lambda from a Lattice Estimator output
|
||||
:param est: the Lattice Estimator output
|
||||
:param dp: the number of decimal places to consider
|
||||
"""
|
||||
attack_costs = []
|
||||
# note: key does not need to be specified est vs est.keys()
|
||||
for key in est:
|
||||
attack_costs.append(est[key]["rop"])
|
||||
# get the security level correct to 'dp' decimal places
|
||||
security_level = round(log2(min(attack_costs)), dp)
|
||||
|
||||
return security_level
|
||||
|
||||
|
||||
def inequality(x, y):
|
||||
""" A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
|
||||
:param x: the LHS of the inequality
|
||||
:param y: the RHS of the inequality
|
||||
"""
|
||||
if x <= y:
|
||||
return 1
|
||||
|
||||
if x > y:
|
||||
return -1
|
||||
|
||||
|
||||
def automated_param_select_n(params, target_security=128):
|
||||
""" A function used to generate the smallest value of n which allows for
|
||||
target_security bits of security, for the input values of (params.Xe.stddev,params.q)
|
||||
:param params: the standard deviation of the error
|
||||
:param target_security: the target number of bits of security, 128 is default
|
||||
|
||||
EXAMPLE:
|
||||
sage: X = automated_param_select_n(Kyber512, target_security = 128)
|
||||
sage: X
|
||||
456
|
||||
"""
|
||||
|
||||
# get an estimate based on the prev. model
|
||||
print("n = {}".format(params.n))
|
||||
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
|
||||
# n_start = max(n_start, 450)
|
||||
# TODO: think about throwing an error if the required n < 450
|
||||
|
||||
params = params.updated(n=n_start)
|
||||
costs2 = estimate(params)
|
||||
security_level = get_security_level(costs2, 2)
|
||||
z = inequality(security_level, target_security)
|
||||
|
||||
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
|
||||
while z * security_level < z * target_security:
|
||||
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
|
||||
params = params.updated(n = params.n + z * 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
if -1 * params.Xe.stddev > 0:
|
||||
print("target security level is unattainable")
|
||||
break
|
||||
|
||||
# final estimate (we went too far in the above loop)
|
||||
if security_level < target_security:
|
||||
# we make n larger
|
||||
print("we make n larger")
|
||||
params = params.updated(n=params.n + 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
|
||||
log2(params.Xe.stddev),
|
||||
log2(params.q),
|
||||
security_level))
|
||||
|
||||
if security_level < target_security:
|
||||
params.updated(n=None)
|
||||
|
||||
return params, security_level
|
||||
|
||||
|
||||
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
|
||||
"""
|
||||
:param params_in: a initial set of LWE parameters
|
||||
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
|
||||
:param target_security_levels: a list of the target number of bits of security, 128 is default
|
||||
:param name: a name to save the file
|
||||
"""
|
||||
|
||||
(sd_min, sd_max) = sd_range
|
||||
for lam in target_security_levels:
|
||||
for sd in range(sd_min, sd_max + 1):
|
||||
print("run for {}".format(lam, sd))
|
||||
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
|
||||
(params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
|
||||
|
||||
try:
|
||||
results = load("{}.sobj".format(name))
|
||||
except:
|
||||
results = dict()
|
||||
results["{}".format(lam)] = []
|
||||
|
||||
results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
|
||||
save(results, "{}.sobj".format(name))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
|
||||
"""
|
||||
The top level function which we use to run the experiment
|
||||
|
||||
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
|
||||
:param target_security_levels: a list of the target number of bits of security, 128 is default
|
||||
:param name: a name to save the file
|
||||
"""
|
||||
if __name__ == '__main__':
|
||||
|
||||
D = ND.DiscreteGaussian
|
||||
vals = range(sd_range[0], sd_range[1])
|
||||
pool = multiprocessing.Pool(2)
|
||||
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
|
||||
inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
|
||||
res = pool.starmap(generate_parameter_matrix, inputs)
|
||||
|
||||
return "done"
|
||||
|
||||
|
||||
# The script runs the following commands
|
||||
import sys
|
||||
# grab values of the command-line input arguments
|
||||
a = int(sys.argv[1])
|
||||
b = int(sys.argv[2])
|
||||
c = int(sys.argv[3])
|
||||
# run the code
|
||||
generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))
|
||||
|
||||
from estimator_new import *
|
||||
from sage.all import oo, save, load
|
||||
from math import log2
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def old_models(security_level, sd, logq=32):
|
||||
"""
|
||||
Use the old model as a starting point for the data gathering step
|
||||
:param security_level: the security level under consideration
|
||||
:param sd : the standard deviation of the LWE error distribution Xe
|
||||
:param logq : the (base 2 log) value of the LWE modulus q
|
||||
"""
|
||||
|
||||
def evaluate_model(a, b, stddev=sd):
|
||||
return (stddev - b)/a
|
||||
|
||||
models = dict()
|
||||
|
||||
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
|
||||
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
|
||||
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
|
||||
models["128"] = (-0.026361288425133814, 2.0014671315214696 + logq)
|
||||
models["144"] = (-0.023744534465622812, 2.1710601038230712 + logq)
|
||||
models["160"] = (-0.021667220727651954, 2.3565507936475476 + logq)
|
||||
models["176"] = (-0.019947662046189942, 2.5109588704235803 + logq)
|
||||
models["192"] = (-0.018552804646747204, 2.7168913723130816 + logq)
|
||||
models["208"] = (-0.017291091126923574, 2.7956961446214326 + logq)
|
||||
models["224"] = (-0.016257546811508806, 2.9582401000615226 + logq)
|
||||
models["240"] = (-0.015329741032015766, 3.0744579055889782 + logq)
|
||||
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
|
||||
|
||||
(a, b) = models["{}".format(security_level)]
|
||||
n_est = evaluate_model(a, b, sd)
|
||||
|
||||
return round(n_est)
|
||||
|
||||
|
||||
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
|
||||
"""
|
||||
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
|
||||
:param params: the input LWE parameters
|
||||
:param red_cost_model: the lattice reduction cost model
|
||||
:param skip: attacks to skip
|
||||
"""
|
||||
|
||||
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
|
||||
|
||||
return est
|
||||
|
||||
|
||||
def get_security_level(est, dp=2):
|
||||
"""
|
||||
Get the security level lambda from a Lattice Estimator output
|
||||
:param est: the Lattice Estimator output
|
||||
:param dp: the number of decimal places to consider
|
||||
"""
|
||||
attack_costs = []
|
||||
# note: key does not need to be specified est vs est.keys()
|
||||
for key in est:
|
||||
attack_costs.append(est[key]["rop"])
|
||||
# get the security level correct to 'dp' decimal places
|
||||
security_level = round(log2(min(attack_costs)), dp)
|
||||
|
||||
return security_level
|
||||
|
||||
|
||||
def inequality(x, y):
|
||||
""" A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
|
||||
:param x: the LHS of the inequality
|
||||
:param y: the RHS of the inequality
|
||||
"""
|
||||
if x <= y:
|
||||
return 1
|
||||
|
||||
if x > y:
|
||||
return -1
|
||||
|
||||
|
||||
def automated_param_select_n(params, target_security=128):
|
||||
""" A function used to generate the smallest value of n which allows for
|
||||
target_security bits of security, for the input values of (params.Xe.stddev,params.q)
|
||||
:param params: the standard deviation of the error
|
||||
:param target_security: the target number of bits of security, 128 is default
|
||||
|
||||
EXAMPLE:
|
||||
sage: X = automated_param_select_n(Kyber512, target_security = 128)
|
||||
sage: X
|
||||
456
|
||||
"""
|
||||
|
||||
# get an estimate based on the prev. model
|
||||
print("n = {}".format(params.n))
|
||||
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
|
||||
# n_start = max(n_start, 450)
|
||||
# TODO: think about throwing an error if the required n < 450
|
||||
|
||||
params = params.updated(n=n_start)
|
||||
costs2 = estimate(params)
|
||||
security_level = get_security_level(costs2, 2)
|
||||
z = inequality(security_level, target_security)
|
||||
|
||||
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
|
||||
while z * security_level < z * target_security:
|
||||
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
|
||||
params = params.updated(n = params.n + z * 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
if -1 * params.Xe.stddev > 0:
|
||||
print("target security level is unattainable")
|
||||
break
|
||||
|
||||
# final estimate (we went too far in the above loop)
|
||||
if security_level < target_security:
|
||||
# we make n larger
|
||||
print("we make n larger")
|
||||
params = params.updated(n=params.n + 8)
|
||||
costs = estimate(params)
|
||||
security_level = get_security_level(costs, 2)
|
||||
|
||||
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
|
||||
log2(params.Xe.stddev),
|
||||
log2(params.q),
|
||||
security_level))
|
||||
|
||||
if security_level < target_security:
|
||||
params.updated(n=None)
|
||||
|
||||
return params, security_level
|
||||
|
||||
|
||||
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
|
||||
"""
|
||||
:param params_in: a initial set of LWE parameters
|
||||
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
|
||||
:param target_security_levels: a list of the target number of bits of security, 128 is default
|
||||
:param name: a name to save the file
|
||||
"""
|
||||
|
||||
(sd_min, sd_max) = sd_range
|
||||
for lam in target_security_levels:
|
||||
for sd in range(sd_min, sd_max + 1):
|
||||
print("run for {}".format(lam, sd))
|
||||
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
|
||||
(params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
|
||||
|
||||
try:
|
||||
results = load("{}.sobj".format(name))
|
||||
except:
|
||||
results = dict()
|
||||
results["{}".format(lam)] = []
|
||||
|
||||
results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
|
||||
save(results, "{}.sobj".format(name))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
|
||||
"""
|
||||
The top level function which we use to run the experiment
|
||||
|
||||
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
|
||||
:param target_security_levels: a list of the target number of bits of security, 128 is default
|
||||
:param name: a name to save the file
|
||||
"""
|
||||
if __name__ == '__main__':
|
||||
|
||||
D = ND.DiscreteGaussian
|
||||
vals = range(sd_range[0], sd_range[1])
|
||||
pool = multiprocessing.Pool(2)
|
||||
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
|
||||
inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
|
||||
res = pool.starmap(generate_parameter_matrix, inputs)
|
||||
|
||||
return "done"
|
||||
|
||||
|
||||
# The script runs the following commands
|
||||
import sys
|
||||
# grab values of the command-line input arguments
|
||||
a = int(sys.argv[1])
|
||||
b = int(sys.argv[2])
|
||||
c = int(sys.argv[3])
|
||||
# run the code
|
||||
generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))
|
||||
|
||||
import numpy as np
|
||||
from sage.all import save, load
|
||||
|
||||
def sort_data(security_level):
|
||||
from operator import itemgetter
|
||||
|
||||
# step 1. load the data
|
||||
X = load("{}.sobj".format(security_level))
|
||||
|
||||
# step 2. sort by SD
|
||||
x = sorted(X["{}".format(security_level)], key = itemgetter(2))
|
||||
|
||||
# step3. replace the sorted value
|
||||
X["{}".format(security_level)] = x
|
||||
|
||||
return X
|
||||
|
||||
def generate_curve(security_level):
|
||||
|
||||
# step 1. get the data
|
||||
X = sort_data(security_level)
|
||||
|
||||
# step 2. group the n and sigma data into lists
|
||||
N = []
|
||||
SD = []
|
||||
for x in X["{}".format(security_level)]:
|
||||
N.append(x[0])
|
||||
SD.append(x[2] + 0.5)
|
||||
|
||||
# step 3. perform interpolation and return coefficients
|
||||
(a,b) = np.polyfit(N, SD, 1)
|
||||
|
||||
return a, b
|
||||
|
||||
|
||||
def verify_curve(security_level, a = None, b = None):
|
||||
|
||||
# step 1. get the table and max values of n, sd
|
||||
X = sort_data(security_level)
|
||||
n_max = X["{}".format(security_level)][0][0]
|
||||
sd_max = X["{}".format(security_level)][-1][2]
|
||||
|
||||
# step 2. a function to get model values
|
||||
def f_model(a, b, n):
|
||||
return ceil(a * n + b)
|
||||
|
||||
# step 3. a function to get table values
|
||||
def f_table(table, n):
|
||||
for i in range(len(table)):
|
||||
n_val = table[i][0]
|
||||
if n < n_val:
|
||||
pass
|
||||
else:
|
||||
j = i
|
||||
break
|
||||
|
||||
# now j is the correct index, we return the corresponding sd
|
||||
return table[j][2]
|
||||
|
||||
# step 3. for each n, check whether we satisfy the table
|
||||
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
|
||||
print(n_min)
|
||||
print(n_max)
|
||||
|
||||
for n in range(n_max, n_min, - 1):
|
||||
model_sd = f_model(a, b, n)
|
||||
table_sd = f_table(X["{}".format(security_level)], n)
|
||||
print(n , table_sd, model_sd, model_sd >= table_sd)
|
||||
|
||||
if table_sd > model_sd:
|
||||
print("MODEL FAILS at n = {}".format(n))
|
||||
return "FAIL"
|
||||
|
||||
return "PASS", n_min
|
||||
|
||||
|
||||
def generate_and_verify(security_levels, log_q, name = "verified_curves"):
|
||||
|
||||
data = []
|
||||
|
||||
for sec in security_levels:
|
||||
print("WE GO FOR {}".format(sec))
|
||||
# generate the model for security level sec
|
||||
(a_sec, b_sec) = generate_curve(sec)
|
||||
# verify the model for security level sec
|
||||
res = verify_curve(sec, a_sec, b_sec)
|
||||
# append the information into a list
|
||||
data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
|
||||
save(data, "{}.sobj".format(name))
|
||||
|
||||
return data
|
||||
|
||||
# To verify the curves we use
|
||||
generate_and_verify([80, 96, 112, 128, 144, 160, 176, 192, 256], log_q = 64)
|
||||
|
||||
BIN
verified_curves.sobj
Normal file
9
verified_curves.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
[(-0.04042633119364589, 1.6609788641436722, 80, 'PASS', 450),
|
||||
(-0.03414780360867051, 2.017310258660345, 96, 'PASS', 450),
|
||||
(-0.029670137081135885, 2.162463714083856, 112, 'PASS', 450),
|
||||
(-0.02640502876522622, 2.4826422691043177, 128, 'PASS', 450),
|
||||
(-0.023821437305989134, 2.7177789440636673, 144, 'PASS', 450),
|
||||
(-0.02174358218716036, 2.938810548493322, 160, 'PASS', 498),
|
||||
(-0.019904056582117684, 2.8161252801542247, 176, 'PASS', 551),
|
||||
(-0.018610403247590085, 3.2996236848399008, 192, 'PASS', 606),
|
||||
(-0.014606812351714953, 3.8493629234693003, 256, 'PASS', 826)]
|
||||
98
verify_curves.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
from sage.all import save, load, ceil
|
||||
|
||||
|
||||
def sort_data(security_level):
|
||||
from operator import itemgetter
|
||||
|
||||
# step 1. load the data
|
||||
X = load("{}.sobj".format(security_level))
|
||||
|
||||
# step 2. sort by SD
|
||||
x = sorted(X["{}".format(security_level)], key=itemgetter(2))
|
||||
|
||||
# step3. replace the sorted value
|
||||
X["{}".format(security_level)] = x
|
||||
|
||||
return X
|
||||
|
||||
|
||||
def generate_curve(security_level):
|
||||
|
||||
# step 1. get the data
|
||||
X = sort_data(security_level)
|
||||
|
||||
# step 2. group the n and sigma data into lists
|
||||
N = []
|
||||
SD = []
|
||||
for x in X["{}".format(security_level)]:
|
||||
N.append(x[0])
|
||||
SD.append(x[2] + 0.5)
|
||||
|
||||
# step 3. perform interpolation and return coefficients
|
||||
(a, b) = np.polyfit(N, SD, 1)
|
||||
|
||||
return a, b
|
||||
|
||||
|
||||
def verify_curve(security_level, a=None, b=None):
|
||||
|
||||
# step 1. get the table and max values of n, sd
|
||||
X = sort_data(security_level)
|
||||
n_max = X["{}".format(security_level)][0][0]
|
||||
sd_max = X["{}".format(security_level)][-1][2]
|
||||
|
||||
# step 2. a function to get model values
|
||||
def f_model(a, b, n):
|
||||
return ceil(a * n + b)
|
||||
|
||||
# step 3. a function to get table values
|
||||
def f_table(table, n):
|
||||
for i in range(len(table)):
|
||||
n_val = table[i][0]
|
||||
if n < n_val:
|
||||
pass
|
||||
else:
|
||||
j = i
|
||||
break
|
||||
|
||||
# now j is the correct index, we return the corresponding sd
|
||||
return table[j][2]
|
||||
|
||||
# step 3. for each n, check whether we satisfy the table
|
||||
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
|
||||
print(n_min)
|
||||
print(n_max)
|
||||
|
||||
for n in range(n_max, n_min, - 1):
|
||||
model_sd = f_model(a, b, n)
|
||||
table_sd = f_table(X["{}".format(security_level)], n)
|
||||
print(n, table_sd, model_sd, model_sd >= table_sd)
|
||||
|
||||
if table_sd > model_sd:
|
||||
print("MODEL FAILS at n = {}".format(n))
|
||||
return "FAIL"
|
||||
|
||||
return "PASS", n_min
|
||||
|
||||
|
||||
def generate_and_verify(security_levels, log_q, name="verified_curves"):
|
||||
|
||||
data = []
|
||||
|
||||
for sec in security_levels:
|
||||
print("WE GO FOR {}".format(sec))
|
||||
# generate the model for security level sec
|
||||
(a_sec, b_sec) = generate_curve(sec)
|
||||
# verify the model for security level sec
|
||||
res = verify_curve(sec, a_sec, b_sec)
|
||||
# append the information into a list
|
||||
data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
|
||||
save(data, "{}.sobj".format(name))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
data = generate_and_verify(
|
||||
[80, 96, 112, 128, 144, 160, 176, 192, 256], log_q=64)
|
||||
print(data)
|
||||