mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Merge remote-tracking branch 'origin/master' into update-prng-seed
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@@ -116,3 +116,10 @@ Thumbs.db
|
||||
|
||||
# Sphinx build
|
||||
_build/
|
||||
|
||||
# environment
|
||||
.env
|
||||
|
||||
# temp doc files
|
||||
doc/readme.md
|
||||
doc/xml
|
||||
|
||||
13
.gitmodules
vendored
13
.gitmodules
vendored
@@ -1,12 +1,15 @@
|
||||
[submodule "SimpleOT"]
|
||||
path = SimpleOT
|
||||
path = deps/SimpleOT
|
||||
url = https://github.com/mkskeller/SimpleOT
|
||||
[submodule "mpir"]
|
||||
path = mpir
|
||||
url = git://github.com/wbhart/mpir.git
|
||||
[submodule "Programs/Circuits"]
|
||||
path = Programs/Circuits
|
||||
url = https://github.com/mkskeller/bristol-fashion
|
||||
[submodule "simde"]
|
||||
path = simde
|
||||
path = deps/simde
|
||||
url = https://github.com/simd-everywhere/simde
|
||||
[submodule "deps/libOTe"]
|
||||
path = deps/libOTe
|
||||
url = https://github.com/mkskeller/softspoken-implementation
|
||||
[submodule "deps/SimplestOT_C"]
|
||||
path = deps/SimplestOT_C
|
||||
url = https://github.com/mkskeller/SimplestOT_C
|
||||
|
||||
@@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) :
|
||||
}
|
||||
cout << "Compiler: " << prev << endl;
|
||||
P = new PlainPlayer(N, 0);
|
||||
Share<gf2n_long>::MAC_Check::setup(*P);
|
||||
if (argc > 4)
|
||||
threshold = atoi(argv[4]);
|
||||
cout << "Threshold for multi-threaded evaluation: " << threshold << endl;
|
||||
@@ -259,7 +260,8 @@ ProgramParty::~ProgramParty()
|
||||
reset();
|
||||
if (P)
|
||||
{
|
||||
cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl;
|
||||
cerr << "Data sent in online phase = " << P->total_comm().sent * 1e-6
|
||||
<< " MB (this party only)" << endl;
|
||||
delete P;
|
||||
}
|
||||
delete[] eval_threads;
|
||||
@@ -281,6 +283,7 @@ FakeProgramParty::~FakeProgramParty()
|
||||
cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes()
|
||||
<< " GB" << endl;
|
||||
#endif
|
||||
Share<gf2n_long>::MAC_Check::teardown();
|
||||
}
|
||||
|
||||
void FakeProgramParty::_compute_prfs_outputs(Key* keys)
|
||||
|
||||
@@ -48,8 +48,6 @@ public:
|
||||
static void inputbvec(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
|
||||
ProcessorBase& input_processor, const vector<int>& args);
|
||||
|
||||
RealGarbleWire(const Register& reg) : PRFRegister(reg) {}
|
||||
|
||||
void garble(PRFOutputs& prf_output, const RealGarbleWire<T>& left,
|
||||
const RealGarbleWire<T>& right);
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ void RealGarbleWire<T>::inputbvec(
|
||||
{
|
||||
GarbleInputter<T> inputter;
|
||||
processor.inputbvec(inputter, input_processor, args,
|
||||
inputter.party.P->my_num());
|
||||
*inputter.party.P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -175,7 +175,7 @@ void GarbleInputter<T>::exchange()
|
||||
assert(party.P != 0);
|
||||
assert(party.MC != 0);
|
||||
auto& protocol = party.shared_proc->protocol;
|
||||
protocol.init_mul(party.shared_proc);
|
||||
protocol.init_mul();
|
||||
for (auto& tuple : tuples)
|
||||
protocol.prepare_mul(tuple.first->mask,
|
||||
T::constant(1, party.P->my_num(), party.mac_key)
|
||||
|
||||
@@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>
|
||||
|
||||
bool one_shot;
|
||||
|
||||
size_t data_sent;
|
||||
|
||||
public:
|
||||
static RealProgramParty& s();
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ RealProgramParty<T>* RealProgramParty<T>::singleton = 0;
|
||||
|
||||
template<class T>
|
||||
RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
garble_processor(garble_machine), dummy_proc({{}, 0})
|
||||
garble_processor(garble_machine), dummy_proc({}, 0)
|
||||
{
|
||||
assert(singleton == 0);
|
||||
singleton = this;
|
||||
@@ -64,7 +64,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
online_opts = {opt, argc, argv, 1000};
|
||||
else
|
||||
online_opts = {opt, argc, argv};
|
||||
assert(not online_opts.interactive);
|
||||
|
||||
online_opts.finalize(opt, argc, argv);
|
||||
this->load(online_opts.progname);
|
||||
@@ -97,8 +96,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
if (online_opts.live_prep)
|
||||
{
|
||||
mac_key.randomize(prng);
|
||||
if (T::needs_ot)
|
||||
BaseMachine::s().ot_setups.push_back({*P, true});
|
||||
prep = new typename T::LivePrep(0, usage);
|
||||
}
|
||||
else
|
||||
@@ -107,10 +104,12 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
|
||||
}
|
||||
|
||||
T::MAC_Check::setup(*P);
|
||||
MC = new typename T::MAC_Check(mac_key);
|
||||
|
||||
garble_processor.reset(program);
|
||||
this->processor.open_input_file(N.my_num(), 0);
|
||||
this->processor.open_input_file(N.my_num(), 0, online_opts.cmd_private_input_file);
|
||||
this->processor.setup_redirection(P->my_num(), 0, online_opts, this->processor.out);
|
||||
|
||||
shared_proc = new SubProcessor<T>(dummy_proc, *MC, *prep, *P);
|
||||
|
||||
@@ -155,7 +154,9 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
while (next != GC::DONE_BREAK);
|
||||
|
||||
MC->Check(*P);
|
||||
data_sent = P->comm_stats.total_data() + prep->data_sent();
|
||||
|
||||
if (online_opts.verbose)
|
||||
P->total_comm().print();
|
||||
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
}
|
||||
@@ -173,7 +174,8 @@ void RealProgramParty<T>::garble()
|
||||
garble_jobs.clear();
|
||||
garble_inputter->reset_all(*P);
|
||||
auto& protocol = *garble_protocol;
|
||||
protocol.init_mul(shared_proc);
|
||||
protocol.init(*prep, shared_proc->MC);
|
||||
protocol.init_mul();
|
||||
|
||||
next = this->first_phase(program, garble_processor, this->garble_machine);
|
||||
|
||||
@@ -181,7 +183,8 @@ void RealProgramParty<T>::garble()
|
||||
protocol.exchange();
|
||||
|
||||
typename T::Protocol second_protocol(*P);
|
||||
second_protocol.init_mul(shared_proc);
|
||||
second_protocol.init(*prep, shared_proc->MC);
|
||||
second_protocol.init_mul();
|
||||
for (auto& job : garble_jobs)
|
||||
job.middle_round(*this, second_protocol);
|
||||
|
||||
@@ -212,7 +215,8 @@ RealProgramParty<T>::~RealProgramParty()
|
||||
delete prep;
|
||||
delete garble_inputter;
|
||||
delete garble_protocol;
|
||||
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
|
||||
garble_machine.print_comm(*this->P, this->P->total_comm());
|
||||
T::MAC_Check::teardown();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -568,6 +568,7 @@ void EvalRegister::inputb(GC::Processor<GC::Secret<EvalRegister> >& processor,
|
||||
octetStream& my_os = oss[party.get_id() - 1];
|
||||
vector<InputAccess> accesses;
|
||||
InputArgList a(args);
|
||||
processor.complexity += a.n_input_bits();
|
||||
for (auto x : a)
|
||||
{
|
||||
accesses.push_back({x , processor});
|
||||
|
||||
@@ -23,6 +23,7 @@ using namespace std;
|
||||
#include "Tools/PointerVector.h"
|
||||
#include "Tools/Bundle.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
#include "Processor/Instruction.h"
|
||||
|
||||
//#define PAD_TO_8(n) (n+8-n%8)
|
||||
#define PAD_TO_8(n) (n)
|
||||
@@ -61,11 +62,13 @@ private:
|
||||
#endif
|
||||
};
|
||||
#else
|
||||
class BaseKeyVector : public vector<Key>
|
||||
class BaseKeyVector : public CheckVector<Key>
|
||||
{
|
||||
typedef CheckVector<Key> super;
|
||||
|
||||
public:
|
||||
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {}
|
||||
void resize(int size) { vector<Key>::resize(size, Key(0)); }
|
||||
BaseKeyVector(int size = 0) : super(size, Key(0)) {}
|
||||
void resize(int size) { super::resize(size, Key(0)); }
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -151,7 +154,7 @@ public:
|
||||
* for pipelining matters.
|
||||
*/
|
||||
|
||||
Register(int n_parties);
|
||||
Register();
|
||||
|
||||
void init(int n_parties);
|
||||
void init(int rfd, int n_parties);
|
||||
@@ -234,6 +237,9 @@ public:
|
||||
template <class T>
|
||||
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
|
||||
template <class T>
|
||||
static void andrsvec(T& processor, const vector<int>& args)
|
||||
{ processor.andrsvec(args); }
|
||||
template <class T>
|
||||
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
|
||||
template <class T>
|
||||
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
|
||||
@@ -277,10 +283,6 @@ public:
|
||||
|
||||
static int threshold(int) { throw not_implemented(); }
|
||||
|
||||
static Register new_reg();
|
||||
static Register tmp_reg() { return new_reg(); }
|
||||
static Register and_reg() { return new_reg(); }
|
||||
|
||||
template<class T>
|
||||
static void store(NoMemory& dest,
|
||||
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
|
||||
@@ -289,6 +291,16 @@ public:
|
||||
static void inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args);
|
||||
|
||||
template<class U>
|
||||
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("convcbit2s not implemented"); }
|
||||
template<class U>
|
||||
static void andm(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("andm not implemented"); }
|
||||
|
||||
static void run_tapes(const vector<int>&)
|
||||
{ throw runtime_error("multi-threading not implemented"); }
|
||||
|
||||
// most BMR phases don't need actual input
|
||||
template<class T>
|
||||
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
|
||||
@@ -298,8 +310,6 @@ public:
|
||||
void other_input(Input&, int) {}
|
||||
|
||||
char get_output() { return 0; }
|
||||
|
||||
ProgramRegister(const Register& reg) : Register(reg) {}
|
||||
};
|
||||
|
||||
class PRFRegister : public ProgramRegister
|
||||
@@ -311,8 +321,6 @@ public:
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const NoMemory& source);
|
||||
|
||||
PRFRegister(const Register& reg) : ProgramRegister(reg) {}
|
||||
|
||||
void op(const PRFRegister& left, const PRFRegister& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
void input(party_id_t from, char input = -1);
|
||||
@@ -388,8 +396,6 @@ public:
|
||||
static void convcbit(Integer& dest, const GC::Clear& source,
|
||||
GC::Processor<GC::Secret<EvalRegister>>& proc);
|
||||
|
||||
EvalRegister(const Register& reg) : ProgramRegister(reg) {}
|
||||
|
||||
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
|
||||
@@ -419,8 +425,6 @@ public:
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const NoMemory& source);
|
||||
|
||||
GarbleRegister(const Register& reg) : ProgramRegister(reg) {}
|
||||
|
||||
void op(const Register& left, const Register& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
void input(party_id_t from, char value = -1);
|
||||
@@ -444,8 +448,6 @@ public:
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const NoMemory& source);
|
||||
|
||||
RandomRegister(const Register& reg) : ProgramRegister(reg) {}
|
||||
|
||||
void randomize();
|
||||
|
||||
void op(const Register& left, const Register& right, Function func);
|
||||
@@ -461,12 +463,6 @@ public:
|
||||
};
|
||||
|
||||
|
||||
inline Register::Register(int n_parties) :
|
||||
garbled_entry(n_parties), external(NO_SIGNAL),
|
||||
mask(NO_SIGNAL), keys(n_parties)
|
||||
{
|
||||
}
|
||||
|
||||
inline void KeyVector::operator=(const KeyVector& other)
|
||||
{
|
||||
resize(other.size());
|
||||
|
||||
@@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
NoOpInputter inputter;
|
||||
int my_num = -1;
|
||||
try
|
||||
{
|
||||
my_num = ProgramParty::s().P->my_num();
|
||||
}
|
||||
catch (exception&)
|
||||
{
|
||||
}
|
||||
processor.inputbvec(inputter, input_processor, args, my_num);
|
||||
processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
{
|
||||
EvalInputter inputter;
|
||||
processor.inputbvec(inputter, input_processor, args,
|
||||
ProgramParty::s().P->my_num());
|
||||
*ProgramParty::s().P);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -9,10 +9,10 @@
|
||||
#include "CommonParty.h"
|
||||
#include "Party.h"
|
||||
|
||||
|
||||
inline Register ProgramRegister::new_reg()
|
||||
inline Register::Register() :
|
||||
garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL),
|
||||
mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties())
|
||||
{
|
||||
return Register(CommonParty::s().get_n_parties());
|
||||
}
|
||||
|
||||
#endif /* BMR_REGISTER_INLINE_H_ */
|
||||
|
||||
@@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty()
|
||||
_received_gc_received = 0;
|
||||
n_received = 0;
|
||||
randomfd = open("/dev/urandom", O_RDONLY);
|
||||
done_filling = false;
|
||||
}
|
||||
|
||||
BaseTrustedParty::~BaseTrustedParty()
|
||||
{
|
||||
close(randomfd);
|
||||
}
|
||||
|
||||
TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :
|
||||
|
||||
@@ -20,7 +20,7 @@ public:
|
||||
vector<SendBuffer> msg_input_masks;
|
||||
|
||||
BaseTrustedParty();
|
||||
virtual ~BaseTrustedParty() {}
|
||||
virtual ~BaseTrustedParty();
|
||||
|
||||
/* From NodeUpdatable class */
|
||||
virtual void NodeReady();
|
||||
@@ -104,7 +104,6 @@ private:
|
||||
void add_all_keys(const Register& reg, bool external);
|
||||
};
|
||||
|
||||
|
||||
inline void BaseTrustedParty::add_keys(const Register& reg)
|
||||
{
|
||||
for(int p = 0; p < get_n_parties(); p++)
|
||||
|
||||
96
CHANGELOG.md
96
CHANGELOG.md
@@ -1,5 +1,101 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.3.6 (May 9, 2023)
|
||||
|
||||
- More extensive benchmarking outputs
|
||||
- Replace MPIR by GMP
|
||||
- Secure reading of edaBits from files
|
||||
- Semi-honest client communication
|
||||
- Back-propagation for average pooling
|
||||
- Parallelized convolution
|
||||
- Probabilistic truncation as in ABY3
|
||||
- More balanced communication in Shamir secret sharing
|
||||
- Avoid unnecessary communication in Dealer protocol
|
||||
- Linear solver using Cholesky decomposition
|
||||
- Accept .py files for compilation
|
||||
- Fixed security bug: proper accounting for random elements
|
||||
|
||||
## 0.3.5 (Feb 16, 2023)
|
||||
|
||||
- Easier-to-use machine learning interface
|
||||
- Integrated compilation-execution facility
|
||||
- Import/export sequential models and parameters from/to PyTorch
|
||||
- Binary-format input files
|
||||
- Less aggressive round optimization for faster compilation by default
|
||||
- Multithreading with client interface
|
||||
- Functionality to protect order of specific memory accesses
|
||||
- Oblivious transfer works again on older (pre-2011) x86 CPUs
|
||||
- clang is used by default
|
||||
|
||||
## 0.3.4 (Nov 9, 2022)
|
||||
|
||||
- Decision tree learning
|
||||
- Optimized oblivious shuffle in Rep3
|
||||
- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC
|
||||
- Optimized element-vector AND in SemiBin
|
||||
- Optimized input protocol in Shamir-based protocols
|
||||
- Square-root ORAM (@Quitlox)
|
||||
- Improved ORAM in binary circuits
|
||||
- UTF-8 outputs
|
||||
|
||||
## 0.3.3 (Aug 25, 2022)
|
||||
|
||||
- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate
|
||||
- Fix security bug in MAC check when using multithreading
|
||||
- Fix security bug to prevent selective failure attack by checking earlier
|
||||
- Fix security bug in Mama: insufficient sacrifice.
|
||||
- Inverse permutation (@Quitlox)
|
||||
- Easier direct compilation (@eriktaubeneck)
|
||||
- Generally allow element-vector operations
|
||||
- Increase maximum register size to 2^54
|
||||
- Client example in Python
|
||||
- Uniform base OTs across platforms
|
||||
- Multithreaded base OT computation
|
||||
- Faster random bit generation in two-player Semi(2k)
|
||||
|
||||
## 0.3.2 (May 27, 2022)
|
||||
|
||||
- Secure shuffling
|
||||
- O(n log n) radix sorting
|
||||
- Documented BGV encryption interface
|
||||
- Optimized matrix multiplication in dealer protocol
|
||||
- Fixed security bug in homomorphic encryption parameter generation
|
||||
- Fixed security bug in Temi matrix multiplication
|
||||
|
||||
## 0.3.1 (Apr 19, 2022)
|
||||
|
||||
- Protocol in dealer model
|
||||
- Command-line option for security parameter
|
||||
- Fixed security bug in SPDZ2k (see Section 3.4 of [the updated paper](https://eprint.iacr.org/2018/482))
|
||||
- Ability to run high-level (Python) code from C++
|
||||
- More memory capacity due to 64-bit addressing
|
||||
- Homomorphic encryption for more fields of characteristic two
|
||||
- Docker container
|
||||
|
||||
## 0.3.0 (Feb 17, 2022)
|
||||
|
||||
- Semi-honest computation based on threshold semi-homomorphic encryption
|
||||
- Batch normalization backward propagation
|
||||
- AlexNet for CIFAR-10
|
||||
- Specific private output protocols
|
||||
- Semi-honest additive secret sharing without communication
|
||||
- Sending of personal values
|
||||
- Allow overwriting of persistence files
|
||||
- Protocol signature in persistence files
|
||||
|
||||
## 0.2.9 (Jan 11, 2022)
|
||||
|
||||
- Disassembler
|
||||
- Run-time parameter for probabilistic truncation error
|
||||
- Probabilistic truncation for some protocols computing modulo a prime
|
||||
- Simplified C++ interface
|
||||
- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757)
|
||||
- More general scalar-vector multiplication
|
||||
- Complete memory support for clear bits
|
||||
- Extended clear bit functionality with Yao's garbled circuits
|
||||
- Allow preprocessing information to be supplied via named pipes
|
||||
- In-place operations for containers
|
||||
|
||||
## 0.2.8 (Nov 4, 2021)
|
||||
|
||||
- Tested on Apple laptop with ARM chip
|
||||
|
||||
43
CONFIG
43
CONFIG
@@ -8,6 +8,9 @@ GDEBUG = -g
|
||||
# set this to your preferred local storage directory
|
||||
PREP_DIR = '-DPREP_DIR="Player-Data/"'
|
||||
|
||||
# directory to store SSL keys
|
||||
SSL_DIR = '-DSSL_DIR="Player-Data/"'
|
||||
|
||||
# set for SHE preprocessing (SPDZ and Overdrive)
|
||||
USE_NTL = 0
|
||||
|
||||
@@ -28,25 +31,40 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
|
||||
ARCH = -march=native
|
||||
|
||||
MACHINE := $(shell uname -m)
|
||||
ARM := $(shell uname -m | grep x86; echo $$?)
|
||||
OS := $(shell uname -s)
|
||||
ifeq ($(MACHINE), x86_64)
|
||||
# set this to 0 to avoid using AVX for OT
|
||||
ifeq ($(OS), Linux)
|
||||
CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?)
|
||||
ifeq ($(CHECK_AVX), 0)
|
||||
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
|
||||
AVX_OT = 1
|
||||
else
|
||||
AVX_OT = 0
|
||||
endif
|
||||
else
|
||||
AVX_OT = 1
|
||||
AVX_OT = 0
|
||||
endif
|
||||
else
|
||||
ARCH =
|
||||
AVX_OT = 0
|
||||
endif
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
|
||||
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
|
||||
endif
|
||||
|
||||
ifeq ($(OS), Linux)
|
||||
ifeq ($(ARM), 1)
|
||||
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
|
||||
ARCH = -march=armv8.2-a+crypto
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
|
||||
USE_KOS = 0
|
||||
|
||||
# allow to set compiler in CONFIG.mine
|
||||
CXX = g++
|
||||
CXX = clang++
|
||||
|
||||
# use CONFIG.mine to overwrite DIR settings
|
||||
-include CONFIG.mine
|
||||
@@ -65,9 +83,13 @@ endif
|
||||
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
|
||||
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
|
||||
|
||||
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
|
||||
LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
|
||||
LDLIBS += $(BREW_LDLIBS)
|
||||
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
|
||||
LDLIBS += -lboost_system -lssl -lcrypto
|
||||
|
||||
CFLAGS += -I./local/include
|
||||
|
||||
ifeq ($(USE_NTL),1)
|
||||
CFLAGS += -DUSE_NTL
|
||||
LDLIBS := -lntl $(LDLIBS)
|
||||
@@ -83,7 +105,8 @@ else
|
||||
BOOST = -lboost_thread $(MY_BOOST)
|
||||
endif
|
||||
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
|
||||
CFLAGS += $(BREW_CFLAGS)
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = $(CXX)
|
||||
|
||||
@@ -94,3 +117,9 @@ ifeq ($(USE_NTL),1)
|
||||
CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(USE_KOS),1)
|
||||
CFLAGS += -DUSE_KOS
|
||||
else
|
||||
CFLAGS += -std=c++17
|
||||
endif
|
||||
|
||||
@@ -13,11 +13,14 @@ import Compiler.instructions as spdz
|
||||
import Compiler.tools as tools
|
||||
import collections
|
||||
import itertools
|
||||
import math
|
||||
|
||||
class SecretBitsAF(base.RegisterArgFormat):
|
||||
reg_type = 'sb'
|
||||
name = 'sbit'
|
||||
class ClearBitsAF(base.RegisterArgFormat):
|
||||
reg_type = 'cb'
|
||||
name = 'cbit'
|
||||
|
||||
base.ArgFormats['sb'] = SecretBitsAF
|
||||
base.ArgFormats['sbw'] = SecretBitsAF
|
||||
@@ -50,6 +53,7 @@ opcodes = dict(
|
||||
INPUTBVEC = 0x247,
|
||||
SPLIT = 0x248,
|
||||
CONVCBIT2S = 0x249,
|
||||
ANDRSVEC = 0x24a,
|
||||
XORCBI = 0x210,
|
||||
BITDECC = 0x211,
|
||||
NOTCB = 0x212,
|
||||
@@ -64,6 +68,8 @@ opcodes = dict(
|
||||
MULCBI = 0x21c,
|
||||
SHRCBI = 0x21d,
|
||||
SHLCBI = 0x21e,
|
||||
LDMCBI = 0x258,
|
||||
STMCBI = 0x259,
|
||||
CONVCINTVEC = 0x21f,
|
||||
PRINTREGSIGNED = 0x220,
|
||||
PRINTREGB = 0x221,
|
||||
@@ -153,6 +159,52 @@ class andrs(BinaryVectorInstruction):
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
|
||||
req_node.increment(('bit', 'mixed'),
|
||||
sum(int(math.ceil(x / 64)) for x in self.args[::4]))
|
||||
|
||||
class andrsvec(base.VarArgsInstruction, base.Mergeable,
|
||||
base.DynFormatInstruction):
|
||||
""" Constant-vector AND of secret bit registers (vectorized version).
|
||||
|
||||
:param: total number of arguments to follow (int)
|
||||
:param: number of arguments to follow for one operation /
|
||||
operation vector size plus three (int)
|
||||
:param: vector size (int)
|
||||
:param: result vector (sbit)
|
||||
:param: (repeat)...
|
||||
:param: constant operand (sbits)
|
||||
:param: vector operand
|
||||
:param: (repeat)...
|
||||
:param: (repeat from number of arguments to follow for one operation)...
|
||||
|
||||
"""
|
||||
code = opcodes['ANDRSVEC']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(andrsvec, self).__init__(*args, **kwargs)
|
||||
for i, n in self.bases(iter(self.args)):
|
||||
size = self.args[i + 1]
|
||||
for x in self.args[i + 2:i + n]:
|
||||
assert x.n == size
|
||||
|
||||
@classmethod
|
||||
def dynamic_arg_format(cls, args):
|
||||
yield 'int'
|
||||
for i, n in cls.bases(args):
|
||||
yield 'int'
|
||||
n_args = (n - 3) // 2
|
||||
assert n_args > 0
|
||||
for j in range(n_args):
|
||||
yield 'sbw'
|
||||
for j in range(n_args + 1):
|
||||
yield 'sb'
|
||||
yield 'int'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for i, n in self.bases(iter(self.args)):
|
||||
size = self.args[i + 1]
|
||||
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
|
||||
req_node.increment(('bit', 'mixed'), size)
|
||||
|
||||
class ands(BinaryVectorInstruction):
|
||||
""" Bitwise AND of secret bit register vector.
|
||||
@@ -303,7 +355,7 @@ class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
|
||||
:param: memory address (int)
|
||||
"""
|
||||
code = opcodes['LDMSB']
|
||||
arg_format = ['sbw','int']
|
||||
arg_format = ['sbw','long']
|
||||
|
||||
class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
|
||||
""" Copy secret bit register to secret bit memory cell with compile-time
|
||||
@@ -313,7 +365,7 @@ class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
|
||||
:param: memory address (int)
|
||||
"""
|
||||
code = opcodes['STMSB']
|
||||
arg_format = ['sb','int']
|
||||
arg_format = ['sb','long']
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super(type(self), self).__init__(*args, **kwargs)
|
||||
# import inspect
|
||||
@@ -328,7 +380,7 @@ class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
|
||||
:param: memory address (int)
|
||||
"""
|
||||
code = opcodes['LDMCB']
|
||||
arg_format = ['cbw','int']
|
||||
arg_format = ['cbw','long']
|
||||
|
||||
class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
|
||||
""" Copy clear bit register to clear bit memory cell with compile-time
|
||||
@@ -338,9 +390,10 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
|
||||
:param: memory address (int)
|
||||
"""
|
||||
code = opcodes['STMCB']
|
||||
arg_format = ['cb','int']
|
||||
arg_format = ['cb','long']
|
||||
|
||||
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
|
||||
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction,
|
||||
base.IndirectMemoryInstruction):
|
||||
""" Copy secret bit memory cell with run-time address to secret bit
|
||||
register.
|
||||
|
||||
@@ -349,8 +402,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
|
||||
"""
|
||||
code = opcodes['LDMSBI']
|
||||
arg_format = ['sbw','ci']
|
||||
direct = staticmethod(ldmsb)
|
||||
|
||||
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
|
||||
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction,
|
||||
base.IndirectMemoryInstruction):
|
||||
""" Copy secret bit register to secret bit memory cell with run-time
|
||||
address.
|
||||
|
||||
@@ -359,6 +414,31 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
|
||||
"""
|
||||
code = opcodes['STMSBI']
|
||||
arg_format = ['sb','ci']
|
||||
direct = staticmethod(stmsb)
|
||||
|
||||
class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction,
|
||||
base.IndirectMemoryInstruction):
|
||||
""" Copy clear bit memory cell with run-time address to clear bit
|
||||
register.
|
||||
|
||||
:param: destination (cbit)
|
||||
:param: memory address (regint)
|
||||
"""
|
||||
code = opcodes['LDMCBI']
|
||||
arg_format = ['cbw','ci']
|
||||
direct = staticmethod(ldmcb)
|
||||
|
||||
class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction,
|
||||
base.IndirectMemoryInstruction):
|
||||
""" Copy clear bit register to clear bit memory cell with run-time
|
||||
address.
|
||||
|
||||
:param: source (cbit)
|
||||
:param: memory address (regint)
|
||||
"""
|
||||
code = opcodes['STMCBI']
|
||||
arg_format = ['cb','ci']
|
||||
direct = staticmethod(stmcb)
|
||||
|
||||
class ldmsdi(base.ReadMemoryInstruction):
|
||||
code = opcodes['LDMSDI']
|
||||
@@ -475,7 +555,7 @@ class movsb(NonVectorInstruction):
|
||||
code = opcodes['MOVSB']
|
||||
arg_format = ['sbw','sb']
|
||||
|
||||
class trans(base.VarArgsInstruction):
|
||||
class trans(base.VarArgsInstruction, base.DynFormatInstruction):
|
||||
""" Secret bit register vector transpose. The first destination vector
|
||||
will contain the least significant bits of all source vectors etc.
|
||||
|
||||
@@ -489,10 +569,22 @@ class trans(base.VarArgsInstruction):
|
||||
code = opcodes['TRANS']
|
||||
is_vec = lambda self: True
|
||||
def __init__(self, *args):
|
||||
self.arg_format = ['int'] + ['sbw'] * args[0] + \
|
||||
['sb'] * (len(args) - 1 - args[0])
|
||||
super(trans, self).__init__(*args)
|
||||
|
||||
@classmethod
|
||||
def dynamic_arg_format(cls, args):
|
||||
yield 'int'
|
||||
n = next(args)
|
||||
for i in range(n):
|
||||
yield 'sbw'
|
||||
next(args)
|
||||
while True:
|
||||
try:
|
||||
yield 'sb'
|
||||
next(args)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
class bitb(NonVectorInstruction):
|
||||
""" Copy fresh secret random bit to secret bit register.
|
||||
|
||||
@@ -538,7 +630,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
|
||||
req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1])
|
||||
|
||||
class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
|
||||
base.Mergeable):
|
||||
base.Mergeable, base.DynFormatInstruction):
|
||||
""" Copy private input to secret bit registers bit by bit. The input is
|
||||
read as floating-point number, multiplied by a power of two, rounded to an
|
||||
integer, and then decomposed into bits.
|
||||
@@ -555,11 +647,19 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
|
||||
code = opcodes['INPUTBVEC']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.arg_format = []
|
||||
for x in self.get_arg_tuples(args):
|
||||
self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3)
|
||||
super(inputbvec, self).__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def dynamic_arg_format(cls, args):
|
||||
yield 'int'
|
||||
for i, n in cls.bases(args):
|
||||
yield 'int'
|
||||
yield 'p'
|
||||
assert n > 3
|
||||
for j in range(n - 3):
|
||||
yield 'sbw'
|
||||
yield 'int'
|
||||
|
||||
@staticmethod
|
||||
def get_arg_tuples(args):
|
||||
i = 0
|
||||
@@ -568,10 +668,6 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
|
||||
i += args[i]
|
||||
assert i == len(args)
|
||||
|
||||
def merge(self, other):
|
||||
self.args += other.args
|
||||
self.arg_format += other.arg_format
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for x in self.get_arg_tuples(self.args):
|
||||
req_node.increment(('bit', 'input', x[2]), x[0] - 3)
|
||||
|
||||
@@ -3,10 +3,13 @@ This modules contains basic types for binary circuits. The
|
||||
fixed-length types obtained by :py:obj:`get_type(n)` are the preferred
|
||||
way of using them, and in some cases required in connection with
|
||||
container types.
|
||||
|
||||
Computation using these types will always be executed as a binary
|
||||
circuit. See :ref:`protocol-pairs` for the exact protocols.
|
||||
"""
|
||||
|
||||
from Compiler.types import MemValue, read_mem_value, regint, Array, cint
|
||||
from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint
|
||||
from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint, sintbit
|
||||
from Compiler.program import Tape, Program
|
||||
from Compiler.exceptions import *
|
||||
from Compiler import util, oram, floatingpoint, library
|
||||
@@ -14,10 +17,10 @@ from Compiler import instructions_base
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
import math
|
||||
import itertools
|
||||
from functools import reduce
|
||||
|
||||
class bits(Tape.Register, _structure, _bit):
|
||||
""" Base class for binary registers. """
|
||||
n = 40
|
||||
unit = 64
|
||||
PreOp = staticmethod(floatingpoint.PreOpN)
|
||||
@@ -41,7 +44,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return cls.types[length]
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cls):
|
||||
if isinstance(other, cls) and cls.n == other.n:
|
||||
return other
|
||||
elif isinstance(other, MemValue):
|
||||
return cls.conv(other.read())
|
||||
@@ -56,12 +59,12 @@ class bits(Tape.Register, _structure, _bit):
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
bits = list(bits)
|
||||
if len(bits) == 1:
|
||||
if len(bits) == 1 and isinstance(bits[0], cls):
|
||||
return bits[0]
|
||||
bits = list(bits)
|
||||
for i in range(len(bits)):
|
||||
if util.is_constant(bits[i]):
|
||||
bits[i] = sbit(bits[i])
|
||||
bits[i] = cls.bit_type(bits[i])
|
||||
res = cls.new(n=len(bits))
|
||||
if len(bits) <= cls.unit:
|
||||
cls.bitcom(res, *(sbit.conv(bit) for bit in bits))
|
||||
@@ -111,11 +114,16 @@ class bits(Tape.Register, _structure, _bit):
|
||||
if mem_type == 'sd':
|
||||
return cls.load_dynamic_mem(address)
|
||||
else:
|
||||
for i in range(res.size):
|
||||
cls.load_inst[util.is_constant(address)](res[i], address + i)
|
||||
cls.mem_op(cls.load_inst, res, address)
|
||||
return res
|
||||
def store_in_mem(self, address):
|
||||
self.store_inst[isinstance(address, int)](self, address)
|
||||
self.mem_op(self.store_inst, self, address)
|
||||
@staticmethod
|
||||
def mem_op(inst, reg, address):
|
||||
direct = isinstance(address, int)
|
||||
if not direct:
|
||||
address = regint.conv(address)
|
||||
inst[direct](reg, address)
|
||||
@classmethod
|
||||
def new(cls, value=None, n=None):
|
||||
if util.is_constant(value):
|
||||
@@ -147,19 +155,29 @@ class bits(Tape.Register, _structure, _bit):
|
||||
self.set_length(self.n or util.int_len(other))
|
||||
self.load_int(other)
|
||||
elif isinstance(other, regint):
|
||||
assert(other.size == math.ceil(self.n / self.unit))
|
||||
for i, (x, y) in enumerate(zip(self, other)):
|
||||
assert self.unit == 64
|
||||
n_units = int(math.ceil(self.n / self.unit))
|
||||
n_convs = min(other.size, n_units)
|
||||
for i in range(n_convs):
|
||||
x = self[i]
|
||||
y = other[i]
|
||||
self.conv_regint(min(self.unit, self.n - i * self.unit), x, y)
|
||||
for i in range(n_convs, n_units):
|
||||
inst.ldbits(self[i], min(self.unit, self.n - i * self.unit), 0)
|
||||
elif (isinstance(self, type(other)) or isinstance(other, type(self))) \
|
||||
and self.n == other.n:
|
||||
for i in range(math.ceil(self.n / self.unit)):
|
||||
self.mov(self[i], other[i])
|
||||
elif isinstance(other, sintbit) and isinstance(self, sbits):
|
||||
assert len(other) == 1
|
||||
r = sint.get_dabit()
|
||||
self.mov(self, r[1] ^ other.bit_xor(r[0]).reveal())
|
||||
elif isinstance(other, sint) and isinstance(self, sbits):
|
||||
self.mov(self, sbitvec(other, self.n).elements()[0])
|
||||
else:
|
||||
try:
|
||||
bits = other.bit_decompose()
|
||||
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
|
||||
bits = bits[:self.n] + [self.bit_type(0)] * (self.n - len(bits))
|
||||
other = self.bit_compose(bits)
|
||||
assert(isinstance(other, type(self)))
|
||||
assert(other.n == self.n)
|
||||
@@ -184,6 +202,8 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return 0
|
||||
elif self.is_long_one(other):
|
||||
return self
|
||||
elif isinstance(other, _vec):
|
||||
return other & other.from_vec([self])
|
||||
else:
|
||||
return self._and(other)
|
||||
@read_mem_value
|
||||
@@ -222,17 +242,30 @@ class bits(Tape.Register, _structure, _bit):
|
||||
This will output 1.
|
||||
"""
|
||||
return result_conv(x, y)(self & (x ^ y) ^ y)
|
||||
def zero_if_not(self, condition):
|
||||
if util.is_constant(condition):
|
||||
return self * condition
|
||||
else:
|
||||
return self * cbit.conv(condition)
|
||||
def expand(self, length):
|
||||
if self.n in (length, None):
|
||||
return self
|
||||
elif self.n == 1:
|
||||
return self.get_type(length).bit_compose([self] * length)
|
||||
else:
|
||||
raise CompilerError('cannot expand from %s to %s' % (self.n, length))
|
||||
|
||||
class cbits(bits):
|
||||
""" Clear bits register. Helper type with limited functionality. """
|
||||
max_length = 64
|
||||
reg_type = 'cb'
|
||||
is_clear = True
|
||||
load_inst = (None, inst.ldmcb)
|
||||
store_inst = (None, inst.stmcb)
|
||||
load_inst = (inst.ldmcbi, inst.ldmcb)
|
||||
store_inst = (inst.stmcbi, inst.stmcb)
|
||||
bitdec = inst.bitdecc
|
||||
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
|
||||
conv_cint_vec = inst.convcintvec
|
||||
mov = staticmethod(lambda x, y: inst.addcbi(x, y, 0))
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
return sum(bit << i for i, bit in enumerate(bits))
|
||||
@@ -241,14 +274,26 @@ class cbits(bits):
|
||||
assert n == res.n
|
||||
assert n == other.size
|
||||
cls.conv_cint_vec(cint(other, size=other.size), res)
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cbits) and cls.n != None and \
|
||||
cls.n // cls.unit == other.n // cls.unit:
|
||||
if isinstance(other, cls):
|
||||
return other
|
||||
else:
|
||||
res = cls()
|
||||
for i in range(math.ceil(cls.n / cls.unit)):
|
||||
cls.mov(res[i], other[i])
|
||||
return res
|
||||
else:
|
||||
return super(cbits, cls).conv(other)
|
||||
types = {}
|
||||
def load_int(self, value):
|
||||
if self.n <= 64:
|
||||
tmp = regint(value)
|
||||
elif value == self.long_one():
|
||||
tmp = cint(1, size=self.n)
|
||||
else:
|
||||
raise CompilerError('loading long integers to cbits not supported')
|
||||
n_limbs = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=n_limbs)
|
||||
for i in range(n_limbs):
|
||||
tmp[i].load_int(value % 2 ** self.unit)
|
||||
value >>= self.unit
|
||||
self.load_other(tmp)
|
||||
def store_in_dynamic_mem(self, address):
|
||||
inst.stmsdci(self, cbits.conv(address))
|
||||
@@ -270,8 +315,15 @@ class cbits(bits):
|
||||
return op(self, cbits(other))
|
||||
__add__ = lambda self, other: \
|
||||
self.clear_op(other, inst.addcb, inst.addcbi, operator.add)
|
||||
__sub__ = lambda self, other: \
|
||||
self.clear_op(-other, inst.addcb, inst.addcbi, operator.add)
|
||||
def __sub__(self, other):
|
||||
try:
|
||||
return self + -other
|
||||
except:
|
||||
return type(self)(regint(self) - regint(other))
|
||||
def __rsub__(self, other):
|
||||
return type(self)(other - regint(self))
|
||||
def __neg__(self):
|
||||
return type(self)(-regint(self))
|
||||
def _xor(self, other):
|
||||
if isinstance(other, (sbits, sbitvec)):
|
||||
return NotImplemented
|
||||
@@ -363,7 +415,6 @@ class sbits(bits):
|
||||
reg_type = 'sb'
|
||||
is_clear = False
|
||||
clear_type = cbits
|
||||
default_type = cbits
|
||||
load_inst = (inst.ldmsbi, inst.ldmsb)
|
||||
store_inst = (inst.stmsbi, inst.stmsb)
|
||||
bitdec = inst.bitdecs
|
||||
@@ -385,16 +436,25 @@ class sbits(bits):
|
||||
else:
|
||||
return sbits.get_type(n)(value)
|
||||
@staticmethod
|
||||
def _new(value):
|
||||
return value
|
||||
@staticmethod
|
||||
def get_random_bit():
|
||||
res = sbit()
|
||||
inst.bitb(res)
|
||||
return res
|
||||
@staticmethod
|
||||
def _check_input_player(player):
|
||||
if not util.is_constant(player):
|
||||
raise CompilerError('player must be known at compile time '
|
||||
'for binary circuit inputs')
|
||||
@classmethod
|
||||
def get_input_from(cls, player, n_bits=None):
|
||||
""" Secret input from :py:obj:`player`.
|
||||
|
||||
:param: player (int)
|
||||
"""
|
||||
cls._check_input_player(player)
|
||||
if n_bits is None:
|
||||
n_bits = cls.n
|
||||
res = cls()
|
||||
@@ -463,6 +523,8 @@ class sbits(bits):
|
||||
if isinstance(other, int):
|
||||
return self.mul_int(other)
|
||||
try:
|
||||
if (self.n, other.n) == (1, 1):
|
||||
return self & other
|
||||
if min(self.n, other.n) != 1:
|
||||
raise NotImplementedError('high order multiplication')
|
||||
n = max(self.n, other.n)
|
||||
@@ -554,7 +616,15 @@ class sbits(bits):
|
||||
rows = list(rows)
|
||||
if len(rows) == 1 and rows[0].n <= rows[0].unit:
|
||||
return rows[0].bit_decompose()
|
||||
n_columns = rows[0].n
|
||||
for row in rows:
|
||||
try:
|
||||
n_columns = row.n
|
||||
break
|
||||
except:
|
||||
pass
|
||||
for i in range(len(rows)):
|
||||
if util.is_zero(rows[i]):
|
||||
rows[i] = cls.get_type(n_columns)(0)
|
||||
for row in rows:
|
||||
assert(row.n == n_columns)
|
||||
if n_columns == 1 and len(rows) <= cls.unit:
|
||||
@@ -578,7 +648,7 @@ class sbits(bits):
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return sbitint.ripple_carry_adder(*args, **kwargs)
|
||||
|
||||
class sbitvec(_vec):
|
||||
class sbitvec(_vec, _bit):
|
||||
""" Vector of registers of secret bits, effectively a matrix of secret bits.
|
||||
This facilitates parallel arithmetic operations in binary circuits.
|
||||
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
|
||||
@@ -586,7 +656,7 @@ class sbitvec(_vec):
|
||||
You can access the rows by member :py:obj:`v` and the columns by calling
|
||||
:py:obj:`elements`.
|
||||
|
||||
There are three ways to create an instance:
|
||||
There are four ways to create an instance:
|
||||
|
||||
1. By transposition::
|
||||
|
||||
@@ -619,8 +689,14 @@ class sbitvec(_vec):
|
||||
This should output::
|
||||
|
||||
[1, 0, 1]
|
||||
|
||||
4. Private input::
|
||||
|
||||
x = sbitvec.get_type(32).get_input_from(player)
|
||||
|
||||
"""
|
||||
bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v)))
|
||||
is_clear = False
|
||||
@classmethod
|
||||
def get_type(cls, n):
|
||||
""" Create type for fixed-length vector of registers of secret bits.
|
||||
@@ -634,17 +710,28 @@ class sbitvec(_vec):
|
||||
return sbit.malloc(size * n, creator_tape=creator_tape)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
@staticmethod
|
||||
def mem_size():
|
||||
return n
|
||||
@classmethod
|
||||
def get_input_from(cls, player):
|
||||
def get_input_from(cls, player, size=1, f=0):
|
||||
""" Secret input from :py:obj:`player`. The input is decomposed
|
||||
into bits.
|
||||
|
||||
:param: player (int)
|
||||
"""
|
||||
res = cls.from_vec(sbit() for i in range(n))
|
||||
inst.inputbvec(n + 3, 0, player, *res.v)
|
||||
return res
|
||||
v = [0] * n
|
||||
sbits._check_input_player(player)
|
||||
instructions_base.check_vector_size(size)
|
||||
for i in range(size):
|
||||
vv = [sbit() for i in range(n)]
|
||||
inst.inputbvec(n + 3, f, player, *vv)
|
||||
for j in range(n):
|
||||
tmp = vv[j] << i
|
||||
v[j] = tmp ^ v[j]
|
||||
sbits._check_input_player(player)
|
||||
return cls.from_vec(v)
|
||||
get_raw_input_from = get_input_from
|
||||
@classmethod
|
||||
def from_vec(cls, vector):
|
||||
@@ -652,47 +739,54 @@ class sbitvec(_vec):
|
||||
res.v = _complement_two_extend(list(vector), n)[:n]
|
||||
return res
|
||||
def __init__(self, other=None, size=None):
|
||||
assert size in (None, 1)
|
||||
instructions_base.check_vector_size(size)
|
||||
if other is not None:
|
||||
if util.is_constant(other):
|
||||
self.v = [sbit((other >> i) & 1) for i in range(n)]
|
||||
t = sbits.get_type(size or 1)
|
||||
self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1))
|
||||
for i in range(n)]
|
||||
elif isinstance(other, _vec):
|
||||
self.v = self.bit_extend(other.v, n)
|
||||
self.v = [type(x)(x) for x in self.bit_extend(other.v, n)]
|
||||
elif isinstance(other, (list, tuple)):
|
||||
self.v = self.bit_extend(sbitvec(other).v, n)
|
||||
else:
|
||||
self.v = sbits.get_type(n)(other).bit_decompose()
|
||||
assert len(self.v) == n
|
||||
assert size is None or size == self.v[0].n
|
||||
@classmethod
|
||||
def load_mem(cls, address):
|
||||
def load_mem(cls, address, size=None):
|
||||
if size not in (None, 1):
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
sb = sbits.get_type(size)
|
||||
return cls.from_vec(sb.bit_compose(
|
||||
sbit.load_mem(address + i + j * n) for j in range(size))
|
||||
for i in range(n))
|
||||
if not isinstance(address, int) and len(address) == n:
|
||||
return cls.from_vec(sbit.load_mem(x) for x in address)
|
||||
else:
|
||||
return cls.from_vec(sbit.load_mem(address + i)
|
||||
for i in range(n))
|
||||
def store_in_mem(self, address):
|
||||
size = 1
|
||||
for x in self.v:
|
||||
assert util.is_constant(x) or x.n == 1
|
||||
v = [sbit.conv(x) for x in self.v]
|
||||
if not util.is_constant(x):
|
||||
size = max(size, x.n)
|
||||
v = [sbits.get_type(size).conv(x) for x in self.v]
|
||||
if not isinstance(address, int) and len(address) == n:
|
||||
assert max_n == 1
|
||||
for x, y in zip(v, address):
|
||||
x.store_in_mem(y)
|
||||
else:
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
for i in range(n):
|
||||
v[i].store_in_mem(address + i)
|
||||
for j, x in enumerate(v[i].bit_decompose()):
|
||||
x.store_in_mem(address + i + j * n)
|
||||
def reveal(self):
|
||||
if len(self) > cbits.unit:
|
||||
return self.elements()[0].reveal()
|
||||
revealed = [cbit() for i in range(len(self))]
|
||||
for i in range(len(self)):
|
||||
try:
|
||||
inst.reveal(1, revealed[i], self.v[i])
|
||||
except:
|
||||
revealed[i] = cbit.conv(self.v[i])
|
||||
return cbits.get_type(len(self)).bit_compose(revealed)
|
||||
return util.untuplify([x.reveal() for x in self.elements()])
|
||||
@classmethod
|
||||
def two_power(cls, nn):
|
||||
return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1))
|
||||
def two_power(cls, nn, size=1):
|
||||
return cls.from_vec(
|
||||
[0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1))
|
||||
def coerce(self, other):
|
||||
if util.is_constant(other):
|
||||
return self.from_vec(util.bit_decompose(other, n))
|
||||
@@ -705,8 +799,12 @@ class sbitvec(_vec):
|
||||
bits += [0] * (n - len(bits))
|
||||
assert len(bits) == n
|
||||
return cls.from_vec(bits)
|
||||
def zero_if_not(self, condition):
|
||||
return self.from_vec(x.zero_if_not(condition) for x in self.v)
|
||||
def __str__(self):
|
||||
return 'sbitvec(%d)' % n
|
||||
sbitvecn.basic_type = sbitvecn
|
||||
sbitvecn.reg_type = 'sb'
|
||||
return sbitvecn
|
||||
@classmethod
|
||||
def from_vec(cls, vector):
|
||||
@@ -723,6 +821,15 @@ class sbitvec(_vec):
|
||||
def from_matrix(cls, matrix):
|
||||
# any number of rows, limited number of columns
|
||||
return cls.combine(cls(row) for row in matrix)
|
||||
@classmethod
|
||||
def from_hex(cls, string):
|
||||
""" Create from hexadecimal string (little-endian). """
|
||||
assert len(string) % 2 == 0
|
||||
v = []
|
||||
for i in range(0, len(string), 2):
|
||||
v += [sbit(int(x))
|
||||
for x in reversed(bin(int(string[i:i + 2], 16))[2:].zfill(8))]
|
||||
return cls.from_vec(v)
|
||||
def __init__(self, elements=None, length=None, input_length=None):
|
||||
if length:
|
||||
assert isinstance(elements, sint)
|
||||
@@ -769,19 +876,20 @@ class sbitvec(_vec):
|
||||
size = other.size
|
||||
return (other.get_vector(base, min(64, size - base)) \
|
||||
for base in range(0, size, 64))
|
||||
if not isinstance(other, type(self)):
|
||||
return type(self)(other)
|
||||
return other
|
||||
def __xor__(self, other):
|
||||
other = self.coerce(other)
|
||||
return self.from_vec(x ^ y for x, y in zip(self.v, other))
|
||||
return self.from_vec(x ^ y for x, y in zip(*self.expand(other)))
|
||||
def __and__(self, other):
|
||||
return self.from_vec(x & y for x, y in zip(self.v, other.v))
|
||||
return self.from_vec(x & y for x, y in zip(*self.expand(other)))
|
||||
__rxor__ = __xor__
|
||||
__rand__ = __and__
|
||||
def __invert__(self):
|
||||
return self.from_vec(~x for x in self.v)
|
||||
def if_else(self, x, y):
|
||||
assert(len(self.v) == 1)
|
||||
try:
|
||||
return self.from_vec(util.if_else(self.v[0], a, b) \
|
||||
for a, b in zip(x, y))
|
||||
except:
|
||||
return util.if_else(self.v[0], x, y)
|
||||
return util.if_else(self.v[0], x, y)
|
||||
def __iter__(self):
|
||||
return iter(self.v)
|
||||
def __len__(self):
|
||||
@@ -794,6 +902,7 @@ class sbitvec(_vec):
|
||||
return cls.from_vec(other.v)
|
||||
else:
|
||||
return cls(other)
|
||||
hard_conv = conv
|
||||
@property
|
||||
def size(self):
|
||||
if not self.v or util.is_constant(self.v[0]):
|
||||
@@ -806,7 +915,7 @@ class sbitvec(_vec):
|
||||
def store_in_mem(self, address):
|
||||
for i, x in enumerate(self.elements()):
|
||||
x.store_in_mem(address + i)
|
||||
def bit_decompose(self, n_bits=None, security=None):
|
||||
def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None):
|
||||
return self.v[:n_bits]
|
||||
bit_compose = from_vec
|
||||
def reveal(self):
|
||||
@@ -823,6 +932,34 @@ class sbitvec(_vec):
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, int):
|
||||
return self.from_vec(x * other for x in self.v)
|
||||
if isinstance(other, sbitvec):
|
||||
if len(other.v) == 1:
|
||||
other = other.v[0]
|
||||
elif len(self.v) == 1:
|
||||
self, other = other, self.v[0]
|
||||
else:
|
||||
raise CompilerError('no operand of lenght 1: %d/%d',
|
||||
(len(self.v), len(other.v)))
|
||||
if not isinstance(other, sbits):
|
||||
return NotImplemented
|
||||
ops = []
|
||||
for x in self.v:
|
||||
if not util.is_zero(x):
|
||||
assert x.n == other.n
|
||||
ops.append(x)
|
||||
if ops:
|
||||
prods = [sbits.get_type(other.n)() for i in ops]
|
||||
inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops)
|
||||
res = []
|
||||
i = 0
|
||||
for x in self.v:
|
||||
if util.is_zero(x):
|
||||
res.append(0)
|
||||
else:
|
||||
res.append(prods[i])
|
||||
i += 1
|
||||
return sbitvec.from_vec(res)
|
||||
__rmul__ = __mul__
|
||||
def __add__(self, other):
|
||||
return self.from_vec(x + y for x, y in zip(self.v, other))
|
||||
def bit_and(self, other):
|
||||
@@ -831,6 +968,60 @@ class sbitvec(_vec):
|
||||
return self ^ other
|
||||
def right_shift(self, m, k, security=None, signed=True):
|
||||
return self.from_vec(self.v[m:])
|
||||
def tree_reduce(self, function):
|
||||
elements = self.elements()
|
||||
while len(elements) > 1:
|
||||
size = len(elements)
|
||||
half = size // 2
|
||||
left = elements[:half]
|
||||
right = elements[half:2*half]
|
||||
odd = elements[2*half:]
|
||||
sides = [self.from_vec(sbitvec(x).v) for x in (left, right)]
|
||||
red = function(*sides)
|
||||
elements = red.elements()
|
||||
elements += odd
|
||||
return self.from_vec(sbitvec(elements).v)
|
||||
@classmethod
|
||||
def comp_result(cls, x):
|
||||
return cls.get_type(1).from_vec([x])
|
||||
def expand(self, other, expand=True):
|
||||
m = 1
|
||||
for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []):
|
||||
try:
|
||||
m = max(m, x.n)
|
||||
except:
|
||||
pass
|
||||
res = []
|
||||
if not util.is_constant(other):
|
||||
other = self.coerce(other)
|
||||
for y in self, other:
|
||||
if isinstance(y, int):
|
||||
res.append([x * sbits.get_type(m)().long_one()
|
||||
for x in util.bit_decompose(y, len(self.v))])
|
||||
else:
|
||||
res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v])
|
||||
return res
|
||||
def demux(self):
|
||||
if len(self) == 1:
|
||||
return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]])
|
||||
a = sbitvec.from_vec(self.v[:len(self) // 2]).demux()
|
||||
b = sbitvec.from_vec(self.v[len(self) // 2:]).demux()
|
||||
prod = [a * bb for bb in b.v]
|
||||
return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod)))
|
||||
def reverse_bytes(self):
|
||||
if len(self.v) % 8 != 0:
|
||||
raise CompilerError('bit length not divisible by eight')
|
||||
return self.from_vec(sum(reversed(
|
||||
[self.v[i:i + 8] for i in range(0, len(self.v), 8)]), []))
|
||||
def reveal_print_hex(self):
|
||||
""" Reveal and print in hexademical (one line per element). """
|
||||
for x in self.reverse_bytes().elements():
|
||||
x.reveal().print_reg()
|
||||
def update(self, other):
|
||||
other = self.conv(other)
|
||||
assert len(self.v) == len(other.v)
|
||||
for x, y in zip(self.v, other.v):
|
||||
x.update(y)
|
||||
|
||||
class bit(object):
|
||||
n = 1
|
||||
@@ -881,10 +1072,11 @@ class cbit(bit, cbits):
|
||||
sbits.bit_type = sbit
|
||||
cbits.bit_type = cbit
|
||||
sbit.clear_type = cbit
|
||||
sbits.default_type = sbits
|
||||
|
||||
class bitsBlock(oram.Block):
|
||||
value_type = sbits
|
||||
def __init__(self, value, start, lengths, entries_per_block):
|
||||
self.value_type = type(value)
|
||||
oram.Block.__init__(self, value, lengths)
|
||||
length = sum(self.lengths)
|
||||
used_bits = entries_per_block * length
|
||||
@@ -929,7 +1121,10 @@ sbits.dynamic_array = DynamicArray
|
||||
cbits.dynamic_array = Array
|
||||
|
||||
def _complement_two_extend(bits, k):
|
||||
return bits[:k] + [bits[-1]] * (k - len(bits))
|
||||
if len(bits) == 1:
|
||||
return bits + [0] * (k - len(bits))
|
||||
else:
|
||||
return bits[:k] + [bits[-1]] * (k - len(bits))
|
||||
|
||||
class _sbitintbase:
|
||||
def extend(self, n):
|
||||
@@ -988,6 +1183,9 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
mul: 15
|
||||
lt: 0
|
||||
|
||||
This class is retained for compatibility, but development now
|
||||
focuses on :py:class:`sbitintvec`.
|
||||
|
||||
"""
|
||||
n_bits = None
|
||||
bin_type = None
|
||||
@@ -1079,7 +1277,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
:param k: bit length of input """
|
||||
return _sbitintbase.pow2(self, k)
|
||||
|
||||
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
|
||||
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
"""
|
||||
Vector of signed integers for parallel binary computation::
|
||||
|
||||
@@ -1114,19 +1312,34 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
|
||||
def __add__(self, other):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
other = self.coerce(other)
|
||||
assert(len(self.v) == len(other.v))
|
||||
v = sbitint.bit_adder(self.v, other.v)
|
||||
return self.from_vec(v)
|
||||
a, b = self.expand(other)
|
||||
v = sbitint.bit_adder(a, b)
|
||||
return self.get_type(len(v)).from_vec(v)
|
||||
__radd__ = __add__
|
||||
__sub__ = _bitint.__sub__
|
||||
def __rsub__(self, other):
|
||||
a, b = self.expand(other)
|
||||
return self.from_vec(b) - self.from_vec(a)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbits):
|
||||
return self.from_vec(other * x for x in self.v)
|
||||
elif len(self.v) == 1:
|
||||
return other * self.v[0]
|
||||
elif isinstance(other, sbitfixvec):
|
||||
return NotImplemented
|
||||
my_bits, other_bits = self.expand(other, False)
|
||||
matrix = []
|
||||
for i, b in enumerate(util.bit_decompose(other)):
|
||||
matrix.append([x & b for x in self.v[:len(self.v)-i]])
|
||||
m = float('inf')
|
||||
for x in itertools.chain(my_bits, other_bits):
|
||||
try:
|
||||
m = min(m, x.n)
|
||||
except:
|
||||
pass
|
||||
for i, b in enumerate(other_bits):
|
||||
if m == 1:
|
||||
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
|
||||
else:
|
||||
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
|
||||
v = sbitint.wallace_tree_from_matrix(matrix)
|
||||
return self.from_vec(v[:len(self.v)])
|
||||
__rmul__ = __mul__
|
||||
@@ -1157,22 +1370,27 @@ class cbitfix(object):
|
||||
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
|
||||
@classmethod
|
||||
def _new(cls, value):
|
||||
if isinstance(value, list):
|
||||
return [cls._new(x) for x in value]
|
||||
res = cls()
|
||||
if cls.k < value.unit:
|
||||
bits = value.bit_decompose(cls.k)
|
||||
sign = bits[-1]
|
||||
value += (sign << (cls.k)) * -1
|
||||
res.v = value
|
||||
return res
|
||||
def output(self):
|
||||
v = self.v
|
||||
if self.k < v.unit:
|
||||
bits = self.v.bit_decompose(self.k)
|
||||
sign = bits[-1]
|
||||
v += (sign << (self.k)) * -1
|
||||
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
|
||||
cbits(0), cbits(0))
|
||||
|
||||
class sbitfix(_fix):
|
||||
""" Secret signed integer in one binary register.
|
||||
""" Secret signed fixed-point number in one binary register.
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
|
||||
This class is retained for compatibility, but development now
|
||||
focuses on :py:class:`sbitfixvec`.
|
||||
|
||||
Example::
|
||||
|
||||
print_ln('add: %s', (sbitfix(0.5) + sbitfix(0.3)).reveal())
|
||||
@@ -1211,6 +1429,7 @@ class sbitfix(_fix):
|
||||
|
||||
:param: player (int)
|
||||
"""
|
||||
sbits._check_input_player(player)
|
||||
v = cls.int_type()
|
||||
inst.inputb(player, cls.k, cls.f, v)
|
||||
return cls._new(v)
|
||||
@@ -1233,7 +1452,7 @@ class sbitfix(_fix):
|
||||
cls.set_precision(f, k)
|
||||
return cls._new(cls.int_type(other), k, f)
|
||||
|
||||
class sbitfixvec(_fix):
|
||||
class sbitfixvec(_fix, _vec):
|
||||
""" Vector of fixed-point numbers for parallel binary computation.
|
||||
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
@@ -1262,23 +1481,27 @@ class sbitfixvec(_fix):
|
||||
int_type = sbitintvec.get_type(sbitfix.k)
|
||||
float_type = type(None)
|
||||
clear_type = cbitfix
|
||||
@property
|
||||
def bit_type(self):
|
||||
return type(self.v[0])
|
||||
@classmethod
|
||||
def set_precision(cls, f, k=None):
|
||||
super(sbitfixvec, cls).set_precision(f=f, k=k)
|
||||
cls.int_type = sbitintvec.get_type(cls.k)
|
||||
@classmethod
|
||||
def get_input_from(cls, player):
|
||||
def get_input_from(cls, player, size=1):
|
||||
""" Secret input from :py:obj:`player`.
|
||||
|
||||
:param: player (int)
|
||||
"""
|
||||
v = [sbit() for i in range(sbitfix.k)]
|
||||
inst.inputbvec(len(v) + 3, sbitfix.f, player, *v)
|
||||
return cls._new(cls.int_type.from_vec(v))
|
||||
return cls._new(cls.int_type.get_input_from(player, size=size,
|
||||
f=sbitfix.f))
|
||||
def __init__(self, value=None, *args, **kwargs):
|
||||
if isinstance(value, (list, tuple)):
|
||||
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))
|
||||
else:
|
||||
if isinstance(value, sbitvec):
|
||||
value = self.int_type(value)
|
||||
super(sbitfixvec, self).__init__(value, *args, **kwargs)
|
||||
def elements(self):
|
||||
return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()]
|
||||
@@ -1288,9 +1511,12 @@ class sbitfixvec(_fix):
|
||||
else:
|
||||
return super(sbitfixvec, self).mul(other)
|
||||
def __xor__(self, other):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
return self._new(self.v ^ other.v)
|
||||
def __and__(self, other):
|
||||
return self._new(self.v & other.v)
|
||||
__rxor__ = __xor__
|
||||
@staticmethod
|
||||
def multipliable(other, k, f, size):
|
||||
class cls(_fix):
|
||||
|
||||
@@ -2,30 +2,3 @@ from . import compilerLib, program, instructions, types, library, floatingpoint
|
||||
from .GC import types as GC_types
|
||||
import inspect
|
||||
from .config import *
|
||||
from .compilerLib import run
|
||||
|
||||
|
||||
# add all instructions to the program VARS dictionary
|
||||
compilerLib.VARS = {}
|
||||
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
|
||||
|
||||
for mod in (types, GC_types):
|
||||
instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\
|
||||
if t[1].__module__ == mod.__name__]
|
||||
|
||||
instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\
|
||||
if t[1].__module__ == library.__name__]
|
||||
|
||||
for op in instr_classes:
|
||||
compilerLib.VARS[op.__name__] = op
|
||||
|
||||
# add open and input separately due to name conflict
|
||||
compilerLib.VARS['open'] = instructions.asm_open
|
||||
compilerLib.VARS['vopen'] = instructions.vasm_open
|
||||
compilerLib.VARS['gopen'] = instructions.gasm_open
|
||||
compilerLib.VARS['vgopen'] = instructions.vgasm_open
|
||||
compilerLib.VARS['input'] = instructions.asm_input
|
||||
compilerLib.VARS['ginput'] = instructions.gasm_input
|
||||
|
||||
compilerLib.VARS['comparison'] = comparison
|
||||
compilerLib.VARS['floatingpoint'] = floatingpoint
|
||||
|
||||
@@ -15,11 +15,11 @@ from functools import reduce
|
||||
class BlockAllocator:
|
||||
""" Manages freed memory blocks. """
|
||||
def __init__(self):
|
||||
self.by_logsize = [defaultdict(set) for i in range(32)]
|
||||
self.by_logsize = [defaultdict(set) for i in range(64)]
|
||||
self.by_address = {}
|
||||
|
||||
def by_size(self, size):
|
||||
if size >= 2 ** 32:
|
||||
if size >= 2 ** 64:
|
||||
raise CompilerError('size exceeds addressing capability')
|
||||
return self.by_logsize[int(math.log(size, 2))][size]
|
||||
|
||||
@@ -101,6 +101,7 @@ class StraightlineAllocator:
|
||||
self.dealloc |= reg.vector
|
||||
else:
|
||||
self.dealloc.add(reg)
|
||||
reg.duplicates.remove(reg)
|
||||
base = reg.vectorbase
|
||||
|
||||
seen = set_by_id()
|
||||
@@ -171,7 +172,7 @@ class StraightlineAllocator:
|
||||
for reg in self.alloc:
|
||||
for x in reg.get_all():
|
||||
if x not in self.dealloc and reg not in self.dealloc \
|
||||
and len(x.duplicates) == 1:
|
||||
and len(x.duplicates) == 0:
|
||||
print('Warning: read before write at register', x)
|
||||
print('\tregister trace: %s' % format_trace(x.caller,
|
||||
'\t\t'))
|
||||
@@ -261,6 +262,7 @@ class Merger:
|
||||
instructions = self.instructions
|
||||
merge_nodes = self.open_nodes
|
||||
depths = self.depths
|
||||
self.req_num = defaultdict(lambda: 0)
|
||||
if not merge_nodes:
|
||||
return 0
|
||||
|
||||
@@ -281,6 +283,7 @@ class Merger:
|
||||
print('Merging %d %s in round %d/%d' % \
|
||||
(len(merge), t.__name__, i, len(merges)))
|
||||
self.do_merge(merge)
|
||||
self.req_num[t.__name__, 'round'] += 1
|
||||
|
||||
preorder = None
|
||||
|
||||
@@ -310,9 +313,9 @@ class Merger:
|
||||
|
||||
reg_nodes = {}
|
||||
last_def = defaultdict_by_id(lambda: -1)
|
||||
last_read = defaultdict_by_id(list)
|
||||
last_mem_write = []
|
||||
last_mem_read = []
|
||||
warned_about_mem = []
|
||||
last_mem_write_of = defaultdict(list)
|
||||
last_mem_read_of = defaultdict(list)
|
||||
last_print_str = None
|
||||
@@ -329,6 +332,8 @@ class Merger:
|
||||
round_type = {}
|
||||
|
||||
def add_edge(i, j):
|
||||
if i in (-1, j):
|
||||
return
|
||||
G.add_edge(i, j)
|
||||
for d in (self.depths, self.real_depths):
|
||||
if d[j] < d[i]:
|
||||
@@ -336,10 +341,15 @@ class Merger:
|
||||
|
||||
def read(reg, n):
|
||||
for dup in reg.duplicates:
|
||||
if last_def[dup] != -1:
|
||||
if last_def[dup] not in (-1, n):
|
||||
add_edge(last_def[dup], n)
|
||||
last_read[reg].append(n)
|
||||
|
||||
def write(reg, n):
|
||||
for dup in reg.duplicates:
|
||||
add_edge(last_def[dup], n)
|
||||
for m in last_read[dup]:
|
||||
add_edge(m, n)
|
||||
last_def[reg] = n
|
||||
|
||||
def handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
@@ -361,20 +371,22 @@ class Merger:
|
||||
addr_i = addr + i
|
||||
handle_mem_access(addr_i, reg_type, last_access_this_kind,
|
||||
last_access_other_kind)
|
||||
if block.warn_about_mem and not warned_about_mem and \
|
||||
(instr.get_size() > 100):
|
||||
if block.warn_about_mem and \
|
||||
not block.parent.warned_about_mem and \
|
||||
(instr.get_size() > 100) and not instr._protect:
|
||||
print('WARNING: Order of memory instructions ' \
|
||||
'not preserved due to long vector, errors possible')
|
||||
warned_about_mem.append(True)
|
||||
block.parent.warned_about_mem = True
|
||||
else:
|
||||
handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
last_access_other_kind)
|
||||
if block.warn_about_mem and not warned_about_mem and \
|
||||
not isinstance(instr, DirectMemoryInstruction):
|
||||
if block.warn_about_mem and \
|
||||
not block.parent.warned_about_mem and \
|
||||
not isinstance(instr, DirectMemoryInstruction) and \
|
||||
not instr._protect:
|
||||
print('WARNING: Order of memory instructions ' \
|
||||
'not preserved, errors possible')
|
||||
# hack
|
||||
warned_about_mem.append(True)
|
||||
block.parent.warned_about_mem = True
|
||||
|
||||
def strict_mem_access(n, last_this_kind, last_other_kind):
|
||||
if last_other_kind and last_this_kind and \
|
||||
@@ -403,6 +415,20 @@ class Merger:
|
||||
add_edge(last_input[t][1], n)
|
||||
last_input[t][0] = n
|
||||
|
||||
def keep_text_order(inst, n):
|
||||
if inst.get_players() is None:
|
||||
# switch
|
||||
for x in list(last_input.keys()):
|
||||
if isinstance(x, int):
|
||||
add_edge(last_input[x][0], n)
|
||||
del last_input[x]
|
||||
keep_merged_order(instr, n, None)
|
||||
elif last_input[None][0] is not None:
|
||||
keep_merged_order(instr, n, None)
|
||||
else:
|
||||
for player in inst.get_players():
|
||||
keep_merged_order(instr, n, player)
|
||||
|
||||
for n,instr in enumerate(block.instructions):
|
||||
outputs,inputs = instr.get_def(), instr.get_used()
|
||||
|
||||
@@ -411,13 +437,6 @@ class Merger:
|
||||
# if options.debug:
|
||||
# col = colordict[instr.__class__.__name__]
|
||||
# G.add_node(n, color=col, label=str(instr))
|
||||
for reg in inputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
read(i, n)
|
||||
else:
|
||||
read(reg, n)
|
||||
|
||||
for reg in outputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
@@ -425,9 +444,16 @@ class Merger:
|
||||
else:
|
||||
write(reg, n)
|
||||
|
||||
for reg in inputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
read(i, n)
|
||||
else:
|
||||
read(reg, n)
|
||||
|
||||
# will be merged
|
||||
if isinstance(instr, TextInputInstruction):
|
||||
keep_merged_order(instr, n, TextInputInstruction)
|
||||
keep_text_order(instr, n)
|
||||
elif isinstance(instr, RawInputInstruction):
|
||||
keep_merged_order(instr, n, RawInputInstruction)
|
||||
|
||||
@@ -456,14 +482,14 @@ class Merger:
|
||||
depths[n] = depth
|
||||
|
||||
if isinstance(instr, ReadMemoryInstruction):
|
||||
if options.preserve_mem_order:
|
||||
if options.preserve_mem_order or instr._protect:
|
||||
strict_mem_access(n, last_mem_read, last_mem_write)
|
||||
else:
|
||||
elif not options.preserve_mem_order:
|
||||
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
|
||||
elif isinstance(instr, WriteMemoryInstruction):
|
||||
if options.preserve_mem_order:
|
||||
if options.preserve_mem_order or instr._protect:
|
||||
strict_mem_access(n, last_mem_write, last_mem_read)
|
||||
else:
|
||||
elif not options.preserve_mem_order:
|
||||
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
|
||||
elif isinstance(instr, matmulsm):
|
||||
if options.preserve_mem_order:
|
||||
@@ -478,11 +504,7 @@ class Merger:
|
||||
add_edge(last_print_str, n)
|
||||
last_print_str = n
|
||||
elif isinstance(instr, PublicFileIOInstruction):
|
||||
keep_order(instr, n, instr.__class__)
|
||||
elif isinstance(instr, startprivateoutput_class):
|
||||
keep_order(instr, n, startprivateoutput_class, 2)
|
||||
elif isinstance(instr, stopprivateoutput_class):
|
||||
keep_order(instr, n, stopprivateoutput_class, 2)
|
||||
keep_order(instr, n, PublicFileIOInstruction)
|
||||
elif isinstance(instr, prep_class):
|
||||
keep_order(instr, n, instr.args[0])
|
||||
elif isinstance(instr, StackInstruction):
|
||||
@@ -520,7 +542,9 @@ class Merger:
|
||||
can_eliminate_defs = True
|
||||
for reg in inst.get_def():
|
||||
for dup in reg.duplicates:
|
||||
if not dup.can_eliminate:
|
||||
if not (dup.can_eliminate and reduce(
|
||||
operator.and_,
|
||||
(x.can_eliminate for x in dup.vector), True)):
|
||||
can_eliminate_defs = False
|
||||
break
|
||||
# remove if instruction has result that isn't used
|
||||
@@ -535,18 +559,6 @@ class Merger:
|
||||
if unused_result:
|
||||
eliminate(i)
|
||||
count += 1
|
||||
# remove unnecessary stack instructions
|
||||
# left by optimization with budget
|
||||
if isinstance(inst, popint_class) and \
|
||||
(not G.degree(i) or (G.degree(i) == 1 and
|
||||
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
|
||||
and \
|
||||
inst.args[0].can_eliminate and \
|
||||
len(G.pred[i]) == 1 and \
|
||||
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
|
||||
eliminate(list(G.pred[i])[0])
|
||||
eliminate(i)
|
||||
count += 2
|
||||
if count > 0 and self.block.parent.program.verbose:
|
||||
print('Eliminated %d dead instructions, among which %d opens: %s' \
|
||||
% (count, open_count, dict(stats)))
|
||||
@@ -570,8 +582,15 @@ class Merger:
|
||||
class RegintOptimizer:
|
||||
def __init__(self):
|
||||
self.cache = util.dict_by_id()
|
||||
self.offset_cache = util.dict_by_id()
|
||||
self.rev_offset_cache = {}
|
||||
|
||||
def run(self, instructions):
|
||||
def add_offset(self, res, new_base, new_offset):
|
||||
self.offset_cache[res] = new_base, new_offset
|
||||
if (new_base.i, new_offset) not in self.rev_offset_cache:
|
||||
self.rev_offset_cache[new_base.i, new_offset] = res
|
||||
|
||||
def run(self, instructions, program):
|
||||
for i, inst in enumerate(instructions):
|
||||
if isinstance(inst, ldint_class):
|
||||
self.cache[inst.args[0]] = inst.args[1]
|
||||
@@ -584,15 +603,35 @@ class RegintOptimizer:
|
||||
instructions[i] = ldint(inst.args[0], res,
|
||||
add_to_prog=False)
|
||||
elif isinstance(inst, addint_class):
|
||||
if inst.args[1] in self.cache and \
|
||||
self.cache[inst.args[1]] == 0:
|
||||
instructions[i] = inst.args[0].link(inst.args[2])
|
||||
elif inst.args[2] in self.cache and \
|
||||
self.cache[inst.args[2]] == 0:
|
||||
instructions[i] = inst.args[0].link(inst.args[1])
|
||||
def f(base, delta_reg):
|
||||
delta = self.cache[delta_reg]
|
||||
if base in self.offset_cache:
|
||||
reg, offset = self.offset_cache[base]
|
||||
new_base, new_offset = reg, offset + delta
|
||||
else:
|
||||
new_base, new_offset = base, delta
|
||||
self.add_offset(inst.args[0], new_base, new_offset)
|
||||
if inst.args[1] in self.cache:
|
||||
f(inst.args[2], inst.args[1])
|
||||
elif inst.args[2] in self.cache:
|
||||
f(inst.args[1], inst.args[2])
|
||||
elif isinstance(inst, subint_class) and \
|
||||
inst.args[2] in self.cache:
|
||||
delta = self.cache[inst.args[2]]
|
||||
if inst.args[1] in self.offset_cache:
|
||||
reg, offset = self.offset_cache[inst.args[1]]
|
||||
new_base, new_offset = reg, offset - delta
|
||||
else:
|
||||
new_base, new_offset = inst.args[1], -delta
|
||||
self.add_offset(inst.args[0], new_base, new_offset)
|
||||
elif isinstance(inst, IndirectMemoryInstruction):
|
||||
if inst.args[1] in self.cache:
|
||||
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
|
||||
instructions[i]._protect = inst._protect
|
||||
elif inst.args[1] in self.offset_cache:
|
||||
base, offset = self.offset_cache[inst.args[1]]
|
||||
addr = self.rev_offset_cache[base.i, offset]
|
||||
inst.args[1] = addr
|
||||
elif type(inst) == convint_class:
|
||||
if inst.args[1] in self.cache:
|
||||
res = self.cache[inst.args[1]]
|
||||
@@ -606,7 +645,13 @@ class RegintOptimizer:
|
||||
if op == 0:
|
||||
instructions[i] = ldsi(inst.args[0], 0,
|
||||
add_to_prog=False)
|
||||
elif op == 1:
|
||||
elif isinstance(inst, (crash, cond_print_str, cond_print_plain)):
|
||||
if inst.args[0] in self.cache:
|
||||
cond = self.cache[inst.args[0]]
|
||||
if not cond:
|
||||
instructions[i] = None
|
||||
inst.args[0].link(inst.args[1])
|
||||
pre = len(instructions)
|
||||
instructions[:] = list(filter(lambda x: x is not None, instructions))
|
||||
post = len(instructions)
|
||||
if pre != post and program.options.verbose:
|
||||
print('regint optimizer removed %d instructions' % (pre - post))
|
||||
|
||||
@@ -63,7 +63,7 @@ class Circuit:
|
||||
i = 0
|
||||
for l in self.n_output_wires:
|
||||
v = []
|
||||
for i in range(l):
|
||||
for j in range(l):
|
||||
v.append(flat_res[i])
|
||||
i += 1
|
||||
res.append(sbitvec.from_vec(v))
|
||||
@@ -127,18 +127,24 @@ def sha3_256(x):
|
||||
|
||||
from circuit import sha3_256
|
||||
a = sbitvec.from_vec([])
|
||||
b = sbitvec(sint(0xcc), 8, 8)
|
||||
for x in a, b:
|
||||
sha3_256(x).elements()[0].reveal().print_reg()
|
||||
b = sbitvec.from_hex('cc')
|
||||
c = sbitvec.from_hex('41fb')
|
||||
d = sbitvec.from_hex('1f877c')
|
||||
e = sbitvec.from_vec([sbit(0)] * 8)
|
||||
for x in a, b, c, d, e:
|
||||
sha3_256(x).reveal_print_hex()
|
||||
|
||||
This should output the first two test vectors of SHA3-256 in
|
||||
byte-reversed order::
|
||||
This should output the `test vectors
|
||||
<https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
|
||||
of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the
|
||||
0 byte::
|
||||
|
||||
0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
|
||||
0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067
|
||||
Reg[0] = 0xa7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a #
|
||||
Reg[0] = 0x677035391cd3701293d385f037ba32796252bb7ce180b00b582dd9b20aaad7f0 #
|
||||
Reg[0] = 0x39f31b6e653dfcd9caed2602fd87f61b6254f581312fb6eeec4d7148fa2e72aa #
|
||||
Reg[0] = 0xbc22345e4bd3f792a341cf18ac0789f1c9c966712a501b19d1b6632ccd408ec5 #
|
||||
Reg[0] = 0x5d53469f20fef4f8eab52b88044ede69c77a6a68a60728609fc4a65ff531e7d0 #
|
||||
|
||||
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
|
||||
implemented for computation modulo a power of two.
|
||||
"""
|
||||
|
||||
global Keccak_f
|
||||
@@ -236,10 +242,10 @@ class ieee_float:
|
||||
return cls._circuits[name]
|
||||
|
||||
def __init__(self, value):
|
||||
if isinstance(value, sbitvec):
|
||||
if isinstance(value, (sbitint, sbitintvec)):
|
||||
self.value = self.circuit('i2f')(sbitvec.conv(value))
|
||||
elif isinstance(value, sbitvec):
|
||||
self.value = value
|
||||
elif isinstance(value, (sbitint, sbitintvec)):
|
||||
self.value = self.circuit('i2f')(sbitvec(value))
|
||||
elif util.is_constant_float(value):
|
||||
self.value = sbitvec(sbits.get_type(64)(
|
||||
struct.unpack('Q', struct.pack('d', value))[0]))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
|
||||
from Compiler.path_oram import *
|
||||
from Compiler.oram import *
|
||||
from Compiler.path_oram import PathORAM, XOR
|
||||
from Compiler.util import bit_compose
|
||||
|
||||
def first_diff(a_bits, b_bits):
|
||||
|
||||
@@ -50,6 +50,9 @@ def set_variant(options):
|
||||
do_precomp = False
|
||||
elif variant is not None:
|
||||
raise CompilerError('Unknown comparison variant: %s' % variant)
|
||||
if const_rounds and instructions_base.program.options.binary:
|
||||
raise CompilerError(
|
||||
'Comparison variant choice incompatible with binary circuits')
|
||||
|
||||
def ld2i(c, n):
|
||||
""" Load immediate 2^n into clear GF(p) register c """
|
||||
@@ -77,30 +80,28 @@ def LTZ(s, a, k, kappa):
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
movs(s, program.non_linear.ltz(a, k, kappa))
|
||||
|
||||
def LtzRing(a, k):
|
||||
from .types import sint, _bitint
|
||||
from .GC.types import sbitvec
|
||||
if program.use_split():
|
||||
summands = a.split_to_two_summands(k)
|
||||
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
|
||||
msb = carry ^ summands[0][-1] ^ summands[1][-1]
|
||||
movs(s, sint.conv(msb))
|
||||
return
|
||||
elif program.options.ring:
|
||||
return sint.conv(msb)
|
||||
else:
|
||||
from . import floatingpoint
|
||||
require_ring_size(k, 'comparison')
|
||||
m = k - 1
|
||||
shift = int(program.options.ring) - k
|
||||
r_prime, r_bin = MaskingBitsInRing(k)
|
||||
tmp = a - r_prime
|
||||
c_prime = (tmp << shift).reveal() >> shift
|
||||
c_prime = (tmp << shift).reveal(False) >> shift
|
||||
a = r_bin[0].bit_decompose_clear(c_prime, m)
|
||||
b = r_bin[:m]
|
||||
u = CarryOutRaw(a[::-1], b[::-1])
|
||||
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
|
||||
return
|
||||
t = sint()
|
||||
Trunc(t, a, k, k - 1, kappa, True)
|
||||
subsfi(s, t, 0)
|
||||
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
|
||||
|
||||
def LessThanZero(a, k, kappa):
|
||||
from . import types
|
||||
@@ -191,7 +192,7 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
r = sint.bit_compose(r_bits)
|
||||
if signed:
|
||||
a += (1 << (k - 1))
|
||||
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
|
||||
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
|
||||
masked = shifted >> n_shift
|
||||
u = sint()
|
||||
BitLTL(u, masked, r_bits[:n_bits], 0)
|
||||
@@ -232,7 +233,7 @@ def Mod2mRing(a_prime, a, k, m, signed):
|
||||
shift = int(program.options.ring) - m
|
||||
r_prime, r_bin = MaskingBitsInRing(m, True)
|
||||
tmp = a + r_prime
|
||||
c_prime = (tmp << shift).reveal() >> shift
|
||||
c_prime = (tmp << shift).reveal(False) >> shift
|
||||
u = sint()
|
||||
BitLTL(u, c_prime, r_bin[:m], 0)
|
||||
res = (u << m) + c_prime - r_prime
|
||||
@@ -262,7 +263,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
t[1] = a
|
||||
adds(t[2], t[0], t[1])
|
||||
adds(t[3], t[2], r_prime)
|
||||
asm_open(c, t[3])
|
||||
asm_open(True, c, t[3])
|
||||
modc(c_prime, c, c2m)
|
||||
if const_rounds:
|
||||
BitLTC1(u, c_prime, r, kappa)
|
||||
@@ -293,7 +294,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
|
||||
"""
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
from .types import sint
|
||||
if program.use_edabit() and m > 1 and not const_rounds:
|
||||
if program.use_edabit() and not const_rounds:
|
||||
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
|
||||
tmp, b[:] = sint.get_edabit(m, True)
|
||||
movs(r_prime, tmp)
|
||||
@@ -511,7 +512,7 @@ def PreMulC_with_inverses_and_vectors(p, a):
|
||||
movs(w[0], r[0])
|
||||
movs(a_vec[0], a[0])
|
||||
vmuls(k, t[0], w, a_vec)
|
||||
vasm_open(k, m, t[0])
|
||||
vasm_open(k, True, m, t[0])
|
||||
PreMulC_end(p, a, c, m, z)
|
||||
|
||||
def PreMulC_with_inverses(p, a):
|
||||
@@ -539,7 +540,7 @@ def PreMulC_with_inverses(p, a):
|
||||
w[1][0] = r[0][0]
|
||||
for i in range(k):
|
||||
muls(t[0][i], w[1][i], a[i])
|
||||
asm_open(m[i], t[0][i])
|
||||
asm_open(True, m[i], t[0][i])
|
||||
PreMulC_end(p, a, c, m, z)
|
||||
|
||||
def PreMulC_without_inverses(p, a):
|
||||
@@ -564,7 +565,7 @@ def PreMulC_without_inverses(p, a):
|
||||
#adds(tt[0][i], t[0][i], a[i])
|
||||
#subs(tt[1][i], tt[0][i], a[i])
|
||||
#startopen(tt[1][i])
|
||||
asm_open(u[i], t[0][i])
|
||||
asm_open(True, u[i], t[0][i])
|
||||
for i in range(k-1):
|
||||
muls(v[i], r[i+1], s[i])
|
||||
w[0] = r[0]
|
||||
@@ -580,7 +581,7 @@ def PreMulC_without_inverses(p, a):
|
||||
mulm(z[i], s[i], u_inv[i])
|
||||
for i in range(k):
|
||||
muls(t[1][i], w[i], a[i])
|
||||
asm_open(m[i], t[1][i])
|
||||
asm_open(True, m[i], t[1][i])
|
||||
PreMulC_end(p, a, c, m, z)
|
||||
|
||||
def PreMulC_end(p, a, c, m, z):
|
||||
@@ -638,6 +639,7 @@ def Mod2(a_0, a, k, kappa, signed):
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
|
||||
r_0 = r_prime
|
||||
mulsi(t[0], r_dprime, 2)
|
||||
if signed:
|
||||
ld2i(c2k1, k - 1)
|
||||
@@ -646,7 +648,7 @@ def Mod2(a_0, a, k, kappa, signed):
|
||||
t[1] = a
|
||||
adds(t[2], t[0], t[1])
|
||||
adds(t[3], t[2], r_prime)
|
||||
asm_open(c, t[3])
|
||||
asm_open(True, c, t[3])
|
||||
from . import floatingpoint
|
||||
c_0 = floatingpoint.bits(c, 1)[0]
|
||||
mulci(tc, c_0, 2)
|
||||
|
||||
@@ -1,94 +1,568 @@
|
||||
from Compiler.program import Program
|
||||
from .GC import types as GC_types
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import re, tempfile, os
|
||||
import tempfile
|
||||
import subprocess
|
||||
from optparse import OptionParser
|
||||
|
||||
from Compiler.exceptions import CompilerError
|
||||
|
||||
from .GC import types as GC_types
|
||||
from .program import Program, defaults
|
||||
|
||||
|
||||
def run(args, options):
|
||||
""" Compile a file and output a Program object.
|
||||
|
||||
If options.merge_opens is set to True, will attempt to merge any
|
||||
parallelisable open instructions. """
|
||||
|
||||
prog = Program(args, options)
|
||||
VARS['program'] = prog
|
||||
if options.binary:
|
||||
VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary))
|
||||
VARS['sfix'] = GC_types.sbitfixvec
|
||||
for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \
|
||||
'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \
|
||||
'squant':
|
||||
del VARS[i]
|
||||
|
||||
print('Compiling file', prog.infile)
|
||||
f = open(prog.infile, 'rb')
|
||||
|
||||
changed = False
|
||||
if options.flow_optimization:
|
||||
output = []
|
||||
if_stack = []
|
||||
for line in open(prog.infile):
|
||||
if if_stack and not re.match(if_stack[-1][0], line):
|
||||
if_stack.pop()
|
||||
m = re.match(
|
||||
'(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):',
|
||||
line)
|
||||
if m:
|
||||
output.append('%s@for_range_opt(%s)\n' % (m.group(1),
|
||||
m.group(3)))
|
||||
output.append('%sdef _(%s):\n' % (m.group(1), m.group(2)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match('(\s*)if(\W.*):', line)
|
||||
if m:
|
||||
if_stack.append((m.group(1), len(output)))
|
||||
output.append('%s@if_(%s)\n' % (m.group(1), m.group(2)))
|
||||
output.append('%sdef _():\n' % (m.group(1)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match('(\s*)elif\s+', line)
|
||||
if m:
|
||||
raise CompilerError('elif not supported')
|
||||
if if_stack:
|
||||
m = re.match('%selse:' % if_stack[-1][0], line)
|
||||
if m:
|
||||
start = if_stack[-1][1]
|
||||
ws = if_stack[-1][0]
|
||||
output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws,
|
||||
output[start])
|
||||
output.append('%s@else_\n' % ws)
|
||||
output.append('%sdef _():\n' % ws)
|
||||
continue
|
||||
output.append(line)
|
||||
if changed:
|
||||
infile = tempfile.NamedTemporaryFile('w+', delete=False)
|
||||
for line in output:
|
||||
infile.write(line)
|
||||
infile.seek(0)
|
||||
class Compiler:
|
||||
def __init__(self, custom_args=None, usage=None, execute=False):
|
||||
if usage:
|
||||
self.usage = usage
|
||||
else:
|
||||
infile = open(prog.infile)
|
||||
else:
|
||||
infile = open(prog.infile)
|
||||
self.usage = "usage: %prog [options] filename [args]"
|
||||
self.execute = execute
|
||||
self.custom_args = custom_args
|
||||
self.build_option_parser()
|
||||
self.VARS = {}
|
||||
self.root = os.path.dirname(__file__) + '/..'
|
||||
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, 'Compiler')
|
||||
# create the tapes
|
||||
exec(compile(infile.read(), infile.name, 'exec'), VARS)
|
||||
def build_option_parser(self):
|
||||
parser = OptionParser(usage=self.usage)
|
||||
parser.add_option(
|
||||
"-n",
|
||||
"--nomerge",
|
||||
action="store_false",
|
||||
dest="merge_opens",
|
||||
default=defaults.merge_opens,
|
||||
help="don't attempt to merge open instructions",
|
||||
)
|
||||
parser.add_option("-o", "--output", dest="outfile", help="specify output file")
|
||||
parser.add_option(
|
||||
"-a",
|
||||
"--asm-output",
|
||||
dest="asmoutfile",
|
||||
help="asm output file for debugging",
|
||||
)
|
||||
parser.add_option(
|
||||
"-g",
|
||||
"--galoissize",
|
||||
dest="galois",
|
||||
default=defaults.galois,
|
||||
help="bit length of Galois field",
|
||||
)
|
||||
parser.add_option(
|
||||
"-d",
|
||||
"--debug",
|
||||
action="store_true",
|
||||
dest="debug",
|
||||
help="keep track of trace for debugging",
|
||||
)
|
||||
parser.add_option(
|
||||
"-c",
|
||||
"--comparison",
|
||||
dest="comparison",
|
||||
default="log",
|
||||
help="comparison variant: log|plain|inv|sinv",
|
||||
)
|
||||
parser.add_option(
|
||||
"-M",
|
||||
"--preserve-mem-order",
|
||||
action="store_true",
|
||||
dest="preserve_mem_order",
|
||||
default=defaults.preserve_mem_order,
|
||||
help="preserve order of memory instructions; possible efficiency loss",
|
||||
)
|
||||
parser.add_option(
|
||||
"-O",
|
||||
"--optimize-hard",
|
||||
action="store_true",
|
||||
dest="optimize_hard",
|
||||
help="lower number of rounds at higher compilation cost "
|
||||
"(disables -C and increases the budget to 100000)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-u",
|
||||
"--noreallocate",
|
||||
action="store_true",
|
||||
dest="noreallocate",
|
||||
default=defaults.noreallocate,
|
||||
help="don't reallocate",
|
||||
)
|
||||
parser.add_option(
|
||||
"-m",
|
||||
"--max-parallel-open",
|
||||
dest="max_parallel_open",
|
||||
default=defaults.max_parallel_open,
|
||||
help="restrict number of parallel opens",
|
||||
)
|
||||
parser.add_option(
|
||||
"-D",
|
||||
"--dead-code-elimination",
|
||||
action="store_true",
|
||||
dest="dead_code_elimination",
|
||||
default=defaults.dead_code_elimination,
|
||||
help="eliminate instructions with unused result",
|
||||
)
|
||||
parser.add_option(
|
||||
"-p",
|
||||
"--profile",
|
||||
action="store_true",
|
||||
dest="profile",
|
||||
help="profile compilation",
|
||||
)
|
||||
parser.add_option(
|
||||
"-s",
|
||||
"--stop",
|
||||
action="store_true",
|
||||
dest="stop",
|
||||
help="stop on register errors",
|
||||
)
|
||||
parser.add_option(
|
||||
"-R",
|
||||
"--ring",
|
||||
dest="ring",
|
||||
default=defaults.ring,
|
||||
help="bit length of ring (default: 0 for field)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-B",
|
||||
"--binary",
|
||||
dest="binary",
|
||||
default=defaults.binary,
|
||||
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-G",
|
||||
"--garbled-circuit",
|
||||
dest="garbled",
|
||||
action="store_true",
|
||||
help="compile for binary circuits only (default: false)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-F",
|
||||
"--field",
|
||||
dest="field",
|
||||
default=defaults.field,
|
||||
help="bit length of sint modulo prime (default: 64)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-P",
|
||||
"--prime",
|
||||
dest="prime",
|
||||
default=defaults.prime,
|
||||
help="prime modulus (default: not specified)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-I",
|
||||
"--insecure",
|
||||
action="store_true",
|
||||
dest="insecure",
|
||||
help="activate insecure functionality for benchmarking",
|
||||
)
|
||||
parser.add_option(
|
||||
"-b",
|
||||
"--budget",
|
||||
dest="budget",
|
||||
help="set budget for optimized loop unrolling (default: %d)" % \
|
||||
defaults.budget,
|
||||
)
|
||||
parser.add_option(
|
||||
"-X",
|
||||
"--mixed",
|
||||
action="store_true",
|
||||
dest="mixed",
|
||||
help="mixing arithmetic and binary computation",
|
||||
)
|
||||
parser.add_option(
|
||||
"-Y",
|
||||
"--edabit",
|
||||
action="store_true",
|
||||
dest="edabit",
|
||||
help="mixing arithmetic and binary computation using edaBits",
|
||||
)
|
||||
parser.add_option(
|
||||
"-Z",
|
||||
"--split",
|
||||
default=defaults.split,
|
||||
dest="split",
|
||||
help="mixing arithmetic and binary computation "
|
||||
"using direct conversion if supported "
|
||||
"(number of parties as argument)",
|
||||
)
|
||||
parser.add_option(
|
||||
"--invperm",
|
||||
action="store_true",
|
||||
dest="invperm",
|
||||
help="speedup inverse permutation (only use in two-party, "
|
||||
"semi-honest environment)"
|
||||
)
|
||||
parser.add_option(
|
||||
"-C",
|
||||
"--CISC",
|
||||
action="store_true",
|
||||
dest="cisc",
|
||||
help="faster CISC compilation mode "
|
||||
"(used by default unless -O is given)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-K",
|
||||
"--keep-cisc",
|
||||
dest="keep_cisc",
|
||||
help="don't translate CISC instructions",
|
||||
)
|
||||
parser.add_option(
|
||||
"-l",
|
||||
"--flow-optimization",
|
||||
action="store_true",
|
||||
dest="flow_optimization",
|
||||
help="optimize control flow",
|
||||
)
|
||||
parser.add_option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
dest="verbose",
|
||||
help="more verbose output",
|
||||
)
|
||||
if self.execute:
|
||||
parser.add_option(
|
||||
"-E",
|
||||
"--execute",
|
||||
dest="execute",
|
||||
help="protocol to execute with",
|
||||
)
|
||||
parser.add_option(
|
||||
"-H",
|
||||
"--hostfile",
|
||||
dest="hostfile",
|
||||
help="hosts to execute with",
|
||||
)
|
||||
self.parser = parser
|
||||
|
||||
if changed and not options.debug:
|
||||
os.unlink(infile.name)
|
||||
def parse_args(self):
|
||||
self.options, self.args = self.parser.parse_args(self.custom_args)
|
||||
if self.execute:
|
||||
if not self.options.execute:
|
||||
raise CompilerError("must give name of protocol with '-E'")
|
||||
protocol = self.options.execute
|
||||
if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \
|
||||
protocol.find("brain") >= 0 or protocol == "emulate":
|
||||
if not (self.options.ring or self.options.binary):
|
||||
self.options.ring = "64"
|
||||
if self.options.field:
|
||||
raise CompilerError(
|
||||
"field option not compatible with %s" % protocol)
|
||||
else:
|
||||
if protocol.find("bin") >= 0 or protocol.find("ccd") >= 0 or \
|
||||
protocol.find("bmr") >= 0 or \
|
||||
protocol in ("replicated", "tinier", "tiny", "yao"):
|
||||
if not self.options.binary:
|
||||
self.options.binary = "32"
|
||||
if self.options.ring or self.options.field:
|
||||
raise CompilerError(
|
||||
"ring/field options not compatible with %s" %
|
||||
protocol)
|
||||
if self.options.ring:
|
||||
raise CompilerError(
|
||||
"ring option not compatible with %s" % protocol)
|
||||
if protocol == "emulate":
|
||||
self.options.keep_cisc = ''
|
||||
|
||||
prog.finalize()
|
||||
def build_program(self, name=None):
|
||||
self.prog = Program(self.args, self.options, name=name)
|
||||
if self.execute:
|
||||
if self.options.execute in \
|
||||
("emulate", "ring", "rep-field"):
|
||||
self.prog.use_trunc_pr = True
|
||||
if self.options.execute in ("ring",):
|
||||
self.prog.use_split(3)
|
||||
if self.options.execute in ("semi2k",):
|
||||
self.prog.use_split(2)
|
||||
if self.options.execute in ("rep4-ring",):
|
||||
self.prog.use_split(4)
|
||||
|
||||
if prog.req_num:
|
||||
print('Program requires:')
|
||||
for x in prog.req_num.pretty():
|
||||
print(x)
|
||||
def build_vars(self):
|
||||
from . import comparison, floatingpoint, instructions, library, types
|
||||
|
||||
if prog.verbose:
|
||||
print('Program requires:', repr(prog.req_num))
|
||||
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
|
||||
print('Memory size:', dict(prog.allocated_mem))
|
||||
# add all instructions to the program VARS dictionary
|
||||
instr_classes = [
|
||||
t[1] for t in inspect.getmembers(instructions, inspect.isclass)
|
||||
]
|
||||
|
||||
return prog
|
||||
for mod in (types, GC_types):
|
||||
instr_classes += [
|
||||
t[1]
|
||||
for t in inspect.getmembers(mod, inspect.isclass)
|
||||
if t[1].__module__ == mod.__name__
|
||||
]
|
||||
|
||||
instr_classes += [
|
||||
t[1]
|
||||
for t in inspect.getmembers(library, inspect.isfunction)
|
||||
if t[1].__module__ == library.__name__
|
||||
]
|
||||
|
||||
for op in instr_classes:
|
||||
self.VARS[op.__name__] = op
|
||||
|
||||
# backward compatibility for deprecated classes
|
||||
self.VARS["sbitint"] = GC_types.sbitintvec
|
||||
self.VARS["sbitfix"] = GC_types.sbitfixvec
|
||||
|
||||
# add open and input separately due to name conflict
|
||||
self.VARS["vopen"] = instructions.vasm_open
|
||||
self.VARS["gopen"] = instructions.gasm_open
|
||||
self.VARS["vgopen"] = instructions.vgasm_open
|
||||
self.VARS["ginput"] = instructions.gasm_input
|
||||
|
||||
self.VARS["comparison"] = comparison
|
||||
self.VARS["floatingpoint"] = floatingpoint
|
||||
|
||||
self.VARS["program"] = self.prog
|
||||
if self.options.binary:
|
||||
self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary))
|
||||
self.VARS["sfix"] = GC_types.sbitfixvec
|
||||
for i in [
|
||||
"cint",
|
||||
"cfix",
|
||||
"cgf2n",
|
||||
"sintbit",
|
||||
"sgf2n",
|
||||
"sgf2nint",
|
||||
"sgf2nuint",
|
||||
"sgf2nuint32",
|
||||
"sgf2nfloat",
|
||||
"cfloat",
|
||||
"squant",
|
||||
]:
|
||||
del self.VARS[i]
|
||||
|
||||
def prep_compile(self, name=None, build=True):
|
||||
self.parse_args()
|
||||
if len(self.args) < 1 and name is None:
|
||||
self.parser.print_help()
|
||||
exit(1)
|
||||
if build:
|
||||
self.build(name=name)
|
||||
|
||||
def build(self, name=None):
|
||||
self.build_program(name=name)
|
||||
self.build_vars()
|
||||
|
||||
def compile_file(self):
|
||||
"""Compile a file and output a Program object.
|
||||
|
||||
If options.merge_opens is set to True, will attempt to merge any
|
||||
parallelisable open instructions."""
|
||||
print("Compiling file", self.prog.infile)
|
||||
|
||||
with open(self.prog.infile, "r") as f:
|
||||
changed = False
|
||||
if self.options.flow_optimization:
|
||||
output = []
|
||||
if_stack = []
|
||||
for line in f:
|
||||
if if_stack and not re.match(if_stack[-1][0], line):
|
||||
if_stack.pop()
|
||||
m = re.match(
|
||||
r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_.]+)\):",
|
||||
line,
|
||||
)
|
||||
if m:
|
||||
output.append(
|
||||
"%s@for_range_opt(%s)\n" % (m.group(1), m.group(3))
|
||||
)
|
||||
output.append("%sdef _(%s):\n" % (m.group(1), m.group(2)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match(r"(\s*)if(\W.*):", line)
|
||||
if m:
|
||||
if_stack.append((m.group(1), len(output)))
|
||||
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
|
||||
output.append("%sdef _():\n" % (m.group(1)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match(r"(\s*)elif\s+", line)
|
||||
if m:
|
||||
raise CompilerError("elif not supported")
|
||||
if if_stack:
|
||||
m = re.match("%selse:" % if_stack[-1][0], line)
|
||||
if m:
|
||||
start = if_stack[-1][1]
|
||||
ws = if_stack[-1][0]
|
||||
output[start] = re.sub(
|
||||
r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start]
|
||||
)
|
||||
output.append("%s@else_\n" % ws)
|
||||
output.append("%sdef _():\n" % ws)
|
||||
continue
|
||||
output.append(line)
|
||||
if changed:
|
||||
infile = tempfile.NamedTemporaryFile("w+", delete=False)
|
||||
for line in output:
|
||||
infile.write(line)
|
||||
infile.seek(0)
|
||||
else:
|
||||
infile = open(self.prog.infile)
|
||||
else:
|
||||
infile = open(self.prog.infile)
|
||||
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, "%s/Compiler" % self.root)
|
||||
# create the tapes
|
||||
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
|
||||
|
||||
if changed and not self.options.debug:
|
||||
os.unlink(infile.name)
|
||||
|
||||
return self.finalize_compile()
|
||||
|
||||
def register_function(self, name=None):
|
||||
"""
|
||||
To register a function to be compiled, use this as a decorator.
|
||||
Example:
|
||||
|
||||
@compiler.register_function('test-mpc')
|
||||
def test_mpc(compiler):
|
||||
...
|
||||
"""
|
||||
|
||||
def inner(func):
|
||||
self.compile_name = name or func.__name__
|
||||
self.compile_function = func
|
||||
return func
|
||||
|
||||
return inner
|
||||
|
||||
def compile_func(self):
|
||||
if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")):
|
||||
raise CompilerError(
|
||||
"No function to compile. "
|
||||
"Did you decorate a function with @register_fuction(name)?"
|
||||
)
|
||||
self.prep_compile(self.compile_name)
|
||||
print(
|
||||
"Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__)
|
||||
)
|
||||
self.compile_function()
|
||||
self.finalize_compile()
|
||||
|
||||
def finalize_compile(self):
|
||||
self.prog.finalize()
|
||||
|
||||
if self.prog.req_num:
|
||||
print("Program requires at most:")
|
||||
for x in self.prog.req_num.pretty():
|
||||
print(x)
|
||||
|
||||
if self.prog.verbose:
|
||||
print("Program requires:", repr(self.prog.req_num))
|
||||
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
|
||||
print("Memory size:", dict(self.prog.allocated_mem))
|
||||
|
||||
return self.prog
|
||||
|
||||
@staticmethod
|
||||
def executable_from_protocol(protocol):
|
||||
match = {
|
||||
"ring": "replicated-ring",
|
||||
"rep-field": "replicated-field",
|
||||
"replicated": "replicated-bin"
|
||||
}
|
||||
if protocol in match:
|
||||
protocol = match[protocol]
|
||||
if protocol.find("bmr") == -1:
|
||||
protocol = re.sub("^mal-", "malicious-", protocol)
|
||||
if protocol == "emulate":
|
||||
return protocol + ".x"
|
||||
else:
|
||||
return protocol + "-party.x"
|
||||
|
||||
def local_execution(self, args=[]):
|
||||
executable = self.executable_from_protocol(self.options.execute)
|
||||
if not os.path.exists("%s/%s" % (self.root, executable)):
|
||||
print("Creating binary for virtual machine...")
|
||||
try:
|
||||
subprocess.run(["make", executable], check=True, cwd=self.root)
|
||||
except:
|
||||
raise CompilerError(
|
||||
"Cannot produce %s. " % executable + \
|
||||
"Note that compilation requires a few GB of RAM.")
|
||||
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
|
||||
os.execl(vm, vm, self.prog.name, *args)
|
||||
|
||||
def remote_execution(self, args=[]):
|
||||
vm = self.executable_from_protocol(self.options.execute)
|
||||
hosts = list(x.strip()
|
||||
for x in filter(None, open(self.options.hostfile)))
|
||||
# test availability before compilation
|
||||
from fabric import Connection
|
||||
import subprocess
|
||||
print("Creating static binary for virtual machine...")
|
||||
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
|
||||
|
||||
# transfer files
|
||||
import glob
|
||||
hostnames = []
|
||||
destinations = []
|
||||
for host in hosts:
|
||||
split = host.split('/', maxsplit=1)
|
||||
hostnames.append(split[0])
|
||||
if len(split) > 1:
|
||||
destinations.append(split[1])
|
||||
else:
|
||||
destinations.append('.')
|
||||
connections = [Connection(hostname) for hostname in hostnames]
|
||||
print("Setting up players...")
|
||||
|
||||
def run(i):
|
||||
dest = destinations[i]
|
||||
connection = connections[i]
|
||||
connection.run(
|
||||
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
|
||||
dest)
|
||||
# executable
|
||||
connection.put("%s/static/%s" % (self.root, vm), dest)
|
||||
# program
|
||||
dest += "/"
|
||||
connection.put("Programs/Schedules/%s.sch" % self.prog.name,
|
||||
dest + "Programs/Schedules")
|
||||
for filename in glob.glob(
|
||||
"Programs/Bytecode/%s-*.bc" % self.prog.name):
|
||||
connection.put(filename, dest + "Programs/Bytecode")
|
||||
# inputs
|
||||
for filename in glob.glob("Player-Data/Input*-P%d-*" % i):
|
||||
connection.put(filename, dest + "Player-Data")
|
||||
# key and certificates
|
||||
for suffix in ('key', 'pem'):
|
||||
connection.put("Player-Data/P%d.%s" % (i, suffix),
|
||||
dest + "Player-Data")
|
||||
for filename in glob.glob("Player-Data/*.0"):
|
||||
connection.put(filename, dest + "Player-Data")
|
||||
|
||||
import threading
|
||||
import random
|
||||
threads = []
|
||||
for i in range(len(hosts)):
|
||||
threads.append(threading.Thread(target=run, args=(i,)))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# execution
|
||||
threads = []
|
||||
# random port numbers to avoid conflict
|
||||
port = 10000 + random.randrange(40000)
|
||||
if '@' in hostnames[0]:
|
||||
party0 = hostnames[0].split('@')[1]
|
||||
else:
|
||||
party0 = hostnames[0]
|
||||
for i in range(len(connections)):
|
||||
run = lambda i: connections[i].run(
|
||||
"cd %s; ./%s -p %d %s -h %s -pn %d %s" % \
|
||||
(destinations[i], vm, i, self.prog.name, party0, port,
|
||||
' '.join(args)))
|
||||
threads.append(threading.Thread(target=run, args=(i,)))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
645
Compiler/decision_tree.py
Normal file
645
Compiler/decision_tree.py
Normal file
@@ -0,0 +1,645 @@
|
||||
from Compiler.types import *
|
||||
from Compiler.sorting import *
|
||||
from Compiler.library import *
|
||||
from Compiler import util, oram
|
||||
|
||||
from itertools import accumulate
|
||||
import math
|
||||
|
||||
debug = False
|
||||
debug_split = False
|
||||
max_leaves = None
|
||||
|
||||
def get_type(x):
|
||||
if isinstance(x, (Array, SubMultiArray)):
|
||||
return x.value_type
|
||||
elif isinstance(x, (tuple, list)):
|
||||
x = x[0] + x[-1]
|
||||
if util.is_constant(x):
|
||||
return cint
|
||||
else:
|
||||
return type(x)
|
||||
else:
|
||||
return type(x)
|
||||
|
||||
def PrefixSum(x):
|
||||
return x.get_vector().prefix_sum()
|
||||
|
||||
def PrefixSumR(x):
|
||||
tmp = get_type(x).Array(len(x))
|
||||
tmp.assign_vector(x)
|
||||
break_point()
|
||||
tmp[:] = tmp.get_reverse_vector().prefix_sum()
|
||||
break_point()
|
||||
return tmp.get_reverse_vector()
|
||||
|
||||
def PrefixSum_inv(x):
|
||||
tmp = get_type(x).Array(len(x) + 1)
|
||||
tmp.assign_vector(x, base=1)
|
||||
tmp[0] = 0
|
||||
return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x))
|
||||
|
||||
def PrefixSumR_inv(x):
|
||||
tmp = get_type(x).Array(len(x) + 1)
|
||||
tmp.assign_vector(x)
|
||||
tmp[-1] = 0
|
||||
return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x))
|
||||
|
||||
class SortPerm:
|
||||
def __init__(self, x):
|
||||
B = sint.Matrix(len(x), 2)
|
||||
B.set_column(0, 1 - x.get_vector())
|
||||
B.set_column(1, x.get_vector())
|
||||
self.perm = Array.create_from(dest_comp(B))
|
||||
def apply(self, x):
|
||||
res = Array.create_from(x)
|
||||
reveal_sort(self.perm, res, False)
|
||||
return res
|
||||
def unapply(self, x):
|
||||
res = Array.create_from(x)
|
||||
reveal_sort(self.perm, res, True)
|
||||
return res
|
||||
|
||||
def Sort(keys, *to_sort, n_bits=None, time=False):
|
||||
if time:
|
||||
start_timer(1)
|
||||
for k in keys:
|
||||
assert len(k) == len(keys[0])
|
||||
n_bits = n_bits or [None] * len(keys)
|
||||
bs = Matrix.create_from(
|
||||
sum([k.get_vector().bit_decompose(nb)
|
||||
for k, nb in reversed(list(zip(keys, n_bits)))], []))
|
||||
get_vec = lambda x: x[:] if isinstance(x, Array) else x
|
||||
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
|
||||
for x in to_sort)
|
||||
res = res.transpose()
|
||||
if time:
|
||||
start_timer(11)
|
||||
radix_sort_from_matrix(bs, res)
|
||||
if time:
|
||||
stop_timer(11)
|
||||
stop_timer(1)
|
||||
res = res.transpose()
|
||||
return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f)
|
||||
if isinstance(get_vec(y), sfix)
|
||||
else x for (x, y) in zip(res, to_sort)]
|
||||
|
||||
def VectMax(key, *data, debug=False):
|
||||
def reducer(x, y):
|
||||
b = x[0] > y[0]
|
||||
if debug:
|
||||
print_ln('max b=%s', b.reveal())
|
||||
return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
|
||||
if debug:
|
||||
key = list(key)
|
||||
data = [list(x) for x in data]
|
||||
print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
|
||||
res = util.tree_reduce(reducer, zip(key, *data))[1:]
|
||||
if debug:
|
||||
print_ln('vect max res=%s', util.reveal(res))
|
||||
return res
|
||||
|
||||
def GroupSum(g, x):
|
||||
assert len(g) == len(x)
|
||||
p = PrefixSumR(x) * g
|
||||
pi = SortPerm(g.get_vector().bit_not())
|
||||
p1 = pi.apply(p)
|
||||
s1 = PrefixSumR_inv(p1)
|
||||
d1 = PrefixSum_inv(s1)
|
||||
d = pi.unapply(d1) * g
|
||||
return PrefixSum(d)
|
||||
|
||||
def GroupPrefixSum(g, x):
|
||||
assert len(g) == len(x)
|
||||
s = get_type(x).Array(len(x) + 1)
|
||||
s[0] = 0
|
||||
s.assign_vector(PrefixSum(x), base=1)
|
||||
q = get_type(s).Array(len(x))
|
||||
q.assign_vector(s.get_vector(size=len(x)) * g)
|
||||
return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
|
||||
|
||||
def GroupMax(g, keys, *x):
|
||||
if debug:
|
||||
print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
|
||||
util.reveal(keys), util.reveal(x))
|
||||
assert len(keys) == len(g)
|
||||
for xx in x:
|
||||
assert len(xx) == len(g)
|
||||
n = len(g)
|
||||
m = int(math.ceil(math.log(n, 2)))
|
||||
keys = Array.create_from(keys)
|
||||
x = [Array.create_from(xx) for xx in x]
|
||||
g_new = Array.create_from(g)
|
||||
g_old = g_new.same_shape()
|
||||
for d in range(m):
|
||||
w = 2 ** d
|
||||
g_old[:] = g_new[:]
|
||||
break_point()
|
||||
vsize = n - w
|
||||
g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
|
||||
g_old.get_vector(size=vsize, base=w)), base=w)
|
||||
b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
|
||||
for xx in [keys] + x:
|
||||
a = b.if_else(xx.get_vector(size=vsize),
|
||||
xx.get_vector(size=vsize, base=w))
|
||||
xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
|
||||
xx.get_vector(size=vsize, base=w), a), base=w)
|
||||
break_point()
|
||||
if debug:
|
||||
print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
|
||||
util.reveal(a), util.reveal(keys),
|
||||
util.reveal(x), g_new.reveal())
|
||||
t = sint.Array(len(g))
|
||||
t[-1] = 1
|
||||
t.assign_vector(g.get_vector(size=n - 1, base=1))
|
||||
if debug:
|
||||
print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
|
||||
util.reveal(t), util.reveal(keys), util.reveal(x))
|
||||
return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
|
||||
|
||||
def ModifiedGini(g, y, debug=False):
|
||||
assert len(g) == len(y)
|
||||
y = [y.get_vector().bit_not(), y]
|
||||
u = [GroupPrefixSum(g, yy) for yy in y]
|
||||
s = [GroupSum(g, yy) for yy in y]
|
||||
w = [ss - uu for ss, uu in zip(s, u)]
|
||||
us = sum(u)
|
||||
ws = sum(w)
|
||||
uqs = u[0] ** 2 + u[1] ** 2
|
||||
wqs = w[0] ** 2 + w[1] ** 2
|
||||
res = sfix(uqs) / us + sfix(wqs) / ws
|
||||
if debug:
|
||||
print_ln('g=%s y=%s s=%s',
|
||||
util.reveal(g), util.reveal(y),
|
||||
util.reveal(s))
|
||||
print_ln('u0=%s', util.reveal(u[0]))
|
||||
print_ln('u0=%s', util.reveal(u[1]))
|
||||
print_ln('us=%s', util.reveal(us))
|
||||
print_ln('w0=%s', util.reveal(w[0]))
|
||||
print_ln('w1=%s', util.reveal(w[1]))
|
||||
print_ln('ws=%s', util.reveal(ws))
|
||||
print_ln('uqs=%s', util.reveal(uqs))
|
||||
print_ln('wqs=%s', util.reveal(wqs))
|
||||
if debug:
|
||||
print_ln('gini %s %s', type(res), util.reveal(res))
|
||||
return res
|
||||
|
||||
MIN_VALUE = -10000
|
||||
|
||||
def FormatLayer(h, g, *a):
|
||||
return CropLayer(h, *FormatLayer_without_crop(g, *a))
|
||||
|
||||
def FormatLayer_without_crop(g, *a, debug=False):
|
||||
for x in a:
|
||||
assert len(x) == len(g)
|
||||
v = [g.if_else(aa, 0) for aa in a]
|
||||
if debug:
|
||||
print_ln('format in %s', util.reveal(a))
|
||||
print_ln('format mux %s', util.reveal(v))
|
||||
v = Sort([g.bit_not()], *v, n_bits=[1])
|
||||
if debug:
|
||||
print_ln('format sort %s', util.reveal(v))
|
||||
return v
|
||||
|
||||
def CropLayer(k, *v):
|
||||
if max_leaves:
|
||||
n = min(2 ** k, max_leaves)
|
||||
else:
|
||||
n = 2 ** k
|
||||
return [vv[:min(n, len(vv))] for vv in v]
|
||||
|
||||
def TrainLeafNodes(h, g, y, NID):
|
||||
assert len(g) == len(y)
|
||||
assert len(g) == len(NID)
|
||||
Label = GroupSum(g, y.bit_not()) < GroupSum(g, y)
|
||||
return FormatLayer(h, g, NID, Label)
|
||||
|
||||
def GroupSame(g, y):
|
||||
assert len(g) == len(y)
|
||||
s = GroupSum(g, [sint(1)] * len(g))
|
||||
s0 = GroupSum(g, y.bit_not())
|
||||
s1 = GroupSum(g, y)
|
||||
if debug_split:
|
||||
print_ln('group same g=%s', util.reveal(g))
|
||||
print_ln('group same y=%s', util.reveal(y))
|
||||
return (s == s0).bit_or(s == s1)
|
||||
|
||||
def GroupFirstOne(g, b):
|
||||
assert len(g) == len(b)
|
||||
s = GroupPrefixSum(g, b)
|
||||
return s * b == 1
|
||||
|
||||
class TreeTrainer:
|
||||
""" Decision tree training by `Hamada et al.`_
|
||||
|
||||
:param x: sample data (by attribute, list or
|
||||
:py:obj:`~Compiler.types.Matrix`)
|
||||
:param y: binary labels (list or sint vector)
|
||||
:param h: height (int)
|
||||
:param binary: binary attributes instead of continuous
|
||||
:param attr_lengths: attribute description for mixed data
|
||||
(list of 0/1 for continuous/binary)
|
||||
:param n_threads: number of threads (default: single thread)
|
||||
|
||||
.. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
|
||||
|
||||
"""
|
||||
def ApplyTests(self, x, AID, Threshold):
|
||||
m = len(x)
|
||||
n = len(AID)
|
||||
assert len(AID) == len(Threshold)
|
||||
for xx in x:
|
||||
assert len(xx) == len(AID)
|
||||
e = sint.Matrix(m, n)
|
||||
AID = Array.create_from(AID)
|
||||
@for_range_multithread(self.n_threads, 1, m)
|
||||
def _(j):
|
||||
e[j][:] = AID[:] == j
|
||||
xx = sum(x[j] * e[j] for j in range(m))
|
||||
if self.debug > 1:
|
||||
print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx))
|
||||
print_ln('threshold %s', util.reveal(Threshold))
|
||||
return 2 * xx < Threshold
|
||||
|
||||
def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False):
|
||||
assert len(g) == len(x)
|
||||
assert len(g) == len(y)
|
||||
if time:
|
||||
start_timer(2)
|
||||
s = ModifiedGini(g, y, debug=debug or self.debug > 2)
|
||||
if time:
|
||||
stop_timer(2)
|
||||
if debug or self.debug > 1:
|
||||
print_ln('gini %s', s.reveal())
|
||||
xx = x
|
||||
t = get_type(x).Array(len(x))
|
||||
t[-1] = MIN_VALUE
|
||||
t.assign_vector(xx.get_vector(size=len(x) - 1) + \
|
||||
xx.get_vector(size=len(x) - 1, base=1))
|
||||
gg = g
|
||||
p = sint.Array(len(x))
|
||||
p[-1] = 1
|
||||
p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
|
||||
xx.get_vector(size=len(x) - 1) == \
|
||||
xx.get_vector(size=len(x) - 1, base=1)))
|
||||
break_point()
|
||||
if debug:
|
||||
print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p))
|
||||
s = p[:].if_else(MIN_VALUE, s)
|
||||
t = p[:].if_else(MIN_VALUE, t[:])
|
||||
if debug:
|
||||
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
|
||||
if time:
|
||||
start_timer(3)
|
||||
s, t = GroupMax(gg, s, t)
|
||||
if time:
|
||||
stop_timer(3)
|
||||
if debug:
|
||||
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
|
||||
return t, s
|
||||
|
||||
def GlobalTestSelection(self, x, y, g):
|
||||
assert len(y) == len(g)
|
||||
for xx in x:
|
||||
assert(len(xx) == len(g))
|
||||
m = len(x)
|
||||
n = len(y)
|
||||
u, t = [get_type(x).Matrix(m, n) for i in range(2)]
|
||||
v = get_type(y).Matrix(m, n)
|
||||
s = sfix.Matrix(m, n)
|
||||
@for_range_multithread(self.n_threads, 1, m)
|
||||
def _(j):
|
||||
single = not self.n_threads or self.n_threads == 1
|
||||
time = self.time and single
|
||||
if debug:
|
||||
print_ln('run %s', j)
|
||||
@if_e(self.attr_lengths[j])
|
||||
def _():
|
||||
u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
|
||||
n_bits=[util.log2(n), 1], time=time)
|
||||
@else_
|
||||
def _():
|
||||
u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
|
||||
n_bits=[util.log2(n), None],
|
||||
time=time)
|
||||
if self.debug_threading:
|
||||
print_ln('global sort %s %s %s', j, util.reveal(u[j]),
|
||||
util.reveal(v[j]))
|
||||
t[j][:], s[j][:] = self.AttributeWiseTestSelection(
|
||||
g, u[j], v[j], time=time, debug=self.debug_selection)
|
||||
if self.debug_threading:
|
||||
print_ln('global attribute %s %s %s', j, util.reveal(t[j]),
|
||||
util.reveal(s[j]))
|
||||
n = len(g)
|
||||
a = sint.Array(n)
|
||||
if self.debug_threading:
|
||||
print_ln('global s=%s', util.reveal(s))
|
||||
if self.debug_gini:
|
||||
print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)),
|
||||
*(ss[0].reveal() for ss in s))
|
||||
if self.time:
|
||||
start_timer(4)
|
||||
if self.debug > 1:
|
||||
print_ln('s=%s', s.reveal_nested())
|
||||
print_ln('t=%s', t.reveal_nested())
|
||||
a[:], tt = VectMax((s[j][:] for j in range(m)), range(m),
|
||||
(t[j][:] for j in range(m)), debug=self.debug > 1)
|
||||
tt = Array.create_from(tt)
|
||||
if self.time:
|
||||
stop_timer(4)
|
||||
if self.debug > 1:
|
||||
print_ln('a=%s', util.reveal(a))
|
||||
print_ln('tt=%s', util.reveal(tt))
|
||||
return a[:], tt[:]
|
||||
|
||||
def TrainInternalNodes(self, k, x, y, g, NID):
|
||||
assert len(g) == len(y)
|
||||
for xx in x:
|
||||
assert len(xx) == len(g)
|
||||
AID, Threshold = self.GlobalTestSelection(x, y, g)
|
||||
s = GroupSame(g[:], y[:])
|
||||
if self.debug > 1 or debug_split:
|
||||
print_ln('AID=%s', util.reveal(AID))
|
||||
print_ln('Threshold=%s', util.reveal(Threshold))
|
||||
print_ln('GroupSame=%s', util.reveal(s))
|
||||
AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold)
|
||||
if self.debug > 1 or debug_split:
|
||||
print_ln('AID=%s', util.reveal(AID))
|
||||
print_ln('Threshold=%s', util.reveal(Threshold))
|
||||
b = self.ApplyTests(x, AID, Threshold)
|
||||
layer = FormatLayer_without_crop(g[:], NID, AID, Threshold,
|
||||
debug=self.debug > 1)
|
||||
return *layer, b
|
||||
|
||||
@method_block
|
||||
def train_layer(self, k):
|
||||
x = self.x
|
||||
y = self.y
|
||||
g = self.g
|
||||
NID = self.NID
|
||||
if self.debug > 1:
|
||||
print_ln('g=%s', g.reveal())
|
||||
print_ln('y=%s', y.reveal())
|
||||
print_ln('x=%s', x.reveal_nested())
|
||||
self.nids[k], self.aids[k], self.thresholds[k], b = \
|
||||
self.TrainInternalNodes(k, x, y, g, NID)
|
||||
if self.debug > 1:
|
||||
print_ln('layer %s:', k)
|
||||
for name, data in zip(('NID', 'AID', 'Thr'),
|
||||
(self.nids[k], self.aids[k],
|
||||
self.thresholds[k])):
|
||||
print_ln(' %s: %s', name, data.reveal())
|
||||
NID[:] = 2 ** k * b + NID
|
||||
b_not = b.bit_not()
|
||||
if self.debug > 1:
|
||||
print_ln('b_not=%s', b_not.reveal())
|
||||
g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b)
|
||||
y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1])
|
||||
for i, xxx in enumerate(xx):
|
||||
x[i] = xxx
|
||||
|
||||
def __init__(self, x, y, h, binary=False, attr_lengths=None,
|
||||
n_threads=None):
|
||||
assert not (binary and attr_lengths)
|
||||
if binary:
|
||||
attr_lengths = [1] * len(x)
|
||||
else:
|
||||
attr_lengths = attr_lengths or ([0] * len(x))
|
||||
for l in attr_lengths:
|
||||
assert l in (0, 1)
|
||||
self.attr_lengths = Array.create_from(regint(attr_lengths))
|
||||
Array.check_indices = False
|
||||
Matrix.disable_index_checks()
|
||||
for xx in x:
|
||||
assert len(xx) == len(y)
|
||||
n = len(y)
|
||||
self.g = sint.Array(n)
|
||||
self.g.assign_all(0)
|
||||
self.g[0] = 1
|
||||
self.NID = sint.Array(n)
|
||||
self.NID.assign_all(1)
|
||||
self.y = Array.create_from(y)
|
||||
self.x = Matrix.create_from(x)
|
||||
self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)]
|
||||
self.thresholds = self.x.value_type.Matrix(h, n)
|
||||
self.n_threads = n_threads
|
||||
self.debug_selection = False
|
||||
self.debug_threading = False
|
||||
self.debug_gini = False
|
||||
self.debug = False
|
||||
self.time = False
|
||||
|
||||
def train(self):
|
||||
""" Train and return decision tree. """
|
||||
h = len(self.nids)
|
||||
@for_range(h)
|
||||
def _(k):
|
||||
self.train_layer(k)
|
||||
return self.get_tree(h)
|
||||
|
||||
def train_with_testing(self, *test_set, output=False):
|
||||
""" Train decision tree and test against test data.
|
||||
|
||||
:param y: binary labels (list or sint vector)
|
||||
:param x: sample data (by attribute, list or
|
||||
:py:obj:`~Compiler.types.Matrix`)
|
||||
:param output: output tree after every level
|
||||
:returns: tree
|
||||
|
||||
"""
|
||||
for k in range(len(self.nids)):
|
||||
self.train_layer(k)
|
||||
tree = self.get_tree(k + 1)
|
||||
if output:
|
||||
output_decision_tree(tree)
|
||||
test_decision_tree('train', tree, self.y, self.x,
|
||||
n_threads=self.n_threads)
|
||||
if test_set:
|
||||
test_decision_tree('test', tree, *test_set,
|
||||
n_threads=self.n_threads)
|
||||
return tree
|
||||
|
||||
def get_tree(self, h):
|
||||
Layer = [None] * (h + 1)
|
||||
for k in range(h):
|
||||
Layer[k] = CropLayer(k, self.nids[k], self.aids[k],
|
||||
self.thresholds[k])
|
||||
Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID)
|
||||
return Layer
|
||||
|
||||
def DecisionTreeTraining(x, y, h, binary=False):
|
||||
return TreeTrainer(x, y, h, binary=binary).train()
|
||||
|
||||
def output_decision_tree(layers):
|
||||
""" Print decision tree output by :py:class:`TreeTrainer`. """
|
||||
print_ln('full model %s', util.reveal(layers))
|
||||
for i, layer in enumerate(layers[:-1]):
|
||||
print_ln('level %s:', i)
|
||||
for j, x in enumerate(('NID', 'AID', 'Thr')):
|
||||
print_ln(' %s: %s', x, util.reveal(layer[j]))
|
||||
print_ln('leaves:')
|
||||
for j, x in enumerate(('NID', 'result')):
|
||||
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))
|
||||
|
||||
def pick(bits, x):
|
||||
if len(bits) == 1:
|
||||
return bits[0] * x[0]
|
||||
else:
|
||||
try:
|
||||
return x[0].dot_product(bits, x)
|
||||
except:
|
||||
return sum(aa * bb for aa, bb in zip(bits, x))
|
||||
|
||||
def run_decision_tree(layers, data):
|
||||
""" Run decision tree against sample data.
|
||||
|
||||
:param layers: tree output by :py:class:`TreeTrainer`
|
||||
:param data: sample data (:py:class:`~Compiler.types.Array`)
|
||||
:returns: binary label
|
||||
|
||||
"""
|
||||
h = len(layers) - 1
|
||||
index = 1
|
||||
for k, layer in enumerate(layers[:-1]):
|
||||
assert len(layer) == 3
|
||||
for x in layer:
|
||||
assert len(x) <= 2 ** k
|
||||
bits = layer[0].equal(index, k)
|
||||
threshold = pick(bits, layer[2])
|
||||
key_index = pick(bits, layer[1])
|
||||
if key_index.is_clear:
|
||||
key = data[key_index]
|
||||
else:
|
||||
key = pick(
|
||||
oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
|
||||
child = 2 * key < threshold
|
||||
index += child * 2 ** k
|
||||
bits = layers[h][0].equal(index, h)
|
||||
return pick(bits, layers[h][1])
|
||||
|
||||
def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
|
||||
if time:
|
||||
start_timer(100)
|
||||
n = len(y)
|
||||
x = x.transpose().reveal()
|
||||
y = y.reveal()
|
||||
guess = regint.Array(n)
|
||||
truth = regint.Array(n)
|
||||
correct = regint.Array(2)
|
||||
parts = regint.Array(2)
|
||||
layers = [[Array.create_from(util.reveal(x)) for x in layer]
|
||||
for layer in layers]
|
||||
@for_range_multithread(n_threads, 1, n)
|
||||
def _(i):
|
||||
guess[i] = run_decision_tree([[part[:] for part in layer]
|
||||
for layer in layers], x[i]).reveal()
|
||||
truth[i] = y[i].reveal()
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
parts[truth[i]] += 1
|
||||
c = (guess[i].bit_xor(truth[i]).bit_not())
|
||||
correct[truth[i]] += c
|
||||
print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
|
||||
sum(correct), n, correct[0], parts[0], correct[1], parts[1])
|
||||
if time:
|
||||
stop_timer(100)
|
||||
|
||||
class TreeClassifier:
|
||||
""" Tree classification with convenient interface. Uses
|
||||
:py:class:`TreeTrainer` internally.
|
||||
|
||||
:param max_depth: the depth of the decision tree
|
||||
|
||||
"""
|
||||
def __init__(self, max_depth):
|
||||
self.max_depth = max_depth
|
||||
|
||||
@staticmethod
|
||||
def get_attr_lengths(attr_types):
|
||||
if attr_types == None:
|
||||
return None
|
||||
else:
|
||||
return [1 if x == 'b' else 0 for x in attr_types]
|
||||
|
||||
def fit(self, X, y, attr_types=None):
|
||||
""" Train tree.
|
||||
|
||||
:param X: sample data with row-wise samples (sint/sfix matrix)
|
||||
:param y: binary labels (sint list/array)
|
||||
|
||||
"""
|
||||
self.tree = TreeTrainer(
|
||||
X.transpose(), y, self.max_depth,
|
||||
attr_lengths=self.get_attr_lengths(attr_types)).train()
|
||||
|
||||
def fit_with_testing(self, X_train, y_train, X_test, y_test,
|
||||
attr_types=None, output_trees=False, debug=False):
|
||||
""" Train tree with accuracy output after every level.
|
||||
|
||||
:param X_train: training data with row-wise samples (sint/sfix matrix)
|
||||
:param y_train: training binary labels (sint list/array)
|
||||
:param X_test: testing data with row-wise samples (sint/sfix matrix)
|
||||
:param y_test: testing binary labels (sint list/array)
|
||||
:param attr_types: attributes types (list of 'b'/'c' for
|
||||
binary/continuous; default is all continuous)
|
||||
:param output_trees: output tree after every level
|
||||
:param debug: output debugging information
|
||||
|
||||
"""
|
||||
trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth,
|
||||
attr_lengths=self.get_attr_lengths(attr_types))
|
||||
trainer.debug = debug
|
||||
trainer.debug_gini = debug
|
||||
trainer.debug_threading = debug > 1
|
||||
self.tree = trainer.train_with_testing(y_test, X_test.transpose(),
|
||||
output=output_trees)
|
||||
|
||||
def predict(self, X):
|
||||
""" Use tree for prediction.
|
||||
|
||||
:param X: sample data with row-wise samples (sint/sfix matrix)
|
||||
:returns: sint array
|
||||
|
||||
"""
|
||||
res = sint.Array(len(X))
|
||||
@for_range(len(X))
|
||||
def _(i):
|
||||
res[i] = run_decision_tree(self.tree, X[i])
|
||||
return res
|
||||
|
||||
def output(self):
|
||||
""" Output decision tree. """
|
||||
output_decision_tree(self.tree)
|
||||
|
||||
def preprocess_pandas(data):
|
||||
""" Preprocess pandas data frame to suit
|
||||
:py:class:`TreeClassifier` by expanding non-continuous attributes
|
||||
to several binary attributes as a unary encoding.
|
||||
|
||||
:returns: a tuple of the processed data and a type list for the
|
||||
:py:obj:`attr_types` argument.
|
||||
|
||||
"""
|
||||
import pandas
|
||||
import numpy
|
||||
res = []
|
||||
types = []
|
||||
for i, t in enumerate(data.dtypes):
|
||||
if pandas.api.types.is_int64_dtype(t):
|
||||
res.append(data.iloc[:,i].to_numpy())
|
||||
types.append('c')
|
||||
elif pandas.api.types.is_object_dtype(t):
|
||||
values = data.iloc[:,i].unique()
|
||||
print('converting the following to unary:', values)
|
||||
if len(values) == 2:
|
||||
res.append(data.iloc[:,i].to_numpy() == values[1])
|
||||
types.append('b')
|
||||
else:
|
||||
for value in values:
|
||||
res.append(data.iloc[:,i].to_numpy() == value)
|
||||
types.append('b')
|
||||
else:
|
||||
raise CompilerError('unknown pandas type: ' + t)
|
||||
res = numpy.array(res)
|
||||
res = numpy.swapaxes(res, 0, 1)
|
||||
return res, types
|
||||
@@ -4,8 +4,11 @@ from Compiler.program import Program
|
||||
|
||||
ORAM = OptimalORAM
|
||||
|
||||
prog = program.Program.prog
|
||||
prog.set_bit_length(min(64, prog.bit_length))
|
||||
try:
|
||||
prog = program.Program.prog
|
||||
prog.set_bit_length(min(64, prog.bit_length))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
class HeapEntry(object):
|
||||
fields = ['empty', 'prio', 'value']
|
||||
@@ -47,9 +50,11 @@ class HeapEntry(object):
|
||||
print_ln('empty %s, prio %s, value %s', *(reveal(x) for x in self))
|
||||
|
||||
class HeapORAM(object):
|
||||
def __init__(self, size, oram_type, init_rounds, int_type):
|
||||
def __init__(self, size, oram_type, init_rounds, int_type, entry_size=None):
|
||||
if entry_size is None:
|
||||
entry_size = (32,log2(size))
|
||||
self.int_type = int_type
|
||||
self.oram = oram_type(size, entry_size=(32,log2(size)), \
|
||||
self.oram = oram_type(size, entry_size=entry_size, \
|
||||
init_rounds=init_rounds, \
|
||||
value_type=int_type.basic_type)
|
||||
def __getitem__(self, index):
|
||||
@@ -74,13 +79,15 @@ class HeapORAM(object):
|
||||
return len(self.oram)
|
||||
|
||||
class HeapQ(object):
|
||||
def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint):
|
||||
def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint, entry_size=None):
|
||||
if entry_size is None:
|
||||
entry_size = (32, log2(max_size))
|
||||
basic_type = int_type.basic_type
|
||||
self.max_size = max_size
|
||||
self.levels = log2(max_size)
|
||||
self.depth = self.levels - 1
|
||||
self.heap = HeapORAM(2**self.levels, oram_type, init_rounds, int_type)
|
||||
self.value_index = oram_type(max_size, entry_size=log2(max_size), \
|
||||
self.heap = HeapORAM(2**self.levels, oram_type, init_rounds, int_type, entry_size=entry_size)
|
||||
self.value_index = oram_type(max_size, entry_size=entry_size[1], \
|
||||
init_rounds=init_rounds, \
|
||||
value_type=basic_type)
|
||||
self.size = MemValue(int_type(0))
|
||||
@@ -99,7 +106,7 @@ class HeapQ(object):
|
||||
bits.reverse()
|
||||
bits = [0] + floatingpoint.PreOR(bits, self.levels)
|
||||
bits = [bits[i+1] - bits[i] for i in range(self.levels)]
|
||||
shift = sum([bit << i for i,bit in enumerate(bits)])
|
||||
shift = self.int_type.bit_compose(bits)
|
||||
childpos = MemValue(start * shift)
|
||||
@for_range(self.levels - 1)
|
||||
def f(i):
|
||||
@@ -215,12 +222,13 @@ class HeapQ(object):
|
||||
print_ln()
|
||||
print_ln()
|
||||
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
|
||||
basic_type = int_type.basic_type
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else -1
|
||||
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
|
||||
init_rounds=vert_loops, value_type=basic_type)
|
||||
init_rounds=vert_loops, value_type=int_type)
|
||||
int_type = dist.value_type
|
||||
basic_type = int_type.basic_type
|
||||
#visited = ORAM(e_index.size)
|
||||
#previous = oram_type(e_index.size)
|
||||
Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \
|
||||
@@ -240,7 +248,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
|
||||
u = MemValue(basic_type(0))
|
||||
@for_range(n_loops or edges.size)
|
||||
def f(i):
|
||||
cint(i).print_reg('loop')
|
||||
print_ln('loop %s', i)
|
||||
time()
|
||||
u.write(if_else(last_edge, Q.pop(last_edge), u))
|
||||
#visited.access(u, True, last_edge)
|
||||
|
||||
@@ -12,4 +12,7 @@ class ArgumentError(CompilerError):
|
||||
""" Exception raised for errors in instruction argument parsing. """
|
||||
def __init__(self, arg, msg):
|
||||
self.arg = arg
|
||||
self.msg = msg
|
||||
self.msg = msg
|
||||
|
||||
class VectorMismatch(CompilerError):
|
||||
pass
|
||||
|
||||
@@ -28,13 +28,15 @@ def shift_two(n, pos):
|
||||
|
||||
def maskRing(a, k):
|
||||
shift = int(program.Program.prog.options.ring) - k
|
||||
if program.Program.prog.use_dabit:
|
||||
if program.Program.prog.use_edabit():
|
||||
r_prime, r = types.sint.get_edabit(k)
|
||||
elif program.Program.prog.use_dabit:
|
||||
rr, r = zip(*(types.sint.get_dabit() for i in range(k)))
|
||||
r_prime = types.sint.bit_compose(rr)
|
||||
else:
|
||||
r = [types.sint.get_random_bit() for i in range(k)]
|
||||
r_prime = types.sint.bit_compose(r)
|
||||
c = ((a + r_prime) << shift).reveal() >> shift
|
||||
c = ((a + r_prime) << shift).reveal(False) >> shift
|
||||
return c, r
|
||||
|
||||
def maskField(a, k, kappa):
|
||||
@@ -45,7 +47,7 @@ def maskField(a, k, kappa):
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
|
||||
# always signed due to usage in equality testing
|
||||
a += two_power(k)
|
||||
asm_open(c, a + two_power(k) * r_dprime + r_prime)
|
||||
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
|
||||
return c, r
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
@@ -231,7 +233,7 @@ def Inv(a):
|
||||
ldi(one, 1)
|
||||
inverse(t[0], t[1])
|
||||
s = t[0]*a
|
||||
asm_open(c[0], s)
|
||||
asm_open(True, c[0], s)
|
||||
# avoid division by zero for benchmarking
|
||||
divc(c[1], one, c[0])
|
||||
#divc(c[1], c[0], one)
|
||||
@@ -279,7 +281,7 @@ def BitDecRingRaw(a, k, m):
|
||||
else:
|
||||
r_bits = [types.sint.get_random_bit() for i in range(m)]
|
||||
r = types.sint.bit_compose(r_bits)
|
||||
shifted = ((a - r) << n_shift).reveal()
|
||||
shifted = ((a - r) << n_shift).reveal(False)
|
||||
masked = shifted >> n_shift
|
||||
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
|
||||
return bits
|
||||
@@ -287,7 +289,7 @@ def BitDecRingRaw(a, k, m):
|
||||
def BitDecRing(a, k, m):
|
||||
bits = BitDecRingRaw(a, k, m)
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
|
||||
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
instructions_base.set_global_vector_size(a.size)
|
||||
@@ -297,18 +299,19 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
r = [types.sint() for i in range(m)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
pow2 = two_power(k + kappa)
|
||||
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
|
||||
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
|
||||
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
|
||||
instructions_base.reset_global_vector_size()
|
||||
return res
|
||||
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
|
||||
return [types.sint.conv(bit) for bit in res]
|
||||
return [types.sintbit.conv(bit) for bit in res]
|
||||
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def Pow2(a, l, kappa):
|
||||
comparison.program.curr_tape.require_bit_length(l - 1)
|
||||
m = int(ceil(log(l, 2)))
|
||||
t = BitDec(a, m, m, kappa)
|
||||
return Pow2_from_bits(t)
|
||||
@@ -316,7 +319,7 @@ def Pow2(a, l, kappa):
|
||||
def Pow2_from_bits(bits):
|
||||
m = len(bits)
|
||||
t = list(bits)
|
||||
pow2k = [types.cint() for i in range(m)]
|
||||
pow2k = [None for i in range(m)]
|
||||
for i in range(m):
|
||||
pow2k[i] = two_power(2**i)
|
||||
t[i] = t[i]*pow2k[i] + 1 - t[i]
|
||||
@@ -339,10 +342,10 @@ def B2U_from_Pow2(pow2a, l, kappa):
|
||||
if program.Program.prog.options.ring:
|
||||
n_shift = int(program.Program.prog.options.ring) - l
|
||||
assert n_shift > 0
|
||||
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift
|
||||
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift
|
||||
else:
|
||||
comparison.PRandInt(t, kappa)
|
||||
asm_open(c, pow2a + two_power(l) * t +
|
||||
asm_open(True, c, pow2a + two_power(l) * t +
|
||||
sum(two_power(i) * r[i] for i in range(l)))
|
||||
comparison.program.curr_tape.require_bit_length(l + kappa)
|
||||
c = list(r_bits[0].bit_decompose_clear(c, l))
|
||||
@@ -384,15 +387,15 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
r_dprime += t1 - t2
|
||||
if program.Program.prog.options.ring:
|
||||
n_shift = int(program.Program.prog.options.ring) - l
|
||||
c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift
|
||||
c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift
|
||||
else:
|
||||
comparison.PRandInt(rk, kappa)
|
||||
r_dprime += two_power(l) * rk
|
||||
asm_open(c, a + r_dprime + r_prime)
|
||||
asm_open(True, c, a + r_dprime + r_prime)
|
||||
for i in range(1,l):
|
||||
ci[i] = c % two_power(i)
|
||||
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
|
||||
lts(d, c_dprime, r_prime, l, kappa)
|
||||
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
|
||||
if compute_modulo:
|
||||
b = c_dprime - r_prime + pow2m * d
|
||||
return b, pow2m
|
||||
@@ -414,7 +417,7 @@ def TruncInRing(to_shift, l, pow2m):
|
||||
rev *= pow2m
|
||||
r_bits = [types.sint.get_random_bit() for i in range(l)]
|
||||
r = types.sint.bit_compose(r_bits)
|
||||
shifted = (rev - (r << n_shift)).reveal()
|
||||
shifted = (rev - (r << n_shift)).reveal(False)
|
||||
masked = shifted >> n_shift
|
||||
bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l))
|
||||
return types.sint.bit_compose(reversed(bits))
|
||||
@@ -455,7 +458,7 @@ def Int2FL(a, gamma, l, kappa=None):
|
||||
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
|
||||
else:
|
||||
v = 2**(l-gamma+1) * t
|
||||
p = (p + gamma - 1 - l) * (1 -z)
|
||||
p = (p + gamma - 1 - l) * z.bit_not()
|
||||
return v, p, z, s
|
||||
|
||||
def FLRound(x, mode):
|
||||
@@ -528,7 +531,7 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
msb = r_bits[-1]
|
||||
n_shift = n_ring - (k + 1)
|
||||
tmp = a + r
|
||||
masked = (tmp << n_shift).reveal()
|
||||
masked = (tmp << n_shift).reveal(False)
|
||||
shifted = (masked << 1 >> (n_shift + m + 1))
|
||||
overflow = msb.bit_xor(masked >> (n_ring - 1))
|
||||
res = shifted - upper + \
|
||||
@@ -549,7 +552,7 @@ def TruncPrField(a, k, m, kappa=None):
|
||||
k, m, kappa, use_dabit=False)
|
||||
two_to_m = two_power(m)
|
||||
r = two_to_m * r_dprime + r_prime
|
||||
c = (b + r).reveal()
|
||||
c = (b + r).reveal(False)
|
||||
c_prime = c % two_to_m
|
||||
a_prime = c_prime - r_prime
|
||||
d = (a - a_prime) / two_to_m
|
||||
@@ -629,14 +632,16 @@ def BITLT(a, b, bit_length):
|
||||
# - From the paper
|
||||
# Multiparty Computation for Interval, Equality, and Comparison without
|
||||
# Bit-Decomposition Protocol
|
||||
def BitDecFull(a, maybe_mixed=False):
|
||||
def BitDecFull(a, n_bits=None, maybe_mixed=False):
|
||||
from .library import get_program, do_while, if_, break_point
|
||||
from .types import sint, regint, longint, cint
|
||||
p = get_program().prime
|
||||
assert p
|
||||
bit_length = p.bit_length()
|
||||
n_bits = n_bits or bit_length
|
||||
assert n_bits <= bit_length
|
||||
logp = int(round(math.log(p, 2)))
|
||||
if abs(p - 2 ** logp) / p < 2 ** -get_program().security:
|
||||
if get_program().rabbit_gap():
|
||||
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
|
||||
# no need for exact randomness generation
|
||||
# if modulo a power of two is close enough
|
||||
@@ -663,26 +668,26 @@ def BitDecFull(a, maybe_mixed=False):
|
||||
def _():
|
||||
for i in range(bit_length):
|
||||
tbits[j][i].link(sint.get_random_bit())
|
||||
c = regint(BITLT(tbits[j], pbits, bit_length).reveal())
|
||||
c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False))
|
||||
done[j].link(c)
|
||||
return (sum(done) != a.size)
|
||||
for j in range(a.size):
|
||||
for i in range(bit_length):
|
||||
movs(bbits[i][j], tbits[j][i])
|
||||
b = sint.bit_compose(bbits)
|
||||
c = (a-b).reveal()
|
||||
c = (a-b).reveal(False)
|
||||
cmodp = c
|
||||
t = bbits[0].bit_decompose_clear(p - c, bit_length)
|
||||
c = longint(c, bit_length)
|
||||
czero = (c==0)
|
||||
q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
|
||||
fbar = [bbits[0].clear_type.conv(cint(x))
|
||||
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
|
||||
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
|
||||
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
|
||||
for x in ((1<<bit_length)+c-p).bit_decompose(n_bits)]
|
||||
fbard = bbits[0].bit_decompose_clear(cmodp, n_bits)
|
||||
g = [q.if_else(fbar[i], fbard[i]) for i in range(n_bits)]
|
||||
h = bbits[0].bit_adder(bbits, g)
|
||||
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
|
||||
for i in range(bit_length)]
|
||||
for i in range(n_bits)]
|
||||
if maybe_mixed:
|
||||
return abits
|
||||
else:
|
||||
|
||||
@@ -17,6 +17,7 @@ right order.
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
import math
|
||||
from . import tools
|
||||
from random import randint
|
||||
from functools import reduce
|
||||
@@ -69,7 +70,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDMC']
|
||||
arg_format = ['cw','int']
|
||||
arg_format = ['cw','long']
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -84,7 +85,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDMS']
|
||||
arg_format = ['sw','int']
|
||||
arg_format = ['sw','long']
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -99,7 +100,7 @@ class stmc(base.DirectMemoryWriteInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['STMC']
|
||||
arg_format = ['c','int']
|
||||
arg_format = ['c','long']
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -114,7 +115,7 @@ class stms(base.DirectMemoryWriteInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['STMS']
|
||||
arg_format = ['s','int']
|
||||
arg_format = ['s','long']
|
||||
|
||||
@base.vectorize
|
||||
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
@@ -128,7 +129,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDMINT']
|
||||
arg_format = ['ciw','int']
|
||||
arg_format = ['ciw','long']
|
||||
|
||||
@base.vectorize
|
||||
class stmint(base.DirectMemoryWriteInstruction):
|
||||
@@ -142,7 +143,7 @@ class stmint(base.DirectMemoryWriteInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['STMINT']
|
||||
arg_format = ['ci','int']
|
||||
arg_format = ['ci','long']
|
||||
|
||||
@base.vectorize
|
||||
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
@@ -294,6 +295,7 @@ class movint(base.Instruction):
|
||||
@base.vectorize
|
||||
class pushint(base.StackInstruction):
|
||||
""" Pushes clear integer register to the thread-local stack.
|
||||
Considered obsolete.
|
||||
|
||||
:param: source (regint)
|
||||
"""
|
||||
@@ -303,6 +305,7 @@ class pushint(base.StackInstruction):
|
||||
@base.vectorize
|
||||
class popint(base.StackInstruction):
|
||||
""" Pops from the thread-local stack to clear integer register.
|
||||
Considered obsolete.
|
||||
|
||||
:param: destination (regint)
|
||||
"""
|
||||
@@ -353,7 +356,17 @@ class reqbl(base.Instruction):
|
||||
code = base.opcodes['REQBL']
|
||||
arg_format = ['int']
|
||||
|
||||
class active(base.Instruction):
|
||||
""" Indicate whether program is compatible with malicious-security
|
||||
protocols.
|
||||
|
||||
:param: 0 for no, 1 for yes
|
||||
"""
|
||||
code = base.opcodes['ACTIVE']
|
||||
arg_format = ['int']
|
||||
|
||||
class time(base.IOInstruction):
|
||||
|
||||
""" Output time since start of computation. """
|
||||
code = base.opcodes['TIME']
|
||||
arg_format = []
|
||||
@@ -384,7 +397,15 @@ class use(base.Instruction):
|
||||
:param: number (int, -1 for unknown)
|
||||
"""
|
||||
code = base.opcodes['USE']
|
||||
arg_format = ['int','int','int']
|
||||
arg_format = ['int','int','long']
|
||||
|
||||
@classmethod
|
||||
def get_usage(cls, args):
|
||||
from .program import field_types, data_types
|
||||
from .util import find_in_dict
|
||||
return {(find_in_dict(field_types, args[0].i),
|
||||
find_in_dict(data_types, args[1].i)):
|
||||
args[2].i}
|
||||
|
||||
class use_inp(base.Instruction):
|
||||
""" Input usage. Necessary to avoid reusage while using
|
||||
@@ -395,7 +416,14 @@ class use_inp(base.Instruction):
|
||||
:param: number (int, -1 for unknown)
|
||||
"""
|
||||
code = base.opcodes['USE_INP']
|
||||
arg_format = ['int','int','int']
|
||||
arg_format = ['int','int','long']
|
||||
|
||||
@classmethod
|
||||
def get_usage(cls, args):
|
||||
from .program import field_types, data_types
|
||||
from .util import find_in_dict
|
||||
return {(find_in_dict(field_types, args[0].i), 'input', args[1].i):
|
||||
args[2].i}
|
||||
|
||||
class use_edabit(base.Instruction):
|
||||
""" edaBit usage. Necessary to avoid reusage while using
|
||||
@@ -407,7 +435,11 @@ class use_edabit(base.Instruction):
|
||||
:param: number (int, -1 for unknown)
|
||||
"""
|
||||
code = base.opcodes['USE_EDABIT']
|
||||
arg_format = ['int','int','int']
|
||||
arg_format = ['int','int','long']
|
||||
|
||||
@classmethod
|
||||
def get_usage(cls, args):
|
||||
return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i}
|
||||
|
||||
class use_matmul(base.Instruction):
|
||||
""" Matrix multiplication usage. Used for multithreading of
|
||||
@@ -419,7 +451,11 @@ class use_matmul(base.Instruction):
|
||||
:param: number (int, -1 for unknown)
|
||||
"""
|
||||
code = base.opcodes['USE_MATMUL']
|
||||
arg_format = ['int','int','int','int']
|
||||
arg_format = ['int','int','int','long']
|
||||
|
||||
@classmethod
|
||||
def get_usage(cls, args):
|
||||
return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i}
|
||||
|
||||
class run_tape(base.Instruction):
|
||||
""" Start tape/bytecode file in another thread.
|
||||
@@ -442,7 +478,7 @@ class join_tape(base.Instruction):
|
||||
arg_format = ['int']
|
||||
|
||||
class crash(base.IOInstruction):
|
||||
""" Crash runtime if the register's value is > 0.
|
||||
""" Crash runtime if the value in the register is not zero.
|
||||
|
||||
:param: Crash condition (regint)"""
|
||||
code = base.opcodes['CRASH']
|
||||
@@ -464,7 +500,12 @@ class use_prep(base.Instruction):
|
||||
:param: number of items to use (int, -1 for unknown)
|
||||
"""
|
||||
code = base.opcodes['USE_PREP']
|
||||
arg_format = ['str','int']
|
||||
arg_format = ['str','long']
|
||||
|
||||
@classmethod
|
||||
def get_usage(cls, args):
|
||||
return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp',
|
||||
args[0].str): args[1].i}
|
||||
|
||||
class nplayers(base.Instruction):
|
||||
""" Store number of players in clear integer register.
|
||||
@@ -585,6 +626,18 @@ class submr(base.SubBase):
|
||||
code = base.opcodes['SUBMR']
|
||||
arg_format = ['sw','c','s']
|
||||
|
||||
@base.vectorize
|
||||
class prefixsums(base.Instruction):
|
||||
""" Prefix sum.
|
||||
|
||||
:param: result (sint)
|
||||
:param: input (sint)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['PREFIXSUMS']
|
||||
arg_format = ['sw','s']
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class mulc(base.MulBase):
|
||||
@@ -778,30 +831,6 @@ class gbitcom(base.Instruction):
|
||||
return True
|
||||
|
||||
|
||||
###
|
||||
### Special GF(2) arithmetic instructions
|
||||
###
|
||||
|
||||
@base.vectorize
|
||||
class gmulbitc(base.MulBase):
|
||||
r""" Clear GF(2^n) by clear GF(2) multiplication """
|
||||
__slots__ = []
|
||||
code = base.opcodes['GMULBITC']
|
||||
arg_format = ['cgw','cg','cg']
|
||||
|
||||
def is_gf2n(self):
|
||||
return True
|
||||
|
||||
@base.vectorize
|
||||
class gmulbitm(base.MulBase):
|
||||
r""" Secret GF(2^n) by clear GF(2) multiplication """
|
||||
__slots__ = []
|
||||
code = base.opcodes['GMULBITM']
|
||||
arg_format = ['sgw','sg','cg']
|
||||
|
||||
def is_gf2n(self):
|
||||
return True
|
||||
|
||||
###
|
||||
### Arithmetic with immediate values
|
||||
###
|
||||
@@ -1046,6 +1075,7 @@ class shrci(base.ClearShiftInstruction):
|
||||
code = base.opcodes['SHRCI']
|
||||
op = '__rshift__'
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class shrsi(base.ClearShiftInstruction):
|
||||
""" Bitwise right shift of secret register (vector) by (constant)
|
||||
@@ -1184,7 +1214,7 @@ class randoms(base.Instruction):
|
||||
field_type = 'modp'
|
||||
|
||||
@base.vectorize
|
||||
class randomfulls(base.Instruction):
|
||||
class randomfulls(base.DataInstruction):
|
||||
""" Store share(s) of a fresh secret random element in secret
|
||||
register (vectors).
|
||||
|
||||
@@ -1194,6 +1224,10 @@ class randomfulls(base.Instruction):
|
||||
code = base.opcodes['RANDOMFULLS']
|
||||
arg_format = ['sw']
|
||||
field_type = 'modp'
|
||||
data_type = 'random'
|
||||
|
||||
def get_repeat(self):
|
||||
return len(self.args)
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -1229,15 +1263,20 @@ class inverse(base.DataInstruction):
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class inputmask(base.Instruction):
|
||||
r""" Load secret $s_i$ with the next input mask for player $p$ and
|
||||
write the mask on player $p$'s private output. """
|
||||
""" Store fresh random input mask(s) in secret register (vector) and clear
|
||||
register (vector) of the relevant player.
|
||||
|
||||
:param: mask (sint)
|
||||
:param: mask (cint, player only)
|
||||
:param: player (int)
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['INPUTMASK']
|
||||
arg_format = ['sw', 'p']
|
||||
arg_format = ['sw', 'cw', 'p']
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, 'input', self.args[1]), \
|
||||
req_node.increment((self.field_type, 'input', self.args[2]), \
|
||||
self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
@@ -1275,7 +1314,7 @@ class prep(base.Instruction):
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, self.args[0]), 1)
|
||||
req_node.increment((self.field_type, self.args[0]), self.get_size())
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -1293,10 +1332,8 @@ class asm_input(base.TextInputInstruction):
|
||||
arg_format = tools.cycle(['sw', 'p'])
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for player in self.args[1::2]:
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
def get_players(self):
|
||||
return self.args[1::2]
|
||||
|
||||
@base.vectorize
|
||||
class inputfix(base.TextInputInstruction):
|
||||
@@ -1305,10 +1342,8 @@ class inputfix(base.TextInputInstruction):
|
||||
arg_format = tools.cycle(['sw', 'int', 'p'])
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for player in self.args[2::3]:
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
def get_players(self):
|
||||
return self.args[2::3]
|
||||
|
||||
@base.vectorize
|
||||
class inputfloat(base.TextInputInstruction):
|
||||
@@ -1322,7 +1357,7 @@ class inputfloat(base.TextInputInstruction):
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
4 * self.get_size())
|
||||
|
||||
class inputmixed_base(base.TextInputInstruction):
|
||||
class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction):
|
||||
__slots__ = []
|
||||
field_type = 'modp'
|
||||
# the following has to match TYPE: (N_DEST, N_PARAM)
|
||||
@@ -1341,22 +1376,30 @@ class inputmixed_base(base.TextInputInstruction):
|
||||
type_id = self.type_ids[name]
|
||||
super(inputmixed_base, self).__init__(type_id, *args)
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
for i in self.bases():
|
||||
t = self.args[i]
|
||||
yield 'int'
|
||||
@classmethod
|
||||
def dynamic_arg_format(self, args):
|
||||
yield 'int'
|
||||
for i, t in self.bases(iter(args)):
|
||||
for j in range(self.types[t][0]):
|
||||
yield 'sw'
|
||||
for j in range(self.types[t][1]):
|
||||
yield 'int'
|
||||
yield self.player_arg_type
|
||||
yield 'int'
|
||||
|
||||
def bases(self):
|
||||
@classmethod
|
||||
def bases(self, args):
|
||||
i = 0
|
||||
while i < len(self.args):
|
||||
yield i
|
||||
i += sum(self.types[self.args[i]]) + 2
|
||||
while True:
|
||||
try:
|
||||
t = next(args)
|
||||
except StopIteration:
|
||||
return
|
||||
yield i, t
|
||||
n = sum(self.types[t])
|
||||
i += n + 2
|
||||
for j in range(n + 1):
|
||||
next(args)
|
||||
|
||||
@base.vectorize
|
||||
class inputmixed(inputmixed_base):
|
||||
@@ -1380,14 +1423,16 @@ class inputmixed(inputmixed_base):
|
||||
player_arg_type = 'p'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for i in self.bases():
|
||||
t = self.args[i]
|
||||
for i, t in self.bases(iter(self.args)):
|
||||
player = self.args[i + sum(self.types[t]) + 1]
|
||||
n_dest = self.types[t][0]
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
n_dest * self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
def get_players(self):
|
||||
for i, t in self.bases(iter(self.args)):
|
||||
yield self.args[i + sum(self.types[t]) + 1]
|
||||
|
||||
class inputmixedreg(inputmixed_base):
|
||||
""" Store private input in secret registers (vectors). The input is
|
||||
read as integer or floating-point number and the latter is then
|
||||
@@ -1407,11 +1452,29 @@ class inputmixedreg(inputmixed_base):
|
||||
"""
|
||||
code = base.opcodes['INPUTMIXEDREG']
|
||||
player_arg_type = 'ci'
|
||||
is_vec = lambda self: True
|
||||
|
||||
def __init__(self, *args):
|
||||
inputmixed_base.__init__(self, *args)
|
||||
for i, t in self.bases(iter(self.args)):
|
||||
n = self.types[t][0]
|
||||
for j in range(i + 1, i + 1 + n):
|
||||
assert args[j].size == self.get_size()
|
||||
|
||||
def get_size(self):
|
||||
return self.args[1].size
|
||||
|
||||
def get_code(self):
|
||||
return inputmixed_base.get_code(
|
||||
self, self.get_size() if self.get_size() > 1 else 0)
|
||||
|
||||
def add_usage(self, req_node):
|
||||
# player 0 as proxy
|
||||
req_node.increment((self.field_type, 'input', 0), float('inf'))
|
||||
|
||||
def get_players(self):
|
||||
pass
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class rawinput(base.RawInputInstruction, base.Mergeable):
|
||||
@@ -1433,7 +1496,23 @@ class rawinput(base.RawInputInstruction, base.Mergeable):
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
|
||||
class inputpersonal(base.Instruction, base.Mergeable):
|
||||
class personal_base(base.Instruction, base.Mergeable):
|
||||
__slots__ = []
|
||||
field_type = 'modp'
|
||||
|
||||
def __init__(self, *args):
|
||||
super(personal_base, self).__init__(*args)
|
||||
for i in range(0, len(args), 4):
|
||||
assert args[i + 2].size == args[i]
|
||||
assert args[i + 3].size == args[i]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for i in range(0, len(self.args), 4):
|
||||
player = self.args[i + 1]
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.args[i])
|
||||
|
||||
class inputpersonal(personal_base):
|
||||
""" Private input from cint.
|
||||
|
||||
:param: vector size (int)
|
||||
@@ -1445,19 +1524,47 @@ class inputpersonal(base.Instruction, base.Mergeable):
|
||||
__slots__ = []
|
||||
code = base.opcodes['INPUTPERSONAL']
|
||||
arg_format = tools.cycle(['int','p','sw','c'])
|
||||
field_type = 'modp'
|
||||
|
||||
def __init__(self, *args):
|
||||
super(inputpersonal, self).__init__(*args)
|
||||
for i in range(0, len(args), 4):
|
||||
assert args[i + 2].size == args[i]
|
||||
assert args[i + 3].size == args[i]
|
||||
class privateoutput(personal_base, base.DataInstruction):
|
||||
""" Private output to cint.
|
||||
|
||||
:param: vector size (int)
|
||||
:param: player (int)
|
||||
:param: destination (cint)
|
||||
:param: source (sint)
|
||||
:param: (repeat from vector size)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['PRIVATEOUTPUT']
|
||||
arg_format = tools.cycle(['int','p','cw','s'])
|
||||
data_type = 'open'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for i in range(0, len(self.args), 4):
|
||||
player = self.args[i + 1]
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.args[i])
|
||||
personal_base.add_usage(self, req_node)
|
||||
base.DataInstruction.add_usage(self, req_node)
|
||||
|
||||
def get_repeat(self):
|
||||
return sum(self.args[::4])
|
||||
|
||||
class sendpersonal(base.Instruction, base.Mergeable):
|
||||
""" Private input from cint.
|
||||
|
||||
:param: vector size (int)
|
||||
:param: destination player (int)
|
||||
:param: destination (cint)
|
||||
:param: source player (int)
|
||||
:param: source (cint)
|
||||
:param: (repeat from vector size)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['SENDPERSONAL']
|
||||
arg_format = tools.cycle(['int','p','cw','p','c'])
|
||||
|
||||
def __init__(self, *args):
|
||||
super(sendpersonal, self).__init__(*args)
|
||||
for i in range(0, len(args), 5):
|
||||
assert args[i + 2].size == args[i]
|
||||
assert args[i + 4].size == args[i]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -1546,7 +1653,7 @@ class print_char(base.IOInstruction):
|
||||
arg_format = ['int']
|
||||
|
||||
def __init__(self, ch):
|
||||
super(print_char, self).__init__(ord(ch))
|
||||
super(print_char, self).__init__(ch)
|
||||
|
||||
class print_char4(base.IOInstruction):
|
||||
""" Output four bytes.
|
||||
@@ -1650,6 +1757,7 @@ class writesockets(base.IOInstruction):
|
||||
from registers into a socket for a specified client id. If the
|
||||
protocol uses MACs, the client should be different for every party.
|
||||
|
||||
:param: number of arguments to follow
|
||||
:param: client id (regint)
|
||||
:param: message type (must be 0)
|
||||
:param: vector size (int)
|
||||
@@ -1727,14 +1835,15 @@ class writesharestofile(base.IOInstruction):
|
||||
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
|
||||
(appending at the end).
|
||||
|
||||
:param: number of shares (int)
|
||||
:param: number of arguments to follow / number of shares plus one (int)
|
||||
:param: position (regint, -1 for appending)
|
||||
:param: source (sint)
|
||||
:param: (repeat from source)...
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['WRITEFILESHARE']
|
||||
arg_format = itertools.repeat('s')
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('s'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -1788,26 +1897,19 @@ class floatoutput(base.PublicFileIOInstruction):
|
||||
code = base.opcodes['FLOATOUTPUT']
|
||||
arg_format = ['p','c','c','c','c']
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class startprivateoutput(base.Instruction):
|
||||
r""" Initiate private output to $n$ of $s_j$ via $s_i$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['STARTPRIVATEOUTPUT']
|
||||
arg_format = ['sw','s','p']
|
||||
field_type = 'modp'
|
||||
class fixinput(base.PublicFileIOInstruction):
|
||||
""" Binary fixed-point input.
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, 'input', self.args[2]), \
|
||||
self.get_size())
|
||||
:param: player (int)
|
||||
:param: destination (cint)
|
||||
:param: exponent (int)
|
||||
:param: input type (0: 64-bit integer, 1: float, 2: double)
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class stopprivateoutput(base.Instruction):
|
||||
r""" Previously iniated private output to $n$ via $c_i$. """
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['STOPPRIVATEOUTPUT']
|
||||
arg_format = ['cw','c','p']
|
||||
code = base.opcodes['FIXINPUT']
|
||||
arg_format = ['p','cw','int','int']
|
||||
|
||||
@base.vectorize
|
||||
class rand(base.Instruction):
|
||||
@@ -2122,17 +2224,26 @@ class gconvgf2n(base.Instruction):
|
||||
# rename 'open' to avoid conflict with built-in open function
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class asm_open(base.VarArgsInstruction):
|
||||
class asm_open(base.VarArgsInstruction, base.DataInstruction):
|
||||
""" Reveal secret registers (vectors) to clear registers (vectors).
|
||||
|
||||
:param: number of argument to follow (multiple of two)
|
||||
:param: number of argument to follow (odd number)
|
||||
:param: check after opening (0/1)
|
||||
:param: destination (cint)
|
||||
:param: source (sint)
|
||||
:param: (repeat the last two)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['OPEN']
|
||||
arg_format = tools.cycle(['cw','s'])
|
||||
arg_format = tools.chain(['int'], tools.cycle(['cw','s']))
|
||||
data_type = 'open'
|
||||
|
||||
def get_repeat(self):
|
||||
return (len(self.args) - 1) // 2
|
||||
|
||||
def merge(self, other):
|
||||
self.args[0] |= other.args[0]
|
||||
self.args += other.args[1:]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -2209,7 +2320,8 @@ class mulrs(base.VarArgsInstruction, base.DataInstruction):
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
class dotprods(base.VarArgsInstruction, base.DataInstruction,
|
||||
base.DynFormatInstruction):
|
||||
""" Dot product of secret registers (vectors).
|
||||
Note that the vectorized version works element-wise.
|
||||
|
||||
@@ -2237,31 +2349,30 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
flat_args += [x, y]
|
||||
base.Instruction.__init__(self, *flat_args)
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
@classmethod
|
||||
def dynamic_arg_format(self, args):
|
||||
field = 'g' if self.is_gf2n() else ''
|
||||
for i in self.bases():
|
||||
yield 'int'
|
||||
yield 'int'
|
||||
for i, n in self.bases(args):
|
||||
yield 's' + field + 'w'
|
||||
for j in range(self.args[i] - 2):
|
||||
assert n > 2
|
||||
for j in range(n - 2):
|
||||
yield 's' + field
|
||||
yield 'int'
|
||||
|
||||
gf2n_arg_format = arg_format
|
||||
|
||||
def bases(self):
|
||||
i = 0
|
||||
while i < len(self.args):
|
||||
yield i
|
||||
i += self.args[i]
|
||||
@property
|
||||
def gf2n_arg_format(self):
|
||||
return self.arg_format()
|
||||
|
||||
def get_repeat(self):
|
||||
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size()
|
||||
return sum(self.args[i] // 2
|
||||
for i, n in self.bases(iter(self.args))) * self.get_size()
|
||||
|
||||
def get_def(self):
|
||||
return [self.args[i + 1] for i in self.bases()]
|
||||
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
|
||||
|
||||
def get_used(self):
|
||||
for i in self.bases():
|
||||
for i, n in self.bases(iter(self.args)):
|
||||
for reg in self.args[i + 2:i + self.args[i]]:
|
||||
yield reg
|
||||
|
||||
@@ -2317,9 +2428,10 @@ class matmulsm(matmul_base):
|
||||
super(matmulsm, self).add_usage(req_node)
|
||||
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
|
||||
|
||||
class conv2ds(base.DataInstruction):
|
||||
class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
|
||||
""" Secret 2D convolution.
|
||||
|
||||
:param: number of arguments to follow (int)
|
||||
:param: result (sint vector in row-first order)
|
||||
:param: inputs (sint vector in row-first order)
|
||||
:param: weights (sint vector in row-first order)
|
||||
@@ -2335,10 +2447,12 @@ class conv2ds(base.DataInstruction):
|
||||
:param: padding height (int)
|
||||
:param: padding width (int)
|
||||
:param: batch size (int)
|
||||
:param: repeat from result...
|
||||
|
||||
"""
|
||||
code = base.opcodes['CONV2DS']
|
||||
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
|
||||
'int','int','int','int']
|
||||
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
|
||||
'int','int','int','int','int','int','int'])
|
||||
data_type = 'triple'
|
||||
is_vec = lambda self: True
|
||||
|
||||
@@ -2349,14 +2463,16 @@ class conv2ds(base.DataInstruction):
|
||||
assert args[2].size == args[7] * args[8] * args[11]
|
||||
|
||||
def get_repeat(self):
|
||||
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
|
||||
self.args[11] * self.args[14]
|
||||
args = self.args
|
||||
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
|
||||
args[i+11] * args[i+14] for i in range(0, len(args), 15))
|
||||
|
||||
def add_usage(self, req_node):
|
||||
super(conv2ds, self).add_usage(req_node)
|
||||
args = self.args
|
||||
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
|
||||
args[14] * args[3] * args[4])), 1)
|
||||
for i in range(0, len(self.args), 15):
|
||||
args = self.args[i:i + 15]
|
||||
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
|
||||
args[14] * args[3] * args[4])), 1)
|
||||
|
||||
@base.vectorize
|
||||
class trunc_pr(base.VarArgsInstruction):
|
||||
@@ -2372,6 +2488,124 @@ class trunc_pr(base.VarArgsInstruction):
|
||||
code = base.opcodes['TRUNC_PR']
|
||||
arg_format = tools.cycle(['sw','s','int','int'])
|
||||
|
||||
class shuffle_base(base.DataInstruction):
|
||||
n_relevant_parties = 2
|
||||
|
||||
@staticmethod
|
||||
def logn(n):
|
||||
return int(math.ceil(math.log(n, 2)))
|
||||
|
||||
@classmethod
|
||||
def n_swaps(cls, n):
|
||||
logn = cls.logn(n)
|
||||
return logn * 2 ** logn - 2 ** logn + 1
|
||||
|
||||
def add_gen_usage(self, req_node, n):
|
||||
# hack for unknown usage
|
||||
req_node.increment(('bit', 'inverse'), float('inf'))
|
||||
# minimal usage with two relevant parties
|
||||
logn = self.logn(n)
|
||||
n_switches = self.n_swaps(n)
|
||||
for i in range(self.n_relevant_parties):
|
||||
req_node.increment((self.field_type, 'input', i), n_switches)
|
||||
# multiplications for bit check
|
||||
req_node.increment((self.field_type, 'triple'),
|
||||
n_switches * self.n_relevant_parties)
|
||||
|
||||
def add_apply_usage(self, req_node, n, record_size):
|
||||
req_node.increment(('bit', 'inverse'), float('inf'))
|
||||
logn = self.logn(n)
|
||||
n_switches = self.n_swaps(n) * self.n_relevant_parties
|
||||
if n != 2 ** logn:
|
||||
record_size += 1
|
||||
req_node.increment((self.field_type, 'triple'),
|
||||
n_switches * record_size)
|
||||
|
||||
@base.gf2n
|
||||
class secshuffle(base.VectorInstruction, shuffle_base):
|
||||
""" Secure shuffling.
|
||||
|
||||
:param: destination (sint)
|
||||
:param: source (sint)
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['SECSHUFFLE']
|
||||
arg_format = ['sw','s','int']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(secshuffle_class, self).__init__(*args, **kwargs)
|
||||
assert len(args[0]) == len(args[1])
|
||||
assert len(args[0]) > args[2]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
self.add_gen_usage(req_node, len(self.args[0]))
|
||||
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
|
||||
|
||||
class gensecshuffle(shuffle_base):
|
||||
""" Generate secure shuffle to bit used several times.
|
||||
|
||||
:param: destination (regint)
|
||||
:param: size (int)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['GENSECSHUFFLE']
|
||||
arg_format = ['ciw','int']
|
||||
|
||||
def add_usage(self, req_node):
|
||||
self.add_gen_usage(req_node, self.args[1])
|
||||
|
||||
class applyshuffle(base.VectorInstruction, shuffle_base):
|
||||
""" Generate secure shuffle to bit used several times.
|
||||
|
||||
:param: destination (sint)
|
||||
:param: source (sint)
|
||||
:param: number of elements to be treated as one (int)
|
||||
:param: handle (regint)
|
||||
:param: reverse (0/1)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['APPLYSHUFFLE']
|
||||
arg_format = ['sw','s','int','ci','int']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(applyshuffle, self).__init__(*args, **kwargs)
|
||||
assert len(args[0]) == len(args[1])
|
||||
assert len(args[0]) > args[2]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
|
||||
|
||||
class delshuffle(base.Instruction):
|
||||
""" Delete secure shuffle.
|
||||
|
||||
:param: handle (regint)
|
||||
|
||||
"""
|
||||
code = base.opcodes['DELSHUFFLE']
|
||||
arg_format = ['ci']
|
||||
|
||||
class inverse_permutation(base.VectorInstruction, shuffle_base):
|
||||
""" Calculate the inverse permutation of a secret permutation.
|
||||
|
||||
:param: destination (sint)
|
||||
:param: source (sint)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['INVPERM']
|
||||
arg_format = ['sw', 's']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(inverse_permutation, self).__init__(*args, **kwargs)
|
||||
assert len(args[0]) == len(args[1])
|
||||
|
||||
def add_usage(self, req_node):
|
||||
self.add_gen_usage(req_node, len(self.args[0]))
|
||||
self.add_apply_usage(req_node, len(self.args[0]), 1)
|
||||
|
||||
|
||||
class check(base.Instruction):
|
||||
"""
|
||||
Force MAC check in current thread and all idle thread if current
|
||||
@@ -2399,7 +2633,7 @@ class sqrs(base.CISC):
|
||||
c = [program.curr_block.new_reg('c') for i in range(2)]
|
||||
square(s[0], s[1])
|
||||
subs(s[2], self.args[1], s[0])
|
||||
asm_open(c[0], s[2])
|
||||
asm_open(False, c[0], s[2])
|
||||
mulc(c[1], c[0], c[0])
|
||||
mulm(s[3], self.args[1], c[0])
|
||||
adds(s[4], s[3], s[3])
|
||||
@@ -2407,19 +2641,6 @@ class sqrs(base.CISC):
|
||||
subml(self.args[0], s[5], c[1])
|
||||
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class lts(base.CISC):
|
||||
""" Secret comparison $s_i = (s_j < s_k)$. """
|
||||
__slots__ = []
|
||||
arg_format = ['sw', 's', 's', 'int', 'int']
|
||||
|
||||
def expand(self):
|
||||
from .types import sint
|
||||
a = sint()
|
||||
subs(a, self.args[1], self.args[2])
|
||||
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
|
||||
|
||||
# placeholder for documentation
|
||||
class cisc:
|
||||
""" Meta instruction for emulation. This instruction is only generated
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import inspect
|
||||
import functools
|
||||
import copy
|
||||
import sys
|
||||
import struct
|
||||
from Compiler.exceptions import *
|
||||
from Compiler.config import *
|
||||
from Compiler import util
|
||||
@@ -64,6 +66,7 @@ opcodes = dict(
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -78,6 +81,7 @@ opcodes = dict(
|
||||
SUBSI = 0x2A,
|
||||
SUBCFI = 0x2B,
|
||||
SUBSFI = 0x2C,
|
||||
PREFIXSUMS = 0x2D,
|
||||
# Multiplication/division
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
@@ -103,6 +107,13 @@ opcodes = dict(
|
||||
MATMULSM = 0xAB,
|
||||
CONV2DS = 0xAC,
|
||||
CHECK = 0xAF,
|
||||
PRIVATEOUTPUT = 0xAD,
|
||||
# Shuffling
|
||||
SECSHUFFLE = 0xFA,
|
||||
GENSECSHUFFLE = 0xFB,
|
||||
APPLYSHUFFLE = 0xFC,
|
||||
DELSHUFFLE = 0xFD,
|
||||
INVPERM = 0xFE,
|
||||
# Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
@@ -126,6 +137,7 @@ opcodes = dict(
|
||||
INPUTMIXEDREG = 0xF3,
|
||||
RAWINPUT = 0xF4,
|
||||
INPUTPERSONAL = 0xF5,
|
||||
SENDPERSONAL = 0xF6,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
@@ -198,8 +210,9 @@ opcodes = dict(
|
||||
CONDPRINTPLAIN = 0xE1,
|
||||
INTOUTPUT = 0xE6,
|
||||
FLOATOUTPUT = 0xE7,
|
||||
GBITDEC = 0x184,
|
||||
GBITCOM = 0x185,
|
||||
FIXINPUT = 0xE8,
|
||||
GBITDEC = 0x18A,
|
||||
GBITCOM = 0x18B,
|
||||
# Secure socket
|
||||
INITSECURESOCKET = 0x1BA,
|
||||
RESPSECURESOCKET = 0x1BB
|
||||
@@ -215,8 +228,13 @@ def int_to_bytes(x):
|
||||
global_vector_size_stack = []
|
||||
global_instruction_type_stack = ['modp']
|
||||
|
||||
def check_vector_size(size):
|
||||
if isinstance(size, program.curr_tape.Register):
|
||||
raise CompilerError('vector size must be known at compile time')
|
||||
|
||||
def set_global_vector_size(size):
|
||||
stack = global_vector_size_stack
|
||||
check_vector_size(size)
|
||||
if size == 1 and not stack:
|
||||
return
|
||||
stack.append(size)
|
||||
@@ -299,11 +317,12 @@ def vectorize(instruction, global_dict=None):
|
||||
vectorized_name = 'v' + instruction.__name__
|
||||
Vectorized_Instruction.__name__ = vectorized_name
|
||||
global_dict[vectorized_name] = Vectorized_Instruction
|
||||
|
||||
if 'sphinx.extension' in sys.modules:
|
||||
return instruction
|
||||
|
||||
global_dict[instruction.__name__ + '_class'] = instruction
|
||||
instruction.__doc__ = ''
|
||||
# exclude GF(2^n) instructions from documentation
|
||||
if instruction.code and instruction.code >> 8 == 1:
|
||||
maybe_vectorized_instruction.__doc__ = ''
|
||||
maybe_vectorized_instruction.arg_format = instruction.arg_format
|
||||
return maybe_vectorized_instruction
|
||||
|
||||
|
||||
@@ -332,7 +351,7 @@ def gf2n(instruction):
|
||||
if isinstance(arg_format, list):
|
||||
__format = []
|
||||
for __f in arg_format:
|
||||
if __f in ('int', 'p', 'ci', 'str'):
|
||||
if __f in ('int', 'long', 'p', 'ci', 'str'):
|
||||
__format.append(__f)
|
||||
else:
|
||||
__format.append(__f[0] + 'g' + __f[1:])
|
||||
@@ -355,12 +374,13 @@ def gf2n(instruction):
|
||||
arg_format = instruction_cls.gf2n_arg_format
|
||||
elif isinstance(instruction_cls.arg_format, itertools.repeat):
|
||||
__f = next(instruction_cls.arg_format)
|
||||
if __f != 'int' and __f != 'p':
|
||||
if __f not in ('int', 'long', 'p'):
|
||||
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
|
||||
else:
|
||||
arg_format = copy.deepcopy(instruction_cls.arg_format)
|
||||
reformat(arg_format)
|
||||
|
||||
@classmethod
|
||||
def is_gf2n(self):
|
||||
return True
|
||||
|
||||
@@ -389,8 +409,11 @@ def gf2n(instruction):
|
||||
else:
|
||||
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction
|
||||
|
||||
if 'sphinx.extension' in sys.modules:
|
||||
return instruction
|
||||
|
||||
global_dict[instruction.__name__ + '_class'] = instruction_cls
|
||||
instruction_cls.__doc__ = ''
|
||||
maybe_gf2n_instruction.arg_format = instruction.arg_format
|
||||
return maybe_gf2n_instruction
|
||||
#return instruction
|
||||
|
||||
@@ -404,6 +427,7 @@ def cisc(function):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.security = program.security
|
||||
self.calls = [(args, kwargs)]
|
||||
self.params = []
|
||||
self.used = []
|
||||
@@ -427,7 +451,7 @@ def cisc(function):
|
||||
|
||||
def merge_id(self):
|
||||
return self.function, tuple(self.params), \
|
||||
tuple(sorted(self.kwargs.items()))
|
||||
tuple(sorted(self.kwargs.items())), self.security
|
||||
|
||||
def merge(self, other):
|
||||
self.calls += other.calls
|
||||
@@ -452,7 +476,10 @@ def cisc(function):
|
||||
except:
|
||||
args.append(arg)
|
||||
program.options.cisc = False
|
||||
old_security = program.security
|
||||
program.security = self.security
|
||||
self.function(*args, **self.kwargs)
|
||||
program.security = old_security
|
||||
program.options.cisc = True
|
||||
reset_global_vector_size()
|
||||
program.curr_tape = old_tape
|
||||
@@ -499,8 +526,12 @@ def cisc(function):
|
||||
for arg in self.args:
|
||||
try:
|
||||
new_regs.append(type(arg)(size=size))
|
||||
except:
|
||||
except TypeError:
|
||||
break
|
||||
except:
|
||||
print([call[0][0].size for call in self.calls])
|
||||
raise
|
||||
assert len(new_regs) > 1
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
|
||||
@@ -523,7 +554,7 @@ def cisc(function):
|
||||
|
||||
def get_bytes(self):
|
||||
assert len(self.kwargs) < 2
|
||||
res = int_to_bytes(opcodes['CISC'])
|
||||
res = LongArgFormat.encode(opcodes['CISC'])
|
||||
res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
|
||||
name = self.function.__name__
|
||||
String.check(name)
|
||||
@@ -559,7 +590,7 @@ def cisc(function):
|
||||
same_sizes &= arg.size == args[0].size
|
||||
except:
|
||||
pass
|
||||
if program.options.cisc and same_sizes:
|
||||
if program.use_cisc() and same_sizes:
|
||||
return MergeCISC(*args, **kwargs)
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
@@ -572,9 +603,9 @@ def ret_cisc(function):
|
||||
instruction = cisc(instruction)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if not program.options.cisc:
|
||||
return function(*args, **kwargs)
|
||||
from Compiler import types
|
||||
if not (program.options.cisc and isinstance(args[0], types._register)):
|
||||
return function(*args, **kwargs)
|
||||
if isinstance(args[0], types._clear):
|
||||
res_type = type(args[1])
|
||||
else:
|
||||
@@ -651,7 +682,8 @@ class RegisterArgFormat(ArgFormat):
|
||||
raise ArgumentError(arg, 'Invalid register argument')
|
||||
if arg.program != program.curr_tape:
|
||||
raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
|
||||
util.format_trace(arg.caller))
|
||||
util.format_trace(arg.caller) +
|
||||
'\nMaybe use MemValue')
|
||||
if arg.reg_type != cls.reg_type:
|
||||
raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \
|
||||
(arg.reg_type, cls.reg_type))
|
||||
@@ -661,37 +693,68 @@ class RegisterArgFormat(ArgFormat):
|
||||
assert arg.i >= 0
|
||||
return int_to_bytes(arg.i)
|
||||
|
||||
def __init__(self, f):
|
||||
self.i = struct.unpack('>I', f.read(4))[0]
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
|
||||
class ClearModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearModp
|
||||
name = 'cint'
|
||||
|
||||
class SecretModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.SecretModp
|
||||
name = 'sint'
|
||||
|
||||
class ClearGF2NAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearGF2N
|
||||
name = 'cgf2n'
|
||||
|
||||
class SecretGF2NAF(RegisterArgFormat):
|
||||
reg_type = RegType.SecretGF2N
|
||||
name = 'sgf2n'
|
||||
|
||||
class ClearIntAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearInt
|
||||
name = 'regint'
|
||||
|
||||
class IntArgFormat(ArgFormat):
|
||||
n_bits = 32
|
||||
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, int) and not arg is None:
|
||||
raise ArgumentError(arg, 'Expected an integer-valued argument')
|
||||
if not arg is None:
|
||||
if not isinstance(arg, int):
|
||||
raise ArgumentError(arg, 'Expected an integer-valued argument')
|
||||
if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits:
|
||||
raise ArgumentError(
|
||||
arg, 'Immediate value outside of %d-bit range' % cls.n_bits)
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return int_to_bytes(arg)
|
||||
|
||||
def __init__(self, f):
|
||||
self.i = struct.unpack('>i', f.read(4))[0]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.i)
|
||||
|
||||
class LongArgFormat(IntArgFormat):
|
||||
n_bits = 64
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return list(struct.pack('>q', arg))
|
||||
|
||||
def __init__(self, f):
|
||||
self.i = struct.unpack('>q', f.read(8))[0]
|
||||
|
||||
class ImmediateModpAF(IntArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
super(ImmediateModpAF, cls).check(arg)
|
||||
if arg >= 2**32 or arg < -2**32:
|
||||
raise ArgumentError(arg, 'Immediate value outside of 32-bit range')
|
||||
|
||||
class ImmediateGF2NAF(IntArgFormat):
|
||||
@classmethod
|
||||
@@ -702,6 +765,8 @@ class ImmediateGF2NAF(IntArgFormat):
|
||||
class PlayerNoAF(IntArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not util.is_constant(arg):
|
||||
raise CompilerError('Player number must be known at compile time')
|
||||
super(PlayerNoAF, cls).check(arg)
|
||||
if arg > 256:
|
||||
raise ArgumentError(arg, 'Player number > 256')
|
||||
@@ -722,6 +787,13 @@ class String(ArgFormat):
|
||||
def encode(cls, arg):
|
||||
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))
|
||||
|
||||
def __init__(self, f):
|
||||
tmp = f.read(16)
|
||||
self.str = str(tmp[0:tmp.find(b'\0')], 'ascii')
|
||||
|
||||
def __str__(self):
|
||||
return self.str
|
||||
|
||||
ArgFormats = {
|
||||
'c': ClearModpAF,
|
||||
's': SecretModpAF,
|
||||
@@ -736,6 +808,7 @@ ArgFormats = {
|
||||
'i': ImmediateModpAF,
|
||||
'ig': ImmediateGF2NAF,
|
||||
'int': IntArgFormat,
|
||||
'long': LongArgFormat,
|
||||
'p': PlayerNoAF,
|
||||
'str': String,
|
||||
}
|
||||
@@ -776,7 +849,7 @@ class Instruction(object):
|
||||
return (prefix << self.code_length) + self.code
|
||||
|
||||
def get_encoding(self):
|
||||
enc = int_to_bytes(self.get_code())
|
||||
enc = LongArgFormat.encode(self.get_code())
|
||||
# add the number of registers if instruction flagged as has var args
|
||||
if self.has_var_args():
|
||||
enc += int_to_bytes(len(self.args))
|
||||
@@ -829,6 +902,7 @@ class Instruction(object):
|
||||
def is_vec(self):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_gf2n(self):
|
||||
return False
|
||||
|
||||
@@ -877,6 +951,10 @@ class Instruction(object):
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
@staticmethod
|
||||
def get_usage(args):
|
||||
return {}
|
||||
|
||||
# String version of instruction attempting to replicate encoded version
|
||||
def __str__(self):
|
||||
|
||||
@@ -890,6 +968,66 @@ class Instruction(object):
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
|
||||
|
||||
class ParsedInstruction:
|
||||
reverse_opcodes = {}
|
||||
|
||||
def __init__(self, f):
|
||||
cls = type(self)
|
||||
from Compiler import instructions
|
||||
from Compiler.GC import instructions as gc_inst
|
||||
if not cls.reverse_opcodes:
|
||||
for module in instructions, gc_inst:
|
||||
for x, y in inspect.getmodule(module).__dict__.items():
|
||||
if inspect.isclass(y) and y.__name__[0] != 'v':
|
||||
try:
|
||||
cls.reverse_opcodes[y.code] = y
|
||||
except AttributeError:
|
||||
pass
|
||||
read = lambda: struct.unpack('>I', f.read(4))[0]
|
||||
full_code = struct.unpack('>Q', f.read(8))[0]
|
||||
code = full_code % (1 << Instruction.code_length)
|
||||
self.size = full_code >> Instruction.code_length
|
||||
self.type = cls.reverse_opcodes[code]
|
||||
t = self.type
|
||||
name = t.__name__
|
||||
try:
|
||||
n_args = len(t.arg_format)
|
||||
self.var_args = False
|
||||
except:
|
||||
n_args = read()
|
||||
self.var_args = True
|
||||
try:
|
||||
arg_format = iter(t.arg_format)
|
||||
except:
|
||||
if name == 'cisc':
|
||||
arg_format = itertools.chain(['str'], itertools.repeat('int'))
|
||||
else:
|
||||
def arg_iter():
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
yield self.args[i].i
|
||||
except AttributeError:
|
||||
yield None
|
||||
i += 1
|
||||
arg_format = t.dynamic_arg_format(arg_iter())
|
||||
self.args = []
|
||||
for i in range(n_args):
|
||||
self.args.append(ArgFormats[next(arg_format)](f))
|
||||
|
||||
def __str__(self):
|
||||
name = self.type.__name__
|
||||
res = name + ' '
|
||||
if self.size > 1:
|
||||
res = 'v' + res + str(self.size) + ', '
|
||||
if self.var_args:
|
||||
res += str(len(self.args)) + ', '
|
||||
res += ', '.join(str(arg) for arg in self.args)
|
||||
return res
|
||||
|
||||
def get_usage(self):
|
||||
return self.type.get_usage(self.args)
|
||||
|
||||
class VarArgsInstruction(Instruction):
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -901,6 +1039,26 @@ class VectorInstruction(Instruction):
|
||||
def get_code(self):
|
||||
return super(VectorInstruction, self).get_code(len(self.args[0]))
|
||||
|
||||
class DynFormatInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
return self.dynamic_arg_format(iter(self.args))
|
||||
|
||||
@classmethod
|
||||
def bases(self, args):
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
n = next(args)
|
||||
except StopIteration:
|
||||
return
|
||||
yield i, n
|
||||
i += n
|
||||
for j in range(n - 1):
|
||||
next(args)
|
||||
|
||||
###
|
||||
### Basic arithmetic
|
||||
###
|
||||
@@ -934,21 +1092,27 @@ class ClearImmediate(ImmediateBase):
|
||||
### Memory access instructions
|
||||
###
|
||||
|
||||
class DirectMemoryInstruction(Instruction):
|
||||
class MemoryInstruction(Instruction):
|
||||
__slots__ = ['_protect']
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MemoryInstruction, self).__init__(*args, **kwargs)
|
||||
self._protect = program._protect_memory
|
||||
|
||||
class DirectMemoryInstruction(MemoryInstruction):
|
||||
__slots__ = []
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DirectMemoryInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
class IndirectMemoryInstruction(Instruction):
|
||||
class IndirectMemoryInstruction(MemoryInstruction):
|
||||
__slots__ = []
|
||||
|
||||
def get_direct(self, address):
|
||||
return self.direct(self.args[0], address, add_to_prog=False)
|
||||
|
||||
class ReadMemoryInstruction(Instruction):
|
||||
class ReadMemoryInstruction(MemoryInstruction):
|
||||
__slots__ = []
|
||||
|
||||
class WriteMemoryInstruction(Instruction):
|
||||
class WriteMemoryInstruction(MemoryInstruction):
|
||||
__slots__ = []
|
||||
|
||||
class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
|
||||
@@ -975,12 +1139,16 @@ class IOInstruction(DoNotEliminateInstruction):
|
||||
@classmethod
|
||||
def str_to_int(cls, s):
|
||||
""" Convert a 4 character string to an integer. """
|
||||
try:
|
||||
s = bytearray(s, 'utf8')
|
||||
except:
|
||||
pass
|
||||
if len(s) > 4:
|
||||
raise CompilerError('String longer than 4 characters')
|
||||
n = 0
|
||||
for c in reversed(s.ljust(4)):
|
||||
n <<= 8
|
||||
n += ord(c)
|
||||
n += c
|
||||
return n
|
||||
|
||||
class AsymmetricCommunicationInstruction(DoNotEliminateInstruction):
|
||||
@@ -999,6 +1167,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
|
||||
""" Input from text file or stdin """
|
||||
__slots__ = []
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for player in self.get_players():
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
|
||||
###
|
||||
### Data access instructions
|
||||
###
|
||||
|
||||
@@ -12,6 +12,7 @@ import inspect,math
|
||||
import random
|
||||
import collections
|
||||
import operator
|
||||
import copy
|
||||
from functools import reduce
|
||||
|
||||
def get_program():
|
||||
@@ -63,6 +64,7 @@ def print_str(s, *args):
|
||||
variables/registers with ``%s``. """
|
||||
def print_plain_str(ss):
|
||||
""" Print a plain string (no custom formatting options) """
|
||||
ss = bytearray(ss, 'utf8')
|
||||
i = 1
|
||||
while 4*i <= len(ss):
|
||||
print_char4(ss[4*(i-1):4*i])
|
||||
@@ -115,7 +117,12 @@ def print_ln(s='', *args):
|
||||
|
||||
print_ln('a is %s.', a.reveal())
|
||||
"""
|
||||
print_str(s + '\n', *args)
|
||||
print_str(str(s) + '\n', *args)
|
||||
|
||||
def print_both(s, end='\n'):
|
||||
""" Print line during compilation and execution. """
|
||||
print(s, end=end)
|
||||
print_str(s + end)
|
||||
|
||||
def print_ln_if(cond, ss, *args):
|
||||
""" Print line if :py:obj:`cond` is true. The further arguments
|
||||
@@ -137,7 +144,7 @@ def print_str_if(cond, ss, *args):
|
||||
""" Print string conditionally. See :py:func:`print_ln_if` for details. """
|
||||
if util.is_constant(cond):
|
||||
if cond:
|
||||
print_ln(ss, *args)
|
||||
print_str(ss, *args)
|
||||
else:
|
||||
subs = ss.split('%s')
|
||||
assert len(subs) == len(args) + 1
|
||||
@@ -153,7 +160,8 @@ def print_str_if(cond, ss, *args):
|
||||
print_str_if(cond, *_expand_to_print(val))
|
||||
else:
|
||||
print_str_if(cond, str(val))
|
||||
s += '\0' * ((-len(s)) % 4)
|
||||
s = bytearray(s, 'utf8')
|
||||
s += b'\0' * ((-len(s)) % 4)
|
||||
while s:
|
||||
cond.print_if(s[:4])
|
||||
s = s[4:]
|
||||
@@ -219,7 +227,10 @@ def crash(condition=None):
|
||||
:param condition: crash if true (default: true)
|
||||
|
||||
"""
|
||||
if condition == None:
|
||||
if isinstance(condition, localint):
|
||||
# allow crash on local values
|
||||
condition = condition._v
|
||||
if condition is None:
|
||||
condition = regint(1)
|
||||
instructions.crash(regint.conv(condition))
|
||||
|
||||
@@ -239,6 +250,10 @@ def store_in_mem(value, address):
|
||||
try:
|
||||
value.store_in_mem(address)
|
||||
except AttributeError:
|
||||
if isinstance(value, (list, tuple)):
|
||||
for i, x in enumerate(value):
|
||||
store_in_mem(x, address + i)
|
||||
return
|
||||
# legacy
|
||||
if value.is_clear:
|
||||
if isinstance(address, cint):
|
||||
@@ -257,11 +272,13 @@ def reveal(secret):
|
||||
try:
|
||||
return secret.reveal()
|
||||
except AttributeError:
|
||||
if secret.is_clear:
|
||||
return secret
|
||||
if secret.is_gf2n:
|
||||
res = cgf2n()
|
||||
else:
|
||||
res = cint()
|
||||
instructions.asm_open(res, secret)
|
||||
instructions.asm_open(True, res, secret)
|
||||
return res
|
||||
|
||||
@vectorize
|
||||
@@ -278,13 +295,13 @@ def get_arg():
|
||||
ldarg(res)
|
||||
return res
|
||||
|
||||
def make_array(l):
|
||||
def make_array(l, t=None):
|
||||
if isinstance(l, program.Tape.Register):
|
||||
res = Array(1, type(l))
|
||||
res[0] = l
|
||||
res = Array(len(l), t or type(l))
|
||||
res[:] = l
|
||||
else:
|
||||
l = list(l)
|
||||
res = Array(len(l), type(l[0]) if l else cint)
|
||||
res = Array(len(l), t or type(l[0]) if l else cint)
|
||||
res.assign(l)
|
||||
return res
|
||||
|
||||
@@ -456,6 +473,10 @@ def method_block(function):
|
||||
return wrapper
|
||||
|
||||
def cond_swap(x,y):
|
||||
from .types import SubMultiArray
|
||||
if isinstance(x, (Array, SubMultiArray)):
|
||||
b = x[0] > y[0]
|
||||
return list(zip(*[b.cond_swap(xx, yy) for xx, yy in zip(x, y)]))
|
||||
b = x < y
|
||||
if isinstance(x, sfloat):
|
||||
res = ([], [])
|
||||
@@ -467,11 +488,11 @@ def cond_swap(x,y):
|
||||
res[0].append(bx + yy - by)
|
||||
res[1].append(xx - bx + by)
|
||||
return sfloat(*res[0]), sfloat(*res[1])
|
||||
bx = b * x
|
||||
by = b * y
|
||||
return bx + y - by, x - bx + by
|
||||
return b.cond_swap(y, x)
|
||||
|
||||
def sort(a):
|
||||
print("WARNING: you're using bubble sort")
|
||||
|
||||
res = a
|
||||
|
||||
for i in range(len(a)):
|
||||
@@ -497,282 +518,36 @@ def odd_even_merge_sort(a):
|
||||
if len(a) == 1:
|
||||
return
|
||||
elif len(a) % 2 == 0:
|
||||
aa = a
|
||||
a = list(a)
|
||||
lower = a[:len(a)//2]
|
||||
upper = a[len(a)//2:]
|
||||
odd_even_merge_sort(lower)
|
||||
odd_even_merge_sort(upper)
|
||||
a[:] = lower + upper
|
||||
odd_even_merge(a)
|
||||
aa[:] = a
|
||||
else:
|
||||
raise CompilerError('Length of list must be power of two')
|
||||
|
||||
def chunky_odd_even_merge_sort(a):
|
||||
tmp = a[0].Array(len(a))
|
||||
for i,j in enumerate(a):
|
||||
tmp[i] = j
|
||||
l = 1
|
||||
while l < len(a):
|
||||
l *= 2
|
||||
k = 1
|
||||
while k < l:
|
||||
k *= 2
|
||||
def round():
|
||||
for i in range(len(a)):
|
||||
a[i] = tmp[i]
|
||||
for i in range(len(a) // l):
|
||||
for j in range(l // k):
|
||||
base = i * l + j
|
||||
step = l // k
|
||||
if k == 2:
|
||||
a[base], a[base+step] = cond_swap(a[base], a[base+step])
|
||||
else:
|
||||
b = a[base:base+k*step:step]
|
||||
for m in range(base + step, base + (k - 1) * step, 2 * step):
|
||||
a[m], a[m+step] = cond_swap(a[m], a[m+step])
|
||||
for i in range(len(a)):
|
||||
tmp[i] = a[i]
|
||||
chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True)
|
||||
chunk.start()
|
||||
chunk.join()
|
||||
#round()
|
||||
for i in range(len(a)):
|
||||
a[i] = tmp[i]
|
||||
raise CompilerError(
|
||||
'This function has been removed, use loopy_odd_even_merge_sort instead')
|
||||
|
||||
def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False):
|
||||
if n is None:
|
||||
n = len(a)
|
||||
a_base = instructions.program.malloc(n, 's')
|
||||
for i,j in enumerate(a):
|
||||
store_in_mem(j, a_base + i)
|
||||
else:
|
||||
a_base = a
|
||||
tmp_base = instructions.program.malloc(n, 's')
|
||||
chunks = {}
|
||||
threads = []
|
||||
|
||||
def run_threads():
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
del threads[:]
|
||||
|
||||
def run_chunk(size, base):
|
||||
if size not in chunks:
|
||||
def swap_list(list_base):
|
||||
for i in range(size // 2):
|
||||
base = list_base + 2 * i
|
||||
x, y = cond_swap(sint.load_mem(base),
|
||||
sint.load_mem(base + 1))
|
||||
store_in_mem(x, base)
|
||||
store_in_mem(y, base + 1)
|
||||
chunks[size] = FunctionTape(swap_list, 'sort-%d' % size)
|
||||
return chunks[size](base)
|
||||
|
||||
def run_round(size):
|
||||
# minimize number of chunk sizes
|
||||
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
|
||||
lower_size = size // n_chunks // 2 * 2
|
||||
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
|
||||
# print len(to_swap) == lower_size * n_lower_size + \
|
||||
# (lower_size + 2) * (n_chunks - n_lower_size), \
|
||||
# len(to_swap), n_chunks, lower_size, n_lower_size
|
||||
base = 0
|
||||
round_threads = []
|
||||
for i in range(n_lower_size):
|
||||
round_threads.append(run_chunk(lower_size, tmp_base + base))
|
||||
base += lower_size
|
||||
for i in range(n_chunks - n_lower_size):
|
||||
round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
|
||||
base += lower_size + 2
|
||||
run_threads_in_rounds(round_threads)
|
||||
|
||||
postproc_chunks = []
|
||||
wrap_chunks = {}
|
||||
post_threads = []
|
||||
pre_threads = []
|
||||
|
||||
def load_and_store(x, y, to_right):
|
||||
if to_right:
|
||||
store_in_mem(sint.load_mem(x), y)
|
||||
else:
|
||||
store_in_mem(sint.load_mem(y), x)
|
||||
|
||||
def run_setup(k, a_addr, step, tmp_addr):
|
||||
if k == 2:
|
||||
def mem_op(preproc, a_addr, step, tmp_addr):
|
||||
load_and_store(a_addr, tmp_addr, preproc)
|
||||
load_and_store(a_addr + step, tmp_addr + 1, preproc)
|
||||
res = 2
|
||||
else:
|
||||
def mem_op(preproc, a_addr, step, tmp_addr):
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
# for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)):
|
||||
for i in range(k - 2):
|
||||
m = a_addr + step + i * step
|
||||
load_and_store(m, tmp_addr + i, preproc)
|
||||
res = k - 2
|
||||
if not use_chunk_wraps or k <= 4:
|
||||
mem_op(True, a_addr, step, tmp_addr)
|
||||
postproc_chunks.append((mem_op, (a_addr, step, tmp_addr)))
|
||||
else:
|
||||
if k not in wrap_chunks:
|
||||
pre_chunk = FunctionTape(mem_op, 'pre-%d' % k,
|
||||
compile_args=[True])
|
||||
post_chunk = FunctionTape(mem_op, 'post-%d' % k,
|
||||
compile_args=[False])
|
||||
wrap_chunks[k] = (pre_chunk, post_chunk)
|
||||
pre_chunk, post_chunk = wrap_chunks[k]
|
||||
pre_threads.append(pre_chunk(a_addr, step, tmp_addr))
|
||||
post_threads.append(post_chunk(a_addr, step, tmp_addr))
|
||||
return res
|
||||
|
||||
def run_threads_in_rounds(all_threads):
|
||||
for thread in all_threads:
|
||||
if len(threads) == n_threads:
|
||||
run_threads()
|
||||
threads.append(thread)
|
||||
run_threads()
|
||||
del all_threads[:]
|
||||
|
||||
def run_postproc():
|
||||
run_threads_in_rounds(post_threads)
|
||||
for chunk,args in postproc_chunks:
|
||||
chunk(False, *args)
|
||||
postproc_chunks[:] = []
|
||||
|
||||
l = 1
|
||||
while l < n:
|
||||
l *= 2
|
||||
k = 1
|
||||
while k < l:
|
||||
k *= 2
|
||||
size = 0
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
for i in range(n // l):
|
||||
for j in range(l // k):
|
||||
base = i * l + j
|
||||
step = l // k
|
||||
size += run_setup(k, a_base + base, step, tmp_base + size)
|
||||
run_threads_in_rounds(pre_threads)
|
||||
run_round(size)
|
||||
run_postproc()
|
||||
|
||||
if isinstance(a, list):
|
||||
for i in range(n):
|
||||
a[i] = sint.load_mem(a_base + i)
|
||||
instructions.program.free(a_base, 's')
|
||||
instructions.program.free(tmp_base, 's')
|
||||
raise CompilerError(
|
||||
'This function has been removed, use loopy_odd_even_merge_sort instead')
|
||||
|
||||
def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7):
|
||||
if n is None:
|
||||
n = len(a)
|
||||
a_base = instructions.program.malloc(n, 's')
|
||||
for i,j in enumerate(a):
|
||||
store_in_mem(j, a_base + i)
|
||||
else:
|
||||
a_base = a
|
||||
tmp_base = instructions.program.malloc(n, 's')
|
||||
tmp_i = instructions.program.malloc(1, 'ci')
|
||||
chunks = {}
|
||||
threads = []
|
||||
|
||||
def run_threads():
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
del threads[:]
|
||||
|
||||
def run_threads_in_rounds(all_threads):
|
||||
for thread in all_threads:
|
||||
if len(threads) == n_threads:
|
||||
run_threads()
|
||||
threads.append(thread)
|
||||
run_threads()
|
||||
del all_threads[:]
|
||||
|
||||
def run_chunk(size, base):
|
||||
if size not in chunks:
|
||||
def swap_list(list_base):
|
||||
for i in range(size // 2):
|
||||
base = list_base + 2 * i
|
||||
x, y = cond_swap(sint.load_mem(base),
|
||||
sint.load_mem(base + 1))
|
||||
store_in_mem(x, base)
|
||||
store_in_mem(y, base + 1)
|
||||
chunks[size] = FunctionTape(swap_list, 'sort-%d' % size)
|
||||
return chunks[size](base)
|
||||
|
||||
def run_round(size):
|
||||
# minimize number of chunk sizes
|
||||
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
|
||||
lower_size = size // n_chunks // 2 * 2
|
||||
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
|
||||
# print len(to_swap) == lower_size * n_lower_size + \
|
||||
# (lower_size + 2) * (n_chunks - n_lower_size), \
|
||||
# len(to_swap), n_chunks, lower_size, n_lower_size
|
||||
base = 0
|
||||
round_threads = []
|
||||
for i in range(n_lower_size):
|
||||
round_threads.append(run_chunk(lower_size, tmp_base + base))
|
||||
base += lower_size
|
||||
for i in range(n_chunks - n_lower_size):
|
||||
round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
|
||||
base += lower_size + 2
|
||||
run_threads_in_rounds(round_threads)
|
||||
|
||||
l = 1
|
||||
while l < n:
|
||||
l *= 2
|
||||
k = 1
|
||||
while k < l:
|
||||
k *= 2
|
||||
def load_and_store(x, y):
|
||||
if to_tmp:
|
||||
store_in_mem(sint.load_mem(x), y)
|
||||
else:
|
||||
store_in_mem(sint.load_mem(y), x)
|
||||
def outer(i):
|
||||
def inner(j):
|
||||
base = j + a_base + i * l
|
||||
step = l // k
|
||||
if k == 2:
|
||||
tmp_addr = regint.load_mem(tmp_i)
|
||||
load_and_store(base, tmp_addr)
|
||||
load_and_store(base + step, tmp_addr + 1)
|
||||
store_in_mem(tmp_addr + 2, tmp_i)
|
||||
else:
|
||||
def inner2(m):
|
||||
m += base
|
||||
tmp_addr = regint.load_mem(tmp_i)
|
||||
load_and_store(m, tmp_addr)
|
||||
store_in_mem(tmp_addr + 1, tmp_i)
|
||||
range_loop(inner2, step, (k - 1) * step, step)
|
||||
range_loop(inner, l // k)
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
to_tmp = True
|
||||
store_in_mem(tmp_base, tmp_i)
|
||||
range_loop(outer, n // l)
|
||||
if k == 2:
|
||||
run_round(n)
|
||||
else:
|
||||
run_round(n // k * (k - 2))
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
to_tmp = False
|
||||
store_in_mem(tmp_base, tmp_i)
|
||||
range_loop(outer, n // l)
|
||||
|
||||
if isinstance(a, list):
|
||||
for i in range(n):
|
||||
a[i] = sint.load_mem(a_base + i)
|
||||
instructions.program.free(a_base, 's')
|
||||
instructions.program.free(tmp_base, 's')
|
||||
instructions.program.free(tmp_i, 'ci')
|
||||
raise CompilerError(
|
||||
'This function has been removed, use loopy_odd_even_merge_sort instead')
|
||||
|
||||
|
||||
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
|
||||
n_threads=None):
|
||||
a_in = a
|
||||
if isinstance(a_in, list):
|
||||
a = Array.create_from(a)
|
||||
steps = {}
|
||||
l = sorted_length
|
||||
while l < len(a):
|
||||
@@ -816,8 +591,14 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
|
||||
swap(m2, step)
|
||||
steps[key] = step
|
||||
steps[key](l)
|
||||
if isinstance(a_in, list):
|
||||
a_in[:] = list(a)
|
||||
|
||||
def mergesort(A):
|
||||
if not get_program().options.insecure:
|
||||
raise CompilerError('mergesort reveals the order of elements, '
|
||||
'use --insecure to activate it')
|
||||
|
||||
B = Array(len(A), sint)
|
||||
|
||||
def merge(i_left, i_right, i_end):
|
||||
@@ -845,12 +626,18 @@ def mergesort(A):
|
||||
width.imul(2)
|
||||
return width < len(A)
|
||||
|
||||
def range_loop(loop_body, start, stop=None, step=None):
|
||||
def _range_prep(start, stop, step):
|
||||
if stop is None:
|
||||
stop = start
|
||||
start = 0
|
||||
if step is None:
|
||||
step = 1
|
||||
if util.is_zero(step):
|
||||
raise CompilerError('step must not be zero')
|
||||
return start, stop, step
|
||||
|
||||
def range_loop(loop_body, start, stop=None, step=None):
|
||||
start, stop, step = _range_prep(start, stop, step)
|
||||
def loop_fn(i):
|
||||
res = loop_body(i)
|
||||
return util.if_else(res == 0, stop, i + step)
|
||||
@@ -859,8 +646,6 @@ def range_loop(loop_body, start, stop=None, step=None):
|
||||
condition = lambda x: x < stop
|
||||
elif step < 0:
|
||||
condition = lambda x: x > stop
|
||||
else:
|
||||
raise CompilerError('step must not be zero')
|
||||
else:
|
||||
b = step > 0
|
||||
condition = lambda x: b * (x < stop) + (1 - b) * (x > stop)
|
||||
@@ -870,36 +655,34 @@ def range_loop(loop_body, start, stop=None, step=None):
|
||||
# known loop count
|
||||
if condition(start):
|
||||
get_tape().req_node.children[-1].aggregator = \
|
||||
lambda x: ((stop - start) // step) * x[0]
|
||||
lambda x: int(ceil(((stop - start) / step))) * x[0]
|
||||
|
||||
def for_range(start, stop=None, step=None):
|
||||
"""
|
||||
Decorator to execute loop bodies consecutively. Arguments work as
|
||||
in Python :py:func:`range`, but they can by any public
|
||||
in Python :py:func:`range`, but they can be any public
|
||||
integer. Information has to be passed out via container types such
|
||||
as :py:class:`~Compiler.types.Array` or declaring registers as
|
||||
:py:obj:`global`. Note that changing Python data structures such
|
||||
as :py:class:`~Compiler.types.Array` or using :py:func:`update`.
|
||||
Note that changing Python data structures such
|
||||
as lists within the loop is not possible, but the compiler cannot
|
||||
warn about this.
|
||||
|
||||
:param start/stop/step: regint/cint/int
|
||||
|
||||
Example:
|
||||
|
||||
.. code::
|
||||
The following should output 10::
|
||||
|
||||
n = 10
|
||||
a = sint.Array(n)
|
||||
x = sint(0)
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
a[i] = i
|
||||
global x
|
||||
x += 1
|
||||
x.update(x + 1)
|
||||
print_ln('%s', x.reveal())
|
||||
|
||||
Note that you cannot overwrite data structures such as
|
||||
:py:class:`~Compiler.types.Array` in a loop even when using
|
||||
:py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign`
|
||||
instead.
|
||||
:py:class:`~Compiler.types.Array` in a loop. Use
|
||||
:py:func:`~Compiler.types.Array.assign` instead.
|
||||
"""
|
||||
def decorator(loop_body):
|
||||
range_loop(loop_body, start, stop, step)
|
||||
@@ -909,11 +692,13 @@ def for_range(start, stop=None, step=None):
|
||||
def for_range_parallel(n_parallel, n_loops):
|
||||
"""
|
||||
Decorator to execute a loop :py:obj:`n_loops` up to
|
||||
:py:obj:`n_parallel` loop bodies in parallel.
|
||||
:py:obj:`n_parallel` loop bodies with optimized communication in a
|
||||
single thread.
|
||||
In most cases, it is easier to use :py:func:`for_range_opt`.
|
||||
Using any other control flow instruction inside the loop breaks
|
||||
the optimization.
|
||||
|
||||
:param n_parallel: compile-time (int)
|
||||
:param n_parallel: optimization parameter (int)
|
||||
:param n_loops: regint/cint/int or list of int
|
||||
|
||||
Example:
|
||||
@@ -937,7 +722,7 @@ def for_range_parallel(n_parallel, n_loops):
|
||||
return for_range_multithread(None, n_parallel, n_loops)
|
||||
return map_reduce_single(n_parallel, n_loops)
|
||||
|
||||
def for_range_opt(n_loops, budget=None):
|
||||
def for_range_opt(start, stop=None, step=None, budget=None):
|
||||
""" Execute loop bodies in parallel up to an optimization budget.
|
||||
This prevents excessive loop unrolling. The budget is respected
|
||||
even with nested loops. Note that the optimization is rather
|
||||
@@ -947,8 +732,10 @@ def for_range_opt(n_loops, budget=None):
|
||||
:py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the
|
||||
optimization.
|
||||
|
||||
:param n_loops: int/regint/cint
|
||||
:param budget: number of instructions after which to start optimization (default is 100,000)
|
||||
:param start/stop/step: int/regint/cint (used as in :py:func:`range`)
|
||||
or :py:obj:`start` only as list/tuple of int (see below)
|
||||
:param budget: number of instructions after which to start optimization
|
||||
(default is 100,000)
|
||||
|
||||
Example:
|
||||
|
||||
@@ -968,6 +755,15 @@ def for_range_opt(n_loops, budget=None):
|
||||
def f(i, j):
|
||||
...
|
||||
"""
|
||||
if stop is not None:
|
||||
start, stop, step = _range_prep(start, stop, step)
|
||||
def wrapper(loop_body):
|
||||
n_loops = (step - 1 + stop - start) // step
|
||||
@for_range_opt(n_loops, budget=budget)
|
||||
def _(i):
|
||||
return loop_body(start + i * step)
|
||||
return wrapper
|
||||
n_loops = start
|
||||
if isinstance(n_loops, (list, tuple)):
|
||||
return for_range_opt_multithread(None, n_loops)
|
||||
return map_reduce_single(None, n_loops, budget=budget)
|
||||
@@ -1009,9 +805,11 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
def f(i):
|
||||
state = tuplify(initializer())
|
||||
start_block = get_block()
|
||||
j = i * n_parallel
|
||||
one = regint(1)
|
||||
for k in range(n_parallel):
|
||||
j = i * n_parallel + k
|
||||
state = reducer(tuplify(loop_body(j)), state)
|
||||
j += one
|
||||
if n_parallel > 1 and start_block != get_block():
|
||||
print('WARNING: parallelization broken '
|
||||
'by control flow instruction')
|
||||
@@ -1028,12 +826,16 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
state = tuplify(initializer())
|
||||
k = 0
|
||||
block = get_block()
|
||||
assert not isinstance(n_loops, int) or n_loops > 0
|
||||
pre = copy.copy(loop_body.__globals__)
|
||||
while (not util.is_constant(n_loops) or k < n_loops) \
|
||||
and (len(get_block()) < budget or k == 0) \
|
||||
and block is get_block():
|
||||
j = i + k
|
||||
state = reducer(tuplify(loop_body(j)), state)
|
||||
k += 1
|
||||
RegintOptimizer().run(block.instructions, get_program())
|
||||
_link(pre, loop_body.__globals__)
|
||||
r = reducer(mem_state, state)
|
||||
write_state_to_memory(r)
|
||||
global n_opt_loops
|
||||
@@ -1064,7 +866,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
del blocks[-n_to_merge + 1:]
|
||||
del get_tape().req_node.children[-1]
|
||||
merged.children = []
|
||||
RegintOptimizer().run(merged.instructions)
|
||||
RegintOptimizer().run(merged.instructions, get_program())
|
||||
get_tape().active_basicblock = merged
|
||||
else:
|
||||
req_node = get_tape().req_node.children[-1].nodes[0]
|
||||
@@ -1131,6 +933,15 @@ def for_range_opt_multithread(n_threads, n_loops):
|
||||
@for_range_opt_multithread(2, [5, 3])
|
||||
def f(i, j):
|
||||
...
|
||||
|
||||
Note that you cannot use registers across threads. Use
|
||||
:py:class:`MemValue` instead::
|
||||
|
||||
a = MemValue(sint(0))
|
||||
@for_range_opt_multithread(8, 80)
|
||||
def _(i):
|
||||
b = a + 1
|
||||
|
||||
"""
|
||||
return for_range_multithread(n_threads, None, n_loops)
|
||||
|
||||
@@ -1142,6 +953,7 @@ def multithread(n_threads, n_items=None, max_size=None):
|
||||
|
||||
:param n_threads: compile-time (int)
|
||||
:param n_items: regint/cint/int (default: :py:obj:`n_threads`)
|
||||
:param max_size: maximum size to be processed at once (default: no limit)
|
||||
|
||||
The following executes ``f(0, 8)``, ``f(8, 8)``, and
|
||||
``f(16, 9)`` in three different threads:
|
||||
@@ -1158,6 +970,7 @@ def multithread(n_threads, n_items=None, max_size=None):
|
||||
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
|
||||
reducer=None, looping=False)
|
||||
else:
|
||||
max_size = max(1, max_size)
|
||||
def wrapper(function):
|
||||
@multithread(n_threads, n_items)
|
||||
def new_function(base, size):
|
||||
@@ -1205,7 +1018,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
if t != regint:
|
||||
raise CompilerError('Not implemented for other than regint')
|
||||
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
|
||||
state = tuple(initializer())
|
||||
state = initializer()
|
||||
if len(state) == 0:
|
||||
state_type = cint
|
||||
elif isinstance(state, (tuple, list)):
|
||||
state_type = type(state[0])
|
||||
else:
|
||||
state_type = type(state)
|
||||
def f(inc):
|
||||
base = args[get_arg()][0]
|
||||
if not util.is_constant(thread_rounds):
|
||||
@@ -1218,8 +1037,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
if thread_mem_req:
|
||||
thread_mem = Array(thread_mem_req[regint], regint, \
|
||||
args[get_arg()].address + 2)
|
||||
mem_state = Array(len(state), type(state[0]) \
|
||||
if state else cint, args[get_arg()][1])
|
||||
mem_state = Array(len(state), state_type, args[get_arg()][1])
|
||||
@map_reduce_single(n_parallel, thread_rounds + inc, \
|
||||
initializer, reducer, mem_state)
|
||||
def f(i):
|
||||
@@ -1251,14 +1069,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
threads = prog.run_tapes(thread_args)
|
||||
for thread in threads:
|
||||
prog.join_tape(thread)
|
||||
if state:
|
||||
if len(state):
|
||||
if thread_rounds:
|
||||
for i in range(n_threads - remainder):
|
||||
state = reducer(Array(len(state), type(state[0]), \
|
||||
state = reducer(Array(len(state), state_type, \
|
||||
args[remainder + i][1]), state)
|
||||
if remainder:
|
||||
for i in range(remainder):
|
||||
state = reducer(Array(len(state), type(state[0]).reg_type, \
|
||||
state = reducer(Array(len(state), state_type, \
|
||||
args[i][1]), state)
|
||||
def returner():
|
||||
return untuplify(state)
|
||||
@@ -1294,7 +1112,50 @@ def map_sum_opt(n_threads, n_loops, types):
|
||||
"""
|
||||
return map_sum(n_threads, None, n_loops, len(types), types)
|
||||
|
||||
def map_sum_simple(n_threads, n_loops, type, size):
|
||||
""" Vectorized multi-threaded sum reduction. The following computes a
|
||||
100 sums of ten squares in three threads::
|
||||
|
||||
@map_sum_simple(3, 10, sint, 100)
|
||||
def summer(i):
|
||||
return sint(regint.inc(100, i, 0)) ** 2
|
||||
|
||||
result = summer()
|
||||
|
||||
:param n_threads: number of threads (int)
|
||||
:param n_loops: number of loop runs (regint/cint/int)
|
||||
:param type: return type, must match the return statement
|
||||
in the loop
|
||||
:param size: vector size, must match the return statement
|
||||
in the loop
|
||||
|
||||
"""
|
||||
initializer = lambda: type(0, size=size)
|
||||
def summer(*args):
|
||||
assert len(args) == 2
|
||||
args = list(args)
|
||||
for i in (0, 1):
|
||||
if isinstance(args[i], tuple):
|
||||
assert len(args[i]) == 1
|
||||
args[i] = args[i][0]
|
||||
for i in (0, 1):
|
||||
assert len(args[i]) == size
|
||||
if isinstance(args[i], Array):
|
||||
args[i] = args[i][:]
|
||||
return args[0] + args[1]
|
||||
return map_reduce(n_threads, 1, n_loops, initializer, summer)
|
||||
|
||||
def tree_reduce_multithread(n_threads, function, vector):
|
||||
""" Round-efficient reduction in several threads. The following code
|
||||
computes the maximum of an array in 10 threads::
|
||||
|
||||
tree_reduce_multithread(10, lambda x, y: x.max(y), a)
|
||||
|
||||
:param n_threads: number of threads (int)
|
||||
:param function: reduction function taking exactly two arguments
|
||||
:param vector: register vector or array
|
||||
|
||||
"""
|
||||
inputs = vector.Array(len(vector))
|
||||
inputs.assign_vector(vector)
|
||||
outputs = vector.Array(len(vector) // 2)
|
||||
@@ -1311,6 +1172,18 @@ def tree_reduce_multithread(n_threads, function, vector):
|
||||
left = (left + 1) // 2
|
||||
return inputs[0]
|
||||
|
||||
def tree_reduce(function, sequence):
|
||||
""" Round-efficient reduction. The following computes the maximum
|
||||
of the list :py:obj:`l`::
|
||||
|
||||
m = tree_reduce(lambda x, y: x.max(y), l)
|
||||
|
||||
:param function: reduction function taking two arguments
|
||||
:param sequence: list, vector, or array
|
||||
|
||||
"""
|
||||
return util.tree_reduce(function, sequence)
|
||||
|
||||
def foreach_enumerate(a):
|
||||
""" Run-time loop over public data. This uses
|
||||
``Player-Data/Public-Input/<progname>``. Example:
|
||||
@@ -1338,61 +1211,57 @@ def foreach_enumerate(a):
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def while_loop(loop_body, condition, arg, g=None):
|
||||
def while_loop(loop_body, condition, arg=None, g=None):
|
||||
if not callable(condition):
|
||||
raise CompilerError('Condition must be callable')
|
||||
# store arg in stack
|
||||
pre_condition = condition(arg)
|
||||
if not isinstance(pre_condition, (bool,int)) or pre_condition:
|
||||
if arg is None:
|
||||
pre_condition = condition()
|
||||
def loop_fn():
|
||||
loop_body()
|
||||
return condition()
|
||||
else:
|
||||
pre_condition = condition(arg)
|
||||
arg = regint(arg)
|
||||
def loop_fn():
|
||||
result = loop_body(arg)
|
||||
result.link(arg)
|
||||
cont = condition(result)
|
||||
return cont
|
||||
if isinstance(result, MemValue):
|
||||
result = result.read()
|
||||
arg.update(result)
|
||||
return condition(result)
|
||||
if not isinstance(pre_condition, (bool,int)) or pre_condition:
|
||||
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))
|
||||
|
||||
def while_do(condition, *args):
|
||||
""" While-do loop. The decorator requires an initialization, and
|
||||
the loop body function must return a suitable input for
|
||||
:py:obj:`condition`.
|
||||
""" While-do loop.
|
||||
|
||||
:param condition: function returning public integer (regint/cint/int)
|
||||
:param args: arguments given to :py:obj:`condition` and loop body
|
||||
|
||||
The following executes an ten-fold loop:
|
||||
|
||||
.. code::
|
||||
|
||||
@while_do(lambda x: x < 10, regint(0))
|
||||
def f(i):
|
||||
i = regint(0)
|
||||
@while_do(lambda: i < 10)
|
||||
def f():
|
||||
...
|
||||
return i + 1
|
||||
i.update(i + 1)
|
||||
...
|
||||
|
||||
"""
|
||||
def decorator(loop_body):
|
||||
while_loop(loop_body, condition, *args)
|
||||
return loop_body
|
||||
return decorator
|
||||
|
||||
def do_loop(condition, loop_fn):
|
||||
# store initial condition to stack
|
||||
pushint(condition if isinstance(condition,regint) else regint(condition))
|
||||
def wrapped_loop():
|
||||
# save condition to stack
|
||||
new_cond = regint.pop()
|
||||
# run the loop
|
||||
condition = loop_fn(new_cond)
|
||||
pushint(condition)
|
||||
return condition
|
||||
do_while(wrapped_loop)
|
||||
regint.pop()
|
||||
|
||||
def _run_and_link(function, g=None):
|
||||
if g is None:
|
||||
g = function.__globals__
|
||||
import copy
|
||||
pre = copy.copy(g)
|
||||
res = function()
|
||||
_link(pre, g)
|
||||
return res
|
||||
|
||||
def _link(pre, g):
|
||||
if g:
|
||||
from .types import _single
|
||||
for name, var in pre.items():
|
||||
@@ -1402,7 +1271,6 @@ def _run_and_link(function, g=None):
|
||||
raise CompilerError('cannot reassign constants in blocks')
|
||||
if id(new_var) != id(var):
|
||||
new_var.link(var)
|
||||
return res
|
||||
|
||||
def do_while(loop_fn, g=None):
|
||||
""" Do-while loop. The loop is stopped if the return value is zero.
|
||||
@@ -1442,11 +1310,17 @@ def if_then(condition):
|
||||
state = State()
|
||||
if callable(condition):
|
||||
condition = condition()
|
||||
try:
|
||||
if not condition.is_clear:
|
||||
raise CompilerError('cannot branch on secret values')
|
||||
except AttributeError:
|
||||
pass
|
||||
state.condition = regint.conv(condition)
|
||||
state.start_block = instructions.program.curr_block
|
||||
state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
|
||||
name='if-block')
|
||||
state.has_else = False
|
||||
state.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
instructions.program.curr_tape.if_states.append(state)
|
||||
|
||||
def else_then():
|
||||
@@ -1531,6 +1405,8 @@ def if_(condition):
|
||||
def if_e(condition):
|
||||
"""
|
||||
Conditional execution with else block.
|
||||
Use :py:class:`~Compiler.types.MemValue` to assign values that
|
||||
live beyond.
|
||||
|
||||
:param condition: regint/cint/int
|
||||
|
||||
@@ -1538,12 +1414,13 @@ def if_e(condition):
|
||||
|
||||
.. code::
|
||||
|
||||
y = MemValue(0)
|
||||
@if_e(x > 0)
|
||||
def _():
|
||||
...
|
||||
y.write(1)
|
||||
@else_
|
||||
def _():
|
||||
...
|
||||
y.write(0)
|
||||
"""
|
||||
try:
|
||||
condition = bool(condition)
|
||||
@@ -1647,11 +1524,18 @@ def get_player_id():
|
||||
return res
|
||||
|
||||
def listen_for_clients(port):
|
||||
""" Listen for clients on specific port. """
|
||||
""" Listen for clients on specific port base.
|
||||
|
||||
:param port: port base (int/regint/cint)
|
||||
"""
|
||||
instructions.listen(regint.conv(port))
|
||||
|
||||
def accept_client_connection(port):
|
||||
""" Listen for clients on specific port. """
|
||||
""" Accept client connection on specific port base.
|
||||
|
||||
:param port: port base (int/regint/cint)
|
||||
:returns: client id
|
||||
"""
|
||||
res = regint()
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
return res
|
||||
@@ -1779,7 +1663,9 @@ def sint_cint_division(a, b, k, f, kappa):
|
||||
return (sign_a * sign_b) * A
|
||||
|
||||
def IntDiv(a, b, k, kappa=None):
|
||||
return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k,
|
||||
l = 2 * k + 1
|
||||
b = a.conv(b)
|
||||
return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k,
|
||||
kappa, nearest=True)
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
@@ -1801,24 +1687,25 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
|
||||
base.set_global_vector_size(b.size)
|
||||
alpha = b.get_type(2 * k).two_power(2*f)
|
||||
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
|
||||
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
|
||||
x = alpha - b.extend(2 * k) * w
|
||||
base.reset_global_vector_size()
|
||||
|
||||
y = a.extend(2 *k) * w
|
||||
y = y.round(2*k, f, kappa, nearest, signed=True)
|
||||
l_y = k + 3 * f - res_f
|
||||
y = a.extend(l_y) * w
|
||||
y = y.round(l_y, f, kappa, nearest, signed=True)
|
||||
|
||||
for i in range(theta - 1):
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
|
||||
y = y.extend(l_y) * (alpha + x).extend(l_y)
|
||||
x = x * x
|
||||
y = y.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, 2*f, kappa, nearest, signed=True)
|
||||
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
|
||||
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
y = y.extend(l_y) * (alpha + x).extend(l_y)
|
||||
y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
return y
|
||||
|
||||
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
|
||||
|
||||
1305
Compiler/ml.py
1305
Compiler/ml.py
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,8 @@ This has to imported explicitly.
|
||||
|
||||
|
||||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
from Compiler import floatingpoint
|
||||
from Compiler import types
|
||||
from Compiler import comparison
|
||||
@@ -290,12 +292,11 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
# how many bits to use from integer part
|
||||
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
|
||||
n_bits = a.f + n_int_bits
|
||||
sint = types.sint
|
||||
sint = a.int_type
|
||||
if types.program.options.ring and not as19:
|
||||
intbitint = types.intbitint
|
||||
n_shift = int(types.program.options.ring) - a.k
|
||||
if types.program.use_split():
|
||||
assert not zero_output
|
||||
from Compiler.GC.types import sbitvec
|
||||
if types.program.use_split() == 3:
|
||||
x = a.v.split_to_two_summands(a.k)
|
||||
@@ -327,6 +328,7 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
s = sint.conv(bits[-1])
|
||||
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
|
||||
higher_bits = bits[a.f:n_bits]
|
||||
bits_to_check = bits[n_bits:-1]
|
||||
else:
|
||||
if types.program.use_edabit():
|
||||
l = sint.get_edabit(a.f, True)
|
||||
@@ -338,7 +340,7 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
||||
r = sint.bit_compose(r_bits)
|
||||
lower_r = sint.bit_compose(r_bits[:a.f])
|
||||
shifted = ((a.v - r) << n_shift).reveal()
|
||||
shifted = ((a.v - r) << n_shift).reveal(False)
|
||||
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
|
||||
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
|
||||
r_bits[a.f-1::-1])
|
||||
@@ -367,17 +369,17 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
bits = a.v.bit_decompose(a.k, maybe_mixed=True)
|
||||
lower = sint.bit_compose(bits[:a.f])
|
||||
higher_bits = bits[a.f:n_bits]
|
||||
s = sint.conv(bits[-1])
|
||||
s = a.bit_type.conv(bits[-1])
|
||||
bits_to_check = bits[n_bits:-1]
|
||||
if not as19:
|
||||
c = types.sfix._new(lower, k=a.k, f=a.f)
|
||||
c = a._new(lower, k=a.k, f=a.f)
|
||||
assert(len(higher_bits) == n_bits - a.f)
|
||||
pow2_bits = [sint.conv(x) for x in higher_bits]
|
||||
d = floatingpoint.Pow2_from_bits(pow2_bits)
|
||||
g = exp_from_parts(d, c)
|
||||
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
|
||||
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits,
|
||||
2 ** n_int_bits, signed=False,
|
||||
nearest=types.sfix.round_nearest),
|
||||
nearest=a.round_nearest),
|
||||
k=a.k, f=a.f)
|
||||
if zero_output:
|
||||
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
|
||||
@@ -398,6 +400,36 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
return s.if_else(1 / g, g)
|
||||
|
||||
|
||||
def mux_exp(x, y, block_size=8):
|
||||
assert util.is_constant_float(x)
|
||||
from Compiler.GC.types import sbitvec, sbits
|
||||
bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v
|
||||
sign = bits[-1]
|
||||
m = math.log(2 ** (y.k - y.f - 1), x)
|
||||
del bits[int(math.ceil(math.log(m, 2))) + y.f:]
|
||||
parts = []
|
||||
for i in range(0, len(bits), block_size):
|
||||
one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v
|
||||
exp = []
|
||||
try:
|
||||
for j in range(len(one_hot)):
|
||||
exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f))
|
||||
except OverflowError:
|
||||
pass
|
||||
exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp))
|
||||
bin_part = [0] * max(x.bit_length() for x in exp)
|
||||
for j in range(len(bin_part)):
|
||||
for k, (a, b) in enumerate(zip(one_hot, exp)):
|
||||
bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \
|
||||
else 0
|
||||
if util.is_zero(bin_part[j]):
|
||||
bin_part[j] = sbits.get_type(y.size)(0)
|
||||
if i == 0:
|
||||
bin_part[j] = sign.if_else(0, bin_part[j])
|
||||
parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part))))
|
||||
return util.tree_reduce(operator.mul, parts)
|
||||
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def log2_fx(x, use_division=True):
|
||||
@@ -420,6 +452,8 @@ def log2_fx(x, use_division=True):
|
||||
p -= x.f
|
||||
vlen = x.f
|
||||
v = x._new(v, k=x.k, f=x.f)
|
||||
elif isinstance(x, (types._register, types.cfix)):
|
||||
return log2_fx(types.sfix(x), use_division)
|
||||
else:
|
||||
d = types.sfloat(x)
|
||||
v, p, vlen = d.v, d.p, d.vlen
|
||||
@@ -501,7 +535,7 @@ def abs_fx(x):
|
||||
#
|
||||
# @return floored sint value of x
|
||||
def floor_fx(x):
|
||||
return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x))
|
||||
return load_sint(floatingpoint.Trunc(x.v, x.k, x.f, x.kappa), type(x))
|
||||
|
||||
|
||||
### sqrt methods
|
||||
@@ -627,7 +661,7 @@ def sqrt_simplified_fx(x):
|
||||
h = h * r
|
||||
H = 4 * (h * h)
|
||||
|
||||
if not x.round_nearest or (2 * f < k - 1):
|
||||
if not x.round_nearest or (2 * x.f < x.k - 1):
|
||||
H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H)
|
||||
|
||||
H = H * x
|
||||
@@ -772,9 +806,7 @@ def sqrt_fx(x_l, k, f):
|
||||
@instructions_base.sfix_cisc
|
||||
def sqrt(x, k=None, f=None):
|
||||
"""
|
||||
Returns the square root (sfix) of any given fractional
|
||||
value as long as it can be rounded to a integral value
|
||||
with :py:obj:`f` bits of decimal precision.
|
||||
Square root.
|
||||
|
||||
:param x: fractional input (sfix).
|
||||
|
||||
@@ -882,7 +914,7 @@ def SqrtComp(z, old=False):
|
||||
k = len(z)
|
||||
if isinstance(z[0], types.sint):
|
||||
return types.sfix._new(sum(z[i] * types.cfix(
|
||||
2 ** (-(i - f + 1) / 2)).v for i in range(k)))
|
||||
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
|
||||
k_prime = k // 2
|
||||
f_prime = f // 2
|
||||
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .comparison import *
|
||||
from .floatingpoint import *
|
||||
from .types import *
|
||||
from . import comparison
|
||||
from . import comparison, program
|
||||
|
||||
class NonLinear:
|
||||
kappa = None
|
||||
@@ -30,6 +30,17 @@ class NonLinear:
|
||||
def trunc_pr(self, a, k, m, signed=True):
|
||||
if isinstance(a, types.cint):
|
||||
return shift_two(a, m)
|
||||
prog = program.Program.prog
|
||||
if prog.use_trunc_pr:
|
||||
if not prog.options.ring:
|
||||
prog.curr_tape.require_bit_length(k + prog.security)
|
||||
if signed and prog.use_trunc_pr != -1:
|
||||
a += (1 << (k - 1))
|
||||
res = sint()
|
||||
trunc_pr(res, a, k, m)
|
||||
if signed and prog.use_trunc_pr != -1:
|
||||
res -= (1 << (k - m - 1))
|
||||
return res
|
||||
return self._trunc_pr(a, k, m, signed)
|
||||
|
||||
def trunc_round_nearest(self, a, k, m, signed):
|
||||
@@ -44,6 +55,9 @@ class NonLinear:
|
||||
return a
|
||||
return self._trunc(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
return -self.trunc(a, k, k - 1, kappa, True)
|
||||
|
||||
class Masking(NonLinear):
|
||||
def eqz(self, a, k):
|
||||
c, r = self._mask(a, k)
|
||||
@@ -100,42 +114,44 @@ class KnownPrime(NonLinear):
|
||||
def _mod2m(self, a, k, m, signed):
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
|
||||
return sint.bit_compose(self.bit_dec(a, k, m, True))
|
||||
|
||||
def _trunc_pr(self, a, k, m, signed):
|
||||
# nearest truncation
|
||||
return self.trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
def _trunc(self, a, k, m, signed=None):
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
|
||||
if signed:
|
||||
res -= cint(1) << (k - 1 - m)
|
||||
return res
|
||||
return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed)
|
||||
|
||||
def trunc_round_nearest(self, a, k, m, signed):
|
||||
a += cint(1) << (m - 1)
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
k += 1
|
||||
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
|
||||
res = self._trunc(a, k, m, False)
|
||||
if signed:
|
||||
res -= cint(1) << (k - m - 2)
|
||||
return res
|
||||
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
assert k < self.prime.bit_length()
|
||||
bits = BitDecFull(a, maybe_mixed=maybe_mixed)
|
||||
if len(bits) < m:
|
||||
raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
|
||||
return bits[:m]
|
||||
bits = BitDecFull(a, m, maybe_mixed=maybe_mixed)
|
||||
assert len(bits) == m
|
||||
return bits
|
||||
|
||||
def eqz(self, a, k):
|
||||
# always signed
|
||||
a += two_power(k)
|
||||
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
if k + 1 < self.prime.bit_length():
|
||||
# https://dl.acm.org/doi/10.1145/3474123.3486757
|
||||
# "negative" values wrap around when doubling, thus becoming odd
|
||||
return self.mod2m(2 * a, k + 1, 1, False)
|
||||
else:
|
||||
return super(KnownPrime, self).ltz(a, k, kappa)
|
||||
|
||||
class Ring(Masking):
|
||||
""" Non-linear functionality modulo a power of two known at compile time.
|
||||
"""
|
||||
@@ -172,3 +188,6 @@ class Ring(Masking):
|
||||
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
|
||||
else:
|
||||
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
return LtzRing(a, k)
|
||||
|
||||
213
Compiler/oram.py
213
Compiler/oram.py
@@ -348,7 +348,7 @@ class Entry(object):
|
||||
def __len__(self):
|
||||
return 2 + len(self.x)
|
||||
def __repr__(self):
|
||||
return '{empty=%s}' % self.is_empty if self.is_empty \
|
||||
return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \
|
||||
else '{%s: %s}' % (self.v, self.x)
|
||||
def __add__(self, other):
|
||||
try:
|
||||
@@ -466,12 +466,14 @@ class AbstractORAM(object):
|
||||
def get_array(size, t, *args, **kwargs):
|
||||
return t.dynamic_array(size, t, *args, **kwargs)
|
||||
def read(self, index):
|
||||
return self._read(self.value_type.hard_conv(index))
|
||||
res = self._read(self.index_type.hard_conv(index))
|
||||
res = [self.value_type._new(x) for x in res]
|
||||
return res
|
||||
def write(self, index, value):
|
||||
value = util.tuplify(value)
|
||||
value = [self.value_type.conv(x) for x in value]
|
||||
new_value = [self.value_type.get_type(length).hard_conv(v) \
|
||||
for length,v in zip(self.entry_size, value \
|
||||
if isinstance(value, (tuple, list)) \
|
||||
else (value,))]
|
||||
for length,v in zip(self.entry_size, value)]
|
||||
return self._write(self.index_type.hard_conv(index), *new_value)
|
||||
def access(self, index, new_value, write, new_empty=False):
|
||||
return self._access(self.index_type.hard_conv(index),
|
||||
@@ -795,18 +797,19 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
for i,value in enumerate(values):
|
||||
index = MemValue(self.value_type.hard_conv(i))
|
||||
new_value = [MemValue(self.value_type.hard_conv(v)) \
|
||||
for v in (value if isinstance(value, (tuple, list)) \
|
||||
for v in (value if isinstance(
|
||||
value, (tuple, list, Array)) \
|
||||
else (value,))]
|
||||
self.ram[i] = Entry(index, new_value, value_type=self.value_type)
|
||||
|
||||
class TrivialORAM(RefTrivialORAM, AbstractORAM):
|
||||
""" Trivial ORAM (obviously). """
|
||||
ref_type = RefTrivialORAM
|
||||
def __init__(self, size, value_type=sint, value_length=1, index_size=None, \
|
||||
def __init__(self, size, value_type=None, value_length=1, index_size=None, \
|
||||
entry_size=None, contiguous=True, init_rounds=-1):
|
||||
self.index_size = index_size or log2(size)
|
||||
self.value_type = value_type
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
self.value_type = value_type or sint
|
||||
self.index_type = self.value_type.get_type(self.index_size)
|
||||
if entry_size is None:
|
||||
self.value_length = value_length
|
||||
self.entry_size = [None] * value_length
|
||||
@@ -859,15 +862,16 @@ class LinearORAM(TrivialORAM):
|
||||
empty_entry = self.empty_entry(False)
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type.get_type(l) for l in self.entry_size])
|
||||
self.value_length + 1, t)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
access_here = self.index_vector[i]
|
||||
return access_here * ValueTuple((entry.empty(),) + entry.x)
|
||||
not_found = f()[0]
|
||||
read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x
|
||||
not_found = self.value_type.bit_type(f()[0])
|
||||
read_value = ValueTuple(self.value_type.get_type(l)(x) for l, x in zip(self.entry_size, f()[1:])) + \
|
||||
not_found * empty_entry.x
|
||||
maybe_stop_timer(6)
|
||||
return read_value, not_found
|
||||
@method_block
|
||||
@@ -876,7 +880,9 @@ class LinearORAM(TrivialORAM):
|
||||
empty_entry = self.empty_entry(False)
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
new_value = make_array(new_value)
|
||||
new_value = make_array(
|
||||
new_value, self.value_type.get_type(
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
@@ -892,7 +898,9 @@ class LinearORAM(TrivialORAM):
|
||||
empty_entry = self.empty_entry(False)
|
||||
index_vector = \
|
||||
demux_array(bit_decompose(index, self.index_size))
|
||||
new_value = make_array(new_value)
|
||||
new_value = make_array(
|
||||
new_value, self.value_type.get_type(
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
new_empty = MemValue(new_empty)
|
||||
write = MemValue(write)
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
@@ -986,7 +994,8 @@ class List(EndRecursiveEviction):
|
||||
for i,value in enumerate(values):
|
||||
index = self.value_type.hard_conv(i)
|
||||
new_value = [self.value_type.hard_conv(v) \
|
||||
for v in (value if isinstance(value, (tuple, list)) \
|
||||
for v in (value if isinstance(
|
||||
value, (tuple, list, Array)) \
|
||||
else (value,))]
|
||||
self.__setitem__(index, new_value)
|
||||
def __repr__(self):
|
||||
@@ -1025,8 +1034,9 @@ def get_n_threads_for_tree(size):
|
||||
|
||||
class TreeORAM(AbstractORAM):
|
||||
""" Tree ORAM. """
|
||||
def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \
|
||||
def __init__(self, size, value_type=None, value_length=1, entry_size=None, \
|
||||
bucket_oram=TrivialORAM, init_rounds=-1):
|
||||
value_type = value_type or sint
|
||||
print('create oram of size', size)
|
||||
self.bucket_oram = bucket_oram
|
||||
# heuristic bucket size
|
||||
@@ -1062,11 +1072,12 @@ class TreeORAM(AbstractORAM):
|
||||
stop_timer(1)
|
||||
start_timer()
|
||||
self.root = RefBucket(1, self)
|
||||
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
|
||||
self.index = self.index_structure(size, self.D, self.index_type,
|
||||
init_rounds, True)
|
||||
|
||||
self.read_value = Array(self.value_length, value_type)
|
||||
self.read_value = Array(self.value_length, value_type.default_type)
|
||||
self.read_non_empty = MemValue(self.value_type.bit_type(0))
|
||||
self.state = MemValue(self.value_type(0))
|
||||
self.state = MemValue(self.value_type.default_type(0))
|
||||
@method_block
|
||||
def add_to_root(self, state, is_empty, v, *x):
|
||||
if len(x) != self.value_length:
|
||||
@@ -1106,10 +1117,10 @@ class TreeORAM(AbstractORAM):
|
||||
self.evict_bucket(RefBucket(p_bucket2, self), d)
|
||||
@method_block
|
||||
def read_and_renew_index(self, u):
|
||||
l_star = random_block(self.D, self.value_type)
|
||||
l_star = random_block(self.D, self.index_type)
|
||||
if use_insecure_randomness:
|
||||
new_path = regint.get_random(self.D)
|
||||
l_star = self.value_type(new_path)
|
||||
l_star = self.index_type(new_path)
|
||||
self.state.write(l_star)
|
||||
return self.index.update(u, l_star, evict=False).reveal()
|
||||
@method_block
|
||||
@@ -1120,7 +1131,7 @@ class TreeORAM(AbstractORAM):
|
||||
parallel = get_parallel(self.index_size, *self.internal_value_type())
|
||||
@map_sum(get_n_threads_for_tree(self.size), parallel, levels, \
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type] * self.value_length)
|
||||
[self.value_type.default_type] * self.value_length)
|
||||
def process(level):
|
||||
b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level))
|
||||
bucket = RefBucket(b_index, self)
|
||||
@@ -1142,9 +1153,9 @@ class TreeORAM(AbstractORAM):
|
||||
Program.prog.curr_tape.start_new_basicblock()
|
||||
crash()
|
||||
def internal_value_type(self):
|
||||
return self.value_type, self.value_length + 1
|
||||
return self.value_type.default_type, self.value_length + 1
|
||||
def internal_entry_size(self):
|
||||
return self.value_type, [self.D] + list(self.entry_size)
|
||||
return self.value_type.default_type, [self.D] + list(self.entry_size)
|
||||
def n_buckets(self):
|
||||
return 2**(self.D+1)
|
||||
@method_block
|
||||
@@ -1176,8 +1187,9 @@ class TreeORAM(AbstractORAM):
|
||||
#print 'pre-add', self
|
||||
maybe_start_timer(4)
|
||||
self.add_to_root(state, entry.empty(), \
|
||||
self.value_type(entry.v.read()), \
|
||||
*(self.value_type(i.read()) for i in entry.x))
|
||||
self.index_type(entry.v.read()), \
|
||||
*(self.value_type.default_type(i.read())
|
||||
for i in entry.x))
|
||||
maybe_stop_timer(4)
|
||||
#print 'pre-evict', self
|
||||
if evict:
|
||||
@@ -1221,28 +1233,35 @@ class TreeORAM(AbstractORAM):
|
||||
""" Batch initalization. Obliviously shuffles and adds N entries to
|
||||
random leaf buckets. """
|
||||
m = len(values)
|
||||
assert((m & (m-1)) == 0)
|
||||
if not (m & (m-1)) == 0:
|
||||
raise CompilerError('Batch size must a power of 2.')
|
||||
if m != self.size:
|
||||
raise CompilerError('Batch initialization must have N values.')
|
||||
if self.value_type != sint:
|
||||
raise CompilerError('Batch initialization only possible with sint.')
|
||||
|
||||
depth = log2(m)
|
||||
leaves = [0] * m
|
||||
entries = [0] * m
|
||||
indexed_values = [0] * m
|
||||
leaves = self.value_type.Array(m)
|
||||
indexed_values = \
|
||||
self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1)
|
||||
|
||||
# assign indices 0, ..., m-1
|
||||
for i,value in enumerate(values):
|
||||
@for_range(m)
|
||||
def _(i):
|
||||
value = values[i]
|
||||
index = MemValue(self.value_type.hard_conv(i))
|
||||
new_value = [MemValue(self.value_type.hard_conv(v)) \
|
||||
for v in (value if isinstance(value, (tuple, list)) \
|
||||
else (value,))]
|
||||
indexed_values[i] = [index] + new_value
|
||||
|
||||
|
||||
entries = sint.Matrix(self.bucket_size * 2 ** self.D,
|
||||
len(Entry(0, list(indexed_values[0]), False)))
|
||||
|
||||
# assign leaves
|
||||
for i,index_value in enumerate(indexed_values):
|
||||
@for_range(len(indexed_values))
|
||||
def _(i):
|
||||
index_value = list(indexed_values[i])
|
||||
leaves[i] = random_block(self.D, self.value_type)
|
||||
|
||||
index = index_value[0]
|
||||
@@ -1252,18 +1271,20 @@ class TreeORAM(AbstractORAM):
|
||||
|
||||
# save unsorted leaves for position map
|
||||
unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves]
|
||||
permutation.sort(leaves, comp=permutation.normal_comparator)
|
||||
leaves.sort()
|
||||
|
||||
bucket_sz = 0
|
||||
# B[i] = (pos, leaf, "last in bucket" flag) for i-th entry
|
||||
B = [[0]*3 for i in range(m)]
|
||||
B = sint.Matrix(m, 3)
|
||||
B[0] = [0, leaves[0], 0]
|
||||
B[-1] = [None, None, sint(1)]
|
||||
s = 0
|
||||
s = MemValue(sint(0))
|
||||
|
||||
for i in range(1, m):
|
||||
@for_range_opt(m - 1)
|
||||
def _(j):
|
||||
i = j + 1
|
||||
eq = leaves[i].equal(leaves[i-1])
|
||||
s = (s + eq) * eq
|
||||
s.write((s + eq) * eq)
|
||||
B[i][0] = s
|
||||
B[i][1] = leaves[i]
|
||||
B[i-1][2] = 1 - eq
|
||||
@@ -1271,7 +1292,7 @@ class TreeORAM(AbstractORAM):
|
||||
#last_in_bucket[i-1] = 1 - eq
|
||||
|
||||
# shuffle
|
||||
permutation.shuffle(B, value_type=sint)
|
||||
B.secure_shuffle()
|
||||
#cint(0).print_reg('shuf')
|
||||
|
||||
sz = MemValue(0) #cint(0)
|
||||
@@ -1279,7 +1300,8 @@ class TreeORAM(AbstractORAM):
|
||||
empty_positions = Array(nleaves, self.value_type)
|
||||
empty_leaves = Array(nleaves, self.value_type)
|
||||
|
||||
for i in range(m):
|
||||
@for_range(m)
|
||||
def _(i):
|
||||
if_then(reveal(B[i][2]))
|
||||
#if B[i][2] == 1:
|
||||
#cint(i).print_reg('last')
|
||||
@@ -1291,12 +1313,13 @@ class TreeORAM(AbstractORAM):
|
||||
empty_positions[szval] = B[i][0] #pos[i][0]
|
||||
#empty_positions[szval].reveal().print_reg('ps0')
|
||||
empty_leaves[szval] = B[i][1] #pos[i][1]
|
||||
sz += 1
|
||||
sz.iadd(1)
|
||||
end_if()
|
||||
|
||||
pos_bits = []
|
||||
pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2)
|
||||
|
||||
for i in range(nleaves):
|
||||
@for_range_opt(nleaves)
|
||||
def _(i):
|
||||
leaf = empty_leaves[i]
|
||||
# split into 2 if bucket size can't fit into one field elem
|
||||
if self.bucket_size + Program.prog.security > 128:
|
||||
@@ -1315,46 +1338,39 @@ class TreeORAM(AbstractORAM):
|
||||
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
|
||||
else:
|
||||
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
|
||||
pos_bits += [[b, leaf] for b in bucket_bits]
|
||||
assert len(bucket_bits) == self.bucket_size
|
||||
for j, b in enumerate(bucket_bits):
|
||||
pos_bits[i * self.bucket_size + j] = [b, leaf]
|
||||
|
||||
# sort to get empty positions first
|
||||
permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator)
|
||||
pos_bits.sort(n_bits=1)
|
||||
|
||||
# now assign positions to empty entries
|
||||
empty_entries = [0] * (self.bucket_size*2**self.D - m)
|
||||
|
||||
for i in range(self.bucket_size*2**self.D - m):
|
||||
@for_range(len(entries) - m)
|
||||
def _(i):
|
||||
vtype, vlength = self.internal_value_type()
|
||||
leaf = vtype(pos_bits[i][1])
|
||||
# set leaf in empty entry for assigning after shuffle
|
||||
value = tuple([leaf] + [vtype(0) for j in range(vlength)])
|
||||
value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)])
|
||||
entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype)
|
||||
empty_entries[i] = entry
|
||||
entries[m + i] = entry
|
||||
|
||||
# now shuffle, reveal positions and place entries
|
||||
entries = entries + empty_entries
|
||||
while len(entries) & (len(entries)-1) != 0:
|
||||
entries.append(None)
|
||||
permutation.shuffle(entries, value_type=sint)
|
||||
entries = [entry for entry in entries if entry is not None]
|
||||
clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries]
|
||||
entries.secure_shuffle()
|
||||
clear_leaves = Array.create_from(
|
||||
Entry(entries.get_columns()).x[0].reveal())
|
||||
|
||||
Program.prog.curr_tape.start_new_basicblock()
|
||||
|
||||
bucket_sizes = Array(2**self.D, regint)
|
||||
for i in range(2**self.D):
|
||||
bucket_sizes[i] = 0
|
||||
k = 0
|
||||
for entry,leaf in zip(entries, clear_leaves):
|
||||
leaf = leaf.read()
|
||||
k += 1
|
||||
|
||||
# for some reason leaf_buckets is in bit-reversed order
|
||||
bits = bit_decompose(leaf, self.D)
|
||||
rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1]))
|
||||
bucket = RefBucket(rev_leaf + (1 << self.D), self)
|
||||
# hack: 1*entry ensures MemValues are converted to sints
|
||||
bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry
|
||||
@for_range_opt(len(entries))
|
||||
def _(k):
|
||||
leaf = clear_leaves[k]
|
||||
bucket = RefBucket(leaf + (1 << self.D), self)
|
||||
bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k])
|
||||
bucket_sizes[leaf] += 1
|
||||
|
||||
self.index.batch_init([leaf.read() for leaf in unsorted_leaves])
|
||||
@@ -1493,6 +1509,7 @@ class PackedIndexStructure(object):
|
||||
self.l[i] = [0] * self.elements_per_block
|
||||
time()
|
||||
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
|
||||
print_ln('packed ORAM init done')
|
||||
print('index initialized, size', size)
|
||||
def translate_index(self, index):
|
||||
""" Bit slicing *index* according parameters. Output is tuple
|
||||
@@ -1598,16 +1615,20 @@ class PackedIndexStructure(object):
|
||||
def batch_init(self, values):
|
||||
""" Initialize m values with indices 0, ..., m-1 """
|
||||
m = len(values)
|
||||
n_entries = max(1, m/self.entries_per_block)
|
||||
new_values = [0] * n_entries
|
||||
n_entries = max(1, m//self.entries_per_block)
|
||||
new_values = sint.Matrix(n_entries, self.elements_per_block)
|
||||
values = Array.create_from(values)
|
||||
|
||||
for i in range(n_entries):
|
||||
@for_range(n_entries)
|
||||
def _(i):
|
||||
block = [0] * self.elements_per_block
|
||||
for j in range(self.elements_per_block):
|
||||
base = i * self.entries_per_block + j * self.entries_per_element
|
||||
for k in range(self.entries_per_element):
|
||||
if base + k < m:
|
||||
block[j] += values[base + k] << (k * self.entry_size)
|
||||
@if_(base + k < m)
|
||||
def _():
|
||||
block[j] += \
|
||||
values[base + k] << (k * sum(self.entry_size))
|
||||
|
||||
new_values[i] = block
|
||||
|
||||
@@ -1661,13 +1682,51 @@ class OneLevelORAM(TreeORAM):
|
||||
pattern after one recursion. """
|
||||
index_structure = BaseORAMIndexStructure
|
||||
|
||||
class BinaryORAM:
|
||||
def __init__(self, size, value_type=None, **kwargs):
|
||||
import circuit_oram
|
||||
from Compiler.GC import types
|
||||
n_bits = int(get_program().options.binary)
|
||||
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
|
||||
self.index_type = self.value_type
|
||||
oram_value_type = types.sbits.get_type(64)
|
||||
if 'entry_size' not in kwargs:
|
||||
kwargs['entry_size'] = n_bits
|
||||
self.oram = circuit_oram.OptimalCircuitORAM(
|
||||
size, value_type=oram_value_type, **kwargs)
|
||||
self.size = size
|
||||
def get_index(self, index):
|
||||
return self.oram.value_type(self.index_type.conv(index).elements()[0])
|
||||
def __setitem__(self, index, value):
|
||||
value = list(self.oram.value_type(
|
||||
self.value_type.conv(v).elements()[0]) for v in tuplify(value))
|
||||
self.oram[self.get_index(index)] = value
|
||||
def __getitem__(self, index):
|
||||
value = self.oram[self.get_index(index)]
|
||||
return untuplify(tuple(self.value_type(v) for v in tuplify(value)))
|
||||
def read(self, index):
|
||||
return self.oram.read(index)
|
||||
def read_and_maybe_remove(self, index):
|
||||
return self.oram.read_and_maybe_remove(index)
|
||||
def access(self, *args):
|
||||
return self.oram.access(*args)
|
||||
def add(self, *args, **kwargs):
|
||||
return self.oram.add(*args, **kwargs)
|
||||
def delete(self, *args, **kwargs):
|
||||
return self.oram.delete(*args, **kwargs)
|
||||
|
||||
def OptimalORAM(size,*args,**kwargs):
|
||||
""" Create an ORAM instance suitable for the size based on
|
||||
experiments.
|
||||
|
||||
:param size: number of elements
|
||||
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn`
|
||||
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
|
||||
:py:class:`sfix`
|
||||
"""
|
||||
if not util.is_constant(size):
|
||||
raise CompilerError('ORAM size has be a compile-time constant')
|
||||
if get_program().options.binary:
|
||||
return BinaryORAM(size, *args, **kwargs)
|
||||
if optimal_threshold is None:
|
||||
if n_threads == 1:
|
||||
threshold = 2**11
|
||||
@@ -1716,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
|
||||
def test_oram(oram_type, N, value_type=sint, iterations=100):
|
||||
stop_grind()
|
||||
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
|
||||
test_oram_initialized(oram, iterations)
|
||||
return oram
|
||||
|
||||
def test_oram_initialized(oram, iterations=100):
|
||||
N = oram.size
|
||||
value_type = oram.value_type
|
||||
value_type = value_type.get_type(32)
|
||||
index_type = value_type.get_type(log2(N))
|
||||
start_grind()
|
||||
@@ -1783,7 +1848,7 @@ def test_batch_init(oram_type, N):
|
||||
oram = oram_type(N, value_type)
|
||||
print('initialized')
|
||||
print_reg(cint(0), 'init')
|
||||
oram.batch_init([value_type(i) for i in range(N)])
|
||||
oram.batch_init(Array.create_from(sint(regint.inc(N))))
|
||||
print_reg(cint(0), 'done')
|
||||
@for_range(N)
|
||||
def f(i):
|
||||
|
||||
1309
Compiler/path_oblivious_heap.py
Normal file
1309
Compiler/path_oblivious_heap.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,8 +8,11 @@ from functools import reduce
|
||||
|
||||
#import pdb
|
||||
|
||||
prog = program.Program.prog
|
||||
prog.set_bit_length(min(64, prog.bit_length))
|
||||
try:
|
||||
prog = program.Program.prog
|
||||
prog.set_bit_length(min(64, prog.bit_length))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
class Counter(object):
|
||||
def __init__(self, val=0, max_val=None, size=None, value_type=sgf2n):
|
||||
@@ -111,24 +114,6 @@ def bucket_size_sorter(x, y):
|
||||
return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z])
|
||||
|
||||
|
||||
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
|
||||
""" Simulate secure shuffling with Waksman network for 2 players.
|
||||
|
||||
|
||||
Returns the network switching config so it may be re-used later. """
|
||||
n = len(x)
|
||||
if n & (n-1) != 0:
|
||||
raise CompilerError('shuffle requires n a power of 2')
|
||||
if config is None:
|
||||
config = permutation.configure_waksman(permutation.random_perm(n))
|
||||
for i,c in enumerate(config):
|
||||
config[i] = [value_type(b) for b in c]
|
||||
permutation.waksman(x, config, reverse=reverse)
|
||||
permutation.waksman(x, config, reverse=reverse)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def LT(a, b):
|
||||
a_bits = bit_decompose(a)
|
||||
b_bits = bit_decompose(b)
|
||||
@@ -472,10 +457,15 @@ class PathORAM(TreeORAM):
|
||||
print_ln()
|
||||
|
||||
# shuffle entries and levels
|
||||
while len(merged_entries) & (len(merged_entries)-1) != 0:
|
||||
merged_entries.append(None) #self.root.bucket.empty_entry(False))
|
||||
permutation.rec_shuffle(merged_entries, value_type=self.value_type)
|
||||
merged_entries = [e for e in merged_entries if e is not None]
|
||||
flat = []
|
||||
for x in merged_entries:
|
||||
flat += list(x[0]) + [x[1]]
|
||||
flat = self.value_type(flat)
|
||||
assert len(flat) % len(merged_entries) == 0
|
||||
l = len(flat) // len(merged_entries)
|
||||
shuffled = flat.secure_shuffle(l)
|
||||
merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]]
|
||||
for i in range(len(shuffled) // l)]
|
||||
|
||||
# need to copy entries/levels to memory for re-positioning
|
||||
entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)
|
||||
|
||||
@@ -10,16 +10,6 @@ if '_Array' not in dir():
|
||||
from Compiler.program import Program
|
||||
_Array = Array
|
||||
|
||||
SORT_BITS = []
|
||||
insecure_random = Random(0)
|
||||
|
||||
def predefined_comparator(x, y):
|
||||
""" Assumes SORT_BITS is populated with the required sorting network bits """
|
||||
if predefined_comparator.sort_bits_iter is None:
|
||||
predefined_comparator.sort_bits_iter = iter(SORT_BITS)
|
||||
return next(predefined_comparator.sort_bits_iter)
|
||||
predefined_comparator.sort_bits_iter = None
|
||||
|
||||
def list_comparator(x, y):
|
||||
""" Uses the first element in the list for comparison """
|
||||
return x[0] < y[0]
|
||||
@@ -37,10 +27,6 @@ def bitwise_comparator(x, y):
|
||||
|
||||
def cond_swap_bit(x,y, b):
|
||||
""" swap if b == 1 """
|
||||
if x is None:
|
||||
return y, None
|
||||
elif y is None:
|
||||
return x, None
|
||||
if isinstance(x, list):
|
||||
t = [(xi - yi) * b for xi,yi in zip(x, y)]
|
||||
return [xi - ti for xi,ti in zip(x, t)], \
|
||||
@@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator):
|
||||
else:
|
||||
raise CompilerError('Length of list must be power of two')
|
||||
|
||||
def merge(a, b, comp):
|
||||
""" General length merge (pads to power of 2) """
|
||||
while len(a) & (len(a)-1) != 0:
|
||||
a.append(None)
|
||||
while len(b) & (len(b)-1) != 0:
|
||||
b.append(None)
|
||||
if len(a) < len(b):
|
||||
a += [None] * (len(b) - len(a))
|
||||
elif len(b) < len(a):
|
||||
b += [None] * (len(b) - len(b))
|
||||
t = a + b
|
||||
odd_even_merge(t, comp)
|
||||
for i,v in enumerate(t[::]):
|
||||
if v is None:
|
||||
t.remove(None)
|
||||
return t
|
||||
|
||||
def sort(a, comp):
|
||||
""" Pads to power of 2, sorts, removes padding """
|
||||
length = len(a)
|
||||
@@ -112,47 +81,12 @@ def sort(a, comp):
|
||||
odd_even_merge_sort(a, comp)
|
||||
del a[length:]
|
||||
|
||||
def recursive_merge(a, comp):
|
||||
""" Recursively merge a list of sorted lists (initially sorted by size) """
|
||||
if len(a) == 1:
|
||||
return
|
||||
# merge smallest two lists, place result in correct position, recurse
|
||||
t = merge(a[0], a[1], comp)
|
||||
del a[0]
|
||||
del a[0]
|
||||
added = False
|
||||
for i,c in enumerate(a):
|
||||
if len(c) >= len(t):
|
||||
a.insert(i, t)
|
||||
added = True
|
||||
break
|
||||
if not added:
|
||||
a.append(t)
|
||||
recursive_merge(a, comp)
|
||||
# The following functionality for shuffling isn't used any more as it
|
||||
# has been moved to the virtual machine. The code has been kept for
|
||||
# reference.
|
||||
|
||||
def random_perm(n):
|
||||
""" Generate a random permutation of length n
|
||||
|
||||
WARNING: randomness fixed at compile-time, this is NOT secure
|
||||
"""
|
||||
if not Program.prog.options.insecure:
|
||||
raise CompilerError('no secure implementation of Waksman permution, '
|
||||
'use --insecure to activate')
|
||||
a = list(range(n))
|
||||
for i in range(n-1, 0, -1):
|
||||
j = insecure_random.randint(0, i)
|
||||
t = a[i]
|
||||
a[i] = a[j]
|
||||
a[j] = t
|
||||
return a
|
||||
|
||||
def inverse(perm):
|
||||
inv = [None] * len(perm)
|
||||
for i, p in enumerate(perm):
|
||||
inv[p] = i
|
||||
return inv
|
||||
|
||||
def configure_waksman(perm):
|
||||
def configure_waksman(perm, n_iter=[0]):
|
||||
top = n_iter == [0]
|
||||
n = len(perm)
|
||||
if n == 2:
|
||||
return [(perm[0], perm[0])]
|
||||
@@ -175,6 +109,7 @@ def configure_waksman(perm):
|
||||
via = 0
|
||||
j0 = j
|
||||
while True:
|
||||
n_iter[0] += 1
|
||||
#print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2)
|
||||
|
||||
i = inv_perm[j]
|
||||
@@ -209,8 +144,11 @@ def configure_waksman(perm):
|
||||
|
||||
assert sorted(p0) == list(range(n//2))
|
||||
assert sorted(p1) == list(range(n//2))
|
||||
p0_config = configure_waksman(p0)
|
||||
p1_config = configure_waksman(p1)
|
||||
p0_config = configure_waksman(p0, n_iter)
|
||||
p1_config = configure_waksman(p1, n_iter)
|
||||
if top:
|
||||
print(n_iter[0], 'iterations for Waksman')
|
||||
assert O[0] == 0, 'not a Waksman network'
|
||||
return [I + O] + [a+b for a,b in zip(p0_config, p1_config)]
|
||||
|
||||
def waksman(a, config, depth=0, start=0, reverse=False):
|
||||
@@ -358,23 +296,10 @@ def iter_waksman(a, config, reverse=False):
|
||||
# nblocks /= 2
|
||||
# depth -= 1
|
||||
|
||||
def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
|
||||
n = len(x)
|
||||
if n & (n-1) != 0:
|
||||
raise CompilerError('shuffle requires n a power of 2')
|
||||
if config is None:
|
||||
config = configure_waksman(random_perm(n))
|
||||
for i,c in enumerate(config):
|
||||
config[i] = [value_type.bit_type(b) for b in c]
|
||||
waksman(x, config, reverse=reverse)
|
||||
waksman(x, config, reverse=reverse)
|
||||
|
||||
|
||||
def config_shuffle(n, value_type):
|
||||
""" Compute config for oblivious shuffling.
|
||||
|
||||
Take mod 2 for active sec. """
|
||||
perm = random_perm(n)
|
||||
def config_from_perm(perm, value_type):
|
||||
n = len(perm)
|
||||
assert(list(sorted(perm))) == list(range(n))
|
||||
if n & (n-1) != 0:
|
||||
# pad permutation to power of 2
|
||||
m = 2**int(math.ceil(math.log(n, 2)))
|
||||
@@ -394,103 +319,3 @@ def config_shuffle(n, value_type):
|
||||
for j,b in enumerate(c):
|
||||
config[i * len(perm) + j] = b
|
||||
return config
|
||||
|
||||
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
|
||||
""" Simulate secure shuffling with Waksman network for 2 players.
|
||||
WARNING: This is not a properly secure implementation but has roughly the right complexity.
|
||||
|
||||
Returns the network switching config so it may be re-used later. """
|
||||
n = len(x)
|
||||
m = 2**int(math.ceil(math.log(n, 2)))
|
||||
assert n == m, 'only working for powers of two'
|
||||
if config is None:
|
||||
config = config_shuffle(n, value_type)
|
||||
|
||||
if isinstance(x, list):
|
||||
if isinstance(x[0], list):
|
||||
length = len(x[0])
|
||||
assert len(x) == length
|
||||
for i in range(length):
|
||||
xi = Array(m, value_type.reg_type)
|
||||
for j in range(n):
|
||||
xi[j] = x[j][i]
|
||||
for j in range(n, m):
|
||||
xi[j] = value_type(0)
|
||||
iter_waksman(xi, config, reverse=reverse)
|
||||
iter_waksman(xi, config, reverse=reverse)
|
||||
for j, y in enumerate(xi):
|
||||
x[j][i] = y
|
||||
else:
|
||||
xa = Array(m, value_type.reg_type)
|
||||
for i in range(n):
|
||||
xa[i] = x[i]
|
||||
for i in range(n, m):
|
||||
xa[i] = value_type(0)
|
||||
iter_waksman(xa, config, reverse=reverse)
|
||||
iter_waksman(xa, config, reverse=reverse)
|
||||
x[:] = xa
|
||||
elif isinstance(x, Array):
|
||||
if len(x) != m and config is None:
|
||||
raise CompilerError('Non-power of 2 Array input not yet supported')
|
||||
iter_waksman(x, config, reverse=reverse)
|
||||
iter_waksman(x, config, reverse=reverse)
|
||||
else:
|
||||
raise CompilerError('Invalid type for shuffle:', type(x))
|
||||
|
||||
return config
|
||||
|
||||
def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None):
|
||||
""" Shuffle a list of ORAM entries.
|
||||
|
||||
Randomly permutes the first "perm_size" entries, leaving the rest (empty
|
||||
entry padding) in the same position. """
|
||||
n = len(x)
|
||||
l = len(x[0])
|
||||
if n & (n-1) != 0:
|
||||
raise CompilerError('Entries must be padded to power of two length.')
|
||||
if perm_size is None:
|
||||
perm_size = n
|
||||
|
||||
xarrays = [Array(n, value_type.reg_type) for i in range(l)]
|
||||
for i in range(n):
|
||||
for j,value in enumerate(x[i]):
|
||||
if isinstance(value, MemValue):
|
||||
xarrays[j][i] = value.read()
|
||||
else:
|
||||
xarrays[j][i] = value
|
||||
|
||||
if config is None:
|
||||
config = config_shuffle(perm_size, value_type)
|
||||
for xi in xarrays:
|
||||
shuffle(xi, config, value_type, reverse)
|
||||
for i in range(n):
|
||||
x[i] = entry_cls(xarrays[j][i] for j in range(l))
|
||||
return config
|
||||
|
||||
|
||||
def sort_zeroes(bits, x, n_ones, value_type):
|
||||
""" Return Array of values in "x" where the corresponding bit in "bits" is
|
||||
a 0.
|
||||
|
||||
The total number of zeroes in "bits" must be known.
|
||||
"bits" and "x" must be Arrays. """
|
||||
config = config_shuffle(len(x), value_type)
|
||||
shuffle(bits, config=config, value_type=value_type)
|
||||
shuffle(x, config=config, value_type=value_type)
|
||||
result = Array(n_ones, value_type.reg_type)
|
||||
|
||||
sz = MemValue(0)
|
||||
last_x = MemValue(value_type(0))
|
||||
#for i,b in enumerate(bits):
|
||||
#if_then(b.reveal() == 0)
|
||||
#result[sz.read()] = x[i]
|
||||
#sz += 1
|
||||
#end_if()
|
||||
@for_range(len(bits))
|
||||
def f(i):
|
||||
found = (bits[i].reveal() == 0)
|
||||
szval = sz.read()
|
||||
result[szval] = last_x + (x[i] - last_x) * found
|
||||
sz.write(sz + found)
|
||||
last_x.write(result[szval])
|
||||
return result
|
||||
|
||||
1014
Compiler/program.py
1014
Compiler/program.py
File diff suppressed because it is too large
Load Diff
73
Compiler/sorting.py
Normal file
73
Compiler/sorting.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import itertools
|
||||
from Compiler import types, library, instructions
|
||||
|
||||
def dest_comp(B):
|
||||
Bt = B.transpose()
|
||||
St_flat = Bt.get_vector().prefix_sum()
|
||||
Tt_flat = Bt.get_vector() * St_flat.get_vector()
|
||||
Tt = types.Matrix(*Bt.sizes, B.value_type)
|
||||
Tt.assign_vector(Tt_flat)
|
||||
return sum(Tt) - 1
|
||||
|
||||
def reveal_sort(k, D, reverse=False):
|
||||
""" Sort in place according to "perfect" key. The name hints at the fact
|
||||
that a random order of the keys is revealed.
|
||||
|
||||
:param k: vector or Array of sint containing exactly :math:`0,\dots,n-1`
|
||||
in any order
|
||||
:param D: Array or MultiArray to sort
|
||||
:param reverse: wether :py:obj:`key` is a permutation in forward or
|
||||
backward order
|
||||
|
||||
"""
|
||||
assert len(k) == len(D)
|
||||
library.break_point()
|
||||
shuffle = types.sint.get_secure_shuffle(len(k))
|
||||
k_prime = k.get_vector().secure_permute(shuffle).reveal()
|
||||
idx = types.Array.create_from(k_prime)
|
||||
if reverse:
|
||||
D.assign_vector(D.get_slice_vector(idx))
|
||||
library.break_point()
|
||||
D.secure_permute(shuffle, reverse=True)
|
||||
else:
|
||||
D.secure_permute(shuffle)
|
||||
library.break_point()
|
||||
v = D.get_vector()
|
||||
D.assign_slice_vector(idx, v)
|
||||
library.break_point()
|
||||
instructions.delshuffle(shuffle)
|
||||
|
||||
def radix_sort(k, D, n_bits=None, signed=True):
|
||||
""" Sort in place according to key.
|
||||
|
||||
:param k: keys (vector or Array of sint or sfix)
|
||||
:param D: Array or MultiArray to sort
|
||||
:param n_bits: number of bits in keys (int)
|
||||
:param signed: whether keys are signed (bool)
|
||||
|
||||
"""
|
||||
assert len(k) == len(D)
|
||||
bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits))
|
||||
if signed and len(bs) > 1:
|
||||
bs[-1][:] = bs[-1][:].bit_not()
|
||||
radix_sort_from_matrix(bs, D)
|
||||
|
||||
def radix_sort_from_matrix(bs, D):
|
||||
n = len(D)
|
||||
for b in bs:
|
||||
assert(len(b) == n)
|
||||
B = types.sint.Matrix(n, 2)
|
||||
h = types.Array.create_from(types.sint(types.regint.inc(n)))
|
||||
@library.for_range(len(bs))
|
||||
def _(i):
|
||||
b = bs[i]
|
||||
B.set_column(0, 1 - b.get_vector())
|
||||
B.set_column(1, b.get_vector())
|
||||
c = types.Array.create_from(dest_comp(B))
|
||||
reveal_sort(c, h, reverse=False)
|
||||
@library.if_e(i < len(bs) - 1)
|
||||
def _():
|
||||
reveal_sort(h, bs[i + 1], reverse=True)
|
||||
@library.else_
|
||||
def _():
|
||||
reveal_sort(h, D, reverse=True)
|
||||
804
Compiler/sqrt_oram.py
Normal file
804
Compiler/sqrt_oram.py
Normal file
@@ -0,0 +1,804 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
from Compiler import library as lib
|
||||
from Compiler import util
|
||||
from Compiler.GC.types import cbit, sbit, sbitint, sbits
|
||||
from Compiler.program import Program
|
||||
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
|
||||
regint, sint, sintbit)
|
||||
from Compiler.oram import demux_array, get_n_threads
|
||||
|
||||
# Adds messages on completion of heavy computation steps
|
||||
debug = False
|
||||
# Finer grained trace of steps that the ORAM performs
|
||||
# + runtime error checks
|
||||
# Warning: reveals information and makes the computation insecure
|
||||
trace = False
|
||||
|
||||
n_threads = 16
|
||||
n_parallel = 1024
|
||||
|
||||
# Avoids any memory allocation if set to False
|
||||
# Setting to False prevents some optimizations but allows for controlling the ORAMs outside of the main tape
|
||||
allow_memory_allocation = True
|
||||
|
||||
|
||||
def get_n_threads(n_loops):
|
||||
if n_threads is None:
|
||||
if n_loops > 2048:
|
||||
return 8
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return n_threads
|
||||
|
||||
|
||||
T = TypeVar("T", sint, sbitint)
|
||||
B = TypeVar("B", sintbit, sbit)
|
||||
|
||||
|
||||
class SqrtOram(Generic[T, B]):
|
||||
"""Oblivious RAM using the "Square-Root" algorithm.
|
||||
|
||||
:param MultiArray data: The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array).
|
||||
:param sint value_type: The secret type to use, defaults to sint.
|
||||
:param int k: Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
|
||||
:param int period: Leave at None, this parameter is used to recursively pass down the top-level period.
|
||||
"""
|
||||
# TODO: Preferably this is an Array of vectors, but this is currently not supported
|
||||
# One should regard these structures as Arrays where an entry may hold more
|
||||
# than one value (which is a nice property to have when using the ORAM in
|
||||
# practise).
|
||||
shuffle: MultiArray
|
||||
stash: MultiArray
|
||||
# A block has an index and data
|
||||
# `shuffle` and `stash` store the data,
|
||||
# `shufflei` and `stashi` store the index
|
||||
shufflei: Array
|
||||
stashi: Array
|
||||
|
||||
shuffle_used: Array
|
||||
position_map: PositionMap
|
||||
|
||||
# The size of the ORAM, i.e. how many elements it stores
|
||||
n: int
|
||||
# The period, i.e. how many calls can be made to the ORAM before it needs to be refreshed
|
||||
T: int
|
||||
# Keep track of how far we are in the period, and coincidentally how large
|
||||
# the stash is (each access results in a fake or real block being put on
|
||||
# the stash)
|
||||
t: cint
|
||||
|
||||
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None:
|
||||
global debug, allow_memory_allocation
|
||||
|
||||
# Correctly initialize the shuffle (memory) depending on the type of data
|
||||
if isinstance(data, MultiArray):
|
||||
self.shuffle = data
|
||||
self.n = len(data)
|
||||
elif isinstance(data, sint):
|
||||
self.n = math.ceil(len(data) // entry_length)
|
||||
if (len(data) % entry_length != 0):
|
||||
raise Exception('Data incorrectly padded.')
|
||||
self.shuffle = MultiArray(
|
||||
(self.n, entry_length), value_type=value_type)
|
||||
self.shuffle.assign_part_vector(data.get_vector())
|
||||
else:
|
||||
raise Exception("Incorrect format.")
|
||||
|
||||
# Only sint is supported
|
||||
if value_type != sint and value_type != sbitint:
|
||||
raise Exception("The value_type must be either sint or sbitint")
|
||||
|
||||
# Set derived constants
|
||||
self.value_type = value_type
|
||||
self.bit_type: Type[B] = value_type.bit_type
|
||||
self.index_size = util.log2(self.n) + 1 # +1 because signed
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
self.entry_length = entry_length
|
||||
self.size = self.n
|
||||
|
||||
if debug:
|
||||
lib.print_ln(
|
||||
'Initializing SqrtORAM of size %s at depth %s', self.n, k)
|
||||
|
||||
self.shuffle_used = cint.Array(self.n)
|
||||
# Random permutation on the data
|
||||
self.shufflei = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
# Calculate the period if not given
|
||||
# upon recursion, the period should stay the same ("in sync"),
|
||||
# therefore it can be passed as a constructor parameter
|
||||
self.T = int(math.ceil(
|
||||
math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
|
||||
if debug and not period:
|
||||
lib.print_ln('Period set to %s', self.T)
|
||||
|
||||
# Here we allocate the memory for the permutation
|
||||
# Note that self.shuffle_the_shuffle mutates this field
|
||||
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
|
||||
self.permutation = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
# We allow the caller to postpone the initialization of the shuffle
|
||||
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
|
||||
# Note that if you do not initialize, the ORAM is insecure
|
||||
if initialize:
|
||||
# If the ORAM is not initialized with existing data, we can apply
|
||||
# a small optimization by forgoing shuffling the shuffle, as all
|
||||
# entries of the shuffle are equal and empty.
|
||||
if empty_data:
|
||||
random_shuffle = sint.get_secure_shuffle(self.n)
|
||||
self.shufflei.secure_permute(random_shuffle)
|
||||
self.permutation.assign(self.shufflei[:].inverse_permutation())
|
||||
if trace:
|
||||
lib.print_ln('Calculated inverse permutation')
|
||||
else:
|
||||
self.shuffle_the_shuffle()
|
||||
else:
|
||||
print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.')
|
||||
# Initialize position map (recursive oram)
|
||||
self.position_map = PositionMap.create(self.permutation, k + 1, self.T)
|
||||
|
||||
# Initialize stash
|
||||
self.stash = MultiArray((self.T, entry_length), value_type=value_type)
|
||||
self.stashi = Array(self.T, value_type=value_type)
|
||||
self.t = MemValue(cint(0))
|
||||
|
||||
# Initialize temp variables needed during the computation
|
||||
self.found_ = self.bit_type.Array(size=self.T)
|
||||
self.j = MemValue(cint(0, size=1))
|
||||
|
||||
# To prevent the compiler from recompiling the same code over and over again, we should use @method_block
|
||||
# However, @method_block requires allocation (of return address), which is not allowed when not in the main thread
|
||||
# Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread
|
||||
SqrtOram.shuffle_the_shuffle = lib.method_block(SqrtOram.shuffle_the_shuffle) if allow_memory_allocation else SqrtOram.shuffle_the_shuffle
|
||||
SqrtOram.refresh = lib.method_block(SqrtOram.refresh) if allow_memory_allocation else SqrtOram.refresh
|
||||
SqrtOram.reinitialize = lib.method_block(SqrtOram.reinitialize) if allow_memory_allocation else SqrtOram.reinitialize
|
||||
|
||||
@lib.method_block
|
||||
def access(self, index: T, write: B, *value: T):
|
||||
global trace,n_parallel
|
||||
if trace:
|
||||
@lib.if_e(write.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('Writing to secret index %s', index.reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value, size=self.entry_length).get_vector(
|
||||
0, size=self.entry_length)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
def _():
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
result: T = MemValue(self.value_type(0, size=self.entry_length))
|
||||
|
||||
# First we scan the stash for the item
|
||||
self.found_.assign_all(0)
|
||||
|
||||
# This will result in a bit array with at most one True,
|
||||
# indicating where in the stash 'index' is found
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
|
||||
self.bit_type(regint.inc(size, base=base) <
|
||||
self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
# check wheterh the demuxed array contains a True
|
||||
# TODO: What if the index=0?
|
||||
found.write(sum(self.found_))
|
||||
|
||||
# Store the stash item into the result if found
|
||||
# If the item is not in the stash, the result will simple remain 0
|
||||
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
|
||||
self.entry_length, [self.value_type] * self.entry_length)
|
||||
def stash_item(i):
|
||||
entry = self.stash[i][:]
|
||||
access_here = self.found_[i]
|
||||
# This is a bit unfortunate
|
||||
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
|
||||
# Therefore we loop till self.T (the max value of self.t)
|
||||
# is_in_time = i < self.t
|
||||
|
||||
# If we are writing, we need to add the value
|
||||
self.stash[i] += write * access_here * (value - entry)
|
||||
return (entry * access_here)[:]
|
||||
result += self.value_type(stash_item(), size=self.entry_length)
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('Found item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Did not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
physical_address = self.position_map.get_position(index, found)
|
||||
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
|
||||
# with its equivalent on the stash once the period is up.
|
||||
self.shuffle_used[physical_address] = cbit(True)
|
||||
|
||||
# If the item was not found in the stash
|
||||
# ...we update the item in the shuffle
|
||||
self.shuffle[physical_address] += write * \
|
||||
found.bit_not() * (value - self.shuffle[physical_address][:])
|
||||
# ...and the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
self.stash[self.t].assign(self.shuffle[physical_address][:])
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if trace:
|
||||
@lib.if_((write * found.bit_not()).reveal())
|
||||
def _():
|
||||
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
|
||||
), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
return result
|
||||
|
||||
@lib.method_block
|
||||
def write(self, index: T, *value: T):
|
||||
global trace, n_parallel
|
||||
if trace:
|
||||
lib.print_ln('Writing to secret index %s', index.reveal())
|
||||
|
||||
if isinstance(value, tuple) or isinstance(value,list):
|
||||
value = self.value_type(value, size=self.entry_length)
|
||||
print(value, type(value))
|
||||
elif isinstance(value, self.value_type):
|
||||
value = self.value_type(*value, size=self.entry_length)
|
||||
print(value, type(value))
|
||||
else:
|
||||
raise Exception("Cannot handle type of value passed")
|
||||
print(self.entry_length, value, type(value),len(value))
|
||||
value = MemValue(value)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
def _():
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
result: T = MemValue(self.value_type(0, size=self.entry_length))
|
||||
|
||||
# First we scan the stash for the item
|
||||
self.found_.assign_all(0)
|
||||
|
||||
# This will result in an bit array with at most one True,
|
||||
# indicating where in the stash 'index' is found
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
|
||||
self.bit_type(regint.inc(size, base=base) <
|
||||
self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
# check wheterh the demuxed array contains a True
|
||||
# TODO: What if the index=0?
|
||||
found.write(sum(self.found_))
|
||||
|
||||
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
|
||||
self.entry_length, [self.value_type] * self.entry_length)
|
||||
def stash_item(i):
|
||||
entry = self.stash[i][:]
|
||||
access_here = self.found_[i]
|
||||
# This is a bit unfortunate
|
||||
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
|
||||
# Therefore we loop till self.T (the max value of self.t)
|
||||
# is_in_time = i < self.t
|
||||
|
||||
# We update the stash value
|
||||
self.stash[i] += access_here * (value - entry)
|
||||
return (entry * access_here)[:]
|
||||
result += self.value_type(stash_item(), size=self.entry_length)
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('Found item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Did not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
physical_address = self.position_map.get_position(index, found)
|
||||
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
|
||||
# with its equivalent on the stash once the period is up.
|
||||
self.shuffle_used[physical_address] = cbit(True)
|
||||
|
||||
# If the item was not found in the stash
|
||||
# ...we update the item in the shuffle
|
||||
self.shuffle[physical_address] += found.bit_not() * \
|
||||
(value - self.shuffle[physical_address][:])
|
||||
# ...and the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
self.stash[self.t].assign(self.shuffle[physical_address][:])
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if trace:
|
||||
@lib.if_(found.bit_not().reveal())
|
||||
def _():
|
||||
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
|
||||
), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
return result
|
||||
|
||||
@lib.method_block
|
||||
def read(self, index: T, *value: T):
|
||||
global debug, trace, n_parallel
|
||||
if trace:
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
def _():
|
||||
if debug:
|
||||
lib.print_ln('Refreshing SqrtORAM')
|
||||
lib.print_ln('t=%s according to me', self.t)
|
||||
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
result: T = MemValue(self.value_type(0, size=self.entry_length))
|
||||
|
||||
# First we scan the stash for the item
|
||||
self.found_.assign_all(0)
|
||||
|
||||
# This will result in a bit array with at most one True,
|
||||
# indicating where in the stash 'index' is found
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
|
||||
self.bit_type(regint.inc(size, base=base) <
|
||||
self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
# check whether the demuxed array contains a True
|
||||
# TODO: What if the index=0?
|
||||
found.write(sum(self.found_))
|
||||
lib.check_point()
|
||||
|
||||
# Store the stash item into the result if found
|
||||
# If the item is not in the stash, the result will simple remain 0
|
||||
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
|
||||
self.entry_length, [self.value_type] * self.entry_length)
|
||||
def stash_item(i):
|
||||
entry = self.stash[i][:]
|
||||
access_here = self.found_[i]
|
||||
# This is a bit unfortunate
|
||||
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
|
||||
# Therefore we loop till self.T (the max value of self.t)
|
||||
# is_in_time = i < self.t
|
||||
|
||||
return (entry * access_here)[:]
|
||||
result += self.value_type(stash_item(), size=self.entry_length)
|
||||
|
||||
if trace:
|
||||
# @lib.for_range(self.t)
|
||||
# def _(i):
|
||||
# lib.print_ln("stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())
|
||||
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('Found item in stash (found=%s)', found.reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Did not find item in stash (found=%s)', found.reveal())
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
physical_address = self.position_map.get_position(index, found)
|
||||
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
|
||||
# with its equivalent on the stash once the period is up.
|
||||
self.shuffle_used[physical_address] = cbit(True)
|
||||
|
||||
# If the item was not found in the stash
|
||||
# the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
self.stash[self.t].assign(self.shuffle[physical_address][:])
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if trace:
|
||||
lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
return result
|
||||
|
||||
__getitem__ = read
|
||||
__setitem__ = write
|
||||
|
||||
def shuffle_the_shuffle(self) -> None:
|
||||
"""Permute the memory using a newly generated permutation and return
|
||||
the permutation that would generate this particular shuffling.
|
||||
|
||||
This permutation is needed to know how to map logical addresses to
|
||||
physical addresses, and is used as such by the postition map."""
|
||||
|
||||
global trace
|
||||
# Random permutation on n elements
|
||||
random_shuffle = sint.get_secure_shuffle(self.n)
|
||||
if trace:
|
||||
lib.print_ln('Generated shuffle')
|
||||
# Apply the random permutation
|
||||
self.shuffle.secure_permute(random_shuffle)
|
||||
if trace:
|
||||
lib.print_ln('Shuffled shuffle')
|
||||
self.shufflei.secure_permute(random_shuffle)
|
||||
if trace:
|
||||
lib.print_ln('Shuffled shuffle indexes')
|
||||
|
||||
lib.check_point()
|
||||
# Calculate the permutation that would have produced the newly produced
|
||||
# shuffle order. This can be calculated by regarding the logical
|
||||
# indexes (shufflei) as a permutation and calculating its inverse,
|
||||
# i.e. find P such that P([1,2,3,...]) = shufflei.
|
||||
# this is not necessarily equal to the inverse of the above generated
|
||||
# random_shuffle, as the shuffle may already be out of order (e.g. when
|
||||
# refreshing).
|
||||
self.permutation.assign(self.shufflei[:].inverse_permutation())
|
||||
# If shufflei does not contain exactly the indices
|
||||
# [i for i in range(self.n)],
|
||||
# the underlying waksman network of 'inverse_permutation' will hang.
|
||||
if trace:
|
||||
lib.print_ln('Calculated inverse permutation')
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the ORAM by reinserting the stash back into the shuffle, and
|
||||
reshuffling the shuffle.
|
||||
|
||||
This must happen on the T'th (period) accesses to the ORAM."""
|
||||
|
||||
self.j.write(0)
|
||||
# Shuffle and emtpy the stash, and store elements back into shuffle
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
@lib.if_(self.shuffle_used[i])
|
||||
def _():
|
||||
self.shuffle[i] = self.stash[self.j]
|
||||
self.shufflei[i] = self.stashi[self.j]
|
||||
self.j += 1
|
||||
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
self._reset_shuffle_used()
|
||||
|
||||
# Reinitialize position map
|
||||
self.shuffle_the_shuffle()
|
||||
# Note that we skip here the step of "packing" the permutation.
|
||||
# Since the underlying memory of the position map is already aligned in
|
||||
# this packed structure, we can simply overwrite the memory while
|
||||
# maintaining the structure.
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
|
||||
def reinitialize(self, *data: T):
|
||||
# Note that this method is only used during refresh, and as such is
|
||||
# only called with a permutation as data.
|
||||
|
||||
# The logical addresses of some previous permutation are irrelevant and must be reset
|
||||
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
self._reset_shuffle_used()
|
||||
|
||||
# Note that the self.shuffle is actually a MultiArray
|
||||
# This structure is preserved while overwriting the values using
|
||||
# assign_vector
|
||||
self.shuffle.assign_vector(self.value_type(
|
||||
data, size=self.n * self.entry_length))
|
||||
# Note that this updates self.permutation (see constructor for explanation)
|
||||
self.shuffle_the_shuffle()
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
|
||||
def _reset_shuffle_used(self):
|
||||
global allow_memory_allocation
|
||||
if allow_memory_allocation:
|
||||
self.shuffle_used.assign_all(0)
|
||||
else:
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
self.shuffle_used[i] = cint(0)
|
||||
|
||||
|
||||
class PositionMap(Generic[T, B]):
|
||||
PACK_LOG: int = 3
|
||||
PACK: int = 1 << PACK_LOG
|
||||
|
||||
n: int # n in the paper
|
||||
depth: cint # k in the paper
|
||||
value_type: Type[T]
|
||||
|
||||
def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None:
|
||||
self.n = n
|
||||
self.depth = MemValue(cint(k))
|
||||
self.value_type = value_type
|
||||
self.bit_type = value_type.bit_type
|
||||
self.index_type = self.value_type.get_type(util.log2(n) + 1) # +1 because signed
|
||||
|
||||
@abstractmethod
|
||||
def get_position(self, logical_address: _secret, fake: B) -> Any:
|
||||
"""Retrieve the block at the given (secret) logical address."""
|
||||
global trace
|
||||
if trace:
|
||||
print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
|
||||
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
"""Reinitialize this PositionMap.
|
||||
|
||||
Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
|
||||
we cannot simply call __init__ on self. Instead, we must take care to
|
||||
reuse and overwrite the same memory.
|
||||
"""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = sint) -> PositionMap:
|
||||
"""Creates a new PositionMap. This is the method one should call when
|
||||
needing a new position map. Depending on the size of the given data, it
|
||||
will either instantiate a RecursivePositionMap or
|
||||
a LinearPositionMap."""
|
||||
n = len(permutation)
|
||||
|
||||
global debug
|
||||
if n / PositionMap.PACK <= period:
|
||||
if debug:
|
||||
lib.print_ln(
|
||||
'Initializing LinearPositionMap at depth %s of size %s', k, n)
|
||||
res = LinearPositionMap(permutation, value_type, k=k)
|
||||
else:
|
||||
if debug:
|
||||
lib.print_ln(
|
||||
'Initializing RecursivePositionMap at depth %s of size %s', k, n)
|
||||
res = RecursivePositionMap(permutation, period, value_type, k=k)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
|
||||
def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None:
|
||||
PositionMap.__init__(self, len(permutation), k=k)
|
||||
pack = PositionMap.PACK
|
||||
|
||||
# We pack the permutation into a smaller structure, index with a new permutation
|
||||
packed_size = int(math.ceil(self.n / pack))
|
||||
packed_structure = MultiArray(
|
||||
(packed_size, pack), value_type=value_type)
|
||||
for i in range(packed_size):
|
||||
packed_structure[i] = Array.create_from(
|
||||
permutation[i*pack:(i+1)*pack])
|
||||
|
||||
SqrtOram.__init__(self, packed_structure, value_type=value_type,
|
||||
period=period, entry_length=pack, k=self.depth)
|
||||
|
||||
# Initialize random temp variables needed during the computation
|
||||
self.block_index_demux: Array = self.bit_type.Array(self.T)
|
||||
self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK)
|
||||
|
||||
@lib.method_block
|
||||
def get_position(self, logical_address: T, fake: B) -> _clear:
|
||||
super().get_position(logical_address, fake)
|
||||
|
||||
pack = PositionMap.PACK
|
||||
pack_log = PositionMap.PACK_LOG
|
||||
|
||||
# The item at logical_address
|
||||
# will be in block with index h (block.<h>)
|
||||
# at position l in block.data (block.data<l>)
|
||||
program = Program.prog
|
||||
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
|
||||
logical_address).right_shift(pack_log, program.bit_length)))
|
||||
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))
|
||||
|
||||
global trace
|
||||
if trace:
|
||||
print_at_depth(self.depth, '-> logical_address=%s: h=%s, l=%s', logical_address.reveal(), h.reveal(), l.reveal())
|
||||
# @lib.for_range(self.t)
|
||||
# def _(i):
|
||||
# print_at_depth(self.depth, "stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())
|
||||
|
||||
# The resulting physical address
|
||||
p = MemValue(self.index_type(-1))
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
|
||||
# First we try and retrieve the item from the stash at position stash[h][l]
|
||||
# Since h and l are secret, we do this by scanning the entire stash
|
||||
|
||||
# First we scan the stash for the block we need
|
||||
self.block_index_demux.assign_all(0)
|
||||
@lib.for_range_opt_multithread(get_n_threads(self.T), self.T)
|
||||
def _(i):
|
||||
self.block_index_demux[i] = ( self.stashi[i] == h) & self.bit_type(i < self.t)
|
||||
# We can determine if the 'index' is in the stash by checking the
|
||||
# block_index_demux array
|
||||
found = sum(self.block_index_demux)
|
||||
# Once a block is found, we use the following condition to pick the correct item from that block
|
||||
demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux)
|
||||
|
||||
# Finally we use the conditions to conditionally write p
|
||||
@lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type])
|
||||
def p_(i):
|
||||
# We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum
|
||||
# Therefore we include the check (i < self.t)
|
||||
return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i // pack< self.t)
|
||||
p.write(p_())
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 0)
|
||||
def _(): print_at_depth(self.depth, 'Retrieve shuffle[%s]:', h.reveal())
|
||||
@lib.else_
|
||||
def __():
|
||||
print_at_depth(self.depth, 'Retrieve dummy element from shuffle:')
|
||||
|
||||
# Then we try and retrieve the item from the shuffle (the actual memory)
|
||||
# Depending on whether we found the item in the stash, we either
|
||||
# block 'h' in which 'index' resides, or a random block from the shuffle
|
||||
p_prime = self.position_map.get_position(h, found)
|
||||
self.shuffle_used[p_prime] = cbit(True)
|
||||
|
||||
# The block retrieved from the shuffle
|
||||
block_p_prime: Array = self.shuffle[p_prime]
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 0)
|
||||
def _():
|
||||
print_at_depth(self.depth, 'Retrieved position from shuffle[%s]=(%s: %s)',
|
||||
p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
print_at_depth(self.depth, 'Retrieved dummy position from shuffle[%s]=(%s: %s)',
|
||||
p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())
|
||||
|
||||
# We add the retrieved block from the shuffle to the stash
|
||||
self.stash[self.t].assign(block_p_prime[:])
|
||||
self.stashi[self.t] = self.shufflei[p_prime]
|
||||
# Increase t
|
||||
self.t += 1
|
||||
|
||||
# if found or not fake
|
||||
condition: B = self.bit_type(fake.bit_or(found.bit_not()))
|
||||
# Retrieve l'th item from block
|
||||
# l is secret, so we must use linear scan
|
||||
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(
|
||||
pack)) & condition.expand_to_vector(pack))
|
||||
|
||||
@lib.for_range_opt(pack)
|
||||
def _(i):
|
||||
p.write((hit[i]).if_else(block_p_prime[i], p))
|
||||
|
||||
return p.reveal()
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
SqrtOram.reinitialize(self, *permutation)
|
||||
|
||||
|
||||
class LinearPositionMap(PositionMap):
|
||||
physical: Array
|
||||
used: Array
|
||||
|
||||
def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None:
|
||||
PositionMap.__init__(self, len(data), value_type, k=k)
|
||||
self.physical = data
|
||||
self.used = self.bit_type.Array(self.n)
|
||||
|
||||
# Initialize random temp variables needed during the computation
|
||||
self.physical_demux: Array = self.bit_type.Array(self.n)
|
||||
|
||||
@lib.method_block
|
||||
def get_position(self, logical_address: T, fake: B) -> _clear:
|
||||
"""
|
||||
This method corresponds to GetPosBase in the paper.
|
||||
"""
|
||||
super().get_position(logical_address, fake)
|
||||
|
||||
global trace
|
||||
if trace:
|
||||
@lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal())
|
||||
def _():
|
||||
lib.runtime_error(
|
||||
'logical_address must lie between 0 and self.n - 1')
|
||||
|
||||
fake = MemValue(self.bit_type(fake))
|
||||
logical_address = MemValue(logical_address)
|
||||
|
||||
p: MemValue = MemValue(self.index_type(-1))
|
||||
done: B = self.bit_type(False)
|
||||
|
||||
# In order to get an address at secret logical_address,
|
||||
# we need to perform a linear scan.
|
||||
self.physical_demux.assign_all(0)
|
||||
|
||||
@lib.for_range_opt_multithread(get_n_threads(self.n), self.n)
|
||||
def condition_i(i):
|
||||
self.physical_demux[i] = \
|
||||
(self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \
|
||||
| (fake & self.used[i].bit_not())
|
||||
|
||||
# In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False)
|
||||
# We only need once, so we pick the first one we find
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
self.physical_demux[i] &= done.bit_not()
|
||||
done.update(done | self.physical_demux[i])
|
||||
|
||||
# Retrieve the value from the physical memory obliviously
|
||||
@lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type])
|
||||
def calc_p(i):
|
||||
return self.physical[i] * self.physical_demux[i]
|
||||
p.write(calc_p())
|
||||
|
||||
# Update self.used
|
||||
self.used.assign(self.used[:] | self.physical_demux[:])
|
||||
|
||||
if trace:
|
||||
@lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))
|
||||
def _():
|
||||
lib.runtime_error(
|
||||
'%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)
|
||||
|
||||
return p.reveal()
|
||||
|
||||
def reinitialize(self, *data: T):
|
||||
self.physical.assign_vector(data)
|
||||
|
||||
global allow_memory_allocation
|
||||
if allow_memory_allocation:
|
||||
self.used.assign_all(False)
|
||||
else:
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
self.used[i] = self.bit_type(0)
|
||||
|
||||
def print_at_depth(depth: cint, message: str, *kwargs):
|
||||
lib.print_str('%s', depth)
|
||||
@lib.for_range(depth)
|
||||
def _(i):
|
||||
lib.print_char(' ')
|
||||
lib.print_char(' ')
|
||||
lib.print_ln(message, *kwargs)
|
||||
1398
Compiler/types.py
1398
Compiler/types.py
File diff suppressed because it is too large
Load Diff
@@ -116,6 +116,11 @@ def round_to_int(x):
|
||||
return x.round_to_int()
|
||||
|
||||
def tree_reduce(function, sequence):
|
||||
try:
|
||||
return sequence.tree_reduce(function)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
sequence = list(sequence)
|
||||
assert len(sequence) > 0
|
||||
n = len(sequence)
|
||||
@@ -233,6 +238,9 @@ def mem_size(x):
|
||||
except AttributeError:
|
||||
return 1
|
||||
|
||||
def find_in_dict(d, v):
|
||||
return list(d.keys())[list(d.values()).index(v)]
|
||||
|
||||
class set_by_id(object):
|
||||
def __init__(self, init=[]):
|
||||
self.content = {}
|
||||
@@ -257,6 +265,9 @@ class set_by_id(object):
|
||||
def pop(self):
|
||||
return self.content.popitem()[1]
|
||||
|
||||
def remove(self, value):
|
||||
del self.content[id(value)]
|
||||
|
||||
def __ior__(self, values):
|
||||
for value in values:
|
||||
self.add(value)
|
||||
|
||||
144
Dockerfile
Normal file
144
Dockerfile
Normal file
@@ -0,0 +1,144 @@
|
||||
###############################################################################
|
||||
# Build this stage for a build environment, e.g.: #
|
||||
# #
|
||||
# docker build --tag mpspdz:buildenv --target buildenv . #
|
||||
# #
|
||||
# The above is equivalent to: #
|
||||
# #
|
||||
# docker build --tag mpspdz:buildenv \ #
|
||||
# --target buildenv \ #
|
||||
# --build-arg arch=native \ #
|
||||
# --build-arg cxx=clang++-11 \ #
|
||||
# --build-arg use_ntl=0 \ #
|
||||
# --build-arg prep_dir="Player-Data" \ #
|
||||
# --build-arg ssl_dir="Player-Data" #
|
||||
# --build-arg cryptoplayers=0 #
|
||||
# #
|
||||
# To build for an x86-64 architecture, with g++, NTL (for HE), custom #
|
||||
# prep_dir & ssl_dir, and to use encrypted channels for 4 players: #
|
||||
# #
|
||||
# docker build --tag mpspdz:buildenv \ #
|
||||
# --target buildenv \ #
|
||||
# --build-arg arch=x86-64 \ #
|
||||
# --build-arg cxx=g++ \ #
|
||||
# --build-arg use_ntl=1 \ #
|
||||
# --build-arg prep_dir="/opt/prepdata" \ #
|
||||
# --build-arg ssl_dir="/opt/ssl" #
|
||||
# --build-arg cryptoplayers=4 . #
|
||||
# #
|
||||
# To work in a container to build different machines, and compile programs: #
|
||||
# #
|
||||
# docker run --rm -it mpspdz:buildenv bash #
|
||||
# #
|
||||
# Once in the container, build a machine and compile a program: #
|
||||
# #
|
||||
# $ make replicated-ring-party.x #
|
||||
# $ ./compile.py -R 64 tutorial #
|
||||
# #
|
||||
###############################################################################
|
||||
FROM python:3.10.3-bullseye as buildenv
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
automake \
|
||||
build-essential \
|
||||
clang-11 \
|
||||
cmake \
|
||||
git \
|
||||
libboost-dev \
|
||||
libboost-thread-dev \
|
||||
libclang-dev \
|
||||
libgmp-dev \
|
||||
libntl-dev \
|
||||
libsodium-dev \
|
||||
libssl-dev \
|
||||
libtool \
|
||||
vim \
|
||||
gdb \
|
||||
valgrind \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV MP_SPDZ_HOME /usr/src/MP-SPDZ
|
||||
WORKDIR $MP_SPDZ_HOME
|
||||
|
||||
RUN pip install --upgrade pip ipython
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG arch=native
|
||||
ARG cxx=clang++-11
|
||||
ARG use_ntl=0
|
||||
ARG prep_dir="Player-Data"
|
||||
ARG ssl_dir="Player-Data"
|
||||
|
||||
RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \
|
||||
&& echo "CXX = ${cxx}" >> CONFIG.mine \
|
||||
&& echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \
|
||||
&& echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \
|
||||
&& echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \
|
||||
>> CONFIG.mine \
|
||||
&& mkdir -p $prep_dir $ssl_dir \
|
||||
&& echo "PREP_DIR = '-DPREP_DIR=\"${prep_dir}/\"'" >> CONFIG.mine \
|
||||
&& echo "SSL_DIR = '-DSSL_DIR=\"${ssl_dir}/\"'" >> CONFIG.mine
|
||||
|
||||
# ssl keys
|
||||
ARG cryptoplayers=0
|
||||
ENV PLAYERS ${cryptoplayers}
|
||||
RUN ./Scripts/setup-ssl.sh ${cryptoplayers} ${ssl_dir}
|
||||
|
||||
RUN make boost libote
|
||||
|
||||
###############################################################################
|
||||
# Use this stage to a build a specific virtual machine. For example: #
|
||||
# #
|
||||
# docker build --tag mpspdz:shamir \ #
|
||||
# --target machine \ #
|
||||
# --build-arg machine=shamir-party.x \ #
|
||||
# --build-arg gfp_mod_sz=4 . #
|
||||
# #
|
||||
# The above will build shamir-party.x with 256 bit length. #
|
||||
# #
|
||||
# If no build arguments are passed (via --build-arg), mascot-party.x is built #
|
||||
# with the default 128 bit length. #
|
||||
###############################################################################
|
||||
FROM buildenv as machine
|
||||
|
||||
ARG machine="mascot-party.x"
|
||||
|
||||
ARG gfp_mod_sz=2
|
||||
|
||||
RUN echo "MOD = -DGFP_MOD_SZ=${gfp_mod_sz}" >> CONFIG.mine
|
||||
|
||||
RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/
|
||||
|
||||
|
||||
################################################################################
|
||||
# This is the default stage. Use it to compile a high-level program. #
|
||||
# By default, tutorial.mpc is compiled with --field=64 bits. #
|
||||
# #
|
||||
# docker build --tag mpspdz:mascot-tutorial \ #
|
||||
# --build-arg src=tutorial \ #
|
||||
# --build-arg compile_options="--field=64" . #
|
||||
# #
|
||||
# Note that build arguments from previous stages can also be passed. For #
|
||||
# instance, building replicated-ring-party.x, for 3 crypto players with custom #
|
||||
# PREP_DIR and SSL_DIR, and compiling tutorial.mpc with --ring=64: #
|
||||
# #
|
||||
# docker build --tag mpspdz:replicated-ring \ #
|
||||
# --build-arg machine=replicated-ring-party.x \ #
|
||||
# --build-arg prep_dir=/opt/prep \ #
|
||||
# --build-arg ssl_dir=/opt/ssl \ #
|
||||
# --build-arg cryptoplayers=3 \ #
|
||||
# --build-arg compile_options="--ring=64" . #
|
||||
# #
|
||||
# Test it: #
|
||||
# #
|
||||
# docker run --rm -it mpspdz:replicated-ring ./Scripts/ring.sh tutorial #
|
||||
################################################################################
|
||||
FROM machine as program
|
||||
|
||||
ARG src="tutorial"
|
||||
ARG compile_options="--field=64"
|
||||
RUN ./compile.py ${compile_options} ${src}
|
||||
RUN mkdir -p Player-Data \
|
||||
&& echo 1 2 3 4 > Player-Data/Input-P0-0 \
|
||||
&& echo 1 2 3 4 > Player-Data/Input-P1-0
|
||||
@@ -24,4 +24,5 @@ int main()
|
||||
generate_mac_keys<Share<P256Element::Scalar>>(key, 2, prefix, G);
|
||||
make_mult_triples<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix, G);
|
||||
make_inverse<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix, G);
|
||||
P256Element::finish();
|
||||
}
|
||||
|
||||
@@ -14,7 +14,14 @@ void P256Element::init()
|
||||
curve = EC_GROUP_new_by_curve_name(NID_secp256k1);
|
||||
assert(curve != 0);
|
||||
auto modulus = EC_GROUP_get0_order(curve);
|
||||
Scalar::init_field(BN_bn2dec(modulus), false);
|
||||
auto mod = BN_bn2dec(modulus);
|
||||
Scalar::init_field(mod, false);
|
||||
free(mod);
|
||||
}
|
||||
|
||||
void P256Element::finish()
|
||||
{
|
||||
EC_GROUP_free(curve);
|
||||
}
|
||||
|
||||
P256Element::P256Element()
|
||||
@@ -29,7 +36,7 @@ P256Element::P256Element(const Scalar& other) :
|
||||
{
|
||||
BIGNUM* exp = BN_new();
|
||||
BN_dec2bn(&exp, bigint(other).get_str().c_str());
|
||||
assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0);
|
||||
assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0);
|
||||
BN_free(exp);
|
||||
}
|
||||
|
||||
@@ -38,10 +45,15 @@ P256Element::P256Element(word other) :
|
||||
{
|
||||
BIGNUM* exp = BN_new();
|
||||
BN_dec2bn(&exp, to_string(other).c_str());
|
||||
assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0);
|
||||
assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0);
|
||||
BN_free(exp);
|
||||
}
|
||||
|
||||
P256Element::~P256Element()
|
||||
{
|
||||
EC_POINT_free(point);
|
||||
}
|
||||
|
||||
P256Element& P256Element::operator =(const P256Element& other)
|
||||
{
|
||||
assert(EC_POINT_copy(point, other.point) != 0);
|
||||
@@ -56,7 +68,11 @@ void P256Element::check()
|
||||
P256Element::Scalar P256Element::x() const
|
||||
{
|
||||
BIGNUM* x = BN_new();
|
||||
#if OPENSSL_VERSION_MAJOR >= 3
|
||||
assert(EC_POINT_get_affine_coordinates(curve, point, x, 0, 0) != 0);
|
||||
#else
|
||||
assert(EC_POINT_get_affine_coordinates_GFp(curve, point, x, 0, 0) != 0);
|
||||
#endif
|
||||
char* xx = BN_bn2dec(x);
|
||||
Scalar res((bigint(xx)));
|
||||
OPENSSL_free(xx);
|
||||
@@ -95,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const
|
||||
return not cmp;
|
||||
}
|
||||
|
||||
void P256Element::pack(octetStream& os) const
|
||||
void P256Element::pack(octetStream& os, int) const
|
||||
{
|
||||
octet* buffer;
|
||||
size_t length = EC_POINT_point2buf(curve, point,
|
||||
@@ -103,9 +119,10 @@ void P256Element::pack(octetStream& os) const
|
||||
assert(length != 0);
|
||||
os.store_int(length, 8);
|
||||
os.append(buffer, length);
|
||||
free(buffer);
|
||||
}
|
||||
|
||||
void P256Element::unpack(octetStream& os)
|
||||
void P256Element::unpack(octetStream& os, int)
|
||||
{
|
||||
size_t length = os.get_int(8);
|
||||
assert(
|
||||
|
||||
@@ -22,20 +22,23 @@ private:
|
||||
EC_POINT* point;
|
||||
|
||||
public:
|
||||
typedef void next;
|
||||
typedef P256Element next;
|
||||
typedef void Square;
|
||||
|
||||
static const true_type invertible;
|
||||
|
||||
static int size() { return 0; }
|
||||
static int length() { return 256; }
|
||||
static string type_string() { return "P256"; }
|
||||
|
||||
static void init();
|
||||
static void finish();
|
||||
|
||||
P256Element();
|
||||
P256Element(const P256Element& other);
|
||||
P256Element(const Scalar& other);
|
||||
P256Element(word other);
|
||||
~P256Element();
|
||||
|
||||
P256Element& operator=(const P256Element& other);
|
||||
|
||||
@@ -55,10 +58,10 @@ public:
|
||||
|
||||
void assign_zero() { *this = {}; }
|
||||
bool is_zero() { return *this == P256Element(); }
|
||||
void add(octetStream& os) { *this += os.get<P256Element>(); }
|
||||
void add(octetStream& os, int = -1) { *this += os.get<P256Element>(); }
|
||||
|
||||
void pack(octetStream& os) const;
|
||||
void unpack(octetStream& os);
|
||||
void pack(octetStream& os, int = -1) const;
|
||||
void unpack(octetStream& os, int = -1);
|
||||
|
||||
octetStream hash(size_t n_bytes) const;
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ The following binaries have been used for the paper:
|
||||
All binaries offer the same interface. With MASCOT for example, run
|
||||
the following:
|
||||
```
|
||||
./mascot-ecsda-party.x -p 0 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
|
||||
./mascot-ecsda-party.x -p 1 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
|
||||
./mascot-ecdsa-party.x -p 0 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
|
||||
./mascot-ecdsa-party.x -p 1 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@@ -45,12 +45,13 @@ int main(int argc, const char** argv)
|
||||
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
|
||||
read_mac_key(prefix, N, keyp);
|
||||
|
||||
pShare::MAC_Check::setup(P);
|
||||
Share<P256Element>::MAC_Check::setup(P);
|
||||
|
||||
DataPositions usage;
|
||||
Sub_Data_Files<pShare> prep(N, prefix, usage);
|
||||
typename pShare::Direct_MC MCp(keyp);
|
||||
ArithmeticProcessor _({}, 0);
|
||||
BaseMachine machine;
|
||||
machine.ot_setups.push_back({P, false});
|
||||
SubProcessor<pShare> proc(_, MCp, prep, P);
|
||||
|
||||
pShare sk, __;
|
||||
@@ -60,4 +61,8 @@ int main(int argc, const char** argv)
|
||||
preprocessing(tuples, n_tuples, sk, proc, opts);
|
||||
check(tuples, sk, keyp, P);
|
||||
sign_benchmark(tuples, sk, MCp, P, opts);
|
||||
|
||||
pShare::MAC_Check::teardown();
|
||||
Share<P256Element>::MAC_Check::teardown();
|
||||
P256Element::finish();
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Machines/MalRep.hpp"
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
@@ -52,10 +54,10 @@ void run(int argc, const char** argv)
|
||||
P.unchecked_broadcast(bundle);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
pShare sk = typename T<P256Element::Scalar>::Honest::Protocol(P).get_random();
|
||||
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
|
||||
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
|
||||
DataPositions usage;
|
||||
@@ -69,4 +71,5 @@ void run(int argc, const char** argv)
|
||||
preprocessing(tuples, n_tuples, sk, proc, opts);
|
||||
// check(tuples, sk, {}, P);
|
||||
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
|
||||
P256Element::finish();
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#define NO_MIXED_CIRCUITS
|
||||
|
||||
#define NO_SECURITY_CHECK
|
||||
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
@@ -92,9 +92,6 @@ void run(int argc, const char** argv)
|
||||
P256Element::init();
|
||||
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);
|
||||
|
||||
BaseMachine machine;
|
||||
machine.ot_setups.push_back({P, true});
|
||||
|
||||
P256Element::Scalar keyp;
|
||||
SeededPRNG G;
|
||||
keyp.randomize(G);
|
||||
@@ -102,6 +99,9 @@ void run(int argc, const char** argv)
|
||||
typedef T<P256Element::Scalar> pShare;
|
||||
DataPositions usage;
|
||||
|
||||
pShare::MAC_Check::setup(P);
|
||||
T<P256Element>::MAC_Check::setup(P);
|
||||
|
||||
OnlineOptions::singleton.batch_size = 1;
|
||||
typename pShare::Direct_MC MCp(keyp);
|
||||
ArithmeticProcessor _({}, 0);
|
||||
@@ -113,10 +113,10 @@ void run(int argc, const char** argv)
|
||||
P.unchecked_broadcast(bundle);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
sk_prep.get_two(DATA_INVERSE, sk, __);
|
||||
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
|
||||
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
|
||||
typename pShare::TriplePrep prep(0, usage);
|
||||
@@ -137,4 +137,8 @@ void run(int argc, const char** argv)
|
||||
preprocessing(tuples, n_tuples, sk, proc, opts);
|
||||
//check(tuples, sk, keyp, P);
|
||||
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
|
||||
|
||||
pShare::MAC_Check::teardown();
|
||||
T<P256Element>::MAC_Check::teardown();
|
||||
P256Element::finish();
|
||||
}
|
||||
|
||||
@@ -41,8 +41,8 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
timer.start();
|
||||
Player& P = proc.P;
|
||||
auto& prep = proc.DataF;
|
||||
size_t start = P.sent + prep.data_sent();
|
||||
auto stats = P.comm_stats + prep.comm_stats();
|
||||
size_t start = P.total_comm().sent;
|
||||
auto stats = P.total_comm();
|
||||
auto& extra_player = P;
|
||||
|
||||
auto& protocol = proc.protocol;
|
||||
@@ -77,7 +77,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (prep_mul)
|
||||
{
|
||||
protocol.init_mul(&proc);
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
protocol.prepare_mul(inv_ks[i], sk);
|
||||
protocol.start_exchange();
|
||||
@@ -106,9 +106,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
timer.stop();
|
||||
cout << "Generated " << buffer_size << " tuples in " << timer.elapsed()
|
||||
<< " seconds, throughput " << buffer_size / timer.elapsed() << ", "
|
||||
<< 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size
|
||||
<< 1e-3 * (P.total_comm().sent - start) / buffer_size
|
||||
<< " kbytes per tuple" << endl;
|
||||
(P.comm_stats + prep.comm_stats() - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
}
|
||||
|
||||
template<template<class U> class T>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "Protocols/SemiPrep.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "GC/SemiSecret.hpp"
|
||||
#include "ot-ecdsa-party.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
@@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
(void) pk;
|
||||
Timer timer;
|
||||
timer.start();
|
||||
size_t start = P.sent;
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
EcSignature signature;
|
||||
vector<P256Element> opened_R;
|
||||
if (opts.R_after_msg)
|
||||
@@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
auto& protocol = proc->protocol;
|
||||
if (proc)
|
||||
{
|
||||
protocol.init_mul(proc);
|
||||
protocol.init_mul();
|
||||
protocol.prepare_mul(sk, tuple.a);
|
||||
protocol.start_exchange();
|
||||
}
|
||||
@@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
auto rx = tuple.R.x();
|
||||
signature.s = MC.open(
|
||||
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
|
||||
auto diff = (P.total_comm() - stats);
|
||||
cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending "
|
||||
<< (P.sent - start) << " bytes" << endl;
|
||||
auto diff = (P.comm_stats - stats);
|
||||
<< diff.sent << " bytes" << endl;
|
||||
diff.print(true);
|
||||
return signature;
|
||||
}
|
||||
@@ -139,11 +138,11 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
P.unchecked_broadcast(bundle);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
P256Element pk = MCc.open(sk, P);
|
||||
MCc.Check(P);
|
||||
cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
|
||||
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
|
||||
{
|
||||
@@ -154,13 +153,12 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto& check_player = MCp.get_check_player(P);
|
||||
auto stats = check_player.comm_stats;
|
||||
auto start = check_player.sent;
|
||||
auto stats = check_player.total_comm();
|
||||
MCp.Check(P);
|
||||
MCc.Check(P);
|
||||
auto diff = (check_player.total_comm() - stats);
|
||||
cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending "
|
||||
<< (check_player.sent - start) << " bytes" << endl;
|
||||
auto diff = (check_player.comm_stats - stats);
|
||||
<< diff.sent << " bytes" << endl;
|
||||
diff.print();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,24 +8,92 @@
|
||||
|
||||
#include "Networking/ssl_sockets.h"
|
||||
|
||||
#ifdef NO_CLIENT_TLS
|
||||
class client_ctx
|
||||
{
|
||||
public:
|
||||
client_ctx(string)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
class client_socket
|
||||
{
|
||||
public:
|
||||
int socket;
|
||||
|
||||
client_socket(boost::asio::io_service&,
|
||||
client_ctx&, int plaintext_socket, string,
|
||||
string, bool) : socket(plaintext_socket)
|
||||
{
|
||||
}
|
||||
|
||||
~client_socket()
|
||||
{
|
||||
close(socket);
|
||||
}
|
||||
};
|
||||
|
||||
inline void send(client_socket* socket, octet* data, size_t len)
|
||||
{
|
||||
send(socket->socket, data, len);
|
||||
}
|
||||
|
||||
inline void receive(client_socket* socket, octet* data, size_t len)
|
||||
{
|
||||
receive(socket->socket, data, len);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
typedef ssl_ctx client_ctx;
|
||||
typedef ssl_socket client_socket;
|
||||
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Client-side interface
|
||||
*/
|
||||
class Client
|
||||
{
|
||||
vector<int> plain_sockets;
|
||||
ssl_ctx ctx;
|
||||
client_ctx ctx;
|
||||
ssl_service io_service;
|
||||
|
||||
public:
|
||||
vector<ssl_socket*> sockets;
|
||||
/**
|
||||
* Sockets for cleartext communication
|
||||
*/
|
||||
vector<client_socket*> sockets;
|
||||
|
||||
/**
|
||||
* Specification of computation domain
|
||||
*/
|
||||
octetStream specification;
|
||||
|
||||
/**
|
||||
* Start a new set of connections to computing parties.
|
||||
* @param hostnames location of computing parties
|
||||
* @param port_base port base
|
||||
* @param my_client_id client identifier
|
||||
*/
|
||||
Client(const vector<string>& hostnames, int port_base, int my_client_id);
|
||||
~Client();
|
||||
|
||||
/**
|
||||
* Securely input private values.
|
||||
* @param values vector of integer-like values
|
||||
*/
|
||||
template<class T>
|
||||
void send_private_inputs(const vector<T>& values);
|
||||
|
||||
template<class T>
|
||||
vector<T> receive_outputs(int n);
|
||||
/**
|
||||
* Securely receive output values.
|
||||
* @param n number of values
|
||||
* @returns vector of integer-like values
|
||||
*/
|
||||
template<class T, class U = T>
|
||||
vector<U> receive_outputs(int n);
|
||||
};
|
||||
|
||||
#endif /* EXTERNALIO_CLIENT_H_ */
|
||||
|
||||
@@ -20,7 +20,7 @@ Client::Client(const vector<string>& hostnames, int port_base,
|
||||
{
|
||||
set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i);
|
||||
octetStream(to_string(my_client_id)).Send(plain_sockets[i]);
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
|
||||
sockets[i] = new client_socket(io_service, ctx, plain_sockets[i],
|
||||
"P" + to_string(i), "C" + to_string(my_client_id), true);
|
||||
if (i == 0)
|
||||
specification.Receive(sockets[0]);
|
||||
@@ -46,20 +46,37 @@ void Client::send_private_inputs(const vector<T>& values)
|
||||
octetStream os;
|
||||
vector< vector<T> > triples(num_inputs, vector<T>(3));
|
||||
vector<T> triple_shares(3);
|
||||
bool active = true;
|
||||
|
||||
// Receive num_inputs triples from SPDZ
|
||||
for (size_t j = 0; j < sockets.size(); j++)
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "receiving from " << j << endl << flush;
|
||||
#endif
|
||||
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[j]);
|
||||
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "received " << os.get_length() << " from " << j << endl;
|
||||
cerr << "received " << os.get_length() << " from " << j << endl << flush;
|
||||
#endif
|
||||
|
||||
if (j == 0)
|
||||
{
|
||||
if (os.get_length() == 3 * values.size() * T::size())
|
||||
active = true;
|
||||
else
|
||||
active = false;
|
||||
}
|
||||
|
||||
int n_expected = active ? 3 : 1;
|
||||
if (os.get_length() != n_expected * T::size() * values.size())
|
||||
throw runtime_error("unexpected data length in sending");
|
||||
|
||||
for (int j = 0; j < num_inputs; j++)
|
||||
{
|
||||
for (int k = 0; k < 3; k++)
|
||||
for (int k = 0; k < n_expected; k++)
|
||||
{
|
||||
triple_shares[k].unpack(os);
|
||||
triples[j][k] += triple_shares[k];
|
||||
@@ -67,16 +84,18 @@ void Client::send_private_inputs(const vector<T>& values)
|
||||
}
|
||||
}
|
||||
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
if (active)
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
throw mac_fail();
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send inputs + triple[0], so SPDZ can compute shares of each value
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
@@ -91,19 +110,33 @@ void Client::send_private_inputs(const vector<T>& values)
|
||||
|
||||
// Receive shares of the result and sum together.
|
||||
// Also receive authenticating values.
|
||||
template<class T>
|
||||
vector<T> Client::receive_outputs(int n)
|
||||
template<class T, class U>
|
||||
vector<U> Client::receive_outputs(int n)
|
||||
{
|
||||
vector<T> triples(3 * n);
|
||||
octetStream os;
|
||||
bool active = true;
|
||||
for (auto& socket : sockets)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(socket);
|
||||
#ifdef VERBOSE_COMM
|
||||
cout << "received " << os.get_length() << endl;
|
||||
cout << "received " << os.get_length() << endl << flush;
|
||||
#endif
|
||||
for (int j = 0; j < 3 * n; j++)
|
||||
|
||||
if (socket == sockets[0])
|
||||
{
|
||||
if (os.get_length() == (size_t) 3 * n * T::size())
|
||||
active = true;
|
||||
else
|
||||
active = false;
|
||||
}
|
||||
|
||||
int n_expected = n * (active ? 3 : 1);
|
||||
if (os.get_length() != (size_t) n_expected * T::size())
|
||||
throw runtime_error("unexpected data length in receiving");
|
||||
|
||||
for (int j = 0; j < n_expected; j++)
|
||||
{
|
||||
T value;
|
||||
value.unpack(os);
|
||||
@@ -111,16 +144,24 @@ vector<T> Client::receive_outputs(int n)
|
||||
}
|
||||
}
|
||||
|
||||
vector<T> output_values;
|
||||
for (int i = 0; i < 3 * n; i += 3)
|
||||
if (active)
|
||||
{
|
||||
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
|
||||
vector<U> output_values;
|
||||
for (int i = 0; i < 3 * n; i += 3)
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
throw mac_fail();
|
||||
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
throw mac_fail();
|
||||
}
|
||||
output_values.push_back(triples[i]);
|
||||
}
|
||||
output_values.push_back(triples[i]);
|
||||
}
|
||||
|
||||
return output_values;
|
||||
return output_values;
|
||||
}
|
||||
else
|
||||
{
|
||||
triples.resize(n);
|
||||
return triples;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,30 @@
|
||||
The ExternalIO directory contains an example of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
|
||||
The ExternalIO directory contains an example of managing I/O between
|
||||
external client processes and parties running MP-SPDZ engines. These
|
||||
instructions assume that MP-SPDZ has been built as per the [project
|
||||
readme](../README.md).
|
||||
|
||||
## Working Examples
|
||||
|
||||
[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a
|
||||
[bankers-bonus-client.cpp](../ExternalIO/bankers-bonus-client.cpp) and
|
||||
[bankers-bonus-client.py](../ExternalIO/bankers-bonus-client.py) act as a
|
||||
client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc)
|
||||
and demonstrates sending input and receiving output as described by
|
||||
[Damgård et al.](https://eprint.iacr.org/2015/1006) The computation
|
||||
allows up to eight clients to input a number and computes the client
|
||||
with the largest input. You can run it as follows from the main
|
||||
with the largest input. You can run the C++ code as follows from the main
|
||||
directory:
|
||||
```
|
||||
make bankers-bonus-client.x
|
||||
./compile.py bankers_bonus 1
|
||||
Scripts/setup-ssl.sh <nparties>
|
||||
Scripts/setup-clients.sh 3
|
||||
Scripts/<protocol>.sh &
|
||||
PLAYERS=<nparties> Scripts/<protocol>.sh bankers_bonus-1 &
|
||||
./bankers-bonus-client.x 0 <nparties> 100 0 &
|
||||
./bankers-bonus-client.x 1 <nparties> 200 0 &
|
||||
./bankers-bonus-client.x 2 <nparties> 50 1
|
||||
```
|
||||
`<protocol>` can be any arithmetic protocol (e.g., `mascot`) but not a
|
||||
binary protocol (e.g., `yao`).
|
||||
This should output that the winning id is 1. Note that the ids have to
|
||||
be incremental, and the client with the highest id has to input 1 as
|
||||
the last argument while the others have to input 0 there. Furthermore,
|
||||
@@ -28,58 +34,30 @@ protocol script. The setup scripts generate the necessary SSL
|
||||
certificates and keys. Therefore, if you run the computation on
|
||||
different hosts, you will have to distribute the `*.pem` files.
|
||||
|
||||
For the Python client, make sure to install
|
||||
[gmpy2](https://pypi.org/project/gmpy2), and run
|
||||
`ExternalIO/bankers-bonus-client.py` instead of
|
||||
`bankers-bonus-client.x`.
|
||||
|
||||
## I/O MPC Instructions
|
||||
|
||||
### Connection Setup
|
||||
|
||||
**listen**(*int port_num*)
|
||||
|
||||
Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background.
|
||||
|
||||
*port_num* - the port number to listen on.
|
||||
|
||||
**acceptclientconnection**(*regint client_socket_id*, *int port_num*)
|
||||
|
||||
Picks the first available client socket connection. Blocks if none available.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*port_num* - the port number identifies the socket server to accept connections on.
|
||||
1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients)
|
||||
2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection)
|
||||
3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection)
|
||||
|
||||
### Data Exchange
|
||||
|
||||
Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py).
|
||||
Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types).
|
||||
|
||||
*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*)
|
||||
1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket)
|
||||
2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client)
|
||||
3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients)
|
||||
|
||||
Read a share of an input from a client, blocking on the client send.
|
||||
## Client-Side Interface
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*number_of_inputs* - the number of inputs expected
|
||||
|
||||
*[inputs]* - returned list of shares of private input.
|
||||
|
||||
**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*)
|
||||
|
||||
Write shares of values including macs to an external client.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*[values]* - list of shares of values to send to client.
|
||||
|
||||
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
|
||||
|
||||
See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message.
|
||||
|
||||
*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*)
|
||||
|
||||
Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf)
|
||||
|
||||
*number_of_inputs* - the number of inputs expected
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
|
||||
|
||||
*[inputs]* - returned list of shares of private input.
|
||||
The example uses the `Client` class implemented in
|
||||
`ExternalIO/Client.hpp` to handle the communication, see
|
||||
[this reference](https://mp-spdz.readthedocs.io/en/latest/io.html#reference) for
|
||||
documentation.
|
||||
|
||||
@@ -46,7 +46,7 @@
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
template<class T>
|
||||
template<class T, class U>
|
||||
void one_run(T salary_value, Client& client)
|
||||
{
|
||||
// Run the computation
|
||||
@@ -54,18 +54,18 @@ void one_run(T salary_value, Client& client)
|
||||
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
|
||||
|
||||
// Get the result back (client_id of winning client)
|
||||
T result = client.receive_outputs<T>(1)[0];
|
||||
U result = client.receive_outputs<T>(1)[0];
|
||||
|
||||
cout << "Winning client id is : " << result << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class T, class U>
|
||||
void run(double salary_value, Client& client)
|
||||
{
|
||||
// sint
|
||||
one_run<T>(long(round(salary_value)), client);
|
||||
one_run<T, U>(long(round(salary_value)), client);
|
||||
// sfix with f = 16
|
||||
one_run<T>(long(round(salary_value * exp2(16))), client);
|
||||
one_run<T, U>(long(round(salary_value * exp2(16))), client);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
@@ -125,7 +125,7 @@ int main(int argc, char** argv)
|
||||
{
|
||||
gfp::init_field(specification.get<bigint>());
|
||||
cerr << "using prime " << gfp::pr() << endl;
|
||||
run<gfp>(salary_value, client);
|
||||
run<gfp, gfp>(salary_value, client);
|
||||
break;
|
||||
}
|
||||
case 'R':
|
||||
@@ -134,13 +134,13 @@ int main(int argc, char** argv)
|
||||
switch (R)
|
||||
{
|
||||
case 64:
|
||||
run<Z2<64>>(salary_value, client);
|
||||
run<Z2<64>, Z2<64>>(salary_value, client);
|
||||
break;
|
||||
case 104:
|
||||
run<Z2<104>>(salary_value, client);
|
||||
run<Z2<104>, Z2<64>>(salary_value, client);
|
||||
break;
|
||||
case 128:
|
||||
run<Z2<128>>(salary_value, client);
|
||||
run<Z2<128>, Z2<64>>(salary_value, client);
|
||||
break;
|
||||
default:
|
||||
cerr << R << "-bit ring not implemented";
|
||||
|
||||
35
ExternalIO/bankers-bonus-client.py
Executable file
35
ExternalIO/bankers-bonus-client.py
Executable file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/python3
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append('.')
|
||||
|
||||
from client import *
|
||||
from domains import *
|
||||
|
||||
client_id = int(sys.argv[1])
|
||||
n_parties = int(sys.argv[2])
|
||||
bonus = float(sys.argv[3])
|
||||
finish = int(sys.argv[4])
|
||||
|
||||
client = Client(['localhost'] * n_parties, 14000, client_id)
|
||||
|
||||
type = client.specification.get_int(4)
|
||||
|
||||
if type == ord('R'):
|
||||
domain = Z2(client.specification.get_int(4))
|
||||
elif type == ord('p'):
|
||||
domain = Fp(client.specification.get_bigint())
|
||||
else:
|
||||
raise Exception('invalid type')
|
||||
|
||||
for socket in client.sockets:
|
||||
os = octetStream()
|
||||
os.store(finish)
|
||||
os.Send(socket)
|
||||
|
||||
for x in bonus, bonus * 2 ** 16:
|
||||
client.send_private_inputs([domain(x)])
|
||||
|
||||
print('Winning client id is :',
|
||||
client.receive_outputs(domain, 1)[0].v % 2 ** 64)
|
||||
126
ExternalIO/client.py
Normal file
126
ExternalIO/client.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import socket, ssl
|
||||
import struct
|
||||
import time
|
||||
|
||||
class Client:
|
||||
def __init__(self, hostnames, port_base, my_client_id):
|
||||
ctx = ssl.SSLContext()
|
||||
name = 'C%d' % my_client_id
|
||||
prefix = 'Player-Data/%s' % name
|
||||
ctx.load_cert_chain(certfile=prefix + '.pem', keyfile=prefix + '.key')
|
||||
ctx.load_verify_locations(capath='Player-Data')
|
||||
|
||||
self.sockets = []
|
||||
for i, hostname in enumerate(hostnames):
|
||||
for j in range(10000):
|
||||
try:
|
||||
plain_socket = socket.create_connection(
|
||||
(hostname, port_base + i))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
if j < 60:
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise
|
||||
octetStream(b'%d' % my_client_id).Send(plain_socket)
|
||||
self.sockets.append(ctx.wrap_socket(plain_socket,
|
||||
server_hostname='P%d' % i))
|
||||
|
||||
self.specification = octetStream()
|
||||
self.specification.Receive(self.sockets[0])
|
||||
|
||||
def receive_triples(self, T, n):
|
||||
triples = [[0, 0, 0] for i in range(n)]
|
||||
os = octetStream()
|
||||
for socket in self.sockets:
|
||||
os.Receive(socket)
|
||||
if socket == self.sockets[0]:
|
||||
active = os.get_length() == 3 * n * T.size()
|
||||
n_expected = 3 if active else 1
|
||||
if os.get_length() != n_expected * T.size() * n:
|
||||
import sys
|
||||
print (os.get_length(), n_expected, T.size(), n, active, file=sys.stderr)
|
||||
raise Exception('unexpected data length')
|
||||
for triple in triples:
|
||||
for i in range(n_expected):
|
||||
t = T()
|
||||
t.unpack(os)
|
||||
triple[i] += t
|
||||
res = []
|
||||
if active:
|
||||
for triple in triples:
|
||||
prod = triple[0] * triple[1]
|
||||
if prod != triple[2]:
|
||||
raise Exception(
|
||||
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
|
||||
return triples
|
||||
|
||||
def send_private_inputs(self, values):
|
||||
T = type(values[0])
|
||||
triples = self.receive_triples(T, len(values))
|
||||
os = octetStream()
|
||||
assert len(values) == len(triples)
|
||||
for value, triple in zip(values, triples):
|
||||
(value + triple[0]).pack(os)
|
||||
for socket in self.sockets:
|
||||
os.Send(socket)
|
||||
|
||||
def receive_outputs(self, T, n):
|
||||
triples = self.receive_triples(T, n)
|
||||
return [triple[0] for triple in triples]
|
||||
|
||||
class octetStream:
|
||||
def __init__(self, value=None):
|
||||
self.buf = b''
|
||||
self.ptr = 0
|
||||
if value is not None:
|
||||
self.buf += value
|
||||
|
||||
def get_length(self):
|
||||
return len(self.buf)
|
||||
|
||||
def reset_write_head(self):
|
||||
self.buf = b''
|
||||
self.ptr = 0
|
||||
|
||||
def Send(self, socket):
|
||||
socket.sendall(struct.pack('<i', len(self.buf)))
|
||||
socket.sendall(self.buf)
|
||||
|
||||
def Receive(self, socket):
|
||||
length = struct.unpack('<I', socket.recv(4))[0]
|
||||
self.buf = b''
|
||||
while len(self.buf) < length:
|
||||
self.buf += socket.recv(length - len(self.buf))
|
||||
self.ptr = 0
|
||||
|
||||
def store(self, value):
|
||||
self.buf += struct.pack('<i', value)
|
||||
|
||||
def get_int(self, length):
|
||||
buf = self.consume(length)
|
||||
if length == 4:
|
||||
return struct.unpack('<i', buf)[0]
|
||||
elif length == 8:
|
||||
return struct.unpack('<q', buf)[0]
|
||||
raise ValueError()
|
||||
|
||||
def get_bigint(self):
|
||||
sign = self.consume(1)[0]
|
||||
assert(sign in (0, 1))
|
||||
length = self.get_int(4)
|
||||
if length:
|
||||
res = 0
|
||||
buf = self.consume(length)
|
||||
for i, b in enumerate(reversed(buf)):
|
||||
res += b << (i * 8)
|
||||
if sign:
|
||||
res *= -1
|
||||
return res
|
||||
else:
|
||||
return 0
|
||||
|
||||
def consume(self, length):
|
||||
self.ptr += length
|
||||
assert self.ptr <= len(self.buf)
|
||||
return self.buf[self.ptr - length:self.ptr]
|
||||
74
ExternalIO/domains.py
Normal file
74
ExternalIO/domains.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import struct
|
||||
|
||||
class Domain:
|
||||
def __init__(self, value=0):
|
||||
self.v = int(value % self.modulus)
|
||||
assert(self.v >= 0)
|
||||
|
||||
def __add__(self, other):
|
||||
try:
|
||||
res = self.v + other.v
|
||||
except:
|
||||
res = self.v + other
|
||||
return type(self)(res)
|
||||
|
||||
def __mul__(self, other):
|
||||
try:
|
||||
res = self.v * other.v
|
||||
except:
|
||||
res = self.v * other
|
||||
return type(self)(res)
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.v == other.v
|
||||
|
||||
def __neq__(self, other):
|
||||
return self.v != other.v
|
||||
|
||||
@classmethod
|
||||
def size(cls):
|
||||
return cls.n_bytes
|
||||
|
||||
def unpack(self, os):
|
||||
self.v = 0
|
||||
buf = os.consume(self.n_bytes)
|
||||
for i, b in enumerate(buf):
|
||||
self.v += b << (i * 8)
|
||||
|
||||
def pack(self, os):
|
||||
v = self.v
|
||||
temp_buf = []
|
||||
for i in range(self.n_bytes):
|
||||
temp_buf.append(v & 0xff)
|
||||
v >>= 8
|
||||
#Instead of using python a loop per value we let struct pack handle all it
|
||||
os.buf += struct.pack('<{}B'.format(len(temp_buf)), *tuple(temp_buf))
|
||||
|
||||
def Z2(k):
|
||||
class Z(Domain):
|
||||
modulus = 2 ** k
|
||||
n_words = (k + 63) // 64
|
||||
n_bytes = (k + 7) // 8
|
||||
|
||||
return Z
|
||||
|
||||
def Fp(mod):
|
||||
import gmpy2
|
||||
|
||||
class Fp(Domain):
|
||||
modulus = mod
|
||||
n_words = (modulus.bit_length() + 63) // 64
|
||||
n_bytes = 8 * n_words
|
||||
R = 2 ** (64 * n_words) % modulus
|
||||
R_inv = gmpy2.invert(R, modulus)
|
||||
|
||||
def unpack(self, os):
|
||||
Domain.unpack(self, os)
|
||||
self.v = self.v * self.R_inv % self.modulus
|
||||
|
||||
def pack(self, os):
|
||||
Domain.pack(type(self)(self.v * self.R), os)
|
||||
|
||||
return Fp
|
||||
@@ -58,7 +58,8 @@ public:
|
||||
{
|
||||
if (this->size() != y.size())
|
||||
throw out_of_range("vector length mismatch");
|
||||
for (unsigned int i = 0; i < this->size(); i++)
|
||||
size_t n = this->size();
|
||||
for (unsigned int i = 0; i < n; i++)
|
||||
(*this)[i] += y[i];
|
||||
return *this;
|
||||
}
|
||||
@@ -67,9 +68,11 @@ public:
|
||||
{
|
||||
if (this->size() != y.size())
|
||||
throw out_of_range("vector length mismatch");
|
||||
AddableVector<T> res(y.size());
|
||||
for (unsigned int i = 0; i < this->size(); i++)
|
||||
res[i] = (*this)[i] - y[i];
|
||||
AddableVector<T> res;
|
||||
res.reserve(y.size());
|
||||
size_t n = this->size();
|
||||
for (unsigned int i = 0; i < n; i++)
|
||||
res.push_back((*this)[i] - y[i]);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "Ciphertext.h"
|
||||
#include "PPData.h"
|
||||
#include "P2Data.h"
|
||||
#include "Tools/Exceptions.h"
|
||||
|
||||
@@ -31,6 +30,12 @@ word check_pk_id(word a, word b)
|
||||
}
|
||||
|
||||
|
||||
void Ciphertext::Scale()
|
||||
{
|
||||
Scale(params->get_plaintext_modulus());
|
||||
}
|
||||
|
||||
|
||||
void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1)
|
||||
{
|
||||
if (c0.params!=c1.params) { throw params_mismatch(); }
|
||||
@@ -108,16 +113,34 @@ void Ciphertext::mul(const Ciphertext& c, const Rq_Element& ra)
|
||||
::mul(cc1,ra,c.cc1);
|
||||
}
|
||||
|
||||
void Ciphertext::add(octetStream& os)
|
||||
void Ciphertext::add(octetStream& os, int)
|
||||
{
|
||||
Ciphertext tmp(*params);
|
||||
tmp.unpack(os);
|
||||
*this += tmp;
|
||||
}
|
||||
|
||||
void Ciphertext::rerandomize(const FHE_PK& pk)
|
||||
{
|
||||
Rq_Element tmp(*params);
|
||||
SeededPRNG G;
|
||||
vector<FFT_Data::S> r(params->FFTD()[0].m());
|
||||
bigint p = pk.p();
|
||||
assert(p != 0);
|
||||
for (auto& x : r)
|
||||
{
|
||||
G.get(x, params->p0().numBits() - p.numBits() - 1);
|
||||
x *= p;
|
||||
}
|
||||
tmp.from(r, 0);
|
||||
Scale();
|
||||
cc0 += tmp;
|
||||
auto zero = pk.encrypt(*params);
|
||||
zero.Scale(pk.p());
|
||||
*this += zero;
|
||||
}
|
||||
|
||||
|
||||
template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c);
|
||||
template void mul(Ciphertext& ans,const Plaintext<gfp,PPData,bigint>& a,const Ciphertext& c);
|
||||
template void mul(Ciphertext& ans,const Plaintext<gf2n_short,P2Data,int>& a,const Ciphertext& c);
|
||||
|
||||
|
||||
template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a,
|
||||
const Ciphertext& c);
|
||||
|
||||
@@ -15,6 +15,12 @@ template<class T,class FD,class S> void mul(Ciphertext& ans,const Ciphertext& c,
|
||||
void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1);
|
||||
void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk);
|
||||
|
||||
/**
|
||||
* BGV ciphertext.
|
||||
* The class allows adding two ciphertexts as well as adding a plaintext and
|
||||
* a ciphertext via operator overloading. The multiplication of two ciphertexts
|
||||
* requires the public key and thus needs a separate function.
|
||||
*/
|
||||
class Ciphertext
|
||||
{
|
||||
Rq_Element cc0,cc1;
|
||||
@@ -54,6 +60,7 @@ class Ciphertext
|
||||
|
||||
// Scale down an element from level 1 to level 0, if at level 0 do nothing
|
||||
void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); }
|
||||
void Scale();
|
||||
|
||||
// Throws error if ans,c0,c1 etc have different params settings
|
||||
// - Thus programmer needs to ensure this rather than this being done
|
||||
@@ -90,6 +97,12 @@ class Ciphertext
|
||||
template <class FD>
|
||||
Ciphertext& operator*=(const Plaintext_<FD>& other) { ::mul(*this, *this, other); return *this; }
|
||||
|
||||
/**
|
||||
* Ciphertext multiplication.
|
||||
* @param pk public key
|
||||
* @param x second ciphertext
|
||||
* @returns product ciphertext
|
||||
*/
|
||||
Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const
|
||||
{ Ciphertext res(*params); ::mul(res, *this, x, pk); return res; }
|
||||
|
||||
@@ -98,21 +111,25 @@ class Ciphertext
|
||||
return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this};
|
||||
}
|
||||
|
||||
/// Re-randomize for circuit privacy.
|
||||
void rerandomize(const FHE_PK& pk);
|
||||
|
||||
int level() const { return cc0.level(); }
|
||||
|
||||
// pack/unpack (like IO) also assume params are known and already set
|
||||
// correctly
|
||||
void pack(octetStream& o) const
|
||||
/// Append to buffer
|
||||
void pack(octetStream& o, int = -1) const
|
||||
{ cc0.pack(o); cc1.pack(o); o.store(pk_id); }
|
||||
void unpack(octetStream& o)
|
||||
{ cc0.unpack(o); cc1.unpack(o); o.get(pk_id); }
|
||||
|
||||
/// Read from buffer. Assumes parameters are set correctly
|
||||
void unpack(octetStream& o, int = -1)
|
||||
{ cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); }
|
||||
|
||||
void output(ostream& s) const
|
||||
{ cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); }
|
||||
void input(istream& s)
|
||||
{ cc0.input(s); cc1.input(s); s.read((char*)&pk_id, sizeof(pk_id)); }
|
||||
|
||||
void add(octetStream& os);
|
||||
void add(octetStream& os, int = -1);
|
||||
|
||||
size_t report_size(ReportType type) const { return cc0.report_size(type) + cc1.report_size(type); }
|
||||
};
|
||||
|
||||
@@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag(
|
||||
{
|
||||
auto& c = products.at(i);
|
||||
for (int j = 0; j < n_matrices; j++)
|
||||
{
|
||||
res.at(j).entries.init();
|
||||
for (size_t k = 0; k < n_rows; k++)
|
||||
res.at(j)[{k, i}] = c.element(j * n_rows + k);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -259,6 +259,3 @@ void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forw
|
||||
else
|
||||
{ throw crash_requested(); }
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
|
||||
|
||||
|
||||
FFT_Data::FFT_Data() :
|
||||
twop(-1)
|
||||
{
|
||||
}
|
||||
|
||||
void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
{
|
||||
R=Rg;
|
||||
@@ -78,6 +83,8 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
for (int r=0; r<2; r++)
|
||||
{ FFT_Iter(b[r],twop,two_root[0],PrD); }
|
||||
}
|
||||
else
|
||||
throw bad_value();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class FFT_Data
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
FFT_Data() { ; }
|
||||
FFT_Data();
|
||||
FFT_Data(const Ring& Rg,const Zp_Data& PrD)
|
||||
{ init(Rg,PrD); }
|
||||
|
||||
|
||||
117
FHE/FHE_Keys.cpp
117
FHE/FHE_Keys.cpp
@@ -2,7 +2,6 @@
|
||||
#include "FHE_Keys.h"
|
||||
#include "Ciphertext.h"
|
||||
#include "P2Data.h"
|
||||
#include "PPData.h"
|
||||
#include "FFT_Data.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
@@ -12,6 +11,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p())
|
||||
{
|
||||
}
|
||||
|
||||
FHE_SK::FHE_SK(const FHE_Params& pms) :
|
||||
FHE_SK(pms, pms.get_plaintext_modulus())
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
FHE_SK& FHE_SK::operator+=(const FHE_SK& c)
|
||||
{
|
||||
@@ -38,6 +42,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G)
|
||||
}
|
||||
|
||||
|
||||
FHE_PK::FHE_PK(const FHE_Params& pms) :
|
||||
FHE_PK(pms, pms.get_plaintext_modulus())
|
||||
{
|
||||
}
|
||||
|
||||
Rq_Element FHE_PK::sample_secret_key(PRNG& G)
|
||||
{
|
||||
Rq_Element sk = FHE_SK(*this).s();
|
||||
@@ -47,11 +56,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
|
||||
}
|
||||
|
||||
void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
{
|
||||
Rq_Element a(*this);
|
||||
a.randomize(G);
|
||||
partial_key_gen(sk, a, G, noise_boost);
|
||||
}
|
||||
|
||||
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
|
||||
int noise_boost)
|
||||
{
|
||||
FHE_PK& PK = *this;
|
||||
|
||||
// Generate the main public key
|
||||
PK.a0.randomize(G);
|
||||
a0 = a;
|
||||
|
||||
// b0=a0*s+p*e0
|
||||
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
|
||||
@@ -77,9 +93,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
mul(es,es,PK.pr);
|
||||
add(PK.Sw_b,PK.Sw_b,es);
|
||||
|
||||
// Lowering level as we only decrypt at level 0
|
||||
sk.lower_level();
|
||||
|
||||
// bs=bs-p1*s^2
|
||||
Rq_Element s2;
|
||||
mul(s2,sk,sk); // Mult at level 0
|
||||
@@ -175,32 +188,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext<typename FD::T, FD, typename FD::S>&
|
||||
template<class FD>
|
||||
Ciphertext FHE_PK::encrypt(
|
||||
const Plaintext<typename FD::T, FD, typename FD::S>& mess) const
|
||||
{
|
||||
return encrypt(Rq_Element(*params, mess));
|
||||
}
|
||||
|
||||
Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const
|
||||
{
|
||||
Random_Coins rc(*params);
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
rc.generate(G);
|
||||
return encrypt(mess, rc);
|
||||
Ciphertext res(*params);
|
||||
quasi_encrypt(res, mess, rc);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template<class T, class FD, class S>
|
||||
void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (T::characteristic_two ^ (pr == 2))
|
||||
throw pr_mismatch();
|
||||
|
||||
Rq_Element ans = quasi_decrypt(c);
|
||||
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
|
||||
}
|
||||
|
||||
Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
|
||||
Rq_Element ans;
|
||||
|
||||
mul(ans,c.c1(),sk);
|
||||
sub(ans,c.c0(),ans);
|
||||
ans.change_rep(polynomial);
|
||||
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c)
|
||||
{
|
||||
return decrypt(c, params->get_plaintext_field_data<FFT_Data>());
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
Plaintext<typename FD::T, FD, typename FD::S> FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD)
|
||||
{
|
||||
@@ -295,12 +327,12 @@ void FHE_PK::unpack(octetStream& o)
|
||||
o.consume((octet*) tag, 8);
|
||||
if (memcmp(tag, "PKPKPKPK", 8))
|
||||
throw runtime_error("invalid serialization of public key");
|
||||
a0.unpack(o);
|
||||
b0.unpack(o);
|
||||
a0.unpack(o, *params);
|
||||
b0.unpack(o, *params);
|
||||
if (params->n_mults() > 0)
|
||||
{
|
||||
Sw_a.unpack(o);
|
||||
Sw_b.unpack(o);
|
||||
Sw_a.unpack(o, *params);
|
||||
Sw_b.unpack(o, *params);
|
||||
}
|
||||
pr.unpack(o);
|
||||
}
|
||||
@@ -318,7 +350,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
|
||||
const bigint& pr) const
|
||||
{
|
||||
@@ -334,15 +365,13 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
|
||||
template<class FD>
|
||||
void FHE_SK::check(const FHE_PK& pk, const FD& FieldD)
|
||||
{
|
||||
check(*params, pk, pr);
|
||||
check(*params, pk, FieldD.get_prime());
|
||||
pk.check_noise(*this);
|
||||
if (decrypt(pk.encrypt(Plaintext_<FD>(FieldD)), FieldD) !=
|
||||
Plaintext_<FD>(FieldD))
|
||||
throw runtime_error("incorrect key pair");
|
||||
}
|
||||
|
||||
|
||||
|
||||
void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
|
||||
{
|
||||
if (this->pr != pr)
|
||||
@@ -357,30 +386,36 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
|
||||
}
|
||||
}
|
||||
|
||||
bigint FHE_SK::get_noise(const Ciphertext& c)
|
||||
{
|
||||
sk.lower_level();
|
||||
Ciphertext cc = c;
|
||||
if (cc.level())
|
||||
cc.Scale();
|
||||
Rq_Element tmp = quasi_decrypt(cc);
|
||||
bigint res;
|
||||
bigint q = tmp.get_modulus();
|
||||
bigint half_q = q / 2;
|
||||
for (auto& x : tmp.to_vec_bigint())
|
||||
{
|
||||
// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " ";
|
||||
res = max(res, x > half_q ? x - q : x);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
#define X(FD) \
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FD>& mess, \
|
||||
const Random_Coins& rc) const; \
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FD>& mess) const; \
|
||||
template Plaintext_<FD> FHE_SK::decrypt(const Ciphertext& c, \
|
||||
const FD& FieldD); \
|
||||
template void FHE_SK::decrypt(Plaintext_<FD>& res, \
|
||||
const Ciphertext& c) const; \
|
||||
template void FHE_SK::decrypt_any(Plaintext_<FD>& res, \
|
||||
const Ciphertext& c); \
|
||||
template void FHE_SK::check(const FHE_PK& pk, const FD&);
|
||||
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
|
||||
|
||||
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
|
||||
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
|
||||
|
||||
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
const FFT_Data& FieldD);
|
||||
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
const P2Data& FieldD);
|
||||
|
||||
template void FHE_SK::decrypt_any(Plaintext_<FFT_Data>& res,
|
||||
const Ciphertext& c);
|
||||
template void FHE_SK::decrypt_any(Plaintext_<P2Data>& res,
|
||||
const Ciphertext& c);
|
||||
|
||||
template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&);
|
||||
template void FHE_SK::check(const FHE_PK& pk, const P2Data&);
|
||||
X(FFT_Data)
|
||||
X(P2Data)
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
class FHE_PK;
|
||||
class Ciphertext;
|
||||
|
||||
/**
|
||||
* BGV secret key.
|
||||
* The class allows addition.
|
||||
*/
|
||||
class FHE_SK
|
||||
{
|
||||
Rq_Element sk;
|
||||
@@ -29,6 +33,8 @@ class FHE_SK
|
||||
// secret key always on lower level
|
||||
void assign(const Rq_Element& s) { sk=s; sk.lower_level(); }
|
||||
|
||||
FHE_SK(const FHE_Params& pms);
|
||||
|
||||
FHE_SK(const FHE_Params& pms, const bigint& p)
|
||||
: sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; }
|
||||
|
||||
@@ -38,8 +44,12 @@ class FHE_SK
|
||||
|
||||
const Rq_Element& s() const { return sk; }
|
||||
|
||||
void pack(octetStream& os) const { sk.pack(os); pr.pack(os); }
|
||||
void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); }
|
||||
/// Append to buffer
|
||||
void pack(octetStream& os, int = -1) const { sk.pack(os); pr.pack(os); }
|
||||
|
||||
/// Read from buffer. Assumes parameters are set correctly
|
||||
void unpack(octetStream& os, int = -1)
|
||||
{ sk.unpack(os, *params); pr.unpack(os); }
|
||||
|
||||
// Assumes Ring and prime of mess have already been set correctly
|
||||
// Ciphertext c must be at level 0 or an error occurs
|
||||
@@ -50,9 +60,14 @@ class FHE_SK
|
||||
template <class FD>
|
||||
Plaintext<typename FD::T, FD, typename FD::S> decrypt(const Ciphertext& c, const FD& FieldD);
|
||||
|
||||
/// Decryption for cleartexts modulo prime
|
||||
Plaintext_<FFT_Data> decrypt(const Ciphertext& c);
|
||||
|
||||
template <class FD>
|
||||
void decrypt_any(Plaintext_<FD>& mess, const Ciphertext& c);
|
||||
|
||||
Rq_Element quasi_decrypt(const Ciphertext& c) const;
|
||||
|
||||
// Three stage procedure for Distributed Decryption
|
||||
// - First stage produces my shares
|
||||
// - Second stage adds in another players shares, do this once for each other player
|
||||
@@ -62,7 +77,6 @@ class FHE_SK
|
||||
void dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_number,int num_players) const;
|
||||
void dist_decrypt_2(vector<bigint>& vv,const vector<bigint>& vv1) const;
|
||||
|
||||
|
||||
friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
|
||||
|
||||
/* Add secret keys
|
||||
@@ -75,17 +89,23 @@ class FHE_SK
|
||||
|
||||
bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; }
|
||||
|
||||
void add(octetStream& os) { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
|
||||
void add(octetStream& os, int = -1)
|
||||
{ FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
|
||||
|
||||
void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const;
|
||||
|
||||
template<class FD>
|
||||
void check(const FHE_PK& pk, const FD& FieldD);
|
||||
|
||||
bigint get_noise(const Ciphertext& c);
|
||||
|
||||
friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; }
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* BGV public key.
|
||||
*/
|
||||
class FHE_PK
|
||||
{
|
||||
Rq_Element a0,b0;
|
||||
@@ -104,8 +124,10 @@ class FHE_PK
|
||||
)
|
||||
{ a0=a; b0=b; Sw_a=sa; Sw_b=sb; }
|
||||
|
||||
|
||||
FHE_PK(const FHE_Params& pms, const bigint& p = 0)
|
||||
|
||||
FHE_PK(const FHE_Params& pms);
|
||||
|
||||
FHE_PK(const FHE_Params& pms, const bigint& p)
|
||||
: a0(pms.FFTD(),evaluation,evaluation),
|
||||
b0(pms.FFTD(),evaluation,evaluation),
|
||||
Sw_a(pms.FFTD(),evaluation,evaluation),
|
||||
@@ -143,19 +165,26 @@ class FHE_PK
|
||||
|
||||
template <class FD>
|
||||
Ciphertext encrypt(const Plaintext<typename FD::T, FD, typename FD::S>& mess, const Random_Coins& rc) const;
|
||||
|
||||
/// Encryption
|
||||
template <class FD>
|
||||
Ciphertext encrypt(const Plaintext<typename FD::T, FD, typename FD::S>& mess) const;
|
||||
Ciphertext encrypt(const Rq_Element& mess) const;
|
||||
|
||||
friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
|
||||
|
||||
Rq_Element sample_secret_key(PRNG& G);
|
||||
void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1);
|
||||
void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
|
||||
int noise_boost = 1);
|
||||
|
||||
void check_noise(const FHE_SK& sk) const;
|
||||
void check_noise(const Rq_Element& x, bool check_modulo = false) const;
|
||||
|
||||
// params setting is done out of these IO/pack/unpack functions
|
||||
/// Append to buffer
|
||||
void pack(octetStream& o) const;
|
||||
|
||||
/// Read from buffer. Assumes parameters are set correctly
|
||||
void unpack(octetStream& o);
|
||||
|
||||
bool operator!=(const FHE_PK& x) const;
|
||||
@@ -168,21 +197,39 @@ class FHE_PK
|
||||
void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
|
||||
|
||||
|
||||
/**
|
||||
* BGV key pair
|
||||
*/
|
||||
class FHE_KeyPair
|
||||
{
|
||||
public:
|
||||
/// Public key
|
||||
FHE_PK pk;
|
||||
/// Secret key
|
||||
FHE_SK sk;
|
||||
|
||||
FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) :
|
||||
FHE_KeyPair(const FHE_Params& params, const bigint& pr) :
|
||||
pk(params, pr), sk(params, pr)
|
||||
{
|
||||
}
|
||||
|
||||
/// Initialization
|
||||
FHE_KeyPair(const FHE_Params& params) :
|
||||
pk(params), sk(params)
|
||||
{
|
||||
}
|
||||
|
||||
void generate(PRNG& G)
|
||||
{
|
||||
KeyGen(pk, sk, G);
|
||||
}
|
||||
|
||||
/// Generate fresh keys
|
||||
void generate()
|
||||
{
|
||||
SeededPRNG G;
|
||||
generate(G);
|
||||
}
|
||||
};
|
||||
|
||||
template <class S>
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
|
||||
#include "FHE_Params.h"
|
||||
#include "NTL-Subs.h"
|
||||
#include "FHE/Ring_Element.h"
|
||||
#include "Tools/Exceptions.h"
|
||||
#include "Protocols/HemiOptions.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
FHE_Params::FHE_Params(int n_mults, int drown_sec) :
|
||||
FFTData(n_mults + 1), Chi(0.7), sec_p(drown_sec), matrix_dim(1)
|
||||
{
|
||||
}
|
||||
|
||||
void FHE_Params::set(const Ring& R,
|
||||
const vector<bigint>& primes)
|
||||
@@ -12,16 +20,35 @@ void FHE_Params::set(const Ring& R,
|
||||
for (size_t i = 0; i < FFTData.size(); i++)
|
||||
FFTData[i].init(R,primes[i]);
|
||||
|
||||
set_sec(40);
|
||||
set_sec(sec_p);
|
||||
}
|
||||
|
||||
void FHE_Params::set_sec(int sec)
|
||||
{
|
||||
assert(sec >= 0);
|
||||
sec_p=sec;
|
||||
Bval=1; Bval=Bval<<sec_p;
|
||||
Bval=FFTData[0].get_prime()/(2*(1+Bval));
|
||||
if (Bval == 0)
|
||||
throw runtime_error("distributed decryption bound is zero");
|
||||
}
|
||||
|
||||
void FHE_Params::set_min_sec(int sec)
|
||||
{
|
||||
set_sec(max(sec, sec_p));
|
||||
}
|
||||
|
||||
void FHE_Params::set_matrix_dim(int matrix_dim)
|
||||
{
|
||||
assert(matrix_dim > 0);
|
||||
if (FFTData[0].get_prime() != 0)
|
||||
throw runtime_error("cannot change matrix dimension after parameter generation");
|
||||
this->matrix_dim = matrix_dim;
|
||||
}
|
||||
|
||||
void FHE_Params::set_matrix_dim_from_options()
|
||||
{
|
||||
set_matrix_dim(
|
||||
HemiOptions::singleton.plain_matmul ?
|
||||
1 : OnlineOptions::singleton.batch_size);
|
||||
}
|
||||
|
||||
bigint FHE_Params::Q() const
|
||||
@@ -40,6 +67,8 @@ void FHE_Params::pack(octetStream& o) const
|
||||
Chi.pack(o);
|
||||
Bval.pack(o);
|
||||
o.store(sec_p);
|
||||
o.store(matrix_dim);
|
||||
fd.pack(o);
|
||||
}
|
||||
|
||||
void FHE_Params::unpack(octetStream& o)
|
||||
@@ -52,6 +81,8 @@ void FHE_Params::unpack(octetStream& o)
|
||||
Chi.unpack(o);
|
||||
Bval.unpack(o);
|
||||
o.get(sec_p);
|
||||
o.get(matrix_dim);
|
||||
fd.unpack(o);
|
||||
}
|
||||
|
||||
bool FHE_Params::operator!=(const FHE_Params& other) const
|
||||
@@ -64,3 +95,31 @@ bool FHE_Params::operator!=(const FHE_Params& other) const
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
void FHE_Params::basic_generation_mod_prime(int plaintext_length)
|
||||
{
|
||||
if (n_mults() == 0)
|
||||
generate_semi_setup(plaintext_length, 0, *this, fd, false);
|
||||
else
|
||||
{
|
||||
Parameters parameters(1, plaintext_length, 0);
|
||||
parameters.generate_setup(*this, fd);
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
const FFT_Data& FHE_Params::get_plaintext_field_data() const
|
||||
{
|
||||
return fd;
|
||||
}
|
||||
|
||||
template<>
|
||||
const P2Data& FHE_Params::get_plaintext_field_data() const
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
bigint FHE_Params::get_plaintext_modulus() const
|
||||
{
|
||||
return fd.get_prime();
|
||||
}
|
||||
|
||||
@@ -13,7 +13,11 @@
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "FHE/DiscreteGauss.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Protocols/config.h"
|
||||
|
||||
/**
|
||||
* Cryptosystem parameters
|
||||
*/
|
||||
class FHE_Params
|
||||
{
|
||||
protected:
|
||||
@@ -26,19 +30,29 @@ class FHE_Params
|
||||
// Data for distributed decryption
|
||||
int sec_p;
|
||||
bigint Bval;
|
||||
int matrix_dim;
|
||||
|
||||
FFT_Data fd;
|
||||
|
||||
public:
|
||||
|
||||
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
|
||||
/**
|
||||
* Initialization.
|
||||
* @param n_mults number of ciphertext multiplications (0/1)
|
||||
* @param drown_sec parameter for function privacy (default 40)
|
||||
*/
|
||||
FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY);
|
||||
|
||||
int n_mults() const { return FFTData.size() - 1; }
|
||||
|
||||
// Rely on default copy assignment/constructor (not that they should
|
||||
// ever be needed)
|
||||
|
||||
void set(const Ring& R,const vector<bigint>& primes);
|
||||
void set(const vector<bigint>& primes);
|
||||
void set_sec(int sec);
|
||||
void set_min_sec(int sec);
|
||||
|
||||
void set_matrix_dim(int matrix_dim);
|
||||
void set_matrix_dim_from_options();
|
||||
int get_matrix_dim() const { return matrix_dim; }
|
||||
|
||||
const vector<FFT_Data>& FFTD() const { return FFTData; }
|
||||
|
||||
@@ -55,10 +69,24 @@ class FHE_Params
|
||||
int phi_m() const { return FFTData[0].phi_m(); }
|
||||
const Ring& get_ring() { return FFTData[0].get_R(); }
|
||||
|
||||
/// Append to buffer
|
||||
void pack(octetStream& o) const;
|
||||
|
||||
/// Read from buffer
|
||||
void unpack(octetStream& o);
|
||||
|
||||
bool operator!=(const FHE_Params& other) const;
|
||||
|
||||
/**
|
||||
* Generate parameter for computation modulo a prime
|
||||
* @param plaintext_length bit length of prime
|
||||
*/
|
||||
void basic_generation_mod_prime(int plaintext_length);
|
||||
|
||||
template<class FD>
|
||||
const FD& get_plaintext_field_data() const;
|
||||
|
||||
bigint get_plaintext_modulus() const;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -68,6 +68,10 @@ void HNF(matrix& H,matrix& U,const matrix& A)
|
||||
{
|
||||
int m=A.size(),n=A[0].size(),r,i,j,k;
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "HNF m=" << m << ", n=" << n << endl;
|
||||
#endif
|
||||
|
||||
H=A;
|
||||
ident(U,n);
|
||||
r=min(m,n);
|
||||
@@ -79,9 +83,9 @@ void HNF(matrix& H,matrix& U,const matrix& A)
|
||||
{ if (step==2)
|
||||
{ // Step 2
|
||||
k=-1;
|
||||
mn=bigint(1)<<256;
|
||||
mn=bigint(0);
|
||||
for (j=i; j<n; j++)
|
||||
{ if (H[i][j]!=0 && abs(H[i][j])<mn)
|
||||
{ if (H[i][j]!=0 && (abs(H[i][j])<mn || mn == 0))
|
||||
{ k=j; mn=abs(H[i][j]); }
|
||||
}
|
||||
if (k!=-1)
|
||||
@@ -207,6 +211,11 @@ void SNF_Step(matrix& S,matrix& V)
|
||||
void SNF(matrix& S,const matrix& A,matrix& V)
|
||||
{
|
||||
int m=A.size(),n=A[0].size();
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "SNF m=" << m << ", n=" << n << endl;
|
||||
#endif
|
||||
|
||||
S=A;
|
||||
ident(V,n);
|
||||
|
||||
@@ -239,49 +248,6 @@ matrix inv(const matrix& A)
|
||||
}
|
||||
|
||||
|
||||
vector<modp> solve(modp_matrix& A,const Zp_Data& PrD)
|
||||
{
|
||||
unsigned int n=A.size();
|
||||
if ((n+1)!=A[0].size()) { throw invalid_params(); }
|
||||
|
||||
modp t,ti;
|
||||
for (unsigned int r=0; r<n; r++)
|
||||
{ // Find pivot
|
||||
unsigned int p=r;
|
||||
while (isZero(A[p][r],PrD)) { p++; }
|
||||
// Do pivoting
|
||||
if (p!=r)
|
||||
{ for (unsigned int c=0; c<n+1; c++)
|
||||
{ t=A[p][c]; A[p][c]=A[r][c]; A[r][c]=t; }
|
||||
}
|
||||
// Make Lcoeff=1
|
||||
Inv(ti,A[r][r],PrD);
|
||||
for (unsigned int c=0; c<n+1; c++)
|
||||
{ Mul(A[r][c],A[r][c],ti,PrD); }
|
||||
// Now kill off other entries in this column
|
||||
for (unsigned int rr=0; rr<n; rr++)
|
||||
{ if (r!=rr)
|
||||
{ for (unsigned int c=0; c<n+1; c++)
|
||||
{ Mul(t,A[rr][c],A[r][r],PrD);
|
||||
Sub(A[rr][c],A[rr][c],t,PrD);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sanity check and extract answer
|
||||
vector<modp> ans;
|
||||
ans.resize(n);
|
||||
for (unsigned int i=0; i<n; i++)
|
||||
{ for (unsigned int j=0; j<n; j++)
|
||||
{ if (i!=j && !isZero(A[i][j],PrD)) { throw bad_value(); }
|
||||
else if (!isOne(A[i][j],PrD)) { throw bad_value(); }
|
||||
}
|
||||
ans[i]=A[i][n];
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Input matrix is assumed to have more rows than columns */
|
||||
void pinv(imatrix& Ai,const imatrix& B)
|
||||
|
||||
@@ -10,7 +10,6 @@ using namespace std;
|
||||
#include "Tools/BitVector.h"
|
||||
|
||||
typedef vector< vector<bigint> > matrix;
|
||||
typedef vector< vector<modp> > modp_matrix;
|
||||
|
||||
class imatrix : public vector< BitVector >
|
||||
{
|
||||
@@ -39,13 +38,6 @@ void print(const imatrix& S);
|
||||
// requires column operations to create the inverse
|
||||
matrix inv(const matrix& A);
|
||||
|
||||
// Another special routine for modp matrices.
|
||||
// Solves
|
||||
// Ax=b
|
||||
// Assumes A is unimodular, square and only requires row operations to
|
||||
// create the inverse. In put is C=(A||b) and the routines alters A
|
||||
vector<modp> solve(modp_matrix& C,const Zp_Data& PrD);
|
||||
|
||||
// Finds a pseudo-inverse of a matrix A modulo 2
|
||||
// - Input matrix is assumed to have more rows than columns
|
||||
void pinv(imatrix& Ai,const imatrix& A);
|
||||
|
||||
258
FHE/NTL-Subs.cpp
258
FHE/NTL-Subs.cpp
@@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2)
|
||||
|
||||
template <>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, FFT_Data& FTD, bool round_up)
|
||||
FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
|
||||
{
|
||||
int m = 1024;
|
||||
int lgp = plaintext_length;
|
||||
@@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
while (true)
|
||||
{
|
||||
tmp_params = params;
|
||||
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
|
||||
SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec,
|
||||
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
|
||||
bigint p1 = 2 * p * m, p0 = p;
|
||||
while (nb.min_p0(params.n_mults() > 0, p1) > p0)
|
||||
@@ -89,27 +89,30 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
|
||||
template <>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, P2Data& P2D, bool round_up)
|
||||
FHE_Params& params, P2Data& P2D, bool round_up, int n)
|
||||
{
|
||||
if (params.n_mults() > 0)
|
||||
throw runtime_error("only implemented for 0-level BGV");
|
||||
gf2n_short::init_field(plaintext_length);
|
||||
int m;
|
||||
char_2_dimension(m, plaintext_length);
|
||||
SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec,
|
||||
SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec,
|
||||
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
|
||||
int lgp0 = numBits(nb.min_p0(false, 0));
|
||||
int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up);
|
||||
assert(nb.min_phi_m(lgp0, false) * 2 <= m);
|
||||
load_or_generate(P2D, params.get_ring());
|
||||
return extra_slack;
|
||||
}
|
||||
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, bool round_up)
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up)
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "Need ciphertext modulus of length " << lgp0;
|
||||
if (params.n_mults() > 0)
|
||||
cout << "+" << lgp1;
|
||||
cout << " and " << phi_N(m) << " slots" << endl;
|
||||
#endif
|
||||
|
||||
int extra_slack = 0;
|
||||
if (round_up)
|
||||
@@ -124,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b
|
||||
}
|
||||
extra_slack = i - 1;
|
||||
lgp0 += extra_slack;
|
||||
#ifdef VERBOSE
|
||||
cout << "Rounding up to " << lgp0 << ", giving extra slack of "
|
||||
<< extra_slack << " bits" << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
Ring R;
|
||||
@@ -147,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b
|
||||
int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
|
||||
bool round_up, FHE_Params& params)
|
||||
{
|
||||
(void) lg2pi, (void) n;
|
||||
|
||||
#ifdef VERBOSE
|
||||
if (n >= 2 and n <= 10)
|
||||
cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
|
||||
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
|
||||
cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
|
||||
cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
|
||||
#endif
|
||||
|
||||
int extra_slack = 0;
|
||||
if (round_up)
|
||||
@@ -170,20 +179,18 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
|
||||
extra_slack = 2 * i;
|
||||
lg2p0 += i;
|
||||
lg2p1 += i;
|
||||
#ifdef VERBOSE
|
||||
cout << "Rounding up to " << lg2p0 << "+" << lg2p1
|
||||
<< ", giving extra slack of " << extra_slack << " bits" << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
cout << "Total length: " << lg2p0 + lg2p1 << endl;
|
||||
#endif
|
||||
|
||||
return extra_slack;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Here onwards needs NTL
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
|
||||
@@ -220,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p,
|
||||
{
|
||||
double phi_m_bound =
|
||||
NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1);
|
||||
|
||||
#ifdef VERBOSE
|
||||
cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl;
|
||||
#endif
|
||||
|
||||
if (phi_N(m) < phi_m_bound)
|
||||
{
|
||||
int old_m = m;
|
||||
(void) old_m;
|
||||
m = 2 << int(ceil(log2(phi_m_bound)));
|
||||
|
||||
#ifdef VERBOSE
|
||||
cout << "m = " << old_m << " too small, increasing it to " << m << endl;
|
||||
#endif
|
||||
|
||||
generate_prime(p, numBits(p), m);
|
||||
}
|
||||
else
|
||||
@@ -249,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p,
|
||||
void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
|
||||
const string& i, const bigint& pr0)
|
||||
{
|
||||
(void) i;
|
||||
|
||||
if (lg2pr==0) { throw invalid_params(); }
|
||||
|
||||
bigint step=m;
|
||||
@@ -265,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
|
||||
assert(numBits(pr) == lg2pr);
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl;
|
||||
cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl;
|
||||
#endif
|
||||
|
||||
assert(pr % m == 1);
|
||||
assert(pr % p == 1);
|
||||
assert(numBits(pr) == lg2pr);
|
||||
|
||||
cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -345,10 +364,12 @@ ZZX Cyclotomic(int N)
|
||||
return F;
|
||||
}
|
||||
#else
|
||||
// simplified version powers of two
|
||||
int phi_N(int N)
|
||||
{
|
||||
if (((N - 1) & N) != 0)
|
||||
throw runtime_error("compile with NTL support");
|
||||
throw runtime_error(
|
||||
"compile with NTL support (USE_NTL=1 in CONFIG.mine)");
|
||||
else if (N == 1)
|
||||
return 1;
|
||||
else
|
||||
@@ -398,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly)
|
||||
for (int i=0; i<Rg.phim+1; i++)
|
||||
{ Rg.poly[i]=to_int(coeff(P,i)); }
|
||||
#else
|
||||
throw runtime_error("compile with NTL support");
|
||||
throw runtime_error(
|
||||
"compile with NTL support (USE_NTL=1 in CONFIG.mine)");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -433,6 +455,16 @@ GF2X Subs_PowX_Mod(const GF2X& a,int pow,int m,const GF2X& c)
|
||||
|
||||
|
||||
|
||||
GF2X get_F(const Ring& Rg)
|
||||
{
|
||||
GF2X F;
|
||||
for (int i=0; i<=Rg.phi_m(); i++)
|
||||
{ if (((Rg.Phi()[i])%2)!=0)
|
||||
{ SetCoeff(F,i,1); }
|
||||
}
|
||||
//cout << "F = " << F << endl;
|
||||
return F;
|
||||
}
|
||||
|
||||
void init(P2Data& P2D,const Ring& Rg)
|
||||
{
|
||||
@@ -443,16 +475,12 @@ void init(P2Data& P2D,const Ring& Rg)
|
||||
{ SetCoeff(G,gf2n_short::get_t(i),1); }
|
||||
//cout << "G = " << G << endl;
|
||||
|
||||
for (int i=0; i<=Rg.phi_m(); i++)
|
||||
{ if (((Rg.Phi()[i])%2)!=0)
|
||||
{ SetCoeff(F,i,1); }
|
||||
}
|
||||
//cout << "F = " << F << endl;
|
||||
F = get_F(Rg);
|
||||
|
||||
// seed randomness to achieve same result for all players
|
||||
// randomness is used in SFCanZass and FindRoot
|
||||
SetSeed(ZZ(0));
|
||||
|
||||
|
||||
// Now factor F modulo 2
|
||||
vec_GF2X facts=SFCanZass(F);
|
||||
|
||||
@@ -464,17 +492,34 @@ void init(P2Data& P2D,const Ring& Rg)
|
||||
// Compute the quotient group
|
||||
QGroup QGrp;
|
||||
int Gord=-1,e=Rg.phi_m()/d; // e = # of plaintext slots, phi(m)/degree
|
||||
int seed=1;
|
||||
while (Gord!=e)
|
||||
|
||||
if ((e*gf2n_short::degree())!=Rg.phi_m())
|
||||
{ cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl;
|
||||
cout << e << " * " << gf2n_short::degree() << " != " << Rg.phi_m() << endl;
|
||||
throw invalid_params();
|
||||
}
|
||||
|
||||
int max_tries = 10;
|
||||
for (int seed = 0;; seed++)
|
||||
{ QGrp.assign(Rg.m(),seed); // QGrp encodes the the quotient group Z_m^*/<2>
|
||||
Gord=QGrp.order();
|
||||
if (Gord!=e) { cout << "Group order wrong, need to repeat the Haf-Mc algorithm" << endl; seed++; }
|
||||
Gord = QGrp.order();
|
||||
if (Gord == e)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (seed == max_tries)
|
||||
{
|
||||
cerr << "abort after " << max_tries << " tries" << endl;
|
||||
throw invalid_params();
|
||||
}
|
||||
else
|
||||
cout << "Group order wrong, need to repeat the Haf-Mc algorithm"
|
||||
<< endl;
|
||||
}
|
||||
}
|
||||
//cout << " l = " << Gord << " , d = " << d << endl;
|
||||
if ((Gord*gf2n_short::degree())!=Rg.phi_m())
|
||||
{ cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl;
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
vector<GF2X> Fi(Gord);
|
||||
vector<GF2X> Rts(Gord);
|
||||
@@ -595,6 +640,27 @@ void char_2_dimension(int& m, int& lg2)
|
||||
m=5797;
|
||||
lg2=40;
|
||||
break;
|
||||
case 64:
|
||||
m = 9615;
|
||||
break;
|
||||
case 63:
|
||||
m = 9271;
|
||||
break;
|
||||
case 28:
|
||||
m = 3277;
|
||||
break;
|
||||
case 16:
|
||||
m = 4369;
|
||||
break;
|
||||
case 15:
|
||||
m = 4681;
|
||||
break;
|
||||
case 12:
|
||||
m = 4095;
|
||||
break;
|
||||
case 11:
|
||||
m = 2047;
|
||||
break;
|
||||
default:
|
||||
throw runtime_error("field size not supported");
|
||||
break;
|
||||
@@ -630,7 +696,7 @@ void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D)
|
||||
finalize_lengths(lg2p0, lg2p1, n, m, lg2pi[0], round_up, params);
|
||||
}
|
||||
|
||||
if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) > phi_N(m))
|
||||
if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) * 2 > m)
|
||||
throw runtime_error("number of slots too small");
|
||||
|
||||
cout << "m = " << m << endl;
|
||||
@@ -676,135 +742,3 @@ void load_or_generate(P2Data& P2D, const Ring& R)
|
||||
P2D.store(R);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#ifdef USE_NTL
|
||||
/*
|
||||
* Create FHE parameters for a general plaintext modulus p
|
||||
* Basically this is for general large primes only
|
||||
*/
|
||||
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params)
|
||||
{
|
||||
cout << "Setting up parameters" << endl;
|
||||
|
||||
int lgp=numBits(p);
|
||||
int mm,idx;
|
||||
|
||||
// mm is the minimum value of m we will accept
|
||||
if (lgp<48)
|
||||
{ mm=100; // Test case
|
||||
idx=0;
|
||||
}
|
||||
else if (lgp <96)
|
||||
{ mm=8192;
|
||||
idx=1;
|
||||
}
|
||||
else if (lgp<192)
|
||||
{ mm=16384;
|
||||
idx=2;
|
||||
}
|
||||
else if (lgp<384)
|
||||
{ mm=16384;
|
||||
idx=3;
|
||||
}
|
||||
else if (lgp<768)
|
||||
{ mm=32768;
|
||||
idx=4;
|
||||
}
|
||||
else
|
||||
{ throw invalid_params(); }
|
||||
|
||||
// Now find the small factors of p-1 and their exponents
|
||||
bigint t=p-1;
|
||||
vector<long> primes(100),exp(100);
|
||||
|
||||
PrimeSeq s;
|
||||
long pr;
|
||||
pr=s.next();
|
||||
int len=0;
|
||||
while (pr<2*mm)
|
||||
{ int e=0;
|
||||
while ((t%pr)==0)
|
||||
{ e++;
|
||||
t=t/pr;
|
||||
}
|
||||
if (e!=0)
|
||||
{ primes[len]=pr;
|
||||
exp[len]=e;
|
||||
if (len!=0) { cout << " * "; }
|
||||
cout << pr << "^" << e << flush;
|
||||
len++;
|
||||
}
|
||||
pr=s.next();
|
||||
}
|
||||
cout << endl;
|
||||
|
||||
// We want to find the best m which divides pr-1, such that
|
||||
// - 2*m > phi(m) > mm
|
||||
// - m has the smallest number of factors
|
||||
vector<int> ee;
|
||||
ee.resize(len);
|
||||
for (int i=0; i<len; i++) { ee[i]=0; }
|
||||
int min_hwt=-1,m=-1,bphi_m=-1,bmx=-1;
|
||||
bool flag=true;
|
||||
while (flag)
|
||||
{ int cand_m=1,hwt=0,mx=0;
|
||||
for (int i=0; i<len; i++)
|
||||
{ //cout << ee[i] << " ";
|
||||
if (ee[i]!=0)
|
||||
{ hwt++;
|
||||
for (int j=0; j<ee[i]; j++)
|
||||
{ cand_m*=primes[i]; }
|
||||
if (ee[i]>mx) { mx=ee[i]; }
|
||||
}
|
||||
}
|
||||
// Put "if" here to stop searching for things which will never work
|
||||
if (cand_m>1 && cand_m<4*mm)
|
||||
{ //cout << " : " << cand_m << " : " << hwt << flush;
|
||||
int phim=phi_N(cand_m);
|
||||
//cout << " : " << phim << " : " << mm << endl;
|
||||
if (phim>mm && phim<3*mm)
|
||||
{ if (m==-1 || hwt<min_hwt || (hwt==min_hwt && mx<bmx) || (hwt==min_hwt && mx==bmx && phim<bphi_m))
|
||||
{ m=cand_m;
|
||||
min_hwt=hwt;
|
||||
bphi_m=phim;
|
||||
bmx=mx;
|
||||
//cout << "\t OK" << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{ //cout << endl;
|
||||
}
|
||||
int i=0;
|
||||
ee[i]=ee[i]+1;
|
||||
while (ee[i]>exp[i] && flag)
|
||||
{ ee[i]=0;
|
||||
i++;
|
||||
if (i==len) { flag=false; i=0; }
|
||||
else { ee[i]=ee[i]+1; }
|
||||
}
|
||||
}
|
||||
if (m==-1)
|
||||
{ throw bad_value(); }
|
||||
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
|
||||
|
||||
Parameters parameters(n, lgp, sec);
|
||||
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p,params);
|
||||
int mx=0;
|
||||
for (int i=0; i<R.phi_m(); i++)
|
||||
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
|
||||
cout << "Max Coeff = " << mx << endl;
|
||||
|
||||
init(R, m, true);
|
||||
|
||||
Zp_Data Zp(p);
|
||||
PPD.init(R,Zp);
|
||||
gfp::init_field(p);
|
||||
|
||||
pr0 = parameters.pr0;
|
||||
pr1 = parameters.pr1;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
#ifndef _NTL_Subs
|
||||
#define _NTL_Subs
|
||||
|
||||
/* All these routines use NTL on the inside */
|
||||
|
||||
#include "FHE/Ring.h"
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/FHE_Params.h"
|
||||
|
||||
/* Routines to set up key sizes given the number of players n
|
||||
@@ -47,26 +44,28 @@ public:
|
||||
|
||||
};
|
||||
|
||||
// Main setup routine (need NTL if online_only is false)
|
||||
// Main setup routine
|
||||
void generate_setup(int nparties, int lgp, int lg2,
|
||||
int sec, bool skip_2 = false, int slack = 0, bool round_up = false);
|
||||
|
||||
// semi-homomorphic, includes slack
|
||||
template <class FD>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, FD& FieldD, bool round_up);
|
||||
FHE_Params& params, FD& FieldD, bool round_up, int n = 1);
|
||||
|
||||
// field-independent semi-homomorphic setup
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1,
|
||||
bool round_up);
|
||||
|
||||
// Everything else needs NTL
|
||||
void init(Ring& Rg, int m, bool generate_poly);
|
||||
void init(P2Data& P2D,const Ring& Rg);
|
||||
|
||||
// For use when we want p to be a specific value
|
||||
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params);
|
||||
namespace NTL
|
||||
{
|
||||
class GF2X;
|
||||
}
|
||||
|
||||
NTL::GF2X get_F(const Ring& Rg);
|
||||
|
||||
// generate moduli according to lengths and other parameters
|
||||
void generate_moduli(bigint& pr0, bigint& pr1, const int m,
|
||||
|
||||
@@ -36,10 +36,12 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s);
|
||||
// unify parameters by taking maximum over TopGear or not
|
||||
bigint B_clean_top_gear = B_clean * 2;
|
||||
bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.));
|
||||
bigint B_clean_not_top_gear = B_clean << max(slack - sec, 0);
|
||||
B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
|
||||
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
|
||||
int matrix_dim = params.get_matrix_dim();
|
||||
#ifdef NOISY
|
||||
cout << "phi(m): " << phi_m << endl;
|
||||
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
|
||||
cout << "V_s: " << V_s << endl;
|
||||
cout << "c1: " << c1 << endl;
|
||||
@@ -48,9 +50,14 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
cout << "log(slack): " << slack << endl;
|
||||
cout << "B_clean: " << B_clean << endl;
|
||||
cout << "B_scale: " << B_scale << endl;
|
||||
cout << "matrix dimension: " << matrix_dim << endl;
|
||||
cout << "drown sec: " << params.secp() << endl;
|
||||
cout << "sec: " << sec << endl;
|
||||
#endif
|
||||
|
||||
drown = 1 + n * (bigint(1) << sec);
|
||||
assert(matrix_dim > 0);
|
||||
assert(params.secp() >= 0);
|
||||
drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp());
|
||||
}
|
||||
|
||||
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
|
||||
@@ -68,8 +75,14 @@ double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma)
|
||||
{
|
||||
if (sigma <= 0)
|
||||
sigma = FHE_Params().get_R();
|
||||
// the constant was updated using Martin Albrecht's LWE estimator in Sep 2019
|
||||
return 37.8 * (log_q - log2(sigma));
|
||||
// the constant was updated using Martin Albrecht's LWE estimator in Mar 2022
|
||||
// found the following pairs for 128-bit security
|
||||
// and alpha = 0.7 * sqrt(2*pi) / q
|
||||
// m = 2048, log_2(q) = 68
|
||||
// m = 4096, log_2(q) = 138
|
||||
// m = 8192, log_2(q) = 302
|
||||
// m = 16384, log_2(q) = 560
|
||||
return 15.1 * log_q;
|
||||
}
|
||||
|
||||
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, const FHE_Params& params)
|
||||
@@ -92,7 +105,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants()
|
||||
{
|
||||
tp *= t;
|
||||
double lgtp = log(tp) / log(2.0);
|
||||
if (C[i] < 0 && lgtp < FHE_epsilon)
|
||||
if (C[i] < 0 && lgtp < -FHE_epsilon)
|
||||
{
|
||||
C[i] = pow(x, i);
|
||||
}
|
||||
@@ -114,7 +127,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
|
||||
cout << "n: " << n << endl;
|
||||
cout << "sec: " << sec << endl;
|
||||
cout << "sigma: " << this->sigma << endl;
|
||||
cout << "h: " << h << endl;
|
||||
cout << "B_clean size: " << numBits(B_clean) << endl;
|
||||
cout << "B_scale size: " << numBits(B_scale) << endl;
|
||||
cout << "B_KS size: " << numBits(B_KS) << endl;
|
||||
@@ -155,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1)
|
||||
|
||||
bigint NoiseBounds::min_p1()
|
||||
{
|
||||
return drown * B_KS + 1;
|
||||
return max(bigint(drown * B_KS), bigint((phi_m * p) << 10));
|
||||
}
|
||||
|
||||
bigint NoiseBounds::opt_p1()
|
||||
@@ -169,8 +181,10 @@ bigint NoiseBounds::opt_p1()
|
||||
// solve
|
||||
mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a);
|
||||
bigint res = ceil(s);
|
||||
#ifdef VERBOSE
|
||||
cout << "Optimal p1 vs minimal: " << numBits(res) << "/"
|
||||
<< numBits(min_p1()) << endl;
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -182,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1)
|
||||
{
|
||||
min_p0 *= 2;
|
||||
min_p1 *= 2;
|
||||
#ifdef VERBOSE
|
||||
cout << "increasing lengths: " << numBits(min_p0) << "/"
|
||||
<< numBits(min_p1) << endl;
|
||||
#endif
|
||||
}
|
||||
lg2p1 = numBits(min_p1);
|
||||
lg2p0 = numBits(min_p0);
|
||||
|
||||
@@ -42,6 +42,8 @@ public:
|
||||
bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); }
|
||||
static double min_phi_m(int log_q, double sigma);
|
||||
static double min_phi_m(int log_q, const FHE_Params& params);
|
||||
|
||||
bigint get_B_clean() { return B_clean; }
|
||||
};
|
||||
|
||||
// as per ePrint 2012:642 for slack = 0
|
||||
|
||||
@@ -55,13 +55,13 @@ void P2Data::check_dimensions() const
|
||||
// cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl;
|
||||
if (A.size() != Ai.size())
|
||||
throw runtime_error("forward and backward mapping dimensions mismatch");
|
||||
if (A.size() != A[0].size())
|
||||
if (A.size() != A.at(0).size())
|
||||
throw runtime_error("forward mapping not square");
|
||||
if (Ai.size() != Ai[0].size())
|
||||
if (Ai.size() != Ai.at(0).size())
|
||||
throw runtime_error("backward mapping not square");
|
||||
if ((int)A[0].size() != slots * gf2n_short::degree())
|
||||
if ((int)A.at(0).size() != slots * gf2n_short::degree())
|
||||
throw runtime_error(
|
||||
"mapping dimension incorrect: " + to_string(A[0].size())
|
||||
"mapping dimension incorrect: " + to_string(A.at(0).size())
|
||||
+ " != " + to_string(slots) + " * "
|
||||
+ to_string(gf2n_short::degree()));
|
||||
}
|
||||
|
||||
100
FHE/PPData.cpp
100
FHE/PPData.cpp
@@ -1,100 +0,0 @@
|
||||
#include "FHE/Subroutines.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/FFT.h"
|
||||
#include "FHE/Matrix.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
|
||||
|
||||
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
{
|
||||
R=Rg;
|
||||
prData=PrD;
|
||||
|
||||
root=Find_Primitive_Root_m(Rg.m(),Rg.Phi(),PrD);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void PPData::to_eval(vector<modp>& elem) const
|
||||
{
|
||||
if (elem.size()!= (unsigned) R.phi_m())
|
||||
{ throw params_mismatch(); }
|
||||
|
||||
throw not_implemented();
|
||||
|
||||
/*
|
||||
vector<modp> ans;
|
||||
ans.resize(R.phi_m());
|
||||
modp x=root;
|
||||
for (int i=0; i<R.phi_m(); i++)
|
||||
{ ans[i]=elem[R.phi_m()-1];
|
||||
for (int j=1; j<R.phi_m(); j++)
|
||||
{ Mul(ans[i],ans[i],x,prData);
|
||||
Add(ans[i],ans[i],elem[R.phi_m()-j-1],prData);
|
||||
}
|
||||
Mul(x,x,root,prData);
|
||||
}
|
||||
elem=ans;
|
||||
*/
|
||||
}
|
||||
|
||||
void PPData::from_eval(vector<modp>& elem) const
|
||||
{
|
||||
// avoid warning
|
||||
elem.empty();
|
||||
throw not_implemented();
|
||||
|
||||
/*
|
||||
modp_matrix A;
|
||||
int n=phi_m();
|
||||
A.resize(n, vector<modp>(n+1) );
|
||||
modp x=root;
|
||||
for (int i=0; i<n; i++)
|
||||
{ assignOne(A[0][i],prData);
|
||||
for (int j=1; j<n; j++)
|
||||
{ Mul(A[j][i],A[j-1][i],x,prData); }
|
||||
Mul(x,x,root,prData);
|
||||
A[i][n]=elem[i];
|
||||
}
|
||||
elem=solve(A,prData);
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
|
||||
void PPData::reset_iteration()
|
||||
{
|
||||
pow = 1;
|
||||
theta = {root, prData};
|
||||
thetaPow = theta;
|
||||
}
|
||||
|
||||
void PPData::next_iteration()
|
||||
{
|
||||
do
|
||||
{ thetaPow *= (theta);
|
||||
pow++;
|
||||
}
|
||||
while (gcd(pow,m())!=1);
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
gfp PPData::get_evaluation(const vector<bigint>& mess) const
|
||||
{
|
||||
// Uses Horner's rule
|
||||
gfp ans;
|
||||
ans = mess[mess.size()-1];
|
||||
gfp coeff;
|
||||
for (int j=mess.size()-2; j>=0; j--)
|
||||
{ ans *= (thetaPow);
|
||||
coeff = mess[j];
|
||||
ans += (coeff);
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
61
FHE/PPData.h
61
FHE/PPData.h
@@ -1,61 +0,0 @@
|
||||
#ifndef _PPData
|
||||
#define _PPData
|
||||
|
||||
#include "Math/modp.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
#include "Math/gfpvar.h"
|
||||
#include "Math/fixint.h"
|
||||
#include "FHE/Ring.h"
|
||||
#include "FHE/FFT_Data.h"
|
||||
|
||||
/* Class for holding modular arithmetic data wrt the ring
|
||||
*
|
||||
* It also holds the ring
|
||||
*/
|
||||
|
||||
class PPData
|
||||
{
|
||||
public:
|
||||
typedef gfp T;
|
||||
typedef bigint S;
|
||||
typedef typename FFT_Data::poly_type poly_type;
|
||||
|
||||
Ring R;
|
||||
Zp_Data prData;
|
||||
|
||||
modp root; // m'th Root of Unity mod pr
|
||||
|
||||
void init(const Ring& Rg,const Zp_Data& PrD);
|
||||
|
||||
PPData() { ; }
|
||||
PPData(const Ring& Rg,const Zp_Data& PrD)
|
||||
{ init(Rg,PrD); }
|
||||
|
||||
const Zp_Data& get_prD() const { return prData; }
|
||||
const bigint& get_prime() const { return prData.pr; }
|
||||
int phi_m() const { return R.phi_m(); }
|
||||
int m() const { return R.m(); }
|
||||
int num_slots() const { return R.phi_m(); }
|
||||
|
||||
|
||||
int p(int i) const { return R.p(i); }
|
||||
int p_inv(int i) const { return R.p_inv(i); }
|
||||
const vector<int>& Phi() const { return R.Phi(); }
|
||||
|
||||
// Convert input vector from poly to evaluation representation
|
||||
// - Uses naive method and not FFT, we only use this rarely in any case
|
||||
void to_eval(vector<modp>& elem) const;
|
||||
void from_eval(vector<modp>& elem) const;
|
||||
|
||||
// Following are used to iteratively get slots, as we use PPData when
|
||||
// we do not have an efficient FFT algorithm
|
||||
gfp thetaPow,theta;
|
||||
int pow;
|
||||
void reset_iteration();
|
||||
void next_iteration();
|
||||
gfp get_evaluation(const vector<bigint>& mess) const;
|
||||
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
|
||||
#include "FHE/Plaintext.h"
|
||||
#include "FHE/Ring_Element.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHE/Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
@@ -11,10 +10,43 @@
|
||||
|
||||
|
||||
|
||||
template<class T, class FD, class S>
|
||||
Plaintext<T, FD, S>::Plaintext(const FHE_Params& params) :
|
||||
Plaintext(params.get_plaintext_field_data<FD>(), Both)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
template<class T, class FD, class S>
|
||||
unsigned int Plaintext<T, FD, S>::num_slots() const
|
||||
{
|
||||
return (*Field_Data).phi_m();
|
||||
}
|
||||
|
||||
template<class T, class FD, class S>
|
||||
int Plaintext<T, FD, S>::degree() const
|
||||
{
|
||||
return (*Field_Data).phi_m();
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
unsigned int Plaintext<gf2n_short,P2Data,int>::num_slots() const
|
||||
{
|
||||
return (*Field_Data).num_slots();
|
||||
}
|
||||
|
||||
template<>
|
||||
int Plaintext<gf2n_short,P2Data,int>::degree() const
|
||||
{
|
||||
return (*Field_Data).degree();
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp, FFT_Data, bigint>::from(const Generator<bigint>& source) const
|
||||
{
|
||||
b.resize(degree);
|
||||
b.resize(degree());
|
||||
for (auto& x : b)
|
||||
{
|
||||
source.get(bigint::tmp);
|
||||
@@ -31,7 +63,7 @@ void Plaintext<gfp,FFT_Data,bigint>::from_poly() const
|
||||
Ring_Element e(*Field_Data,polynomial);
|
||||
e.from(b);
|
||||
e.change_rep(evaluation);
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
a[i] = gfp(e.get_element(i), e.get_FFTD().get_prD());
|
||||
type=Both;
|
||||
@@ -52,45 +84,12 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::from_poly() const
|
||||
{
|
||||
if (type!=Polynomial) { return; }
|
||||
vector<modp> aa((*Field_Data).phi_m());
|
||||
for (unsigned int i=0; i<aa.size(); i++)
|
||||
{ to_modp(aa[i], bigint::tmp = b[i], (*Field_Data).prData); }
|
||||
(*Field_Data).to_eval(aa);
|
||||
a.resize(n_slots);
|
||||
for (unsigned int i=0; i<aa.size(); i++)
|
||||
a[i] = {aa[i], Field_Data->get_prD()};
|
||||
type=Both;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::to_poly() const
|
||||
{
|
||||
if (type!=Evaluation) { return; }
|
||||
cout << "This is VERY inefficient to convert a plaintext to poly representation" << endl;
|
||||
vector<modp> bb((*Field_Data).phi_m());
|
||||
for (unsigned int i=0; i<bb.size(); i++)
|
||||
{ bb[i]=a[i].get(); }
|
||||
(*Field_Data).from_eval(bb);
|
||||
for (unsigned int i=0; i<bb.size(); i++)
|
||||
{
|
||||
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
|
||||
b[i] = bigint::tmp;
|
||||
}
|
||||
type=Both;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gf2n_short,P2Data,int>::from_poly() const
|
||||
{
|
||||
if (type!=Polynomial) { return; }
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
(*Field_Data).backward(a,b);
|
||||
type=Both;
|
||||
}
|
||||
@@ -106,34 +105,13 @@ void Plaintext<gf2n_short,P2Data,int>::to_poly() const
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,FFT_Data,bigint>::set_sizes()
|
||||
{ n_slots = (*Field_Data).phi_m();
|
||||
degree = n_slots;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::set_sizes()
|
||||
{ n_slots = (*Field_Data).phi_m();
|
||||
degree = n_slots;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gf2n_short,P2Data,int>::set_sizes()
|
||||
{ n_slots = (*Field_Data).num_slots();
|
||||
degree = (*Field_Data).degree();
|
||||
}
|
||||
|
||||
|
||||
template<class T, class FD, class S>
|
||||
void Plaintext<T, FD, S>::allocate(PT_Type type) const
|
||||
{
|
||||
if (type != Evaluation)
|
||||
b.resize(degree);
|
||||
b.resize(degree());
|
||||
if (type != Polynomial)
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
this->type = type;
|
||||
}
|
||||
|
||||
@@ -141,7 +119,7 @@ void Plaintext<T, FD, S>::allocate(PT_Type type) const
|
||||
template<class T, class FD, class S>
|
||||
void Plaintext<T, FD, S>::allocate_slots(const bigint& value)
|
||||
{
|
||||
b.resize(degree);
|
||||
b.resize(degree());
|
||||
for (auto& x : b)
|
||||
x.allocate_slots(value);
|
||||
}
|
||||
@@ -236,7 +214,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
|
||||
type=Polynomial;
|
||||
break;
|
||||
case Diagonal:
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
a[0].randomize(G);
|
||||
for (unsigned int i=1; i<a.size(); i++)
|
||||
{ a[i]=a[0]; }
|
||||
@@ -244,7 +222,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
|
||||
break;
|
||||
default:
|
||||
// Gen a plaintext with 0/1 in each slot
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{
|
||||
if (G.get_bit())
|
||||
@@ -272,7 +250,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t)
|
||||
b[0].generateUniform(G, n_bits, false);
|
||||
}
|
||||
else
|
||||
for (int i = 0; i < n_slots; i++)
|
||||
for (size_t i = 0; i < num_slots(); i++)
|
||||
b[i].generateUniform(G, n_bits, false);
|
||||
break;
|
||||
default:
|
||||
@@ -288,7 +266,7 @@ void Plaintext<T,FD,S>::assign_zero(PT_Type t)
|
||||
allocate();
|
||||
if (type!=Polynomial)
|
||||
{
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ a[i].assign_zero(); }
|
||||
}
|
||||
@@ -306,7 +284,7 @@ void Plaintext<T,FD,S>::assign_one(PT_Type t)
|
||||
allocate();
|
||||
if (type!=Polynomial)
|
||||
{
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ a[i].assign_one(); }
|
||||
}
|
||||
@@ -359,35 +337,7 @@ void add(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
|
||||
z.allocate();
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.n_slots);
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i] = (x.a[i] + y.a[i]); }
|
||||
}
|
||||
if (z.type!=Evaluation)
|
||||
{ for (unsigned int i=0; i<z.b.size(); i++)
|
||||
{ z.b[i]=x.b[i]+y.b[i];
|
||||
if (z.b[i]>(*z.Field_Data).get_prime())
|
||||
{ z.b[i]-=(*z.Field_Data).get_prime(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void add(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
|
||||
const Plaintext<gfp,PPData,bigint>& y)
|
||||
{
|
||||
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
|
||||
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
|
||||
|
||||
if (x.type==Both && y.type!=Both) { z.type=y.type; }
|
||||
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
|
||||
else if (x.type!=y.type) { throw rep_mismatch(); }
|
||||
else { z.type=x.type; }
|
||||
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.n_slots);
|
||||
z.a.resize(z.num_slots());
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i] = (x.a[i] + y.a[i]); }
|
||||
}
|
||||
@@ -418,7 +368,7 @@ void add(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,i
|
||||
z.allocate();
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.n_slots);
|
||||
z.a.resize(z.num_slots());
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i].add(x.a[i],y.a[i]); }
|
||||
}
|
||||
@@ -446,7 +396,7 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
|
||||
z.allocate();
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.n_slots);
|
||||
z.a.resize(z.num_slots());
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i]= (x.a[i] - y.a[i]); }
|
||||
}
|
||||
@@ -463,36 +413,6 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void sub(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
|
||||
const Plaintext<gfp,PPData,bigint>& y)
|
||||
{
|
||||
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
|
||||
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
|
||||
|
||||
if (x.type==Both && y.type!=Both) { z.type=y.type; }
|
||||
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
|
||||
else if (x.type!=y.type) { throw rep_mismatch(); }
|
||||
else { z.type=x.type; }
|
||||
|
||||
z.allocate();
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.n_slots);
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i] = (x.a[i] - y.a[i]); }
|
||||
}
|
||||
if (z.type!=Evaluation)
|
||||
{ for (unsigned int i=0; i<z.b.size(); i++)
|
||||
{ z.b[i]=x.b[i]-y.b[i];
|
||||
if (z.b[i]<0)
|
||||
{ z.b[i]+=(*z.Field_Data).get_prime(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
@@ -510,7 +430,7 @@ void sub(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,i
|
||||
z.allocate();
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.n_slots);
|
||||
z.a.resize(z.num_slots());
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i].sub(x.a[i],y.a[i]); }
|
||||
}
|
||||
@@ -545,7 +465,7 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
|
||||
{
|
||||
if (type!=Polynomial)
|
||||
{
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ a[i].negate(); }
|
||||
}
|
||||
@@ -560,23 +480,6 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::negate()
|
||||
{
|
||||
if (type!=Polynomial)
|
||||
{
|
||||
a.resize(n_slots);
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ a[i].negate(); }
|
||||
}
|
||||
if (type!=Evaluation)
|
||||
{ for (unsigned int i=0; i<b.size(); i++)
|
||||
{ if (b[i]!=0)
|
||||
{ b[i]=(*Field_Data).get_prime()-b[i]; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
@@ -607,7 +510,7 @@ bool Plaintext<T,FD,S>::equals(const Plaintext& x) const
|
||||
|
||||
if (type!=Polynomial and x.type!=Polynomial)
|
||||
{
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ if (!(a[i] == x.a[i])) { return false; } }
|
||||
}
|
||||
@@ -671,9 +574,9 @@ void Plaintext<T,FD,S>::unpack(octetStream& o)
|
||||
unsigned int size;
|
||||
o.get(size);
|
||||
allocate();
|
||||
if (size != b.size())
|
||||
if (size != b.size() and size != 0)
|
||||
throw length_error("unexpected length received");
|
||||
for (unsigned int i = 0; i < b.size(); i++)
|
||||
for (unsigned int i = 0; i < size; i++)
|
||||
b[i] = o.get<S>();
|
||||
}
|
||||
|
||||
@@ -719,12 +622,6 @@ template void mul(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data
|
||||
|
||||
|
||||
|
||||
template class Plaintext<gfp,PPData,bigint>;
|
||||
|
||||
template void mul(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,const Plaintext<gfp,PPData,bigint>& y);
|
||||
|
||||
|
||||
|
||||
template class Plaintext<gf2n_short,P2Data,int>;
|
||||
|
||||
template void mul(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,int>& x,const Plaintext<gf2n_short,P2Data,int>& y);
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "FHE/Generator.h"
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "Math/fixint.h"
|
||||
|
||||
#include <vector>
|
||||
@@ -25,6 +26,8 @@ using namespace std;
|
||||
|
||||
class FHE_PK;
|
||||
class Rq_Element;
|
||||
class FHE_Params;
|
||||
class FFT_Data;
|
||||
template<class T> class AddableVector;
|
||||
|
||||
// Forward declaration as apparently this is needed for friends in templates
|
||||
@@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits };
|
||||
|
||||
enum PT_Type { Polynomial, Evaluation, Both };
|
||||
|
||||
/**
|
||||
* BGV plaintext.
|
||||
* Use ``Plaintext_mod_prime`` instead of filling in the templates.
|
||||
* The plaintext is held in one of the two representations or both,
|
||||
* polynomial and evaluation. The latter is the one allowing element-wise
|
||||
* multiplication over a vector.
|
||||
* Plaintexts can be added, subtracted, and multiplied via operator overloading.
|
||||
*/
|
||||
template<class T,class FD,class _>
|
||||
class Plaintext
|
||||
{
|
||||
typedef typename FD::poly_type S;
|
||||
|
||||
int n_slots;
|
||||
int degree;
|
||||
|
||||
mutable vector<T> a; // The thing in evaluation/FFT form
|
||||
mutable vector<S> b; // Now in polynomial form
|
||||
@@ -58,33 +67,47 @@ class Plaintext
|
||||
|
||||
const FD *Field_Data;
|
||||
|
||||
void set_sizes();
|
||||
int degree() const;
|
||||
|
||||
public:
|
||||
|
||||
const FD& get_field() const { return *Field_Data; }
|
||||
unsigned int num_slots() const { return n_slots; }
|
||||
|
||||
/// Number of slots
|
||||
unsigned int num_slots() const;
|
||||
|
||||
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
|
||||
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
|
||||
{ Field_Data=&FieldD; allocate(type); }
|
||||
|
||||
Plaintext(const FD& FieldD, const Rq_Element& other);
|
||||
|
||||
/// Initialization
|
||||
Plaintext(const FHE_Params& params);
|
||||
|
||||
void allocate(PT_Type type) const;
|
||||
void allocate() const { allocate(type); }
|
||||
void allocate_slots(const bigint& value);
|
||||
int get_min_alloc();
|
||||
|
||||
// Access evaluation representation
|
||||
/**
|
||||
* Read slot.
|
||||
* @param i slot number
|
||||
* @returns slot content
|
||||
*/
|
||||
T element(int i) const
|
||||
{ if (type==Polynomial)
|
||||
{ from_poly(); }
|
||||
return a[i];
|
||||
}
|
||||
/**
|
||||
* Write to slot
|
||||
* @param i slot number
|
||||
* @param e new slot content
|
||||
*/
|
||||
void set_element(int i,const T& e)
|
||||
{ if (type==Polynomial)
|
||||
{ throw not_implemented(); }
|
||||
a.resize(n_slots);
|
||||
a.resize(num_slots());
|
||||
a[i]=e;
|
||||
type=Evaluation;
|
||||
}
|
||||
@@ -171,10 +194,10 @@ class Plaintext
|
||||
|
||||
bool is_diagonal() const;
|
||||
|
||||
/* Pack and unpack into an octetStream
|
||||
* For unpack we assume the FFTD has been assigned correctly already
|
||||
*/
|
||||
/// Append to buffer
|
||||
void pack(octetStream& o) const;
|
||||
|
||||
/// Read from buffer. Assumes parameters are set correctly
|
||||
void unpack(octetStream& o);
|
||||
|
||||
size_t report_size(ReportType type);
|
||||
@@ -185,4 +208,6 @@ class Plaintext
|
||||
template <class FD>
|
||||
using Plaintext_ = Plaintext<typename FD::T, FD, typename FD::S>;
|
||||
|
||||
typedef Plaintext_<FFT_Data> Plaintext_mod_prime;
|
||||
|
||||
#endif
|
||||
|
||||
@@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o)
|
||||
o.get(pi_inv);
|
||||
o.get(poly);
|
||||
}
|
||||
else
|
||||
else if (mm != 0)
|
||||
init(*this, mm);
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ void Ring_Element::prepare_push()
|
||||
|
||||
void Ring_Element::allocate()
|
||||
{
|
||||
assert(FFTD);
|
||||
element.resize(FFTD->phi_m());
|
||||
}
|
||||
|
||||
@@ -86,7 +87,6 @@ void Ring_Element::negate()
|
||||
|
||||
void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
if (a.element.empty())
|
||||
{
|
||||
@@ -99,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
return;
|
||||
}
|
||||
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
|
||||
if (&ans == &a)
|
||||
{
|
||||
ans += b;
|
||||
@@ -401,19 +403,29 @@ void Ring_Element::change_rep(RepType r)
|
||||
|
||||
bool Ring_Element::equals(const Ring_Element& a) const
|
||||
{
|
||||
if (element.empty() and a.element.empty())
|
||||
return true;
|
||||
else if (element.empty() or a.element.empty())
|
||||
throw not_implemented();
|
||||
|
||||
if (rep!=a.rep) { throw rep_mismatch(); }
|
||||
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
|
||||
|
||||
if (is_zero() or a.is_zero())
|
||||
return is_zero() and a.is_zero();
|
||||
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } }
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool Ring_Element::is_zero() const
|
||||
{
|
||||
if (element.empty())
|
||||
return true;
|
||||
for (auto& x : element)
|
||||
if (not ::isZero(x, FFTD->get_prD()))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
ConversionIterator Ring_Element::get_iterator() const
|
||||
{
|
||||
if (rep != polynomial)
|
||||
@@ -560,6 +572,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const
|
||||
{
|
||||
if (&FFTD != this->FFTD)
|
||||
throw params_mismatch();
|
||||
if (is_zero())
|
||||
throw runtime_error("element is zero");
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -95,6 +95,7 @@ class Ring_Element
|
||||
void randomize(PRNG& G,bool Diag=false);
|
||||
|
||||
bool equals(const Ring_Element& a) const;
|
||||
bool is_zero() const;
|
||||
|
||||
// This is a NOP in cases where we cannot do a FFT
|
||||
void change_rep(RepType r);
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
Rq_Element::Rq_Element(const FHE_PK& pk) :
|
||||
Rq_Element(pk.get_params().FFTD())
|
||||
Rq_Element(pk.get_params().FFTD(), evaluation, evaluation)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
|
||||
}
|
||||
}
|
||||
|
||||
void Rq_Element::add(octetStream& os, int)
|
||||
{
|
||||
Rq_Element tmp(*this);
|
||||
tmp.unpack(os);
|
||||
*this += tmp;
|
||||
}
|
||||
|
||||
void Rq_Element::randomize(PRNG& G,int l)
|
||||
{
|
||||
set_level(l);
|
||||
@@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p)
|
||||
|
||||
// Now add delta back onto a0
|
||||
Rq_Element bb(b0,b1);
|
||||
add(*this,*this,bb);
|
||||
::add(*this,*this,bb);
|
||||
|
||||
// Now divide by p1 mod p0
|
||||
modp p1_inv,pp;
|
||||
@@ -291,7 +298,7 @@ void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y)
|
||||
partial_assign(x);
|
||||
}
|
||||
|
||||
void Rq_Element::pack(octetStream& o) const
|
||||
void Rq_Element::pack(octetStream& o, int) const
|
||||
{
|
||||
check_level();
|
||||
o.store(lev);
|
||||
@@ -299,7 +306,7 @@ void Rq_Element::pack(octetStream& o) const
|
||||
a[i].pack(o);
|
||||
}
|
||||
|
||||
void Rq_Element::unpack(octetStream& o)
|
||||
void Rq_Element::unpack(octetStream& o, int)
|
||||
{
|
||||
unsigned int ll; o.get(ll); lev=ll;
|
||||
check_level();
|
||||
@@ -340,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const
|
||||
return sz;
|
||||
}
|
||||
|
||||
void Rq_Element::unpack(octetStream& o, const FHE_Params& params)
|
||||
{
|
||||
set_data(params.FFTD());
|
||||
unpack(o);
|
||||
}
|
||||
|
||||
void Rq_Element::print_first_non_zero() const
|
||||
{
|
||||
vector<bigint> v = to_vec_bigint();
|
||||
|
||||
@@ -69,8 +69,9 @@ protected:
|
||||
a({b0}), lev(n_mults()) {}
|
||||
|
||||
template<class T, class FD, class S>
|
||||
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext) :
|
||||
Rq_Element(params)
|
||||
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext,
|
||||
RepType r0 = polynomial, RepType r1 = polynomial) :
|
||||
Rq_Element(params, r0, r1)
|
||||
{
|
||||
from(plaintext.get_iterator());
|
||||
}
|
||||
@@ -93,12 +94,14 @@ protected:
|
||||
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
|
||||
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
|
||||
|
||||
void add(octetStream& os, int = -1);
|
||||
|
||||
template<class S>
|
||||
Rq_Element& operator+=(const vector<S>& other);
|
||||
|
||||
Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; }
|
||||
Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; }
|
||||
|
||||
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; }
|
||||
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; }
|
||||
Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; }
|
||||
template <class T>
|
||||
Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; }
|
||||
@@ -154,8 +157,11 @@ protected:
|
||||
* For unpack we assume the prData for a0 and a1 has been assigned
|
||||
* correctly already
|
||||
*/
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
void pack(octetStream& o, int = -1) const;
|
||||
void unpack(octetStream& o, int = -1);
|
||||
|
||||
// without prior initialization
|
||||
void unpack(octetStream& o, const FHE_Params& params);
|
||||
|
||||
void output(ostream& s) const;
|
||||
void input(istream& s);
|
||||
@@ -176,7 +182,7 @@ Rq_Element& Rq_Element::operator+=(const vector<S>& other)
|
||||
{
|
||||
Rq_Element tmp = *this;
|
||||
tmp.from(Iterator<S>(other), lev);
|
||||
add(*this, *this, tmp);
|
||||
::add(*this, *this, tmp);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,35 +11,15 @@ void Subs(modp& ans,const vector<int>& poly,const modp& x,const Zp_Data& ZpD)
|
||||
assignZero(ans,ZpD);
|
||||
for (int i=poly.size()-1; i>=0; i--)
|
||||
{ Mul(ans,ans,x,ZpD);
|
||||
switch (poly[i])
|
||||
{ case 0:
|
||||
break;
|
||||
case 1:
|
||||
Add(ans,ans,one,ZpD);
|
||||
break;
|
||||
case -1:
|
||||
Sub(ans,ans,one,ZpD);
|
||||
break;
|
||||
case 2:
|
||||
Add(ans,ans,one,ZpD);
|
||||
Add(ans,ans,one,ZpD);
|
||||
break;
|
||||
case -2:
|
||||
Sub(ans,ans,one,ZpD);
|
||||
Sub(ans,ans,one,ZpD);
|
||||
break;
|
||||
case 3:
|
||||
Add(ans,ans,one,ZpD);
|
||||
Add(ans,ans,one,ZpD);
|
||||
Add(ans,ans,one,ZpD);
|
||||
break;
|
||||
case -3:
|
||||
Sub(ans,ans,one,ZpD);
|
||||
Sub(ans,ans,one,ZpD);
|
||||
Sub(ans,ans,one,ZpD);
|
||||
break;
|
||||
default:
|
||||
throw not_implemented();
|
||||
if (poly[i] > 0)
|
||||
{
|
||||
for (int j = 0; j < poly[i]; j++)
|
||||
Add(ans, ans, one, ZpD);
|
||||
}
|
||||
if (poly[i] < 0)
|
||||
{
|
||||
for (int j = 0; j < -poly[i]; j++)
|
||||
Sub(ans, ans, one, ZpD);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,10 +40,9 @@ template <class FD>
|
||||
void PartSetup<FD>::generate_setup(int n_parties, int plaintext_length, int sec,
|
||||
int slack, bool round_up)
|
||||
{
|
||||
sec = max(sec, 40);
|
||||
params.set_min_sec(sec);
|
||||
Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup(
|
||||
params, FieldD);
|
||||
params.set_sec(sec);
|
||||
pk = FHE_PK(params, FieldD.get_prime());
|
||||
sk = FHE_SK(params, FieldD.get_prime());
|
||||
calpha = Ciphertext(params);
|
||||
@@ -180,11 +179,8 @@ void PartSetup<P2Data>::init_field()
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
void PartSetup<FD>::check(int sec) const
|
||||
void PartSetup<FD>::check() const
|
||||
{
|
||||
sec = max(sec, 40);
|
||||
if (abs(sec - params.secp()) > 2)
|
||||
throw runtime_error("security parameters vary too much between protocol and distributed decryption");
|
||||
sk.check(params, pk, FieldD.get_prime());
|
||||
}
|
||||
|
||||
@@ -203,7 +199,7 @@ template<class FD>
|
||||
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
|
||||
int plaintext_length, int sec)
|
||||
{
|
||||
::secure_init(*this, P, machine, plaintext_length, sec);
|
||||
::secure_init(*this, P, machine, plaintext_length, sec, params);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
|
||||
@@ -57,7 +57,7 @@ public:
|
||||
|
||||
void init_field();
|
||||
|
||||
void check(int sec) const;
|
||||
void check() const;
|
||||
bool operator!=(const PartSetup<FD>& other);
|
||||
|
||||
void secure_init(Player& P, MachineBase& machine, int plaintext_length,
|
||||
|
||||
@@ -274,7 +274,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
(*P).Broadcast_Receive(ctx_Delta);
|
||||
|
||||
// Output the ctx_Delta to a file
|
||||
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
ofstream outf(filename);
|
||||
for (int j=0; j<(*P).num_players(); j++)
|
||||
{
|
||||
@@ -308,7 +308,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
octetStream occ,ctx_D;
|
||||
for (int i=0; i<2*TT; i++)
|
||||
{ if (open[i]==1)
|
||||
{ sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
{ snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
ifstream inpf(filename);
|
||||
for (int j=0; j<(*P).num_players(); j++)
|
||||
{
|
||||
@@ -386,7 +386,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
Ciphertext enc1(params),enc2(params),eDelta(params);
|
||||
octetStream oe1,oe2;
|
||||
|
||||
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
|
||||
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
|
||||
ifstream inpf(filename);
|
||||
for (int k=0; k<(*P).num_players(); k++)
|
||||
{
|
||||
|
||||
@@ -57,7 +57,7 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
|
||||
template <class FD>
|
||||
void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
|
||||
OT_ROLE role, int n_summands)
|
||||
OT_ROLE role, int)
|
||||
{
|
||||
o.reset_write_head();
|
||||
|
||||
@@ -67,20 +67,10 @@ void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
|
||||
G.ReSeed();
|
||||
timers["Mask randomization"].start();
|
||||
product_share.randomize(G);
|
||||
bigint B = 6 * machine.setup<FD>().params.get_R();
|
||||
B *= machine.setup<FD>().FieldD.get_prime();
|
||||
B <<= machine.drown_sec;
|
||||
// slack
|
||||
B *= NonInteractiveProof::slack(machine.sec,
|
||||
machine.setup<FD>().params.phi_m());
|
||||
B <<= machine.extra_slack;
|
||||
B *= n_summands;
|
||||
rc.generateUniform(G, 0, B, B);
|
||||
mask = c;
|
||||
mask.rerandomize(other_pk);
|
||||
timers["Mask randomization"].stop();
|
||||
timers["Encryption"].start();
|
||||
other_pk.encrypt(mask, product_share, rc);
|
||||
timers["Encryption"].stop();
|
||||
mask += c;
|
||||
mask += product_share;
|
||||
mask.pack(o);
|
||||
res -= product_share;
|
||||
}
|
||||
@@ -130,6 +120,13 @@ void Multiplier<FD>::report_size(ReportType type, MemoryUsage& res)
|
||||
res += memory_usage;
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
const vector<Ciphertext>& Multiplier<FD>::get_multiplicands(
|
||||
const vector<vector<Ciphertext> >& others_ct, const FHE_PK&)
|
||||
{
|
||||
return others_ct[P.get_full_player().get_player(-P.get_offset())];
|
||||
}
|
||||
|
||||
|
||||
template class Multiplier<FFT_Data>;
|
||||
template class Multiplier<P2Data>;
|
||||
|
||||
@@ -55,6 +55,9 @@ public:
|
||||
size_t report_size(ReportType type);
|
||||
void report_size(ReportType type, MemoryUsage& res);
|
||||
size_t report_volatile() { return volatile_capacity; }
|
||||
|
||||
const vector<Ciphertext>& get_multiplicands(
|
||||
const vector<vector<Ciphertext>>& others_ct, const FHE_PK&);
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_MULTIPLIER_H_ */
|
||||
|
||||
@@ -24,7 +24,7 @@ PairwiseGenerator<FD>::PairwiseGenerator(int thread_num,
|
||||
thread_num, machine.output, machine.get_prep_dir<FD>(P)),
|
||||
EC(P, machine.other_pks, machine.setup<FD>().FieldD, timers, machine, *this),
|
||||
MC(machine.setup<FD>().alphai),
|
||||
n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)),
|
||||
n_ciphertexts(EC.proof.U),
|
||||
C(n_ciphertexts, machine.setup<FD>().params), volatile_memory(0),
|
||||
machine(machine)
|
||||
{
|
||||
@@ -175,7 +175,7 @@ size_t PairwiseGenerator<FD>::report_size(ReportType type)
|
||||
template <class FD>
|
||||
size_t PairwiseGenerator<FD>::report_sent()
|
||||
{
|
||||
return P.sent;
|
||||
return P.total_comm().sent;
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
|
||||
@@ -17,19 +17,17 @@ PairwiseMachine::PairwiseMachine(Player& P) :
|
||||
{
|
||||
}
|
||||
|
||||
PairwiseMachine::PairwiseMachine(int argc, const char** argv) :
|
||||
MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")),
|
||||
other_pks(N.num_players(), {setup_p.params, 0}),
|
||||
pk(other_pks[N.my_num()]), sk(pk)
|
||||
RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) :
|
||||
MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise"))
|
||||
{
|
||||
init();
|
||||
}
|
||||
|
||||
void PairwiseMachine::init()
|
||||
void RealPairwiseMachine::init()
|
||||
{
|
||||
if (use_gf2n)
|
||||
{
|
||||
field_size = 40;
|
||||
field_size = gf2n_short::DEFAULT_LENGTH;
|
||||
gf2n_short::init_field(field_size);
|
||||
setup_keys<P2Data>();
|
||||
}
|
||||
@@ -63,11 +61,11 @@ PairwiseSetup<P2Data>& PairwiseMachine::setup()
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
void PairwiseMachine::setup_keys()
|
||||
void RealPairwiseMachine::setup_keys()
|
||||
{
|
||||
auto& N = P;
|
||||
PairwiseSetup<FD>& s = setup<FD>();
|
||||
s.init(P, drown_sec, field_size, extra_slack);
|
||||
s.init(P, sec, field_size, extra_slack);
|
||||
if (output)
|
||||
write_mac_key(get_prep_dir<FD>(P), P.my_num(), P.num_players(), s.alphai);
|
||||
for (auto& x : other_pks)
|
||||
@@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys()
|
||||
if (i != N.my_num())
|
||||
other_pks[i].unpack(os[i]);
|
||||
set_mac_key(s.alphai);
|
||||
Share<typename FD::T>::MAC_Check::setup(P);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void PairwiseMachine::set_mac_key(T alphai)
|
||||
void RealPairwiseMachine::set_mac_key(T alphai)
|
||||
{
|
||||
typedef typename T::FD FD;
|
||||
auto& N = P;
|
||||
@@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const
|
||||
bundle.compare(P);
|
||||
}
|
||||
|
||||
template void PairwiseMachine::setup_keys<FFT_Data>();
|
||||
template void PairwiseMachine::setup_keys<P2Data>();
|
||||
template void RealPairwiseMachine::setup_keys<FFT_Data>();
|
||||
template void RealPairwiseMachine::setup_keys<P2Data>();
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "FHEOffline/SimpleMachine.h"
|
||||
#include "FHEOffline/PairwiseSetup.h"
|
||||
|
||||
class PairwiseMachine : public MachineBase
|
||||
class PairwiseMachine : public virtual MachineBase
|
||||
{
|
||||
public:
|
||||
PairwiseSetup<FFT_Data> setup_p;
|
||||
@@ -23,15 +23,6 @@ public:
|
||||
vector<Ciphertext> enc_alphas;
|
||||
|
||||
PairwiseMachine(Player& P);
|
||||
PairwiseMachine(int argc, const char** argv);
|
||||
|
||||
void init();
|
||||
|
||||
template <class FD>
|
||||
void setup_keys();
|
||||
|
||||
template <class T>
|
||||
void set_mac_key(T alphai);
|
||||
|
||||
template <class FD>
|
||||
PairwiseSetup<FD>& setup();
|
||||
@@ -42,4 +33,18 @@ public:
|
||||
void check(Player& P) const;
|
||||
};
|
||||
|
||||
class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine
|
||||
{
|
||||
public:
|
||||
RealPairwiseMachine(int argc, const char** argv);
|
||||
|
||||
void init();
|
||||
|
||||
template <class FD>
|
||||
void setup_keys();
|
||||
|
||||
template <class T>
|
||||
void set_mac_key(T alphai);
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "Math/Setup.h"
|
||||
#include "FHEOffline/Proof.h"
|
||||
#include "FHEOffline/PairwiseMachine.h"
|
||||
#include "FHEOffline/TemiSetup.h"
|
||||
#include "Tools/Commit.h"
|
||||
#include "Tools/Bundle.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
@@ -53,7 +54,7 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
|
||||
template <class FD>
|
||||
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
|
||||
{
|
||||
::secure_init(*this, P, machine, plaintext_length, sec);
|
||||
::secure_init(*this, P, machine, plaintext_length, sec, params);
|
||||
alpha = FieldD;
|
||||
machine.sk = FHE_SK(params, FieldD.get_prime());
|
||||
for (auto& pk : machine.other_pks)
|
||||
@@ -62,16 +63,20 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
|
||||
|
||||
template <class T, class U>
|
||||
void secure_init(T& setup, Player& P, U& machine,
|
||||
int plaintext_length, int sec)
|
||||
int plaintext_length, int sec, FHE_Params& params)
|
||||
{
|
||||
assert(sec >= 0);
|
||||
machine.sec = sec;
|
||||
sec = max(sec, 40);
|
||||
machine.drown_sec = sec;
|
||||
params.set_min_sec(sec);
|
||||
string filename = PREP_DIR + T::name() + "-"
|
||||
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
|
||||
+ to_string(params.secp()) + "-"
|
||||
+ to_string(params.get_matrix_dim()) + "-"
|
||||
+ OnlineOptions::singleton.prime.get_str() + "-"
|
||||
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
|
||||
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
|
||||
string reason;
|
||||
|
||||
try
|
||||
{
|
||||
ifstream file(filename);
|
||||
@@ -79,13 +84,30 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
os.input(file);
|
||||
os.get(machine.extra_slack);
|
||||
setup.unpack(os);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
reason = e.what();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
setup.check(P, machine);
|
||||
}
|
||||
catch (...)
|
||||
catch (exception& e)
|
||||
{
|
||||
cout << "Finding parameters for security " << sec << " and field size ~2^"
|
||||
<< plaintext_length << endl;
|
||||
setup.params = setup.params.n_mults();
|
||||
reason = e.what();
|
||||
}
|
||||
|
||||
if (not reason.empty())
|
||||
{
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Generating parameters for security " << sec
|
||||
<< " and field size ~2^" << plaintext_length
|
||||
<< " because no suitable material "
|
||||
"from a previous run was found (" << reason << ")"
|
||||
<< endl;
|
||||
setup = {};
|
||||
setup.generate(P, machine, plaintext_length, sec);
|
||||
setup.check(P, machine);
|
||||
octetStream os;
|
||||
@@ -94,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
ofstream file(filename);
|
||||
os.output(file);
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
{
|
||||
cerr << "Ciphertext length: " << params.p0().numBits();
|
||||
for (size_t i = 1; i < params.FFTD().size(); i++)
|
||||
cerr << "+" << params.FFTD()[i].get_prime().numBits();
|
||||
cerr << endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
@@ -208,5 +238,8 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
|
||||
template class PairwiseSetup<FFT_Data>;
|
||||
template class PairwiseSetup<P2Data>;
|
||||
|
||||
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int);
|
||||
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int);
|
||||
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
|
||||
template void secure_init(TemiSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
template void secure_init(TemiSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
|
||||
@@ -15,7 +15,7 @@ class MachineBase;
|
||||
|
||||
template <class T, class U>
|
||||
void secure_init(T& setup, Player& P, U& machine,
|
||||
int plaintext_length, int sec);
|
||||
int plaintext_length, int sec, FHE_Params& params);
|
||||
|
||||
template <class FD>
|
||||
class PairwiseSetup
|
||||
|
||||
@@ -577,7 +577,7 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
|
||||
for (int j = min; j < max; j++)
|
||||
{
|
||||
AddableVector<Ciphertext> C;
|
||||
vector<Plaintext_<FD>> m(EC.machine->sec, FieldD);
|
||||
vector<Plaintext_<FD>> m(personal_EC.proof.U, FieldD);
|
||||
if (j == P.my_num())
|
||||
{
|
||||
for (auto& x : m)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user