mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
441 lines
17 KiB
Rust
441 lines
17 KiB
Rust
use proc_macro2::Span;
|
|
use quote::{quote, ToTokens};
|
|
use syn::spanned::Spanned;
|
|
use syn::{
|
|
parse_quote, Data, GenericArgument, GenericParam, Generics, Ident, Lifetime, Path,
|
|
PathArguments, Token, TraitBound, Type, TypeParam, TypePath, WhereClause,
|
|
};
|
|
|
|
use crate::transparent::{TransparentStruct, TransparentStructKind};
|
|
use crate::versionize_attribute::{
|
|
ClassicVersionizeAttribute, ConversionType, ConvertVersionizeAttribute, VersionizeAttribute,
|
|
};
|
|
use crate::{
|
|
add_lifetime_where_clause, add_trait_where_clause, add_where_lifetime_bound_to_generics,
|
|
parse_const_str, DISPATCH_TRAIT_NAME, ERROR_TRAIT_NAME, FROM_TRAIT_NAME, INTO_TRAIT_NAME,
|
|
SEND_TRAIT_NAME, STATIC_LIFETIME_NAME, SYNC_TRAIT_NAME, TRY_INTO_TRAIT_NAME,
|
|
UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME,
|
|
VERSIONIZE_TRAIT_NAME,
|
|
};
|
|
|
|
pub(crate) enum VersionizeImplementor {
|
|
Classic(ClassicVersionizeAttribute),
|
|
Convert(ConvertVersionizeAttribute),
|
|
Transparent(Box<TransparentStruct>),
|
|
}
|
|
|
|
impl VersionizeImplementor {
|
|
pub(crate) fn new(
|
|
attributes: VersionizeAttribute,
|
|
decla: &Data,
|
|
base_span: Span,
|
|
) -> syn::Result<Self> {
|
|
match attributes {
|
|
VersionizeAttribute::Classic(classic) => Ok(Self::Classic(classic)),
|
|
VersionizeAttribute::Convert(convert) => Ok(Self::Convert(convert)),
|
|
VersionizeAttribute::Transparent => Ok(Self::Transparent(Box::new(
|
|
TransparentStruct::new(decla, base_span)?,
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Checks if the type should have a "true" Versionize implementation or if the implementation
|
|
/// is delegated to another type
|
|
pub(crate) fn is_directly_versioned(&self) -> bool {
|
|
match self {
|
|
Self::Classic(_) => true,
|
|
Self::Convert(_) => false,
|
|
Self::Transparent(_) => false,
|
|
}
|
|
}
|
|
|
|
/// Return the associated type used in the `Versionize` trait: `MyType::Versioned<'vers>`
|
|
///
|
|
/// If the type is directly versioned, this will be a type generated by the `VersionDispatch`.
|
|
///
|
|
/// If we have a conversion before the versioning, we re-use the versioned_owned type of the
|
|
/// conversion target. The versioned_owned is needed because the conversion will create a new
|
|
/// value, so we can't just use a reference.
|
|
pub(crate) fn versioned_type(
|
|
&self,
|
|
lifetime: &Lifetime,
|
|
input_generics: &Generics,
|
|
) -> proc_macro2::TokenStream {
|
|
match self {
|
|
Self::Classic(attr) => {
|
|
let (_, ty_generics, _) = input_generics.split_for_impl();
|
|
|
|
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
|
|
let dispatch_enum_path = &attr.dispatch_enum;
|
|
quote! {
|
|
<#dispatch_enum_path #ty_generics as
|
|
#dispatch_trait<Self>>::Ref<#lifetime>
|
|
}
|
|
}
|
|
Self::Convert(_) => {
|
|
// If we want to apply a conversion before the call to versionize we need to use the
|
|
// "owned" alternative of the dispatch enum to be able to store the
|
|
// conversion result.
|
|
self.versioned_owned_type(input_generics)
|
|
}
|
|
Self::Transparent(transparent) => {
|
|
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
|
|
let inner_type = &transparent.inner_type;
|
|
quote! { <#inner_type as #versionize_trait>::Versioned<#lifetime>}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return the where clause for `MyType::Versioned<'vers>`. if `MyType` has generics, this means
|
|
/// adding a 'vers lifetime bound on them.
|
|
pub(crate) fn versioned_type_where_clause(
|
|
&self,
|
|
lifetime: &Lifetime,
|
|
input_generics: &Generics,
|
|
) -> Option<WhereClause> {
|
|
let mut generics = input_generics.clone();
|
|
|
|
add_where_lifetime_bound_to_generics(&mut generics, lifetime);
|
|
let (_, _, where_clause) = generics.split_for_impl();
|
|
where_clause.cloned()
|
|
}
|
|
|
|
/// Return the associated type used in the `VersionizeOwned` trait: `MyType::VersionedOwned`
|
|
///
|
|
/// If the type is directly versioned, this will be a type generated by the `VersionDispatch`.
|
|
///
|
|
/// If we have a conversion before the versioning, we re-use the versioned_owned type of the
|
|
/// conversion target.
|
|
pub(crate) fn versioned_owned_type(
|
|
&self,
|
|
input_generics: &Generics,
|
|
) -> proc_macro2::TokenStream {
|
|
let (_, ty_generics, _) = input_generics.split_for_impl();
|
|
match self {
|
|
Self::Classic(attr) => {
|
|
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
|
|
let dispatch_enum_path = &attr.dispatch_enum;
|
|
quote! {
|
|
<#dispatch_enum_path #ty_generics as
|
|
#dispatch_trait<Self>>::Owned
|
|
}
|
|
}
|
|
Self::Convert(convert_attr) => {
|
|
let convert_type_path = &convert_attr.conversion_target;
|
|
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
|
|
|
quote! {
|
|
<#convert_type_path as #versionize_owned_trait>::VersionedOwned
|
|
}
|
|
}
|
|
Self::Transparent(transparent) => {
|
|
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
|
let inner_type = &transparent.inner_type;
|
|
|
|
quote! { <#inner_type as #versionize_owned_trait>::VersionedOwned }
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return the where clause for `MyType::VersionedOwned`.
|
|
///
|
|
/// This is simply the where clause of the input type.
|
|
pub(crate) fn versioned_owned_type_where_clause(
|
|
&self,
|
|
input_generics: &Generics,
|
|
) -> Option<WhereClause> {
|
|
match self {
|
|
Self::Classic(_) => input_generics.split_for_impl().2.cloned(),
|
|
Self::Convert(convert_attr) => extract_generics(&convert_attr.conversion_target)
|
|
.split_for_impl()
|
|
.2
|
|
.cloned(),
|
|
Self::Transparent(_) => input_generics.split_for_impl().2.cloned(),
|
|
}
|
|
}
|
|
|
|
/// Return the where clause needed to implement the Versionize trait.
|
|
///
|
|
/// This is the same as the one for the VersionizeOwned, with an additional "Clone" bound in the
|
|
/// case where we need to perform a conversion before the versioning.
|
|
pub(crate) fn versionize_trait_where_clause(
|
|
&self,
|
|
input_generics: &Generics,
|
|
) -> syn::Result<Option<WhereClause>> {
|
|
// The base bounds for the owned traits are also used for the ref traits
|
|
let mut generics = input_generics.clone();
|
|
|
|
match self {
|
|
VersionizeImplementor::Classic(_) => {
|
|
self.versionize_owned_trait_where_clause(&generics)
|
|
}
|
|
VersionizeImplementor::Convert(_) => {
|
|
// The versionize method takes a ref. We need to own the input type in the
|
|
// conversion case to apply `From<Input> for Target`. This adds a
|
|
// `Clone` bound to have a better error message if the input type is
|
|
// not Clone.
|
|
add_trait_where_clause(&mut generics, [&parse_quote! { Self }], &["Clone"])?;
|
|
self.versionize_owned_trait_where_clause(&generics)
|
|
}
|
|
VersionizeImplementor::Transparent(transparent) => {
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&transparent.inner_type],
|
|
&[VERSIONIZE_TRAIT_NAME],
|
|
)?;
|
|
Ok(generics.split_for_impl().2.cloned())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return the where clause needed to implement the VersionizeOwned trait.
|
|
///
|
|
/// If the type is directly versioned, the bound states that the argument points to a valid
|
|
/// DispatchEnum for this type. This is done by adding a bound on this argument to
|
|
/// `VersionsDisaptch<Self>`.
|
|
///
|
|
/// If there is a conversion, the target of the conversion should implement `VersionizeOwned`
|
|
/// and `From<Self>`.
|
|
pub(crate) fn versionize_owned_trait_where_clause(
|
|
&self,
|
|
input_generics: &Generics,
|
|
) -> syn::Result<Option<WhereClause>> {
|
|
let mut generics = input_generics.clone();
|
|
match self {
|
|
Self::Classic(attr) => {
|
|
let dispatch_generics = generics.clone();
|
|
let dispatch_ty_generics = dispatch_generics.split_for_impl().1;
|
|
let dispatch_enum_path = &attr.dispatch_enum;
|
|
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&parse_quote!(#dispatch_enum_path #dispatch_ty_generics)],
|
|
&[format!("{}<Self>", DISPATCH_TRAIT_NAME,)],
|
|
)?;
|
|
}
|
|
Self::Convert(convert_attr) => {
|
|
let convert_type_path = &convert_attr.conversion_target;
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&parse_quote!(#convert_type_path)],
|
|
&[
|
|
VERSIONIZE_OWNED_TRAIT_NAME,
|
|
&format!("{}<Self>", FROM_TRAIT_NAME),
|
|
],
|
|
)?;
|
|
}
|
|
Self::Transparent(transparent) => {
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&transparent.inner_type],
|
|
&[VERSIONIZE_OWNED_TRAIT_NAME],
|
|
)?;
|
|
}
|
|
}
|
|
|
|
Ok(generics.split_for_impl().2.cloned())
|
|
}
|
|
|
|
/// Return the where clause for the `Unversionize` trait.
|
|
///
|
|
/// If the versioning is direct, this is the same bound as the one used for `VersionizeOwned`.
|
|
///
|
|
/// If there is a conversion, the target of the conversion need to implement `Unversionize` and
|
|
/// `Into` or `TryInto<T, E>`, with `E: Error + Send + Sync + 'static`
|
|
pub(crate) fn unversionize_trait_where_clause(
|
|
&self,
|
|
input_generics: &Generics,
|
|
) -> syn::Result<Option<WhereClause>> {
|
|
match self {
|
|
Self::Classic(_) => self.versionize_owned_trait_where_clause(input_generics),
|
|
Self::Convert(convert_attr) => {
|
|
let mut generics = input_generics.clone();
|
|
let convert_type_path = &convert_attr.conversion_target;
|
|
let into_trait = match convert_attr.conversion_type {
|
|
ConversionType::Direct => format!("{}<Self>", INTO_TRAIT_NAME),
|
|
ConversionType::Try => {
|
|
// Doing a TryFrom requires that the error
|
|
// impl Error + Send + Sync + 'static
|
|
let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME);
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&parse_quote!(<#convert_type_path as #try_into_trait<Self>>::Error)],
|
|
&[ERROR_TRAIT_NAME, SYNC_TRAIT_NAME, SEND_TRAIT_NAME],
|
|
)?;
|
|
add_lifetime_where_clause(
|
|
&mut generics,
|
|
[&parse_quote!(<#convert_type_path as #try_into_trait<Self>>::Error)],
|
|
&[STATIC_LIFETIME_NAME],
|
|
)?;
|
|
|
|
format!("{}<Self>", TRY_INTO_TRAIT_NAME)
|
|
}
|
|
};
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&parse_quote!(#convert_type_path)],
|
|
&[
|
|
UNVERSIONIZE_TRAIT_NAME,
|
|
&format!("{}<Self>", FROM_TRAIT_NAME),
|
|
&into_trait,
|
|
],
|
|
)?;
|
|
|
|
Ok(generics.split_for_impl().2.cloned())
|
|
}
|
|
Self::Transparent(transparent) => {
|
|
let mut generics = input_generics.clone();
|
|
|
|
add_trait_where_clause(
|
|
&mut generics,
|
|
[&transparent.inner_type],
|
|
&[UNVERSIONIZE_TRAIT_NAME],
|
|
)?;
|
|
Ok(generics.split_for_impl().2.cloned())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return the body of the versionize method.
|
|
pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream {
|
|
match self {
|
|
Self::Classic(_) => {
|
|
quote! {
|
|
self.into()
|
|
}
|
|
}
|
|
Self::Convert(convert_attr) => {
|
|
let versionize_owned_trait: TraitBound =
|
|
parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
|
let convert_type_path = with_turbofish(&convert_attr.conversion_target);
|
|
quote! {
|
|
#versionize_owned_trait::versionize_owned(#convert_type_path::from(self.to_owned()))
|
|
}
|
|
}
|
|
Self::Transparent(transparent) => match &transparent.kind {
|
|
TransparentStructKind::NewType => {
|
|
quote! {
|
|
self.0.versionize()
|
|
}
|
|
}
|
|
TransparentStructKind::SingleField(field_name) => {
|
|
quote! {
|
|
self.#field_name.versionize()
|
|
}
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Return the body of the versionize_owned method.
|
|
pub(crate) fn versionize_owned_method_body(&self) -> proc_macro2::TokenStream {
|
|
let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
|
|
|
match self {
|
|
Self::Classic(_) => {
|
|
quote! {
|
|
self.into()
|
|
}
|
|
}
|
|
Self::Convert(convert_attr) => {
|
|
let convert_type_path = with_turbofish(&convert_attr.conversion_target);
|
|
quote! {
|
|
#versionize_owned_trait::versionize_owned(#convert_type_path::from(self))
|
|
}
|
|
}
|
|
Self::Transparent(transparent) => match &transparent.kind {
|
|
TransparentStructKind::NewType => {
|
|
quote! {
|
|
self.0.versionize_owned()
|
|
}
|
|
}
|
|
TransparentStructKind::SingleField(field_name) => {
|
|
quote! {
|
|
self.#field_name.versionize_owned()
|
|
}
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Return the body of the unversionize method.
|
|
pub(crate) fn unversionize_method_body(&self, arg_name: &Ident) -> proc_macro2::TokenStream {
|
|
let error: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
|
match self {
|
|
Self::Classic(_) => {
|
|
quote! { #arg_name.try_into() }
|
|
}
|
|
Self::Convert(convert_attr) => {
|
|
let target = with_turbofish(&convert_attr.conversion_target);
|
|
match convert_attr.conversion_type {
|
|
ConversionType::Direct => {
|
|
quote! { #target::unversionize(#arg_name).map(|value| value.into()) }
|
|
}
|
|
ConversionType::Try => {
|
|
let target_name = format!("{}", target.to_token_stream());
|
|
quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::<Self>::try_into(value)
|
|
.map_err(|e| #error::conversion(#target_name, e)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Self::Transparent(transparent) => {
|
|
let inner = match &transparent.inner_type {
|
|
Type::Path(path) => Type::Path(TypePath {
|
|
qself: path.qself.clone(),
|
|
path: with_turbofish(&path.path),
|
|
}),
|
|
inner => inner.clone(),
|
|
};
|
|
|
|
match &transparent.kind {
|
|
TransparentStructKind::NewType => {
|
|
quote! {
|
|
#inner::unversionize(#arg_name).map(Self)
|
|
}
|
|
}
|
|
TransparentStructKind::SingleField(field_name) => {
|
|
quote! {
|
|
Ok(Self { #field_name: #inner::unversionize(#arg_name)? })
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return the same type but with generics that use the turbofish syntax. Converts
|
|
/// `MyStruct<T>` into `MyStruct::<T>`
|
|
fn with_turbofish(path: &Path) -> Path {
|
|
let mut with_turbo = path.clone();
|
|
|
|
for segment in with_turbo.segments.iter_mut() {
|
|
if let PathArguments::AngleBracketed(generics) = &mut segment.arguments {
|
|
generics.colon2_token = Some(Token));
|
|
}
|
|
}
|
|
|
|
with_turbo
|
|
}
|
|
|
|
/// Extract the generics inside a type
|
|
fn extract_generics(path: &Path) -> Generics {
|
|
let mut generics = Generics::default();
|
|
|
|
if let Some(last_segment) = path.segments.last() {
|
|
if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
|
|
for arg in &args.args {
|
|
if let GenericArgument::Type(Type::Path(type_path)) = arg {
|
|
if let Some(ident) = type_path.path.get_ident() {
|
|
let param = TypeParam::from(ident.clone());
|
|
generics.params.push(GenericParam::Type(param));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
generics
|
|
}
|