feat(versionable): Handle ?Sized bounds in the proc macro

This commit is contained in:
Nicolas Sarlin
2024-10-07 15:17:34 +02:00
committed by Nicolas Sarlin
parent 51da8fe735
commit 9cc0b9050e
3 changed files with 134 additions and 13 deletions

View File

@@ -7,8 +7,8 @@ use syn::{
use crate::{
add_lifetime_param, add_trait_where_clause, add_where_lifetime_bound_to_generics,
extend_where_clause, parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME,
SERIALIZE_TRAIT_NAME,
extend_where_clause, filter_unsized_bounds, parse_const_str, DESERIALIZE_TRAIT_NAME,
LIFETIME_NAME, SERIALIZE_TRAIT_NAME,
};
/// Generates an impl block for the From trait. This will be:
@@ -114,7 +114,7 @@ pub(crate) trait AssociatedType: Sized {
/// Returns the generics and bounds that should be added to the type
fn type_generics(&self) -> syn::Result<Generics> {
let mut generics = self.orig_type_generics().clone();
let mut generics = filter_unsized_bounds(self.orig_type_generics());
if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind() {
if let Some(lifetime) = opt_lifetime {
add_lifetime_param(&mut generics, lifetime);

View File

@@ -18,9 +18,10 @@ use proc_macro2::Span;
use quote::{quote, ToTokens};
use syn::parse::Parse;
use syn::punctuated::Punctuated;
use syn::token::Plus;
use syn::{
parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
LifetimeParam, Path, TraitBound, Type, TypeParamBound, WhereClause,
LifetimeParam, Path, TraitBound, TraitBoundModifier, Type, TypeParamBound, WhereClause,
};
/// Adds the full path of the current crate name to avoid name clashes in generated code.
@@ -134,6 +135,8 @@ pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
pub fn derive_versionize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let input_generics = filter_unsized_bounds(&input.generics);
let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
&input.attrs
));
@@ -165,26 +168,26 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
// split generics so they can be used inside the generated code
let (_, ty_generics, _) = input.generics.split_for_impl();
let (_, ty_generics, _) = input_generics.split_for_impl();
// Generates the associated types required by the traits
let versioned_type = implementor.versioned_type(&lifetime, &input.generics);
let versioned_owned_type = implementor.versioned_owned_type(&input.generics);
let versioned_type = implementor.versioned_type(&lifetime, &input_generics);
let versioned_owned_type = implementor.versioned_owned_type(&input_generics);
let versioned_type_where_clause =
implementor.versioned_type_where_clause(&lifetime, &input.generics);
implementor.versioned_type_where_clause(&lifetime, &input_generics);
let versioned_owned_type_where_clause =
implementor.versioned_owned_type_where_clause(&input.generics);
implementor.versioned_owned_type_where_clause(&input_generics);
// If the original type has some generics, we need to add bounds on them for
// the traits impl
let versionize_trait_where_clause =
syn_unwrap!(implementor.versionize_trait_where_clause(&input.generics));
syn_unwrap!(implementor.versionize_trait_where_clause(&input_generics));
let versionize_owned_trait_where_clause =
syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input.generics));
syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input_generics));
let unversionize_trait_where_clause =
syn_unwrap!(implementor.unversionize_trait_where_clause(&input.generics));
syn_unwrap!(implementor.unversionize_trait_where_clause(&input_generics));
let trait_impl_generics = input.generics.split_for_impl().0;
let trait_impl_generics = input_generics.split_for_impl().0;
let versionize_body = implementor.versionize_method_body();
let versionize_owned_body = implementor.versionize_owned_method_body();
@@ -431,3 +434,55 @@ fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result
fn parse_const_str<T: Parse>(s: &'static str) -> T {
syn::parse_str(s).expect("Parsing of const string should not fail")
}
/// Remove the '?Sized' bounds from the generics
///
/// The VersionDispatch trait requires that the versioned type is Sized so we have to remove this
/// bound. It means that for a type `MyStruct<T: ?Sized>`, we will only be able to call
/// `.versionize()` when T is Sized.
fn filter_unsized_bounds(generics: &Generics) -> Generics {
let mut generics = generics.clone();
for param in generics.type_params_mut() {
param.bounds = remove_unsized_bound(&param.bounds);
}
if let Some(clause) = &mut generics.where_clause {
for pred in &mut clause.predicates {
match pred {
syn::WherePredicate::Lifetime(_) => {}
syn::WherePredicate::Type(type_predicate) => {
type_predicate.bounds = remove_unsized_bound(&type_predicate.bounds);
}
_ => {}
}
}
}
generics
}
/// Filter the ?Sized bound in a list of bounds
fn remove_unsized_bound(
bounds: &Punctuated<TypeParamBound, Plus>,
) -> Punctuated<TypeParamBound, Plus> {
bounds
.iter()
.filter(|bound| match bound {
TypeParamBound::Trait(trait_bound) => {
if !matches!(trait_bound.modifier, TraitBoundModifier::None) {
if let Some(segment) = trait_bound.path.segments.iter().last() {
if segment.ident == "Sized" {
return false;
}
}
}
true
}
TypeParamBound::Lifetime(_) => true,
TypeParamBound::Verbatim(_) => true,
_ => true,
})
.cloned()
.collect()
}

View File

@@ -0,0 +1,66 @@
use std::convert::Infallible;
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct<T: ?Sized> {
builtin: u32,
attr: T,
}
#[derive(Version)]
struct MyStructV0 {
builtin: u32,
}
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
attr: T::default(),
builtin: self.builtin,
})
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStruct<T>),
}
mod v0 {
use tfhe_versionable::{Versionize, VersionsDispatch};
#[derive(Versionize)]
#[versionize(MyStructVersions)]
pub(super) struct MyStruct {
pub(super) builtin: u32,
}
#[derive(VersionsDispatch)]
#[allow(unused)]
pub(super) enum MyStructVersions {
V0(MyStruct),
}
}
fn main() {
let value = 1234;
let ms = v0::MyStruct { builtin: value };
let serialized = bincode::serialize(&ms.versionize()).unwrap();
let unserialized =
MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
assert_eq!(unserialized.builtin, value);
}
#[test]
fn test() {
main()
}