Use bit packing to reduce communication.

This commit is contained in:
Marcel Keller
2023-05-15 13:33:28 +10:00
parent 8e735daab5
commit 2f76d73698
26 changed files with 142 additions and 71 deletions

View File

@@ -58,7 +58,7 @@ public:
void assign_zero() { *this = {}; } void assign_zero() { *this = {}; }
bool is_zero() { return *this == P256Element(); } bool is_zero() { return *this == P256Element(); }
void add(octetStream& os) { *this += os.get<P256Element>(); } void add(octetStream& os, int = -1) { *this += os.get<P256Element>(); }
void pack(octetStream& os, int = -1) const; void pack(octetStream& os, int = -1) const;
void unpack(octetStream& os, int = -1); void unpack(octetStream& os, int = -1);

View File

@@ -113,7 +113,7 @@ void Ciphertext::mul(const Ciphertext& c, const Rq_Element& ra)
::mul(cc1,ra,c.cc1); ::mul(cc1,ra,c.cc1);
} }
void Ciphertext::add(octetStream& os) void Ciphertext::add(octetStream& os, int)
{ {
Ciphertext tmp(*params); Ciphertext tmp(*params);
tmp.unpack(os); tmp.unpack(os);

View File

@@ -117,11 +117,11 @@ class Ciphertext
int level() const { return cc0.level(); } int level() const { return cc0.level(); }
/// Append to buffer /// Append to buffer
void pack(octetStream& o) const void pack(octetStream& o, int = -1) const
{ cc0.pack(o); cc1.pack(o); o.store(pk_id); } { cc0.pack(o); cc1.pack(o); o.store(pk_id); }
/// Read from buffer. Assumes parameters are set correctly /// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o) void unpack(octetStream& o, int = -1)
{ cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); } { cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); }
void output(ostream& s) const void output(ostream& s) const
@@ -129,7 +129,7 @@ class Ciphertext
void input(istream& s) void input(istream& s)
{ cc0.input(s); cc1.input(s); s.read((char*)&pk_id, sizeof(pk_id)); } { cc0.input(s); cc1.input(s); s.read((char*)&pk_id, sizeof(pk_id)); }
void add(octetStream& os); void add(octetStream& os, int = -1);
size_t report_size(ReportType type) const { return cc0.report_size(type) + cc1.report_size(type); } size_t report_size(ReportType type) const { return cc0.report_size(type) + cc1.report_size(type); }
}; };

View File

@@ -45,10 +45,11 @@ class FHE_SK
const Rq_Element& s() const { return sk; } const Rq_Element& s() const { return sk; }
/// Append to buffer /// Append to buffer
void pack(octetStream& os) const { sk.pack(os); pr.pack(os); } void pack(octetStream& os, int = -1) const { sk.pack(os); pr.pack(os); }
/// Read from buffer. Assumes parameters are set correctly /// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); } void unpack(octetStream& os, int = -1)
{ sk.unpack(os, *params); pr.unpack(os); }
// Assumes Ring and prime of mess have already been set correctly // Assumes Ring and prime of mess have already been set correctly
// Ciphertext c must be at level 0 or an error occurs // Ciphertext c must be at level 0 or an error occurs
@@ -88,7 +89,8 @@ class FHE_SK
bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; } bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; }
void add(octetStream& os) { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; } void add(octetStream& os, int = -1)
{ FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const; void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const;

View File

@@ -109,7 +109,7 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
} }
} }
void Rq_Element::add(octetStream& os) void Rq_Element::add(octetStream& os, int)
{ {
Rq_Element tmp(*this); Rq_Element tmp(*this);
tmp.unpack(os); tmp.unpack(os);
@@ -298,7 +298,7 @@ void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y)
partial_assign(x); partial_assign(x);
} }
void Rq_Element::pack(octetStream& o) const void Rq_Element::pack(octetStream& o, int) const
{ {
check_level(); check_level();
o.store(lev); o.store(lev);
@@ -306,7 +306,7 @@ void Rq_Element::pack(octetStream& o) const
a[i].pack(o); a[i].pack(o);
} }
void Rq_Element::unpack(octetStream& o) void Rq_Element::unpack(octetStream& o, int)
{ {
unsigned int ll; o.get(ll); lev=ll; unsigned int ll; o.get(ll); lev=ll;
check_level(); check_level();

View File

@@ -94,7 +94,7 @@ protected:
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
void add(octetStream& os); void add(octetStream& os, int = -1);
template<class S> template<class S>
Rq_Element& operator+=(const vector<S>& other); Rq_Element& operator+=(const vector<S>& other);
@@ -157,8 +157,8 @@ protected:
* For unpack we assume the prData for a0 and a1 has been assigned * For unpack we assume the prData for a0 and a1 has been assigned
* correctly already * correctly already
*/ */
void pack(octetStream& o) const; void pack(octetStream& o, int = -1) const;
void unpack(octetStream& o); void unpack(octetStream& o, int = -1);
// without prior initialization // without prior initialization
void unpack(octetStream& o, const FHE_Params& params); void unpack(octetStream& o, const FHE_Params& params);

View File

@@ -33,4 +33,9 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
prepare_mul(x, y, n); prepare_mul(x, y, n);
} }
void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n)
{
super::prepare_mul(x.mask(n), y.mask(n), n);
}
} /* namespace GC */ } /* namespace GC */

View File

@@ -24,6 +24,7 @@ public:
void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat); bool repeat);
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n);
}; };
} /* namespace GC */ } /* namespace GC */

View File

@@ -51,6 +51,11 @@ public:
return other * *this; return other * *this;
} }
void add(octetStream& os, int = -1)
{
*this += os.get<Bit>();
}
void pack(octetStream& os, int = -1) const void pack(octetStream& os, int = -1) const
{ {
super::pack(os, 1); super::pack(os, 1);

View File

@@ -56,7 +56,12 @@ public:
void extend_bit(BitVec_& res, int) const { res = extend_bit(); } void extend_bit(BitVec_& res, int) const { res = extend_bit(); }
void add(octetStream& os) { *this += os.get<BitVec_>(); } void add(octetStream& os, int n_bits)
{
BitVec_ tmp;
tmp.unpack(os, n_bits);
*this += tmp;
}
void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; } void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; }
@@ -70,7 +75,7 @@ public:
if (n == -1) if (n == -1)
pack(os); pack(os);
else if (n == 1) else if (n == 1)
os.store_int<1>(this->a & 1); os.store_bit(this->a);
else else
os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); os.store_int(super::mask(n).get(), DIV_CEIL(n, 8));
} }
@@ -80,7 +85,7 @@ public:
if (n == -1) if (n == -1)
unpack(os); unpack(os);
else if (n == 1) else if (n == 1)
this->a = os.get_int<1>(); this->a = os.get_bit();
else else
this->a = os.get_int(DIV_CEIL(n, 8)); this->a = os.get_int(DIV_CEIL(n, 8));
} }

View File

@@ -151,7 +151,7 @@ public:
bool operator==(const Z2<K>& other) const; bool operator==(const Z2<K>& other) const;
bool operator!=(const Z2<K>& other) const { return not (*this == other); } bool operator!=(const Z2<K>& other) const { return not (*this == other); }
void add(octetStream& os) { *this += (os.consume(size())); } void add(octetStream& os, int = -1) { *this += (os.consume(size())); }
Z2 lazy_add(const Z2& x) const; Z2 lazy_add(const Z2& x) const;
Z2 lazy_mul(const Z2& x) const; Z2 lazy_mul(const Z2& x) const;

View File

@@ -138,7 +138,7 @@ protected:
{ a=x.a^y.a; } { a=x.a^y.a; }
void add(octet* x) void add(octet* x)
{ a^=*(U*)(x); } { a^=*(U*)(x); }
void add(octetStream& os) void add(octetStream& os, int = -1)
{ add(os.consume(size())); } { add(os.consume(size())); }
void sub(const gf2n_& x,const gf2n_& y) void sub(const gf2n_& x,const gf2n_& y)
{ a=x.a^y.a; } { a=x.a^y.a; }

View File

@@ -187,7 +187,7 @@ class gfp_ : public ValueInterface
bool operator!=(const gfp_& y) const { return !equal(y); } bool operator!=(const gfp_& y) const { return !equal(y); }
// x+y // x+y
void add(octetStream& os) void add(octetStream& os, int = -1)
{ add(os.consume(size())); } { add(os.consume(size())); }
void add(const gfp_& x,const gfp_& y) void add(const gfp_& x,const gfp_& y)
{ ZpD.Add<L>(a.x,x.a.x,y.a.x); } { ZpD.Add<L>(a.x,x.a.x,y.a.x); }

View File

@@ -290,7 +290,7 @@ bool gfpvar_<X, L>::operator !=(const gfpvar_<X, L>& other) const
} }
template<int X, int L> template<int X, int L>
void gfpvar_<X, L>::add(octetStream& other) void gfpvar_<X, L>::add(octetStream& other, int)
{ {
*this += other.get<gfpvar_<X, L>>(); *this += other.get<gfpvar_<X, L>>();
} }

View File

@@ -148,7 +148,7 @@ public:
bool operator==(const gfpvar_& other) const; bool operator==(const gfpvar_& other) const;
bool operator!=(const gfpvar_& other) const; bool operator!=(const gfpvar_& other) const;
void add(octetStream& other); void add(octetStream& other, int = -1);
void negate(); void negate();

View File

@@ -125,6 +125,7 @@ void InputBase<T>::add_from_all(const typename T::open_type& input, int n_bits)
template<class T> template<class T>
void Input<T>::send_mine() void Input<T>::send_mine()
{ {
this->os[P.my_num()].append(0);
P.send_all(this->os[P.my_num()]); P.send_all(this->os[P.my_num()]);
} }

View File

@@ -35,12 +35,17 @@ class TreeSum
void start(vector<T>& values, const Player& P); void start(vector<T>& values, const Player& P);
void finish(vector<T>& values, const Player& P); void finish(vector<T>& values, const Player& P);
void add_openings(vector<T>& values, const Player& P, int sum_players,
int last_sum_players, int send_player);
protected: protected:
int base_player; int base_player;
int opening_sum; int opening_sum;
int max_broadcast; int max_broadcast;
octetStream os; octetStream os;
vector<int> lengths;
void ReceiveValues(vector<T>& values, const Player& P, int sender); void ReceiveValues(vector<T>& values, const Player& P, int sender);
virtual void AddToValues(vector<T>& values) { (void)values; } virtual void AddToValues(vector<T>& values) { (void)values; }
@@ -157,10 +162,6 @@ using MAC_Check_Z2k_ = MAC_Check_Z2k<typename W::open_type,
typename W::mac_key_type, typename W::open_type, W>; typename W::mac_key_type, typename W::open_type, W>;
template<class T>
void add_openings(vector<T>& values, const Player& P, int sum_players, int last_sum_players, int send_player, TreeSum<T>& MC);
/** /**
* SPDZ opening protocol with MAC check (pairwise communication) * SPDZ opening protocol with MAC check (pairwise communication)
*/ */
@@ -239,13 +240,16 @@ size_t TreeSum<T>::report_size(ReportType type)
} }
template<class T> template<class T>
void add_openings(vector<T>& values, const Player& P, int sum_players, int last_sum_players, int send_player, TreeSum<T>& MC) void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
int sum_players, int last_sum_players, int send_player)
{ {
auto& MC = *this;
MC.player_timers.resize(P.num_players()); MC.player_timers.resize(P.num_players());
vector<octetStream>& oss = MC.oss; vector<octetStream>& oss = MC.oss;
oss.resize(P.num_players()); oss.resize(P.num_players());
vector<int> senders; vector<int> senders;
senders.reserve(P.num_players()); senders.reserve(P.num_players());
bool use_lengths = values.size() == lengths.size();
for (int relative_sender = positive_modulo(P.my_num() - send_player, P.num_players()) + sum_players; for (int relative_sender = positive_modulo(P.my_num() - send_player, P.num_players()) + sum_players;
relative_sender < last_sum_players; relative_sender += sum_players) relative_sender < last_sum_players; relative_sender += sum_players)
@@ -266,7 +270,7 @@ void add_openings(vector<T>& values, const Player& P, int sum_players, int last_
MC.timers[SUM].start(); MC.timers[SUM].start();
for (unsigned int i=0; i<values.size(); i++) for (unsigned int i=0; i<values.size(); i++)
{ {
values[i].add(oss[j]); values[i].add(oss[j], use_lengths ? lengths[i] : -1);
} }
MC.timers[SUM].stop(); MC.timers[SUM].stop();
} }
@@ -278,6 +282,7 @@ void TreeSum<T>::start(vector<T>& values, const Player& P)
os.reset_write_head(); os.reset_write_head();
int sum_players = P.num_players(); int sum_players = P.num_players();
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players()); int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());
bool use_lengths = values.size() == lengths.size();
while (true) while (true)
{ {
// summing phase // summing phase
@@ -289,7 +294,8 @@ void TreeSum<T>::start(vector<T>& values, const Player& P)
{ {
// send to the player up the tree // send to the player up the tree
for (unsigned int i=0; i<values.size(); i++) for (unsigned int i=0; i<values.size(); i++)
{ values[i].pack(os); } values[i].pack(os, use_lengths ? lengths[i] : -1);
os.append(0);
int receiver = positive_modulo(base_player + my_relative_num % sum_players, P.num_players()); int receiver = positive_modulo(base_player + my_relative_num % sum_players, P.num_players());
timers[SEND].start(); timers[SEND].start();
P.send_to(receiver,os); P.send_to(receiver,os);
@@ -300,7 +306,7 @@ void TreeSum<T>::start(vector<T>& values, const Player& P)
{ {
// if receiving, add the values // if receiving, add the values
timers[RECV_ADD].start(); timers[RECV_ADD].start();
add_openings<T>(values, P, sum_players, last_sum_players, base_player, *this); add_openings(values, P, sum_players, last_sum_players, base_player);
timers[RECV_ADD].stop(); timers[RECV_ADD].stop();
} }
} }
@@ -311,7 +317,8 @@ void TreeSum<T>::start(vector<T>& values, const Player& P)
os.reset_write_head(); os.reset_write_head();
size_t n = values.size(); size_t n = values.size();
for (unsigned int i=0; i<n; i++) for (unsigned int i=0; i<n; i++)
{ values[i].pack(os); } values[i].pack(os, use_lengths ? lengths[i] : -1);
os.append(0);
timers[BCAST].start(); timers[BCAST].start();
for (int i = 1; i < max_broadcast && i < P.num_players(); i++) for (int i = 1; i < max_broadcast && i < P.num_players(); i++)
{ {
@@ -357,8 +364,9 @@ void TreeSum<T>::ReceiveValues(vector<T>& values, const Player& P, int sender)
timers[RECV_SUM].start(); timers[RECV_SUM].start();
P.receive_player(sender, os); P.receive_player(sender, os);
timers[RECV_SUM].stop(); timers[RECV_SUM].stop();
bool use_lengths = values.size() == lengths.size();
for (unsigned int i = 0; i < values.size(); i++) for (unsigned int i = 0; i < values.size(); i++)
values[i].unpack(os); values[i].unpack(os, use_lengths ? lengths[i] : -1);
AddToValues(values); AddToValues(values);
} }

View File

@@ -208,7 +208,7 @@ void MaliciousBitOnlyRepPrep<T>::buffer_bits()
assert(MC.open(f, P) * MC.open(f, P) == MC.open(h, P)); assert(MC.open(f, P) * MC.open(f, P) == MC.open(h, P));
#endif #endif
} }
auto t = Create_Random<typename T::clear>(P); auto t = Create_Random<typename T::open_type>(P);
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)
{ {
T& a = bits[i]; T& a = bits[i];
@@ -216,7 +216,7 @@ void MaliciousBitOnlyRepPrep<T>::buffer_bits()
masked.push_back(t * a - f); masked.push_back(t * a - f);
} }
MC.POpen(opened, masked, P); MC.POpen(opened, masked, P);
typename T::clear t2 = t * t; typename T::open_type t2 = t * t;
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)
{ {
T& a = bits[i]; T& a = bits[i];

View File

@@ -136,6 +136,9 @@ void Rep4<T>::prepare_joint_input(int sender, int backup, int receiver,
throw not_implemented(); throw not_implemented();
} }
} }
for (auto& x : send_os)
x.append(0);
} }
template<class T> template<class T>
@@ -176,6 +179,7 @@ void Rep4<T>::finalize_joint_input(int sender, int backup, int receiver,
x.res[index] += res[1]; x.res[index] += res[1];
} }
os->consume(0);
receive_hashes[sender][backup].update(start, receive_hashes[sender][backup].update(start,
os->get_data_ptr() - start); os->get_data_ptr() - start);
} }

View File

@@ -200,6 +200,7 @@ void Replicated<T>::prepare_reshare(const typename T::clear& share,
template<class T> template<class T>
void Replicated<T>::exchange() void Replicated<T>::exchange()
{ {
os[0].append(0);
if (os[0].get_length() > 0) if (os[0].get_length() > 0)
P.pass_around(os[0], os[1], 1); P.pass_around(os[0], os[1], 1);
this->rounds++; this->rounds++;
@@ -208,6 +209,7 @@ void Replicated<T>::exchange()
template<class T> template<class T>
void Replicated<T>::start_exchange() void Replicated<T>::start_exchange()
{ {
os[0].append(0);
P.send_relative(1, os[0]); P.send_relative(1, os[0]);
this->rounds++; this->rounds++;
} }

View File

@@ -47,12 +47,16 @@ void ReplicatedInput<T>::add_other(int player, int)
template<class T> template<class T>
void ReplicatedInput<T>::send_mine() void ReplicatedInput<T>::send_mine()
{ {
for (auto& x : os)
x.append(0);
P.send_relative(os); P.send_relative(os);
} }
template<class T> template<class T>
void ReplicatedInput<T>::exchange() void ReplicatedInput<T>::exchange()
{ {
for (auto& x : os)
x.append(0);
bool receive = expect[P.get_player(1)]; bool receive = expect[P.get_player(1)];
bool send = not os[1].empty(); bool send = not os[1].empty();
auto& dest = InputBase<T>::os[P.get_player(1)]; auto& dest = InputBase<T>::os[P.get_player(1)];

View File

@@ -16,7 +16,6 @@ template<class T>
class SemiMC : public TreeSum<typename T::open_type>, public MAC_Check_Base<T> class SemiMC : public TreeSum<typename T::open_type>, public MAC_Check_Base<T>
{ {
protected: protected:
vector<int> lengths;
public: public:
// emulate MAC_Check // emulate MAC_Check

View File

@@ -14,15 +14,15 @@ template<class T>
void SemiMC<T>::init_open(const Player& P, int n) void SemiMC<T>::init_open(const Player& P, int n)
{ {
MAC_Check_Base<T>::init_open(P, n); MAC_Check_Base<T>::init_open(P, n);
lengths.clear(); this->lengths.clear();
lengths.reserve(n); this->lengths.reserve(n);
} }
template<class T> template<class T>
void SemiMC<T>::prepare_open(const T& secret, int n_bits) void SemiMC<T>::prepare_open(const T& secret, int n_bits)
{ {
this->values.push_back(secret); this->values.push_back(secret);
lengths.push_back(n_bits); this->lengths.push_back(n_bits);
} }
template<class T> template<class T>
@@ -53,6 +53,7 @@ void DirectSemiMC<T>::exchange_(const PlayerBase& P)
assert(this->values.size() == this->lengths.size()); assert(this->values.size() == this->lengths.size());
for (size_t i = 0; i < this->lengths.size(); i++) for (size_t i = 0; i < this->lengths.size(); i++)
this->values[i].pack(oss.mine, this->lengths[i]); this->values[i].pack(oss.mine, this->lengths[i]);
oss.mine.append(0);
P.unchecked_broadcast(oss); P.unchecked_broadcast(oss);
size_t n = P.num_players(); size_t n = P.num_players();
size_t me = P.my_num(); size_t me = P.my_num();

View File

@@ -39,6 +39,7 @@ public:
octetStream tmp(v.size() * sizeof(T)); octetStream tmp(v.size() * sizeof(T));
for (size_t i = 0; i < v.size(); i++) for (size_t i = 0; i < v.size(); i++)
v[i].pack(tmp, bit_lengths[i]); v[i].pack(tmp, bit_lengths[i]);
tmp.append(0);
update(tmp); update(tmp);
} }
void update(const string& str); void update(const string& str);

View File

@@ -41,6 +41,7 @@ void octetStream::assign(const octetStream& os)
len=os.len; len=os.len;
memcpy(data,os.data,len*sizeof(octet)); memcpy(data,os.data,len*sizeof(octet));
ptr=os.ptr; ptr=os.ptr;
bits = os.bits;
} }
@@ -68,6 +69,7 @@ octetStream::octetStream(const octetStream& os)
data=new octet[mxlen]; data=new octet[mxlen];
memcpy(data,os.data,len*sizeof(octet)); memcpy(data,os.data,len*sizeof(octet));
ptr=os.ptr; ptr=os.ptr;
bits = os.bits;
} }
octetStream::octetStream(FlexBuffer& buffer) octetStream::octetStream(FlexBuffer& buffer)
@@ -123,26 +125,20 @@ bool octetStream::equals(const octetStream& a) const
void octetStream::append_random(size_t num) void octetStream::append_random(size_t num)
{ {
resize(len+num); randombytes_buf(append(num), num);
randombytes_buf(data+len, num);
len+=num;
} }
void octetStream::concat(const octetStream& os) void octetStream::concat(const octetStream& os)
{ {
resize(len+os.len); memcpy(append(os.len), os.data, os.len*sizeof(octet));
memcpy(data+len,os.data,os.len*sizeof(octet));
len+=os.len;
} }
void octetStream::store_bytes(octet* x, const size_t l) void octetStream::store_bytes(octet* x, const size_t l)
{ {
resize(len+4+l); encode_length(append(4), l, 4);
encode_length(data+len,l,4); len+=4; memcpy(append(l), x, l*sizeof(octet));
memcpy(data+len,x,l*sizeof(octet));
len+=l;
} }
void octetStream::get_bytes(octet* ans, size_t& length) void octetStream::get_bytes(octet* ans, size_t& length)
@@ -153,9 +149,7 @@ void octetStream::get_bytes(octet* ans, size_t& length)
void octetStream::store(int l) void octetStream::store(int l)
{ {
resize(len+4); encode_length(append(4), l, 4);
encode_length(data+len,l,4);
len+=4;
} }
@@ -168,15 +162,9 @@ void octetStream::get(int& l)
void octetStream::store(const bigint& x) void octetStream::store(const bigint& x)
{ {
size_t num=numBytes(x); size_t num=numBytes(x);
resize(len+num+5); *append(1) = x < 0;
encode_length(append(4), num, 4);
(data+len)[0]=0; bytesFromBigint(append(num), x, num);
if (x<0) { (data+len)[0]=1; }
len++;
encode_length(data+len,num,4); len+=4;
bytesFromBigint(data+len,x,num);
len+=num;
} }

View File

@@ -48,6 +48,18 @@ class octetStream
size_t len,mxlen,ptr; // len is the "write head", ptr is the "read head" size_t len,mxlen,ptr; // len is the "write head", ptr is the "read head"
octet *data; octet *data;
class BitBuffer
{
public:
uint8_t n, buffer;
BitBuffer() : n(0), buffer(0)
{
}
};
// buffers for bit packing
array<BitBuffer, 2> bits;
void reset(); void reset();
public: public:
@@ -86,9 +98,9 @@ class octetStream
/// Allocation /// Allocation
size_t get_max_length() const { return mxlen; } size_t get_max_length() const { return mxlen; }
/// Data pointer /// Data pointer
octet* get_data() const { return data; } octet* get_data() const { assert(bits[0].n == 0); return data; }
/// Read pointer /// Read pointer
octet* get_data_ptr() const { return data + ptr; } octet* get_data_ptr() const { assert(bits[1].n == 0); return data + ptr; }
/// Whether done reading /// Whether done reading
bool done() const { return ptr == len; } bool done() const { return ptr == len; }
@@ -111,9 +123,9 @@ class octetStream
void concat(const octetStream& os); void concat(const octetStream& os);
/// Reset reading /// Reset reading
void reset_read_head() { ptr=0; } void reset_read_head() { ptr = 0; bits[1].n = 0; }
/// Set length to zero but keep allocation /// Set length to zero but keep allocation
void reset_write_head() { len=0; ptr=0; } void reset_write_head() { len = 0; bits[0].n = 0; reset_read_head(); }
// Move len back num // Move len back num
void rewind_write_head(size_t num) { len-=num; } void rewind_write_head(size_t num) { len-=num; }
@@ -166,6 +178,9 @@ class octetStream
template<int N_BYTES> template<int N_BYTES>
size_t get_int(); size_t get_int();
void store_bit(char a);
char get_bit();
/// Append big integer /// Append big integer
void store(const bigint& x); void store(const bigint& x);
/// Read big integer /// Read big integer
@@ -292,6 +307,13 @@ inline void octetStream::reserve(size_t l)
inline octet* octetStream::append(const size_t l) inline octet* octetStream::append(const size_t l)
{ {
if (bits[0].n)
{
bits[0].n = 0;
store_int<1>(bits[0].buffer);
bits[0].buffer = 0;
}
if (len+l>mxlen) if (len+l>mxlen)
resize(len+l); resize(len+l);
octet* res = data + len; octet* res = data + len;
@@ -312,6 +334,7 @@ inline void octetStream::append_no_resize(const octet* x, const size_t l)
inline octet* octetStream::consume(size_t l) inline octet* octetStream::consume(size_t l)
{ {
bits[1].n = 0;
if(ptr + l > len) if(ptr + l > len)
throw runtime_error("insufficient data"); throw runtime_error("insufficient data");
octet* res = data + ptr; octet* res = data + ptr;
@@ -326,9 +349,7 @@ inline void octetStream::consume(octet* x,const size_t l)
inline void octetStream::store_int(size_t l, int n_bytes) inline void octetStream::store_int(size_t l, int n_bytes)
{ {
resize(len+n_bytes); encode_length(append(n_bytes), l, n_bytes);
encode_length(data+len,l,n_bytes);
len+=n_bytes;
} }
inline size_t octetStream::get_int(int n_bytes) inline size_t octetStream::get_int(int n_bytes)
@@ -340,10 +361,8 @@ template<int N_BYTES>
inline void octetStream::store_int(size_t l) inline void octetStream::store_int(size_t l)
{ {
assert(N_BYTES <= 8); assert(N_BYTES <= 8);
resize(len+N_BYTES);
uint64_t tmp = htole64(l); uint64_t tmp = htole64(l);
memcpy(data + len, &tmp, N_BYTES); memcpy(append(N_BYTES), &tmp, N_BYTES);
len+=N_BYTES;
} }
template<int N_BYTES> template<int N_BYTES>
@@ -355,12 +374,38 @@ inline size_t octetStream::get_int()
return le64toh(tmp); return le64toh(tmp);
} }
inline void octetStream::store_bit(char a)
{
auto& n = bits[0].n;
auto& buffer = bits[0].buffer;
if (n == 8)
append(0);
buffer |= (a & 1) << n;
n++;
}
inline char octetStream::get_bit()
{
auto& n = bits[1].n;
auto& buffer = bits[1].buffer;
if (n == 0)
{
buffer = get_int<1>();
n = 8;
}
return (buffer >> (8 - n--)) & 1;
}
template<class T> template<class T>
inline void octetStream::Send(T socket_num) const inline void octetStream::Send(T socket_num) const
{ {
send(socket_num,len,LENGTH_SIZE); send(socket_num,len,LENGTH_SIZE);
send(socket_num,data,len); send(socket_num, get_data(), len);
} }