diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index bd005c84..534c2a99 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -58,7 +58,7 @@ public: void assign_zero() { *this = {}; } bool is_zero() { return *this == P256Element(); } - void add(octetStream& os) { *this += os.get(); } + void add(octetStream& os, int = -1) { *this += os.get(); } void pack(octetStream& os, int = -1) const; void unpack(octetStream& os, int = -1); diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 0a638f85..27f85b61 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -113,7 +113,7 @@ void Ciphertext::mul(const Ciphertext& c, const Rq_Element& ra) ::mul(cc1,ra,c.cc1); } -void Ciphertext::add(octetStream& os) +void Ciphertext::add(octetStream& os, int) { Ciphertext tmp(*params); tmp.unpack(os); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index 11a23e2a..9ebeed32 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -117,11 +117,11 @@ class Ciphertext int level() const { return cc0.level(); } /// Append to buffer - void pack(octetStream& o) const + void pack(octetStream& o, int = -1) const { cc0.pack(o); cc1.pack(o); o.store(pk_id); } /// Read from buffer. Assumes parameters are set correctly - void unpack(octetStream& o) + void unpack(octetStream& o, int = -1) { cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); } void output(ostream& s) const @@ -129,7 +129,7 @@ class Ciphertext void input(istream& s) { cc0.input(s); cc1.input(s); s.read((char*)&pk_id, sizeof(pk_id)); } - void add(octetStream& os); + void add(octetStream& os, int = -1); size_t report_size(ReportType type) const { return cc0.report_size(type) + cc1.report_size(type); } }; diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index f342e203..b0e88e97 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -45,10 +45,11 @@ class FHE_SK const Rq_Element& s() const { return sk; } /// Append to buffer - void pack(octetStream& os) const { sk.pack(os); pr.pack(os); } + void pack(octetStream& os, int = -1) const { sk.pack(os); pr.pack(os); } /// Read from buffer. Assumes parameters are set correctly - void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); } + void unpack(octetStream& os, int = -1) + { sk.unpack(os, *params); pr.unpack(os); } // Assumes Ring and prime of mess have already been set correctly // Ciphertext c must be at level 0 or an error occurs @@ -88,7 +89,8 @@ class FHE_SK bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; } - void add(octetStream& os) { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; } + void add(octetStream& os, int = -1) + { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; } void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const; diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index d6a14aab..f65eddbd 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -109,7 +109,7 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b) } } -void Rq_Element::add(octetStream& os) +void Rq_Element::add(octetStream& os, int) { Rq_Element tmp(*this); tmp.unpack(os); @@ -298,7 +298,7 @@ void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y) partial_assign(x); } -void Rq_Element::pack(octetStream& o) const +void Rq_Element::pack(octetStream& o, int) const { check_level(); o.store(lev); @@ -306,7 +306,7 @@ void Rq_Element::pack(octetStream& o) const a[i].pack(o); } -void Rq_Element::unpack(octetStream& o) +void Rq_Element::unpack(octetStream& o, int) { unsigned int ll; o.get(ll); lev=ll; check_level(); diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index 4e0cdf97..f315d22b 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -94,7 +94,7 @@ protected: friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); - void add(octetStream& os); + void add(octetStream& os, int = -1); template Rq_Element& operator+=(const vector& other); @@ -157,8 +157,8 @@ protected: * For unpack we assume the prData for a0 and a1 has been assigned * correctly already */ - void pack(octetStream& o) const; - void unpack(octetStream& o); + void pack(octetStream& o, int = -1) const; + void unpack(octetStream& o, int = -1); // without prior initialization void unpack(octetStream& o, const FHE_Params& params); diff --git a/GC/Semi.cpp b/GC/Semi.cpp index e00fed69..0a0e3f91 100644 --- a/GC/Semi.cpp +++ b/GC/Semi.cpp @@ -33,4 +33,9 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, prepare_mul(x, y, n); } +void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n) +{ + super::prepare_mul(x.mask(n), y.mask(n), n); +} + } /* namespace GC */ diff --git a/GC/Semi.h b/GC/Semi.h index 92f9139a..65411253 100644 --- a/GC/Semi.h +++ b/GC/Semi.h @@ -24,6 +24,7 @@ public: void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, bool repeat); + void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n); }; } /* namespace GC */ diff --git a/Math/Bit.h b/Math/Bit.h index 10c4e018..b5d56510 100644 --- a/Math/Bit.h +++ b/Math/Bit.h @@ -51,6 +51,11 @@ public: return other * *this; } + void add(octetStream& os, int = -1) + { + *this += os.get(); + } + void pack(octetStream& os, int = -1) const { super::pack(os, 1); diff --git a/Math/BitVec.h b/Math/BitVec.h index f4b0a1e2..a362d010 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -56,7 +56,12 @@ public: void extend_bit(BitVec_& res, int) const { res = extend_bit(); } - void add(octetStream& os) { *this += os.get(); } + void add(octetStream& os, int n_bits) + { + BitVec_ tmp; + tmp.unpack(os, n_bits); + *this += tmp; + } void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; } @@ -70,7 +75,7 @@ public: if (n == -1) pack(os); else if (n == 1) - os.store_int<1>(this->a & 1); + os.store_bit(this->a); else os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } @@ -80,7 +85,7 @@ public: if (n == -1) unpack(os); else if (n == 1) - this->a = os.get_int<1>(); + this->a = os.get_bit(); else this->a = os.get_int(DIV_CEIL(n, 8)); } diff --git a/Math/Z2k.h b/Math/Z2k.h index b5ffb196..3e10c852 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -151,7 +151,7 @@ public: bool operator==(const Z2& other) const; bool operator!=(const Z2& other) const { return not (*this == other); } - void add(octetStream& os) { *this += (os.consume(size())); } + void add(octetStream& os, int = -1) { *this += (os.consume(size())); } Z2 lazy_add(const Z2& x) const; Z2 lazy_mul(const Z2& x) const; diff --git a/Math/gf2n.h b/Math/gf2n.h index a09d9aed..bed5ba72 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -138,7 +138,7 @@ protected: { a=x.a^y.a; } void add(octet* x) { a^=*(U*)(x); } - void add(octetStream& os) + void add(octetStream& os, int = -1) { add(os.consume(size())); } void sub(const gf2n_& x,const gf2n_& y) { a=x.a^y.a; } diff --git a/Math/gfp.h b/Math/gfp.h index 9f5475e2..fe6f64c3 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -187,7 +187,7 @@ class gfp_ : public ValueInterface bool operator!=(const gfp_& y) const { return !equal(y); } // x+y - void add(octetStream& os) + void add(octetStream& os, int = -1) { add(os.consume(size())); } void add(const gfp_& x,const gfp_& y) { ZpD.Add(a.x,x.a.x,y.a.x); } diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index eb065cf4..06b75385 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -290,7 +290,7 @@ bool gfpvar_::operator !=(const gfpvar_& other) const } template -void gfpvar_::add(octetStream& other) +void gfpvar_::add(octetStream& other, int) { *this += other.get>(); } diff --git a/Math/gfpvar.h b/Math/gfpvar.h index 55c08d4b..b6ab2ae3 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -148,7 +148,7 @@ public: bool operator==(const gfpvar_& other) const; bool operator!=(const gfpvar_& other) const; - void add(octetStream& other); + void add(octetStream& other, int = -1); void negate(); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 09c6e056..050ab31e 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -125,6 +125,7 @@ void InputBase::add_from_all(const typename T::open_type& input, int n_bits) template void Input::send_mine() { + this->os[P.my_num()].append(0); P.send_all(this->os[P.my_num()]); } diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 311de4d9..2cfe3d8b 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -35,12 +35,17 @@ class TreeSum void start(vector& values, const Player& P); void finish(vector& values, const Player& P); + void add_openings(vector& values, const Player& P, int sum_players, + int last_sum_players, int send_player); + protected: int base_player; int opening_sum; int max_broadcast; octetStream os; + vector lengths; + void ReceiveValues(vector& values, const Player& P, int sender); virtual void AddToValues(vector& values) { (void)values; } @@ -157,10 +162,6 @@ using MAC_Check_Z2k_ = MAC_Check_Z2k; -template - void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player, TreeSum& MC); - - /** * SPDZ opening protocol with MAC check (pairwise communication) */ @@ -239,13 +240,16 @@ size_t TreeSum::report_size(ReportType type) } template -void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player, TreeSum& MC) +void TreeSum::add_openings(vector& values, const Player& P, + int sum_players, int last_sum_players, int send_player) { + auto& MC = *this; MC.player_timers.resize(P.num_players()); vector& oss = MC.oss; oss.resize(P.num_players()); vector senders; senders.reserve(P.num_players()); + bool use_lengths = values.size() == lengths.size(); for (int relative_sender = positive_modulo(P.my_num() - send_player, P.num_players()) + sum_players; relative_sender < last_sum_players; relative_sender += sum_players) @@ -266,7 +270,7 @@ void add_openings(vector& values, const Player& P, int sum_players, int last_ MC.timers[SUM].start(); for (unsigned int i=0; i::start(vector& values, const Player& P) os.reset_write_head(); int sum_players = P.num_players(); int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players()); + bool use_lengths = values.size() == lengths.size(); while (true) { // summing phase @@ -289,7 +294,8 @@ void TreeSum::start(vector& values, const Player& P) { // send to the player up the tree for (unsigned int i=0; i::start(vector& values, const Player& P) { // if receiving, add the values timers[RECV_ADD].start(); - add_openings(values, P, sum_players, last_sum_players, base_player, *this); + add_openings(values, P, sum_players, last_sum_players, base_player); timers[RECV_ADD].stop(); } } @@ -311,7 +317,8 @@ void TreeSum::start(vector& values, const Player& P) os.reset_write_head(); size_t n = values.size(); for (unsigned int i=0; i::ReceiveValues(vector& values, const Player& P, int sender) timers[RECV_SUM].start(); P.receive_player(sender, os); timers[RECV_SUM].stop(); + bool use_lengths = values.size() == lengths.size(); for (unsigned int i = 0; i < values.size(); i++) - values[i].unpack(os); + values[i].unpack(os, use_lengths ? lengths[i] : -1); AddToValues(values); } diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index ec26bc84..9203d849 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -208,7 +208,7 @@ void MaliciousBitOnlyRepPrep::buffer_bits() assert(MC.open(f, P) * MC.open(f, P) == MC.open(h, P)); #endif } - auto t = Create_Random(P); + auto t = Create_Random(P); for (int i = 0; i < buffer_size; i++) { T& a = bits[i]; @@ -216,7 +216,7 @@ void MaliciousBitOnlyRepPrep::buffer_bits() masked.push_back(t * a - f); } MC.POpen(opened, masked, P); - typename T::clear t2 = t * t; + typename T::open_type t2 = t * t; for (int i = 0; i < buffer_size; i++) { T& a = bits[i]; diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index a2deab2b..941fabcc 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -136,6 +136,9 @@ void Rep4::prepare_joint_input(int sender, int backup, int receiver, throw not_implemented(); } } + + for (auto& x : send_os) + x.append(0); } template @@ -176,6 +179,7 @@ void Rep4::finalize_joint_input(int sender, int backup, int receiver, x.res[index] += res[1]; } + os->consume(0); receive_hashes[sender][backup].update(start, os->get_data_ptr() - start); } diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index cad357a6..1f0cc5e2 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -200,6 +200,7 @@ void Replicated::prepare_reshare(const typename T::clear& share, template void Replicated::exchange() { + os[0].append(0); if (os[0].get_length() > 0) P.pass_around(os[0], os[1], 1); this->rounds++; @@ -208,6 +209,7 @@ void Replicated::exchange() template void Replicated::start_exchange() { + os[0].append(0); P.send_relative(1, os[0]); this->rounds++; } diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index ffc34d6f..dc116967 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -47,12 +47,16 @@ void ReplicatedInput::add_other(int player, int) template void ReplicatedInput::send_mine() { + for (auto& x : os) + x.append(0); P.send_relative(os); } template void ReplicatedInput::exchange() { + for (auto& x : os) + x.append(0); bool receive = expect[P.get_player(1)]; bool send = not os[1].empty(); auto& dest = InputBase::os[P.get_player(1)]; diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index 27fd3b71..747ca539 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -16,7 +16,6 @@ template class SemiMC : public TreeSum, public MAC_Check_Base { protected: - vector lengths; public: // emulate MAC_Check diff --git a/Protocols/SemiMC.hpp b/Protocols/SemiMC.hpp index 75aa0c6e..0a9c7a09 100644 --- a/Protocols/SemiMC.hpp +++ b/Protocols/SemiMC.hpp @@ -14,15 +14,15 @@ template void SemiMC::init_open(const Player& P, int n) { MAC_Check_Base::init_open(P, n); - lengths.clear(); - lengths.reserve(n); + this->lengths.clear(); + this->lengths.reserve(n); } template void SemiMC::prepare_open(const T& secret, int n_bits) { this->values.push_back(secret); - lengths.push_back(n_bits); + this->lengths.push_back(n_bits); } template @@ -53,6 +53,7 @@ void DirectSemiMC::exchange_(const PlayerBase& P) assert(this->values.size() == this->lengths.size()); for (size_t i = 0; i < this->lengths.size(); i++) this->values[i].pack(oss.mine, this->lengths[i]); + oss.mine.append(0); P.unchecked_broadcast(oss); size_t n = P.num_players(); size_t me = P.my_num(); diff --git a/Tools/Hash.h b/Tools/Hash.h index 706ddf32..1fe89903 100644 --- a/Tools/Hash.h +++ b/Tools/Hash.h @@ -39,6 +39,7 @@ public: octetStream tmp(v.size() * sizeof(T)); for (size_t i = 0; i < v.size(); i++) v[i].pack(tmp, bit_lengths[i]); + tmp.append(0); update(tmp); } void update(const string& str); diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index 2955becb..d028cd2c 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -41,6 +41,7 @@ void octetStream::assign(const octetStream& os) len=os.len; memcpy(data,os.data,len*sizeof(octet)); ptr=os.ptr; + bits = os.bits; } @@ -68,6 +69,7 @@ octetStream::octetStream(const octetStream& os) data=new octet[mxlen]; memcpy(data,os.data,len*sizeof(octet)); ptr=os.ptr; + bits = os.bits; } octetStream::octetStream(FlexBuffer& buffer) @@ -123,26 +125,20 @@ bool octetStream::equals(const octetStream& a) const void octetStream::append_random(size_t num) { - resize(len+num); - randombytes_buf(data+len, num); - len+=num; + randombytes_buf(append(num), num); } void octetStream::concat(const octetStream& os) { - resize(len+os.len); - memcpy(data+len,os.data,os.len*sizeof(octet)); - len+=os.len; + memcpy(append(os.len), os.data, os.len*sizeof(octet)); } void octetStream::store_bytes(octet* x, const size_t l) { - resize(len+4+l); - encode_length(data+len,l,4); len+=4; - memcpy(data+len,x,l*sizeof(octet)); - len+=l; + encode_length(append(4), l, 4); + memcpy(append(l), x, l*sizeof(octet)); } void octetStream::get_bytes(octet* ans, size_t& length) @@ -153,9 +149,7 @@ void octetStream::get_bytes(octet* ans, size_t& length) void octetStream::store(int l) { - resize(len+4); - encode_length(data+len,l,4); - len+=4; + encode_length(append(4), l, 4); } @@ -168,15 +162,9 @@ void octetStream::get(int& l) void octetStream::store(const bigint& x) { size_t num=numBytes(x); - resize(len+num+5); - - (data+len)[0]=0; - if (x<0) { (data+len)[0]=1; } - len++; - - encode_length(data+len,num,4); len+=4; - bytesFromBigint(data+len,x,num); - len+=num; + *append(1) = x < 0; + encode_length(append(4), num, 4); + bytesFromBigint(append(num), x, num); } diff --git a/Tools/octetStream.h b/Tools/octetStream.h index e69d04cf..77faa05e 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -48,6 +48,18 @@ class octetStream size_t len,mxlen,ptr; // len is the "write head", ptr is the "read head" octet *data; + class BitBuffer + { + public: + uint8_t n, buffer; + BitBuffer() : n(0), buffer(0) + { + } + }; + + // buffers for bit packing + array bits; + void reset(); public: @@ -86,9 +98,9 @@ class octetStream /// Allocation size_t get_max_length() const { return mxlen; } /// Data pointer - octet* get_data() const { return data; } + octet* get_data() const { assert(bits[0].n == 0); return data; } /// Read pointer - octet* get_data_ptr() const { return data + ptr; } + octet* get_data_ptr() const { assert(bits[1].n == 0); return data + ptr; } /// Whether done reading bool done() const { return ptr == len; } @@ -111,9 +123,9 @@ class octetStream void concat(const octetStream& os); /// Reset reading - void reset_read_head() { ptr=0; } + void reset_read_head() { ptr = 0; bits[1].n = 0; } /// Set length to zero but keep allocation - void reset_write_head() { len=0; ptr=0; } + void reset_write_head() { len = 0; bits[0].n = 0; reset_read_head(); } // Move len back num void rewind_write_head(size_t num) { len-=num; } @@ -166,6 +178,9 @@ class octetStream template size_t get_int(); + void store_bit(char a); + char get_bit(); + /// Append big integer void store(const bigint& x); /// Read big integer @@ -292,6 +307,13 @@ inline void octetStream::reserve(size_t l) inline octet* octetStream::append(const size_t l) { + if (bits[0].n) + { + bits[0].n = 0; + store_int<1>(bits[0].buffer); + bits[0].buffer = 0; + } + if (len+l>mxlen) resize(len+l); octet* res = data + len; @@ -312,6 +334,7 @@ inline void octetStream::append_no_resize(const octet* x, const size_t l) inline octet* octetStream::consume(size_t l) { + bits[1].n = 0; if(ptr + l > len) throw runtime_error("insufficient data"); octet* res = data + ptr; @@ -326,9 +349,7 @@ inline void octetStream::consume(octet* x,const size_t l) inline void octetStream::store_int(size_t l, int n_bytes) { - resize(len+n_bytes); - encode_length(data+len,l,n_bytes); - len+=n_bytes; + encode_length(append(n_bytes), l, n_bytes); } inline size_t octetStream::get_int(int n_bytes) @@ -340,10 +361,8 @@ template inline void octetStream::store_int(size_t l) { assert(N_BYTES <= 8); - resize(len+N_BYTES); uint64_t tmp = htole64(l); - memcpy(data + len, &tmp, N_BYTES); - len+=N_BYTES; + memcpy(append(N_BYTES), &tmp, N_BYTES); } template @@ -355,12 +374,38 @@ inline size_t octetStream::get_int() return le64toh(tmp); } +inline void octetStream::store_bit(char a) +{ + auto& n = bits[0].n; + auto& buffer = bits[0].buffer; + + if (n == 8) + append(0); + + buffer |= (a & 1) << n; + n++; +} + +inline char octetStream::get_bit() +{ + auto& n = bits[1].n; + auto& buffer = bits[1].buffer; + + if (n == 0) + { + buffer = get_int<1>(); + n = 8; + } + + return (buffer >> (8 - n--)) & 1; +} + template inline void octetStream::Send(T socket_num) const { send(socket_num,len,LENGTH_SIZE); - send(socket_num,data,len); + send(socket_num, get_data(), len); }