mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Optimized matrix multiplication in Hemi.
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "AtlasSecret.h"
|
||||
#include "TinyMC.h"
|
||||
|
||||
#include "Protocols/Shamir.hpp"
|
||||
#include "Protocols/ShamirMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Secret.hpp"
|
||||
|
||||
@@ -92,9 +92,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
size_t data_sent()
|
||||
NamedCommStats comm_stats()
|
||||
{
|
||||
return part_prep.data_sent();
|
||||
return part_prep.comm_stats();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -108,7 +108,8 @@ public:
|
||||
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
|
||||
static FakeSecret input(int from, word input, int n_bits);
|
||||
|
||||
static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; }
|
||||
static FakeSecret constant(clear value, int = 0, mac_key_type = {}, int = -1)
|
||||
{ return value; }
|
||||
|
||||
FakeSecret() {}
|
||||
template <class T>
|
||||
|
||||
@@ -68,6 +68,7 @@ enum
|
||||
CLEAR_WRITE = 0x210,
|
||||
XORCBI = 0x210,
|
||||
BITDECC = 0x211,
|
||||
NOTCB = 0x212,
|
||||
CONVCINT = 0x213,
|
||||
REVEAL = 0x214,
|
||||
STMSDCI = 0x215,
|
||||
|
||||
@@ -84,12 +84,6 @@ public:
|
||||
return new HashMaliciousRepMC<U>;
|
||||
}
|
||||
|
||||
static U constant(const BitVec& other, int my_num, const BitVec& alphai)
|
||||
{
|
||||
(void) my_num, (void) alphai;
|
||||
return other;
|
||||
}
|
||||
|
||||
MalRepSecretBase() {}
|
||||
template<class T>
|
||||
MalRepSecretBase(const T& other) : super(other) {}
|
||||
|
||||
@@ -143,7 +143,7 @@ public:
|
||||
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; }
|
||||
|
||||
NoShare() {}
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ public:
|
||||
void xors(const vector<int>& args, size_t start, size_t end);
|
||||
void xorc(const ::BaseInstruction& instruction);
|
||||
void nots(const ::BaseInstruction& instruction);
|
||||
void notcb(const ::BaseInstruction& instruction);
|
||||
void andm(const ::BaseInstruction& instruction);
|
||||
void and_(const vector<int>& args, bool repeat);
|
||||
void andrs(const vector<int>& args) { and_(args, true); }
|
||||
|
||||
@@ -257,6 +257,19 @@ void Processor<T>::nots(const ::BaseInstruction& instruction)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::notcb(const ::BaseInstruction& instruction)
|
||||
{
|
||||
int total = instruction.get_n();
|
||||
int unit = Clear::N_BITS;
|
||||
for (int i = 0; i < DIV_CEIL(total, unit); i++)
|
||||
{
|
||||
int n = min(unit, total - i * unit);
|
||||
C[instruction.get_r(0) + i] =
|
||||
Clear(~C[instruction.get_r(1) + i].get()).mask(n);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::andm(const ::BaseInstruction& instruction)
|
||||
{
|
||||
|
||||
@@ -30,7 +30,7 @@ public:
|
||||
static MC* new_mc(typename super::mac_key_type) { return new MC; }
|
||||
|
||||
static This constant(const typename super::clear& constant, int my_num,
|
||||
typename super::mac_key_type = {})
|
||||
typename super::mac_key_type = {}, int = -1)
|
||||
{
|
||||
return Rep4Share<typename super::clear>::constant(constant, my_num);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "OT/NPartyTripleGenerator.hpp"
|
||||
|
||||
namespace GC
|
||||
@@ -65,12 +66,12 @@ void SemiPrep::buffer_bits()
|
||||
}
|
||||
}
|
||||
|
||||
size_t SemiPrep::data_sent()
|
||||
NamedCommStats SemiPrep::comm_stats()
|
||||
{
|
||||
if (triple_generator)
|
||||
return triple_generator->data_sent();
|
||||
return triple_generator->comm_stats();
|
||||
else
|
||||
return 0;
|
||||
return {};
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
size_t data_sent();
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -102,16 +102,16 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
if (not this->machine.use_encryption and not T::dishonest_majority)
|
||||
insecure("unencrypted communication");
|
||||
|
||||
Server* server = network_opts.start_networking(this->N, my_num);
|
||||
network_opts.start_networking(this->N, my_num);
|
||||
|
||||
if (online_opts.live_prep)
|
||||
if (T::needs_ot)
|
||||
{
|
||||
Player* P;
|
||||
if (this->machine.use_encryption)
|
||||
P = new CryptoPlayer(this->N, 0xFFFF);
|
||||
P = new CryptoPlayer(this->N, "shareparty");
|
||||
else
|
||||
P = new PlainPlayer(this->N, 0xFFFF);
|
||||
P = new PlainPlayer(this->N, "shareparty");
|
||||
for (int i = 0; i < this->machine.nthreads; i++)
|
||||
this->machine.ot_setups.push_back({*P, true});
|
||||
delete P;
|
||||
@@ -133,9 +133,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
this->run();
|
||||
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -171,8 +171,8 @@ class ReplicatedSecret : public RepSecretBase<U, 2>
|
||||
public:
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
|
||||
typename super::mac_key_type)
|
||||
static ReplicatedSecret constant(const typename super::clear& value,
|
||||
int my_num, typename super::mac_key_type, int = -1)
|
||||
{
|
||||
ReplicatedSecret res;
|
||||
if (my_num < 2)
|
||||
|
||||
@@ -58,8 +58,8 @@ public:
|
||||
void pre_run();
|
||||
void post_run() { ShareThread<T>::post_run(); }
|
||||
|
||||
size_t data_sent()
|
||||
{ return Thread<T>::data_sent() + this->DataF.data_sent(); }
|
||||
NamedCommStats comm_stats()
|
||||
{ return Thread<T>::comm_stats() + this->DataF.comm_stats(); }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -56,7 +56,7 @@ public:
|
||||
void join_tape();
|
||||
void finish();
|
||||
|
||||
virtual size_t data_sent();
|
||||
virtual NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -51,10 +51,11 @@ void Thread<T>::run()
|
||||
singleton = this;
|
||||
BaseMachine::s().thread_num = thread_num;
|
||||
secure_prng.ReSeed();
|
||||
string id = "T" + to_string(thread_num);
|
||||
if (machine.use_encryption)
|
||||
P = new CryptoPlayer(N, thread_num << 16);
|
||||
P = new CryptoPlayer(N, id);
|
||||
else
|
||||
P = new PlainPlayer(N, thread_num << 16);
|
||||
P = new PlainPlayer(N, id);
|
||||
processor.open_input_file(N.my_num(), thread_num,
|
||||
master.opts.cmd_private_input_file);
|
||||
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
|
||||
@@ -98,10 +99,10 @@ void Thread<T>::finish()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t GC::Thread<T>::data_sent()
|
||||
NamedCommStats Thread<T>::comm_stats()
|
||||
{
|
||||
assert(P);
|
||||
return P->comm_stats.total_data();
|
||||
return P->comm_stats;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -58,7 +58,7 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
|
||||
template<class T>
|
||||
void ThreadMaster<T>::run()
|
||||
{
|
||||
P = new PlainPlayer(N, 0xff << 24);
|
||||
P = new PlainPlayer(N, "main");
|
||||
|
||||
machine.load_schedule(progname);
|
||||
|
||||
@@ -87,12 +87,10 @@ void ThreadMaster<T>::run()
|
||||
|
||||
NamedCommStats stats = P->comm_stats;
|
||||
ExecutionStats exe_stats;
|
||||
size_t data_sent = P->comm_stats.total_data();
|
||||
for (auto thread : threads)
|
||||
{
|
||||
stats += thread->P->comm_stats;
|
||||
exe_stats += thread->processor.stats;
|
||||
data_sent += thread->data_sent();
|
||||
delete thread;
|
||||
}
|
||||
|
||||
@@ -102,7 +100,7 @@ void ThreadMaster<T>::run()
|
||||
stats.print();
|
||||
|
||||
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
|
||||
cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
|
||||
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -48,7 +48,7 @@ public:
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol);
|
||||
|
||||
size_t data_sent();
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -92,13 +92,13 @@ void GC::TinierSharePrep<T>::buffer_bits()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t TinierSharePrep<T>::data_sent()
|
||||
NamedCommStats TinierSharePrep<T>::comm_stats()
|
||||
{
|
||||
size_t res = 0;
|
||||
NamedCommStats res;
|
||||
if (triple_generator)
|
||||
res += triple_generator->data_sent();
|
||||
res += triple_generator->comm_stats();
|
||||
if (real_triple_generator)
|
||||
res += real_triple_generator->data_sent();
|
||||
res += real_triple_generator->comm_stats();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -70,11 +70,14 @@ public:
|
||||
T::reveal_inst(processor, args);
|
||||
}
|
||||
|
||||
static This constant(BitVec other, int my_num, mac_key_type alphai)
|
||||
static This constant(BitVec other, int my_num, mac_key_type alphai,
|
||||
int n_bits = -1)
|
||||
{
|
||||
if (n_bits < 0)
|
||||
n_bits = other.length();
|
||||
This res;
|
||||
res.resize_regs(other.length());
|
||||
for (int i = 0; i < other.length(); i++)
|
||||
res.resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai);
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
X(XORCB, processor.xorc(instruction)) \
|
||||
X(XORCBI, C0.xor_(PC1, IMM)) \
|
||||
X(NOTS, processor.nots(INST)) \
|
||||
X(NOTCB, processor.notcb(INST)) \
|
||||
X(ANDRS, T::andrs(PROC, EXTRA)) \
|
||||
X(ANDS, T::ands(PROC, EXTRA)) \
|
||||
X(ADDCB, C0 = PC1 + PC2) \
|
||||
@@ -140,6 +141,7 @@
|
||||
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
|
||||
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
|
||||
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
|
||||
X(CRASH, if (I0.get()) throw crash_requested()) \
|
||||
|
||||
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS
|
||||
|
||||
|
||||
Reference in New Issue
Block a user