From d5cc6559ee69907a5e4cb2e58320f339008e7601 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" Date: Thu, 12 May 2022 16:09:25 +0200 Subject: [PATCH] feat(interface): add input graph creation --- concrete-optimizer-cpp/build.rs | 2 +- .../src/concrete-optimizer.rs | 133 ++- .../src/cpp/concrete-optimizer.cpp | 783 +++++++++++++++++- .../src/cpp/concrete-optimizer.hpp | 673 ++++++++++++++- concrete-optimizer-cpp/tests/src/main.cpp | 48 +- .../src/graph/operator/tensor.rs | 2 +- 6 files changed, 1622 insertions(+), 19 deletions(-) diff --git a/concrete-optimizer-cpp/build.rs b/concrete-optimizer-cpp/build.rs index df7cae062..d8a64227f 100644 --- a/concrete-optimizer-cpp/build.rs +++ b/concrete-optimizer-cpp/build.rs @@ -1,5 +1,5 @@ fn main() { let _build = cxx_build::bridge("src/concrete-optimizer.rs"); - println!("cargo:rerun-if-changed=src/lib.rs"); + println!("cargo:rerun-if-changed=src/"); } diff --git a/concrete-optimizer-cpp/src/concrete-optimizer.rs b/concrete-optimizer-cpp/src/concrete-optimizer.rs index df19d62cd..4942e4c2c 100644 --- a/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -1,3 +1,8 @@ +use concrete_optimizer::graph::operator::{ + self, FunctionTable, LevelledComplexity, OperatorIndex, Shape, +}; +use concrete_optimizer::graph::unparametrized; + fn no_solution() -> ffi::Solution { ffi::Solution { p_error: 1.0, // error probability to signal an impossible solution @@ -5,7 +10,7 @@ fn no_solution() -> ffi::Solution { } } -fn optimise_bootstrap( +fn optimize_bootstrap( precision: u64, security_level: u64, noise_factor: f64, @@ -51,19 +56,141 @@ impl From for ffi::S } } +pub struct OperationDag(unparametrized::OperationDag); + +fn empty() -> Box { + Box::new(OperationDag(unparametrized::OperationDag::new())) +} + +impl OperationDag { + fn add_input(&mut self, out_precision: u8, out_shape: &[u64]) -> ffi::OperatorIndex { + let out_shape = Shape { + dimensions_size: out_shape.to_owned(), + }; + + self.0.add_input(out_precision, out_shape).into() + } + + fn add_lut(&mut self, input: ffi::OperatorIndex, table: &[u64]) -> ffi::OperatorIndex { + let table = FunctionTable { + values: table.to_owned(), + }; + + self.0.add_lut(input.into(), table).into() + } + + #[allow(clippy::boxed_local)] + fn add_dot( + &mut self, + inputs: &[ffi::OperatorIndex], + weights: Box, + ) -> ffi::OperatorIndex { + let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); + + self.0.add_dot(&inputs, &weights.0).into() + } + + fn add_levelled_op( + &mut self, + inputs: &[ffi::OperatorIndex], + lwe_dim_cost_factor: f64, + fixed_cost: f64, + manp: f64, + out_shape: &[u64], + comment: &str, + ) -> ffi::OperatorIndex { + let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); + + let out_shape = Shape { + dimensions_size: out_shape.to_owned(), + }; + + let complexity = LevelledComplexity { + lwe_dim_cost_factor, + fixed_cost, + }; + + self.0 + .add_levelled_op(&inputs, complexity, manp, out_shape, comment) + .into() + } +} + +pub struct Weights(operator::Weights); + +fn vector(weights: &[u64]) -> Box { + Box::new(Weights(operator::Weights::vector(weights))) +} + +impl From for ffi::OperatorIndex { + fn from(oi: OperatorIndex) -> Self { + Self { index: oi.i } + } +} + +#[allow(clippy::from_over_into)] +impl Into for ffi::OperatorIndex { + fn into(self) -> OperatorIndex { + OperatorIndex { i: self.index } + } +} + #[cxx::bridge] mod ffi { + #[namespace = "concrete_optimizer"] extern "Rust" { - fn optimise_bootstrap( + + #[namespace = "concrete_optimizer::v0"] + fn optimize_bootstrap( precision: u64, security_level: u64, noise_factor: f64, maximum_acceptable_error_probability: f64, ) -> Solution; + + type OperationDag; + + #[namespace = "concrete_optimizer::dag"] + fn empty() -> Box; + + fn add_input( + self: &mut OperationDag, + out_precision: u8, + out_shape: &[u64], + ) -> OperatorIndex; + + fn add_lut(self: &mut OperationDag, input: OperatorIndex, table: &[u64]) -> OperatorIndex; + + fn add_dot( + self: &mut OperationDag, + inputs: &[OperatorIndex], + weights: Box, + ) -> OperatorIndex; + + fn add_levelled_op( + self: &mut OperationDag, + inputs: &[OperatorIndex], + lwe_dim_cost_factor: f64, + fixed_cost: f64, + manp: f64, + out_shape: &[u64], + comment: &str, + ) -> OperatorIndex; + + type Weights; + + #[namespace = "concrete_optimizer::weights"] + fn vector(weights: &[u64]) -> Box; } - #[namespace = "concrete_optimizer"] + #[derive(Clone, Copy)] + #[namespace = "concrete_optimizer::dag"] + struct OperatorIndex { + index: usize, + } + + #[namespace = "concrete_optimizer::v0"] #[derive(Debug, Clone, Copy, Default)] pub struct Solution { pub input_lwe_dimension: u64, //n_big diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index ec804d308..893a28fa8 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1,13 +1,667 @@ +#include +#include +#include #include +#include +#include +#include +#include #include +#include + +namespace rust { +inline namespace cxxbridge1 { +// #include "rust/cxx.h" + +#ifndef CXXBRIDGE1_PANIC +#define CXXBRIDGE1_PANIC +template +void panic [[noreturn]] (const char *msg); +#endif // CXXBRIDGE1_PANIC + +namespace { +template +class impl; +} // namespace + +class String; + +template +::std::size_t size_of(); +template +::std::size_t align_of(); + +#ifndef CXXBRIDGE1_RUST_STR +#define CXXBRIDGE1_RUST_STR +class Str final { +public: + Str() noexcept; + Str(const String &) noexcept; + Str(const std::string &); + Str(const char *); + Str(const char *, std::size_t); + + Str &operator=(const Str &) &noexcept = default; + + explicit operator std::string() const; + + const char *data() const noexcept; + std::size_t size() const noexcept; + std::size_t length() const noexcept; + bool empty() const noexcept; + + Str(const Str &) noexcept = default; + ~Str() noexcept = default; + + using iterator = const char *; + using const_iterator = const char *; + const_iterator begin() const noexcept; + const_iterator end() const noexcept; + const_iterator cbegin() const noexcept; + const_iterator cend() const noexcept; + + bool operator==(const Str &) const noexcept; + bool operator!=(const Str &) const noexcept; + bool operator<(const Str &) const noexcept; + bool operator<=(const Str &) const noexcept; + bool operator>(const Str &) const noexcept; + bool operator>=(const Str &) const noexcept; + + void swap(Str &) noexcept; + +private: + class uninit; + Str(uninit) noexcept; + friend impl; + + std::array repr; +}; +#endif // CXXBRIDGE1_RUST_STR + +#ifndef CXXBRIDGE1_RUST_SLICE +#define CXXBRIDGE1_RUST_SLICE +namespace detail { +template +struct copy_assignable_if {}; + +template <> +struct copy_assignable_if { + copy_assignable_if() noexcept = default; + copy_assignable_if(const copy_assignable_if &) noexcept = default; + copy_assignable_if &operator=(const copy_assignable_if &) &noexcept = delete; + copy_assignable_if &operator=(copy_assignable_if &&) &noexcept = default; +}; +} // namespace detail + +template +class Slice final + : private detail::copy_assignable_if::value> { +public: + using value_type = T; + + Slice() noexcept; + Slice(T *, std::size_t count) noexcept; + + Slice &operator=(const Slice &) &noexcept = default; + Slice &operator=(Slice &&) &noexcept = default; + + T *data() const noexcept; + std::size_t size() const noexcept; + std::size_t length() const noexcept; + bool empty() const noexcept; + + T &operator[](std::size_t n) const noexcept; + T &at(std::size_t n) const; + T &front() const noexcept; + T &back() const noexcept; + + Slice(const Slice &) noexcept = default; + ~Slice() noexcept = default; + + class iterator; + iterator begin() const noexcept; + iterator end() const noexcept; + + void swap(Slice &) noexcept; + +private: + class uninit; + Slice(uninit) noexcept; + friend impl; + friend void sliceInit(void *, const void *, std::size_t) noexcept; + friend void *slicePtr(const void *) noexcept; + friend std::size_t sliceLen(const void *) noexcept; + + std::array repr; +}; + +template +class Slice::iterator final { +public: + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = typename std::add_pointer::type; + using reference = typename std::add_lvalue_reference::type; + + reference operator*() const noexcept; + pointer operator->() const noexcept; + reference operator[](difference_type) const noexcept; + + iterator &operator++() noexcept; + iterator operator++(int) noexcept; + iterator &operator--() noexcept; + iterator operator--(int) noexcept; + + iterator &operator+=(difference_type) noexcept; + iterator &operator-=(difference_type) noexcept; + iterator operator+(difference_type) const noexcept; + iterator operator-(difference_type) const noexcept; + difference_type operator-(const iterator &) const noexcept; + + bool operator==(const iterator &) const noexcept; + bool operator!=(const iterator &) const noexcept; + bool operator<(const iterator &) const noexcept; + bool operator<=(const iterator &) const noexcept; + bool operator>(const iterator &) const noexcept; + bool operator>=(const iterator &) const noexcept; + +private: + friend class Slice; + void *pos; + std::size_t stride; +}; + +template +Slice::Slice() noexcept { + sliceInit(this, reinterpret_cast(align_of()), 0); +} + +template +Slice::Slice(T *s, std::size_t count) noexcept { + assert(s != nullptr || count == 0); + sliceInit(this, + s == nullptr && count == 0 + ? reinterpret_cast(align_of()) + : const_cast::type *>(s), + count); +} + +template +T *Slice::data() const noexcept { + return reinterpret_cast(slicePtr(this)); +} + +template +std::size_t Slice::size() const noexcept { + return sliceLen(this); +} + +template +std::size_t Slice::length() const noexcept { + return this->size(); +} + +template +bool Slice::empty() const noexcept { + return this->size() == 0; +} + +template +T &Slice::operator[](std::size_t n) const noexcept { + assert(n < this->size()); + auto ptr = static_cast(slicePtr(this)) + size_of() * n; + return *reinterpret_cast(ptr); +} + +template +T &Slice::at(std::size_t n) const { + if (n >= this->size()) { + panic("rust::Slice index out of range"); + } + return (*this)[n]; +} + +template +T &Slice::front() const noexcept { + assert(!this->empty()); + return (*this)[0]; +} + +template +T &Slice::back() const noexcept { + assert(!this->empty()); + return (*this)[this->size() - 1]; +} + +template +typename Slice::iterator::reference +Slice::iterator::operator*() const noexcept { + return *static_cast(this->pos); +} + +template +typename Slice::iterator::pointer +Slice::iterator::operator->() const noexcept { + return static_cast(this->pos); +} + +template +typename Slice::iterator::reference Slice::iterator::operator[]( + typename Slice::iterator::difference_type n) const noexcept { + auto ptr = static_cast(this->pos) + this->stride * n; + return *reinterpret_cast(ptr); +} + +template +typename Slice::iterator &Slice::iterator::operator++() noexcept { + this->pos = static_cast(this->pos) + this->stride; + return *this; +} + +template +typename Slice::iterator Slice::iterator::operator++(int) noexcept { + auto ret = iterator(*this); + this->pos = static_cast(this->pos) + this->stride; + return ret; +} + +template +typename Slice::iterator &Slice::iterator::operator--() noexcept { + this->pos = static_cast(this->pos) - this->stride; + return *this; +} + +template +typename Slice::iterator Slice::iterator::operator--(int) noexcept { + auto ret = iterator(*this); + this->pos = static_cast(this->pos) - this->stride; + return ret; +} + +template +typename Slice::iterator &Slice::iterator::operator+=( + typename Slice::iterator::difference_type n) noexcept { + this->pos = static_cast(this->pos) + this->stride * n; + return *this; +} + +template +typename Slice::iterator &Slice::iterator::operator-=( + typename Slice::iterator::difference_type n) noexcept { + this->pos = static_cast(this->pos) - this->stride * n; + return *this; +} + +template +typename Slice::iterator Slice::iterator::operator+( + typename Slice::iterator::difference_type n) const noexcept { + auto ret = iterator(*this); + ret.pos = static_cast(this->pos) + this->stride * n; + return ret; +} + +template +typename Slice::iterator Slice::iterator::operator-( + typename Slice::iterator::difference_type n) const noexcept { + auto ret = iterator(*this); + ret.pos = static_cast(this->pos) - this->stride * n; + return ret; +} + +template +typename Slice::iterator::difference_type +Slice::iterator::operator-(const iterator &other) const noexcept { + auto diff = std::distance(static_cast(other.pos), + static_cast(this->pos)); + return diff / this->stride; +} + +template +bool Slice::iterator::operator==(const iterator &other) const noexcept { + return this->pos == other.pos; +} + +template +bool Slice::iterator::operator!=(const iterator &other) const noexcept { + return this->pos != other.pos; +} + +template +bool Slice::iterator::operator<(const iterator &other) const noexcept { + return this->pos < other.pos; +} + +template +bool Slice::iterator::operator<=(const iterator &other) const noexcept { + return this->pos <= other.pos; +} + +template +bool Slice::iterator::operator>(const iterator &other) const noexcept { + return this->pos > other.pos; +} + +template +bool Slice::iterator::operator>=(const iterator &other) const noexcept { + return this->pos >= other.pos; +} + +template +typename Slice::iterator Slice::begin() const noexcept { + iterator it; + it.pos = slicePtr(this); + it.stride = size_of(); + return it; +} + +template +typename Slice::iterator Slice::end() const noexcept { + iterator it = this->begin(); + it.pos = static_cast(it.pos) + it.stride * this->size(); + return it; +} + +template +void Slice::swap(Slice &rhs) noexcept { + std::swap(*this, rhs); +} +#endif // CXXBRIDGE1_RUST_SLICE + +#ifndef CXXBRIDGE1_RUST_BOX +#define CXXBRIDGE1_RUST_BOX +template +class Box final { +public: + using element_type = T; + using const_pointer = + typename std::add_pointer::type>::type; + using pointer = typename std::add_pointer::type; + + Box() = delete; + Box(Box &&) noexcept; + ~Box() noexcept; + + explicit Box(const T &); + explicit Box(T &&); + + Box &operator=(Box &&) &noexcept; + + const T *operator->() const noexcept; + const T &operator*() const noexcept; + T *operator->() noexcept; + T &operator*() noexcept; + + template + static Box in_place(Fields &&...); + + void swap(Box &) noexcept; + + static Box from_raw(T *) noexcept; + + T *into_raw() noexcept; + + /* Deprecated */ using value_type = element_type; + +private: + class uninit; + class allocation; + Box(uninit) noexcept; + void drop() noexcept; + + friend void swap(Box &lhs, Box &rhs) noexcept { lhs.swap(rhs); } + + T *ptr; +}; + +template +class Box::uninit {}; + +template +class Box::allocation { + static T *alloc() noexcept; + static void dealloc(T *) noexcept; + +public: + allocation() noexcept : ptr(alloc()) {} + ~allocation() noexcept { + if (this->ptr) { + dealloc(this->ptr); + } + } + T *ptr; +}; + +template +Box::Box(Box &&other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; +} + +template +Box::Box(const T &val) { + allocation alloc; + ::new (alloc.ptr) T(val); + this->ptr = alloc.ptr; + alloc.ptr = nullptr; +} + +template +Box::Box(T &&val) { + allocation alloc; + ::new (alloc.ptr) T(std::move(val)); + this->ptr = alloc.ptr; + alloc.ptr = nullptr; +} + +template +Box::~Box() noexcept { + if (this->ptr) { + this->drop(); + } +} + +template +Box &Box::operator=(Box &&other) &noexcept { + if (this->ptr) { + this->drop(); + } + this->ptr = other.ptr; + other.ptr = nullptr; + return *this; +} + +template +const T *Box::operator->() const noexcept { + return this->ptr; +} + +template +const T &Box::operator*() const noexcept { + return *this->ptr; +} + +template +T *Box::operator->() noexcept { + return this->ptr; +} + +template +T &Box::operator*() noexcept { + return *this->ptr; +} + +template +template +Box Box::in_place(Fields &&...fields) { + allocation alloc; + auto ptr = alloc.ptr; + ::new (ptr) T{std::forward(fields)...}; + alloc.ptr = nullptr; + return from_raw(ptr); +} + +template +void Box::swap(Box &rhs) noexcept { + using std::swap; + swap(this->ptr, rhs.ptr); +} + +template +Box Box::from_raw(T *raw) noexcept { + Box box = uninit{}; + box.ptr = raw; + return box; +} + +template +T *Box::into_raw() noexcept { + T *raw = this->ptr; + this->ptr = nullptr; + return raw; +} + +template +Box::Box(uninit) noexcept {} +#endif // CXXBRIDGE1_RUST_BOX + +#ifndef CXXBRIDGE1_RUST_OPAQUE +#define CXXBRIDGE1_RUST_OPAQUE +class Opaque { +public: + Opaque() = delete; + Opaque(const Opaque &) = delete; + ~Opaque() = delete; +}; +#endif // CXXBRIDGE1_RUST_OPAQUE + +#ifndef CXXBRIDGE1_IS_COMPLETE +#define CXXBRIDGE1_IS_COMPLETE +namespace detail { +namespace { +template +struct is_complete : std::false_type {}; +template +struct is_complete : std::true_type {}; +} // namespace +} // namespace detail +#endif // CXXBRIDGE1_IS_COMPLETE + +#ifndef CXXBRIDGE1_LAYOUT +#define CXXBRIDGE1_LAYOUT +class layout { + template + friend std::size_t size_of(); + template + friend std::size_t align_of(); + template + static typename std::enable_if::value, + std::size_t>::type + do_size_of() { + return T::layout::size(); + } + template + static typename std::enable_if::value, + std::size_t>::type + do_size_of() { + return sizeof(T); + } + template + static + typename std::enable_if::value, std::size_t>::type + size_of() { + return do_size_of(); + } + template + static typename std::enable_if::value, + std::size_t>::type + do_align_of() { + return T::layout::align(); + } + template + static typename std::enable_if::value, + std::size_t>::type + do_align_of() { + return alignof(T); + } + template + static + typename std::enable_if::value, std::size_t>::type + align_of() { + return do_align_of(); + } +}; + +template +std::size_t size_of() { + return layout::size_of(); +} + +template +std::size_t align_of() { + return layout::align_of(); +} +#endif // CXXBRIDGE1_LAYOUT +} // namespace cxxbridge1 +} // namespace rust namespace concrete_optimizer { - struct Solution; + struct OperationDag; + struct Weights; + namespace dag { + struct OperatorIndex; + } + namespace v0 { + struct Solution; + } } namespace concrete_optimizer { -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Solution -#define CXXBRIDGE1_STRUCT_concrete_optimizer$Solution +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag +#define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag +struct OperationDag final : public ::rust::Opaque { + ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice out_shape, ::rust::Str comment) noexcept; + ~OperationDag() = delete; + +private: + friend ::rust::layout; + struct layout { + static ::std::size_t size() noexcept; + static ::std::size_t align() noexcept; + }; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Weights +#define CXXBRIDGE1_STRUCT_concrete_optimizer$Weights +struct Weights final : public ::rust::Opaque { + ~Weights() = delete; + +private: + friend ::rust::layout; + struct layout { + static ::std::size_t size() noexcept; + static ::std::size_t align() noexcept; + }; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights + +namespace dag { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex +#define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex +struct OperatorIndex final { + ::std::size_t index; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex +} // namespace dag + +namespace v0 { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$v0$Solution +#define CXXBRIDGE1_STRUCT_concrete_optimizer$v0$Solution struct Solution final { ::std::uint64_t input_lwe_dimension; ::std::uint64_t internal_ks_output_lwe_dimension; @@ -23,13 +677,128 @@ struct Solution final { using IsRelocatable = ::std::true_type; }; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Solution +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$v0$Solution extern "C" { -::concrete_optimizer::Solution concrete_optimizer$cxxbridge1$optimise_bootstrap(::std::uint64_t precision, ::std::uint64_t security_level, double noise_factor, double maximum_acceptable_error_probability) noexcept; +::concrete_optimizer::v0::Solution concrete_optimizer$v0$cxxbridge1$optimize_bootstrap(::std::uint64_t precision, ::std::uint64_t security_level, double noise_factor, double maximum_acceptable_error_probability) noexcept; +} // extern "C" +} // namespace v0 + +extern "C" { +::std::size_t concrete_optimizer$cxxbridge1$OperationDag$operator$sizeof() noexcept; +::std::size_t concrete_optimizer$cxxbridge1$OperationDag$operator$alignof() noexcept; } // extern "C" -::concrete_optimizer::Solution optimise_bootstrap(::std::uint64_t precision, ::std::uint64_t security_level, double noise_factor, double maximum_acceptable_error_probability) noexcept { - return concrete_optimizer$cxxbridge1$optimise_bootstrap(precision, security_level, noise_factor, maximum_acceptable_error_probability); +namespace dag { +extern "C" { +::concrete_optimizer::OperationDag *concrete_optimizer$dag$cxxbridge1$empty() noexcept; +} // extern "C" +} // namespace dag + +extern "C" { +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_input(::concrete_optimizer::OperationDag &self, ::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept; + +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_lut(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept; + +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_dot(::concrete_optimizer::OperationDag &self, ::rust::Slice inputs, ::concrete_optimizer::Weights *weights) noexcept; + +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_levelled_op(::concrete_optimizer::OperationDag &self, ::rust::Slice inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice out_shape, ::rust::Str comment) noexcept; +::std::size_t concrete_optimizer$cxxbridge1$Weights$operator$sizeof() noexcept; +::std::size_t concrete_optimizer$cxxbridge1$Weights$operator$alignof() noexcept; +} // extern "C" + +namespace weights { +extern "C" { +::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice weights) noexcept; +} // extern "C" +} // namespace weights + +namespace v0 { +::concrete_optimizer::v0::Solution optimize_bootstrap(::std::uint64_t precision, ::std::uint64_t security_level, double noise_factor, double maximum_acceptable_error_probability) noexcept { + return concrete_optimizer$v0$cxxbridge1$optimize_bootstrap(precision, security_level, noise_factor, maximum_acceptable_error_probability); } +} // namespace v0 + +::std::size_t OperationDag::layout::size() noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$operator$sizeof(); +} + +::std::size_t OperationDag::layout::align() noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$operator$alignof(); +} + +namespace dag { +::rust::Box<::concrete_optimizer::OperationDag> empty() noexcept { + return ::rust::Box<::concrete_optimizer::OperationDag>::from_raw(concrete_optimizer$dag$cxxbridge1$empty()); +} +} // namespace dag + +::concrete_optimizer::dag::OperatorIndex OperationDag::add_input(::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$add_input(*this, out_precision, out_shape); +} + +::concrete_optimizer::dag::OperatorIndex OperationDag::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$add_lut(*this, input, table); +} + +::concrete_optimizer::dag::OperatorIndex OperationDag::add_dot(::rust::Slice inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$add_dot(*this, inputs, weights.into_raw()); +} + +::concrete_optimizer::dag::OperatorIndex OperationDag::add_levelled_op(::rust::Slice inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice out_shape, ::rust::Str comment) noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$add_levelled_op(*this, inputs, lwe_dim_cost_factor, fixed_cost, manp, out_shape, comment); +} + +::std::size_t Weights::layout::size() noexcept { + return concrete_optimizer$cxxbridge1$Weights$operator$sizeof(); +} + +::std::size_t Weights::layout::align() noexcept { + return concrete_optimizer$cxxbridge1$Weights$operator$alignof(); +} + +namespace weights { +::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice weights) noexcept { + return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$vector(weights)); +} +} // namespace weights } // namespace concrete_optimizer + +extern "C" { +::concrete_optimizer::OperationDag *cxxbridge1$box$concrete_optimizer$OperationDag$alloc() noexcept; +void cxxbridge1$box$concrete_optimizer$OperationDag$dealloc(::concrete_optimizer::OperationDag *) noexcept; +void cxxbridge1$box$concrete_optimizer$OperationDag$drop(::rust::Box<::concrete_optimizer::OperationDag> *ptr) noexcept; + +::concrete_optimizer::Weights *cxxbridge1$box$concrete_optimizer$Weights$alloc() noexcept; +void cxxbridge1$box$concrete_optimizer$Weights$dealloc(::concrete_optimizer::Weights *) noexcept; +void cxxbridge1$box$concrete_optimizer$Weights$drop(::rust::Box<::concrete_optimizer::Weights> *ptr) noexcept; +} // extern "C" + +namespace rust { +inline namespace cxxbridge1 { +template <> +::concrete_optimizer::OperationDag *Box<::concrete_optimizer::OperationDag>::allocation::alloc() noexcept { + return cxxbridge1$box$concrete_optimizer$OperationDag$alloc(); +} +template <> +void Box<::concrete_optimizer::OperationDag>::allocation::dealloc(::concrete_optimizer::OperationDag *ptr) noexcept { + cxxbridge1$box$concrete_optimizer$OperationDag$dealloc(ptr); +} +template <> +void Box<::concrete_optimizer::OperationDag>::drop() noexcept { + cxxbridge1$box$concrete_optimizer$OperationDag$drop(this); +} +template <> +::concrete_optimizer::Weights *Box<::concrete_optimizer::Weights>::allocation::alloc() noexcept { + return cxxbridge1$box$concrete_optimizer$Weights$alloc(); +} +template <> +void Box<::concrete_optimizer::Weights>::allocation::dealloc(::concrete_optimizer::Weights *ptr) noexcept { + cxxbridge1$box$concrete_optimizer$Weights$dealloc(ptr); +} +template <> +void Box<::concrete_optimizer::Weights>::drop() noexcept { + cxxbridge1$box$concrete_optimizer$Weights$drop(this); +} +} // namespace cxxbridge1 +} // namespace rust diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 6db772aeb..8d7546e0f 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1,14 +1,668 @@ #pragma once +#include +#include +#include #include +#include +#include +#include +#include #include +#include + +namespace rust { +inline namespace cxxbridge1 { +// #include "rust/cxx.h" + +#ifndef CXXBRIDGE1_PANIC +#define CXXBRIDGE1_PANIC +template +void panic [[noreturn]] (const char *msg); +#endif // CXXBRIDGE1_PANIC + +namespace { +template +class impl; +} // namespace + +class String; + +template +::std::size_t size_of(); +template +::std::size_t align_of(); + +#ifndef CXXBRIDGE1_RUST_STR +#define CXXBRIDGE1_RUST_STR +class Str final { +public: + Str() noexcept; + Str(const String &) noexcept; + Str(const std::string &); + Str(const char *); + Str(const char *, std::size_t); + + Str &operator=(const Str &) &noexcept = default; + + explicit operator std::string() const; + + const char *data() const noexcept; + std::size_t size() const noexcept; + std::size_t length() const noexcept; + bool empty() const noexcept; + + Str(const Str &) noexcept = default; + ~Str() noexcept = default; + + using iterator = const char *; + using const_iterator = const char *; + const_iterator begin() const noexcept; + const_iterator end() const noexcept; + const_iterator cbegin() const noexcept; + const_iterator cend() const noexcept; + + bool operator==(const Str &) const noexcept; + bool operator!=(const Str &) const noexcept; + bool operator<(const Str &) const noexcept; + bool operator<=(const Str &) const noexcept; + bool operator>(const Str &) const noexcept; + bool operator>=(const Str &) const noexcept; + + void swap(Str &) noexcept; + +private: + class uninit; + Str(uninit) noexcept; + friend impl; + + std::array repr; +}; +#endif // CXXBRIDGE1_RUST_STR + +#ifndef CXXBRIDGE1_RUST_SLICE +#define CXXBRIDGE1_RUST_SLICE +namespace detail { +template +struct copy_assignable_if {}; + +template <> +struct copy_assignable_if { + copy_assignable_if() noexcept = default; + copy_assignable_if(const copy_assignable_if &) noexcept = default; + copy_assignable_if &operator=(const copy_assignable_if &) &noexcept = delete; + copy_assignable_if &operator=(copy_assignable_if &&) &noexcept = default; +}; +} // namespace detail + +template +class Slice final + : private detail::copy_assignable_if::value> { +public: + using value_type = T; + + Slice() noexcept; + Slice(T *, std::size_t count) noexcept; + + Slice &operator=(const Slice &) &noexcept = default; + Slice &operator=(Slice &&) &noexcept = default; + + T *data() const noexcept; + std::size_t size() const noexcept; + std::size_t length() const noexcept; + bool empty() const noexcept; + + T &operator[](std::size_t n) const noexcept; + T &at(std::size_t n) const; + T &front() const noexcept; + T &back() const noexcept; + + Slice(const Slice &) noexcept = default; + ~Slice() noexcept = default; + + class iterator; + iterator begin() const noexcept; + iterator end() const noexcept; + + void swap(Slice &) noexcept; + +private: + class uninit; + Slice(uninit) noexcept; + friend impl; + friend void sliceInit(void *, const void *, std::size_t) noexcept; + friend void *slicePtr(const void *) noexcept; + friend std::size_t sliceLen(const void *) noexcept; + + std::array repr; +}; + +template +class Slice::iterator final { +public: + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = typename std::add_pointer::type; + using reference = typename std::add_lvalue_reference::type; + + reference operator*() const noexcept; + pointer operator->() const noexcept; + reference operator[](difference_type) const noexcept; + + iterator &operator++() noexcept; + iterator operator++(int) noexcept; + iterator &operator--() noexcept; + iterator operator--(int) noexcept; + + iterator &operator+=(difference_type) noexcept; + iterator &operator-=(difference_type) noexcept; + iterator operator+(difference_type) const noexcept; + iterator operator-(difference_type) const noexcept; + difference_type operator-(const iterator &) const noexcept; + + bool operator==(const iterator &) const noexcept; + bool operator!=(const iterator &) const noexcept; + bool operator<(const iterator &) const noexcept; + bool operator<=(const iterator &) const noexcept; + bool operator>(const iterator &) const noexcept; + bool operator>=(const iterator &) const noexcept; + +private: + friend class Slice; + void *pos; + std::size_t stride; +}; + +template +Slice::Slice() noexcept { + sliceInit(this, reinterpret_cast(align_of()), 0); +} + +template +Slice::Slice(T *s, std::size_t count) noexcept { + assert(s != nullptr || count == 0); + sliceInit(this, + s == nullptr && count == 0 + ? reinterpret_cast(align_of()) + : const_cast::type *>(s), + count); +} + +template +T *Slice::data() const noexcept { + return reinterpret_cast(slicePtr(this)); +} + +template +std::size_t Slice::size() const noexcept { + return sliceLen(this); +} + +template +std::size_t Slice::length() const noexcept { + return this->size(); +} + +template +bool Slice::empty() const noexcept { + return this->size() == 0; +} + +template +T &Slice::operator[](std::size_t n) const noexcept { + assert(n < this->size()); + auto ptr = static_cast(slicePtr(this)) + size_of() * n; + return *reinterpret_cast(ptr); +} + +template +T &Slice::at(std::size_t n) const { + if (n >= this->size()) { + panic("rust::Slice index out of range"); + } + return (*this)[n]; +} + +template +T &Slice::front() const noexcept { + assert(!this->empty()); + return (*this)[0]; +} + +template +T &Slice::back() const noexcept { + assert(!this->empty()); + return (*this)[this->size() - 1]; +} + +template +typename Slice::iterator::reference +Slice::iterator::operator*() const noexcept { + return *static_cast(this->pos); +} + +template +typename Slice::iterator::pointer +Slice::iterator::operator->() const noexcept { + return static_cast(this->pos); +} + +template +typename Slice::iterator::reference Slice::iterator::operator[]( + typename Slice::iterator::difference_type n) const noexcept { + auto ptr = static_cast(this->pos) + this->stride * n; + return *reinterpret_cast(ptr); +} + +template +typename Slice::iterator &Slice::iterator::operator++() noexcept { + this->pos = static_cast(this->pos) + this->stride; + return *this; +} + +template +typename Slice::iterator Slice::iterator::operator++(int) noexcept { + auto ret = iterator(*this); + this->pos = static_cast(this->pos) + this->stride; + return ret; +} + +template +typename Slice::iterator &Slice::iterator::operator--() noexcept { + this->pos = static_cast(this->pos) - this->stride; + return *this; +} + +template +typename Slice::iterator Slice::iterator::operator--(int) noexcept { + auto ret = iterator(*this); + this->pos = static_cast(this->pos) - this->stride; + return ret; +} + +template +typename Slice::iterator &Slice::iterator::operator+=( + typename Slice::iterator::difference_type n) noexcept { + this->pos = static_cast(this->pos) + this->stride * n; + return *this; +} + +template +typename Slice::iterator &Slice::iterator::operator-=( + typename Slice::iterator::difference_type n) noexcept { + this->pos = static_cast(this->pos) - this->stride * n; + return *this; +} + +template +typename Slice::iterator Slice::iterator::operator+( + typename Slice::iterator::difference_type n) const noexcept { + auto ret = iterator(*this); + ret.pos = static_cast(this->pos) + this->stride * n; + return ret; +} + +template +typename Slice::iterator Slice::iterator::operator-( + typename Slice::iterator::difference_type n) const noexcept { + auto ret = iterator(*this); + ret.pos = static_cast(this->pos) - this->stride * n; + return ret; +} + +template +typename Slice::iterator::difference_type +Slice::iterator::operator-(const iterator &other) const noexcept { + auto diff = std::distance(static_cast(other.pos), + static_cast(this->pos)); + return diff / this->stride; +} + +template +bool Slice::iterator::operator==(const iterator &other) const noexcept { + return this->pos == other.pos; +} + +template +bool Slice::iterator::operator!=(const iterator &other) const noexcept { + return this->pos != other.pos; +} + +template +bool Slice::iterator::operator<(const iterator &other) const noexcept { + return this->pos < other.pos; +} + +template +bool Slice::iterator::operator<=(const iterator &other) const noexcept { + return this->pos <= other.pos; +} + +template +bool Slice::iterator::operator>(const iterator &other) const noexcept { + return this->pos > other.pos; +} + +template +bool Slice::iterator::operator>=(const iterator &other) const noexcept { + return this->pos >= other.pos; +} + +template +typename Slice::iterator Slice::begin() const noexcept { + iterator it; + it.pos = slicePtr(this); + it.stride = size_of(); + return it; +} + +template +typename Slice::iterator Slice::end() const noexcept { + iterator it = this->begin(); + it.pos = static_cast(it.pos) + it.stride * this->size(); + return it; +} + +template +void Slice::swap(Slice &rhs) noexcept { + std::swap(*this, rhs); +} +#endif // CXXBRIDGE1_RUST_SLICE + +#ifndef CXXBRIDGE1_RUST_BOX +#define CXXBRIDGE1_RUST_BOX +template +class Box final { +public: + using element_type = T; + using const_pointer = + typename std::add_pointer::type>::type; + using pointer = typename std::add_pointer::type; + + Box() = delete; + Box(Box &&) noexcept; + ~Box() noexcept; + + explicit Box(const T &); + explicit Box(T &&); + + Box &operator=(Box &&) &noexcept; + + const T *operator->() const noexcept; + const T &operator*() const noexcept; + T *operator->() noexcept; + T &operator*() noexcept; + + template + static Box in_place(Fields &&...); + + void swap(Box &) noexcept; + + static Box from_raw(T *) noexcept; + + T *into_raw() noexcept; + + /* Deprecated */ using value_type = element_type; + +private: + class uninit; + class allocation; + Box(uninit) noexcept; + void drop() noexcept; + + friend void swap(Box &lhs, Box &rhs) noexcept { lhs.swap(rhs); } + + T *ptr; +}; + +template +class Box::uninit {}; + +template +class Box::allocation { + static T *alloc() noexcept; + static void dealloc(T *) noexcept; + +public: + allocation() noexcept : ptr(alloc()) {} + ~allocation() noexcept { + if (this->ptr) { + dealloc(this->ptr); + } + } + T *ptr; +}; + +template +Box::Box(Box &&other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; +} + +template +Box::Box(const T &val) { + allocation alloc; + ::new (alloc.ptr) T(val); + this->ptr = alloc.ptr; + alloc.ptr = nullptr; +} + +template +Box::Box(T &&val) { + allocation alloc; + ::new (alloc.ptr) T(std::move(val)); + this->ptr = alloc.ptr; + alloc.ptr = nullptr; +} + +template +Box::~Box() noexcept { + if (this->ptr) { + this->drop(); + } +} + +template +Box &Box::operator=(Box &&other) &noexcept { + if (this->ptr) { + this->drop(); + } + this->ptr = other.ptr; + other.ptr = nullptr; + return *this; +} + +template +const T *Box::operator->() const noexcept { + return this->ptr; +} + +template +const T &Box::operator*() const noexcept { + return *this->ptr; +} + +template +T *Box::operator->() noexcept { + return this->ptr; +} + +template +T &Box::operator*() noexcept { + return *this->ptr; +} + +template +template +Box Box::in_place(Fields &&...fields) { + allocation alloc; + auto ptr = alloc.ptr; + ::new (ptr) T{std::forward(fields)...}; + alloc.ptr = nullptr; + return from_raw(ptr); +} + +template +void Box::swap(Box &rhs) noexcept { + using std::swap; + swap(this->ptr, rhs.ptr); +} + +template +Box Box::from_raw(T *raw) noexcept { + Box box = uninit{}; + box.ptr = raw; + return box; +} + +template +T *Box::into_raw() noexcept { + T *raw = this->ptr; + this->ptr = nullptr; + return raw; +} + +template +Box::Box(uninit) noexcept {} +#endif // CXXBRIDGE1_RUST_BOX + +#ifndef CXXBRIDGE1_RUST_OPAQUE +#define CXXBRIDGE1_RUST_OPAQUE +class Opaque { +public: + Opaque() = delete; + Opaque(const Opaque &) = delete; + ~Opaque() = delete; +}; +#endif // CXXBRIDGE1_RUST_OPAQUE + +#ifndef CXXBRIDGE1_IS_COMPLETE +#define CXXBRIDGE1_IS_COMPLETE +namespace detail { +namespace { +template +struct is_complete : std::false_type {}; +template +struct is_complete : std::true_type {}; +} // namespace +} // namespace detail +#endif // CXXBRIDGE1_IS_COMPLETE + +#ifndef CXXBRIDGE1_LAYOUT +#define CXXBRIDGE1_LAYOUT +class layout { + template + friend std::size_t size_of(); + template + friend std::size_t align_of(); + template + static typename std::enable_if::value, + std::size_t>::type + do_size_of() { + return T::layout::size(); + } + template + static typename std::enable_if::value, + std::size_t>::type + do_size_of() { + return sizeof(T); + } + template + static + typename std::enable_if::value, std::size_t>::type + size_of() { + return do_size_of(); + } + template + static typename std::enable_if::value, + std::size_t>::type + do_align_of() { + return T::layout::align(); + } + template + static typename std::enable_if::value, + std::size_t>::type + do_align_of() { + return alignof(T); + } + template + static + typename std::enable_if::value, std::size_t>::type + align_of() { + return do_align_of(); + } +}; + +template +std::size_t size_of() { + return layout::size_of(); +} + +template +std::size_t align_of() { + return layout::align_of(); +} +#endif // CXXBRIDGE1_LAYOUT +} // namespace cxxbridge1 +} // namespace rust namespace concrete_optimizer { - struct Solution; + struct OperationDag; + struct Weights; + namespace dag { + struct OperatorIndex; + } + namespace v0 { + struct Solution; + } } namespace concrete_optimizer { -#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Solution -#define CXXBRIDGE1_STRUCT_concrete_optimizer$Solution +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag +#define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag +struct OperationDag final : public ::rust::Opaque { + ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice out_shape, ::rust::Str comment) noexcept; + ~OperationDag() = delete; + +private: + friend ::rust::layout; + struct layout { + static ::std::size_t size() noexcept; + static ::std::size_t align() noexcept; + }; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag + +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Weights +#define CXXBRIDGE1_STRUCT_concrete_optimizer$Weights +struct Weights final : public ::rust::Opaque { + ~Weights() = delete; + +private: + friend ::rust::layout; + struct layout { + static ::std::size_t size() noexcept; + static ::std::size_t align() noexcept; + }; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights + +namespace dag { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex +#define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex +struct OperatorIndex final { + ::std::size_t index; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex +} // namespace dag + +namespace v0 { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$v0$Solution +#define CXXBRIDGE1_STRUCT_concrete_optimizer$v0$Solution struct Solution final { ::std::uint64_t input_lwe_dimension; ::std::uint64_t internal_ks_output_lwe_dimension; @@ -24,7 +678,16 @@ struct Solution final { using IsRelocatable = ::std::true_type; }; -#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Solution +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$v0$Solution -::concrete_optimizer::Solution optimise_bootstrap(::std::uint64_t precision, ::std::uint64_t security_level, double noise_factor, double maximum_acceptable_error_probability) noexcept; +::concrete_optimizer::v0::Solution optimize_bootstrap(::std::uint64_t precision, ::std::uint64_t security_level, double noise_factor, double maximum_acceptable_error_probability) noexcept; +} // namespace v0 + +namespace dag { +::rust::Box<::concrete_optimizer::OperationDag> empty() noexcept; +} // namespace dag + +namespace weights { +::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice weights) noexcept; +} // namespace weights } // namespace concrete_optimizer diff --git a/concrete-optimizer-cpp/tests/src/main.cpp b/concrete-optimizer-cpp/tests/src/main.cpp index 8cd5c3ad1..dd8230fbd 100644 --- a/concrete-optimizer-cpp/tests/src/main.cpp +++ b/concrete-optimizer-cpp/tests/src/main.cpp @@ -1,7 +1,16 @@ #include "concrete-optimizer.hpp" +#include -int main(int argc, char *argv[]) { - concrete_optimizer::Solution solution = concrete_optimizer::optimise_bootstrap(1, 128, 1, .05); +template +rust::cxxbridge1::Slice slice(std::vector &vec) { + const T *data = vec.data(); + + return rust::cxxbridge1::Slice(data, vec.size()); +} + +int test1() { + concrete_optimizer::v0::Solution solution = + concrete_optimizer::v0::optimize_bootstrap(1, 128, 1, .05); if (solution.glwe_polynomial_size != 1024) { return 1; @@ -9,3 +18,38 @@ int main(int argc, char *argv[]) { return 0; } + +int test2() { + auto dag = concrete_optimizer::dag::empty(); + + std::vector shape = {3}; + + concrete_optimizer::dag::OperatorIndex node1 = + dag->add_input(1, slice(shape)); + + std::vector inputs = {node1}; + + std::vector weight_vec = {3}; + + rust::cxxbridge1::Box weights = + concrete_optimizer::weights::vector(slice(weight_vec)); + + concrete_optimizer::dag::OperatorIndex node2 = + dag->add_dot(slice(inputs), std::move(weights)); + + return 0; +} + +int main(int argc, char *argv[]) { + int err = test1(); + + if (err) + return err; + + err = test2(); + + if (err) + return err; + + return 0; +} diff --git a/concrete-optimizer/src/graph/operator/tensor.rs b/concrete-optimizer/src/graph/operator/tensor.rs index b78b6f02e..82a6c9628 100644 --- a/concrete-optimizer/src/graph/operator/tensor.rs +++ b/concrete-optimizer/src/graph/operator/tensor.rs @@ -1,7 +1,7 @@ use delegate::delegate; #[derive(Clone, PartialEq, Eq, Debug)] pub struct Shape { - dimensions_size: Vec, + pub dimensions_size: Vec, } impl Shape {