Merge pull request #9 from zama-ai/update_estimator

Update estimator
This commit is contained in:
Ben
2022-06-24 13:57:07 +01:00
committed by GitHub
38 changed files with 804 additions and 134 deletions

View File

@@ -19,4 +19,5 @@ jobs:
- name: PEP8
run: |
pip install --upgrade pyproject-flake8
flake8 new_scripts.py
flake8 generate_data.py
flake8 verify_curves.py

BIN
112.sobj Normal file

Binary file not shown.

BIN
128.sobj Normal file

Binary file not shown.

BIN
144.sobj Normal file

Binary file not shown.

BIN
160.sobj Normal file

Binary file not shown.

BIN
176.sobj Normal file

Binary file not shown.

BIN
192.sobj Normal file

Binary file not shown.

BIN
256.sobj Normal file

Binary file not shown.

BIN
80.sobj Normal file

Binary file not shown.

BIN
96.sobj Normal file

Binary file not shown.

View File

@@ -3,12 +3,12 @@ Parameter curves for Concrete
This GitHub repository contains the code needed to generate the parameter curves used inside Zama. The repository contains the following files:
- cpp/, Python scripts to generate a cpp file containing the parameter curves
- cpp/, Python scripts to generate a cpp file containing the parameter curves (needs updating)
- data/, a folder containing the data generated for previous curves.
- estimator/, Zama's internal version of the LWE Estimator
- figs/, a folder containing various figures related to the parameter curves
- scripts.py, a copy of all scripts required to generate the parameter curves
- a variety of other python files, used for estimating the security of previous Concrete parameter sets
- estimator_new/, the Lattice estimator (TODO: add as a submodule and use dependabot to alert for new commits)
- old_files/, legacy files used for previous versions
- generate_data.py, functions to gather raw data from the lattice estimator
- verify_curves.py, functions to generate and verify curves from raw data
.. image:: logo.svg
:align: center
@@ -21,52 +21,43 @@ This is an example of how to generate the parameter curves, and save them to file
::
sage: load("scripts.py")
sage: results = get_zama_curves()
sage: save(results, "v0.sobj")
./job.sh
::
We can load results files, and find the interpolants.
This will generate several data files, {80, 96, 112, 128, 144, 160, 176, 192, 256}.sobj
To generate the parameter curves from the data files, we run
`sage verify_curves.py`
This will generate a list of the form:
::
sage: load("scripts.py")
sage: interps = []
sage: results = load("v0.sobj")
sage: for result in results:
sage: interps.append(interpolate_result(result, log_q = 64))
sage: interps
[(-0.040476778656126484, 1.143346508563902),
(-0.03417207792207793, 1.4805194805194737),
(-0.029681716023268107, 1.752723426758335),
(-0.0263748887657055, 2.0121439233304894),
(-0.023730136557783763, 2.1537066948924095),
(-0.021604493958972515, 2.2696862472846204),
(-0.019897520946588438, 2.4423829771964796),
(-0.018504919354426233, 2.6634073426215745),
(-0.017254242957361113, 2.7353702447139026),
(-0.016178309410530816, 2.8493969373734758),
(-0.01541034709414119, 3.1982749283836283),
(-0.014327640360322604, 2.899270827311096)]
[(-0.04042633119364589, 1.6609788641436722, 80, 'PASS', 450),
(-0.03414780360867051, 2.017310258660345, 96, 'PASS', 450),
(-0.029670137081135885, 2.162463714083856, 112, 'PASS', 450),
(-0.02640502876522622, 2.4826422691043177, 128, 'PASS', 450),
(-0.023821437305989134, 2.7177789440636673, 144, 'PASS', 450),
(-0.02174358218716036, 2.938810548493322, 160, 'PASS', 498),
(-0.019904056582117684, 2.8161252801542247, 176, 'PASS', 551),
(-0.018610403247590085, 3.2996236848399008, 192, 'PASS', 606),
(-0.014606812351714953, 3.8493629234693003, 256, 'PASS', 826)]
::
Finding the value of n_{alpha} is done manually. We can also verify the interpolants which are generated at the same time:
Each element is a tuple (a, b, security, P, n_min), where (a, b) are the model
parameters, security is the security level, P is a flag denoting PASS or FAIL of
the verification, and n_min is the smallest recommended value of `n` to be used.
Each model outputs a value of sigma and has the form:
::
# verify the interpolant used for lambda = 256 (which is interps[-1])
sage: z = verify_interpolants(interps[-1], (128,2048), 64)
[... code runs, can take ~10 mins ...]
# find the index corresponding to n_alpha, which is where security drops below the target security level (256 here)
sage: n_alpha = find_nalpha(z, 256)
653
# so the model in this case is
(-0.014327640360322604, 2.899270827311096, 653)
# which corresponds to
# sd(n) = max(-0.014327640360322604 * n + 2.899270827311096, -logq + 2), n >= 653
f(a, b, n) = max(ceil(a * n + b), -log2(q) + 2)
::
where the -log2(q) + 2 term ensures that we are always using at least two bits of noise.
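For concreteness, here is a minimal sketch (not part of the repository) of evaluating the
128-bit model from the verified output above, assuming q = 2**64 and reading sigma as the
noise standard deviation relative to q (which is what the -log2(q) + 2 floor suggests); the
helper name minimal_log2_sigma is illustrative only:
::
from math import ceil

def minimal_log2_sigma(a, b, n, log_q=64):
    # f(a, b, n) = max(ceil(a * n + b), -log2(q) + 2), as given above
    return max(ceil(a * n + b), -log_q + 2)

# 128-bit entry from the verified list above: (a, b, security, 'PASS', n_min)
(a_128, b_128, sec, status, n_min) = (-0.02640502876522622, 2.4826422691043177, 128, 'PASS', 450)
n = 1024
assert n >= n_min  # the curve is only verified down to n_min
print(minimal_log2_sigma(a_128, b_128, n))  # -24, i.e. ceil of roughly -24.56
::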
Version History
-------------------
@@ -76,14 +67,7 @@ Data for the curves are kept in /data. The following files are present:
v0: generated using the {usvp, dual, decoding} attacks
v0.1: generated using the {mitm, usvp, dual, decoding} attacks
v0.2: generated using the lattice estimator
::
TODO List
-------------------
There are several updates which are still required.
1. Consider Hybrid attacks (WIP, Michael + Ben are coding up hybrid-dual/hybrid-decoding estimates)
2. CI/CD stuff for new pushes to the external LWE Estimator.
3. Fully automate the process of finding n_{alpha} for each curve.
4. Functionality for q != 64? This is covered by the curve, but we currently don't account for it in the models, and it needs to be done manually (see the sketch after this list).
5. cpp file generation
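For TODO item 4, one possible manual adjustment is sketched below. It assumes that the
curves scale in log2(q) in the same way as old_models() in generate_data.py (where the
intercept is shifted by log2(q)), i.e. that the fitted (a, b) describe sigma relative to q;
the helper name is hypothetical and this is illustrative only, not verified against the estimator:
::
from math import ceil

def minimal_log2_sigma_absolute(a, b, n, log_q):
    # assumption: (a, b) come from a curve fitted at log2(q) = 64 and describe sigma / q;
    # adding log_q converts to an absolute log2(sigma), and the floor of 2 is the
    # "two bits of noise" rule above expressed in absolute terms
    return max(ceil(a * n + b) + log_q, 2)
::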

BIN
data/v0.2/112.sobj Normal file

Binary file not shown.

BIN
data/v0.2/128.sobj Normal file

Binary file not shown.

BIN
data/v0.2/144.sobj Normal file

Binary file not shown.

BIN
data/v0.2/160.sobj Normal file

Binary file not shown.

BIN
data/v0.2/176.sobj Normal file

Binary file not shown.

BIN
data/v0.2/192.sobj Normal file

Binary file not shown.

BIN
data/v0.2/256.sobj Normal file

Binary file not shown.

BIN
data/v0.2/80.sobj Normal file

Binary file not shown.

BIN
data/v0.2/96.sobj Normal file

Binary file not shown.

View File

@@ -1,8 +1,11 @@
import sys
from estimator_new import *
from sage.all import oo, save
from sage.all import oo, save, load, ceil
from math import log2
import multiprocessing
def old_models(security_level, sd, logq = 32):
def old_models(security_level, sd, logq=32):
"""
Use the old model as a starting point for the data gathering step
:param security_level: the security level under consideration
@@ -10,12 +13,11 @@ def old_models(security_level, sd, logq = 32):
:param logq : the (base 2 log) value of the LWE modulus q
"""
def evaluate_model(sd, a, b):
return (sd - b)/a
def evaluate_model(a, b, stddev=sd):
return (stddev - b)/a
models = dict()
# TODO: figure out a way to import these from a datafile, for future version
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
@@ -30,32 +32,37 @@ def old_models(security_level, sd, logq = 32):
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
(a, b) = models["{}".format(security_level)]
n_est = evaluate_model(sd, a, b)
n_est = evaluate_model(a, b, sd)
return round(n_est)
def estimate(params, red_cost_model = RC.BDGL16):
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
"""
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
:param params: the input LWE parameters
:param red_cost_model: the lattice reduction cost model
:param skip: attacks to skip
"""
est = LWE.estimate(params, deny_list=("arora-gb", "bkw"), red_cost_model=red_cost_model)
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
def get_security_level(est, dp = 2):
def get_security_level(est, dp=2):
"""
Get the security level lambda from a Lattice Estimator output
:param est: the Lattice Estimator output
:param dp : the number of decimal places to consider
:param dp: the number of decimal places to consider
"""
attack_costs = []
for key in est.keys():
# note: iterating over est directly is equivalent to iterating over est.keys()
for key in est:
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
security_level = round(log2(min(attack_costs)), dp)
return security_level
@@ -83,120 +90,106 @@ def automated_param_select_n(params, target_security=128):
456
"""
# get an initial estimate
# costs = estimate(params)
# security_level = get_security_level(costs, 2)
# determine if we are above or below the target security level
# z = inequality(security_level, target_security)
# get an estimate based on the prev. model
print("n = {}".format(params.n))
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
# TODO -- is this how we want to deal with the small n issue? Shouldn't the model have this baked in?
# we want to start no lower than n = 450
n_start = max(n_start, 450)
print("n_start = {}".format(n_start))
n_start = old_models(target_security, log2(
params.Xe.stddev), log2(params.q))
# n_start = max(n_start, 450)
# TODO: think about throwing an error if the required n < 450
params = params.updated(n=n_start)
print(params)
costs2 = estimate(params)
security_level = get_security_level(costs2, 2)
z = inequality(security_level, target_security)
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
while z * security_level < z * target_security:
# if params.n > 1024:
# we only need to consider powers-of-two in this case
# TODO: fill in this case! For n > 1024 we only need to consider every 256
params = params.updated(n = params.n + z * 8)
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
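# move n in steps of 8 in the direction given by z, until the estimated security crosses the target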
params = params.updated(n=params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
if -1 * params.Xe.stddev > 0:
print("target security level is unatainable")
print("target security level is unattainable")
break
# final estimate (we went too far in the above loop)
if security_level < target_security:
# TODO: we should somehow keep the previous estimate stored so that we don't need to compute it twice
# if we do this we need to make sure that it works for both sides (i.e. if (i-1) is above or below the
# security level
params = params.updated(n = params.n - z * 8)
# we make n larger
print("we make n larger")
params = params.updated(n=params.n + 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
log2(params.Xe.stddev),
log2(params.q),
security_level))
log2(
params.Xe.stddev),
log2(
params.q),
security_level))
# final sanity check so we don't return insecure (or inf) parameters
# TODO: figure out inf in new estimator
# or security_level == oo:
if security_level < target_security:
params.updated(n=None)
return params
return params, security_level
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128]):
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
"""
:param params_in: an initial set of LWE parameters
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param params: the standard deviation of the LWE error
:param target_security: the target number of bits of security, 128 is default
EXAMPLE:
sage: X = generate_parameter_matrix()
sage: X
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
results = dict()
# grab min and max value/s of n
(sd_min, sd_max) = sd_range
for lam in target_security_levels:
results["{}".format(lam)] = []
for sd in range(sd_min, sd_max + 1):
print("run for {}".format(lam, sd))
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
params_out = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
results["{}".format(lam)].append((params_out.n, params_out.q, params_out.Xe.stddev))
(params_out, sec) = automated_param_select_n(
params_in.updated(Xe=Xe_new), target_security=lam)
try:
results = load("{}.sobj".format(name))
except:
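# no saved results file for this name yet: start a fresh dictionary keyed by security level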
results = dict()
results["{}".format(lam)] = []
results["{}".format(lam)].append(
(params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
save(results, "{}.sobj".format(name))
return results
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
"""
The top level function which we use to run the experiment
def test_it():
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
if __name__ == '__main__':
D = nd.NoiseDistribution.DiscreteGaussian
DEFAULT_PARAMETERS = LWE.Parameters(n=1024, q=2**64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
D = ND.DiscreteGaussian
vals = range(sd_range[0], sd_range[1])
pool = multiprocessing.Pool(2)
init_params = LWE.Parameters(
n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
inputs = [(init_params, (val, val), target_security_levels, name)
for val in vals]
res = pool.starmap(generate_parameter_matrix, inputs)
return "done"
# x = estimate(params)
# y = get_security_level(x, 2)
# print(y)
#z1 = automated_param_select_n(schemes.TFHE630.updated(n=786), 128)
#print(z1)
sd_range = [1,4]
print("working...")
z3 = generate_parameter_matrix(DEFAULT_PARAMETERS, sd_range=[5, 6], target_security_levels=[128, 192, 256])
# TODO: in this function call the initial guess for n is way off (security is ~60-bits instead of close to 128).
print(z3)
save(z3, "123.sobj")
return z3
def generate_zama_curves64(sd_range=[2, 60], target_security_levels=[128, 192, 256]):
D = ND.DiscreteGaussian
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(131072.00), m=oo, tag='TFHE_DEFAULT')
raw_data = generate_parameter_matrix(init_params, sd_range=sd_range, target_security_levels=target_security_levels)
return raw_data
generate_zama_curves64()
# The script runs the following commands
# grab values of the command-line input arguments
a = int(sys.argv[1])
b = int(sys.argv[2])
c = int(sys.argv[3])
# run the code
generate_zama_curves64(sd_range=(b, c), target_security_levels=[
a], name="{}".format(a))

66
job.sh Executable file
View File

@@ -0,0 +1,66 @@
#!/bin/sh
# 80-bits
sage generate_data.py 80 2 12
sage generate_data.py 80 12 22
sage generate_data.py 80 22 32
sage generate_data.py 80 32 42
sage generate_data.py 80 42 52
sage generate_data.py 80 52 59
# 96-bits
sage generate_data.py 96 2 12
sage generate_data.py 96 12 22
sage generate_data.py 96 22 32
sage generate_data.py 96 32 42
sage generate_data.py 96 42 52
sage generate_data.py 96 52 59
# 112-bits
sage generate_data.py 112 2 12
sage generate_data.py 112 12 22
sage generate_data.py 112 22 32
sage generate_data.py 112 32 42
sage generate_data.py 112 42 52
sage generate_data.py 112 52 59
# 128-bits
sage generate_data.py 128 2 12
sage generate_data.py 128 12 22
sage generate_data.py 128 22 32
sage generate_data.py 128 32 42
sage generate_data.py 128 42 52
sage generate_data.py 128 52 59
# 144-bits
sage generate_data.py 144 2 12
sage generate_data.py 144 12 22
sage generate_data.py 144 22 32
sage generate_data.py 144 32 42
sage generate_data.py 144 42 52
sage generate_data.py 144 52 59
# 160-bits
sage generate_data.py 160 2 12
sage generate_data.py 160 12 22
sage generate_data.py 160 22 32
sage generate_data.py 160 32 42
sage generate_data.py 160 42 52
sage generate_data.py 160 52 59
# 176-bits
sage generate_data.py 176 2 12
sage generate_data.py 176 12 22
sage generate_data.py 176 22 32
sage generate_data.py 176 32 42
sage generate_data.py 176 42 52
sage generate_data.py 176 52 59
# 192-bits
sage generate_data.py 192 2 12
sage generate_data.py 192 12 22
sage generate_data.py 192 22 32
sage generate_data.py 192 32 42
sage generate_data.py 192 42 52
sage generate_data.py 192 52 59
# 256-bits
sage generate_data.py 256 2 12
sage generate_data.py 256 12 22
sage generate_data.py 256 22 32
sage generate_data.py 256 32 42
sage generate_data.py 256 42 52
sage generate_data.py 256 52 59

View File

5 image files changed (before/after previews not shown; file sizes 74 KiB, 34 KiB, 32 KiB, 26 KiB and 20 KiB, unchanged).

View File

@@ -0,0 +1,17 @@
from estimator_new import *
from sage.all import oo, save
def test():
# estimate a few TFHE-like parameter sets with the new estimator
D = ND.DiscreteGaussian
params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT')
names = [params, params.updated(n=761), params.updated(q=2 ** 65), params.updated(n=762)]
for name in names:
x = LWE.estimate(name, deny_list=("arora-gb", "bkw"))
return 0
test()

View File

@@ -0,0 +1,31 @@
from multiprocessing import *
from estimator_new import *
from sage.all import oo, save
def test_memory(x):
print("doing job...")
print(x)
y = LWE.estimate(x, deny_list=("arora-gb", "bkw"))
return y
if __name__ == "__main__":
D = ND.DiscreteGaussian
params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2**57), m=oo, tag='TFHE_DEFAULT')
names = [params, params.updated(n=761), params.updated(q=2**65), params.updated(n=762)]
procs = []
# note: print_func is not defined in this script
# proc = Process(target=print_func)
# procs.append(proc)
# proc.start()
p = Pool(1)
for name in names:
proc = Process(target=test_memory, args=(name,))
procs.append(proc)
proc.start()
proc.join()
for proc in procs:
proc.join()

471
old_files/new_scripts.py Normal file
View File

@@ -0,0 +1,471 @@
from estimator_new import *
from sage.all import oo, save, load, ceil
from math import log2
import multiprocessing
def old_models(security_level, sd, logq=32):
"""
Use the old model as a starting point for the data gathering step
:param security_level: the security level under consideration
:param sd : the standard deviation of the LWE error distribution Xe
:param logq : the (base 2 log) value of the LWE modulus q
"""
def evaluate_model(a, b, stddev=sd):
return (stddev - b)/a
models = dict()
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
models["128"] = (-0.026361288425133814, 2.0014671315214696 + logq)
models["144"] = (-0.023744534465622812, 2.1710601038230712 + logq)
models["160"] = (-0.021667220727651954, 2.3565507936475476 + logq)
models["176"] = (-0.019947662046189942, 2.5109588704235803 + logq)
models["192"] = (-0.018552804646747204, 2.7168913723130816 + logq)
models["208"] = (-0.017291091126923574, 2.7956961446214326 + logq)
models["224"] = (-0.016257546811508806, 2.9582401000615226 + logq)
models["240"] = (-0.015329741032015766, 3.0744579055889782 + logq)
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
(a, b) = models["{}".format(security_level)]
n_est = evaluate_model(a, b, sd)
return round(n_est)
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
"""
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
:param params: the input LWE parameters
:param red_cost_model: the lattice reduction cost model
:param skip: attacks to skip
"""
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
def get_security_level(est, dp=2):
"""
Get the security level lambda from a Lattice Estimator output
:param est: the Lattice Estimator output
:param dp: the number of decimal places to consider
"""
attack_costs = []
# note: iterating over est directly is equivalent to iterating over est.keys()
for key in est:
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
security_level = round(log2(min(attack_costs)), dp)
return security_level
def inequality(x, y):
""" A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
:param x: the LHS of the inequality
:param y: the RHS of the inequality
"""
if x <= y:
return 1
if x > y:
return -1
def automated_param_select_n(params, target_security=128):
""" A function used to generate the smallest value of n which allows for
target_security bits of security, for the input values of (params.Xe.stddev,params.q)
:param params: the standard deviation of the error
:param target_security: the target number of bits of security, 128 is default
EXAMPLE:
sage: X = automated_param_select_n(Kyber512, target_security = 128)
sage: X
456
"""
# get an estimate based on the prev. model
print("n = {}".format(params.n))
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
# n_start = max(n_start, 450)
# TODO: think about throwing an error if the required n < 450
params = params.updated(n=n_start)
costs2 = estimate(params)
security_level = get_security_level(costs2, 2)
z = inequality(security_level, target_security)
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
while z * security_level < z * target_security:
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
params = params.updated(n = params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
if -1 * params.Xe.stddev > 0:
print("target security level is unattainable")
break
# final estimate (we went too far in the above loop)
if security_level < target_security:
# we make n larger
print("we make n larger")
params = params.updated(n=params.n + 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
log2(params.Xe.stddev),
log2(params.q),
security_level))
if security_level < target_security:
params.updated(n=None)
return params, security_level
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
"""
:param params_in: an initial set of LWE parameters
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
(sd_min, sd_max) = sd_range
for lam in target_security_levels:
for sd in range(sd_min, sd_max + 1):
print("run for {}".format(lam, sd))
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
(params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
try:
results = load("{}.sobj".format(name))
except:
results = dict()
results["{}".format(lam)] = []
results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
save(results, "{}.sobj".format(name))
return results
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
"""
The top level function which we use to run the experiment
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
if __name__ == '__main__':
D = ND.DiscreteGaussian
vals = range(sd_range[0], sd_range[1])
pool = multiprocessing.Pool(2)
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
res = pool.starmap(generate_parameter_matrix, inputs)
return "done"
# The script runs the following commands
import sys
# grab values of the command-line input arguments
a = int(sys.argv[1])
b = int(sys.argv[2])
c = int(sys.argv[3])
# run the code
generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))
from estimator_new import *
from sage.all import oo, save, load
from math import log2
import multiprocessing
def old_models(security_level, sd, logq=32):
"""
Use the old model as a starting point for the data gathering step
:param security_level: the security level under consideration
:param sd : the standard deviation of the LWE error distribution Xe
:param logq : the (base 2 log) value of the LWE modulus q
"""
def evaluate_model(a, b, stddev=sd):
return (stddev - b)/a
models = dict()
models["80"] = (-0.04049295502947623, 1.1288318226557081 + logq)
models["96"] = (-0.03416314056943681, 1.4704806061716345 + logq)
models["112"] = (-0.02970984362676178, 1.7848907787798667 + logq)
models["128"] = (-0.026361288425133814, 2.0014671315214696 + logq)
models["144"] = (-0.023744534465622812, 2.1710601038230712 + logq)
models["160"] = (-0.021667220727651954, 2.3565507936475476 + logq)
models["176"] = (-0.019947662046189942, 2.5109588704235803 + logq)
models["192"] = (-0.018552804646747204, 2.7168913723130816 + logq)
models["208"] = (-0.017291091126923574, 2.7956961446214326 + logq)
models["224"] = (-0.016257546811508806, 2.9582401000615226 + logq)
models["240"] = (-0.015329741032015766, 3.0744579055889782 + logq)
models["256"] = (-0.014530554319171845, 3.2094375376751745 + logq)
(a, b) = models["{}".format(security_level)]
n_est = evaluate_model(a, b, sd)
return round(n_est)
def estimate(params, red_cost_model=RC.BDGL16, skip=("arora-gb", "bkw")):
"""
Retrieve an estimate using the Lattice Estimator, for a given set of input parameters
:param params: the input LWE parameters
:param red_cost_model: the lattice reduction cost model
:param skip: attacks to skip
"""
est = LWE.estimate(params, red_cost_model=red_cost_model, deny_list=skip)
return est
def get_security_level(est, dp=2):
"""
Get the security level lambda from a Lattice Estimator output
:param est: the Lattice Estimator output
:param dp: the number of decimal places to consider
"""
attack_costs = []
# note: iterating over est directly is equivalent to iterating over est.keys()
for key in est:
attack_costs.append(est[key]["rop"])
# get the security level correct to 'dp' decimal places
security_level = round(log2(min(attack_costs)), dp)
return security_level
def inequality(x, y):
""" A utility function which compresses the conditions x < y and x > y into a single condition via a multiplier
:param x: the LHS of the inequality
:param y: the RHS of the inequality
"""
if x <= y:
return 1
if x > y:
return -1
def automated_param_select_n(params, target_security=128):
""" A function used to generate the smallest value of n which allows for
target_security bits of security, for the input values of (params.Xe.stddev,params.q)
:param params: the standard deviation of the error
:param target_security: the target number of bits of security, 128 is default
EXAMPLE:
sage: X = automated_param_select_n(Kyber512, target_security = 128)
sage: X
456
"""
# get an estimate based on the prev. model
print("n = {}".format(params.n))
n_start = old_models(target_security, log2(params.Xe.stddev), log2(params.q))
# n_start = max(n_start, 450)
# TODO: think about throwing an error if the required n < 450
params = params.updated(n=n_start)
costs2 = estimate(params)
security_level = get_security_level(costs2, 2)
z = inequality(security_level, target_security)
# we keep n > 2 * target_security as a rough baseline for mitm security (on binary key guessing)
while z * security_level < z * target_security:
# TODO: fill in this case! For n > 1024 we only need to consider every 256 (optimization)
params = params.updated(n = params.n + z * 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
if -1 * params.Xe.stddev > 0:
print("target security level is unattainable")
break
# final estimate (we went too far in the above loop)
if security_level < target_security:
# we make n larger
print("we make n larger")
params = params.updated(n=params.n + 8)
costs = estimate(params)
security_level = get_security_level(costs, 2)
print("the finalised parameters are n = {}, log2(sd) = {}, log2(q) = {}, with a security level of {}-bits".format(params.n,
log2(params.Xe.stddev),
log2(params.q),
security_level))
if security_level < target_security:
params.updated(n=None)
return params, security_level
def generate_parameter_matrix(params_in, sd_range, target_security_levels=[128], name="default_name"):
"""
:param params_in: an initial set of LWE parameters
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
(sd_min, sd_max) = sd_range
for lam in target_security_levels:
for sd in range(sd_min, sd_max + 1):
print("run for {}".format(lam, sd))
Xe_new = nd.NoiseDistribution.DiscreteGaussian(2**sd)
(params_out, sec) = automated_param_select_n(params_in.updated(Xe=Xe_new), target_security=lam)
try:
results = load("{}.sobj".format(name))
except:
results = dict()
results["{}".format(lam)] = []
results["{}".format(lam)].append((params_out.n, log2(params_out.q), log2(params_out.Xe.stddev), sec))
save(results, "{}.sobj".format(name))
return results
def generate_zama_curves64(sd_range=[2, 58], target_security_levels=[128], name="default_name"):
"""
The top level function which we use to run the experiment
:param sd_range: a tuple (sd_min, sd_max) giving the values of sd for which to generate parameters
:param target_security_levels: a list of the target number of bits of security, 128 is default
:param name: a name to save the file
"""
if __name__ == '__main__':
D = ND.DiscreteGaussian
vals = range(sd_range[0], sd_range[1])
pool = multiprocessing.Pool(2)
init_params = LWE.Parameters(n=1024, q=2 ** 64, Xs=D(0.50, -0.50), Xe=D(2 ** 55), m=oo, tag='params')
inputs = [(init_params, (val, val), target_security_levels, name) for val in vals]
res = pool.starmap(generate_parameter_matrix, inputs)
return "done"
# The script runs the following commands
import sys
# grab values of the command-line input arguments
a = int(sys.argv[1])
b = int(sys.argv[2])
c = int(sys.argv[3])
# run the code
generate_zama_curves64(sd_range= (b,c), target_security_levels=[a], name="{}".format(a))
import numpy as np
from sage.all import save, load
def sort_data(security_level):
from operator import itemgetter
# step 1. load the data
X = load("{}.sobj".format(security_level))
# step 2. sort by SD
x = sorted(X["{}".format(security_level)], key = itemgetter(2))
# step 3. replace the sorted value
X["{}".format(security_level)] = x
return X
def generate_curve(security_level):
# step 1. get the data
X = sort_data(security_level)
# step 2. group the n and sigma data into lists
N = []
SD = []
for x in X["{}".format(security_level)]:
N.append(x[0])
SD.append(x[2] + 0.5)
# step 3. perform interpolation and return coefficients
(a,b) = np.polyfit(N, SD, 1)
return a, b
def verify_curve(security_level, a = None, b = None):
# step 1. get the table and max values of n, sd
X = sort_data(security_level)
n_max = X["{}".format(security_level)][0][0]
sd_max = X["{}".format(security_level)][-1][2]
# step 2. a function to get model values
def f_model(a, b, n):
return ceil(a * n + b)
# step 3. a function to get table values
def f_table(table, n):
for i in range(len(table)):
n_val = table[i][0]
if n < n_val:
pass
else:
j = i
break
# now j is the correct index, we return the corresponding sd
return table[j][2]
# step 4. for each n, check whether we satisfy the table
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
print(n_min)
print(n_max)
for n in range(n_max, n_min, - 1):
model_sd = f_model(a, b, n)
table_sd = f_table(X["{}".format(security_level)], n)
print(n , table_sd, model_sd, model_sd >= table_sd)
if table_sd > model_sd:
print("MODEL FAILS at n = {}".format(n))
return "FAIL"
return "PASS", n_min
def generate_and_verify(security_levels, log_q, name = "verified_curves"):
data = []
for sec in security_levels:
print("WE GO FOR {}".format(sec))
# generate the model for security level sec
(a_sec, b_sec) = generate_curve(sec)
# verify the model for security level sec
res = verify_curve(sec, a_sec, b_sec)
# append the information into a list
data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
save(data, "{}.sobj".format(name))
return data
# To verify the curves we use
generate_and_verify([80, 96, 112, 128, 144, 160, 176, 192, 256], log_q = 64)

BIN
verified_curves.sobj Normal file

Binary file not shown.

9
verified_curves.txt Normal file
View File

@@ -0,0 +1,9 @@
[(-0.04042633119364589, 1.6609788641436722, 80, 'PASS', 450),
(-0.03414780360867051, 2.017310258660345, 96, 'PASS', 450),
(-0.029670137081135885, 2.162463714083856, 112, 'PASS', 450),
(-0.02640502876522622, 2.4826422691043177, 128, 'PASS', 450),
(-0.023821437305989134, 2.7177789440636673, 144, 'PASS', 450),
(-0.02174358218716036, 2.938810548493322, 160, 'PASS', 498),
(-0.019904056582117684, 2.8161252801542247, 176, 'PASS', 551),
(-0.018610403247590085, 3.2996236848399008, 192, 'PASS', 606),
(-0.014606812351714953, 3.8493629234693003, 256, 'PASS', 826)]

98
verify_curves.py Normal file
View File

@@ -0,0 +1,98 @@
import numpy as np
from sage.all import save, load, ceil
def sort_data(security_level):
from operator import itemgetter
# step 1. load the data
X = load("{}.sobj".format(security_level))
# step 2. sort by SD
x = sorted(X["{}".format(security_level)], key=itemgetter(2))
# step 3. replace the sorted value
X["{}".format(security_level)] = x
return X
def generate_curve(security_level):
# step 1. get the data
X = sort_data(security_level)
# step 2. group the n and sigma data into lists
N = []
SD = []
for x in X["{}".format(security_level)]:
N.append(x[0])
SD.append(x[2] + 0.5)
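# note: log2(sd) is offset by +0.5 before the linear fit; verify_curve() later evaluates the model with ceil()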
# step 3. perform interpolation and return coefficients
(a, b) = np.polyfit(N, SD, 1)
return a, b
def verify_curve(security_level, a=None, b=None):
# step 1. get the table and max values of n, sd
X = sort_data(security_level)
n_max = X["{}".format(security_level)][0][0]
sd_max = X["{}".format(security_level)][-1][2]
# step 2. a function to get model values
def f_model(a, b, n):
return ceil(a * n + b)
# step 3. a function to get table values
def f_table(table, n):
for i in range(len(table)):
n_val = table[i][0]
if n < n_val:
pass
else:
j = i
break
# now j is the correct index, we return the corresponding sd
return table[j][2]
# step 4. for each n, check whether we satisfy the table
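# n_min combines a rough mitm baseline (n >= 2 * security_level, cf. generate_data.py), the n >= 450 floor, and the smallest n present in the data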
n_min = max(2 * security_level, 450, X["{}".format(security_level)][-1][0])
print(n_min)
print(n_max)
for n in range(n_max, n_min, - 1):
model_sd = f_model(a, b, n)
table_sd = f_table(X["{}".format(security_level)], n)
print(n, table_sd, model_sd, model_sd >= table_sd)
if table_sd > model_sd:
print("MODEL FAILS at n = {}".format(n))
return "FAIL"
return "PASS", n_min
def generate_and_verify(security_levels, log_q, name="verified_curves"):
data = []
for sec in security_levels:
print("WE GO FOR {}".format(sec))
# generate the model for security level sec
(a_sec, b_sec) = generate_curve(sec)
# verify the model for security level sec
res = verify_curve(sec, a_sec, b_sec)
# append the information into a list
data.append((a_sec, b_sec - log_q, sec, res[0], res[1]))
save(data, "{}.sobj".format(name))
return data
data = generate_and_verify(
[80, 96, 112, 128, 144, 160, 176, 192, 256], log_q=64)
print(data)