Merge remote-tracking branch 'origin/master' into update-prng-seed

This commit is contained in:
jonas
2023-05-30 10:34:06 +02:00
625 changed files with 24382 additions and 6556 deletions

7
.gitignore vendored
View File

@@ -116,3 +116,10 @@ Thumbs.db
# Sphinx build
_build/
# environment
.env
# temp doc files
doc/readme.md
doc/xml

13
.gitmodules vendored
View File

@@ -1,12 +1,15 @@
[submodule "SimpleOT"]
path = SimpleOT
path = deps/SimpleOT
url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = mpir
url = git://github.com/wbhart/mpir.git
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion
[submodule "simde"]
path = simde
path = deps/simde
url = https://github.com/simd-everywhere/simde
[submodule "deps/libOTe"]
path = deps/libOTe
url = https://github.com/mkskeller/softspoken-implementation
[submodule "deps/SimplestOT_C"]
path = deps/SimplestOT_C
url = https://github.com/mkskeller/SimplestOT_C

View File

@@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) :
}
cout << "Compiler: " << prev << endl;
P = new PlainPlayer(N, 0);
Share<gf2n_long>::MAC_Check::setup(*P);
if (argc > 4)
threshold = atoi(argv[4]);
cout << "Threshold for multi-threaded evaluation: " << threshold << endl;
@@ -259,7 +260,8 @@ ProgramParty::~ProgramParty()
reset();
if (P)
{
cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl;
cerr << "Data sent in online phase = " << P->total_comm().sent * 1e-6
<< " MB (this party only)" << endl;
delete P;
}
delete[] eval_threads;
@@ -281,6 +283,7 @@ FakeProgramParty::~FakeProgramParty()
cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes()
<< " GB" << endl;
#endif
Share<gf2n_long>::MAC_Check::teardown();
}
void FakeProgramParty::_compute_prfs_outputs(Key* keys)

View File

@@ -48,8 +48,6 @@ public:
static void inputbvec(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args);
RealGarbleWire(const Register& reg) : PRFRegister(reg) {}
void garble(PRFOutputs& prf_output, const RealGarbleWire<T>& left,
const RealGarbleWire<T>& right);

View File

@@ -110,7 +110,7 @@ void RealGarbleWire<T>::inputbvec(
{
GarbleInputter<T> inputter;
processor.inputbvec(inputter, input_processor, args,
inputter.party.P->my_num());
*inputter.party.P);
}
template<class T>
@@ -175,7 +175,7 @@ void GarbleInputter<T>::exchange()
assert(party.P != 0);
assert(party.MC != 0);
auto& protocol = party.shared_proc->protocol;
protocol.init_mul(party.shared_proc);
protocol.init_mul();
for (auto& tuple : tuples)
protocol.prepare_mul(tuple.first->mask,
T::constant(1, party.P->my_num(), party.mac_key)

View File

@@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>
bool one_shot;
size_t data_sent;
public:
static RealProgramParty& s();

View File

@@ -28,7 +28,7 @@ RealProgramParty<T>* RealProgramParty<T>::singleton = 0;
template<class T>
RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
garble_processor(garble_machine), dummy_proc({{}, 0})
garble_processor(garble_machine), dummy_proc({}, 0)
{
assert(singleton == 0);
singleton = this;
@@ -64,7 +64,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
online_opts = {opt, argc, argv, 1000};
else
online_opts = {opt, argc, argv};
assert(not online_opts.interactive);
online_opts.finalize(opt, argc, argv);
this->load(online_opts.progname);
@@ -97,8 +96,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
if (online_opts.live_prep)
{
mac_key.randomize(prng);
if (T::needs_ot)
BaseMachine::s().ot_setups.push_back({*P, true});
prep = new typename T::LivePrep(0, usage);
}
else
@@ -107,10 +104,12 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
}
T::MAC_Check::setup(*P);
MC = new typename T::MAC_Check(mac_key);
garble_processor.reset(program);
this->processor.open_input_file(N.my_num(), 0);
this->processor.open_input_file(N.my_num(), 0, online_opts.cmd_private_input_file);
this->processor.setup_redirection(P->my_num(), 0, online_opts, this->processor.out);
shared_proc = new SubProcessor<T>(dummy_proc, *MC, *prep, *P);
@@ -155,7 +154,9 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);
MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();
if (online_opts.verbose)
P->total_comm().print();
this->machine.write_memory(this->N.my_num());
}
@@ -173,7 +174,8 @@ void RealProgramParty<T>::garble()
garble_jobs.clear();
garble_inputter->reset_all(*P);
auto& protocol = *garble_protocol;
protocol.init_mul(shared_proc);
protocol.init(*prep, shared_proc->MC);
protocol.init_mul();
next = this->first_phase(program, garble_processor, this->garble_machine);
@@ -181,7 +183,8 @@ void RealProgramParty<T>::garble()
protocol.exchange();
typename T::Protocol second_protocol(*P);
second_protocol.init_mul(shared_proc);
second_protocol.init(*prep, shared_proc->MC);
second_protocol.init_mul();
for (auto& job : garble_jobs)
job.middle_round(*this, second_protocol);
@@ -212,7 +215,8 @@ RealProgramParty<T>::~RealProgramParty()
delete prep;
delete garble_inputter;
delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
garble_machine.print_comm(*this->P, this->P->total_comm());
T::MAC_Check::teardown();
}
template<class T>

View File

@@ -568,6 +568,7 @@ void EvalRegister::inputb(GC::Processor<GC::Secret<EvalRegister> >& processor,
octetStream& my_os = oss[party.get_id() - 1];
vector<InputAccess> accesses;
InputArgList a(args);
processor.complexity += a.n_input_bits();
for (auto x : a)
{
accesses.push_back({x , processor});

View File

@@ -23,6 +23,7 @@ using namespace std;
#include "Tools/PointerVector.h"
#include "Tools/Bundle.h"
#include "Tools/SwitchableOutput.h"
#include "Processor/Instruction.h"
//#define PAD_TO_8(n) (n+8-n%8)
#define PAD_TO_8(n) (n)
@@ -61,11 +62,13 @@ private:
#endif
};
#else
class BaseKeyVector : public vector<Key>
class BaseKeyVector : public CheckVector<Key>
{
typedef CheckVector<Key> super;
public:
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {}
void resize(int size) { vector<Key>::resize(size, Key(0)); }
BaseKeyVector(int size = 0) : super(size, Key(0)) {}
void resize(int size) { super::resize(size, Key(0)); }
};
#endif
@@ -151,7 +154,7 @@ public:
* for pipelining matters.
*/
Register(int n_parties);
Register();
void init(int n_parties);
void init(int rfd, int n_parties);
@@ -234,6 +237,9 @@ public:
template <class T>
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
template <class T>
static void andrsvec(T& processor, const vector<int>& args)
{ processor.andrsvec(args); }
template <class T>
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
template <class T>
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
@@ -277,10 +283,6 @@ public:
static int threshold(int) { throw not_implemented(); }
static Register new_reg();
static Register tmp_reg() { return new_reg(); }
static Register and_reg() { return new_reg(); }
template<class T>
static void store(NoMemory& dest,
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
@@ -289,6 +291,16 @@ public:
static void inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args);
template<class U>
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("convcbit2s not implemented"); }
template<class U>
static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); }
static void run_tapes(const vector<int>&)
{ throw runtime_error("multi-threading not implemented"); }
// most BMR phases don't need actual input
template<class T>
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
@@ -298,8 +310,6 @@ public:
void other_input(Input&, int) {}
char get_output() { return 0; }
ProgramRegister(const Register& reg) : Register(reg) {}
};
class PRFRegister : public ProgramRegister
@@ -311,8 +321,6 @@ public:
static void load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source);
PRFRegister(const Register& reg) : ProgramRegister(reg) {}
void op(const PRFRegister& left, const PRFRegister& right, Function func);
void XOR(const Register& left, const Register& right);
void input(party_id_t from, char input = -1);
@@ -388,8 +396,6 @@ public:
static void convcbit(Integer& dest, const GC::Clear& source,
GC::Processor<GC::Secret<EvalRegister>>& proc);
EvalRegister(const Register& reg) : ProgramRegister(reg) {}
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
void XOR(const Register& left, const Register& right);
@@ -419,8 +425,6 @@ public:
static void load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source);
GarbleRegister(const Register& reg) : ProgramRegister(reg) {}
void op(const Register& left, const Register& right, Function func);
void XOR(const Register& left, const Register& right);
void input(party_id_t from, char value = -1);
@@ -444,8 +448,6 @@ public:
static void load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source);
RandomRegister(const Register& reg) : ProgramRegister(reg) {}
void randomize();
void op(const Register& left, const Register& right, Function func);
@@ -461,12 +463,6 @@ public:
};
inline Register::Register(int n_parties) :
garbled_entry(n_parties), external(NO_SIGNAL),
mask(NO_SIGNAL), keys(n_parties)
{
}
inline void KeyVector::operator=(const KeyVector& other)
{
resize(other.size());

View File

@@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args)
{
NoOpInputter inputter;
int my_num = -1;
try
{
my_num = ProgramParty::s().P->my_num();
}
catch (exception&)
{
}
processor.inputbvec(inputter, input_processor, args, my_num);
processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P);
}
template<class T>
@@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor,
{
EvalInputter inputter;
processor.inputbvec(inputter, input_processor, args,
ProgramParty::s().P->my_num());
*ProgramParty::s().P);
}
template <class T>

View File

@@ -9,10 +9,10 @@
#include "CommonParty.h"
#include "Party.h"
inline Register ProgramRegister::new_reg()
inline Register::Register() :
garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL),
mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties())
{
return Register(CommonParty::s().get_n_parties());
}
#endif /* BMR_REGISTER_INLINE_H_ */

View File

@@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty()
_received_gc_received = 0;
n_received = 0;
randomfd = open("/dev/urandom", O_RDONLY);
done_filling = false;
}
BaseTrustedParty::~BaseTrustedParty()
{
close(randomfd);
}
TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :

View File

@@ -20,7 +20,7 @@ public:
vector<SendBuffer> msg_input_masks;
BaseTrustedParty();
virtual ~BaseTrustedParty() {}
virtual ~BaseTrustedParty();
/* From NodeUpdatable class */
virtual void NodeReady();
@@ -104,7 +104,6 @@ private:
void add_all_keys(const Register& reg, bool external);
};
inline void BaseTrustedParty::add_keys(const Register& reg)
{
for(int p = 0; p < get_n_parties(); p++)

View File

@@ -1,5 +1,101 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.6 (May 9, 2023)
- More extensive benchmarking outputs
- Replace MPIR by GMP
- Secure reading of edaBits from files
- Semi-honest client communication
- Back-propagation for average pooling
- Parallelized convolution
- Probabilistic truncation as in ABY3
- More balanced communication in Shamir secret sharing
- Avoid unnecessary communication in Dealer protocol
- Linear solver using Cholesky decomposition
- Accept .py files for compilation
- Fixed security bug: proper accounting for random elements
## 0.3.5 (Feb 16, 2023)
- Easier-to-use machine learning interface
- Integrated compilation-execution facility
- Import/export sequential models and parameters from/to PyTorch
- Binary-format input files
- Less aggressive round optimization for faster compilation by default
- Multithreading with client interface
- Functionality to protect order of specific memory accesses
- Oblivious transfer works again on older (pre-2011) x86 CPUs
- clang is used by default
## 0.3.4 (Nov 9, 2022)
- Decision tree learning
- Optimized oblivious shuffle in Rep3
- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC
- Optimized element-vector AND in SemiBin
- Optimized input protocol in Shamir-based protocols
- Square-root ORAM (@Quitlox)
- Improved ORAM in binary circuits
- UTF-8 outputs
## 0.3.3 (Aug 25, 2022)
- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate
- Fix security bug in MAC check when using multithreading
- Fix security bug to prevent selective failure attack by checking earlier
- Fix security bug in Mama: insufficient sacrifice.
- Inverse permutation (@Quitlox)
- Easier direct compilation (@eriktaubeneck)
- Generally allow element-vector operations
- Increase maximum register size to 2^54
- Client example in Python
- Uniform base OTs across platforms
- Multithreaded base OT computation
- Faster random bit generation in two-player Semi(2k)
## 0.3.2 (May 27, 2022)
- Secure shuffling
- O(n log n) radix sorting
- Documented BGV encryption interface
- Optimized matrix multiplication in dealer protocol
- Fixed security bug in homomorphic encryption parameter generation
- Fixed security bug in Temi matrix multiplication
## 0.3.1 (Apr 19, 2022)
- Protocol in dealer model
- Command-line option for security parameter
- Fixed security bug in SPDZ2k (see Section 3.4 of [the updated paper](https://eprint.iacr.org/2018/482))
- Ability to run high-level (Python) code from C++
- More memory capacity due to 64-bit addressing
- Homomorphic encryption for more fields of characteristic two
- Docker container
## 0.3.0 (Feb 17, 2022)
- Semi-honest computation based on threshold semi-homomorphic encryption
- Batch normalization backward propagation
- AlexNet for CIFAR-10
- Specific private output protocols
- Semi-honest additive secret sharing without communication
- Sending of personal values
- Allow overwriting of persistence files
- Protocol signature in persistence files
## 0.2.9 (Jan 11, 2022)
- Disassembler
- Run-time parameter for probabilistic truncation error
- Probabilistic truncation for some protocols computing modulo a prime
- Simplified C++ interface
- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757)
- More general scalar-vector multiplication
- Complete memory support for clear bits
- Extended clear bit functionality with Yao's garbled circuits
- Allow preprocessing information to be supplied via named pipes
- In-place operations for containers
## 0.2.8 (Nov 4, 2021)
- Tested on Apple laptop with ARM chip

43
CONFIG
View File

@@ -8,6 +8,9 @@ GDEBUG = -g
# set this to your preferred local storage directory
PREP_DIR = '-DPREP_DIR="Player-Data/"'
# directory to store SSL keys
SSL_DIR = '-DSSL_DIR="Player-Data/"'
# set for SHE preprocessing (SPDZ and Overdrive)
USE_NTL = 0
@@ -28,25 +31,40 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
ARCH = -march=native
MACHINE := $(shell uname -m)
ARM := $(shell uname -m | grep x86; echo $$?)
OS := $(shell uname -s)
ifeq ($(MACHINE), x86_64)
# set this to 0 to avoid using AVX for OT
ifeq ($(OS), Linux)
CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?)
ifeq ($(CHECK_AVX), 0)
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
AVX_OT = 1
else
AVX_OT = 0
endif
else
AVX_OT = 1
AVX_OT = 0
endif
else
ARCH =
AVX_OT = 0
endif
ifeq ($(OS), Darwin)
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
endif
ifeq ($(OS), Linux)
ifeq ($(ARM), 1)
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
ARCH = -march=armv8.2-a+crypto
endif
endif
endif
USE_KOS = 0
# allow to set compiler in CONFIG.mine
CXX = g++
CXX = clang++
# use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine
@@ -65,9 +83,13 @@ endif
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
LDLIBS += $(BREW_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto
CFLAGS += -I./local/include
ifeq ($(USE_NTL),1)
CFLAGS += -DUSE_NTL
LDLIBS := -lntl $(LDLIBS)
@@ -83,7 +105,8 @@ else
BOOST = -lboost_thread $(MY_BOOST)
endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(BREW_CFLAGS)
CPPFLAGS = $(CFLAGS)
LD = $(CXX)
@@ -94,3 +117,9 @@ ifeq ($(USE_NTL),1)
CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy
endif
endif
ifeq ($(USE_KOS),1)
CFLAGS += -DUSE_KOS
else
CFLAGS += -std=c++17
endif

View File

@@ -13,11 +13,14 @@ import Compiler.instructions as spdz
import Compiler.tools as tools
import collections
import itertools
import math
class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb'
name = 'sbit'
class ClearBitsAF(base.RegisterArgFormat):
reg_type = 'cb'
name = 'cbit'
base.ArgFormats['sb'] = SecretBitsAF
base.ArgFormats['sbw'] = SecretBitsAF
@@ -50,6 +53,7 @@ opcodes = dict(
INPUTBVEC = 0x247,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
ANDRSVEC = 0x24a,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
@@ -64,6 +68,8 @@ opcodes = dict(
MULCBI = 0x21c,
SHRCBI = 0x21d,
SHLCBI = 0x21e,
LDMCBI = 0x258,
STMCBI = 0x259,
CONVCINTVEC = 0x21f,
PRINTREGSIGNED = 0x220,
PRINTREGB = 0x221,
@@ -153,6 +159,52 @@ class andrs(BinaryVectorInstruction):
def add_usage(self, req_node):
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
req_node.increment(('bit', 'mixed'),
sum(int(math.ceil(x / 64)) for x in self.args[::4]))
class andrsvec(base.VarArgsInstruction, base.Mergeable,
base.DynFormatInstruction):
""" Constant-vector AND of secret bit registers (vectorized version).
:param: total number of arguments to follow (int)
:param: number of arguments to follow for one operation /
operation vector size plus three (int)
:param: vector size (int)
:param: result vector (sbit)
:param: (repeat)...
:param: constant operand (sbits)
:param: vector operand
:param: (repeat)...
:param: (repeat from number of arguments to follow for one operation)...
"""
code = opcodes['ANDRSVEC']
def __init__(self, *args, **kwargs):
super(andrsvec, self).__init__(*args, **kwargs)
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
for x in self.args[i + 2:i + n]:
assert x.n == size
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
for i, n in cls.bases(args):
yield 'int'
n_args = (n - 3) // 2
assert n_args > 0
for j in range(n_args):
yield 'sbw'
for j in range(n_args + 1):
yield 'sb'
yield 'int'
def add_usage(self, req_node):
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
req_node.increment(('bit', 'mixed'), size)
class ands(BinaryVectorInstruction):
""" Bitwise AND of secret bit register vector.
@@ -303,7 +355,7 @@ class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
:param: memory address (int)
"""
code = opcodes['LDMSB']
arg_format = ['sbw','int']
arg_format = ['sbw','long']
class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
""" Copy secret bit register to secret bit memory cell with compile-time
@@ -313,7 +365,7 @@ class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
:param: memory address (int)
"""
code = opcodes['STMSB']
arg_format = ['sb','int']
arg_format = ['sb','long']
# def __init__(self, *args, **kwargs):
# super(type(self), self).__init__(*args, **kwargs)
# import inspect
@@ -328,7 +380,7 @@ class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
:param: memory address (int)
"""
code = opcodes['LDMCB']
arg_format = ['cbw','int']
arg_format = ['cbw','long']
class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
""" Copy clear bit register to clear bit memory cell with compile-time
@@ -338,9 +390,10 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
:param: memory address (int)
"""
code = opcodes['STMCB']
arg_format = ['cb','int']
arg_format = ['cb','long']
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy secret bit memory cell with run-time address to secret bit
register.
@@ -349,8 +402,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
"""
code = opcodes['LDMSBI']
arg_format = ['sbw','ci']
direct = staticmethod(ldmsb)
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy secret bit register to secret bit memory cell with run-time
address.
@@ -359,6 +414,31 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
"""
code = opcodes['STMSBI']
arg_format = ['sb','ci']
direct = staticmethod(stmsb)
class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy clear bit memory cell with run-time address to clear bit
register.
:param: destination (cbit)
:param: memory address (regint)
"""
code = opcodes['LDMCBI']
arg_format = ['cbw','ci']
direct = staticmethod(ldmcb)
class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy clear bit register to clear bit memory cell with run-time
address.
:param: source (cbit)
:param: memory address (regint)
"""
code = opcodes['STMCBI']
arg_format = ['cb','ci']
direct = staticmethod(stmcb)
class ldmsdi(base.ReadMemoryInstruction):
code = opcodes['LDMSDI']
@@ -475,7 +555,7 @@ class movsb(NonVectorInstruction):
code = opcodes['MOVSB']
arg_format = ['sbw','sb']
class trans(base.VarArgsInstruction):
class trans(base.VarArgsInstruction, base.DynFormatInstruction):
""" Secret bit register vector transpose. The first destination vector
will contain the least significant bits of all source vectors etc.
@@ -489,10 +569,22 @@ class trans(base.VarArgsInstruction):
code = opcodes['TRANS']
is_vec = lambda self: True
def __init__(self, *args):
self.arg_format = ['int'] + ['sbw'] * args[0] + \
['sb'] * (len(args) - 1 - args[0])
super(trans, self).__init__(*args)
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
n = next(args)
for i in range(n):
yield 'sbw'
next(args)
while True:
try:
yield 'sb'
next(args)
except StopIteration:
break
class bitb(NonVectorInstruction):
""" Copy fresh secret random bit to secret bit register.
@@ -538,7 +630,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1])
class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
base.Mergeable):
base.Mergeable, base.DynFormatInstruction):
""" Copy private input to secret bit registers bit by bit. The input is
read as floating-point number, multiplied by a power of two, rounded to an
integer, and then decomposed into bits.
@@ -555,11 +647,19 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
code = opcodes['INPUTBVEC']
def __init__(self, *args, **kwargs):
self.arg_format = []
for x in self.get_arg_tuples(args):
self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3)
super(inputbvec, self).__init__(*args, **kwargs)
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
for i, n in cls.bases(args):
yield 'int'
yield 'p'
assert n > 3
for j in range(n - 3):
yield 'sbw'
yield 'int'
@staticmethod
def get_arg_tuples(args):
i = 0
@@ -568,10 +668,6 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
i += args[i]
assert i == len(args)
def merge(self, other):
self.args += other.args
self.arg_format += other.arg_format
def add_usage(self, req_node):
for x in self.get_arg_tuples(self.args):
req_node.increment(('bit', 'input', x[2]), x[0] - 3)

View File

@@ -3,10 +3,13 @@ This modules contains basic types for binary circuits. The
fixed-length types obtained by :py:obj:`get_type(n)` are the preferred
way of using them, and in some cases required in connection with
container types.
Computation using these types will always be executed as a binary
circuit. See :ref:`protocol-pairs` for the exact protocols.
"""
from Compiler.types import MemValue, read_mem_value, regint, Array, cint
from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint
from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint, sintbit
from Compiler.program import Tape, Program
from Compiler.exceptions import *
from Compiler import util, oram, floatingpoint, library
@@ -14,10 +17,10 @@ from Compiler import instructions_base
import Compiler.GC.instructions as inst
import operator
import math
import itertools
from functools import reduce
class bits(Tape.Register, _structure, _bit):
""" Base class for binary registers. """
n = 40
unit = 64
PreOp = staticmethod(floatingpoint.PreOpN)
@@ -41,7 +44,7 @@ class bits(Tape.Register, _structure, _bit):
return cls.types[length]
@classmethod
def conv(cls, other):
if isinstance(other, cls):
if isinstance(other, cls) and cls.n == other.n:
return other
elif isinstance(other, MemValue):
return cls.conv(other.read())
@@ -56,12 +59,12 @@ class bits(Tape.Register, _structure, _bit):
@classmethod
def bit_compose(cls, bits):
bits = list(bits)
if len(bits) == 1:
if len(bits) == 1 and isinstance(bits[0], cls):
return bits[0]
bits = list(bits)
for i in range(len(bits)):
if util.is_constant(bits[i]):
bits[i] = sbit(bits[i])
bits[i] = cls.bit_type(bits[i])
res = cls.new(n=len(bits))
if len(bits) <= cls.unit:
cls.bitcom(res, *(sbit.conv(bit) for bit in bits))
@@ -111,11 +114,16 @@ class bits(Tape.Register, _structure, _bit):
if mem_type == 'sd':
return cls.load_dynamic_mem(address)
else:
for i in range(res.size):
cls.load_inst[util.is_constant(address)](res[i], address + i)
cls.mem_op(cls.load_inst, res, address)
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
self.mem_op(self.store_inst, self, address)
@staticmethod
def mem_op(inst, reg, address):
direct = isinstance(address, int)
if not direct:
address = regint.conv(address)
inst[direct](reg, address)
@classmethod
def new(cls, value=None, n=None):
if util.is_constant(value):
@@ -147,19 +155,29 @@ class bits(Tape.Register, _structure, _bit):
self.set_length(self.n or util.int_len(other))
self.load_int(other)
elif isinstance(other, regint):
assert(other.size == math.ceil(self.n / self.unit))
for i, (x, y) in enumerate(zip(self, other)):
assert self.unit == 64
n_units = int(math.ceil(self.n / self.unit))
n_convs = min(other.size, n_units)
for i in range(n_convs):
x = self[i]
y = other[i]
self.conv_regint(min(self.unit, self.n - i * self.unit), x, y)
for i in range(n_convs, n_units):
inst.ldbits(self[i], min(self.unit, self.n - i * self.unit), 0)
elif (isinstance(self, type(other)) or isinstance(other, type(self))) \
and self.n == other.n:
for i in range(math.ceil(self.n / self.unit)):
self.mov(self[i], other[i])
elif isinstance(other, sintbit) and isinstance(self, sbits):
assert len(other) == 1
r = sint.get_dabit()
self.mov(self, r[1] ^ other.bit_xor(r[0]).reveal())
elif isinstance(other, sint) and isinstance(self, sbits):
self.mov(self, sbitvec(other, self.n).elements()[0])
else:
try:
bits = other.bit_decompose()
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
bits = bits[:self.n] + [self.bit_type(0)] * (self.n - len(bits))
other = self.bit_compose(bits)
assert(isinstance(other, type(self)))
assert(other.n == self.n)
@@ -184,6 +202,8 @@ class bits(Tape.Register, _structure, _bit):
return 0
elif self.is_long_one(other):
return self
elif isinstance(other, _vec):
return other & other.from_vec([self])
else:
return self._and(other)
@read_mem_value
@@ -222,17 +242,30 @@ class bits(Tape.Register, _structure, _bit):
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
def zero_if_not(self, condition):
if util.is_constant(condition):
return self * condition
else:
return self * cbit.conv(condition)
def expand(self, length):
if self.n in (length, None):
return self
elif self.n == 1:
return self.get_type(length).bit_compose([self] * length)
else:
raise CompilerError('cannot expand from %s to %s' % (self.n, length))
class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
max_length = 64
reg_type = 'cb'
is_clear = True
load_inst = (None, inst.ldmcb)
store_inst = (None, inst.stmcb)
load_inst = (inst.ldmcbi, inst.ldmcb)
store_inst = (inst.stmcbi, inst.stmcb)
bitdec = inst.bitdecc
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
conv_cint_vec = inst.convcintvec
mov = staticmethod(lambda x, y: inst.addcbi(x, y, 0))
@classmethod
def bit_compose(cls, bits):
return sum(bit << i for i, bit in enumerate(bits))
@@ -241,14 +274,26 @@ class cbits(bits):
assert n == res.n
assert n == other.size
cls.conv_cint_vec(cint(other, size=other.size), res)
@classmethod
def conv(cls, other):
if isinstance(other, cbits) and cls.n != None and \
cls.n // cls.unit == other.n // cls.unit:
if isinstance(other, cls):
return other
else:
res = cls()
for i in range(math.ceil(cls.n / cls.unit)):
cls.mov(res[i], other[i])
return res
else:
return super(cbits, cls).conv(other)
types = {}
def load_int(self, value):
if self.n <= 64:
tmp = regint(value)
elif value == self.long_one():
tmp = cint(1, size=self.n)
else:
raise CompilerError('loading long integers to cbits not supported')
n_limbs = math.ceil(self.n / self.unit)
tmp = regint(size=n_limbs)
for i in range(n_limbs):
tmp[i].load_int(value % 2 ** self.unit)
value >>= self.unit
self.load_other(tmp)
def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
@@ -270,8 +315,15 @@ class cbits(bits):
return op(self, cbits(other))
__add__ = lambda self, other: \
self.clear_op(other, inst.addcb, inst.addcbi, operator.add)
__sub__ = lambda self, other: \
self.clear_op(-other, inst.addcb, inst.addcbi, operator.add)
def __sub__(self, other):
try:
return self + -other
except:
return type(self)(regint(self) - regint(other))
def __rsub__(self, other):
return type(self)(other - regint(self))
def __neg__(self):
return type(self)(-regint(self))
def _xor(self, other):
if isinstance(other, (sbits, sbitvec)):
return NotImplemented
@@ -363,7 +415,6 @@ class sbits(bits):
reg_type = 'sb'
is_clear = False
clear_type = cbits
default_type = cbits
load_inst = (inst.ldmsbi, inst.ldmsb)
store_inst = (inst.stmsbi, inst.stmsb)
bitdec = inst.bitdecs
@@ -385,16 +436,25 @@ class sbits(bits):
else:
return sbits.get_type(n)(value)
@staticmethod
def _new(value):
return value
@staticmethod
def get_random_bit():
res = sbit()
inst.bitb(res)
return res
@staticmethod
def _check_input_player(player):
if not util.is_constant(player):
raise CompilerError('player must be known at compile time '
'for binary circuit inputs')
@classmethod
def get_input_from(cls, player, n_bits=None):
""" Secret input from :py:obj:`player`.
:param: player (int)
"""
cls._check_input_player(player)
if n_bits is None:
n_bits = cls.n
res = cls()
@@ -463,6 +523,8 @@ class sbits(bits):
if isinstance(other, int):
return self.mul_int(other)
try:
if (self.n, other.n) == (1, 1):
return self & other
if min(self.n, other.n) != 1:
raise NotImplementedError('high order multiplication')
n = max(self.n, other.n)
@@ -554,7 +616,15 @@ class sbits(bits):
rows = list(rows)
if len(rows) == 1 and rows[0].n <= rows[0].unit:
return rows[0].bit_decompose()
n_columns = rows[0].n
for row in rows:
try:
n_columns = row.n
break
except:
pass
for i in range(len(rows)):
if util.is_zero(rows[i]):
rows[i] = cls.get_type(n_columns)(0)
for row in rows:
assert(row.n == n_columns)
if n_columns == 1 and len(rows) <= cls.unit:
@@ -578,7 +648,7 @@ class sbits(bits):
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
class sbitvec(_vec):
class sbitvec(_vec, _bit):
""" Vector of registers of secret bits, effectively a matrix of secret bits.
This facilitates parallel arithmetic operations in binary circuits.
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -586,7 +656,7 @@ class sbitvec(_vec):
You can access the rows by member :py:obj:`v` and the columns by calling
:py:obj:`elements`.
There are three ways to create an instance:
There are four ways to create an instance:
1. By transposition::
@@ -619,8 +689,14 @@ class sbitvec(_vec):
This should output::
[1, 0, 1]
4. Private input::
x = sbitvec.get_type(32).get_input_from(player)
"""
bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v)))
is_clear = False
@classmethod
def get_type(cls, n):
""" Create type for fixed-length vector of registers of secret bits.
@@ -634,17 +710,28 @@ class sbitvec(_vec):
return sbit.malloc(size * n, creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
@staticmethod
def mem_size():
return n
@classmethod
def get_input_from(cls, player):
def get_input_from(cls, player, size=1, f=0):
""" Secret input from :py:obj:`player`. The input is decomposed
into bits.
:param: player (int)
"""
res = cls.from_vec(sbit() for i in range(n))
inst.inputbvec(n + 3, 0, player, *res.v)
return res
v = [0] * n
sbits._check_input_player(player)
instructions_base.check_vector_size(size)
for i in range(size):
vv = [sbit() for i in range(n)]
inst.inputbvec(n + 3, f, player, *vv)
for j in range(n):
tmp = vv[j] << i
v[j] = tmp ^ v[j]
sbits._check_input_player(player)
return cls.from_vec(v)
get_raw_input_from = get_input_from
@classmethod
def from_vec(cls, vector):
@@ -652,47 +739,54 @@ class sbitvec(_vec):
res.v = _complement_two_extend(list(vector), n)[:n]
return res
def __init__(self, other=None, size=None):
assert size in (None, 1)
instructions_base.check_vector_size(size)
if other is not None:
if util.is_constant(other):
self.v = [sbit((other >> i) & 1) for i in range(n)]
t = sbits.get_type(size or 1)
self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1))
for i in range(n)]
elif isinstance(other, _vec):
self.v = self.bit_extend(other.v, n)
self.v = [type(x)(x) for x in self.bit_extend(other.v, n)]
elif isinstance(other, (list, tuple)):
self.v = self.bit_extend(sbitvec(other).v, n)
else:
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
assert size is None or size == self.v[0].n
@classmethod
def load_mem(cls, address):
def load_mem(cls, address, size=None):
if size not in (None, 1):
assert isinstance(address, int) or len(address) == 1
sb = sbits.get_type(size)
return cls.from_vec(sb.bit_compose(
sbit.load_mem(address + i + j * n) for j in range(size))
for i in range(n))
if not isinstance(address, int) and len(address) == n:
return cls.from_vec(sbit.load_mem(x) for x in address)
else:
return cls.from_vec(sbit.load_mem(address + i)
for i in range(n))
def store_in_mem(self, address):
size = 1
for x in self.v:
assert util.is_constant(x) or x.n == 1
v = [sbit.conv(x) for x in self.v]
if not util.is_constant(x):
size = max(size, x.n)
v = [sbits.get_type(size).conv(x) for x in self.v]
if not isinstance(address, int) and len(address) == n:
assert max_n == 1
for x, y in zip(v, address):
x.store_in_mem(y)
else:
assert isinstance(address, int) or len(address) == 1
for i in range(n):
v[i].store_in_mem(address + i)
for j, x in enumerate(v[i].bit_decompose()):
x.store_in_mem(address + i + j * n)
def reveal(self):
if len(self) > cbits.unit:
return self.elements()[0].reveal()
revealed = [cbit() for i in range(len(self))]
for i in range(len(self)):
try:
inst.reveal(1, revealed[i], self.v[i])
except:
revealed[i] = cbit.conv(self.v[i])
return cbits.get_type(len(self)).bit_compose(revealed)
return util.untuplify([x.reveal() for x in self.elements()])
@classmethod
def two_power(cls, nn):
return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1))
def two_power(cls, nn, size=1):
return cls.from_vec(
[0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1))
def coerce(self, other):
if util.is_constant(other):
return self.from_vec(util.bit_decompose(other, n))
@@ -705,8 +799,12 @@ class sbitvec(_vec):
bits += [0] * (n - len(bits))
assert len(bits) == n
return cls.from_vec(bits)
def zero_if_not(self, condition):
return self.from_vec(x.zero_if_not(condition) for x in self.v)
def __str__(self):
return 'sbitvec(%d)' % n
sbitvecn.basic_type = sbitvecn
sbitvecn.reg_type = 'sb'
return sbitvecn
@classmethod
def from_vec(cls, vector):
@@ -723,6 +821,15 @@ class sbitvec(_vec):
def from_matrix(cls, matrix):
# any number of rows, limited number of columns
return cls.combine(cls(row) for row in matrix)
@classmethod
def from_hex(cls, string):
""" Create from hexadecimal string (little-endian). """
assert len(string) % 2 == 0
v = []
for i in range(0, len(string), 2):
v += [sbit(int(x))
for x in reversed(bin(int(string[i:i + 2], 16))[2:].zfill(8))]
return cls.from_vec(v)
def __init__(self, elements=None, length=None, input_length=None):
if length:
assert isinstance(elements, sint)
@@ -769,19 +876,20 @@ class sbitvec(_vec):
size = other.size
return (other.get_vector(base, min(64, size - base)) \
for base in range(0, size, 64))
if not isinstance(other, type(self)):
return type(self)(other)
return other
def __xor__(self, other):
other = self.coerce(other)
return self.from_vec(x ^ y for x, y in zip(self.v, other))
return self.from_vec(x ^ y for x, y in zip(*self.expand(other)))
def __and__(self, other):
return self.from_vec(x & y for x, y in zip(self.v, other.v))
return self.from_vec(x & y for x, y in zip(*self.expand(other)))
__rxor__ = __xor__
__rand__ = __and__
def __invert__(self):
return self.from_vec(~x for x in self.v)
def if_else(self, x, y):
assert(len(self.v) == 1)
try:
return self.from_vec(util.if_else(self.v[0], a, b) \
for a, b in zip(x, y))
except:
return util.if_else(self.v[0], x, y)
return util.if_else(self.v[0], x, y)
def __iter__(self):
return iter(self.v)
def __len__(self):
@@ -794,6 +902,7 @@ class sbitvec(_vec):
return cls.from_vec(other.v)
else:
return cls(other)
hard_conv = conv
@property
def size(self):
if not self.v or util.is_constant(self.v[0]):
@@ -806,7 +915,7 @@ class sbitvec(_vec):
def store_in_mem(self, address):
for i, x in enumerate(self.elements()):
x.store_in_mem(address + i)
def bit_decompose(self, n_bits=None, security=None):
def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None):
return self.v[:n_bits]
bit_compose = from_vec
def reveal(self):
@@ -823,6 +932,34 @@ class sbitvec(_vec):
def __mul__(self, other):
if isinstance(other, int):
return self.from_vec(x * other for x in self.v)
if isinstance(other, sbitvec):
if len(other.v) == 1:
other = other.v[0]
elif len(self.v) == 1:
self, other = other, self.v[0]
else:
raise CompilerError('no operand of lenght 1: %d/%d',
(len(self.v), len(other.v)))
if not isinstance(other, sbits):
return NotImplemented
ops = []
for x in self.v:
if not util.is_zero(x):
assert x.n == other.n
ops.append(x)
if ops:
prods = [sbits.get_type(other.n)() for i in ops]
inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops)
res = []
i = 0
for x in self.v:
if util.is_zero(x):
res.append(0)
else:
res.append(prods[i])
i += 1
return sbitvec.from_vec(res)
__rmul__ = __mul__
def __add__(self, other):
return self.from_vec(x + y for x, y in zip(self.v, other))
def bit_and(self, other):
@@ -831,6 +968,60 @@ class sbitvec(_vec):
return self ^ other
def right_shift(self, m, k, security=None, signed=True):
return self.from_vec(self.v[m:])
def tree_reduce(self, function):
elements = self.elements()
while len(elements) > 1:
size = len(elements)
half = size // 2
left = elements[:half]
right = elements[half:2*half]
odd = elements[2*half:]
sides = [self.from_vec(sbitvec(x).v) for x in (left, right)]
red = function(*sides)
elements = red.elements()
elements += odd
return self.from_vec(sbitvec(elements).v)
@classmethod
def comp_result(cls, x):
return cls.get_type(1).from_vec([x])
def expand(self, other, expand=True):
m = 1
for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []):
try:
m = max(m, x.n)
except:
pass
res = []
if not util.is_constant(other):
other = self.coerce(other)
for y in self, other:
if isinstance(y, int):
res.append([x * sbits.get_type(m)().long_one()
for x in util.bit_decompose(y, len(self.v))])
else:
res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v])
return res
def demux(self):
if len(self) == 1:
return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]])
a = sbitvec.from_vec(self.v[:len(self) // 2]).demux()
b = sbitvec.from_vec(self.v[len(self) // 2:]).demux()
prod = [a * bb for bb in b.v]
return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod)))
def reverse_bytes(self):
if len(self.v) % 8 != 0:
raise CompilerError('bit length not divisible by eight')
return self.from_vec(sum(reversed(
[self.v[i:i + 8] for i in range(0, len(self.v), 8)]), []))
def reveal_print_hex(self):
""" Reveal and print in hexademical (one line per element). """
for x in self.reverse_bytes().elements():
x.reveal().print_reg()
def update(self, other):
other = self.conv(other)
assert len(self.v) == len(other.v)
for x, y in zip(self.v, other.v):
x.update(y)
class bit(object):
n = 1
@@ -881,10 +1072,11 @@ class cbit(bit, cbits):
sbits.bit_type = sbit
cbits.bit_type = cbit
sbit.clear_type = cbit
sbits.default_type = sbits
class bitsBlock(oram.Block):
value_type = sbits
def __init__(self, value, start, lengths, entries_per_block):
self.value_type = type(value)
oram.Block.__init__(self, value, lengths)
length = sum(self.lengths)
used_bits = entries_per_block * length
@@ -929,7 +1121,10 @@ sbits.dynamic_array = DynamicArray
cbits.dynamic_array = Array
def _complement_two_extend(bits, k):
return bits[:k] + [bits[-1]] * (k - len(bits))
if len(bits) == 1:
return bits + [0] * (k - len(bits))
else:
return bits[:k] + [bits[-1]] * (k - len(bits))
class _sbitintbase:
def extend(self, n):
@@ -988,6 +1183,9 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
mul: 15
lt: 0
This class is retained for compatibility, but development now
focuses on :py:class:`sbitintvec`.
"""
n_bits = None
bin_type = None
@@ -1079,7 +1277,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
:param k: bit length of input """
return _sbitintbase.pow2(self, k)
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
"""
Vector of signed integers for parallel binary computation::
@@ -1114,19 +1312,34 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
def __add__(self, other):
if util.is_zero(other):
return self
other = self.coerce(other)
assert(len(self.v) == len(other.v))
v = sbitint.bit_adder(self.v, other.v)
return self.from_vec(v)
a, b = self.expand(other)
v = sbitint.bit_adder(a, b)
return self.get_type(len(v)).from_vec(v)
__radd__ = __add__
__sub__ = _bitint.__sub__
def __rsub__(self, other):
a, b = self.expand(other)
return self.from_vec(b) - self.from_vec(a)
def __mul__(self, other):
if isinstance(other, sbits):
return self.from_vec(other * x for x in self.v)
elif len(self.v) == 1:
return other * self.v[0]
elif isinstance(other, sbitfixvec):
return NotImplemented
my_bits, other_bits = self.expand(other, False)
matrix = []
for i, b in enumerate(util.bit_decompose(other)):
matrix.append([x & b for x in self.v[:len(self.v)-i]])
m = float('inf')
for x in itertools.chain(my_bits, other_bits):
try:
m = min(m, x.n)
except:
pass
for i, b in enumerate(other_bits):
if m == 1:
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
else:
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v[:len(self.v)])
__rmul__ = __mul__
@@ -1157,22 +1370,27 @@ class cbitfix(object):
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
@classmethod
def _new(cls, value):
if isinstance(value, list):
return [cls._new(x) for x in value]
res = cls()
if cls.k < value.unit:
bits = value.bit_decompose(cls.k)
sign = bits[-1]
value += (sign << (cls.k)) * -1
res.v = value
return res
def output(self):
v = self.v
if self.k < v.unit:
bits = self.v.bit_decompose(self.k)
sign = bits[-1]
v += (sign << (self.k)) * -1
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))
class sbitfix(_fix):
""" Secret signed integer in one binary register.
""" Secret signed fixed-point number in one binary register.
Use :py:obj:`set_precision()` to change the precision.
This class is retained for compatibility, but development now
focuses on :py:class:`sbitfixvec`.
Example::
print_ln('add: %s', (sbitfix(0.5) + sbitfix(0.3)).reveal())
@@ -1211,6 +1429,7 @@ class sbitfix(_fix):
:param: player (int)
"""
sbits._check_input_player(player)
v = cls.int_type()
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
@@ -1233,7 +1452,7 @@ class sbitfix(_fix):
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)
class sbitfixvec(_fix):
class sbitfixvec(_fix, _vec):
""" Vector of fixed-point numbers for parallel binary computation.
Use :py:obj:`set_precision()` to change the precision.
@@ -1262,23 +1481,27 @@ class sbitfixvec(_fix):
int_type = sbitintvec.get_type(sbitfix.k)
float_type = type(None)
clear_type = cbitfix
@property
def bit_type(self):
return type(self.v[0])
@classmethod
def set_precision(cls, f, k=None):
super(sbitfixvec, cls).set_precision(f=f, k=k)
cls.int_type = sbitintvec.get_type(cls.k)
@classmethod
def get_input_from(cls, player):
def get_input_from(cls, player, size=1):
""" Secret input from :py:obj:`player`.
:param: player (int)
"""
v = [sbit() for i in range(sbitfix.k)]
inst.inputbvec(len(v) + 3, sbitfix.f, player, *v)
return cls._new(cls.int_type.from_vec(v))
return cls._new(cls.int_type.get_input_from(player, size=size,
f=sbitfix.f))
def __init__(self, value=None, *args, **kwargs):
if isinstance(value, (list, tuple)):
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))
else:
if isinstance(value, sbitvec):
value = self.int_type(value)
super(sbitfixvec, self).__init__(value, *args, **kwargs)
def elements(self):
return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()]
@@ -1288,9 +1511,12 @@ class sbitfixvec(_fix):
else:
return super(sbitfixvec, self).mul(other)
def __xor__(self, other):
if util.is_zero(other):
return self
return self._new(self.v ^ other.v)
def __and__(self, other):
return self._new(self.v & other.v)
__rxor__ = __xor__
@staticmethod
def multipliable(other, k, f, size):
class cls(_fix):

View File

@@ -2,30 +2,3 @@ from . import compilerLib, program, instructions, types, library, floatingpoint
from .GC import types as GC_types
import inspect
from .config import *
from .compilerLib import run
# add all instructions to the program VARS dictionary
compilerLib.VARS = {}
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
for mod in (types, GC_types):
instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\
if t[1].__module__ == mod.__name__]
instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\
if t[1].__module__ == library.__name__]
for op in instr_classes:
compilerLib.VARS[op.__name__] = op
# add open and input separately due to name conflict
compilerLib.VARS['open'] = instructions.asm_open
compilerLib.VARS['vopen'] = instructions.vasm_open
compilerLib.VARS['gopen'] = instructions.gasm_open
compilerLib.VARS['vgopen'] = instructions.vgasm_open
compilerLib.VARS['input'] = instructions.asm_input
compilerLib.VARS['ginput'] = instructions.gasm_input
compilerLib.VARS['comparison'] = comparison
compilerLib.VARS['floatingpoint'] = floatingpoint

View File

@@ -15,11 +15,11 @@ from functools import reduce
class BlockAllocator:
""" Manages freed memory blocks. """
def __init__(self):
self.by_logsize = [defaultdict(set) for i in range(32)]
self.by_logsize = [defaultdict(set) for i in range(64)]
self.by_address = {}
def by_size(self, size):
if size >= 2 ** 32:
if size >= 2 ** 64:
raise CompilerError('size exceeds addressing capability')
return self.by_logsize[int(math.log(size, 2))][size]
@@ -101,6 +101,7 @@ class StraightlineAllocator:
self.dealloc |= reg.vector
else:
self.dealloc.add(reg)
reg.duplicates.remove(reg)
base = reg.vectorbase
seen = set_by_id()
@@ -171,7 +172,7 @@ class StraightlineAllocator:
for reg in self.alloc:
for x in reg.get_all():
if x not in self.dealloc and reg not in self.dealloc \
and len(x.duplicates) == 1:
and len(x.duplicates) == 0:
print('Warning: read before write at register', x)
print('\tregister trace: %s' % format_trace(x.caller,
'\t\t'))
@@ -261,6 +262,7 @@ class Merger:
instructions = self.instructions
merge_nodes = self.open_nodes
depths = self.depths
self.req_num = defaultdict(lambda: 0)
if not merge_nodes:
return 0
@@ -281,6 +283,7 @@ class Merger:
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))
self.do_merge(merge)
self.req_num[t.__name__, 'round'] += 1
preorder = None
@@ -310,9 +313,9 @@ class Merger:
reg_nodes = {}
last_def = defaultdict_by_id(lambda: -1)
last_read = defaultdict_by_id(list)
last_mem_write = []
last_mem_read = []
warned_about_mem = []
last_mem_write_of = defaultdict(list)
last_mem_read_of = defaultdict(list)
last_print_str = None
@@ -329,6 +332,8 @@ class Merger:
round_type = {}
def add_edge(i, j):
if i in (-1, j):
return
G.add_edge(i, j)
for d in (self.depths, self.real_depths):
if d[j] < d[i]:
@@ -336,10 +341,15 @@ class Merger:
def read(reg, n):
for dup in reg.duplicates:
if last_def[dup] != -1:
if last_def[dup] not in (-1, n):
add_edge(last_def[dup], n)
last_read[reg].append(n)
def write(reg, n):
for dup in reg.duplicates:
add_edge(last_def[dup], n)
for m in last_read[dup]:
add_edge(m, n)
last_def[reg] = n
def handle_mem_access(addr, reg_type, last_access_this_kind,
@@ -361,20 +371,22 @@ class Merger:
addr_i = addr + i
handle_mem_access(addr_i, reg_type, last_access_this_kind,
last_access_other_kind)
if block.warn_about_mem and not warned_about_mem and \
(instr.get_size() > 100):
if block.warn_about_mem and \
not block.parent.warned_about_mem and \
(instr.get_size() > 100) and not instr._protect:
print('WARNING: Order of memory instructions ' \
'not preserved due to long vector, errors possible')
warned_about_mem.append(True)
block.parent.warned_about_mem = True
else:
handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind)
if block.warn_about_mem and not warned_about_mem and \
not isinstance(instr, DirectMemoryInstruction):
if block.warn_about_mem and \
not block.parent.warned_about_mem and \
not isinstance(instr, DirectMemoryInstruction) and \
not instr._protect:
print('WARNING: Order of memory instructions ' \
'not preserved, errors possible')
# hack
warned_about_mem.append(True)
block.parent.warned_about_mem = True
def strict_mem_access(n, last_this_kind, last_other_kind):
if last_other_kind and last_this_kind and \
@@ -403,6 +415,20 @@ class Merger:
add_edge(last_input[t][1], n)
last_input[t][0] = n
def keep_text_order(inst, n):
if inst.get_players() is None:
# switch
for x in list(last_input.keys()):
if isinstance(x, int):
add_edge(last_input[x][0], n)
del last_input[x]
keep_merged_order(instr, n, None)
elif last_input[None][0] is not None:
keep_merged_order(instr, n, None)
else:
for player in inst.get_players():
keep_merged_order(instr, n, player)
for n,instr in enumerate(block.instructions):
outputs,inputs = instr.get_def(), instr.get_used()
@@ -411,13 +437,6 @@ class Merger:
# if options.debug:
# col = colordict[instr.__class__.__name__]
# G.add_node(n, color=col, label=str(instr))
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
for reg in outputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
@@ -425,9 +444,16 @@ class Merger:
else:
write(reg, n)
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
# will be merged
if isinstance(instr, TextInputInstruction):
keep_merged_order(instr, n, TextInputInstruction)
keep_text_order(instr, n)
elif isinstance(instr, RawInputInstruction):
keep_merged_order(instr, n, RawInputInstruction)
@@ -456,14 +482,14 @@ class Merger:
depths[n] = depth
if isinstance(instr, ReadMemoryInstruction):
if options.preserve_mem_order:
if options.preserve_mem_order or instr._protect:
strict_mem_access(n, last_mem_read, last_mem_write)
else:
elif not options.preserve_mem_order:
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
elif isinstance(instr, WriteMemoryInstruction):
if options.preserve_mem_order:
if options.preserve_mem_order or instr._protect:
strict_mem_access(n, last_mem_write, last_mem_read)
else:
elif not options.preserve_mem_order:
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
elif isinstance(instr, matmulsm):
if options.preserve_mem_order:
@@ -478,11 +504,7 @@ class Merger:
add_edge(last_print_str, n)
last_print_str = n
elif isinstance(instr, PublicFileIOInstruction):
keep_order(instr, n, instr.__class__)
elif isinstance(instr, startprivateoutput_class):
keep_order(instr, n, startprivateoutput_class, 2)
elif isinstance(instr, stopprivateoutput_class):
keep_order(instr, n, stopprivateoutput_class, 2)
keep_order(instr, n, PublicFileIOInstruction)
elif isinstance(instr, prep_class):
keep_order(instr, n, instr.args[0])
elif isinstance(instr, StackInstruction):
@@ -520,7 +542,9 @@ class Merger:
can_eliminate_defs = True
for reg in inst.get_def():
for dup in reg.duplicates:
if not dup.can_eliminate:
if not (dup.can_eliminate and reduce(
operator.and_,
(x.can_eliminate for x in dup.vector), True)):
can_eliminate_defs = False
break
# remove if instruction has result that isn't used
@@ -535,18 +559,6 @@ class Merger:
if unused_result:
eliminate(i)
count += 1
# remove unnecessary stack instructions
# left by optimization with budget
if isinstance(inst, popint_class) and \
(not G.degree(i) or (G.degree(i) == 1 and
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
and \
inst.args[0].can_eliminate and \
len(G.pred[i]) == 1 and \
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
eliminate(list(G.pred[i])[0])
eliminate(i)
count += 2
if count > 0 and self.block.parent.program.verbose:
print('Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats)))
@@ -570,8 +582,15 @@ class Merger:
class RegintOptimizer:
def __init__(self):
self.cache = util.dict_by_id()
self.offset_cache = util.dict_by_id()
self.rev_offset_cache = {}
def run(self, instructions):
def add_offset(self, res, new_base, new_offset):
self.offset_cache[res] = new_base, new_offset
if (new_base.i, new_offset) not in self.rev_offset_cache:
self.rev_offset_cache[new_base.i, new_offset] = res
def run(self, instructions, program):
for i, inst in enumerate(instructions):
if isinstance(inst, ldint_class):
self.cache[inst.args[0]] = inst.args[1]
@@ -584,15 +603,35 @@ class RegintOptimizer:
instructions[i] = ldint(inst.args[0], res,
add_to_prog=False)
elif isinstance(inst, addint_class):
if inst.args[1] in self.cache and \
self.cache[inst.args[1]] == 0:
instructions[i] = inst.args[0].link(inst.args[2])
elif inst.args[2] in self.cache and \
self.cache[inst.args[2]] == 0:
instructions[i] = inst.args[0].link(inst.args[1])
def f(base, delta_reg):
delta = self.cache[delta_reg]
if base in self.offset_cache:
reg, offset = self.offset_cache[base]
new_base, new_offset = reg, offset + delta
else:
new_base, new_offset = base, delta
self.add_offset(inst.args[0], new_base, new_offset)
if inst.args[1] in self.cache:
f(inst.args[2], inst.args[1])
elif inst.args[2] in self.cache:
f(inst.args[1], inst.args[2])
elif isinstance(inst, subint_class) and \
inst.args[2] in self.cache:
delta = self.cache[inst.args[2]]
if inst.args[1] in self.offset_cache:
reg, offset = self.offset_cache[inst.args[1]]
new_base, new_offset = reg, offset - delta
else:
new_base, new_offset = inst.args[1], -delta
self.add_offset(inst.args[0], new_base, new_offset)
elif isinstance(inst, IndirectMemoryInstruction):
if inst.args[1] in self.cache:
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
instructions[i]._protect = inst._protect
elif inst.args[1] in self.offset_cache:
base, offset = self.offset_cache[inst.args[1]]
addr = self.rev_offset_cache[base.i, offset]
inst.args[1] = addr
elif type(inst) == convint_class:
if inst.args[1] in self.cache:
res = self.cache[inst.args[1]]
@@ -606,7 +645,13 @@ class RegintOptimizer:
if op == 0:
instructions[i] = ldsi(inst.args[0], 0,
add_to_prog=False)
elif op == 1:
elif isinstance(inst, (crash, cond_print_str, cond_print_plain)):
if inst.args[0] in self.cache:
cond = self.cache[inst.args[0]]
if not cond:
instructions[i] = None
inst.args[0].link(inst.args[1])
pre = len(instructions)
instructions[:] = list(filter(lambda x: x is not None, instructions))
post = len(instructions)
if pre != post and program.options.verbose:
print('regint optimizer removed %d instructions' % (pre - post))

View File

@@ -63,7 +63,7 @@ class Circuit:
i = 0
for l in self.n_output_wires:
v = []
for i in range(l):
for j in range(l):
v.append(flat_res[i])
i += 1
res.append(sbitvec.from_vec(v))
@@ -127,18 +127,24 @@ def sha3_256(x):
from circuit import sha3_256
a = sbitvec.from_vec([])
b = sbitvec(sint(0xcc), 8, 8)
for x in a, b:
sha3_256(x).elements()[0].reveal().print_reg()
b = sbitvec.from_hex('cc')
c = sbitvec.from_hex('41fb')
d = sbitvec.from_hex('1f877c')
e = sbitvec.from_vec([sbit(0)] * 8)
for x in a, b, c, d, e:
sha3_256(x).reveal_print_hex()
This should output the first two test vectors of SHA3-256 in
byte-reversed order::
This should output the `test vectors
<https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the
0 byte::
0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067
Reg[0] = 0xa7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a #
Reg[0] = 0x677035391cd3701293d385f037ba32796252bb7ce180b00b582dd9b20aaad7f0 #
Reg[0] = 0x39f31b6e653dfcd9caed2602fd87f61b6254f581312fb6eeec4d7148fa2e72aa #
Reg[0] = 0xbc22345e4bd3f792a341cf18ac0789f1c9c966712a501b19d1b6632ccd408ec5 #
Reg[0] = 0x5d53469f20fef4f8eab52b88044ede69c77a6a68a60728609fc4a65ff531e7d0 #
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
implemented for computation modulo a power of two.
"""
global Keccak_f
@@ -236,10 +242,10 @@ class ieee_float:
return cls._circuits[name]
def __init__(self, value):
if isinstance(value, sbitvec):
if isinstance(value, (sbitint, sbitintvec)):
self.value = self.circuit('i2f')(sbitvec.conv(value))
elif isinstance(value, sbitvec):
self.value = value
elif isinstance(value, (sbitint, sbitintvec)):
self.value = self.circuit('i2f')(sbitvec(value))
elif util.is_constant_float(value):
self.value = sbitvec(sbits.get_type(64)(
struct.unpack('Q', struct.pack('d', value))[0]))

View File

@@ -1,5 +1,6 @@
from Compiler.path_oram import *
from Compiler.oram import *
from Compiler.path_oram import PathORAM, XOR
from Compiler.util import bit_compose
def first_diff(a_bits, b_bits):

View File

@@ -50,6 +50,9 @@ def set_variant(options):
do_precomp = False
elif variant is not None:
raise CompilerError('Unknown comparison variant: %s' % variant)
if const_rounds and instructions_base.program.options.binary:
raise CompilerError(
'Comparison variant choice incompatible with binary circuits')
def ld2i(c, n):
""" Load immediate 2^n into clear GF(p) register c """
@@ -77,30 +80,28 @@ def LTZ(s, a, k, kappa):
k: bit length of a
"""
movs(s, program.non_linear.ltz(a, k, kappa))
def LtzRing(a, k):
from .types import sint, _bitint
from .GC.types import sbitvec
if program.use_split():
summands = a.split_to_two_summands(k)
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
msb = carry ^ summands[0][-1] ^ summands[1][-1]
movs(s, sint.conv(msb))
return
elif program.options.ring:
return sint.conv(msb)
else:
from . import floatingpoint
require_ring_size(k, 'comparison')
m = k - 1
shift = int(program.options.ring) - k
r_prime, r_bin = MaskingBitsInRing(k)
tmp = a - r_prime
c_prime = (tmp << shift).reveal() >> shift
c_prime = (tmp << shift).reveal(False) >> shift
a = r_bin[0].bit_decompose_clear(c_prime, m)
b = r_bin[:m]
u = CarryOutRaw(a[::-1], b[::-1])
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
return
t = sint()
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
def LessThanZero(a, k, kappa):
from . import types
@@ -191,7 +192,7 @@ def TruncLeakyInRing(a, k, m, signed):
r = sint.bit_compose(r_bits)
if signed:
a += (1 << (k - 1))
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
masked = shifted >> n_shift
u = sint()
BitLTL(u, masked, r_bits[:n_bits], 0)
@@ -232,7 +233,7 @@ def Mod2mRing(a_prime, a, k, m, signed):
shift = int(program.options.ring) - m
r_prime, r_bin = MaskingBitsInRing(m, True)
tmp = a + r_prime
c_prime = (tmp << shift).reveal() >> shift
c_prime = (tmp << shift).reveal(False) >> shift
u = sint()
BitLTL(u, c_prime, r_bin[:m], 0)
res = (u << m) + c_prime - r_prime
@@ -262,7 +263,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
t[1] = a
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
asm_open(c, t[3])
asm_open(True, c, t[3])
modc(c_prime, c, c2m)
if const_rounds:
BitLTC1(u, c_prime, r, kappa)
@@ -293,7 +294,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
"""
program.curr_tape.require_bit_length(k + kappa)
from .types import sint
if program.use_edabit() and m > 1 and not const_rounds:
if program.use_edabit() and not const_rounds:
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
tmp, b[:] = sint.get_edabit(m, True)
movs(r_prime, tmp)
@@ -511,7 +512,7 @@ def PreMulC_with_inverses_and_vectors(p, a):
movs(w[0], r[0])
movs(a_vec[0], a[0])
vmuls(k, t[0], w, a_vec)
vasm_open(k, m, t[0])
vasm_open(k, True, m, t[0])
PreMulC_end(p, a, c, m, z)
def PreMulC_with_inverses(p, a):
@@ -539,7 +540,7 @@ def PreMulC_with_inverses(p, a):
w[1][0] = r[0][0]
for i in range(k):
muls(t[0][i], w[1][i], a[i])
asm_open(m[i], t[0][i])
asm_open(True, m[i], t[0][i])
PreMulC_end(p, a, c, m, z)
def PreMulC_without_inverses(p, a):
@@ -564,7 +565,7 @@ def PreMulC_without_inverses(p, a):
#adds(tt[0][i], t[0][i], a[i])
#subs(tt[1][i], tt[0][i], a[i])
#startopen(tt[1][i])
asm_open(u[i], t[0][i])
asm_open(True, u[i], t[0][i])
for i in range(k-1):
muls(v[i], r[i+1], s[i])
w[0] = r[0]
@@ -580,7 +581,7 @@ def PreMulC_without_inverses(p, a):
mulm(z[i], s[i], u_inv[i])
for i in range(k):
muls(t[1][i], w[i], a[i])
asm_open(m[i], t[1][i])
asm_open(True, m[i], t[1][i])
PreMulC_end(p, a, c, m, z)
def PreMulC_end(p, a, c, m, z):
@@ -638,6 +639,7 @@ def Mod2(a_0, a, k, kappa, signed):
t = [program.curr_block.new_reg('s') for i in range(6)]
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
r_0 = r_prime
mulsi(t[0], r_dprime, 2)
if signed:
ld2i(c2k1, k - 1)
@@ -646,7 +648,7 @@ def Mod2(a_0, a, k, kappa, signed):
t[1] = a
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
asm_open(c, t[3])
asm_open(True, c, t[3])
from . import floatingpoint
c_0 = floatingpoint.bits(c, 1)[0]
mulci(tc, c_0, 2)

View File

@@ -1,94 +1,568 @@
from Compiler.program import Program
from .GC import types as GC_types
import inspect
import os
import re
import sys
import re, tempfile, os
import tempfile
import subprocess
from optparse import OptionParser
from Compiler.exceptions import CompilerError
from .GC import types as GC_types
from .program import Program, defaults
def run(args, options):
""" Compile a file and output a Program object.
If options.merge_opens is set to True, will attempt to merge any
parallelisable open instructions. """
prog = Program(args, options)
VARS['program'] = prog
if options.binary:
VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary))
VARS['sfix'] = GC_types.sbitfixvec
for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \
'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \
'squant':
del VARS[i]
print('Compiling file', prog.infile)
f = open(prog.infile, 'rb')
changed = False
if options.flow_optimization:
output = []
if_stack = []
for line in open(prog.infile):
if if_stack and not re.match(if_stack[-1][0], line):
if_stack.pop()
m = re.match(
'(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):',
line)
if m:
output.append('%s@for_range_opt(%s)\n' % (m.group(1),
m.group(3)))
output.append('%sdef _(%s):\n' % (m.group(1), m.group(2)))
changed = True
continue
m = re.match('(\s*)if(\W.*):', line)
if m:
if_stack.append((m.group(1), len(output)))
output.append('%s@if_(%s)\n' % (m.group(1), m.group(2)))
output.append('%sdef _():\n' % (m.group(1)))
changed = True
continue
m = re.match('(\s*)elif\s+', line)
if m:
raise CompilerError('elif not supported')
if if_stack:
m = re.match('%selse:' % if_stack[-1][0], line)
if m:
start = if_stack[-1][1]
ws = if_stack[-1][0]
output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws,
output[start])
output.append('%s@else_\n' % ws)
output.append('%sdef _():\n' % ws)
continue
output.append(line)
if changed:
infile = tempfile.NamedTemporaryFile('w+', delete=False)
for line in output:
infile.write(line)
infile.seek(0)
class Compiler:
def __init__(self, custom_args=None, usage=None, execute=False):
if usage:
self.usage = usage
else:
infile = open(prog.infile)
else:
infile = open(prog.infile)
self.usage = "usage: %prog [options] filename [args]"
self.execute = execute
self.custom_args = custom_args
self.build_option_parser()
self.VARS = {}
self.root = os.path.dirname(__file__) + '/..'
# make compiler modules directly accessible
sys.path.insert(0, 'Compiler')
# create the tapes
exec(compile(infile.read(), infile.name, 'exec'), VARS)
def build_option_parser(self):
parser = OptionParser(usage=self.usage)
parser.add_option(
"-n",
"--nomerge",
action="store_false",
dest="merge_opens",
default=defaults.merge_opens,
help="don't attempt to merge open instructions",
)
parser.add_option("-o", "--output", dest="outfile", help="specify output file")
parser.add_option(
"-a",
"--asm-output",
dest="asmoutfile",
help="asm output file for debugging",
)
parser.add_option(
"-g",
"--galoissize",
dest="galois",
default=defaults.galois,
help="bit length of Galois field",
)
parser.add_option(
"-d",
"--debug",
action="store_true",
dest="debug",
help="keep track of trace for debugging",
)
parser.add_option(
"-c",
"--comparison",
dest="comparison",
default="log",
help="comparison variant: log|plain|inv|sinv",
)
parser.add_option(
"-M",
"--preserve-mem-order",
action="store_true",
dest="preserve_mem_order",
default=defaults.preserve_mem_order,
help="preserve order of memory instructions; possible efficiency loss",
)
parser.add_option(
"-O",
"--optimize-hard",
action="store_true",
dest="optimize_hard",
help="lower number of rounds at higher compilation cost "
"(disables -C and increases the budget to 100000)",
)
parser.add_option(
"-u",
"--noreallocate",
action="store_true",
dest="noreallocate",
default=defaults.noreallocate,
help="don't reallocate",
)
parser.add_option(
"-m",
"--max-parallel-open",
dest="max_parallel_open",
default=defaults.max_parallel_open,
help="restrict number of parallel opens",
)
parser.add_option(
"-D",
"--dead-code-elimination",
action="store_true",
dest="dead_code_elimination",
default=defaults.dead_code_elimination,
help="eliminate instructions with unused result",
)
parser.add_option(
"-p",
"--profile",
action="store_true",
dest="profile",
help="profile compilation",
)
parser.add_option(
"-s",
"--stop",
action="store_true",
dest="stop",
help="stop on register errors",
)
parser.add_option(
"-R",
"--ring",
dest="ring",
default=defaults.ring,
help="bit length of ring (default: 0 for field)",
)
parser.add_option(
"-B",
"--binary",
dest="binary",
default=defaults.binary,
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
)
parser.add_option(
"-G",
"--garbled-circuit",
dest="garbled",
action="store_true",
help="compile for binary circuits only (default: false)",
)
parser.add_option(
"-F",
"--field",
dest="field",
default=defaults.field,
help="bit length of sint modulo prime (default: 64)",
)
parser.add_option(
"-P",
"--prime",
dest="prime",
default=defaults.prime,
help="prime modulus (default: not specified)",
)
parser.add_option(
"-I",
"--insecure",
action="store_true",
dest="insecure",
help="activate insecure functionality for benchmarking",
)
parser.add_option(
"-b",
"--budget",
dest="budget",
help="set budget for optimized loop unrolling (default: %d)" % \
defaults.budget,
)
parser.add_option(
"-X",
"--mixed",
action="store_true",
dest="mixed",
help="mixing arithmetic and binary computation",
)
parser.add_option(
"-Y",
"--edabit",
action="store_true",
dest="edabit",
help="mixing arithmetic and binary computation using edaBits",
)
parser.add_option(
"-Z",
"--split",
default=defaults.split,
dest="split",
help="mixing arithmetic and binary computation "
"using direct conversion if supported "
"(number of parties as argument)",
)
parser.add_option(
"--invperm",
action="store_true",
dest="invperm",
help="speedup inverse permutation (only use in two-party, "
"semi-honest environment)"
)
parser.add_option(
"-C",
"--CISC",
action="store_true",
dest="cisc",
help="faster CISC compilation mode "
"(used by default unless -O is given)",
)
parser.add_option(
"-K",
"--keep-cisc",
dest="keep_cisc",
help="don't translate CISC instructions",
)
parser.add_option(
"-l",
"--flow-optimization",
action="store_true",
dest="flow_optimization",
help="optimize control flow",
)
parser.add_option(
"-v",
"--verbose",
action="store_true",
dest="verbose",
help="more verbose output",
)
if self.execute:
parser.add_option(
"-E",
"--execute",
dest="execute",
help="protocol to execute with",
)
parser.add_option(
"-H",
"--hostfile",
dest="hostfile",
help="hosts to execute with",
)
self.parser = parser
if changed and not options.debug:
os.unlink(infile.name)
def parse_args(self):
self.options, self.args = self.parser.parse_args(self.custom_args)
if self.execute:
if not self.options.execute:
raise CompilerError("must give name of protocol with '-E'")
protocol = self.options.execute
if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \
protocol.find("brain") >= 0 or protocol == "emulate":
if not (self.options.ring or self.options.binary):
self.options.ring = "64"
if self.options.field:
raise CompilerError(
"field option not compatible with %s" % protocol)
else:
if protocol.find("bin") >= 0 or protocol.find("ccd") >= 0 or \
protocol.find("bmr") >= 0 or \
protocol in ("replicated", "tinier", "tiny", "yao"):
if not self.options.binary:
self.options.binary = "32"
if self.options.ring or self.options.field:
raise CompilerError(
"ring/field options not compatible with %s" %
protocol)
if self.options.ring:
raise CompilerError(
"ring option not compatible with %s" % protocol)
if protocol == "emulate":
self.options.keep_cisc = ''
prog.finalize()
def build_program(self, name=None):
self.prog = Program(self.args, self.options, name=name)
if self.execute:
if self.options.execute in \
("emulate", "ring", "rep-field"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring",):
self.prog.use_split(3)
if self.options.execute in ("semi2k",):
self.prog.use_split(2)
if self.options.execute in ("rep4-ring",):
self.prog.use_split(4)
if prog.req_num:
print('Program requires:')
for x in prog.req_num.pretty():
print(x)
def build_vars(self):
from . import comparison, floatingpoint, instructions, library, types
if prog.verbose:
print('Program requires:', repr(prog.req_num))
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
print('Memory size:', dict(prog.allocated_mem))
# add all instructions to the program VARS dictionary
instr_classes = [
t[1] for t in inspect.getmembers(instructions, inspect.isclass)
]
return prog
for mod in (types, GC_types):
instr_classes += [
t[1]
for t in inspect.getmembers(mod, inspect.isclass)
if t[1].__module__ == mod.__name__
]
instr_classes += [
t[1]
for t in inspect.getmembers(library, inspect.isfunction)
if t[1].__module__ == library.__name__
]
for op in instr_classes:
self.VARS[op.__name__] = op
# backward compatibility for deprecated classes
self.VARS["sbitint"] = GC_types.sbitintvec
self.VARS["sbitfix"] = GC_types.sbitfixvec
# add open and input separately due to name conflict
self.VARS["vopen"] = instructions.vasm_open
self.VARS["gopen"] = instructions.gasm_open
self.VARS["vgopen"] = instructions.vgasm_open
self.VARS["ginput"] = instructions.gasm_input
self.VARS["comparison"] = comparison
self.VARS["floatingpoint"] = floatingpoint
self.VARS["program"] = self.prog
if self.options.binary:
self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary))
self.VARS["sfix"] = GC_types.sbitfixvec
for i in [
"cint",
"cfix",
"cgf2n",
"sintbit",
"sgf2n",
"sgf2nint",
"sgf2nuint",
"sgf2nuint32",
"sgf2nfloat",
"cfloat",
"squant",
]:
del self.VARS[i]
def prep_compile(self, name=None, build=True):
self.parse_args()
if len(self.args) < 1 and name is None:
self.parser.print_help()
exit(1)
if build:
self.build(name=name)
def build(self, name=None):
self.build_program(name=name)
self.build_vars()
def compile_file(self):
"""Compile a file and output a Program object.
If options.merge_opens is set to True, will attempt to merge any
parallelisable open instructions."""
print("Compiling file", self.prog.infile)
with open(self.prog.infile, "r") as f:
changed = False
if self.options.flow_optimization:
output = []
if_stack = []
for line in f:
if if_stack and not re.match(if_stack[-1][0], line):
if_stack.pop()
m = re.match(
r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_.]+)\):",
line,
)
if m:
output.append(
"%s@for_range_opt(%s)\n" % (m.group(1), m.group(3))
)
output.append("%sdef _(%s):\n" % (m.group(1), m.group(2)))
changed = True
continue
m = re.match(r"(\s*)if(\W.*):", line)
if m:
if_stack.append((m.group(1), len(output)))
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
output.append("%sdef _():\n" % (m.group(1)))
changed = True
continue
m = re.match(r"(\s*)elif\s+", line)
if m:
raise CompilerError("elif not supported")
if if_stack:
m = re.match("%selse:" % if_stack[-1][0], line)
if m:
start = if_stack[-1][1]
ws = if_stack[-1][0]
output[start] = re.sub(
r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start]
)
output.append("%s@else_\n" % ws)
output.append("%sdef _():\n" % ws)
continue
output.append(line)
if changed:
infile = tempfile.NamedTemporaryFile("w+", delete=False)
for line in output:
infile.write(line)
infile.seek(0)
else:
infile = open(self.prog.infile)
else:
infile = open(self.prog.infile)
# make compiler modules directly accessible
sys.path.insert(0, "%s/Compiler" % self.root)
# create the tapes
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
if changed and not self.options.debug:
os.unlink(infile.name)
return self.finalize_compile()
def register_function(self, name=None):
"""
To register a function to be compiled, use this as a decorator.
Example:
@compiler.register_function('test-mpc')
def test_mpc(compiler):
...
"""
def inner(func):
self.compile_name = name or func.__name__
self.compile_function = func
return func
return inner
def compile_func(self):
if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")):
raise CompilerError(
"No function to compile. "
"Did you decorate a function with @register_fuction(name)?"
)
self.prep_compile(self.compile_name)
print(
"Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__)
)
self.compile_function()
self.finalize_compile()
def finalize_compile(self):
self.prog.finalize()
if self.prog.req_num:
print("Program requires at most:")
for x in self.prog.req_num.pretty():
print(x)
if self.prog.verbose:
print("Program requires:", repr(self.prog.req_num))
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
print("Memory size:", dict(self.prog.allocated_mem))
return self.prog
@staticmethod
def executable_from_protocol(protocol):
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
if protocol in match:
protocol = match[protocol]
if protocol.find("bmr") == -1:
protocol = re.sub("^mal-", "malicious-", protocol)
if protocol == "emulate":
return protocol + ".x"
else:
return protocol + "-party.x"
def local_execution(self, args=[]):
executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...")
try:
subprocess.run(["make", executable], check=True, cwd=self.root)
except:
raise CompilerError(
"Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.")
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=[]):
vm = self.executable_from_protocol(self.options.execute)
hosts = list(x.strip()
for x in filter(None, open(self.options.hostfile)))
# test availability before compilation
from fabric import Connection
import subprocess
print("Creating static binary for virtual machine...")
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
# transfer files
import glob
hostnames = []
destinations = []
for host in hosts:
split = host.split('/', maxsplit=1)
hostnames.append(split[0])
if len(split) > 1:
destinations.append(split[1])
else:
destinations.append('.')
connections = [Connection(hostname) for hostname in hostnames]
print("Setting up players...")
def run(i):
dest = destinations[i]
connection = connections[i]
connection.run(
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest)
# executable
connection.put("%s/static/%s" % (self.root, vm), dest)
# program
dest += "/"
connection.put("Programs/Schedules/%s.sch" % self.prog.name,
dest + "Programs/Schedules")
for filename in glob.glob(
"Programs/Bytecode/%s-*.bc" % self.prog.name):
connection.put(filename, dest + "Programs/Bytecode")
# inputs
for filename in glob.glob("Player-Data/Input*-P%d-*" % i):
connection.put(filename, dest + "Player-Data")
# key and certificates
for suffix in ('key', 'pem'):
connection.put("Player-Data/P%d.%s" % (i, suffix),
dest + "Player-Data")
for filename in glob.glob("Player-Data/*.0"):
connection.put(filename, dest + "Player-Data")
import threading
import random
threads = []
for i in range(len(hosts)):
threads.append(threading.Thread(target=run, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# execution
threads = []
# random port numbers to avoid conflict
port = 10000 + random.randrange(40000)
if '@' in hostnames[0]:
party0 = hostnames[0].split('@')[1]
else:
party0 = hostnames[0]
for i in range(len(connections)):
run = lambda i: connections[i].run(
"cd %s; ./%s -p %d %s -h %s -pn %d %s" % \
(destinations[i], vm, i, self.prog.name, party0, port,
' '.join(args)))
threads.append(threading.Thread(target=run, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()

645
Compiler/decision_tree.py Normal file
View File

@@ -0,0 +1,645 @@
from Compiler.types import *
from Compiler.sorting import *
from Compiler.library import *
from Compiler import util, oram
from itertools import accumulate
import math
debug = False
debug_split = False
max_leaves = None
def get_type(x):
if isinstance(x, (Array, SubMultiArray)):
return x.value_type
elif isinstance(x, (tuple, list)):
x = x[0] + x[-1]
if util.is_constant(x):
return cint
else:
return type(x)
else:
return type(x)
def PrefixSum(x):
return x.get_vector().prefix_sum()
def PrefixSumR(x):
tmp = get_type(x).Array(len(x))
tmp.assign_vector(x)
break_point()
tmp[:] = tmp.get_reverse_vector().prefix_sum()
break_point()
return tmp.get_reverse_vector()
def PrefixSum_inv(x):
tmp = get_type(x).Array(len(x) + 1)
tmp.assign_vector(x, base=1)
tmp[0] = 0
return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x))
def PrefixSumR_inv(x):
tmp = get_type(x).Array(len(x) + 1)
tmp.assign_vector(x)
tmp[-1] = 0
return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x))
class SortPerm:
def __init__(self, x):
B = sint.Matrix(len(x), 2)
B.set_column(0, 1 - x.get_vector())
B.set_column(1, x.get_vector())
self.perm = Array.create_from(dest_comp(B))
def apply(self, x):
res = Array.create_from(x)
reveal_sort(self.perm, res, False)
return res
def unapply(self, x):
res = Array.create_from(x)
reveal_sort(self.perm, res, True)
return res
def Sort(keys, *to_sort, n_bits=None, time=False):
if time:
start_timer(1)
for k in keys:
assert len(k) == len(keys[0])
n_bits = n_bits or [None] * len(keys)
bs = Matrix.create_from(
sum([k.get_vector().bit_decompose(nb)
for k, nb in reversed(list(zip(keys, n_bits)))], []))
get_vec = lambda x: x[:] if isinstance(x, Array) else x
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
for x in to_sort)
res = res.transpose()
if time:
start_timer(11)
radix_sort_from_matrix(bs, res)
if time:
stop_timer(11)
stop_timer(1)
res = res.transpose()
return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f)
if isinstance(get_vec(y), sfix)
else x for (x, y) in zip(res, to_sort)]
def VectMax(key, *data, debug=False):
def reducer(x, y):
b = x[0] > y[0]
if debug:
print_ln('max b=%s', b.reveal())
return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
if debug:
key = list(key)
data = [list(x) for x in data]
print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
res = util.tree_reduce(reducer, zip(key, *data))[1:]
if debug:
print_ln('vect max res=%s', util.reveal(res))
return res
def GroupSum(g, x):
assert len(g) == len(x)
p = PrefixSumR(x) * g
pi = SortPerm(g.get_vector().bit_not())
p1 = pi.apply(p)
s1 = PrefixSumR_inv(p1)
d1 = PrefixSum_inv(s1)
d = pi.unapply(d1) * g
return PrefixSum(d)
def GroupPrefixSum(g, x):
assert len(g) == len(x)
s = get_type(x).Array(len(x) + 1)
s[0] = 0
s.assign_vector(PrefixSum(x), base=1)
q = get_type(s).Array(len(x))
q.assign_vector(s.get_vector(size=len(x)) * g)
return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
def GroupMax(g, keys, *x):
if debug:
print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
util.reveal(keys), util.reveal(x))
assert len(keys) == len(g)
for xx in x:
assert len(xx) == len(g)
n = len(g)
m = int(math.ceil(math.log(n, 2)))
keys = Array.create_from(keys)
x = [Array.create_from(xx) for xx in x]
g_new = Array.create_from(g)
g_old = g_new.same_shape()
for d in range(m):
w = 2 ** d
g_old[:] = g_new[:]
break_point()
vsize = n - w
g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
g_old.get_vector(size=vsize, base=w)), base=w)
b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
for xx in [keys] + x:
a = b.if_else(xx.get_vector(size=vsize),
xx.get_vector(size=vsize, base=w))
xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
xx.get_vector(size=vsize, base=w), a), base=w)
break_point()
if debug:
print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
util.reveal(a), util.reveal(keys),
util.reveal(x), g_new.reveal())
t = sint.Array(len(g))
t[-1] = 1
t.assign_vector(g.get_vector(size=n - 1, base=1))
if debug:
print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
util.reveal(t), util.reveal(keys), util.reveal(x))
return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
def ModifiedGini(g, y, debug=False):
assert len(g) == len(y)
y = [y.get_vector().bit_not(), y]
u = [GroupPrefixSum(g, yy) for yy in y]
s = [GroupSum(g, yy) for yy in y]
w = [ss - uu for ss, uu in zip(s, u)]
us = sum(u)
ws = sum(w)
uqs = u[0] ** 2 + u[1] ** 2
wqs = w[0] ** 2 + w[1] ** 2
res = sfix(uqs) / us + sfix(wqs) / ws
if debug:
print_ln('g=%s y=%s s=%s',
util.reveal(g), util.reveal(y),
util.reveal(s))
print_ln('u0=%s', util.reveal(u[0]))
print_ln('u0=%s', util.reveal(u[1]))
print_ln('us=%s', util.reveal(us))
print_ln('w0=%s', util.reveal(w[0]))
print_ln('w1=%s', util.reveal(w[1]))
print_ln('ws=%s', util.reveal(ws))
print_ln('uqs=%s', util.reveal(uqs))
print_ln('wqs=%s', util.reveal(wqs))
if debug:
print_ln('gini %s %s', type(res), util.reveal(res))
return res
MIN_VALUE = -10000
def FormatLayer(h, g, *a):
return CropLayer(h, *FormatLayer_without_crop(g, *a))
def FormatLayer_without_crop(g, *a, debug=False):
for x in a:
assert len(x) == len(g)
v = [g.if_else(aa, 0) for aa in a]
if debug:
print_ln('format in %s', util.reveal(a))
print_ln('format mux %s', util.reveal(v))
v = Sort([g.bit_not()], *v, n_bits=[1])
if debug:
print_ln('format sort %s', util.reveal(v))
return v
def CropLayer(k, *v):
if max_leaves:
n = min(2 ** k, max_leaves)
else:
n = 2 ** k
return [vv[:min(n, len(vv))] for vv in v]
def TrainLeafNodes(h, g, y, NID):
assert len(g) == len(y)
assert len(g) == len(NID)
Label = GroupSum(g, y.bit_not()) < GroupSum(g, y)
return FormatLayer(h, g, NID, Label)
def GroupSame(g, y):
assert len(g) == len(y)
s = GroupSum(g, [sint(1)] * len(g))
s0 = GroupSum(g, y.bit_not())
s1 = GroupSum(g, y)
if debug_split:
print_ln('group same g=%s', util.reveal(g))
print_ln('group same y=%s', util.reveal(y))
return (s == s0).bit_or(s == s1)
def GroupFirstOne(g, b):
assert len(g) == len(b)
s = GroupPrefixSum(g, b)
return s * b == 1
class TreeTrainer:
""" Decision tree training by `Hamada et al.`_
:param x: sample data (by attribute, list or
:py:obj:`~Compiler.types.Matrix`)
:param y: binary labels (list or sint vector)
:param h: height (int)
:param binary: binary attributes instead of continuous
:param attr_lengths: attribute description for mixed data
(list of 0/1 for continuous/binary)
:param n_threads: number of threads (default: single thread)
.. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
"""
def ApplyTests(self, x, AID, Threshold):
m = len(x)
n = len(AID)
assert len(AID) == len(Threshold)
for xx in x:
assert len(xx) == len(AID)
e = sint.Matrix(m, n)
AID = Array.create_from(AID)
@for_range_multithread(self.n_threads, 1, m)
def _(j):
e[j][:] = AID[:] == j
xx = sum(x[j] * e[j] for j in range(m))
if self.debug > 1:
print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx))
print_ln('threshold %s', util.reveal(Threshold))
return 2 * xx < Threshold
def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False):
assert len(g) == len(x)
assert len(g) == len(y)
if time:
start_timer(2)
s = ModifiedGini(g, y, debug=debug or self.debug > 2)
if time:
stop_timer(2)
if debug or self.debug > 1:
print_ln('gini %s', s.reveal())
xx = x
t = get_type(x).Array(len(x))
t[-1] = MIN_VALUE
t.assign_vector(xx.get_vector(size=len(x) - 1) + \
xx.get_vector(size=len(x) - 1, base=1))
gg = g
p = sint.Array(len(x))
p[-1] = 1
p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
xx.get_vector(size=len(x) - 1) == \
xx.get_vector(size=len(x) - 1, base=1)))
break_point()
if debug:
print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p))
s = p[:].if_else(MIN_VALUE, s)
t = p[:].if_else(MIN_VALUE, t[:])
if debug:
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
if time:
start_timer(3)
s, t = GroupMax(gg, s, t)
if time:
stop_timer(3)
if debug:
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
return t, s
def GlobalTestSelection(self, x, y, g):
assert len(y) == len(g)
for xx in x:
assert(len(xx) == len(g))
m = len(x)
n = len(y)
u, t = [get_type(x).Matrix(m, n) for i in range(2)]
v = get_type(y).Matrix(m, n)
s = sfix.Matrix(m, n)
@for_range_multithread(self.n_threads, 1, m)
def _(j):
single = not self.n_threads or self.n_threads == 1
time = self.time and single
if debug:
print_ln('run %s', j)
@if_e(self.attr_lengths[j])
def _():
u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
n_bits=[util.log2(n), 1], time=time)
@else_
def _():
u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
n_bits=[util.log2(n), None],
time=time)
if self.debug_threading:
print_ln('global sort %s %s %s', j, util.reveal(u[j]),
util.reveal(v[j]))
t[j][:], s[j][:] = self.AttributeWiseTestSelection(
g, u[j], v[j], time=time, debug=self.debug_selection)
if self.debug_threading:
print_ln('global attribute %s %s %s', j, util.reveal(t[j]),
util.reveal(s[j]))
n = len(g)
a = sint.Array(n)
if self.debug_threading:
print_ln('global s=%s', util.reveal(s))
if self.debug_gini:
print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)),
*(ss[0].reveal() for ss in s))
if self.time:
start_timer(4)
if self.debug > 1:
print_ln('s=%s', s.reveal_nested())
print_ln('t=%s', t.reveal_nested())
a[:], tt = VectMax((s[j][:] for j in range(m)), range(m),
(t[j][:] for j in range(m)), debug=self.debug > 1)
tt = Array.create_from(tt)
if self.time:
stop_timer(4)
if self.debug > 1:
print_ln('a=%s', util.reveal(a))
print_ln('tt=%s', util.reveal(tt))
return a[:], tt[:]
def TrainInternalNodes(self, k, x, y, g, NID):
assert len(g) == len(y)
for xx in x:
assert len(xx) == len(g)
AID, Threshold = self.GlobalTestSelection(x, y, g)
s = GroupSame(g[:], y[:])
if self.debug > 1 or debug_split:
print_ln('AID=%s', util.reveal(AID))
print_ln('Threshold=%s', util.reveal(Threshold))
print_ln('GroupSame=%s', util.reveal(s))
AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold)
if self.debug > 1 or debug_split:
print_ln('AID=%s', util.reveal(AID))
print_ln('Threshold=%s', util.reveal(Threshold))
b = self.ApplyTests(x, AID, Threshold)
layer = FormatLayer_without_crop(g[:], NID, AID, Threshold,
debug=self.debug > 1)
return *layer, b
@method_block
def train_layer(self, k):
x = self.x
y = self.y
g = self.g
NID = self.NID
if self.debug > 1:
print_ln('g=%s', g.reveal())
print_ln('y=%s', y.reveal())
print_ln('x=%s', x.reveal_nested())
self.nids[k], self.aids[k], self.thresholds[k], b = \
self.TrainInternalNodes(k, x, y, g, NID)
if self.debug > 1:
print_ln('layer %s:', k)
for name, data in zip(('NID', 'AID', 'Thr'),
(self.nids[k], self.aids[k],
self.thresholds[k])):
print_ln(' %s: %s', name, data.reveal())
NID[:] = 2 ** k * b + NID
b_not = b.bit_not()
if self.debug > 1:
print_ln('b_not=%s', b_not.reveal())
g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b)
y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1])
for i, xxx in enumerate(xx):
x[i] = xxx
def __init__(self, x, y, h, binary=False, attr_lengths=None,
n_threads=None):
assert not (binary and attr_lengths)
if binary:
attr_lengths = [1] * len(x)
else:
attr_lengths = attr_lengths or ([0] * len(x))
for l in attr_lengths:
assert l in (0, 1)
self.attr_lengths = Array.create_from(regint(attr_lengths))
Array.check_indices = False
Matrix.disable_index_checks()
for xx in x:
assert len(xx) == len(y)
n = len(y)
self.g = sint.Array(n)
self.g.assign_all(0)
self.g[0] = 1
self.NID = sint.Array(n)
self.NID.assign_all(1)
self.y = Array.create_from(y)
self.x = Matrix.create_from(x)
self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)]
self.thresholds = self.x.value_type.Matrix(h, n)
self.n_threads = n_threads
self.debug_selection = False
self.debug_threading = False
self.debug_gini = False
self.debug = False
self.time = False
def train(self):
""" Train and return decision tree. """
h = len(self.nids)
@for_range(h)
def _(k):
self.train_layer(k)
return self.get_tree(h)
def train_with_testing(self, *test_set, output=False):
""" Train decision tree and test against test data.
:param y: binary labels (list or sint vector)
:param x: sample data (by attribute, list or
:py:obj:`~Compiler.types.Matrix`)
:param output: output tree after every level
:returns: tree
"""
for k in range(len(self.nids)):
self.train_layer(k)
tree = self.get_tree(k + 1)
if output:
output_decision_tree(tree)
test_decision_tree('train', tree, self.y, self.x,
n_threads=self.n_threads)
if test_set:
test_decision_tree('test', tree, *test_set,
n_threads=self.n_threads)
return tree
def get_tree(self, h):
Layer = [None] * (h + 1)
for k in range(h):
Layer[k] = CropLayer(k, self.nids[k], self.aids[k],
self.thresholds[k])
Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID)
return Layer
def DecisionTreeTraining(x, y, h, binary=False):
return TreeTrainer(x, y, h, binary=binary).train()
def output_decision_tree(layers):
""" Print decision tree output by :py:class:`TreeTrainer`. """
print_ln('full model %s', util.reveal(layers))
for i, layer in enumerate(layers[:-1]):
print_ln('level %s:', i)
for j, x in enumerate(('NID', 'AID', 'Thr')):
print_ln(' %s: %s', x, util.reveal(layer[j]))
print_ln('leaves:')
for j, x in enumerate(('NID', 'result')):
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))
def pick(bits, x):
if len(bits) == 1:
return bits[0] * x[0]
else:
try:
return x[0].dot_product(bits, x)
except:
return sum(aa * bb for aa, bb in zip(bits, x))
def run_decision_tree(layers, data):
""" Run decision tree against sample data.
:param layers: tree output by :py:class:`TreeTrainer`
:param data: sample data (:py:class:`~Compiler.types.Array`)
:returns: binary label
"""
h = len(layers) - 1
index = 1
for k, layer in enumerate(layers[:-1]):
assert len(layer) == 3
for x in layer:
assert len(x) <= 2 ** k
bits = layer[0].equal(index, k)
threshold = pick(bits, layer[2])
key_index = pick(bits, layer[1])
if key_index.is_clear:
key = data[key_index]
else:
key = pick(
oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
child = 2 * key < threshold
index += child * 2 ** k
bits = layers[h][0].equal(index, h)
return pick(bits, layers[h][1])
def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
if time:
start_timer(100)
n = len(y)
x = x.transpose().reveal()
y = y.reveal()
guess = regint.Array(n)
truth = regint.Array(n)
correct = regint.Array(2)
parts = regint.Array(2)
layers = [[Array.create_from(util.reveal(x)) for x in layer]
for layer in layers]
@for_range_multithread(n_threads, 1, n)
def _(i):
guess[i] = run_decision_tree([[part[:] for part in layer]
for layer in layers], x[i]).reveal()
truth[i] = y[i].reveal()
@for_range(n)
def _(i):
parts[truth[i]] += 1
c = (guess[i].bit_xor(truth[i]).bit_not())
correct[truth[i]] += c
print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
sum(correct), n, correct[0], parts[0], correct[1], parts[1])
if time:
stop_timer(100)
class TreeClassifier:
""" Tree classification with convenient interface. Uses
:py:class:`TreeTrainer` internally.
:param max_depth: the depth of the decision tree
"""
def __init__(self, max_depth):
self.max_depth = max_depth
@staticmethod
def get_attr_lengths(attr_types):
if attr_types == None:
return None
else:
return [1 if x == 'b' else 0 for x in attr_types]
def fit(self, X, y, attr_types=None):
""" Train tree.
:param X: sample data with row-wise samples (sint/sfix matrix)
:param y: binary labels (sint list/array)
"""
self.tree = TreeTrainer(
X.transpose(), y, self.max_depth,
attr_lengths=self.get_attr_lengths(attr_types)).train()
def fit_with_testing(self, X_train, y_train, X_test, y_test,
attr_types=None, output_trees=False, debug=False):
""" Train tree with accuracy output after every level.
:param X_train: training data with row-wise samples (sint/sfix matrix)
:param y_train: training binary labels (sint list/array)
:param X_test: testing data with row-wise samples (sint/sfix matrix)
:param y_test: testing binary labels (sint list/array)
:param attr_types: attributes types (list of 'b'/'c' for
binary/continuous; default is all continuous)
:param output_trees: output tree after every level
:param debug: output debugging information
"""
trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth,
attr_lengths=self.get_attr_lengths(attr_types))
trainer.debug = debug
trainer.debug_gini = debug
trainer.debug_threading = debug > 1
self.tree = trainer.train_with_testing(y_test, X_test.transpose(),
output=output_trees)
def predict(self, X):
""" Use tree for prediction.
:param X: sample data with row-wise samples (sint/sfix matrix)
:returns: sint array
"""
res = sint.Array(len(X))
@for_range(len(X))
def _(i):
res[i] = run_decision_tree(self.tree, X[i])
return res
def output(self):
""" Output decision tree. """
output_decision_tree(self.tree)
def preprocess_pandas(data):
""" Preprocess pandas data frame to suit
:py:class:`TreeClassifier` by expanding non-continuous attributes
to several binary attributes as a unary encoding.
:returns: a tuple of the processed data and a type list for the
:py:obj:`attr_types` argument.
"""
import pandas
import numpy
res = []
types = []
for i, t in enumerate(data.dtypes):
if pandas.api.types.is_int64_dtype(t):
res.append(data.iloc[:,i].to_numpy())
types.append('c')
elif pandas.api.types.is_object_dtype(t):
values = data.iloc[:,i].unique()
print('converting the following to unary:', values)
if len(values) == 2:
res.append(data.iloc[:,i].to_numpy() == values[1])
types.append('b')
else:
for value in values:
res.append(data.iloc[:,i].to_numpy() == value)
types.append('b')
else:
raise CompilerError('unknown pandas type: ' + t)
res = numpy.array(res)
res = numpy.swapaxes(res, 0, 1)
return res, types

View File

@@ -4,8 +4,11 @@ from Compiler.program import Program
ORAM = OptimalORAM
prog = program.Program.prog
prog.set_bit_length(min(64, prog.bit_length))
try:
prog = program.Program.prog
prog.set_bit_length(min(64, prog.bit_length))
except AttributeError:
pass
class HeapEntry(object):
fields = ['empty', 'prio', 'value']
@@ -47,9 +50,11 @@ class HeapEntry(object):
print_ln('empty %s, prio %s, value %s', *(reveal(x) for x in self))
class HeapORAM(object):
def __init__(self, size, oram_type, init_rounds, int_type):
def __init__(self, size, oram_type, init_rounds, int_type, entry_size=None):
if entry_size is None:
entry_size = (32,log2(size))
self.int_type = int_type
self.oram = oram_type(size, entry_size=(32,log2(size)), \
self.oram = oram_type(size, entry_size=entry_size, \
init_rounds=init_rounds, \
value_type=int_type.basic_type)
def __getitem__(self, index):
@@ -74,13 +79,15 @@ class HeapORAM(object):
return len(self.oram)
class HeapQ(object):
def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint):
def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint, entry_size=None):
if entry_size is None:
entry_size = (32, log2(max_size))
basic_type = int_type.basic_type
self.max_size = max_size
self.levels = log2(max_size)
self.depth = self.levels - 1
self.heap = HeapORAM(2**self.levels, oram_type, init_rounds, int_type)
self.value_index = oram_type(max_size, entry_size=log2(max_size), \
self.heap = HeapORAM(2**self.levels, oram_type, init_rounds, int_type, entry_size=entry_size)
self.value_index = oram_type(max_size, entry_size=entry_size[1], \
init_rounds=init_rounds, \
value_type=basic_type)
self.size = MemValue(int_type(0))
@@ -99,7 +106,7 @@ class HeapQ(object):
bits.reverse()
bits = [0] + floatingpoint.PreOR(bits, self.levels)
bits = [bits[i+1] - bits[i] for i in range(self.levels)]
shift = sum([bit << i for i,bit in enumerate(bits)])
shift = self.int_type.bit_compose(bits)
childpos = MemValue(start * shift)
@for_range(self.levels - 1)
def f(i):
@@ -215,12 +222,13 @@ class HeapQ(object):
print_ln()
print_ln()
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
basic_type = int_type.basic_type
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else -1
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
init_rounds=vert_loops, value_type=basic_type)
init_rounds=vert_loops, value_type=int_type)
int_type = dist.value_type
basic_type = int_type.basic_type
#visited = ORAM(e_index.size)
#previous = oram_type(e_index.size)
Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \
@@ -240,7 +248,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
u = MemValue(basic_type(0))
@for_range(n_loops or edges.size)
def f(i):
cint(i).print_reg('loop')
print_ln('loop %s', i)
time()
u.write(if_else(last_edge, Q.pop(last_edge), u))
#visited.access(u, True, last_edge)

View File

@@ -12,4 +12,7 @@ class ArgumentError(CompilerError):
""" Exception raised for errors in instruction argument parsing. """
def __init__(self, arg, msg):
self.arg = arg
self.msg = msg
self.msg = msg
class VectorMismatch(CompilerError):
pass

View File

@@ -28,13 +28,15 @@ def shift_two(n, pos):
def maskRing(a, k):
shift = int(program.Program.prog.options.ring) - k
if program.Program.prog.use_dabit:
if program.Program.prog.use_edabit():
r_prime, r = types.sint.get_edabit(k)
elif program.Program.prog.use_dabit:
rr, r = zip(*(types.sint.get_dabit() for i in range(k)))
r_prime = types.sint.bit_compose(rr)
else:
r = [types.sint.get_random_bit() for i in range(k)]
r_prime = types.sint.bit_compose(r)
c = ((a + r_prime) << shift).reveal() >> shift
c = ((a + r_prime) << shift).reveal(False) >> shift
return c, r
def maskField(a, k, kappa):
@@ -45,7 +47,7 @@ def maskField(a, k, kappa):
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
# always signed due to usage in equality testing
a += two_power(k)
asm_open(c, a + two_power(k) * r_dprime + r_prime)
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
return c, r
@instructions_base.ret_cisc
@@ -231,7 +233,7 @@ def Inv(a):
ldi(one, 1)
inverse(t[0], t[1])
s = t[0]*a
asm_open(c[0], s)
asm_open(True, c[0], s)
# avoid division by zero for benchmarking
divc(c[1], one, c[0])
#divc(c[1], c[0], one)
@@ -279,7 +281,7 @@ def BitDecRingRaw(a, k, m):
else:
r_bits = [types.sint.get_random_bit() for i in range(m)]
r = types.sint.bit_compose(r_bits)
shifted = ((a - r) << n_shift).reveal()
shifted = ((a - r) << n_shift).reveal(False)
masked = shifted >> n_shift
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return bits
@@ -287,7 +289,7 @@ def BitDecRingRaw(a, k, m):
def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size)
@@ -297,18 +299,19 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
r = [types.sint() for i in range(m)]
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
pow2 = two_power(k + kappa)
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
instructions_base.reset_global_vector_size()
return res
def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
return [types.sint.conv(bit) for bit in res]
return [types.sintbit.conv(bit) for bit in res]
@instructions_base.ret_cisc
def Pow2(a, l, kappa):
comparison.program.curr_tape.require_bit_length(l - 1)
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
return Pow2_from_bits(t)
@@ -316,7 +319,7 @@ def Pow2(a, l, kappa):
def Pow2_from_bits(bits):
m = len(bits)
t = list(bits)
pow2k = [types.cint() for i in range(m)]
pow2k = [None for i in range(m)]
for i in range(m):
pow2k[i] = two_power(2**i)
t[i] = t[i]*pow2k[i] + 1 - t[i]
@@ -339,10 +342,10 @@ def B2U_from_Pow2(pow2a, l, kappa):
if program.Program.prog.options.ring:
n_shift = int(program.Program.prog.options.ring) - l
assert n_shift > 0
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift
else:
comparison.PRandInt(t, kappa)
asm_open(c, pow2a + two_power(l) * t +
asm_open(True, c, pow2a + two_power(l) * t +
sum(two_power(i) * r[i] for i in range(l)))
comparison.program.curr_tape.require_bit_length(l + kappa)
c = list(r_bits[0].bit_decompose_clear(c, l))
@@ -384,15 +387,15 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
r_dprime += t1 - t2
if program.Program.prog.options.ring:
n_shift = int(program.Program.prog.options.ring) - l
c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift
c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift
else:
comparison.PRandInt(rk, kappa)
r_dprime += two_power(l) * rk
asm_open(c, a + r_dprime + r_prime)
asm_open(True, c, a + r_dprime + r_prime)
for i in range(1,l):
ci[i] = c % two_power(i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
lts(d, c_dprime, r_prime, l, kappa)
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
if compute_modulo:
b = c_dprime - r_prime + pow2m * d
return b, pow2m
@@ -414,7 +417,7 @@ def TruncInRing(to_shift, l, pow2m):
rev *= pow2m
r_bits = [types.sint.get_random_bit() for i in range(l)]
r = types.sint.bit_compose(r_bits)
shifted = (rev - (r << n_shift)).reveal()
shifted = (rev - (r << n_shift)).reveal(False)
masked = shifted >> n_shift
bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l))
return types.sint.bit_compose(reversed(bits))
@@ -455,7 +458,7 @@ def Int2FL(a, gamma, l, kappa=None):
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
else:
v = 2**(l-gamma+1) * t
p = (p + gamma - 1 - l) * (1 -z)
p = (p + gamma - 1 - l) * z.bit_not()
return v, p, z, s
def FLRound(x, mode):
@@ -528,7 +531,7 @@ def TruncPrRing(a, k, m, signed=True):
msb = r_bits[-1]
n_shift = n_ring - (k + 1)
tmp = a + r
masked = (tmp << n_shift).reveal()
masked = (tmp << n_shift).reveal(False)
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = msb.bit_xor(masked >> (n_ring - 1))
res = shifted - upper + \
@@ -549,7 +552,7 @@ def TruncPrField(a, k, m, kappa=None):
k, m, kappa, use_dabit=False)
two_to_m = two_power(m)
r = two_to_m * r_dprime + r_prime
c = (b + r).reveal()
c = (b + r).reveal(False)
c_prime = c % two_to_m
a_prime = c_prime - r_prime
d = (a - a_prime) / two_to_m
@@ -629,14 +632,16 @@ def BITLT(a, b, bit_length):
# - From the paper
# Multiparty Computation for Interval, Equality, and Comparison without
# Bit-Decomposition Protocol
def BitDecFull(a, maybe_mixed=False):
def BitDecFull(a, n_bits=None, maybe_mixed=False):
from .library import get_program, do_while, if_, break_point
from .types import sint, regint, longint, cint
p = get_program().prime
assert p
bit_length = p.bit_length()
n_bits = n_bits or bit_length
assert n_bits <= bit_length
logp = int(round(math.log(p, 2)))
if abs(p - 2 ** logp) / p < 2 ** -get_program().security:
if get_program().rabbit_gap():
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
# no need for exact randomness generation
# if modulo a power of two is close enough
@@ -663,26 +668,26 @@ def BitDecFull(a, maybe_mixed=False):
def _():
for i in range(bit_length):
tbits[j][i].link(sint.get_random_bit())
c = regint(BITLT(tbits[j], pbits, bit_length).reveal())
c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False))
done[j].link(c)
return (sum(done) != a.size)
for j in range(a.size):
for i in range(bit_length):
movs(bbits[i][j], tbits[j][i])
b = sint.bit_compose(bbits)
c = (a-b).reveal()
c = (a-b).reveal(False)
cmodp = c
t = bbits[0].bit_decompose_clear(p - c, bit_length)
c = longint(c, bit_length)
czero = (c==0)
q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
fbar = [bbits[0].clear_type.conv(cint(x))
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
for x in ((1<<bit_length)+c-p).bit_decompose(n_bits)]
fbard = bbits[0].bit_decompose_clear(cmodp, n_bits)
g = [q.if_else(fbar[i], fbard[i]) for i in range(n_bits)]
h = bbits[0].bit_adder(bbits, g)
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
for i in range(bit_length)]
for i in range(n_bits)]
if maybe_mixed:
return abits
else:

View File

@@ -17,6 +17,7 @@ right order.
import itertools
import operator
import math
from . import tools
from random import randint
from functools import reduce
@@ -69,7 +70,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
"""
__slots__ = []
code = base.opcodes['LDMC']
arg_format = ['cw','int']
arg_format = ['cw','long']
@base.gf2n
@base.vectorize
@@ -84,7 +85,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
"""
__slots__ = []
code = base.opcodes['LDMS']
arg_format = ['sw','int']
arg_format = ['sw','long']
@base.gf2n
@base.vectorize
@@ -99,7 +100,7 @@ class stmc(base.DirectMemoryWriteInstruction):
"""
__slots__ = []
code = base.opcodes['STMC']
arg_format = ['c','int']
arg_format = ['c','long']
@base.gf2n
@base.vectorize
@@ -114,7 +115,7 @@ class stms(base.DirectMemoryWriteInstruction):
"""
__slots__ = []
code = base.opcodes['STMS']
arg_format = ['s','int']
arg_format = ['s','long']
@base.vectorize
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
@@ -128,7 +129,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
"""
__slots__ = []
code = base.opcodes['LDMINT']
arg_format = ['ciw','int']
arg_format = ['ciw','long']
@base.vectorize
class stmint(base.DirectMemoryWriteInstruction):
@@ -142,7 +143,7 @@ class stmint(base.DirectMemoryWriteInstruction):
"""
__slots__ = []
code = base.opcodes['STMINT']
arg_format = ['ci','int']
arg_format = ['ci','long']
@base.vectorize
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
@@ -294,6 +295,7 @@ class movint(base.Instruction):
@base.vectorize
class pushint(base.StackInstruction):
""" Pushes clear integer register to the thread-local stack.
Considered obsolete.
:param: source (regint)
"""
@@ -303,6 +305,7 @@ class pushint(base.StackInstruction):
@base.vectorize
class popint(base.StackInstruction):
""" Pops from the thread-local stack to clear integer register.
Considered obsolete.
:param: destination (regint)
"""
@@ -353,7 +356,17 @@ class reqbl(base.Instruction):
code = base.opcodes['REQBL']
arg_format = ['int']
class active(base.Instruction):
""" Indicate whether program is compatible with malicious-security
protocols.
:param: 0 for no, 1 for yes
"""
code = base.opcodes['ACTIVE']
arg_format = ['int']
class time(base.IOInstruction):
""" Output time since start of computation. """
code = base.opcodes['TIME']
arg_format = []
@@ -384,7 +397,15 @@ class use(base.Instruction):
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE']
arg_format = ['int','int','int']
arg_format = ['int','int','long']
@classmethod
def get_usage(cls, args):
from .program import field_types, data_types
from .util import find_in_dict
return {(find_in_dict(field_types, args[0].i),
find_in_dict(data_types, args[1].i)):
args[2].i}
class use_inp(base.Instruction):
""" Input usage. Necessary to avoid reusage while using
@@ -395,7 +416,14 @@ class use_inp(base.Instruction):
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_INP']
arg_format = ['int','int','int']
arg_format = ['int','int','long']
@classmethod
def get_usage(cls, args):
from .program import field_types, data_types
from .util import find_in_dict
return {(find_in_dict(field_types, args[0].i), 'input', args[1].i):
args[2].i}
class use_edabit(base.Instruction):
""" edaBit usage. Necessary to avoid reusage while using
@@ -407,7 +435,11 @@ class use_edabit(base.Instruction):
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_EDABIT']
arg_format = ['int','int','int']
arg_format = ['int','int','long']
@classmethod
def get_usage(cls, args):
return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i}
class use_matmul(base.Instruction):
""" Matrix multiplication usage. Used for multithreading of
@@ -419,7 +451,11 @@ class use_matmul(base.Instruction):
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_MATMUL']
arg_format = ['int','int','int','int']
arg_format = ['int','int','int','long']
@classmethod
def get_usage(cls, args):
return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i}
class run_tape(base.Instruction):
""" Start tape/bytecode file in another thread.
@@ -442,7 +478,7 @@ class join_tape(base.Instruction):
arg_format = ['int']
class crash(base.IOInstruction):
""" Crash runtime if the register's value is > 0.
""" Crash runtime if the value in the register is not zero.
:param: Crash condition (regint)"""
code = base.opcodes['CRASH']
@@ -464,7 +500,12 @@ class use_prep(base.Instruction):
:param: number of items to use (int, -1 for unknown)
"""
code = base.opcodes['USE_PREP']
arg_format = ['str','int']
arg_format = ['str','long']
@classmethod
def get_usage(cls, args):
return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp',
args[0].str): args[1].i}
class nplayers(base.Instruction):
""" Store number of players in clear integer register.
@@ -585,6 +626,18 @@ class submr(base.SubBase):
code = base.opcodes['SUBMR']
arg_format = ['sw','c','s']
@base.vectorize
class prefixsums(base.Instruction):
""" Prefix sum.
:param: result (sint)
:param: input (sint)
"""
__slots__ = []
code = base.opcodes['PREFIXSUMS']
arg_format = ['sw','s']
@base.gf2n
@base.vectorize
class mulc(base.MulBase):
@@ -778,30 +831,6 @@ class gbitcom(base.Instruction):
return True
###
### Special GF(2) arithmetic instructions
###
@base.vectorize
class gmulbitc(base.MulBase):
r""" Clear GF(2^n) by clear GF(2) multiplication """
__slots__ = []
code = base.opcodes['GMULBITC']
arg_format = ['cgw','cg','cg']
def is_gf2n(self):
return True
@base.vectorize
class gmulbitm(base.MulBase):
r""" Secret GF(2^n) by clear GF(2) multiplication """
__slots__ = []
code = base.opcodes['GMULBITM']
arg_format = ['sgw','sg','cg']
def is_gf2n(self):
return True
###
### Arithmetic with immediate values
###
@@ -1046,6 +1075,7 @@ class shrci(base.ClearShiftInstruction):
code = base.opcodes['SHRCI']
op = '__rshift__'
@base.gf2n
@base.vectorize
class shrsi(base.ClearShiftInstruction):
""" Bitwise right shift of secret register (vector) by (constant)
@@ -1184,7 +1214,7 @@ class randoms(base.Instruction):
field_type = 'modp'
@base.vectorize
class randomfulls(base.Instruction):
class randomfulls(base.DataInstruction):
""" Store share(s) of a fresh secret random element in secret
register (vectors).
@@ -1194,6 +1224,10 @@ class randomfulls(base.Instruction):
code = base.opcodes['RANDOMFULLS']
arg_format = ['sw']
field_type = 'modp'
data_type = 'random'
def get_repeat(self):
return len(self.args)
@base.gf2n
@base.vectorize
@@ -1229,15 +1263,20 @@ class inverse(base.DataInstruction):
@base.gf2n
@base.vectorize
class inputmask(base.Instruction):
r""" Load secret $s_i$ with the next input mask for player $p$ and
write the mask on player $p$'s private output. """
""" Store fresh random input mask(s) in secret register (vector) and clear
register (vector) of the relevant player.
:param: mask (sint)
:param: mask (cint, player only)
:param: player (int)
"""
__slots__ = []
code = base.opcodes['INPUTMASK']
arg_format = ['sw', 'p']
arg_format = ['sw', 'cw', 'p']
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[1]), \
req_node.increment((self.field_type, 'input', self.args[2]), \
self.get_size())
@base.vectorize
@@ -1275,7 +1314,7 @@ class prep(base.Instruction):
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, self.args[0]), 1)
req_node.increment((self.field_type, self.args[0]), self.get_size())
def has_var_args(self):
return True
@@ -1293,10 +1332,8 @@ class asm_input(base.TextInputInstruction):
arg_format = tools.cycle(['sw', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[1::2]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
def get_players(self):
return self.args[1::2]
@base.vectorize
class inputfix(base.TextInputInstruction):
@@ -1305,10 +1342,8 @@ class inputfix(base.TextInputInstruction):
arg_format = tools.cycle(['sw', 'int', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[2::3]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
def get_players(self):
return self.args[2::3]
@base.vectorize
class inputfloat(base.TextInputInstruction):
@@ -1322,7 +1357,7 @@ class inputfloat(base.TextInputInstruction):
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())
class inputmixed_base(base.TextInputInstruction):
class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction):
__slots__ = []
field_type = 'modp'
# the following has to match TYPE: (N_DEST, N_PARAM)
@@ -1341,22 +1376,30 @@ class inputmixed_base(base.TextInputInstruction):
type_id = self.type_ids[name]
super(inputmixed_base, self).__init__(type_id, *args)
@property
def arg_format(self):
for i in self.bases():
t = self.args[i]
yield 'int'
@classmethod
def dynamic_arg_format(self, args):
yield 'int'
for i, t in self.bases(iter(args)):
for j in range(self.types[t][0]):
yield 'sw'
for j in range(self.types[t][1]):
yield 'int'
yield self.player_arg_type
yield 'int'
def bases(self):
@classmethod
def bases(self, args):
i = 0
while i < len(self.args):
yield i
i += sum(self.types[self.args[i]]) + 2
while True:
try:
t = next(args)
except StopIteration:
return
yield i, t
n = sum(self.types[t])
i += n + 2
for j in range(n + 1):
next(args)
@base.vectorize
class inputmixed(inputmixed_base):
@@ -1380,14 +1423,16 @@ class inputmixed(inputmixed_base):
player_arg_type = 'p'
def add_usage(self, req_node):
for i in self.bases():
t = self.args[i]
for i, t in self.bases(iter(self.args)):
player = self.args[i + sum(self.types[t]) + 1]
n_dest = self.types[t][0]
req_node.increment((self.field_type, 'input', player), \
n_dest * self.get_size())
@base.vectorize
def get_players(self):
for i, t in self.bases(iter(self.args)):
yield self.args[i + sum(self.types[t]) + 1]
class inputmixedreg(inputmixed_base):
""" Store private input in secret registers (vectors). The input is
read as integer or floating-point number and the latter is then
@@ -1407,11 +1452,29 @@ class inputmixedreg(inputmixed_base):
"""
code = base.opcodes['INPUTMIXEDREG']
player_arg_type = 'ci'
is_vec = lambda self: True
def __init__(self, *args):
inputmixed_base.__init__(self, *args)
for i, t in self.bases(iter(self.args)):
n = self.types[t][0]
for j in range(i + 1, i + 1 + n):
assert args[j].size == self.get_size()
def get_size(self):
return self.args[1].size
def get_code(self):
return inputmixed_base.get_code(
self, self.get_size() if self.get_size() > 1 else 0)
def add_usage(self, req_node):
# player 0 as proxy
req_node.increment((self.field_type, 'input', 0), float('inf'))
def get_players(self):
pass
@base.gf2n
@base.vectorize
class rawinput(base.RawInputInstruction, base.Mergeable):
@@ -1433,7 +1496,23 @@ class rawinput(base.RawInputInstruction, base.Mergeable):
req_node.increment((self.field_type, 'input', player), \
self.get_size())
class inputpersonal(base.Instruction, base.Mergeable):
class personal_base(base.Instruction, base.Mergeable):
__slots__ = []
field_type = 'modp'
def __init__(self, *args):
super(personal_base, self).__init__(*args)
for i in range(0, len(args), 4):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
class inputpersonal(personal_base):
""" Private input from cint.
:param: vector size (int)
@@ -1445,19 +1524,47 @@ class inputpersonal(base.Instruction, base.Mergeable):
__slots__ = []
code = base.opcodes['INPUTPERSONAL']
arg_format = tools.cycle(['int','p','sw','c'])
field_type = 'modp'
def __init__(self, *args):
super(inputpersonal, self).__init__(*args)
for i in range(0, len(args), 4):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]
class privateoutput(personal_base, base.DataInstruction):
""" Private output to cint.
:param: vector size (int)
:param: player (int)
:param: destination (cint)
:param: source (sint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['PRIVATEOUTPUT']
arg_format = tools.cycle(['int','p','cw','s'])
data_type = 'open'
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
personal_base.add_usage(self, req_node)
base.DataInstruction.add_usage(self, req_node)
def get_repeat(self):
return sum(self.args[::4])
class sendpersonal(base.Instruction, base.Mergeable):
""" Private input from cint.
:param: vector size (int)
:param: destination player (int)
:param: destination (cint)
:param: source player (int)
:param: source (cint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['SENDPERSONAL']
arg_format = tools.cycle(['int','p','cw','p','c'])
def __init__(self, *args):
super(sendpersonal, self).__init__(*args)
for i in range(0, len(args), 5):
assert args[i + 2].size == args[i]
assert args[i + 4].size == args[i]
@base.gf2n
@base.vectorize
@@ -1546,7 +1653,7 @@ class print_char(base.IOInstruction):
arg_format = ['int']
def __init__(self, ch):
super(print_char, self).__init__(ord(ch))
super(print_char, self).__init__(ch)
class print_char4(base.IOInstruction):
""" Output four bytes.
@@ -1650,6 +1757,7 @@ class writesockets(base.IOInstruction):
from registers into a socket for a specified client id. If the
protocol uses MACs, the client should be different for every party.
:param: number of arguments to follow
:param: client id (regint)
:param: message type (must be 0)
:param: vector size (int)
@@ -1727,14 +1835,15 @@ class writesharestofile(base.IOInstruction):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
:param: number of shares (int)
:param: number of arguments to follow / number of shares plus one (int)
:param: position (regint, -1 for appending)
:param: source (sint)
:param: (repeat from source)...
"""
__slots__ = []
code = base.opcodes['WRITEFILESHARE']
arg_format = itertools.repeat('s')
arg_format = tools.chain(['ci'], itertools.repeat('s'))
def has_var_args(self):
return True
@@ -1788,26 +1897,19 @@ class floatoutput(base.PublicFileIOInstruction):
code = base.opcodes['FLOATOUTPUT']
arg_format = ['p','c','c','c','c']
@base.gf2n
@base.vectorize
class startprivateoutput(base.Instruction):
r""" Initiate private output to $n$ of $s_j$ via $s_i$. """
__slots__ = []
code = base.opcodes['STARTPRIVATEOUTPUT']
arg_format = ['sw','s','p']
field_type = 'modp'
class fixinput(base.PublicFileIOInstruction):
""" Binary fixed-point input.
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[2]), \
self.get_size())
:param: player (int)
:param: destination (cint)
:param: exponent (int)
:param: input type (0: 64-bit integer, 1: float, 2: double)
@base.gf2n
@base.vectorize
class stopprivateoutput(base.Instruction):
r""" Previously iniated private output to $n$ via $c_i$. """
"""
__slots__ = []
code = base.opcodes['STOPPRIVATEOUTPUT']
arg_format = ['cw','c','p']
code = base.opcodes['FIXINPUT']
arg_format = ['p','cw','int','int']
@base.vectorize
class rand(base.Instruction):
@@ -2122,17 +2224,26 @@ class gconvgf2n(base.Instruction):
# rename 'open' to avoid conflict with built-in open function
@base.gf2n
@base.vectorize
class asm_open(base.VarArgsInstruction):
class asm_open(base.VarArgsInstruction, base.DataInstruction):
""" Reveal secret registers (vectors) to clear registers (vectors).
:param: number of argument to follow (multiple of two)
:param: number of argument to follow (odd number)
:param: check after opening (0/1)
:param: destination (cint)
:param: source (sint)
:param: (repeat the last two)...
"""
__slots__ = []
code = base.opcodes['OPEN']
arg_format = tools.cycle(['cw','s'])
arg_format = tools.chain(['int'], tools.cycle(['cw','s']))
data_type = 'open'
def get_repeat(self):
return (len(self.args) - 1) // 2
def merge(self, other):
self.args[0] |= other.args[0]
self.args += other.args[1:]
@base.gf2n
@base.vectorize
@@ -2209,7 +2320,8 @@ class mulrs(base.VarArgsInstruction, base.DataInstruction):
@base.gf2n
@base.vectorize
class dotprods(base.VarArgsInstruction, base.DataInstruction):
class dotprods(base.VarArgsInstruction, base.DataInstruction,
base.DynFormatInstruction):
""" Dot product of secret registers (vectors).
Note that the vectorized version works element-wise.
@@ -2237,31 +2349,30 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
flat_args += [x, y]
base.Instruction.__init__(self, *flat_args)
@property
def arg_format(self):
@classmethod
def dynamic_arg_format(self, args):
field = 'g' if self.is_gf2n() else ''
for i in self.bases():
yield 'int'
yield 'int'
for i, n in self.bases(args):
yield 's' + field + 'w'
for j in range(self.args[i] - 2):
assert n > 2
for j in range(n - 2):
yield 's' + field
yield 'int'
gf2n_arg_format = arg_format
def bases(self):
i = 0
while i < len(self.args):
yield i
i += self.args[i]
@property
def gf2n_arg_format(self):
return self.arg_format()
def get_repeat(self):
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size()
return sum(self.args[i] // 2
for i, n in self.bases(iter(self.args))) * self.get_size()
def get_def(self):
return [self.args[i + 1] for i in self.bases()]
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
def get_used(self):
for i in self.bases():
for i, n in self.bases(iter(self.args)):
for reg in self.args[i + 2:i + self.args[i]]:
yield reg
@@ -2317,9 +2428,10 @@ class matmulsm(matmul_base):
super(matmulsm, self).add_usage(req_node)
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
class conv2ds(base.DataInstruction):
class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
""" Secret 2D convolution.
:param: number of arguments to follow (int)
:param: result (sint vector in row-first order)
:param: inputs (sint vector in row-first order)
:param: weights (sint vector in row-first order)
@@ -2335,10 +2447,12 @@ class conv2ds(base.DataInstruction):
:param: padding height (int)
:param: padding width (int)
:param: batch size (int)
:param: repeat from result...
"""
code = base.opcodes['CONV2DS']
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
'int','int','int','int']
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
'int','int','int','int','int','int','int'])
data_type = 'triple'
is_vec = lambda self: True
@@ -2349,14 +2463,16 @@ class conv2ds(base.DataInstruction):
assert args[2].size == args[7] * args[8] * args[11]
def get_repeat(self):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
self.args[11] * self.args[14]
args = self.args
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
args[i+11] * args[i+14] for i in range(0, len(args), 15))
def add_usage(self, req_node):
super(conv2ds, self).add_usage(req_node)
args = self.args
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
for i in range(0, len(self.args), 15):
args = self.args[i:i + 15]
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
@base.vectorize
class trunc_pr(base.VarArgsInstruction):
@@ -2372,6 +2488,124 @@ class trunc_pr(base.VarArgsInstruction):
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])
class shuffle_base(base.DataInstruction):
n_relevant_parties = 2
@staticmethod
def logn(n):
return int(math.ceil(math.log(n, 2)))
@classmethod
def n_swaps(cls, n):
logn = cls.logn(n)
return logn * 2 ** logn - 2 ** logn + 1
def add_gen_usage(self, req_node, n):
# hack for unknown usage
req_node.increment(('bit', 'inverse'), float('inf'))
# minimal usage with two relevant parties
logn = self.logn(n)
n_switches = self.n_swaps(n)
for i in range(self.n_relevant_parties):
req_node.increment((self.field_type, 'input', i), n_switches)
# multiplications for bit check
req_node.increment((self.field_type, 'triple'),
n_switches * self.n_relevant_parties)
def add_apply_usage(self, req_node, n, record_size):
req_node.increment(('bit', 'inverse'), float('inf'))
logn = self.logn(n)
n_switches = self.n_swaps(n) * self.n_relevant_parties
if n != 2 ** logn:
record_size += 1
req_node.increment((self.field_type, 'triple'),
n_switches * record_size)
@base.gf2n
class secshuffle(base.VectorInstruction, shuffle_base):
""" Secure shuffling.
:param: destination (sint)
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['SECSHUFFLE']
arg_format = ['sw','s','int']
def __init__(self, *args, **kwargs):
super(secshuffle_class, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
def add_usage(self, req_node):
self.add_gen_usage(req_node, len(self.args[0]))
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
class gensecshuffle(shuffle_base):
""" Generate secure shuffle to bit used several times.
:param: destination (regint)
:param: size (int)
"""
__slots__ = []
code = base.opcodes['GENSECSHUFFLE']
arg_format = ['ciw','int']
def add_usage(self, req_node):
self.add_gen_usage(req_node, self.args[1])
class applyshuffle(base.VectorInstruction, shuffle_base):
""" Generate secure shuffle to bit used several times.
:param: destination (sint)
:param: source (sint)
:param: number of elements to be treated as one (int)
:param: handle (regint)
:param: reverse (0/1)
"""
__slots__ = []
code = base.opcodes['APPLYSHUFFLE']
arg_format = ['sw','s','int','ci','int']
def __init__(self, *args, **kwargs):
super(applyshuffle, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
def add_usage(self, req_node):
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
class delshuffle(base.Instruction):
""" Delete secure shuffle.
:param: handle (regint)
"""
code = base.opcodes['DELSHUFFLE']
arg_format = ['ci']
class inverse_permutation(base.VectorInstruction, shuffle_base):
""" Calculate the inverse permutation of a secret permutation.
:param: destination (sint)
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['INVPERM']
arg_format = ['sw', 's']
def __init__(self, *args, **kwargs):
super(inverse_permutation, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
def add_usage(self, req_node):
self.add_gen_usage(req_node, len(self.args[0]))
self.add_apply_usage(req_node, len(self.args[0]), 1)
class check(base.Instruction):
"""
Force MAC check in current thread and all idle thread if current
@@ -2399,7 +2633,7 @@ class sqrs(base.CISC):
c = [program.curr_block.new_reg('c') for i in range(2)]
square(s[0], s[1])
subs(s[2], self.args[1], s[0])
asm_open(c[0], s[2])
asm_open(False, c[0], s[2])
mulc(c[1], c[0], c[0])
mulm(s[3], self.args[1], c[0])
adds(s[4], s[3], s[3])
@@ -2407,19 +2641,6 @@ class sqrs(base.CISC):
subml(self.args[0], s[5], c[1])
@base.gf2n
@base.vectorize
class lts(base.CISC):
""" Secret comparison $s_i = (s_j < s_k)$. """
__slots__ = []
arg_format = ['sw', 's', 's', 'int', 'int']
def expand(self):
from .types import sint
a = sint()
subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
# placeholder for documentation
class cisc:
""" Meta instruction for emulation. This instruction is only generated

View File

@@ -4,6 +4,8 @@ import time
import inspect
import functools
import copy
import sys
import struct
from Compiler.exceptions import *
from Compiler.config import *
from Compiler import util
@@ -64,6 +66,7 @@ opcodes = dict(
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -78,6 +81,7 @@ opcodes = dict(
SUBSI = 0x2A,
SUBCFI = 0x2B,
SUBSFI = 0x2C,
PREFIXSUMS = 0x2D,
# Multiplication/division
MULC = 0x30,
MULM = 0x31,
@@ -103,6 +107,13 @@ opcodes = dict(
MATMULSM = 0xAB,
CONV2DS = 0xAC,
CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
# Shuffling
SECSHUFFLE = 0xFA,
GENSECSHUFFLE = 0xFB,
APPLYSHUFFLE = 0xFC,
DELSHUFFLE = 0xFD,
INVPERM = 0xFE,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -126,6 +137,7 @@ opcodes = dict(
INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4,
INPUTPERSONAL = 0xF5,
SENDPERSONAL = 0xF6,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
@@ -198,8 +210,9 @@ opcodes = dict(
CONDPRINTPLAIN = 0xE1,
INTOUTPUT = 0xE6,
FLOATOUTPUT = 0xE7,
GBITDEC = 0x184,
GBITCOM = 0x185,
FIXINPUT = 0xE8,
GBITDEC = 0x18A,
GBITCOM = 0x18B,
# Secure socket
INITSECURESOCKET = 0x1BA,
RESPSECURESOCKET = 0x1BB
@@ -215,8 +228,13 @@ def int_to_bytes(x):
global_vector_size_stack = []
global_instruction_type_stack = ['modp']
def check_vector_size(size):
if isinstance(size, program.curr_tape.Register):
raise CompilerError('vector size must be known at compile time')
def set_global_vector_size(size):
stack = global_vector_size_stack
check_vector_size(size)
if size == 1 and not stack:
return
stack.append(size)
@@ -299,11 +317,12 @@ def vectorize(instruction, global_dict=None):
vectorized_name = 'v' + instruction.__name__
Vectorized_Instruction.__name__ = vectorized_name
global_dict[vectorized_name] = Vectorized_Instruction
if 'sphinx.extension' in sys.modules:
return instruction
global_dict[instruction.__name__ + '_class'] = instruction
instruction.__doc__ = ''
# exclude GF(2^n) instructions from documentation
if instruction.code and instruction.code >> 8 == 1:
maybe_vectorized_instruction.__doc__ = ''
maybe_vectorized_instruction.arg_format = instruction.arg_format
return maybe_vectorized_instruction
@@ -332,7 +351,7 @@ def gf2n(instruction):
if isinstance(arg_format, list):
__format = []
for __f in arg_format:
if __f in ('int', 'p', 'ci', 'str'):
if __f in ('int', 'long', 'p', 'ci', 'str'):
__format.append(__f)
else:
__format.append(__f[0] + 'g' + __f[1:])
@@ -355,12 +374,13 @@ def gf2n(instruction):
arg_format = instruction_cls.gf2n_arg_format
elif isinstance(instruction_cls.arg_format, itertools.repeat):
__f = next(instruction_cls.arg_format)
if __f != 'int' and __f != 'p':
if __f not in ('int', 'long', 'p'):
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
else:
arg_format = copy.deepcopy(instruction_cls.arg_format)
reformat(arg_format)
@classmethod
def is_gf2n(self):
return True
@@ -389,8 +409,11 @@ def gf2n(instruction):
else:
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction
if 'sphinx.extension' in sys.modules:
return instruction
global_dict[instruction.__name__ + '_class'] = instruction_cls
instruction_cls.__doc__ = ''
maybe_gf2n_instruction.arg_format = instruction.arg_format
return maybe_gf2n_instruction
#return instruction
@@ -404,6 +427,7 @@ def cisc(function):
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.security = program.security
self.calls = [(args, kwargs)]
self.params = []
self.used = []
@@ -427,7 +451,7 @@ def cisc(function):
def merge_id(self):
return self.function, tuple(self.params), \
tuple(sorted(self.kwargs.items()))
tuple(sorted(self.kwargs.items())), self.security
def merge(self, other):
self.calls += other.calls
@@ -452,7 +476,10 @@ def cisc(function):
except:
args.append(arg)
program.options.cisc = False
old_security = program.security
program.security = self.security
self.function(*args, **self.kwargs)
program.security = old_security
program.options.cisc = True
reset_global_vector_size()
program.curr_tape = old_tape
@@ -499,8 +526,12 @@ def cisc(function):
for arg in self.args:
try:
new_regs.append(type(arg)(size=size))
except:
except TypeError:
break
except:
print([call[0][0].size for call in self.calls])
raise
assert len(new_regs) > 1
base = 0
for call in self.calls:
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
@@ -523,7 +554,7 @@ def cisc(function):
def get_bytes(self):
assert len(self.kwargs) < 2
res = int_to_bytes(opcodes['CISC'])
res = LongArgFormat.encode(opcodes['CISC'])
res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
name = self.function.__name__
String.check(name)
@@ -559,7 +590,7 @@ def cisc(function):
same_sizes &= arg.size == args[0].size
except:
pass
if program.options.cisc and same_sizes:
if program.use_cisc() and same_sizes:
return MergeCISC(*args, **kwargs)
else:
return function(*args, **kwargs)
@@ -572,9 +603,9 @@ def ret_cisc(function):
instruction = cisc(instruction)
def wrapper(*args, **kwargs):
if not program.options.cisc:
return function(*args, **kwargs)
from Compiler import types
if not (program.options.cisc and isinstance(args[0], types._register)):
return function(*args, **kwargs)
if isinstance(args[0], types._clear):
res_type = type(args[1])
else:
@@ -651,7 +682,8 @@ class RegisterArgFormat(ArgFormat):
raise ArgumentError(arg, 'Invalid register argument')
if arg.program != program.curr_tape:
raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
util.format_trace(arg.caller))
util.format_trace(arg.caller) +
'\nMaybe use MemValue')
if arg.reg_type != cls.reg_type:
raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \
(arg.reg_type, cls.reg_type))
@@ -661,37 +693,68 @@ class RegisterArgFormat(ArgFormat):
assert arg.i >= 0
return int_to_bytes(arg.i)
def __init__(self, f):
self.i = struct.unpack('>I', f.read(4))[0]
def __str__(self):
return self.reg_type + str(self.i)
class ClearModpAF(RegisterArgFormat):
reg_type = RegType.ClearModp
name = 'cint'
class SecretModpAF(RegisterArgFormat):
reg_type = RegType.SecretModp
name = 'sint'
class ClearGF2NAF(RegisterArgFormat):
reg_type = RegType.ClearGF2N
name = 'cgf2n'
class SecretGF2NAF(RegisterArgFormat):
reg_type = RegType.SecretGF2N
name = 'sgf2n'
class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt
name = 'regint'
class IntArgFormat(ArgFormat):
n_bits = 32
@classmethod
def check(cls, arg):
if not isinstance(arg, int) and not arg is None:
raise ArgumentError(arg, 'Expected an integer-valued argument')
if not arg is None:
if not isinstance(arg, int):
raise ArgumentError(arg, 'Expected an integer-valued argument')
if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits:
raise ArgumentError(
arg, 'Immediate value outside of %d-bit range' % cls.n_bits)
@classmethod
def encode(cls, arg):
return int_to_bytes(arg)
def __init__(self, f):
self.i = struct.unpack('>i', f.read(4))[0]
def __str__(self):
return str(self.i)
class LongArgFormat(IntArgFormat):
n_bits = 64
@classmethod
def encode(cls, arg):
return list(struct.pack('>q', arg))
def __init__(self, f):
self.i = struct.unpack('>q', f.read(8))[0]
class ImmediateModpAF(IntArgFormat):
@classmethod
def check(cls, arg):
super(ImmediateModpAF, cls).check(arg)
if arg >= 2**32 or arg < -2**32:
raise ArgumentError(arg, 'Immediate value outside of 32-bit range')
class ImmediateGF2NAF(IntArgFormat):
@classmethod
@@ -702,6 +765,8 @@ class ImmediateGF2NAF(IntArgFormat):
class PlayerNoAF(IntArgFormat):
@classmethod
def check(cls, arg):
if not util.is_constant(arg):
raise CompilerError('Player number must be known at compile time')
super(PlayerNoAF, cls).check(arg)
if arg > 256:
raise ArgumentError(arg, 'Player number > 256')
@@ -722,6 +787,13 @@ class String(ArgFormat):
def encode(cls, arg):
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))
def __init__(self, f):
tmp = f.read(16)
self.str = str(tmp[0:tmp.find(b'\0')], 'ascii')
def __str__(self):
return self.str
ArgFormats = {
'c': ClearModpAF,
's': SecretModpAF,
@@ -736,6 +808,7 @@ ArgFormats = {
'i': ImmediateModpAF,
'ig': ImmediateGF2NAF,
'int': IntArgFormat,
'long': LongArgFormat,
'p': PlayerNoAF,
'str': String,
}
@@ -776,7 +849,7 @@ class Instruction(object):
return (prefix << self.code_length) + self.code
def get_encoding(self):
enc = int_to_bytes(self.get_code())
enc = LongArgFormat.encode(self.get_code())
# add the number of registers if instruction flagged as has var args
if self.has_var_args():
enc += int_to_bytes(len(self.args))
@@ -829,6 +902,7 @@ class Instruction(object):
def is_vec(self):
return False
@classmethod
def is_gf2n(self):
return False
@@ -877,6 +951,10 @@ class Instruction(object):
new_args.append(arg)
return new_args
@staticmethod
def get_usage(args):
return {}
# String version of instruction attempting to replicate encoded version
def __str__(self):
@@ -890,6 +968,66 @@ class Instruction(object):
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
class ParsedInstruction:
reverse_opcodes = {}
def __init__(self, f):
cls = type(self)
from Compiler import instructions
from Compiler.GC import instructions as gc_inst
if not cls.reverse_opcodes:
for module in instructions, gc_inst:
for x, y in inspect.getmodule(module).__dict__.items():
if inspect.isclass(y) and y.__name__[0] != 'v':
try:
cls.reverse_opcodes[y.code] = y
except AttributeError:
pass
read = lambda: struct.unpack('>I', f.read(4))[0]
full_code = struct.unpack('>Q', f.read(8))[0]
code = full_code % (1 << Instruction.code_length)
self.size = full_code >> Instruction.code_length
self.type = cls.reverse_opcodes[code]
t = self.type
name = t.__name__
try:
n_args = len(t.arg_format)
self.var_args = False
except:
n_args = read()
self.var_args = True
try:
arg_format = iter(t.arg_format)
except:
if name == 'cisc':
arg_format = itertools.chain(['str'], itertools.repeat('int'))
else:
def arg_iter():
i = 0
while True:
try:
yield self.args[i].i
except AttributeError:
yield None
i += 1
arg_format = t.dynamic_arg_format(arg_iter())
self.args = []
for i in range(n_args):
self.args.append(ArgFormats[next(arg_format)](f))
def __str__(self):
name = self.type.__name__
res = name + ' '
if self.size > 1:
res = 'v' + res + str(self.size) + ', '
if self.var_args:
res += str(len(self.args)) + ', '
res += ', '.join(str(arg) for arg in self.args)
return res
def get_usage(self):
return self.type.get_usage(self.args)
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
@@ -901,6 +1039,26 @@ class VectorInstruction(Instruction):
def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0]))
class DynFormatInstruction(Instruction):
__slots__ = []
@property
def arg_format(self):
return self.dynamic_arg_format(iter(self.args))
@classmethod
def bases(self, args):
i = 0
while True:
try:
n = next(args)
except StopIteration:
return
yield i, n
i += n
for j in range(n - 1):
next(args)
###
### Basic arithmetic
###
@@ -934,21 +1092,27 @@ class ClearImmediate(ImmediateBase):
### Memory access instructions
###
class DirectMemoryInstruction(Instruction):
class MemoryInstruction(Instruction):
__slots__ = ['_protect']
def __init__(self, *args, **kwargs):
super(MemoryInstruction, self).__init__(*args, **kwargs)
self._protect = program._protect_memory
class DirectMemoryInstruction(MemoryInstruction):
__slots__ = []
def __init__(self, *args, **kwargs):
super(DirectMemoryInstruction, self).__init__(*args, **kwargs)
class IndirectMemoryInstruction(Instruction):
class IndirectMemoryInstruction(MemoryInstruction):
__slots__ = []
def get_direct(self, address):
return self.direct(self.args[0], address, add_to_prog=False)
class ReadMemoryInstruction(Instruction):
class ReadMemoryInstruction(MemoryInstruction):
__slots__ = []
class WriteMemoryInstruction(Instruction):
class WriteMemoryInstruction(MemoryInstruction):
__slots__ = []
class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
@@ -975,12 +1139,16 @@ class IOInstruction(DoNotEliminateInstruction):
@classmethod
def str_to_int(cls, s):
""" Convert a 4 character string to an integer. """
try:
s = bytearray(s, 'utf8')
except:
pass
if len(s) > 4:
raise CompilerError('String longer than 4 characters')
n = 0
for c in reversed(s.ljust(4)):
n <<= 8
n += ord(c)
n += c
return n
class AsymmetricCommunicationInstruction(DoNotEliminateInstruction):
@@ -999,6 +1167,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
""" Input from text file or stdin """
__slots__ = []
def add_usage(self, req_node):
for player in self.get_players():
req_node.increment((self.field_type, 'input', player), \
self.get_size())
###
### Data access instructions
###

View File

@@ -12,6 +12,7 @@ import inspect,math
import random
import collections
import operator
import copy
from functools import reduce
def get_program():
@@ -63,6 +64,7 @@ def print_str(s, *args):
variables/registers with ``%s``. """
def print_plain_str(ss):
""" Print a plain string (no custom formatting options) """
ss = bytearray(ss, 'utf8')
i = 1
while 4*i <= len(ss):
print_char4(ss[4*(i-1):4*i])
@@ -115,7 +117,12 @@ def print_ln(s='', *args):
print_ln('a is %s.', a.reveal())
"""
print_str(s + '\n', *args)
print_str(str(s) + '\n', *args)
def print_both(s, end='\n'):
""" Print line during compilation and execution. """
print(s, end=end)
print_str(s + end)
def print_ln_if(cond, ss, *args):
""" Print line if :py:obj:`cond` is true. The further arguments
@@ -137,7 +144,7 @@ def print_str_if(cond, ss, *args):
""" Print string conditionally. See :py:func:`print_ln_if` for details. """
if util.is_constant(cond):
if cond:
print_ln(ss, *args)
print_str(ss, *args)
else:
subs = ss.split('%s')
assert len(subs) == len(args) + 1
@@ -153,7 +160,8 @@ def print_str_if(cond, ss, *args):
print_str_if(cond, *_expand_to_print(val))
else:
print_str_if(cond, str(val))
s += '\0' * ((-len(s)) % 4)
s = bytearray(s, 'utf8')
s += b'\0' * ((-len(s)) % 4)
while s:
cond.print_if(s[:4])
s = s[4:]
@@ -219,7 +227,10 @@ def crash(condition=None):
:param condition: crash if true (default: true)
"""
if condition == None:
if isinstance(condition, localint):
# allow crash on local values
condition = condition._v
if condition is None:
condition = regint(1)
instructions.crash(regint.conv(condition))
@@ -239,6 +250,10 @@ def store_in_mem(value, address):
try:
value.store_in_mem(address)
except AttributeError:
if isinstance(value, (list, tuple)):
for i, x in enumerate(value):
store_in_mem(x, address + i)
return
# legacy
if value.is_clear:
if isinstance(address, cint):
@@ -257,11 +272,13 @@ def reveal(secret):
try:
return secret.reveal()
except AttributeError:
if secret.is_clear:
return secret
if secret.is_gf2n:
res = cgf2n()
else:
res = cint()
instructions.asm_open(res, secret)
instructions.asm_open(True, res, secret)
return res
@vectorize
@@ -278,13 +295,13 @@ def get_arg():
ldarg(res)
return res
def make_array(l):
def make_array(l, t=None):
if isinstance(l, program.Tape.Register):
res = Array(1, type(l))
res[0] = l
res = Array(len(l), t or type(l))
res[:] = l
else:
l = list(l)
res = Array(len(l), type(l[0]) if l else cint)
res = Array(len(l), t or type(l[0]) if l else cint)
res.assign(l)
return res
@@ -456,6 +473,10 @@ def method_block(function):
return wrapper
def cond_swap(x,y):
from .types import SubMultiArray
if isinstance(x, (Array, SubMultiArray)):
b = x[0] > y[0]
return list(zip(*[b.cond_swap(xx, yy) for xx, yy in zip(x, y)]))
b = x < y
if isinstance(x, sfloat):
res = ([], [])
@@ -467,11 +488,11 @@ def cond_swap(x,y):
res[0].append(bx + yy - by)
res[1].append(xx - bx + by)
return sfloat(*res[0]), sfloat(*res[1])
bx = b * x
by = b * y
return bx + y - by, x - bx + by
return b.cond_swap(y, x)
def sort(a):
print("WARNING: you're using bubble sort")
res = a
for i in range(len(a)):
@@ -497,282 +518,36 @@ def odd_even_merge_sort(a):
if len(a) == 1:
return
elif len(a) % 2 == 0:
aa = a
a = list(a)
lower = a[:len(a)//2]
upper = a[len(a)//2:]
odd_even_merge_sort(lower)
odd_even_merge_sort(upper)
a[:] = lower + upper
odd_even_merge(a)
aa[:] = a
else:
raise CompilerError('Length of list must be power of two')
def chunky_odd_even_merge_sort(a):
tmp = a[0].Array(len(a))
for i,j in enumerate(a):
tmp[i] = j
l = 1
while l < len(a):
l *= 2
k = 1
while k < l:
k *= 2
def round():
for i in range(len(a)):
a[i] = tmp[i]
for i in range(len(a) // l):
for j in range(l // k):
base = i * l + j
step = l // k
if k == 2:
a[base], a[base+step] = cond_swap(a[base], a[base+step])
else:
b = a[base:base+k*step:step]
for m in range(base + step, base + (k - 1) * step, 2 * step):
a[m], a[m+step] = cond_swap(a[m], a[m+step])
for i in range(len(a)):
tmp[i] = a[i]
chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True)
chunk.start()
chunk.join()
#round()
for i in range(len(a)):
a[i] = tmp[i]
raise CompilerError(
'This function has been removed, use loopy_odd_even_merge_sort instead')
def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False):
if n is None:
n = len(a)
a_base = instructions.program.malloc(n, 's')
for i,j in enumerate(a):
store_in_mem(j, a_base + i)
else:
a_base = a
tmp_base = instructions.program.malloc(n, 's')
chunks = {}
threads = []
def run_threads():
for thread in threads:
thread.start()
for thread in threads:
thread.join()
del threads[:]
def run_chunk(size, base):
if size not in chunks:
def swap_list(list_base):
for i in range(size // 2):
base = list_base + 2 * i
x, y = cond_swap(sint.load_mem(base),
sint.load_mem(base + 1))
store_in_mem(x, base)
store_in_mem(y, base + 1)
chunks[size] = FunctionTape(swap_list, 'sort-%d' % size)
return chunks[size](base)
def run_round(size):
# minimize number of chunk sizes
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
lower_size = size // n_chunks // 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
# print len(to_swap) == lower_size * n_lower_size + \
# (lower_size + 2) * (n_chunks - n_lower_size), \
# len(to_swap), n_chunks, lower_size, n_lower_size
base = 0
round_threads = []
for i in range(n_lower_size):
round_threads.append(run_chunk(lower_size, tmp_base + base))
base += lower_size
for i in range(n_chunks - n_lower_size):
round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
base += lower_size + 2
run_threads_in_rounds(round_threads)
postproc_chunks = []
wrap_chunks = {}
post_threads = []
pre_threads = []
def load_and_store(x, y, to_right):
if to_right:
store_in_mem(sint.load_mem(x), y)
else:
store_in_mem(sint.load_mem(y), x)
def run_setup(k, a_addr, step, tmp_addr):
if k == 2:
def mem_op(preproc, a_addr, step, tmp_addr):
load_and_store(a_addr, tmp_addr, preproc)
load_and_store(a_addr + step, tmp_addr + 1, preproc)
res = 2
else:
def mem_op(preproc, a_addr, step, tmp_addr):
instructions.program.curr_tape.merge_opens = False
# for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)):
for i in range(k - 2):
m = a_addr + step + i * step
load_and_store(m, tmp_addr + i, preproc)
res = k - 2
if not use_chunk_wraps or k <= 4:
mem_op(True, a_addr, step, tmp_addr)
postproc_chunks.append((mem_op, (a_addr, step, tmp_addr)))
else:
if k not in wrap_chunks:
pre_chunk = FunctionTape(mem_op, 'pre-%d' % k,
compile_args=[True])
post_chunk = FunctionTape(mem_op, 'post-%d' % k,
compile_args=[False])
wrap_chunks[k] = (pre_chunk, post_chunk)
pre_chunk, post_chunk = wrap_chunks[k]
pre_threads.append(pre_chunk(a_addr, step, tmp_addr))
post_threads.append(post_chunk(a_addr, step, tmp_addr))
return res
def run_threads_in_rounds(all_threads):
for thread in all_threads:
if len(threads) == n_threads:
run_threads()
threads.append(thread)
run_threads()
del all_threads[:]
def run_postproc():
run_threads_in_rounds(post_threads)
for chunk,args in postproc_chunks:
chunk(False, *args)
postproc_chunks[:] = []
l = 1
while l < n:
l *= 2
k = 1
while k < l:
k *= 2
size = 0
instructions.program.curr_tape.merge_opens = False
for i in range(n // l):
for j in range(l // k):
base = i * l + j
step = l // k
size += run_setup(k, a_base + base, step, tmp_base + size)
run_threads_in_rounds(pre_threads)
run_round(size)
run_postproc()
if isinstance(a, list):
for i in range(n):
a[i] = sint.load_mem(a_base + i)
instructions.program.free(a_base, 's')
instructions.program.free(tmp_base, 's')
raise CompilerError(
'This function has been removed, use loopy_odd_even_merge_sort instead')
def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7):
if n is None:
n = len(a)
a_base = instructions.program.malloc(n, 's')
for i,j in enumerate(a):
store_in_mem(j, a_base + i)
else:
a_base = a
tmp_base = instructions.program.malloc(n, 's')
tmp_i = instructions.program.malloc(1, 'ci')
chunks = {}
threads = []
def run_threads():
for thread in threads:
thread.start()
for thread in threads:
thread.join()
del threads[:]
def run_threads_in_rounds(all_threads):
for thread in all_threads:
if len(threads) == n_threads:
run_threads()
threads.append(thread)
run_threads()
del all_threads[:]
def run_chunk(size, base):
if size not in chunks:
def swap_list(list_base):
for i in range(size // 2):
base = list_base + 2 * i
x, y = cond_swap(sint.load_mem(base),
sint.load_mem(base + 1))
store_in_mem(x, base)
store_in_mem(y, base + 1)
chunks[size] = FunctionTape(swap_list, 'sort-%d' % size)
return chunks[size](base)
def run_round(size):
# minimize number of chunk sizes
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
lower_size = size // n_chunks // 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
# print len(to_swap) == lower_size * n_lower_size + \
# (lower_size + 2) * (n_chunks - n_lower_size), \
# len(to_swap), n_chunks, lower_size, n_lower_size
base = 0
round_threads = []
for i in range(n_lower_size):
round_threads.append(run_chunk(lower_size, tmp_base + base))
base += lower_size
for i in range(n_chunks - n_lower_size):
round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
base += lower_size + 2
run_threads_in_rounds(round_threads)
l = 1
while l < n:
l *= 2
k = 1
while k < l:
k *= 2
def load_and_store(x, y):
if to_tmp:
store_in_mem(sint.load_mem(x), y)
else:
store_in_mem(sint.load_mem(y), x)
def outer(i):
def inner(j):
base = j + a_base + i * l
step = l // k
if k == 2:
tmp_addr = regint.load_mem(tmp_i)
load_and_store(base, tmp_addr)
load_and_store(base + step, tmp_addr + 1)
store_in_mem(tmp_addr + 2, tmp_i)
else:
def inner2(m):
m += base
tmp_addr = regint.load_mem(tmp_i)
load_and_store(m, tmp_addr)
store_in_mem(tmp_addr + 1, tmp_i)
range_loop(inner2, step, (k - 1) * step, step)
range_loop(inner, l // k)
instructions.program.curr_tape.merge_opens = False
to_tmp = True
store_in_mem(tmp_base, tmp_i)
range_loop(outer, n // l)
if k == 2:
run_round(n)
else:
run_round(n // k * (k - 2))
instructions.program.curr_tape.merge_opens = False
to_tmp = False
store_in_mem(tmp_base, tmp_i)
range_loop(outer, n // l)
if isinstance(a, list):
for i in range(n):
a[i] = sint.load_mem(a_base + i)
instructions.program.free(a_base, 's')
instructions.program.free(tmp_base, 's')
instructions.program.free(tmp_i, 'ci')
raise CompilerError(
'This function has been removed, use loopy_odd_even_merge_sort instead')
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
n_threads=None):
a_in = a
if isinstance(a_in, list):
a = Array.create_from(a)
steps = {}
l = sorted_length
while l < len(a):
@@ -816,8 +591,14 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
swap(m2, step)
steps[key] = step
steps[key](l)
if isinstance(a_in, list):
a_in[:] = list(a)
def mergesort(A):
if not get_program().options.insecure:
raise CompilerError('mergesort reveals the order of elements, '
'use --insecure to activate it')
B = Array(len(A), sint)
def merge(i_left, i_right, i_end):
@@ -845,12 +626,18 @@ def mergesort(A):
width.imul(2)
return width < len(A)
def range_loop(loop_body, start, stop=None, step=None):
def _range_prep(start, stop, step):
if stop is None:
stop = start
start = 0
if step is None:
step = 1
if util.is_zero(step):
raise CompilerError('step must not be zero')
return start, stop, step
def range_loop(loop_body, start, stop=None, step=None):
start, stop, step = _range_prep(start, stop, step)
def loop_fn(i):
res = loop_body(i)
return util.if_else(res == 0, stop, i + step)
@@ -859,8 +646,6 @@ def range_loop(loop_body, start, stop=None, step=None):
condition = lambda x: x < stop
elif step < 0:
condition = lambda x: x > stop
else:
raise CompilerError('step must not be zero')
else:
b = step > 0
condition = lambda x: b * (x < stop) + (1 - b) * (x > stop)
@@ -870,36 +655,34 @@ def range_loop(loop_body, start, stop=None, step=None):
# known loop count
if condition(start):
get_tape().req_node.children[-1].aggregator = \
lambda x: ((stop - start) // step) * x[0]
lambda x: int(ceil(((stop - start) / step))) * x[0]
def for_range(start, stop=None, step=None):
"""
Decorator to execute loop bodies consecutively. Arguments work as
in Python :py:func:`range`, but they can by any public
in Python :py:func:`range`, but they can be any public
integer. Information has to be passed out via container types such
as :py:class:`~Compiler.types.Array` or declaring registers as
:py:obj:`global`. Note that changing Python data structures such
as :py:class:`~Compiler.types.Array` or using :py:func:`update`.
Note that changing Python data structures such
as lists within the loop is not possible, but the compiler cannot
warn about this.
:param start/stop/step: regint/cint/int
Example:
.. code::
The following should output 10::
n = 10
a = sint.Array(n)
x = sint(0)
@for_range(n)
def _(i):
a[i] = i
global x
x += 1
x.update(x + 1)
print_ln('%s', x.reveal())
Note that you cannot overwrite data structures such as
:py:class:`~Compiler.types.Array` in a loop even when using
:py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign`
instead.
:py:class:`~Compiler.types.Array` in a loop. Use
:py:func:`~Compiler.types.Array.assign` instead.
"""
def decorator(loop_body):
range_loop(loop_body, start, stop, step)
@@ -909,11 +692,13 @@ def for_range(start, stop=None, step=None):
def for_range_parallel(n_parallel, n_loops):
"""
Decorator to execute a loop :py:obj:`n_loops` up to
:py:obj:`n_parallel` loop bodies in parallel.
:py:obj:`n_parallel` loop bodies with optimized communication in a
single thread.
In most cases, it is easier to use :py:func:`for_range_opt`.
Using any other control flow instruction inside the loop breaks
the optimization.
:param n_parallel: compile-time (int)
:param n_parallel: optimization parameter (int)
:param n_loops: regint/cint/int or list of int
Example:
@@ -937,7 +722,7 @@ def for_range_parallel(n_parallel, n_loops):
return for_range_multithread(None, n_parallel, n_loops)
return map_reduce_single(n_parallel, n_loops)
def for_range_opt(n_loops, budget=None):
def for_range_opt(start, stop=None, step=None, budget=None):
""" Execute loop bodies in parallel up to an optimization budget.
This prevents excessive loop unrolling. The budget is respected
even with nested loops. Note that the optimization is rather
@@ -947,8 +732,10 @@ def for_range_opt(n_loops, budget=None):
:py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the
optimization.
:param n_loops: int/regint/cint
:param budget: number of instructions after which to start optimization (default is 100,000)
:param start/stop/step: int/regint/cint (used as in :py:func:`range`)
or :py:obj:`start` only as list/tuple of int (see below)
:param budget: number of instructions after which to start optimization
(default is 100,000)
Example:
@@ -968,6 +755,15 @@ def for_range_opt(n_loops, budget=None):
def f(i, j):
...
"""
if stop is not None:
start, stop, step = _range_prep(start, stop, step)
def wrapper(loop_body):
n_loops = (step - 1 + stop - start) // step
@for_range_opt(n_loops, budget=budget)
def _(i):
return loop_body(start + i * step)
return wrapper
n_loops = start
if isinstance(n_loops, (list, tuple)):
return for_range_opt_multithread(None, n_loops)
return map_reduce_single(None, n_loops, budget=budget)
@@ -1009,9 +805,11 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
def f(i):
state = tuplify(initializer())
start_block = get_block()
j = i * n_parallel
one = regint(1)
for k in range(n_parallel):
j = i * n_parallel + k
state = reducer(tuplify(loop_body(j)), state)
j += one
if n_parallel > 1 and start_block != get_block():
print('WARNING: parallelization broken '
'by control flow instruction')
@@ -1028,12 +826,16 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
state = tuplify(initializer())
k = 0
block = get_block()
assert not isinstance(n_loops, int) or n_loops > 0
pre = copy.copy(loop_body.__globals__)
while (not util.is_constant(n_loops) or k < n_loops) \
and (len(get_block()) < budget or k == 0) \
and block is get_block():
j = i + k
state = reducer(tuplify(loop_body(j)), state)
k += 1
RegintOptimizer().run(block.instructions, get_program())
_link(pre, loop_body.__globals__)
r = reducer(mem_state, state)
write_state_to_memory(r)
global n_opt_loops
@@ -1064,7 +866,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
del blocks[-n_to_merge + 1:]
del get_tape().req_node.children[-1]
merged.children = []
RegintOptimizer().run(merged.instructions)
RegintOptimizer().run(merged.instructions, get_program())
get_tape().active_basicblock = merged
else:
req_node = get_tape().req_node.children[-1].nodes[0]
@@ -1131,6 +933,15 @@ def for_range_opt_multithread(n_threads, n_loops):
@for_range_opt_multithread(2, [5, 3])
def f(i, j):
...
Note that you cannot use registers across threads. Use
:py:class:`MemValue` instead::
a = MemValue(sint(0))
@for_range_opt_multithread(8, 80)
def _(i):
b = a + 1
"""
return for_range_multithread(n_threads, None, n_loops)
@@ -1142,6 +953,7 @@ def multithread(n_threads, n_items=None, max_size=None):
:param n_threads: compile-time (int)
:param n_items: regint/cint/int (default: :py:obj:`n_threads`)
:param max_size: maximum size to be processed at once (default: no limit)
The following executes ``f(0, 8)``, ``f(8, 8)``, and
``f(16, 9)`` in three different threads:
@@ -1158,6 +970,7 @@ def multithread(n_threads, n_items=None, max_size=None):
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False)
else:
max_size = max(1, max_size)
def wrapper(function):
@multithread(n_threads, n_items)
def new_function(base, size):
@@ -1205,7 +1018,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if t != regint:
raise CompilerError('Not implemented for other than regint')
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
state = tuple(initializer())
state = initializer()
if len(state) == 0:
state_type = cint
elif isinstance(state, (tuple, list)):
state_type = type(state[0])
else:
state_type = type(state)
def f(inc):
base = args[get_arg()][0]
if not util.is_constant(thread_rounds):
@@ -1218,8 +1037,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint, \
args[get_arg()].address + 2)
mem_state = Array(len(state), type(state[0]) \
if state else cint, args[get_arg()][1])
mem_state = Array(len(state), state_type, args[get_arg()][1])
@map_reduce_single(n_parallel, thread_rounds + inc, \
initializer, reducer, mem_state)
def f(i):
@@ -1251,14 +1069,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
if state:
if len(state):
if thread_rounds:
for i in range(n_threads - remainder):
state = reducer(Array(len(state), type(state[0]), \
state = reducer(Array(len(state), state_type, \
args[remainder + i][1]), state)
if remainder:
for i in range(remainder):
state = reducer(Array(len(state), type(state[0]).reg_type, \
state = reducer(Array(len(state), state_type, \
args[i][1]), state)
def returner():
return untuplify(state)
@@ -1294,7 +1112,50 @@ def map_sum_opt(n_threads, n_loops, types):
"""
return map_sum(n_threads, None, n_loops, len(types), types)
def map_sum_simple(n_threads, n_loops, type, size):
""" Vectorized multi-threaded sum reduction. The following computes a
100 sums of ten squares in three threads::
@map_sum_simple(3, 10, sint, 100)
def summer(i):
return sint(regint.inc(100, i, 0)) ** 2
result = summer()
:param n_threads: number of threads (int)
:param n_loops: number of loop runs (regint/cint/int)
:param type: return type, must match the return statement
in the loop
:param size: vector size, must match the return statement
in the loop
"""
initializer = lambda: type(0, size=size)
def summer(*args):
assert len(args) == 2
args = list(args)
for i in (0, 1):
if isinstance(args[i], tuple):
assert len(args[i]) == 1
args[i] = args[i][0]
for i in (0, 1):
assert len(args[i]) == size
if isinstance(args[i], Array):
args[i] = args[i][:]
return args[0] + args[1]
return map_reduce(n_threads, 1, n_loops, initializer, summer)
def tree_reduce_multithread(n_threads, function, vector):
""" Round-efficient reduction in several threads. The following code
computes the maximum of an array in 10 threads::
tree_reduce_multithread(10, lambda x, y: x.max(y), a)
:param n_threads: number of threads (int)
:param function: reduction function taking exactly two arguments
:param vector: register vector or array
"""
inputs = vector.Array(len(vector))
inputs.assign_vector(vector)
outputs = vector.Array(len(vector) // 2)
@@ -1311,6 +1172,18 @@ def tree_reduce_multithread(n_threads, function, vector):
left = (left + 1) // 2
return inputs[0]
def tree_reduce(function, sequence):
""" Round-efficient reduction. The following computes the maximum
of the list :py:obj:`l`::
m = tree_reduce(lambda x, y: x.max(y), l)
:param function: reduction function taking two arguments
:param sequence: list, vector, or array
"""
return util.tree_reduce(function, sequence)
def foreach_enumerate(a):
""" Run-time loop over public data. This uses
``Player-Data/Public-Input/<progname>``. Example:
@@ -1338,61 +1211,57 @@ def foreach_enumerate(a):
return f
return decorator
def while_loop(loop_body, condition, arg, g=None):
def while_loop(loop_body, condition, arg=None, g=None):
if not callable(condition):
raise CompilerError('Condition must be callable')
# store arg in stack
pre_condition = condition(arg)
if not isinstance(pre_condition, (bool,int)) or pre_condition:
if arg is None:
pre_condition = condition()
def loop_fn():
loop_body()
return condition()
else:
pre_condition = condition(arg)
arg = regint(arg)
def loop_fn():
result = loop_body(arg)
result.link(arg)
cont = condition(result)
return cont
if isinstance(result, MemValue):
result = result.read()
arg.update(result)
return condition(result)
if not isinstance(pre_condition, (bool,int)) or pre_condition:
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))
def while_do(condition, *args):
""" While-do loop. The decorator requires an initialization, and
the loop body function must return a suitable input for
:py:obj:`condition`.
""" While-do loop.
:param condition: function returning public integer (regint/cint/int)
:param args: arguments given to :py:obj:`condition` and loop body
The following executes an ten-fold loop:
.. code::
@while_do(lambda x: x < 10, regint(0))
def f(i):
i = regint(0)
@while_do(lambda: i < 10)
def f():
...
return i + 1
i.update(i + 1)
...
"""
def decorator(loop_body):
while_loop(loop_body, condition, *args)
return loop_body
return decorator
def do_loop(condition, loop_fn):
# store initial condition to stack
pushint(condition if isinstance(condition,regint) else regint(condition))
def wrapped_loop():
# save condition to stack
new_cond = regint.pop()
# run the loop
condition = loop_fn(new_cond)
pushint(condition)
return condition
do_while(wrapped_loop)
regint.pop()
def _run_and_link(function, g=None):
if g is None:
g = function.__globals__
import copy
pre = copy.copy(g)
res = function()
_link(pre, g)
return res
def _link(pre, g):
if g:
from .types import _single
for name, var in pre.items():
@@ -1402,7 +1271,6 @@ def _run_and_link(function, g=None):
raise CompilerError('cannot reassign constants in blocks')
if id(new_var) != id(var):
new_var.link(var)
return res
def do_while(loop_fn, g=None):
""" Do-while loop. The loop is stopped if the return value is zero.
@@ -1442,11 +1310,17 @@ def if_then(condition):
state = State()
if callable(condition):
condition = condition()
try:
if not condition.is_clear:
raise CompilerError('cannot branch on secret values')
except AttributeError:
pass
state.condition = regint.conv(condition)
state.start_block = instructions.program.curr_block
state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
name='if-block')
state.has_else = False
state.caller = [frame[1:] for frame in inspect.stack()[1:]]
instructions.program.curr_tape.if_states.append(state)
def else_then():
@@ -1531,6 +1405,8 @@ def if_(condition):
def if_e(condition):
"""
Conditional execution with else block.
Use :py:class:`~Compiler.types.MemValue` to assign values that
live beyond.
:param condition: regint/cint/int
@@ -1538,12 +1414,13 @@ def if_e(condition):
.. code::
y = MemValue(0)
@if_e(x > 0)
def _():
...
y.write(1)
@else_
def _():
...
y.write(0)
"""
try:
condition = bool(condition)
@@ -1647,11 +1524,18 @@ def get_player_id():
return res
def listen_for_clients(port):
""" Listen for clients on specific port. """
""" Listen for clients on specific port base.
:param port: port base (int/regint/cint)
"""
instructions.listen(regint.conv(port))
def accept_client_connection(port):
""" Listen for clients on specific port. """
""" Accept client connection on specific port base.
:param port: port base (int/regint/cint)
:returns: client id
"""
res = regint()
instructions.acceptclientconnection(res, regint.conv(port))
return res
@@ -1779,7 +1663,9 @@ def sint_cint_division(a, b, k, f, kappa):
return (sign_a * sign_b) * A
def IntDiv(a, b, k, kappa=None):
return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k,
l = 2 * k + 1
b = a.conv(b)
return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k,
kappa, nearest=True)
@instructions_base.ret_cisc
@@ -1801,24 +1687,25 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
theta = int(ceil(log(k/3.5) / log(2)))
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f)
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()
y = a.extend(2 *k) * w
y = y.round(2*k, f, kappa, nearest, signed=True)
l_y = k + 3 * f - res_f
y = a.extend(l_y) * w
y = y.round(l_y, f, kappa, nearest, signed=True)
for i in range(theta - 1):
x = x.extend(2 * k)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.extend(l_y) * (alpha + x).extend(l_y)
x = x * x
y = y.round(2*k, 2*f, kappa, nearest, signed=True)
y = y.round(l_y, 2*f, kappa, nearest, signed=True)
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
x = x.extend(2 * k)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
y = y.extend(l_y) * (alpha + x).extend(l_y)
y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True)
return y
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,8 @@ This has to imported explicitly.
import math
import operator
from functools import reduce
from Compiler import floatingpoint
from Compiler import types
from Compiler import comparison
@@ -290,12 +292,11 @@ def exp2_fx(a, zero_output=False, as19=False):
# how many bits to use from integer part
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
n_bits = a.f + n_int_bits
sint = types.sint
sint = a.int_type
if types.program.options.ring and not as19:
intbitint = types.intbitint
n_shift = int(types.program.options.ring) - a.k
if types.program.use_split():
assert not zero_output
from Compiler.GC.types import sbitvec
if types.program.use_split() == 3:
x = a.v.split_to_two_summands(a.k)
@@ -327,6 +328,7 @@ def exp2_fx(a, zero_output=False, as19=False):
s = sint.conv(bits[-1])
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
higher_bits = bits[a.f:n_bits]
bits_to_check = bits[n_bits:-1]
else:
if types.program.use_edabit():
l = sint.get_edabit(a.f, True)
@@ -338,7 +340,7 @@ def exp2_fx(a, zero_output=False, as19=False):
r_bits = [sint.get_random_bit() for i in range(a.k)]
r = sint.bit_compose(r_bits)
lower_r = sint.bit_compose(r_bits[:a.f])
shifted = ((a.v - r) << n_shift).reveal()
shifted = ((a.v - r) << n_shift).reveal(False)
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
r_bits[a.f-1::-1])
@@ -367,17 +369,17 @@ def exp2_fx(a, zero_output=False, as19=False):
bits = a.v.bit_decompose(a.k, maybe_mixed=True)
lower = sint.bit_compose(bits[:a.f])
higher_bits = bits[a.f:n_bits]
s = sint.conv(bits[-1])
s = a.bit_type.conv(bits[-1])
bits_to_check = bits[n_bits:-1]
if not as19:
c = types.sfix._new(lower, k=a.k, f=a.f)
c = a._new(lower, k=a.k, f=a.f)
assert(len(higher_bits) == n_bits - a.f)
pow2_bits = [sint.conv(x) for x in higher_bits]
d = floatingpoint.Pow2_from_bits(pow2_bits)
g = exp_from_parts(d, c)
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits,
2 ** n_int_bits, signed=False,
nearest=types.sfix.round_nearest),
nearest=a.round_nearest),
k=a.k, f=a.f)
if zero_output:
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
@@ -398,6 +400,36 @@ def exp2_fx(a, zero_output=False, as19=False):
return s.if_else(1 / g, g)
def mux_exp(x, y, block_size=8):
assert util.is_constant_float(x)
from Compiler.GC.types import sbitvec, sbits
bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v
sign = bits[-1]
m = math.log(2 ** (y.k - y.f - 1), x)
del bits[int(math.ceil(math.log(m, 2))) + y.f:]
parts = []
for i in range(0, len(bits), block_size):
one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v
exp = []
try:
for j in range(len(one_hot)):
exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f))
except OverflowError:
pass
exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp))
bin_part = [0] * max(x.bit_length() for x in exp)
for j in range(len(bin_part)):
for k, (a, b) in enumerate(zip(one_hot, exp)):
bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \
else 0
if util.is_zero(bin_part[j]):
bin_part[j] = sbits.get_type(y.size)(0)
if i == 0:
bin_part[j] = sign.if_else(0, bin_part[j])
parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part))))
return util.tree_reduce(operator.mul, parts)
@types.vectorize
@instructions_base.sfix_cisc
def log2_fx(x, use_division=True):
@@ -420,6 +452,8 @@ def log2_fx(x, use_division=True):
p -= x.f
vlen = x.f
v = x._new(v, k=x.k, f=x.f)
elif isinstance(x, (types._register, types.cfix)):
return log2_fx(types.sfix(x), use_division)
else:
d = types.sfloat(x)
v, p, vlen = d.v, d.p, d.vlen
@@ -501,7 +535,7 @@ def abs_fx(x):
#
# @return floored sint value of x
def floor_fx(x):
return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x))
return load_sint(floatingpoint.Trunc(x.v, x.k, x.f, x.kappa), type(x))
### sqrt methods
@@ -627,7 +661,7 @@ def sqrt_simplified_fx(x):
h = h * r
H = 4 * (h * h)
if not x.round_nearest or (2 * f < k - 1):
if not x.round_nearest or (2 * x.f < x.k - 1):
H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H)
H = H * x
@@ -772,9 +806,7 @@ def sqrt_fx(x_l, k, f):
@instructions_base.sfix_cisc
def sqrt(x, k=None, f=None):
"""
Returns the square root (sfix) of any given fractional
value as long as it can be rounded to a integral value
with :py:obj:`f` bits of decimal precision.
Square root.
:param x: fractional input (sfix).
@@ -882,7 +914,7 @@ def SqrtComp(z, old=False):
k = len(z)
if isinstance(z[0], types.sint):
return types.sfix._new(sum(z[i] * types.cfix(
2 ** (-(i - f + 1) / 2)).v for i in range(k)))
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
k_prime = k // 2
f_prime = f // 2
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))

View File

@@ -1,7 +1,7 @@
from .comparison import *
from .floatingpoint import *
from .types import *
from . import comparison
from . import comparison, program
class NonLinear:
kappa = None
@@ -30,6 +30,17 @@ class NonLinear:
def trunc_pr(self, a, k, m, signed=True):
if isinstance(a, types.cint):
return shift_two(a, m)
prog = program.Program.prog
if prog.use_trunc_pr:
if not prog.options.ring:
prog.curr_tape.require_bit_length(k + prog.security)
if signed and prog.use_trunc_pr != -1:
a += (1 << (k - 1))
res = sint()
trunc_pr(res, a, k, m)
if signed and prog.use_trunc_pr != -1:
res -= (1 << (k - m - 1))
return res
return self._trunc_pr(a, k, m, signed)
def trunc_round_nearest(self, a, k, m, signed):
@@ -44,6 +55,9 @@ class NonLinear:
return a
return self._trunc(a, k, m, signed)
def ltz(self, a, k, kappa=None):
return -self.trunc(a, k, k - 1, kappa, True)
class Masking(NonLinear):
def eqz(self, a, k):
c, r = self._mask(a, k)
@@ -100,42 +114,44 @@ class KnownPrime(NonLinear):
def _mod2m(self, a, k, m, signed):
if signed:
a += cint(1) << (k - 1)
return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
return sint.bit_compose(self.bit_dec(a, k, m, True))
def _trunc_pr(self, a, k, m, signed):
# nearest truncation
return self.trunc_round_nearest(a, k, m, signed)
def _trunc(self, a, k, m, signed=None):
if signed:
a += cint(1) << (k - 1)
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
if signed:
res -= cint(1) << (k - 1 - m)
return res
return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed)
def trunc_round_nearest(self, a, k, m, signed):
a += cint(1) << (m - 1)
if signed:
a += cint(1) << (k - 1)
k += 1
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
res = self._trunc(a, k, m, False)
if signed:
res -= cint(1) << (k - m - 2)
return res
def bit_dec(self, a, k, m, maybe_mixed=False):
assert k < self.prime.bit_length()
bits = BitDecFull(a, maybe_mixed=maybe_mixed)
if len(bits) < m:
raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
return bits[:m]
bits = BitDecFull(a, m, maybe_mixed=maybe_mixed)
assert len(bits) == m
return bits
def eqz(self, a, k):
# always signed
a += two_power(k)
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
def ltz(self, a, k, kappa=None):
if k + 1 < self.prime.bit_length():
# https://dl.acm.org/doi/10.1145/3474123.3486757
# "negative" values wrap around when doubling, thus becoming odd
return self.mod2m(2 * a, k + 1, 1, False)
else:
return super(KnownPrime, self).ltz(a, k, kappa)
class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time.
"""
@@ -172,3 +188,6 @@ class Ring(Masking):
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
else:
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
def ltz(self, a, k, kappa=None):
return LtzRing(a, k)

View File

@@ -348,7 +348,7 @@ class Entry(object):
def __len__(self):
return 2 + len(self.x)
def __repr__(self):
return '{empty=%s}' % self.is_empty if self.is_empty \
return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \
else '{%s: %s}' % (self.v, self.x)
def __add__(self, other):
try:
@@ -466,12 +466,14 @@ class AbstractORAM(object):
def get_array(size, t, *args, **kwargs):
return t.dynamic_array(size, t, *args, **kwargs)
def read(self, index):
return self._read(self.value_type.hard_conv(index))
res = self._read(self.index_type.hard_conv(index))
res = [self.value_type._new(x) for x in res]
return res
def write(self, index, value):
value = util.tuplify(value)
value = [self.value_type.conv(x) for x in value]
new_value = [self.value_type.get_type(length).hard_conv(v) \
for length,v in zip(self.entry_size, value \
if isinstance(value, (tuple, list)) \
else (value,))]
for length,v in zip(self.entry_size, value)]
return self._write(self.index_type.hard_conv(index), *new_value)
def access(self, index, new_value, write, new_empty=False):
return self._access(self.index_type.hard_conv(index),
@@ -795,18 +797,19 @@ class RefTrivialORAM(EndRecursiveEviction):
for i,value in enumerate(values):
index = MemValue(self.value_type.hard_conv(i))
new_value = [MemValue(self.value_type.hard_conv(v)) \
for v in (value if isinstance(value, (tuple, list)) \
for v in (value if isinstance(
value, (tuple, list, Array)) \
else (value,))]
self.ram[i] = Entry(index, new_value, value_type=self.value_type)
class TrivialORAM(RefTrivialORAM, AbstractORAM):
""" Trivial ORAM (obviously). """
ref_type = RefTrivialORAM
def __init__(self, size, value_type=sint, value_length=1, index_size=None, \
def __init__(self, size, value_type=None, value_length=1, index_size=None, \
entry_size=None, contiguous=True, init_rounds=-1):
self.index_size = index_size or log2(size)
self.value_type = value_type
self.index_type = value_type.get_type(self.index_size)
self.value_type = value_type or sint
self.index_type = self.value_type.get_type(self.index_size)
if entry_size is None:
self.value_length = value_length
self.entry_size = [None] * value_length
@@ -859,15 +862,16 @@ class LinearORAM(TrivialORAM):
empty_entry = self.empty_entry(False)
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type.get_type(l) for l in self.entry_size])
self.value_length + 1, t)
def f(i):
entry = self.ram[i]
access_here = self.index_vector[i]
return access_here * ValueTuple((entry.empty(),) + entry.x)
not_found = f()[0]
read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x
not_found = self.value_type.bit_type(f()[0])
read_value = ValueTuple(self.value_type.get_type(l)(x) for l, x in zip(self.entry_size, f()[1:])) + \
not_found * empty_entry.x
maybe_stop_timer(6)
return read_value, not_found
@method_block
@@ -876,7 +880,9 @@ class LinearORAM(TrivialORAM):
empty_entry = self.empty_entry(False)
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
new_value = make_array(new_value)
new_value = make_array(
new_value, self.value_type.get_type(
max(x or 0 for x in self.entry_size)))
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
def f(i):
entry = self.ram[i]
@@ -892,7 +898,9 @@ class LinearORAM(TrivialORAM):
empty_entry = self.empty_entry(False)
index_vector = \
demux_array(bit_decompose(index, self.index_size))
new_value = make_array(new_value)
new_value = make_array(
new_value, self.value_type.get_type(
max(x or 0 for x in self.entry_size)))
new_empty = MemValue(new_empty)
write = MemValue(write)
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
@@ -986,7 +994,8 @@ class List(EndRecursiveEviction):
for i,value in enumerate(values):
index = self.value_type.hard_conv(i)
new_value = [self.value_type.hard_conv(v) \
for v in (value if isinstance(value, (tuple, list)) \
for v in (value if isinstance(
value, (tuple, list, Array)) \
else (value,))]
self.__setitem__(index, new_value)
def __repr__(self):
@@ -1025,8 +1034,9 @@ def get_n_threads_for_tree(size):
class TreeORAM(AbstractORAM):
""" Tree ORAM. """
def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \
def __init__(self, size, value_type=None, value_length=1, entry_size=None, \
bucket_oram=TrivialORAM, init_rounds=-1):
value_type = value_type or sint
print('create oram of size', size)
self.bucket_oram = bucket_oram
# heuristic bucket size
@@ -1062,11 +1072,12 @@ class TreeORAM(AbstractORAM):
stop_timer(1)
start_timer()
self.root = RefBucket(1, self)
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
self.index = self.index_structure(size, self.D, self.index_type,
init_rounds, True)
self.read_value = Array(self.value_length, value_type)
self.read_value = Array(self.value_length, value_type.default_type)
self.read_non_empty = MemValue(self.value_type.bit_type(0))
self.state = MemValue(self.value_type(0))
self.state = MemValue(self.value_type.default_type(0))
@method_block
def add_to_root(self, state, is_empty, v, *x):
if len(x) != self.value_length:
@@ -1106,10 +1117,10 @@ class TreeORAM(AbstractORAM):
self.evict_bucket(RefBucket(p_bucket2, self), d)
@method_block
def read_and_renew_index(self, u):
l_star = random_block(self.D, self.value_type)
l_star = random_block(self.D, self.index_type)
if use_insecure_randomness:
new_path = regint.get_random(self.D)
l_star = self.value_type(new_path)
l_star = self.index_type(new_path)
self.state.write(l_star)
return self.index.update(u, l_star, evict=False).reveal()
@method_block
@@ -1120,7 +1131,7 @@ class TreeORAM(AbstractORAM):
parallel = get_parallel(self.index_size, *self.internal_value_type())
@map_sum(get_n_threads_for_tree(self.size), parallel, levels, \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type] * self.value_length)
[self.value_type.default_type] * self.value_length)
def process(level):
b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level))
bucket = RefBucket(b_index, self)
@@ -1142,9 +1153,9 @@ class TreeORAM(AbstractORAM):
Program.prog.curr_tape.start_new_basicblock()
crash()
def internal_value_type(self):
return self.value_type, self.value_length + 1
return self.value_type.default_type, self.value_length + 1
def internal_entry_size(self):
return self.value_type, [self.D] + list(self.entry_size)
return self.value_type.default_type, [self.D] + list(self.entry_size)
def n_buckets(self):
return 2**(self.D+1)
@method_block
@@ -1176,8 +1187,9 @@ class TreeORAM(AbstractORAM):
#print 'pre-add', self
maybe_start_timer(4)
self.add_to_root(state, entry.empty(), \
self.value_type(entry.v.read()), \
*(self.value_type(i.read()) for i in entry.x))
self.index_type(entry.v.read()), \
*(self.value_type.default_type(i.read())
for i in entry.x))
maybe_stop_timer(4)
#print 'pre-evict', self
if evict:
@@ -1221,28 +1233,35 @@ class TreeORAM(AbstractORAM):
""" Batch initalization. Obliviously shuffles and adds N entries to
random leaf buckets. """
m = len(values)
assert((m & (m-1)) == 0)
if not (m & (m-1)) == 0:
raise CompilerError('Batch size must a power of 2.')
if m != self.size:
raise CompilerError('Batch initialization must have N values.')
if self.value_type != sint:
raise CompilerError('Batch initialization only possible with sint.')
depth = log2(m)
leaves = [0] * m
entries = [0] * m
indexed_values = [0] * m
leaves = self.value_type.Array(m)
indexed_values = \
self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1)
# assign indices 0, ..., m-1
for i,value in enumerate(values):
@for_range(m)
def _(i):
value = values[i]
index = MemValue(self.value_type.hard_conv(i))
new_value = [MemValue(self.value_type.hard_conv(v)) \
for v in (value if isinstance(value, (tuple, list)) \
else (value,))]
indexed_values[i] = [index] + new_value
entries = sint.Matrix(self.bucket_size * 2 ** self.D,
len(Entry(0, list(indexed_values[0]), False)))
# assign leaves
for i,index_value in enumerate(indexed_values):
@for_range(len(indexed_values))
def _(i):
index_value = list(indexed_values[i])
leaves[i] = random_block(self.D, self.value_type)
index = index_value[0]
@@ -1252,18 +1271,20 @@ class TreeORAM(AbstractORAM):
# save unsorted leaves for position map
unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves]
permutation.sort(leaves, comp=permutation.normal_comparator)
leaves.sort()
bucket_sz = 0
# B[i] = (pos, leaf, "last in bucket" flag) for i-th entry
B = [[0]*3 for i in range(m)]
B = sint.Matrix(m, 3)
B[0] = [0, leaves[0], 0]
B[-1] = [None, None, sint(1)]
s = 0
s = MemValue(sint(0))
for i in range(1, m):
@for_range_opt(m - 1)
def _(j):
i = j + 1
eq = leaves[i].equal(leaves[i-1])
s = (s + eq) * eq
s.write((s + eq) * eq)
B[i][0] = s
B[i][1] = leaves[i]
B[i-1][2] = 1 - eq
@@ -1271,7 +1292,7 @@ class TreeORAM(AbstractORAM):
#last_in_bucket[i-1] = 1 - eq
# shuffle
permutation.shuffle(B, value_type=sint)
B.secure_shuffle()
#cint(0).print_reg('shuf')
sz = MemValue(0) #cint(0)
@@ -1279,7 +1300,8 @@ class TreeORAM(AbstractORAM):
empty_positions = Array(nleaves, self.value_type)
empty_leaves = Array(nleaves, self.value_type)
for i in range(m):
@for_range(m)
def _(i):
if_then(reveal(B[i][2]))
#if B[i][2] == 1:
#cint(i).print_reg('last')
@@ -1291,12 +1313,13 @@ class TreeORAM(AbstractORAM):
empty_positions[szval] = B[i][0] #pos[i][0]
#empty_positions[szval].reveal().print_reg('ps0')
empty_leaves[szval] = B[i][1] #pos[i][1]
sz += 1
sz.iadd(1)
end_if()
pos_bits = []
pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2)
for i in range(nleaves):
@for_range_opt(nleaves)
def _(i):
leaf = empty_leaves[i]
# split into 2 if bucket size can't fit into one field elem
if self.bucket_size + Program.prog.security > 128:
@@ -1315,46 +1338,39 @@ class TreeORAM(AbstractORAM):
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
else:
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
pos_bits += [[b, leaf] for b in bucket_bits]
assert len(bucket_bits) == self.bucket_size
for j, b in enumerate(bucket_bits):
pos_bits[i * self.bucket_size + j] = [b, leaf]
# sort to get empty positions first
permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator)
pos_bits.sort(n_bits=1)
# now assign positions to empty entries
empty_entries = [0] * (self.bucket_size*2**self.D - m)
for i in range(self.bucket_size*2**self.D - m):
@for_range(len(entries) - m)
def _(i):
vtype, vlength = self.internal_value_type()
leaf = vtype(pos_bits[i][1])
# set leaf in empty entry for assigning after shuffle
value = tuple([leaf] + [vtype(0) for j in range(vlength)])
value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)])
entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype)
empty_entries[i] = entry
entries[m + i] = entry
# now shuffle, reveal positions and place entries
entries = entries + empty_entries
while len(entries) & (len(entries)-1) != 0:
entries.append(None)
permutation.shuffle(entries, value_type=sint)
entries = [entry for entry in entries if entry is not None]
clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries]
entries.secure_shuffle()
clear_leaves = Array.create_from(
Entry(entries.get_columns()).x[0].reveal())
Program.prog.curr_tape.start_new_basicblock()
bucket_sizes = Array(2**self.D, regint)
for i in range(2**self.D):
bucket_sizes[i] = 0
k = 0
for entry,leaf in zip(entries, clear_leaves):
leaf = leaf.read()
k += 1
# for some reason leaf_buckets is in bit-reversed order
bits = bit_decompose(leaf, self.D)
rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1]))
bucket = RefBucket(rev_leaf + (1 << self.D), self)
# hack: 1*entry ensures MemValues are converted to sints
bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry
@for_range_opt(len(entries))
def _(k):
leaf = clear_leaves[k]
bucket = RefBucket(leaf + (1 << self.D), self)
bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k])
bucket_sizes[leaf] += 1
self.index.batch_init([leaf.read() for leaf in unsorted_leaves])
@@ -1493,6 +1509,7 @@ class PackedIndexStructure(object):
self.l[i] = [0] * self.elements_per_block
time()
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
print_ln('packed ORAM init done')
print('index initialized, size', size)
def translate_index(self, index):
""" Bit slicing *index* according parameters. Output is tuple
@@ -1598,16 +1615,20 @@ class PackedIndexStructure(object):
def batch_init(self, values):
""" Initialize m values with indices 0, ..., m-1 """
m = len(values)
n_entries = max(1, m/self.entries_per_block)
new_values = [0] * n_entries
n_entries = max(1, m//self.entries_per_block)
new_values = sint.Matrix(n_entries, self.elements_per_block)
values = Array.create_from(values)
for i in range(n_entries):
@for_range(n_entries)
def _(i):
block = [0] * self.elements_per_block
for j in range(self.elements_per_block):
base = i * self.entries_per_block + j * self.entries_per_element
for k in range(self.entries_per_element):
if base + k < m:
block[j] += values[base + k] << (k * self.entry_size)
@if_(base + k < m)
def _():
block[j] += \
values[base + k] << (k * sum(self.entry_size))
new_values[i] = block
@@ -1661,13 +1682,51 @@ class OneLevelORAM(TreeORAM):
pattern after one recursion. """
index_structure = BaseORAMIndexStructure
class BinaryORAM:
def __init__(self, size, value_type=None, **kwargs):
import circuit_oram
from Compiler.GC import types
n_bits = int(get_program().options.binary)
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
self.index_type = self.value_type
oram_value_type = types.sbits.get_type(64)
if 'entry_size' not in kwargs:
kwargs['entry_size'] = n_bits
self.oram = circuit_oram.OptimalCircuitORAM(
size, value_type=oram_value_type, **kwargs)
self.size = size
def get_index(self, index):
return self.oram.value_type(self.index_type.conv(index).elements()[0])
def __setitem__(self, index, value):
value = list(self.oram.value_type(
self.value_type.conv(v).elements()[0]) for v in tuplify(value))
self.oram[self.get_index(index)] = value
def __getitem__(self, index):
value = self.oram[self.get_index(index)]
return untuplify(tuple(self.value_type(v) for v in tuplify(value)))
def read(self, index):
return self.oram.read(index)
def read_and_maybe_remove(self, index):
return self.oram.read_and_maybe_remove(index)
def access(self, *args):
return self.oram.access(*args)
def add(self, *args, **kwargs):
return self.oram.add(*args, **kwargs)
def delete(self, *args, **kwargs):
return self.oram.delete(*args, **kwargs)
def OptimalORAM(size,*args,**kwargs):
""" Create an ORAM instance suitable for the size based on
experiments.
:param size: number of elements
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn`
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
:py:class:`sfix`
"""
if not util.is_constant(size):
raise CompilerError('ORAM size has be a compile-time constant')
if get_program().options.binary:
return BinaryORAM(size, *args, **kwargs)
if optimal_threshold is None:
if n_threads == 1:
threshold = 2**11
@@ -1716,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
def test_oram(oram_type, N, value_type=sint, iterations=100):
stop_grind()
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
test_oram_initialized(oram, iterations)
return oram
def test_oram_initialized(oram, iterations=100):
N = oram.size
value_type = oram.value_type
value_type = value_type.get_type(32)
index_type = value_type.get_type(log2(N))
start_grind()
@@ -1783,7 +1848,7 @@ def test_batch_init(oram_type, N):
oram = oram_type(N, value_type)
print('initialized')
print_reg(cint(0), 'init')
oram.batch_init([value_type(i) for i in range(N)])
oram.batch_init(Array.create_from(sint(regint.inc(N))))
print_reg(cint(0), 'done')
@for_range(N)
def f(i):

File diff suppressed because it is too large Load Diff

View File

@@ -8,8 +8,11 @@ from functools import reduce
#import pdb
prog = program.Program.prog
prog.set_bit_length(min(64, prog.bit_length))
try:
prog = program.Program.prog
prog.set_bit_length(min(64, prog.bit_length))
except AttributeError:
pass
class Counter(object):
def __init__(self, val=0, max_val=None, size=None, value_type=sgf2n):
@@ -111,24 +114,6 @@ def bucket_size_sorter(x, y):
return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z])
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
""" Simulate secure shuffling with Waksman network for 2 players.
Returns the network switching config so it may be re-used later. """
n = len(x)
if n & (n-1) != 0:
raise CompilerError('shuffle requires n a power of 2')
if config is None:
config = permutation.configure_waksman(permutation.random_perm(n))
for i,c in enumerate(config):
config[i] = [value_type(b) for b in c]
permutation.waksman(x, config, reverse=reverse)
permutation.waksman(x, config, reverse=reverse)
return config
def LT(a, b):
a_bits = bit_decompose(a)
b_bits = bit_decompose(b)
@@ -472,10 +457,15 @@ class PathORAM(TreeORAM):
print_ln()
# shuffle entries and levels
while len(merged_entries) & (len(merged_entries)-1) != 0:
merged_entries.append(None) #self.root.bucket.empty_entry(False))
permutation.rec_shuffle(merged_entries, value_type=self.value_type)
merged_entries = [e for e in merged_entries if e is not None]
flat = []
for x in merged_entries:
flat += list(x[0]) + [x[1]]
flat = self.value_type(flat)
assert len(flat) % len(merged_entries) == 0
l = len(flat) // len(merged_entries)
shuffled = flat.secure_shuffle(l)
merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]]
for i in range(len(shuffled) // l)]
# need to copy entries/levels to memory for re-positioning
entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)

View File

@@ -10,16 +10,6 @@ if '_Array' not in dir():
from Compiler.program import Program
_Array = Array
SORT_BITS = []
insecure_random = Random(0)
def predefined_comparator(x, y):
""" Assumes SORT_BITS is populated with the required sorting network bits """
if predefined_comparator.sort_bits_iter is None:
predefined_comparator.sort_bits_iter = iter(SORT_BITS)
return next(predefined_comparator.sort_bits_iter)
predefined_comparator.sort_bits_iter = None
def list_comparator(x, y):
""" Uses the first element in the list for comparison """
return x[0] < y[0]
@@ -37,10 +27,6 @@ def bitwise_comparator(x, y):
def cond_swap_bit(x,y, b):
""" swap if b == 1 """
if x is None:
return y, None
elif y is None:
return x, None
if isinstance(x, list):
t = [(xi - yi) * b for xi,yi in zip(x, y)]
return [xi - ti for xi,ti in zip(x, t)], \
@@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator):
else:
raise CompilerError('Length of list must be power of two')
def merge(a, b, comp):
""" General length merge (pads to power of 2) """
while len(a) & (len(a)-1) != 0:
a.append(None)
while len(b) & (len(b)-1) != 0:
b.append(None)
if len(a) < len(b):
a += [None] * (len(b) - len(a))
elif len(b) < len(a):
b += [None] * (len(b) - len(b))
t = a + b
odd_even_merge(t, comp)
for i,v in enumerate(t[::]):
if v is None:
t.remove(None)
return t
def sort(a, comp):
""" Pads to power of 2, sorts, removes padding """
length = len(a)
@@ -112,47 +81,12 @@ def sort(a, comp):
odd_even_merge_sort(a, comp)
del a[length:]
def recursive_merge(a, comp):
""" Recursively merge a list of sorted lists (initially sorted by size) """
if len(a) == 1:
return
# merge smallest two lists, place result in correct position, recurse
t = merge(a[0], a[1], comp)
del a[0]
del a[0]
added = False
for i,c in enumerate(a):
if len(c) >= len(t):
a.insert(i, t)
added = True
break
if not added:
a.append(t)
recursive_merge(a, comp)
# The following functionality for shuffling isn't used any more as it
# has been moved to the virtual machine. The code has been kept for
# reference.
def random_perm(n):
""" Generate a random permutation of length n
WARNING: randomness fixed at compile-time, this is NOT secure
"""
if not Program.prog.options.insecure:
raise CompilerError('no secure implementation of Waksman permution, '
'use --insecure to activate')
a = list(range(n))
for i in range(n-1, 0, -1):
j = insecure_random.randint(0, i)
t = a[i]
a[i] = a[j]
a[j] = t
return a
def inverse(perm):
inv = [None] * len(perm)
for i, p in enumerate(perm):
inv[p] = i
return inv
def configure_waksman(perm):
def configure_waksman(perm, n_iter=[0]):
top = n_iter == [0]
n = len(perm)
if n == 2:
return [(perm[0], perm[0])]
@@ -175,6 +109,7 @@ def configure_waksman(perm):
via = 0
j0 = j
while True:
n_iter[0] += 1
#print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2)
i = inv_perm[j]
@@ -209,8 +144,11 @@ def configure_waksman(perm):
assert sorted(p0) == list(range(n//2))
assert sorted(p1) == list(range(n//2))
p0_config = configure_waksman(p0)
p1_config = configure_waksman(p1)
p0_config = configure_waksman(p0, n_iter)
p1_config = configure_waksman(p1, n_iter)
if top:
print(n_iter[0], 'iterations for Waksman')
assert O[0] == 0, 'not a Waksman network'
return [I + O] + [a+b for a,b in zip(p0_config, p1_config)]
def waksman(a, config, depth=0, start=0, reverse=False):
@@ -358,23 +296,10 @@ def iter_waksman(a, config, reverse=False):
# nblocks /= 2
# depth -= 1
def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
n = len(x)
if n & (n-1) != 0:
raise CompilerError('shuffle requires n a power of 2')
if config is None:
config = configure_waksman(random_perm(n))
for i,c in enumerate(config):
config[i] = [value_type.bit_type(b) for b in c]
waksman(x, config, reverse=reverse)
waksman(x, config, reverse=reverse)
def config_shuffle(n, value_type):
""" Compute config for oblivious shuffling.
Take mod 2 for active sec. """
perm = random_perm(n)
def config_from_perm(perm, value_type):
n = len(perm)
assert(list(sorted(perm))) == list(range(n))
if n & (n-1) != 0:
# pad permutation to power of 2
m = 2**int(math.ceil(math.log(n, 2)))
@@ -394,103 +319,3 @@ def config_shuffle(n, value_type):
for j,b in enumerate(c):
config[i * len(perm) + j] = b
return config
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
""" Simulate secure shuffling with Waksman network for 2 players.
WARNING: This is not a properly secure implementation but has roughly the right complexity.
Returns the network switching config so it may be re-used later. """
n = len(x)
m = 2**int(math.ceil(math.log(n, 2)))
assert n == m, 'only working for powers of two'
if config is None:
config = config_shuffle(n, value_type)
if isinstance(x, list):
if isinstance(x[0], list):
length = len(x[0])
assert len(x) == length
for i in range(length):
xi = Array(m, value_type.reg_type)
for j in range(n):
xi[j] = x[j][i]
for j in range(n, m):
xi[j] = value_type(0)
iter_waksman(xi, config, reverse=reverse)
iter_waksman(xi, config, reverse=reverse)
for j, y in enumerate(xi):
x[j][i] = y
else:
xa = Array(m, value_type.reg_type)
for i in range(n):
xa[i] = x[i]
for i in range(n, m):
xa[i] = value_type(0)
iter_waksman(xa, config, reverse=reverse)
iter_waksman(xa, config, reverse=reverse)
x[:] = xa
elif isinstance(x, Array):
if len(x) != m and config is None:
raise CompilerError('Non-power of 2 Array input not yet supported')
iter_waksman(x, config, reverse=reverse)
iter_waksman(x, config, reverse=reverse)
else:
raise CompilerError('Invalid type for shuffle:', type(x))
return config
def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None):
""" Shuffle a list of ORAM entries.
Randomly permutes the first "perm_size" entries, leaving the rest (empty
entry padding) in the same position. """
n = len(x)
l = len(x[0])
if n & (n-1) != 0:
raise CompilerError('Entries must be padded to power of two length.')
if perm_size is None:
perm_size = n
xarrays = [Array(n, value_type.reg_type) for i in range(l)]
for i in range(n):
for j,value in enumerate(x[i]):
if isinstance(value, MemValue):
xarrays[j][i] = value.read()
else:
xarrays[j][i] = value
if config is None:
config = config_shuffle(perm_size, value_type)
for xi in xarrays:
shuffle(xi, config, value_type, reverse)
for i in range(n):
x[i] = entry_cls(xarrays[j][i] for j in range(l))
return config
def sort_zeroes(bits, x, n_ones, value_type):
""" Return Array of values in "x" where the corresponding bit in "bits" is
a 0.
The total number of zeroes in "bits" must be known.
"bits" and "x" must be Arrays. """
config = config_shuffle(len(x), value_type)
shuffle(bits, config=config, value_type=value_type)
shuffle(x, config=config, value_type=value_type)
result = Array(n_ones, value_type.reg_type)
sz = MemValue(0)
last_x = MemValue(value_type(0))
#for i,b in enumerate(bits):
#if_then(b.reveal() == 0)
#result[sz.read()] = x[i]
#sz += 1
#end_if()
@for_range(len(bits))
def f(i):
found = (bits[i].reveal() == 0)
szval = sz.read()
result[szval] = last_x + (x[i] - last_x) * found
sz.write(sz + found)
last_x.write(result[szval])
return result

File diff suppressed because it is too large Load Diff

73
Compiler/sorting.py Normal file
View File

@@ -0,0 +1,73 @@
import itertools
from Compiler import types, library, instructions
def dest_comp(B):
Bt = B.transpose()
St_flat = Bt.get_vector().prefix_sum()
Tt_flat = Bt.get_vector() * St_flat.get_vector()
Tt = types.Matrix(*Bt.sizes, B.value_type)
Tt.assign_vector(Tt_flat)
return sum(Tt) - 1
def reveal_sort(k, D, reverse=False):
""" Sort in place according to "perfect" key. The name hints at the fact
that a random order of the keys is revealed.
:param k: vector or Array of sint containing exactly :math:`0,\dots,n-1`
in any order
:param D: Array or MultiArray to sort
:param reverse: wether :py:obj:`key` is a permutation in forward or
backward order
"""
assert len(k) == len(D)
library.break_point()
shuffle = types.sint.get_secure_shuffle(len(k))
k_prime = k.get_vector().secure_permute(shuffle).reveal()
idx = types.Array.create_from(k_prime)
if reverse:
D.assign_vector(D.get_slice_vector(idx))
library.break_point()
D.secure_permute(shuffle, reverse=True)
else:
D.secure_permute(shuffle)
library.break_point()
v = D.get_vector()
D.assign_slice_vector(idx, v)
library.break_point()
instructions.delshuffle(shuffle)
def radix_sort(k, D, n_bits=None, signed=True):
""" Sort in place according to key.
:param k: keys (vector or Array of sint or sfix)
:param D: Array or MultiArray to sort
:param n_bits: number of bits in keys (int)
:param signed: whether keys are signed (bool)
"""
assert len(k) == len(D)
bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits))
if signed and len(bs) > 1:
bs[-1][:] = bs[-1][:].bit_not()
radix_sort_from_matrix(bs, D)
def radix_sort_from_matrix(bs, D):
n = len(D)
for b in bs:
assert(len(b) == n)
B = types.sint.Matrix(n, 2)
h = types.Array.create_from(types.sint(types.regint.inc(n)))
@library.for_range(len(bs))
def _(i):
b = bs[i]
B.set_column(0, 1 - b.get_vector())
B.set_column(1, b.get_vector())
c = types.Array.create_from(dest_comp(B))
reveal_sort(c, h, reverse=False)
@library.if_e(i < len(bs) - 1)
def _():
reveal_sort(h, bs[i + 1], reverse=True)
@library.else_
def _():
reveal_sort(h, D, reverse=True)

804
Compiler/sqrt_oram.py Normal file
View File

@@ -0,0 +1,804 @@
from __future__ import annotations
import math
from abc import abstractmethod
from typing import Any, Generic, Type, TypeVar
from Compiler import library as lib
from Compiler import util
from Compiler.GC.types import cbit, sbit, sbitint, sbits
from Compiler.program import Program
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
regint, sint, sintbit)
from Compiler.oram import demux_array, get_n_threads
# Adds messages on completion of heavy computation steps
debug = False
# Finer grained trace of steps that the ORAM performs
# + runtime error checks
# Warning: reveals information and makes the computation insecure
trace = False
n_threads = 16
n_parallel = 1024
# Avoids any memory allocation if set to False
# Setting to False prevents some optimizations but allows for controlling the ORAMs outside of the main tape
allow_memory_allocation = True
def get_n_threads(n_loops):
if n_threads is None:
if n_loops > 2048:
return 8
else:
return None
else:
return n_threads
T = TypeVar("T", sint, sbitint)
B = TypeVar("B", sintbit, sbit)
class SqrtOram(Generic[T, B]):
"""Oblivious RAM using the "Square-Root" algorithm.
:param MultiArray data: The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array).
:param sint value_type: The secret type to use, defaults to sint.
:param int k: Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
:param int period: Leave at None, this parameter is used to recursively pass down the top-level period.
"""
# TODO: Preferably this is an Array of vectors, but this is currently not supported
# One should regard these structures as Arrays where an entry may hold more
# than one value (which is a nice property to have when using the ORAM in
# practise).
shuffle: MultiArray
stash: MultiArray
# A block has an index and data
# `shuffle` and `stash` store the data,
# `shufflei` and `stashi` store the index
shufflei: Array
stashi: Array
shuffle_used: Array
position_map: PositionMap
# The size of the ORAM, i.e. how many elements it stores
n: int
# The period, i.e. how many calls can be made to the ORAM before it needs to be refreshed
T: int
# Keep track of how far we are in the period, and coincidentally how large
# the stash is (each access results in a fake or real block being put on
# the stash)
t: cint
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None:
global debug, allow_memory_allocation
# Correctly initialize the shuffle (memory) depending on the type of data
if isinstance(data, MultiArray):
self.shuffle = data
self.n = len(data)
elif isinstance(data, sint):
self.n = math.ceil(len(data) // entry_length)
if (len(data) % entry_length != 0):
raise Exception('Data incorrectly padded.')
self.shuffle = MultiArray(
(self.n, entry_length), value_type=value_type)
self.shuffle.assign_part_vector(data.get_vector())
else:
raise Exception("Incorrect format.")
# Only sint is supported
if value_type != sint and value_type != sbitint:
raise Exception("The value_type must be either sint or sbitint")
# Set derived constants
self.value_type = value_type
self.bit_type: Type[B] = value_type.bit_type
self.index_size = util.log2(self.n) + 1 # +1 because signed
self.index_type = value_type.get_type(self.index_size)
self.entry_length = entry_length
self.size = self.n
if debug:
lib.print_ln(
'Initializing SqrtORAM of size %s at depth %s', self.n, k)
self.shuffle_used = cint.Array(self.n)
# Random permutation on the data
self.shufflei = Array.create_from(
[self.index_type(i) for i in range(self.n)])
# Calculate the period if not given
# upon recursion, the period should stay the same ("in sync"),
# therefore it can be passed as a constructor parameter
self.T = int(math.ceil(
math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
if debug and not period:
lib.print_ln('Period set to %s', self.T)
# Here we allocate the memory for the permutation
# Note that self.shuffle_the_shuffle mutates this field
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
self.permutation = Array.create_from(
[self.index_type(i) for i in range(self.n)])
# We allow the caller to postpone the initialization of the shuffle
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
# Note that if you do not initialize, the ORAM is insecure
if initialize:
# If the ORAM is not initialized with existing data, we can apply
# a small optimization by forgoing shuffling the shuffle, as all
# entries of the shuffle are equal and empty.
if empty_data:
random_shuffle = sint.get_secure_shuffle(self.n)
self.shufflei.secure_permute(random_shuffle)
self.permutation.assign(self.shufflei[:].inverse_permutation())
if trace:
lib.print_ln('Calculated inverse permutation')
else:
self.shuffle_the_shuffle()
else:
print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.')
# Initialize position map (recursive oram)
self.position_map = PositionMap.create(self.permutation, k + 1, self.T)
# Initialize stash
self.stash = MultiArray((self.T, entry_length), value_type=value_type)
self.stashi = Array(self.T, value_type=value_type)
self.t = MemValue(cint(0))
# Initialize temp variables needed during the computation
self.found_ = self.bit_type.Array(size=self.T)
self.j = MemValue(cint(0, size=1))
# To prevent the compiler from recompiling the same code over and over again, we should use @method_block
# However, @method_block requires allocation (of return address), which is not allowed when not in the main thread
# Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread
SqrtOram.shuffle_the_shuffle = lib.method_block(SqrtOram.shuffle_the_shuffle) if allow_memory_allocation else SqrtOram.shuffle_the_shuffle
SqrtOram.refresh = lib.method_block(SqrtOram.refresh) if allow_memory_allocation else SqrtOram.refresh
SqrtOram.reinitialize = lib.method_block(SqrtOram.reinitialize) if allow_memory_allocation else SqrtOram.reinitialize
@lib.method_block
def access(self, index: T, write: B, *value: T):
global trace,n_parallel
if trace:
@lib.if_e(write.reveal() == 1)
def _():
lib.print_ln('Writing to secret index %s', index.reveal())
@lib.else_
def __():
lib.print_ln('Reading from secret index %s', index.reveal())
value = self.value_type(value, size=self.entry_length).get_vector(
0, size=self.entry_length)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
def _():
self.refresh()
found: B = MemValue(self.bit_type(False))
result: T = MemValue(self.value_type(0, size=self.entry_length))
# First we scan the stash for the item
self.found_.assign_all(0)
# This will result in a bit array with at most one True,
# indicating where in the stash 'index' is found
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
self.bit_type(regint.inc(size, base=base) <
self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
# check wheterh the demuxed array contains a True
# TODO: What if the index=0?
found.write(sum(self.found_))
# Store the stash item into the result if found
# If the item is not in the stash, the result will simple remain 0
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
self.entry_length, [self.value_type] * self.entry_length)
def stash_item(i):
entry = self.stash[i][:]
access_here = self.found_[i]
# This is a bit unfortunate
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
# Therefore we loop till self.T (the max value of self.t)
# is_in_time = i < self.t
# If we are writing, we need to add the value
self.stash[i] += write * access_here * (value - entry)
return (entry * access_here)[:]
result += self.value_type(stash_item(), size=self.entry_length)
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('Found item in stash')
@lib.else_
def __():
lib.print_ln('Did not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
physical_address = self.position_map.get_position(index, found)
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
# with its equivalent on the stash once the period is up.
self.shuffle_used[physical_address] = cbit(True)
# If the item was not found in the stash
# ...we update the item in the shuffle
self.shuffle[physical_address] += write * \
found.bit_not() * (value - self.shuffle[physical_address][:])
# ...and the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
self.stash[self.t].assign(self.shuffle[physical_address][:])
self.stashi[self.t] = self.shufflei[physical_address]
if trace:
@lib.if_((write * found.bit_not()).reveal())
def _():
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
), self.shuffle[physical_address].reveal(), physical_address)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
return result
@lib.method_block
def write(self, index: T, *value: T):
global trace, n_parallel
if trace:
lib.print_ln('Writing to secret index %s', index.reveal())
if isinstance(value, tuple) or isinstance(value,list):
value = self.value_type(value, size=self.entry_length)
print(value, type(value))
elif isinstance(value, self.value_type):
value = self.value_type(*value, size=self.entry_length)
print(value, type(value))
else:
raise Exception("Cannot handle type of value passed")
print(self.entry_length, value, type(value),len(value))
value = MemValue(value)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
def _():
self.refresh()
found: B = MemValue(self.bit_type(False))
result: T = MemValue(self.value_type(0, size=self.entry_length))
# First we scan the stash for the item
self.found_.assign_all(0)
# This will result in an bit array with at most one True,
# indicating where in the stash 'index' is found
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
self.bit_type(regint.inc(size, base=base) <
self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
# check wheterh the demuxed array contains a True
# TODO: What if the index=0?
found.write(sum(self.found_))
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
self.entry_length, [self.value_type] * self.entry_length)
def stash_item(i):
entry = self.stash[i][:]
access_here = self.found_[i]
# This is a bit unfortunate
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
# Therefore we loop till self.T (the max value of self.t)
# is_in_time = i < self.t
# We update the stash value
self.stash[i] += access_here * (value - entry)
return (entry * access_here)[:]
result += self.value_type(stash_item(), size=self.entry_length)
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('Found item in stash')
@lib.else_
def __():
lib.print_ln('Did not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
physical_address = self.position_map.get_position(index, found)
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
# with its equivalent on the stash once the period is up.
self.shuffle_used[physical_address] = cbit(True)
# If the item was not found in the stash
# ...we update the item in the shuffle
self.shuffle[physical_address] += found.bit_not() * \
(value - self.shuffle[physical_address][:])
# ...and the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
self.stash[self.t].assign(self.shuffle[physical_address][:])
self.stashi[self.t] = self.shufflei[physical_address]
if trace:
@lib.if_(found.bit_not().reveal())
def _():
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
), self.shuffle[physical_address].reveal(), physical_address)
lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
return result
@lib.method_block
def read(self, index: T, *value: T):
global debug, trace, n_parallel
if trace:
lib.print_ln('Reading from secret index %s', index.reveal())
value = self.value_type(value)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
def _():
if debug:
lib.print_ln('Refreshing SqrtORAM')
lib.print_ln('t=%s according to me', self.t)
self.refresh()
found: B = MemValue(self.bit_type(False))
result: T = MemValue(self.value_type(0, size=self.entry_length))
# First we scan the stash for the item
self.found_.assign_all(0)
# This will result in a bit array with at most one True,
# indicating where in the stash 'index' is found
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
self.bit_type(regint.inc(size, base=base) <
self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
# check whether the demuxed array contains a True
# TODO: What if the index=0?
found.write(sum(self.found_))
lib.check_point()
# Store the stash item into the result if found
# If the item is not in the stash, the result will simple remain 0
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
self.entry_length, [self.value_type] * self.entry_length)
def stash_item(i):
entry = self.stash[i][:]
access_here = self.found_[i]
# This is a bit unfortunate
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
# Therefore we loop till self.T (the max value of self.t)
# is_in_time = i < self.t
return (entry * access_here)[:]
result += self.value_type(stash_item(), size=self.entry_length)
if trace:
# @lib.for_range(self.t)
# def _(i):
# lib.print_ln("stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('Found item in stash (found=%s)', found.reveal())
@lib.else_
def __():
lib.print_ln('Did not find item in stash (found=%s)', found.reveal())
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
physical_address = self.position_map.get_position(index, found)
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
# with its equivalent on the stash once the period is up.
self.shuffle_used[physical_address] = cbit(True)
# If the item was not found in the stash
# the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
self.stash[self.t].assign(self.shuffle[physical_address][:])
self.stashi[self.t] = self.shufflei[physical_address]
if trace:
lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
return result
__getitem__ = read
__setitem__ = write
def shuffle_the_shuffle(self) -> None:
"""Permute the memory using a newly generated permutation and return
the permutation that would generate this particular shuffling.
This permutation is needed to know how to map logical addresses to
physical addresses, and is used as such by the postition map."""
global trace
# Random permutation on n elements
random_shuffle = sint.get_secure_shuffle(self.n)
if trace:
lib.print_ln('Generated shuffle')
# Apply the random permutation
self.shuffle.secure_permute(random_shuffle)
if trace:
lib.print_ln('Shuffled shuffle')
self.shufflei.secure_permute(random_shuffle)
if trace:
lib.print_ln('Shuffled shuffle indexes')
lib.check_point()
# Calculate the permutation that would have produced the newly produced
# shuffle order. This can be calculated by regarding the logical
# indexes (shufflei) as a permutation and calculating its inverse,
# i.e. find P such that P([1,2,3,...]) = shufflei.
# this is not necessarily equal to the inverse of the above generated
# random_shuffle, as the shuffle may already be out of order (e.g. when
# refreshing).
self.permutation.assign(self.shufflei[:].inverse_permutation())
# If shufflei does not contain exactly the indices
# [i for i in range(self.n)],
# the underlying waksman network of 'inverse_permutation' will hang.
if trace:
lib.print_ln('Calculated inverse permutation')
def refresh(self):
"""Refresh the ORAM by reinserting the stash back into the shuffle, and
reshuffling the shuffle.
This must happen on the T'th (period) accesses to the ORAM."""
self.j.write(0)
# Shuffle and emtpy the stash, and store elements back into shuffle
@lib.for_range_opt(self.n)
def _(i):
@lib.if_(self.shuffle_used[i])
def _():
self.shuffle[i] = self.stash[self.j]
self.shufflei[i] = self.stashi[self.j]
self.j += 1
# Reset the clock
self.t.write(0)
# Reset shuffle_used
self._reset_shuffle_used()
# Reinitialize position map
self.shuffle_the_shuffle()
# Note that we skip here the step of "packing" the permutation.
# Since the underlying memory of the position map is already aligned in
# this packed structure, we can simply overwrite the memory while
# maintaining the structure.
self.position_map.reinitialize(*self.permutation)
def reinitialize(self, *data: T):
# Note that this method is only used during refresh, and as such is
# only called with a permutation as data.
# The logical addresses of some previous permutation are irrelevant and must be reset
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
# Reset the clock
self.t.write(0)
# Reset shuffle_used
self._reset_shuffle_used()
# Note that the self.shuffle is actually a MultiArray
# This structure is preserved while overwriting the values using
# assign_vector
self.shuffle.assign_vector(self.value_type(
data, size=self.n * self.entry_length))
# Note that this updates self.permutation (see constructor for explanation)
self.shuffle_the_shuffle()
self.position_map.reinitialize(*self.permutation)
def _reset_shuffle_used(self):
global allow_memory_allocation
if allow_memory_allocation:
self.shuffle_used.assign_all(0)
else:
@lib.for_range_opt(self.n)
def _(i):
self.shuffle_used[i] = cint(0)
class PositionMap(Generic[T, B]):
PACK_LOG: int = 3
PACK: int = 1 << PACK_LOG
n: int # n in the paper
depth: cint # k in the paper
value_type: Type[T]
def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None:
self.n = n
self.depth = MemValue(cint(k))
self.value_type = value_type
self.bit_type = value_type.bit_type
self.index_type = self.value_type.get_type(util.log2(n) + 1) # +1 because signed
@abstractmethod
def get_position(self, logical_address: _secret, fake: B) -> Any:
"""Retrieve the block at the given (secret) logical address."""
global trace
if trace:
print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
def reinitialize(self, *permutation: T):
"""Reinitialize this PositionMap.
Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
we cannot simply call __init__ on self. Instead, we must take care to
reuse and overwrite the same memory.
"""
...
@classmethod
def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = sint) -> PositionMap:
"""Creates a new PositionMap. This is the method one should call when
needing a new position map. Depending on the size of the given data, it
will either instantiate a RecursivePositionMap or
a LinearPositionMap."""
n = len(permutation)
global debug
if n / PositionMap.PACK <= period:
if debug:
lib.print_ln(
'Initializing LinearPositionMap at depth %s of size %s', k, n)
res = LinearPositionMap(permutation, value_type, k=k)
else:
if debug:
lib.print_ln(
'Initializing RecursivePositionMap at depth %s of size %s', k, n)
res = RecursivePositionMap(permutation, period, value_type, k=k)
return res
class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None:
PositionMap.__init__(self, len(permutation), k=k)
pack = PositionMap.PACK
# We pack the permutation into a smaller structure, index with a new permutation
packed_size = int(math.ceil(self.n / pack))
packed_structure = MultiArray(
(packed_size, pack), value_type=value_type)
for i in range(packed_size):
packed_structure[i] = Array.create_from(
permutation[i*pack:(i+1)*pack])
SqrtOram.__init__(self, packed_structure, value_type=value_type,
period=period, entry_length=pack, k=self.depth)
# Initialize random temp variables needed during the computation
self.block_index_demux: Array = self.bit_type.Array(self.T)
self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK)
@lib.method_block
def get_position(self, logical_address: T, fake: B) -> _clear:
super().get_position(logical_address, fake)
pack = PositionMap.PACK
pack_log = PositionMap.PACK_LOG
# The item at logical_address
# will be in block with index h (block.<h>)
# at position l in block.data (block.data<l>)
program = Program.prog
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
logical_address).right_shift(pack_log, program.bit_length)))
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))
global trace
if trace:
print_at_depth(self.depth, '-> logical_address=%s: h=%s, l=%s', logical_address.reveal(), h.reveal(), l.reveal())
# @lib.for_range(self.t)
# def _(i):
# print_at_depth(self.depth, "stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())
# The resulting physical address
p = MemValue(self.index_type(-1))
found: B = MemValue(self.bit_type(False))
# First we try and retrieve the item from the stash at position stash[h][l]
# Since h and l are secret, we do this by scanning the entire stash
# First we scan the stash for the block we need
self.block_index_demux.assign_all(0)
@lib.for_range_opt_multithread(get_n_threads(self.T), self.T)
def _(i):
self.block_index_demux[i] = ( self.stashi[i] == h) & self.bit_type(i < self.t)
# We can determine if the 'index' is in the stash by checking the
# block_index_demux array
found = sum(self.block_index_demux)
# Once a block is found, we use the following condition to pick the correct item from that block
demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux)
# Finally we use the conditions to conditionally write p
@lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type])
def p_(i):
# We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum
# Therefore we include the check (i < self.t)
return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i // pack< self.t)
p.write(p_())
if trace:
@lib.if_e(found.reveal() == 0)
def _(): print_at_depth(self.depth, 'Retrieve shuffle[%s]:', h.reveal())
@lib.else_
def __():
print_at_depth(self.depth, 'Retrieve dummy element from shuffle:')
# Then we try and retrieve the item from the shuffle (the actual memory)
# Depending on whether we found the item in the stash, we either
# block 'h' in which 'index' resides, or a random block from the shuffle
p_prime = self.position_map.get_position(h, found)
self.shuffle_used[p_prime] = cbit(True)
# The block retrieved from the shuffle
block_p_prime: Array = self.shuffle[p_prime]
if trace:
@lib.if_e(found.reveal() == 0)
def _():
print_at_depth(self.depth, 'Retrieved position from shuffle[%s]=(%s: %s)',
p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())
@lib.else_
def __():
print_at_depth(self.depth, 'Retrieved dummy position from shuffle[%s]=(%s: %s)',
p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())
# We add the retrieved block from the shuffle to the stash
self.stash[self.t].assign(block_p_prime[:])
self.stashi[self.t] = self.shufflei[p_prime]
# Increase t
self.t += 1
# if found or not fake
condition: B = self.bit_type(fake.bit_or(found.bit_not()))
# Retrieve l'th item from block
# l is secret, so we must use linear scan
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(
pack)) & condition.expand_to_vector(pack))
@lib.for_range_opt(pack)
def _(i):
p.write((hit[i]).if_else(block_p_prime[i], p))
return p.reveal()
def reinitialize(self, *permutation: T):
SqrtOram.reinitialize(self, *permutation)
class LinearPositionMap(PositionMap):
physical: Array
used: Array
def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None:
PositionMap.__init__(self, len(data), value_type, k=k)
self.physical = data
self.used = self.bit_type.Array(self.n)
# Initialize random temp variables needed during the computation
self.physical_demux: Array = self.bit_type.Array(self.n)
@lib.method_block
def get_position(self, logical_address: T, fake: B) -> _clear:
"""
This method corresponds to GetPosBase in the paper.
"""
super().get_position(logical_address, fake)
global trace
if trace:
@lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal())
def _():
lib.runtime_error(
'logical_address must lie between 0 and self.n - 1')
fake = MemValue(self.bit_type(fake))
logical_address = MemValue(logical_address)
p: MemValue = MemValue(self.index_type(-1))
done: B = self.bit_type(False)
# In order to get an address at secret logical_address,
# we need to perform a linear scan.
self.physical_demux.assign_all(0)
@lib.for_range_opt_multithread(get_n_threads(self.n), self.n)
def condition_i(i):
self.physical_demux[i] = \
(self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \
| (fake & self.used[i].bit_not())
# In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False)
# We only need once, so we pick the first one we find
@lib.for_range_opt(self.n)
def _(i):
self.physical_demux[i] &= done.bit_not()
done.update(done | self.physical_demux[i])
# Retrieve the value from the physical memory obliviously
@lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type])
def calc_p(i):
return self.physical[i] * self.physical_demux[i]
p.write(calc_p())
# Update self.used
self.used.assign(self.used[:] | self.physical_demux[:])
if trace:
@lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))
def _():
lib.runtime_error(
'%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)
return p.reveal()
def reinitialize(self, *data: T):
self.physical.assign_vector(data)
global allow_memory_allocation
if allow_memory_allocation:
self.used.assign_all(False)
else:
@lib.for_range_opt(self.n)
def _(i):
self.used[i] = self.bit_type(0)
def print_at_depth(depth: cint, message: str, *kwargs):
lib.print_str('%s', depth)
@lib.for_range(depth)
def _(i):
lib.print_char(' ')
lib.print_char(' ')
lib.print_ln(message, *kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -116,6 +116,11 @@ def round_to_int(x):
return x.round_to_int()
def tree_reduce(function, sequence):
try:
return sequence.tree_reduce(function)
except AttributeError:
pass
sequence = list(sequence)
assert len(sequence) > 0
n = len(sequence)
@@ -233,6 +238,9 @@ def mem_size(x):
except AttributeError:
return 1
def find_in_dict(d, v):
return list(d.keys())[list(d.values()).index(v)]
class set_by_id(object):
def __init__(self, init=[]):
self.content = {}
@@ -257,6 +265,9 @@ class set_by_id(object):
def pop(self):
return self.content.popitem()[1]
def remove(self, value):
del self.content[id(value)]
def __ior__(self, values):
for value in values:
self.add(value)

144
Dockerfile Normal file
View File

@@ -0,0 +1,144 @@
###############################################################################
# Build this stage for a build environment, e.g.: #
# #
# docker build --tag mpspdz:buildenv --target buildenv . #
# #
# The above is equivalent to: #
# #
# docker build --tag mpspdz:buildenv \ #
# --target buildenv \ #
# --build-arg arch=native \ #
# --build-arg cxx=clang++-11 \ #
# --build-arg use_ntl=0 \ #
# --build-arg prep_dir="Player-Data" \ #
# --build-arg ssl_dir="Player-Data" #
# --build-arg cryptoplayers=0 #
# #
# To build for an x86-64 architecture, with g++, NTL (for HE), custom #
# prep_dir & ssl_dir, and to use encrypted channels for 4 players: #
# #
# docker build --tag mpspdz:buildenv \ #
# --target buildenv \ #
# --build-arg arch=x86-64 \ #
# --build-arg cxx=g++ \ #
# --build-arg use_ntl=1 \ #
# --build-arg prep_dir="/opt/prepdata" \ #
# --build-arg ssl_dir="/opt/ssl" #
# --build-arg cryptoplayers=4 . #
# #
# To work in a container to build different machines, and compile programs: #
# #
# docker run --rm -it mpspdz:buildenv bash #
# #
# Once in the container, build a machine and compile a program: #
# #
# $ make replicated-ring-party.x #
# $ ./compile.py -R 64 tutorial #
# #
###############################################################################
FROM python:3.10.3-bullseye as buildenv
RUN apt-get update && apt-get install -y --no-install-recommends \
automake \
build-essential \
clang-11 \
cmake \
git \
libboost-dev \
libboost-thread-dev \
libclang-dev \
libgmp-dev \
libntl-dev \
libsodium-dev \
libssl-dev \
libtool \
vim \
gdb \
valgrind \
&& rm -rf /var/lib/apt/lists/*
ENV MP_SPDZ_HOME /usr/src/MP-SPDZ
WORKDIR $MP_SPDZ_HOME
RUN pip install --upgrade pip ipython
COPY . .
ARG arch=native
ARG cxx=clang++-11
ARG use_ntl=0
ARG prep_dir="Player-Data"
ARG ssl_dir="Player-Data"
RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \
&& echo "CXX = ${cxx}" >> CONFIG.mine \
&& echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \
&& echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \
&& echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \
>> CONFIG.mine \
&& mkdir -p $prep_dir $ssl_dir \
&& echo "PREP_DIR = '-DPREP_DIR=\"${prep_dir}/\"'" >> CONFIG.mine \
&& echo "SSL_DIR = '-DSSL_DIR=\"${ssl_dir}/\"'" >> CONFIG.mine
# ssl keys
ARG cryptoplayers=0
ENV PLAYERS ${cryptoplayers}
RUN ./Scripts/setup-ssl.sh ${cryptoplayers} ${ssl_dir}
RUN make boost libote
###############################################################################
# Use this stage to a build a specific virtual machine. For example: #
# #
# docker build --tag mpspdz:shamir \ #
# --target machine \ #
# --build-arg machine=shamir-party.x \ #
# --build-arg gfp_mod_sz=4 . #
# #
# The above will build shamir-party.x with 256 bit length. #
# #
# If no build arguments are passed (via --build-arg), mascot-party.x is built #
# with the default 128 bit length. #
###############################################################################
FROM buildenv as machine
ARG machine="mascot-party.x"
ARG gfp_mod_sz=2
RUN echo "MOD = -DGFP_MOD_SZ=${gfp_mod_sz}" >> CONFIG.mine
RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/
################################################################################
# This is the default stage. Use it to compile a high-level program. #
# By default, tutorial.mpc is compiled with --field=64 bits. #
# #
# docker build --tag mpspdz:mascot-tutorial \ #
# --build-arg src=tutorial \ #
# --build-arg compile_options="--field=64" . #
# #
# Note that build arguments from previous stages can also be passed. For #
# instance, building replicated-ring-party.x, for 3 crypto players with custom #
# PREP_DIR and SSL_DIR, and compiling tutorial.mpc with --ring=64: #
# #
# docker build --tag mpspdz:replicated-ring \ #
# --build-arg machine=replicated-ring-party.x \ #
# --build-arg prep_dir=/opt/prep \ #
# --build-arg ssl_dir=/opt/ssl \ #
# --build-arg cryptoplayers=3 \ #
# --build-arg compile_options="--ring=64" . #
# #
# Test it: #
# #
# docker run --rm -it mpspdz:replicated-ring ./Scripts/ring.sh tutorial #
################################################################################
FROM machine as program
ARG src="tutorial"
ARG compile_options="--field=64"
RUN ./compile.py ${compile_options} ${src}
RUN mkdir -p Player-Data \
&& echo 1 2 3 4 > Player-Data/Input-P0-0 \
&& echo 1 2 3 4 > Player-Data/Input-P1-0

View File

@@ -24,4 +24,5 @@ int main()
generate_mac_keys<Share<P256Element::Scalar>>(key, 2, prefix, G);
make_mult_triples<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix, G);
make_inverse<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix, G);
P256Element::finish();
}

View File

@@ -14,7 +14,14 @@ void P256Element::init()
curve = EC_GROUP_new_by_curve_name(NID_secp256k1);
assert(curve != 0);
auto modulus = EC_GROUP_get0_order(curve);
Scalar::init_field(BN_bn2dec(modulus), false);
auto mod = BN_bn2dec(modulus);
Scalar::init_field(mod, false);
free(mod);
}
void P256Element::finish()
{
EC_GROUP_free(curve);
}
P256Element::P256Element()
@@ -29,7 +36,7 @@ P256Element::P256Element(const Scalar& other) :
{
BIGNUM* exp = BN_new();
BN_dec2bn(&exp, bigint(other).get_str().c_str());
assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0);
assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0);
BN_free(exp);
}
@@ -38,10 +45,15 @@ P256Element::P256Element(word other) :
{
BIGNUM* exp = BN_new();
BN_dec2bn(&exp, to_string(other).c_str());
assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0);
assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0);
BN_free(exp);
}
P256Element::~P256Element()
{
EC_POINT_free(point);
}
P256Element& P256Element::operator =(const P256Element& other)
{
assert(EC_POINT_copy(point, other.point) != 0);
@@ -56,7 +68,11 @@ void P256Element::check()
P256Element::Scalar P256Element::x() const
{
BIGNUM* x = BN_new();
#if OPENSSL_VERSION_MAJOR >= 3
assert(EC_POINT_get_affine_coordinates(curve, point, x, 0, 0) != 0);
#else
assert(EC_POINT_get_affine_coordinates_GFp(curve, point, x, 0, 0) != 0);
#endif
char* xx = BN_bn2dec(x);
Scalar res((bigint(xx)));
OPENSSL_free(xx);
@@ -95,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const
return not cmp;
}
void P256Element::pack(octetStream& os) const
void P256Element::pack(octetStream& os, int) const
{
octet* buffer;
size_t length = EC_POINT_point2buf(curve, point,
@@ -103,9 +119,10 @@ void P256Element::pack(octetStream& os) const
assert(length != 0);
os.store_int(length, 8);
os.append(buffer, length);
free(buffer);
}
void P256Element::unpack(octetStream& os)
void P256Element::unpack(octetStream& os, int)
{
size_t length = os.get_int(8);
assert(

View File

@@ -22,20 +22,23 @@ private:
EC_POINT* point;
public:
typedef void next;
typedef P256Element next;
typedef void Square;
static const true_type invertible;
static int size() { return 0; }
static int length() { return 256; }
static string type_string() { return "P256"; }
static void init();
static void finish();
P256Element();
P256Element(const P256Element& other);
P256Element(const Scalar& other);
P256Element(word other);
~P256Element();
P256Element& operator=(const P256Element& other);
@@ -55,10 +58,10 @@ public:
void assign_zero() { *this = {}; }
bool is_zero() { return *this == P256Element(); }
void add(octetStream& os) { *this += os.get<P256Element>(); }
void add(octetStream& os, int = -1) { *this += os.get<P256Element>(); }
void pack(octetStream& os) const;
void unpack(octetStream& os);
void pack(octetStream& os, int = -1) const;
void unpack(octetStream& os, int = -1);
octetStream hash(size_t n_bytes) const;

View File

@@ -24,8 +24,8 @@ The following binaries have been used for the paper:
All binaries offer the same interface. With MASCOT for example, run
the following:
```
./mascot-ecsda-party.x -p 0 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
./mascot-ecsda-party.x -p 1 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
./mascot-ecdsa-party.x -p 0 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
./mascot-ecdsa-party.x -p 1 [-N <number of parties>] [-h <host of party 0>] [-D] [<number of prep tuples>]
...
```

View File

@@ -45,12 +45,13 @@ int main(int argc, const char** argv)
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
read_mac_key(prefix, N, keyp);
pShare::MAC_Check::setup(P);
Share<P256Element>::MAC_Check::setup(P);
DataPositions usage;
Sub_Data_Files<pShare> prep(N, prefix, usage);
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
BaseMachine machine;
machine.ot_setups.push_back({P, false});
SubProcessor<pShare> proc(_, MCp, prep, P);
pShare sk, __;
@@ -60,4 +61,8 @@ int main(int argc, const char** argv)
preprocessing(tuples, n_tuples, sk, proc, opts);
check(tuples, sk, keyp, P);
sign_benchmark(tuples, sk, MCp, P, opts);
pShare::MAC_Check::teardown();
Share<P256Element>::MAC_Check::teardown();
P256Element::finish();
}

View File

@@ -30,6 +30,8 @@
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/MalRep.hpp"
#include "Machines/Rep.hpp"
#include <assert.h>
@@ -52,10 +54,10 @@ void run(int argc, const char** argv)
P.unchecked_broadcast(bundle);
Timer timer;
timer.start();
auto stats = P.comm_stats;
auto stats = P.total_comm();
pShare sk = typename T<P256Element::Scalar>::Honest::Protocol(P).get_random();
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
(P.total_comm() - stats).print(true);
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
DataPositions usage;
@@ -69,4 +71,5 @@ void run(int argc, const char** argv)
preprocessing(tuples, n_tuples, sk, proc, opts);
// check(tuples, sk, {}, P);
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
P256Element::finish();
}

View File

@@ -5,6 +5,8 @@
#define NO_MIXED_CIRCUITS
#define NO_SECURITY_CHECK
#include "GC/TinierSecret.h"
#include "GC/TinyMC.h"
#include "GC/VectorInput.h"

View File

@@ -92,9 +92,6 @@ void run(int argc, const char** argv)
P256Element::init();
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);
BaseMachine machine;
machine.ot_setups.push_back({P, true});
P256Element::Scalar keyp;
SeededPRNG G;
keyp.randomize(G);
@@ -102,6 +99,9 @@ void run(int argc, const char** argv)
typedef T<P256Element::Scalar> pShare;
DataPositions usage;
pShare::MAC_Check::setup(P);
T<P256Element>::MAC_Check::setup(P);
OnlineOptions::singleton.batch_size = 1;
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
@@ -113,10 +113,10 @@ void run(int argc, const char** argv)
P.unchecked_broadcast(bundle);
Timer timer;
timer.start();
auto stats = P.comm_stats;
auto stats = P.total_comm();
sk_prep.get_two(DATA_INVERSE, sk, __);
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
(P.total_comm() - stats).print(true);
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
typename pShare::TriplePrep prep(0, usage);
@@ -137,4 +137,8 @@ void run(int argc, const char** argv)
preprocessing(tuples, n_tuples, sk, proc, opts);
//check(tuples, sk, keyp, P);
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
pShare::MAC_Check::teardown();
T<P256Element>::MAC_Check::teardown();
P256Element::finish();
}

View File

@@ -41,8 +41,8 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
timer.start();
Player& P = proc.P;
auto& prep = proc.DataF;
size_t start = P.sent + prep.data_sent();
auto stats = P.comm_stats + prep.comm_stats();
size_t start = P.total_comm().sent;
auto stats = P.total_comm();
auto& extra_player = P;
auto& protocol = proc.protocol;
@@ -77,7 +77,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (prep_mul)
{
protocol.init_mul(&proc);
protocol.init_mul();
for (int i = 0; i < buffer_size; i++)
protocol.prepare_mul(inv_ks[i], sk);
protocol.start_exchange();
@@ -106,9 +106,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
timer.stop();
cout << "Generated " << buffer_size << " tuples in " << timer.elapsed()
<< " seconds, throughput " << buffer_size / timer.elapsed() << ", "
<< 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size
<< 1e-3 * (P.total_comm().sent - start) / buffer_size
<< " kbytes per tuple" << endl;
(P.comm_stats + prep.comm_stats() - stats).print(true);
(P.total_comm() - stats).print(true);
}
template<template<class U> class T>

View File

@@ -10,6 +10,7 @@
#include "Protocols/SemiPrep.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "GC/SemiSecret.hpp"
#include "ot-ecdsa-party.hpp"
#include <assert.h>

View File

@@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length,
(void) pk;
Timer timer;
timer.start();
size_t start = P.sent;
auto stats = P.comm_stats;
auto stats = P.total_comm();
EcSignature signature;
vector<P256Element> opened_R;
if (opts.R_after_msg)
@@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length,
auto& protocol = proc->protocol;
if (proc)
{
protocol.init_mul(proc);
protocol.init_mul();
protocol.prepare_mul(sk, tuple.a);
protocol.start_exchange();
}
@@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length,
auto rx = tuple.R.x();
signature.s = MC.open(
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
auto diff = (P.total_comm() - stats);
cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending "
<< (P.sent - start) << " bytes" << endl;
auto diff = (P.comm_stats - stats);
<< diff.sent << " bytes" << endl;
diff.print(true);
return signature;
}
@@ -139,11 +138,11 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
P.unchecked_broadcast(bundle);
Timer timer;
timer.start();
auto stats = P.comm_stats;
auto stats = P.total_comm();
P256Element pk = MCc.open(sk, P);
MCc.Check(P);
cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
(P.total_comm() - stats).print(true);
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
{
@@ -154,13 +153,12 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
Timer timer;
timer.start();
auto& check_player = MCp.get_check_player(P);
auto stats = check_player.comm_stats;
auto start = check_player.sent;
auto stats = check_player.total_comm();
MCp.Check(P);
MCc.Check(P);
auto diff = (check_player.total_comm() - stats);
cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending "
<< (check_player.sent - start) << " bytes" << endl;
auto diff = (check_player.comm_stats - stats);
<< diff.sent << " bytes" << endl;
diff.print();
}
}

View File

@@ -8,24 +8,92 @@
#include "Networking/ssl_sockets.h"
#ifdef NO_CLIENT_TLS
class client_ctx
{
public:
client_ctx(string)
{
}
};
class client_socket
{
public:
int socket;
client_socket(boost::asio::io_service&,
client_ctx&, int plaintext_socket, string,
string, bool) : socket(plaintext_socket)
{
}
~client_socket()
{
close(socket);
}
};
inline void send(client_socket* socket, octet* data, size_t len)
{
send(socket->socket, data, len);
}
inline void receive(client_socket* socket, octet* data, size_t len)
{
receive(socket->socket, data, len);
}
#else
typedef ssl_ctx client_ctx;
typedef ssl_socket client_socket;
#endif
/**
* Client-side interface
*/
class Client
{
vector<int> plain_sockets;
ssl_ctx ctx;
client_ctx ctx;
ssl_service io_service;
public:
vector<ssl_socket*> sockets;
/**
* Sockets for cleartext communication
*/
vector<client_socket*> sockets;
/**
* Specification of computation domain
*/
octetStream specification;
/**
* Start a new set of connections to computing parties.
* @param hostnames location of computing parties
* @param port_base port base
* @param my_client_id client identifier
*/
Client(const vector<string>& hostnames, int port_base, int my_client_id);
~Client();
/**
* Securely input private values.
* @param values vector of integer-like values
*/
template<class T>
void send_private_inputs(const vector<T>& values);
template<class T>
vector<T> receive_outputs(int n);
/**
* Securely receive output values.
* @param n number of values
* @returns vector of integer-like values
*/
template<class T, class U = T>
vector<U> receive_outputs(int n);
};
#endif /* EXTERNALIO_CLIENT_H_ */

View File

@@ -20,7 +20,7 @@ Client::Client(const vector<string>& hostnames, int port_base,
{
set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i);
octetStream(to_string(my_client_id)).Send(plain_sockets[i]);
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
sockets[i] = new client_socket(io_service, ctx, plain_sockets[i],
"P" + to_string(i), "C" + to_string(my_client_id), true);
if (i == 0)
specification.Receive(sockets[0]);
@@ -46,20 +46,37 @@ void Client::send_private_inputs(const vector<T>& values)
octetStream os;
vector< vector<T> > triples(num_inputs, vector<T>(3));
vector<T> triple_shares(3);
bool active = true;
// Receive num_inputs triples from SPDZ
for (size_t j = 0; j < sockets.size(); j++)
{
#ifdef VERBOSE_COMM
cerr << "receiving from " << j << endl << flush;
#endif
os.reset_write_head();
os.Receive(sockets[j]);
#ifdef VERBOSE_COMM
cerr << "received " << os.get_length() << " from " << j << endl;
cerr << "received " << os.get_length() << " from " << j << endl << flush;
#endif
if (j == 0)
{
if (os.get_length() == 3 * values.size() * T::size())
active = true;
else
active = false;
}
int n_expected = active ? 3 : 1;
if (os.get_length() != n_expected * T::size() * values.size())
throw runtime_error("unexpected data length in sending");
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
for (int k = 0; k < n_expected; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
@@ -67,16 +84,18 @@ void Client::send_private_inputs(const vector<T>& values)
}
}
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
if (active)
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
throw mac_fail();
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
throw mac_fail();
}
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
@@ -91,19 +110,33 @@ void Client::send_private_inputs(const vector<T>& values)
// Receive shares of the result and sum together.
// Also receive authenticating values.
template<class T>
vector<T> Client::receive_outputs(int n)
template<class T, class U>
vector<U> Client::receive_outputs(int n)
{
vector<T> triples(3 * n);
octetStream os;
bool active = true;
for (auto& socket : sockets)
{
os.reset_write_head();
os.Receive(socket);
#ifdef VERBOSE_COMM
cout << "received " << os.get_length() << endl;
cout << "received " << os.get_length() << endl << flush;
#endif
for (int j = 0; j < 3 * n; j++)
if (socket == sockets[0])
{
if (os.get_length() == (size_t) 3 * n * T::size())
active = true;
else
active = false;
}
int n_expected = n * (active ? 3 : 1);
if (os.get_length() != (size_t) n_expected * T::size())
throw runtime_error("unexpected data length in receiving");
for (int j = 0; j < n_expected; j++)
{
T value;
value.unpack(os);
@@ -111,16 +144,24 @@ vector<T> Client::receive_outputs(int n)
}
}
vector<T> output_values;
for (int i = 0; i < 3 * n; i += 3)
if (active)
{
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
vector<U> output_values;
for (int i = 0; i < 3 * n; i += 3)
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
throw mac_fail();
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
throw mac_fail();
}
output_values.push_back(triples[i]);
}
output_values.push_back(triples[i]);
}
return output_values;
return output_values;
}
else
{
triples.resize(n);
return triples;
}
}

View File

@@ -1,24 +1,30 @@
The ExternalIO directory contains an example of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
The ExternalIO directory contains an example of managing I/O between
external client processes and parties running MP-SPDZ engines. These
instructions assume that MP-SPDZ has been built as per the [project
readme](../README.md).
## Working Examples
[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a
[bankers-bonus-client.cpp](../ExternalIO/bankers-bonus-client.cpp) and
[bankers-bonus-client.py](../ExternalIO/bankers-bonus-client.py) act as a
client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc)
and demonstrates sending input and receiving output as described by
[Damgård et al.](https://eprint.iacr.org/2015/1006) The computation
allows up to eight clients to input a number and computes the client
with the largest input. You can run it as follows from the main
with the largest input. You can run the C++ code as follows from the main
directory:
```
make bankers-bonus-client.x
./compile.py bankers_bonus 1
Scripts/setup-ssl.sh <nparties>
Scripts/setup-clients.sh 3
Scripts/<protocol>.sh &
PLAYERS=<nparties> Scripts/<protocol>.sh bankers_bonus-1 &
./bankers-bonus-client.x 0 <nparties> 100 0 &
./bankers-bonus-client.x 1 <nparties> 200 0 &
./bankers-bonus-client.x 2 <nparties> 50 1
```
`<protocol>` can be any arithmetic protocol (e.g., `mascot`) but not a
binary protocol (e.g., `yao`).
This should output that the winning id is 1. Note that the ids have to
be incremental, and the client with the highest id has to input 1 as
the last argument while the others have to input 0 there. Furthermore,
@@ -28,58 +34,30 @@ protocol script. The setup scripts generate the necessary SSL
certificates and keys. Therefore, if you run the computation on
different hosts, you will have to distribute the `*.pem` files.
For the Python client, make sure to install
[gmpy2](https://pypi.org/project/gmpy2), and run
`ExternalIO/bankers-bonus-client.py` instead of
`bankers-bonus-client.x`.
## I/O MPC Instructions
### Connection Setup
**listen**(*int port_num*)
Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background.
*port_num* - the port number to listen on.
**acceptclientconnection**(*regint client_socket_id*, *int port_num*)
Picks the first available client socket connection. Blocks if none available.
*client_socket_id* - an identifier used to refer to the client socket.
*port_num* - the port number identifies the socket server to accept connections on.
1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients)
2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection)
3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection)
### Data Exchange
Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py).
Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types).
*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*)
1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket)
2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client)
3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients)
Read a share of an input from a client, blocking on the client send.
## Client-Side Interface
*client_socket_id* - an identifier used to refer to the client socket.
*number_of_inputs* - the number of inputs expected
*[inputs]* - returned list of shares of private input.
**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*)
Write shares of values including macs to an external client.
*client_socket_id* - an identifier used to refer to the client socket.
*[values]* - list of shares of values to send to client.
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message.
*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*)
Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf)
*number_of_inputs* - the number of inputs expected
*client_socket_id* - an identifier used to refer to the client socket.
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
*[inputs]* - returned list of shares of private input.
The example uses the `Client` class implemented in
`ExternalIO/Client.hpp` to handle the communication, see
[this reference](https://mp-spdz.readthedocs.io/en/latest/io.html#reference) for
documentation.

View File

@@ -46,7 +46,7 @@
#include <sstream>
#include <fstream>
template<class T>
template<class T, class U>
void one_run(T salary_value, Client& client)
{
// Run the computation
@@ -54,18 +54,18 @@ void one_run(T salary_value, Client& client)
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back (client_id of winning client)
T result = client.receive_outputs<T>(1)[0];
U result = client.receive_outputs<T>(1)[0];
cout << "Winning client id is : " << result << endl;
}
template<class T>
template<class T, class U>
void run(double salary_value, Client& client)
{
// sint
one_run<T>(long(round(salary_value)), client);
one_run<T, U>(long(round(salary_value)), client);
// sfix with f = 16
one_run<T>(long(round(salary_value * exp2(16))), client);
one_run<T, U>(long(round(salary_value * exp2(16))), client);
}
int main(int argc, char** argv)
@@ -125,7 +125,7 @@ int main(int argc, char** argv)
{
gfp::init_field(specification.get<bigint>());
cerr << "using prime " << gfp::pr() << endl;
run<gfp>(salary_value, client);
run<gfp, gfp>(salary_value, client);
break;
}
case 'R':
@@ -134,13 +134,13 @@ int main(int argc, char** argv)
switch (R)
{
case 64:
run<Z2<64>>(salary_value, client);
run<Z2<64>, Z2<64>>(salary_value, client);
break;
case 104:
run<Z2<104>>(salary_value, client);
run<Z2<104>, Z2<64>>(salary_value, client);
break;
case 128:
run<Z2<128>>(salary_value, client);
run<Z2<128>, Z2<64>>(salary_value, client);
break;
default:
cerr << R << "-bit ring not implemented";

View File

@@ -0,0 +1,35 @@
#!/usr/bin/python3
import sys
sys.path.append('.')
from client import *
from domains import *
client_id = int(sys.argv[1])
n_parties = int(sys.argv[2])
bonus = float(sys.argv[3])
finish = int(sys.argv[4])
client = Client(['localhost'] * n_parties, 14000, client_id)
type = client.specification.get_int(4)
if type == ord('R'):
domain = Z2(client.specification.get_int(4))
elif type == ord('p'):
domain = Fp(client.specification.get_bigint())
else:
raise Exception('invalid type')
for socket in client.sockets:
os = octetStream()
os.store(finish)
os.Send(socket)
for x in bonus, bonus * 2 ** 16:
client.send_private_inputs([domain(x)])
print('Winning client id is :',
client.receive_outputs(domain, 1)[0].v % 2 ** 64)

126
ExternalIO/client.py Normal file
View File

@@ -0,0 +1,126 @@
import socket, ssl
import struct
import time
class Client:
def __init__(self, hostnames, port_base, my_client_id):
ctx = ssl.SSLContext()
name = 'C%d' % my_client_id
prefix = 'Player-Data/%s' % name
ctx.load_cert_chain(certfile=prefix + '.pem', keyfile=prefix + '.key')
ctx.load_verify_locations(capath='Player-Data')
self.sockets = []
for i, hostname in enumerate(hostnames):
for j in range(10000):
try:
plain_socket = socket.create_connection(
(hostname, port_base + i))
break
except ConnectionRefusedError:
if j < 60:
time.sleep(1)
else:
raise
octetStream(b'%d' % my_client_id).Send(plain_socket)
self.sockets.append(ctx.wrap_socket(plain_socket,
server_hostname='P%d' % i))
self.specification = octetStream()
self.specification.Receive(self.sockets[0])
def receive_triples(self, T, n):
triples = [[0, 0, 0] for i in range(n)]
os = octetStream()
for socket in self.sockets:
os.Receive(socket)
if socket == self.sockets[0]:
active = os.get_length() == 3 * n * T.size()
n_expected = 3 if active else 1
if os.get_length() != n_expected * T.size() * n:
import sys
print (os.get_length(), n_expected, T.size(), n, active, file=sys.stderr)
raise Exception('unexpected data length')
for triple in triples:
for i in range(n_expected):
t = T()
t.unpack(os)
triple[i] += t
res = []
if active:
for triple in triples:
prod = triple[0] * triple[1]
if prod != triple[2]:
raise Exception(
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
return triples
def send_private_inputs(self, values):
T = type(values[0])
triples = self.receive_triples(T, len(values))
os = octetStream()
assert len(values) == len(triples)
for value, triple in zip(values, triples):
(value + triple[0]).pack(os)
for socket in self.sockets:
os.Send(socket)
def receive_outputs(self, T, n):
triples = self.receive_triples(T, n)
return [triple[0] for triple in triples]
class octetStream:
def __init__(self, value=None):
self.buf = b''
self.ptr = 0
if value is not None:
self.buf += value
def get_length(self):
return len(self.buf)
def reset_write_head(self):
self.buf = b''
self.ptr = 0
def Send(self, socket):
socket.sendall(struct.pack('<i', len(self.buf)))
socket.sendall(self.buf)
def Receive(self, socket):
length = struct.unpack('<I', socket.recv(4))[0]
self.buf = b''
while len(self.buf) < length:
self.buf += socket.recv(length - len(self.buf))
self.ptr = 0
def store(self, value):
self.buf += struct.pack('<i', value)
def get_int(self, length):
buf = self.consume(length)
if length == 4:
return struct.unpack('<i', buf)[0]
elif length == 8:
return struct.unpack('<q', buf)[0]
raise ValueError()
def get_bigint(self):
sign = self.consume(1)[0]
assert(sign in (0, 1))
length = self.get_int(4)
if length:
res = 0
buf = self.consume(length)
for i, b in enumerate(reversed(buf)):
res += b << (i * 8)
if sign:
res *= -1
return res
else:
return 0
def consume(self, length):
self.ptr += length
assert self.ptr <= len(self.buf)
return self.buf[self.ptr - length:self.ptr]

74
ExternalIO/domains.py Normal file
View File

@@ -0,0 +1,74 @@
import struct
class Domain:
def __init__(self, value=0):
self.v = int(value % self.modulus)
assert(self.v >= 0)
def __add__(self, other):
try:
res = self.v + other.v
except:
res = self.v + other
return type(self)(res)
def __mul__(self, other):
try:
res = self.v * other.v
except:
res = self.v * other
return type(self)(res)
__radd__ = __add__
def __eq__(self, other):
return self.v == other.v
def __neq__(self, other):
return self.v != other.v
@classmethod
def size(cls):
return cls.n_bytes
def unpack(self, os):
self.v = 0
buf = os.consume(self.n_bytes)
for i, b in enumerate(buf):
self.v += b << (i * 8)
def pack(self, os):
v = self.v
temp_buf = []
for i in range(self.n_bytes):
temp_buf.append(v & 0xff)
v >>= 8
#Instead of using python a loop per value we let struct pack handle all it
os.buf += struct.pack('<{}B'.format(len(temp_buf)), *tuple(temp_buf))
def Z2(k):
class Z(Domain):
modulus = 2 ** k
n_words = (k + 63) // 64
n_bytes = (k + 7) // 8
return Z
def Fp(mod):
import gmpy2
class Fp(Domain):
modulus = mod
n_words = (modulus.bit_length() + 63) // 64
n_bytes = 8 * n_words
R = 2 ** (64 * n_words) % modulus
R_inv = gmpy2.invert(R, modulus)
def unpack(self, os):
Domain.unpack(self, os)
self.v = self.v * self.R_inv % self.modulus
def pack(self, os):
Domain.pack(type(self)(self.v * self.R), os)
return Fp

View File

@@ -58,7 +58,8 @@ public:
{
if (this->size() != y.size())
throw out_of_range("vector length mismatch");
for (unsigned int i = 0; i < this->size(); i++)
size_t n = this->size();
for (unsigned int i = 0; i < n; i++)
(*this)[i] += y[i];
return *this;
}
@@ -67,9 +68,11 @@ public:
{
if (this->size() != y.size())
throw out_of_range("vector length mismatch");
AddableVector<T> res(y.size());
for (unsigned int i = 0; i < this->size(); i++)
res[i] = (*this)[i] - y[i];
AddableVector<T> res;
res.reserve(y.size());
size_t n = this->size();
for (unsigned int i = 0; i < n; i++)
res.push_back((*this)[i] - y[i]);
return res;
}

View File

@@ -1,5 +1,4 @@
#include "Ciphertext.h"
#include "PPData.h"
#include "P2Data.h"
#include "Tools/Exceptions.h"
@@ -31,6 +30,12 @@ word check_pk_id(word a, word b)
}
void Ciphertext::Scale()
{
Scale(params->get_plaintext_modulus());
}
void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1)
{
if (c0.params!=c1.params) { throw params_mismatch(); }
@@ -108,16 +113,34 @@ void Ciphertext::mul(const Ciphertext& c, const Rq_Element& ra)
::mul(cc1,ra,c.cc1);
}
void Ciphertext::add(octetStream& os)
void Ciphertext::add(octetStream& os, int)
{
Ciphertext tmp(*params);
tmp.unpack(os);
*this += tmp;
}
void Ciphertext::rerandomize(const FHE_PK& pk)
{
Rq_Element tmp(*params);
SeededPRNG G;
vector<FFT_Data::S> r(params->FFTD()[0].m());
bigint p = pk.p();
assert(p != 0);
for (auto& x : r)
{
G.get(x, params->p0().numBits() - p.numBits() - 1);
x *= p;
}
tmp.from(r, 0);
Scale();
cc0 += tmp;
auto zero = pk.encrypt(*params);
zero.Scale(pk.p());
*this += zero;
}
template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans,const Plaintext<gfp,PPData,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans,const Plaintext<gf2n_short,P2Data,int>& a,const Ciphertext& c);
template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a,
const Ciphertext& c);

View File

@@ -15,6 +15,12 @@ template<class T,class FD,class S> void mul(Ciphertext& ans,const Ciphertext& c,
void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1);
void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk);
/**
* BGV ciphertext.
* The class allows adding two ciphertexts as well as adding a plaintext and
* a ciphertext via operator overloading. The multiplication of two ciphertexts
* requires the public key and thus needs a separate function.
*/
class Ciphertext
{
Rq_Element cc0,cc1;
@@ -54,6 +60,7 @@ class Ciphertext
// Scale down an element from level 1 to level 0, if at level 0 do nothing
void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); }
void Scale();
// Throws error if ans,c0,c1 etc have different params settings
// - Thus programmer needs to ensure this rather than this being done
@@ -90,6 +97,12 @@ class Ciphertext
template <class FD>
Ciphertext& operator*=(const Plaintext_<FD>& other) { ::mul(*this, *this, other); return *this; }
/**
* Ciphertext multiplication.
* @param pk public key
* @param x second ciphertext
* @returns product ciphertext
*/
Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const
{ Ciphertext res(*params); ::mul(res, *this, x, pk); return res; }
@@ -98,21 +111,25 @@ class Ciphertext
return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this};
}
/// Re-randomize for circuit privacy.
void rerandomize(const FHE_PK& pk);
int level() const { return cc0.level(); }
// pack/unpack (like IO) also assume params are known and already set
// correctly
void pack(octetStream& o) const
/// Append to buffer
void pack(octetStream& o, int = -1) const
{ cc0.pack(o); cc1.pack(o); o.store(pk_id); }
void unpack(octetStream& o)
{ cc0.unpack(o); cc1.unpack(o); o.get(pk_id); }
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o, int = -1)
{ cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); }
void output(ostream& s) const
{ cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); }
void input(istream& s)
{ cc0.input(s); cc1.input(s); s.read((char*)&pk_id, sizeof(pk_id)); }
void add(octetStream& os);
void add(octetStream& os, int = -1);
size_t report_size(ReportType type) const { return cc0.report_size(type) + cc1.report_size(type); }
};

View File

@@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag(
{
auto& c = products.at(i);
for (int j = 0; j < n_matrices; j++)
{
res.at(j).entries.init();
for (size_t k = 0; k < n_rows; k++)
res.at(j)[{k, i}] = c.element(j * n_rows + k);
}
}
return res;
}

View File

@@ -259,6 +259,3 @@ void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forw
else
{ throw crash_requested(); }
}

View File

@@ -7,6 +7,11 @@
FFT_Data::FFT_Data() :
twop(-1)
{
}
void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
{
R=Rg;
@@ -78,6 +83,8 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
for (int r=0; r<2; r++)
{ FFT_Iter(b[r],twop,two_root[0],PrD); }
}
else
throw bad_value();
}
}

View File

@@ -50,7 +50,7 @@ class FFT_Data
void pack(octetStream& o) const;
void unpack(octetStream& o);
FFT_Data() { ; }
FFT_Data();
FFT_Data(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }

View File

@@ -2,7 +2,6 @@
#include "FHE_Keys.h"
#include "Ciphertext.h"
#include "P2Data.h"
#include "PPData.h"
#include "FFT_Data.h"
#include "Math/modp.hpp"
@@ -12,6 +11,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p())
{
}
FHE_SK::FHE_SK(const FHE_Params& pms) :
FHE_SK(pms, pms.get_plaintext_modulus())
{
}
FHE_SK& FHE_SK::operator+=(const FHE_SK& c)
{
@@ -38,6 +42,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G)
}
FHE_PK::FHE_PK(const FHE_Params& pms) :
FHE_PK(pms, pms.get_plaintext_modulus())
{
}
Rq_Element FHE_PK::sample_secret_key(PRNG& G)
{
Rq_Element sk = FHE_SK(*this).s();
@@ -47,11 +56,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
}
void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
{
Rq_Element a(*this);
a.randomize(G);
partial_key_gen(sk, a, G, noise_boost);
}
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost)
{
FHE_PK& PK = *this;
// Generate the main public key
PK.a0.randomize(G);
a0 = a;
// b0=a0*s+p*e0
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
@@ -77,9 +93,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
mul(es,es,PK.pr);
add(PK.Sw_b,PK.Sw_b,es);
// Lowering level as we only decrypt at level 0
sk.lower_level();
// bs=bs-p1*s^2
Rq_Element s2;
mul(s2,sk,sk); // Mult at level 0
@@ -175,32 +188,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext<typename FD::T, FD, typename FD::S>&
template<class FD>
Ciphertext FHE_PK::encrypt(
const Plaintext<typename FD::T, FD, typename FD::S>& mess) const
{
return encrypt(Rq_Element(*params, mess));
}
Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const
{
Random_Coins rc(*params);
PRNG G;
G.ReSeed();
rc.generate(G);
return encrypt(mess, rc);
Ciphertext res(*params);
quasi_encrypt(res, mess, rc);
return res;
}
template<class T, class FD, class S>
void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (T::characteristic_two ^ (pr == 2))
throw pr_mismatch();
Rq_Element ans = quasi_decrypt(c);
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
}
Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
ans.change_rep(polynomial);
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
return ans;
}
Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c)
{
return decrypt(c, params->get_plaintext_field_data<FFT_Data>());
}
template<class FD>
Plaintext<typename FD::T, FD, typename FD::S> FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD)
{
@@ -295,12 +327,12 @@ void FHE_PK::unpack(octetStream& o)
o.consume((octet*) tag, 8);
if (memcmp(tag, "PKPKPKPK", 8))
throw runtime_error("invalid serialization of public key");
a0.unpack(o);
b0.unpack(o);
a0.unpack(o, *params);
b0.unpack(o, *params);
if (params->n_mults() > 0)
{
Sw_a.unpack(o);
Sw_b.unpack(o);
Sw_a.unpack(o, *params);
Sw_b.unpack(o, *params);
}
pr.unpack(o);
}
@@ -318,7 +350,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const
return false;
}
void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
const bigint& pr) const
{
@@ -334,15 +365,13 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
template<class FD>
void FHE_SK::check(const FHE_PK& pk, const FD& FieldD)
{
check(*params, pk, pr);
check(*params, pk, FieldD.get_prime());
pk.check_noise(*this);
if (decrypt(pk.encrypt(Plaintext_<FD>(FieldD)), FieldD) !=
Plaintext_<FD>(FieldD))
throw runtime_error("incorrect key pair");
}
void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
{
if (this->pr != pr)
@@ -357,30 +386,36 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
}
}
bigint FHE_SK::get_noise(const Ciphertext& c)
{
sk.lower_level();
Ciphertext cc = c;
if (cc.level())
cc.Scale();
Rq_Element tmp = quasi_decrypt(cc);
bigint res;
bigint q = tmp.get_modulus();
bigint half_q = q / 2;
for (auto& x : tmp.to_vec_bigint())
{
// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " ";
res = max(res, x > half_q ? x - q : x);
}
return res;
}
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
const Random_Coins& rc) const;
#define X(FD) \
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FD>& mess, \
const Random_Coins& rc) const; \
template Ciphertext FHE_PK::encrypt(const Plaintext_<FD>& mess) const; \
template Plaintext_<FD> FHE_SK::decrypt(const Ciphertext& c, \
const FD& FieldD); \
template void FHE_SK::decrypt(Plaintext_<FD>& res, \
const Ciphertext& c) const; \
template void FHE_SK::decrypt_any(Plaintext_<FD>& res, \
const Ciphertext& c); \
template void FHE_SK::check(const FHE_PK& pk, const FD&);
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
const FFT_Data& FieldD);
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,
const P2Data& FieldD);
template void FHE_SK::decrypt_any(Plaintext_<FFT_Data>& res,
const Ciphertext& c);
template void FHE_SK::decrypt_any(Plaintext_<P2Data>& res,
const Ciphertext& c);
template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&);
template void FHE_SK::check(const FHE_PK& pk, const P2Data&);
X(FFT_Data)
X(P2Data)

View File

@@ -12,6 +12,10 @@
class FHE_PK;
class Ciphertext;
/**
* BGV secret key.
* The class allows addition.
*/
class FHE_SK
{
Rq_Element sk;
@@ -29,6 +33,8 @@ class FHE_SK
// secret key always on lower level
void assign(const Rq_Element& s) { sk=s; sk.lower_level(); }
FHE_SK(const FHE_Params& pms);
FHE_SK(const FHE_Params& pms, const bigint& p)
: sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; }
@@ -38,8 +44,12 @@ class FHE_SK
const Rq_Element& s() const { return sk; }
void pack(octetStream& os) const { sk.pack(os); pr.pack(os); }
void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); }
/// Append to buffer
void pack(octetStream& os, int = -1) const { sk.pack(os); pr.pack(os); }
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& os, int = -1)
{ sk.unpack(os, *params); pr.unpack(os); }
// Assumes Ring and prime of mess have already been set correctly
// Ciphertext c must be at level 0 or an error occurs
@@ -50,9 +60,14 @@ class FHE_SK
template <class FD>
Plaintext<typename FD::T, FD, typename FD::S> decrypt(const Ciphertext& c, const FD& FieldD);
/// Decryption for cleartexts modulo prime
Plaintext_<FFT_Data> decrypt(const Ciphertext& c);
template <class FD>
void decrypt_any(Plaintext_<FD>& mess, const Ciphertext& c);
Rq_Element quasi_decrypt(const Ciphertext& c) const;
// Three stage procedure for Distributed Decryption
// - First stage produces my shares
// - Second stage adds in another players shares, do this once for each other player
@@ -62,7 +77,6 @@ class FHE_SK
void dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_number,int num_players) const;
void dist_decrypt_2(vector<bigint>& vv,const vector<bigint>& vv1) const;
friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
/* Add secret keys
@@ -75,17 +89,23 @@ class FHE_SK
bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; }
void add(octetStream& os) { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
void add(octetStream& os, int = -1)
{ FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const;
template<class FD>
void check(const FHE_PK& pk, const FD& FieldD);
bigint get_noise(const Ciphertext& c);
friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; }
};
/**
* BGV public key.
*/
class FHE_PK
{
Rq_Element a0,b0;
@@ -104,8 +124,10 @@ class FHE_PK
)
{ a0=a; b0=b; Sw_a=sa; Sw_b=sb; }
FHE_PK(const FHE_Params& pms, const bigint& p = 0)
FHE_PK(const FHE_Params& pms);
FHE_PK(const FHE_Params& pms, const bigint& p)
: a0(pms.FFTD(),evaluation,evaluation),
b0(pms.FFTD(),evaluation,evaluation),
Sw_a(pms.FFTD(),evaluation,evaluation),
@@ -143,19 +165,26 @@ class FHE_PK
template <class FD>
Ciphertext encrypt(const Plaintext<typename FD::T, FD, typename FD::S>& mess, const Random_Coins& rc) const;
/// Encryption
template <class FD>
Ciphertext encrypt(const Plaintext<typename FD::T, FD, typename FD::S>& mess) const;
Ciphertext encrypt(const Rq_Element& mess) const;
friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
Rq_Element sample_secret_key(PRNG& G);
void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1);
void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost = 1);
void check_noise(const FHE_SK& sk) const;
void check_noise(const Rq_Element& x, bool check_modulo = false) const;
// params setting is done out of these IO/pack/unpack functions
/// Append to buffer
void pack(octetStream& o) const;
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o);
bool operator!=(const FHE_PK& x) const;
@@ -168,21 +197,39 @@ class FHE_PK
void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
/**
* BGV key pair
*/
class FHE_KeyPair
{
public:
/// Public key
FHE_PK pk;
/// Secret key
FHE_SK sk;
FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) :
FHE_KeyPair(const FHE_Params& params, const bigint& pr) :
pk(params, pr), sk(params, pr)
{
}
/// Initialization
FHE_KeyPair(const FHE_Params& params) :
pk(params), sk(params)
{
}
void generate(PRNG& G)
{
KeyGen(pk, sk, G);
}
/// Generate fresh keys
void generate()
{
SeededPRNG G;
generate(G);
}
};
template <class S>

View File

@@ -1,7 +1,15 @@
#include "FHE_Params.h"
#include "NTL-Subs.h"
#include "FHE/Ring_Element.h"
#include "Tools/Exceptions.h"
#include "Protocols/HemiOptions.h"
#include "Processor/OnlineOptions.h"
FHE_Params::FHE_Params(int n_mults, int drown_sec) :
FFTData(n_mults + 1), Chi(0.7), sec_p(drown_sec), matrix_dim(1)
{
}
void FHE_Params::set(const Ring& R,
const vector<bigint>& primes)
@@ -12,16 +20,35 @@ void FHE_Params::set(const Ring& R,
for (size_t i = 0; i < FFTData.size(); i++)
FFTData[i].init(R,primes[i]);
set_sec(40);
set_sec(sec_p);
}
void FHE_Params::set_sec(int sec)
{
assert(sec >= 0);
sec_p=sec;
Bval=1; Bval=Bval<<sec_p;
Bval=FFTData[0].get_prime()/(2*(1+Bval));
if (Bval == 0)
throw runtime_error("distributed decryption bound is zero");
}
void FHE_Params::set_min_sec(int sec)
{
set_sec(max(sec, sec_p));
}
void FHE_Params::set_matrix_dim(int matrix_dim)
{
assert(matrix_dim > 0);
if (FFTData[0].get_prime() != 0)
throw runtime_error("cannot change matrix dimension after parameter generation");
this->matrix_dim = matrix_dim;
}
void FHE_Params::set_matrix_dim_from_options()
{
set_matrix_dim(
HemiOptions::singleton.plain_matmul ?
1 : OnlineOptions::singleton.batch_size);
}
bigint FHE_Params::Q() const
@@ -40,6 +67,8 @@ void FHE_Params::pack(octetStream& o) const
Chi.pack(o);
Bval.pack(o);
o.store(sec_p);
o.store(matrix_dim);
fd.pack(o);
}
void FHE_Params::unpack(octetStream& o)
@@ -52,6 +81,8 @@ void FHE_Params::unpack(octetStream& o)
Chi.unpack(o);
Bval.unpack(o);
o.get(sec_p);
o.get(matrix_dim);
fd.unpack(o);
}
bool FHE_Params::operator!=(const FHE_Params& other) const
@@ -64,3 +95,31 @@ bool FHE_Params::operator!=(const FHE_Params& other) const
else
return false;
}
void FHE_Params::basic_generation_mod_prime(int plaintext_length)
{
if (n_mults() == 0)
generate_semi_setup(plaintext_length, 0, *this, fd, false);
else
{
Parameters parameters(1, plaintext_length, 0);
parameters.generate_setup(*this, fd);
}
}
template<>
const FFT_Data& FHE_Params::get_plaintext_field_data() const
{
return fd;
}
template<>
const P2Data& FHE_Params::get_plaintext_field_data() const
{
throw not_implemented();
}
bigint FHE_Params::get_plaintext_modulus() const
{
return fd.get_prime();
}

View File

@@ -13,7 +13,11 @@
#include "FHE/FFT_Data.h"
#include "FHE/DiscreteGauss.h"
#include "Tools/random.h"
#include "Protocols/config.h"
/**
* Cryptosystem parameters
*/
class FHE_Params
{
protected:
@@ -26,19 +30,29 @@ class FHE_Params
// Data for distributed decryption
int sec_p;
bigint Bval;
int matrix_dim;
FFT_Data fd;
public:
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
/**
* Initialization.
* @param n_mults number of ciphertext multiplications (0/1)
* @param drown_sec parameter for function privacy (default 40)
*/
FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY);
int n_mults() const { return FFTData.size() - 1; }
// Rely on default copy assignment/constructor (not that they should
// ever be needed)
void set(const Ring& R,const vector<bigint>& primes);
void set(const vector<bigint>& primes);
void set_sec(int sec);
void set_min_sec(int sec);
void set_matrix_dim(int matrix_dim);
void set_matrix_dim_from_options();
int get_matrix_dim() const { return matrix_dim; }
const vector<FFT_Data>& FFTD() const { return FFTData; }
@@ -55,10 +69,24 @@ class FHE_Params
int phi_m() const { return FFTData[0].phi_m(); }
const Ring& get_ring() { return FFTData[0].get_R(); }
/// Append to buffer
void pack(octetStream& o) const;
/// Read from buffer
void unpack(octetStream& o);
bool operator!=(const FHE_Params& other) const;
/**
* Generate parameter for computation modulo a prime
* @param plaintext_length bit length of prime
*/
void basic_generation_mod_prime(int plaintext_length);
template<class FD>
const FD& get_plaintext_field_data() const;
bigint get_plaintext_modulus() const;
};
#endif

View File

@@ -68,6 +68,10 @@ void HNF(matrix& H,matrix& U,const matrix& A)
{
int m=A.size(),n=A[0].size(),r,i,j,k;
#ifdef VERBOSE
cerr << "HNF m=" << m << ", n=" << n << endl;
#endif
H=A;
ident(U,n);
r=min(m,n);
@@ -79,9 +83,9 @@ void HNF(matrix& H,matrix& U,const matrix& A)
{ if (step==2)
{ // Step 2
k=-1;
mn=bigint(1)<<256;
mn=bigint(0);
for (j=i; j<n; j++)
{ if (H[i][j]!=0 && abs(H[i][j])<mn)
{ if (H[i][j]!=0 && (abs(H[i][j])<mn || mn == 0))
{ k=j; mn=abs(H[i][j]); }
}
if (k!=-1)
@@ -207,6 +211,11 @@ void SNF_Step(matrix& S,matrix& V)
void SNF(matrix& S,const matrix& A,matrix& V)
{
int m=A.size(),n=A[0].size();
#ifdef VERBOSE
cerr << "SNF m=" << m << ", n=" << n << endl;
#endif
S=A;
ident(V,n);
@@ -239,49 +248,6 @@ matrix inv(const matrix& A)
}
vector<modp> solve(modp_matrix& A,const Zp_Data& PrD)
{
unsigned int n=A.size();
if ((n+1)!=A[0].size()) { throw invalid_params(); }
modp t,ti;
for (unsigned int r=0; r<n; r++)
{ // Find pivot
unsigned int p=r;
while (isZero(A[p][r],PrD)) { p++; }
// Do pivoting
if (p!=r)
{ for (unsigned int c=0; c<n+1; c++)
{ t=A[p][c]; A[p][c]=A[r][c]; A[r][c]=t; }
}
// Make Lcoeff=1
Inv(ti,A[r][r],PrD);
for (unsigned int c=0; c<n+1; c++)
{ Mul(A[r][c],A[r][c],ti,PrD); }
// Now kill off other entries in this column
for (unsigned int rr=0; rr<n; rr++)
{ if (r!=rr)
{ for (unsigned int c=0; c<n+1; c++)
{ Mul(t,A[rr][c],A[r][r],PrD);
Sub(A[rr][c],A[rr][c],t,PrD);
}
}
}
}
// Sanity check and extract answer
vector<modp> ans;
ans.resize(n);
for (unsigned int i=0; i<n; i++)
{ for (unsigned int j=0; j<n; j++)
{ if (i!=j && !isZero(A[i][j],PrD)) { throw bad_value(); }
else if (!isOne(A[i][j],PrD)) { throw bad_value(); }
}
ans[i]=A[i][n];
}
return ans;
}
/* Input matrix is assumed to have more rows than columns */
void pinv(imatrix& Ai,const imatrix& B)

View File

@@ -10,7 +10,6 @@ using namespace std;
#include "Tools/BitVector.h"
typedef vector< vector<bigint> > matrix;
typedef vector< vector<modp> > modp_matrix;
class imatrix : public vector< BitVector >
{
@@ -39,13 +38,6 @@ void print(const imatrix& S);
// requires column operations to create the inverse
matrix inv(const matrix& A);
// Another special routine for modp matrices.
// Solves
// Ax=b
// Assumes A is unimodular, square and only requires row operations to
// create the inverse. In put is C=(A||b) and the routines alters A
vector<modp> solve(modp_matrix& C,const Zp_Data& PrD);
// Finds a pseudo-inverse of a matrix A modulo 2
// - Input matrix is assumed to have more rows than columns
void pinv(imatrix& Ai,const imatrix& A);

View File

@@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2)
template <>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FFT_Data& FTD, bool round_up)
FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
{
int m = 1024;
int lgp = plaintext_length;
@@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec,
while (true)
{
tmp_params = params;
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
bigint p1 = 2 * p * m, p0 = p;
while (nb.min_p0(params.n_mults() > 0, p1) > p0)
@@ -89,27 +89,30 @@ int generate_semi_setup(int plaintext_length, int sec,
template <>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, P2Data& P2D, bool round_up)
FHE_Params& params, P2Data& P2D, bool round_up, int n)
{
if (params.n_mults() > 0)
throw runtime_error("only implemented for 0-level BGV");
gf2n_short::init_field(plaintext_length);
int m;
char_2_dimension(m, plaintext_length);
SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec,
SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
int lgp0 = numBits(nb.min_p0(false, 0));
int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up);
assert(nb.min_phi_m(lgp0, false) * 2 <= m);
load_or_generate(P2D, params.get_ring());
return extra_slack;
}
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, bool round_up)
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up)
{
#ifdef VERBOSE
cout << "Need ciphertext modulus of length " << lgp0;
if (params.n_mults() > 0)
cout << "+" << lgp1;
cout << " and " << phi_N(m) << " slots" << endl;
#endif
int extra_slack = 0;
if (round_up)
@@ -124,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b
}
extra_slack = i - 1;
lgp0 += extra_slack;
#ifdef VERBOSE
cout << "Rounding up to " << lgp0 << ", giving extra slack of "
<< extra_slack << " bits" << endl;
#endif
}
Ring R;
@@ -147,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b
int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
bool round_up, FHE_Params& params)
{
(void) lg2pi, (void) n;
#ifdef VERBOSE
if (n >= 2 and n <= 10)
cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
#endif
int extra_slack = 0;
if (round_up)
@@ -170,20 +179,18 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
extra_slack = 2 * i;
lg2p0 += i;
lg2p1 += i;
#ifdef VERBOSE
cout << "Rounding up to " << lg2p0 << "+" << lg2p1
<< ", giving extra slack of " << extra_slack << " bits" << endl;
#endif
}
#ifdef VERBOSE
cout << "Total length: " << lg2p0 + lg2p1 << endl;
#endif
return extra_slack;
}
/******************************************************************************
* Here onwards needs NTL
******************************************************************************/
@@ -220,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p,
{
double phi_m_bound =
NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1);
#ifdef VERBOSE
cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl;
#endif
if (phi_N(m) < phi_m_bound)
{
int old_m = m;
(void) old_m;
m = 2 << int(ceil(log2(phi_m_bound)));
#ifdef VERBOSE
cout << "m = " << old_m << " too small, increasing it to " << m << endl;
#endif
generate_prime(p, numBits(p), m);
}
else
@@ -249,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p,
void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
const string& i, const bigint& pr0)
{
(void) i;
if (lg2pr==0) { throw invalid_params(); }
bigint step=m;
@@ -265,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
assert(numBits(pr) == lg2pr);
}
#ifdef VERBOSE
cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl;
cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl;
#endif
assert(pr % m == 1);
assert(pr % p == 1);
assert(numBits(pr) == lg2pr);
cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl;
}
/*
@@ -345,10 +364,12 @@ ZZX Cyclotomic(int N)
return F;
}
#else
// simplified version powers of two
int phi_N(int N)
{
if (((N - 1) & N) != 0)
throw runtime_error("compile with NTL support");
throw runtime_error(
"compile with NTL support (USE_NTL=1 in CONFIG.mine)");
else if (N == 1)
return 1;
else
@@ -398,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly)
for (int i=0; i<Rg.phim+1; i++)
{ Rg.poly[i]=to_int(coeff(P,i)); }
#else
throw runtime_error("compile with NTL support");
throw runtime_error(
"compile with NTL support (USE_NTL=1 in CONFIG.mine)");
#endif
}
}
@@ -433,6 +455,16 @@ GF2X Subs_PowX_Mod(const GF2X& a,int pow,int m,const GF2X& c)
GF2X get_F(const Ring& Rg)
{
GF2X F;
for (int i=0; i<=Rg.phi_m(); i++)
{ if (((Rg.Phi()[i])%2)!=0)
{ SetCoeff(F,i,1); }
}
//cout << "F = " << F << endl;
return F;
}
void init(P2Data& P2D,const Ring& Rg)
{
@@ -443,16 +475,12 @@ void init(P2Data& P2D,const Ring& Rg)
{ SetCoeff(G,gf2n_short::get_t(i),1); }
//cout << "G = " << G << endl;
for (int i=0; i<=Rg.phi_m(); i++)
{ if (((Rg.Phi()[i])%2)!=0)
{ SetCoeff(F,i,1); }
}
//cout << "F = " << F << endl;
F = get_F(Rg);
// seed randomness to achieve same result for all players
// randomness is used in SFCanZass and FindRoot
SetSeed(ZZ(0));
// Now factor F modulo 2
vec_GF2X facts=SFCanZass(F);
@@ -464,17 +492,34 @@ void init(P2Data& P2D,const Ring& Rg)
// Compute the quotient group
QGroup QGrp;
int Gord=-1,e=Rg.phi_m()/d; // e = # of plaintext slots, phi(m)/degree
int seed=1;
while (Gord!=e)
if ((e*gf2n_short::degree())!=Rg.phi_m())
{ cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl;
cout << e << " * " << gf2n_short::degree() << " != " << Rg.phi_m() << endl;
throw invalid_params();
}
int max_tries = 10;
for (int seed = 0;; seed++)
{ QGrp.assign(Rg.m(),seed); // QGrp encodes the the quotient group Z_m^*/<2>
Gord=QGrp.order();
if (Gord!=e) { cout << "Group order wrong, need to repeat the Haf-Mc algorithm" << endl; seed++; }
Gord = QGrp.order();
if (Gord == e)
{
break;
}
else
{
if (seed == max_tries)
{
cerr << "abort after " << max_tries << " tries" << endl;
throw invalid_params();
}
else
cout << "Group order wrong, need to repeat the Haf-Mc algorithm"
<< endl;
}
}
//cout << " l = " << Gord << " , d = " << d << endl;
if ((Gord*gf2n_short::degree())!=Rg.phi_m())
{ cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl;
throw not_implemented();
}
vector<GF2X> Fi(Gord);
vector<GF2X> Rts(Gord);
@@ -595,6 +640,27 @@ void char_2_dimension(int& m, int& lg2)
m=5797;
lg2=40;
break;
case 64:
m = 9615;
break;
case 63:
m = 9271;
break;
case 28:
m = 3277;
break;
case 16:
m = 4369;
break;
case 15:
m = 4681;
break;
case 12:
m = 4095;
break;
case 11:
m = 2047;
break;
default:
throw runtime_error("field size not supported");
break;
@@ -630,7 +696,7 @@ void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D)
finalize_lengths(lg2p0, lg2p1, n, m, lg2pi[0], round_up, params);
}
if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) > phi_N(m))
if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) * 2 > m)
throw runtime_error("number of slots too small");
cout << "m = " << m << endl;
@@ -676,135 +742,3 @@ void load_or_generate(P2Data& P2D, const Ring& R)
P2D.store(R);
}
}
#ifdef USE_NTL
/*
* Create FHE parameters for a general plaintext modulus p
* Basically this is for general large primes only
*/
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params)
{
cout << "Setting up parameters" << endl;
int lgp=numBits(p);
int mm,idx;
// mm is the minimum value of m we will accept
if (lgp<48)
{ mm=100; // Test case
idx=0;
}
else if (lgp <96)
{ mm=8192;
idx=1;
}
else if (lgp<192)
{ mm=16384;
idx=2;
}
else if (lgp<384)
{ mm=16384;
idx=3;
}
else if (lgp<768)
{ mm=32768;
idx=4;
}
else
{ throw invalid_params(); }
// Now find the small factors of p-1 and their exponents
bigint t=p-1;
vector<long> primes(100),exp(100);
PrimeSeq s;
long pr;
pr=s.next();
int len=0;
while (pr<2*mm)
{ int e=0;
while ((t%pr)==0)
{ e++;
t=t/pr;
}
if (e!=0)
{ primes[len]=pr;
exp[len]=e;
if (len!=0) { cout << " * "; }
cout << pr << "^" << e << flush;
len++;
}
pr=s.next();
}
cout << endl;
// We want to find the best m which divides pr-1, such that
// - 2*m > phi(m) > mm
// - m has the smallest number of factors
vector<int> ee;
ee.resize(len);
for (int i=0; i<len; i++) { ee[i]=0; }
int min_hwt=-1,m=-1,bphi_m=-1,bmx=-1;
bool flag=true;
while (flag)
{ int cand_m=1,hwt=0,mx=0;
for (int i=0; i<len; i++)
{ //cout << ee[i] << " ";
if (ee[i]!=0)
{ hwt++;
for (int j=0; j<ee[i]; j++)
{ cand_m*=primes[i]; }
if (ee[i]>mx) { mx=ee[i]; }
}
}
// Put "if" here to stop searching for things which will never work
if (cand_m>1 && cand_m<4*mm)
{ //cout << " : " << cand_m << " : " << hwt << flush;
int phim=phi_N(cand_m);
//cout << " : " << phim << " : " << mm << endl;
if (phim>mm && phim<3*mm)
{ if (m==-1 || hwt<min_hwt || (hwt==min_hwt && mx<bmx) || (hwt==min_hwt && mx==bmx && phim<bphi_m))
{ m=cand_m;
min_hwt=hwt;
bphi_m=phim;
bmx=mx;
//cout << "\t OK" << endl;
}
}
}
else
{ //cout << endl;
}
int i=0;
ee[i]=ee[i]+1;
while (ee[i]>exp[i] && flag)
{ ee[i]=0;
i++;
if (i==len) { flag=false; i=0; }
else { ee[i]=ee[i]+1; }
}
}
if (m==-1)
{ throw bad_value(); }
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
Parameters parameters(n, lgp, sec);
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p,params);
int mx=0;
for (int i=0; i<R.phi_m(); i++)
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
cout << "Max Coeff = " << mx << endl;
init(R, m, true);
Zp_Data Zp(p);
PPD.init(R,Zp);
gfp::init_field(p);
pr0 = parameters.pr0;
pr1 = parameters.pr1;
}
#endif

View File

@@ -1,12 +1,9 @@
#ifndef _NTL_Subs
#define _NTL_Subs
/* All these routines use NTL on the inside */
#include "FHE/Ring.h"
#include "FHE/FFT_Data.h"
#include "FHE/P2Data.h"
#include "FHE/PPData.h"
#include "FHE/FHE_Params.h"
/* Routines to set up key sizes given the number of players n
@@ -47,26 +44,28 @@ public:
};
// Main setup routine (need NTL if online_only is false)
// Main setup routine
void generate_setup(int nparties, int lgp, int lg2,
int sec, bool skip_2 = false, int slack = 0, bool round_up = false);
// semi-homomorphic, includes slack
template <class FD>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FD& FieldD, bool round_up);
FHE_Params& params, FD& FieldD, bool round_up, int n = 1);
// field-independent semi-homomorphic setup
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1,
bool round_up);
// Everything else needs NTL
void init(Ring& Rg, int m, bool generate_poly);
void init(P2Data& P2D,const Ring& Rg);
// For use when we want p to be a specific value
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params);
namespace NTL
{
class GF2X;
}
NTL::GF2X get_F(const Ring& Rg);
// generate moduli according to lengths and other parameters
void generate_moduli(bigint& pr0, bigint& pr1, const int m,

View File

@@ -36,10 +36,12 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s);
// unify parameters by taking maximum over TopGear or not
bigint B_clean_top_gear = B_clean * 2;
bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.));
bigint B_clean_not_top_gear = B_clean << max(slack - sec, 0);
B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
int matrix_dim = params.get_matrix_dim();
#ifdef NOISY
cout << "phi(m): " << phi_m << endl;
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cout << "V_s: " << V_s << endl;
cout << "c1: " << c1 << endl;
@@ -48,9 +50,14 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
cout << "log(slack): " << slack << endl;
cout << "B_clean: " << B_clean << endl;
cout << "B_scale: " << B_scale << endl;
cout << "matrix dimension: " << matrix_dim << endl;
cout << "drown sec: " << params.secp() << endl;
cout << "sec: " << sec << endl;
#endif
drown = 1 + n * (bigint(1) << sec);
assert(matrix_dim > 0);
assert(params.secp() >= 0);
drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp());
}
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
@@ -68,8 +75,14 @@ double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma)
{
if (sigma <= 0)
sigma = FHE_Params().get_R();
// the constant was updated using Martin Albrecht's LWE estimator in Sep 2019
return 37.8 * (log_q - log2(sigma));
// the constant was updated using Martin Albrecht's LWE estimator in Mar 2022
// found the following pairs for 128-bit security
// and alpha = 0.7 * sqrt(2*pi) / q
// m = 2048, log_2(q) = 68
// m = 4096, log_2(q) = 138
// m = 8192, log_2(q) = 302
// m = 16384, log_2(q) = 560
return 15.1 * log_q;
}
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, const FHE_Params& params)
@@ -92,7 +105,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants()
{
tp *= t;
double lgtp = log(tp) / log(2.0);
if (C[i] < 0 && lgtp < FHE_epsilon)
if (C[i] < 0 && lgtp < -FHE_epsilon)
{
C[i] = pow(x, i);
}
@@ -114,7 +127,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
cout << "n: " << n << endl;
cout << "sec: " << sec << endl;
cout << "sigma: " << this->sigma << endl;
cout << "h: " << h << endl;
cout << "B_clean size: " << numBits(B_clean) << endl;
cout << "B_scale size: " << numBits(B_scale) << endl;
cout << "B_KS size: " << numBits(B_KS) << endl;
@@ -155,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1)
bigint NoiseBounds::min_p1()
{
return drown * B_KS + 1;
return max(bigint(drown * B_KS), bigint((phi_m * p) << 10));
}
bigint NoiseBounds::opt_p1()
@@ -169,8 +181,10 @@ bigint NoiseBounds::opt_p1()
// solve
mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a);
bigint res = ceil(s);
#ifdef VERBOSE
cout << "Optimal p1 vs minimal: " << numBits(res) << "/"
<< numBits(min_p1()) << endl;
#endif
return res;
}
@@ -182,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1)
{
min_p0 *= 2;
min_p1 *= 2;
#ifdef VERBOSE
cout << "increasing lengths: " << numBits(min_p0) << "/"
<< numBits(min_p1) << endl;
#endif
}
lg2p1 = numBits(min_p1);
lg2p0 = numBits(min_p0);

View File

@@ -42,6 +42,8 @@ public:
bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); }
static double min_phi_m(int log_q, double sigma);
static double min_phi_m(int log_q, const FHE_Params& params);
bigint get_B_clean() { return B_clean; }
};
// as per ePrint 2012:642 for slack = 0

View File

@@ -55,13 +55,13 @@ void P2Data::check_dimensions() const
// cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl;
if (A.size() != Ai.size())
throw runtime_error("forward and backward mapping dimensions mismatch");
if (A.size() != A[0].size())
if (A.size() != A.at(0).size())
throw runtime_error("forward mapping not square");
if (Ai.size() != Ai[0].size())
if (Ai.size() != Ai.at(0).size())
throw runtime_error("backward mapping not square");
if ((int)A[0].size() != slots * gf2n_short::degree())
if ((int)A.at(0).size() != slots * gf2n_short::degree())
throw runtime_error(
"mapping dimension incorrect: " + to_string(A[0].size())
"mapping dimension incorrect: " + to_string(A.at(0).size())
+ " != " + to_string(slots) + " * "
+ to_string(gf2n_short::degree()));
}

View File

@@ -1,100 +0,0 @@
#include "FHE/Subroutines.h"
#include "FHE/PPData.h"
#include "FHE/FFT.h"
#include "FHE/Matrix.h"
#include "Math/modp.hpp"
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
{
R=Rg;
prData=PrD;
root=Find_Primitive_Root_m(Rg.m(),Rg.Phi(),PrD);
}
void PPData::to_eval(vector<modp>& elem) const
{
if (elem.size()!= (unsigned) R.phi_m())
{ throw params_mismatch(); }
throw not_implemented();
/*
vector<modp> ans;
ans.resize(R.phi_m());
modp x=root;
for (int i=0; i<R.phi_m(); i++)
{ ans[i]=elem[R.phi_m()-1];
for (int j=1; j<R.phi_m(); j++)
{ Mul(ans[i],ans[i],x,prData);
Add(ans[i],ans[i],elem[R.phi_m()-j-1],prData);
}
Mul(x,x,root,prData);
}
elem=ans;
*/
}
void PPData::from_eval(vector<modp>& elem) const
{
// avoid warning
elem.empty();
throw not_implemented();
/*
modp_matrix A;
int n=phi_m();
A.resize(n, vector<modp>(n+1) );
modp x=root;
for (int i=0; i<n; i++)
{ assignOne(A[0][i],prData);
for (int j=1; j<n; j++)
{ Mul(A[j][i],A[j-1][i],x,prData); }
Mul(x,x,root,prData);
A[i][n]=elem[i];
}
elem=solve(A,prData);
*/
}
void PPData::reset_iteration()
{
pow = 1;
theta = {root, prData};
thetaPow = theta;
}
void PPData::next_iteration()
{
do
{ thetaPow *= (theta);
pow++;
}
while (gcd(pow,m())!=1);
}
gfp PPData::get_evaluation(const vector<bigint>& mess) const
{
// Uses Horner's rule
gfp ans;
ans = mess[mess.size()-1];
gfp coeff;
for (int j=mess.size()-2; j>=0; j--)
{ ans *= (thetaPow);
coeff = mess[j];
ans += (coeff);
}
return ans;
}

View File

@@ -1,61 +0,0 @@
#ifndef _PPData
#define _PPData
#include "Math/modp.h"
#include "Math/Zp_Data.h"
#include "Math/gfpvar.h"
#include "Math/fixint.h"
#include "FHE/Ring.h"
#include "FHE/FFT_Data.h"
/* Class for holding modular arithmetic data wrt the ring
*
* It also holds the ring
*/
class PPData
{
public:
typedef gfp T;
typedef bigint S;
typedef typename FFT_Data::poly_type poly_type;
Ring R;
Zp_Data prData;
modp root; // m'th Root of Unity mod pr
void init(const Ring& Rg,const Zp_Data& PrD);
PPData() { ; }
PPData(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }
const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; }
int phi_m() const { return R.phi_m(); }
int m() const { return R.m(); }
int num_slots() const { return R.phi_m(); }
int p(int i) const { return R.p(i); }
int p_inv(int i) const { return R.p_inv(i); }
const vector<int>& Phi() const { return R.Phi(); }
// Convert input vector from poly to evaluation representation
// - Uses naive method and not FFT, we only use this rarely in any case
void to_eval(vector<modp>& elem) const;
void from_eval(vector<modp>& elem) const;
// Following are used to iteratively get slots, as we use PPData when
// we do not have an efficient FFT algorithm
gfp thetaPow,theta;
int pow;
void reset_iteration();
void next_iteration();
gfp get_evaluation(const vector<bigint>& mess) const;
};
#endif

View File

@@ -1,7 +1,6 @@
#include "FHE/Plaintext.h"
#include "FHE/Ring_Element.h"
#include "FHE/PPData.h"
#include "FHE/P2Data.h"
#include "FHE/Rq_Element.h"
#include "FHE_Keys.h"
@@ -11,10 +10,43 @@
template<class T, class FD, class S>
Plaintext<T, FD, S>::Plaintext(const FHE_Params& params) :
Plaintext(params.get_plaintext_field_data<FD>(), Both)
{
}
template<class T, class FD, class S>
unsigned int Plaintext<T, FD, S>::num_slots() const
{
return (*Field_Data).phi_m();
}
template<class T, class FD, class S>
int Plaintext<T, FD, S>::degree() const
{
return (*Field_Data).phi_m();
}
template<>
unsigned int Plaintext<gf2n_short,P2Data,int>::num_slots() const
{
return (*Field_Data).num_slots();
}
template<>
int Plaintext<gf2n_short,P2Data,int>::degree() const
{
return (*Field_Data).degree();
}
template<>
void Plaintext<gfp, FFT_Data, bigint>::from(const Generator<bigint>& source) const
{
b.resize(degree);
b.resize(degree());
for (auto& x : b)
{
source.get(bigint::tmp);
@@ -31,7 +63,7 @@ void Plaintext<gfp,FFT_Data,bigint>::from_poly() const
Ring_Element e(*Field_Data,polynomial);
e.from(b);
e.change_rep(evaluation);
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
a[i] = gfp(e.get_element(i), e.get_FFTD().get_prD());
type=Both;
@@ -52,45 +84,12 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
}
template<>
void Plaintext<gfp,PPData,bigint>::from_poly() const
{
if (type!=Polynomial) { return; }
vector<modp> aa((*Field_Data).phi_m());
for (unsigned int i=0; i<aa.size(); i++)
{ to_modp(aa[i], bigint::tmp = b[i], (*Field_Data).prData); }
(*Field_Data).to_eval(aa);
a.resize(n_slots);
for (unsigned int i=0; i<aa.size(); i++)
a[i] = {aa[i], Field_Data->get_prD()};
type=Both;
}
template<>
void Plaintext<gfp,PPData,bigint>::to_poly() const
{
if (type!=Evaluation) { return; }
cout << "This is VERY inefficient to convert a plaintext to poly representation" << endl;
vector<modp> bb((*Field_Data).phi_m());
for (unsigned int i=0; i<bb.size(); i++)
{ bb[i]=a[i].get(); }
(*Field_Data).from_eval(bb);
for (unsigned int i=0; i<bb.size(); i++)
{
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
b[i] = bigint::tmp;
}
type=Both;
}
template<>
void Plaintext<gf2n_short,P2Data,int>::from_poly() const
{
if (type!=Polynomial) { return; }
a.resize(n_slots);
a.resize(num_slots());
(*Field_Data).backward(a,b);
type=Both;
}
@@ -106,34 +105,13 @@ void Plaintext<gf2n_short,P2Data,int>::to_poly() const
template<>
void Plaintext<gfp,FFT_Data,bigint>::set_sizes()
{ n_slots = (*Field_Data).phi_m();
degree = n_slots;
}
template<>
void Plaintext<gfp,PPData,bigint>::set_sizes()
{ n_slots = (*Field_Data).phi_m();
degree = n_slots;
}
template<>
void Plaintext<gf2n_short,P2Data,int>::set_sizes()
{ n_slots = (*Field_Data).num_slots();
degree = (*Field_Data).degree();
}
template<class T, class FD, class S>
void Plaintext<T, FD, S>::allocate(PT_Type type) const
{
if (type != Evaluation)
b.resize(degree);
b.resize(degree());
if (type != Polynomial)
a.resize(n_slots);
a.resize(num_slots());
this->type = type;
}
@@ -141,7 +119,7 @@ void Plaintext<T, FD, S>::allocate(PT_Type type) const
template<class T, class FD, class S>
void Plaintext<T, FD, S>::allocate_slots(const bigint& value)
{
b.resize(degree);
b.resize(degree());
for (auto& x : b)
x.allocate_slots(value);
}
@@ -236,7 +214,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
type=Polynomial;
break;
case Diagonal:
a.resize(n_slots);
a.resize(num_slots());
a[0].randomize(G);
for (unsigned int i=1; i<a.size(); i++)
{ a[i]=a[0]; }
@@ -244,7 +222,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
break;
default:
// Gen a plaintext with 0/1 in each slot
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{
if (G.get_bit())
@@ -272,7 +250,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t)
b[0].generateUniform(G, n_bits, false);
}
else
for (int i = 0; i < n_slots; i++)
for (size_t i = 0; i < num_slots(); i++)
b[i].generateUniform(G, n_bits, false);
break;
default:
@@ -288,7 +266,7 @@ void Plaintext<T,FD,S>::assign_zero(PT_Type t)
allocate();
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].assign_zero(); }
}
@@ -306,7 +284,7 @@ void Plaintext<T,FD,S>::assign_one(PT_Type t)
allocate();
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].assign_one(); }
}
@@ -359,35 +337,7 @@ void add(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] + y.a[i]); }
}
if (z.type!=Evaluation)
{ for (unsigned int i=0; i<z.b.size(); i++)
{ z.b[i]=x.b[i]+y.b[i];
if (z.b[i]>(*z.Field_Data).get_prime())
{ z.b[i]-=(*z.Field_Data).get_prime(); }
}
}
}
template<>
void add(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
const Plaintext<gfp,PPData,bigint>& y)
{
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
if (x.type==Both && y.type!=Both) { z.type=y.type; }
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
else if (x.type!=y.type) { throw rep_mismatch(); }
else { z.type=x.type; }
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] + y.a[i]); }
}
@@ -418,7 +368,7 @@ void add(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,i
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i].add(x.a[i],y.a[i]); }
}
@@ -446,7 +396,7 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i]= (x.a[i] - y.a[i]); }
}
@@ -463,36 +413,6 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
template<>
void sub(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
const Plaintext<gfp,PPData,bigint>& y)
{
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
if (x.type==Both && y.type!=Both) { z.type=y.type; }
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
else if (x.type!=y.type) { throw rep_mismatch(); }
else { z.type=x.type; }
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] - y.a[i]); }
}
if (z.type!=Evaluation)
{ for (unsigned int i=0; i<z.b.size(); i++)
{ z.b[i]=x.b[i]-y.b[i];
if (z.b[i]<0)
{ z.b[i]+=(*z.Field_Data).get_prime(); }
}
}
}
template<>
@@ -510,7 +430,7 @@ void sub(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,i
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i].sub(x.a[i],y.a[i]); }
}
@@ -545,7 +465,7 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
{
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].negate(); }
}
@@ -560,23 +480,6 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
}
}
template<>
void Plaintext<gfp,PPData,bigint>::negate()
{
if (type!=Polynomial)
{
a.resize(n_slots);
for (unsigned int i=0; i<a.size(); i++)
{ a[i].negate(); }
}
if (type!=Evaluation)
{ for (unsigned int i=0; i<b.size(); i++)
{ if (b[i]!=0)
{ b[i]=(*Field_Data).get_prime()-b[i]; }
}
}
}
template<>
@@ -607,7 +510,7 @@ bool Plaintext<T,FD,S>::equals(const Plaintext& x) const
if (type!=Polynomial and x.type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ if (!(a[i] == x.a[i])) { return false; } }
}
@@ -671,9 +574,9 @@ void Plaintext<T,FD,S>::unpack(octetStream& o)
unsigned int size;
o.get(size);
allocate();
if (size != b.size())
if (size != b.size() and size != 0)
throw length_error("unexpected length received");
for (unsigned int i = 0; i < b.size(); i++)
for (unsigned int i = 0; i < size; i++)
b[i] = o.get<S>();
}
@@ -719,12 +622,6 @@ template void mul(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data
template class Plaintext<gfp,PPData,bigint>;
template void mul(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,const Plaintext<gfp,PPData,bigint>& y);
template class Plaintext<gf2n_short,P2Data,int>;
template void mul(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,int>& x,const Plaintext<gf2n_short,P2Data,int>& y);

View File

@@ -18,6 +18,7 @@
*/
#include "FHE/Generator.h"
#include "FHE/FFT_Data.h"
#include "Math/fixint.h"
#include <vector>
@@ -25,6 +26,8 @@ using namespace std;
class FHE_PK;
class Rq_Element;
class FHE_Params;
class FFT_Data;
template<class T> class AddableVector;
// Forward declaration as apparently this is needed for friends in templates
@@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits };
enum PT_Type { Polynomial, Evaluation, Both };
/**
* BGV plaintext.
* Use ``Plaintext_mod_prime`` instead of filling in the templates.
* The plaintext is held in one of the two representations or both,
* polynomial and evaluation. The latter is the one allowing element-wise
* multiplication over a vector.
* Plaintexts can be added, subtracted, and multiplied via operator overloading.
*/
template<class T,class FD,class _>
class Plaintext
{
typedef typename FD::poly_type S;
int n_slots;
int degree;
mutable vector<T> a; // The thing in evaluation/FFT form
mutable vector<S> b; // Now in polynomial form
@@ -58,33 +67,47 @@ class Plaintext
const FD *Field_Data;
void set_sizes();
int degree() const;
public:
const FD& get_field() const { return *Field_Data; }
unsigned int num_slots() const { return n_slots; }
/// Number of slots
unsigned int num_slots() const;
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
{ Field_Data=&FieldD; allocate(type); }
Plaintext(const FD& FieldD, const Rq_Element& other);
/// Initialization
Plaintext(const FHE_Params& params);
void allocate(PT_Type type) const;
void allocate() const { allocate(type); }
void allocate_slots(const bigint& value);
int get_min_alloc();
// Access evaluation representation
/**
* Read slot.
* @param i slot number
* @returns slot content
*/
T element(int i) const
{ if (type==Polynomial)
{ from_poly(); }
return a[i];
}
/**
* Write to slot
* @param i slot number
* @param e new slot content
*/
void set_element(int i,const T& e)
{ if (type==Polynomial)
{ throw not_implemented(); }
a.resize(n_slots);
a.resize(num_slots());
a[i]=e;
type=Evaluation;
}
@@ -171,10 +194,10 @@ class Plaintext
bool is_diagonal() const;
/* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already
*/
/// Append to buffer
void pack(octetStream& o) const;
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o);
size_t report_size(ReportType type);
@@ -185,4 +208,6 @@ class Plaintext
template <class FD>
using Plaintext_ = Plaintext<typename FD::T, FD, typename FD::S>;
typedef Plaintext_<FFT_Data> Plaintext_mod_prime;
#endif

View File

@@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o)
o.get(pi_inv);
o.get(poly);
}
else
else if (mm != 0)
init(*this, mm);
}

View File

@@ -50,6 +50,7 @@ void Ring_Element::prepare_push()
void Ring_Element::allocate()
{
assert(FFTD);
element.resize(FFTD->phi_m());
}
@@ -86,7 +87,6 @@ void Ring_Element::negate()
void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
if (a.element.empty())
{
@@ -99,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
return;
}
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (&ans == &a)
{
ans += b;
@@ -401,19 +403,29 @@ void Ring_Element::change_rep(RepType r)
bool Ring_Element::equals(const Ring_Element& a) const
{
if (element.empty() and a.element.empty())
return true;
else if (element.empty() or a.element.empty())
throw not_implemented();
if (rep!=a.rep) { throw rep_mismatch(); }
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
if (is_zero() or a.is_zero())
return is_zero() and a.is_zero();
for (int i=0; i<(*FFTD).phi_m(); i++)
{ if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } }
return true;
}
bool Ring_Element::is_zero() const
{
if (element.empty())
return true;
for (auto& x : element)
if (not ::isZero(x, FFTD->get_prD()))
return false;
return true;
}
ConversionIterator Ring_Element::get_iterator() const
{
if (rep != polynomial)
@@ -560,6 +572,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const
{
if (&FFTD != this->FFTD)
throw params_mismatch();
if (is_zero())
throw runtime_error("element is zero");
}

View File

@@ -95,6 +95,7 @@ class Ring_Element
void randomize(PRNG& G,bool Diag=false);
bool equals(const Ring_Element& a) const;
bool is_zero() const;
// This is a NOP in cases where we cannot do a FFT
void change_rep(RepType r);

View File

@@ -5,7 +5,7 @@
#include "Math/modp.hpp"
Rq_Element::Rq_Element(const FHE_PK& pk) :
Rq_Element(pk.get_params().FFTD())
Rq_Element(pk.get_params().FFTD(), evaluation, evaluation)
{
}
@@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
}
}
void Rq_Element::add(octetStream& os, int)
{
Rq_Element tmp(*this);
tmp.unpack(os);
*this += tmp;
}
void Rq_Element::randomize(PRNG& G,int l)
{
set_level(l);
@@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p)
// Now add delta back onto a0
Rq_Element bb(b0,b1);
add(*this,*this,bb);
::add(*this,*this,bb);
// Now divide by p1 mod p0
modp p1_inv,pp;
@@ -291,7 +298,7 @@ void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y)
partial_assign(x);
}
void Rq_Element::pack(octetStream& o) const
void Rq_Element::pack(octetStream& o, int) const
{
check_level();
o.store(lev);
@@ -299,7 +306,7 @@ void Rq_Element::pack(octetStream& o) const
a[i].pack(o);
}
void Rq_Element::unpack(octetStream& o)
void Rq_Element::unpack(octetStream& o, int)
{
unsigned int ll; o.get(ll); lev=ll;
check_level();
@@ -340,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const
return sz;
}
void Rq_Element::unpack(octetStream& o, const FHE_Params& params)
{
set_data(params.FFTD());
unpack(o);
}
void Rq_Element::print_first_non_zero() const
{
vector<bigint> v = to_vec_bigint();

View File

@@ -69,8 +69,9 @@ protected:
a({b0}), lev(n_mults()) {}
template<class T, class FD, class S>
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext) :
Rq_Element(params)
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext,
RepType r0 = polynomial, RepType r1 = polynomial) :
Rq_Element(params, r0, r1)
{
from(plaintext.get_iterator());
}
@@ -93,12 +94,14 @@ protected:
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
void add(octetStream& os, int = -1);
template<class S>
Rq_Element& operator+=(const vector<S>& other);
Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; }
Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; }
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; }
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; }
Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; }
template <class T>
Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; }
@@ -154,8 +157,11 @@ protected:
* For unpack we assume the prData for a0 and a1 has been assigned
* correctly already
*/
void pack(octetStream& o) const;
void unpack(octetStream& o);
void pack(octetStream& o, int = -1) const;
void unpack(octetStream& o, int = -1);
// without prior initialization
void unpack(octetStream& o, const FHE_Params& params);
void output(ostream& s) const;
void input(istream& s);
@@ -176,7 +182,7 @@ Rq_Element& Rq_Element::operator+=(const vector<S>& other)
{
Rq_Element tmp = *this;
tmp.from(Iterator<S>(other), lev);
add(*this, *this, tmp);
::add(*this, *this, tmp);
return *this;
}

View File

@@ -11,35 +11,15 @@ void Subs(modp& ans,const vector<int>& poly,const modp& x,const Zp_Data& ZpD)
assignZero(ans,ZpD);
for (int i=poly.size()-1; i>=0; i--)
{ Mul(ans,ans,x,ZpD);
switch (poly[i])
{ case 0:
break;
case 1:
Add(ans,ans,one,ZpD);
break;
case -1:
Sub(ans,ans,one,ZpD);
break;
case 2:
Add(ans,ans,one,ZpD);
Add(ans,ans,one,ZpD);
break;
case -2:
Sub(ans,ans,one,ZpD);
Sub(ans,ans,one,ZpD);
break;
case 3:
Add(ans,ans,one,ZpD);
Add(ans,ans,one,ZpD);
Add(ans,ans,one,ZpD);
break;
case -3:
Sub(ans,ans,one,ZpD);
Sub(ans,ans,one,ZpD);
Sub(ans,ans,one,ZpD);
break;
default:
throw not_implemented();
if (poly[i] > 0)
{
for (int j = 0; j < poly[i]; j++)
Add(ans, ans, one, ZpD);
}
if (poly[i] < 0)
{
for (int j = 0; j < -poly[i]; j++)
Sub(ans, ans, one, ZpD);
}
}
}

View File

@@ -40,10 +40,9 @@ template <class FD>
void PartSetup<FD>::generate_setup(int n_parties, int plaintext_length, int sec,
int slack, bool round_up)
{
sec = max(sec, 40);
params.set_min_sec(sec);
Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup(
params, FieldD);
params.set_sec(sec);
pk = FHE_PK(params, FieldD.get_prime());
sk = FHE_SK(params, FieldD.get_prime());
calpha = Ciphertext(params);
@@ -180,11 +179,8 @@ void PartSetup<P2Data>::init_field()
}
template <class FD>
void PartSetup<FD>::check(int sec) const
void PartSetup<FD>::check() const
{
sec = max(sec, 40);
if (abs(sec - params.secp()) > 2)
throw runtime_error("security parameters vary too much between protocol and distributed decryption");
sk.check(params, pk, FieldD.get_prime());
}
@@ -203,7 +199,7 @@ template<class FD>
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
int plaintext_length, int sec)
{
::secure_init(*this, P, machine, plaintext_length, sec);
::secure_init(*this, P, machine, plaintext_length, sec, params);
}
template<class FD>

View File

@@ -57,7 +57,7 @@ public:
void init_field();
void check(int sec) const;
void check() const;
bool operator!=(const PartSetup<FD>& other);
void secure_init(Player& P, MachineBase& machine, int plaintext_length,

View File

@@ -274,7 +274,7 @@ void EncCommit<T,FD,S>::Create_More() const
(*P).Broadcast_Receive(ctx_Delta);
// Output the ctx_Delta to a file
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
ofstream outf(filename);
for (int j=0; j<(*P).num_players(); j++)
{
@@ -308,7 +308,7 @@ void EncCommit<T,FD,S>::Create_More() const
octetStream occ,ctx_D;
for (int i=0; i<2*TT; i++)
{ if (open[i]==1)
{ sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
{ snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
ifstream inpf(filename);
for (int j=0; j<(*P).num_players(); j++)
{
@@ -386,7 +386,7 @@ void EncCommit<T,FD,S>::Create_More() const
Ciphertext enc1(params),enc2(params),eDelta(params);
octetStream oe1,oe2;
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
ifstream inpf(filename);
for (int k=0; k<(*P).num_players(); k++)
{

View File

@@ -57,7 +57,7 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
template <class FD>
void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
OT_ROLE role, int n_summands)
OT_ROLE role, int)
{
o.reset_write_head();
@@ -67,20 +67,10 @@ void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
G.ReSeed();
timers["Mask randomization"].start();
product_share.randomize(G);
bigint B = 6 * machine.setup<FD>().params.get_R();
B *= machine.setup<FD>().FieldD.get_prime();
B <<= machine.drown_sec;
// slack
B *= NonInteractiveProof::slack(machine.sec,
machine.setup<FD>().params.phi_m());
B <<= machine.extra_slack;
B *= n_summands;
rc.generateUniform(G, 0, B, B);
mask = c;
mask.rerandomize(other_pk);
timers["Mask randomization"].stop();
timers["Encryption"].start();
other_pk.encrypt(mask, product_share, rc);
timers["Encryption"].stop();
mask += c;
mask += product_share;
mask.pack(o);
res -= product_share;
}
@@ -130,6 +120,13 @@ void Multiplier<FD>::report_size(ReportType type, MemoryUsage& res)
res += memory_usage;
}
template<class FD>
const vector<Ciphertext>& Multiplier<FD>::get_multiplicands(
const vector<vector<Ciphertext> >& others_ct, const FHE_PK&)
{
return others_ct[P.get_full_player().get_player(-P.get_offset())];
}
template class Multiplier<FFT_Data>;
template class Multiplier<P2Data>;

View File

@@ -55,6 +55,9 @@ public:
size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res);
size_t report_volatile() { return volatile_capacity; }
const vector<Ciphertext>& get_multiplicands(
const vector<vector<Ciphertext>>& others_ct, const FHE_PK&);
};
#endif /* FHEOFFLINE_MULTIPLIER_H_ */

View File

@@ -24,7 +24,7 @@ PairwiseGenerator<FD>::PairwiseGenerator(int thread_num,
thread_num, machine.output, machine.get_prep_dir<FD>(P)),
EC(P, machine.other_pks, machine.setup<FD>().FieldD, timers, machine, *this),
MC(machine.setup<FD>().alphai),
n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)),
n_ciphertexts(EC.proof.U),
C(n_ciphertexts, machine.setup<FD>().params), volatile_memory(0),
machine(machine)
{
@@ -175,7 +175,7 @@ size_t PairwiseGenerator<FD>::report_size(ReportType type)
template <class FD>
size_t PairwiseGenerator<FD>::report_sent()
{
return P.sent;
return P.total_comm().sent;
}
template <class FD>

View File

@@ -17,19 +17,17 @@ PairwiseMachine::PairwiseMachine(Player& P) :
{
}
PairwiseMachine::PairwiseMachine(int argc, const char** argv) :
MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")),
other_pks(N.num_players(), {setup_p.params, 0}),
pk(other_pks[N.my_num()]), sk(pk)
RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) :
MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise"))
{
init();
}
void PairwiseMachine::init()
void RealPairwiseMachine::init()
{
if (use_gf2n)
{
field_size = 40;
field_size = gf2n_short::DEFAULT_LENGTH;
gf2n_short::init_field(field_size);
setup_keys<P2Data>();
}
@@ -63,11 +61,11 @@ PairwiseSetup<P2Data>& PairwiseMachine::setup()
}
template <class FD>
void PairwiseMachine::setup_keys()
void RealPairwiseMachine::setup_keys()
{
auto& N = P;
PairwiseSetup<FD>& s = setup<FD>();
s.init(P, drown_sec, field_size, extra_slack);
s.init(P, sec, field_size, extra_slack);
if (output)
write_mac_key(get_prep_dir<FD>(P), P.my_num(), P.num_players(), s.alphai);
for (auto& x : other_pks)
@@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys()
if (i != N.my_num())
other_pks[i].unpack(os[i]);
set_mac_key(s.alphai);
Share<typename FD::T>::MAC_Check::setup(P);
}
template <class T>
void PairwiseMachine::set_mac_key(T alphai)
void RealPairwiseMachine::set_mac_key(T alphai)
{
typedef typename T::FD FD;
auto& N = P;
@@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const
bundle.compare(P);
}
template void PairwiseMachine::setup_keys<FFT_Data>();
template void PairwiseMachine::setup_keys<P2Data>();
template void RealPairwiseMachine::setup_keys<FFT_Data>();
template void RealPairwiseMachine::setup_keys<P2Data>();

View File

@@ -10,7 +10,7 @@
#include "FHEOffline/SimpleMachine.h"
#include "FHEOffline/PairwiseSetup.h"
class PairwiseMachine : public MachineBase
class PairwiseMachine : public virtual MachineBase
{
public:
PairwiseSetup<FFT_Data> setup_p;
@@ -23,15 +23,6 @@ public:
vector<Ciphertext> enc_alphas;
PairwiseMachine(Player& P);
PairwiseMachine(int argc, const char** argv);
void init();
template <class FD>
void setup_keys();
template <class T>
void set_mac_key(T alphai);
template <class FD>
PairwiseSetup<FD>& setup();
@@ -42,4 +33,18 @@ public:
void check(Player& P) const;
};
class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine
{
public:
RealPairwiseMachine(int argc, const char** argv);
void init();
template <class FD>
void setup_keys();
template <class T>
void set_mac_key(T alphai);
};
#endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */

View File

@@ -9,6 +9,7 @@
#include "Math/Setup.h"
#include "FHEOffline/Proof.h"
#include "FHEOffline/PairwiseMachine.h"
#include "FHEOffline/TemiSetup.h"
#include "Tools/Commit.h"
#include "Tools/Bundle.h"
#include "Processor/OnlineOptions.h"
@@ -53,7 +54,7 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
template <class FD>
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
{
::secure_init(*this, P, machine, plaintext_length, sec);
::secure_init(*this, P, machine, plaintext_length, sec, params);
alpha = FieldD;
machine.sk = FHE_SK(params, FieldD.get_prime());
for (auto& pk : machine.other_pks)
@@ -62,16 +63,20 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
template <class T, class U>
void secure_init(T& setup, Player& P, U& machine,
int plaintext_length, int sec)
int plaintext_length, int sec, FHE_Params& params)
{
assert(sec >= 0);
machine.sec = sec;
sec = max(sec, 40);
machine.drown_sec = sec;
params.set_min_sec(sec);
string filename = PREP_DIR + T::name() + "-"
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
+ to_string(params.secp()) + "-"
+ to_string(params.get_matrix_dim()) + "-"
+ OnlineOptions::singleton.prime.get_str() + "-"
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
string reason;
try
{
ifstream file(filename);
@@ -79,13 +84,30 @@ void secure_init(T& setup, Player& P, U& machine,
os.input(file);
os.get(machine.extra_slack);
setup.unpack(os);
}
catch (exception& e)
{
reason = e.what();
}
try
{
setup.check(P, machine);
}
catch (...)
catch (exception& e)
{
cout << "Finding parameters for security " << sec << " and field size ~2^"
<< plaintext_length << endl;
setup.params = setup.params.n_mults();
reason = e.what();
}
if (not reason.empty())
{
if (OnlineOptions::singleton.verbose)
cerr << "Generating parameters for security " << sec
<< " and field size ~2^" << plaintext_length
<< " because no suitable material "
"from a previous run was found (" << reason << ")"
<< endl;
setup = {};
setup.generate(P, machine, plaintext_length, sec);
setup.check(P, machine);
octetStream os;
@@ -94,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine,
ofstream file(filename);
os.output(file);
}
if (OnlineOptions::singleton.verbose)
{
cerr << "Ciphertext length: " << params.p0().numBits();
for (size_t i = 1; i < params.FFTD().size(); i++)
cerr << "+" << params.FFTD()[i].get_prime().numBits();
cerr << endl;
}
}
template <class FD>
@@ -208,5 +238,8 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
template class PairwiseSetup<FFT_Data>;
template class PairwiseSetup<P2Data>;
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int);
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int);
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(TemiSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(TemiSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);

View File

@@ -15,7 +15,7 @@ class MachineBase;
template <class T, class U>
void secure_init(T& setup, Player& P, U& machine,
int plaintext_length, int sec);
int plaintext_length, int sec, FHE_Params& params);
template <class FD>
class PairwiseSetup

View File

@@ -577,7 +577,7 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
for (int j = min; j < max; j++)
{
AddableVector<Ciphertext> C;
vector<Plaintext_<FD>> m(EC.machine->sec, FieldD);
vector<Plaintext_<FD>> m(personal_EC.proof.U, FieldD);
if (j == P.my_num())
{
for (auto& x : m)

Some files were not shown because too many files have changed in this diff Show More