diff --git a/src/node/client.rs b/src/node/client.rs index 1bf7cbab7..d5927fde2 100644 --- a/src/node/client.rs +++ b/src/node/client.rs @@ -41,16 +41,9 @@ impl Client { // Initialize or load the wallet wallet.init_db().await?; - // Check if there is a default keypair and generate one in - // case we don't have any. - if wallet.get_default_keypair().await.is_err() { - // TODO: Clean this up with Option to have less calls. - if wallet.get_keypairs().await?.is_empty() { - wallet.keygen().await?; - } - } - - wallet.set_default_keypair(&wallet.get_keypairs().await?[0].public).await?; + // Get default keypair or create one + let main_keypair = wallet.get_default_keypair_or_create_one().await?; + info!(target: "client", "Main keypair: {}", Address::from(main_keypair.public)); // Generate merkle tree if we don't have one. // TODO: See what to do about this @@ -58,9 +51,6 @@ impl Client { wallet.tree_gen().await?; } - let main_keypair = wallet.get_default_keypair().await?; - info!(target: "client", "Main keypair: {}", Address::from(main_keypair.public)); - Ok(Self { main_keypair: Mutex::new(main_keypair), wallet, @@ -252,8 +242,7 @@ impl Client { } pub async fn set_default_keypair(&self, public: &PublicKey) -> Result<()> { - self.wallet.set_default_keypair(public).await?; - let kp = self.wallet.get_default_keypair().await?; + let kp = self.wallet.set_default_keypair(public).await?; let mut mk = self.main_keypair.lock().await; *mk = kp; drop(mk); diff --git a/src/wallet/walletdb.rs b/src/wallet/walletdb.rs index 7e9517502..d97c8eb7a 100644 --- a/src/wallet/walletdb.rs +++ b/src/wallet/walletdb.rs @@ -128,7 +128,7 @@ impl WalletDb { Ok(()) } - pub async fn set_default_keypair(&self, public: &PublicKey) -> Result<()> { + pub async fn set_default_keypair(&self, public: &PublicKey) -> Result { debug!("Set default keypair"); let mut conn = self.conn.acquire().await?; @@ -143,14 +143,15 @@ impl WalletDb { .execute(&mut conn) .await?; - Ok(()) + let keypair = self.get_default_keypair().await?; + Ok(keypair) } - pub async fn get_default_keypair(&self) -> Result { + async fn get_default_keypair(&self) -> Result { debug!("Returning default keypair"); let mut conn = self.conn.acquire().await?; - let is_default = 1; + let is_default: u32 = 1; let row = sqlx::query("SELECT * FROM keys WHERE is_default = ?1;") .bind(is_default) @@ -165,11 +166,28 @@ impl WalletDb { pub async fn get_default_address(&self) -> Result
{ debug!("Returning default address"); - let keypair = self.get_default_keypair().await?; + let keypair = self.get_default_keypair_or_create_one().await?; Ok(Address::from(keypair.public)) } + pub async fn get_default_keypair_or_create_one(&self) -> Result { + debug!("Returning default keypair or create one"); + + let default_keypair = self.get_default_keypair().await; + + let keypair = if default_keypair.is_err() { + let keypairs = self.get_keypairs().await?; + let kp = if keypairs.is_empty() { self.keygen().await? } else { keypairs[0] }; + self.set_default_keypair(&kp.public).await?; + kp + } else { + default_keypair? + }; + + Ok(keypair) + } + pub async fn get_keypairs(&self) -> Result> { debug!("Returning keypairs"); let mut conn = self.conn.acquire().await?; @@ -498,7 +516,7 @@ mod tests { // set the keypair at index 1 as the default keypair wallet.set_default_keypair(&keypair2.public).await?; // get default keypair - assert_eq!(keypair2, wallet.get_default_keypair().await?); + assert_eq!(keypair2, wallet.get_default_keypair_or_create_one().await?); // get_own_coins() let own_coins = wallet.get_own_coins().await?;