mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(vers): add crate for types versioning/backward compatibility
This commit is contained in:
committed by
Nicolas Sarlin
parent
c227bf4a49
commit
444ebbde57
@@ -7,6 +7,8 @@ members = [
|
||||
"apps/trivium",
|
||||
"concrete-csprng",
|
||||
"backends/tfhe-cuda-backend",
|
||||
"utils/tfhe-versionable",
|
||||
"utils/tfhe-versionable-derive"
|
||||
]
|
||||
|
||||
[profile.bench]
|
||||
|
||||
5
Makefile
5
Makefile
@@ -653,6 +653,11 @@ test_zk_pok: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
-p tfhe-zk-pok
|
||||
|
||||
.PHONY: test_versionable # Run tests for tfhe-versionable subcrate
|
||||
test_versionable: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
-p tfhe-versionable
|
||||
|
||||
.PHONY: doc # Build rust doc
|
||||
doc: install_rs_check_toolchain
|
||||
@# Even though we are not in docs.rs, this allows to "just" build the doc
|
||||
|
||||
@@ -2,13 +2,15 @@ use no_comment::{languages, IntoWithoutComments};
|
||||
use std::collections::HashSet;
|
||||
use std::io::{Error, ErrorKind};
|
||||
|
||||
const FILES_TO_IGNORE: [&str; 3] = [
|
||||
const FILES_TO_IGNORE: [&str; 4] = [
|
||||
// This contains fragments of code that are unrelated to TFHE-rs
|
||||
"tfhe/docs/tutorials/sha256_bool.md",
|
||||
// This contains fragments of code coming from the tutorial that cannot be run as a doctest
|
||||
"tfhe/examples/fhe_strings/README.md",
|
||||
// TODO: This contains code that could be executed as a trivium docstring
|
||||
"apps/trivium/README.md",
|
||||
// TODO: should we test this ?
|
||||
"utils/tfhe-versionable/README.md",
|
||||
];
|
||||
|
||||
pub fn check_tfhe_docs_are_tested() -> Result<(), Error> {
|
||||
|
||||
20
utils/tfhe-versionable-derive/Cargo.toml
Normal file
20
utils/tfhe-versionable-derive/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "tfhe-versionable-derive"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
keywords = ["versioning", "serialization", "encoding", "proc-macro", "derive"]
|
||||
homepage = "https://zama.ai/"
|
||||
documentation = "https://docs.rs/tfhe_versionable_derive"
|
||||
repository = "https://github.com/zama-ai/tfhe-rs"
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "tfhe-versionable-derive: A set of proc macro for easier implementation of the tfhe-versionable traits"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
syn = { version = "2.0", features = ["full"] }
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
271
utils/tfhe-versionable-derive/src/associated.rs
Normal file
271
utils/tfhe-versionable-derive/src/associated.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{
|
||||
parse_quote, DeriveInput, ImplGenerics, Item, ItemImpl, Lifetime, Path, Type, WhereClause,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
add_lifetime_bound, add_trait_bound, add_trait_where_clause, add_where_lifetime_bound,
|
||||
parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME, SERIALIZE_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// Generates an impl block for the From trait. This will be:
|
||||
/// ```
|
||||
/// impl From<Src> for Dest {
|
||||
/// fn from(value: Src) -> Self {
|
||||
/// ...[constructor]...
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub(crate) fn generate_from_trait_impl(
|
||||
src: &Type,
|
||||
dest: &Type,
|
||||
impl_generics: &ImplGenerics,
|
||||
where_clause: Option<&WhereClause>,
|
||||
constructor: &TokenStream,
|
||||
from_variable_name: &str,
|
||||
) -> syn::Result<ItemImpl> {
|
||||
let from_variable = Ident::new(from_variable_name, Span::call_site());
|
||||
Ok(parse_quote! {
|
||||
impl #impl_generics From<#src> for #dest #where_clause {
|
||||
fn from(#from_variable: #src) -> Self {
|
||||
#constructor
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates an impl block for the TryFrom trait. This will be:
|
||||
/// ```
|
||||
/// impl TryFrom<Src> for Dest {
|
||||
/// type Error = ErrorType;
|
||||
/// fn from(value: Src) -> Self {
|
||||
/// ...[constructor]...
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub(crate) fn generate_try_from_trait_impl(
|
||||
src: &Type,
|
||||
dest: &Type,
|
||||
error: &Type,
|
||||
impl_generics: &ImplGenerics,
|
||||
where_clause: Option<&WhereClause>,
|
||||
constructor: &TokenStream,
|
||||
from_variable_name: &str,
|
||||
) -> syn::Result<ItemImpl> {
|
||||
let from_variable = Ident::new(from_variable_name, Span::call_site());
|
||||
Ok(parse_quote! {
|
||||
impl #impl_generics TryFrom<#src> for #dest #where_clause {
|
||||
type Error = #error;
|
||||
fn try_from(#from_variable: #src) -> Result<Self, Self::Error> {
|
||||
#constructor
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// The ownership kind of the data for a associated type.
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum AssociatedTypeKind {
|
||||
/// This version type use references to non-Copy rust underlying built-in types.
|
||||
/// This is used for versioning before serialization. Unit types are considered as ref types
|
||||
/// for trait implementations, but they do not hold a lifetime.
|
||||
Ref(Option<Lifetime>),
|
||||
/// This version type own the non-Copy rust underlying built-in types.
|
||||
/// This is used for unversioning after serialization.
|
||||
Owned,
|
||||
}
|
||||
|
||||
/// A type that will be generated by the proc macro that are used in the versioning/unversioning
|
||||
/// process. We use associated types to avoid to rely on generated names. The two associated types
|
||||
/// used in this proc macro are the [`DispatchType`] and the [`VersionType`].
|
||||
///
|
||||
/// To be able have a more efficient versioning, these types actually come in two versions:
|
||||
/// - A `ref` type, that holds a reference to the underlying data. This is used for faster
|
||||
/// versioning using only references.
|
||||
/// - An owned type, that owns the underlying data. This is used for unversioning. The ownership of
|
||||
/// the data will be transfered during the unversioning process.
|
||||
///
|
||||
/// [`DispatchType`]: crate::dispatch_type::DispatchType
|
||||
/// [`VersionType`]: crate::dispatch_type::VersionType
|
||||
pub(crate) trait AssociatedType: Sized {
|
||||
/// This will create the alternative of the type that holds a reference to the underlying data
|
||||
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self>;
|
||||
/// This will create the alternative of the type that owns the underlying data
|
||||
fn new_owned(orig_type: &DeriveInput) -> syn::Result<Self>;
|
||||
|
||||
/// Generates the type declaration for this type
|
||||
fn generate_type_declaration(&self) -> syn::Result<Item>;
|
||||
|
||||
/// Generates conversion methods between the origin type and the associated type. If the version
|
||||
/// type is a ref, the conversion is `From<&'vers OrigType> for Associated<'vers>` because this
|
||||
/// conversion is used for versioning. If the version type is owned, the conversion is
|
||||
/// `From<XXXAssociatedOwned> for XXX` because the owned type is used for unversioning (where
|
||||
/// Associated should be replaced by [`Version`] or [`Dispatch`].
|
||||
///
|
||||
/// [`Dispatch`]: crate::dispatch_type::DispatchType
|
||||
/// [`Version`]: crate::dispatch_type::VersionType
|
||||
fn generate_conversion(&self) -> syn::Result<Vec<ItemImpl>>;
|
||||
|
||||
/// The lifetime added for this type, if it is a "ref" type. It also returns None if the type is
|
||||
/// a unit type (no data)
|
||||
//fn lifetime(&self) -> Option<&Lifetime>;
|
||||
|
||||
/// The identifier used to name this type
|
||||
fn ident(&self) -> Ident;
|
||||
|
||||
/// The lifetime associated with this type, if it is a "ref" type. It can also be None if the
|
||||
/// ref type holds no data.
|
||||
fn lifetime(&self) -> Option<&Lifetime>;
|
||||
|
||||
/// The types that compose the original type. For example, for a structure, this is the type of
|
||||
/// its attributes
|
||||
fn inner_types(&self) -> syn::Result<Vec<&Type>>;
|
||||
|
||||
/// If the associating trait that uses this type needs a type parameter, this returns it.
|
||||
/// For the `VersionsDispatch` trait this paramter is the name of the currently used version,
|
||||
/// which is the latest variant of the dispatch enum. The `Version` trait does not need a
|
||||
/// parameter.
|
||||
fn as_trait_param(&self) -> Option<syn::Result<&Type>>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ConversionDirection {
|
||||
OrigToAssociated,
|
||||
AssociatedToOrig,
|
||||
}
|
||||
|
||||
/// A trait that is used to hold a category of associated types generated by this proc macro.
|
||||
/// These traits holds the 2 versions of the associated type, the "ref" one and the "owned" one.
|
||||
pub(crate) struct AssociatingTrait<T> {
|
||||
ref_type: T,
|
||||
owned_type: T,
|
||||
orig_type: DeriveInput,
|
||||
trait_path: Path,
|
||||
/// Bounds that should be added to the generics for the impl
|
||||
generics_bounds: Vec<String>,
|
||||
/// Bounds that should be added on the struct attributes
|
||||
attributes_bounds: Vec<String>,
|
||||
}
|
||||
|
||||
impl<T: AssociatedType> AssociatingTrait<T> {
|
||||
pub(crate) fn new(
|
||||
orig_type: &DeriveInput,
|
||||
name: &str,
|
||||
generics_bounds: &[&str],
|
||||
attributes_bounds: &[&str],
|
||||
) -> syn::Result<Self> {
|
||||
let ref_type = T::new_ref(orig_type)?;
|
||||
let owned_type = T::new_owned(orig_type)?;
|
||||
let trait_path = syn::parse_str(name)?;
|
||||
|
||||
let generics_bounds = generics_bounds
|
||||
.iter()
|
||||
.map(|bound| bound.to_string())
|
||||
.collect();
|
||||
|
||||
let attributes_bounds = attributes_bounds
|
||||
.iter()
|
||||
.map(|bound| bound.to_string())
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
ref_type,
|
||||
owned_type,
|
||||
orig_type: orig_type.clone(),
|
||||
trait_path,
|
||||
generics_bounds,
|
||||
attributes_bounds,
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the impl for the associating trait
|
||||
pub(crate) fn generate_impl(&self) -> syn::Result<TokenStream> {
|
||||
let orig_ident = &self.orig_type.ident;
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
|
||||
let ref_ident = self.ref_type.ident();
|
||||
let owned_ident = self.owned_type.ident();
|
||||
|
||||
let mut generics = self.orig_type.generics.clone();
|
||||
|
||||
for bound in &self.generics_bounds {
|
||||
add_trait_bound(&mut generics, bound)?;
|
||||
}
|
||||
|
||||
let trait_param = self.ref_type.as_trait_param().transpose()?;
|
||||
|
||||
let mut ref_type_generics = generics.clone();
|
||||
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
self.ref_type.inner_types()?,
|
||||
&self.attributes_bounds,
|
||||
)?;
|
||||
|
||||
// If the original type has some generics, we need to add a lifetime bound on them
|
||||
if let Some(lifetime) = self.ref_type.lifetime() {
|
||||
add_lifetime_bound(&mut ref_type_generics, lifetime);
|
||||
add_where_lifetime_bound(&mut ref_type_generics, lifetime);
|
||||
}
|
||||
|
||||
let (impl_generics, orig_generics, where_clause) = generics.split_for_impl();
|
||||
let (_, ref_generics, ref_where_clause) = ref_type_generics.split_for_impl();
|
||||
|
||||
let trait_ident = &self.trait_path;
|
||||
|
||||
Ok(quote! {
|
||||
impl #impl_generics #trait_ident<#trait_param> for #orig_ident #orig_generics #where_clause {
|
||||
type Ref<#lifetime> = #ref_ident #ref_generics #ref_where_clause;
|
||||
type Owned = #owned_ident #orig_generics;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn generate_types_declarations(&self) -> syn::Result<TokenStream> {
|
||||
let owned_decla = self.owned_type.generate_type_declaration()?;
|
||||
|
||||
let owned_conversion = self.owned_type.generate_conversion()?;
|
||||
|
||||
let serialize_trait: Path = parse_const_str(SERIALIZE_TRAIT_NAME);
|
||||
let deserialize_trait: Path = parse_const_str(DESERIALIZE_TRAIT_NAME);
|
||||
|
||||
let ignored_lints = quote! {
|
||||
#[allow(
|
||||
// We add bounds on the generated code because it will make the compiler
|
||||
// generate better errors in case of misuse of the macros. However in some cases
|
||||
// this may generate a warning, so we silence it.
|
||||
private_bounds,
|
||||
// If these lints doesn't trigger on the orginal type, we don't want them to trigger
|
||||
// on the generated one
|
||||
clippy::upper_case_acronyms,
|
||||
clippy::large_enum_variant
|
||||
)
|
||||
]};
|
||||
|
||||
let owned_tokens = quote! {
|
||||
#[derive(#serialize_trait, #deserialize_trait)]
|
||||
#ignored_lints
|
||||
#owned_decla
|
||||
|
||||
#(#owned_conversion)*
|
||||
};
|
||||
|
||||
let ref_decla = self.ref_type.generate_type_declaration()?;
|
||||
|
||||
let ref_conversion = self.ref_type.generate_conversion()?;
|
||||
|
||||
let ref_tokens = quote! {
|
||||
#[derive(#serialize_trait)]
|
||||
#ignored_lints
|
||||
#ref_decla
|
||||
|
||||
#(#ref_conversion)*
|
||||
};
|
||||
|
||||
Ok(quote! {
|
||||
#owned_tokens
|
||||
#ref_tokens
|
||||
})
|
||||
}
|
||||
}
|
||||
356
utils/tfhe-versionable-derive/src/dispatch_type.rs
Normal file
356
utils/tfhe-versionable-derive/src/dispatch_type.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::{format_ident, quote};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Data, DeriveInput, Field, Fields, Generics, ItemEnum, ItemImpl, Lifetime, Path,
|
||||
Type, Variant,
|
||||
};
|
||||
|
||||
use crate::associated::{
|
||||
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
|
||||
};
|
||||
use crate::{
|
||||
add_lifetime_bound, add_trait_bound, add_trait_where_clause, parse_const_str, LIFETIME_NAME,
|
||||
UNVERSIONIZE_ERROR_NAME, VERSIONIZE_TRAIT_NAME, VERSION_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// This is the enum that holds all the versions of a specific type. Each variant of the enum is
|
||||
/// a Version of a given type. The users writes the input enum using its own types. The macro
|
||||
/// will generate two types:
|
||||
/// - a `ref` type that uses the `ref` Version equivalent of each variant
|
||||
/// - an owned type, that uses the VersionOwned equivalent of each variant
|
||||
pub(crate) struct DispatchType {
|
||||
orig_type: ItemEnum,
|
||||
kind: AssociatedTypeKind,
|
||||
}
|
||||
|
||||
/// The `VersionsDispatch` macro can only be used on enum. This converts the
|
||||
/// generic `DeriveInput` into an `ItemEnum` or returns an explicit error.
|
||||
fn derive_input_to_enum(input: &DeriveInput) -> syn::Result<ItemEnum> {
|
||||
match &input.data {
|
||||
Data::Enum(enu) => Ok(ItemEnum {
|
||||
attrs: input.attrs.clone(),
|
||||
vis: input.vis.clone(),
|
||||
enum_token: enu.enum_token,
|
||||
ident: input.ident.clone(),
|
||||
generics: input.generics.clone(),
|
||||
brace_token: enu.brace_token,
|
||||
variants: enu.variants.clone(),
|
||||
}),
|
||||
_ => Err(syn::Error::new(
|
||||
input.span(),
|
||||
"VersionsDispatch can only be derived on an enum",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
impl AssociatedType for DispatchType {
|
||||
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self> {
|
||||
for lt in orig_type.generics.lifetimes() {
|
||||
// check for collision with other lifetimes in `orig_type`
|
||||
if lt.lifetime.ident == LIFETIME_NAME {
|
||||
return Err(syn::Error::new(
|
||||
lt.lifetime.span(),
|
||||
format!(
|
||||
"Lifetime name {} conflicts with the one used by macro `Version`",
|
||||
LIFETIME_NAME
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
Ok(Self {
|
||||
orig_type: derive_input_to_enum(orig_type)?,
|
||||
kind: AssociatedTypeKind::Ref(Some(lifetime)),
|
||||
})
|
||||
}
|
||||
|
||||
fn new_owned(orig_type: &DeriveInput) -> syn::Result<Self> {
|
||||
Ok(Self {
|
||||
orig_type: derive_input_to_enum(orig_type)?,
|
||||
kind: AssociatedTypeKind::Owned,
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_type_declaration(&self) -> syn::Result<syn::Item> {
|
||||
let variants: syn::Result<Punctuated<Variant, Comma>> = self
|
||||
.orig_type
|
||||
.variants
|
||||
.iter()
|
||||
.map(|variant| {
|
||||
let dispatch_field = self.convert_field(self.variant_field(variant)?);
|
||||
let dispatch_variant = Variant {
|
||||
fields: Fields::Unnamed(parse_quote!((#dispatch_field))),
|
||||
..variant.clone()
|
||||
};
|
||||
|
||||
Ok(dispatch_variant)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ItemEnum {
|
||||
ident: self.ident(),
|
||||
generics: self.generics()?,
|
||||
variants: variants?,
|
||||
..self.orig_type.clone()
|
||||
}
|
||||
.into())
|
||||
}
|
||||
|
||||
fn generate_conversion(&self) -> syn::Result<Vec<ItemImpl>> {
|
||||
let generics = self.generics()?;
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
match &self.kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => {
|
||||
// Wraps the highest version into the dispatch enum
|
||||
let src_type = self.latest_version_type()?;
|
||||
let src = parse_quote! { &#lifetime #src_type };
|
||||
let dest_ident = self.ident();
|
||||
let dest = parse_quote! { #dest_ident #ty_generics };
|
||||
let constructor = self.generate_conversion_constructor_ref("value")?;
|
||||
|
||||
generate_from_trait_impl(
|
||||
&src,
|
||||
&dest,
|
||||
&impl_generics,
|
||||
where_clause,
|
||||
&constructor,
|
||||
"value",
|
||||
)
|
||||
.map(|res| vec![res])
|
||||
}
|
||||
AssociatedTypeKind::Owned => {
|
||||
// Upgrade to the highest version the convert to the main type
|
||||
let src_ident = self.ident();
|
||||
let src = parse_quote! { #src_ident #ty_generics };
|
||||
let dest_type = self.latest_version_type()?;
|
||||
let dest = parse_quote! { #dest_type };
|
||||
let error = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
let constructor = self.generate_conversion_constructor_owned("value")?;
|
||||
|
||||
let assoc_to_orig = generate_try_from_trait_impl(
|
||||
&src,
|
||||
&dest,
|
||||
&error,
|
||||
&impl_generics,
|
||||
where_clause,
|
||||
&constructor,
|
||||
"value",
|
||||
)?;
|
||||
|
||||
// Wraps the highest version into the dispatch enum
|
||||
let src_type = self.latest_version_type()?;
|
||||
let src = parse_quote! { &#src_type };
|
||||
let dest_ident = self.ident();
|
||||
let dest = parse_quote! { #dest_ident #ty_generics };
|
||||
let constructor = self.generate_conversion_constructor_ref("value")?;
|
||||
|
||||
let orig_to_assoc = generate_from_trait_impl(
|
||||
&src,
|
||||
&dest,
|
||||
&impl_generics,
|
||||
where_clause,
|
||||
&constructor,
|
||||
"value",
|
||||
)?;
|
||||
|
||||
Ok(vec![orig_to_assoc, assoc_to_orig])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ident(&self) -> Ident {
|
||||
match &self.kind {
|
||||
AssociatedTypeKind::Ref(_) => {
|
||||
format_ident!("{}Dispatch", self.orig_type.ident)
|
||||
}
|
||||
AssociatedTypeKind::Owned => {
|
||||
format_ident!("{}DispatchOwned", self.orig_type.ident)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lifetime(&self) -> Option<&Lifetime> {
|
||||
match &self.kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => lifetime.as_ref(),
|
||||
AssociatedTypeKind::Owned => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_trait_param(&self) -> Option<syn::Result<&Type>> {
|
||||
Some(self.latest_version_type())
|
||||
}
|
||||
|
||||
fn inner_types(&self) -> syn::Result<Vec<&Type>> {
|
||||
self.version_types()
|
||||
}
|
||||
}
|
||||
|
||||
impl DispatchType {
|
||||
/// Returns the error sent to the user for a wrong use of this macro
|
||||
fn error(&self) -> syn::Error {
|
||||
syn::Error::new(
|
||||
self.orig_type.span(),
|
||||
"VersionsDispatch should be used on a enum with single anonymous field variants",
|
||||
)
|
||||
}
|
||||
|
||||
fn generics(&self) -> syn::Result<Generics> {
|
||||
let mut generics = self.orig_type.generics.clone();
|
||||
if let AssociatedTypeKind::Ref(Some(lifetime)) = &self.kind {
|
||||
add_lifetime_bound(&mut generics, lifetime);
|
||||
}
|
||||
|
||||
add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSION_TRAIT_NAME])?;
|
||||
|
||||
add_trait_bound(&mut generics, VERSIONIZE_TRAIT_NAME)?;
|
||||
|
||||
Ok(generics)
|
||||
}
|
||||
|
||||
/// Returns the number of versions in this dispatch enum
|
||||
fn versions_count(&self) -> usize {
|
||||
self.orig_type.variants.len()
|
||||
}
|
||||
|
||||
/// Returns the latest version of the original type, which is the last variant in the enum
|
||||
fn latest_version(&self) -> syn::Result<&Variant> {
|
||||
self.orig_type.variants.last().ok_or_else(|| self.error())
|
||||
}
|
||||
|
||||
fn version_types(&self) -> syn::Result<Vec<&Type>> {
|
||||
self.orig_type
|
||||
.variants
|
||||
.iter()
|
||||
.map(|variant| self.variant_field(variant))
|
||||
.map(|field_opt| field_opt.map(|field| &field.ty))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns the type of the version at index `idx`
|
||||
fn version_type_at(&self, idx: usize) -> syn::Result<&Type> {
|
||||
self.variant_at(idx)
|
||||
.and_then(|variant| self.variant_field(variant))
|
||||
.map(|field| &field.ty)
|
||||
}
|
||||
|
||||
/// Returns the variant at index `idx`
|
||||
fn variant_at(&self, idx: usize) -> syn::Result<&Variant> {
|
||||
self.orig_type
|
||||
.variants
|
||||
.iter()
|
||||
.nth(idx)
|
||||
.ok_or_else(|| self.error())
|
||||
}
|
||||
|
||||
/// Returns the type of the latest version of the original type
|
||||
fn latest_version_type(&self) -> syn::Result<&Type> {
|
||||
self.latest_version()
|
||||
.and_then(|variant| self.variant_field(variant))
|
||||
.map(|field| &field.ty)
|
||||
}
|
||||
|
||||
/// Returns the field inside a specific variant of the enum. Checks that this variant contains
|
||||
/// only one unnamed field.
|
||||
fn variant_field<'a>(&'a self, variant: &'a Variant) -> syn::Result<&'a Field> {
|
||||
match &variant.fields {
|
||||
// Check that the variant is of the form `Vn(XXXVersion)`
|
||||
Fields::Named(_) => Err(self.error()),
|
||||
Fields::Unnamed(fields) => {
|
||||
if fields.unnamed.len() != 1 {
|
||||
Err(self.error())
|
||||
} else {
|
||||
// Ok to unwrap because we checked that len is 1
|
||||
Ok(fields.unnamed.first().unwrap())
|
||||
}
|
||||
}
|
||||
Fields::Unit => Err(self.error()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the field of a variant of a dispatch enum into a field that uses
|
||||
/// the `Version` equivalent of the type
|
||||
fn convert_field(&self, field: &Field) -> Field {
|
||||
let orig_ty = field.ty.clone();
|
||||
let version_trait: Path = parse_const_str(VERSION_TRAIT_NAME);
|
||||
|
||||
let ty: Type = match &self.kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
|
||||
<#orig_ty as #version_trait>::Ref<#lifetime>
|
||||
},
|
||||
AssociatedTypeKind::Owned => parse_quote! {
|
||||
<#orig_ty as #version_trait>::Owned
|
||||
},
|
||||
};
|
||||
|
||||
Field {
|
||||
ty,
|
||||
..field.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the conversion from a reference to the original type into the `ref` dispatch
|
||||
/// type. This basically generates code that wrapes the input into the last variant of the enum.
|
||||
fn generate_conversion_constructor_ref(&self, arg_name: &str) -> syn::Result<TokenStream> {
|
||||
let variant_ident = &self.latest_version()?.ident;
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
|
||||
Ok(quote! {
|
||||
Self::#variant_ident(#arg_ident.into())
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates conversion from the `owned` dispatch type to the original type. This generates a
|
||||
/// `match` on the dispatch enum that calls the update method on each version enough times to
|
||||
/// get to the latest version.
|
||||
fn generate_conversion_constructor_owned(&self, arg_name: &str) -> syn::Result<TokenStream> {
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
let error_ty: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
|
||||
let match_cases =
|
||||
self.orig_type
|
||||
.variants
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(src_idx, variant)| -> syn::Result<_> {
|
||||
let last_version = self.versions_count() - 1;
|
||||
let enum_ident = self.ident();
|
||||
let target_type = self.version_type_at(src_idx)?;
|
||||
let variant_ident = &variant.ident;
|
||||
let var_name = format_ident!("v{}", src_idx);
|
||||
|
||||
let upgrades_needed = last_version - src_idx;
|
||||
|
||||
// Add chained calls to the upgrade method, with error handling
|
||||
let upgrades_chain = (0..upgrades_needed).map(|upgrade_idx| {
|
||||
// Here we can unwrap because src_idx + upgrade_idx < version_count or we wouldn't need to upgrade
|
||||
let src_variant = self.variant_at(src_idx + upgrade_idx).unwrap().ident.to_string();
|
||||
let dest_variant = self.variant_at(src_idx + upgrade_idx + 1).unwrap().ident.to_string();
|
||||
quote! {
|
||||
.and_then(|value| {
|
||||
value
|
||||
.upgrade()
|
||||
.map_err(|e|
|
||||
#error_ty::upgrade(#src_variant, #dest_variant, &e)
|
||||
)
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
Ok(quote! {
|
||||
#enum_ident::#variant_ident(#var_name) => TryInto::<#target_type>::try_into(#var_name)
|
||||
#(#upgrades_chain)*
|
||||
})
|
||||
}).collect::<syn::Result<Vec<TokenStream>>>()?;
|
||||
|
||||
Ok(quote! {
|
||||
match #arg_ident {
|
||||
#(#match_cases),*
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
372
utils/tfhe-versionable-derive/src/lib.rs
Normal file
372
utils/tfhe-versionable-derive/src/lib.rs
Normal file
@@ -0,0 +1,372 @@
|
||||
//! Set of derive macro to automatically implement the `Versionize` and `Unversionize` traits.
|
||||
//! The macro defined in this crate are:
|
||||
//! - `Versionize`: should be derived on the main type that is used in your code
|
||||
//! - `Version`: should be derived on a previous version of this type
|
||||
//! - `VersionsDispatch`: should be derived ont the enum that holds all the versions of the type
|
||||
//! - `NotVersioned`: can be used to implement `Versionize` for a type that is not really versioned
|
||||
|
||||
mod associated;
|
||||
mod dispatch_type;
|
||||
mod version_type;
|
||||
mod versionize_attribute;
|
||||
|
||||
use dispatch_type::DispatchType;
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::Span;
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::parse::Parse;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::{
|
||||
parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
|
||||
LifetimeParam, Path, TraitBound, Type, TypeParamBound,
|
||||
};
|
||||
use versionize_attribute::VersionizeAttribute;
|
||||
|
||||
/// Adds the full path of the current crate name to avoid name clashes in generated code.
|
||||
macro_rules! crate_full_path {
|
||||
($trait_name:expr) => {
|
||||
concat!("::tfhe_versionable::", $trait_name)
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) const LIFETIME_NAME: &str = "'vers";
|
||||
pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version");
|
||||
pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch");
|
||||
pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize");
|
||||
pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
|
||||
pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError");
|
||||
|
||||
pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
|
||||
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
|
||||
pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";
|
||||
|
||||
use associated::AssociatingTrait;
|
||||
|
||||
use crate::version_type::VersionType;
|
||||
|
||||
/// unwrap a `syn::Result` by extracting the Ok value or returning from the outer function with
|
||||
/// a compile error
|
||||
macro_rules! syn_unwrap {
|
||||
($e:expr) => {
|
||||
match $e {
|
||||
Ok(res) => res,
|
||||
Err(err) => return err.to_compile_error().into(),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Version)]
|
||||
/// Implement the `Version` trait for the target type.
|
||||
pub fn derive_version(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
|
||||
impl_version_trait(&input).into()
|
||||
}
|
||||
|
||||
/// Actual implementation of the version trait. This will create the ref and owned
|
||||
/// associated types and use them to implement the trait.
|
||||
fn impl_version_trait(input: &DeriveInput) -> proc_macro2::TokenStream {
|
||||
let version_trait = syn_unwrap!(AssociatingTrait::<VersionType>::new(
|
||||
input,
|
||||
VERSION_TRAIT_NAME,
|
||||
&[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME],
|
||||
&[VERSIONIZE_TRAIT_NAME, UNVERSIONIZE_TRAIT_NAME]
|
||||
));
|
||||
|
||||
let version_types = syn_unwrap!(version_trait.generate_types_declarations());
|
||||
|
||||
let version_impl = syn_unwrap!(version_trait.generate_impl());
|
||||
|
||||
quote! {
|
||||
const _: () = {
|
||||
#version_types
|
||||
#version_impl
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Implement the `VersionsDispatch` trait for the target type. The type where this macro is
|
||||
/// applied should be an enum where each variant is a version of the type that we want to
|
||||
/// versionize.
|
||||
#[proc_macro_derive(VersionsDispatch)]
|
||||
pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
|
||||
let dispatch_trait = syn_unwrap!(AssociatingTrait::<DispatchType>::new(
|
||||
&input,
|
||||
DISPATCH_TRAIT_NAME,
|
||||
&[
|
||||
VERSIONIZE_TRAIT_NAME,
|
||||
UNVERSIONIZE_TRAIT_NAME,
|
||||
SERIALIZE_TRAIT_NAME,
|
||||
DESERIALIZE_OWNED_TRAIT_NAME
|
||||
],
|
||||
&[]
|
||||
));
|
||||
|
||||
let dispatch_types = syn_unwrap!(dispatch_trait.generate_types_declarations());
|
||||
|
||||
let dispatch_impl = syn_unwrap!(dispatch_trait.generate_impl());
|
||||
|
||||
quote! {
|
||||
const _: () = {
|
||||
#dispatch_types
|
||||
#dispatch_impl
|
||||
};
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
/// This derives the `Versionize` and `Unversionize` trait for the target type. This macro
|
||||
/// has a mandatory attribute parameter, which is the name of the versioned enum for this type.
|
||||
/// This enum can be anywhere in the code but should be in scope.
|
||||
#[proc_macro_derive(Versionize, attributes(versionize))]
|
||||
pub fn derive_versionize(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
|
||||
let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
|
||||
&input.attrs
|
||||
));
|
||||
|
||||
// If we apply a type conversion before the call to versionize, the type that implements
|
||||
// the `Version` trait is the target type and not Self
|
||||
let version_trait_impl: Option<proc_macro2::TokenStream> = if attributes.needs_conversion() {
|
||||
None
|
||||
} else {
|
||||
Some(impl_version_trait(&input))
|
||||
};
|
||||
|
||||
let dispatch_enum_path = attributes.dispatch_enum();
|
||||
let dispatch_target = attributes.dispatch_target();
|
||||
let input_ident = &input.ident;
|
||||
let mut ref_generics = input.generics.clone();
|
||||
let mut trait_generics = input.generics.clone();
|
||||
let (_, ty_generics, owned_where_clause) = input.generics.split_for_impl();
|
||||
|
||||
// If the original type has some generics, we need to add bounds on them for
|
||||
// the impl
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
add_where_lifetime_bound(&mut ref_generics, &lifetime);
|
||||
|
||||
// The versionize method takes a ref. We need to own the input type in the conversion case
|
||||
// to apply `From<Input> for Target`. This adds a `Clone` bound to have a better error message
|
||||
// if the input type is not Clone.
|
||||
if attributes.needs_conversion() {
|
||||
syn_unwrap!(add_trait_where_clause(
|
||||
&mut trait_generics,
|
||||
[&parse_quote! { Self }],
|
||||
&["Clone"]
|
||||
));
|
||||
};
|
||||
|
||||
let dispatch_generics = if attributes.needs_conversion() {
|
||||
None
|
||||
} else {
|
||||
Some(&ty_generics)
|
||||
};
|
||||
|
||||
let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME);
|
||||
|
||||
syn_unwrap!(add_trait_where_clause(
|
||||
&mut trait_generics,
|
||||
[&parse_quote!(#dispatch_enum_path #dispatch_generics)],
|
||||
&[format!(
|
||||
"{}<{}>",
|
||||
DISPATCH_TRAIT_NAME,
|
||||
dispatch_target.to_token_stream()
|
||||
)]
|
||||
));
|
||||
|
||||
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
|
||||
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
|
||||
|
||||
let (_, _, ref_where_clause) = ref_generics.split_for_impl();
|
||||
let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
|
||||
|
||||
// If we want to apply a conversion before the call to versionize we need to use the "owned"
|
||||
// alternative of the dispatch enum to be able to store the conversion result.
|
||||
let versioned_type_kind = if attributes.needs_conversion() {
|
||||
quote! { Owned #owned_where_clause }
|
||||
} else {
|
||||
quote! { Ref<#lifetime> #ref_where_clause }
|
||||
};
|
||||
|
||||
let versionize_body = attributes.versionize_method_body();
|
||||
let unversionize_arg_name = Ident::new("versioned", Span::call_site());
|
||||
let unversionize_body = attributes.unversionize_method_body(&unversionize_arg_name);
|
||||
let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
|
||||
quote! {
|
||||
#version_trait_impl
|
||||
|
||||
impl #impl_generics #versionize_trait for #input_ident #ty_generics
|
||||
#where_clause
|
||||
{
|
||||
type Versioned<#lifetime> =
|
||||
<#dispatch_enum_path #dispatch_generics as
|
||||
#dispatch_trait<#dispatch_target>>::#versioned_type_kind;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
#versionize_body
|
||||
}
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
#versionize_body
|
||||
}
|
||||
|
||||
type VersionedOwned =
|
||||
<#dispatch_enum_path #dispatch_generics as
|
||||
#dispatch_trait<#dispatch_target>>::Owned #owned_where_clause;
|
||||
|
||||
}
|
||||
|
||||
impl #impl_generics #unversionize_trait for #input_ident #ty_generics
|
||||
#where_clause
|
||||
{
|
||||
fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
|
||||
#unversionize_body
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
/// This derives the `Versionize` and `Unversionize` trait for a type that should not
|
||||
/// be versioned. The `versionize` method will simply return self
|
||||
#[proc_macro_derive(NotVersioned)]
|
||||
pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
|
||||
let mut generics = input.generics.clone();
|
||||
syn_unwrap!(add_trait_where_clause(
|
||||
&mut generics,
|
||||
&[parse_quote! { Self }],
|
||||
&["Clone"]
|
||||
));
|
||||
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
let input_ident = &input.ident;
|
||||
|
||||
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
|
||||
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
|
||||
let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
|
||||
|
||||
quote! {
|
||||
impl #impl_generics #versionize_trait for #input_ident #ty_generics #where_clause {
|
||||
type Versioned<#lifetime> = &#lifetime Self;
|
||||
type VersionedOwned = Self;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self
|
||||
}
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl #impl_generics #unversionize_trait for #input_ident #ty_generics #where_clause {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
|
||||
Ok(versioned)
|
||||
}
|
||||
}
|
||||
|
||||
impl NotVersioned for #input_ident #ty_generics #where_clause {}
|
||||
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
/// Adds a where clause with a lifetime bound on all the generic types and lifetimes in `generics`
|
||||
fn add_where_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
let mut params = Vec::new();
|
||||
for param in generics.params.iter() {
|
||||
let param_ident = match param {
|
||||
GenericParam::Lifetime(generic_lifetime) => {
|
||||
if generic_lifetime.lifetime.ident == lifetime.ident {
|
||||
continue;
|
||||
}
|
||||
&generic_lifetime.lifetime.ident
|
||||
}
|
||||
GenericParam::Type(generic_type) => &generic_type.ident,
|
||||
GenericParam::Const(_) => continue,
|
||||
};
|
||||
params.push(param_ident.clone());
|
||||
}
|
||||
|
||||
for param in params.iter() {
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote! { #param: #lifetime });
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a lifetime bound for all the generic types in `generics`
|
||||
fn add_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) {
|
||||
generics
|
||||
.params
|
||||
.push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())));
|
||||
for param in generics.type_params_mut() {
|
||||
param
|
||||
.bounds
|
||||
.push(TypeParamBound::Lifetime(lifetime.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the input str trait bound
|
||||
fn parse_trait_bound(trait_name: &str) -> syn::Result<TraitBound> {
|
||||
let trait_path: Path = syn::parse_str(trait_name)?;
|
||||
Ok(parse_quote!(#trait_path))
|
||||
}
|
||||
|
||||
/// Adds a trait bound for `trait_name` on all the generic types in `generics`
|
||||
fn add_trait_bound(generics: &mut Generics, trait_name: &str) -> syn::Result<()> {
|
||||
let trait_bound: TraitBound = parse_trait_bound(trait_name)?;
|
||||
for param in generics.type_params_mut() {
|
||||
param
|
||||
.bounds
|
||||
.push(TypeParamBound::Trait(trait_bound.clone()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds a "where clause" bound for all the input types with all the input traits
|
||||
fn add_trait_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
|
||||
generics: &mut Generics,
|
||||
types: I,
|
||||
traits_name: &[S],
|
||||
) -> syn::Result<()> {
|
||||
let preds = &mut generics.make_where_clause().predicates;
|
||||
|
||||
if !traits_name.is_empty() {
|
||||
let bounds: Vec<TraitBound> = traits_name
|
||||
.iter()
|
||||
.map(|bound_name| parse_trait_bound(bound_name.as_ref()))
|
||||
.collect::<syn::Result<_>>()?;
|
||||
for ty in types {
|
||||
preds.push(parse_quote! { #ty: #(#bounds)+* });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates a Result [`syn::punctuated::Punctuated`] from an iterator of Results
|
||||
fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result<T>>>(
|
||||
iter: I,
|
||||
) -> syn::Result<Punctuated<T, P>> {
|
||||
let mut ret = Punctuated::new();
|
||||
for value in iter {
|
||||
ret.push(value?)
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
/// Like [`syn::parse_str`] for inputs that are known at compile time to be valid
|
||||
fn parse_const_str<T: Parse>(s: &'static str) -> T {
|
||||
syn::parse_str(s).expect("Parsing of const string should not fail")
|
||||
}
|
||||
702
utils/tfhe-versionable-derive/src/version_type.rs
Normal file
702
utils/tfhe-versionable-derive/src/version_type.rs
Normal file
@@ -0,0 +1,702 @@
|
||||
use std::iter::zip;
|
||||
|
||||
use proc_macro2::{Literal, Span, TokenStream};
|
||||
use quote::{format_ident, quote};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime,
|
||||
Path, Type, Variant,
|
||||
};
|
||||
|
||||
use crate::associated::{
|
||||
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
|
||||
ConversionDirection,
|
||||
};
|
||||
use crate::{
|
||||
add_lifetime_bound, add_trait_where_clause, parse_const_str, parse_trait_bound,
|
||||
punctuated_from_iter_result, LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME,
|
||||
VERSIONIZE_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// The types generated for a specific version of a given exposed type. These types are identical to
|
||||
/// the user written version types except that their subtypes are replaced by their "Versioned"
|
||||
/// form. This allows recursive versionning.
|
||||
pub(crate) struct VersionType {
|
||||
orig_type: DeriveInput,
|
||||
kind: AssociatedTypeKind,
|
||||
}
|
||||
|
||||
impl AssociatedType for VersionType {
|
||||
fn new_ref(orig_type: &DeriveInput) -> syn::Result<VersionType> {
|
||||
let lifetime = if is_unit(orig_type) {
|
||||
None
|
||||
} else {
|
||||
for lt in orig_type.generics.lifetimes() {
|
||||
// check for collision with other lifetimes in `orig_type`
|
||||
if lt.lifetime.ident == LIFETIME_NAME {
|
||||
return Err(syn::Error::new(
|
||||
lt.lifetime.span(),
|
||||
format!(
|
||||
"Lifetime name {} conflicts with the one used by macro `Version`",
|
||||
LIFETIME_NAME
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
Some(Lifetime::new(LIFETIME_NAME, Span::call_site()))
|
||||
};
|
||||
Ok(Self {
|
||||
orig_type: orig_type.clone(),
|
||||
kind: AssociatedTypeKind::Ref(lifetime),
|
||||
})
|
||||
}
|
||||
|
||||
fn new_owned(orig_type: &DeriveInput) -> syn::Result<Self> {
|
||||
Ok(Self {
|
||||
orig_type: orig_type.clone(),
|
||||
kind: AssociatedTypeKind::Owned,
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_type_declaration(&self) -> syn::Result<Item> {
|
||||
match &self.orig_type.data {
|
||||
Data::Struct(stru) => self.generate_struct(stru).map(Item::Struct),
|
||||
Data::Enum(enu) => self.generate_enum(enu).map(Item::Enum),
|
||||
Data::Union(uni) => self.generate_union(uni).map(Item::Union),
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_conversion(&self) -> syn::Result<Vec<ItemImpl>> {
|
||||
let (_, orig_generics, _) = self.orig_type.generics.split_for_impl();
|
||||
|
||||
match &self.kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => {
|
||||
// Convert from `&'vers XXX` into `XXXVersion<'vers>`
|
||||
let generics = self.conversion_generics(ConversionDirection::OrigToAssociated)?;
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
let src_ident = self.orig_type.ident.clone();
|
||||
let src = lifetime
|
||||
.as_ref()
|
||||
.map(|lifetime| parse_quote! { &#lifetime #src_ident #orig_generics })
|
||||
.unwrap_or_else(|| parse_quote! { &#src_ident #orig_generics });
|
||||
let dest_ident = self.ident();
|
||||
let dest = parse_quote! { #dest_ident #ty_generics };
|
||||
let constructor = self.generate_conversion_constructor(
|
||||
"value",
|
||||
&src_ident,
|
||||
ConversionDirection::OrigToAssociated,
|
||||
)?;
|
||||
|
||||
generate_from_trait_impl(
|
||||
&src,
|
||||
&dest,
|
||||
&impl_generics,
|
||||
where_clause,
|
||||
&constructor,
|
||||
"value",
|
||||
)
|
||||
.map(|res| vec![res])
|
||||
}
|
||||
AssociatedTypeKind::Owned => {
|
||||
// Convert from `XXXVersionOwned` into `XXX`
|
||||
let generics = self.conversion_generics(ConversionDirection::AssociatedToOrig)?;
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
let src_ident = self.ident();
|
||||
let src = parse_quote! { #src_ident #ty_generics };
|
||||
let dest_ident = self.orig_type.ident.clone();
|
||||
let dest = parse_quote! { #dest_ident #orig_generics };
|
||||
let error = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
let constructor = self.generate_conversion_constructor(
|
||||
"value",
|
||||
&src_ident,
|
||||
ConversionDirection::AssociatedToOrig,
|
||||
)?;
|
||||
|
||||
let assoc_to_orig = generate_try_from_trait_impl(
|
||||
&src,
|
||||
&dest,
|
||||
&error,
|
||||
&impl_generics,
|
||||
where_clause,
|
||||
&constructor,
|
||||
"value",
|
||||
)?;
|
||||
|
||||
// Convert from `&XXX` into `XXXVersionOwned`
|
||||
let generics = self.conversion_generics(ConversionDirection::OrigToAssociated)?;
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
let src_ident = self.orig_type.ident.clone();
|
||||
let src = parse_quote! { &#src_ident #orig_generics };
|
||||
let dest_ident = self.ident();
|
||||
let dest = parse_quote! { #dest_ident #ty_generics };
|
||||
let constructor = self.generate_conversion_constructor(
|
||||
"value",
|
||||
&src_ident,
|
||||
ConversionDirection::OrigToAssociated,
|
||||
)?;
|
||||
|
||||
let orig_to_assoc = generate_from_trait_impl(
|
||||
&src,
|
||||
&dest,
|
||||
&impl_generics,
|
||||
where_clause,
|
||||
&constructor,
|
||||
"value",
|
||||
)?;
|
||||
|
||||
Ok(vec![assoc_to_orig, orig_to_assoc])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ident(&self) -> Ident {
|
||||
match &self.kind {
|
||||
AssociatedTypeKind::Ref(_) => {
|
||||
format_ident!("{}Version", self.orig_type.ident)
|
||||
}
|
||||
AssociatedTypeKind::Owned => {
|
||||
format_ident!("{}VersionOwned", self.orig_type.ident)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lifetime(&self) -> Option<&Lifetime> {
|
||||
match &self.kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => lifetime.as_ref(),
|
||||
AssociatedTypeKind::Owned => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_trait_param(&self) -> Option<syn::Result<&Type>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn inner_types(&self) -> syn::Result<Vec<&Type>> {
|
||||
self.orig_type_fields()
|
||||
.iter()
|
||||
.map(|field| Ok(&field.ty))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl VersionType {
|
||||
/// Returns the fields of the original declaration.
|
||||
fn orig_type_fields(&self) -> Punctuated<&Field, Comma> {
|
||||
derive_type_fields(&self.orig_type)
|
||||
}
|
||||
|
||||
fn type_generics(&self) -> syn::Result<Generics> {
|
||||
let mut generics = self.orig_type.generics.clone();
|
||||
if let AssociatedTypeKind::Ref(Some(lifetime)) = &self.kind {
|
||||
add_lifetime_bound(&mut generics, lifetime);
|
||||
}
|
||||
|
||||
add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSIONIZE_TRAIT_NAME])?;
|
||||
|
||||
Ok(generics)
|
||||
}
|
||||
|
||||
fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result<Generics> {
|
||||
let mut generics = self.type_generics()?;
|
||||
|
||||
if let ConversionDirection::AssociatedToOrig = direction {
|
||||
if let AssociatedTypeKind::Owned = &self.kind {
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
self.inner_types()?,
|
||||
&[UNVERSIONIZE_TRAIT_NAME],
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(generics)
|
||||
}
|
||||
|
||||
/// Generates the declaration for the Version equivalent of the input struct
|
||||
fn generate_struct(&self, stru: &DataStruct) -> syn::Result<ItemStruct> {
|
||||
let fields = match &stru.fields {
|
||||
Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?),
|
||||
Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?),
|
||||
Fields::Unit => Fields::Unit,
|
||||
};
|
||||
|
||||
let versioned_stru = ItemStruct {
|
||||
fields,
|
||||
ident: self.ident(),
|
||||
vis: self.orig_type.vis.clone(),
|
||||
attrs: Vec::new(),
|
||||
generics: self.type_generics()?,
|
||||
struct_token: stru.struct_token,
|
||||
semi_token: stru.semi_token,
|
||||
};
|
||||
|
||||
Ok(versioned_stru)
|
||||
}
|
||||
|
||||
/// Generates the declaration for the Version equivalent of the input enum
|
||||
fn generate_enum(&self, enu: &DataEnum) -> syn::Result<ItemEnum> {
|
||||
if enu.variants.is_empty() {
|
||||
return Err(syn::Error::new(
|
||||
self.orig_type.span(),
|
||||
"Version cannot be derived on empty enums",
|
||||
));
|
||||
}
|
||||
|
||||
let variants = punctuated_from_iter_result(
|
||||
enu.variants
|
||||
.iter()
|
||||
.map(|variant| self.convert_enum_variant(variant)),
|
||||
)?;
|
||||
|
||||
let versioned_enu = ItemEnum {
|
||||
ident: self.ident(),
|
||||
vis: self.orig_type.vis.clone(),
|
||||
attrs: Vec::new(),
|
||||
generics: self.type_generics()?,
|
||||
enum_token: enu.enum_token,
|
||||
brace_token: enu.brace_token,
|
||||
variants,
|
||||
};
|
||||
|
||||
Ok(versioned_enu)
|
||||
}
|
||||
|
||||
/// Generates the declaration for the Version equivalent of the input union
|
||||
fn generate_union(&self, uni: &DataUnion) -> syn::Result<ItemUnion> {
|
||||
let fields = self.convert_fields_named(&uni.fields)?;
|
||||
|
||||
let versioned_uni = ItemUnion {
|
||||
fields,
|
||||
ident: self.ident(),
|
||||
vis: self.orig_type.vis.clone(),
|
||||
attrs: Vec::new(),
|
||||
generics: self.type_generics()?,
|
||||
union_token: uni.union_token,
|
||||
};
|
||||
|
||||
Ok(versioned_uni)
|
||||
}
|
||||
|
||||
/// Converts an enum variant into its "Version" form
|
||||
fn convert_enum_variant(&self, variant: &Variant) -> syn::Result<Variant> {
|
||||
let fields = match &variant.fields {
|
||||
Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?),
|
||||
Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?),
|
||||
Fields::Unit => Fields::Unit,
|
||||
};
|
||||
|
||||
let versioned_variant = Variant {
|
||||
attrs: Vec::new(),
|
||||
ident: variant.ident.clone(),
|
||||
fields,
|
||||
discriminant: variant.discriminant.clone(),
|
||||
};
|
||||
|
||||
Ok(versioned_variant)
|
||||
}
|
||||
|
||||
/// Converts unnamed fields into Versioned
|
||||
fn convert_fields_unnamed(&self, fields: &FieldsUnnamed) -> syn::Result<FieldsUnnamed> {
|
||||
Ok(FieldsUnnamed {
|
||||
unnamed: punctuated_from_iter_result(self.convert_fields(fields.unnamed.iter()))?,
|
||||
..fields.clone()
|
||||
})
|
||||
}
|
||||
|
||||
/// Converts named fields into Versioned
|
||||
fn convert_fields_named(&self, fields: &FieldsNamed) -> syn::Result<FieldsNamed> {
|
||||
Ok(FieldsNamed {
|
||||
named: punctuated_from_iter_result(self.convert_fields(fields.named.iter()))?,
|
||||
..fields.clone()
|
||||
})
|
||||
}
|
||||
|
||||
/// Converts all fields in the given iterator into their "Versioned" counterparts.
|
||||
fn convert_fields<'a, I: Iterator<Item = &'a Field> + 'a>(
|
||||
&self,
|
||||
fields_iter: I,
|
||||
) -> impl IntoIterator<Item = syn::Result<Field>> + 'a {
|
||||
let kind = self.kind.clone();
|
||||
fields_iter.into_iter().map(move |field| {
|
||||
let unver_ty = field.ty.clone();
|
||||
|
||||
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
|
||||
|
||||
let ty: Type = match &kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
|
||||
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
|
||||
},
|
||||
AssociatedTypeKind::Owned => parse_quote! {
|
||||
<#unver_ty as #versionize_trait>::VersionedOwned
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Field {
|
||||
ty,
|
||||
..field.clone()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor part of the conversion impl block. This will create the dest type
|
||||
/// using fields of the src one. This is easy since they both have the same shape.
|
||||
/// If the conversion is from the original type to a reference version type, this is done by
|
||||
/// calling the `versionize` method on all fields.
|
||||
/// If this is a conversion between the owned version type to the original type, this is done by
|
||||
/// calling the `unversionize` method.
|
||||
fn generate_conversion_constructor(
|
||||
&self,
|
||||
arg_name: &str,
|
||||
src_type: &Ident,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let constructor = match &self.orig_type.data {
|
||||
Data::Struct(stru) => self.generate_constructor_struct(arg_name, stru, direction),
|
||||
Data::Enum(enu) => self.generate_constructor_enum(arg_name, src_type, enu, direction),
|
||||
Data::Union(uni) => self.generate_constructor_union(arg_name, uni, direction),
|
||||
}?;
|
||||
|
||||
match direction {
|
||||
ConversionDirection::OrigToAssociated => Ok(constructor),
|
||||
ConversionDirection::AssociatedToOrig => Ok(quote! { Ok(#constructor) }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the constructor for a struct.
|
||||
fn generate_constructor_struct(
|
||||
&self,
|
||||
arg_name: &str,
|
||||
stru: &DataStruct,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields = match &stru.fields {
|
||||
Fields::Named(fields) => {
|
||||
self.generate_constructor_fields_named(arg_name, fields.named.iter(), direction)?
|
||||
}
|
||||
Fields::Unnamed(fields) => self.generate_constructor_fields_unnamed(
|
||||
arg_name,
|
||||
fields.unnamed.iter(),
|
||||
direction,
|
||||
)?,
|
||||
Fields::Unit => TokenStream::new(),
|
||||
};
|
||||
|
||||
Ok(quote! {
|
||||
Self #fields
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for an enum.
|
||||
fn generate_constructor_enum(
|
||||
&self,
|
||||
arg_name: &str,
|
||||
src_type: &Ident,
|
||||
enu: &DataEnum,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
if enu.variants.is_empty() {
|
||||
return Err(syn::Error::new(
|
||||
self.orig_type.span(),
|
||||
"Version cannot be derived on empty enums",
|
||||
));
|
||||
}
|
||||
|
||||
let variant_constructors: syn::Result<Vec<TokenStream>> = enu
|
||||
.variants
|
||||
.iter()
|
||||
.map(|variant| self.generate_constructor_enum_variant(src_type, variant, direction))
|
||||
.collect();
|
||||
let variant_constructors = variant_constructors?;
|
||||
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
|
||||
Ok(quote! {
|
||||
match #arg_ident {
|
||||
#(#variant_constructors),*
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for an union.
|
||||
fn generate_constructor_union(
|
||||
&self,
|
||||
arg_name: &str,
|
||||
uni: &DataUnion,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields =
|
||||
self.generate_constructor_fields_named(arg_name, uni.fields.named.iter(), direction)?;
|
||||
|
||||
Ok(quote! {
|
||||
Self #fields
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for a specific variant of an enum
|
||||
fn generate_constructor_enum_variant(
|
||||
&self,
|
||||
src_type: &Ident,
|
||||
variant: &Variant,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let (param, fields) = match &variant.fields {
|
||||
Fields::Named(fields) => {
|
||||
let args_iter = fields
|
||||
.named
|
||||
.iter()
|
||||
// Ok to unwrap because the field is named so field.ident is Some
|
||||
.map(|field| field.ident.as_ref().unwrap());
|
||||
let args = args_iter.clone();
|
||||
|
||||
(
|
||||
quote! { {
|
||||
#(#args),*
|
||||
}},
|
||||
self.generate_constructor_enum_variants_named(
|
||||
args_iter.cloned(),
|
||||
fields.named.iter(),
|
||||
direction,
|
||||
)?,
|
||||
)
|
||||
}
|
||||
Fields::Unnamed(fields) => {
|
||||
let args_iter = generate_args_list(fields.unnamed.len());
|
||||
let args = args_iter.clone();
|
||||
(
|
||||
quote! { (#(#args),*) },
|
||||
self.generate_constructor_enum_variants_unnamed(
|
||||
args_iter,
|
||||
fields.unnamed.iter(),
|
||||
direction,
|
||||
)?,
|
||||
)
|
||||
}
|
||||
Fields::Unit => (TokenStream::new(), TokenStream::new()),
|
||||
};
|
||||
let variant_ident = &variant.ident;
|
||||
|
||||
Ok(quote! {
|
||||
#src_type::#variant_ident #param => Self::#variant_ident #fields
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for the fields of a named struct.
|
||||
fn generate_constructor_fields_named<'a, I: Iterator<Item = &'a Field> + 'a>(
|
||||
&self,
|
||||
arg_name: &'a str,
|
||||
fields: I,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = fields
|
||||
.into_iter()
|
||||
.map(move |field| self.generate_constructor_field_named(arg_name, field, direction))
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
|
||||
Ok(quote! {
|
||||
{
|
||||
#(#fields),*
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for a field of a named struct.
|
||||
fn generate_constructor_field_named(
|
||||
&self,
|
||||
arg_name: &str,
|
||||
field: &Field,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
// Ok to unwrap because the field is named so field.ident is Some
|
||||
let field_ident = field.ident.as_ref().unwrap();
|
||||
let ty = &field.ty;
|
||||
let param = quote! { #arg_ident.#field_ident };
|
||||
|
||||
let rhs = self.generate_constructor_field_rhs(ty, param, false, direction)?;
|
||||
|
||||
Ok(quote! {
|
||||
#field_ident: #rhs
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for the fields of a named enum variant.
|
||||
fn generate_constructor_enum_variants_named<
|
||||
'a,
|
||||
I: Iterator<Item = &'a Field> + 'a,
|
||||
J: Iterator<Item = Ident>,
|
||||
>(
|
||||
&self,
|
||||
arg_names: J,
|
||||
fields: I,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
|
||||
.map(move |(arg_name, field)| {
|
||||
// Ok to unwrap because the field is named so field.ident is Some
|
||||
let field_ident = field.ident.as_ref().unwrap();
|
||||
let rhs = self.generate_constructor_field_rhs(
|
||||
&field.ty,
|
||||
quote! {#arg_name},
|
||||
true,
|
||||
direction,
|
||||
)?;
|
||||
Ok(quote! {
|
||||
#field_ident: #rhs
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
|
||||
Ok(quote! {
|
||||
{
|
||||
#(#fields),*
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for the fields of an unnamed struct.
|
||||
fn generate_constructor_fields_unnamed<'a, I: Iterator<Item = &'a Field> + 'a>(
|
||||
&self,
|
||||
arg_name: &'a str,
|
||||
fields: I,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = fields
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(idx, field)| {
|
||||
self.generate_constructor_field_unnamed(arg_name, field, idx, direction)
|
||||
})
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
|
||||
Ok(quote! {
|
||||
(#(#fields),*)
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for a field of an unnamed struct.
|
||||
fn generate_constructor_field_unnamed(
|
||||
&self,
|
||||
arg_name: &str,
|
||||
field: &Field,
|
||||
idx: usize,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
let idx = Literal::usize_unsuffixed(idx);
|
||||
let ty = &field.ty;
|
||||
let param = quote! { #arg_ident.#idx };
|
||||
|
||||
self.generate_constructor_field_rhs(ty, param, false, direction)
|
||||
}
|
||||
|
||||
/// Generates the constructor for the fields of an unnamed enum variant.
|
||||
fn generate_constructor_enum_variants_unnamed<
|
||||
'a,
|
||||
I: Iterator<Item = &'a Field> + 'a,
|
||||
J: Iterator<Item = Ident>,
|
||||
>(
|
||||
&self,
|
||||
arg_names: J,
|
||||
fields: I,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
|
||||
.map(move |(arg_name, field)| {
|
||||
self.generate_constructor_field_rhs(&field.ty, quote! {#arg_name}, true, direction)
|
||||
})
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
|
||||
Ok(quote! {
|
||||
(#(#fields),*)
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the rhs part of a field constructor.
|
||||
/// For example, in `Self { count: value.count.versionize() }`, this is
|
||||
/// `value.count.versionize()`.
|
||||
fn generate_constructor_field_rhs(
|
||||
&self,
|
||||
ty: &Type,
|
||||
field_param: TokenStream,
|
||||
is_ref: bool, // True if the param is already a reference
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
|
||||
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
|
||||
|
||||
let field_constructor = match direction {
|
||||
ConversionDirection::OrigToAssociated => {
|
||||
let param = if is_ref {
|
||||
field_param
|
||||
} else {
|
||||
quote! {&#field_param}
|
||||
};
|
||||
|
||||
match self.kind {
|
||||
AssociatedTypeKind::Ref(_) => quote! {
|
||||
#versionize_trait::versionize(#param)
|
||||
},
|
||||
AssociatedTypeKind::Owned => quote! {
|
||||
#versionize_trait::versionize_owned(#param)
|
||||
},
|
||||
}
|
||||
}
|
||||
ConversionDirection::AssociatedToOrig => match self.kind {
|
||||
AssociatedTypeKind::Ref(_) =>
|
||||
panic!("No conversion should be generated between associated ref type to original type"),
|
||||
AssociatedTypeKind::Owned => quote! {
|
||||
<#ty as #unversionize_trait>::unversionize(#field_param)?
|
||||
},
|
||||
},
|
||||
};
|
||||
Ok(field_constructor)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a list of argument names. This is used to create a pattern matching of a
|
||||
/// tuple-like enum variant.
|
||||
fn generate_args_list(count: usize) -> impl Iterator<Item = Ident> + Clone {
|
||||
(0..count).map(|val| format_ident!("value{}", val))
|
||||
}
|
||||
|
||||
/// Checks if the type is a unit type that contains no data
|
||||
fn is_unit(input: &DeriveInput) -> bool {
|
||||
match &input.data {
|
||||
Data::Struct(stru) => stru.fields.is_empty(),
|
||||
Data::Enum(enu) => enu.variants.iter().all(|variant| variant.fields.is_empty()),
|
||||
Data::Union(uni) => uni.fields.named.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the fields of the input type. This is independant of the kind of type
|
||||
/// (enum, struct, ...)
|
||||
fn derive_type_fields(input: &DeriveInput) -> Punctuated<&Field, Comma> {
|
||||
match &input.data {
|
||||
Data::Struct(stru) => match &stru.fields {
|
||||
Fields::Named(fields) => Punctuated::from_iter(fields.named.iter()),
|
||||
Fields::Unnamed(fields) => Punctuated::from_iter(fields.unnamed.iter()),
|
||||
Fields::Unit => Punctuated::new(),
|
||||
},
|
||||
Data::Enum(enu) => Punctuated::<&Field, Comma>::from_iter(
|
||||
enu.variants
|
||||
.iter()
|
||||
.filter_map(|variant| match &variant.fields {
|
||||
Fields::Named(fields) => Some(fields.named.iter()),
|
||||
Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
|
||||
Fields::Unit => None,
|
||||
})
|
||||
.flatten(),
|
||||
),
|
||||
Data::Union(uni) => Punctuated::from_iter(uni.fields.named.iter()),
|
||||
}
|
||||
}
|
||||
185
utils/tfhe-versionable-derive/src/versionize_attribute.rs
Normal file
185
utils/tfhe-versionable-derive/src/versionize_attribute.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use proc_macro2::Span;
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::{Attribute, Expr, Ident, Lit, Meta, Path, Token, Type};
|
||||
|
||||
use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME};
|
||||
|
||||
/// Name of the attribute used to give arguments to our macros
|
||||
const VERSIONIZE_ATTR_NAME: &str = "versionize";
|
||||
|
||||
pub(crate) struct VersionizeAttribute {
|
||||
dispatch_enum: Path,
|
||||
from: Option<Path>,
|
||||
try_from: Option<Path>,
|
||||
into: Option<Path>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct VersionizeAttributeBuilder {
|
||||
dispatch_enum: Option<Path>,
|
||||
from: Option<Path>,
|
||||
try_from: Option<Path>,
|
||||
into: Option<Path>,
|
||||
}
|
||||
|
||||
impl VersionizeAttributeBuilder {
|
||||
fn build(self) -> Option<VersionizeAttribute> {
|
||||
// These attributes are mutually exclusive
|
||||
if self.from.is_some() && self.try_from.is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(VersionizeAttribute {
|
||||
dispatch_enum: self.dispatch_enum?,
|
||||
from: self.from,
|
||||
try_from: self.try_from,
|
||||
into: self.into,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl VersionizeAttribute {
|
||||
/// Find and parse an attribute with the form `#[versionize(DispatchType)]`, where
|
||||
/// `DispatchType` is the name of the type holding the dispatch enum.
|
||||
/// Returns an error if no `versionize` attribute has been found, if multiple attributes are
|
||||
/// present on the same struct or if the attribute is malformed.
|
||||
pub(crate) fn parse_from_attributes_list(attributes: &[Attribute]) -> syn::Result<Self> {
|
||||
let version_attributes: Vec<&Attribute> = attributes
|
||||
.iter()
|
||||
.filter(|attr| attr.path().is_ident(VERSIONIZE_ATTR_NAME))
|
||||
.collect();
|
||||
|
||||
match version_attributes.as_slice() {
|
||||
[] => Err(syn::Error::new(
|
||||
Span::call_site(),
|
||||
"Missing `versionize` attribute for `Versionize`",
|
||||
)),
|
||||
[attr] => Self::parse_from_attribute(attr),
|
||||
[_, attr2, ..] => Err(syn::Error::new(
|
||||
attr2.span(),
|
||||
"Multiple `versionize` attributes found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_error(span: Span) -> syn::Error {
|
||||
syn::Error::new(span, "Malformed `versionize` attribute")
|
||||
}
|
||||
|
||||
/// Parse a `versionize` attribute.
|
||||
/// The attribute is assumed to be a `versionize` attribute.
|
||||
pub(crate) fn parse_from_attribute(attribute: &Attribute) -> syn::Result<Self> {
|
||||
let nested = attribute.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
|
||||
|
||||
let mut attribute_builder = VersionizeAttributeBuilder::default();
|
||||
for meta in nested.iter() {
|
||||
match meta {
|
||||
Meta::Path(dispatch_enum) => {
|
||||
if attribute_builder.dispatch_enum.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
attribute_builder.dispatch_enum = Some(dispatch_enum.clone());
|
||||
}
|
||||
}
|
||||
Meta::List(_) => {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
}
|
||||
Meta::NameValue(name_value) => {
|
||||
if name_value.path.is_ident("from") {
|
||||
if attribute_builder.from.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
attribute_builder.from =
|
||||
Some(parse_path_ignore_quotes(&name_value.value)?);
|
||||
}
|
||||
} else if name_value.path.is_ident("try_from") {
|
||||
if attribute_builder.try_from.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
attribute_builder.try_from =
|
||||
Some(parse_path_ignore_quotes(&name_value.value)?);
|
||||
}
|
||||
} else if name_value.path.is_ident("into") {
|
||||
if attribute_builder.into.is_some() {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
} else {
|
||||
attribute_builder.into =
|
||||
Some(parse_path_ignore_quotes(&name_value.value)?);
|
||||
}
|
||||
} else {
|
||||
return Err(Self::default_error(meta.span()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
attribute_builder
|
||||
.build()
|
||||
.ok_or_else(|| Self::default_error(attribute.span()))
|
||||
}
|
||||
|
||||
pub(crate) fn dispatch_enum(&self) -> &Path {
|
||||
&self.dispatch_enum
|
||||
}
|
||||
|
||||
pub(crate) fn needs_conversion(&self) -> bool {
|
||||
self.try_from.is_some() || self.from.is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn dispatch_target(&self) -> Path {
|
||||
self.from
|
||||
.as_ref()
|
||||
.or(self.try_from.as_ref())
|
||||
.map(|target| target.to_owned())
|
||||
.unwrap_or_else(|| {
|
||||
syn::parse_str("Self").expect("Parsing of const value should never fail")
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream {
|
||||
self.into
|
||||
.as_ref()
|
||||
.map(|target| {
|
||||
quote! {
|
||||
#target::from(self.to_owned()).versionize_owned()
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
quote! {
|
||||
self.into()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn unversionize_method_body(&self, arg_name: &Ident) -> proc_macro2::TokenStream {
|
||||
let error: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
if let Some(target) = &self.from {
|
||||
quote! { #target::unversionize(#arg_name).map(|value| value.into()) }
|
||||
} else if let Some(target) = &self.try_from {
|
||||
let target_name = format!("{}", target.to_token_stream());
|
||||
quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::<Self>::try_into(value)
|
||||
.map_err(|e| #error::conversion(#target_name, &format!("{}", e))))
|
||||
}
|
||||
} else {
|
||||
quote! { #arg_name.try_into() }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
|
||||
match &value {
|
||||
Expr::Path(expr_path) => Ok(expr_path.path.clone()),
|
||||
Expr::Lit(expr_lit) => match &expr_lit.lit {
|
||||
Lit::Str(s) => syn::parse_str(&s.value()),
|
||||
_ => Err(syn::Error::new(
|
||||
value.span(),
|
||||
"Malformed `versionize` attribute",
|
||||
)),
|
||||
},
|
||||
_ => Err(syn::Error::new(
|
||||
value.span(),
|
||||
"Malformed `versionize` attribute",
|
||||
)),
|
||||
}
|
||||
}
|
||||
28
utils/tfhe-versionable/Cargo.toml
Normal file
28
utils/tfhe-versionable/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "tfhe-versionable"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
keywords = ["versioning", "serialization", "encoding"]
|
||||
homepage = "https://zama.ai/"
|
||||
documentation = "https://docs.rs/tfhe_versionable"
|
||||
repository = "https://github.com/zama-ai/tfhe-rs"
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "tfhe-versionable: Add versioning informations/backward compatibility on rust types used for serialization"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dev-dependencies]
|
||||
static_assertions = "1.1"
|
||||
trybuild = { version = "1", features = ["diff"] }
|
||||
|
||||
# used to test various serialization formats
|
||||
bincode = "1.3"
|
||||
serde_json = "1.0"
|
||||
ciborium = "0.2"
|
||||
rmp-serde = "1.3"
|
||||
serde_yaml = "0.9"
|
||||
toml = "0.8"
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tfhe-versionable-derive = { version = "0.1.0", path = "../tfhe-versionable-derive" }
|
||||
102
utils/tfhe-versionable/README.md
Normal file
102
utils/tfhe-versionable/README.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# TFHE-versionable
|
||||
This crate provides type level versioning for serialized data. It offers a way to add backward
|
||||
compatibility on any data type. The versioning scheme works recursively and is idependant of the
|
||||
chosen serialization backend.
|
||||
|
||||
To use it, simply define an enum that have a variant for each version of your target type.
|
||||
|
||||
For example, if you have defined an internal type:
|
||||
```rust
|
||||
struct MyStruct {
|
||||
val: u32
|
||||
}
|
||||
```
|
||||
|
||||
You have to define the following enum:
|
||||
```rust
|
||||
enum MyStructVersions {
|
||||
V0(MyStruct)
|
||||
}
|
||||
```
|
||||
|
||||
If at a subsequent point in time you want to add a field to this struct, the idea is to copy the previously defined version of the struct and create a new one with the added field. This mostly becomes:
|
||||
```rust
|
||||
struct MyStruct {
|
||||
val: u32,
|
||||
newval: u64
|
||||
}
|
||||
|
||||
struct MyStructV0 {
|
||||
val: u32
|
||||
}
|
||||
|
||||
enum MyStructVersions {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct)
|
||||
}
|
||||
```
|
||||
|
||||
You also have to implement the `Upgrade` trait, that tells how to go from a version to another.
|
||||
|
||||
To make this work recursively, this crate defines 3 derive macro that should be used on these types:
|
||||
- `Versionize` should be used on the current version of your type, the one that is used in your code
|
||||
- `Version` is used on every previous version of this type
|
||||
- `VersionsDispatch` is used on the enum holding all the versions
|
||||
|
||||
This will implement the `Versionize`/`Unversionize` traits with their `versionize` and `unversionize` methods that should be used before/after the calls to `serialize`/`deserialize`.
|
||||
|
||||
The enum variants should keep their order and names between versions. The only supported operation is to add a new variant.
|
||||
|
||||
# Complete example
|
||||
```rust
|
||||
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
|
||||
|
||||
// The structure that should be versioned, as defined in your code
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)] // Link to the enum type that will holds all the versions of this type
|
||||
struct MyStruct<T: Default> {
|
||||
attr: T,
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
// To avoid polluting your code, the old versions can be defined in another module/file, along with the dispatch enum
|
||||
#[derive(Version)] // Used to mark an old version of the type
|
||||
struct MyStructV0 {
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
// The Upgrade trait tells how to go from the first version to the last. During unversioning, the
|
||||
// upgrade method will be called on the deserialized value enough times to go to the last variant.
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> MyStruct<T> {
|
||||
MyStruct {
|
||||
attr: T::default(),
|
||||
builtin: self.builtin,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is the dispatch enum, that holds one variant for each version of your type.
|
||||
#[derive(VersionsDispatch)]
|
||||
// This enum is not directly used but serves as a template to generate new enums that will be
|
||||
// serialized. This allows recursive versioning.
|
||||
#[allow(unused)]
|
||||
enum MyStructVersions<T: Default> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct<T>),
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let ms = MyStruct {
|
||||
attr: 37u64,
|
||||
builtin: 1234,
|
||||
};
|
||||
|
||||
let serialized = bincode::serialize(&ms.versionize()).unwrap();
|
||||
|
||||
// This can be called in future versions of your application, when more variants have been added
|
||||
let _unserialized = MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap());
|
||||
}
|
||||
```
|
||||
|
||||
See the `examples` folder for more usecases.
|
||||
50
utils/tfhe-versionable/examples/convert.rs
Normal file
50
utils/tfhe-versionable/examples/convert.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
//! Show how to call a conversion method (from/into) before versioning/unversioning
|
||||
|
||||
use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Clone, Versionize)]
|
||||
#[versionize(SerializableMyStructVersions, from = SerializableMyStruct, into = SerializableMyStruct)]
|
||||
struct MyStruct {
|
||||
val: u64,
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(SerializableMyStructVersions)]
|
||||
struct SerializableMyStruct {
|
||||
high: u32,
|
||||
low: u32,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum SerializableMyStructVersions {
|
||||
V0(SerializableMyStruct),
|
||||
}
|
||||
|
||||
impl From<MyStruct> for SerializableMyStruct {
|
||||
fn from(value: MyStruct) -> Self {
|
||||
println!("{}", value.val);
|
||||
Self {
|
||||
high: (value.val >> 32) as u32,
|
||||
low: (value.val & 0xffffffff) as u32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SerializableMyStruct> for MyStruct {
|
||||
fn from(value: SerializableMyStruct) -> Self {
|
||||
Self {
|
||||
val: ((value.high as u64) << 32) | (value.low as u64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let stru = MyStruct { val: 37 };
|
||||
|
||||
let serialized = bincode::serialize(&stru.versionize()).unwrap();
|
||||
|
||||
let stru_decoded = MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(stru.val, stru_decoded.val)
|
||||
}
|
||||
77
utils/tfhe-versionable/examples/failed_upgrade.rs
Normal file
77
utils/tfhe-versionable/examples/failed_upgrade.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
//! The upgrade method can return an error. In that case, the error is propagated to
|
||||
//! the outer `unversionize` call.
|
||||
|
||||
use tfhe_versionable::{Unversionize, Versionize};
|
||||
|
||||
mod v0 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct(pub Option<u32>);
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::VersionsDispatch;
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions {
|
||||
V0(MyStruct),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod v1 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct(pub u32);
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV0(pub Option<u32>);
|
||||
|
||||
impl Upgrade<MyStruct> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStruct, String> {
|
||||
match self.0 {
|
||||
Some(val) => Ok(MyStruct(val)),
|
||||
None => Err("Cannot convert from empty \"MyStructV0\"".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let v0 = v0::MyStruct(Some(37));
|
||||
let serialized = bincode::serialize(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0.unwrap(), v1.0);
|
||||
|
||||
let v0_empty = v0::MyStruct(None);
|
||||
let serialized_empty = bincode::serialize(&v0_empty.versionize()).unwrap();
|
||||
|
||||
assert!(v1::MyStruct::unversionize(bincode::deserialize(&serialized_empty).unwrap()).is_err());
|
||||
}
|
||||
100
utils/tfhe-versionable/examples/manual_impl.rs
Normal file
100
utils/tfhe-versionable/examples/manual_impl.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
//! The simple example, with manual implementation of the versionize trait
|
||||
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::{Unversionize, UnversionizeError, Upgrade, Versionize};
|
||||
|
||||
struct MyStruct<T: Default> {
|
||||
attr: T,
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct MyStructV0 {
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStruct<T>, String> {
|
||||
Ok(MyStruct {
|
||||
attr: T::default(),
|
||||
builtin: self.builtin,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct MyStructVersion<'vers, T: 'vers + Default + Versionize> {
|
||||
attr: T::Versioned<'vers>,
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct MyStructVersionOwned<T: Default + Versionize> {
|
||||
attr: T::VersionedOwned,
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
impl<T: Default + Versionize + Serialize + DeserializeOwned> Versionize for MyStruct<T> {
|
||||
type Versioned<'vers> = MyStructVersionsDispatch<'vers, T>
|
||||
where
|
||||
Self: 'vers;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
let ver = MyStructVersion {
|
||||
attr: self.attr.versionize(),
|
||||
builtin: self.builtin,
|
||||
};
|
||||
MyStructVersionsDispatch::V1(ver)
|
||||
}
|
||||
|
||||
type VersionedOwned = MyStructVersionsDispatchOwned<T>;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
let ver = MyStructVersionOwned {
|
||||
attr: self.attr.versionize_owned(),
|
||||
builtin: self.builtin,
|
||||
};
|
||||
MyStructVersionsDispatchOwned::V1(ver)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Default + Versionize + Unversionize + Serialize + DeserializeOwned> Unversionize
|
||||
for MyStruct<T>
|
||||
{
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
match versioned {
|
||||
MyStructVersionsDispatchOwned::V0(v0) => v0
|
||||
.upgrade()
|
||||
.map_err(|e| UnversionizeError::upgrade("V0", "V1", &e)),
|
||||
MyStructVersionsDispatchOwned::V1(v1) => Ok(Self {
|
||||
attr: T::unversionize(v1.attr)?,
|
||||
builtin: v1.builtin,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[allow(dead_code)]
|
||||
enum MyStructVersionsDispatch<'vers, T: 'vers + Default + Versionize> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStructVersion<'vers, T>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
enum MyStructVersionsDispatchOwned<T: Default + Versionize> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStructVersionOwned<T>),
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let ms = MyStruct {
|
||||
attr: 37u64,
|
||||
builtin: 1234,
|
||||
};
|
||||
|
||||
let serialized = bincode::serialize(&ms.versionize()).unwrap();
|
||||
|
||||
let _unserialized = MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap());
|
||||
}
|
||||
24
utils/tfhe-versionable/examples/not_versioned.rs
Normal file
24
utils/tfhe-versionable/examples/not_versioned.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
//! This example shows how to create a type that should not be versioned, even if it is
|
||||
//! included in other versioned types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::{NotVersioned, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, NotVersioned)]
|
||||
struct MyStructNotVersioned {
|
||||
val: u32,
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
struct MyStruct {
|
||||
inner: MyStructNotVersioned,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum MyStructVersions {
|
||||
V0(MyStruct),
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
45
utils/tfhe-versionable/examples/recursive.rs
Normal file
45
utils/tfhe-versionable/examples/recursive.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
//! An example of recursive versioning
|
||||
|
||||
use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructInnerVersions)]
|
||||
struct MyStructInner<T: Default> {
|
||||
attr: T,
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
struct MyStructInnerV0 {
|
||||
attr: u32,
|
||||
}
|
||||
|
||||
impl<T: Default> Upgrade<MyStructInner<T>> for MyStructInnerV0 {
|
||||
fn upgrade(self) -> Result<MyStructInner<T>, String> {
|
||||
Ok(MyStructInner {
|
||||
attr: T::default(),
|
||||
builtin: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum MyStructInnerVersions<T: Default> {
|
||||
V0(MyStructInnerV0),
|
||||
V1(MyStructInner<T>),
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
struct MyStruct<T: Default> {
|
||||
inner: MyStructInner<T>,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum MyStructVersions<T: Default> {
|
||||
V0(MyStruct<T>),
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
52
utils/tfhe-versionable/examples/simple.rs
Normal file
52
utils/tfhe-versionable/examples/simple.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
//! Shows a basic usage of this crate
|
||||
|
||||
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
|
||||
|
||||
// The structure that should be versioned, as defined in your code
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)] // Link to the enum type that will holds all the versions of this
|
||||
// type
|
||||
struct MyStruct<T: Default> {
|
||||
attr: T,
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
// To avoid polluting your code, the old versions can be defined in another module/file, along with
|
||||
// the dispatch enum
|
||||
#[derive(Version)] // Used to mark an old version of the type
|
||||
struct MyStructV0 {
|
||||
builtin: u32,
|
||||
}
|
||||
|
||||
// The Upgrade trait tells how to go from the first version to the last. During unversioning, the
|
||||
// upgrade method will be called on the deserialized value enough times to go to the last variant.
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStruct<T>, String> {
|
||||
Ok(MyStruct {
|
||||
attr: T::default(),
|
||||
builtin: self.builtin,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// This is the dispatch enum, that holds one variant for each version of your type.
|
||||
#[derive(VersionsDispatch)]
|
||||
// This enum is not directly used but serves as a template to generate a new enum that will be
|
||||
// serialized. This allows recursive versioning.
|
||||
#[allow(unused)]
|
||||
enum MyStructVersions<T: Default> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct<T>),
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let ms = MyStruct {
|
||||
attr: 37u64,
|
||||
builtin: 1234,
|
||||
};
|
||||
|
||||
let serialized = bincode::serialize(&ms.versionize()).unwrap();
|
||||
|
||||
// This can be called in future versions of your application, when more variants have been added
|
||||
let _unserialized = MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap());
|
||||
}
|
||||
139
utils/tfhe-versionable/examples/upgrades.rs
Normal file
139
utils/tfhe-versionable/examples/upgrades.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
//! A more realistic example of a codebase that evolves in time. Each "mod vN" should be seen as
|
||||
//! a version of an application. The "backward_compat" mods can be in different files.
|
||||
|
||||
use tfhe_versionable::{Unversionize, Versionize};
|
||||
|
||||
mod v0 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct(pub u32);
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::VersionsDispatch;
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions {
|
||||
V0(MyStruct),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod v1 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct<T: Default>(pub u32, pub T);
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV0(pub u32);
|
||||
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStruct<T>, String> {
|
||||
Ok(MyStruct(self.0, T::default()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions<T: Default> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct<T>),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod v2 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::{MyEnumVersions, MyStructVersions};
|
||||
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(MyEnumVersions)]
|
||||
pub enum MyEnum<T: Default> {
|
||||
Variant0,
|
||||
Variant1 { count: u64 },
|
||||
Variant2(T),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct<T: Default> {
|
||||
pub count: u32,
|
||||
pub attr: T,
|
||||
}
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use super::{MyEnum, MyStruct};
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV0(pub u32);
|
||||
|
||||
impl<T: Default> Upgrade<MyStructV1<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStructV1<T>, String> {
|
||||
Ok(MyStructV1(self.0, T::default()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV1<T>(pub u32, pub T);
|
||||
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV1<T> {
|
||||
fn upgrade(self) -> Result<MyStruct<T>, String> {
|
||||
Ok(MyStruct {
|
||||
count: self.0,
|
||||
attr: T::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions<T: Default> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStructV1<T>),
|
||||
V2(MyStruct<T>),
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyEnumVersions<T: Default> {
|
||||
V0(MyEnum<T>),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let serialized = bincode::serialize(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, u64::default());
|
||||
|
||||
let v2 = v2::MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, u64::default());
|
||||
}
|
||||
31
utils/tfhe-versionable/src/derived_traits.rs
Normal file
31
utils/tfhe-versionable/src/derived_traits.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
//! These traits are not meant to be manually implemented, they are just used in the derive macro
|
||||
//! for easier access to generated types
|
||||
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::UnversionizeError;
|
||||
|
||||
/// This trait is used to mark a specific version of a given type
|
||||
pub trait Version: Sized {
|
||||
type Ref<'vers>: From<&'vers Self> + Serialize
|
||||
where
|
||||
Self: 'vers;
|
||||
type Owned: for<'vers> From<&'vers Self>
|
||||
+ TryInto<Self, Error = UnversionizeError>
|
||||
+ DeserializeOwned
|
||||
+ Serialize;
|
||||
}
|
||||
|
||||
/// This trait is implemented on the dispatch enum for a given type. The dispatch enum
|
||||
/// is an enum that holds all the versions of the type. Each variant should implement the
|
||||
/// `Version` trait.
|
||||
pub trait VersionsDispatch<Unversioned>: Sized {
|
||||
type Ref<'vers>: From<&'vers Unversioned> + Serialize
|
||||
where
|
||||
Unversioned: 'vers;
|
||||
type Owned: for<'vers> From<&'vers Unversioned>
|
||||
+ TryInto<Unversioned, Error = UnversionizeError>
|
||||
+ DeserializeOwned
|
||||
+ Serialize;
|
||||
}
|
||||
279
utils/tfhe-versionable/src/lib.rs
Normal file
279
utils/tfhe-versionable/src/lib.rs
Normal file
@@ -0,0 +1,279 @@
|
||||
//! Provides a way to add versioning informations/backward compatibility on rust types used for
|
||||
//! serialization.
|
||||
//!
|
||||
//! This crates provides a set of traits [`Versionize`] and [`Unversionize`] that perform a
|
||||
//! conversion between a type and its `Versioned` counterpart. The versioned type is an enum
|
||||
//! that has a variant for each version of the type.
|
||||
//! These traits can be generated using the [`tfhe_versionable_derive::Versionize`] proc macro.
|
||||
|
||||
pub mod derived_traits;
|
||||
pub mod upgrade;
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::fmt::Display;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub use derived_traits::{Version, VersionsDispatch};
|
||||
pub use upgrade::Upgrade;
|
||||
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
pub use tfhe_versionable_derive::{NotVersioned, Version, Versionize, VersionsDispatch};
|
||||
|
||||
/// This trait means that the type can be converted into a versioned equivalent
|
||||
/// type.
|
||||
pub trait Versionize {
|
||||
/// The equivalent versioned type. It should have a variant for each version.
|
||||
/// It may own the underlying data or only hold a read-only reference to it.
|
||||
type Versioned<'vers>: Serialize
|
||||
where
|
||||
Self: 'vers;
|
||||
|
||||
/// Wraps the object into a versioned enum with a variant for each version. This will
|
||||
/// use references on the underlying types if possible.
|
||||
fn versionize(&self) -> Self::Versioned<'_>;
|
||||
|
||||
type VersionedOwned: Serialize + DeserializeOwned;
|
||||
|
||||
/// Wraps the object into a versioned enum with a variant for each version. This will
|
||||
/// clone the underlying types.
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Errors that can arise in the unversionizing process.
|
||||
pub enum UnversionizeError {
|
||||
/// An error in the upgrade between `vers_from` and `vers_into`
|
||||
Upgrade {
|
||||
from_vers: String,
|
||||
into_vers: String,
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// An error has been returned in the conversion method provided by the `try_from` parameter
|
||||
/// attribute
|
||||
Conversion { from_type: String, message: String },
|
||||
}
|
||||
|
||||
impl Display for UnversionizeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Upgrade {
|
||||
from_vers,
|
||||
into_vers,
|
||||
message,
|
||||
} => write!(
|
||||
f,
|
||||
"Failed to upgrade from {from_vers} into {into_vers}: {message}"
|
||||
),
|
||||
Self::Conversion { from_type, message } => {
|
||||
write!(f, "Failed to convert from {from_type}: {message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UnversionizeError {
|
||||
pub fn upgrade(from_vers: &str, into_vers: &str, message: &str) -> Self {
|
||||
Self::Upgrade {
|
||||
from_vers: from_vers.to_string(),
|
||||
into_vers: into_vers.to_string(),
|
||||
message: message.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conversion(from_type: &str, message: &str) -> Self {
|
||||
Self::Conversion {
|
||||
from_type: from_type.to_string(),
|
||||
message: message.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for UnversionizeError {
|
||||
fn from(_value: Infallible) -> Self {
|
||||
panic!("Infallible error type should never be reached")
|
||||
}
|
||||
}
|
||||
|
||||
/// This trait means that we can convert from a versioned enum into the target type. This trait
|
||||
/// can only be implemented on Owned/static types, whereas `Versionize` can also be implemented
|
||||
/// on reference types.
|
||||
pub trait Unversionize: Versionize + Sized {
|
||||
/// Creates an object from a versioned enum, and eventually upgrades from previous
|
||||
/// variants.
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError>;
|
||||
}
|
||||
|
||||
/// Marker trait for a type that it not really versioned, where the `versionize` method returns
|
||||
/// Self or &Self.
|
||||
pub trait NotVersioned: Versionize {}
|
||||
|
||||
/// Implements the versionable traits for a rust primitive scalar type (integer, float, bool and
|
||||
/// char) Since these types won't move between versions, we consider that they are their own
|
||||
/// versionized types
|
||||
macro_rules! impl_scalar_versionize {
|
||||
($t:ty) => {
|
||||
impl Versionize for $t {
|
||||
type Versioned<'a> = $t;
|
||||
|
||||
type VersionedOwned = $t;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
*self
|
||||
}
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl Unversionize for $t {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
Ok(versioned)
|
||||
}
|
||||
}
|
||||
|
||||
impl NotVersioned for $t {}
|
||||
};
|
||||
}
|
||||
|
||||
impl_scalar_versionize!(bool);
|
||||
|
||||
impl_scalar_versionize!(u8);
|
||||
impl_scalar_versionize!(u16);
|
||||
impl_scalar_versionize!(u32);
|
||||
impl_scalar_versionize!(u64);
|
||||
impl_scalar_versionize!(u128);
|
||||
impl_scalar_versionize!(usize);
|
||||
|
||||
impl_scalar_versionize!(i8);
|
||||
impl_scalar_versionize!(i16);
|
||||
impl_scalar_versionize!(i32);
|
||||
impl_scalar_versionize!(i64);
|
||||
impl_scalar_versionize!(i128);
|
||||
|
||||
impl_scalar_versionize!(f32);
|
||||
impl_scalar_versionize!(f64);
|
||||
|
||||
impl_scalar_versionize!(char);
|
||||
|
||||
impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> Versionize for Vec<T> {
|
||||
type Versioned<'vers> = &'vers [T] where T: 'vers;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self.as_slice()
|
||||
}
|
||||
|
||||
type VersionedOwned = Self;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> Unversionize for Vec<T> {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
Ok(versioned)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> NotVersioned for Vec<T> {}
|
||||
|
||||
impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> Versionize for [T] {
|
||||
type Versioned<'vers> = &'vers [T] where T: 'vers;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self
|
||||
}
|
||||
|
||||
type VersionedOwned = Vec<T>;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
self.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> NotVersioned for [T] {}
|
||||
|
||||
impl Versionize for String {
|
||||
type Versioned<'vers> = &'vers str;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self.as_ref()
|
||||
}
|
||||
|
||||
type VersionedOwned = Self;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Unversionize for String {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
Ok(versioned)
|
||||
}
|
||||
}
|
||||
|
||||
impl NotVersioned for String {}
|
||||
|
||||
impl Versionize for str {
|
||||
type Versioned<'vers> = &'vers str;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self
|
||||
}
|
||||
|
||||
type VersionedOwned = String;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl NotVersioned for str {}
|
||||
|
||||
impl<T: Versionize> Versionize for Option<T> {
|
||||
type Versioned<'vers> = Option<T::Versioned<'vers>> where T: 'vers;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self.as_ref().map(|val| val.versionize())
|
||||
}
|
||||
|
||||
type VersionedOwned = Option<T::VersionedOwned>;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
self.as_ref().map(|val| val.versionize_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Unversionize> Unversionize for Option<T> {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
versioned.map(|val| T::unversionize(val)).transpose()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Versionize for PhantomData<T> {
|
||||
type Versioned<'vers> = Self
|
||||
where
|
||||
Self: 'vers;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
*self
|
||||
}
|
||||
|
||||
type VersionedOwned = Self;
|
||||
|
||||
fn versionize_owned(&self) -> Self::VersionedOwned {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Unversionize for PhantomData<T> {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
Ok(versioned)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> NotVersioned for PhantomData<T> {}
|
||||
7
utils/tfhe-versionable/src/upgrade.rs
Normal file
7
utils/tfhe-versionable/src/upgrade.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! How to perform conversion from one version to the next.
|
||||
|
||||
/// This trait should be implemented for each version of the original type that is not the current
|
||||
/// one. The upgrade method is called in chains until we get to the last version of the type.
|
||||
pub trait Upgrade<T> {
|
||||
fn upgrade(self) -> Result<T, String>;
|
||||
}
|
||||
7
utils/tfhe-versionable/tests/derive_macro.rs
Normal file
7
utils/tfhe-versionable/tests/derive_macro.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
#[test]
|
||||
fn tests() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.pass("tests/testcases/unit.rs");
|
||||
t.pass("tests/testcases/struct.rs");
|
||||
t.pass("tests/testcases/enum.rs");
|
||||
}
|
||||
262
utils/tfhe-versionable/tests/formats.rs
Normal file
262
utils/tfhe-versionable/tests/formats.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
//! Test that backward compatibility works with various serde compatible formats
|
||||
|
||||
use std::io::Cursor;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::{NotVersioned, Unversionize, Versionize};
|
||||
|
||||
#[derive(Serialize, Deserialize, NotVersioned, Copy, Clone, Eq, PartialEq, Debug)]
|
||||
struct MyU64(u64);
|
||||
|
||||
// Use a better default value for tests that 0
|
||||
impl Default for MyU64 {
|
||||
fn default() -> Self {
|
||||
Self(6789)
|
||||
}
|
||||
}
|
||||
|
||||
mod v0 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct(pub u32);
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::VersionsDispatch;
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions {
|
||||
V0(MyStruct),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod v1 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct<T: Default>(pub u32, pub T);
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV0(pub u32);
|
||||
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStruct<T>, String> {
|
||||
Ok(MyStruct(self.0, T::default()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions<T: Default> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct<T>),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod v2 {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use backward_compat::MyStructVersions;
|
||||
|
||||
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct<T: Default> {
|
||||
pub count: u32,
|
||||
pub attr: T,
|
||||
}
|
||||
|
||||
mod backward_compat {
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use super::MyStruct;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV0(pub u32);
|
||||
|
||||
impl<T: Default> Upgrade<MyStructV1<T>> for MyStructV0 {
|
||||
fn upgrade(self) -> Result<MyStructV1<T>, String> {
|
||||
Ok(MyStructV1(self.0, T::default()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStructV1<T>(pub u32, pub T);
|
||||
|
||||
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV1<T> {
|
||||
fn upgrade(self) -> Result<MyStruct<T>, String> {
|
||||
Ok(MyStruct {
|
||||
count: self.0,
|
||||
attr: T::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub enum MyStructVersions<T: Default> {
|
||||
V0(MyStructV0),
|
||||
V1(MyStructV1<T>),
|
||||
V2(MyStruct<T>),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bincode() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let v0_ser = bincode::serialize(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<MyU64>::unversionize(bincode::deserialize(&v0_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, MyU64::default());
|
||||
|
||||
let v1_ser = bincode::serialize(&v1.versionize()).unwrap();
|
||||
|
||||
let v2 = v2::MyStruct::<MyU64>::unversionize(bincode::deserialize(&v0_ser).unwrap()).unwrap();
|
||||
let v2_from_v1 =
|
||||
v2::MyStruct::<MyU64>::unversionize(bincode::deserialize(&v1_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, MyU64::default());
|
||||
assert_eq!(v2, v2_from_v1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cbor() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let mut v0_ser = Vec::new();
|
||||
ciborium::ser::into_writer(&v0.versionize(), &mut v0_ser).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<MyU64>::unversionize(
|
||||
ciborium::de::from_reader(&mut Cursor::new(&v0_ser)).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, MyU64::default());
|
||||
|
||||
let mut v1_ser = Vec::new();
|
||||
ciborium::ser::into_writer(&v1.versionize(), &mut v1_ser).unwrap();
|
||||
|
||||
let v2 = v2::MyStruct::<MyU64>::unversionize(
|
||||
ciborium::de::from_reader(&mut Cursor::new(&v0_ser)).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let v2_from_v1 = v2::MyStruct::<MyU64>::unversionize(
|
||||
ciborium::de::from_reader(&mut Cursor::new(&v1_ser)).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, MyU64::default());
|
||||
assert_eq!(v2, v2_from_v1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messagepack() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let v0_ser = rmp_serde::to_vec(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<MyU64>::unversionize(rmp_serde::from_slice(&v0_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, MyU64::default());
|
||||
|
||||
let v1_ser = rmp_serde::to_vec(&v1.versionize()).unwrap();
|
||||
|
||||
let v2 = v2::MyStruct::<MyU64>::unversionize(rmp_serde::from_slice(&v0_ser).unwrap()).unwrap();
|
||||
let v2_from_v1 =
|
||||
v2::MyStruct::<MyU64>::unversionize(rmp_serde::from_slice(&v1_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, MyU64::default());
|
||||
assert_eq!(v2, v2_from_v1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let v0_ser = serde_json::to_string(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<MyU64>::unversionize(serde_json::from_str(&v0_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, MyU64::default());
|
||||
|
||||
let v1_ser = serde_json::to_string(&v1.versionize()).unwrap();
|
||||
|
||||
let v2 = v2::MyStruct::<MyU64>::unversionize(serde_json::from_str(&v0_ser).unwrap()).unwrap();
|
||||
let v2_from_v1 =
|
||||
v2::MyStruct::<MyU64>::unversionize(serde_json::from_str(&v1_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, MyU64::default());
|
||||
assert_eq!(v2, v2_from_v1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yaml() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let v0_ser = serde_yaml::to_string(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<MyU64>::unversionize(serde_yaml::from_str(&v0_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, MyU64::default());
|
||||
|
||||
let v1_ser = serde_yaml::to_string(&v1.versionize()).unwrap();
|
||||
|
||||
let v2 = v2::MyStruct::<MyU64>::unversionize(serde_yaml::from_str(&v0_ser).unwrap()).unwrap();
|
||||
let v2_from_v1 =
|
||||
v2::MyStruct::<MyU64>::unversionize(serde_yaml::from_str(&v1_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, MyU64::default());
|
||||
assert_eq!(v2, v2_from_v1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_toml() {
|
||||
let v0 = v0::MyStruct(37);
|
||||
|
||||
let v0_ser = toml::to_string(&v0.versionize()).unwrap();
|
||||
|
||||
let v1 = v1::MyStruct::<MyU64>::unversionize(toml::from_str(&v0_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v1.0);
|
||||
assert_eq!(v1.1, MyU64::default());
|
||||
|
||||
let v1_ser = toml::to_string(&v1.versionize()).unwrap();
|
||||
|
||||
let v2 = v2::MyStruct::<MyU64>::unversionize(toml::from_str(&v0_ser).unwrap()).unwrap();
|
||||
let v2_from_v1 = v2::MyStruct::<MyU64>::unversionize(toml::from_str(&v1_ser).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(v0.0, v2.count);
|
||||
assert_eq!(v2.attr, MyU64::default());
|
||||
assert_eq!(v2, v2_from_v1);
|
||||
}
|
||||
29
utils/tfhe-versionable/tests/testcases/enum.rs
Normal file
29
utils/tfhe-versionable/tests/testcases/enum.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use static_assertions::assert_impl_all;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::{NotVersioned, Version};
|
||||
|
||||
// Simple contentless enum
|
||||
#[derive(Version)]
|
||||
pub enum MyEnum {
|
||||
Variant0,
|
||||
Variant1,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, NotVersioned)]
|
||||
pub struct MyStruct {
|
||||
val: u32,
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
pub enum MyEnum2<T> {
|
||||
Variant1(MyStruct),
|
||||
Variant2 { val1: u64, val2: u32 },
|
||||
Variant3(T),
|
||||
}
|
||||
|
||||
fn main() {
|
||||
assert_impl_all!(MyEnum: Version);
|
||||
|
||||
assert_impl_all!(MyEnum2<u64>: Version);
|
||||
}
|
||||
53
utils/tfhe-versionable/tests/testcases/struct.rs
Normal file
53
utils/tfhe-versionable/tests/testcases/struct.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use static_assertions::assert_impl_all;
|
||||
|
||||
use tfhe_versionable::{Version, Versionize, VersionsDispatch};
|
||||
|
||||
// Empty struct
|
||||
#[derive(Version)]
|
||||
pub struct MyEmptyStruct();
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyEmptyStruct2 {}
|
||||
|
||||
// Simple anonymous struct
|
||||
#[derive(Version)]
|
||||
pub struct MyAnonStruct(u32);
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyAnonStruct2(u32, u64);
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyAnonStruct3<T>(u32, T);
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub struct MyStruct<T> {
|
||||
field0: u64,
|
||||
field1: T,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
pub enum MyStructVersions<T> {
|
||||
V0(MyStruct<T>),
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyStruct2<T, U> {
|
||||
field0: MyStruct<T>,
|
||||
field1: U,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
assert_impl_all!(MyEmptyStruct: Version);
|
||||
assert_impl_all!(MyEmptyStruct2: Version);
|
||||
|
||||
assert_impl_all!(MyAnonStruct: Version);
|
||||
|
||||
assert_impl_all!(MyAnonStruct2: Version);
|
||||
|
||||
assert_impl_all!(MyAnonStruct3<u64>: Version);
|
||||
|
||||
assert_impl_all!(MyStruct<u32>: Version);
|
||||
|
||||
assert_impl_all!(MyStruct2<usize, String>: Version);
|
||||
}
|
||||
10
utils/tfhe-versionable/tests/testcases/unit.rs
Normal file
10
utils/tfhe-versionable/tests/testcases/unit.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use static_assertions::assert_impl_all;
|
||||
|
||||
use tfhe_versionable::Version;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct MyUnit;
|
||||
|
||||
fn main() {
|
||||
assert_impl_all!(MyUnit: Version);
|
||||
}
|
||||
Reference in New Issue
Block a user