mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 15:37:58 -05:00
Resolves #191 and #113 --------- Co-authored-by: DmytroTym <dmytrotym1@gmail.com> Co-authored-by: ImmanuelSegol <3ditds@gmail.com>
347 lines
12 KiB
Python
347 lines
12 KiB
Python
import json
|
|
import math
|
|
import os
|
|
from string import Template
|
|
import sys
|
|
|
|
|
|
# Positional command-line arguments understood by this script, in order.
argv_list = ['thisfile', 'curve_json', 'command']

# Map each expected argument name to the value supplied on the command line.
# Arguments that were not supplied become "", and anything beyond
# len(argv_list) is ignored.
new_curve_args = {
    name: (sys.argv[idx] if idx < len(sys.argv) else "")
    for idx, name in enumerate(argv_list)
}
|
|
|
|
def to_hex(val: int, length: int) -> str:
    """Render ``val`` as comma-separated 32-bit hex limbs, least significant first.

    The result has ``length // 8`` limbs of 8 hex digits each, e.g.
    ``to_hex(0x123456789, 16) == '0x23456789, 0x00000001'``. Missing high
    limbs are zero-filled. If ``val`` needs more than ``length // 8`` limbs,
    only the lowest ``length // 8`` limbs are emitted and the high limbs are
    dropped silently.

    Args:
        val: Non-negative integer to encode.
        length: Total width in hex digits (callers in this file pass
            ``8 * limbs``).

    Returns:
        A string such as ``'0x00000001, 0x00000000'``; empty when
        ``length < 8``.

    Raises:
        ValueError: If ``val`` is negative. A negative value has no limb
            encoding here; ``hex()`` would yield a ``-0x`` prefix that turns
            into a malformed literal.
    """
    if val < 0:
        raise ValueError(f"to_hex expects a non-negative integer, got {val}")

    limb_digits = 8  # one 32-bit limb == 8 hex digits
    digits = format(val, "x")

    # Left-pad to a whole number of limbs, then up to the requested width
    # (rjust is a no-op when the value is already wider than `length`).
    remainder = len(digits) % limb_digits
    if remainder:
        digits = "0" * (limb_digits - remainder) + digits
    digits = digits.rjust(length, "0")

    # Split most-significant-first, then reverse so the lowest limb comes first.
    limbs = [digits[i:i + limb_digits] for i in range(0, len(digits), limb_digits)][::-1]
    return ", ".join(f"0x{limb}" for limb in limbs[:length // limb_digits])
|
|
|
|
|
|
def compute_values(modulus, modulus_bit_count, limbs):
    """Derive the hex-limb constants shared by both field parameter sets.

    Args:
        modulus: Field modulus as an integer.
        modulus_bit_count: Bit count used for the reduction constant ``m``.
        limbs: Number of 32-bit limbs in one field element.

    Returns:
        A 12-tuple of ``to_hex`` strings, in this order:
        modulus, 2*modulus, 4*modulus, modulus (double width),
        modulus^2, 2*modulus^2, 4*modulus^2 (all double width),
        m = floor(2^(2*modulus_bit_count) / modulus), one, zero,
        R = 2^(32*limbs) mod modulus, and R^-1 mod modulus.
        Single-width entries span ``limbs`` limbs; double-width entries
        span ``2*limbs`` limbs.
    """
    digits = 8 * limbs       # hex digits in a single-width element
    wide = 2 * digits        # hex digits in a double-width element
    bits = 4 * digits        # bit width of a single-width element (32 * limbs)
    square = modulus * modulus

    return (
        to_hex(modulus, digits),
        to_hex(modulus * 2, digits),
        to_hex(modulus * 4, digits),
        to_hex(modulus, wide),
        to_hex(square, wide),
        to_hex(square * 2, wide),
        to_hex(square * 4, wide),
        to_hex(int(2 ** (2 * modulus_bit_count) // modulus), digits),
        to_hex(1, digits),
        to_hex(0, digits),
        to_hex(pow(2, bits, modulus), digits),
        # Modular inverse of 2^bits; raises ValueError if it does not exist.
        to_hex(pow(2, -bits, modulus), digits),
    )
|
|
|
|
|
|
def get_fq_params(modulus, modulus_bit_count, limbs, nonresidue):
    """Build the ``fq_*`` template substitutions for the base field.

    Args:
        modulus: Base-field modulus q.
        modulus_bit_count: Bit count passed through to ``compute_values``.
        limbs: Number of 32-bit limbs per base-field element.
        nonresidue: Signed integer non-residue taken from the curve config.

    Returns:
        Dict mapping template keys to values: the twelve ``fq_*`` constants
        from ``compute_values``, ``nonresidue`` as its absolute value, and
        ``nonresidue_is_negative`` as the string ``'true'`` or ``'false'``.
    """
    keys = (
        'fq_modulus',
        'fq_modulus_2',
        'fq_modulus_4',
        'fq_modulus_wide',
        'fq_modulus_squared',
        'fq_modulus_squared_2',
        'fq_modulus_squared_4',
        'fq_m',
        'fq_one',
        'fq_zero',
        'fq_montgomery_r',
        'fq_montgomery_r_inv',
    )
    result = dict(zip(keys, compute_values(modulus, modulus_bit_count, limbs)))

    # The template takes the magnitude and the sign separately.
    is_negative = 'true' if nonresidue < 0 else 'false'
    result['nonresidue'] = abs(nonresidue)
    result['nonresidue_is_negative'] = is_negative
    return result
|
|
|
|
|
|
def get_fp_params(modulus, modulus_bit_count, limbs, root_of_unity, size=0):
    """Build the ``fp_*`` template values for the scalar field, plus NTT tables.

    Args:
        modulus: Scalar-field modulus p.
        modulus_bit_count: Bit count passed through to ``compute_values``.
        limbs: Number of 32-bit limbs per scalar-field element.
        root_of_unity: Starting element; its successive squares mod p make up
            the ``omega`` table.
        size: Number of rows in each table. With ``size <= 0`` all three
            tables are empty strings rather than an error.

    Returns:
        Dict with the twelve ``fp_*`` constants from ``compute_values`` and:
          * ``omega``: rows root_of_unity^(2^k) mod p, for k = size-1 down to 0;
          * ``omega_inv``: the modular inverses of those rows, in the same order;
          * ``inv``: rows (2^(k+1))^-1 mod p, for k = 0 .. size-1.
        Each row is `` {<limbs>}``. Rows are separated by a comma plus newline,
        with no trailing comma after the last row.

    Raises:
        ValueError: From ``pow(x, -1, modulus)`` if a value has no inverse mod p.
    """
    (
        modulus_,
        modulus_2,
        modulus_4,
        modulus_wide,
        modulus_squared,
        modulus_squared_2,
        modulus_squared_4,
        m,
        one,
        zero,
        montgomery_r,
        montgomery_r_inv,
    ) = compute_values(modulus, modulus_bit_count, limbs)
    limb_size = 8 * limbs

    def rows(values):
        # One " {limb, limb, ...}" initializer per value, joined by ",\n".
        return ",\n".join(" {" + to_hex(v, limb_size) + "}" for v in values)

    # Successive squares: root_of_unity^(2^k) mod p for k = 0 .. size-1.
    omegas = []
    om = root_of_unity
    for k in range(size):
        if k > 0:
            om = pow(om, 2, modulus)
        omegas.append(om)
    omegas_inv = [pow(o, -1, modulus) for o in omegas]

    return {
        'fp_modulus': modulus_,
        'fp_modulus_2': modulus_2,
        'fp_modulus_4': modulus_4,
        'fp_modulus_wide': modulus_wide,
        'fp_modulus_squared': modulus_squared,
        'fp_modulus_squared_2': modulus_squared_2,
        'fp_modulus_squared_4': modulus_squared_4,
        'fp_m': m,
        'fp_one': one,
        'fp_zero': zero,
        'fp_montgomery_r': montgomery_r,
        'fp_montgomery_r_inv': montgomery_r_inv,
        # Tables are emitted from the largest power of two down to k = 0.
        'omega': rows(reversed(omegas)),
        'omega_inv': rows(reversed(omegas_inv)),
        'inv': rows(pow(2 ** (k + 1), -1, modulus) for k in range(size)),
    }
|
|
|
|
|
|
def get_generators(g1_gen_x, g1_gen_y, g2_gen_x_re, g2_gen_x_im, g2_gen_y_re, g2_gen_y_im, size):
    """Encode the generator coordinates from the config as ``fq_gen_*`` values.

    Args:
        g1_gen_x, g1_gen_y: G1 generator coordinates.
        g2_gen_x_re, g2_gen_x_im, g2_gen_y_re, g2_gen_y_im: G2 generator
            coordinates, each given as ``_re``/``_im`` components.
        size: Width in hex digits passed to ``to_hex``.

    Returns:
        Dict mapping each ``fq_gen_*`` template key to its ``to_hex`` string.
    """
    coordinates = (
        ('fq_gen_x', g1_gen_x),
        ('fq_gen_y', g1_gen_y),
        ('fq_gen_x_re', g2_gen_x_re),
        ('fq_gen_x_im', g2_gen_x_im),
        ('fq_gen_y_re', g2_gen_y_re),
        ('fq_gen_y_im', g2_gen_y_im),
    )
    return {key: to_hex(value, size) for key, value in coordinates}
|
|
|
|
|
|
def get_weier_params(weierstrass_b, weierstrass_b_g2_re, weierstrass_b_g2_im, size):
    """Encode the ``weierstrass_b*`` config values as ``weier_*`` template values.

    Args:
        weierstrass_b: The G1 ``b`` value.
        weierstrass_b_g2_re, weierstrass_b_g2_im: ``_re``/``_im`` components
            of the G2 ``b`` value.
        size: Width in hex digits passed to ``to_hex``.

    Returns:
        Dict mapping each ``weier_*`` template key to its ``to_hex`` string.
    """
    coefficients = {
        'weier_b': weierstrass_b,
        'weier_b_g2_re': weierstrass_b_g2_re,
        'weier_b_g2_im': weierstrass_b_g2_im,
    }
    return {key: to_hex(value, size) for key, value in coefficients.items()}
|
|
|
|
|
|
def get_params(config):
    """Collect every substitution for ``params.cuh.tmpl`` from a parsed curve config.

    Side effect: assigns the module-level global ``ntt_size`` from
    ``config["ntt_size"]``.

    Args:
        config: Parsed curve JSON. Keys read: curve_name, modulus_p,
            bit_count_p, limb_p, ntt_size, modulus_q, bit_count_q, limb_q,
            root_of_unity, nonresidue, weierstrass_b, weierstrass_b_g2_re,
            weierstrass_b_g2_im, g1_gen_x, g1_gen_y, g2_gen_x_re,
            g2_gen_x_im, g2_gen_y_re, g2_gen_y_im.

    Returns:
        One dict merging the header values with the results of
        ``get_fp_params``, ``get_fq_params``, ``get_generators`` and
        ``get_weier_params`` (later dicts win on duplicate keys).

    Raises:
        KeyError: If a required key is missing.
        SystemExit: If ``root_of_unity`` equals ``modulus_p``.
    """
    global ntt_size

    curve_name = config["curve_name"]
    modulus_p, bit_count_p, limb_p = (config[k] for k in ("modulus_p", "bit_count_p", "limb_p"))
    ntt_size = config["ntt_size"]
    modulus_q, bit_count_q, limb_q = (config[k] for k in ("modulus_q", "bit_count_q", "limb_q"))
    root_of_unity = config["root_of_unity"]
    nonresidue = config["nonresidue"]

    # A value equal to p is 0 mod p, so it cannot be a root of unity.
    if root_of_unity == modulus_p:
        sys.exit("Invalid root_of_unity value; please update in curve parameters")

    b_g1, b_g2_re, b_g2_im = (
        config[k] for k in ("weierstrass_b", "weierstrass_b_g2_re", "weierstrass_b_g2_im")
    )
    generator_coords = [
        config[k]
        for k in ("g1_gen_x", "g1_gen_y", "g2_gen_x_re", "g2_gen_x_im", "g2_gen_y_re", "g2_gen_y_im")
    ]

    merged = {
        'curve_name_U': curve_name.upper(),
        'fp_num_limbs': limb_p,
        'fq_num_limbs': limb_q,
        'fp_modulus_bit_count': bit_count_p,
        'fq_modulus_bit_count': bit_count_q,
        'num_omegas': ntt_size,
    }
    merged.update(get_fp_params(modulus_p, bit_count_p, limb_p, root_of_unity, ntt_size))
    merged.update(get_fq_params(modulus_q, bit_count_q, limb_q, nonresidue))
    # Generators and Weierstrass values are base-field elements: 8 hex digits per q-limb.
    merged.update(get_generators(*generator_coords, 8 * limb_q))
    merged.update(get_weier_params(b_g1, b_g2_re, b_g2_im, 8 * limb_q))
    return merged
|
|
|
|
|
|
# Load the curve description named on the command line, then derive the
# name variants and limb counts that the generators below use.
with open(new_curve_args['curve_json']) as json_file:
    config = json.load(json_file)

curve_name_lower, curve_name_upper = config["curve_name"].lower(), config["curve_name"].upper()
limb_q, limb_p = config["limb_q"], config["limb_p"]
|
|
|
|
# Create Cuda interface

# Output directory for this curve's generated CUDA sources. exist_ok=True
# replaces the separate os.path.exists() check, which could race with
# another process creating the directory in between.
newpath = f'./icicle/curves/{curve_name_lower}'
os.makedirs(newpath, exist_ok=True)

# params.cuh is regenerated on every run; no command check guards it.
with open("./icicle/curves/curve_template/params.cuh.tmpl", "r") as params_file:
    params_file_template = Template(params_file.read())
params = get_params(config)
params_content = params_file_template.safe_substitute(params)
with open(f'./icicle/curves/{curve_name_lower}/params.cuh', 'w') as f:
    f.write(params_content)
|
|
|
|
# Curve-specific kernel sources. When the script is run with "-update" these
# files are left as they are.
if new_curve_args['command'] != '-update':
    for template_path, output_path in (
        ("./icicle/curves/curve_template/lde.cu.tmpl", f'./icicle/curves/{curve_name_lower}/lde.cu'),
        ("./icicle/curves/curve_template/msm.cu.tmpl", f'./icicle/curves/{curve_name_lower}/msm.cu'),
        ("./icicle/curves/curve_template/ve_mod_mult.cu.tmpl", f'./icicle/curves/{curve_name_lower}/ve_mod_mult.cu'),
    ):
        with open(template_path, "r") as template_file:
            template_content = Template(template_file.read())
        rendered = template_content.safe_substitute(
            CURVE_NAME_U=curve_name_upper,
            CURVE_NAME_L=curve_name_lower
        )
        with open(output_path, 'w') as f:
            f.write(rendered)
|
|
|
|
|
|
# Sources regenerated on every run. Each template gets exactly the
# placeholders listed next to it; safe_substitute leaves any other $-names
# untouched.
for template_path, output_path, substitutions in (
    (
        './icicle/curves/curve_template/curve_config.cuh.tmpl',
        f'./icicle/curves/{curve_name_lower}/curve_config.cuh',
        {'CURVE_NAME_U': curve_name_upper},
    ),
    (
        './icicle/curves/curve_template/projective.cu.tmpl',
        f'./icicle/curves/{curve_name_lower}/projective.cu',
        {'CURVE_NAME_U': curve_name_upper, 'CURVE_NAME_L': curve_name_lower},
    ),
    (
        './icicle/curves/curve_template/supported_operations.cu.tmpl',
        f'./icicle/curves/{curve_name_lower}/supported_operations.cu',
        {},
    ),
):
    with open(template_path, 'r') as template_file:
        template_content = Template(template_file.read())
    with open(output_path, 'w') as f:
        f.write(template_content.safe_substitute(substitutions))
|
|
|
|
|
|
# Register this curve's supported operations in icicle/curves/index.cu, once.
# Compare whole lines against the exact #include this script writes. A bare
# substring test on the curve name would wrongly treat a curve such as "bn"
# as already present whenever "bn254" is listed.
include_line = f'#include "{curve_name_lower}/supported_operations.cu"'
with open('./icicle/curves/index.cu', 'r+') as f:
    index_text = f.read()
    if include_line not in {line.strip() for line in index_text.splitlines()}:
        # read() left the file position at EOF, so this write appends.
        f.write(f'\n{include_line}')
|
|
|
|
|
|
|
|
# Create Rust interface and tests

# Curves whose two fields use the same limb count share one template;
# otherwise the q-field placeholders are filled in as well. Each "_limbs_x"
# placeholder (replaced by limbs * 8 * 4, i.e. 32 bits per limb) must be
# replaced before "limbs_x", because "limbs_x" is a substring of "_limbs_x".
if limb_p == limb_q:
    curve_template_path = "./src/curve_templates/curve_same_limbs.rs"
    limb_replacements = [
        ("_limbs_p", str(limb_p * 8 * 4)),
        ("limbs_p", str(limb_p)),
    ]
else:
    curve_template_path = "./src/curve_templates/curve_different_limbs.rs"
    limb_replacements = [
        ("_limbs_p", str(limb_p * 8 * 4)),
        ("limbs_p", str(limb_p)),
        ("_limbs_q", str(limb_q * 8 * 4)),
        ("limbs_q", str(limb_q)),
    ]

with open(curve_template_path, "r") as curve_file:
    content = curve_file.read()
content = content.replace("CURVE_NAME_U", curve_name_upper)
content = content.replace("CURVE_NAME_L", curve_name_lower)
for placeholder, value in limb_replacements:
    content = content.replace(placeholder, value)
# Context managers close the generated files even if a write raises.
with open("./src/curves/" + curve_name_lower + ".rs", "w") as text_file:
    text_file.write(content)

with open("./src/curve_templates/test.rs", "r") as test_file:
    content = test_file.read()
content = content.replace("CURVE_NAME_U", curve_name_upper)
content = content.replace("CURVE_NAME_L", curve_name_lower)
with open("./src/test_" + curve_name_lower + ".rs", "w") as text_file:
    text_file.write(content)
|
|
|
|
def _register_rust_module(path, module_name):
    """Append ``pub mod <module_name>;`` to the Rust file at ``path`` unless
    a ``mod <module_name>;`` declaration (with any visibility) already exists."""
    import re

    with open(path, 'r+') as f:
        text = f.read()
        # Match a real declaration such as "pub mod foo;" or "mod foo ;". A bare
        # substring test would treat "bn" as declared whenever "bn254" is.
        if re.search(rf'\bmod\s+{re.escape(module_name)}\s*;', text) is None:
            # read() left the file position at EOF, so this write appends.
            f.write('\npub mod ' + module_name + ';')


# The curve module lives at src/curves/<curve>.rs, so it is declared in
# src/curves/mod.rs.
_register_rust_module('./src/curves/mod.rs', curve_name_lower)
# The generated tests live at src/test_<curve>.rs, so the crate root needs
# `pub mod test_<curve>;`. A crate-root `pub mod <curve>;` would require
# src/<curve>.rs, which this script never creates.
_register_rust_module('./src/lib.rs', 'test_' + curve_name_lower)