Optimized matrix multiplication in Hemi.

This commit is contained in:
Marcel Keller
2021-09-17 14:29:28 +10:00
parent 5c6f101c12
commit 799929b801
151 changed files with 5262 additions and 748 deletions

View File

@@ -6,6 +6,7 @@
#include "AtlasSecret.h"
#include "TinyMC.h"
#include "Protocols/Shamir.hpp"
#include "Protocols/ShamirMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Secret.hpp"

View File

@@ -92,9 +92,9 @@ public:
}
}
size_t data_sent()
NamedCommStats comm_stats()
{
return part_prep.data_sent();
return part_prep.comm_stats();
}
};

View File

@@ -108,7 +108,8 @@ public:
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, word input, int n_bits);
static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; }
static FakeSecret constant(clear value, int = 0, mac_key_type = {}, int = -1)
{ return value; }
FakeSecret() {}
template <class T>

View File

@@ -68,6 +68,7 @@ enum
CLEAR_WRITE = 0x210,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
CONVCINT = 0x213,
REVEAL = 0x214,
STMSDCI = 0x215,

View File

@@ -84,12 +84,6 @@ public:
return new HashMaliciousRepMC<U>;
}
static U constant(const BitVec& other, int my_num, const BitVec& alphai)
{
(void) my_num, (void) alphai;
return other;
}
MalRepSecretBase() {}
template<class T>
MalRepSecretBase(const T& other) : super(other) {}

View File

@@ -143,7 +143,7 @@ public:
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; }
NoShare() {}

View File

@@ -86,6 +86,7 @@ public:
void xors(const vector<int>& args, size_t start, size_t end);
void xorc(const ::BaseInstruction& instruction);
void nots(const ::BaseInstruction& instruction);
void notcb(const ::BaseInstruction& instruction);
void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat);
void andrs(const vector<int>& args) { and_(args, true); }

View File

@@ -257,6 +257,19 @@ void Processor<T>::nots(const ::BaseInstruction& instruction)
}
}
template<class T>
void Processor<T>::notcb(const ::BaseInstruction& instruction)
{
int total = instruction.get_n();
int unit = Clear::N_BITS;
for (int i = 0; i < DIV_CEIL(total, unit); i++)
{
int n = min(unit, total - i * unit);
C[instruction.get_r(0) + i] =
Clear(~C[instruction.get_r(1) + i].get()).mask(n);
}
}
template<class T>
void Processor<T>::andm(const ::BaseInstruction& instruction)
{

View File

@@ -30,7 +30,7 @@ public:
static MC* new_mc(typename super::mac_key_type) { return new MC; }
static This constant(const typename super::clear& constant, int my_num,
typename super::mac_key_type = {})
typename super::mac_key_type = {}, int = -1)
{
return Rep4Share<typename super::clear>::constant(constant, my_num);
}

View File

@@ -10,6 +10,7 @@
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Replicated.hpp"
#include "OT/NPartyTripleGenerator.hpp"
namespace GC
@@ -65,12 +66,12 @@ void SemiPrep::buffer_bits()
}
}
size_t SemiPrep::data_sent()
NamedCommStats SemiPrep::comm_stats()
{
if (triple_generator)
return triple_generator->data_sent();
return triple_generator->comm_stats();
else
return 0;
return {};
}
} /* namespace GC */

View File

@@ -53,7 +53,7 @@ public:
throw not_implemented();
}
size_t data_sent();
NamedCommStats comm_stats();
};
} /* namespace GC */

View File

@@ -102,16 +102,16 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
if (not this->machine.use_encryption and not T::dishonest_majority)
insecure("unencrypted communication");
Server* server = network_opts.start_networking(this->N, my_num);
network_opts.start_networking(this->N, my_num);
if (online_opts.live_prep)
if (T::needs_ot)
{
Player* P;
if (this->machine.use_encryption)
P = new CryptoPlayer(this->N, 0xFFFF);
P = new CryptoPlayer(this->N, "shareparty");
else
P = new PlainPlayer(this->N, 0xFFFF);
P = new PlainPlayer(this->N, "shareparty");
for (int i = 0; i < this->machine.nthreads; i++)
this->machine.ot_setups.push_back({*P, true});
delete P;
@@ -133,9 +133,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
this->run();
this->machine.write_memory(this->N.my_num());
if (server)
delete server;
}
template<class T>

View File

@@ -171,8 +171,8 @@ class ReplicatedSecret : public RepSecretBase<U, 2>
public:
typedef ReplicatedBase Protocol;
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
typename super::mac_key_type)
static ReplicatedSecret constant(const typename super::clear& value,
int my_num, typename super::mac_key_type, int = -1)
{
ReplicatedSecret res;
if (my_num < 2)

View File

@@ -58,8 +58,8 @@ public:
void pre_run();
void post_run() { ShareThread<T>::post_run(); }
size_t data_sent()
{ return Thread<T>::data_sent() + this->DataF.data_sent(); }
NamedCommStats comm_stats()
{ return Thread<T>::comm_stats() + this->DataF.comm_stats(); }
};
template<class T>

View File

@@ -56,7 +56,7 @@ public:
void join_tape();
void finish();
virtual size_t data_sent();
virtual NamedCommStats comm_stats();
};
template<class T>

View File

@@ -51,10 +51,11 @@ void Thread<T>::run()
singleton = this;
BaseMachine::s().thread_num = thread_num;
secure_prng.ReSeed();
string id = "T" + to_string(thread_num);
if (machine.use_encryption)
P = new CryptoPlayer(N, thread_num << 16);
P = new CryptoPlayer(N, id);
else
P = new PlainPlayer(N, thread_num << 16);
P = new PlainPlayer(N, id);
processor.open_input_file(N.my_num(), thread_num,
master.opts.cmd_private_input_file);
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
@@ -98,10 +99,10 @@ void Thread<T>::finish()
}
template<class T>
size_t GC::Thread<T>::data_sent()
NamedCommStats Thread<T>::comm_stats()
{
assert(P);
return P->comm_stats.total_data();
return P->comm_stats;
}
} /* namespace GC */

View File

@@ -58,7 +58,7 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
template<class T>
void ThreadMaster<T>::run()
{
P = new PlainPlayer(N, 0xff << 24);
P = new PlainPlayer(N, "main");
machine.load_schedule(progname);
@@ -87,12 +87,10 @@ void ThreadMaster<T>::run()
NamedCommStats stats = P->comm_stats;
ExecutionStats exe_stats;
size_t data_sent = P->comm_stats.total_data();
for (auto thread : threads)
{
stats += thread->P->comm_stats;
exe_stats += thread->processor.stats;
data_sent += thread->data_sent();
delete thread;
}
@@ -102,7 +100,7 @@ void ThreadMaster<T>::run()
stats.print();
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
}
} /* namespace GC */

View File

@@ -48,7 +48,7 @@ public:
void set_protocol(typename T::Protocol& protocol);
size_t data_sent();
NamedCommStats comm_stats();
};
}

View File

@@ -92,13 +92,13 @@ void GC::TinierSharePrep<T>::buffer_bits()
}
template<class T>
size_t TinierSharePrep<T>::data_sent()
NamedCommStats TinierSharePrep<T>::comm_stats()
{
size_t res = 0;
NamedCommStats res;
if (triple_generator)
res += triple_generator->data_sent();
res += triple_generator->comm_stats();
if (real_triple_generator)
res += real_triple_generator->data_sent();
res += real_triple_generator->comm_stats();
return res;
}

View File

@@ -70,11 +70,14 @@ public:
T::reveal_inst(processor, args);
}
static This constant(BitVec other, int my_num, mac_key_type alphai)
static This constant(BitVec other, int my_num, mac_key_type alphai,
int n_bits = -1)
{
if (n_bits < 0)
n_bits = other.length();
This res;
res.resize_regs(other.length());
for (int i = 0; i < other.length(); i++)
res.resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai);
return res;
}

View File

@@ -43,6 +43,7 @@
X(XORCB, processor.xorc(instruction)) \
X(XORCBI, C0.xor_(PC1, IMM)) \
X(NOTS, processor.nots(INST)) \
X(NOTCB, processor.notcb(INST)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \
X(ANDS, T::ands(PROC, EXTRA)) \
X(ADDCB, C0 = PC1 + PC2) \
@@ -140,6 +141,7 @@
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
X(CRASH, if (I0.get()) throw crash_requested()) \
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS