Various improvements.

This commit is contained in:
Marcel Keller
2020-08-24 23:29:03 +10:00
parent cf1719b83a
commit ad583afb7e
149 changed files with 2887 additions and 859 deletions

View File

@@ -42,7 +42,6 @@ CommonFakeParty::~CommonFakeParty()
CommonParty::~CommonParty()
{
cerr << "Total time: " << timer.elapsed() << endl;
#ifdef VERBOSE
cerr << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl;
cerr << "CPU time: " << cpu_timer.elapsed() << endl;
@@ -50,6 +49,7 @@ CommonParty::~CommonParty()
cerr << "Second phase time: " << timers[1].elapsed() << endl;
cerr << "Number of gates: " << gate_counter << endl;
#endif
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
}
void CommonParty::check(int n_parties)

View File

@@ -8,6 +8,8 @@
#include "Party.h"
#include "GC/ShareSecret.hpp"
template<class T>
ProgramPartySpec<T>* ProgramPartySpec<T>::singleton = 0;

View File

@@ -9,15 +9,32 @@
#include "Register.h"
template<class T> class RealProgramParty;
template<class T> class RealGarbleWire;
template<class T>
class GarbleInputter
{
public:
RealProgramParty<T>& party;
Bundle<octetStream> oss;
PointerVector<pair<RealGarbleWire<T>*, int>> tuples;
GarbleInputter();
void exchange();
};
template<class T>
class RealGarbleWire : public PRFRegister
{
friend class RealProgramParty<T>;
friend class GarbleInputter<T>;
T mask;
public:
typedef GarbleInputter<T> Input;
static void store(NoMemory& dest,
const vector<GC::WriteAccess<GC::Secret<RealGarbleWire>>>& accesses);
static void load(vector<GC::ReadAccess<GC::Secret<RealGarbleWire>>>& accesses,
@@ -26,6 +43,11 @@ public:
static void convcbit(Integer& dest, const GC::Clear& source,
GC::Processor<GC::Secret<RealGarbleWire>>& processor);
static void inputb(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
const vector<int>& args);
static void inputbvec(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args);
RealGarbleWire(const Register& reg) : PRFRegister(reg) {}
void garble(PRFOutputs& prf_output, const RealGarbleWire<T>& left,
@@ -37,6 +59,10 @@ public:
void public_input(bool value);
void random();
void output();
void my_input(Input& Inputter, bool value, int n_bits);
void other_input(Input& Inputter, int from);
void finalize_input(Input&, int from, int);
};
template<class T>

View File

@@ -94,36 +94,81 @@ void RealGarbleWire<T>::XOR(const RealGarbleWire<T>& left, const RealGarbleWire<
}
template<class T>
void RealGarbleWire<T>::input(party_id_t from, char input)
void RealGarbleWire<T>::inputb(
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
const vector<int>& args)
{
GarbleInputter<T> inputter;
processor.inputb(inputter, processor, args,
inputter.party.P->my_num());
}
template<class T>
void RealGarbleWire<T>::inputbvec(
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args)
{
GarbleInputter<T> inputter;
processor.inputbvec(inputter, input_processor, args,
inputter.party.P->my_num());
}
template<class T>
GarbleInputter<T>::GarbleInputter() :
party(RealProgramParty<T>::s()), oss(*party.P)
{
}
template<class T>
void RealGarbleWire<T>::my_input(Input& inputter, bool, int n_bits)
{
assert(n_bits == 1);
inputter.tuples.push_back({this, inputter.party.P->my_num()});
}
template<class T>
void RealGarbleWire<T>::other_input(Input& inputter, int from)
{
inputter.tuples.push_back({this, from});
}
template<class T>
void GarbleInputter<T>::exchange()
{
PRFRegister::input(from, input);
auto& party = RealProgramParty<T>::s();
assert(party.shared_proc != 0);
auto& inputter = party.shared_proc->input;
inputter.reset(from - 1);
if (from == party.get_id())
inputter.reset_all(*party.P);
for (auto& tuple : tuples)
{
char my_mask;
my_mask = party.prng.get_bit();
party.input_masks.serialize(my_mask);
inputter.add_mine(my_mask);
inputter.send_mine();
mask = inputter.finalize_mine();
int from = tuple.second;
party_id_t from_id = from + 1;
tuple.first->PRFRegister::input(from_id, -1);
if (from_id == party.get_id())
{
char my_mask;
my_mask = party.prng.get_bit();
party.garble_input_masks.serialize(my_mask);
inputter.add_mine(my_mask);
#ifdef DEBUG_MASK
cout << "my mask: " << (int)my_mask << endl;
cout << "my mask: " << (int)my_mask << endl;
#endif
}
else
{
inputter.add_other(from);
}
}
else
{
inputter.add_other(from - 1);
octetStream os;
party.P->receive_player(from - 1, os, true);
inputter.finalize_other(from - 1, mask, os);
}
inputter.exchange();
for (auto& tuple : tuples)
tuple.first->mask = (inputter.finalize(tuple.second));
// important to make sure that mask is a bit
try
{
mask.force_to_bit();
for (auto& tuple : tuples)
tuple.first->mask.force_to_bit();
}
catch (not_implemented& e)
{
@@ -131,16 +176,36 @@ void RealGarbleWire<T>::input(party_id_t from, char input)
assert(party.MC != 0);
auto& protocol = party.shared_proc->protocol;
protocol.init_mul(party.shared_proc);
protocol.prepare_mul(mask, T::constant(1, party.P->my_num(), party.mac_key) - mask);
for (auto& tuple : tuples)
protocol.prepare_mul(tuple.first->mask,
T::constant(1, party.P->my_num(), party.mac_key)
- tuple.first->mask);
protocol.exchange();
if (party.MC->open(protocol.finalize_mul(), *party.P) != 0)
vector<T> to_check;
to_check.reserve(tuples.size());
for (size_t i = 0; i < tuples.size(); i++)
{
to_check.push_back(protocol.finalize_mul());
}
try
{
party.MC->CheckFor(0, to_check, *party.P);
}
catch (mac_fail&)
{
throw runtime_error("input mask not a bit");
}
}
#ifdef DEBUG_MASK
cout << "shared mask: " << party.MC->POpen(mask, *party.P) << endl;
#endif
}
template<class T>
void RealGarbleWire<T>::finalize_input(GarbleInputter<T>&, int, int)
{
}
template<class T>
void RealGarbleWire<T>::public_input(bool value)
{
@@ -169,7 +234,7 @@ void RealGarbleWire<T>::output()
assert(party.MC != 0);
assert(party.P != 0);
auto m = party.MC->open(mask, *party.P);
party.output_masks.push_back(m.get_bit(0));
party.garble_output_masks.push_back(m.get_bit(0));
party.taint();
#ifdef DEBUG_MASK
cout << "output mask: " << m << endl;

View File

@@ -19,6 +19,7 @@ class RealProgramParty : public ProgramPartySpec<T>
typedef typename T::Input Inputter;
friend class RealGarbleWire<T>;
friend class GarbleInputter<T>;
friend class GarbleJob<T>;
static RealProgramParty* singleton;
@@ -40,9 +41,15 @@ class RealProgramParty : public ProgramPartySpec<T>
GC::BreakType next;
bool one_shot;
size_t data_sent;
public:
static RealProgramParty& s();
LocalBuffer garble_input_masks, garble_output_masks;
RealProgramParty(int argc, const char** argv);
~RealProgramParty();

View File

@@ -44,10 +44,20 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
"-N", // Flag token.
"--nparties" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Evaluate only after garbling.", // Help description.
"-O", // Flag token.
"--oneshot" // Flag token.
);
opt.parse(argc, argv);
int nparties;
opt.get("-N")->getInt(nparties);
this->check(nparties);
one_shot = opt.isSet("-O");
NetworkOptions network_opts(opt, argc, argv);
OnlineOptions& online_opts = OnlineOptions::singleton;
@@ -90,7 +100,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
mac_key.randomize(prng);
if (T::needs_ot)
BaseMachine::s().ot_setups.push_back({*P, true});
prep = Preprocessing<T>::get_live_prep(0, usage);
prep = new typename T::TriplePrep(0, usage);
}
else
{
@@ -122,6 +132,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
for (int i = 0; i < SPDZ_OP_N; i++)
this->spdz_wires[i].push_back({});
this->timer.reset();
do
{
next = GC::TIME_BREAK;
@@ -129,7 +140,14 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
try
{
this->online_timer.start();
this->start_online_round();
if (one_shot)
this->start_online_round();
else
{
this->load_garbled_circuit();
next = this->second_phase(program, this->processor,
this->machine, this->dynamic_memory);
}
this->online_timer.stop();
}
catch (needs_cleaning& e)
@@ -139,6 +157,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);
MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();
if (server)
delete server;
@@ -152,7 +171,7 @@ void RealProgramParty<T>::garble()
auto& program = this->program;
auto& MC = this->MC;
while (next == GC::TIME_BREAK)
do
{
garble_jobs.clear();
garble_inputter->reset_all(*P);
@@ -178,13 +197,15 @@ void RealProgramParty<T>::garble()
vector<typename T::clear> opened;
MC->POpen(opened, wires, *P);
LocalBuffer garbled_circuit;
for (auto& x : opened)
this->garbled_circuit.serialize(x);
garbled_circuit.serialize(x);
this->garbled_circuits.push_and_clear(this->garbled_circuit);
this->input_masks_store.push_and_clear(this->input_masks);
this->output_masks_store.push_and_clear(this->output_masks);
this->garbled_circuits.push_and_clear(garbled_circuit);
this->input_masks_store.push_and_clear(garble_input_masks);
this->output_masks_store.push_and_clear(garble_output_masks);
}
while (one_shot and next == GC::TIME_BREAK);
}
template<class T>
@@ -194,6 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
delete prep;
delete garble_inputter;
delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
}
template<class T>

View File

@@ -601,37 +601,64 @@ void EvalRegister::input_helper(char value, octetStream& os)
os.serialize(get_external());
}
void EvalRegister::input(party_id_t from, char value)
EvalInputter::EvalInputter() :
party(ProgramParty::s()), oss(*party.P)
{
auto& party = ProgramParty::s();
}
void EvalRegister::my_input(EvalInputter& inputter, bool input, int n_bits)
{
assert(n_bits == 1);
auto& party = inputter.party;
party.load_wire(*this);
octetStream os;
if (from == party.get_id())
{
if (value and (1 - value))
throw runtime_error("invalid input");
input_helper(value, os);
party.P->send_all(os, true);
}
else
{
party.P->receive_player(from - 1, os);
char ext;
os.unserialize(ext);
set_external(ext);
}
vector<octetStream> oss(party.get_n_parties());
size_t id = party.get_id() - 1;
oss[id].serialize(garbled_entry[id]);
#ifdef DEBUG_COMM
cout << "send " << garbled_entry[id] << ", "
<< oss[id].get_length() << " bytes from " << id << endl;
#endif
input_helper(input, inputter.oss.mine);
inputter.tuples.push_back({this, party.P->my_num()});
}
void EvalRegister::other_input(EvalInputter& inputter, int from)
{
auto& party = inputter.party;
party.load_wire(*this);
inputter.tuples.push_back({this, from});
}
void EvalInputter::exchange()
{
party.P->Broadcast_Receive(oss, true);
for (auto& tuple : tuples)
{
if (tuple.from != party.P->my_num())
{
char ext;
oss[tuple.from].unserialize(ext);
tuple.reg->set_external(ext);
}
}
size_t id = party.get_id() - 1;
for (auto& os : oss)
os.reset_write_head();
for (auto& tuple : tuples)
{
oss[id].serialize(tuple.reg->get_garbled_entry()[id]);
#ifdef DEBUG_COMM
cout << "send " << garbled_entry[id] << ", "
<< oss[id].get_length() << " bytes from " << id << endl;
#endif
}
party.P->Broadcast_Receive(oss, true);
}
void EvalRegister::finalize_input(EvalInputter& inputter, int, int)
{
auto& party = inputter.party;
size_t id = party.get_id() - 1;
for (size_t i = 0; i < (size_t)party.get_n_parties(); i++)
{
if (i != id)
oss[i].unserialize(garbled_entry[i]);
inputter.oss[i].unserialize(garbled_entry[i]);
}
keys[external] = garbled_entry;
#ifdef DEBUG
@@ -934,4 +961,4 @@ void KeyTuple<I>::print(int wire_id, party_id_t pid)
}
template class KeyTuple<2>;
template class KeyTuple<4>;
template class KeyTuple<4> ;

View File

@@ -20,6 +20,8 @@ using namespace std;
#include "GC/ArgTuples.h"
#include "Math/gf2n.h"
#include "Tools/FlexBuffer.h"
#include "Tools/PointerVector.h"
#include "Tools/Bundle.h"
//#define PAD_TO_8(n) (n+8-n%8)
#define PAD_TO_8(n) (n)
@@ -201,6 +203,8 @@ public:
inline BlackHole& endl(BlackHole& b) { return b; }
inline BlackHole& flush(BlackHole& b) { return b; }
class ProcessorBase;
class Phase
{
public:
@@ -209,6 +213,8 @@ public:
typedef BlackHole out_type;
static BlackHole out;
static const bool actual_inputs = true;
template <class T>
static void store_clear_in_dynamic(T& mem, const vector<GC::ClearWriteAccess>& accesses)
{ (void)mem; (void)accesses; }
@@ -231,6 +237,9 @@ public:
template <class T>
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
template <class T>
static void inputbvec(T&, ProcessorBase&, const vector<int>&)
{ throw not_implemented(); }
template <class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
{ return T::input(from, processor.get_input(n_bits), n_bits); }
@@ -244,9 +253,24 @@ public:
void output() {}
};
class NoOpInputter
{
public:
PointerVector<char> inputs;
void exchange()
{
}
};
class ProgramRegister : public Phase, public Register
{
public:
typedef NoOpInputter Input;
// only true for evaluation
static const bool actual_inputs = false;
static Register new_reg();
static Register tmp_reg() { return new_reg(); }
static Register and_reg() { return new_reg(); }
@@ -255,11 +279,18 @@ public:
static void store(NoMemory& dest,
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
template <class T>
static void inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args);
// most BMR phases don't need actual input
template<class T>
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{ (void)processor; return T::input(args.from + 1, 0, args.n_bits); }
void my_input(Input&, bool, int) {}
void other_input(Input&, int) {}
char get_output() { return 0; }
ProgramRegister(const Register& reg) : Register(reg) {}
@@ -282,6 +313,37 @@ public:
void public_input(bool value);
void random();
void output();
void finalize_input(NoOpInputter&, int from, int)
{ input(from + 1, -1); }
};
class ProgramParty;
class EvalRegister;
class EvalInputter
{
class Tuple
{
public:
EvalRegister* reg;
int from;
Tuple(EvalRegister* reg, int from) :
reg(reg), from(from)
{
}
};
public:
ProgramParty& party;
Bundle<octetStream> oss;
vector<Tuple> tuples;
EvalInputter();
void add_other(int from);
void exchange();
};
class EvalRegister : public ProgramRegister
@@ -289,9 +351,13 @@ class EvalRegister : public ProgramRegister
public:
static string name() { return "Evaluation"; }
typedef EvalInputter Input;
typedef ostream& out_type;
static ostream& out;
static const bool actual_inputs = true;
template<class T, class U>
static void store(GC::Memory<U>& dest,
const vector<GC::WriteAccess<T> >& accesses);
@@ -303,6 +369,9 @@ public:
static void andrs(T& processor, const vector<int>& args);
template <class T>
static void inputb(T& processor, const vector<int>& args);
template <class T>
static void inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args);
template <class T>
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
@@ -330,6 +399,10 @@ public:
static void check_input(long long input, int n_bits);
void input(party_id_t from, char value = -1);
void input_helper(char value, octetStream& os);
void my_input(EvalInputter& inputter, bool input, int);
void other_input(EvalInputter& inputter, int from);
void finalize_input(EvalInputter& inputter, int from, int);
};
class GarbleRegister : public ProgramRegister
@@ -349,6 +422,9 @@ public:
void public_input(bool value);
void random();
void output() {}
void finalize_input(NoOpInputter&, int from, int)
{ input(from + 1, -1); }
};
class RandomRegister : public ProgramRegister
@@ -374,6 +450,9 @@ public:
void public_input(bool value);
void random();
void output();
void finalize_input(NoOpInputter&, int from, int)
{ input(from + 1, -1); }
};

View File

@@ -9,6 +9,31 @@
#include "Register.h"
#include "Party.h"
template<class T>
void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args)
{
NoOpInputter inputter;
int my_num = -1;
try
{
my_num = ProgramParty::s().P->my_num();
}
catch (exception&)
{
}
processor.inputbvec(inputter, input_processor, args, my_num);
}
template<class T>
void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args)
{
EvalInputter inputter;
processor.inputbvec(inputter, input_processor, args,
ProgramParty::s().P->my_num());
}
template <class T>
void PRFRegister::load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source)

View File

@@ -29,6 +29,7 @@
#include "GC/ThreadMaster.hpp"
#include "GC/Program.hpp"
#include "GC/Instruction.hpp"
#include "GC/ShareSecret.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/Share.hpp"

View File

@@ -1,5 +1,14 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.1.9 (Aug 24, 2020)
- Streamline inputs to binary circuits
- Improved private output
- Emulator for arithmetic circuits
- Efficient dot product with Shamir's secret sharing
- Lower memory usage for TensorFlow inference
- This version breaks bytecode compatibilty.
## 0.1.8 (June 15, 2020)
- Half-gate garbling

View File

@@ -37,6 +37,7 @@ opcodes = dict(
STMSBI = 0x243,
MOVSB = 0x244,
INPUTB = 0x246,
INPUTBVEC = 0x247,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
XORCBI = 0x210,
@@ -269,6 +270,24 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
arg_format = tools.cycle(['p','int','int','sbw'])
is_vec = lambda self: True
class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
base.Mergeable):
__slots__ = []
code = opcodes['INPUTBVEC']
def __init__(self, *args, **kwargs):
self.arg_format = []
i = 0
while i < len(args):
self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (args[i] - 3)
i += args[i]
assert i == len(args)
super(inputbvec, self).__init__(*args, **kwargs)
def merge(self, other):
self.args += other.args
self.arg_format += other.arg_format
class print_regb(base.VectorInstruction, base.IOInstruction):
code = opcodes['PRINTREGB']
arg_format = ['cb','i']

View File

@@ -88,7 +88,8 @@ class bits(Tape.Register, _structure, _bit):
if mem_type == 'sd':
return cls.load_dynamic_mem(address)
else:
cls.load_inst[util.is_constant(address)](res, address)
for i in range(res.size):
cls.load_inst[util.is_constant(address)](res[i], address + i)
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
@@ -120,17 +121,19 @@ class bits(Tape.Register, _structure, _bit):
assert(other.size == math.ceil(self.n / self.unit))
for i, (x, y) in enumerate(zip(self, other)):
self.conv_regint(min(self.unit, self.n - i * self.unit), x, y)
elif isinstance(self, type(other)) or isinstance(other, type(self)):
assert(self.n == other.n)
elif (isinstance(self, type(other)) or isinstance(other, type(self))) \
and self.n == other.n:
for i in range(math.ceil(self.n / self.unit)):
self.mov(self[i], other[i])
else:
try:
other = self.bit_compose(other.bit_decompose())
bits = other.bit_decompose()
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
other = self.bit_compose(bits)
self.load_other(other)
except:
raise CompilerError('cannot convert from %s to %s' % \
(type(other), type(self)))
raise CompilerError('cannot convert %s/%s from %s to %s' % \
(str(other), repr(other), type(other), type(self)))
def long_one(self):
return 2**self.n - 1 if self.n != None else None
def __repr__(self):
@@ -160,6 +163,9 @@ class cbits(bits):
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
conv_cint_vec = inst.convcintvec
@classmethod
def bit_compose(cls, bits):
return sum(bit << i for i, bit in enumerate(bits))
@classmethod
def conv_regint_by_bit(cls, n, res, other):
assert n == res.n
assert n == other.size
@@ -459,46 +465,68 @@ class sbits(bits):
class sbitvec(_vec):
@classmethod
def get_type(cls, n):
class sbitvecn(cls):
class sbitvecn(cls, _structure):
@staticmethod
def malloc(size):
return sbits.malloc(size * n)
return sbit.malloc(size * n)
@staticmethod
def n_elements():
return n
@classmethod
def get_input_from(cls, player):
return cls.from_vec(
sbits.get_input_from(player, n).bit_decompose(n))
res = cls.from_vec(sbit() for i in range(n))
inst.inputbvec(n + 3, 0, player, *res.v)
return res
get_raw_input_from = get_input_from
def __init__(self, other=None):
if other is not None:
self.v = sbits(other, n=n).bit_decompose(n)
if util.is_constant(other):
self.v = [sbit((other >> i) & 1) for i in range(n)]
else:
self.v = sbits(other, n=n).bit_decompose(n)
@classmethod
def load_mem(cls, address):
try:
assert len(address) == n
if not isinstance(address, int) and len(address) == n:
return cls.from_vec(sbit.load_mem(x) for x in address)
except:
else:
return cls.from_vec(sbit.load_mem(address + i)
for i in range(n))
def store_in_mem(self, address):
assert self.v[0].n == 1
try:
assert len(address) == n
if not isinstance(address, int) and len(address) == n:
for x, y in zip(self.v, address):
x.store_in_mem(y)
except:
else:
for i in range(n):
self.v[i].store_in_mem(address + i)
def reveal(self):
return self.elements()[0].reveal()
revealed = [cbit() for i in range(len(self))]
for i in range(len(self)):
inst.reveal(1, revealed[i], self.v[i])
return cbits.get_type(len(self)).bit_compose(revealed)
@classmethod
def two_power(cls, nn):
return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1))
def coerce(self, other):
if util.is_constant(other):
return self.from_vec(util.bit_decompose(other, n))
else:
return super(sbitvecn, self).coerce(other)
@classmethod
def bit_compose(cls, bits):
if len(bits) < n:
bits += [0] * (n - len(bits))
assert len(bits) == n
return cls.from_vec(bits)
def __str__(self):
return 'sbitvec(%d)' % n
return sbitvecn
@classmethod
def from_vec(cls, vector):
res = cls()
res.v = list(vector)
return res
compose = from_vec
@classmethod
def combine(cls, vectors):
res = cls()
@@ -512,11 +540,7 @@ class sbitvec(_vec):
if length:
assert isinstance(elements, sint)
if Program.prog.use_split():
n = Program.prog.use_split()
columns = [[sbits.get_type(elements.size)()
for i in range(n)] for i in range(length)]
inst.split(n, elements, *sum(columns, []))
x = sbitint.wallace_tree_without_finish(columns, False)
x = elements.split_to_two_summands(length)
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
else:
assert Program.prog.options.ring
@@ -569,11 +593,14 @@ class sbitvec(_vec):
@property
def size(self):
return self.v[0].n
@property
def n_bits(self):
return len(self.v)
def store_in_mem(self, address):
for i, x in enumerate(self.elements()):
x.store_in_mem(address + i)
def bit_decompose(self):
return self.v
def bit_decompose(self, n_bits=None):
return self.v[:n_bits]
bit_compose = from_vec
def reveal(self):
assert len(self) == 1
@@ -673,7 +700,41 @@ cbits.dynamic_array = Array
def _complement_two_extend(bits, k):
return bits + [bits[-1]] * (k - len(bits))
class sbitint(_bitint, _number, sbits):
class _sbitintbase:
def extend(self, n):
bits = self.bit_decompose()
bits += [bits[-1]] * (n - len(bits))
return self.get_type(n).bit_compose(bits)
def cast(self, n):
bits = self.bit_decompose()[:n]
bits += [bits[-1]] * (n - len(bits))
return self.get_type(n).bit_compose(bits)
def round(self, k, m, kappa=None, nearest=None, signed=None):
bits = self.bit_decompose()
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
return self.get_type(k - m).compose(res_bits)
def int_div(self, other, bit_length=None):
k = bit_length or max(self.n, other.n)
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
def Norm(self, k, f, kappa=None, simplex_flag=False):
absolute_val = abs(self)
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k)[::-1]
suffixes = floatingpoint.PreOR(bits)[::-1]
z = [0] * k
for i in range(k - 1):
z[i] = suffixes[i] - suffixes[i+1]
z[k - 1] = suffixes[k-1]
z.reverse()
t2k = self.get_type(2 * k)
acc = t2k.bit_compose(z)
sign = self.bit_decompose()[-1]
signed_acc = util.if_else(sign, -acc, acc)
absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose())
part_reciprocal = absolute_val_2k * acc
return part_reciprocal, signed_acc
class sbitint(_bitint, _number, sbits, _sbitintbase):
n_bits = None
bin_type = None
types = {}
@@ -728,43 +789,11 @@ class sbitint(_bitint, _number, sbits):
res_bits = product.bit_decompose()[m:k]
t = self.combo_type(other)
return t.bit_compose(res_bits)
def Norm(self, k, f, kappa=None, simplex_flag=False):
absolute_val = abs(self)
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k)[::-1]
suffixes = floatingpoint.PreOR(bits)[::-1]
z = [0] * k
for i in range(k - 1):
z[i] = suffixes[i] - suffixes[i+1]
z[k - 1] = suffixes[k-1]
z.reverse()
t2k = self.get_type(2 * k)
acc = t2k.bit_compose(z)
sign = self.bit_decompose()[-1]
signed_acc = util.if_else(sign, -acc, acc)
absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose())
part_reciprocal = absolute_val_2k * acc
return part_reciprocal, signed_acc
def extend(self, n):
bits = self.bit_decompose()
bits += [bits[-1]] * (n - len(bits))
return self.get_type(n).bit_compose(bits)
def __mul__(self, other):
if isinstance(other, sbitintvec):
return other * self
else:
return super(sbitint, self).__mul__(other)
def cast(self, n):
bits = self.bit_decompose()[:n]
bits += [bits[-1]] * (n - len(bits))
return self.get_type(n).bit_compose(bits)
def int_div(self, other, bit_length=None):
k = bit_length or max(self.n, other.n)
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
def round(self, k, m, kappa=None, nearest=None, signed=None):
bits = self.bit_decompose()
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
return self.get_type(k - m).compose(res_bits)
@classmethod
def get_bit_matrix(cls, self_bits, other):
n = len(self_bits)
@@ -782,10 +811,11 @@ class sbitint(_bitint, _number, sbits):
res.append([(x & bit) for x in other.bit_decompose(n - i)])
return res
class sbitintvec(sbitvec, _number):
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
def __add__(self, other):
if util.is_zero(other):
return self
other = self.coerce(other)
assert(len(self.v) == len(other.v))
v = sbitint.bit_adder(self.v, other.v)
return self.from_vec(v)
@@ -797,7 +827,7 @@ class sbitintvec(sbitvec, _number):
if isinstance(other, sbits):
return self.from_vec(other * x for x in self.v)
matrix = []
for i, b in enumerate(other.bit_decompose()):
for i, b in enumerate(util.bit_decompose(other)):
matrix.append([x * b for x in self.v[:len(self.v)-i]])
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v[:len(self.v)])

View File

@@ -12,6 +12,51 @@ import operator
import sys
from functools import reduce
class BlockAllocator:
""" Manages freed memory blocks. """
def __init__(self):
self.by_logsize = [defaultdict(set) for i in range(32)]
self.by_address = {}
def by_size(self, size):
return self.by_logsize[int(math.log(size, 2))][size]
def push(self, address, size):
end = address + size
if end in self.by_address:
next_size = self.by_address.pop(end)
self.by_size(next_size).remove(end)
size += next_size
self.by_size(size).add(address)
self.by_address[address] = size
def pop(self, size):
if len(self.by_size(size)) > 0:
block_size = size
else:
logsize = int(math.log(size, 2))
for block_size, addresses in self.by_logsize[logsize].items():
if block_size >= size and len(addresses) > 0:
break
else:
done = False
for x in self.by_logsize[logsize + 1:]:
for block_size, addresses in x.items():
if len(addresses) > 0:
done = True
break
if done:
break
else:
block_size = 0
if block_size >= size:
addr = self.by_size(block_size).pop()
del self.by_address[addr]
diff = block_size - size
if diff:
self.by_size(diff).add(addr + size)
self.by_address[addr + size] = diff
return addr
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
@@ -509,3 +554,47 @@ class Merger:
for i in range(self.G.n):
print('%d: %s' % (self.depths[i], self.instructions[i]), file=f)
f.close()
class RegintOptimizer:
def __init__(self):
self.cache = util.dict_by_id()
def run(self, instructions):
for i, inst in enumerate(instructions):
if isinstance(inst, ldint_class):
self.cache[inst.args[0]] = inst.args[1]
elif isinstance(inst, IntegerInstruction):
if inst.args[1] in self.cache and inst.args[2] in self.cache:
res = inst.op(self.cache[inst.args[1]],
self.cache[inst.args[2]])
if abs(res) < 2 ** 31:
self.cache[inst.args[0]] = res
instructions[i] = ldint(inst.args[0], res,
add_to_prog=False)
elif isinstance(inst, addint_class):
if inst.args[1] in self.cache and \
self.cache[inst.args[1]] == 0:
instructions[i] = inst.args[0].link(inst.args[2])
elif inst.args[2] in self.cache and \
self.cache[inst.args[2]] == 0:
instructions[i] = inst.args[0].link(inst.args[1])
elif isinstance(inst, IndirectMemoryInstruction):
if inst.args[1] in self.cache:
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
elif isinstance(inst, convint_class):
if inst.args[1] in self.cache:
res = self.cache[inst.args[1]]
self.cache[inst.args[0]] = res
if abs(res) < 2 ** 31:
instructions[i] = ldi(inst.args[0], res,
add_to_prog=False)
elif isinstance(inst, mulm_class):
if inst.args[2] in self.cache:
op = self.cache[inst.args[2]]
if op == 0:
instructions[i] = ldsi(inst.args[0], 0,
add_to_prog=False)
elif op == 1:
instructions[i] = None
inst.args[0].link(inst.args[1])
instructions[:] = filter(lambda x: x is not None, instructions)

View File

@@ -79,7 +79,10 @@ def LTZ(s, a, k, kappa):
from .types import sint, _bitint
from .GC.types import sbitvec
if program.use_split():
movs(s, sint.conv(sbitvec(a, k).v[-1]))
summands = a.split_to_two_summands(k)
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
msb = carry ^ summands[0][-1] ^ summands[1][-1]
movs(s, sint.conv(msb))
return
elif program.options.ring:
from . import floatingpoint
@@ -132,9 +135,31 @@ def Trunc(d, a, k, m, kappa, signed):
mulm(d, t, c[2])
def TruncRing(d, a, k, m, signed):
a_prime = Mod2mRing(None, a, k, m, signed)
a -= a_prime
res = TruncLeakyInRing(a, k, m, signed)
if program.use_split() == 3:
from Compiler.types import sint
from .GC.types import sbitint
length = int(program.options.ring)
summands = a.split_to_n_summands(length, 3)
x = sbitint.wallace_tree_without_finish(summands, True)
if m == 1:
low = x[1][1]
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
sint.conv(x[0][-1])
else:
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
low = sint.conv(mid_carry) + sint.conv(x[0][m])
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
for xx, yy in zip(x[1][m:-1],
x[0][m:-1])))
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
high = top_carry + sint.conv(x[0][-1])
shifted = sint()
shrsi(shifted, a, m)
res = shifted + sint.conv(low) - (high << (length - m))
else:
a_prime = Mod2mRing(None, a, k, m, signed)
a -= a_prime
res = TruncLeakyInRing(a, k, m, signed)
if d is not None:
movs(d, res)
return res
@@ -281,7 +306,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
"""
program.curr_tape.require_bit_length(k + kappa)
from .types import sint
if program.use_edabit() and m > 1:
if program.use_edabit() and m > 1 and not const_rounds:
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
tmp, b[:] = sint.get_edabit(m, True)
movs(r_prime, tmp)
@@ -290,7 +315,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
t[0][1] = b[-1]
PRandInt(r_dprime, k + kappa - m)
# r_dprime is always multiplied by 2^m
if use_dabit and program.use_dabit and m > 1:
if use_dabit and program.use_dabit and m > 1 and not const_rounds:
r, b[:] = zip(*(sint.get_dabit() for i in range(m)))
r = sint.bit_compose(r)
movs(r_prime, r)
@@ -383,7 +408,7 @@ def BitLTC1(u, a, b, kappa):
Mod2(u, t[4][k-1], k, kappa, False)
return p, a_bits, d, s, t, c, b, pre_input
def carry(b, a, compute_p):
def carry(b, a, compute_p=True):
""" Carry propogation:
return (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1))
"""

View File

@@ -9,14 +9,14 @@ import time
import sys
def run(args, options, param=-1, merge_opens=True,
def run(args, options, merge_opens=True,
reallocate=True, debug=False):
""" Compile a file and output a Program object.
If merge_opens is set to True, will attempt to merge any parallelisable open
instructions. """
prog = Program(args, options, param)
prog = Program(args, options)
instructions.program = prog
instructions_base.program = prog
types.program = prog
@@ -24,7 +24,7 @@ def run(args, options, param=-1, merge_opens=True,
prog.DEBUG = debug
VARS['program'] = prog
if options.binary:
VARS['sint'] = GC_types.sbitint.get_type(int(options.binary))
VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary))
VARS['sfix'] = GC_types.sbitfix
comparison.set_variant(options)

View File

@@ -6,6 +6,7 @@ USER_MEM = 8192
P_VALUES = { 32: 2147565569, \
64: 9223372036855103489, \
128: 170141183460469231731687303715885907969, \
192: 3138550867693340381917894711603833208051177722232017256453, \
256: 57896044618658097711785492504343953926634992332820282019728792003956566065153, \
512: 6703903964971298549787012499102923063739682910296196688861780721860882015036773488400937149083451713845015929093243025426876941405973284973216824503566337 }
@@ -19,7 +20,7 @@ BIT_LENGTHS = { -1: 64,
512: 64 }
COST = defaultdict(lambda: defaultdict(lambda: 0),
COST = defaultdict(lambda: defaultdict(lambda: 0),
{ 'modp': defaultdict(lambda: 0,
{ 'triple': 0.00020652622883106154,
'square': 0.00020652622883106154,

View File

@@ -104,11 +104,15 @@ def PreORC(a, kappa=None, m=None, raw=False):
k = len(a)
if k == 1:
return [a[0]]
prog = program.Program.prog
kappa = kappa or prog.security
m = m or k
if isinstance(a[0], types.sgf2n):
max_k = program.Program.prog.galois_length - 1
else:
max_k = int(log(program.Program.prog.P) / log(2)) - kappa
# assume prime length is power of two
prime_length = 2 ** int(ceil(log(prog.bit_length + kappa, 2)))
max_k = prime_length - kappa - 2
assert(max_k > 0)
if k <= max_k:
p = [None] * m
@@ -132,7 +136,8 @@ def PreORC(a, kappa=None, m=None, raw=False):
# not constant-round anymore
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0])
return sum(([or_op(x, y) for x in si]
for si,y in zip(s[1:],t)), s[0])[-m:]
def PreOpL(op, items):
"""
@@ -549,9 +554,10 @@ def TruncPrRing(a, k, m, signed=True):
trunc_pr(res, a, k, m)
else:
# extra bit to mask overflow
if program.Program.prog.use_edabit():
lower = sint.get_edabit(m, True)[0]
upper = sint.get_edabit(k - m, True)[0]
prog = program.Program.prog
if prog.use_edabit() or prog.use_split() == 3:
lower = sint.get_random_int(m)
upper = sint.get_random_int(k - m)
msb = sint.get_random_bit()
r = (msb << k) + (upper << m) + lower
else:

View File

@@ -91,64 +91,74 @@ class stmint(base.DirectMemoryWriteInstruction):
# must have seperate instructions because address is always modp
@base.vectorize
class ldmci(base.ReadMemoryInstruction):
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
code = base.opcodes['LDMCI']
arg_format = ['cw','ci']
direct = staticmethod(ldmc)
@base.vectorize
class ldmsi(base.ReadMemoryInstruction):
class ldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """
code = base.opcodes['LDMSI']
arg_format = ['sw','ci']
direct = staticmethod(ldms)
@base.vectorize
class stmci(base.WriteMemoryInstruction):
class stmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
code = base.opcodes['STMCI']
arg_format = ['c','ci']
direct = staticmethod(stmc)
@base.vectorize
class stmsi(base.WriteMemoryInstruction):
class stmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
code = base.opcodes['STMSI']
arg_format = ['s','ci']
direct = staticmethod(stms)
@base.vectorize
class ldminti(base.ReadMemoryInstruction):
class ldminti(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """
code = base.opcodes['LDMINTI']
arg_format = ['ciw','ci']
direct = staticmethod(ldmint)
@base.vectorize
class stminti(base.WriteMemoryInstruction):
class stminti(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """
code = base.opcodes['STMINTI']
arg_format = ['ci','ci']
direct = staticmethod(stmint)
@base.vectorize
class gldmci(base.ReadMemoryInstruction):
class gldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
code = base.opcodes['LDMCI'] + 0x100
arg_format = ['cgw','ci']
direct = staticmethod(gldmc)
@base.vectorize
class gldmsi(base.ReadMemoryInstruction):
class gldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """
code = base.opcodes['LDMSI'] + 0x100
arg_format = ['sgw','ci']
direct = staticmethod(gldms)
@base.vectorize
class gstmci(base.WriteMemoryInstruction):
class gstmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
code = base.opcodes['STMCI'] + 0x100
arg_format = ['cg','ci']
direct = staticmethod(gstmc)
@base.vectorize
class gstmsi(base.WriteMemoryInstruction):
class gstmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
code = base.opcodes['STMSI'] + 0x100
arg_format = ['sg','ci']
direct = staticmethod(gstms)
@base.gf2n
@base.vectorize
@@ -268,7 +278,7 @@ class use_edabit(base.Instruction):
class run_tape(base.Instruction):
r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """
code = base.opcodes['RUN_TAPE']
arg_format = ['int','int','int']
arg_format = tools.cycle(['int','int','int'])
class join_tape(base.Instruction):
r""" Join thread $c_i$. """
@@ -661,6 +671,12 @@ class shrci(base.ClearShiftInstruction):
code = base.opcodes['SHRCI']
op = '__rshift__'
@base.vectorize
class shrsi(base.ClearShiftInstruction):
r""" Secret bitwise shift right by immediate value. """
__slots__ = []
code = base.opcodes['SHRSI']
arg_format = ['sw','s','i']
###
### Data access instructions
@@ -742,6 +758,14 @@ class sedabit(base.Instruction):
def add_usage(self, req_node):
req_node.increment(('sedabit', len(self.args) - 1), self.get_size())
@base.vectorize
class randoms(base.Instruction):
""" Random share """
__slots__ = []
code = base.opcodes['RANDOMS']
arg_format = ['sw','int']
field_type = 'modp'
@base.gf2n
@base.vectorize
class square(base.DataInstruction):
@@ -781,6 +805,17 @@ class inputmask(base.Instruction):
req_node.increment((self.field_type, 'input', self.args[1]), \
self.get_size())
@base.vectorize
class inputmaskreg(base.Instruction):
__slots__ = []
code = base.opcodes['INPUTMASKREG']
arg_format = ['sw', 'cw', 'ci']
field_type = 'modp'
def add_usage(self, req_node):
# player 0 as proxy
req_node.increment((self.field_type, 'input', 0), float('inf'))
@base.gf2n
@base.vectorize
class prep(base.Instruction):
@@ -1165,21 +1200,25 @@ class ldint(base.Instruction):
class addint(base.IntegerInstruction):
__slots__ = []
code = base.opcodes['ADDINT']
op = operator.add
@base.vectorize
class subint(base.IntegerInstruction):
__slots__ = []
code = base.opcodes['SUBINT']
op = operator.sub
@base.vectorize
class mulint(base.IntegerInstruction):
__slots__ = []
code = base.opcodes['MULINT']
op = operator.mul
@base.vectorize
class divint(base.IntegerInstruction):
__slots__ = []
code = base.opcodes['DIVINT']
op = operator.floordiv
@base.vectorize
class bitdecint(base.Instruction):
@@ -1228,18 +1267,21 @@ class ltc(base.IntegerInstruction):
r""" Clear comparison $c_i = (c_j \stackrel{?}{<} c_k)$. """
__slots__ = []
code = base.opcodes['LTC']
op = operator.lt
@base.vectorize
class gtc(base.IntegerInstruction):
r""" Clear comparison $c_i = (c_j \stackrel{?}{>} c_k)$. """
__slots__ = []
code = base.opcodes['GTC']
op = operator.gt
@base.vectorize
class eqc(base.IntegerInstruction):
r""" Clear comparison $c_i = (c_j \stackrel{?}{==} c_k)$. """
__slots__ = []
code = base.opcodes['EQC']
op = operator.eq
###

View File

@@ -106,10 +106,12 @@ opcodes = dict(
GBITTRIPLE = 0x154,
GBITGF2NTRIPLE = 0x155,
INPUTMASK = 0x56,
INPUTMASKREG = 0x5C,
PREP = 0x57,
DABIT = 0x58,
EDABIT = 0x59,
SEDABIT = 0x5A,
RANDOMS = 0x5B,
# Input
INPUT = 0x60,
INPUTFIX = 0xF0,
@@ -143,6 +145,7 @@ opcodes = dict(
SHRC = 0x81,
SHLCI = 0x82,
SHRCI = 0x83,
SHRSI = 0x84,
# Branching and comparison
JMP = 0x90,
JMPNZ = 0x91,
@@ -446,10 +449,11 @@ def cisc(function):
assert int(program.options.max_parallel_open) == 0, \
'merging restriction not compatible with ' \
'mergeable CISC instructions'
merger.longest_paths_merge()
n_rounds = merger.longest_paths_merge()
filtered = filter(lambda x: x is not None, block.instructions)
self.instructions[self.merge_id()] = list(filtered), args
template, args = self.instructions[self.merge_id()]
self.instructions[self.merge_id()] = list(filtered), args, \
n_rounds
template, args, self.n_rounds = self.instructions[self.merge_id()]
subs = util.dict_by_id()
for arg, reg in zip(args, regs):
subs[arg] = reg
@@ -486,6 +490,9 @@ def cisc(function):
base += reg.size
return block.instructions
def expanded_rounds(self):
return self.n_rounds - 1
MergeCISC.__name__ = function.__name__
def wrapper(*args, **kwargs):
@@ -649,7 +656,7 @@ class String(ArgFormat):
@classmethod
def encode(cls, arg):
return arg + '\0' * (cls.length - len(arg))
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))
ArgFormats = {
'c': ClearModpAF,
@@ -683,6 +690,7 @@ class Instruction(object):
"""
__slots__ = ['args', 'arg_format', 'code', 'caller']
count = 0
code_length = 10
def __init__(self, *args, **kwargs):
""" Create an instruction and append it to the program list. """
@@ -701,7 +709,7 @@ class Instruction(object):
print("Compiled %d lines at" % self.__class__.count, time.asctime())
def get_code(self, prefix=0):
return (prefix << 10) + self.code
return (prefix << self.code_length) + self.code
def get_encoding(self):
enc = int_to_bytes(self.get_code())
@@ -713,7 +721,10 @@ class Instruction(object):
return enc
def get_bytes(self):
return bytearray(self.get_encoding())
try:
return bytearray(self.get_encoding())
except TypeError:
raise CompilerError('cannot encode %s/%s' % (self, self.get_encoding()))
def check_args(self):
""" Check the args match up with that specified in arg_format """
@@ -787,6 +798,9 @@ class Instruction(object):
def expand_merged(self):
return [self]
def expanded_rounds(self):
return 0
def get_new_args(self, size, subs):
new_args = []
for arg, f in zip(self.args, self.arg_format):
@@ -864,6 +878,12 @@ class DirectMemoryInstruction(Instruction):
def __init__(self, *args, **kwargs):
super(DirectMemoryInstruction, self).__init__(*args, **kwargs)
class IndirectMemoryInstruction(Instruction):
__slots__ = []
def get_direct(self, address):
return self.direct(self.args[0], address, add_to_prog=False)
class ReadMemoryInstruction(Instruction):
__slots__ = []
@@ -875,7 +895,7 @@ class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
__slots__ = []
def __init__(self, *args, **kwargs):
if program.curr_tape.prevent_direct_memory_write:
raise CompilerError('Direct memory writing prevented')
raise CompilerError('Direct memory writing prevented in threads')
super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs)
###

View File

@@ -6,6 +6,7 @@ in particularly providing flow control and output.
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc
from Compiler.instructions import *
from Compiler.util import tuplify,untuplify,is_zero
from Compiler.allocator import RegintOptimizer
from Compiler import instructions,instructions_base,comparison,program,util
import inspect,math
import random
@@ -54,6 +55,9 @@ def set_instruction_type(function):
return instruction_typed_function
def _expand_to_print(val):
return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val)
def print_str(s, *args):
""" Print a string, with optional args for adding
variables/registers with ``%s``. """
@@ -90,7 +94,7 @@ def print_str(s, *args):
elif isinstance(val, cfloat):
val.print_float_plain()
elif isinstance(val, (list, tuple, Array)):
print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val)
print_str(*_expand_to_print(val))
else:
try:
val.output()
@@ -127,6 +131,9 @@ def print_ln_if(cond, ss, *args):
print_ln_if(get_player_id() == 0, 'Player 0 here')
"""
print_str_if(cond, ss + '\n', *args)
def print_str_if(cond, ss, *args):
if util.is_constant(cond):
if cond:
print_ln(ss, *args)
@@ -138,9 +145,14 @@ def print_ln_if(cond, ss, *args):
cond = cint.conv(cond)
for i, s in enumerate(subs):
if i != 0:
args[i - 1].output_if(cond)
if i == len(args):
s += '\n'
val = args[i - 1]
try:
val.output_if(cond)
except:
if isinstance(val, (list, tuple, Array)):
print_str_if(cond, *_expand_to_print(val))
else:
print_str_if(cond, str(val))
s += '\0' * ((-len(s)) % 4)
while s:
cond.print_if(s[:4])
@@ -162,7 +174,15 @@ def print_ln_to(player, ss, *args):
new_args = []
for arg in args:
if isinstance(arg, personal):
assert arg.player == player
if util.is_constant(arg.player) ^ util.is_constant(player):
match = False
else:
if util.is_constant(player):
match = arg.player == player
else:
match = id(arg.player) == id(player)
if not match:
raise CompilerError('player mismatch in personal printing')
new_args.append(arg._v)
else:
new_args.append(arg)
@@ -968,6 +988,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
del blocks[-n_to_merge + 1:]
del get_tape().req_node.children[-1]
merged.children = []
RegintOptimizer().run(merged.instructions)
get_tape().active_basicblock = merged
else:
req_node = get_tape().req_node.children[-1].nodes[0]
@@ -1117,7 +1138,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
else:
return loop_body(base + i)
prog = get_program()
threads = []
thread_args = []
if not util.is_zero(thread_rounds):
tape = prog.new_tape(f, (0,), 'multithread')
for i in range(n_threads - remainder):
@@ -1125,7 +1146,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
args[remainder + i][0] = i * thread_rounds
if len(mem_state):
args[remainder + i][1] = mem_state.address
threads.append(prog.run_tape(tape, remainder + i))
thread_args.append((tape, remainder + i))
if remainder:
tape1 = prog.new_tape(f, (1,), 'multithread1')
for i in range(remainder):
@@ -1133,7 +1154,8 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
args[i][0] = (n_threads - remainder + i) * thread_rounds + i
if len(mem_state):
args[i][1] = mem_state.address
threads.append(prog.run_tape(tape1, i))
thread_args.append((tape1, i))
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
if state:
@@ -1625,6 +1647,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
y = y.round(2*k, 2*f, kappa, nearest, signed=True)
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
x = x.extend(2 * k)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
return y
@@ -1659,7 +1682,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k, kappa)[::-1]
suffixes = PreOR(bits)[::-1]
suffixes = PreOR(bits, kappa)[::-1]
z = [0] * k
for i in range(k - 1):

View File

@@ -112,6 +112,12 @@ def lse_0_from_e_x(x, e_x):
def lse_0(x):
return lse_0_from_e_x(x, exp(x))
def approx_lse_0(x, n=3):
assert n != 5
a = x < -0.5
b = x > 0.5
return a.if_else(0, b.if_else(x, 0.5 * (x + 0.5) ** 2)) - x
def relu_prime(x):
""" ReLU derivative. """
return (0 <= x)
@@ -145,10 +151,19 @@ class Tensor(MultiArray):
kwargs['alloc'] = False
super(Tensor, self).__init__(*args, **kwargs)
def input_from(self, *args, **kwargs):
self.alloc()
super(Tensor, self).input_from(*args, **kwargs)
def __getitem__(self, *args):
self.alloc()
return super(Tensor, self).__getitem__(*args)
class Layer:
n_threads = 1
inputs = []
input_bias = True
thetas = lambda self: ()
@property
def shape(self):
@@ -193,14 +208,13 @@ class Output(Layer):
self.approx = approx
nablas = lambda self: ()
thetas = lambda self: ()
reset = lambda self: None
def divisor(self, divisor, size):
return cfix(1.0 / divisor, size=size)
def forward(self, batch):
if self.approx:
if self.approx == 5:
self.l.write(999)
return
N = len(batch)
@@ -209,10 +223,12 @@ class Output(Layer):
def _(base, size):
x = self.X.get_vector(base, size)
y = self.Y.get(batch.get_vector(base, size))
if self.approx:
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
return
e_x = exp(-x)
self.e_x.assign(e_x, base)
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
e_x = self.e_x.get_vector(0, N)
self.l.write(sum(lse) * \
self.divisor(N, 1))
@@ -323,7 +339,7 @@ class Dense(DenseBase):
self.X = MultiArray([N, d, d_in], sfix)
self.Y = MultiArray([N, d, d_out], sfix)
self.W = sfix.Matrix(d_in, d_out)
self.W = Tensor([d_in, d_out], sfix)
self.b = sfix.Array(d_out)
self.nabla_Y = MultiArray([N, d, d_out], sfix)
@@ -546,12 +562,12 @@ class MaxPool(NoVariableLayer):
for x in strides, ksize:
for i in 0, 3:
assert x[i] == 1
self.X = MultiArray(shape, sfix)
self.X = Tensor(shape, sfix)
if padding == 'SAME':
output_shape = [int(math.ceil(shape[i] / strides[i])) for i in range(4)]
else:
output_shape = [(shape[i] - ksize[i]) // strides[i] + 1 for i in range(4)]
self.Y = MultiArray(output_shape, sfix)
self.Y = Tensor(output_shape, sfix)
self.strides = strides
self.ksize = ksize
@@ -741,6 +757,7 @@ class ConvBase(BaseLayer):
use_conv2ds = False
temp_weights = None
temp_inputs = None
thetas = lambda self: (self.weights, self.bias)
@classmethod
def init_temp(cls, layers):
@@ -780,7 +797,7 @@ class ConvBase(BaseLayer):
self.weight_squant = self.new_squant()
self.bias_squant = self.new_squant()
self.weights = MultiArray(weight_shape, self.weight_squant)
self.weights = Tensor(weight_shape, self.weight_squant)
self.bias = Array(output_shape[-1], self.bias_squant)
self.unreduced = Tensor(self.output_shape, sint)
@@ -1158,7 +1175,8 @@ class Optimizer:
layer.last_used = list(filter(lambda x: x not in used, layer.inputs))
used.update(layer.inputs)
def forward(self, N=None, batch=None, keep_intermediate=True):
def forward(self, N=None, batch=None, keep_intermediate=True,
model_from=None):
""" Compute graph.
:param N: batch size (used if batch not given)
@@ -1172,12 +1190,16 @@ class Optimizer:
if layer.inputs and len(layer.inputs) == 1 and layer.inputs[0] is not None:
layer._X.address = layer.inputs[0].Y.address
layer.Y.alloc()
if model_from is not None:
layer.input_from(model_from)
break_point()
layer.forward(batch=batch)
break_point()
if not keep_intermediate:
for l in layer.last_used:
l.Y.delete()
for theta in layer.thetas():
theta.delete()
def eval(self, data):
""" Compute evaluation after training. """
@@ -1207,6 +1229,7 @@ class Optimizer:
else:
N = self.layers[0].N
i = MemValue(0)
n_iterations = MemValue(0)
@do_while
def _():
if self.X_by_label is None:
@@ -1216,6 +1239,7 @@ class Optimizer:
n = N // len(self.X_by_label)
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
self.X_by_label) / n))
n_iterations.iadd(n_per_epoch)
print('%d runs per epoch' % n_per_epoch)
indices_by_label = []
for label, X in enumerate(self.X_by_label):
@@ -1235,7 +1259,7 @@ class Optimizer:
self.backward(batch=batch)
self.update(i)
loss = self.layers[-1].l
if self.report_loss and not self.layers[-1].approx:
if self.report_loss and self.layers[-1].approx != 5:
print_ln('loss after epoch %s: %s', i, loss.reveal())
else:
print_ln('done with epoch %s', i)
@@ -1245,7 +1269,7 @@ class Optimizer:
if self.tol > 0:
res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
return res
print_ln('finished after %s epochs', i)
print_ln('finished after %s epochs and %s iterations', i, n_iterations)
class Adam(Optimizer):
def __init__(self, layers, n_epochs):

View File

@@ -3,6 +3,7 @@ from Compiler.exceptions import *
from Compiler.instructions_base import RegType
import Compiler.instructions
import Compiler.instructions_base
import Compiler.instructions_base as inst_base
from . import compilerLib
from . import allocator as al
from . import util
@@ -40,22 +41,20 @@ class Program(object):
These are created by executing a file containing appropriate instructions
and threads. """
def __init__(self, args, options, param=-1):
def __init__(self, args, options):
self.options = options
self.verbose = options.verbose
self.args = args
self.init_names(args)
self.P = P_VALUES[param]
self.param = param
if (param != -1) + sum(x != 0 for x in(options.ring, options.field,
if sum(x != 0 for x in(options.ring, options.field,
options.binary)) > 1:
raise CompilerError('can only use one out of -p, -B, -R, -F')
raise CompilerError('can only use one out of -B, -R, -F')
if options.ring:
self.bit_length = int(options.ring) - 1
else:
self.bit_length = int(options.binary) or int(options.field)
if not self.bit_length:
self.bit_length = BIT_LENGTHS[param]
self.bit_length = 64
print('Default bit length:', self.bit_length)
self.security = 40
print('Default security parameter:', self.security)
@@ -67,7 +66,7 @@ class Program(object):
self._curr_tape = None
self.DEBUG = False
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
self.free_mem_blocks = defaultdict(lambda: defaultdict(set))
self.free_mem_blocks = defaultdict(al.BlockAllocator)
self.allocated_mem_blocks = {}
self.saved = 0
self.req_num = None
@@ -100,6 +99,7 @@ class Program(object):
self._edabit = options.edabit
self._split = False
self._square = False
self._always_raw = False
Program.prog = self
def get_args(self):
@@ -165,20 +165,28 @@ class Program(object):
return tape_index
def run_tape(self, tape_index, arg):
return self.run_tapes([[tape_index, arg]])[0]
def run_tapes(self, args):
if self.curr_tape is not self.tapes[0]:
raise CompilerError('Compiler does not support ' \
'recursive spawning of threads')
if self.free_threads:
thread_number = min(self.free_threads)
self.free_threads.remove(thread_number)
else:
thread_number = self.n_threads
self.n_threads += 1
thread_numbers = []
while len(thread_numbers) < len(args):
if self.free_threads:
thread_numbers.append(min(self.free_threads))
self.free_threads.remove(thread_numbers[-1])
else:
thread_numbers.append(self.n_threads)
self.n_threads += 1
self.curr_tape.start_new_basicblock(name='pre-run_tape')
Compiler.instructions.run_tape(thread_number, arg, tape_index)
Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in
zip(thread_numbers, args)), []))
self.curr_tape.start_new_basicblock(name='post-run_tape')
self.curr_tape.req_node.children.append(self.tapes[tape_index].req_tree)
return thread_number
for arg in args:
self.curr_tape.req_node.children.append(
self.tapes[arg[0]].req_tree)
return thread_numbers
def join_tape(self, thread_number):
self.curr_tape.start_new_basicblock(name='pre-join_tape')
@@ -250,19 +258,9 @@ class Program(object):
mem_type = mem_type.reg_type
elif reg_type is not None:
self.types[mem_type] = reg_type
block_size = 0
blocks = self.free_mem_blocks[mem_type]
if len(blocks[size]) > 0:
block_size = size
else:
for block_size, addresses in blocks.items():
if block_size >= size and len(addresses) > 0:
break
else:
block_size = 0
if block_size >= size:
addr = self.free_mem_blocks[mem_type][block_size].pop()
self.free_mem_blocks[mem_type][block_size - size].add(addr + size)
addr = blocks.pop(size)
if addr is not None:
self.saved += size
else:
addr = self.allocated_mem[mem_type]
@@ -278,7 +276,7 @@ class Program(object):
is not self.curr_tape.basicblocks[0].alloc_pool:
raise CompilerError('Cannot free memory within function block')
size = self.allocated_mem_blocks.pop((addr,mem_type))
self.free_mem_blocks[mem_type][size].add(addr)
self.free_mem_blocks[mem_type].push(addr, size)
def finalize_memory(self):
from . import library
@@ -341,6 +339,20 @@ class Program(object):
else:
self._square = change
def always_raw(self, change=None):
if change is None:
return self._always_raw
else:
self._always_raw = change
def options_from_args(self):
if 'trunc_pr' in self.args:
self.use_trunc_pr = True
if 'split' in self.args or 'split3' in self.args:
self.use_split(3)
if 'raw' in self.args:
self.always_raw(True)
class Tape:
""" A tape contains a list of basic blocks, onto which instructions are added. """
def __init__(self, name, program):
@@ -445,13 +457,14 @@ class Tape:
instructions = self.instructions
for inst in instructions:
inst.add_usage(req_node)
req_node.num['all', 'round'] = self.n_rounds
req_node.num['all', 'inv'] = self.n_to_merge
req_node.num['all', 'round'] += self.n_rounds
req_node.num['all', 'inv'] += self.n_to_merge
def expand_cisc(self):
new_instructions = []
for inst in self.instructions:
new_instructions.extend(inst.expand_merged())
self.n_rounds += inst.expanded_rounds()
self.instructions = new_instructions
def __str__(self):
@@ -503,9 +516,6 @@ class Tape:
def unpurged(function):
def wrapper(self, *args, **kwargs):
if self.purged:
if self.program.verbose:
print('%s called on purged block %s, ignoring' % \
(function.__name__, self.name))
return
return function(self, *args, **kwargs)
return wrapper
@@ -585,7 +595,8 @@ class Tape:
reg_counts = self.count_regs()
if not options.noreallocate:
if self.program.verbose:
print('Tape register usage:', dict(reg_counts))
print('Tape register usage before re-allocation:',
dict(reg_counts))
print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]))
print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]))
print('Re-allocating...')
@@ -613,6 +624,8 @@ class Tape:
alloc_loop(block.exit_block.scope)
allocator.process(block.instructions, block.alloc_pool)
allocator.finalize(options)
if self.program.verbose:
print('Tape register usage:', dict(allocator.usage))
# offline data requirements
if self.program.verbose:
@@ -847,10 +860,6 @@ class Tape:
if t == 'p':
self.req_bit_length[t] = max(bit_length + 1, \
self.req_bit_length[t])
if self.program.param != -1 and bit_length >= self.program.param:
raise CompilerError('Inadequate bit length %d for prime, ' \
'program requires %d bits' % \
(self.program.param, self.req_bit_length['p']))
else:
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
@@ -862,6 +871,7 @@ class Tape:
__slots__ = ["reg_type", "program", "absolute_i", "relative_i", \
"size", "vector", "vectorbase", "caller", \
"can_eliminate", "duplicates"]
maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1
def __init__(self, reg_type, program, size=None, i=None):
""" Creates a new register.
@@ -875,6 +885,8 @@ class Tape:
self.program = program
if size is None:
size = Compiler.instructions_base.get_global_vector_size()
if size is not None and size > self.maximum_size:
raise CompilerError('vector too large')
self.size = size
self.vectorbase = self
self.relative_i = 0

View File

@@ -147,7 +147,8 @@ def vectorize(operation):
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
and not isinstance(args[0], bits) \
and args[0].size != self.size:
raise CompilerError('Different vector sizes of operands')
raise CompilerError('Different vector sizes of operands: %d/%d'
% (self.size, args[0].size))
set_global_vector_size(self.size)
res = operation(self, *args, **kwargs)
reset_global_vector_size()
@@ -789,7 +790,7 @@ class cint(_clear, _int):
bit_length = 1 + int(math.ceil(math.log(abs(val))))
if program.options.ring:
assert(bit_length <= int(program.options.ring))
elif program.param != -1 or program.options.field:
elif program.options.field:
program.curr_tape.require_bit_length(bit_length)
if self.in_immediate_range(val):
ldi(self, val)
@@ -1731,6 +1732,15 @@ class sint(_secret, _int):
""" Secret random n-bit number according to security model.
:param bits: compile-time integer (int) """
if program.use_split() == 3:
tmp = sint()
randoms(tmp, bits)
x = tmp.split_to_two_summands(bits, True)
overflow = comparison.CarryOutLE(x[1][:-1], x[0][:-1]) + \
sint.conv(x[0][-1])
return tmp - (overflow << bits)
elif program.use_edabit():
return sint.get_edabit(bits, True)[0]
res = sint()
comparison.PRandInt(res, bits)
return res
@@ -2068,6 +2078,37 @@ class sint(_secret, _int):
def two_power(n):
return floatingpoint.two_power(n)
def split_to_n_summands(self, length, n):
from .GC.types import sbits
from .GC.instructions import split
columns = [[sbits.get_type(self.size)()
for i in range(n)] for i in range(length)]
split(n, self, *sum(columns, []))
return columns
def split_to_two_summands(self, length, get_carry=False):
n = program.use_split()
assert n
columns = self.split_to_n_summands(length, n)
return _bitint.wallace_tree_without_finish(columns, get_carry)
@vectorize
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Result potentially written to ``Player-Data/Private-Output-P<player>.``
:param player: public integer (int/regint/cint):
:returns: value to be used with :py:func:`Compiler.library.print_ln_to`
"""
if not util.is_constant(player) or self.size > 1:
secret_mask = sint()
player_mask = cint()
inputmaskreg(secret_mask, player_mask, player)
return personal(player,
(self + secret_mask).reveal() - player_mask)
else:
return super(sint, self).reveal_to(player)
class sgf2n(_secret, _gf2n):
""" Secret GF(2^n) value. """
__slots__ = []
@@ -2255,7 +2296,8 @@ class _bitint(object):
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
get_carry=False):
lower = []
for (ai,bi) in zip(a,b):
a, b = a[:], b[:]
for (ai, bi) in zip(a[:], b[:]):
if is_zero(ai) or is_zero(bi):
lower.append(ai + bi)
a.pop(0)
@@ -2404,6 +2446,7 @@ class _bitint(object):
@classmethod
def wallace_tree_without_finish(cls, columns, get_carry=True):
self = cls
columns = [col[:] for col in columns]
while max(len(c) for c in columns) > 2:
new_columns = [[] for i in range(len(columns) + 1)]
for i,col in enumerate(columns):
@@ -2439,11 +2482,12 @@ class _bitint(object):
raise CompilerError('Unclear subtraction')
a = self.bit_decompose()
b = util.bit_decompose(other, self.n_bits)
d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)]
d = [(reduce(util.bit_xor, (ai, bi, 1)), (1 - ai) * bi)
for (ai,bi) in zip(a,b)]
borrow = lambda y,x,*args: \
(x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1]))
borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1]
return self.compose(ai + bi + borrow \
return self.compose(reduce(util.bit_xor, (ai, bi, borrow)) \
for (ai,bi,borrow) in zip(a,b,borrows))
def __rsub__(self, other):
@@ -2715,10 +2759,15 @@ class cfix(_number, _structure):
def set_precision(cls, f, k = None):
""" Set the precision of the integer representation. Note that some
operations are undefined when the precision of :py:class:`sfix` and
:py:class:`cfix` differs.
:py:class:`cfix` differs. The initial defaults are chosen to
allow the best optimization of probabilistic truncation in
computation modulo 2^64 (2*k < 64). Generally, 2*k must be at
most the integer length for rings and at most m-s-1 for
computation modulo an m-bit prime and statistical security s
(default 40).
:param f: bit length of decimal part
:param k: whole bit length of fixed point, defaults to twice :py:obj:`f`.
:param f: bit length of decimal part (initial default 16)
:param k: whole bit length of fixed point, defaults to twice :py:obj:`f` if not given (initial default 31)
"""
cls.f = f
@@ -2789,6 +2838,10 @@ class cfix(_number, _structure):
else:
raise CompilerError('cannot initialize cfix with %s' % v)
def __iter__(self):
for x in self.v:
yield type(self)(x, self.k, self.f)
@vectorize
def load_int(self, v):
self.v = cint(v) * (2 ** self.f)
@@ -3141,7 +3194,6 @@ class _fix(_single):
""" Secret fixed point type. """
__slots__ = ['v', 'f', 'k', 'size']
@classmethod
def set_precision(cls, f, k = None):
cls.f = f
# default bitlength = 2*precision
@@ -3152,6 +3204,7 @@ class _fix(_single):
raise CompilerError('bit length cannot be less than precision')
cls.k = k
set_precision.__doc__ = cfix.set_precision.__doc__
set_precision = classmethod(set_precision)
@classmethod
def coerce(cls, other):
@@ -3377,9 +3430,10 @@ class sfix(_fix):
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Raw representation written to ``Player-Data/Private-Output-P<player>``
Raw representation possibly written to
``Player-Data/Private-Output-P<player>.``
:param player: int
:param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`Compiler.library.print_ln_to`
"""
return personal(player, cfix(self.v.reveal_to(player)._v,
@@ -3419,17 +3473,8 @@ class unreduced_sfix(_single):
sfix.unreduced_type = unreduced_sfix
# this is for 20 bit decimal precision
# with 40 bitlength of entire number
# these constants have been chosen for multiplications to fit in 128 bit prime field
# (precision n1) 41 + (precision n2) 41 + (stat_sec) 40 = 82 + 40 = 122 <= 128
# with statistical security of 40
fixed_lower = 20
fixed_upper = 40
sfix.set_precision(fixed_lower, fixed_upper)
cfix.set_precision(fixed_lower, fixed_upper)
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)
class squant(_single):
""" Quantization as in ArXiv:1712.05877v1 """
@@ -4222,8 +4267,10 @@ class Array(object):
:param value: convertible to basic type """
if conv:
value = self.value_type.conv(value)
if value.size != 1:
raise CompilerError('cannot assign vector to all elements')
mem_value = MemValue(value)
self.address = MemValue(self.address)
self.address = MemValue.if_necessary(self.address)
n_threads = 8 if use_threads and len(self) > 2**20 else 1
@library.for_range_multithread(n_threads, 1024, len(self))
def f(i):
@@ -4242,7 +4289,7 @@ class Array(object):
def get(self, indices):
return self.value_type.load_mem(
regint(self.address, size=len(indices)) + indices,
regint.inc(len(indices), self.address, 0) + indices,
size=len(indices))
def expand_to_vector(self, index, size):
@@ -4258,13 +4305,13 @@ class Array(object):
""" Fill with inputs from player if supported by type.
:param player: public (regint/cint/int) """
if raw:
if raw or program.always_raw():
input_from = self.value_type.get_raw_input_from
else:
input_from = self.value_type.get_input_from
try:
self.assign(input_from(player, size=len(self)))
except:
except TypeError:
@library.for_range_opt(len(self), budget=budget)
def _(i):
self[i] = input_from(player)
@@ -4318,6 +4365,12 @@ class Array(object):
:returns: Array of relevant clear type. """
return Array.create_from(x.reveal() for x in self)
def reveal_list(self):
""" Reveal as list. """
return list(self.get_vector().reveal())
reveal_nested = reveal_list
sint.dynamic_array = Array
sgf2n.dynamic_array = Array
@@ -4454,16 +4507,17 @@ class SubMultiArray(object):
""" Fill with inputs from player if supported by type.
:param player: public (regint/cint/int) """
budget = budget or 2 ** 21
budget = budget or Tape.Register.maximum_size
if (self.total_size() < budget) and \
self.value_type.n_elements() == 1:
if raw:
if raw or program.always_raw():
input_from = self.value_type.get_raw_input_from
else:
input_from = self.value_type.get_input_from
self.assign_vector(input_from(player, size=self.total_size()))
else:
@library.for_range_opt(self.sizes[0], budget=budget)
@library.for_range_opt(self.sizes[0],
budget=budget / self[0].total_size())
def _(i):
self[i].input_from(player, budget=budget, raw=raw)
@@ -4687,6 +4741,21 @@ class SubMultiArray(object):
library.break_point()
return res
def reveal_list(self):
""" Reveal as list. """
return list(self.get_vector().reveal())
def reveal_nested(self):
""" Reveal as nested list. """
flat = iter(self.get_vector().reveal())
res = []
def f(sizes):
if len(sizes) == 1:
return [next(flat) for i in range(sizes[0])]
else:
return [f(sizes[1:]) for i in range(sizes[0])]
return f(self.sizes)
class MultiArray(SubMultiArray):
""" Multidimensional array. """
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
@@ -4836,10 +4905,12 @@ class MemValue(_mem):
self.value_type = type(value)
self.deleted = False
if address is None:
self.address = self.value_type.malloc(1)
self.address = self.value_type.malloc(value.size)
self.size = value.size
self.write(value)
else:
self.address = address
self.size = 1
def delete(self):
self.value_type.free(self.address)
@@ -4869,6 +4940,8 @@ class MemValue(_mem):
elif isinstance(value, int):
self.register = self.value_type(value)
else:
if value.size != self.size:
raise CompilerError('size mismatch')
self.register = value
if not isinstance(self.register, self.value_type):
raise CompilerError('Mismatch in register type, cannot write \

View File

@@ -46,8 +46,10 @@ def right_shift(a, b, bits):
else:
return a.right_shift(b, bits)
def bit_decompose(a, bits):
def bit_decompose(a, bits=None):
if isinstance(a, int):
if bits is None:
bits = int_len(a)
return [int((a >> i) & 1) for i in range(bits)]
else:
return a.bit_decompose(bits)
@@ -122,6 +124,15 @@ def or_op(a, b):
OR = or_op
def bit_xor(a, b):
if is_constant(a):
if is_constant(b):
return a ^ b
else:
return b.bit_xor(a)
else:
return a.bit_xor(b)
def pow2(bits):
powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)]
return tree_reduce(operator.mul, powers)

View File

@@ -92,8 +92,8 @@ bool Proof::check_bounds(T& z, X& t, int i) const
unsigned int j,k;
// Check Bound 1 and Bound 2
AbsoluteBoundChecker<fixint<2>> plain_checker(plain_check * n_proofs);
AbsoluteBoundChecker<fixint<2>> rand_checker(rand_check * n_proofs);
AbsoluteBoundChecker<fixint<GFP_MOD_SZ>> plain_checker(plain_check * n_proofs);
AbsoluteBoundChecker<fixint<GFP_MOD_SZ>> rand_checker(rand_check * n_proofs);
for (j=0; j<phim; j++)
{
auto& te = z[j];

View File

@@ -252,8 +252,10 @@ void SummingEncCommit<FD>::create_more()
this->c.unpack(ciphertexts, this->pk);
commitments.unpack(ciphertexts, this->pk);
#ifdef VERBOSE_HE
cout << "Tree-wise sum of ciphertexts with "
<< 1e-9 * ciphertexts.get_length() << " GB" << endl;
#endif
this->timers["Exchanging ciphertexts"].start();
tree_sum.run(this->c, P);
tree_sum.run(commitments, P);

View File

@@ -13,9 +13,12 @@ template <class T>
class ArgIter
{
vector<int>::const_iterator it;
vector<int>::const_iterator end;
public:
ArgIter(const vector<int>::const_iterator it) : it(it)
ArgIter(const vector<int>::const_iterator it,
const vector<int>::const_iterator end) :
it(it), end(end)
{
}
@@ -27,8 +30,10 @@ public:
ArgIter<T> operator++()
{
auto res = it;
it += T::n;
return res;
it += T(res).n;
if (it > end)
throw runtime_error("wrong number of args");
return {res, end};
}
bool operator!=(const ArgIter<T>& other)
@@ -46,18 +51,16 @@ public:
ArgList(const vector<int>& args) :
args(args)
{
if (args.size() % T::n != 0)
throw runtime_error("wrong number of args");
}
ArgIter<T> begin()
{
return args.begin();
return {args.begin(), args.end()};
}
ArgIter<T> end()
{
return args.end();
return {args.end(), args.end()};
}
};
@@ -81,11 +84,12 @@ public:
}
};
class InputArgList : public ArgList<InputArgs>
template<class T>
class InputArgListBase : public ArgList<T>
{
public:
InputArgList(const vector<int>& args) :
ArgList<InputArgs>(args)
InputArgListBase(const vector<int>& args) :
ArgList<T>(args)
{
}
@@ -97,7 +101,55 @@ public:
return res;
}
int n_input_bits()
{
int res = 0;
for (auto x : *this)
res += x.n_bits;
return res;
}
int n_interactive_inputs_from_me(int my_num);
};
class InputArgList : public InputArgListBase<InputArgs>
{
public:
InputArgList(const vector<int>& args) :
InputArgListBase<InputArgs>(args)
{
}
};
class InputVecArgs
{
public:
int from;
int n;
int& n_bits;
int& n_shift;
int params[2];
vector<int> dest;
InputVecArgs(vector<int>::const_iterator it) : n_bits(params[0]), n_shift(params[1])
{
n = *it++;
n_bits = n - 3;
n_shift = *it++;
from = *it++;
dest.resize(n);
for (int i = 0; i < n_bits; i++)
dest[i] = *it++;
}
};
class InputVecArgList : public InputArgListBase<InputVecArgs>
{
public:
InputVecArgList(const vector<int>& args) :
InputArgListBase<InputVecArgs>(args)
{
}
};
#endif /* GC_ARGTUPLES_H_ */

View File

@@ -12,9 +12,8 @@
namespace GC
{
int FakeSecret::default_length = 128;
ostream& FakeSecret::out = cout;
SwitchableOutput FakeSecret::out;
const int FakeSecret::default_length;
void FakeSecret::load_clear(int n, const Integer& x)
{
@@ -60,12 +59,9 @@ void FakeSecret::store_clear_in_dynamic(Memory<DynamicType>& mem,
void FakeSecret::ands(Processor<FakeSecret>& processor,
const vector<int>& regs)
{
processor.check_args(regs, 4);
for (size_t i = 0; i < regs.size(); i += 4)
processor.S[regs[i + 1]] = processor.S[regs[i + 2]].a & processor.S[regs[i + 3]].a;
processor.ands(regs);
}
void FakeSecret::trans(Processor<FakeSecret>& processor, int n_outputs,
const vector<int>& args)
{
@@ -82,15 +78,13 @@ FakeSecret FakeSecret::input(GC::Processor<FakeSecret>& processor, const InputAr
return input(args.from, processor.get_input(args.params), args.n_bits);
}
FakeSecret FakeSecret::input(int from, const int128& input, int n_bits)
FakeSecret FakeSecret::input(int from, word input, int n_bits)
{
(void)from;
(void)n_bits;
FakeSecret res;
res.a = ((__uint128_t)input.get_upper() << 64) + input.get_lower();
if (res.a > ((__uint128_t)1 << n_bits))
if (n_bits < 64 and input > word(1) << n_bits)
throw out_of_range("input too large");
return res;
return input;
}
void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
@@ -99,7 +93,7 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
if (repeat)
return andrs(n, x, y);
else
throw runtime_error("call static FakeSecret::ands()");
*this = BitVec(x & y).mask(n);
}
} /* namespace GC */

View File

@@ -12,8 +12,14 @@
#include "GC/ArgTuples.h"
#include "Math/gf2nlong.h"
#include "Tools/SwitchableOutput.h"
#include "Processor/DummyProtocol.h"
#include "Protocols/FakePrep.h"
#include "Protocols/FakeMC.h"
#include "Protocols/FakeProtocol.h"
#include "Protocols/FakeInput.h"
#include "Protocols/ShareInterface.h"
#include <random>
#include <fstream>
@@ -26,21 +32,41 @@ class Processor;
template <class T>
class Machine;
class FakeSecret
class FakeSecret : public ShareInterface, public BitVec
{
__uint128_t a;
public:
typedef FakeSecret DynamicType;
typedef Memory<FakeSecret> DynamicMemory;
typedef BitVec mac_key_type;
typedef BitVec clear;
typedef BitVec open_type;
typedef FakeSecret part_type;
typedef FakeSecret small_type;
typedef NoShare bit_type;
typedef FakePrep<FakeSecret> LivePrep;
typedef FakeMC<FakeSecret> MC;
typedef MC MAC_Check;
typedef MC Direct_MC;
typedef FakeProtocol<FakeSecret> Protocol;
typedef FakeInput<FakeSecret> Input;
static string type_string() { return "fake secret"; }
static string phase_name() { return "Faking"; }
static int default_length;
static const int default_length = 64;
typedef ostream& out_type;
static ostream& out;
static const bool is_real = true;
static const bool actual_inputs = true;
static SwitchableOutput out;
static DataFieldType field_type() { return DATA_GF2; }
static MC* new_mc(mac_key_type key) { return new MC(key); }
static void store_clear_in_dynamic(Memory<DynamicType>& mem,
const vector<GC::ClearWriteAccess>& accesses);
@@ -59,6 +85,11 @@ public:
static void inputb(T& processor, const vector<int>& args)
{ processor.input(args); }
template <class T>
static void inputb(T& processor, ArithmeticProcessor&, const vector<int>& args)
{ processor.input(args); }
template <class T, class U>
static void inputbvec(T&, U&, const vector<int>&) { throw not_implemented(); }
template <class T>
static void reveal_inst(T& processor, const vector<int>& args)
{ processor.reveal(args); }
@@ -69,11 +100,13 @@ public:
static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; }
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, const int128& input, int n_bits);
static FakeSecret input(int from, word input, int n_bits);
FakeSecret() : a(0) {}
FakeSecret(const Integer& x) : a(x.get()) {}
FakeSecret(__uint128_t x) : a(x) {}
static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; }
FakeSecret() {}
template <class T>
FakeSecret(T other) : BitVec(other) {}
__uint128_t operator>>(const FakeSecret& other) const { return a >> other.a; }
__uint128_t operator<<(const FakeSecret& other) const { return a << other.a; }
@@ -90,17 +123,19 @@ public:
void bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const;
template <class T>
void xor_(int n, const FakeSecret& x, const T& y) { (void)n; a = x.a ^ y.a; }
void xor_(int n, const FakeSecret& x, const T& y)
{ *this = BitVec(x.a ^ y.a).mask(n); }
void and_(int n, const FakeSecret& x, const FakeSecret& y, bool repeat);
void andrs(int n, const FakeSecret& x, const FakeSecret& y) { (void)n; a = x.a * y.a; }
void andrs(int n, const FakeSecret& x, const FakeSecret& y)
{ *this = BitVec(x.a * (y.a & 1)).mask(n); }
void invert(int, const FakeSecret& x) { *this = ~x.a; }
void invert(int n, const FakeSecret& x) { *this = BitVec(~x.a).mask(n); }
void random_bit() { a = random() % 2; }
void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; }
int size() { return -1; }
void invert(FakeSecret) { throw not_implemented(); }
};
} /* namespace GC */

View File

@@ -65,6 +65,7 @@ enum
STMSBI = 0x243,
MOVSB = 0x244,
INPUTB = 0x246,
INPUTBVEC = 0x247,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
// write to clear

View File

@@ -67,7 +67,7 @@ unsigned Instruction::get_mem(RegType reg_type) const
}
break;
default:
return BaseInstruction::get_mem(reg_type, MAX_SECRECY_TYPE);
return BaseInstruction::get_mem(reg_type);
}
return 0;

View File

@@ -58,6 +58,7 @@ public:
void stop_timer() { timer[0].stop(); }
void reset_timer() { timer[0].reset(); }
void run_tapes(const vector<int>& args);
void run_tape(int thread_number, int tape_number, int arg);
void join_tape(int thread_numer);
};

View File

@@ -61,8 +61,8 @@ template <class T>
template <class U>
void Memories<T>::reset(const U& program)
{
MS.resize_min(*program.direct_mem(SBIT), "memory");
MC.resize_min(*program.direct_mem(CBIT), "memory");
MS.resize_min(program.direct_mem(SBIT), "memory");
MC.resize_min(program.direct_mem(CBIT), "memory");
}
template <class T>
@@ -70,7 +70,7 @@ template <class U>
void Machine<T>::reset(const U& program)
{
Memories<T>::reset(program);
MI.resize_min(*program.direct_mem(INT), "memory");
MI.resize_min(program.direct_mem(INT), "memory");
}
template <class T>
@@ -78,12 +78,20 @@ template <class U, class V>
void Machine<T>::reset(const U& program, V& MD)
{
reset(program);
MD.resize_min(*program.direct_mem(DYN_SBIT), "dynamic memory");
MD.resize_min(program.direct_mem(DYN_SBIT), "dynamic memory");
#ifdef DEBUG_MEMORY
cerr << "reset dynamic mem to " << program.direct_mem(DYN_SBIT) << endl;
#endif
}
template<class T>
void Machine<T>::run_tapes(const vector<int>& args)
{
assert(args.size() % 3 == 0);
for (unsigned i = 0; i < args.size(); i++)
run_tape(args[i], args[i + 1], args[i + 2]);
}
template<class T>
void Machine<T>::run_tape(int thread_number, int tape_number, int arg)
{

View File

@@ -6,15 +6,23 @@
#ifndef GC_NOSHARE_H_
#define GC_NOSHARE_H_
#include "BMR/Register.h"
#include "Processor/DummyProtocol.h"
#include "Tools/SwitchableOutput.h"
class InputArgs;
class ArithmeticProcessor;
namespace GC
{
template<class T> class Processor;
class Clear;
class NoValue : public ValueInterface
{
public:
typedef NoValue Scalar;
const static int n_bits = 0;
const static int MAX_N_BITS = 0;
@@ -28,6 +36,11 @@ public:
return 0;
}
static int length()
{
return 0;
}
static void fail()
{
throw runtime_error("VM does not support binary circuits");
@@ -43,9 +56,13 @@ public:
int operator<<(int) const { fail(); return 0; }
void operator+=(int) { fail(); }
bool operator!=(NoValue) const { fail(); return 0; }
bool get_bit(int) { fail(); return 0; }
void randomize(PRNG&) { fail(); }
void invert() { fail(); }
};
inline ostream& operator<<(ostream& o, NoValue)
@@ -53,18 +70,11 @@ inline ostream& operator<<(ostream& o, NoValue)
return o;
}
template<class T>
inline bool operator!=(const T&, NoValue&)
{
NoValue::fail();
return true;
}
class NoShare : public Phase
class NoShare
{
public:
typedef DummyMC<NoShare> MC;
typedef DummyProtocol Protocol;
typedef DummyProtocol<NoShare> Protocol;
typedef NotImplementedInput<NoShare> Input;
typedef DummyLivePrep<NoShare> LivePrep;
typedef DummyMC<NoShare> MAC_Check;
@@ -83,6 +93,8 @@ public:
static const bool expensive_triples = false;
static const bool is_real = false;
static SwitchableOutput out;
static MC* new_mc(mac_key_type)
{
return new MC;
@@ -118,13 +130,16 @@ public:
NoValue::fail();
}
static void inputb(Processor<NoShare>&, const vector<int>&) { fail(); }
static void inputb(Processor<NoShare>&, ArithmeticProcessor&, const vector<int>&) { fail(); }
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
static void xors(Processor<NoShare>&, const vector<int>&) { fail(); }
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); }
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static NoShare constant(GC::Clear, int, mac_key_type) { fail(); return {}; }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
NoShare() {}
@@ -146,6 +161,8 @@ public:
void operator^=(NoShare) { fail(); }
NoShare operator+(const NoShare&) const { fail(); return {}; }
NoShare operator-(NoShare) const { fail(); return 0; }
NoShare operator*(NoValue) const { fail(); return 0; }
NoShare operator+(int) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; }
@@ -153,6 +170,8 @@ public:
NoShare lsb() const { fail(); return {}; }
NoShare get_bit(int) const { fail(); return {}; }
void invert(int, NoShare) { fail(); }
};
} /* namespace GC */

View File

@@ -20,17 +20,6 @@ using namespace std;
namespace GC
{
class ExecutionStats : public map<int, size_t>
{
public:
ExecutionStats& operator+=(const ExecutionStats& other)
{
for (auto it : other)
(*this)[it.first] += it.second;
return *this;
}
};
template <class T>
class Processor : public ::ProcessorBase, public GC::RuntimeBranching
{
@@ -53,8 +42,6 @@ public:
Memory<Clear> C;
Memory<Integer> I;
ExecutionStats stats;
Timer xor_timer;
Processor(Machine<T>& machine);
@@ -102,9 +89,12 @@ public:
void ands(const vector<int>& args) { and_(args, false); }
void input(const vector<int>& args);
void reveal(const vector<int>& args);
void inputb(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num);
void inputbvec(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num);
void reveal(const ::BaseInstruction& instruction);
void reveal(const vector<int>& args);
void print_reg(int reg, int n, int size);
void print_reg_plain(Clear& value);

View File

@@ -79,6 +79,8 @@ template<class U>
U GC::Processor<T>::get_long_input(const int* params,
ProcessorBase& input_proc, bool interactive)
{
if (not T::actual_inputs)
return {};
U res = input_proc.get_input<FixInput_<U>>(interactive,
&params[1]).items[0];
int n_bits = *params;
@@ -251,8 +253,12 @@ void Processor<T>::and_(const vector<int>& args, bool repeat)
check_args(args, 4);
for (size_t i = 0; i < args.size(); i += 4)
{
assert(args[i] <= T::default_length);
S[args[i+1]].and_(args[i], S[args[i+2]], S[args[i+3]], repeat);
for (int j = 0; j < DIV_CEIL(args[i], T::default_length); j++)
{
int n = min(T::default_length, args[i] - j * T::default_length);
S[args[i + 1] + j].and_(n, S[args[i + 2] + j],
S[args[i + 3] + (repeat ? 0 : j)], repeat);
}
complexity += args[i];
}
}

View File

@@ -48,8 +48,8 @@ class Program
unsigned num_reg(RegType reg_type) const
{ return max_reg[reg_type]; }
const unsigned* direct_mem(RegType reg_type) const
{ return &max_mem[reg_type]; }
unsigned direct_mem(RegType reg_type) const
{ return max_mem[reg_type]; }
template<class T, class U>
BreakType execute(Processor<T>& Proc, U& dynamic_memory, int PC = -1) const;

View File

@@ -60,6 +60,8 @@ public:
typedef NoShare bit_type;
typedef typename T::Input Input;
static string type_string() { return "evaluation secret"; }
static string phase_name() { return T::name(); }
@@ -71,6 +73,8 @@ public:
static const bool is_real = true;
static const bool actual_inputs = T::actual_inputs;
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
static Secret<T> input(Processor<Secret<T>>& processor, const InputArgs& args);
void random(int n_bits, int128 share);
@@ -102,6 +106,10 @@ public:
const vector<int>& args)
{ T::inputb(processor, input_proc, args); }
template<class U>
static void inputbvec(Processor<U>& processor, ProcessorBase& input_proc,
const vector<int>& args)
{ T::inputbvec(processor, input_proc, args); }
template<class U>
static void reveal_inst(Processor<U>& processor, const vector<int>& args)
{ processor.reveal(args); }
@@ -143,6 +151,13 @@ public:
template <class U>
void reveal(size_t n_bits, U& x);
template <class U>
void my_input(U& inputter, BitVec value, int n_bits);
template <class U>
void other_input(U& inputter, int from, int n_bits);
template <class U>
void finalize_input(U& inputter, int from, int n_bits);
int size() const { return registers.size(); }
RegVector& get_regs() { return registers; }
const RegVector& get_regs() const { return registers; }

View File

@@ -57,6 +57,33 @@ Secret<T> Secret<T>::input(party_id_t from, const int128& input, int n_bits)
return res;
}
template<class T>
template<class U>
void GC::Secret<T>::my_input(U& inputter, BitVec value, int n_bits)
{
resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
get_reg(i).my_input(inputter, value.get_bit(i), 1);
}
template<class T>
template<class U>
void GC::Secret<T>::other_input(U& inputter, int from, int n_bits)
{
resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
get_reg(i).other_input(inputter, from);
}
template<class T>
template<class U>
void GC::Secret<T>::finalize_input(U& inputter, int from, int n_bits)
{
resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
get_reg(i).finalize_input(inputter, from, 1);
}
template<class T>
void Secret<T>::random(int n_bits, int128 share)
{

View File

@@ -41,9 +41,8 @@ public:
typedef Memory<U> DynamicMemory;
typedef SwitchableOutput out_type;
static const bool expensive_triples = false;
static const bool is_real = true;
static const bool actual_inputs = true;
static SwitchableOutput out;
@@ -63,6 +62,8 @@ public:
{ inputb(processor, processor, args); }
static void inputb(Processor<U>& processor, ProcessorBase& input_processor,
const vector<int>& args);
static void inputbvec(Processor<U>& processor, ProcessorBase& input_processor,
const vector<int>& args);
static void reveal_inst(Processor<U>& processor, const vector<int>& args);
template<class T>
@@ -75,6 +76,13 @@ public:
void invert(int n, const U& x);
void random_bit();
template<class T>
void my_input(T& inputter, BitVec value, int n_bits);
template<class T>
void other_input(T& inputter, int from, int n_bits = 1);
template<class T>
void finalize_input(T& inputter, int from, int n_bits);
};
template<class U>
@@ -169,6 +177,8 @@ public:
typedef SemiHonestRepSecret small_type;
typedef SemiHonestRepSecret whole_type;
static const bool expensive_triples = false;
static MC* new_mc(mac_key_type) { return new MC; }
SemiHonestRepSecret() {}

View File

@@ -98,6 +98,28 @@ void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
mem[access.address] = access.value;
}
template<class U>
template<class T>
void GC::ShareSecret<U>::my_input(T& inputter, BitVec value, int n_bits)
{
inputter.add_mine(value, n_bits);
}
template<class U>
template<class T>
void GC::ShareSecret<U>::other_input(T& inputter, int from, int)
{
inputter.add_other(from);
}
template<class U>
template<class T>
void GC::ShareSecret<U>::finalize_input(T& inputter, int from,
int n_bits)
{
static_cast<U&>(*this) = inputter.finalize(from, n_bits).mask(n_bits);
}
template<class U>
void ShareSecret<U>::inputb(Processor<U>& processor,
ProcessorBase& input_processor,
@@ -106,25 +128,46 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
auto& party = ShareThread<U>::s();
typename U::Input input(*party.MC, party.DataF, *party.P);
input.reset_all(*party.P);
processor.inputb(input, input_processor, args, party.P->my_num());
}
template<class U>
void ShareSecret<U>::inputbvec(Processor<U>& processor,
ProcessorBase& input_processor,
const vector<int>& args)
{
auto& party = ShareThread<U>::s();
typename U::Input input(*party.MC, party.DataF, *party.P);
input.reset_all(*party.P);
processor.inputbvec(input, input_processor, args, party.P->my_num());
}
template <class T>
void Processor<T>::inputb(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num)
{
InputArgList a(args);
bool interactive = a.n_interactive_inputs_from_me(party.P->my_num()) > 0;
int dl = U::default_length;
complexity += a.n_input_bits();
bool interactive = a.n_interactive_inputs_from_me(my_num) > 0;
int dl = T::default_length;
for (auto x : a)
{
if (x.from == party.P->my_num())
if (x.from == my_num)
{
bigint whole_input = processor.template
get_long_input<bigint>(x.params,
bigint whole_input = get_long_input<bigint>(x.params,
input_processor, interactive);
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
input.add_mine(bigint(whole_input >> (i * dl)).get_si(),
{
auto& res = S[x.dest + i];
res.my_input(input, bigint(whole_input >> (i * dl)).get_si(),
min(dl, x.n_bits - i * dl));
}
}
else
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
input.add_other(x.from);
S[x.dest + i].other_input(input, x.from,
min(dl, x.n_bits - i * dl));
}
if (interactive)
@@ -138,9 +181,51 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
int n_bits = x.n_bits;
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
{
auto& res = processor.S[x.dest + i];
auto& res = S[x.dest + i];
int n = min(dl, n_bits - i * dl);
res = input.finalize(from, n).mask(n);
res.finalize_input(input, from, n);
}
}
}
template <class T>
void Processor<T>::inputbvec(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num)
{
InputVecArgList a(args);
complexity += a.n_input_bits();
bool interactive = a.n_interactive_inputs_from_me(my_num) > 0;
for (auto x : a)
{
if (x.from == my_num)
{
bigint whole_input = get_long_input<bigint>(x.params,
input_processor, interactive);
for (int i = 0; i < x.n_bits; i++)
{
auto& res = S[x.dest[i]];
res.my_input(input, bigint(whole_input >> (i)).get_si() & 1, 1);
}
}
else
for (int i = 0; i < x.n_bits; i++)
S[x.dest[i]].other_input(input, x.from, 1);
}
if (interactive)
cout << "Thank you" << endl;
input.exchange();
for (auto x : a)
{
int from = x.from;
int n_bits = x.n_bits;
for (int i = 0; i < n_bits; i++)
{
auto& res = S[x.dest[i]];
res.finalize_input(input, from, 1);
}
}
}

View File

@@ -57,6 +57,9 @@ public:
void pre_run();
void post_run() { ShareThread<T>::post_run(); }
size_t data_sent()
{ return Thread<T>::data_sent() + this->DataF.data_sent(); }
};
template<class T>

View File

@@ -56,7 +56,7 @@ public:
void join_tape();
void finish();
int n_interactive_inputs_from_me(InputArgList& args);
virtual size_t data_sent();
};
template<class T>

View File

@@ -91,15 +91,17 @@ void Thread<T>::finish()
}
template<class T>
int Thread<T>::n_interactive_inputs_from_me(InputArgList& args)
size_t GC::Thread<T>::data_sent()
{
return args.n_interactive_inputs_from_me(P->my_num());
assert(P);
return P->comm_stats.total_data();
}
} /* namespace GC */
inline int InputArgList::n_interactive_inputs_from_me(int my_num)
template<class T>
inline int InputArgListBase<T>::n_interactive_inputs_from_me(int my_num)
{
int res = 0;
if (ArithmeticProcessor().use_stdin())

View File

@@ -87,10 +87,12 @@ void ThreadMaster<T>::run()
NamedCommStats stats = P->comm_stats;
ExecutionStats exe_stats;
size_t data_sent = P->comm_stats.total_data();
for (auto thread : threads)
{
stats += thread->P->comm_stats;
exe_stats += thread->processor.stats;
data_sent += thread->data_sent();
delete thread;
}
@@ -108,7 +110,8 @@ void ThreadMaster<T>::run()
if (it->second.data > 0)
cerr << it->first << " " << 1e-6 * it->second.data << " MB" << endl;
cerr << "Time = " << timer.elapsed() << endl;
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
}
} /* namespace GC */

View File

@@ -151,6 +151,18 @@ public:
for (auto& reg : this->get_regs())
reg.output(s, human);
}
template <class U>
void my_input(U& inputter, BitVec value, int n_bits)
{
inputter.add_mine(value, n_bits);
}
template <class U>
void finalize_input(U& inputter, int from, int n_bits)
{
*this = inputter.finalize(from, n_bits).mask(n_bits);
}
};
template<int S>

View File

@@ -71,6 +71,7 @@
#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \
X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
X(ANDM, processor.andm(instruction)) \
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
@@ -85,6 +86,7 @@
#define GC_INSTRUCTIONS \
X(INPUTB, T::inputb(PROC, EXTRA)) \
X(INPUTBVEC, T::inputbvec(PROC, PROC, EXTRA)) \
X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \
X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \
X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \
@@ -130,7 +132,7 @@
X(PRINTINT, S0.out << I0) \
X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \
X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \
X(RUN_TAPE, MACH->run_tape(R0, IMM, REG1)) \
X(RUN_TAPE, MACH->run_tapes(EXTRA)) \
X(JOIN_TAPE, MACH->join_tape(R0)) \
X(USE, ) \

59
Machines/emulate.cpp Normal file
View File

@@ -0,0 +1,59 @@
/*
* emulate.cpp
*
*/
#include "Protocols/FakeShare.h"
#include "Processor/Machine.h"
#include "Math/Z2k.h"
#include "Math/gf2n.h"
#include "Processor/RingOptions.h"
#include "Processor/Machine.hpp"
#include "Math/Z2k.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/ShuffleSacrifice.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/FakeShare.hpp"
SwitchableOutput GC::NoShare::out;
int main(int argc, const char** argv)
{
assert(argc > 1);
OnlineOptions online_opts;
Names N(0, 9999, vector<string>({"localhost"}));
ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv);
opt.parse(argc, argv);
string progname;
if (opt.firstArgs.size() > 1)
progname = *opt.firstArgs.at(1);
else if (not opt.lastArgs.empty())
progname = *opt.lastArgs.at(0);
else if (not opt.unknownArgs.empty())
progname = *opt.unknownArgs.at(0);
else
{
string usage;
opt.getUsage(usage);
cerr << usage << endl;
exit(1);
}
switch (ring_opts.R)
{
case 64:
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 128:
Machine<FakeShare<SignedZ2<128>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
default:
cerr << "Not compiled for " << ring_opts.R << "-bit rings" << endl;
}
}

View File

@@ -5,9 +5,8 @@
#include "Protocols/MaliciousRep3Share.h"
#include "Machines/Rep.hpp"
#include "BMR/RealProgramParty.hpp"
#include "Machines/Rep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -3,9 +3,8 @@
*
*/
#include "Machines/ShamirMachine.hpp"
#include "BMR/RealProgramParty.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Math/Z2k.hpp"
int main(int argc, const char** argv)

View File

@@ -15,7 +15,6 @@
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"

View File

@@ -3,9 +3,8 @@
*
*/
#include "Machines/SPDZ.cpp"
#include "BMR/RealProgramParty.hpp"
#include "Machines/SPDZ.hpp"
int main(int argc, const char** argv)
{

View File

@@ -3,9 +3,8 @@
*
*/
#include "Machines/Rep.hpp"
#include "BMR/RealProgramParty.hpp"
#include "Machines/Rep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -14,7 +14,6 @@
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"

View File

@@ -4,12 +4,10 @@
*/
#include "Math/gfp.hpp"
#include "Protocols/ReplicatedMachine.hpp"
#include "Protocols/ReplicatedFieldMachine.hpp"
#include "Machines/Rep.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
ReplicatedMachine<Rep3Share<gfp>, Rep3Share<gf2n>>(argc, argv,
"replicated-field", opt);
ReplicatedFieldMachine<Rep3Share>(argc, argv);
}

View File

@@ -0,0 +1,12 @@
/*
* semi-bmr-party.cpp
*
*/
#include "BMR/RealProgramParty.hpp"
#include "Machines/Semi.hpp"
int main(int argc, const char** argv)
{
RealProgramParty<SemiShare<gf2n_long>>(argc, argv);
}

View File

@@ -3,9 +3,8 @@
*
*/
#include "Machines/ShamirMachine.hpp"
#include "BMR/RealProgramParty.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Math/Z2k.hpp"
int main(int argc, const char** argv)

View File

@@ -19,7 +19,6 @@
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/MAC_Check_Base.hpp"

View File

@@ -19,7 +19,6 @@
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/MAC_Check_Base.hpp"

View File

@@ -57,7 +57,7 @@ include $(wildcard *.d static/*.d)
%.o: %.cpp
$(CXX) -o $@ $< $(CFLAGS) -MMD -MP -c
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x
offline: $(OT_EXE) Check-Offline.x
@@ -193,6 +193,8 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o
mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
emulate.x: GC/FakeSecret.o
semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o
$(LIBSIMPLEOT): SimpleOT/Makefile
$(MAKE) -C SimpleOT

View File

@@ -138,6 +138,8 @@ void write_online_setup(string dirname, const bigint& p)
ofstream outf;
outf.open(ss.str().c_str());
outf << p << endl;
if (!outf.good())
throw file_error("cannot write to " + ss.str());
}
void init_gf2n(int lg2)

View File

@@ -10,6 +10,7 @@
class OnlineOptions;
class bigint;
class PRNG;
class ValueInterface
{
@@ -30,6 +31,8 @@ public:
static int power_of_two(bool, int) { throw not_implemented(); }
void normalize() {}
void randomize_part(PRNG&, int) { throw not_implemented(); }
};
#endif /* MATH_VALUEINTERFACE_H_ */

View File

@@ -117,7 +117,7 @@ public:
Z2<K> operator-(const Z2<K>& other) const;
template <int L>
Z2<K+L> operator*(const Z2<L>& other) const;
Z2<(K > L) ? K : L> operator*(const Z2<L>& other) const;
Z2<K> operator*(bool other) const { return other ? *this : Z2<K>(); }
Z2<K> operator*(int other) const { return *this * Z2<K>(other); }
@@ -171,6 +171,7 @@ public:
void XOR(const Z2& a, const Z2& b);
void randomize(PRNG& G, int n = -1);
void randomize_part(PRNG& G, int n);
void almost_randomize(PRNG& G) { randomize(G); }
void force_to_bit() { throw runtime_error("impossible"); }
@@ -342,9 +343,9 @@ inline Z2<K> Z2<K>::Mul(const Z2<L>& x, const Z2<M>& y)
template <int K>
template <int L>
inline Z2<K+L> Z2<K>::operator*(const Z2<L>& other) const
inline Z2<(K > L) ? K : L> Z2<K>::operator*(const Z2<L>& other) const
{
return Z2<K+L>::Mul(*this, other);
return Z2<(K > L) ? K : L>::Mul(*this, other);
}
template <int K>
@@ -387,6 +388,14 @@ void Z2<K>::randomize(PRNG& G, int n)
normalize();
}
template<int K>
void Z2<K>::randomize_part(PRNG& G, int n)
{
*this = {};
G.get_octets((octet*)a, DIV_CEIL(n, 8));
a[DIV_CEIL(n, 64) - 1] &= uint64_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS);
}
template<int K>
void Z2<K>::pack(octetStream& o, int n) const
{

View File

@@ -41,7 +41,7 @@ template<class T> void generate_prime_setup(string, int, int);
#endif
template<int X, int L>
class gfp_
class gfp_ : public ValueInterface
{
typedef modp_<L> modp_type;
@@ -247,7 +247,7 @@ class gfp_
gfp_ operator^(const gfp_& x) { gfp_ res; res.XOR(*this, x); return res; }
gfp_ operator|(const gfp_& x) { gfp_ res; res.OR(*this, x); return res; }
gfp_ operator<<(int i) const { gfp_ res; res.SHL(*this, i); return res; }
gfp_ operator>>(int i) { gfp_ res; res.SHR(*this, i); return res; }
gfp_ operator>>(int i) const { gfp_ res; res.SHR(*this, i); return res; }
gfp_& operator&=(const gfp_& x) { AND(*this, x); return *this; }
gfp_& operator<<=(int i) { SHL(*this, i); return *this; }

View File

@@ -193,7 +193,8 @@ void gfp_<X, L>::reqbl(int n)
{
if ((int)n > 0 && pr() < bigint(1) << (n-1))
{
cout << "Tape requires prime of bit length " << n << endl;
cerr << "Tape requires prime of bit length " << n << endl;
cerr << "Run with '-lgp " << n << "'" << endl;
throw invalid_params();
}
else if ((int)n < 0)

View File

@@ -450,17 +450,23 @@ void Player::pass_around(octetStream& o, octetStream& to_receive, int offset) co
* size getting in the way
*/
template<class T>
void MultiPlayer<T>::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
void MultiPlayer<T>::Broadcast_Receive_no_stats(vector<octetStream>& o) const
{
if (o.size() != sockets.size())
throw runtime_error("player numbers don't match");
TimeScope ts(comm_stats["Broadcasting"].add(o[player_no]));
for (int i=1; i<nplayers; i++)
{
int send_to = (my_num() + i) % num_players();
int receive_from = (my_num() + num_players() - i) % num_players();
o[my_num()].exchange(sockets[send_to], sockets[receive_from], o[receive_from]);
}
}
template<class T>
void MultiPlayer<T>::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
{
TimeScope ts(comm_stats["Broadcasting"].add(o[player_no]));
Broadcast_Receive_no_stats(o);
if (!donthash)
{ for (int i=0; i<nplayers; i++)
{ hash_update(&ctx,o[i].get_data(),o[i].get_length()); }

View File

@@ -247,6 +247,7 @@ public:
* - Assumes o[player_no] contains the thing broadcast by me
*/
void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const;
void Broadcast_Receive_no_stats(vector<octetStream>& o) const;
// wait for available inputs
void wait_for_available(vector<int>& players, vector<int>& result) const;

View File

@@ -3,23 +3,7 @@
*
*/
#include <OT/TripleMachine.h>
#include "OT/NPartyTripleGenerator.h"
#include "OT/OTTripleSetup.h"
#include "Math/gf2n.h"
#include "Math/Setup.h"
#include "Protocols/Spdz2kShare.h"
#include "Tools/ezOptionParser.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
#include "Math/BitVec.h"
#include "Protocols/fake-stuff.hpp"
#include "Math/Z2k.hpp"
#include <iostream>
#include <fstream>
using namespace std;
#include "MascotParams.h"
MascotParams::MascotParams()
{

View File

@@ -51,18 +51,20 @@ public:
// Both
PRNG local_prng;
Row<T> tmp;
octetStream os;
vector<octetStream> oss;
virtual void consistency_check (vector<octetStream>& os);
void set_coeffs(__m128i* coefficients, PRNG& G, int num_elements) const;
void hash_row(octetStream& os, const Row<T>& row, const __m128i* coefficients);
void hash_row(octet* hash, const Row<T>& row, const __m128i* coefficients);
template<class U>
void hash_row(octetStream& os, const U& row, const __m128i* coefficients);
template<class U>
void hash_row(octet* hash, const U& row, const __m128i* coefficients);
template<class U>
void hash_row(__m128i res[2], const U& row, const __m128i* coefficients);
static void add_mul(__m128i res[2], __m128i a, __m128i b);
};
template <class T>

View File

@@ -10,30 +10,28 @@ template <class T>
void OTVoleBase<T>::evaluate(vector<T>& output, const vector<T>& newReceiverInput) {
const int N1 = newReceiverInput.size() + 1;
output.resize(newReceiverInput.size());
vector<octetStream> os(2);
auto& os = oss;
os.resize(2);
os[0].reset_write_head();
os[1].reset_write_head();
if (this->ot_role & SENDER) {
T extra;
extra.randomize(local_prng);
vector<T> _corr(newReceiverInput);
_corr.push_back(extra);
corr_prime = Row<T>(_corr);
corr_prime.rows = newReceiverInput;
corr_prime.rows.push_back(extra);
for (int i = 0; i < S; ++i)
{
t0[i] = Row<T>(N1);
t0[i].randomize(this->G_sender[i][0]);
t1[i] = Row<T>(N1);
t1[i].randomize(this->G_sender[i][1]);
Row<T> u = corr_prime + t1[i] + t0[i];
u.pack(os[0]);
t0[i].randomize(this->G_sender[i][0], N1);
t1[i].randomize(this->G_sender[i][1], N1);
(corr_prime + t1[i] + t0[i]).pack(os[0]);
}
}
send_if_ot_sender(this->player, os, this->ot_role);
if (this->ot_role & RECEIVER) {
for (int i = 0; i < S; ++i)
{
t[i] = Row<T>(N1);
t[i].randomize(this->G_receiver[i]);
t[i].randomize(this->G_receiver[i], N1);
int choice_bit = this->baseReceiverInput.get_bit(i);
if (choice_bit == 1) {
@@ -84,38 +82,104 @@ void OTVoleBase<T>::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) c
}
template <class T>
void OTVoleBase<T>::hash_row(octetStream& os, const Row<T>& row, const __m128i* coefficients) {
template<class U>
void OTVoleBase<T>::hash_row(octetStream& os, const U& row,
const __m128i* coefficients)
{
octet hash[VOLE_HASH_SIZE] = {0};
this->hash_row(hash, row, coefficients);
os.append(hash, VOLE_HASH_SIZE);
}
template <class T>
void OTVoleBase<T>::hash_row(octet* hash, const Row<T>& row, const __m128i* coefficients) {
int num_blocks = DIV_CEIL(row.size() * T::size(), 16);
os.clear();
for(auto& x : row.rows)
x.pack(os);
os.serialize(int128());
__m128i prods[2];
avx_memzero(prods, sizeof(prods));
template <class U>
void OTVoleBase<T>::hash_row(octet* hash, const U& row,
const __m128i* coefficients)
{
__m128i res[2];
avx_memzero(res, sizeof(res));
for (int i = 0; i < num_blocks; ++i) {
__m128i block;
os.unserialize(block);
mul128(block, coefficients[i], &prods[0], &prods[1]);
res[0] ^= prods[0];
res[1] ^= prods[1];
}
for (int i = 0; i < 2; i++)
res[i] = _mm_setzero_si128();
hash_row(res, row, coefficients);
crypto_generichash(hash, crypto_generichash_BYTES,
(octet*) res, crypto_generichash_BYTES, NULL, 0);
}
template <class T>
template <class U>
void OTVoleBase<T>::hash_row(__m128i res[2], const U& row,
const __m128i* coefficients)
{
auto coeff_base = coefficients;
int num_blocks = DIV_CEIL(row.size() * T::size(), 16);
__m128i buffer[T::size()];
size_t next = 0;
while (next + 16 < row.size())
{
for (int j = 0; j < 16; j++)
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
for (int j = 0; j < T::size(); j++)
add_mul(res, buffer[j], *coefficients++);
}
for (int j = 0; j < 16; j++)
if (next < row.size())
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
for (int j = 0; j < num_blocks % T::size(); j++)
add_mul(res, buffer[j], *coefficients++);
assert(coefficients == coeff_base + num_blocks);
}
template <class T>
void OTVoleBase<T>::add_mul(__m128i res[2], __m128i a, __m128i b)
{
__m128i prods[2];
mul128(a, b, &prods[0], &prods[1]);
res[0] ^= prods[0];
res[1] ^= prods[1];
}
template <>
template <class U>
inline
void OTVoleBase<Z2<128>>::hash_row(__m128i res[2],
const U& row, const __m128i* coefficients)
{
for (size_t i = 0; i < row.size(); i++)
{
__m128i block = int128(row[i].get_limb(1), row[i].get_limb(0)).a;
add_mul(res, block, coefficients[i]);
}
}
template <>
template <class U>
inline
void OTVoleBase<Z2<192>>::hash_row(__m128i res[2], const U& row,
const __m128i* coefficients)
{
__m128i block;
size_t j;
for (j = 0; j < row.size() - 1; j += 2)
{
auto x = row[j];
auto y = row[j + 1];
block = int128(x.get_limb(1), x.get_limb(0)).a;
add_mul(res, block, *coefficients++);
block = int128(y.get_limb(0), x.get_limb(2)).a;
add_mul(res, block, *coefficients++);
block = int128(y.get_limb(2), y.get_limb(1)).a;
add_mul(res, block, *coefficients++);
}
if (j < row.size())
{
auto x = row[j];
block = int128(x.get_limb(1), x.get_limb(0)).a;
add_mul(res, block, *coefficients++);
block = int128(x.get_limb(2)).a;
add_mul(res, block, *coefficients++);
}
}
template <class T>
void OTVoleBase<T>::consistency_check(vector<octetStream>& os) {
PRNG coef_prng_sender;
@@ -144,17 +208,16 @@ void OTVoleBase<T>::consistency_check(vector<octetStream>& os) {
__m128i coefficients[num_blocks];
this->set_coeffs(coefficients, coef_prng_sender, num_blocks);
Row<T> t00(t0.size()), t01(t0.size()), t10(t0.size()), t11(t0.size());
for (int alpha = 0; alpha < S; ++alpha)
{
for (int i = 0; i < n_challenges(); i++)
{
int beta = get_challenge(coef_prng_sender, i);
t00 = t0[alpha] - t0[beta];
t01 = t0[alpha] - t1[beta];
t10 = t1[alpha] - t0[beta];
t11 = t1[alpha] - t1[beta];
auto t00 = t0[alpha] - t0[beta];
auto t01 = t0[alpha] - t1[beta];
auto t10 = t1[alpha] - t0[beta];
auto t11 = t1[alpha] - t1[beta];
this->hash_row(os[0], t00, coefficients);
this->hash_row(os[0], t01, coefficients);
@@ -202,16 +265,16 @@ void OTVoleBase<T>::consistency_check(vector<octetStream>& os) {
int choice_alpha = this->baseReceiverInput.get_bit(alpha);
int choice_beta = this->baseReceiverInput.get_bit(beta);
tmp = t[alpha] - t[beta];
auto diff = t[alpha] - t[beta];
octet* choice_hash = hashes[choice_alpha][choice_beta];
octet diff_t[VOLE_HASH_SIZE] = {0};
this->hash_row(diff_t, tmp, coefficients);
this->hash_row(diff_t, diff, coefficients);
octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta];
octet other_diff[VOLE_HASH_SIZE] = {0};
tmp = u[alpha] - u[beta];
tmp -= t[alpha];
tmp += t[beta];
auto a = u[alpha] - u[beta];
auto b = a - t[alpha];
auto tmp = b + t[beta];
this->hash_row(other_diff, tmp, coefficients);
if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) {

View File

@@ -5,7 +5,8 @@
#include "Math/gf2nlong.h"
#define VOLE_HASH_SIZE crypto_generichash_BYTES
template <class T> class DeferredMinus;
template <class T, class U> class DeferredMinus;
template <class T, class U> class DeferredPlus;
template <class T>
class Row
@@ -20,7 +21,10 @@ public:
Row(const vector<T>& _rows) : rows(_rows) {}
Row(DeferredMinus<T> d) { *this = d; }
template<class U>
Row(DeferredMinus<T, U> d) { *this = d; }
template<class U>
Row(DeferredPlus<T, U> d) { *this = d; }
bool operator==(const Row<T>& other) const;
bool operator!=(const Row<T>& other) const { return not (*this == other); }
@@ -31,23 +35,28 @@ public:
Row<T>& operator*=(const T& other);
Row<T> operator*(const T& other);
Row<T> operator+(const Row<T> & other);
DeferredMinus<T> operator-(const Row<T> & other);
DeferredPlus<T, Row<T>> operator+(const Row<T> & other);
DeferredMinus<T, Row<T>> operator-(const Row<T>& other);
Row<T>& operator=(DeferredMinus<T> d);
template<class U>
Row<T>& operator=(const DeferredMinus<T, U>& d);
template<class U>
Row<T>& operator=(const DeferredPlus<T, U>& d);
Row<T> operator<<(int i) const;
// fine, since elements in vector are allocated contiguously
const void* get_ptr() const { return rows[0].get_ptr(); }
void randomize(PRNG& G);
void randomize(PRNG& G, size_t size);
void pack(octetStream& o) const;
void unpack(octetStream& o);
size_t size() const { return rows.size(); }
const T& operator[](size_t i) const { return rows[i]; }
template <class V>
friend ostream& operator<<(ostream& o, const Row<V>& x);
};
@@ -55,17 +64,67 @@ public:
template <int K>
using Z2kRow = Row<Z2<K>>;
template <class T>
template <class T, class U>
class DeferredMinus
{
public:
const Row<T>& x;
const U& x;
const Row<T>& y;
DeferredMinus(const Row<T>& x, const Row<T>& y) : x(x), y(y)
DeferredMinus(const U& x, const Row<T>& y) : x(x), y(y)
{
assert(x.size() == y.size());
}
size_t size() const
{
return x.size();
}
T operator[](size_t i) const
{
return x[i] - y[i];
}
DeferredPlus<T, DeferredMinus> operator+(const Row<T>& other)
{
return {*this, other};
}
DeferredMinus<T, DeferredMinus> operator-(const Row<T>& other)
{
return {*this, other};
}
};
template <class T, class U>
class DeferredPlus
{
public:
const U& x;
const Row<T>& y;
DeferredPlus(const U& x, const Row<T>& y) : x(x), y(y)
{
assert(x.size() == y.size());
}
size_t size() const
{
return x.size();
}
T operator[](size_t i) const
{
return x[i] + y[i];
}
DeferredPlus<T, DeferredPlus> operator+(const Row<T>& other)
{
return {*this, other};
}
void pack(octetStream& o) const;
};
#endif /* OT_ROW_H_ */

View File

@@ -46,34 +46,46 @@ Row<T> Row<T>::operator *(const T& other)
}
template<class T>
Row<T> Row<T>::operator +(const Row<T>& other)
DeferredPlus<T, Row<T>> Row<T>::operator +(const Row<T>& other)
{
Row<T> res = other;
res += *this;
return res;
return {*this, other};
}
template<class T>
DeferredMinus<T> Row<T>::operator -(const Row<T>& other)
DeferredMinus<T, Row<T>> Row<T>::operator -(const Row<T>& other)
{
return DeferredMinus<T>(*this, other);
return {*this, other};
}
template<class T>
Row<T>& Row<T>::operator=(DeferredMinus<T> d)
template<class U>
Row<T>& Row<T>::operator=(const DeferredMinus<T, U>& d)
{
size_t size = d.x.size();
rows.resize(size);
for (size_t i = 0; i < size; i++)
rows[i] = d.x.rows[i] - d.y.rows[i];
rows[i] = d[i];
return *this;
}
template<class T>
void Row<T>::randomize(PRNG& G)
template<class U>
Row<T>& Row<T>::operator=(const DeferredPlus<T, U>& d)
{
for (size_t i = 0; i < this->size(); i++)
rows[i].randomize(G);
size_t size = d.x.size();
rows.resize(size);
for (size_t i = 0; i < size; i++)
rows[i] = d[i];
return *this;
}
template<class T>
void Row<T>::randomize(PRNG& G, size_t size)
{
rows.clear();
rows.reserve(size);
for (size_t i = 0; i < size; i++)
rows.push_back(G.get<T>());
}
template<class T>
@@ -87,6 +99,14 @@ Row<T> Row<T>::operator<<(int i) const {
return res;
}
template<class T, class U>
void DeferredPlus<T, U>::pack(octetStream& o) const
{
o.store(this->size());
for (size_t i = 0; i < this->size(); i++)
(*this)[i].pack(o);
}
template<class T>
void Row<T>::pack(octetStream& o) const
{
@@ -100,9 +120,10 @@ void Row<T>::unpack(octetStream& o)
{
size_t size;
o.get(size);
this->rows.resize(size);
for (size_t i = 0; i < this->size(); i++)
rows[i].unpack(o);
rows.clear();
rows.reserve(size);
for (size_t i = 0; i < size; i++)
rows.push_back(o.get<T>());
}
template <class V>

View File

@@ -11,6 +11,9 @@ using namespace std;
#include "Math/BitVec.h"
#include "Data_Files.h"
#include "Protocols/Replicated.h"
#include "Protocols/MAC_Check_Base.h"
#include "Processor/Input.h"
class Player;
class DataPositions;
@@ -26,17 +29,21 @@ template<class T> class ShareThread;
}
template<class T>
class DummyMC
class DummyMC : public MAC_Check_Base<T>
{
public:
void POpen(vector<typename T::open_type>&, vector<T>&, Player&)
DummyMC()
{
throw not_implemented();
}
void Check(Player& P)
template<class U>
DummyMC(U, int = 0, int = 0)
{
(void) P;
}
void exchange(const Player&)
{
throw not_implemented();
}
DummyMC<typename T::part_type>& get_part_MC()
@@ -51,7 +58,8 @@ public:
}
};
class DummyProtocol
template<class T>
class DummyProtocol : public ProtocolBase<T>
{
public:
Player& P;
@@ -66,13 +74,11 @@ public:
{
}
template<class T>
void init_mul(SubProcessor<T>* = 0)
{
throw not_implemented();
}
template<class T>
void prepare_mul(const T&, const T&, int = 0)
typename T::clear prepare_mul(const T&, const T&, int = 0)
{
throw not_implemented();
}
@@ -80,10 +86,10 @@ public:
{
throw not_implemented();
}
int finalize_mul(int = 0)
T finalize_mul(int = 0)
{
throw not_implemented();
return 0;
return {};
}
};
@@ -91,6 +97,13 @@ template<class T>
class DummyLivePrep : public Preprocessing<T>
{
public:
static void basic_setup(Player&)
{
}
static void teardown()
{
}
static void fail()
{
throw runtime_error(
@@ -106,6 +119,11 @@ public:
{
}
DummyLivePrep(SubProcessor<T>*, DataPositions& usage) :
Preprocessing<T>(usage)
{
}
void set_protocol(typename T::Protocol&)
{
}
@@ -177,7 +195,7 @@ public:
throw not_implemented();
}
template<class T>
static void input(SubProcessor<T>& proc, vector<int> regs)
static void input(SubProcessor<V>& proc, vector<int> regs, int)
{
(void) proc, (void) regs;
throw not_implemented();
@@ -206,6 +224,14 @@ public:
(void) a, (void) b;
throw not_implemented();
}
static void raw_input(SubProcessor<V>&, vector<int>, int)
{
throw not_implemented();
}
static void input_mixed(SubProcessor<V>&, vector<int>, int, bool)
{
throw not_implemented();
}
};
class NotImplementedOutput

View File

@@ -106,10 +106,12 @@ enum
SQUARE = 0x52,
INV = 0x53,
INPUTMASK = 0x56,
INPUTMASKREG = 0x5C,
PREP = 0x57,
DABIT = 0x58,
EDABIT = 0x59,
SEDABIT = 0x5A,
RANDOMS = 0x5B,
// Input
INPUT = 0x60,
INPUTFIX = 0xF0,
@@ -143,6 +145,7 @@ enum
SHRC = 0x81,
SHLCI = 0x82,
SHRCI = 0x83,
SHRSI = 0x84,
// Branching and comparison
JMP = 0x90,
JMPNZ = 0x91,
@@ -283,13 +286,16 @@ enum
// Register types
enum RegType {
MODP,
GF2N,
INT,
SBIT,
CBIT,
DYN_SBIT,
SINT,
CINT,
SGF2N,
CGF2N,
NONE,
MAX_REG_TYPE,
NONE
};
enum SecrecyType {
@@ -323,6 +329,8 @@ struct TempVars {
class BaseInstruction
{
friend class Program;
protected:
int opcode; // The code
int size; // Vector size
@@ -346,10 +354,10 @@ public:
bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); }
virtual int get_reg_type() const;
bool is_direct_memory_access(SecrecyType sec_type) const;
bool is_direct_memory_access() const;
// Returns the memory size used if applicable and known
unsigned get_mem(RegType reg_type, SecrecyType sec_type) const;
unsigned get_mem(RegType reg_type) const;
// Returns the maximal register used
unsigned get_max_reg(int reg_type) const;

View File

@@ -7,6 +7,7 @@
#include "Processor/IntInput.h"
#include "Processor/FixInput.h"
#include "Processor/FloatInput.h"
#include "Processor/instructions.h"
#include "Exceptions/Exceptions.h"
#include "Tools/time-func.h"
#include "Tools/parse.h"
@@ -102,6 +103,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case MULINT:
case DIVINT:
case CONDPRINTPLAIN:
case INPUTMASKREG:
r[0]=get_int(s);
r[1]=get_int(s);
r[2]=get_int(s);
@@ -199,6 +201,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case ORCI:
case SHLCI:
case SHRCI:
case SHRSI:
case SHLCBI:
case SHRCBI:
case NOTC:
@@ -220,7 +223,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case USE:
case USE_INP:
case USE_EDABIT:
case RUN_TAPE:
case STARTPRIVATEOUTPUT:
case GSTARTPRIVATEOUTPUT:
case STOPPRIVATEOUTPUT:
@@ -261,6 +263,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case INV2M:
case CONDPRINTSTR:
case CONDPRINTSTRB:
case RANDOMS:
r[0]=get_int(s);
n = get_int(s);
break;
@@ -310,6 +313,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case RAWINPUT:
case GRAWINPUT:
case TRUNC_PR:
case RUN_TAPE:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
@@ -444,6 +448,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case ANDRS:
case ANDS:
case INPUTB:
case INPUTBVEC:
case REVEAL:
get_vector(get_int(s), start, s);
break;
@@ -544,15 +549,57 @@ int BaseInstruction::get_reg_type() const
case USE_PREP:
case GUSE_PREP:
case USE_EDABIT:
case RUN_TAPE:
// those use r[] not for registers
return NONE;
case LDI:
case LDMC:
case STMC:
case LDMCI:
case STMCI:
case MOVC:
case ADDC:
case ADDCI:
case SUBC:
case SUBCI:
case SUBCFI:
case MULC:
case MULCI:
case DIVC:
case DIVCI:
case MODC:
case MODCI:
case LEGENDREC:
case DIGESTC:
case INV2M:
case OPEN:
case ANDC:
case XORC:
case ORC:
case ANDCI:
case XORCI:
case ORCI:
case NOTC:
case SHLC:
case SHRC:
case SHLCI:
case SHRCI:
case CONVINT:
return CINT;
default:
if (is_gf2n_instruction())
return GF2N;
{
Instruction tmp;
tmp.opcode = opcode - 0x100;
if (tmp.get_reg_type() == CINT)
return CGF2N;
else
return SGF2N;
}
else if (opcode >> 4 == 0x9)
return INT;
else
return MODP;
return SINT;
}
}
@@ -564,27 +611,35 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
int size_offset = 0;
int size = this->size;
if (opcode == DABIT)
// special treatment for instructions writing to different types
switch (opcode)
{
case DABIT:
if (reg_type == SBIT)
return r[1] + size;
else if (reg_type == MODP)
else if (reg_type == SINT)
return r[0] + size;
else
return 0;
}
else if (opcode == EDABIT or opcode == SEDABIT)
{
case EDABIT:
case SEDABIT:
if (reg_type == SBIT)
skip = 1;
else if (reg_type == MODP)
else if (reg_type == SINT)
return r[0] + size;
else
return 0;
}
else if (get_reg_type() != reg_type)
{
return 0;
break;
case INPUTMASKREG:
if (reg_type == SINT)
return r[0] + size;
else if (reg_type == CINT)
return r[1] + size;
else
return 0;
default:
if (get_reg_type() != reg_type)
return 0;
}
switch (opcode)
@@ -607,6 +662,9 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
return r[0] + start[0] * start[2];
case CONV2DS:
return r[0] + start[0] * start[1];
case OPEN:
skip = 2;
break;
case LDMSD:
case LDMSDI:
skip = 3;
@@ -627,6 +685,20 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
offset = 3;
size_offset = -2;
break;
case INPUTBVEC:
{
int res = 0;
auto it = start.begin();
while (it < start.end())
{
int n = *it - 3;
it += 3;
assert(it + n <= start.end());
for (int i = 0; i < n; i++)
res = max(res, *it++);
}
return res + 1;
}
case ANDM:
case NOTS:
size = DIV_CEIL(n, 64);
@@ -673,16 +745,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
}
inline
unsigned BaseInstruction::get_mem(RegType reg_type, SecrecyType sec_type) const
unsigned BaseInstruction::get_mem(RegType reg_type) const
{
if (get_reg_type() == reg_type and is_direct_memory_access(sec_type))
if (get_reg_type() == reg_type and is_direct_memory_access())
return n + size;
else
return 0;
}
inline
bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const
bool BaseInstruction::is_direct_memory_access() const
{
switch (opcode)
{
@@ -690,12 +762,10 @@ bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const
case STMS:
case GLDMS:
case GSTMS:
return sec_type == SECRET;
case LDMC:
case STMC:
case GLDMC:
case GSTMC:
return sec_type == CLEAR;
case LDMINT:
case STMINT:
case LDMSB:
@@ -731,16 +801,9 @@ __attribute__((always_inline))
#endif
inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
{
Proc.PC+=1;
auto& Procp = Proc.Procp;
auto& Proc2 = Proc.Proc2;
// binary instructions
typedef typename sint::bit_type T;
auto& processor = Proc.Procb;
auto& instruction = *this;
auto& Ci = Proc.get_Ci();
// optimize some instructions
switch (opcode)
{
@@ -772,102 +835,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
for (int i = 0; i < size; i++)
Proc.get_S2_ref(r[0] + i).mul(Proc.read_S2(r[1] + i),Proc.read_C2(r[2] + i));
return;
case LDI:
Proc.temp.assign_ansp(n);
for (int i = 0; i < size; i++)
Proc.write_Cp(r[0] + i,Proc.temp.ansp);
return;
case LDMSI:
for (int i = 0; i < size; i++)
Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(Proc.read_Ci(r[1] + i)));
return;
case LDMS:
for (int i = 0; i < size; i++)
Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(n + i));
return;
case STMSI:
for (int i = 0; i < size; i++)
Proc.machine.Mp.write_S(Proc.read_Ci(r[1] + i), Proc.read_Sp(r[0] + i), Proc.PC);
return;
case STMS:
for (int i = 0; i < size; i++)
Proc.machine.Mp.write_S(n + i, Proc.read_Sp(r[0] + i), Proc.PC);
return;
case ADDC:
for (int i = 0; i < size; i++)
Proc.get_Cp_ref(r[0] + i).add(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i));
return;
case ADDS:
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i));
return;
case ADDM:
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i),Proc.P.my_num(),Proc.MCp.get_alphai());
return;
case ADDCI:
Proc.temp.assign_ansp(n);
for (int i = 0; i < size; i++)
Proc.get_Cp_ref(r[0] + i).add(Proc.temp.ansp,Proc.read_Cp(r[1] + i));
return;
case SUBS:
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i));
return;
case SUBSFI:
Proc.temp.assign_ansp(n);
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).sub(Proc.temp.ansp,Proc.read_Sp(r[1] + i),Proc.P.my_num(),Proc.MCp.get_alphai());
return;
case SUBML:
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i), Proc.read_Cp(r[2] + i),
Proc.P.my_num(), Proc.MCp.get_alphai());
return;
case SUBMR:
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Cp(r[1] + i), Proc.read_Sp(r[2] + i),
Proc.P.my_num(), Proc.MCp.get_alphai());
return;
case MULM:
for (int i = 0; i < size; i++)
Proc.get_Sp_ref(r[0] + i).mul(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i));
return;
case MULC:
for (int i = 0; i < size; i++)
Proc.get_Cp_ref(r[0] + i).mul(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i));
return;
case MULCI:
Proc.temp.assign_ansp(n);
for (int i = 0; i < size; i++)
Proc.get_Cp_ref(r[0] + i).mul(Proc.temp.ansp,Proc.read_Cp(r[1] + i));
return;
case SHRCI:
for (int i = 0; i < size; i++)
Proc.get_Cp_ref(r[0] + i).SHR(Proc.read_Cp(r[1] + i), n);
return;
case TRIPLE:
for (int i = 0; i < size; i++)
Procp.DataF.get_three(DATA_TRIPLE, Proc.get_Sp_ref(r[0] + i),
Proc.get_Sp_ref(r[1] + i), Proc.get_Sp_ref(r[2] + i));
return;
case BIT:
for (int i = 0; i < size; i++)
Procp.DataF.get_one(DATA_BIT, Proc.get_Sp_ref(r[0] + i));
return;
case LDINT:
for (int i = 0; i < size; i++)
Proc.write_Ci(r[0] + i, int(n));
return;
case INCINT:
{
for (int i = 0; i < size; i++)
{
int inc = (i / start[0]) % start[1];
Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1]) + inc * int(n));
}
}
return;
case SHUFFLE:
for (int i = 0; i < size; i++)
Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1] + i));
@@ -877,29 +844,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
swap(Proc.get_Ci_ref(r[0] + i), Proc.get_Ci_ref(r[0] + i + j));
}
return;
case ADDINT:
for (int i = 0; i < size; i++)
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) + Proc.read_Ci(r[2] + i);
return;
case SUBINT:
for (int i = 0; i < size; i++)
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) - Proc.read_Ci(r[2] + i);
return;
case MULINT:
for (int i = 0; i < size; i++)
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) * Proc.read_Ci(r[2] + i);
return;
case DIVINT:
for (int i = 0; i < size; i++)
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) / Proc.read_Ci(r[2] + i);
return;
case CONVINT:
for (int i = 0; i < size; i++)
{
Proc.temp.assign_ansp(Proc.read_Ci(r[1] + i));
Proc.get_Cp_ref(r[0] + i) = Proc.temp.ansp;
}
return;
case CONVMODP:
if (n == 0)
{
@@ -914,9 +858,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
throw Processor_Error(to_string(n) + "-bit conversion impossible; "
"integer registers only have 64 bits");
return;
#define X(NAME, CODE) case NAME: CODE; return;
COMBI_INSTRUCTIONS
#undef X
}
int r[3] = {this->r[0], this->r[1], this->r[2]};
@@ -1275,6 +1216,13 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case GINV:
Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]));
break;
case RANDOMS:
Procp.protocol.randoms_inst(Procp, *this);
return;
case INPUTMASKREG:
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2]));
Proc.write_Cp(r[1], Proc.temp.rrp);
break;
case INPUTMASK:
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n);
if (n == Proc.P.my_num())
@@ -1388,6 +1336,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case GSHRCI:
Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n);
break;
case SHRSI:
Proc.get_Sp_ref(r[0]) = Proc.read_Sp(r[1]) >> n;
break;
case GBITDEC:
for (int j = 0; j < size; j++)
{
@@ -1649,7 +1600,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break;
case RUN_TAPE:
Proc.DataF.skip(
Proc.machine.run_tape(r[0], n, r[1], -1, &Proc.DataF.DataFp,
Proc.machine.run_tapes(start, &Proc.DataF.DataFp,
&Proc.share_thread.DataF));
break;
case JOIN_TAPE:
@@ -1788,8 +1739,41 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
{
unsigned int size = p.size();
Proc.PC=0;
auto& Procp = Proc.Procp;
// binary instructions
typedef typename sint::bit_type T;
auto& processor = Proc.Procb;
auto& Ci = Proc.get_Ci();
while (Proc.PC<size)
{ p[Proc.PC].execute(Proc); }
{
auto& instruction = p[Proc.PC];
auto& r = instruction.r;
auto& n = instruction.n;
auto& start = instruction.start;
auto& size = instruction.size;
#ifdef COUNT_INSTRUCTIONS
Proc.stats[p[Proc.PC].get_opcode()]++;
#endif
Proc.PC++;
switch(instruction.get_opcode())
{
#define X(NAME, PRE, CODE) \
case NAME: { PRE; for (int i = 0; i < size; i++) { CODE; } } break;
ARITHMETIC_INSTRUCTIONS
#undef X
#define X(NAME, CODE) case NAME: CODE; break;
COMBI_INSTRUCTIONS
#undef X
default:
instruction.execute(Proc);
}
}
}
#endif

View File

@@ -17,6 +17,7 @@
#include "GC/Machine.h"
#include "Tools/time-func.h"
#include "Tools/ExecutionStats.h"
#include <vector>
#include <map>
@@ -44,9 +45,6 @@ class Machine : public BaseMachine
// Keep record of used offline data
DataPositions pos;
int tn,numt;
bool usage_unknown;
void load_program(string threadname, string filename);
public:
@@ -71,6 +69,7 @@ class Machine : public BaseMachine
OnlineOptions opts;
atomic<size_t> data_sent;
ExecutionStats stats;
Machine(int my_number, Names& playerNames, string progname,
string memtype, int lg2, bool direct, int opening_sum,
@@ -79,9 +78,12 @@ class Machine : public BaseMachine
const Names& get_N() { return N; }
DataPositions run_tape(int thread_number, int tape_number, int arg,
int line_number, Preprocessing<sint>* prep = 0,
Preprocessing<typename sint::bit_type>* bit_prep = 0);
DataPositions run_tapes(const vector<int> &args, Preprocessing<sint> *prep,
Preprocessing<typename sint::bit_type> *bit_prep);
void fill_buffers(int thread_number, int tape_number,
Preprocessing<sint> *prep,
Preprocessing<typename sint::bit_type> *bit_prep);
DataPositions run_tape(int thread_number, int tape_number, int arg);
DataPositions join_tape(int thread_number);
void run();

View File

@@ -9,6 +9,7 @@
#include "Math/Setup.h"
#include "Tools/mkpath.h"
#include "Tools/Bundle.h"
#include <iostream>
#include <vector>
@@ -22,7 +23,7 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
string progname_str, string memtype, int lg2, bool direct,
int opening_sum, bool receive_threads, int max_broadcast,
bool use_encryption, bool live_prep, OnlineOptions opts)
: my_number(my_number), N(playerNames), tn(0), numt(0), usage_unknown(false),
: my_number(my_number), N(playerNames),
direct(direct), opening_sum(opening_sum),
receive_threads(receive_threads), max_broadcast(max_broadcast),
use_encryption(use_encryption), live_prep(live_prep), opts(opts),
@@ -40,9 +41,12 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
// make directory for outputs if necessary
mkdir_p(PREP_DIR);
auto P = new PlainPlayer(N, 0xF00);
sint::LivePrep::basic_setup(*P);
delete P;
if (opts.live_prep)
{
auto P = new PlainPlayer(N, 0xF00);
sint::LivePrep::basic_setup(*P);
delete P;
}
sint::read_or_generate_mac_key(prep_dir_prefix<sint>(), N, alphapi);
sgf2n::read_or_generate_mac_key(prep_dir_prefix<sgf2n>(), N, alpha2i);
@@ -131,21 +135,29 @@ void Machine<sint, sgf2n>::load_program(string threadname, string filename)
int i = progs.size() - 1;
progs[i].parse(pinp);
pinp.close();
M2.minimum_size(GF2N, progs[i], threadname);
Mp.minimum_size(MODP, progs[i], threadname);
Mi.minimum_size(INT, progs[i], threadname);
M2.minimum_size(SGF2N, CGF2N, progs[i], threadname);
Mp.minimum_size(SINT, CINT, progs[i], threadname);
Mi.minimum_size(NONE, INT, progs[i], threadname);
}
template<class sint, class sgf2n>
DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
int arg, int line_number, Preprocessing<sint>* prep,
DataPositions Machine<sint, sgf2n>::run_tapes(const vector<int>& args,
Preprocessing<sint>* prep, Preprocessing<typename sint::bit_type>* bit_prep)
{
assert(args.size() % 3 == 0);
for (unsigned i = 0; i < args.size(); i += 3)
fill_buffers(args[i], args[i + 1], prep, bit_prep);
DataPositions res(N.num_players());
for (unsigned i = 0; i < args.size(); i += 3)
res.increase(run_tape(args[i], args[i + 1], args[i + 2]));
return res;
}
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
Preprocessing<sint>* prep,
Preprocessing<typename sint::bit_type>* bit_prep)
{
if (size_t(thread_number) >= tinfo.size())
throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size()));
if (size_t(tape_number) >= progs.size())
throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size()));
// central preprocessing
auto usage = progs[tape_number].get_offline_data_used();
if (sint::expensive and prep != 0 and OnlineOptions::singleton.bucket_size == 3)
@@ -203,6 +215,16 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
#endif
}
}
}
template<class sint, class sgf2n>
DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
int arg)
{
if (size_t(thread_number) >= tinfo.size())
throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size()));
if (size_t(tape_number) >= progs.size())
throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size()));
queues[thread_number]->schedule({tape_number, arg, pos});
//printf("Send signal to run program %d in thread %d\n",tape_number,thread_number);
@@ -211,21 +233,10 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
{
if (not opts.live_prep)
{
// only one thread allowed
if (numt>1)
{ cerr << "Line " << line_number << " has " <<
numt << " threads but tape " << tape_number <<
cerr << "Internally called tape " << tape_number <<
" has unknown offline data usage" << endl;
throw invalid_program();
}
else if (line_number == -1)
{
cerr << "Internally called tape " << tape_number <<
" has unknown offline data usage" << endl;
throw invalid_program();
}
throw invalid_program();
}
usage_unknown = true;
return DataPositions(N.num_players());
}
else
@@ -252,43 +263,12 @@ void Machine<sint, sgf2n>::run()
proc_timer.start();
timer[0].start();
bool flag=true;
usage_unknown=false;
int exec=0;
while (flag)
{ inpf >> numt;
if (numt==0)
{ flag=false; }
else
{ for (int i=0; i<numt; i++)
{
// Now load up data
inpf >> tn;
// Cope with passing an integer parameter to a tape
int arg;
if (inpf.get() == ':')
inpf >> arg;
else
arg = 0;
//cerr << "Run scheduled tape " << tn << " in thread " << i << endl;
pos.increase(run_tape(i, tn, arg, exec));
}
// Make sure all terminate before we continue
auto new_pos = join_tape(0);
for (int i=1; i<numt; i++)
{ join_tape(i);
}
if (usage_unknown)
{ // synchronize files
pos = new_pos;
usage_unknown = false;
}
//printf("Finished running line %d\n",exec);
exec++;
}
}
// legacy
int _;
inpf >> _ >> _ >> _;
// run main tape
pos.increase(run_tape(0, 0, 0));
join_tape(0);
print_compiler();
@@ -317,6 +297,16 @@ void Machine<sint, sgf2n>::run()
finish_timer.stop();
#ifdef VERBOSE
cerr << "Memory usage: ";
tinfo[0].print_usage(cerr, Mp.MS, "sint");
tinfo[0].print_usage(cerr, Mp.MC, "cint");
tinfo[0].print_usage(cerr, M2.MS, "sgf2n");
tinfo[0].print_usage(cerr, M2.MS, "cgf2n");
tinfo[0].print_usage(cerr, bit_memories.MS, "sbits");
tinfo[0].print_usage(cerr, bit_memories.MC, "cbits");
tinfo[0].print_usage(cerr, Mi.MC, "regint");
cerr << endl;
for (unsigned int i = 0; i < join_timer.size(); i++)
cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl;
cerr << "Finish timer: " << finish_timer.elapsed() << endl;
@@ -326,6 +316,15 @@ void Machine<sint, sgf2n>::run()
print_timers();
cerr << "Data sent = " << data_sent / 1e6 << " MB" << endl;
PlainPlayer P(N, 0xFFF0);
Bundle<octetStream> bundle(P);
bundle.mine.store(data_sent.load());
P.Broadcast_Receive_no_stats(bundle);
size_t global = 0;
for (auto& os : bundle)
global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB" << endl;
#ifdef VERBOSE
if (opening_sum < N.num_players() && !direct)
cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl;
@@ -374,6 +373,39 @@ void Machine<sint, sgf2n>::run()
pos.print_cost();
#endif
if (not stats.empty())
{
cerr << "Instruction statistics:" << endl;
set<pair<size_t, int>> sorted_stats;
for (auto& x : stats)
{
sorted_stats.insert({x.second, x.first});
}
for (auto& x : sorted_stats)
{
auto opcode = x.second;
auto calls = x.first;
cerr << "\t";
int n_fill = 15;
switch (opcode)
{
#define X(NAME, PRE, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break;
ARITHMETIC_INSTRUCTIONS
#undef X
#define X(NAME, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break;
COMBI_INSTRUCTIONS
#undef X
default:
cerr << hex << setw(5) << showbase << left << opcode;
n_fill -= 5;
cerr << setw(0);
}
for (int i = 0; i < n_fill; i++)
cerr << " ";
cerr << dec << calls << endl;
}
}
#ifndef INSECURE
Data_Files<sint, sgf2n> df(*this);
df.seekg(pos);

View File

@@ -26,7 +26,7 @@ class Memory
public:
CheckVector<T> MS;
vector<typename T::clear> MC;
CheckVector<typename T::clear> MC;
void resize_s(int sz)
{ MS.resize(sz); }
@@ -73,7 +73,8 @@ class Memory
{ (void)start, (void)end; cerr << "Memory protection not activated" << endl; }
#endif
void minimum_size(RegType reg_type, const Program& program, string threadname);
void minimum_size(RegType secret_type, RegType clear_type,
const Program& program, string threadname);
friend ostream& operator<< <>(ostream& s,const Memory<T>& M);
friend istream& operator>> <>(istream& s,Memory<T>& M);

View File

@@ -4,10 +4,13 @@
#include <fstream>
template<class T>
void Memory<T>::minimum_size(RegType reg_type, const Program& program, string threadname)
void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
const Program &program, string threadname)
{
(void) threadname;
const unsigned* sizes = program.direct_mem(reg_type);
unsigned sizes[MAX_SECRECY_TYPE];
sizes[SECRET]= program.direct_mem(secret_type);
sizes[CLEAR] = program.direct_mem(clear_type);
if (sizes[SECRET] > size_s())
{
#ifdef DEBUG_MEMORY

View File

@@ -28,6 +28,9 @@ class thread_info
static void purge_preprocessing(Machine<sint, sgf2n>& machine);
template<class T>
static void print_usage(ostream& o, const vector<T>& regs, string name);
void Sub_Main_Func();
};

View File

@@ -20,6 +20,15 @@
using namespace std;
template<class sint, class sgf2n>
template<class T>
void thread_info<sint, sgf2n>::print_usage(ostream &o,
const vector<T>& regs, string name)
{
if (regs.capacity())
o << name << "=" << regs.capacity() << " ";
}
template<class sint, class sgf2n>
void thread_info<sint, sgf2n>::Sub_Main_Func()
{
@@ -105,7 +114,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
while (flag)
{ // Wait until I have a program to run
wait_timer.start();
auto job = queues->next();
ThreadJob job = queues->next();
program = job.prognum;
wait_timer.stop();
#ifdef DEBUG_THREADS
@@ -279,6 +288,16 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl;
cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl;
cerr << "Register usage: ";
print_usage(cerr, Proc.Procp.get_S(), "sint");
print_usage(cerr, Proc.Procp.get_C(), "cint");
print_usage(cerr, Proc.Proc2.get_S(), "sgf2n");
print_usage(cerr, Proc.Proc2.get_C(), "cgf2n");
print_usage(cerr, Proc.Procb.S, "sbits");
print_usage(cerr, Proc.Procb.C, "cbits");
print_usage(cerr, Proc.get_Ci(), "regint");
cerr << endl;
#endif
// wind down thread by thread
@@ -287,6 +306,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
prep_sent += Proc.Procp.bit_prep.data_sent();
for (auto& x : Proc.Procp.personal_bit_preps)
prep_sent += x->data_sent();
machine.stats += Proc.stats;
delete processor;
machine.data_sent += P.sent + prep_sent;

View File

@@ -163,6 +163,7 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
opt.parse(argc, argv);
vector<string*> allArgs(opt.firstArgs);
allArgs.insert(allArgs.end(), opt.unknownArgs.begin(), opt.unknownArgs.end());
allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end());
string usage;
vector<string> badOptions;

View File

@@ -26,7 +26,7 @@ class Program;
template <class T>
class SubProcessor
{
vector<typename T::clear> C;
CheckVector<typename T::clear> C;
CheckVector<T> S;
DataPositions bit_usage;
@@ -70,11 +70,16 @@ public:
int b);
void conv2ds(const Instruction& instruction);
vector<T>& get_S()
CheckVector<T>& get_S()
{
return S;
}
CheckVector<typename T::clear>& get_C()
{
return C;
}
T& get_S_ref(int i)
{
return S[i];
@@ -132,7 +137,7 @@ public:
template<class sint, class sgf2n>
class Processor : public ArithmeticProcessor
{
int reg_max2,reg_maxp,reg_maxi;
int reg_max2, reg_maxi;
// Data structure used for reading/writing data to/from a socket (i.e. an external party to SPDZ)
octetStream socket_stream;

View File

@@ -117,11 +117,11 @@ string Processor<sint, sgf2n>::get_filename(const char* prefix, bool use_number)
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::reset(const Program& program,int arg)
{
reg_max2 = program.num_reg(GF2N);
reg_maxp = program.num_reg(MODP);
reg_maxi = program.num_reg(INT);
Proc2.resize(reg_max2);
Procp.resize(reg_maxp);
Proc2.get_S().resize(program.num_reg(SGF2N));
Proc2.get_C().resize(program.num_reg(CGF2N));
Procp.get_S().resize(program.num_reg(SINT));
Procp.get_C().resize(program.num_reg(CINT));
Ci.resize(reg_maxi);
this->arg = arg;
Procb.reset(program);

View File

@@ -11,6 +11,8 @@
#include <fstream>
using namespace std;
#include "Tools/ExecutionStats.h"
class ProcessorBase
{
// Stack
@@ -24,6 +26,8 @@ protected:
int arg;
public:
ExecutionStats stats;
void pushi(long x) { stacki.push(x); }
void popi(long& x) { x = stacki.top(); stacki.pop(); }

View File

@@ -10,8 +10,7 @@ void Program::compute_constants()
for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++)
{
max_reg[reg_type] = 0;
for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++)
max_mem[reg_type][sec_type] = 0;
max_mem[reg_type] = 0;
}
for (unsigned int i=0; i<p.size(); i++)
{
@@ -21,13 +20,10 @@ void Program::compute_constants()
{
max_reg[reg_type] = max(max_reg[reg_type],
p[i].get_max_reg(reg_type));
for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++)
max_mem[reg_type][sec_type] = max(max_mem[reg_type][sec_type],
p[i].get_mem(RegType(reg_type), SecrecyType(sec_type)));
max_mem[reg_type] = max(max_mem[reg_type],
p[i].get_mem(RegType(reg_type)));
}
}
max_mem[INT][SECRET] = 0;
}
void Program::parse(istream& s)

View File

@@ -21,7 +21,7 @@ class Program
unsigned max_reg[MAX_REG_TYPE];
// Memory size used directly
unsigned max_mem[MAX_REG_TYPE][MAX_SECRECY_TYPE];
unsigned max_mem[MAX_REG_TYPE];
// True if program contains variable-sized loop
bool unknown_usage;
@@ -45,7 +45,7 @@ class Program
int num_reg(RegType reg_type) const
{ return max_reg[reg_type]; }
const unsigned* direct_mem(RegType reg_type) const
unsigned direct_mem(RegType reg_type) const
{ return max_mem[reg_type]; }
friend ostream& operator<<(ostream& s,const Program& P);

85
Processor/instructions.h Normal file
View File

@@ -0,0 +1,85 @@
/*
* instructions.h
*
*/
#ifndef PROCESSOR_INSTRUCTIONS_H_
#define PROCESSOR_INSTRUCTIONS_H_
#define ARITHMETIC_INSTRUCTIONS \
X(LDI, auto dest = &Procp.get_C()[r[0]]; typename sint::clear tmp = int(n), \
*dest++ = tmp) \
X(LDMS, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.machine.Mp.MS[n], \
*dest++ = *source++) \
X(STMS, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.machine.Mp.MS[n], \
*dest++ = *source++) \
X(LDMSI, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
*dest++ = Proc.machine.Mp.read_S(*source++)) \
X(STMSI, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
Proc.machine.Mp.write_S(*dest++, *source++)) \
X(MOVS, auto dest = &Procp.get_S()[r[0]]; auto source = &Procp.get_S()[r[1]], \
*dest++ = *source++) \
X(ADDS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = &Procp.get_S()[r[2]], \
*dest++ = *op1++ + *op2++) \
X(ADDM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ + sint::constant(*op2++, Proc.P.my_num(), Procp.MC.get_alphai())) \
X(ADDC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ + *op2++) \
X(ADDCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
typename sint::clear op2 = int(n), \
*dest++ = *op1++ + op2) \
X(SUBS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = &Procp.get_S()[r[2]], \
*dest++ = *op1++ - *op2++) \
X(SUBSFI, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = sint::constant(int(n), Proc.P.my_num(), Procp.MC.get_alphai()), \
*dest++ = op2 - *op1++) \
X(SUBML, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ - sint::constant(*op2++, Proc.P.my_num(), Procp.MC.get_alphai())) \
X(SUBMR, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
auto op2 = &Procp.get_S()[r[2]], \
*dest++ = sint::constant(*op1++, Proc.P.my_num(), Procp.MC.get_alphai()) - *op2++) \
X(SUBC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ - *op2++) \
X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ * *op2++) \
X(MULC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ * *op2++) \
X(MULCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
typename sint::clear op2 = int(n), \
*dest++ = *op1++ * op2) \
X(SHRCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]], \
*dest++ = *op1++ >> n) \
X(TRIPLE, auto a = &Procp.get_S()[r[0]]; auto b = &Procp.get_S()[r[1]]; \
auto c = &Procp.get_S()[r[2]], \
Procp.DataF.get_three(DATA_TRIPLE, *a++, *b++, *c++)) \
X(BIT, auto dest = &Procp.get_S()[r[0]], \
Procp.DataF.get_one(DATA_BIT, *dest++)) \
X(LDINT, auto dest = &Proc.get_Ci()[r[0]], \
*dest++ = int(n)) \
X(ADDINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
auto op2 = &Proc.get_Ci()[r[2]], \
*dest++ = *op1++ + *op2++) \
X(SUBINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
auto op2 = &Proc.get_Ci()[r[2]], \
*dest++ = *op1++ - *op2++) \
X(MULINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
auto op2 = &Proc.get_Ci()[r[2]], \
*dest++ = *op1++ * *op2++) \
X(DIVINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
auto op2 = &Proc.get_Ci()[r[2]], \
*dest++ = *op1++ / *op2++) \
X(INCINT, auto dest = &Proc.get_Ci()[r[0]]; auto base = Proc.get_Ci()[r[1]], \
int inc = (i / start[0]) % start[1]; *dest++ = base + inc * int(n)) \
X(CONVINT, auto dest = &Procp.get_C()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
*dest++ = *source++) \
#endif /* PROCESSOR_INSTRUCTIONS_H_ */

View File

@@ -2,11 +2,7 @@ import ml
import util
import math
if 'trunc_pr' in program.args:
program.use_trunc_pr = True
if 'split' in program.args:
program.use_split(3)
program.options_from_args()
program.options.cisc = True
try:
@@ -72,7 +68,7 @@ else:
opt = ml.Optimizer()
opt.layers = layers
for layer in layers:
layer.input_from(0, raw='raw' in program.args)
layer.input_from(0)
layers[0].X.input_from(1)
start_timer(1)
opt.forward(1)

View File

@@ -3,13 +3,7 @@ from Compiler import ml
debug = False
program.use_edabit(True)
program.use_trunc_pr = True
if 'split' in program.args:
program.use_split(3)
if 'split2' in program.args:
program.use_split(2)
program.options_from_args()
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)
@@ -29,9 +23,19 @@ dense = ml.Dense(12800, dim, 1)
layers = [dense, ml.Output(12800, debug=debug, approx='approx' in program.args)]
sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False)
sgd.reset([X_normal, X_pos])
sgd.run(batch_size=batch)
# @for_range(1000)
# def _(i):
# sgd.backward()
if not ('forward' in program.args or 'backward' in program.args):
sgd.reset([X_normal, X_pos])
sgd.run(batch_size=batch)
if 'forward' in program.args:
@for_range(1000)
def _(i):
sgd.forward(N=batch)
if 'backward' in program.args:
b = regint.Array(batch)
b.assign(regint.inc(batch))
@for_range(1000)
def _(i):
sgd.backward(batch=b)

View File

@@ -97,3 +97,6 @@ test(c[0], 0)
test(c[1], 1)
test(c[2], 1)
test(c[3], 0)
test(sbit(sbits(3)), 1)
test(sbits(sbit(1)), 1)

51
Protocols/FakeInput.h Normal file
View File

@@ -0,0 +1,51 @@
/*
* FakeProtocol.h
*
*/
#ifndef PROTOCOLS_FAKEINPUT_H_
#define PROTOCOLS_FAKEINPUT_H_
#include "Replicated.h"
#include "Processor/Input.h"
template<class T>
class FakeInput : public InputBase<T>
{
PointerVector<T> results;
public:
FakeInput(SubProcessor<T>&, typename T::MAC_Check&)
{
}
void reset(int)
{
results.clear();
}
void add_mine(const typename T::open_type& x, int = -1)
{
results.push_back(x);
}
void add_other(int)
{
}
void send_mine()
{
}
T finalize_mine()
{
return results.next();
}
void finalize_other(int, T&, octetStream&, int = -1)
{
throw not_implemented();
}
};
#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */

31
Protocols/FakeMC.h Normal file
View File

@@ -0,0 +1,31 @@
/*
* FakeMC.h
*
*/
#ifndef PROTOCOLS_FAKEMC_H_
#define PROTOCOLS_FAKEMC_H_
#include "MAC_Check_Base.h"
template<class T>
class FakeMC : public MAC_Check_Base<T>
{
public:
FakeMC(T, int = 0, int = 0)
{
}
void exchange(const Player&)
{
for (auto& x : this->secrets)
this->values.push_back(x);
}
FakeMC& get_part_MC()
{
return *this;
}
};
#endif /* PROTOCOLS_FAKEMC_H_ */

82
Protocols/FakePrep.h Normal file
View File

@@ -0,0 +1,82 @@
/*
* FakePrep.h
*
*/
#ifndef PROTOCOLS_FAKEPREP_H_
#define PROTOCOLS_FAKEPREP_H_
#include "ReplicatedPrep.h"
template<class T>
class FakePrep : public BufferPrep<T>
{
SeededPRNG G;
public:
FakePrep(SubProcessor<T>*, DataPositions& usage) :
BufferPrep<T>(usage)
{
}
FakePrep(DataPositions& usage, GC::ShareThread<T>&) :
BufferPrep<T>(usage)
{
}
FakePrep(DataPositions& usage, int = 0) :
BufferPrep<T>(usage)
{
}
void set_protocol(typename T::Protocol&)
{
}
void buffer_triples()
{
for (int i = 0; i < 1000; i++)
{
auto a = G.get<T>();
auto b = G.get<T>();
this->triples.push_back({{a, b, a * b}});
}
}
void buffer_squares()
{
for (int i = 0; i < 1000; i++)
{
auto a = G.get<T>();
this->squares.push_back({{a, a * a}});
}
}
void buffer_inverses()
{
for (int i = 0; i < 1000; i++)
{
auto a = G.get<T>();
T aa;
aa.invert(a);
this->inverses.push_back({{a, aa}});
}
}
void buffer_bits()
{
for (int i = 0; i < 1000; i++)
{
this->bits.push_back(G.get_bit());
}
}
void get_dabit_no_count(T& a, typename T::bit_type& b)
{
auto bit = G.get_bit();
a = bit;
b = bit;
}
};
#endif /* PROTOCOLS_FAKEPREP_H_ */

63
Protocols/FakeProtocol.h Normal file
View File

@@ -0,0 +1,63 @@
/*
* FakeProtocol.h
*
*/
#ifndef PROTOCOLS_FAKEPROTOCOL_H_
#define PROTOCOLS_FAKEPROTOCOL_H_
#include "Replicated.h"
template<class T>
class FakeProtocol : public ProtocolBase<T>
{
PointerVector<T> results;
SeededPRNG G;
public:
Player& P;
FakeProtocol(Player& P) : P(P)
{
}
void init_mul(SubProcessor<T>*)
{
results.clear();
}
typename T::clear prepare_mul(const T& x, const T& y, int = -1)
{
results.push_back(x * y);
return {};
}
void exchange()
{
}
T finalize_mul(int = -1)
{
return results.next();
}
void randoms(T& res, int n_bits)
{
res.randomize_part(G, n_bits);
}
int get_n_relevant_players()
{
return 1;
}
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc)
{
for (size_t i = 0; i < regs.size(); i += 4)
for (int l = 0; l < size; l++)
proc.get_S_ref(regs[i] + l) = proc.get_S_ref(regs[i + 1] + l)
>> regs[i + 3];
}
};
#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */

80
Protocols/FakeShare.h Normal file
View File

@@ -0,0 +1,80 @@
/*
* FakeShare.h
*
*/
#ifndef PROTOCOLS_FAKESHARE_H_
#define PROTOCOLS_FAKESHARE_H_
#include "GC/FakeSecret.h"
#include "ShareInterface.h"
#include "Processor/NoLivePrep.h"
#include "FakeMC.h"
#include "FakeProtocol.h"
#include "FakePrep.h"
#include "FakeInput.h"
template<class T>
class FakeShare : public T, public ShareInterface
{
typedef FakeShare This;
public:
typedef T mac_key_type;
typedef T open_type;
typedef T clear;
typedef FakePrep<This> LivePrep;
typedef FakeMC<This> MAC_Check;
typedef MAC_Check Direct_MC;
typedef FakeInput<This> Input;
typedef ::PrivateOutput<This> PrivateOutput;
typedef FakeProtocol<This> Protocol;
typedef GC::FakeSecret bit_type;
static string type_short()
{
return "emul";
}
static int threshold(int)
{
return 0;
}
static T constant(T value, int = 0, T = 0)
{
return value;
}
FakeShare()
{
}
template<class U>
FakeShare(U other) :
T(other)
{
}
void assign(T value, int = 0, T = 0)
{
*this = value;
}
void add(T a, T b, int = 0, T = {})
{
*this = a + b;
}
void sub(T a, T b, int = 0, T = {})
{
*this = a - b;
}
static void split(vector<bit_type>& dest, const vector<int>& regs, int n_bits,
const This* source, int n_inputs, Player& P);
};
#endif /* PROTOCOLS_FAKESHARE_H_ */

50
Protocols/FakeShare.hpp Normal file
View File

@@ -0,0 +1,50 @@
/*
* FakeShare.cpp
*
*/
#include "FakeShare.h"
#include "Math/Z2k.h"
#include "GC/square64.h"
template<class T>
void FakeShare<T>::split(vector<bit_type>& dest,
const vector<int>& regs, int n_bits, const This* source, int n_inputs,
Player&)
{
assert(n_bits <= 64);
int unit = GC::Clear::N_BITS;
for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++)
{
int start = k * unit;
int m = min(unit, n_inputs - start);
switch (regs.size() / n_bits)
{
case 3:
{
for (int i = 0; i < n_bits; i++)
for (int j = 1; j < 3; j++)
dest.at(regs.at(3 * i + j) + k) = {};
square64 square;
for (int j = 0; j < m; j++)
{
square.rows[j] = (source[j + start]).get_limb(0);
}
square.transpose(m, n_bits);
for (int j = 0; j < n_bits; j++)
{
auto& dest_reg = dest.at(regs.at(3 * j) + k);
dest_reg = square.rows[j];
}
break;
}
default:
not_implemented();
}
}
}

View File

@@ -27,10 +27,20 @@ public:
void buffer_inputs(int player);
};
template<class T>
class RingOnlyBitsFromSquaresPrep : public virtual BufferPrep<T>
{
public:
RingOnlyBitsFromSquaresPrep(SubProcessor<T>* proc, DataPositions& usage);
void buffer_bits();
};
// extra class to avoid recursion
template<class T>
class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
public virtual MalRepRingPrep<T>
public virtual MalRepRingPrep<T>,
public virtual RingOnlyBitsFromSquaresPrep<T>
{
public:
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
@@ -40,12 +50,20 @@ public:
MaliciousRingPrep<T>::set_protocol(protocol);
}
void buffer_triples()
{
MalRepRingPrep<T>::buffer_triples();
}
void buffer_squares()
{
MalRepRingPrep<T>::buffer_squares();
}
void buffer_bits();
void buffer_bits()
{
RingOnlyBitsFromSquaresPrep<T>::buffer_bits();
}
void get_dabit_no_count(T& a, typename T::bit_type& b)
{

Some files were not shown because too many files have changed in this diff Show More