Files
tfhe-rs/utils/tfhe-versionable-derive/src/versionize_impl.rs
2025-04-16 14:08:48 +02:00

441 lines
17 KiB
Rust

use proc_macro2::Span;
use quote::{quote, ToTokens};
use syn::spanned::Spanned;
use syn::{
parse_quote, Data, GenericArgument, GenericParam, Generics, Ident, Lifetime, Path,
PathArguments, Token, TraitBound, Type, TypeParam, TypePath, WhereClause,
};
use crate::transparent::{TransparentStruct, TransparentStructKind};
use crate::versionize_attribute::{
ClassicVersionizeAttribute, ConversionType, ConvertVersionizeAttribute, VersionizeAttribute,
};
use crate::{
add_lifetime_where_clause, add_trait_where_clause, add_where_lifetime_bound_to_generics,
parse_const_str, DISPATCH_TRAIT_NAME, ERROR_TRAIT_NAME, FROM_TRAIT_NAME, INTO_TRAIT_NAME,
SEND_TRAIT_NAME, STATIC_LIFETIME_NAME, SYNC_TRAIT_NAME, TRY_INTO_TRAIT_NAME,
UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME,
VERSIONIZE_TRAIT_NAME,
};
/// Strategy used to generate the versioning trait implementations of a type, selected from the
/// parsed `#[versionize(...)]` attribute (see `VersionizeImplementor::new`).
pub(crate) enum VersionizeImplementor {
/// The type is versioned directly, through the dispatch enum named in the attribute.
Classic(ClassicVersionizeAttribute),
/// The type is first converted into a target type, and the target is what gets versioned.
Convert(ConvertVersionizeAttribute),
/// The type delegates its versioning to the single inner type it wraps.
Transparent(Box<TransparentStruct>),
}
impl VersionizeImplementor {
pub(crate) fn new(
attributes: VersionizeAttribute,
decla: &Data,
base_span: Span,
) -> syn::Result<Self> {
match attributes {
VersionizeAttribute::Classic(classic) => Ok(Self::Classic(classic)),
VersionizeAttribute::Convert(convert) => Ok(Self::Convert(convert)),
VersionizeAttribute::Transparent => Ok(Self::Transparent(Box::new(
TransparentStruct::new(decla, base_span)?,
))),
}
}
/// Checks if the type should have a "true" Versionize implementation or if the implementation
/// is delegated to another type
pub(crate) fn is_directly_versioned(&self) -> bool {
match self {
Self::Classic(_) => true,
Self::Convert(_) => false,
Self::Transparent(_) => false,
}
}
/// Return the associated type used in the `Versionize` trait: `MyType::Versioned<'vers>`
///
/// If the type is directly versioned, this will be a type generated by the `VersionDispatch`.
///
/// If we have a conversion before the versioning, we re-use the versioned_owned type of the
/// conversion target. The versioned_owned is needed because the conversion will create a new
/// value, so we can't just use a reference.
pub(crate) fn versioned_type(
&self,
lifetime: &Lifetime,
input_generics: &Generics,
) -> proc_macro2::TokenStream {
match self {
Self::Classic(attr) => {
let (_, ty_generics, _) = input_generics.split_for_impl();
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
let dispatch_enum_path = &attr.dispatch_enum;
quote! {
<#dispatch_enum_path #ty_generics as
#dispatch_trait<Self>>::Ref<#lifetime>
}
}
Self::Convert(_) => {
// If we want to apply a conversion before the call to versionize we need to use the
// "owned" alternative of the dispatch enum to be able to store the
// conversion result.
self.versioned_owned_type(input_generics)
}
Self::Transparent(transparent) => {
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
let inner_type = &transparent.inner_type;
quote! { <#inner_type as #versionize_trait>::Versioned<#lifetime>}
}
}
}
/// Return the where clause for `MyType::Versioned<'vers>`. if `MyType` has generics, this means
/// adding a 'vers lifetime bound on them.
pub(crate) fn versioned_type_where_clause(
&self,
lifetime: &Lifetime,
input_generics: &Generics,
) -> Option<WhereClause> {
let mut generics = input_generics.clone();
add_where_lifetime_bound_to_generics(&mut generics, lifetime);
let (_, _, where_clause) = generics.split_for_impl();
where_clause.cloned()
}
/// Return the associated type used in the `VersionizeOwned` trait: `MyType::VersionedOwned`
///
/// If the type is directly versioned, this will be a type generated by the `VersionDispatch`.
///
/// If we have a conversion before the versioning, we re-use the versioned_owned type of the
/// conversion target.
pub(crate) fn versioned_owned_type(
&self,
input_generics: &Generics,
) -> proc_macro2::TokenStream {
let (_, ty_generics, _) = input_generics.split_for_impl();
match self {
Self::Classic(attr) => {
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
let dispatch_enum_path = &attr.dispatch_enum;
quote! {
<#dispatch_enum_path #ty_generics as
#dispatch_trait<Self>>::Owned
}
}
Self::Convert(convert_attr) => {
let convert_type_path = &convert_attr.conversion_target;
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
quote! {
<#convert_type_path as #versionize_owned_trait>::VersionedOwned
}
}
Self::Transparent(transparent) => {
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
let inner_type = &transparent.inner_type;
quote! { <#inner_type as #versionize_owned_trait>::VersionedOwned }
}
}
}
/// Return the where clause for `MyType::VersionedOwned`.
///
/// This is simply the where clause of the input type.
pub(crate) fn versioned_owned_type_where_clause(
&self,
input_generics: &Generics,
) -> Option<WhereClause> {
match self {
Self::Classic(_) => input_generics.split_for_impl().2.cloned(),
Self::Convert(convert_attr) => extract_generics(&convert_attr.conversion_target)
.split_for_impl()
.2
.cloned(),
Self::Transparent(_) => input_generics.split_for_impl().2.cloned(),
}
}
/// Return the where clause needed to implement the Versionize trait.
///
/// This is the same as the one for the VersionizeOwned, with an additional "Clone" bound in the
/// case where we need to perform a conversion before the versioning.
pub(crate) fn versionize_trait_where_clause(
&self,
input_generics: &Generics,
) -> syn::Result<Option<WhereClause>> {
// The base bounds for the owned traits are also used for the ref traits
let mut generics = input_generics.clone();
match self {
VersionizeImplementor::Classic(_) => {
self.versionize_owned_trait_where_clause(&generics)
}
VersionizeImplementor::Convert(_) => {
// The versionize method takes a ref. We need to own the input type in the
// conversion case to apply `From<Input> for Target`. This adds a
// `Clone` bound to have a better error message if the input type is
// not Clone.
add_trait_where_clause(&mut generics, [&parse_quote! { Self }], &["Clone"])?;
self.versionize_owned_trait_where_clause(&generics)
}
VersionizeImplementor::Transparent(transparent) => {
add_trait_where_clause(
&mut generics,
[&transparent.inner_type],
&[VERSIONIZE_TRAIT_NAME],
)?;
Ok(generics.split_for_impl().2.cloned())
}
}
}
/// Return the where clause needed to implement the VersionizeOwned trait.
///
/// If the type is directly versioned, the bound states that the argument points to a valid
/// DispatchEnum for this type. This is done by adding a bound on this argument to
/// `VersionsDisaptch<Self>`.
///
/// If there is a conversion, the target of the conversion should implement `VersionizeOwned`
/// and `From<Self>`.
pub(crate) fn versionize_owned_trait_where_clause(
&self,
input_generics: &Generics,
) -> syn::Result<Option<WhereClause>> {
let mut generics = input_generics.clone();
match self {
Self::Classic(attr) => {
let dispatch_generics = generics.clone();
let dispatch_ty_generics = dispatch_generics.split_for_impl().1;
let dispatch_enum_path = &attr.dispatch_enum;
add_trait_where_clause(
&mut generics,
[&parse_quote!(#dispatch_enum_path #dispatch_ty_generics)],
&[format!("{}<Self>", DISPATCH_TRAIT_NAME,)],
)?;
}
Self::Convert(convert_attr) => {
let convert_type_path = &convert_attr.conversion_target;
add_trait_where_clause(
&mut generics,
[&parse_quote!(#convert_type_path)],
&[
VERSIONIZE_OWNED_TRAIT_NAME,
&format!("{}<Self>", FROM_TRAIT_NAME),
],
)?;
}
Self::Transparent(transparent) => {
add_trait_where_clause(
&mut generics,
[&transparent.inner_type],
&[VERSIONIZE_OWNED_TRAIT_NAME],
)?;
}
}
Ok(generics.split_for_impl().2.cloned())
}
/// Return the where clause for the `Unversionize` trait.
///
/// If the versioning is direct, this is the same bound as the one used for `VersionizeOwned`.
///
/// If there is a conversion, the target of the conversion need to implement `Unversionize` and
/// `Into` or `TryInto<T, E>`, with `E: Error + Send + Sync + 'static`
pub(crate) fn unversionize_trait_where_clause(
&self,
input_generics: &Generics,
) -> syn::Result<Option<WhereClause>> {
match self {
Self::Classic(_) => self.versionize_owned_trait_where_clause(input_generics),
Self::Convert(convert_attr) => {
let mut generics = input_generics.clone();
let convert_type_path = &convert_attr.conversion_target;
let into_trait = match convert_attr.conversion_type {
ConversionType::Direct => format!("{}<Self>", INTO_TRAIT_NAME),
ConversionType::Try => {
// Doing a TryFrom requires that the error
// impl Error + Send + Sync + 'static
let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME);
add_trait_where_clause(
&mut generics,
[&parse_quote!(<#convert_type_path as #try_into_trait<Self>>::Error)],
&[ERROR_TRAIT_NAME, SYNC_TRAIT_NAME, SEND_TRAIT_NAME],
)?;
add_lifetime_where_clause(
&mut generics,
[&parse_quote!(<#convert_type_path as #try_into_trait<Self>>::Error)],
&[STATIC_LIFETIME_NAME],
)?;
format!("{}<Self>", TRY_INTO_TRAIT_NAME)
}
};
add_trait_where_clause(
&mut generics,
[&parse_quote!(#convert_type_path)],
&[
UNVERSIONIZE_TRAIT_NAME,
&format!("{}<Self>", FROM_TRAIT_NAME),
&into_trait,
],
)?;
Ok(generics.split_for_impl().2.cloned())
}
Self::Transparent(transparent) => {
let mut generics = input_generics.clone();
add_trait_where_clause(
&mut generics,
[&transparent.inner_type],
&[UNVERSIONIZE_TRAIT_NAME],
)?;
Ok(generics.split_for_impl().2.cloned())
}
}
}
/// Return the body of the versionize method.
pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream {
match self {
Self::Classic(_) => {
quote! {
self.into()
}
}
Self::Convert(convert_attr) => {
let versionize_owned_trait: TraitBound =
parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
let convert_type_path = with_turbofish(&convert_attr.conversion_target);
quote! {
#versionize_owned_trait::versionize_owned(#convert_type_path::from(self.to_owned()))
}
}
Self::Transparent(transparent) => match &transparent.kind {
TransparentStructKind::NewType => {
quote! {
self.0.versionize()
}
}
TransparentStructKind::SingleField(field_name) => {
quote! {
self.#field_name.versionize()
}
}
},
}
}
/// Return the body of the versionize_owned method.
pub(crate) fn versionize_owned_method_body(&self) -> proc_macro2::TokenStream {
let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
match self {
Self::Classic(_) => {
quote! {
self.into()
}
}
Self::Convert(convert_attr) => {
let convert_type_path = with_turbofish(&convert_attr.conversion_target);
quote! {
#versionize_owned_trait::versionize_owned(#convert_type_path::from(self))
}
}
Self::Transparent(transparent) => match &transparent.kind {
TransparentStructKind::NewType => {
quote! {
self.0.versionize_owned()
}
}
TransparentStructKind::SingleField(field_name) => {
quote! {
self.#field_name.versionize_owned()
}
}
},
}
}
/// Return the body of the unversionize method.
pub(crate) fn unversionize_method_body(&self, arg_name: &Ident) -> proc_macro2::TokenStream {
let error: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME);
match self {
Self::Classic(_) => {
quote! { #arg_name.try_into() }
}
Self::Convert(convert_attr) => {
let target = with_turbofish(&convert_attr.conversion_target);
match convert_attr.conversion_type {
ConversionType::Direct => {
quote! { #target::unversionize(#arg_name).map(|value| value.into()) }
}
ConversionType::Try => {
let target_name = format!("{}", target.to_token_stream());
quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::<Self>::try_into(value)
.map_err(|e| #error::conversion(#target_name, e)))
}
}
}
}
Self::Transparent(transparent) => {
let inner = match &transparent.inner_type {
Type::Path(path) => Type::Path(TypePath {
qself: path.qself.clone(),
path: with_turbofish(&path.path),
}),
inner => inner.clone(),
};
match &transparent.kind {
TransparentStructKind::NewType => {
quote! {
#inner::unversionize(#arg_name).map(Self)
}
}
TransparentStructKind::SingleField(field_name) => {
quote! {
Ok(Self { #field_name: #inner::unversionize(#arg_name)? })
}
}
}
}
}
}
}
/// Returns a copy of `path` whose generic arguments use the expression (turbofish) syntax, so
/// that the path can be used in expression position. Converts `MyStruct<T>` into
/// `MyStruct::<T>`.
///
/// Every segment with angle-bracketed arguments gets a `::` token spanned at those arguments;
/// all other segments are left untouched.
fn with_turbofish(path: &Path) -> Path {
    let mut turbofished = path.clone();
    turbofished
        .segments
        .iter_mut()
        .filter_map(|segment| match &mut segment.arguments {
            PathArguments::AngleBracketed(bracketed) => Some(bracketed),
            _ => None,
        })
        .for_each(|bracketed| bracketed.colon2_token = Some(Token![::](bracketed.span())));
    turbofished
}
/// Builds a `Generics` from the type arguments of the last segment of `path`.
///
/// Only arguments that are plain single-identifier type paths (e.g. the `T` in `Foo<T, u8>`,
/// but also `u8`) become type parameters. Lifetimes, constants and more complex types are
/// skipped. The returned generics never have a where clause.
fn extract_generics(path: &Path) -> Generics {
    let mut generics = Generics::default();

    // Nothing to extract unless the last segment has `<...>` arguments.
    let bracketed_args = match path.segments.last().map(|segment| &segment.arguments) {
        Some(PathArguments::AngleBracketed(bracketed)) => &bracketed.args,
        _ => return generics,
    };

    let single_ident_types = bracketed_args.iter().filter_map(|arg| match arg {
        GenericArgument::Type(Type::Path(type_path)) => type_path.path.get_ident(),
        _ => None,
    });

    for ident in single_ident_types {
        let param = TypeParam::from(ident.clone());
        generics.params.push(GenericParam::Type(param));
    }
    generics
}