diff --git a/crates/rlp/rlp-derive/src/de.rs b/crates/rlp/rlp-derive/src/de.rs index 4b7123a2e2..f4edfa7df9 100644 --- a/crates/rlp/rlp-derive/src/de.rs +++ b/crates/rlp/rlp-derive/src/de.rs @@ -1,17 +1,35 @@ use proc_macro2::TokenStream; use quote::quote; +use syn::{Error, Result}; -use crate::utils::has_attribute; +use crate::utils::{attributes_include, field_ident, is_optional, parse_struct}; -pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> TokenStream { - let body = if let syn::Data::Struct(s) = &ast.data { - s - } else { - panic!("#[derive(RlpDecodable)] is only defined for structs."); - }; +pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> Result { + let body = parse_struct(ast, "RlpDecodable")?; + + let fields = body.fields.iter().enumerate(); + + let supports_trailing_opt = attributes_include(&ast.attrs, "trailing"); + + let mut encountered_opt_item = false; + let mut stmts = Vec::with_capacity(body.fields.len()); + for (i, field) in fields { + let is_opt = is_optional(field); + if is_opt { + if !supports_trailing_opt { + return Err(Error::new_spanned(field, "Optional fields are disabled. Add `#[rlp(trailing)]` attribute to the struct in order to enable")) + } + encountered_opt_item = true; + } else if encountered_opt_item && !attributes_include(&field.attrs, "default") { + return Err(Error::new_spanned( + field, + "All subsequent fields must be either optional or default.", + )) + } + + stmts.push(decodable_field(i, field, is_opt)); + } - let stmts: Vec<_> = - body.fields.iter().enumerate().map(|(i, field)| decodable_field(i, field)).collect(); let name = &ast.ident; let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); @@ -45,20 +63,16 @@ pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> TokenStream { } }; - quote! { + Ok(quote! { const _: () = { extern crate reth_rlp; #impl_block }; - } + }) } -pub(crate) fn impl_decodable_wrapper(ast: &syn::DeriveInput) -> TokenStream { - let body = if let syn::Data::Struct(s) = &ast.data { - s - } else { - panic!("#[derive(RlpEncodableWrapper)] is only defined for structs."); - }; +pub(crate) fn impl_decodable_wrapper(ast: &syn::DeriveInput) -> Result { + let body = parse_struct(ast, "RlpEncodableWrapper")?; assert_eq!( body.fields.iter().count(), @@ -77,25 +91,28 @@ pub(crate) fn impl_decodable_wrapper(ast: &syn::DeriveInput) -> TokenStream { } }; - quote! { + Ok(quote! { const _: () = { extern crate reth_rlp; #impl_block }; - } + }) } -fn decodable_field(index: usize, field: &syn::Field) -> TokenStream { - let id = if let Some(ident) = &field.ident { - quote! { #ident } - } else { - let index = syn::Index::from(index); - quote! { #index } - }; +fn decodable_field(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream { + let ident = field_ident(index, field); - if has_attribute(field, "default") { - quote! { #id: Default::default(), } + if attributes_include(&field.attrs, "default") { + quote! { #ident: Default::default(), } + } else if is_opt { + quote! { + #ident: if started_len - b.len() < rlp_head.payload_length { + Some(reth_rlp::Decodable::decode(b)?) + } else { + None + }, + } } else { - quote! { #id: reth_rlp::Decodable::decode(b)?, } + quote! { #ident: reth_rlp::Decodable::decode(b)?, } } } diff --git a/crates/rlp/rlp-derive/src/en.rs b/crates/rlp/rlp-derive/src/en.rs index 46dce4cfdb..ff3c1f10e7 100644 --- a/crates/rlp/rlp-derive/src/en.rs +++ b/crates/rlp/rlp-derive/src/en.rs @@ -1,22 +1,37 @@ use proc_macro2::TokenStream; use quote::quote; +use syn::{Error, Result}; -use crate::utils::has_attribute; +use crate::utils::{attributes_include, field_ident, is_optional, parse_struct}; -pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> TokenStream { - let body = if let syn::Data::Struct(s) = &ast.data { - s - } else { - panic!("#[derive(RlpEncodable)] is only defined for structs."); - }; +pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> Result { + let body = parse_struct(ast, "RlpEncodable")?; - let (length_stmts, stmts): (Vec<_>, Vec<_>) = body + let fields = body .fields .iter() .enumerate() - .filter(|(_, field)| !has_attribute(field, "skip")) - .map(|(i, field)| (encodable_length(i, field), encodable_field(i, field))) - .unzip(); + .filter(|(_, field)| !attributes_include(&field.attrs, "skip")); + + let supports_trailing_opt = attributes_include(&ast.attrs, "trailing"); + + let mut encountered_opt_item = false; + let mut length_stmts = Vec::with_capacity(body.fields.len()); + let mut stmts = Vec::with_capacity(body.fields.len()); + for (i, field) in fields { + let is_opt = is_optional(field); + if is_opt { + if !supports_trailing_opt { + return Err(Error::new_spanned(field, "Optional fields are disabled. Add `#[rlp(trailing)]` attribute to the struct in order to enable")) + } + encountered_opt_item = true; + } else if encountered_opt_item { + return Err(Error::new_spanned(field, "All subsequent fields must be optional.")) + } + + length_stmts.push(encodable_length(i, field, is_opt)); + stmts.push(encodable_field(i, field, is_opt)); + } let name = &ast.ident; let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); @@ -46,20 +61,16 @@ pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> TokenStream { } }; - quote! { + Ok(quote! { const _: () = { extern crate reth_rlp; #impl_block }; - } + }) } -pub(crate) fn impl_encodable_wrapper(ast: &syn::DeriveInput) -> TokenStream { - let body = if let syn::Data::Struct(s) = &ast.data { - s - } else { - panic!("#[derive(RlpEncodableWrapper)] is only defined for structs."); - }; +pub(crate) fn impl_encodable_wrapper(ast: &syn::DeriveInput) -> Result { + let body = parse_struct(ast, "RlpEncodableWrapper")?; let ident = { let fields: Vec<_> = body.fields.iter().collect(); @@ -85,26 +96,22 @@ pub(crate) fn impl_encodable_wrapper(ast: &syn::DeriveInput) -> TokenStream { } }; - quote! { + Ok(quote! { const _: () = { extern crate reth_rlp; #impl_block }; - } + }) } -pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> TokenStream { - let body = if let syn::Data::Struct(s) = &ast.data { - s - } else { - panic!("#[derive(RlpMaxEncodedLen)] is only defined for structs."); - }; +pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> Result { + let body = parse_struct(ast, "RlpMaxEncodedLen")?; let stmts: Vec<_> = body .fields .iter() .enumerate() - .filter(|(_, field)| !has_attribute(field, "skip")) + .filter(|(_, field)| !attributes_include(&field.attrs, "skip")) .map(|(index, field)| encodable_max_length(index, field)) .collect(); let name = &ast.ident; @@ -116,27 +123,22 @@ pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> TokenStream { } }; - quote! { + Ok(quote! { const _: () = { extern crate reth_rlp; #impl_block }; - } + }) } -fn field_ident(index: usize, field: &syn::Field) -> TokenStream { - if let Some(ident) = &field.ident { - quote! { #ident } - } else { - let index = syn::Index::from(index); - quote! { #index } - } -} - -fn encodable_length(index: usize, field: &syn::Field) -> TokenStream { +fn encodable_length(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream { let ident = field_ident(index, field); - quote! { rlp_head.payload_length += reth_rlp::Encodable::length(&self.#ident); } + if is_opt { + quote! { rlp_head.payload_length += &self.#ident.as_ref().map(|val| reth_rlp::Encodable::length(val)).unwrap_or_default(); } + } else { + quote! { rlp_head.payload_length += reth_rlp::Encodable::length(&self.#ident); } + } } fn encodable_max_length(index: usize, field: &syn::Field) -> TokenStream { @@ -149,10 +151,12 @@ fn encodable_max_length(index: usize, field: &syn::Field) -> TokenStream { } } -fn encodable_field(index: usize, field: &syn::Field) -> TokenStream { +fn encodable_field(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream { let ident = field_ident(index, field); - let id = quote! { self.#ident }; - - quote! { reth_rlp::Encodable::encode(&#id, out); } + if is_opt { + quote! { self.#ident.as_ref().map(|val| reth_rlp::Encodable::encode(val, out)); } + } else { + quote! { reth_rlp::Encodable::encode(&self.#ident, out); } + } } diff --git a/crates/rlp/rlp-derive/src/lib.rs b/crates/rlp/rlp-derive/src/lib.rs index c6715a3213..24737a8ae7 100644 --- a/crates/rlp/rlp-derive/src/lib.rs +++ b/crates/rlp/rlp-derive/src/lib.rs @@ -28,34 +28,28 @@ use proc_macro::TokenStream; /// Derives `Encodable` for the type which encodes the all fields as list: `` #[proc_macro_derive(RlpEncodable, attributes(rlp))] pub fn encodable(input: TokenStream) -> TokenStream { - let ast = match syn::parse(input) { - Ok(ast) => ast, - Err(err) => return err.to_compile_error().into(), - }; - let gen = impl_encodable(&ast); - gen.into() + syn::parse(input) + .and_then(|ast| impl_encodable(&ast)) + .unwrap_or_else(|err| err.to_compile_error()) + .into() } /// Derives `Encodable` for the type which encodes the fields as-is, without a header: `` #[proc_macro_derive(RlpEncodableWrapper, attributes(rlp))] pub fn encodable_wrapper(input: TokenStream) -> TokenStream { - let ast = match syn::parse(input) { - Ok(ast) => ast, - Err(err) => return err.to_compile_error().into(), - }; - let gen = impl_encodable_wrapper(&ast); - gen.into() + syn::parse(input) + .and_then(|ast| impl_encodable_wrapper(&ast)) + .unwrap_or_else(|err| err.to_compile_error()) + .into() } /// Derives `MaxEncodedLen` for types of constant size. #[proc_macro_derive(RlpMaxEncodedLen, attributes(rlp))] pub fn max_encoded_len(input: TokenStream) -> TokenStream { - let ast = match syn::parse(input) { - Ok(ast) => ast, - Err(err) => return err.to_compile_error().into(), - }; - let gen = impl_max_encoded_len(&ast); - gen.into() + syn::parse(input) + .and_then(|ast| impl_max_encoded_len(&ast)) + .unwrap_or_else(|err| err.to_compile_error()) + .into() } /// Derives `Decodable` for the type whose implementation expects an rlp-list input: ` TokenStream { /// This is the inverse of `RlpEncodable`. #[proc_macro_derive(RlpDecodable, attributes(rlp))] pub fn decodable(input: TokenStream) -> TokenStream { - let ast = match syn::parse(input) { - Ok(ast) => ast, - Err(err) => return err.to_compile_error().into(), - }; - let gen = impl_decodable(&ast); - gen.into() + syn::parse(input) + .and_then(|ast| impl_decodable(&ast)) + .unwrap_or_else(|err| err.to_compile_error()) + .into() } /// Derives `Decodable` for the type whose implementation expects only the individual fields @@ -78,7 +70,8 @@ pub fn decodable(input: TokenStream) -> TokenStream { /// This is the inverse of `RlpEncodableWrapper`. #[proc_macro_derive(RlpDecodableWrapper, attributes(rlp))] pub fn decodable_wrapper(input: TokenStream) -> TokenStream { - let ast = syn::parse(input).unwrap(); - let gen = impl_decodable_wrapper(&ast); - gen.into() + syn::parse(input) + .and_then(|ast| impl_decodable_wrapper(&ast)) + .unwrap_or_else(|err| err.to_compile_error()) + .into() } diff --git a/crates/rlp/rlp-derive/src/utils.rs b/crates/rlp/rlp-derive/src/utils.rs index cae47fe892..3e5f373059 100644 --- a/crates/rlp/rlp-derive/src/utils.rs +++ b/crates/rlp/rlp-derive/src/utils.rs @@ -1,7 +1,23 @@ -use syn::{Field, Meta, NestedMeta}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Attribute, DataStruct, Error, Field, Meta, NestedMeta, Result, Type, TypePath}; -pub(crate) fn has_attribute(field: &Field, attr_name: &str) -> bool { - field.attrs.iter().any(|attr| { +pub(crate) fn parse_struct<'a>( + ast: &'a syn::DeriveInput, + derive_attr: &str, +) -> Result<&'a DataStruct> { + if let syn::Data::Struct(s) = &ast.data { + Ok(s) + } else { + Err(Error::new_spanned( + ast, + format!("#[derive({derive_attr})] is only defined for structs."), + )) + } +} + +pub(crate) fn attributes_include(attrs: &[Attribute], attr_name: &str) -> bool { + attrs.iter().any(|attr| { if attr.path.is_ident("rlp") { if let Ok(Meta::List(meta)) = attr.parse_meta() { if let Some(NestedMeta::Meta(meta)) = meta.nested.first() { @@ -14,3 +30,23 @@ pub(crate) fn has_attribute(field: &Field, attr_name: &str) -> bool { false }) } + +pub(crate) fn is_optional(field: &Field) -> bool { + if let Type::Path(TypePath { qself, path }) = &field.ty { + qself.is_none() && + path.leading_colon.is_none() && + path.segments.len() == 1 && + path.segments.first().unwrap().ident == "Option" + } else { + false + } +} + +pub(crate) fn field_ident(index: usize, field: &syn::Field) -> TokenStream { + if let Some(ident) = &field.ident { + quote! { #ident } + } else { + let index = syn::Index::from(index); + quote! { #index } + } +} diff --git a/crates/rlp/tests/rlp.rs b/crates/rlp/tests/rlp.rs index 7479d8d1a3..7f82b59094 100644 --- a/crates/rlp/tests/rlp.rs +++ b/crates/rlp/tests/rlp.rs @@ -30,6 +30,15 @@ struct Test4NumbersGenerics<'a, D: Encodable> { d: &'a D, } +#[derive(Debug, PartialEq, RlpEncodable, RlpDecodable)] +#[rlp(trailing)] +struct TestOpt { + a: u8, + b: u64, + c: Option, + d: Option, +} + fn encoded(t: &T) -> BytesMut { let mut out = BytesMut::new(); t.encode(&mut out); @@ -101,3 +110,21 @@ fn invalid_decode_sideeffect() { assert_eq!(sl.len(), fixture.len()); } + +#[test] +fn test_opt_fields_roundtrip() { + let expected = hex!("c20102"); + let item = TestOpt { a: 1, b: 2, c: None, d: None }; + assert_eq!(&*encoded(&item), expected); + assert_eq!(TestOpt::decode(&mut &expected[..]).unwrap(), item); + + let expected = hex!("c3010203"); + let item = TestOpt { a: 1, b: 2, c: Some(3), d: None }; + assert_eq!(&*encoded(&item), expected); + assert_eq!(TestOpt::decode(&mut &expected[..]).unwrap(), item); + + let expected = hex!("c401020304"); + let item = TestOpt { a: 1, b: 2, c: Some(3), d: Some(4) }; + assert_eq!(&*encoded(&item), expected); + assert_eq!(TestOpt::decode(&mut &expected[..]).unwrap(), item); +}