mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-10 07:57:56 -05:00
# This PR: 1. adds a C++ API; 2. renames many API functions; 3. adds in-place Poseidon2; 4. makes inputs const in all Poseidon functions; 5. adds a Poseidon2 benchmark.
162 lines
4.9 KiB
Python
Executable File
162 lines
4.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
This script generates the C API headers (extern declarations) for every configured curve and field.
|
|
"""
|
|
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from string import Template
|
|
|
|
# Output directory for the generated headers: <two levels above this script>/icicle/include/api.
API_PATH = Path(__file__).resolve().parent.parent / "icicle" / "include" / "api"
# Directory holding the per-curve and per-field header templates.
TEMPLATES_PATH = API_PATH / "templates"
|
|
|
|
"""
|
|
Defines a set of curves to generate API for.
|
|
A set corresponding to each curve contains headers that shouldn't be included.
|
|
"""
|
|
CURVES_CONFIG = {
|
|
"bn254": [
|
|
"field_ext.h",
|
|
"vec_ops_ext.h",
|
|
"ntt_ext.h",
|
|
],
|
|
"bls12_381": [
|
|
"poseidon2.h",
|
|
"field_ext.h",
|
|
"vec_ops_ext.h",
|
|
"ntt_ext.h",
|
|
],
|
|
"bls12_377": [
|
|
"poseidon2.h",
|
|
"field_ext.h",
|
|
"vec_ops_ext.h",
|
|
"ntt_ext.h",
|
|
],
|
|
"bw6_761": [
|
|
"poseidon2.h",
|
|
"field_ext.h",
|
|
"vec_ops_ext.h",
|
|
"ntt_ext.h",
|
|
],
|
|
"grumpkin": {
|
|
"poseidon2.h",
|
|
"curve_g2.h",
|
|
"msm_g2.h",
|
|
"ecntt.h",
|
|
"ntt.h",
|
|
"vec_ops_ext.h",
|
|
"field_ext.h",
|
|
"ntt_ext.h",
|
|
},
|
|
}
|
|
|
|
"""
|
|
Defines a set of fields to generate API for.
|
|
A set corresponding to each field contains headers that shouldn't be included.
|
|
"""
|
|
FIELDS_CONFIG = {
|
|
"babybear": {
|
|
"poseidon.h",
|
|
},
|
|
"stark252": {
|
|
"poseidon.h",
|
|
"poseidon2.h",
|
|
"field_ext.h",
|
|
"vec_ops_ext.h",
|
|
"ntt_ext.h",
|
|
}
|
|
}
|
|
|
|
# For cudaError_t and device_context
|
|
COMMON_INCLUDES = [
|
|
'#include <cuda_runtime.h>',
|
|
'#include "gpu-utils/device_context.cuh"',
|
|
]
|
|
|
|
# Banner written at the very top of every generated header; it ends with an
# empty line separating it from what follows.
WARN_TEXT = "\n".join((
    "// WARNING: This file is auto-generated by a script.",
    "// Any changes made to this file may be overwritten.",
    "// Please modify the code generation script instead.",
    "// Path to the code generation script: scripts/gen_c_api.py",
    "",
    "",
))
|
|
|
|
# Include-guard opener; format it with the guard prefix, e.g.
# INCLUDE_ONCE.format("BN254") -> "#ifndef BN254_API_H ...". The matching
# "#endif" has to be emitted at the end of the header.
INCLUDE_ONCE = (
    "#pragma once\n"
    "#ifndef {0}_API_H\n"
    "#define {0}_API_H\n"
    "\n"
)
|
|
|
|
# Template files for curves and for fields.
# Sorted by path because Path.iterdir()/glob() yield entries in arbitrary,
# filesystem-dependent order; without sorting, the order of declarations in
# the generated headers (and thus their diffs) would not be reproducible.
# Only "*.h" files are taken (every template name referenced by the configs
# ends in ".h"), so stray files such as editor swap files or OS metadata in
# these directories are not spliced into the output as templates.
CURVE_HEADERS = sorted(TEMPLATES_PATH.joinpath("curves").glob("*.h"))
FIELD_HEADERS = sorted(TEMPLATES_PATH.joinpath("fields").glob("*.h"))
|
|
|
|
def collect_includes(header_names, params_include, common_includes):
    """Return the ``#include`` lines needed by one generated API header.

    Args:
        header_names: File names of the templates that go into the header
            (e.g. ``"ntt.h"``, ``"poseidon2.h"``). Any iterable; it is
            materialized once because it is scanned several times.
        params_include: The ``#include`` line for the curve/field definition.
        common_includes: Includes every generated header starts with. The
            list is copied, never modified.

    Returns:
        A new list: ``common_includes``, then ``params_include``, then one
        include (two for Poseidon) per feature family that at least one
        template belongs to, in a fixed order.
    """
    names = tuple(header_names)

    def uses(prefix):
        return any(name.startswith(prefix) for name in names)

    includes = list(common_includes)
    includes.append(params_include)
    if uses("ntt"):
        includes.append('#include "ntt/ntt.cuh"')
    if uses("msm"):
        includes.append('#include "msm/msm.cuh"')
    if uses("vec_ops"):
        includes.append('#include "vec_ops/vec_ops.cuh"')
    # Every "poseidon2*" name also starts with "poseidon". A bare prefix test
    # would therefore pull in the Poseidon (v1) and Merkle-tree headers for
    # configs that skip "poseidon.h" but keep "poseidon2.h" (e.g. babybear).
    if any(name.startswith("poseidon") and not name.startswith("poseidon2") for name in names):
        includes.append('#include "poseidon/poseidon.cuh"')
        includes.append('#include "poseidon/tree/merkle.cuh"')
    if uses("poseidon2"):
        includes.append('#include "poseidon2/poseidon2.cuh"')
    return includes


def render_api(guard_prefix, includes, headers, substitutions):
    """Assemble the complete text of one generated API header.

    Args:
        guard_prefix: Prefix of the include-guard macro (``<prefix>_API_H``).
        includes: ``#include`` lines emitted after the warning banner and guard.
        headers: Template file paths whose contents are appended in order,
            each followed by a blank line.
        substitutions: ``$NAME`` -> value mapping. Placeholders it does not
            cover are left as-is (``Template.safe_substitute``).

    Returns:
        The header text, ending with the closing ``#endif``.
    """
    parts = [WARN_TEXT, INCLUDE_ONCE.format(guard_prefix), "\n".join(includes), "\n\n"]
    for header in headers:
        # Read explicitly as UTF-8 so the output does not depend on the locale.
        template = Template(header.read_text(encoding="utf-8"))
        parts.append(template.safe_substitute(substitutions))
        parts.append("\n\n")
    parts.append("#endif")
    return "".join(parts)


def main():
    """Regenerate ``<name>.h`` under API_PATH for every configured curve and field."""
    # Curve headers are assembled from both the curve and the field templates.
    for curve, skip in CURVES_CONFIG.items():
        headers = [header for header in chain(CURVE_HEADERS, FIELD_HEADERS) if header.name not in skip]
        includes = collect_includes(
            [header.name for header in headers],
            f'#include "curves/params/{curve}.cuh"',
            COMMON_INCLUDES,
        )
        contents = render_api(curve.upper(), includes, headers, {"CURVE": curve, "FIELD": curve})
        API_PATH.joinpath(f"{curve}.h").write_text(contents, encoding="utf-8")

    # Standalone field headers use only the field templates.
    for field, skip in FIELDS_CONFIG.items():
        headers = [header for header in FIELD_HEADERS if header.name not in skip]
        includes = collect_includes(
            [header.name for header in headers],
            f'#include "fields/stark_fields/{field}.cuh"',
            COMMON_INCLUDES,
        )
        contents = render_api(field.upper(), includes, headers, {"FIELD": field})
        API_PATH.joinpath(f"{field}.h").write_text(contents, encoding="utf-8")


if __name__ == "__main__":
    main()