/* * ShareMatrix.h * */ #ifndef PROTOCOLS_SHAREMATRIX_H_ #define PROTOCOLS_SHAREMATRIX_H_ #include using namespace std; #include "Share.h" #include "FHE/AddableVector.h" template class MatrixMC; template class NonInitVector { template friend class NonInitVector; size_t size_; public: AddableVector v; NonInitVector(size_t size) : size_(size) { v.reserve(size); } template NonInitVector(const NonInitVector& other) : size_(other.size()), v(other.v) { } size_t size() const { return size_; } void init() { v.resize(size_); } void check() const { #ifdef DEBUG_MATRIX assert(not v.empty()); #endif } typename vector::iterator begin() { check(); return v.begin(); } typename vector::iterator end() { check(); return v.end(); } T& at(size_t index) { check(); return v.at(index); } const T& at(size_t index) const { #ifdef DEBUG_MATRIX assert(index < size()); #endif return (*this)[index]; } T& operator[](size_t index) { check(); return v[index]; } const T& operator[](size_t index) const { check(); return v[index]; } NonInitVector operator-(const NonInitVector& other) const { assert(size() == other.size()); NonInitVector res(size()); if (other.v.empty()) return *this; else if (v.empty()) { res.init(); res.v = res.v - other.v; } else res.v = v - other.v; return res; } NonInitVector& operator+=(const NonInitVector& other) { assert(size() == other.size()); if (not other.v.empty()) { if (v.empty()) *this = other; else v += other.v; } return *this; } bool operator!=(const NonInitVector& other) const { return v != other.v; } void randomize(PRNG& G) { v.clear(); for (size_t i = 0; i < size(); i++) v.push_back(G.get()); } }; template class ValueMatrix : public ValueInterface { typedef ValueMatrix This; public: int n_rows, n_cols; NonInitVector entries; static DataFieldType field_type() { return T::field_type(); } ValueMatrix(int n_rows = 0, int n_cols = 0) : n_rows(n_rows), n_cols(n_cols), entries(n_rows * n_cols) { check(); } template ValueMatrix(const ValueMatrix& other) : n_rows(other.n_rows), n_cols(other.n_cols), entries(other.entries) { check(); } void check() const { assert(entries.size() == size_t(n_rows * n_cols)); } T& operator[](const pair& indices) { #ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); #endif return entries.at(indices.first * n_cols + indices.second); } const T& operator[](const pair& indices) const { #ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); #endif return entries.at(indices.first * n_cols + indices.second); } This& operator+=(const This& other) { entries += other.entries; check(); return *this; } This operator-(const This& other) const { assert(entries.size() == other.entries.size()); This res(n_rows, n_cols); res.entries = entries - other.entries; res.check(); return res; } This operator*(const This& other) const { assert(n_cols == other.n_rows); This res(n_rows, other.n_cols); if (entries.v.empty() or other.entries.v.empty()) return res; res.entries.init(); for (int i = 0; i < n_rows; i++) for (int j = 0; j < other.n_cols; j++) for (int k = 0; k < n_cols; k++) res[{i, j}] += (*this)[{i, k}] * other[{k, j}]; res.check(); return res; } bool operator!=(const This& other) const { if (n_rows != other.n_rows or n_cols != other.n_cols) return true; return entries != other.entries; } void randomize(PRNG& G) { entries.randomize(G); } ValueMatrix transpose() const { ValueMatrix res(this->n_cols, this->n_rows); for (int j = 0; j < this->n_cols; j++) for (int i = 0; i < this->n_rows; i++) res.entries.v.push_back((*this)[{i, j}]); return res; } friend ostream& operator<<(ostream& o, const This&) { return o; } }; template class ShareMatrix : public ValueMatrix, public ShareInterface { typedef ShareMatrix This; typedef ValueMatrix super; public: typedef MatrixMC MAC_Check; typedef Beaver Protocol; typedef ::Input Input; typedef DummyLivePrep LivePrep; typedef ValueMatrix clear; typedef clear open_type; typedef typename T::mac_key_type mac_key_type; static string type_string() { return "matrix"; } static This constant(const clear& other, int my_num, mac_key_type key) { This res(other.n_rows, other.n_cols); for (size_t i = 0; i < other.entries.size(); i++) res.entries.v.push_back(T::constant(other.entries[i], my_num, key)); res.check(); return res; } ShareMatrix(int n_rows = 0, int n_cols = 0) : ValueMatrix(n_rows, n_cols) { } template ShareMatrix(const U& other) : super(other) { } ShareMatrix from_row(int start, int size) const { ShareMatrix res(min(size, this->n_rows - start), this->n_cols); for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) res[{i, j}] = (*this)[{start + i, j}]; return res; } ShareMatrix from_col(int start, int size) const { ShareMatrix res(this->n_rows, min(size, this->n_cols - start)); res.entries.clear(); for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) res.entries.v.push_back((*this)[{i, start + j}]); return res; } ShareMatrix from(int start_row, int start_col, int* sizes, bool for_real = true) const { ShareMatrix res(min(sizes[0], this->n_rows - start_row), min(sizes[1], this->n_cols - start_col)); if (not for_real) return res; for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) res.entries.v.push_back((*this)[{start_row + i, start_col + j}]); return res; } void add_from_col(int start, const ShareMatrix& other) { this->entries.init(); for (int i = 0; i < this->n_rows; i++) for (int j = 0; j < other.n_cols; j++) (*this)[{i, start + j}] += other[{i, j}]; } }; template ShareMatrix operator*(const ValueMatrix& a, const ShareMatrix& b) { assert(a.n_cols == b.n_rows); ShareMatrix res(a.n_rows, b.n_cols); if (a.entries.v.empty() or b.entries.v.empty()) return res; res.entries.init(); for (int i = 0; i < a.n_rows; i++) for (int j = 0; j < b.n_cols; j++) for (int k = 0; k < a.n_cols; k++) res[{i, j}] += a[{i, k}] * b[{k, j}]; res.check(); return res; } template class MatrixMC : public MAC_Check_Base> { typename T::MAC_Check& inner; public: MatrixMC() : inner( *(OnlineOptions::singleton.direct ? new typename T::Direct_MC : new typename T::MAC_Check)) { } ~MatrixMC() { delete &inner; } void exchange(const Player& P) { inner.init_open(P); for (auto& share : this->secrets) { share.check(); for (auto& entry : share.entries) inner.prepare_open(entry); } inner.exchange(P); for (auto& share : this->secrets) { this->values.push_back({share.n_rows, share.n_cols}); if (share.entries.v.empty()) for (size_t i = 0; i < share.entries.size(); i++) inner.finalize_open(); else { auto range = inner.finalize_several(share.entries.size()); auto& v = this->values.back().entries.v; v.insert(v.begin(), range[0], range[1]); } } } }; #endif /* PROTOCOLS_SHAREMATRIX_H_ */