Rep4, SPDZ-wise, MNIST training.

This commit is contained in:
Marcel Keller
2020-10-28 11:20:52 +11:00
parent 53f9b023dc
commit f42e614399
184 changed files with 5837 additions and 820 deletions

View File

@@ -64,7 +64,7 @@ void BitAdder::add(vector<vector<T> >& res,
int n_bits = summands.size();
for (size_t i = begin; i < end; i++)
res[i].resize(n_bits + 1);
res.at(i).resize(n_bits + 1);
size_t n_items = end - begin;

View File

@@ -8,11 +8,12 @@
#include "GC/square64.h"
#include "GC/Processor.hpp"
#include "GC/ShareSecret.hpp"
#include "Processor/Input.hpp"
namespace GC
{
SwitchableOutput FakeSecret::out;
const int FakeSecret::default_length;
void FakeSecret::load_clear(int n, const Integer& x)
@@ -87,6 +88,14 @@ FakeSecret FakeSecret::input(int from, word input, int n_bits)
return input;
}
void FakeSecret::inputbvec(Processor<FakeSecret>& processor,
ProcessorBase& input_processor, const vector<int>& args)
{
Input input;
input.reset_all(*ShareThread<FakeSecret>::s().P);
processor.inputbvec(input, input_processor, args, 0);
}
void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
bool repeat)
{
@@ -96,4 +105,19 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
*this = BitVec(x & y).mask(n);
}
void FakeSecret::my_input(Input& inputter, BitVec value, int n_bits)
{
inputter.add_mine(value, n_bits);
}
void FakeSecret::other_input(Input&, int, int)
{
throw runtime_error("emulation is supposed to be lonely");
}
void FakeSecret::finalize_input(Input& inputter, int from, int n_bits)
{
*this = inputter.finalize(from, n_bits);
}
} /* namespace GC */

View File

@@ -24,6 +24,8 @@
#include <random>
#include <fstream>
class ProcessorBase;
namespace GC
{
@@ -53,7 +55,10 @@ public:
typedef FakeProtocol<FakeSecret> Protocol;
typedef FakeInput<FakeSecret> Input;
typedef SwitchableOutput out_type;
static string type_string() { return "fake secret"; }
static string type_short() { return "emulB"; }
static string phase_name() { return "Faking"; }
static const int default_length = 64;
@@ -62,7 +67,8 @@ public:
static const bool actual_inputs = true;
static SwitchableOutput out;
static const true_type invertible;
static const true_type characteristic_two;
static DataFieldType field_type() { return DATA_GF2; }
@@ -87,8 +93,8 @@ public:
template <class T>
static void inputb(T& processor, ArithmeticProcessor&, const vector<int>& args)
{ processor.input(args); }
template <class T, class U>
static void inputbvec(T&, U&, const vector<int>&) { throw not_implemented(); }
static void inputbvec(Processor<FakeSecret>& processor,
ProcessorBase& input_processor, const vector<int>& args);
template <class T>
static void reveal_inst(T& processor, const vector<int>& args)
{ processor.reveal(args); }
@@ -136,6 +142,14 @@ public:
void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; }
void invert(FakeSecret) { throw not_implemented(); }
void input(istream&, bool) { throw not_implemented(); }
bool operator<(FakeSecret) const { return false; }
void my_input(Input& inputter, BitVec value, int n_bits);
void other_input(Input& inputter, int from, int n_bits = 1);
void finalize_input(Input& inputter, int from, int n_bits);
};
} /* namespace GC */

View File

@@ -47,13 +47,6 @@ template<class T>
void Machine<T>::load_schedule(string progname)
{
BaseMachine::load_schedule(progname);
for (auto i : {1, 0, 0})
{
int n;
inpf >> n;
if (n != i)
throw runtime_error("old schedule format not supported");
}
print_compiler();
}

View File

@@ -7,6 +7,7 @@
#define GC_NOSHARE_H_
#include "Processor/DummyProtocol.h"
#include "BMR/Register.h"
#include "Tools/SwitchableOutput.h"
class InputArgs;
@@ -41,6 +42,11 @@ public:
return 0;
}
static string type_string()
{
return "no";
}
static void fail()
{
throw runtime_error("VM does not support binary circuits");
@@ -93,8 +99,6 @@ public:
static const bool expensive_triples = false;
static const bool is_real = false;
static SwitchableOutput out;
static MC* new_mc(mac_key_type)
{
return new MC;
@@ -130,7 +134,7 @@ public:
NoValue::fail();
}
static void inputb(Processor<NoShare>&, ArithmeticProcessor&, const vector<int>&) { fail(); }
static void inputb(Processor<NoShare>&, const ArithmeticProcessor&, const vector<int>&) { fail(); }
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
static void xors(Processor<NoShare>&, const vector<int>&) { fail(); }
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
@@ -139,6 +143,10 @@ public:
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static void xors(Processor<NoShare>&, vector<int>) { fail(); }
static void ands(Processor<NoShare>&, vector<int>) { fail(); }
static void andrs(Processor<NoShare>&, vector<int>) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
NoShare() {}
@@ -161,8 +169,8 @@ public:
void operator^=(NoShare) { fail(); }
NoShare operator+(const NoShare&) const { fail(); return {}; }
NoShare operator-(NoShare) const { fail(); return 0; }
NoShare operator*(NoValue) const { fail(); return 0; }
NoShare operator-(const NoShare&) const { fail(); return {}; }
NoShare operator*(const NoValue&) const { fail(); return {}; }
NoShare operator+(int) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; }
@@ -172,6 +180,8 @@ public:
NoShare get_bit(int) const { fail(); return {}; }
void invert(int, NoShare) { fail(); }
void input(istream&, bool) { fail(); }
};
} /* namespace GC */

View File

@@ -44,6 +44,8 @@ public:
Timer xor_timer;
typename T::out_type out;
Processor(Machine<T>& machine);
Processor(Memories<T>& memories, Machine<T>* machine = 0);
~Processor();

View File

@@ -301,15 +301,15 @@ void Processor<T>::print_reg(int reg, int n, int size)
bigint output;
for (int i = 0; i < size; i++)
output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i);
T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
print_str(n);
T::out << endl << flush;
out << endl << flush;
}
template <class T>
void Processor<T>::print_reg_plain(Clear& value)
{
T::out << hex << showbase << value << dec << flush;
out << hex << showbase << value << dec << flush;
}
template <class T>
@@ -323,7 +323,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
n_shift = sizeof(value.get()) * 8 - n_bits;
if (n_shift > 63)
n_shift = 0;
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
out << dec << (value.get() << n_shift >> n_shift) << flush;
}
else
{
@@ -334,26 +334,26 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
}
if (tmp >= bigint(1) << (n_bits - 1))
tmp -= bigint(1) << n_bits;
T::out << dec << tmp << flush;
out << dec << tmp << flush;
}
}
template <class T>
void Processor<T>::print_chr(int n)
{
T::out << (char)n << flush;
out << (char)n << flush;
}
template <class T>
void Processor<T>::print_str(int n)
{
T::out << string((char*)&n,sizeof(n)) << flush;
out << string((char*)&n,sizeof(n)) << flush;
}
template <class T>
void Processor<T>::print_float(const vector<int>& args)
{
bigint::output_float(T::out,
bigint::output_float(out,
bigint::get_float(C[args[0]], C[args[1]], C[args[2]], C[args[3]]),
C[args[4]]);
}
@@ -361,7 +361,7 @@ void Processor<T>::print_float(const vector<int>& args)
template <class T>
void Processor<T>::print_float_prec(int n)
{
T::out << setprecision(n);
out << setprecision(n);
}
} /* namespace GC */

26
GC/Rep4Secret.cpp Normal file
View File

@@ -0,0 +1,26 @@
/*
* Rep4Secret.cpp
*
*/
#ifndef GC_REP4SECRET_CPP_
#define GC_REP4SECRET_CPP_
#include "Rep4Secret.h"
#include "ShareSecret.hpp"
#include "ShareThread.hpp"
#include "Protocols/Rep4MC.hpp"
namespace GC
{
void Rep4Secret::load_clear(int n, const Integer& x)
{
this->check_length(n, x);
*this = constant(x, ShareThread<This>::s().P->my_num());
}
}
#endif /* GC_REP4SECRET_CPP_ */

53
GC/Rep4Secret.h Normal file
View File

@@ -0,0 +1,53 @@
/*
* Rep4Secret.h
*
*/
#ifndef GC_REP4SECRET_H_
#define GC_REP4SECRET_H_
#include "ShareSecret.h"
#include "Processor/NoLivePrep.h"
#include "Protocols/Rep4MC.h"
#include "Protocols/Rep4Share.h"
namespace GC
{
class Rep4Secret : public RepSecretBase<Rep4Secret, 3>
{
typedef RepSecretBase<Rep4Secret, 3> super;
typedef Rep4Secret This;
public:
typedef DummyLivePrep<This> LivePrep;
typedef Rep4<This> Protocol;
typedef Rep4MC<This> MC;
typedef MC MAC_Check;
typedef Rep4Input<This> Input;
static const bool expensive_triples = false;
static MC* new_mc(typename super::mac_key_type) { return new MC; }
static This constant(const typename super::clear& constant, int my_num,
typename super::mac_key_type = {})
{
return Rep4Share<typename super::clear>::constant(constant, my_num);
}
Rep4Secret()
{
}
template <class T>
Rep4Secret(const T& other) :
super(other)
{
}
void load_clear(int n, const Integer& x);
};
}
#endif /* GC_REP4SECRET_H_ */

View File

@@ -62,13 +62,13 @@ public:
typedef typename T::Input Input;
typedef typename T::out_type out_type;
static string type_string() { return "evaluation secret"; }
static string phase_name() { return T::name(); }
static int default_length;
static typename T::out_type out;
static const bool needs_ot = false;
static const bool is_real = true;
@@ -170,9 +170,6 @@ public:
template <class T>
int Secret<T>::default_length = 64;
template <class T>
typename T::out_type Secret<T>::out = T::out;
template <class T>
inline ostream& operator<<(ostream& o, Secret<T>& secret)
{

View File

@@ -58,10 +58,11 @@ SemiPrep::~SemiPrep()
void SemiPrep::buffer_bits()
{
auto& thread = Thread<SemiSecret>::s();
word r = thread.secure_prng.get_word();
word r = secure_prng.get_word();
for (size_t i = 0; i < sizeof(word) * 8; i++)
{
this->bits.push_back((r >> i) & 1);
}
}
size_t SemiPrep::data_sent()

View File

@@ -23,6 +23,8 @@ class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret
SemiSecret::TripleGenerator* triple_generator;
MascotParams params;
SeededPRNG secure_prng;
public:
SemiPrep(DataPositions& usage, ShareThread<SemiSecret>& thread);
SemiPrep(DataPositions& usage, bool = true);

View File

@@ -80,8 +80,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
T::out.activate(my_num == 0 or online_opts.interactive);
if (not this->machine.use_encryption and not T::dishonest_majority)
insecure("unencrypted communication");

View File

@@ -44,8 +44,6 @@ public:
static const bool is_real = true;
static const bool actual_inputs = true;
static SwitchableOutput out;
static void store_clear_in_dynamic(Memory<U>& mem,
const vector<ClearWriteAccess>& accesses);
@@ -83,21 +81,26 @@ public:
void other_input(T& inputter, int from, int n_bits = 1);
template<class T>
void finalize_input(T& inputter, int from, int n_bits);
U& operator=(const U&);
};
template<class U>
class ReplicatedSecret : public FixedVec<BitVec, 2>, public ShareSecret<U>
template<class U, int L>
class RepSecretBase : public FixedVec<BitVec, L>, public ShareSecret<U>
{
typedef FixedVec<BitVec, 2> super;
typedef FixedVec<BitVec, L> super;
typedef RepSecretBase This;
public:
typedef U part_type;
typedef U small_type;
typedef U whole_type;
typedef BitVec clear;
typedef BitVec open_type;
typedef BitVec mac_type;
typedef BitVec mac_key_type;
typedef ReplicatedBase Protocol;
typedef NoShare bit_type;
static const int N_BITS = clear::N_BITS;
@@ -109,7 +112,7 @@ public:
static string type_string() { return "replicated secret"; }
static string phase_name() { return "Replicated computation"; }
static const int default_length = 8 * sizeof(typename ReplicatedSecret<U>::value_type);
static const int default_length = N_BITS;
static int threshold(int)
{
@@ -124,9 +127,45 @@ public:
{
}
static void read_or_generate_mac_key(string, const Names&, mac_key_type) {}
static void read_or_generate_mac_key(string, const Player&, mac_key_type)
{
}
static ReplicatedSecret constant(const clear& value, int my_num, mac_key_type)
RepSecretBase()
{
}
template <class T>
RepSecretBase(const T& other) :
super(other)
{
}
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void xor_(int n, const This& x, const This& y)
{ *this = x ^ y; (void)n; }
This operator&(const Clear& other)
{ return super::operator&(BitVec(other)); }
This lsb()
{ return *this & 1; }
This get_bit(int i)
{ return (*this >> i) & 1; }
};
template<class U>
class ReplicatedSecret : public RepSecretBase<U, 2>
{
typedef RepSecretBase<U, 2> super;
public:
typedef ReplicatedBase Protocol;
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
typename super::mac_key_type)
{
ReplicatedSecret res;
if (my_num < 2)
@@ -140,28 +179,44 @@ public:
void load_clear(int n, const Integer& x);
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
BitVec local_mul(const ReplicatedSecret& other) const;
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
{ *this = x ^ y; (void)n; }
void reveal(size_t n_bits, Clear& x);
ReplicatedSecret operator&(const Clear& other)
{ return super::operator&(BitVec(other)); }
ReplicatedSecret lsb()
{ return *this & 1; }
ReplicatedSecret get_bit(int i)
{ return (*this >> i) & 1; }
};
class SemiHonestRepPrep;
class SmallRepSecret : public FixedVec<BitVec_<unsigned char>, 2>
{
typedef FixedVec<BitVec_<unsigned char>, 2> super;
typedef SmallRepSecret This;
public:
typedef ReplicatedMC<This> MC;
typedef BitVec_<unsigned char> open_type;
typedef open_type clear;
typedef BitVec mac_key_type;
static MC* new_mc(mac_key_type)
{
return new MC;
}
SmallRepSecret()
{
}
template<class T>
SmallRepSecret(const T& other) :
super(other)
{
}
This lsb() const
{
return *this & 1;
}
};
class SemiHonestRepSecret : public ReplicatedSecret<SemiHonestRepSecret>
{
typedef ReplicatedSecret<SemiHonestRepSecret> super;
@@ -176,7 +231,7 @@ public:
typedef ReplicatedInput<SemiHonestRepSecret> Input;
typedef SemiHonestRepSecret part_type;
typedef SemiHonestRepSecret small_type;
typedef SmallRepSecret small_type;
typedef SemiHonestRepSecret whole_type;
static const bool expensive_triples = false;

View File

@@ -25,14 +25,11 @@
namespace GC
{
template<class U>
const int ReplicatedSecret<U>::N_BITS;
template<class U, int L>
const int RepSecretBase<U, L>::N_BITS;
template<class U>
const int ReplicatedSecret<U>::default_length;
template<class U>
SwitchableOutput ShareSecret<U>::out;
template<class U, int L>
const int RepSecretBase<U, L>::default_length;
template<class U>
void ShareSecret<U>::check_length(int n, const Integer& x)
@@ -59,16 +56,16 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
*this = x;
}
template<class U>
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
template<class U, int L>
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
template<class U>
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
template<class U, int L>
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;
@@ -285,12 +282,11 @@ void ShareSecret<U>::xors(Processor<U>& processor, const vector<int>& args)
ShareThread<U>::s().xors(processor, args);
}
template<class U>
void ReplicatedSecret<U>::trans(Processor<U>& processor,
template<class U, int L>
void RepSecretBase<U, L>::trans(Processor<U>& processor,
int n_outputs, const vector<int>& args)
{
assert(length == 2);
for (int k = 0; k < 2; k++)
for (int k = 0; k < L; k++)
{
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
@@ -330,6 +326,14 @@ void ShareSecret<U>::random_bit()
*this = res;
}
template<class U>
U& GC::ShareSecret<U>::operator=(const U& other)
{
U& real_this = static_cast<U&>(*this);
real_this = other;
return real_this;
}
}
#endif

View File

@@ -54,7 +54,13 @@ void Thread<T>::run()
P = new CryptoPlayer(N, thread_num << 16);
else
P = new PlainPlayer(N, thread_num << 16);
processor.open_input_file(N.my_num(), thread_num);
processor.open_input_file(N.my_num(), thread_num,
master.opts.cmd_private_input_file);
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
processor.setup_redirection(P->my_num(), thread_num, master.opts);
if (processor.stdout_redirect_file.is_open())
processor.out.redirect_to_file(processor.stdout_redirect_file);
done.push(0);
pre_run();

View File

@@ -38,6 +38,7 @@ public:
typedef typename part_type::sacri_type sacri_type;
typedef typename part_type::mac_type mac_type;
typedef typename part_type::mac_share_type mac_share_type;
typedef BitDiagonal Rectangle;
typedef typename T::super check_type;
@@ -152,6 +153,11 @@ public:
reg.output(s, human);
}
void input(istream&, bool)
{
throw not_implemented();
}
template <class U>
void my_input(U& inputter, BitVec value, int n_bits)
{

View File

@@ -129,7 +129,7 @@
X(GLDMC, ) \
X(LDMS, ) \
X(LDMC, ) \
X(PRINTINT, S0.out << I0) \
X(PRINTINT, PROC.out << I0) \
X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \
X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \
X(RUN_TAPE, MACH->run_tapes(EXTRA)) \