mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(frontend-rust): add printer for tensors
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
#![allow(unused_imports,unused)]
|
||||
use std::{any::Any, marker::PhantomData, pin::Pin};
|
||||
#![allow(unused_imports, unused)]
|
||||
use std::any::Any;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::protocol::{
|
||||
KeysetInfo, LweBootstrapKeyInfo, LweKeyswitchKeyInfo, LweSecretKeyInfo,
|
||||
PackingKeyswitchKeyInfo, ProgramInfo, CircuitInfo
|
||||
CircuitInfo, KeysetInfo, LweBootstrapKeyInfo, LweKeyswitchKeyInfo, LweSecretKeyInfo,
|
||||
PackingKeyswitchKeyInfo, ProgramInfo,
|
||||
};
|
||||
use cxx::{SharedPtr, UniquePtr, CxxVector};
|
||||
use cxx::{CxxVector, SharedPtr, UniquePtr};
|
||||
|
||||
#[cxx::bridge(namespace = "concrete_rust")]
|
||||
mod ffi {
|
||||
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("ffi.h");
|
||||
type c_void;
|
||||
@@ -179,7 +181,6 @@ mod ffi {
|
||||
#[doc(hidden)]
|
||||
fn _deserialize_server_keyset(bytes: &[u8]) -> UniquePtr<ServerKeyset>;
|
||||
|
||||
|
||||
/// A client keyset holding the keys necessary to __encrypt__ input data (and decrypt outputs).
|
||||
///
|
||||
/// Warning:
|
||||
@@ -344,7 +345,10 @@ mod ffi {
|
||||
encryption_prng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientModule>;
|
||||
#[doc(hidden)]
|
||||
fn _get_client_function(self: &ClientModule, name: &str) -> Result<UniquePtr<ClientFunction>>;
|
||||
fn _get_client_function(
|
||||
self: &ClientModule,
|
||||
name: &str,
|
||||
) -> Result<UniquePtr<ClientFunction>>;
|
||||
|
||||
/// Client-side interface to an FHE function.
|
||||
///
|
||||
@@ -356,29 +360,45 @@ mod ffi {
|
||||
fn _client_function_new_encrypted(
|
||||
circuit_info_json: &str,
|
||||
client_keyset: &ClientKeyset,
|
||||
encryption_prng: UniquePtr<EncryptionCsprng>
|
||||
encryption_prng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientFunction>;
|
||||
#[doc(hidden)]
|
||||
fn _client_function_new_simulated(
|
||||
circuit_info_json: &str,
|
||||
encryption_prng: UniquePtr<EncryptionCsprng>
|
||||
encryption_prng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientFunction>;
|
||||
/// Prepare one function input to be sent to the server.
|
||||
///
|
||||
/// Note:
|
||||
/// -----
|
||||
/// This include encoding -> encryption -> conversion to serializable value.
|
||||
fn prepare_input(self: Pin<&mut ClientFunction>, arg: UniquePtr<Value>, pos: usize) -> UniquePtr<TransportValue>;
|
||||
fn prepare_input(
|
||||
self: Pin<&mut ClientFunction>,
|
||||
arg: UniquePtr<Value>,
|
||||
pos: usize,
|
||||
) -> UniquePtr<TransportValue>;
|
||||
/// Process one function output retrieved from the server.
|
||||
///
|
||||
/// Note:
|
||||
/// -----
|
||||
/// This include conversion from deserializable value -> decryption -> decoding.
|
||||
fn process_output(self: Pin<&mut ClientFunction>,result: UniquePtr<TransportValue>, pos: usize) -> UniquePtr<Value>;
|
||||
fn process_output(
|
||||
self: Pin<&mut ClientFunction>,
|
||||
result: UniquePtr<TransportValue>,
|
||||
pos: usize,
|
||||
) -> UniquePtr<Value>;
|
||||
#[doc(hidden)]
|
||||
fn simulate_prepare_input(self: Pin<&mut ClientFunction>,arg: &Value, pos: usize) -> UniquePtr<TransportValue>;
|
||||
fn simulate_prepare_input(
|
||||
self: Pin<&mut ClientFunction>,
|
||||
arg: &Value,
|
||||
pos: usize,
|
||||
) -> UniquePtr<TransportValue>;
|
||||
#[doc(hidden)]
|
||||
fn simulate_process_output(self: Pin<&mut ClientFunction>,result: &TransportValue, pos: usize) -> UniquePtr<Value>;
|
||||
fn simulate_process_output(
|
||||
self: Pin<&mut ClientFunction>,
|
||||
result: &TransportValue,
|
||||
pos: usize,
|
||||
) -> UniquePtr<Value>;
|
||||
|
||||
// ------------------------------------------------------------------------------------------- Server
|
||||
//
|
||||
@@ -389,31 +409,42 @@ mod ffi {
|
||||
/// This object allows to invoke the FHE function on the encrypted inputs coming from the client.
|
||||
type ServerFunction;
|
||||
#[doc(hidden)]
|
||||
unsafe fn _server_function_new(circuit_info_json: &str, func: *mut c_void, use_simulation: bool) -> UniquePtr<ServerFunction>;
|
||||
unsafe fn _server_function_new(
|
||||
circuit_info_json: &str,
|
||||
func: *mut c_void,
|
||||
use_simulation: bool,
|
||||
) -> UniquePtr<ServerFunction>;
|
||||
#[doc(hidden)]
|
||||
fn _call(self: Pin<&mut ServerFunction>, keys: &ServerKeyset, args: &mut [UniquePtr<TransportValue>]) -> UniquePtr<CxxVector<TransportValue>>;
|
||||
fn _call(
|
||||
self: Pin<&mut ServerFunction>,
|
||||
keys: &ServerKeyset,
|
||||
args: &mut [UniquePtr<TransportValue>],
|
||||
) -> UniquePtr<CxxVector<TransportValue>>;
|
||||
#[doc(hidden)]
|
||||
fn _simulate(self: Pin<&mut ServerFunction>, args: &mut [UniquePtr<TransportValue>]) -> UniquePtr<CxxVector<TransportValue>>;
|
||||
fn _simulate(
|
||||
self: Pin<&mut ServerFunction>,
|
||||
args: &mut [UniquePtr<TransportValue>],
|
||||
) -> UniquePtr<CxxVector<TransportValue>>;
|
||||
}
|
||||
}
|
||||
|
||||
pub use ffi::*;
|
||||
|
||||
impl ServerKeyset{
|
||||
impl ServerKeyset {
|
||||
/// Deserialize a server keyset from bytes.
|
||||
pub fn deserialize(bytes: &[u8]) -> UniquePtr<ServerKeyset> {
|
||||
_deserialize_server_keyset(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientKeyset{
|
||||
impl ClientKeyset {
|
||||
/// Deserialize a client keyset from bytes.
|
||||
pub fn deserialize(bytes: &[u8]) -> UniquePtr<ClientKeyset> {
|
||||
_deserialize_client_keyset(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl TransportValue{
|
||||
impl TransportValue {
|
||||
/// Deserialize a `TransportValue` from bytes.
|
||||
pub fn deserialize(bytes: &[u8]) -> UniquePtr<TransportValue> {
|
||||
_deserialize_transport_value(bytes)
|
||||
@@ -421,26 +452,36 @@ impl TransportValue{
|
||||
}
|
||||
|
||||
impl ServerFunction {
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn new(circuit_info: &CircuitInfo, func: *mut c_void, use_simulation: bool) -> UniquePtr<ServerFunction> {
|
||||
unsafe{
|
||||
pub fn new(
|
||||
circuit_info: &CircuitInfo,
|
||||
func: *mut c_void,
|
||||
use_simulation: bool,
|
||||
) -> UniquePtr<ServerFunction> {
|
||||
unsafe {
|
||||
_server_function_new(
|
||||
&serde_json::to_string(circuit_info).unwrap(),
|
||||
func,
|
||||
use_simulation
|
||||
use_simulation,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs a call to the FHE function using the `keys` server keyset and the `args` arguments.
|
||||
pub fn call(self: Pin<&mut ServerFunction>, keys: &ServerKeyset, mut args: Vec<UniquePtr<TransportValue>>) -> Vec<UniquePtr<TransportValue>>{
|
||||
pub fn call(
|
||||
self: Pin<&mut ServerFunction>,
|
||||
keys: &ServerKeyset,
|
||||
mut args: Vec<UniquePtr<TransportValue>>,
|
||||
) -> Vec<UniquePtr<TransportValue>> {
|
||||
let output = self._call(keys, args.as_mut_slice());
|
||||
output.iter().map(|v| v.to_owned()).collect()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn simulate(self: Pin<&mut ServerFunction>, mut args: Vec<UniquePtr<TransportValue>>) -> Vec<UniquePtr<TransportValue>>{
|
||||
pub fn simulate(
|
||||
self: Pin<&mut ServerFunction>,
|
||||
mut args: Vec<UniquePtr<TransportValue>>,
|
||||
) -> Vec<UniquePtr<TransportValue>> {
|
||||
let output = self._simulate(args.as_mut_slice());
|
||||
output.iter().map(|v| v.to_owned()).collect()
|
||||
}
|
||||
@@ -539,7 +580,9 @@ pub trait GetValues {
|
||||
impl GetValues for Tensor<u8> {
|
||||
type Element = u8;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::U8(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::U8(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -547,7 +590,9 @@ impl GetValues for Tensor<u8> {
|
||||
impl GetValues for Tensor<u16> {
|
||||
type Element = u16;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::U16(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::U16(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -555,7 +600,9 @@ impl GetValues for Tensor<u16> {
|
||||
impl GetValues for Tensor<u32> {
|
||||
type Element = u32;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::U32(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::U32(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -563,7 +610,9 @@ impl GetValues for Tensor<u32> {
|
||||
impl GetValues for Tensor<u64> {
|
||||
type Element = u64;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::U64(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::U64(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -571,7 +620,9 @@ impl GetValues for Tensor<u64> {
|
||||
impl GetValues for Tensor<i8> {
|
||||
type Element = i8;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::I8(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::I8(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -579,7 +630,9 @@ impl GetValues for Tensor<i8> {
|
||||
impl GetValues for Tensor<i16> {
|
||||
type Element = i16;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::I16(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::I16(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -587,7 +640,9 @@ impl GetValues for Tensor<i16> {
|
||||
impl GetValues for Tensor<i32> {
|
||||
type Element = i32;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::I32(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::I32(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
@@ -595,26 +650,139 @@ impl GetValues for Tensor<i32> {
|
||||
impl GetValues for Tensor<i64> {
|
||||
type Element = i64;
|
||||
fn get_values(&self) -> &[Self::Element] {
|
||||
let InnerTensor::I64(inner) = &self.inner else {unreachable!()};
|
||||
let InnerTensor::I64(inner) = &self.inner else {
|
||||
unreachable!()
|
||||
};
|
||||
inner._get_values()
|
||||
}
|
||||
}
|
||||
|
||||
struct TensorPrinter<'a, T> {
|
||||
vals: &'a [T],
|
||||
dims: &'a [usize],
|
||||
}
|
||||
|
||||
impl<'a, T> Debug for TensorPrinter<'a, T>
|
||||
where
|
||||
T: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let width = self
|
||||
.vals
|
||||
.iter()
|
||||
.map(|v| format!("{:?}", v).len())
|
||||
.max()
|
||||
.unwrap();
|
||||
fn format_recursive<J: Debug>(
|
||||
f: &mut std::fmt::Formatter<'_>,
|
||||
vals: &[J],
|
||||
dims: &[usize],
|
||||
indent: usize,
|
||||
width: usize,
|
||||
) -> std::fmt::Result {
|
||||
match dims.len() {
|
||||
1 => {
|
||||
if f.alternate() {
|
||||
write!(f, "{:indent$}", "", indent = indent)?;
|
||||
write!(f, "{:width$?}", vals)?;
|
||||
} else {
|
||||
write!(f, "{:?}", vals)?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let stride = dims[1..].iter().product::<usize>();
|
||||
if f.alternate() {
|
||||
write!(f, "{:indent$}", "", indent = indent)?;
|
||||
}
|
||||
write!(f, "[")?;
|
||||
if f.alternate() {
|
||||
writeln!(f)?;
|
||||
}
|
||||
for i in 0..dims[0] {
|
||||
format_recursive(
|
||||
f,
|
||||
&vals[i * stride..(i + 1) * stride],
|
||||
&dims[1..],
|
||||
indent + 2,
|
||||
width,
|
||||
)?;
|
||||
write!(f, ",")?;
|
||||
if f.alternate() {
|
||||
writeln!(f)?;
|
||||
}
|
||||
}
|
||||
if f.alternate() {
|
||||
write!(f, "{:indent$}", "", indent = indent)?;
|
||||
}
|
||||
write!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
if self.dims.len() == 0 {
|
||||
self.vals[0].fmt(f)
|
||||
} else {
|
||||
format_recursive(f, &self.vals, &self.dims, 0, width)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A generic tensor type.
|
||||
pub struct Tensor<T>{
|
||||
pub struct Tensor<T> {
|
||||
inner: InnerTensor,
|
||||
phantom: PhantomData<T>
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Debug for Tensor<T>
|
||||
where
|
||||
Self: GetValues,
|
||||
<Self as GetValues>::Element: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Tensor")
|
||||
.field(
|
||||
"values",
|
||||
&TensorPrinter {
|
||||
vals: self.get_values(),
|
||||
dims: self.dimensions(),
|
||||
},
|
||||
)
|
||||
.field("dimensions", &self.dimensions())
|
||||
.field(
|
||||
"dtype",
|
||||
match self.inner {
|
||||
InnerTensor::U8(_) => &"u8",
|
||||
InnerTensor::U16(_) => &"u16",
|
||||
InnerTensor::U32(_) => &"u32",
|
||||
InnerTensor::U64(_) => &"u64",
|
||||
InnerTensor::I8(_) => &"i8",
|
||||
InnerTensor::I16(_) => &"i16",
|
||||
InnerTensor::I32(_) => &"i32",
|
||||
InnerTensor::I64(_) => &"i64",
|
||||
},
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Tensor<T> {
|
||||
/// Create a new tensor from values and dimensions (shape).
|
||||
pub fn new(values: Vec<T>, dimensions: Vec<usize>) -> Tensor<T> where Tensor<T>:FromElements<Element=T> {
|
||||
pub fn new(values: Vec<T>, dimensions: Vec<usize>) -> Tensor<T>
|
||||
where
|
||||
Tensor<T>: FromElements<Element = T>,
|
||||
{
|
||||
assert_eq!(
|
||||
values.len(),
|
||||
dimensions.iter().product::<usize>(),
|
||||
"Wrong number of dimensions provided"
|
||||
);
|
||||
Self::from_elements(values, dimensions)
|
||||
}
|
||||
|
||||
/// Return the dimensions of the tensor.
|
||||
pub fn dimensions(&self) -> &[usize] {
|
||||
match self.inner{
|
||||
match self.inner {
|
||||
InnerTensor::U8(ref unique_ptr) => unique_ptr._get_dimensions(),
|
||||
InnerTensor::U16(ref unique_ptr) => unique_ptr._get_dimensions(),
|
||||
InnerTensor::U32(ref unique_ptr) => unique_ptr._get_dimensions(),
|
||||
@@ -627,7 +795,10 @@ impl<T> Tensor<T> {
|
||||
}
|
||||
|
||||
/// Return the values of the tensor.
|
||||
pub fn values(&self) -> &[T] where Tensor<T>:GetValues<Element=T> {
|
||||
pub fn values(&self) -> &[T]
|
||||
where
|
||||
Tensor<T>: GetValues<Element = T>,
|
||||
{
|
||||
self.get_values()
|
||||
}
|
||||
}
|
||||
@@ -643,102 +814,150 @@ enum InnerTensor {
|
||||
I64(UniquePtr<TensorI64>),
|
||||
}
|
||||
|
||||
impl Debug for Value {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if self._has_element_type_u8() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<u8> {
|
||||
inner: InnerTensor::U8(self._get_tensor_u8()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_u16() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<u16> {
|
||||
inner: InnerTensor::U16(self._get_tensor_u16()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_u32() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<u32> {
|
||||
inner: InnerTensor::U32(self._get_tensor_u32()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_u64() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<u64> {
|
||||
inner: InnerTensor::U64(self._get_tensor_u64()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_i8() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<i8> {
|
||||
inner: InnerTensor::I8(self._get_tensor_i8()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_i16() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<i16> {
|
||||
inner: InnerTensor::I16(self._get_tensor_i16()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_i32() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<i32> {
|
||||
inner: InnerTensor::I32(self._get_tensor_i32()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else if self._has_element_type_i64() {
|
||||
f.debug_tuple("Value")
|
||||
.field(&Tensor::<i64> {
|
||||
inner: InnerTensor::I64(self._get_tensor_i64()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
.finish()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GetTensor<T> {
|
||||
fn _get_tensor(&self) -> Option<Tensor<T>>;
|
||||
}
|
||||
|
||||
impl GetTensor<u8> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<u8>> {
|
||||
self._has_element_type_u8().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::U8(self._get_tensor_u8()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_u8().then(|| Tensor {
|
||||
inner: InnerTensor::U8(self._get_tensor_u8()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<u16> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<u16>> {
|
||||
self._has_element_type_u16().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::U16(self._get_tensor_u16()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_u16().then(|| Tensor {
|
||||
inner: InnerTensor::U16(self._get_tensor_u16()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<u32> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<u32>> {
|
||||
self._has_element_type_u32().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::U32(self._get_tensor_u32()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_u32().then(|| Tensor {
|
||||
inner: InnerTensor::U32(self._get_tensor_u32()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<u64> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<u64>> {
|
||||
self._has_element_type_u64().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::U64(self._get_tensor_u64()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_u64().then(|| Tensor {
|
||||
inner: InnerTensor::U64(self._get_tensor_u64()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<i8> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<i8>> {
|
||||
self._has_element_type_i8().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::I8(self._get_tensor_i8()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_i8().then(|| Tensor {
|
||||
inner: InnerTensor::I8(self._get_tensor_i8()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<i16> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<i16>> {
|
||||
self._has_element_type_i16().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::I16(self._get_tensor_i16()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_i16().then(|| Tensor {
|
||||
inner: InnerTensor::I16(self._get_tensor_i16()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<i32> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<i32>> {
|
||||
self._has_element_type_i32().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::I32(self._get_tensor_i32()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_i32().then(|| Tensor {
|
||||
inner: InnerTensor::I32(self._get_tensor_i32()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GetTensor<i64> for Value {
|
||||
fn _get_tensor(&self) -> Option<Tensor<i64>> {
|
||||
self._has_element_type_i64().then(|| {
|
||||
Tensor {
|
||||
inner: InnerTensor::I64(self._get_tensor_i64()),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
self._has_element_type_i64().then(|| Tensor {
|
||||
inner: InnerTensor::I64(self._get_tensor_i64()),
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Value{
|
||||
impl Value {
|
||||
/// Create a `Value` from a `Tensor`.
|
||||
pub fn from_tensor<T>(input: Tensor<T>) -> UniquePtr<Value> {
|
||||
match input.inner{
|
||||
match input.inner {
|
||||
InnerTensor::U8(unique_ptr) => _value_from_tensor_u8(unique_ptr),
|
||||
InnerTensor::U16(unique_ptr) => _value_from_tensor_u16(unique_ptr),
|
||||
InnerTensor::U32(unique_ptr) => _value_from_tensor_u32(unique_ptr),
|
||||
@@ -751,14 +970,21 @@ impl Value{
|
||||
}
|
||||
|
||||
/// Unwrap the value to a tensor of the given type (if it indeed holds a tensor with elements of this type).
|
||||
pub fn get_tensor<T>(&self) -> Option<Tensor<T>> where Self: GetTensor<T>{
|
||||
pub fn get_tensor<T>(&self) -> Option<Tensor<T>>
|
||||
where
|
||||
Self: GetTensor<T>,
|
||||
{
|
||||
self._get_tensor()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientModule {
|
||||
/// Create a new client module.
|
||||
pub fn new_encrypted(program_info: &ProgramInfo, client_keyset: &ClientKeyset, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientModule> {
|
||||
pub fn new_encrypted(
|
||||
program_info: &ProgramInfo,
|
||||
client_keyset: &ClientKeyset,
|
||||
csprng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientModule> {
|
||||
_client_module_new_encrypted(
|
||||
&serde_json::to_string(program_info).unwrap(),
|
||||
client_keyset,
|
||||
@@ -767,17 +993,21 @@ impl ClientModule {
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn new_simulated(program_info: &ProgramInfo, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientModule> {
|
||||
_client_module_new_simulated(
|
||||
&serde_json::to_string(program_info).unwrap(),
|
||||
csprng,
|
||||
)
|
||||
pub fn new_simulated(
|
||||
program_info: &ProgramInfo,
|
||||
csprng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientModule> {
|
||||
_client_module_new_simulated(&serde_json::to_string(program_info).unwrap(), csprng)
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientFunction {
|
||||
/// Create a new client function.
|
||||
pub fn new_encrypted(circuit_info: &CircuitInfo, client_keyset: &ClientKeyset, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientFunction> {
|
||||
pub fn new_encrypted(
|
||||
circuit_info: &CircuitInfo,
|
||||
client_keyset: &ClientKeyset,
|
||||
csprng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientFunction> {
|
||||
_client_function_new_encrypted(
|
||||
&serde_json::to_string(circuit_info).unwrap(),
|
||||
client_keyset,
|
||||
@@ -786,16 +1016,15 @@ impl ClientFunction {
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn new_simulated(circuit_info: &CircuitInfo, csprng: UniquePtr<EncryptionCsprng>) -> UniquePtr<ClientFunction> {
|
||||
_client_function_new_simulated(
|
||||
&serde_json::to_string(circuit_info).unwrap(),
|
||||
csprng,
|
||||
)
|
||||
pub fn new_simulated(
|
||||
circuit_info: &CircuitInfo,
|
||||
csprng: UniquePtr<EncryptionCsprng>,
|
||||
) -> UniquePtr<ClientFunction> {
|
||||
_client_function_new_simulated(&serde_json::to_string(circuit_info).unwrap(), csprng)
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerKeyset {
|
||||
|
||||
/// Return references to the lwe bootstrap keys of this server keyset.
|
||||
pub fn lwe_bootstrap_keys(&self) -> Vec<&LweBootstrapKey> {
|
||||
(0..self._lwe_bootstrap_keys_len())
|
||||
@@ -900,7 +1129,7 @@ impl LweBootstrapKey {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TransportValue{
|
||||
impl std::fmt::Debug for TransportValue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "TransportValue")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user