fix(rlp): encode None optional fields as empty if any values left (#1366)

This commit is contained in:
Roman Krasiuk
2023-02-15 12:23:35 +02:00
committed by GitHub
parent a4ad2da06e
commit 022bf6342d
4 changed files with 71 additions and 12 deletions

View File

@@ -2,7 +2,7 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{Error, Result};
use crate::utils::{attributes_include, field_ident, is_optional, parse_struct};
use crate::utils::{attributes_include, field_ident, is_optional, parse_struct, EMPTY_STRING_CODE};
pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpDecodable")?;
@@ -107,7 +107,12 @@ fn decodable_field(index: usize, field: &syn::Field, is_opt: bool) -> TokenStrea
} else if is_opt {
quote! {
#ident: if started_len - b.len() < rlp_head.payload_length {
Some(reth_rlp::Decodable::decode(b)?)
if b.first().map(|b| *b == #EMPTY_STRING_CODE).unwrap_or_default() {
bytes::Buf::advance(b, 1);
None
} else {
Some(reth_rlp::Decodable::decode(b)?)
}
} else {
None
},

View File

@@ -1,24 +1,28 @@
use std::iter::Peekable;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Error, Result};
use crate::utils::{attributes_include, field_ident, is_optional, parse_struct};
use crate::utils::{attributes_include, field_ident, is_optional, parse_struct, EMPTY_STRING_CODE};
pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpEncodable")?;
let fields = body
let mut fields = body
.fields
.iter()
.enumerate()
.filter(|(_, field)| !attributes_include(&field.attrs, "skip"));
.filter(|(_, field)| !attributes_include(&field.attrs, "skip"))
.peekable();
let supports_trailing_opt = attributes_include(&ast.attrs, "trailing");
let mut encountered_opt_item = false;
let mut length_stmts = Vec::with_capacity(body.fields.len());
let mut stmts = Vec::with_capacity(body.fields.len());
for (i, field) in fields {
while let Some((i, field)) = fields.next() {
let is_opt = is_optional(field);
if is_opt {
if !supports_trailing_opt {
@@ -29,8 +33,8 @@ pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> Result<TokenStream> {
return Err(Error::new_spanned(field, "All subsequent fields must be optional."))
}
length_stmts.push(encodable_length(i, field, is_opt));
stmts.push(encodable_field(i, field, is_opt));
length_stmts.push(encodable_length(i, field, is_opt, fields.clone()));
stmts.push(encodable_field(i, field, is_opt, fields.clone()));
}
let name = &ast.ident;
@@ -131,11 +135,23 @@ pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> Result<TokenStream
})
}
fn encodable_length(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream {
fn encodable_length<'a>(
index: usize,
field: &syn::Field,
is_opt: bool,
mut remaining: Peekable<impl Iterator<Item = (usize, &'a syn::Field)>>,
) -> TokenStream {
let ident = field_ident(index, field);
if is_opt {
quote! { rlp_head.payload_length += &self.#ident.as_ref().map(|val| reth_rlp::Encodable::length(val)).unwrap_or_default(); }
let default = if remaining.peek().is_some() {
let condition = remaining_opt_fields_some_condition(remaining);
quote! { #condition as usize }
} else {
quote! { 0 }
};
quote! { rlp_head.payload_length += &self.#ident.as_ref().map(|val| reth_rlp::Encodable::length(val)).unwrap_or(#default); }
} else {
quote! { rlp_head.payload_length += reth_rlp::Encodable::length(&self.#ident); }
}
@@ -151,12 +167,43 @@ fn encodable_max_length(index: usize, field: &syn::Field) -> TokenStream {
}
}
fn encodable_field(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream {
fn encodable_field<'a>(
index: usize,
field: &syn::Field,
is_opt: bool,
mut remaining: Peekable<impl Iterator<Item = (usize, &'a syn::Field)>>,
) -> TokenStream {
let ident = field_ident(index, field);
if is_opt {
quote! { self.#ident.as_ref().map(|val| reth_rlp::Encodable::encode(val, out)); }
let if_some_encode = quote! {
if let Some(val) = self.#ident.as_ref() {
reth_rlp::Encodable::encode(val, out)
}
};
if remaining.peek().is_some() {
let condition = remaining_opt_fields_some_condition(remaining);
quote! {
#if_some_encode
else if #condition {
out.put_u8(#EMPTY_STRING_CODE);
}
}
} else {
quote! { #if_some_encode }
}
} else {
quote! { reth_rlp::Encodable::encode(&self.#ident, out); }
}
}
fn remaining_opt_fields_some_condition<'a>(
remaining: impl Iterator<Item = (usize, &'a syn::Field)>,
) -> TokenStream {
let conditions = remaining.map(|(index, field)| {
let ident = field_ident(index, field);
quote! { self.#ident.is_some() }
});
quote! { #(#conditions) ||* }
}

View File

@@ -2,6 +2,8 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, DataStruct, Error, Field, Meta, NestedMeta, Result, Type, TypePath};
pub(crate) const EMPTY_STRING_CODE: u8 = 0x80;
pub(crate) fn parse_struct<'a>(
ast: &'a syn::DeriveInput,
derive_attr: &str,

View File

@@ -127,4 +127,9 @@ fn test_opt_fields_roundtrip() {
let item = TestOpt { a: 1, b: 2, c: Some(3), d: Some(4) };
assert_eq!(&*encoded(&item), expected);
assert_eq!(TestOpt::decode(&mut &expected[..]).unwrap(), item);
let expected = hex!("c401028004");
let item = TestOpt { a: 1, b: 2, c: None, d: Some(4) };
assert_eq!(&*encoded(&item), expected);
assert_eq!(TestOpt::decode(&mut &expected[..]).unwrap(), item);
}