diff --git a/src/wallet/walletdb.rs b/src/wallet/walletdb.rs index 95073d9d4..8d13b64a2 100644 --- a/src/wallet/walletdb.rs +++ b/src/wallet/walletdb.rs @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -use std::path::PathBuf; +use std::{any::Any, path::PathBuf}; use async_std::sync::{Arc, Mutex}; use log::{debug, info}; @@ -55,6 +55,25 @@ impl From for QueryType { } } +#[derive(Debug)] +pub enum SqlType { + Integer(i64), + Text(String), + Blob(Vec), + Null, +} + +impl SqlType { + pub fn inner(&self) -> Option<&T> { + match self { + SqlType::Integer(v) => (v as &dyn Any).downcast_ref::(), + SqlType::Text(v) => (v as &dyn Any).downcast_ref::(), + SqlType::Blob(v) => (v as &dyn Any).downcast_ref::(), + SqlType::Null => None, + } + } +} + /// Structure representing base wallet operations. /// Additional operations can be implemented by trait extensions. pub struct WalletDb { @@ -86,6 +105,57 @@ impl WalletDb { let _ = self.conn.lock().await.execute(query, ())?; Ok(()) } + + pub async fn query_single( + &self, + table: &str, + col_names: Vec<&str>, + where_queries: Option>, + ) -> Result> { + let mut query = format!("SELECT {} FROM {}", col_names.join(", "), table); + + if let Some(wq) = where_queries.as_ref() { + let where_str: Vec = wq.iter().map(|(k, _)| format!("{} = ?", k)).collect(); + query.push_str(&format!(" WHERE {}", where_str.join(" AND "))); + } + + let params: Vec = where_queries.map_or(Vec::new(), |wq| { + wq.into_iter() + .map(|(_, v)| match v { + SqlType::Integer(i) => rusqlite::types::ToSqlOutput::from(i), + SqlType::Text(t) => rusqlite::types::ToSqlOutput::from(t), + SqlType::Blob(b) => rusqlite::types::ToSqlOutput::from(b), + SqlType::Null => rusqlite::types::ToSqlOutput::from(rusqlite::types::Null), + }) + .collect::>() + }); + + let wallet_conn = self.conn.lock().await; + let mut stmt = wallet_conn.prepare(&query)?; + let params_as_slice: Vec<&dyn rusqlite::ToSql> = + params.iter().map(|x| x as &dyn rusqlite::ToSql).collect(); + let mut rows = stmt.query(params_as_slice.as_slice())?; + + let row = match rows.next()? { + Some(row_result) => row_result, + None => return Ok(vec![]), + }; + + let mut result = vec![]; + for (idx, _) in col_names.iter().enumerate() { + let value: SqlType = match row.get_ref(idx)?.data_type() { + rusqlite::types::Type::Integer => SqlType::Integer(row.get(idx)?), + rusqlite::types::Type::Text => SqlType::Text(row.get(idx)?), + rusqlite::types::Type::Blob => SqlType::Blob(row.get(idx)?), + rusqlite::types::Type::Null => SqlType::Null, + _ => unimplemented!(), + }; + + result.push(value); + } + + Ok(result) + } } #[cfg(test)] @@ -104,4 +174,36 @@ mod tests { stmt.finalize().unwrap(); assert!(numba == 42); } + + #[async_std::test] + async fn test_query_single() { + let wallet = WalletDb::new(None, None).unwrap(); + wallet + .exec_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );") + .await + .unwrap(); + + let why = 42; + let are = "are".to_string(); + let you = 69; + let gae = vec![42u8; 32]; + + let query_str = + format!("INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);"); + + let wallet_conn = wallet.conn.lock().await; + let mut stmt = wallet_conn.prepare(&query_str).unwrap(); + stmt.execute(rusqlite::params![why, are, you, gae]).unwrap(); + stmt.finalize().unwrap(); + drop(wallet_conn); + + let ret = + wallet.query_single("mista", vec!["why", "are", "you", "gae"], None).await.unwrap(); + assert!(ret.len() == 4); + + assert!(ret[0].inner::().unwrap() == &why); + assert!(ret[1].inner::().unwrap() == &are); + assert!(ret[2].inner::().unwrap() == &you); + assert!(ret[3].inner::>().unwrap() == &gae); + } }