mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
fix(versionable): compatibility between "convert" and generics
This commit is contained in:
committed by
Nicolas Sarlin
parent
2af4676588
commit
2b14b22820
@@ -6,8 +6,9 @@ use syn::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
add_lifetime_bound, add_trait_where_clause, add_where_lifetime_bound, extend_where_clause,
|
||||
parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME, SERIALIZE_TRAIT_NAME,
|
||||
add_lifetime_param, add_trait_where_clause, add_where_lifetime_bound_to_generics,
|
||||
extend_where_clause, parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME,
|
||||
SERIALIZE_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// Generates an impl block for the From trait. This will be:
|
||||
@@ -116,7 +117,7 @@ pub(crate) trait AssociatedType: Sized {
|
||||
let mut generics = self.orig_type_generics().clone();
|
||||
if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind() {
|
||||
if let Some(lifetime) = opt_lifetime {
|
||||
add_lifetime_bound(&mut generics, lifetime);
|
||||
add_lifetime_param(&mut generics, lifetime);
|
||||
}
|
||||
add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?;
|
||||
} else {
|
||||
@@ -214,8 +215,8 @@ impl<T: AssociatedType> AssociatingTrait<T> {
|
||||
let mut ref_type_generics = self.ref_type.orig_type_generics().clone();
|
||||
// If the original type has some generics, we need to add a lifetime bound on them
|
||||
if let Some(lifetime) = self.ref_type.lifetime() {
|
||||
add_lifetime_bound(&mut ref_type_generics, lifetime);
|
||||
add_where_lifetime_bound(&mut ref_type_generics, lifetime);
|
||||
add_lifetime_param(&mut ref_type_generics, lifetime);
|
||||
add_where_lifetime_bound_to_generics(&mut ref_type_generics, lifetime);
|
||||
}
|
||||
|
||||
let (impl_generics, orig_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
@@ -42,6 +42,13 @@ pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeE
|
||||
|
||||
pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
|
||||
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
|
||||
pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
|
||||
pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";
|
||||
pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
|
||||
pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
|
||||
pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";
|
||||
pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
|
||||
pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";
|
||||
|
||||
use associated::AssociatingTrait;
|
||||
|
||||
@@ -140,47 +147,7 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
Some(impl_version_trait(&input))
|
||||
};
|
||||
|
||||
let dispatch_enum_path = attributes.dispatch_enum();
|
||||
let dispatch_target = attributes.dispatch_target();
|
||||
let input_ident = &input.ident;
|
||||
let mut ref_generics = input.generics.clone();
|
||||
let mut trait_generics = input.generics.clone();
|
||||
let (_, ty_generics, owned_where_clause) = input.generics.split_for_impl();
|
||||
|
||||
// If the original type has some generics, we need to add bounds on them for
|
||||
// the impl
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
add_where_lifetime_bound(&mut ref_generics, &lifetime);
|
||||
|
||||
// The versionize method takes a ref. We need to own the input type in the conversion case
|
||||
// to apply `From<Input> for Target`. This adds a `Clone` bound to have a better error message
|
||||
// if the input type is not Clone.
|
||||
if attributes.needs_conversion() {
|
||||
syn_unwrap!(add_trait_where_clause(
|
||||
&mut trait_generics,
|
||||
[&parse_quote! { Self }],
|
||||
&["Clone"]
|
||||
));
|
||||
};
|
||||
|
||||
let dispatch_generics = if attributes.needs_conversion() {
|
||||
None
|
||||
} else {
|
||||
Some(&ty_generics)
|
||||
};
|
||||
|
||||
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
|
||||
|
||||
syn_unwrap!(add_trait_where_clause(
|
||||
&mut trait_generics,
|
||||
[&parse_quote!(#dispatch_enum_path #dispatch_generics)],
|
||||
&[format!(
|
||||
"{}<{}>",
|
||||
DISPATCH_TRAIT_NAME,
|
||||
dispatch_target.to_token_stream()
|
||||
)]
|
||||
));
|
||||
|
||||
// Parse the name of the traits that we will implement
|
||||
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
|
||||
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
||||
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
|
||||
@@ -188,19 +155,33 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME);
|
||||
let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME);
|
||||
|
||||
// split generics so they can be used inside the generated code
|
||||
let (_, _, ref_where_clause) = ref_generics.split_for_impl();
|
||||
let (trait_impl_generics, _, trait_where_clause) = trait_generics.split_for_impl();
|
||||
let input_ident = &input.ident;
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
|
||||
// If we want to apply a conversion before the call to versionize we need to use the "owned"
|
||||
// alternative of the dispatch enum to be able to store the conversion result.
|
||||
let versioned_type_kind = if attributes.needs_conversion() {
|
||||
quote! { Owned #owned_where_clause }
|
||||
} else {
|
||||
quote! { Ref<#lifetime> #ref_where_clause }
|
||||
};
|
||||
// split generics so they can be used inside the generated code
|
||||
let (_, ty_generics, _) = input.generics.split_for_impl();
|
||||
|
||||
// Generates the associated types required by the traits
|
||||
let versioned_type = attributes.versioned_type(&lifetime, &input.generics);
|
||||
let versioned_owned_type = attributes.versioned_owned_type(&input.generics);
|
||||
let versioned_type_where_clause =
|
||||
attributes.versioned_type_where_clause(&lifetime, &input.generics);
|
||||
let versioned_owned_type_where_clause =
|
||||
attributes.versioned_owned_type_where_clause(&input.generics);
|
||||
|
||||
// If the original type has some generics, we need to add bounds on them for
|
||||
// the traits impl
|
||||
let versionize_trait_where_clause =
|
||||
syn_unwrap!(attributes.versionize_trait_where_clause(&input.generics));
|
||||
let versionize_owned_trait_where_clause =
|
||||
syn_unwrap!(attributes.versionize_owned_trait_where_clause(&input.generics));
|
||||
let unversionize_trait_where_clause =
|
||||
syn_unwrap!(attributes.unversionize_trait_where_clause(&input.generics));
|
||||
|
||||
let trait_impl_generics = input.generics.split_for_impl().0;
|
||||
|
||||
let versionize_body = attributes.versionize_method_body();
|
||||
let versionize_owned_body = attributes.versionize_owned_method_body();
|
||||
let unversionize_arg_name = Ident::new("versioned", Span::call_site());
|
||||
let unversionize_body = attributes.unversionize_method_body(&unversionize_arg_name);
|
||||
let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
@@ -210,11 +191,9 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
|
||||
#[automatically_derived]
|
||||
impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics
|
||||
#trait_where_clause
|
||||
#versionize_trait_where_clause
|
||||
{
|
||||
type Versioned<#lifetime> =
|
||||
<#dispatch_enum_path #dispatch_generics as
|
||||
#dispatch_trait<#dispatch_target>>::#versioned_type_kind;
|
||||
type Versioned<#lifetime> = #versioned_type #versioned_type_where_clause;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
#versionize_body
|
||||
@@ -223,20 +202,18 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
|
||||
#[automatically_derived]
|
||||
impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics
|
||||
#trait_where_clause
|
||||
#versionize_owned_trait_where_clause
|
||||
{
|
||||
type VersionedOwned =
|
||||
<#dispatch_enum_path #dispatch_generics as
|
||||
#dispatch_trait<#dispatch_target>>::Owned #owned_where_clause;
|
||||
type VersionedOwned = #versioned_owned_type #versioned_owned_type_where_clause;
|
||||
|
||||
fn versionize_owned(self) -> Self::VersionedOwned {
|
||||
#versionize_body
|
||||
#versionize_owned_body
|
||||
}
|
||||
}
|
||||
|
||||
#[automatically_derived]
|
||||
impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics
|
||||
#trait_where_clause
|
||||
#unversionize_trait_where_clause
|
||||
{
|
||||
fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
|
||||
#unversionize_body
|
||||
@@ -245,20 +222,21 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
|
||||
#[automatically_derived]
|
||||
impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics
|
||||
#trait_where_clause
|
||||
#versionize_trait_where_clause
|
||||
{
|
||||
type VersionedSlice<#lifetime> = Vec<<Self as #versionize_trait>::Versioned<#lifetime>> #ref_where_clause;
|
||||
type VersionedSlice<#lifetime> = Vec<<Self as #versionize_trait>::Versioned<#lifetime>> #versioned_type_where_clause;
|
||||
|
||||
fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
|
||||
slice.iter().map(|val| #versionize_trait::versionize(val)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[automatically_derived]
|
||||
impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics
|
||||
#trait_where_clause
|
||||
#versionize_owned_trait_where_clause
|
||||
{
|
||||
|
||||
type VersionedVec = Vec<<Self as #versionize_owned_trait>::VersionedOwned> #owned_where_clause;
|
||||
type VersionedVec = Vec<<Self as #versionize_owned_trait>::VersionedOwned> #versioned_owned_type_where_clause;
|
||||
|
||||
fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
|
||||
vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect()
|
||||
@@ -267,7 +245,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
|
||||
#[automatically_derived]
|
||||
impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics
|
||||
#trait_where_clause {
|
||||
#unversionize_trait_where_clause
|
||||
{
|
||||
fn unversionize_vec(versioned: Self::VersionedVec) -> Result<Vec<Self>, #unversionize_error> {
|
||||
versioned
|
||||
.into_iter()
|
||||
@@ -335,7 +314,7 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
|
||||
}
|
||||
|
||||
/// Adds a where clause with a lifetime bound on all the generic types and lifetimes in `generics`
|
||||
fn add_where_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
fn add_where_lifetime_bound_to_generics(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
let mut params = Vec::new();
|
||||
for param in generics.params.iter() {
|
||||
let param_ident = match param {
|
||||
@@ -359,8 +338,8 @@ fn add_where_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a lifetime bound for all the generic types in `generics`
|
||||
fn add_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
/// Adds a new lifetime param with a bound for all the generic types in `generics`
|
||||
fn add_lifetime_param(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
generics
|
||||
.params
|
||||
.push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())));
|
||||
@@ -398,6 +377,27 @@ fn add_trait_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds a "where clause" bound for all the input types with all the input lifetimes
|
||||
fn add_lifetime_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
|
||||
generics: &mut Generics,
|
||||
types: I,
|
||||
lifetimes: &[S],
|
||||
) -> syn::Result<()> {
|
||||
let preds = &mut generics.make_where_clause().predicates;
|
||||
|
||||
if !lifetimes.is_empty() {
|
||||
let bounds: Vec<Lifetime> = lifetimes
|
||||
.iter()
|
||||
.map(|lifetime| syn::parse_str(lifetime.as_ref()))
|
||||
.collect::<syn::Result<_>>()?;
|
||||
for ty in types {
|
||||
preds.push(parse_quote! { #ty: #(#bounds)+* });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extends a where clause with predicates from another one, filtering duplicates
|
||||
fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) {
|
||||
for extend_predicate in &extension_clause.predicates {
|
||||
|
||||
@@ -2,47 +2,146 @@ use proc_macro2::Span;
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::{Attribute, Expr, Ident, Lit, Meta, Path, Token, TraitBound, Type};
|
||||
use syn::{
|
||||
parse_quote, Attribute, Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Lit,
|
||||
Meta, Path, PathArguments, Token, TraitBound, Type, TypeParam, WhereClause,
|
||||
};
|
||||
|
||||
use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME, VERSIONIZE_OWNED_TRAIT_NAME};
|
||||
use crate::{
|
||||
add_lifetime_where_clause, add_trait_where_clause, add_where_lifetime_bound_to_generics,
|
||||
parse_const_str, DISPATCH_TRAIT_NAME, ERROR_TRAIT_NAME, FROM_TRAIT_NAME, INTO_TRAIT_NAME,
|
||||
SEND_TRAIT_NAME, STATIC_LIFETIME_NAME, SYNC_TRAIT_NAME, TRY_INTO_TRAIT_NAME,
|
||||
UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// Name of the attribute used to give arguments to the `Versionize` macro
|
||||
const VERSIONIZE_ATTR_NAME: &str = "versionize";
|
||||
|
||||
pub(crate) struct VersionizeAttribute {
|
||||
pub(crate) struct ClassicVersionizeAttribute {
|
||||
dispatch_enum: Path,
|
||||
from: Option<Path>,
|
||||
try_from: Option<Path>,
|
||||
into: Option<Path>,
|
||||
}
|
||||
|
||||
pub(crate) enum ConversionType {
|
||||
Direct,
|
||||
Try,
|
||||
}
|
||||
|
||||
pub(crate) struct ConvertVersionizeAttribute {
|
||||
conversion_target: Path,
|
||||
conversion_type: ConversionType,
|
||||
}
|
||||
|
||||
pub(crate) enum VersionizeAttribute {
|
||||
Classic(ClassicVersionizeAttribute),
|
||||
Convert(ConvertVersionizeAttribute),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct VersionizeAttributeBuilder {
|
||||
dispatch_enum: Option<Path>,
|
||||
convert: Option<Path>,
|
||||
try_convert: Option<Path>,
|
||||
from: Option<Path>,
|
||||
try_from: Option<Path>,
|
||||
into: Option<Path>,
|
||||
}
|
||||
|
||||
impl VersionizeAttributeBuilder {
|
||||
fn build(self) -> Option<VersionizeAttribute> {
|
||||
// These attributes are mutually exclusive
|
||||
if self.from.is_some() && self.try_from.is_some() {
|
||||
return None;
|
||||
fn build(self, base_span: &Span) -> syn::Result<VersionizeAttribute> {
|
||||
let convert_is_try = self.try_convert.is_some() || self.try_from.is_some();
|
||||
// User should not use `from` and `try_from` at the same time
|
||||
let from_target = match (self.from, self.try_from) {
|
||||
(None, None) => None,
|
||||
(Some(_), Some(try_from)) => {
|
||||
return Err(syn::Error::new(
|
||||
try_from.span(),
|
||||
"'try_from' and 'from' attributes are mutually exclusive",
|
||||
))
|
||||
}
|
||||
(None, Some(try_from)) => Some(try_from),
|
||||
(Some(from), None) => Some(from),
|
||||
};
|
||||
|
||||
// Same with `convert`/`try_convert`
|
||||
let convert_target = match (self.convert, self.try_convert) {
|
||||
(None, None) => None,
|
||||
(Some(_), Some(try_convert)) => {
|
||||
return Err(syn::Error::new(
|
||||
try_convert.span(),
|
||||
"'try_convert' and 'convert' attributes are mutually exclusive",
|
||||
))
|
||||
}
|
||||
(None, Some(try_convert)) => Some(try_convert),
|
||||
(Some(convert), None) => Some(convert),
|
||||
};
|
||||
|
||||
// from/into are here for similarity with serde, but we don't actually support having
|
||||
// different target inside. So we check this to warn the user
|
||||
let from_target =
|
||||
match (from_target, self.into) {
|
||||
(None, None) => None,
|
||||
(None, Some(into)) => return Err(syn::Error::new(
|
||||
into.span(),
|
||||
"unidirectional conversions are not handled, please add a 'from'/'try_from' \
|
||||
attribute or use the 'convert'/'try_convert' attribute instead",
|
||||
)),
|
||||
(Some(from), None) => return Err(syn::Error::new(
|
||||
from.span(),
|
||||
"unidirectional conversions are not handled, please add a 'into' attribute or \
|
||||
use the 'convert'/'try_convert' attribute instead",
|
||||
)),
|
||||
(Some(from), Some(into)) => {
|
||||
if format!("{}", from.to_token_stream())
|
||||
!= format!("{}", into.to_token_stream())
|
||||
{
|
||||
return Err(syn::Error::new(
|
||||
from.span(),
|
||||
"unidirectional conversions are not handled, 'from' and 'into' parameters \
|
||||
should have the same value",
|
||||
));
|
||||
} else {
|
||||
Some(from)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Finally, checks that the user doesn't use both from/into and convert
|
||||
let conversion_target = match (from_target, convert_target) {
|
||||
(None, None) => None,
|
||||
(Some(_), Some(convert)) => {
|
||||
return Err(syn::Error::new(
|
||||
convert.span(),
|
||||
"'convert' and 'from'/'into' attributes are mutually exclusive",
|
||||
))
|
||||
}
|
||||
(None, Some(convert)) => Some(convert),
|
||||
(Some(from), None) => Some(from),
|
||||
};
|
||||
|
||||
if let Some(conversion_target) = conversion_target {
|
||||
Ok(VersionizeAttribute::Convert(ConvertVersionizeAttribute {
|
||||
conversion_target,
|
||||
conversion_type: if convert_is_try {
|
||||
ConversionType::Try
|
||||
} else {
|
||||
ConversionType::Direct
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(VersionizeAttribute::Classic(ClassicVersionizeAttribute {
|
||||
dispatch_enum: self.dispatch_enum.ok_or(syn::Error::new(
|
||||
*base_span,
|
||||
"Missing dispatch enum argument",
|
||||
))?,
|
||||
}))
|
||||
}
|
||||
Some(VersionizeAttribute {
|
||||
dispatch_enum: self.dispatch_enum?,
|
||||
from: self.from,
|
||||
try_from: self.try_from,
|
||||
into: self.into,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl VersionizeAttribute {
|
||||
/// Find and parse an attribute with the form `#[versionize(DispatchType)]`, where
|
||||
/// `DispatchType` is the name of the type holding the dispatch enum.
|
||||
/// Returns an error if no `versionize` attribute has been found, if multiple attributes are
|
||||
/// Return an error if no `versionize` attribute has been found, if multiple attributes are
|
||||
/// present on the same struct or if the attribute is malformed.
|
||||
pub(crate) fn parse_from_attributes_list(
|
||||
attributes: &[Attribute],
|
||||
@@ -82,8 +181,24 @@ impl VersionizeAttribute {
|
||||
}
|
||||
}
|
||||
Meta::NameValue(name_value) => {
|
||||
// parse versionize(convert = "TypeConvert")
|
||||
if name_value.path.is_ident("convert") {
|
||||
if attribute_builder.convert.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
attribute_builder.convert =
|
||||
Some(parse_path_ignore_quotes(&name_value.value)?);
|
||||
}
|
||||
// parse versionize(try_convert = "TypeTryConvert")
|
||||
} else if name_value.path.is_ident("try_convert") {
|
||||
if attribute_builder.try_convert.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
attribute_builder.try_convert =
|
||||
Some(parse_path_ignore_quotes(&name_value.value)?);
|
||||
}
|
||||
// parse versionize(from = "TypeFrom")
|
||||
if name_value.path.is_ident("from") {
|
||||
} else if name_value.path.is_ident("from") {
|
||||
if attribute_builder.from.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
@@ -122,60 +237,289 @@ impl VersionizeAttribute {
|
||||
}
|
||||
}
|
||||
|
||||
attribute_builder
|
||||
.build()
|
||||
.ok_or_else(|| Self::default_error(attribute.span()))
|
||||
}
|
||||
|
||||
pub(crate) fn dispatch_enum(&self) -> &Path {
|
||||
&self.dispatch_enum
|
||||
attribute_builder.build(&attribute.span())
|
||||
}
|
||||
|
||||
pub(crate) fn needs_conversion(&self) -> bool {
|
||||
self.try_from.is_some() || self.from.is_some()
|
||||
match self {
|
||||
VersionizeAttribute::Classic(_) => false,
|
||||
VersionizeAttribute::Convert(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn dispatch_target(&self) -> Path {
|
||||
self.from
|
||||
.as_ref()
|
||||
.or(self.try_from.as_ref())
|
||||
.map(|target| target.to_owned())
|
||||
.unwrap_or_else(|| {
|
||||
syn::parse_str("Self").expect("Parsing of const value should never fail")
|
||||
})
|
||||
/// Return the associated type used in the `Versionize` trait: `MyType::Versioned<'vers>`
|
||||
///
|
||||
/// If the type is directly versioned, this will be a type generated by the `VersionDispatch`.
|
||||
///
|
||||
/// If we have a conversion before the versioning, we re-use the versioned_owned type of the
|
||||
/// conversion target. The versioned_owned is needed because the conversion will create a new
|
||||
/// value, so we can't just use a reference.
|
||||
pub(crate) fn versioned_type(
|
||||
&self,
|
||||
lifetime: &Lifetime,
|
||||
input_generics: &Generics,
|
||||
) -> proc_macro2::TokenStream {
|
||||
match self {
|
||||
VersionizeAttribute::Classic(attr) => {
|
||||
let (_, ty_generics, _) = input_generics.split_for_impl();
|
||||
|
||||
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
|
||||
let dispatch_enum_path = &attr.dispatch_enum;
|
||||
quote! {
|
||||
<#dispatch_enum_path #ty_generics as
|
||||
#dispatch_trait<Self>>::Ref<#lifetime>
|
||||
}
|
||||
}
|
||||
VersionizeAttribute::Convert(_) => {
|
||||
// If we want to apply a conversion before the call to versionize we need to use the
|
||||
// "owned" alternative of the dispatch enum to be able to store the
|
||||
// conversion result.
|
||||
self.versioned_owned_type(input_generics)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the where clause for `MyType::Versioned<'vers>`. if `MyType` has generics, this means
|
||||
/// adding a 'vers lifetime bound on them.
|
||||
pub(crate) fn versioned_type_where_clause(
|
||||
&self,
|
||||
lifetime: &Lifetime,
|
||||
input_generics: &Generics,
|
||||
) -> Option<WhereClause> {
|
||||
let mut generics = input_generics.clone();
|
||||
|
||||
add_where_lifetime_bound_to_generics(&mut generics, lifetime);
|
||||
let (_, _, where_clause) = generics.split_for_impl();
|
||||
where_clause.cloned()
|
||||
}
|
||||
|
||||
/// Return the associated type used in the `VersionizeOwned` trait: `MyType::VersionedOwned`
|
||||
///
|
||||
/// If the type is directly versioned, this will be a type generated by the `VersionDispatch`.
|
||||
///
|
||||
/// If we have a conversion before the versioning, we re-use the versioned_owned type of the
|
||||
/// conversion target.
|
||||
pub(crate) fn versioned_owned_type(
|
||||
&self,
|
||||
input_generics: &Generics,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let (_, ty_generics, _) = input_generics.split_for_impl();
|
||||
match self {
|
||||
VersionizeAttribute::Classic(attr) => {
|
||||
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
|
||||
let dispatch_enum_path = &attr.dispatch_enum;
|
||||
quote! {
|
||||
<#dispatch_enum_path #ty_generics as
|
||||
#dispatch_trait<Self>>::Owned
|
||||
}
|
||||
}
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
let convert_type_path = &convert_attr.conversion_target;
|
||||
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
||||
|
||||
quote! {
|
||||
<#convert_type_path as #versionize_owned_trait>::VersionedOwned
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the where clause for `MyType::VersionedOwned`.
|
||||
///
|
||||
/// This is simply the where clause of the input type.
|
||||
pub(crate) fn versioned_owned_type_where_clause(
|
||||
&self,
|
||||
input_generics: &Generics,
|
||||
) -> Option<WhereClause> {
|
||||
match self {
|
||||
VersionizeAttribute::Classic(_) => input_generics.split_for_impl().2.cloned(),
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
extract_generics(&convert_attr.conversion_target)
|
||||
.split_for_impl()
|
||||
.2
|
||||
.cloned()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the where clause needed to implement the Versionize trait.
|
||||
///
|
||||
/// This is the same as the one for the VersionizeOwned, with an additional "Clone" bound in the
|
||||
/// case where we need to perform a conversion before the versioning.
|
||||
pub(crate) fn versionize_trait_where_clause(
|
||||
&self,
|
||||
input_generics: &Generics,
|
||||
) -> syn::Result<Option<WhereClause>> {
|
||||
// The base bounds for the owned traits are also used for the ref traits
|
||||
let mut generics = input_generics.clone();
|
||||
if self.needs_conversion() {
|
||||
// The versionize method takes a ref. We need to own the input type in the conversion
|
||||
// case to apply `From<Input> for Target`. This adds a `Clone` bound to have
|
||||
// a better error message if the input type is not Clone.
|
||||
add_trait_where_clause(&mut generics, [&parse_quote! { Self }], &["Clone"])?;
|
||||
}
|
||||
|
||||
self.versionize_owned_trait_where_clause(&generics)
|
||||
}
|
||||
|
||||
/// Return the where clause needed to implement the VersionizeOwned trait.
|
||||
///
|
||||
/// If the type is directly versioned, the bound states that the argument points to a valid
|
||||
/// DispatchEnum for this type. This is done by adding a bound on this argument to
|
||||
/// `VersionsDisaptch<Self>`.
|
||||
///
|
||||
/// If there is a conversion, the target of the conversion should implement `VersionizeOwned`
|
||||
/// and `From<Self>`.
|
||||
pub(crate) fn versionize_owned_trait_where_clause(
|
||||
&self,
|
||||
input_generics: &Generics,
|
||||
) -> syn::Result<Option<WhereClause>> {
|
||||
let mut generics = input_generics.clone();
|
||||
match self {
|
||||
VersionizeAttribute::Classic(attr) => {
|
||||
let dispatch_generics = generics.clone();
|
||||
let dispatch_ty_generics = dispatch_generics.split_for_impl().1;
|
||||
let dispatch_enum_path = &attr.dispatch_enum;
|
||||
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
[&parse_quote!(#dispatch_enum_path #dispatch_ty_generics)],
|
||||
&[format!("{}<Self>", DISPATCH_TRAIT_NAME,)],
|
||||
)?;
|
||||
}
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
let convert_type_path = &convert_attr.conversion_target;
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
[&parse_quote!(#convert_type_path)],
|
||||
&[
|
||||
VERSIONIZE_OWNED_TRAIT_NAME,
|
||||
&format!("{}<Self>", FROM_TRAIT_NAME),
|
||||
],
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(generics.split_for_impl().2.cloned())
|
||||
}
|
||||
|
||||
/// Return the where clause for the `Unversionize` trait.
|
||||
///
|
||||
/// If the versioning is direct, this is the same bound as the one used for `VersionizeOwned`.
|
||||
///
|
||||
/// If there is a conversion, the target of the conversion need to implement `Unversionize` and
|
||||
/// `Into` or `TryInto<T, E>`, with `E: Error + Send + Sync + 'static`
|
||||
pub(crate) fn unversionize_trait_where_clause(
|
||||
&self,
|
||||
input_generics: &Generics,
|
||||
) -> syn::Result<Option<WhereClause>> {
|
||||
match self {
|
||||
VersionizeAttribute::Classic(_) => {
|
||||
self.versionize_owned_trait_where_clause(input_generics)
|
||||
}
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
let mut generics = input_generics.clone();
|
||||
let convert_type_path = &convert_attr.conversion_target;
|
||||
let into_trait = match convert_attr.conversion_type {
|
||||
ConversionType::Direct => format!("{}<Self>", INTO_TRAIT_NAME),
|
||||
ConversionType::Try => {
|
||||
// Doing a TryFrom requires that the error
|
||||
// impl Error + Send + Sync + 'static
|
||||
let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME);
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
[&parse_quote!(<#convert_type_path as #try_into_trait<Self>>::Error)],
|
||||
&[ERROR_TRAIT_NAME, SYNC_TRAIT_NAME, SEND_TRAIT_NAME],
|
||||
)?;
|
||||
add_lifetime_where_clause(
|
||||
&mut generics,
|
||||
[&parse_quote!(<#convert_type_path as #try_into_trait<Self>>::Error)],
|
||||
&[STATIC_LIFETIME_NAME],
|
||||
)?;
|
||||
|
||||
format!("{}<Self>", TRY_INTO_TRAIT_NAME)
|
||||
}
|
||||
};
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
[&parse_quote!(#convert_type_path)],
|
||||
&[
|
||||
UNVERSIONIZE_TRAIT_NAME,
|
||||
&format!("{}<Self>", FROM_TRAIT_NAME),
|
||||
&into_trait,
|
||||
],
|
||||
)?;
|
||||
|
||||
Ok(generics.split_for_impl().2.cloned())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the body of the versionize method.
|
||||
pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream {
|
||||
let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
||||
self.into
|
||||
.as_ref()
|
||||
.map(|target| {
|
||||
quote! {
|
||||
#versionize_owned_trait::versionize_owned(#target::from(self.to_owned()))
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
|
||||
match self {
|
||||
VersionizeAttribute::Classic(_) => {
|
||||
quote! {
|
||||
self.into()
|
||||
}
|
||||
})
|
||||
}
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
let convert_type_path = with_turbofish(&convert_attr.conversion_target);
|
||||
quote! {
|
||||
#versionize_owned_trait::versionize_owned(#convert_type_path::from(self.to_owned()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the body of the versionize_owned method.
|
||||
pub(crate) fn versionize_owned_method_body(&self) -> proc_macro2::TokenStream {
|
||||
let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
||||
|
||||
match self {
|
||||
VersionizeAttribute::Classic(_) => {
|
||||
quote! {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
let convert_type_path = with_turbofish(&convert_attr.conversion_target);
|
||||
quote! {
|
||||
#versionize_owned_trait::versionize_owned(#convert_type_path::from(self))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the body of the unversionize method.
|
||||
pub(crate) fn unversionize_method_body(&self, arg_name: &Ident) -> proc_macro2::TokenStream {
|
||||
let error: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
if let Some(target) = &self.from {
|
||||
quote! { #target::unversionize(#arg_name).map(|value| value.into()) }
|
||||
} else if let Some(target) = &self.try_from {
|
||||
let target_name = format!("{}", target.to_token_stream());
|
||||
quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::<Self>::try_into(value)
|
||||
.map_err(|e| #error::conversion(#target_name, e)))
|
||||
match self {
|
||||
VersionizeAttribute::Classic(_) => {
|
||||
quote! { #arg_name.try_into() }
|
||||
}
|
||||
VersionizeAttribute::Convert(convert_attr) => {
|
||||
let target = with_turbofish(&convert_attr.conversion_target);
|
||||
match convert_attr.conversion_type {
|
||||
ConversionType::Direct => {
|
||||
quote! { #target::unversionize(#arg_name).map(|value| value.into()) }
|
||||
}
|
||||
ConversionType::Try => {
|
||||
let target_name = format!("{}", target.to_token_stream());
|
||||
quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::<Self>::try_into(value)
|
||||
.map_err(|e| #error::conversion(#target_name, e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! { #arg_name.try_into() }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allow the user to give type arguments as `#[versionize(MyType)]` as well as
|
||||
/// `#[versionize("MyType")]`
|
||||
fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
|
||||
match &value {
|
||||
Expr::Path(expr_path) => Ok(expr_path.path.clone()),
|
||||
@@ -192,3 +536,37 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the same type but with generics that use the turbofish syntax. Converts
|
||||
/// `MyStruct<T>` into `MyStruct::<T>`
|
||||
fn with_turbofish(path: &Path) -> Path {
|
||||
let mut with_turbo = path.clone();
|
||||
|
||||
for segment in with_turbo.segments.iter_mut() {
|
||||
if let PathArguments::AngleBracketed(generics) = &mut segment.arguments {
|
||||
generics.colon2_token = Some(Token));
|
||||
}
|
||||
}
|
||||
|
||||
with_turbo
|
||||
}
|
||||
|
||||
/// Extract the generics inside a type
|
||||
fn extract_generics(path: &Path) -> Generics {
|
||||
let mut generics = Generics::default();
|
||||
|
||||
if let Some(last_segment) = path.segments.last() {
|
||||
if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
|
||||
for arg in &args.args {
|
||||
if let GenericArgument::Type(Type::Path(type_path)) = arg {
|
||||
if let Some(ident) = type_path.path.get_ident() {
|
||||
let param = TypeParam::from(ident.clone());
|
||||
generics.params.push(GenericParam::Type(param));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
|
||||
@@ -3,48 +3,57 @@
|
||||
use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Clone, Versionize)]
|
||||
#[versionize(SerializableMyStructVersions, from = SerializableMyStruct, into = SerializableMyStruct)]
|
||||
struct MyStruct {
|
||||
// To mimic serde parameters, this can also be expressed as
|
||||
// "#[versionize(from = SerializableMyStruct, into = SerializableMyStruct)]"
|
||||
#[versionize(convert = "SerializableMyStruct<T>")]
|
||||
struct MyStruct<T> {
|
||||
val: u64,
|
||||
generics: T,
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(SerializableMyStructVersions)]
|
||||
struct SerializableMyStruct {
|
||||
struct SerializableMyStruct<T> {
|
||||
high: u32,
|
||||
low: u32,
|
||||
generics: T,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum SerializableMyStructVersions {
|
||||
V0(SerializableMyStruct),
|
||||
enum SerializableMyStructVersions<T> {
|
||||
V0(SerializableMyStruct<T>),
|
||||
}
|
||||
|
||||
impl From<MyStruct> for SerializableMyStruct {
|
||||
fn from(value: MyStruct) -> Self {
|
||||
println!("{}", value.val);
|
||||
impl<T> From<MyStruct<T>> for SerializableMyStruct<T> {
|
||||
fn from(value: MyStruct<T>) -> Self {
|
||||
Self {
|
||||
high: (value.val >> 32) as u32,
|
||||
low: (value.val & 0xffffffff) as u32,
|
||||
generics: value.generics,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SerializableMyStruct> for MyStruct {
|
||||
fn from(value: SerializableMyStruct) -> Self {
|
||||
impl<T> From<SerializableMyStruct<T>> for MyStruct<T> {
|
||||
fn from(value: SerializableMyStruct<T>) -> Self {
|
||||
Self {
|
||||
val: ((value.high as u64) << 32) | (value.low as u64),
|
||||
generics: value.generics,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let stru = MyStruct { val: 37 };
|
||||
let stru = MyStruct {
|
||||
val: 37,
|
||||
generics: 90,
|
||||
};
|
||||
|
||||
let serialized = bincode::serialize(&stru.versionize()).unwrap();
|
||||
|
||||
let stru_decoded = MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
let stru_decoded: MyStruct<i32> =
|
||||
MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(stru.val, stru_decoded.val)
|
||||
}
|
||||
|
||||
51
utils/tfhe-versionable/tests/convert_with_bounds.rs
Normal file
51
utils/tfhe-versionable/tests/convert_with_bounds.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
//! Checks compatibility between the "convert" feature and bounds on the From/Into trait
|
||||
|
||||
use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Clone, Versionize)]
|
||||
#[versionize(try_convert = "SerializableMyStruct")]
|
||||
struct MyStruct<T> {
|
||||
generics: T,
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(SerializableMyStructVersions)]
|
||||
struct SerializableMyStruct {
|
||||
concrete: u64,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum SerializableMyStructVersions {
|
||||
V0(SerializableMyStruct),
|
||||
}
|
||||
|
||||
impl<T: Into<u64>> From<MyStruct<T>> for SerializableMyStruct {
|
||||
fn from(value: MyStruct<T>) -> Self {
|
||||
Self {
|
||||
concrete: value.generics.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TryFrom<u64>> TryFrom<SerializableMyStruct> for MyStruct<T> {
|
||||
fn try_from(value: SerializableMyStruct) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
generics: value.concrete.try_into()?,
|
||||
})
|
||||
}
|
||||
|
||||
type Error = T::Error;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let stru = MyStruct { generics: 90u32 };
|
||||
|
||||
let serialized = bincode::serialize(&stru.versionize()).unwrap();
|
||||
|
||||
let stru_decoded: MyStruct<u32> =
|
||||
MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(stru.generics, stru_decoded.generics)
|
||||
}
|
||||
58
utils/tfhe-versionable/tests/convert_with_generics.rs
Normal file
58
utils/tfhe-versionable/tests/convert_with_generics.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
//! Checks compatibility between the "convert" feature and generics
|
||||
|
||||
use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Clone, Versionize)]
|
||||
#[versionize(convert = "SerializableMyStruct<T>")]
|
||||
struct MyStruct<T> {
|
||||
val: u64,
|
||||
generics: T,
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(SerializableMyStructVersions)]
|
||||
struct SerializableMyStruct<T> {
|
||||
high: u32,
|
||||
low: u32,
|
||||
generics: T,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum SerializableMyStructVersions<T> {
|
||||
V0(SerializableMyStruct<T>),
|
||||
}
|
||||
|
||||
impl<T> From<MyStruct<T>> for SerializableMyStruct<T> {
|
||||
fn from(value: MyStruct<T>) -> Self {
|
||||
Self {
|
||||
high: (value.val >> 32) as u32,
|
||||
low: (value.val & 0xffffffff) as u32,
|
||||
generics: value.generics,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<SerializableMyStruct<T>> for MyStruct<T> {
|
||||
fn from(value: SerializableMyStruct<T>) -> Self {
|
||||
Self {
|
||||
val: ((value.high as u64) << 32) | (value.low as u64),
|
||||
generics: value.generics,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let stru = MyStruct {
|
||||
val: 37,
|
||||
generics: 90,
|
||||
};
|
||||
|
||||
let serialized = bincode::serialize(&stru.versionize()).unwrap();
|
||||
|
||||
let stru_decoded: MyStruct<i32> =
|
||||
MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(stru.val, stru_decoded.val)
|
||||
}
|
||||
Reference in New Issue
Block a user