mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
# Updates:

## Hashing
- Added SpongeHasher class
- Can be used to accept any hash function as an argument
- Absorb and squeeze are now separated
- Memory management is now mostly done by SpongeHasher class, each hash function only describes permutation kernels

## Tree builder
- Tree builder is now hash-agnostic.
- Tree builder now supports 2D input (matrices)
- Tree builder can now use two different hash functions for layer 0 and compression layers

## Poseidon1
- Interface changed to classes
- Now allows for any alpha
- Now allows passing constants not in a single vector
- Now allows for any domain tag
- Constants are now released upon going out of scope
- Rust wrappers changed to Poseidon struct

## Poseidon2
- Interface changed to classes
- Constants are now released upon going out of scope
- Rust wrappers changed to Poseidon2 struct

## Keccak
- Added Keccak class which inherits SpongeHasher
- Now doesn't use GPU registers for storing states

To do:
- [x] Update poseidon1 golang bindings
- [x] Update poseidon1 examples
- [x] Fix poseidon2 cuda test
- [x] Fix poseidon2 merkle tree builder test
- [x] Update keccak class with new design
- [x] Update keccak test
- [x] Check keccak correctness
- [x] Update tree builder rust wrappers
- [x] Leave doc comments

Future work:
- [ ] Add keccak merkle tree builder externs
- [ ] Add keccak rust tree builder wrappers
- [ ] Write docs
- [ ] Add example
- [ ] Fix device output for tree builder

---------

Co-authored-by: Jeremy Felder <jeremy.felder1@gmail.com>
Co-authored-by: nonam3e <71525212+nonam3e@users.noreply.github.com>
168 lines
5.0 KiB
Python
Executable File
168 lines
5.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
This script generates the C API headers (extern declarations) for every curve/field
|
|
"""
|
|
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from string import Template
|
|
|
|
# Directory that receives the generated API headers:
# <directory of this script>/../icicle/include/api (absolute, symlinks resolved).
API_PATH = Path(__file__).resolve().parent.parent / "icicle" / "include" / "api"
# Templates the headers are rendered from (curve and field sub-directories).
TEMPLATES_PATH = API_PATH / "templates"
|
|
|
|
"""
|
|
Defines the curves to generate an API header for.
|
|
A set corresponding to each curve contains headers that shouldn't be included.
|
|
"""
|
|
# Every value is a *set* of template header names to leave out of that curve's
# API. The generator only ever tests `header.name not in skip`, so a set is the
# natural type. Previously four curves used lists and grumpkin a set; they are
# now uniform and consistent with FIELDS_CONFIG.
CURVES_CONFIG = {
    "bn254": {
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bls12_381": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bls12_377": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "bw6_761": {
        "poseidon2.h",
        "field_ext.h",
        "vec_ops_ext.h",
        "ntt_ext.h",
    },
    "grumpkin": {
        "poseidon2.h",
        "curve_g2.h",
        "msm_g2.h",
        "ecntt.h",
        "ntt.h",
        "vec_ops_ext.h",
        "field_ext.h",
        "ntt_ext.h",
    },
}
|
|
|
|
"""
|
|
Defines the fields to generate an API header for.
|
|
A set corresponding to each field contains headers that shouldn't be included.
|
|
"""
|
|
# Standalone field name -> set of template header names excluded from its API.
# Key order (babybear, stark252, m31) is the order the headers are generated in.
FIELDS_CONFIG = dict(
    babybear={"poseidon.h"},
    stark252={"field_ext.h", "ntt_ext.h", "poseidon.h", "poseidon2.h", "vec_ops_ext.h"},
    m31={"ntt.h", "ntt_ext.h", "poseidon.h", "poseidon2.h"},
)
|
|
|
|
# #include lines emitted at the top of every generated API header: the CUDA
# runtime (for cudaError_t), the device context, and the merkle-tree and
# matrix headers. Kept as a list of complete directive strings.
COMMON_INCLUDES = [
    f"#include {target}"
    for target in (
        "<cuda_runtime.h>",
        '"gpu-utils/device_context.cuh"',
        '"merkle-tree/merkle.cuh"',
        '"matrix/matrix.cuh"',
    )
]
|
|
|
|
# Banner written at the very top of every generated header, followed by one
# blank line.
WARN_TEXT = "\n".join([
    "// WARNING: This file is auto-generated by a script.",
    "// Any changes made to this file may be overwritten.",
    "// Please modify the code generation script instead.",
    "// Path to the code generation script: scripts/gen_c_api.py",
]) + "\n\n"
|
|
|
|
# Include-guard prologue; `{0}` is filled via str.format with the upper-cased
# curve/field name. The matching "#endif" is appended by the generator.
INCLUDE_ONCE = "\n".join([
    "#pragma once",
    "#ifndef {0}_API_H",
    "#define {0}_API_H",
]) + "\n\n"
|
|
|
|
# Template files for curve-level and field-level APIs.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, which
# made the generated headers differ between machines. Sort them so the output
# is reproducible. Skip anything that is not a regular file (e.g. a
# subdirectory), since it cannot be opened and read as a template.
CURVE_HEADERS = sorted(path for path in TEMPLATES_PATH.joinpath("curves").iterdir() if path.is_file())
FIELD_HEADERS = sorted(path for path in TEMPLATES_PATH.joinpath("fields").iterdir() if path.is_file())
|
|
|
|
# (template-name prefix, implementation #include) pairs. A generated API header
# gets an implementation include when at least one of its selected templates
# has a file name starting with the prefix. Emitted includes follow this order.
_MODULE_INCLUDES = (
    ("ntt", '#include "ntt/ntt.cuh"'),
    ("msm", '#include "msm/msm.cuh"'),
    ("vec_ops", '#include "vec_ops/vec_ops.cuh"'),
    # Full file names, so that "poseidon" does not also match "poseidon2.h".
    ("poseidon.h", '#include "poseidon/poseidon.cuh"'),
    ("poseidon2.h", '#include "poseidon2/poseidon2.cuh"'),
)


def _module_includes(headers):
    """Return the implementation ``#include`` lines needed by ``headers``.

    Args:
        headers: Template paths selected for one API header; only their
            ``.name`` is inspected.

    Returns:
        A list of ``#include`` lines in ``_MODULE_INCLUDES`` order, each
        appearing at most once.
    """
    names = [header.name for header in headers]
    return [
        include
        for prefix, include in _MODULE_INCLUDES
        if any(name.startswith(prefix) for name in names)
    ]


def _generate_api(api_path: Path, guard: str, headers, params_include: str, substitutions: dict) -> None:
    """Render the templates in ``headers`` into the API header ``api_path``.

    The output consists of, in order:

    - the auto-generated warning;
    - an include guard named ``{guard}_API_H``;
    - the common includes, then ``params_include``, then the implementation
      includes required by ``headers``;
    - every template, in the given order, with its ``$`` placeholders
      substituted;
    - a closing ``#endif``.

    Args:
        api_path: Header file to create or overwrite.
        guard: Prefix of the include-guard macro (the upper-cased curve/field
            name).
        headers: Template files to concatenate, in output order.
        params_include: ``#include`` line for the curve/field parameter header.
        substitutions: Placeholder values. ``safe_substitute`` leaves unknown
            placeholders untouched instead of raising.
    """
    # Build a fresh list so the shared COMMON_INCLUDES is never mutated.
    includes = [*COMMON_INCLUDES, params_include, *_module_includes(headers)]
    parts = [WARN_TEXT, INCLUDE_ONCE.format(guard), "\n".join(includes), "\n\n"]
    for header in headers:
        # Explicit encoding: the platform default differs between systems and
        # would make reading the templates machine-dependent.
        with open(header, encoding="utf-8") as f:
            template = Template(f.read())
        parts.append(template.safe_substitute(substitutions))
        parts.append("\n\n")
    parts.append("#endif")

    with open(api_path, "w", encoding="utf-8") as f:
        f.write("".join(parts))


if __name__ == "__main__":
    # One API header per curve: curve templates followed by field templates,
    # minus the ones listed for that curve in CURVES_CONFIG. The curve name is
    # substituted for both the CURVE and FIELD placeholders.
    for curve, skip in CURVES_CONFIG.items():
        _generate_api(
            API_PATH.joinpath(f"{curve}.h"),
            curve.upper(),
            [header for header in chain(CURVE_HEADERS, FIELD_HEADERS) if header.name not in skip],
            f'#include "curves/params/{curve}.cuh"',
            {"CURVE": curve, "FIELD": curve},
        )

    # One API header per standalone field: field templates only, minus the
    # ones listed for that field in FIELDS_CONFIG.
    for field, skip in FIELDS_CONFIG.items():
        _generate_api(
            API_PATH.joinpath(f"{field}.h"),
            field.upper(),
            [header for header in FIELD_HEADERS if header.name not in skip],
            f'#include "fields/stark_fields/{field}.cuh"',
            {"FIELD": field},
        )