mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(versionable): Handle ?Sized bounds in the proc macro
This commit is contained in:
committed by
Nicolas Sarlin
parent
51da8fe735
commit
9cc0b9050e
@@ -7,8 +7,8 @@ use syn::{
|
||||
|
||||
use crate::{
|
||||
add_lifetime_param, add_trait_where_clause, add_where_lifetime_bound_to_generics,
|
||||
extend_where_clause, parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME,
|
||||
SERIALIZE_TRAIT_NAME,
|
||||
extend_where_clause, filter_unsized_bounds, parse_const_str, DESERIALIZE_TRAIT_NAME,
|
||||
LIFETIME_NAME, SERIALIZE_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// Generates an impl block for the From trait. This will be:
|
||||
@@ -114,7 +114,7 @@ pub(crate) trait AssociatedType: Sized {
|
||||
|
||||
/// Returns the generics and bounds that should be added to the type
|
||||
fn type_generics(&self) -> syn::Result<Generics> {
|
||||
let mut generics = self.orig_type_generics().clone();
|
||||
let mut generics = filter_unsized_bounds(self.orig_type_generics());
|
||||
if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind() {
|
||||
if let Some(lifetime) = opt_lifetime {
|
||||
add_lifetime_param(&mut generics, lifetime);
|
||||
|
||||
@@ -18,9 +18,10 @@ use proc_macro2::Span;
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::parse::Parse;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::Plus;
|
||||
use syn::{
|
||||
parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
|
||||
LifetimeParam, Path, TraitBound, Type, TypeParamBound, WhereClause,
|
||||
LifetimeParam, Path, TraitBound, TraitBoundModifier, Type, TypeParamBound, WhereClause,
|
||||
};
|
||||
|
||||
/// Adds the full path of the current crate name to avoid name clashes in generated code.
|
||||
@@ -134,6 +135,8 @@ pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
|
||||
pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
|
||||
let input_generics = filter_unsized_bounds(&input.generics);
|
||||
|
||||
let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
|
||||
&input.attrs
|
||||
));
|
||||
@@ -165,26 +168,26 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
|
||||
// split generics so they can be used inside the generated code
|
||||
let (_, ty_generics, _) = input.generics.split_for_impl();
|
||||
let (_, ty_generics, _) = input_generics.split_for_impl();
|
||||
|
||||
// Generates the associated types required by the traits
|
||||
let versioned_type = implementor.versioned_type(&lifetime, &input.generics);
|
||||
let versioned_owned_type = implementor.versioned_owned_type(&input.generics);
|
||||
let versioned_type = implementor.versioned_type(&lifetime, &input_generics);
|
||||
let versioned_owned_type = implementor.versioned_owned_type(&input_generics);
|
||||
let versioned_type_where_clause =
|
||||
implementor.versioned_type_where_clause(&lifetime, &input.generics);
|
||||
implementor.versioned_type_where_clause(&lifetime, &input_generics);
|
||||
let versioned_owned_type_where_clause =
|
||||
implementor.versioned_owned_type_where_clause(&input.generics);
|
||||
implementor.versioned_owned_type_where_clause(&input_generics);
|
||||
|
||||
// If the original type has some generics, we need to add bounds on them for
|
||||
// the traits impl
|
||||
let versionize_trait_where_clause =
|
||||
syn_unwrap!(implementor.versionize_trait_where_clause(&input.generics));
|
||||
syn_unwrap!(implementor.versionize_trait_where_clause(&input_generics));
|
||||
let versionize_owned_trait_where_clause =
|
||||
syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input.generics));
|
||||
syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input_generics));
|
||||
let unversionize_trait_where_clause =
|
||||
syn_unwrap!(implementor.unversionize_trait_where_clause(&input.generics));
|
||||
syn_unwrap!(implementor.unversionize_trait_where_clause(&input_generics));
|
||||
|
||||
let trait_impl_generics = input.generics.split_for_impl().0;
|
||||
let trait_impl_generics = input_generics.split_for_impl().0;
|
||||
|
||||
let versionize_body = implementor.versionize_method_body();
|
||||
let versionize_owned_body = implementor.versionize_owned_method_body();
|
||||
@@ -431,3 +434,55 @@ fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result
|
||||
fn parse_const_str<T: Parse>(s: &'static str) -> T {
|
||||
syn::parse_str(s).expect("Parsing of const string should not fail")
|
||||
}
|
||||
|
||||
/// Remove the '?Sized' bounds from the generics
|
||||
///
|
||||
/// The VersionDispatch trait requires that the versioned type is Sized so we have to remove this
|
||||
/// bound. It means that for a type `MyStruct<T: ?Sized>`, we will only be able to call
|
||||
/// `.versionize()` when T is Sized.
|
||||
fn filter_unsized_bounds(generics: &Generics) -> Generics {
|
||||
let mut generics = generics.clone();
|
||||
|
||||
for param in generics.type_params_mut() {
|
||||
param.bounds = remove_unsized_bound(¶m.bounds);
|
||||
}
|
||||
|
||||
if let Some(clause) = &mut generics.where_clause {
|
||||
for pred in &mut clause.predicates {
|
||||
match pred {
|
||||
syn::WherePredicate::Lifetime(_) => {}
|
||||
syn::WherePredicate::Type(type_predicate) => {
|
||||
type_predicate.bounds = remove_unsized_bound(&type_predicate.bounds);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
|
||||
/// Filter the ?Sized bound in a list of bounds
|
||||
fn remove_unsized_bound(
|
||||
bounds: &Punctuated<TypeParamBound, Plus>,
|
||||
) -> Punctuated<TypeParamBound, Plus> {
|
||||
bounds
|
||||
.iter()
|
||||
.filter(|bound| match bound {
|
||||
TypeParamBound::Trait(trait_bound) => {
|
||||
if !matches!(trait_bound.modifier, TraitBoundModifier::None) {
|
||||
if let Some(segment) = trait_bound.path.segments.iter().last() {
|
||||
if segment.ident == "Sized" {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
TypeParamBound::Lifetime(_) => true,
|
||||
TypeParamBound::Verbatim(_) => true,
|
||||
_ => true,
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
66
utils/tfhe-versionable/tests/unsized.rs
Normal file
66
utils/tfhe-versionable/tests/unsized.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
struct MyStruct<T: ?Sized> {
|
||||
builtin: u32,
|
||||
attr: T,
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
struct MyStructV0 {
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
|
||||
type Error = Infallible;
|
||||
|
||||
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
|
||||
Ok(MyStruct {
|
||||
attr: T::default(),
|
||||
builtin: self.builtin,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum MyStructVersions<T> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct<T>),
|
||||
}
|
||||
|
||||
mod v0 {
|
||||
use tfhe_versionable::{Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub(super) struct MyStruct {
|
||||
pub(super) builtin: u32,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub(super) enum MyStructVersions {
|
||||
V0(MyStruct),
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let value = 1234;
|
||||
let ms = v0::MyStruct { builtin: value };
|
||||
|
||||
let serialized = bincode::serialize(&ms.versionize()).unwrap();
|
||||
|
||||
let unserialized =
|
||||
MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(unserialized.builtin, value);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
main()
|
||||
}
|
||||
Reference in New Issue
Block a user