diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index e2e0d5e03..3e323caa1 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -16,6 +16,20 @@ pub(crate) const SERDE_ATTR_NAME: &str = "serde"; /// Transparent mode can also be activated using `#[repr(transparent)]` pub(crate) const REPR_ATTR_NAME: &str = "repr"; +/// The generated associated types will only derive Serialize/Deserialize. We should not propagate +/// any attribute from other derive macro (eg: `#[default]`). This is a list of attributes that +/// should be propagated to the newly created type. +pub(crate) const PRESERVED_FIELD_ATTRIBUTE_NAMES: [&str; 4] = [ + // Not all serde attribute might be good to propagate. However, as a first approach we allow + // all of them. This might need some refining later. + "serde", + // cfg and cfg_attr should be propagated because it might not be possible to define the + // associated fields if the feature are not enabled + "cfg", "cfg_attr", + // allow is propagated to avoid adding some warnings that the user wanted to disable + "allow", +]; + /// Represent the parsed `#[versionize(...)]` attribute pub(crate) enum VersionizeAttribute { Classic(ClassicVersionizeAttribute), @@ -341,27 +355,38 @@ pub(crate) fn is_skipped(attributes: &[Attribute]) -> syn::Result { Ok(false) } -/// Replace `#[versionize(skip)]` with `#[serde(skip)]` in an attributes list +/// Replace `#[versionize(skip)]` with `#[serde(skip)]` in an attributes list, and remove attributes +/// from other derived macro pub(crate) fn replace_versionize_skip_with_serde( attributes: &[Attribute], ) -> syn::Result> { attributes .iter() .cloned() - .map(|attr| { + .filter_map(|attr| { if attr.path().is_ident(VERSIONIZE_ATTR_NAME) { let nested = - attr.parse_args_with(Punctuated::::parse_terminated)?; + match attr.parse_args_with(Punctuated::::parse_terminated) { + Ok(nested) => nested, + Err(e) => return Some(Err(e)), + }; for meta in nested.iter() { if let Meta::Path(path) = meta { if path.is_ident("skip") { - return Ok(parse_quote! { #[serde(skip)] }); + return Some(Ok(parse_quote! { #[serde(skip)] })); } } } } - Ok(attr) + + for preserved_attr in PRESERVED_FIELD_ATTRIBUTE_NAMES { + if attr.path().is_ident(preserved_attr) { + return Some(Ok(attr)); + } + } + + None }) .collect() } diff --git a/utils/tfhe-versionable/tests/enum_default.rs b/utils/tfhe-versionable/tests/enum_default.rs new file mode 100644 index 000000000..a61d805ea --- /dev/null +++ b/utils/tfhe-versionable/tests/enum_default.rs @@ -0,0 +1,31 @@ +//! Test an enum that derives Default using the `#[default]` attribute + +use std::io::Cursor; + +use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch}; +#[derive(Default, Debug, PartialEq, Eq, Versionize)] +#[versionize(MyEnumVersions)] +pub enum MyEnum { + Var0, + #[default] + Var1, +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +pub enum MyEnumVersions { + V0(MyEnum), +} + +#[test] +fn test() { + let enu = MyEnum::default(); + + let mut ser = Vec::new(); + ciborium::ser::into_writer(&enu.versionize(), &mut ser).unwrap(); + + let unvers = + MyEnum::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap(); + + assert_eq!(unvers, enu); +}