mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 04:58:08 -05:00
This PR adds support for the m31 field --------- Co-authored-by: Jeremy Felder <jeremy.felder1@gmail.com>
168 lines
5.0 KiB
Python
Executable File
168 lines
5.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
This script generates extern declarations for every configured curve and field.
|
|
"""
|
|
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from string import Template
|
|
|
|
# Generated API headers are written to icicle/include/api, resolved relative to
# the parent of the directory that contains this script. The header templates
# they are rendered from live in its "templates" subdirectory.
API_PATH = Path(__file__).resolve().parent.parent / "icicle" / "include" / "api"
TEMPLATES_PATH = API_PATH / "templates"
|
|
|
|
"""
|
|
Defines a set of curves to generate API for.
|
|
Each curve maps to the template header file names that are left out of its generated API header.
|
|
"""
|
|
# Every entry is a set of template file names (matched against Path.name) to
# skip for that curve. Previously four entries were lists and only grumpkin was
# a set; all are sets now, as documented above and as in FIELDS_CONFIG.
# The generator only ever tests membership ("name not in skip").
CURVES_CONFIG = {
    "bn254": {
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bls12_381": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bls12_377": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bw6_761": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "grumpkin": {
        "poseidon2.h",
        "curve_g2.h",
        "msm_g2.h",
        "ecntt.h",
        "ntt.h",
        "vec_ops_ext.h",
        "field_ext.h",
        "ntt_ext.h",
    },
}
|
|
|
|
"""
|
|
Defines a set of fields to generate API for.
|
|
Each field maps to the template header file names that are left out of its generated API header.
|
|
"""
|
|
# Key order is the order in which the field API headers are generated.
FIELDS_CONFIG = {
    "babybear": {"poseidon.h"},
    "stark252": {
        "field_ext.h",
        "ntt_ext.h",
        "poseidon.h",
        "poseidon2.h",
        "vec_ops_ext.h",
    },
    "m31": {
        "ntt.h",
        "ntt_ext.h",
        "poseidon.h",
        "poseidon2.h",
    },
}
|
|
|
|
# Includes emitted at the top of every generated API header: the CUDA runtime
# (for cudaError_t) and the device_context header.
COMMON_INCLUDES = [
    "#include <cuda_runtime.h>",
    '#include "gpu-utils/device_context.cuh"',
]
|
|
|
|
# Banner placed at the very top of every generated API header.
# The backslash after the opening quotes suppresses the leading newline.
# The trailing blank line separates the banner from the include guard.
# (The scraped copy had stray "|" lines inside this literal, which would have
# been written verbatim into every generated header.)
WARN_TEXT = """\
// WARNING: This file is auto-generated by a script.
// Any changes made to this file may be overwritten.
// Please modify the code generation script instead.
// Path to the code generation script: scripts/gen_c_api.py

"""
|
|
|
|
# Include-guard preamble for a generated API header. Format it with the
# upper-cased curve/field name, e.g. INCLUDE_ONCE.format("BN254"). The
# matching "#endif" is appended after all rendered templates.
# (The scraped copy had stray "|" lines inside this literal.)
INCLUDE_ONCE = """\
#pragma once
#ifndef {0}_API_H
#define {0}_API_H

"""
|
|
|
|
# Template headers that get rendered into the generated APIs.
# Path.iterdir() yields children in arbitrary, filesystem-dependent order.
# That order becomes the order of declarations in every generated header, so
# sort the entries to make the output reproducible across machines.
CURVE_HEADERS = sorted(TEMPLATES_PATH.joinpath("curves").iterdir())
FIELD_HEADERS = sorted(TEMPLATES_PATH.joinpath("fields").iterdir())
|
|
|
|
def _collect_includes(params_include, headers, *, with_msm):
    """Return the ``#include`` lines for one generated API header.

    Args:
        params_include: the ``#include`` line for the curve/field parameter header.
        headers: template paths that will be rendered into the API header.
        with_msm: whether to consider MSM templates (only curve APIs do).

    Returns:
        A new list: COMMON_INCLUDES, then ``params_include``, then one include
        per feature whose templates appear in ``headers``.
    """
    names = [header.name for header in headers]

    def uses(prefix):
        return any(name.startswith(prefix) for name in names)

    includes = COMMON_INCLUDES.copy()  # copy so the shared module-level list is never mutated
    includes.append(params_include)
    if uses("ntt"):
        includes.append('#include "ntt/ntt.cuh"')
    if with_msm and uses("msm"):
        includes.append('#include "msm/msm.cuh"')
    if uses("vec_ops"):
        includes.append('#include "vec_ops/vec_ops.cuh"')
    # "poseidon" is also a prefix of "poseidon2". A bare startswith("poseidon")
    # therefore pulled the Poseidon (v1) and Merkle-tree headers into targets
    # that skip poseidon.h but keep poseidon2.h, as babybear's config does.
    # Exclude poseidon2 templates from this check.
    if any(name.startswith("poseidon") and not name.startswith("poseidon2") for name in names):
        includes.append('#include "poseidon/poseidon.cuh"')
        includes.append('#include "poseidon/tree/merkle.cuh"')
    if uses("poseidon2"):
        includes.append('#include "poseidon2/poseidon2.cuh"')
    return includes


def _render_api(name, includes, headers, substitutions):
    """Return the full text of the API header for ``name``.

    The text is assembled in this order:
    - WARN_TEXT;
    - the include guard for ``name.upper()``;
    - ``includes``, one per line;
    - each template in ``headers`` rendered with ``substitutions`` and
      followed by a blank line;
    - a closing ``#endif``.

    ``safe_substitute`` leaves unknown ``$placeholders`` untouched instead of
    raising.
    """
    parts = [WARN_TEXT, INCLUDE_ONCE.format(name.upper()), "\n".join(includes), "\n\n"]
    for header in headers:
        template = Template(header.read_text(encoding="utf-8"))
        parts.append(template.safe_substitute(substitutions))
        parts.append("\n\n")
    parts.append("#endif")
    return "".join(parts)


def main():
    """Write ``<API_PATH>/<name>.h`` for every entry of CURVES_CONFIG and FIELDS_CONFIG."""
    # Curve APIs render both curve and field templates; ${FIELD} is also
    # substituted with the curve name.
    for curve, skip in CURVES_CONFIG.items():
        headers = [
            header for header in chain(CURVE_HEADERS, FIELD_HEADERS) if header.name not in skip
        ]
        includes = _collect_includes(
            f'#include "curves/params/{curve}.cuh"', headers, with_msm=True
        )
        contents = _render_api(curve, includes, headers, {"CURVE": curve, "FIELD": curve})
        API_PATH.joinpath(f"{curve}.h").write_text(contents, encoding="utf-8")

    # Field APIs render only the field templates.
    for field, skip in FIELDS_CONFIG.items():
        headers = [header for header in FIELD_HEADERS if header.name not in skip]
        includes = _collect_includes(
            f'#include "fields/stark_fields/{field}.cuh"', headers, with_msm=False
        )
        contents = _render_api(field, includes, headers, {"FIELD": field})
        API_PATH.joinpath(f"{field}.h").write_text(contents, encoding="utf-8")


if __name__ == "__main__":
    main()