draft Poseidon example for review

This commit is contained in:
stas
2024-01-31 12:50:42 -05:00
parent bcfb240ad5
commit 4a4bb5a2aa
2 changed files with 121 additions and 142 deletions

View File

@@ -6,41 +6,35 @@ We recommend to run our examples in [ZK-containers](../../ZK-containers.md) to s
## Key-Takeaway
`Icicle` provides CUDA C++ template classes to accelerate Zero Knowledge (ZK) applications, for example, a popular [Poseidon hash function](https://www.poseidon-hash.info/).
Use class `Poseidon` to instantiate and use the hash function
`Icicle` provides CUDA C++ template `poseidon_hash` to accelerate the popular [Poseidon hash function](https://www.poseidon-hash.info/).
### Instantiate hash function
## Concise Usage Explanation
```c++
Poseidon<BLS12_381::scalar_t> poseidon(arity, stream);
#include "appUtils/poseidon/poseidon.cu"
...
poseidon_hash<scalar_t, arity+1>(input, output, n, constants, config);
```
**Parameters:**
- **data class:** Here the hash operates on `BLS12_381::scalar_t`, a scalar field of the curve `BLS12-381`.
You can think of field's elements as 32-bytes integers modulo `p`, where `p` is a prime number, specific to this field.
- **`scalar_t`:** a scalar field of the selected curve. Currently only `BLS12-381`.
You can think of field's elements as 32-byte integers modulo `p`, where `p` is a prime number, specific to this field.
- **arity:** The number of elements in a hashed block.
- **arity:** number of elements in a hashed block.
- **stream:** CUDA streams allow multiple hashes and higher throughput.
- **n:** number of blocks we hash in parallel.
### Hash multiple blocks in parallel
- **input, output:** `scalar_t` arrays of size $arity*n$ and $n$ respectively.
- **constants:** are defined as below
```c++
poseidon.hash_blocks(inBlocks, nBlocks, outHashes, hashType, stream);
device_context::DeviceContext ctx= device_context::get_default_device_context();
PoseidonConstants<scalar_t> constants;
init_optimized_poseidon_constants<scalar_t>(ctx, &constants);
```
**Parameters:**
- **nBlocks:** number of blocks we hash in parallel.
- **inBlocks:** input array of size `arity*nBlocks`. The blocks are arranged sequentially in the array.
- **outHashes:** output array of size `nBlocks`.
- **HashType:** In this example we use `Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree`.
## What's in the example
1. Define the size of the example: the height of the full binary Merkle tree.
@@ -59,8 +53,7 @@ Our Merkle tree is a **full binary tree** stored in a 1D array.
The tree nodes are stored following a level-first traversal of the binary tree.
For a given level, we use offset to number elements from left to right. The node numbers on the figure below correspond to their locations in the array.
```
```text
Tree Level
0 0
/ \
@@ -70,16 +63,10 @@ For a given level, we use offset to number elements from left to right. The node
1D array representation: {0, 1, 2, 3, 4, 5, 6}
```
### Membership proof structure
We use two arrays:
- position (left/right) of the node along the path toward the root
- hash of a second node with the same parent
- hash of a second node with the same parent

View File

@@ -4,98 +4,92 @@
// select the curve (only 2 available so far)
#define CURVE_ID 2
#include "curves/curve_config.cuh"
#include "utils/device_context.cuh"
// include Poseidon template
#include "appUtils/poseidon/poseidon.cu"
using namespace poseidon;
using namespace curve_config;
#define A 2
#define T A + 1
device_context::DeviceContext ctx= device_context::get_default_device_context();
// location of a tree node in the array for a given level and offset
inline uint32_t tree_index(uint32_t level, uint32_t offset) { return (1 << level) - 1 + offset; }
// We assume the tree has leaves already set, compute all other levels
// void build_tree(
// const uint32_t tree_height, scalar_t* tree, Poseidon<BLS12_381::scalar_t>& poseidon, cudaStream_t stream)
// {
// for (uint32_t level = tree_height - 1; level > 0; level--) {
// const uint32_t next_level = level - 1;
// const uint32_t next_level_width = 1 << next_level;
// poseidon.hash_blocks(
// &tree[tree_index(level, 0)], next_level_width, &tree[tree_index(next_level, 0)],
// Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree, stream);
// }
// }
void build_tree(
const uint32_t tree_height, scalar_t* tree, PoseidonConstants<scalar_t> * constants, PoseidonConfig config)
{
for (uint32_t level = tree_height - 1; level > 0; level--) {
const uint32_t next_level = level - 1;
const uint32_t next_level_width = 1 << next_level;
poseidon_hash<scalar_t, 2+1>(
&tree[tree_index(level, 0)], &tree[tree_index(next_level, 0)], next_level_width, *constants, config);
}
}
// search leaves for a given hash, return offset
// uint32_t query_membership(BLS12_381::scalar_t query, BLS12_381::scalar_t* tree, const uint32_t tree_height)
// {
// const uint32_t tree_width = (1 << (tree_height - 1));
// for (uint32_t i = 0; i < tree_width; i++) {
// const BLS12_381::scalar_t leaf = tree[tree_index(tree_height - 1, i)];
// if (leaf == query) {
// return i; // found the hash
// }
// }
// return tree_height; // hash not found
// }
// linear search leaves for a given hash, return offset
uint32_t query_membership(scalar_t query, scalar_t* tree, const uint32_t tree_height)
{
const uint32_t tree_width = (1 << (tree_height - 1));
for (uint32_t i = 0; i < tree_width; i++) {
const scalar_t leaf = tree[tree_index(tree_height - 1, i)];
if (leaf == query) {
return i; // found the hash
}
}
return tree_height; // hash not found
}
// void generate_proof(
// uint32_t position,
// BLS12_381::scalar_t* tree,
// const uint32_t tree_height,
// uint32_t* proof_lr,
// BLS12_381::scalar_t* proof_hash)
// {
// uint32_t level_index = position;
// for (uint32_t level = tree_height - 1; level > 0; level--) {
// uint32_t lr;
// uint32_t neighbour_index;
// lr = level_index % 2;
// if (lr == 0) {
// // left
// neighbour_index = level_index + 1;
// } else {
// // right
// neighbour_index = level_index - 1;
// }
// proof_lr[level] = lr;
// proof_hash[level] = tree[tree_index(level, neighbour_index)];
// level_index /= 2;
// }
// // the proof must match this:
// proof_hash[0] = tree[tree_index(0, 0)];
// }
void generate_proof(
uint32_t position,
scalar_t* tree,
const uint32_t tree_height,
uint32_t* proof_lr,
scalar_t* proof_hash)
{
uint32_t level_index = position;
for (uint32_t level = tree_height - 1; level > 0; level--) {
uint32_t lr;
uint32_t neighbour_index;
lr = level_index % 2;
if (lr == 0) {
// left
neighbour_index = level_index + 1;
} else {
// right
neighbour_index = level_index - 1;
}
proof_lr[level] = lr;
proof_hash[level] = tree[tree_index(level, neighbour_index)];
level_index /= 2;
}
// the proof must match this:
proof_hash[0] = tree[tree_index(0, 0)];
}
// uint32_t validate_proof(
// const BLS12_381::scalar_t hash,
// const uint32_t tree_height,
// const uint32_t* proof_lr,
// const BLS12_381::scalar_t* proof_hash,
// Poseidon<BLS12_381::scalar_t>& poseidon,
// cudaStream_t stream)
// {
// BLS12_381::scalar_t hashes_in[2], hash_out[1], level_hash;
// level_hash = hash;
// for (uint32_t level = tree_height - 1; level > 0; level--) {
// if (proof_lr[level] == 0) {
// hashes_in[0] = level_hash;
// hashes_in[1] = proof_hash[level];
// } else {
// hashes_in[0] = proof_hash[level];
// hashes_in[1] = level_hash;
// }
// // next level hash
// poseidon.hash_blocks(hashes_in, 1, hash_out, Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree, stream);
// level_hash = hash_out[0];
// }
// return proof_hash[0] == level_hash;
// }
uint32_t validate_proof(
const scalar_t hash,
const uint32_t tree_height,
const uint32_t* proof_lr,
const scalar_t* proof_hash,
PoseidonConstants<scalar_t> * constants,
PoseidonConfig config)
{
scalar_t hashes_in[2], hash_out[1], level_hash;
level_hash = hash;
for (uint32_t level = tree_height - 1; level > 0; level--) {
if (proof_lr[level] == 0) {
hashes_in[0] = level_hash;
hashes_in[1] = proof_hash[level];
} else {
hashes_in[0] = proof_hash[level];
hashes_in[1] = level_hash;
}
// next level hash
poseidon_hash<scalar_t, 2+1>(hashes_in, hash_out, 1, *constants, config);
level_hash = hash_out[0];
}
return proof_hash[0] == level_hash;
}
int main(int argc, char* argv[])
{
@@ -111,51 +105,49 @@ int main(int argc, char* argv[])
scalar_t* tree = static_cast<scalar_t*>(malloc(tree_size * sizeof(scalar_t)));
std::cout << "2. Hashing blocks in parallel" << std::endl;
// ctx = device_context::get_default_device_context();
init_optimized_poseidon_constants<scalar_t, T>(ctx);
const uint32_t data_arity = 2;
const uint32_t data_arity = 4;
std::cout << "Block size (arity): " << data_arity << std::endl;
std::cout << "Initializing blocks..." << std::endl;
scalar_t d = scalar_t::zero();
scalar_t* data = static_cast<scalar_t*>(malloc(tree_width * data_arity * sizeof(scalar_t)));
for (uint32_t i = 0; i < tree_width * data_arity; i++) {
data[i] = d;
d = d + scalar_t::one();
}
std::cout << "Hashing blocks into tree leaves..." << std::endl;
PoseidonConfig config = default_poseidon_config(T);
poseidon_hash<curve_config::scalar_t, T>(data, &tree[tree_index(leaf_level, 0)], tree_width, preloaded_constants<curve_config::scalar_t, T>, config);
// data_poseidon.hash_blocks(data, tree_width, &tree[tree_index(leaf_level, 0)], Poseidon<BLS12_381::scalar_t>::HashType::MerkleTree, stream);
PoseidonConstants<scalar_t> constants;
init_optimized_poseidon_constants<scalar_t>(data_arity, ctx, &constants);
PoseidonConfig config = default_poseidon_config(data_arity+1);
poseidon_hash<curve_config::scalar_t, data_arity+1>(data, &tree[tree_index(leaf_level, 0)], tree_width, constants, config);
// std::cout << "3. Building Merkle tree" << std::endl;
std::cout << "3. Building Merkle tree" << std::endl;
// Poseidon<BLS12_381::scalar_t> tree_poseidon(tree_arity, stream);
// build_tree(tree_height, tree, tree_poseidon, stream);
PoseidonConstants<scalar_t> tree_constants;
init_optimized_poseidon_constants<scalar_t>(tree_arity, ctx, &tree_constants);
PoseidonConfig tree_config = default_poseidon_config(tree_arity+1);
build_tree(tree_height, tree, &tree_constants, tree_config);
// std::cout << "4. Generate membership proof" << std::endl;
// uint32_t position = tree_width - 1;
// std::cout << "Using the hash for block: " << position << std::endl;
// BLS12_381::scalar_t query = tree[tree_index(leaf_level, position)];
// uint32_t query_position = query_membership(query, tree, tree_height);
// // allocate arrays for the proof
// uint32_t* proof_lr = static_cast<uint32_t*>(malloc(tree_height * sizeof(uint32_t)));
// BLS12_381::scalar_t* proof_hash =
// static_cast<BLS12_381::scalar_t*>(malloc(tree_height * sizeof(BLS12_381::scalar_t)));
// generate_proof(query_position, tree, tree_height, proof_lr, proof_hash);
std::cout << "4. Generate membership proof" << std::endl;
uint32_t position = tree_width - 1;
std::cout << "Using the hash for block: " << position << std::endl;
scalar_t query = tree[tree_index(leaf_level, position)];
uint32_t query_position = query_membership(query, tree, tree_height);
// allocate arrays for the proof
uint32_t* proof_lr = static_cast<uint32_t*>(malloc(tree_height * sizeof(uint32_t)));
scalar_t* proof_hash = static_cast<scalar_t*>(malloc(tree_height * sizeof(scalar_t)));
generate_proof(query_position, tree, tree_height, proof_lr, proof_hash);
// std::cout << "5. Validate the hash membership" << std::endl;
// uint32_t validated;
// const BLS12_381::scalar_t hash = tree[tree_index(leaf_level, query_position)];
// validated = validate_proof(hash, tree_height, proof_lr, proof_hash, tree_poseidon, stream);
// std::cout << "Validated: " << validated << std::endl;
std::cout << "5. Validate the hash membership" << std::endl;
uint32_t validated;
const scalar_t hash = tree[tree_index(leaf_level, query_position)];
validated = validate_proof(hash, tree_height, proof_lr, proof_hash, &tree_constants, tree_config);
std::cout << "Validated: " << validated << std::endl;
// std::cout << "6. Tamper the hash" << std::endl;
// const BLS12_381::scalar_t tampered_hash = hash + BLS12_381::scalar_t::one();
// validated = validate_proof(tampered_hash, tree_height, proof_lr, proof_hash, tree_poseidon, stream);
// std::cout << "7. Invalidate tamper hash membership" << std::endl;
// std::cout << "Validated: " << validated << std::endl;
std::cout << "6. Tamper the hash" << std::endl;
const scalar_t tampered_hash = hash + scalar_t::one();
validated = validate_proof(tampered_hash, tree_height, proof_lr, proof_hash, &tree_constants, tree_config);
std::cout << "7. Invalidate tamper hash membership" << std::endl;
std::cout << "Validated: " << validated << std::endl;
return 0;
}