feat(frontend-rust): add printer for tensors

This commit is contained in:
Alexandre Péré
2025-04-24 14:44:50 +02:00
parent 2090e21d4a
commit cc90e5f12b

View File

@@ -1,16 +1,18 @@
#![allow(unused_imports,unused)]
use std::{any::Any, marker::PhantomData, pin::Pin};
#![allow(unused_imports, unused)]
use std::any::Any;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::pin::Pin;
use crate::protocol::{
KeysetInfo, LweBootstrapKeyInfo, LweKeyswitchKeyInfo, LweSecretKeyInfo,
PackingKeyswitchKeyInfo, ProgramInfo, CircuitInfo
CircuitInfo, KeysetInfo, LweBootstrapKeyInfo, LweKeyswitchKeyInfo, LweSecretKeyInfo,
PackingKeyswitchKeyInfo, ProgramInfo,
};
use cxx::{SharedPtr, UniquePtr, CxxVector};
use cxx::{CxxVector, SharedPtr, UniquePtr};
#[cxx::bridge(namespace = "concrete_rust")]
mod ffi {
unsafe extern "C++" {
include!("ffi.h");
type c_void;
@@ -179,7 +181,6 @@ mod ffi {
#[doc(hidden)]
fn _deserialize_server_keyset(bytes: &[u8]) -> UniquePtr<ServerKeyset>;
/// A client keyset holding the keys necessary to __encrypt__ input data (and decrypt outputs).
///
/// Warning:
@@ -344,7 +345,10 @@ mod ffi {
encryption_prng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientModule>;
#[doc(hidden)]
fn _get_client_function(self: &ClientModule, name: &str) -> Result<UniquePtr<ClientFunction>>;
fn _get_client_function(
self: &ClientModule,
name: &str,
) -> Result<UniquePtr<ClientFunction>>;
/// Client-side interface to an FHE function.
///
@@ -356,29 +360,45 @@ mod ffi {
fn _client_function_new_encrypted(
circuit_info_json: &str,
client_keyset: &ClientKeyset,
encryption_prng: UniquePtr<EncryptionCsprng>
encryption_prng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientFunction>;
#[doc(hidden)]
fn _client_function_new_simulated(
circuit_info_json: &str,
encryption_prng: UniquePtr<EncryptionCsprng>
encryption_prng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientFunction>;
/// Prepare one function input to be sent to the server.
///
/// Note:
/// -----
/// This include encoding -> encryption -> conversion to serializable value.
fn prepare_input(self: Pin<&mut ClientFunction>, arg: UniquePtr<Value>, pos: usize) -> UniquePtr<TransportValue>;
fn prepare_input(
self: Pin<&mut ClientFunction>,
arg: UniquePtr<Value>,
pos: usize,
) -> UniquePtr<TransportValue>;
/// Process one function output retrieved from the server.
///
/// Note:
/// -----
/// This include conversion from deserializable value -> decryption -> decoding.
fn process_output(self: Pin<&mut ClientFunction>,result: UniquePtr<TransportValue>, pos: usize) -> UniquePtr<Value>;
fn process_output(
self: Pin<&mut ClientFunction>,
result: UniquePtr<TransportValue>,
pos: usize,
) -> UniquePtr<Value>;
#[doc(hidden)]
fn simulate_prepare_input(self: Pin<&mut ClientFunction>,arg: &Value, pos: usize) -> UniquePtr<TransportValue>;
fn simulate_prepare_input(
self: Pin<&mut ClientFunction>,
arg: &Value,
pos: usize,
) -> UniquePtr<TransportValue>;
#[doc(hidden)]
fn simulate_process_output(self: Pin<&mut ClientFunction>,result: &TransportValue, pos: usize) -> UniquePtr<Value>;
fn simulate_process_output(
self: Pin<&mut ClientFunction>,
result: &TransportValue,
pos: usize,
) -> UniquePtr<Value>;
// ------------------------------------------------------------------------------------------- Server
//
@@ -389,31 +409,42 @@ mod ffi {
/// This object allows to invoke the FHE function on the encrypted inputs coming from the client.
type ServerFunction;
#[doc(hidden)]
unsafe fn _server_function_new(circuit_info_json: &str, func: *mut c_void, use_simulation: bool) -> UniquePtr<ServerFunction>;
unsafe fn _server_function_new(
circuit_info_json: &str,
func: *mut c_void,
use_simulation: bool,
) -> UniquePtr<ServerFunction>;
#[doc(hidden)]
fn _call(self: Pin<&mut ServerFunction>, keys: &ServerKeyset, args: &mut [UniquePtr<TransportValue>]) -> UniquePtr<CxxVector<TransportValue>>;
fn _call(
self: Pin<&mut ServerFunction>,
keys: &ServerKeyset,
args: &mut [UniquePtr<TransportValue>],
) -> UniquePtr<CxxVector<TransportValue>>;
#[doc(hidden)]
fn _simulate(self: Pin<&mut ServerFunction>, args: &mut [UniquePtr<TransportValue>]) -> UniquePtr<CxxVector<TransportValue>>;
fn _simulate(
self: Pin<&mut ServerFunction>,
args: &mut [UniquePtr<TransportValue>],
) -> UniquePtr<CxxVector<TransportValue>>;
}
}
pub use ffi::*;
impl ServerKeyset{
impl ServerKeyset {
/// Deserialize a server keyset from bytes.
pub fn deserialize(bytes: &[u8]) -> UniquePtr<ServerKeyset> {
_deserialize_server_keyset(bytes)
}
}
impl ClientKeyset{
impl ClientKeyset {
/// Deserialize a client keyset from bytes.
pub fn deserialize(bytes: &[u8]) -> UniquePtr<ClientKeyset> {
_deserialize_client_keyset(bytes)
}
}
impl TransportValue{
impl TransportValue {
/// Deserialize a `TransportValue` from bytes.
pub fn deserialize(bytes: &[u8]) -> UniquePtr<TransportValue> {
_deserialize_transport_value(bytes)
@@ -421,26 +452,36 @@ impl TransportValue{
}
impl ServerFunction {
#[doc(hidden)]
pub fn new(circuit_info: &CircuitInfo, func: *mut c_void, use_simulation: bool) -> UniquePtr<ServerFunction> {
unsafe{
pub fn new(
circuit_info: &CircuitInfo,
func: *mut c_void,
use_simulation: bool,
) -> UniquePtr<ServerFunction> {
unsafe {
_server_function_new(
&serde_json::to_string(circuit_info).unwrap(),
func,
use_simulation
use_simulation,
)
}
}
/// Performs a call to the FHE function using the `keys` server keyset and the `args` arguments.
pub fn call(self: Pin<&mut ServerFunction>, keys: &ServerKeyset, mut args: Vec<UniquePtr<TransportValue>>) -> Vec<UniquePtr<TransportValue>>{
pub fn call(
self: Pin<&mut ServerFunction>,
keys: &ServerKeyset,
mut args: Vec<UniquePtr<TransportValue>>,
) -> Vec<UniquePtr<TransportValue>> {
let output = self._call(keys, args.as_mut_slice());
output.iter().map(|v| v.to_owned()).collect()
}
#[doc(hidden)]
pub fn simulate(self: Pin<&mut ServerFunction>, mut args: Vec<UniquePtr<TransportValue>>) -> Vec<UniquePtr<TransportValue>>{
pub fn simulate(
self: Pin<&mut ServerFunction>,
mut args: Vec<UniquePtr<TransportValue>>,
) -> Vec<UniquePtr<TransportValue>> {
let output = self._simulate(args.as_mut_slice());
output.iter().map(|v| v.to_owned()).collect()
}
@@ -539,7 +580,9 @@ pub trait GetValues {
impl GetValues for Tensor<u8> {
type Element = u8;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::U8(inner) = &self.inner else {unreachable!()};
let InnerTensor::U8(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -547,7 +590,9 @@ impl GetValues for Tensor<u8> {
impl GetValues for Tensor<u16> {
type Element = u16;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::U16(inner) = &self.inner else {unreachable!()};
let InnerTensor::U16(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -555,7 +600,9 @@ impl GetValues for Tensor<u16> {
impl GetValues for Tensor<u32> {
type Element = u32;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::U32(inner) = &self.inner else {unreachable!()};
let InnerTensor::U32(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -563,7 +610,9 @@ impl GetValues for Tensor<u32> {
impl GetValues for Tensor<u64> {
type Element = u64;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::U64(inner) = &self.inner else {unreachable!()};
let InnerTensor::U64(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -571,7 +620,9 @@ impl GetValues for Tensor<u64> {
impl GetValues for Tensor<i8> {
type Element = i8;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::I8(inner) = &self.inner else {unreachable!()};
let InnerTensor::I8(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -579,7 +630,9 @@ impl GetValues for Tensor<i8> {
impl GetValues for Tensor<i16> {
type Element = i16;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::I16(inner) = &self.inner else {unreachable!()};
let InnerTensor::I16(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -587,7 +640,9 @@ impl GetValues for Tensor<i16> {
impl GetValues for Tensor<i32> {
type Element = i32;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::I32(inner) = &self.inner else {unreachable!()};
let InnerTensor::I32(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
@@ -595,26 +650,139 @@ impl GetValues for Tensor<i32> {
impl GetValues for Tensor<i64> {
type Element = i64;
fn get_values(&self) -> &[Self::Element] {
let InnerTensor::I64(inner) = &self.inner else {unreachable!()};
let InnerTensor::I64(inner) = &self.inner else {
unreachable!()
};
inner._get_values()
}
}
struct TensorPrinter<'a, T> {
vals: &'a [T],
dims: &'a [usize],
}
impl<'a, T> Debug for TensorPrinter<'a, T>
where
T: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let width = self
.vals
.iter()
.map(|v| format!("{:?}", v).len())
.max()
.unwrap();
fn format_recursive<J: Debug>(
f: &mut std::fmt::Formatter<'_>,
vals: &[J],
dims: &[usize],
indent: usize,
width: usize,
) -> std::fmt::Result {
match dims.len() {
1 => {
if f.alternate() {
write!(f, "{:indent$}", "", indent = indent)?;
write!(f, "{:width$?}", vals)?;
} else {
write!(f, "{:?}", vals)?;
}
}
_ => {
let stride = dims[1..].iter().product::<usize>();
if f.alternate() {
write!(f, "{:indent$}", "", indent = indent)?;
}
write!(f, "[")?;
if f.alternate() {
writeln!(f)?;
}
for i in 0..dims[0] {
format_recursive(
f,
&vals[i * stride..(i + 1) * stride],
&dims[1..],
indent + 2,
width,
)?;
write!(f, ",")?;
if f.alternate() {
writeln!(f)?;
}
}
if f.alternate() {
write!(f, "{:indent$}", "", indent = indent)?;
}
write!(f, "]")?;
}
}
Ok(())
}
if self.dims.len() == 0 {
self.vals[0].fmt(f)
} else {
format_recursive(f, &self.vals, &self.dims, 0, width)
}
}
}
/// A generic tensor type.
pub struct Tensor<T>{
pub struct Tensor<T> {
inner: InnerTensor,
phantom: PhantomData<T>
phantom: PhantomData<T>,
}
impl<T> Debug for Tensor<T>
where
Self: GetValues,
<Self as GetValues>::Element: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tensor")
.field(
"values",
&TensorPrinter {
vals: self.get_values(),
dims: self.dimensions(),
},
)
.field("dimensions", &self.dimensions())
.field(
"dtype",
match self.inner {
InnerTensor::U8(_) => &"u8",
InnerTensor::U16(_) => &"u16",
InnerTensor::U32(_) => &"u32",
InnerTensor::U64(_) => &"u64",
InnerTensor::I8(_) => &"i8",
InnerTensor::I16(_) => &"i16",
InnerTensor::I32(_) => &"i32",
InnerTensor::I64(_) => &"i64",
},
)
.finish()
}
}
impl<T> Tensor<T> {
/// Create a new tensor from values and dimensions (shape).
pub fn new(values: Vec<T>, dimensions: Vec<usize>) -> Tensor<T> where Tensor<T>:FromElements<Element=T> {
pub fn new(values: Vec<T>, dimensions: Vec<usize>) -> Tensor<T>
where
Tensor<T>: FromElements<Element = T>,
{
assert_eq!(
values.len(),
dimensions.iter().product::<usize>(),
"Wrong number of dimensions provided"
);
Self::from_elements(values, dimensions)
}
/// Return the dimensions of the tensor.
pub fn dimensions(&self) -> &[usize] {
match self.inner{
match self.inner {
InnerTensor::U8(ref unique_ptr) => unique_ptr._get_dimensions(),
InnerTensor::U16(ref unique_ptr) => unique_ptr._get_dimensions(),
InnerTensor::U32(ref unique_ptr) => unique_ptr._get_dimensions(),
@@ -627,7 +795,10 @@ impl<T> Tensor<T> {
}
/// Return the values of the tensor.
pub fn values(&self) -> &[T] where Tensor<T>:GetValues<Element=T> {
pub fn values(&self) -> &[T]
where
Tensor<T>: GetValues<Element = T>,
{
self.get_values()
}
}
@@ -643,102 +814,150 @@ enum InnerTensor {
I64(UniquePtr<TensorI64>),
}
impl Debug for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self._has_element_type_u8() {
f.debug_tuple("Value")
.field(&Tensor::<u8> {
inner: InnerTensor::U8(self._get_tensor_u8()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_u16() {
f.debug_tuple("Value")
.field(&Tensor::<u16> {
inner: InnerTensor::U16(self._get_tensor_u16()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_u32() {
f.debug_tuple("Value")
.field(&Tensor::<u32> {
inner: InnerTensor::U32(self._get_tensor_u32()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_u64() {
f.debug_tuple("Value")
.field(&Tensor::<u64> {
inner: InnerTensor::U64(self._get_tensor_u64()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_i8() {
f.debug_tuple("Value")
.field(&Tensor::<i8> {
inner: InnerTensor::I8(self._get_tensor_i8()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_i16() {
f.debug_tuple("Value")
.field(&Tensor::<i16> {
inner: InnerTensor::I16(self._get_tensor_i16()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_i32() {
f.debug_tuple("Value")
.field(&Tensor::<i32> {
inner: InnerTensor::I32(self._get_tensor_i32()),
phantom: PhantomData,
})
.finish()
} else if self._has_element_type_i64() {
f.debug_tuple("Value")
.field(&Tensor::<i64> {
inner: InnerTensor::I64(self._get_tensor_i64()),
phantom: PhantomData,
})
.finish()
} else {
unreachable!()
}
}
}
pub trait GetTensor<T> {
fn _get_tensor(&self) -> Option<Tensor<T>>;
}
impl GetTensor<u8> for Value {
fn _get_tensor(&self) -> Option<Tensor<u8>> {
self._has_element_type_u8().then(|| {
Tensor {
inner: InnerTensor::U8(self._get_tensor_u8()),
phantom: PhantomData,
}
self._has_element_type_u8().then(|| Tensor {
inner: InnerTensor::U8(self._get_tensor_u8()),
phantom: PhantomData,
})
}
}
impl GetTensor<u16> for Value {
fn _get_tensor(&self) -> Option<Tensor<u16>> {
self._has_element_type_u16().then(|| {
Tensor {
inner: InnerTensor::U16(self._get_tensor_u16()),
phantom: PhantomData,
}
self._has_element_type_u16().then(|| Tensor {
inner: InnerTensor::U16(self._get_tensor_u16()),
phantom: PhantomData,
})
}
}
impl GetTensor<u32> for Value {
fn _get_tensor(&self) -> Option<Tensor<u32>> {
self._has_element_type_u32().then(|| {
Tensor {
inner: InnerTensor::U32(self._get_tensor_u32()),
phantom: PhantomData,
}
self._has_element_type_u32().then(|| Tensor {
inner: InnerTensor::U32(self._get_tensor_u32()),
phantom: PhantomData,
})
}
}
impl GetTensor<u64> for Value {
fn _get_tensor(&self) -> Option<Tensor<u64>> {
self._has_element_type_u64().then(|| {
Tensor {
inner: InnerTensor::U64(self._get_tensor_u64()),
phantom: PhantomData,
}
self._has_element_type_u64().then(|| Tensor {
inner: InnerTensor::U64(self._get_tensor_u64()),
phantom: PhantomData,
})
}
}
impl GetTensor<i8> for Value {
fn _get_tensor(&self) -> Option<Tensor<i8>> {
self._has_element_type_i8().then(|| {
Tensor {
inner: InnerTensor::I8(self._get_tensor_i8()),
phantom: PhantomData,
}
self._has_element_type_i8().then(|| Tensor {
inner: InnerTensor::I8(self._get_tensor_i8()),
phantom: PhantomData,
})
}
}
impl GetTensor<i16> for Value {
fn _get_tensor(&self) -> Option<Tensor<i16>> {
self._has_element_type_i16().then(|| {
Tensor {
inner: InnerTensor::I16(self._get_tensor_i16()),
phantom: PhantomData,
}
self._has_element_type_i16().then(|| Tensor {
inner: InnerTensor::I16(self._get_tensor_i16()),
phantom: PhantomData,
})
}
}
impl GetTensor<i32> for Value {
fn _get_tensor(&self) -> Option<Tensor<i32>> {
self._has_element_type_i32().then(|| {
Tensor {
inner: InnerTensor::I32(self._get_tensor_i32()),
phantom: PhantomData,
}
self._has_element_type_i32().then(|| Tensor {
inner: InnerTensor::I32(self._get_tensor_i32()),
phantom: PhantomData,
})
}
}
impl GetTensor<i64> for Value {
fn _get_tensor(&self) -> Option<Tensor<i64>> {
self._has_element_type_i64().then(|| {
Tensor {
inner: InnerTensor::I64(self._get_tensor_i64()),
phantom: PhantomData,
}
self._has_element_type_i64().then(|| Tensor {
inner: InnerTensor::I64(self._get_tensor_i64()),
phantom: PhantomData,
})
}
}
impl Value{
impl Value {
/// Create a `Value` from a `Tensor`.
pub fn from_tensor<T>(input: Tensor<T>) -> UniquePtr<Value> {
match input.inner{
match input.inner {
InnerTensor::U8(unique_ptr) => _value_from_tensor_u8(unique_ptr),
InnerTensor::U16(unique_ptr) => _value_from_tensor_u16(unique_ptr),
InnerTensor::U32(unique_ptr) => _value_from_tensor_u32(unique_ptr),
@@ -751,14 +970,21 @@ impl Value{
}
/// Unwrap the value to a tensor of the given type (if it indeed holds a tensor with elements of this type).
pub fn get_tensor<T>(&self) -> Option<Tensor<T>> where Self: GetTensor<T>{
pub fn get_tensor<T>(&self) -> Option<Tensor<T>>
where
Self: GetTensor<T>,
{
self._get_tensor()
}
}
impl ClientModule {
/// Create a new client module.
pub fn new_encrypted(program_info: &ProgramInfo, client_keyset: &ClientKeyset, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientModule> {
pub fn new_encrypted(
program_info: &ProgramInfo,
client_keyset: &ClientKeyset,
csprng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientModule> {
_client_module_new_encrypted(
&serde_json::to_string(program_info).unwrap(),
client_keyset,
@@ -767,17 +993,21 @@ impl ClientModule {
}
#[doc(hidden)]
pub fn new_simulated(program_info: &ProgramInfo, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientModule> {
_client_module_new_simulated(
&serde_json::to_string(program_info).unwrap(),
csprng,
)
pub fn new_simulated(
program_info: &ProgramInfo,
csprng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientModule> {
_client_module_new_simulated(&serde_json::to_string(program_info).unwrap(), csprng)
}
}
impl ClientFunction {
/// Create a new client function.
pub fn new_encrypted(circuit_info: &CircuitInfo, client_keyset: &ClientKeyset, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientFunction> {
pub fn new_encrypted(
circuit_info: &CircuitInfo,
client_keyset: &ClientKeyset,
csprng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientFunction> {
_client_function_new_encrypted(
&serde_json::to_string(circuit_info).unwrap(),
client_keyset,
@@ -786,16 +1016,15 @@ impl ClientFunction {
}
#[doc(hidden)]
pub fn new_simulated(circuit_info: &CircuitInfo, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientFunction> {
_client_function_new_simulated(
&serde_json::to_string(circuit_info).unwrap(),
csprng,
)
pub fn new_simulated(
circuit_info: &CircuitInfo,
csprng: UniquePtr<EncryptionCsprng>,
) -> UniquePtr<ClientFunction> {
_client_function_new_simulated(&serde_json::to_string(circuit_info).unwrap(), csprng)
}
}
impl ServerKeyset {
/// Return references to the lwe bootstrap keys of this server keyset.
pub fn lwe_bootstrap_keys(&self) -> Vec<&LweBootstrapKey> {
(0..self._lwe_bootstrap_keys_len())
@@ -900,7 +1129,7 @@ impl LweBootstrapKey {
}
}
impl std::fmt::Debug for TransportValue{
impl std::fmt::Debug for TransportValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "TransportValue")
}