mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-10 07:57:56 -05:00
draft Poseidon example for review
This commit is contained in:
@@ -6,41 +6,35 @@ We recommend to run our examples in [ZK-containers](../../ZK-containers.md) to s
|
||||
|
||||
## Key-Takeaway
|
||||
|
||||
`Icicle` provides CUDA C++ template classes to accelerate Zero Knowledge (ZK) applications, for example, a popular [Poseidon hash function](https://www.poseidon-hash.info/).
|
||||
Use class `Poseidon` to instantiate and use the hash function
|
||||
`Icicle` provides CUDA C++ template `poseidon_hash` to accelerate the popular [Poseidon hash function](https://www.poseidon-hash.info/).
|
||||
|
||||
|
||||
### Instantiate hash function
|
||||
## Concise Usage Explanation
|
||||
|
||||
```c++
|
||||
Poseidon<BLS12_381::scalar_t> poseidon(arity, stream);
|
||||
#include "appUtils/poseidon/poseidon.cu"
|
||||
...
|
||||
poseidon_hash<scalar_t, arity+1>(input, output, n, constants, config);
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- **data class:** Here the hash operates on `BLS12_381::scalar_t`, a scalar field of the curve `BLS12-381`.
|
||||
You can think of field's elements as 32-bytes integers modulo `p`, where `p` is a prime number, specific to this field.
|
||||
- **`scalar_t`:** a scalar field of the selected curve. Currently only `BLS12-381`.
|
||||
You can think of field's elements as 32-byte integers modulo `p`, where `p` is a prime number, specific to this field.
|
||||
|
||||
- **arity:** The number of elements in a hashed block.
|
||||
- **arity:** number of elements in a hashed block.
|
||||
|
||||
- **stream:** CUDA streams allow multiple hashes and higher throughput.
|
||||
- **n:** number of blocks we hash in parallel.
|
||||
|
||||
### Hash multiple blocks in parallel
|
||||
- **input, output:** `scalar_t` arrays of size $arity*n$ and $n$ respectively.
|
||||
|
||||
- **constants:** are defined as below
|
||||
|
||||
```c++
|
||||
poseidon.hash_blocks(inBlocks, nBlocks, outHashes, hashType, stream);
|
||||
device_context::DeviceContext ctx= device_context::get_default_device_context();
|
||||
PoseidonConstants<scalar_t> constants;
|
||||
init_optimized_poseidon_constants<scalar_t>(ctx, &constants);
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- **nBlocks:** number of blocks we hash in parallel.
|
||||
|
||||
- **inBlocks:** input array of size `arity*nBlocks`. The blocks are arranged sequentially in the array.
|
||||
|
||||
- **outHashes:** output array of size `nBlocks`.
|
||||
|
||||
- **HashType:** In this example we use `Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree`.
|
||||
|
||||
## What's in the example
|
||||
|
||||
1. Define the size of the example: the height of the full binary Merkle tree.
|
||||
@@ -59,8 +53,7 @@ Our Merkle tree is a **full binary tree** stored in a 1D array.
|
||||
The tree nodes are stored following a level-first traversal of the binary tree.
|
||||
For a given level, we use offset to number elements from left to right. The node numbers on the figure below correspond to their locations in the array.
|
||||
|
||||
|
||||
```
|
||||
```text
|
||||
Tree Level
|
||||
0 0
|
||||
/ \
|
||||
@@ -70,16 +63,10 @@ For a given level, we use offset to number elements from left to right. The node
|
||||
|
||||
1D array representation: {0, 1, 2, 3, 4, 5, 6}
|
||||
```
|
||||
|
||||
### Membership proof structure
|
||||
|
||||
We use two arrays:
|
||||
|
||||
- position (left/right) of the node along the path toward the root
|
||||
- hash of a second node with the same parent
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
- hash of a second node with the same parent
|
||||
|
||||
@@ -4,98 +4,92 @@
|
||||
|
||||
// select the curve (only 2 available so far)
|
||||
#define CURVE_ID 2
|
||||
#include "curves/curve_config.cuh"
|
||||
#include "utils/device_context.cuh"
|
||||
// include Poseidon template
|
||||
#include "appUtils/poseidon/poseidon.cu"
|
||||
using namespace poseidon;
|
||||
using namespace curve_config;
|
||||
|
||||
#define A 2
|
||||
#define T A + 1
|
||||
|
||||
device_context::DeviceContext ctx= device_context::get_default_device_context();
|
||||
|
||||
// location of a tree node in the array for a given level and offset
|
||||
inline uint32_t tree_index(uint32_t level, uint32_t offset) { return (1 << level) - 1 + offset; }
|
||||
|
||||
// We assume the tree has leaves already set, compute all other levels
|
||||
// void build_tree(
|
||||
// const uint32_t tree_height, scalar_t* tree, Poseidon<BLS12_381::scalar_t>& poseidon, cudaStream_t stream)
|
||||
// {
|
||||
// for (uint32_t level = tree_height - 1; level > 0; level--) {
|
||||
// const uint32_t next_level = level - 1;
|
||||
// const uint32_t next_level_width = 1 << next_level;
|
||||
// poseidon.hash_blocks(
|
||||
// &tree[tree_index(level, 0)], next_level_width, &tree[tree_index(next_level, 0)],
|
||||
// Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree, stream);
|
||||
// }
|
||||
// }
|
||||
void build_tree(
|
||||
const uint32_t tree_height, scalar_t* tree, PoseidonConstants<scalar_t> * constants, PoseidonConfig config)
|
||||
{
|
||||
for (uint32_t level = tree_height - 1; level > 0; level--) {
|
||||
const uint32_t next_level = level - 1;
|
||||
const uint32_t next_level_width = 1 << next_level;
|
||||
poseidon_hash<scalar_t, 2+1>(
|
||||
&tree[tree_index(level, 0)], &tree[tree_index(next_level, 0)], next_level_width, *constants, config);
|
||||
}
|
||||
}
|
||||
|
||||
// search leaves for a given hash, return offset
|
||||
// uint32_t query_membership(BLS12_381::scalar_t query, BLS12_381::scalar_t* tree, const uint32_t tree_height)
|
||||
// {
|
||||
// const uint32_t tree_width = (1 << (tree_height - 1));
|
||||
// for (uint32_t i = 0; i < tree_width; i++) {
|
||||
// const BLS12_381::scalar_t leaf = tree[tree_index(tree_height - 1, i)];
|
||||
// if (leaf == query) {
|
||||
// return i; // found the hash
|
||||
// }
|
||||
// }
|
||||
// return tree_height; // hash not found
|
||||
// }
|
||||
// linear search leaves for a given hash, return offset
|
||||
uint32_t query_membership(scalar_t query, scalar_t* tree, const uint32_t tree_height)
|
||||
{
|
||||
const uint32_t tree_width = (1 << (tree_height - 1));
|
||||
for (uint32_t i = 0; i < tree_width; i++) {
|
||||
const scalar_t leaf = tree[tree_index(tree_height - 1, i)];
|
||||
if (leaf == query) {
|
||||
return i; // found the hash
|
||||
}
|
||||
}
|
||||
return tree_height; // hash not found
|
||||
}
|
||||
|
||||
// void generate_proof(
|
||||
// uint32_t position,
|
||||
// BLS12_381::scalar_t* tree,
|
||||
// const uint32_t tree_height,
|
||||
// uint32_t* proof_lr,
|
||||
// BLS12_381::scalar_t* proof_hash)
|
||||
// {
|
||||
// uint32_t level_index = position;
|
||||
// for (uint32_t level = tree_height - 1; level > 0; level--) {
|
||||
// uint32_t lr;
|
||||
// uint32_t neighbour_index;
|
||||
// lr = level_index % 2;
|
||||
// if (lr == 0) {
|
||||
// // left
|
||||
// neighbour_index = level_index + 1;
|
||||
// } else {
|
||||
// // right
|
||||
// neighbour_index = level_index - 1;
|
||||
// }
|
||||
// proof_lr[level] = lr;
|
||||
// proof_hash[level] = tree[tree_index(level, neighbour_index)];
|
||||
// level_index /= 2;
|
||||
// }
|
||||
// // the proof must match this:
|
||||
// proof_hash[0] = tree[tree_index(0, 0)];
|
||||
// }
|
||||
void generate_proof(
|
||||
uint32_t position,
|
||||
scalar_t* tree,
|
||||
const uint32_t tree_height,
|
||||
uint32_t* proof_lr,
|
||||
scalar_t* proof_hash)
|
||||
{
|
||||
uint32_t level_index = position;
|
||||
for (uint32_t level = tree_height - 1; level > 0; level--) {
|
||||
uint32_t lr;
|
||||
uint32_t neighbour_index;
|
||||
lr = level_index % 2;
|
||||
if (lr == 0) {
|
||||
// left
|
||||
neighbour_index = level_index + 1;
|
||||
} else {
|
||||
// right
|
||||
neighbour_index = level_index - 1;
|
||||
}
|
||||
proof_lr[level] = lr;
|
||||
proof_hash[level] = tree[tree_index(level, neighbour_index)];
|
||||
level_index /= 2;
|
||||
}
|
||||
// the proof must match this:
|
||||
proof_hash[0] = tree[tree_index(0, 0)];
|
||||
}
|
||||
|
||||
// uint32_t validate_proof(
|
||||
// const BLS12_381::scalar_t hash,
|
||||
// const uint32_t tree_height,
|
||||
// const uint32_t* proof_lr,
|
||||
// const BLS12_381::scalar_t* proof_hash,
|
||||
// Poseidon<BLS12_381::scalar_t>& poseidon,
|
||||
// cudaStream_t stream)
|
||||
// {
|
||||
// BLS12_381::scalar_t hashes_in[2], hash_out[1], level_hash;
|
||||
// level_hash = hash;
|
||||
// for (uint32_t level = tree_height - 1; level > 0; level--) {
|
||||
// if (proof_lr[level] == 0) {
|
||||
// hashes_in[0] = level_hash;
|
||||
// hashes_in[1] = proof_hash[level];
|
||||
// } else {
|
||||
// hashes_in[0] = proof_hash[level];
|
||||
// hashes_in[1] = level_hash;
|
||||
// }
|
||||
// // next level hash
|
||||
// poseidon.hash_blocks(hashes_in, 1, hash_out, Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree, stream);
|
||||
// level_hash = hash_out[0];
|
||||
// }
|
||||
// return proof_hash[0] == level_hash;
|
||||
// }
|
||||
uint32_t validate_proof(
|
||||
const scalar_t hash,
|
||||
const uint32_t tree_height,
|
||||
const uint32_t* proof_lr,
|
||||
const scalar_t* proof_hash,
|
||||
PoseidonConstants<scalar_t> * constants,
|
||||
PoseidonConfig config)
|
||||
{
|
||||
scalar_t hashes_in[2], hash_out[1], level_hash;
|
||||
level_hash = hash;
|
||||
for (uint32_t level = tree_height - 1; level > 0; level--) {
|
||||
if (proof_lr[level] == 0) {
|
||||
hashes_in[0] = level_hash;
|
||||
hashes_in[1] = proof_hash[level];
|
||||
} else {
|
||||
hashes_in[0] = proof_hash[level];
|
||||
hashes_in[1] = level_hash;
|
||||
}
|
||||
// next level hash
|
||||
poseidon_hash<scalar_t, 2+1>(hashes_in, hash_out, 1, *constants, config);
|
||||
level_hash = hash_out[0];
|
||||
}
|
||||
return proof_hash[0] == level_hash;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -111,51 +105,49 @@ int main(int argc, char* argv[])
|
||||
scalar_t* tree = static_cast<scalar_t*>(malloc(tree_size * sizeof(scalar_t)));
|
||||
|
||||
std::cout << "2. Hashing blocks in parallel" << std::endl;
|
||||
|
||||
// ctx = device_context::get_default_device_context();
|
||||
init_optimized_poseidon_constants<scalar_t, T>(ctx);
|
||||
|
||||
const uint32_t data_arity = 2;
|
||||
const uint32_t data_arity = 4;
|
||||
std::cout << "Block size (arity): " << data_arity << std::endl;
|
||||
std::cout << "Initializing blocks..." << std::endl;
|
||||
scalar_t d = scalar_t::zero();
|
||||
scalar_t* data = static_cast<scalar_t*>(malloc(tree_width * data_arity * sizeof(scalar_t)));
|
||||
|
||||
for (uint32_t i = 0; i < tree_width * data_arity; i++) {
|
||||
data[i] = d;
|
||||
d = d + scalar_t::one();
|
||||
}
|
||||
|
||||
std::cout << "Hashing blocks into tree leaves..." << std::endl;
|
||||
PoseidonConfig config = default_poseidon_config(T);
|
||||
poseidon_hash<curve_config::scalar_t, T>(data, &tree[tree_index(leaf_level, 0)], tree_width, preloaded_constants<curve_config::scalar_t, T>, config);
|
||||
// data_poseidon.hash_blocks(data, tree_width, &tree[tree_index(leaf_level, 0)], Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree, stream);
|
||||
PoseidonConstants<scalar_t> constants;
|
||||
init_optimized_poseidon_constants<scalar_t>(data_arity, ctx, &constants);
|
||||
PoseidonConfig config = default_poseidon_config(data_arity+1);
|
||||
poseidon_hash<curve_config::scalar_t, data_arity+1>(data, &tree[tree_index(leaf_level, 0)], tree_width, constants, config);
|
||||
|
||||
// std::cout << "3. Building Merkle tree" << std::endl;
|
||||
std::cout << "3. Building Merkle tree" << std::endl;
|
||||
// Poseidon<BLS12_381::scalar_t> tree_poseidon(tree_arity, stream);
|
||||
// build_tree(tree_height, tree, tree_poseidon, stream);
|
||||
PoseidonConstants<scalar_t> tree_constants;
|
||||
init_optimized_poseidon_constants<scalar_t>(tree_arity, ctx, &tree_constants);
|
||||
PoseidonConfig tree_config = default_poseidon_config(tree_arity+1);
|
||||
build_tree(tree_height, tree, &tree_constants, tree_config);
|
||||
|
||||
// std::cout << "4. Generate membership proof" << std::endl;
|
||||
// uint32_t position = tree_width - 1;
|
||||
// std::cout << "Using the hash for block: " << position << std::endl;
|
||||
// BLS12_381::scalar_t query = tree[tree_index(leaf_level, position)];
|
||||
// uint32_t query_position = query_membership(query, tree, tree_height);
|
||||
// // allocate arrays for the proof
|
||||
// uint32_t* proof_lr = static_cast<uint32_t*>(malloc(tree_height * sizeof(uint32_t)));
|
||||
// BLS12_381::scalar_t* proof_hash =
|
||||
// static_cast<BLS12_381::scalar_t*>(malloc(tree_height * sizeof(BLS12_381::scalar_t)));
|
||||
// generate_proof(query_position, tree, tree_height, proof_lr, proof_hash);
|
||||
std::cout << "4. Generate membership proof" << std::endl;
|
||||
uint32_t position = tree_width - 1;
|
||||
std::cout << "Using the hash for block: " << position << std::endl;
|
||||
scalar_t query = tree[tree_index(leaf_level, position)];
|
||||
uint32_t query_position = query_membership(query, tree, tree_height);
|
||||
// allocate arrays for the proof
|
||||
uint32_t* proof_lr = static_cast<uint32_t*>(malloc(tree_height * sizeof(uint32_t)));
|
||||
scalar_t* proof_hash = static_cast<scalar_t*>(malloc(tree_height * sizeof(scalar_t)));
|
||||
generate_proof(query_position, tree, tree_height, proof_lr, proof_hash);
|
||||
|
||||
// std::cout << "5. Validate the hash membership" << std::endl;
|
||||
// uint32_t validated;
|
||||
// const BLS12_381::scalar_t hash = tree[tree_index(leaf_level, query_position)];
|
||||
// validated = validate_proof(hash, tree_height, proof_lr, proof_hash, tree_poseidon, stream);
|
||||
// std::cout << "Validated: " << validated << std::endl;
|
||||
std::cout << "5. Validate the hash membership" << std::endl;
|
||||
uint32_t validated;
|
||||
const scalar_t hash = tree[tree_index(leaf_level, query_position)];
|
||||
validated = validate_proof(hash, tree_height, proof_lr, proof_hash, &tree_constants, tree_config);
|
||||
std::cout << "Validated: " << validated << std::endl;
|
||||
|
||||
// std::cout << "6. Tamper the hash" << std::endl;
|
||||
// const BLS12_381::scalar_t tampered_hash = hash + BLS12_381::scalar_t::one();
|
||||
// validated = validate_proof(tampered_hash, tree_height, proof_lr, proof_hash, tree_poseidon, stream);
|
||||
// std::cout << "7. Invalidate tamper hash membership" << std::endl;
|
||||
// std::cout << "Validated: " << validated << std::endl;
|
||||
std::cout << "6. Tamper the hash" << std::endl;
|
||||
const scalar_t tampered_hash = hash + scalar_t::one();
|
||||
validated = validate_proof(tampered_hash, tree_height, proof_lr, proof_hash, &tree_constants, tree_config);
|
||||
|
||||
std::cout << "7. Invalidate tamper hash membership" << std::endl;
|
||||
std::cout << "Validated: " << validated << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user