From 4447f658a9b7599fdafb54b78e7c3999d74cfcef Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Thu, 4 Jul 2024 08:53:22 -0700 Subject: [PATCH] feat(trie): allow setting hashed cursor factory on `Proof` (#9304) --- .../provider/src/providers/state/latest.rs | 2 +- crates/trie/trie/src/proof.rs | 31 ++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/crates/storage/provider/src/providers/state/latest.rs b/crates/storage/provider/src/providers/state/latest.rs index 2ad66bbfd8..a1e8256cf1 100644 --- a/crates/storage/provider/src/providers/state/latest.rs +++ b/crates/storage/provider/src/providers/state/latest.rs @@ -93,7 +93,7 @@ impl<'b, TX: DbTx> StateRootProvider for LatestStateProviderRef<'b, TX> { impl<'b, TX: DbTx> StateProofProvider for LatestStateProviderRef<'b, TX> { fn proof(&self, address: Address, slots: &[B256]) -> ProviderResult { - Ok(Proof::new(self.tx) + Ok(Proof::from_tx(self.tx) .account_proof(address, slots) .map_err(Into::::into)?) } diff --git a/crates/trie/trie/src/proof.rs b/crates/trie/trie/src/proof.rs index 65bce7d286..2342ece986 100644 --- a/crates/trie/trie/src/proof.rs +++ b/crates/trie/trie/src/proof.rs @@ -26,10 +26,22 @@ pub struct Proof<'a, TX, H> { hashed_cursor_factory: H, } +impl<'a, TX, H> Proof<'a, TX, H> { + /// Creates a new proof generator. + pub const fn new(tx: &'a TX, hashed_cursor_factory: H) -> Self { + Self { tx, hashed_cursor_factory } + } + + /// Set the hashed cursor factory. + pub fn with_hashed_cursor_factory(self, hashed_cursor_factory: HF) -> Proof<'a, TX, HF> { + Proof { tx: self.tx, hashed_cursor_factory } + } +} + impl<'a, TX> Proof<'a, TX, &'a TX> { - /// Create a new [Proof] instance. - pub const fn new(tx: &'a TX) -> Self { - Self { tx, hashed_cursor_factory: tx } + /// Create a new [Proof] instance from database transaction. + pub const fn from_tx(tx: &'a TX) -> Self { + Self::new(tx, tx) } } @@ -282,7 +294,8 @@ mod tests { let provider = factory.provider().unwrap(); for (target, expected_proof) in data { let target = Address::from_str(target).unwrap(); - let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &[]).unwrap(); + let account_proof = + Proof::from_tx(provider.tx_ref()).account_proof(target, &[]).unwrap(); similar_asserts::assert_eq!( account_proof.proof, expected_proof, @@ -302,7 +315,8 @@ mod tests { let slots = Vec::from([B256::with_last_byte(1), B256::with_last_byte(3)]); let provider = factory.provider().unwrap(); - let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &slots).unwrap(); + let account_proof = + Proof::from_tx(provider.tx_ref()).account_proof(target, &slots).unwrap(); assert_eq!(account_proof.storage_root, EMPTY_ROOT_HASH, "expected empty storage root"); assert_eq!(slots.len(), account_proof.storage_proofs.len()); @@ -334,7 +348,7 @@ mod tests { ]); let provider = factory.provider().unwrap(); - let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &[]).unwrap(); + let account_proof = Proof::from_tx(provider.tx_ref()).account_proof(target, &[]).unwrap(); similar_asserts::assert_eq!(account_proof.proof, expected_account_proof); assert_eq!(account_proof.verify(root), Ok(())); } @@ -357,7 +371,7 @@ mod tests { ]); let provider = factory.provider().unwrap(); - let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &[]).unwrap(); + let account_proof = Proof::from_tx(provider.tx_ref()).account_proof(target, &[]).unwrap(); similar_asserts::assert_eq!(account_proof.proof, expected_account_proof); assert_eq!(account_proof.verify(root), Ok(())); } @@ -443,7 +457,8 @@ mod tests { }; let provider = factory.provider().unwrap(); - let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &slots).unwrap(); + let account_proof = + Proof::from_tx(provider.tx_ref()).account_proof(target, &slots).unwrap(); similar_asserts::assert_eq!(account_proof, expected); assert_eq!(account_proof.verify(root), Ok(())); }