Files
icicle/scripts/gen_c_api.py
ChickenLover 7fd9ed1b49 Feat/roman/tree builder (#525)
# Updates:

## Hashing

- Added the SpongeHasher class
- Can be used to accept any hash function as an argument
- Absorb and squeeze are now separate operations
- Memory management is now mostly handled by the SpongeHasher class; each hash
  function only describes its permutation kernels

## Tree builder

- Tree builder is now hash-agnostic.
- Tree builder now supports 2D input (matrices)
- Tree builder can now use two different hash functions: one for layer 0 and
  another for the compression layers

## Poseidon1

- Interface changed to classes
- Now allows any alpha
- Now allows passing constants that are not packed into a single vector
- Now allows any domain tag
- Constants are now released when they go out of scope
- Rust wrappers changed to the Poseidon struct
 
## Poseidon2

- Interface changed to classes
- Constants are now released when they go out of scope
- Rust wrappers changed to the Poseidon2 struct
 
## Keccak

- Added a Keccak class that inherits from SpongeHasher
- No longer uses GPU registers to store states
 
 To do:
- [x] Update poseidon1 golang bindings
- [x] Update poseidon1 examples
- [x] Fix poseidon2 cuda test
- [x] Fix poseidon2 merkle tree builder test
- [x] Update keccak class with new design
- [x] Update keccak test
- [x] Check keccak correctness
- [x] Update tree builder rust wrappers
- [x] Leave doc comments

Future work:  
- [ ] Add keccak merkle tree builder externs
- [ ] Add keccak rust tree builder wrappers
- [ ] Write docs
- [ ] Add example
- [ ] Fix device output for tree builder

---------

Co-authored-by: Jeremy Felder <jeremy.felder1@gmail.com>
Co-authored-by: nonam3e <71525212+nonam3e@users.noreply.github.com>
2024-07-11 13:46:25 +07:00

168 lines
5.0 KiB
Python
Executable File

#!/usr/bin/env python3
"""
This script generates the C API headers (extern declarations) for every configured curve and field.
"""
from itertools import chain
from pathlib import Path
from string import Template
# Output directory: one generated `<name>.h` API header per curve/field is written here.
API_PATH = Path(__file__).resolve().parent.parent / "icicle" / "include" / "api"
# Root of the header templates (read from its `curves` and `fields` subdirectories).
TEMPLATES_PATH = API_PATH / "templates"
# Curves to generate an API for. Each curve maps to the set of template header
# file names (from both the curve and the field template directories) that must
# NOT be included in that curve's generated API. All values are sets so that
# every entry has the same type and membership checks are O(1).
CURVES_CONFIG = {
    "bn254": {
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bls12_381": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bls12_377": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bw6_761": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "grumpkin": {
        "poseidon2.h",
        "curve_g2.h",
        "msm_g2.h",
        "ecntt.h",
        "ntt.h",
        "vec_ops_ext.h",
        "field_ext.h",
        "ntt_ext.h",
    },
}
# Fields to generate a field-only API for (only the field templates are used).
# Each field maps to the set of template header file names that must NOT be
# included in that field's generated API.
FIELDS_CONFIG = {
    "babybear": {"poseidon.h"},
    "stark252": {
        "field_ext.h",
        "ntt_ext.h",
        "poseidon.h",
        "poseidon2.h",
        "vec_ops_ext.h",
    },
    "m31": {
        "ntt.h",
        "ntt_ext.h",
        "poseidon.h",
        "poseidon2.h",
    },
}
# Include lines emitted at the top of every generated API header, in this order:
# the CUDA runtime (cudaError_t), the device context, and the Merkle-tree and
# matrix headers (presumably needed by tree-builder declarations in the
# templates -- confirm against the template files).
COMMON_INCLUDES = [
    "#include <cuda_runtime.h>",
    '#include "gpu-utils/device_context.cuh"',
    '#include "merkle-tree/merkle.cuh"',
    '#include "matrix/matrix.cuh"',
]
# Banner placed at the very top of every generated header, telling readers to
# edit this generator instead of the generated output. Every line ends in "\n".
WARN_TEXT = "".join(
    banner_line + "\n"
    for banner_line in (
        "// WARNING: This file is auto-generated by a script.",
        "// Any changes made to this file may be overwritten.",
        "// Please modify the code generation script instead.",
        "// Path to the code generation script: scripts/gen_c_api.py",
    )
)
# Include-guard preamble; `{0}` is filled with the upper-cased curve/field name
# via str.format. The matching `#endif` is appended after the templates.
INCLUDE_ONCE = (
    "#pragma once\n"
    "#ifndef {0}_API_H\n"
    "#define {0}_API_H\n"
)
# Template files for the curve-specific and field-specific APIs.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, so the
# lists are sorted to make the generated headers byte-for-byte reproducible
# across machines. Only regular files are kept so that a stray subdirectory
# cannot make the generator fail when it opens each template.
CURVE_HEADERS = sorted(p for p in TEMPLATES_PATH.joinpath("curves").iterdir() if p.is_file())
FIELD_HEADERS = sorted(p for p in TEMPLATES_PATH.joinpath("fields").iterdir() if p.is_file())
# Extra includes for a generated header: an include line is emitted when at
# least one selected template's file name starts with the paired prefix.
# Tuple order fixes the order in which the include lines are emitted.
_CURVE_INCLUDE_RULES = (
    ("ntt", '#include "ntt/ntt.cuh"'),
    ("msm", '#include "msm/msm.cuh"'),
    ("vec_ops", '#include "vec_ops/vec_ops.cuh"'),
    ("poseidon.h", '#include "poseidon/poseidon.cuh"'),
    ("poseidon2.h", '#include "poseidon2/poseidon2.cuh"'),
)
# Field-only APIs do not pull in the MSM header.
_FIELD_INCLUDE_RULES = tuple(rule for rule in _CURVE_INCLUDE_RULES if rule[0] != "msm")


def _render_api(name, headers, params_include, include_rules, substitutions):
    """Build the full text of one generated API header.

    Args:
        name: curve or field name; its upper-cased form names the include guard.
        headers: template paths to concatenate, in output order.
        params_include: ``#include`` line for the curve/field parameter header.
        include_rules: ``(name_prefix, include_line)`` pairs; an include line is
            emitted when any template's file name starts with its prefix.
        substitutions: mapping used to fill ``$NAME`` placeholders in every template.

    Returns:
        The header text: warning banner, include guard, includes, the
        substituted templates, and a closing ``#endif`` followed by a newline.
    """
    includes = [*COMMON_INCLUDES, params_include]
    includes.extend(
        include
        for prefix, include in include_rules
        if any(header.name.startswith(prefix) for header in headers)
    )
    parts = [WARN_TEXT, INCLUDE_ONCE.format(name.upper()), "\n".join(includes), "\n\n"]
    for header in headers:
        # safe_substitute leaves unknown $placeholders untouched instead of raising.
        template = Template(header.read_text(encoding="utf-8"))
        parts.append(template.safe_substitute(substitutions))
        parts.append("\n\n")
    # End with a newline: text files (and compilers warning on -Wnewline-eof) expect one.
    parts.append("#endif\n")
    return "".join(parts)


def _write_api(name, contents):
    """Write ``contents`` to ``API_PATH/<name>.h`` as UTF-8, replacing any existing file."""
    API_PATH.joinpath(f"{name}.h").write_text(contents, encoding="utf-8")


def main():
    """Generate one API header per entry of CURVES_CONFIG and FIELDS_CONFIG."""
    # Curves use both the curve and the field templates, minus their skip set.
    for curve, skip in CURVES_CONFIG.items():
        headers = [h for h in chain(CURVE_HEADERS, FIELD_HEADERS) if h.name not in skip]
        contents = _render_api(
            curve,
            headers,
            f'#include "curves/params/{curve}.cuh"',
            _CURVE_INCLUDE_RULES,
            {"CURVE": curve, "FIELD": curve},
        )
        _write_api(curve, contents)

    # Field-only APIs use just the field templates, minus their skip set.
    for field, skip in FIELDS_CONFIG.items():
        headers = [h for h in FIELD_HEADERS if h.name not in skip]
        contents = _render_api(
            field,
            headers,
            f'#include "fields/stark_fields/{field}.cuh"',
            _FIELD_INCLUDE_RULES,
            {"FIELD": field},
        )
        _write_api(field, contents)


if __name__ == "__main__":
    main()