diff --git a/src/wallet/walletdb.rs b/src/wallet/walletdb.rs index 40371b7c7..b8d452b85 100644 --- a/src/wallet/walletdb.rs +++ b/src/wallet/walletdb.rs @@ -1,3 +1,4 @@ +use incrementalmerkletree::bridgetree::BridgeTree; use std::{fs::create_dir_all, path::Path, str::FromStr}; use async_std::sync::Arc; @@ -13,6 +14,7 @@ use crate::{ crypto::{ coin::Coin, keypair::{Keypair, PublicKey, SecretKey}, + merkle_node::MerkleNode, note::Note, nullifier::Nullifier, OwnCoin, OwnCoins, @@ -71,11 +73,15 @@ impl WalletDb { pub async fn init_db(&self) -> Result<()> { info!("Initializing wallet database"); + let tree = include_str!("../../sql/tree.sql"); let keys = include_str!("../../sql/keys.sql"); let coins = include_str!("../../sql/coins.sql"); let mut conn = self.conn.acquire().await?; + trace!("Initalizing merkle tree table"); + sqlx::query(tree).execute(&mut conn).await?; + trace!("Initializing keys table"); sqlx::query(keys).execute(&mut conn).await?; @@ -129,6 +135,45 @@ impl WalletDb { Ok(vec![Keypair { public, secret }]) } + pub async fn tree_gen(&self) -> Result<()> { + trace!("Attempting to generate merkle tree"); + let mut conn = self.conn.acquire().await?; + + match sqlx::query("SELECT * FROM tree").fetch_one(&mut conn).await { + Ok(_) => { + error!("Tree already exist"); + Err(Error::from(ClientFailed::TreeExists)) + } + Err(_) => { + let tree = BridgeTree::::new(100); + self.put_tree(tree).await?; + Ok(()) + } + } + } + + pub async fn get_tree(&self) -> Result> { + trace!("Getting merkle tree"); + let mut conn = self.conn.acquire().await?; + + let row = sqlx::query("SELECT tree FROM tree").fetch_one(&mut conn).await?; + let tree: BridgeTree = bincode::deserialize(row.get("tree"))?; + Ok(tree) + } + + pub async fn put_tree(&self, tree: BridgeTree) -> Result<()> { + trace!("Attempting to write merkle tree"); + let mut conn = self.conn.acquire().await?; + + let tree_bytes = bincode::serialize(&tree)?; + sqlx::query("INSERT INTO tree(tree) VALUES (?1)") + .bind(tree_bytes) + .execute(&mut conn) + .await?; + + Ok(()) + } + pub async fn get_own_coins(&self) -> Result { trace!("Finding own coins"); let is_spent = 0; @@ -151,21 +196,12 @@ impl WalletDb { let value_bytes: Vec = row.get("value"); let value = u64::from_le_bytes(value_bytes.try_into().unwrap()); let token_id = self.get_value_deserialized(row.get("token_id"))?; - let note = Note { serial, value, token_id, coin_blind, value_blind }; - // TODO: - // let witness = deserialized(row.6) let secret = self.get_value_deserialized(row.get("secret"))?; let nullifier = self.get_value_deserialized(row.get("nullifier"))?; - let oc = OwnCoin { - coin, - note, - secret, - // witness, - nullifier, - }; + let oc = OwnCoin { coin, note, secret, nullifier }; own_coins.push(oc); } @@ -181,7 +217,6 @@ impl WalletDb { let value_blind = self.get_value_serialized(&own_coin.note.value_blind)?; let value = own_coin.note.value.to_le_bytes(); let token_id = self.get_value_serialized(&own_coin.note.token_id)?; - // TODO: let witness let secret = self.get_value_serialized(&own_coin.secret)?; let is_spent = 0; let nullifier = self.get_value_serialized(&own_coin.nullifier)?; @@ -305,7 +340,11 @@ impl WalletDb { #[cfg(test)] mod tests { use super::*; - use crate::types::{DrkCoinBlind, DrkSerial, DrkValueBlind}; + use crate::{ + crypto::merkle_node::MerkleNode, + types::{DrkCoinBlind, DrkSerial, DrkValueBlind}, + }; + use incrementalmerkletree::bridgetree::BridgeTree; use pasta_curves::{arithmetic::Field, pallas}; use rand::rngs::OsRng; @@ -323,6 +362,7 @@ mod tests { let coin = Coin(pallas::Base::random(&mut OsRng)); let nullifier = Nullifier::new(*s, serial); + OwnCoin { coin, note, secret: *s, nullifier } } @@ -330,6 +370,7 @@ mod tests { async fn test_walletdb() -> Result<()> { let wallet = WalletDb::new("sqlite::memory:", WPASS.to_string()).await?; let keypair = Keypair::random(&mut OsRng); + let tree1 = BridgeTree::::new(100); // init_db() wallet.init_db().await?; @@ -350,6 +391,9 @@ mod tests { wallet.put_own_coins(c2).await?; wallet.put_own_coins(c3).await?; + // put_tree() + wallet.put_tree(tree1).await?; + // get_token_id() let id = wallet.get_token_id().await?; assert_eq!(id.len(), 4); @@ -378,6 +422,10 @@ mod tests { assert_eq!(own_coins[2], c2); assert_eq!(own_coins[3], c3); + // get_tree() + let tree2 = wallet.get_tree().await?; + assert_eq!(tree1, tree2); + Ok(()) } }