From e6fd1d5221e1c62476d46b484a0d5b4e4408d0ea Mon Sep 17 00:00:00 2001 From: parazyd Date: Thu, 24 Aug 2023 11:21:37 +0200 Subject: [PATCH] darkfi-derive-internal: Add support for Async{Encodable,Decodable} derive. --- src/serial/derive-internal/Cargo.toml | 4 + .../derive-internal/src/async_derive.rs | 412 ++++++++++++++++++ src/serial/derive-internal/src/helpers.rs | 38 -- src/serial/derive-internal/src/lib.rs | 404 +---------------- src/serial/derive-internal/src/sync_derive.rs | 407 +++++++++++++++++ 5 files changed, 845 insertions(+), 420 deletions(-) create mode 100644 src/serial/derive-internal/src/async_derive.rs delete mode 100644 src/serial/derive-internal/src/helpers.rs create mode 100644 src/serial/derive-internal/src/sync_derive.rs diff --git a/src/serial/derive-internal/Cargo.toml b/src/serial/derive-internal/Cargo.toml index 4d3b13e7d..342d400ce 100644 --- a/src/serial/derive-internal/Cargo.toml +++ b/src/serial/derive-internal/Cargo.toml @@ -12,3 +12,7 @@ edition = "2021" proc-macro2 = "1.0.66" quote = "1.0.33" syn = {version = "2.0.29", features = ["full", "fold"]} + +[features] +default = [] +async = [] diff --git a/src/serial/derive-internal/src/async_derive.rs b/src/serial/derive-internal/src/async_derive.rs new file mode 100644 index 000000000..c208460f4 --- /dev/null +++ b/src/serial/derive-internal/src/async_derive.rs @@ -0,0 +1,412 @@ +/* This file is part of DarkFi (https://dark.fi) + * + * Copyright (C) 2020-2023 Dyne.org foundation + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +//! Derive (de)serialization for enums and structs, see src/serial/derive +use proc_macro2::{Ident, Span, TokenStream}; +use quote::quote; +use syn::{ + Fields, FieldsNamed, FieldsUnnamed, Index, ItemEnum, ItemStruct, WhereClause, WherePredicate, +}; + +use super::{contains_initialize_with, contains_skip, discriminant_map, VariantParts}; + +fn named_fields( + cratename: &Ident, + enum_ident: &Ident, + variant_ident: &Ident, + discriminant_value: &TokenStream, + fields: &FieldsNamed, +) -> syn::Result { + let mut where_predicates: Vec = vec![]; + let mut variant_header = TokenStream::new(); + let mut variant_body = TokenStream::new(); + + for field in &fields.named { + if !contains_skip(&field.attrs) { + let field_ident = field.ident.clone().unwrap(); + + variant_header.extend(quote! { #field_ident, }); + + let field_type = &field.ty; + where_predicates.push( + syn::parse2(quote! { + #field_type: #cratename::AsyncEncodable + }) + .unwrap(), + ); + + variant_body.extend(quote! { + len += #field_ident.encode_async(s).await?; + }) + } + } + + // `..` pattern matching works even if all fields were specified + variant_header = quote! { { #variant_header .. }}; + let variant_idx_body = quote!( + #enum_ident::#variant_ident { .. } => #discriminant_value, + ); + + Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body }) +} + +fn unnamed_fields( + cratename: &Ident, + enum_ident: &Ident, + variant_ident: &Ident, + discriminant_value: &TokenStream, + fields: &FieldsUnnamed, +) -> syn::Result { + let mut where_predicates: Vec = vec![]; + let mut variant_header = TokenStream::new(); + let mut variant_body = TokenStream::new(); + + for (field_idx, field) in fields.unnamed.iter().enumerate() { + let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported"); + if contains_skip(&field.attrs) { + let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::mixed_site()); + variant_header.extend(quote! { #field_ident, }); + } else { + let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::mixed_site()); + variant_header.extend(quote! { #field_ident, }); + + let field_type = &field.ty; + where_predicates.push( + syn::parse2(quote! { + #field_type: #cratename::AsyncEncodable + }) + .unwrap(), + ); + + variant_body.extend(quote! { + len += #field_ident.encode_async(s).await?; + }) + } + } + + variant_header = quote! { ( #variant_header )}; + let variant_idx_body = quote!( + #enum_ident::#variant_ident(..) => #discriminant_value, + ); + + Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body }) +} + +pub fn async_enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result { + let enum_ident = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + let mut all_variants_idx_body = TokenStream::new(); + let mut fields_body = TokenStream::new(); + let discriminants = discriminant_map(&input.variants); + + for variant in input.variants.iter() { + let variant_ident = &variant.ident; + let discriminant_value = discriminants.get(variant_ident).unwrap(); + let VariantParts { where_predicates, variant_header, variant_body, variant_idx_body } = + match &variant.fields { + Fields::Named(fields) => { + named_fields(&cratename, enum_ident, variant_ident, discriminant_value, fields)? + } + Fields::Unnamed(fields) => unnamed_fields( + &cratename, + enum_ident, + variant_ident, + discriminant_value, + fields, + )?, + Fields::Unit => { + let variant_idx_body = quote!( + #enum_ident::#variant_ident => #discriminant_value, + ); + VariantParts { + where_predicates: vec![], + variant_header: TokenStream::new(), + variant_body: TokenStream::new(), + variant_idx_body, + } + } + }; + where_predicates.into_iter().for_each(|predicate| where_clause.predicates.push(predicate)); + all_variants_idx_body.extend(variant_idx_body); + fields_body.extend(quote!( + #enum_ident::#variant_ident #variant_header => { + #variant_body + } + )) + } + + Ok(quote! { + #[async_trait] + impl #impl_generics #cratename::AsyncEncodable for #enum_ident #ty_generics #where_clause { + async fn encode_async(&self, s: &mut S) -> ::std::io::Result { + let variant_idx: u8 = match self { + #all_variants_idx_body + }; + + let mut len = 0; + let bytes = variant_idx.to_le_bytes(); + s.write_all(&bytes).await?; + len += bytes.len(); + + match self { + #fields_body + } + + Ok(len) + } + } + }) +} + +pub fn async_enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + + let init_method = contains_initialize_with(&input.attrs); + let mut variant_arms = TokenStream::new(); + let discriminants = discriminant_map(&input.variants); + + for variant in input.variants.iter() { + let variant_ident = &variant.ident; + let discriminant = discriminants.get(variant_ident).unwrap(); + let mut variant_header = TokenStream::new(); + match &variant.fields { + Fields::Named(fields) => { + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + if contains_skip(&field.attrs) { + variant_header.extend(quote! { + #field_name: Default::default(), + }); + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::AsyncDecodable + }) + .unwrap(), + ); + + variant_header.extend(quote! { + #field_name: #cratename::AsyncDecodable::decode_async(d).await?, + }); + } + } + variant_header = quote! { { #variant_header }}; + } + Fields::Unnamed(fields) => { + for field in fields.unnamed.iter() { + if contains_skip(&field.attrs) { + variant_header.extend(quote! { Default::default(), }); + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::AsyncDecodable + }) + .unwrap(), + ); + + variant_header.extend(quote! { + #cratename::AsyncDecodable::decode_async(d).await?, + }); + } + } + variant_header = quote! { ( #variant_header )}; + } + Fields::Unit => {} + } + variant_arms.extend(quote! { + if variant_tag == #discriminant { #name::#variant_ident #variant_header } else + }); + } + + let init = if let Some(method_ident) = init_method { + quote! { + return_value.#method_ident(); + } + } else { + quote! {} + }; + + Ok(quote! { + #[async_trait] + impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause { + async fn decode_async(d: &mut D) -> ::std::io::Result { + let variant_tag: u8 = #cratename::AsyncDecodable::decode_async(d).await?; + + let mut return_value = + #variant_arms { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Unexpected variant tag: {:?}", variant_tag), + )) + }; + #init + Ok(return_value) + } + } + }) +} + +pub fn async_struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + + let mut body = TokenStream::new(); + + match &input.fields { + Fields::Named(fields) => { + for field in &fields.named { + if contains_skip(&field.attrs) { + continue + } + + let field_name = field.ident.as_ref().unwrap(); + let delta = quote! { + len += self.#field_name.encode_async(s).await?; + }; + body.extend(delta); + + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::AsyncEncodable + }) + .unwrap(), + ); + } + } + Fields::Unnamed(fields) => { + for field_idx in 0..fields.unnamed.len() { + let field_idx = Index { + index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"), + span: Span::call_site(), + }; + let delta = quote! { + len += self.#field_idx.encode_async(s).await?; + }; + body.extend(delta); + } + } + Fields::Unit => {} + } + + Ok(quote! { + #[async_trait] + impl #impl_generics #cratename::AsyncEncodable for #name #ty_generics #where_clause { + async fn encode_async(&self, s: &mut S) -> ::std::io::Result { + let mut len = 0; + #body + Ok(len) + } + } + }) +} + +pub fn async_struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + + let init_method = contains_initialize_with(&input.attrs); + let return_value = match &input.fields { + Fields::Named(fields) => { + let mut body = TokenStream::new(); + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + + let delta = if contains_skip(&field.attrs) { + quote! { + #field_name: Default::default(), + } + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::AsyncDecodable + }) + .unwrap(), + ); + + quote! { + #field_name: #cratename::AsyncDecodable::decode_async(d).await?, + } + }; + body.extend(delta); + } + quote! { + Self { #body } + } + } + Fields::Unnamed(fields) => { + let mut body = TokenStream::new(); + for _ in 0..fields.unnamed.len() { + let delta = quote! { + #cratename::AsyncDecodable::decode_async(d).await?, + }; + body.extend(delta); + } + quote! { + Self( #body ) + } + } + Fields::Unit => { + quote! { + Self {} + } + } + }; + + if let Some(method_ident) = init_method { + Ok(quote! { + #[async_trait] + impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause { + async fn decode_async(d: &mut D) -> ::std::io::Result { + let mut return_value = #return_value; + return_value.#method_ident(); + Ok(return_value) + } + } + }) + } else { + Ok(quote! { + #[async_trait] + impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause { + async fn decode_async(d: &mut D) -> ::std::io::Result { + Ok(#return_value) + } + } + }) + } +} diff --git a/src/serial/derive-internal/src/helpers.rs b/src/serial/derive-internal/src/helpers.rs deleted file mode 100644 index 7aa0d21dc..000000000 --- a/src/serial/derive-internal/src/helpers.rs +++ /dev/null @@ -1,38 +0,0 @@ -/* This file is part of DarkFi (https://dark.fi) - * - * Copyright (C) 2020-2023 Dyne.org foundation - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -use syn::{Attribute, Path}; - -pub fn contains_skip(attrs: &[Attribute]) -> bool { - attrs.iter().any(|attr| attr.path().is_ident("skip_serialize")) -} - -pub fn contains_initialize_with(attrs: &[Attribute]) -> Option { - for attr in attrs.iter() { - if attr.path().is_ident("init_serialize") { - let mut res = None; - let _ = attr.parse_nested_meta(|meta| { - res = Some(meta.path); - Ok(()) - }); - return res - } - } - - None -} diff --git a/src/serial/derive-internal/src/lib.rs b/src/serial/derive-internal/src/lib.rs index 1e231efb5..d390496bd 100644 --- a/src/serial/derive-internal/src/lib.rs +++ b/src/serial/derive-internal/src/lib.rs @@ -19,15 +19,17 @@ //! Derive (de)serialization for enums and structs, see src/serial/derive use std::collections::HashMap; -use proc_macro2::{Ident, Span, TokenStream}; +use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::{ - punctuated::Punctuated, token::Comma, Fields, FieldsNamed, FieldsUnnamed, Index, ItemEnum, - ItemStruct, Variant, WhereClause, WherePredicate, -}; +use syn::{punctuated::Punctuated, token::Comma, Attribute, Path, Variant, WherePredicate}; -mod helpers; -use helpers::{contains_initialize_with, contains_skip}; +mod sync_derive; +pub use sync_derive::{enum_de, enum_ser, struct_de, struct_ser}; + +#[cfg(feature = "async")] +mod async_derive; +#[cfg(feature = "async")] +pub use async_derive::{async_enum_de, async_enum_ser, async_struct_de, async_struct_ser}; struct VariantParts { where_predicates: Vec, @@ -56,383 +58,21 @@ fn discriminant_map(variants: &Punctuated) -> HashMap syn::Result { - let mut where_predicates: Vec = vec![]; - let mut variant_header = TokenStream::new(); - let mut variant_body = TokenStream::new(); +pub fn contains_skip(attrs: &[Attribute]) -> bool { + attrs.iter().any(|attr| attr.path().is_ident("skip_serialize")) +} - for field in &fields.named { - if !contains_skip(&field.attrs) { - let field_ident = field.ident.clone().unwrap(); - - variant_header.extend(quote! { #field_ident, }); - - let field_type = &field.ty; - where_predicates.push( - syn::parse2(quote! { - #field_type: #cratename::Encodable - }) - .unwrap(), - ); - - variant_body.extend(quote! { - len += #field_ident.encode(&mut s)?; - }) +pub fn contains_initialize_with(attrs: &[Attribute]) -> Option { + for attr in attrs.iter() { + if attr.path().is_ident("init_serialize") { + let mut res = None; + let _ = attr.parse_nested_meta(|meta| { + res = Some(meta.path); + Ok(()) + }); + return res } } - // `..` pattern matching works even if all fields were specified - variant_header = quote! { { #variant_header .. }}; - let variant_idx_body = quote!( - #enum_ident::#variant_ident { .. } => #discriminant_value, - ); - - Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body }) -} - -fn unnamed_fields( - cratename: &Ident, - enum_ident: &Ident, - variant_ident: &Ident, - discriminant_value: &TokenStream, - fields: &FieldsUnnamed, -) -> syn::Result { - let mut where_predicates: Vec = vec![]; - let mut variant_header = TokenStream::new(); - let mut variant_body = TokenStream::new(); - - for (field_idx, field) in fields.unnamed.iter().enumerate() { - let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported"); - if contains_skip(&field.attrs) { - let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::mixed_site()); - variant_header.extend(quote! { #field_ident, }); - } else { - let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::mixed_site()); - variant_header.extend(quote! { #field_ident, }); - - let field_type = &field.ty; - where_predicates.push( - syn::parse2(quote! { - #field_type: #cratename::Encodable - }) - .unwrap(), - ); - - variant_body.extend(quote! { - len += #field_ident.encode(&mut s)?; - }) - } - } - - variant_header = quote! { ( #variant_header )}; - let variant_idx_body = quote!( - #enum_ident::#variant_ident(..) => #discriminant_value, - ); - - Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body }) -} - -pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result { - let enum_ident = &input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let mut where_clause = where_clause.map_or_else( - || WhereClause { where_token: Default::default(), predicates: Default::default() }, - Clone::clone, - ); - let mut all_variants_idx_body = TokenStream::new(); - let mut fields_body = TokenStream::new(); - let discriminants = discriminant_map(&input.variants); - - for variant in input.variants.iter() { - let variant_ident = &variant.ident; - let discriminant_value = discriminants.get(variant_ident).unwrap(); - let VariantParts { where_predicates, variant_header, variant_body, variant_idx_body } = - match &variant.fields { - Fields::Named(fields) => { - named_fields(&cratename, enum_ident, variant_ident, discriminant_value, fields)? - } - Fields::Unnamed(fields) => unnamed_fields( - &cratename, - enum_ident, - variant_ident, - discriminant_value, - fields, - )?, - Fields::Unit => { - let variant_idx_body = quote!( - #enum_ident::#variant_ident => #discriminant_value, - ); - VariantParts { - where_predicates: vec![], - variant_header: TokenStream::new(), - variant_body: TokenStream::new(), - variant_idx_body, - } - } - }; - where_predicates.into_iter().for_each(|predicate| where_clause.predicates.push(predicate)); - all_variants_idx_body.extend(variant_idx_body); - fields_body.extend(quote!( - #enum_ident::#variant_ident #variant_header => { - #variant_body - } - )) - } - - Ok(quote! { - impl #impl_generics #cratename::Encodable for #enum_ident #ty_generics #where_clause { - fn encode(&self, mut s: S) -> ::core::result::Result { - let variant_idx: u8 = match self { - #all_variants_idx_body - }; - - let mut len = 0; - let bytes = variant_idx.to_le_bytes(); - s.write_all(&bytes)?; - len += bytes.len(); - - match self { - #fields_body - } - - Ok(len) - } - } - }) -} - -pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result { - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let mut where_clause = where_clause.map_or_else( - || WhereClause { where_token: Default::default(), predicates: Default::default() }, - Clone::clone, - ); - - let init_method = contains_initialize_with(&input.attrs); - let mut variant_arms = TokenStream::new(); - let discriminants = discriminant_map(&input.variants); - - for variant in input.variants.iter() { - let variant_ident = &variant.ident; - let discriminant = discriminants.get(variant_ident).unwrap(); - let mut variant_header = TokenStream::new(); - match &variant.fields { - Fields::Named(fields) => { - for field in &fields.named { - let field_name = field.ident.as_ref().unwrap(); - if contains_skip(&field.attrs) { - variant_header.extend(quote! { - #field_name: Default::default(), - }); - } else { - let field_type = &field.ty; - where_clause.predicates.push( - syn::parse2(quote! { - #field_type: #cratename::Decodable - }) - .unwrap(), - ); - - variant_header.extend(quote! { - #field_name: #cratename::Decodable::decode(&mut d)?, - }); - } - } - variant_header = quote! { { #variant_header }}; - } - Fields::Unnamed(fields) => { - for field in fields.unnamed.iter() { - if contains_skip(&field.attrs) { - variant_header.extend(quote! { Default::default(), }); - } else { - let field_type = &field.ty; - where_clause.predicates.push( - syn::parse2(quote! { - #field_type: #cratename::Decodable - }) - .unwrap(), - ); - - variant_header.extend(quote! { - #cratename::Decodable::decode(&mut d)?, - }); - } - } - variant_header = quote! { ( #variant_header )}; - } - Fields::Unit => {} - } - variant_arms.extend(quote! { - if variant_tag == #discriminant { #name::#variant_ident #variant_header } else - }); - } - - let init = if let Some(method_ident) = init_method { - quote! { - return_value.#method_ident(); - } - } else { - quote! {} - }; - - Ok(quote! { - impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause { - fn decode(mut d: D) -> ::core::result::Result { - let variant_tag: u8 = #cratename::Decodable::decode(&mut d)?; - - let mut return_value = - #variant_arms { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Unexpected variant tag: {:?}", variant_tag), - )) - }; - #init - Ok(return_value) - } - } - }) -} - -pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let mut where_clause = where_clause.map_or_else( - || WhereClause { where_token: Default::default(), predicates: Default::default() }, - Clone::clone, - ); - - let mut body = TokenStream::new(); - - match &input.fields { - Fields::Named(fields) => { - for field in &fields.named { - if contains_skip(&field.attrs) { - continue - } - - let field_name = field.ident.as_ref().unwrap(); - let delta = quote! { - len += self.#field_name.encode(&mut s)?; - }; - body.extend(delta); - - let field_type = &field.ty; - where_clause.predicates.push( - syn::parse2(quote! { - #field_type: #cratename::Encodable - }) - .unwrap(), - ); - } - } - Fields::Unnamed(fields) => { - for field_idx in 0..fields.unnamed.len() { - let field_idx = Index { - index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"), - span: Span::call_site(), - }; - let delta = quote! { - len += self.#field_idx.encode(&mut s)?; - }; - body.extend(delta); - } - } - Fields::Unit => {} - } - - Ok(quote! { - impl #impl_generics #cratename::Encodable for #name #ty_generics #where_clause { - fn encode(&self, mut s: S) -> ::core::result::Result { - let mut len = 0; - #body - Ok(len) - } - } - }) -} - -pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result { - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let mut where_clause = where_clause.map_or_else( - || WhereClause { where_token: Default::default(), predicates: Default::default() }, - Clone::clone, - ); - - let init_method = contains_initialize_with(&input.attrs); - let return_value = match &input.fields { - Fields::Named(fields) => { - let mut body = TokenStream::new(); - for field in &fields.named { - let field_name = field.ident.as_ref().unwrap(); - - let delta = if contains_skip(&field.attrs) { - quote! { - #field_name: Default::default(), - } - } else { - let field_type = &field.ty; - where_clause.predicates.push( - syn::parse2(quote! { - #field_type: #cratename::Decodable - }) - .unwrap(), - ); - - quote! { - #field_name: #cratename::Decodable::decode(&mut d)?, - } - }; - body.extend(delta); - } - quote! { - Self { #body } - } - } - Fields::Unnamed(fields) => { - let mut body = TokenStream::new(); - for _ in 0..fields.unnamed.len() { - let delta = quote! { - #cratename::Decodable::decode(&mut d)?, - }; - body.extend(delta); - } - quote! { - Self( #body ) - } - } - Fields::Unit => { - quote! { - Self {} - } - } - }; - - if let Some(method_ident) = init_method { - Ok(quote! { - impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause { - fn decode(mut d: D) -> ::core::result::Result { - let mut return_value = #return_value; - return_value.#method_ident(); - Ok(return_value) - } - } - }) - } else { - Ok(quote! { - impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause { - fn decode(mut d: D) -> ::core::result::Result { - Ok(#return_value) - } - } - }) - } + None } diff --git a/src/serial/derive-internal/src/sync_derive.rs b/src/serial/derive-internal/src/sync_derive.rs new file mode 100644 index 000000000..eca8fe4a0 --- /dev/null +++ b/src/serial/derive-internal/src/sync_derive.rs @@ -0,0 +1,407 @@ +/* This file is part of DarkFi (https://dark.fi) + * + * Copyright (C) 2020-2023 Dyne.org foundation + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +//! Derive (de)serialization for enums and structs, see src/serial/derive +use proc_macro2::{Ident, Span, TokenStream}; +use quote::quote; +use syn::{ + Fields, FieldsNamed, FieldsUnnamed, Index, ItemEnum, ItemStruct, WhereClause, WherePredicate, +}; + +use super::{contains_initialize_with, contains_skip, discriminant_map, VariantParts}; + +fn named_fields( + cratename: &Ident, + enum_ident: &Ident, + variant_ident: &Ident, + discriminant_value: &TokenStream, + fields: &FieldsNamed, +) -> syn::Result { + let mut where_predicates: Vec = vec![]; + let mut variant_header = TokenStream::new(); + let mut variant_body = TokenStream::new(); + + for field in &fields.named { + if !contains_skip(&field.attrs) { + let field_ident = field.ident.clone().unwrap(); + + variant_header.extend(quote! { #field_ident, }); + + let field_type = &field.ty; + where_predicates.push( + syn::parse2(quote! { + #field_type: #cratename::Encodable + }) + .unwrap(), + ); + + variant_body.extend(quote! { + len += #field_ident.encode(&mut s)?; + }) + } + } + + // `..` pattern matching works even if all fields were specified + variant_header = quote! { { #variant_header .. }}; + let variant_idx_body = quote!( + #enum_ident::#variant_ident { .. } => #discriminant_value, + ); + + Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body }) +} + +fn unnamed_fields( + cratename: &Ident, + enum_ident: &Ident, + variant_ident: &Ident, + discriminant_value: &TokenStream, + fields: &FieldsUnnamed, +) -> syn::Result { + let mut where_predicates: Vec = vec![]; + let mut variant_header = TokenStream::new(); + let mut variant_body = TokenStream::new(); + + for (field_idx, field) in fields.unnamed.iter().enumerate() { + let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported"); + if contains_skip(&field.attrs) { + let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::mixed_site()); + variant_header.extend(quote! { #field_ident, }); + } else { + let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::mixed_site()); + variant_header.extend(quote! { #field_ident, }); + + let field_type = &field.ty; + where_predicates.push( + syn::parse2(quote! { + #field_type: #cratename::Encodable + }) + .unwrap(), + ); + + variant_body.extend(quote! { + len += #field_ident.encode(&mut s)?; + }) + } + } + + variant_header = quote! { ( #variant_header )}; + let variant_idx_body = quote!( + #enum_ident::#variant_ident(..) => #discriminant_value, + ); + + Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body }) +} + +pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result { + let enum_ident = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + let mut all_variants_idx_body = TokenStream::new(); + let mut fields_body = TokenStream::new(); + let discriminants = discriminant_map(&input.variants); + + for variant in input.variants.iter() { + let variant_ident = &variant.ident; + let discriminant_value = discriminants.get(variant_ident).unwrap(); + let VariantParts { where_predicates, variant_header, variant_body, variant_idx_body } = + match &variant.fields { + Fields::Named(fields) => { + named_fields(&cratename, enum_ident, variant_ident, discriminant_value, fields)? + } + Fields::Unnamed(fields) => unnamed_fields( + &cratename, + enum_ident, + variant_ident, + discriminant_value, + fields, + )?, + Fields::Unit => { + let variant_idx_body = quote!( + #enum_ident::#variant_ident => #discriminant_value, + ); + VariantParts { + where_predicates: vec![], + variant_header: TokenStream::new(), + variant_body: TokenStream::new(), + variant_idx_body, + } + } + }; + where_predicates.into_iter().for_each(|predicate| where_clause.predicates.push(predicate)); + all_variants_idx_body.extend(variant_idx_body); + fields_body.extend(quote!( + #enum_ident::#variant_ident #variant_header => { + #variant_body + } + )) + } + + Ok(quote! { + impl #impl_generics #cratename::Encodable for #enum_ident #ty_generics #where_clause { + fn encode(&self, mut s: S) -> ::core::result::Result { + let variant_idx: u8 = match self { + #all_variants_idx_body + }; + + let mut len = 0; + let bytes = variant_idx.to_le_bytes(); + s.write_all(&bytes)?; + len += bytes.len(); + + match self { + #fields_body + } + + Ok(len) + } + } + }) +} + +pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + + let init_method = contains_initialize_with(&input.attrs); + let mut variant_arms = TokenStream::new(); + let discriminants = discriminant_map(&input.variants); + + for variant in input.variants.iter() { + let variant_ident = &variant.ident; + let discriminant = discriminants.get(variant_ident).unwrap(); + let mut variant_header = TokenStream::new(); + match &variant.fields { + Fields::Named(fields) => { + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + if contains_skip(&field.attrs) { + variant_header.extend(quote! { + #field_name: Default::default(), + }); + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::Decodable + }) + .unwrap(), + ); + + variant_header.extend(quote! { + #field_name: #cratename::Decodable::decode(&mut d)?, + }); + } + } + variant_header = quote! { { #variant_header }}; + } + Fields::Unnamed(fields) => { + for field in fields.unnamed.iter() { + if contains_skip(&field.attrs) { + variant_header.extend(quote! { Default::default(), }); + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::Decodable + }) + .unwrap(), + ); + + variant_header.extend(quote! { + #cratename::Decodable::decode(&mut d)?, + }); + } + } + variant_header = quote! { ( #variant_header )}; + } + Fields::Unit => {} + } + variant_arms.extend(quote! { + if variant_tag == #discriminant { #name::#variant_ident #variant_header } else + }); + } + + let init = if let Some(method_ident) = init_method { + quote! { + return_value.#method_ident(); + } + } else { + quote! {} + }; + + Ok(quote! { + impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause { + fn decode(mut d: D) -> ::core::result::Result { + let variant_tag: u8 = #cratename::Decodable::decode(&mut d)?; + + let mut return_value = + #variant_arms { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Unexpected variant tag: {:?}", variant_tag), + )) + }; + #init + Ok(return_value) + } + } + }) +} + +pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + + let mut body = TokenStream::new(); + + match &input.fields { + Fields::Named(fields) => { + for field in &fields.named { + if contains_skip(&field.attrs) { + continue + } + + let field_name = field.ident.as_ref().unwrap(); + let delta = quote! { + len += self.#field_name.encode(&mut s)?; + }; + body.extend(delta); + + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::Encodable + }) + .unwrap(), + ); + } + } + Fields::Unnamed(fields) => { + for field_idx in 0..fields.unnamed.len() { + let field_idx = Index { + index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"), + span: Span::call_site(), + }; + let delta = quote! { + len += self.#field_idx.encode(&mut s)?; + }; + body.extend(delta); + } + } + Fields::Unit => {} + } + + Ok(quote! { + impl #impl_generics #cratename::Encodable for #name #ty_generics #where_clause { + fn encode(&self, mut s: S) -> ::core::result::Result { + let mut len = 0; + #body + Ok(len) + } + } + }) +} + +pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result { + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut where_clause = where_clause.map_or_else( + || WhereClause { where_token: Default::default(), predicates: Default::default() }, + Clone::clone, + ); + + let init_method = contains_initialize_with(&input.attrs); + let return_value = match &input.fields { + Fields::Named(fields) => { + let mut body = TokenStream::new(); + for field in &fields.named { + let field_name = field.ident.as_ref().unwrap(); + + let delta = if contains_skip(&field.attrs) { + quote! { + #field_name: Default::default(), + } + } else { + let field_type = &field.ty; + where_clause.predicates.push( + syn::parse2(quote! { + #field_type: #cratename::Decodable + }) + .unwrap(), + ); + + quote! { + #field_name: #cratename::Decodable::decode(&mut d)?, + } + }; + body.extend(delta); + } + quote! { + Self { #body } + } + } + Fields::Unnamed(fields) => { + let mut body = TokenStream::new(); + for _ in 0..fields.unnamed.len() { + let delta = quote! { + #cratename::Decodable::decode(&mut d)?, + }; + body.extend(delta); + } + quote! { + Self( #body ) + } + } + Fields::Unit => { + quote! { + Self {} + } + } + }; + + if let Some(method_ident) = init_method { + Ok(quote! { + impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause { + fn decode(mut d: D) -> ::core::result::Result { + let mut return_value = #return_value; + return_value.#method_ident(); + Ok(return_value) + } + } + }) + } else { + Ok(quote! { + impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause { + fn decode(mut d: D) -> ::core::result::Result { + Ok(#return_value) + } + } + }) + } +}