Files
icicle/scripts/gen_c_api.py
PatStiles bdc3da98d6 FEAT(stark252 field): Adds Stark252 curve (#494)
## Describe the changes

Adds support for the stark252 base field.
2024-05-01 14:08:05 +03:00

153 lines
4.6 KiB
Python
Executable File

#!/usr/bin/env python3
"""
This script generates the C API headers (extern declarations) for every configured curve and field.
"""
from itertools import chain
from pathlib import Path
from string import Template
# Generated API headers are written into API_PATH (two levels above this
# script, then icicle/include/api); the templates they are assembled from
# live in its "templates" subdirectory.
API_PATH = Path(__file__).resolve().parents[1] / "icicle" / "include" / "api"
TEMPLATES_PATH = API_PATH / "templates"
"""
Defines a set of curves to generate API for.
A set corresponding to each curve contains headers that shouldn't be included.
"""
CURVES_CONFIG = {
"bn254": [
"field_ext.h",
"vec_ops_ext.h",
"ntt_ext.h",
],
"bls12_381": [
"field_ext.h",
"vec_ops_ext.h",
"ntt_ext.h",
],
"bls12_377": [
"field_ext.h",
"vec_ops_ext.h",
"ntt_ext.h",
],
"bw6_761": [
"field_ext.h",
"vec_ops_ext.h",
"ntt_ext.h",
],
"grumpkin": {
"curve_g2.h",
"msm_g2.h",
"ecntt.h",
"ntt.h",
"vec_ops_ext.h",
"field_ext.h",
"ntt_ext.h",
},
}
"""
Defines a set of fields to generate API for.
A set corresponding to each field contains headers that shouldn't be included.
"""
FIELDS_CONFIG = {
"babybear": {
"poseidon.h",
},
"stark252": {
"poseidon.h",
"field_ext.h",
"vec_ops_ext.h",
"ntt_ext.h",
}
}
# Includes shared by every generated header: the CUDA runtime (for
# cudaError_t) and the device_context definitions.
COMMON_INCLUDES = [
    '#include <cuda_runtime.h>',
    '#include "gpu-utils/device_context.cuh"',
]

# Banner emitted at the very top of every generated header.
WARN_TEXT = (
    "// WARNING: This file is auto-generated by a script.\n"
    "// Any changes made to this file may be overwritten.\n"
    "// Please modify the code generation script instead.\n"
    "// Path to the code generation script: scripts/gen_c_api.py\n"
)

# Include-guard preamble; {0} is filled with the upper-cased curve/field
# name. The matching "#endif" is appended after the template bodies.
INCLUDE_ONCE = (
    "#pragma once\n"
    "#ifndef {0}_API_H\n"
    "#define {0}_API_H\n"
)
# Template headers that make up the generated APIs. Path.iterdir() yields
# entries in arbitrary, filesystem-dependent order, so the lists are sorted to
# keep the generated headers' layout stable across machines; non-files (e.g.
# stray subdirectories) are skipped because they cannot be read as templates.
CURVE_HEADERS = sorted(
    path for path in TEMPLATES_PATH.joinpath("curves").iterdir() if path.is_file()
)
FIELD_HEADERS = sorted(
    path for path in TEMPLATES_PATH.joinpath("fields").iterdir() if path.is_file()
)
if __name__ == "__main__":
# Generate API for ingo_curve
for curve, skip in CURVES_CONFIG.items():
curve_api = API_PATH.joinpath(f"{curve}.h")
headers = [header for header in chain(CURVE_HEADERS, FIELD_HEADERS) if header.name not in skip]
# Collect includes
includes = COMMON_INCLUDES.copy()
includes.append(f'#include "curves/params/{curve}.cuh"')
if any(header.name.startswith("ntt") for header in headers):
includes.append('#include "ntt/ntt.cuh"')
if any(header.name.startswith("msm") for header in headers):
includes.append('#include "msm/msm.cuh"')
if any(header.name.startswith("vec_ops") for header in headers):
includes.append('#include "vec_ops/vec_ops.cuh"')
if any(header.name.startswith("poseidon") for header in headers):
includes.append('#include "poseidon/poseidon.cuh"')
includes.append('#include "poseidon/tree/merkle.cuh"')
contents = WARN_TEXT + INCLUDE_ONCE.format(curve.upper()) + "\n".join(includes) + "\n\n"
for header in headers:
with open(header) as f:
template = Template(f.read())
contents += template.safe_substitute({
"CURVE": curve,
"FIELD": curve,
})
contents += "\n\n"
contents += "#endif"
with open(curve_api, "w") as f:
f.write(contents)
# Generate API for ingo_field
for field, skip in FIELDS_CONFIG.items():
field_api = API_PATH.joinpath(f"{field}.h")
headers = [header for header in FIELD_HEADERS if header.name not in skip]
# Collect includes
includes = COMMON_INCLUDES.copy()
includes.append(f'#include "fields/stark_fields/{field}.cuh"')
if any(header.name.startswith("ntt") for header in headers):
includes.append('#include "ntt/ntt.cuh"')
if any(header.name.startswith("vec_ops") for header in headers):
includes.append('#include "vec_ops/vec_ops.cuh"')
if any(header.name.startswith("poseidon") for header in headers):
includes.append('#include "poseidon/poseidon.cuh"')
includes.append('#include "poseidon/tree/merkle.cuh"')
contents = WARN_TEXT + INCLUDE_ONCE.format(field.upper()) + "\n".join(includes) + "\n\n"
for header in headers:
with open(header) as f:
template = Template(f.read())
contents += template.safe_substitute({
"FIELD": field,
})
contents += "\n\n"
contents += "#endif"
with open(field_api, "w") as f:
f.write(contents)