diff --git a/crates/common/rlp-derive/src/de.rs b/crates/common/rlp-derive/src/de.rs index 5afd98ca9f..4b7123a2e2 100644 --- a/crates/common/rlp-derive/src/de.rs +++ b/crates/common/rlp-derive/src/de.rs @@ -1,6 +1,8 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::utils::has_attribute; + pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> TokenStream { let body = if let syn::Data::Struct(s) = &ast.data { s @@ -91,5 +93,9 @@ fn decodable_field(index: usize, field: &syn::Field) -> TokenStream { quote! { #index } }; - quote! { #id: reth_rlp::Decodable::decode(b)?, } + if has_attribute(field, "default") { + quote! { #id: Default::default(), } + } else { + quote! { #id: reth_rlp::Decodable::decode(b)?, } + } } diff --git a/crates/common/rlp-derive/src/en.rs b/crates/common/rlp-derive/src/en.rs index e35774b596..46dce4cfdb 100644 --- a/crates/common/rlp-derive/src/en.rs +++ b/crates/common/rlp-derive/src/en.rs @@ -1,6 +1,8 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::utils::has_attribute; + pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> TokenStream { let body = if let syn::Data::Struct(s) = &ast.data { s @@ -8,11 +10,14 @@ pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> TokenStream { panic!("#[derive(RlpEncodable)] is only defined for structs."); }; - let length_stmts: Vec<_> = - body.fields.iter().enumerate().map(|(i, field)| encodable_length(i, field)).collect(); + let (length_stmts, stmts): (Vec<_>, Vec<_>) = body + .fields + .iter() + .enumerate() + .filter(|(_, field)| !has_attribute(field, "skip")) + .map(|(i, field)| (encodable_length(i, field), encodable_field(i, field))) + .unzip(); - let stmts: Vec<_> = - body.fields.iter().enumerate().map(|(i, field)| encodable_field(i, field)).collect(); let name = &ast.ident; let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); @@ -92,13 +97,14 @@ pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> TokenStream { let body = if let syn::Data::Struct(s) = &ast.data { s } else { - panic!("#[derive(RlpEncodable)] is only defined for structs."); + panic!("#[derive(RlpMaxEncodedLen)] is only defined for structs."); }; let stmts: Vec<_> = body .fields .iter() .enumerate() + .filter(|(_, field)| !has_attribute(field, "skip")) .map(|(index, field)| encodable_max_length(index, field)) .collect(); let name = &ast.ident; diff --git a/crates/common/rlp-derive/src/lib.rs b/crates/common/rlp-derive/src/lib.rs index 48288678af..c6715a3213 100644 --- a/crates/common/rlp-derive/src/lib.rs +++ b/crates/common/rlp-derive/src/lib.rs @@ -19,6 +19,7 @@ extern crate proc_macro; mod de; mod en; +mod utils; use de::*; use en::*; diff --git a/crates/common/rlp-derive/src/utils.rs b/crates/common/rlp-derive/src/utils.rs new file mode 100644 index 0000000000..cae47fe892 --- /dev/null +++ b/crates/common/rlp-derive/src/utils.rs @@ -0,0 +1,16 @@ +use syn::{Field, Meta, NestedMeta}; + +pub(crate) fn has_attribute(field: &Field, attr_name: &str) -> bool { + field.attrs.iter().any(|attr| { + if attr.path.is_ident("rlp") { + if let Ok(Meta::List(meta)) = attr.parse_meta() { + if let Some(NestedMeta::Meta(meta)) = meta.nested.first() { + return meta.path().is_ident(attr_name) + } + return false + } else { + } + } + false + }) +} diff --git a/crates/common/rlp/tests/rlp.rs b/crates/common/rlp/tests/rlp.rs index 62a073ece3..7479d8d1a3 100644 --- a/crates/common/rlp/tests/rlp.rs +++ b/crates/common/rlp/tests/rlp.rs @@ -13,6 +13,9 @@ struct Test4Numbers { a: u8, b: u64, c: U256, + #[rlp(skip)] + #[rlp(default)] + s: U256, d: U256, } @@ -50,6 +53,9 @@ fn test_encode_item() { c: U256::from_be_bytes(hex!( "56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421" )), + s: U256::from_be_bytes(hex!( + "c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470" + )), d: U256::from_be_bytes(hex!( "c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470" )), @@ -62,8 +68,12 @@ fn test_encode_item() { let out = reth_rlp::encode_fixed_size(&item); assert_eq!(&*out, expected); - let decoded = Decodable::decode(&mut &*expected).unwrap(); - assert_eq!(item, decoded); + let decoded: Test4Numbers = Decodable::decode(&mut &*expected).unwrap(); + assert_eq!(decoded.a, item.a); + assert_eq!(decoded.b, item.b); + assert_eq!(decoded.c, item.c); + assert_eq!(decoded.d, item.d); + assert_eq!(decoded.s, U256::ZERO); let mut rlp_view = Rlp::new(&expected).unwrap(); assert_eq!(rlp_view.get_next().unwrap(), Some(item.a)); @@ -79,6 +89,7 @@ fn test_encode_item() { assert_eq!(encoded(&W(item)), expected); assert_eq!(W::decode(&mut &*expected).unwrap().0, decoded); + assert_eq!(Test4Numbers::LEN, 79); } #[test]