mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Merge branch 'refs/heads/master' into parallel_permutations
# Conflicts: # Processor/Processor.hpp # Protocols/FakeProtocol.h # Protocols/Rep3Shuffler.h # Protocols/Rep3Shuffler.hpp # Protocols/SecureShuffle.h # Protocols/SecureShuffle.hpp
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -49,6 +49,7 @@ callgrind.out.*
|
||||
Programs/Bytecode/*
|
||||
Programs/Schedules/*
|
||||
Programs/Public-Input/*
|
||||
Programs/Functions
|
||||
*.com
|
||||
*.class
|
||||
*.dll
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -13,3 +13,6 @@
|
||||
[submodule "deps/SimplestOT_C"]
|
||||
path = deps/SimplestOT_C
|
||||
url = https://github.com/mkskeller/SimplestOT_C
|
||||
[submodule "deps/sse2neon"]
|
||||
path = deps/sse2neon
|
||||
url = https://github.com/DLTcollab/sse2neon
|
||||
|
||||
@@ -15,7 +15,7 @@ int AndJob::run()
|
||||
#endif
|
||||
__m128i* prf_output = new __m128i[PAD_TO_8(ProgramParty::s().get_n_parties())];
|
||||
auto gate = gates.begin();
|
||||
vector< GC::Secret<EvalRegister> >& S = *this->S;
|
||||
auto& S = *this->S;
|
||||
const vector<int>& args = *this->args;
|
||||
int i_gate = 0;
|
||||
for (size_t i = start; i < end; i += 4)
|
||||
|
||||
@@ -15,7 +15,7 @@ using namespace std;
|
||||
|
||||
class AndJob
|
||||
{
|
||||
vector< GC::Secret<EvalRegister> >* S;
|
||||
StackedVector< GC::Secret<EvalRegister> >* S;
|
||||
const vector<int>* args;
|
||||
|
||||
public:
|
||||
@@ -25,7 +25,7 @@ public:
|
||||
|
||||
AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {}
|
||||
|
||||
void reset(vector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
|
||||
void reset(StackedVector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
|
||||
size_t start, gate_id_t gate_id, size_t n_gates, int n_parties)
|
||||
{
|
||||
this->S = &S;
|
||||
|
||||
26
CHANGELOG.md
26
CHANGELOG.md
@@ -1,5 +1,31 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.4.0 (November 21, 2024)
|
||||
|
||||
- Functionality to call high-level code from C++
|
||||
- Matrix triples from file for all appropriate protocols
|
||||
- Exit with message on errors instead of uncaught exceptions
|
||||
- Reduce memory usage for binary memory
|
||||
- Optimized cint-regint conversion in Dealer protocol
|
||||
- Fixed security bug: missing MAC check in probabilistic truncation
|
||||
|
||||
## 0.3.9 (July 9, 2024)
|
||||
|
||||
- Inference with non-sequential PyTorch networks
|
||||
- SHA-3 for any input length (@hiddely)
|
||||
- Improved client facilities
|
||||
- Shuffling with malicious security for SPDZ-wise protocols by [Asharov et al.](https://ia.cr/2022/1595)
|
||||
- More reusable bytecode via in-thread calling facility
|
||||
- Recursive functions without return values
|
||||
- Fewer rounds for parallel matrix multiplications (@vincent-ehrmanntraut)
|
||||
- Optimized usage of SoftSpokenOT in semi-honest protocols
|
||||
- More integrity checks on storage in MAC-based protocols
|
||||
- Use C++17
|
||||
- Use glibc 2.18 for the binaries
|
||||
- Fixed security bugs: remotely caused buffer overflows (#1382)
|
||||
- Fixed security bug: Missing randomization before revealing to client
|
||||
- Fixed security bug: Bias in Rep3 secure shuffling
|
||||
|
||||
## 0.3.8 (December 14, 2023)
|
||||
|
||||
- Functionality for multiple nodes per party
|
||||
|
||||
8
CONFIG
8
CONFIG
@@ -48,6 +48,8 @@ ARCH =
|
||||
AVX_OT = 0
|
||||
endif
|
||||
|
||||
AVX_SIMPLEOT := $(AVX_OT)
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
|
||||
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
|
||||
@@ -69,11 +71,13 @@ CXX = clang++
|
||||
# use CONFIG.mine to overwrite DIR settings
|
||||
-include CONFIG.mine
|
||||
|
||||
AVX_SIMPLEOT := $(AVX_OT)
|
||||
|
||||
ifeq ($(USE_GF2N_LONG),1)
|
||||
GF2N_LONG = -DUSE_GF2N_LONG
|
||||
endif
|
||||
|
||||
ifeq ($(AVX_OT), 0)
|
||||
ifeq ($(AVX_SIMPLEOT), 0)
|
||||
CFLAGS += -DNO_AVX_OT
|
||||
endif
|
||||
|
||||
@@ -106,7 +110,7 @@ else
|
||||
BOOST = -lboost_thread $(MY_BOOST)
|
||||
endif
|
||||
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++17 -Werror
|
||||
CFLAGS += $(BREW_CFLAGS)
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = $(CXX)
|
||||
|
||||
@@ -203,8 +203,10 @@ class andrsvec(base.VarArgsInstruction, base.Mergeable,
|
||||
def add_usage(self, req_node):
|
||||
for i, n in self.bases(iter(self.args)):
|
||||
size = self.args[i + 1]
|
||||
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
|
||||
req_node.increment(('bit', 'mixed'), size)
|
||||
n = (n - 3) // 2
|
||||
req_node.increment(('bit', 'triple'), size * n)
|
||||
if n > 1:
|
||||
req_node.increment(('bit', 'mixed'), size * ((n + 63) // 64))
|
||||
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs))
|
||||
@@ -538,8 +540,8 @@ class split(base.Instruction):
|
||||
|
||||
:param: number of arguments to follow (number of bits times number of additive shares plus one)
|
||||
:param: source (sint)
|
||||
:param: first share of least significant bit
|
||||
:param: second share of least significant bit
|
||||
:param: first share of least significant bit (sbit)
|
||||
:param: second share of least significant bit (sbit)
|
||||
:param: (remaining share of least significant bit)...
|
||||
:param: (repeat from first share for bit one step higher)...
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,7 @@ from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint,
|
||||
from Compiler.types import vectorized_classmethod
|
||||
from Compiler.program import Tape, Program
|
||||
from Compiler.exceptions import *
|
||||
from Compiler import util, oram, floatingpoint, library
|
||||
from Compiler import util, oram, floatingpoint, library, comparison
|
||||
from Compiler import instructions_base
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
@@ -21,6 +21,11 @@ import math
|
||||
import itertools
|
||||
from functools import reduce
|
||||
|
||||
class _binary:
|
||||
def reveal_to(self, *args, **kwargs):
|
||||
raise CompilerError(
|
||||
'%s does not support revealing to indivual players' % type(self))
|
||||
|
||||
class bits(Tape.Register, _structure, _bit):
|
||||
n = 40
|
||||
unit = 64
|
||||
@@ -149,6 +154,12 @@ class bits(Tape.Register, _structure, _bit):
|
||||
self.n = n
|
||||
def set_size(self, size):
|
||||
pass
|
||||
def load_int(self, value):
|
||||
n_limbs = math.ceil(self.n / self.unit)
|
||||
for i in range(n_limbs):
|
||||
self.conv_regint(min(self.unit, self.n - i * self.unit),
|
||||
self[i], regint(value % 2 ** self.unit))
|
||||
value >>= self.unit
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cint):
|
||||
assert(self.n == other.size)
|
||||
@@ -236,12 +247,14 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return res
|
||||
def if_else(self, x, y):
|
||||
"""
|
||||
Vectorized oblivious selection::
|
||||
Bit-wise oblivious selection::
|
||||
|
||||
sb32 = sbits.get_type(32)
|
||||
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
|
||||
|
||||
This will output 1.
|
||||
This will output 1 because it selects the two least
|
||||
significant bits from 5 and the rest of the bits from 2.
|
||||
|
||||
"""
|
||||
return result_conv(x, y)(self & (x ^ y) ^ y)
|
||||
def zero_if_not(self, condition):
|
||||
@@ -268,6 +281,9 @@ class bits(Tape.Register, _structure, _bit):
|
||||
self.bit_compose(source.bit_decompose()[base:base + size]))
|
||||
def vector_size(self):
|
||||
return self.n
|
||||
@staticmethod
|
||||
def size_for_mem():
|
||||
return 1
|
||||
|
||||
class cbits(bits):
|
||||
""" Clear bits register. Helper type with limited functionality. """
|
||||
@@ -302,13 +318,6 @@ class cbits(bits):
|
||||
else:
|
||||
return super(cbits, cls).conv(other)
|
||||
types = {}
|
||||
def load_int(self, value):
|
||||
n_limbs = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=n_limbs)
|
||||
for i in range(n_limbs):
|
||||
tmp[i].load_int(value % 2 ** self.unit)
|
||||
value >>= self.unit
|
||||
self.load_other(tmp)
|
||||
def store_in_dynamic_mem(self, address):
|
||||
inst.stmsdci(self, cbits.conv(address))
|
||||
def clear_op(self, other, c_inst, ci_inst, op):
|
||||
@@ -502,11 +511,7 @@ class sbits(bits):
|
||||
if self.n <= 32:
|
||||
inst.ldbits(self, self.n, value)
|
||||
else:
|
||||
size = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=size)
|
||||
for i in range(size):
|
||||
tmp[i].load_int((value >> (i * 64)) % 2**64)
|
||||
self.load_other(tmp)
|
||||
bits.load_int(self, value)
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cbits) and self.n == other.n:
|
||||
inst.convcbit2s(self.n, self, other)
|
||||
@@ -675,7 +680,7 @@ class sbits(bits):
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return sbitint.ripple_carry_adder(*args, **kwargs)
|
||||
|
||||
class sbitvec(_vec, _bit):
|
||||
class sbitvec(_vec, _bit, _binary):
|
||||
""" Vector of registers of secret bits, effectively a matrix of secret bits.
|
||||
This facilitates parallel arithmetic operations in binary circuits.
|
||||
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
|
||||
@@ -732,15 +737,20 @@ class sbitvec(_vec, _bit):
|
||||
:py:obj:`v` and the columns by calling :py:obj:`elements`.
|
||||
"""
|
||||
class sbitvecn(cls, _structure):
|
||||
n_bits = n
|
||||
@staticmethod
|
||||
def malloc(size, creator_tape=None):
|
||||
return sbit.malloc(size * n, creator_tape=creator_tape)
|
||||
def get_type(n):
|
||||
return cls.get_type(n)
|
||||
@classmethod
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
return sbit.malloc(
|
||||
size * cls.mem_size(), creator_tape=creator_tape)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
@staticmethod
|
||||
def mem_size():
|
||||
return n
|
||||
return sbits.get_type(n).mem_size()
|
||||
@classmethod
|
||||
def get_input_from(cls, player, size=1, f=0):
|
||||
""" Secret input from :py:obj:`player`. The input is decomposed
|
||||
@@ -748,17 +758,19 @@ class sbitvec(_vec, _bit):
|
||||
|
||||
:param: player (int)
|
||||
"""
|
||||
v = [0] * n
|
||||
sbits._check_input_player(player)
|
||||
instructions_base.check_vector_size(size)
|
||||
for i in range(size):
|
||||
vv = [sbit() for i in range(n)]
|
||||
inst.inputbvec(n + 3, f, player, *vv)
|
||||
for j in range(n):
|
||||
tmp = vv[j] << i
|
||||
v[j] = tmp ^ v[j]
|
||||
sbits._check_input_player(player)
|
||||
return cls.from_vec(v)
|
||||
if size == 1:
|
||||
res = cls.from_vec(sbit() for i in range(n))
|
||||
inst.inputbvec(n + 3, f, player, *res.v)
|
||||
return res
|
||||
else:
|
||||
elements = []
|
||||
for i in range(size):
|
||||
v = sbits.get_type(n)()
|
||||
inst.inputb(player, n, f, v)
|
||||
elements.append(v)
|
||||
return cls(elements)
|
||||
get_raw_input_from = get_input_from
|
||||
@classmethod
|
||||
def from_vec(cls, vector):
|
||||
@@ -780,38 +792,27 @@ class sbitvec(_vec, _bit):
|
||||
self.v = sbits.get_type(n)(other).bit_decompose()
|
||||
assert len(self.v) == n
|
||||
assert size is None or size == self.v[0].n
|
||||
@vectorized_classmethod
|
||||
def load_mem(cls, address):
|
||||
size = instructions_base.get_global_vector_size()
|
||||
if size not in (None, 1):
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
sb = sbits.get_type(size)
|
||||
return cls.from_vec(sb.bit_compose(
|
||||
sbit.load_mem(address + i + j * n) for j in range(size))
|
||||
for i in range(n))
|
||||
if not isinstance(address, int):
|
||||
v = [sbit.load_mem(x, size=n).v[0] for x in address]
|
||||
return cls(v)
|
||||
@classmethod
|
||||
def load_mem(cls, address, size=None):
|
||||
if isinstance(address, int) or len(address) == 1:
|
||||
address = [address + i * cls.mem_size()
|
||||
for i in range(size or 1)]
|
||||
else:
|
||||
return cls.from_vec(sbit.load_mem(address + i)
|
||||
for i in range(n))
|
||||
assert size == None
|
||||
return cls(
|
||||
[sbits.get_type(n).load_mem(x) for x in address])
|
||||
def store_in_mem(self, address):
|
||||
size = 1
|
||||
for x in self.v:
|
||||
if not util.is_constant(x):
|
||||
size = max(size, x.n)
|
||||
v = [sbits.get_type(size).conv(x) for x in self.v]
|
||||
if not isinstance(address, int) and len(address) != 1:
|
||||
v = self.elements()
|
||||
assert len(v) == len(address)
|
||||
for x, y in zip(v, address):
|
||||
for i, xx in enumerate(x.bit_decompose(n)):
|
||||
xx.store_in_mem(y + i)
|
||||
if isinstance(address, int) or len(address) == 1:
|
||||
address = [address + i * self.mem_size()
|
||||
for i in range(size)]
|
||||
else:
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
for i in range(n):
|
||||
for j, x in enumerate(v[i].bit_decompose()):
|
||||
x.store_in_mem(address + i + j * n)
|
||||
assert size == len(address)
|
||||
for x, dest in zip(self.elements(), address):
|
||||
x.store_in_mem(dest)
|
||||
@classmethod
|
||||
def two_power(cls, nn, size=1):
|
||||
return cls.from_vec(
|
||||
@@ -864,7 +865,7 @@ class sbitvec(_vec, _bit):
|
||||
assert isinstance(elements, sint)
|
||||
if Program.prog.use_split():
|
||||
x = elements.split_to_two_summands(length)
|
||||
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
|
||||
v = sbitint.bit_adder(x[0], x[1])
|
||||
else:
|
||||
prog = Program.prog
|
||||
if not prog.options.ring:
|
||||
@@ -877,6 +878,7 @@ class sbitvec(_vec, _bit):
|
||||
length, prog.security)
|
||||
prog.use_edabit(backup)
|
||||
return
|
||||
comparison.require_ring_size(length, 'A2B conversion')
|
||||
l = int(Program.prog.options.ring)
|
||||
r, r_bits = sint.get_edabit(length, size=elements.size)
|
||||
c = ((elements - r) << (l - length)).reveal()
|
||||
@@ -885,6 +887,8 @@ class sbitvec(_vec, _bit):
|
||||
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
|
||||
v = x.v
|
||||
self.v = v[:length]
|
||||
elif isinstance(elements, sbitvec):
|
||||
self.v = elements.v
|
||||
elif elements is not None and not (util.is_constant(elements) and \
|
||||
elements == 0):
|
||||
self.v = sbits.trans(elements)
|
||||
@@ -1347,13 +1351,19 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
def __add__(self, other):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
a, b = self.expand(other)
|
||||
try:
|
||||
a, b = self.expand(other)
|
||||
except:
|
||||
return NotImplemented
|
||||
v = sbitint.bit_adder(a, b)
|
||||
return self.get_type(len(v)).from_vec(v)
|
||||
__radd__ = __add__
|
||||
__sub__ = _bitint.__sub__
|
||||
def __rsub__(self, other):
|
||||
a, b = self.expand(other)
|
||||
try:
|
||||
a, b = self.expand(other)
|
||||
except:
|
||||
return NotImplemented
|
||||
return self.from_vec(b) - self.from_vec(a)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbits):
|
||||
@@ -1447,7 +1457,7 @@ class cbitfix(object):
|
||||
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
|
||||
cbits(0), cbits(0))
|
||||
|
||||
class sbitfix(_fix):
|
||||
class sbitfix(_fix, _binary):
|
||||
""" Secret signed fixed-point number in one binary register.
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
|
||||
@@ -1515,7 +1525,7 @@ class sbitfix(_fix):
|
||||
cls.set_precision(f, k)
|
||||
return cls._new(cls.int_type(other), k, f)
|
||||
|
||||
class sbitfixvec(_fix, _vec):
|
||||
class sbitfixvec(_fix, _vec, _binary):
|
||||
""" Vector of fixed-point numbers for parallel binary computation.
|
||||
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
|
||||
@@ -76,7 +76,7 @@ class AllocRange:
|
||||
self.top += size
|
||||
self.limit = max(self.limit, self.top)
|
||||
if res >= REG_MAX:
|
||||
raise RegisterOverflowError()
|
||||
raise RegisterOverflowError(size)
|
||||
return res
|
||||
|
||||
def free(self, base, size):
|
||||
@@ -178,6 +178,8 @@ class StraightlineAllocator:
|
||||
dup = dup.vectorbase
|
||||
self.alloc[dup] = self.alloc[base]
|
||||
dup.i = self.alloc[base]
|
||||
if not dup.dup_count:
|
||||
dup.dup_count = len(base.duplicates)
|
||||
|
||||
def dealloc_reg(self, reg, inst, free):
|
||||
if reg.vector:
|
||||
@@ -209,7 +211,8 @@ class StraightlineAllocator:
|
||||
for x in itertools.chain(dup.duplicates, base.duplicates):
|
||||
to_check.add(x)
|
||||
|
||||
if reg not in self.program.base_addresses:
|
||||
if reg not in self.program.base_addresses \
|
||||
and not isinstance(inst, call_arg):
|
||||
free.free(base)
|
||||
if inst.is_vec() and base.vector:
|
||||
self.defined[base] = inst
|
||||
@@ -274,8 +277,9 @@ class StraightlineAllocator:
|
||||
for reg in self.alloc:
|
||||
for x in reg.get_all():
|
||||
if x not in self.dealloc and reg not in self.dealloc \
|
||||
and len(x.duplicates) == 0:
|
||||
print('Warning: read before write at register', x)
|
||||
and len(x.duplicates) == x.dup_count:
|
||||
print('Warning: read before write at register %s/%x' %
|
||||
(x, id(x)))
|
||||
print('\tregister trace: %s' % format_trace(x.caller,
|
||||
'\t\t'))
|
||||
if options.stop:
|
||||
@@ -485,7 +489,11 @@ class Merger:
|
||||
if this[-1] < other[0]:
|
||||
del this[:]
|
||||
this.append(n)
|
||||
for inst in other:
|
||||
if last_access_this_kind == last_mem_write_of:
|
||||
insts = itertools.chain(other, this)
|
||||
else:
|
||||
insts = other
|
||||
for inst in insts:
|
||||
add_edge(inst, n)
|
||||
|
||||
def mem_access(n, instr, last_access_this_kind, last_access_other_kind):
|
||||
@@ -518,7 +526,11 @@ class Merger:
|
||||
last_other_kind[-1] > last_this_kind[-1]:
|
||||
last_this_kind[:] = []
|
||||
last_this_kind.append(n)
|
||||
for i in last_other_kind:
|
||||
if last_this_kind == last_mem_write:
|
||||
insts = itertools.chain(last_other_kind, last_this_kind)
|
||||
else:
|
||||
insts = last_other_kind
|
||||
for i in insts:
|
||||
add_edge(i, n)
|
||||
|
||||
def keep_order(instr, n, t, arg_index=None):
|
||||
@@ -608,7 +620,8 @@ class Merger:
|
||||
# so this threshold should lead to acceptable compile times even on slower processors.
|
||||
first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4]
|
||||
second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5]
|
||||
max_dependencies_per_matrix = 1500**2
|
||||
max_dependencies_per_matrix = \
|
||||
self.block.parent.program.budget
|
||||
if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix:
|
||||
if block.warn_about_mem and not block.parent.warned_about_mem:
|
||||
print('WARNING: Order of memory instructions not preserved due to long vector, errors possible')
|
||||
@@ -748,6 +761,8 @@ class Merger:
|
||||
G.remove_node(i)
|
||||
merge_nodes.discard(i)
|
||||
stats[type(instructions[i]).__name__] += 1
|
||||
for reg in instructions[i].get_def():
|
||||
self.block.parent.program.base_addresses.pop(reg)
|
||||
instructions[i] = None
|
||||
if unused_result:
|
||||
eliminate(i)
|
||||
@@ -779,10 +794,10 @@ class RegintOptimizer:
|
||||
self.rev_offset_cache = {}
|
||||
self.range_cache = util.dict_by_id()
|
||||
|
||||
def add_offset(self, res, new_base, new_offset):
|
||||
self.offset_cache[res] = new_base, new_offset
|
||||
if (new_base.i, new_offset) not in self.rev_offset_cache:
|
||||
self.rev_offset_cache[new_base.i, new_offset] = res
|
||||
def add_offset(self, res, new_base, new_offset, multiplier):
|
||||
self.offset_cache[res] = new_base, new_offset, multiplier
|
||||
if (new_base.i, new_offset, multiplier) not in self.rev_offset_cache:
|
||||
self.rev_offset_cache[new_base.i, new_offset, multiplier] = res
|
||||
|
||||
def run(self, instructions, program):
|
||||
for i, inst in enumerate(instructions):
|
||||
@@ -806,31 +821,40 @@ class RegintOptimizer:
|
||||
def f(base, delta_reg):
|
||||
delta = self.cache[delta_reg]
|
||||
if base in self.offset_cache:
|
||||
reg, offset = self.offset_cache[base]
|
||||
reg, offset, mult = self.offset_cache[base]
|
||||
new_base, new_offset = reg, offset + delta
|
||||
else:
|
||||
new_base, new_offset = base, delta
|
||||
self.add_offset(inst.args[0], new_base, new_offset)
|
||||
mult = 1
|
||||
self.add_offset(inst.args[0], new_base, new_offset,
|
||||
mult)
|
||||
if inst.args[1] in self.cache:
|
||||
f(inst.args[2], inst.args[1])
|
||||
elif inst.args[2] in self.cache:
|
||||
f(inst.args[1], inst.args[2])
|
||||
elif isinstance(inst, subint_class) and \
|
||||
inst.args[2] in self.cache:
|
||||
delta = self.cache[inst.args[2]]
|
||||
if inst.args[1] in self.offset_cache:
|
||||
reg, offset = self.offset_cache[inst.args[1]]
|
||||
new_base, new_offset = reg, offset - delta
|
||||
else:
|
||||
new_base, new_offset = inst.args[1], -delta
|
||||
self.add_offset(inst.args[0], new_base, new_offset)
|
||||
elif isinstance(inst, subint_class):
|
||||
def f(reg, cached, reverse):
|
||||
delta = self.cache[cached]
|
||||
if reg in self.offset_cache:
|
||||
reg, offset, mult = self.offset_cache[reg]
|
||||
new_base, new_offset = reg, offset - delta
|
||||
else:
|
||||
new_base = reg
|
||||
new_offset = -delta if reverse else delta
|
||||
mult = 1
|
||||
self.add_offset(inst.args[0], new_base, new_offset,
|
||||
mult if reverse else -mult)
|
||||
if inst.args[1] in self.cache:
|
||||
f(inst.args[2], inst.args[1], False)
|
||||
elif inst.args[2] in self.cache:
|
||||
f(inst.args[1], inst.args[2], True)
|
||||
elif isinstance(inst, IndirectMemoryInstruction):
|
||||
if inst.args[1] in self.cache:
|
||||
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
|
||||
instructions[i]._protect = inst._protect
|
||||
elif inst.args[1] in self.offset_cache:
|
||||
base, offset = self.offset_cache[inst.args[1]]
|
||||
addr = self.rev_offset_cache[base.i, offset]
|
||||
base, offset, mult = self.offset_cache[inst.args[1]]
|
||||
addr = self.rev_offset_cache[base.i, offset, mult]
|
||||
inst.args[1] = addr
|
||||
elif inst.args[1] in self.range_cache:
|
||||
size, base = self.range_cache[inst.args[1]]
|
||||
|
||||
@@ -5,7 +5,7 @@ the ones used below into ``Programs/Circuits`` as follows::
|
||||
|
||||
make Programs/Circuits
|
||||
|
||||
.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC
|
||||
.. _`Bristol Fashion`: https://nigelsmart.github.io/MPC-Circuits
|
||||
|
||||
"""
|
||||
import math
|
||||
@@ -15,6 +15,7 @@ from Compiler.library import function_block, get_tape
|
||||
from Compiler import util
|
||||
import itertools
|
||||
import struct
|
||||
import os
|
||||
|
||||
class Circuit:
|
||||
"""
|
||||
@@ -47,7 +48,12 @@ class Circuit:
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.filename = 'Programs/Circuits/%s.txt' % name
|
||||
if not os.path.exists(self.filename):
|
||||
if os.system('make Programs/Circuits'):
|
||||
raise CompilerError('Cannot download circuit descriptions. '
|
||||
'Make sure make and git are installed.')
|
||||
f = open(self.filename)
|
||||
self.functions = {}
|
||||
|
||||
@@ -57,8 +63,9 @@ class Circuit:
|
||||
def run(self, *inputs):
|
||||
n = inputs[0][0].n, get_tape()
|
||||
if n not in self.functions:
|
||||
self.functions[n] = function_block(lambda *args:
|
||||
self.compile(*args))
|
||||
self.functions[n] = function_block(
|
||||
lambda *args: self.compile(*args))
|
||||
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
|
||||
flat_res = self.functions[n](*itertools.chain(*inputs))
|
||||
res = []
|
||||
i = 0
|
||||
@@ -124,7 +131,7 @@ Keccak_f = None
|
||||
|
||||
def sha3_256(x):
|
||||
"""
|
||||
This function implements SHA3-256 for inputs of up to 1080 bits::
|
||||
This function implements SHA3-256 for inputs of any length::
|
||||
|
||||
from circuit import sha3_256
|
||||
a = sbitvec.from_vec([])
|
||||
@@ -138,7 +145,8 @@ def sha3_256(x):
|
||||
for x in a, b, c, d, e, f, g, h:
|
||||
sha3_256(x).reveal_print_hex()
|
||||
|
||||
This should output the `test vectors
|
||||
This should output the hashes of the above inputs, beginning with
|
||||
the `test vectors
|
||||
<https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
|
||||
of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the
|
||||
0 byte::
|
||||
|
||||
@@ -76,13 +76,13 @@ def require_ring_size(k, op):
|
||||
program.curr_tape.require_bit_length(k)
|
||||
|
||||
@instructions_base.cisc
|
||||
def LTZ(s, a, k, kappa):
|
||||
def LTZ(s, a, k):
|
||||
"""
|
||||
s = (a ?< 0)
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
movs(s, program.non_linear.ltz(a, k, kappa))
|
||||
movs(s, program.non_linear.ltz(a, k))
|
||||
|
||||
def LtzRing(a, k):
|
||||
from .types import sint, _bitint
|
||||
@@ -105,14 +105,14 @@ def LtzRing(a, k):
|
||||
u = CarryOutRaw(a[::-1], b[::-1])
|
||||
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
|
||||
|
||||
def LessThanZero(a, k, kappa):
|
||||
def LessThanZero(a, k):
|
||||
from . import types
|
||||
res = types.sint()
|
||||
LTZ(res, a, k, kappa)
|
||||
LTZ(res, a, k)
|
||||
return res
|
||||
|
||||
@instructions_base.cisc
|
||||
def Trunc(d, a, k, m, kappa, signed):
|
||||
def Trunc(d, a, k, m, signed):
|
||||
"""
|
||||
d = a >> m
|
||||
|
||||
@@ -124,7 +124,7 @@ def Trunc(d, a, k, m, kappa, signed):
|
||||
movs(d, a)
|
||||
return
|
||||
else:
|
||||
movs(d, program.non_linear.trunc(a, k, m, kappa, signed))
|
||||
movs(d, program.non_linear.trunc(a, k, m, signed=signed))
|
||||
|
||||
def TruncRing(d, a, k, m, signed):
|
||||
program.curr_tape.require_bit_length(1)
|
||||
@@ -197,13 +197,13 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
|
||||
masked = shifted >> n_shift
|
||||
u = sint()
|
||||
BitLTL(u, masked, r_bits[:n_bits], 0)
|
||||
BitLTL(u, masked, r_bits[:n_bits])
|
||||
res = (u << n_bits) + masked - r
|
||||
if signed:
|
||||
res -= (1 << (n_bits - 1))
|
||||
return res
|
||||
|
||||
def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
def TruncRoundNearest(a, k, m, signed=False):
|
||||
"""
|
||||
Returns a / 2^m, rounded to the nearest integer.
|
||||
|
||||
@@ -212,12 +212,10 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
"""
|
||||
if m == 0:
|
||||
return a
|
||||
nl = program.non_linear
|
||||
nl.check_security(kappa)
|
||||
return program.non_linear.trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
@instructions_base.cisc
|
||||
def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
def Mod2m(a_prime, a, k, m, signed):
|
||||
"""
|
||||
a_prime = a % 2^m
|
||||
|
||||
@@ -225,8 +223,6 @@ def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
m: compile-time integer
|
||||
signed: True/False, describes a
|
||||
"""
|
||||
nl = program.non_linear
|
||||
nl.check_security(kappa)
|
||||
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
|
||||
|
||||
def Mod2mRing(a_prime, a, k, m, signed):
|
||||
@@ -237,13 +233,13 @@ def Mod2mRing(a_prime, a, k, m, signed):
|
||||
tmp = a + r_prime
|
||||
c_prime = (tmp << shift).reveal(False) >> shift
|
||||
u = sint()
|
||||
BitLTL(u, c_prime, r_bin[:m], 0)
|
||||
BitLTL(u, c_prime, r_bin[:m])
|
||||
res = (u << m) + c_prime - r_prime
|
||||
if a_prime is not None:
|
||||
movs(a_prime, res)
|
||||
return res
|
||||
|
||||
def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
def Mod2mField(a_prime, a, k, m, signed):
|
||||
from .types import sint
|
||||
r_dprime = program.curr_block.new_reg('s')
|
||||
r_prime = program.curr_block.new_reg('s')
|
||||
@@ -255,7 +251,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2m = program.curr_block.new_reg('c')
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
PRandM(r_dprime, r_prime, r, k, m)
|
||||
ld2i(c2m, m)
|
||||
mulm(t[0], r_dprime, c2m)
|
||||
if signed:
|
||||
@@ -268,9 +264,9 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
asm_open(True, c, t[3])
|
||||
modc(c_prime, c, c2m)
|
||||
if const_rounds:
|
||||
BitLTC1(u, c_prime, r, kappa)
|
||||
BitLTC1(u, c_prime, r)
|
||||
else:
|
||||
BitLTL(u, c_prime, r, kappa)
|
||||
BitLTL(u, c_prime, r)
|
||||
mulm(t[4], u, c2m)
|
||||
submr(t[5], c_prime, r_prime)
|
||||
adds(a_prime, t[5], t[4])
|
||||
@@ -288,13 +284,15 @@ def MaskingBitsInRing(m, strict=False):
|
||||
r_bin = r
|
||||
return sint.bit_compose(r), r_bin
|
||||
|
||||
def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
|
||||
def PRandM(r_dprime, r_prime, b, k, m, use_dabit=True):
|
||||
"""
|
||||
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
|
||||
r_prime = random secret integer in range [0, 2^m - 1]
|
||||
b = array containing bits of r_prime
|
||||
"""
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
assert k >= m
|
||||
kappa = program.security
|
||||
program.curr_tape.require_bit_length(k + kappa, reason='statistical masking as in https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation')
|
||||
from .types import sint
|
||||
if program.use_edabit() and not const_rounds:
|
||||
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
|
||||
@@ -329,7 +327,7 @@ def PRandInt(r, k):
|
||||
bit(t[1][i])
|
||||
adds(t[2][i], t[0][i], t[1][i])
|
||||
|
||||
def BitLTC1(u, a, b, kappa):
|
||||
def BitLTC1(u, a, b):
|
||||
"""
|
||||
u = a <? b
|
||||
|
||||
@@ -395,7 +393,7 @@ def BitLTC1(u, a, b, kappa):
|
||||
subcfi(c[3][i], a_bits[i], 1)
|
||||
mulm(t[3][i], s[i], c[3][i])
|
||||
adds(t[4][i], t[4][i-1], t[3][i])
|
||||
Mod2(u, t[4][k-1], k, kappa, False)
|
||||
Mod2(u, t[4][k-1], k, False)
|
||||
return p, a_bits, d, s, t, c, b, pre_input
|
||||
|
||||
def carry(b, a, compute_p=True):
|
||||
@@ -414,7 +412,7 @@ def carry(b, a, compute_p=True):
|
||||
|
||||
# from WP9 report
|
||||
# length of a is even
|
||||
def CarryOutAux(a, kappa):
|
||||
def CarryOutAux(a):
|
||||
k = len(a)
|
||||
if k > 1 and k % 2 == 1:
|
||||
a.append(None)
|
||||
@@ -424,12 +422,12 @@ def CarryOutAux(a, kappa):
|
||||
if k > 1:
|
||||
for i in range(k//2):
|
||||
u[i] = carry(a[2*i+1], a[2*i], i != k//2-1)
|
||||
return CarryOutAux(u[:k//2][::-1], kappa)
|
||||
return CarryOutAux(u[:k//2][::-1])
|
||||
else:
|
||||
return a[0][1]
|
||||
|
||||
# carry out with carry-in bit c
|
||||
def CarryOut(res, a, b, c=0, kappa=None):
|
||||
def CarryOut(res, a, b, c=0):
|
||||
"""
|
||||
res = last carry bit in addition of a and b
|
||||
|
||||
@@ -456,7 +454,7 @@ def CarryOutRaw(a, b, c=0):
|
||||
s[0] = d[-1][0].bit_and(c)
|
||||
s[1] = d[-1][1] + s[0]
|
||||
d[-1][1] = s[1]
|
||||
return CarryOutAux(d[::-1], None)
|
||||
return CarryOutAux(d[::-1])
|
||||
|
||||
def CarryOutRawLE(a, b, c=0):
|
||||
""" Little-endian version """
|
||||
@@ -469,7 +467,7 @@ def CarryOutLE(a, b, c=0):
|
||||
CarryOut(res, a[::-1], b[::-1], c)
|
||||
return res
|
||||
|
||||
def BitLTL(res, a, b, kappa):
|
||||
def BitLTL(res, a, b):
|
||||
"""
|
||||
res = a <? b (logarithmic rounds version)
|
||||
|
||||
@@ -624,7 +622,7 @@ def KMulC(a):
|
||||
PreMulC_without_inverses(p, a)
|
||||
return p
|
||||
|
||||
def Mod2(a_0, a, k, kappa, signed):
|
||||
def Mod2(a_0, a, k, signed):
|
||||
"""
|
||||
a_0 = a % 2
|
||||
|
||||
@@ -641,7 +639,7 @@ def Mod2(a_0, a, k, kappa, signed):
|
||||
tc = program.curr_block.new_reg('c')
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
|
||||
PRandM(r_dprime, r_prime, [r_0], k, 1)
|
||||
r_0 = r_prime
|
||||
mulsi(t[0], r_dprime, 2)
|
||||
if signed:
|
||||
|
||||
@@ -13,8 +13,18 @@ from .program import Program, defaults
|
||||
|
||||
|
||||
class Compiler:
|
||||
singleton = None
|
||||
|
||||
def __init__(self, custom_args=None, usage=None, execute=False,
|
||||
split_args=False):
|
||||
if Compiler.singleton:
|
||||
raise CompilerError(
|
||||
"Cannot have more than one compiler instance. "
|
||||
"It's not possible to run direct compilation programs with "
|
||||
"compile.py or compile-run.py.")
|
||||
else:
|
||||
Compiler.singleton = self
|
||||
|
||||
if usage:
|
||||
self.usage = usage
|
||||
else:
|
||||
@@ -165,7 +175,8 @@ class Compiler:
|
||||
dest="prime",
|
||||
default=defaults.prime,
|
||||
help="use bit decomposition with a specifed prime modulus "
|
||||
"for non-linear computation (default: use the masking approach)",
|
||||
"for non-linear computation (default: use the masking approach). "
|
||||
"Don't use this unless you're certain that you need it.",
|
||||
)
|
||||
parser.add_option(
|
||||
"-I",
|
||||
@@ -252,6 +263,13 @@ class Compiler:
|
||||
dest="hostfile",
|
||||
help="hosts to execute with",
|
||||
)
|
||||
parser.add_option(
|
||||
"-t",
|
||||
"--tidy_output",
|
||||
action="store_true",
|
||||
dest="tidy_output",
|
||||
help="make output prints tidy and grouped by party (note: delays the prints)",
|
||||
)
|
||||
else:
|
||||
parser.add_option(
|
||||
"-E",
|
||||
@@ -263,6 +281,14 @@ class Compiler:
|
||||
|
||||
def parse_args(self):
|
||||
self.options, self.args = self.parser.parse_args(self.custom_args)
|
||||
if self.options.verbose:
|
||||
self.runtime_args += ["--verbose"]
|
||||
if self.options.execute:
|
||||
self.options.execute = re.sub("-party.x$", "", self.options.execute)
|
||||
for s, l in self.match.items():
|
||||
if self.options.execute == l:
|
||||
self.options.execute = s
|
||||
break
|
||||
if self.execute:
|
||||
if not self.options.execute:
|
||||
if len(self.args) > 1:
|
||||
@@ -313,6 +339,8 @@ class Compiler:
|
||||
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
|
||||
if self.options.execute in ("rep4-ring",):
|
||||
self.prog.use_split(4)
|
||||
if self.options.execute.find("dealer") >= 0:
|
||||
self.prog.use_edabit(True)
|
||||
|
||||
def build_vars(self):
|
||||
from . import comparison, floatingpoint, instructions, library, types
|
||||
@@ -368,7 +396,14 @@ class Compiler:
|
||||
"cfloat",
|
||||
"squant",
|
||||
]:
|
||||
del self.VARS[i]
|
||||
class dummy:
|
||||
def __init__(self, *args):
|
||||
raise CompilerError(self.error)
|
||||
dummy.error = i + " not availabe with binary circuits"
|
||||
if i in ("cint", "cfix"):
|
||||
dummy.error += ". See https://mp-spdz.readthedocs.io/en/" \
|
||||
"latest/Compiler.html#Compiler.types." + i
|
||||
self.VARS[i] = dummy
|
||||
else:
|
||||
self.sint = types.sint
|
||||
self.sfix = types.sfix
|
||||
@@ -418,6 +453,8 @@ class Compiler:
|
||||
continue
|
||||
m = re.match(r"(\s*)if(\W.*):", line)
|
||||
if m:
|
||||
while if_stack and if_stack[-1][0] == m.group(1):
|
||||
if_stack.pop()
|
||||
if_stack.append((m.group(1), len(output)))
|
||||
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
|
||||
output.append("%sdef _():\n" % (m.group(1)))
|
||||
@@ -503,13 +540,15 @@ class Compiler:
|
||||
|
||||
return self.prog
|
||||
|
||||
@staticmethod
|
||||
def executable_from_protocol(protocol):
|
||||
match = {
|
||||
"ring": "replicated-ring",
|
||||
"rep-field": "replicated-field",
|
||||
"replicated": "replicated-bin"
|
||||
}
|
||||
match = {
|
||||
"ring": "replicated-ring",
|
||||
"rep-field": "replicated-field",
|
||||
"replicated": "replicated-bin"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def executable_from_protocol(cls, protocol):
|
||||
match = cls.match
|
||||
if protocol in match:
|
||||
protocol = match[protocol]
|
||||
if protocol.find("bmr") == -1:
|
||||
@@ -588,11 +627,25 @@ class Compiler:
|
||||
for filename in glob.glob("Player-Data/*.0"):
|
||||
connection.put(filename, dest + "Player-Data")
|
||||
|
||||
def run_with_error(i):
|
||||
try:
|
||||
run(i)
|
||||
except IOError:
|
||||
print('IO error when copying files, does %s have enough space?' %
|
||||
hostnames[i])
|
||||
raise
|
||||
|
||||
import threading
|
||||
import random
|
||||
import io
|
||||
|
||||
def run_and_capture_outputs(outputs, fn, i):
|
||||
out = fn(i)
|
||||
outputs[i] = out
|
||||
|
||||
threads = []
|
||||
for i in range(len(hosts)):
|
||||
threads.append(threading.Thread(target=run, args=(i,)))
|
||||
threads.append(threading.Thread(target=run_with_error, args=(i,)))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
@@ -600,6 +653,14 @@ class Compiler:
|
||||
|
||||
# execution
|
||||
threads = []
|
||||
|
||||
# tidy up output prints
|
||||
hide_option = False
|
||||
if self.options.tidy_output:
|
||||
outputs = []
|
||||
for i in range(len(connections)):
|
||||
outputs += [""]
|
||||
hide_option = True
|
||||
# random port numbers to avoid conflict
|
||||
port = 10000 + random.randrange(40000)
|
||||
if '@' in hostnames[0]:
|
||||
@@ -614,9 +675,15 @@ class Compiler:
|
||||
run = lambda i: connections[i].run(
|
||||
"cd %s; ./%s -p %d %s -h %s -pn %d %s" % \
|
||||
(destinations[i], vm, i, self.prog.name, party0, port,
|
||||
' '.join(args + N)))
|
||||
threads.append(threading.Thread(target=run, args=(i,)))
|
||||
' '.join(args + N)), hide=hide_option)
|
||||
if self.options.tidy_output:
|
||||
threads.append(threading.Thread(target=run_and_capture_outputs, args=(outputs, run, i,)))
|
||||
else:
|
||||
threads.append(threading.Thread(target=run, args=(i,)))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
if self.options.tidy_output:
|
||||
for out in outputs:
|
||||
print(out)
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections import defaultdict
|
||||
|
||||
REG_MAX = 2 ** 32
|
||||
USER_MEM = 8192
|
||||
MEM_MAX = 2 ** 64
|
||||
|
||||
P_VALUES = { 32: 2147565569, \
|
||||
64: 9223372036855103489, \
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
""" This module implements `Dijkstra's algorithm
|
||||
<https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm>`_ based on
|
||||
oblivious RAM. """
|
||||
|
||||
|
||||
from Compiler.oram import *
|
||||
|
||||
from Compiler.program import Program
|
||||
@@ -222,7 +227,21 @@ class HeapQ(object):
|
||||
print_ln()
|
||||
print_ln()
|
||||
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None,
|
||||
debug=False):
|
||||
""" Securely compute Dijstra's algorithm on a secret graph. See
|
||||
:download:`../Programs/Source/dijkstra_example.mpc` for an
|
||||
explanation of the required inputs.
|
||||
|
||||
:param source: source node (secret or clear-text integer)
|
||||
:param edges: ORAM representation of edges
|
||||
:param e_index: ORAM representation of vertices
|
||||
:param oram_type: ORAM type to use internally (default:
|
||||
:py:func:`~Compiler.oram.OptimalORAM`)
|
||||
:param n_loops: when to stop (default: number of edges)
|
||||
:param int_type: secret integer type (default: sint)
|
||||
|
||||
"""
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else -1
|
||||
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
|
||||
@@ -246,10 +265,12 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
|
||||
last_edge = MemValue(basic_type(1))
|
||||
i_edge = MemValue(int_type(0))
|
||||
u = MemValue(basic_type(0))
|
||||
running = MemValue(basic_type(1))
|
||||
@for_range(n_loops or edges.size)
|
||||
def f(i):
|
||||
print_ln('loop %s', i)
|
||||
time()
|
||||
running.write(last_edge.bit_not().bit_or(Q.size > 0).bit_and(running))
|
||||
u.write(if_else(last_edge, Q.pop(last_edge), u))
|
||||
#visited.access(u, True, last_edge)
|
||||
i_edge.write(int_type(if_else(last_edge, e_index[u], i_edge)))
|
||||
@@ -261,30 +282,50 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
|
||||
dv, not_visited = dist.read(v)
|
||||
# relying on default dv negative here
|
||||
is_shorter = (alt < int_type(dv[0])) + not_visited
|
||||
is_shorter *= running
|
||||
dist.access(v, (basic_type(alt), u), is_shorter)
|
||||
#previous.access(v, u, is_shorter)
|
||||
Q.update(v, basic_type(alt), is_shorter)
|
||||
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s', \
|
||||
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), \
|
||||
not_visited.reveal())
|
||||
if debug:
|
||||
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s, '
|
||||
'shorter: %s, running: %s, queue size: %s, last edge: %s',
|
||||
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(),
|
||||
not_visited.reveal(), is_shorter.reveal(),
|
||||
running.reveal(), Q.size.reveal(), last_edge.reveal())
|
||||
return dist
|
||||
|
||||
def convert_graph(G):
|
||||
edges = [None] * (2 * G.size())
|
||||
e_index = [None] * (len(G))
|
||||
i = 0
|
||||
for v in G:
|
||||
e_index[v] = i
|
||||
for u in G[v]:
|
||||
edges[i] = [u, G[v][u]['weight'], 0]
|
||||
i += 1
|
||||
edges[i-1][-1] = 1
|
||||
return edges, e_index
|
||||
|
||||
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
|
||||
""" Convert a `NetworkX directed graph
|
||||
<https://networkx.org/documentation/stable/reference/classes/digraph.html>`_
|
||||
to the cleartext representation of what :py:func:`dijkstra` expects. """
|
||||
G = G.copy()
|
||||
for u in G:
|
||||
for v in G[u]:
|
||||
G[u][v].setdefault('weight', 1)
|
||||
edges = [None] * (2 * G.size())
|
||||
e_index = [None] * (len(G))
|
||||
i = 0
|
||||
for v in sorted(G):
|
||||
e_index[v] = i
|
||||
for u in sorted(G[v]):
|
||||
edges[i] = [u, G[v][u]['weight'], 0]
|
||||
i += 1
|
||||
if not G[v]:
|
||||
edges[i] = [v, 0, 0]
|
||||
i += 1
|
||||
edges[i-1][-1] = 1
|
||||
return list(filter(lambda x: x, edges)), e_index
|
||||
|
||||
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None,
|
||||
int_type=sint):
|
||||
""" Securely compute Dijstra's algorithm on a cleartext graph.
|
||||
|
||||
:param G: directed graph with NetworkX interface
|
||||
:param source: source node (secret or clear-text integer)
|
||||
:param n_loops: when to stop (default: number of edges)
|
||||
:param int_type: secret integer type (default: sint)
|
||||
|
||||
"""
|
||||
edges_list, e_index_list = convert_graph(G)
|
||||
edges = oram_type(len(edges_list), \
|
||||
entry_size=(log2(len(G)), log2(len(G)), 1), \
|
||||
@@ -558,7 +599,7 @@ def test_stupid_dijkstra_on_cycle(n, n_loops=None):
|
||||
@for_range(n)
|
||||
def f(i):
|
||||
M[i][(i+1)%n] = ExtInt(1)
|
||||
M[i][(i-1)%n] = ExtInt(1)
|
||||
M[i][(i-1+n)%n] = ExtInt(1)
|
||||
if n_loops is not None:
|
||||
stop_timer(1)
|
||||
start_timer()
|
||||
|
||||
@@ -39,26 +39,25 @@ def maskRing(a, k):
|
||||
c = ((a + r_prime) << shift).reveal(False) >> shift
|
||||
return c, r
|
||||
|
||||
def maskField(a, k, kappa):
|
||||
def maskField(a, k):
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
r = [types.sint() for i in range(k)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, k)
|
||||
# always signed due to usage in equality testing
|
||||
a += two_power(k)
|
||||
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
|
||||
return c, r
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def EQZ(a, k, kappa):
|
||||
def EQZ(a, k):
|
||||
prog = program.Program.prog
|
||||
if prog.use_split():
|
||||
from GC.types import sbitvec
|
||||
from Compiler.GC.types import sbitvec
|
||||
v = sbitvec(a, k).v
|
||||
bit = util.tree_reduce(operator.and_, (~b for b in v))
|
||||
return types.sintbit.conv(bit)
|
||||
prog.non_linear.check_security(kappa)
|
||||
return prog.non_linear.eqz(a, k)
|
||||
|
||||
def bits(a,m):
|
||||
@@ -99,12 +98,12 @@ def or_op(a, b, void=None):
|
||||
def mul_op(a, b, void=None):
|
||||
return a * b
|
||||
|
||||
def PreORC(a, kappa=None, m=None, raw=False):
|
||||
def PreORC(a, m=None, raw=False):
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return [a[0]]
|
||||
prog = program.Program.prog
|
||||
kappa = kappa or prog.security
|
||||
kappa = prog.security
|
||||
m = m or k
|
||||
if isinstance(a[0], types.sgf2n):
|
||||
max_k = program.Program.prog.galois_length - 1
|
||||
@@ -128,13 +127,13 @@ def PreORC(a, kappa=None, m=None, raw=False):
|
||||
t = [types.sint() for i in range(m)]
|
||||
b = comparison.PreMulC([a[i] + 1 for i in range(k)])
|
||||
for i in range(m):
|
||||
comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
|
||||
comparison.Mod2(t[i], b[k-1-i], k, False)
|
||||
p[m-1-i] = 1 - t[i]
|
||||
return p
|
||||
else:
|
||||
# not constant-round anymore
|
||||
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
|
||||
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
|
||||
s = [PreORC(a[i:i+max_k], raw=raw) for i in range(0,k,max_k)]
|
||||
t = PreORC([si[-1] for si in s[:-1]], raw=raw)
|
||||
return sum(([or_op(x, y) for x in si]
|
||||
for si,y in zip(s[1:],t)), s[0])[-m:]
|
||||
|
||||
@@ -175,6 +174,41 @@ def PreOpL2(op, items):
|
||||
output[2 * i] = op(v[i - 1], items[2 * i])
|
||||
return output
|
||||
|
||||
def PreOpL2_vec(op, *items):
|
||||
""" Vectorized version of :py:func:`PreOpL2` """
|
||||
k = len(items[0])
|
||||
for x in items:
|
||||
assert len(x) == k
|
||||
if k == 1:
|
||||
return items
|
||||
half = k // 2
|
||||
other_half = (k + 1) // 2 - 1
|
||||
u = op([x.get_vector(base=0, size=half, skip=2) for x in items],
|
||||
[x.get_vector(base=1, size=half, skip=2) for x in items])
|
||||
assert len(u) == len(items)
|
||||
assert len(u[0]) == half
|
||||
v = PreOpL2_vec(op, *u)
|
||||
if other_half:
|
||||
w = op([x.get_vector(base=0, size=other_half) for x in v],
|
||||
[x.get_vector(base=2, size=other_half, skip=2) for x in items])
|
||||
if half == other_half:
|
||||
res = [type(x).zip(x, y) for x, y in zip(v, w)]
|
||||
for i in range(len(res)):
|
||||
res[i] = type(res[i]).concat((items[i].get_vector(base=0, size=1),
|
||||
res[i]))
|
||||
else:
|
||||
if other_half:
|
||||
for i in range(len(w)):
|
||||
w[i] = type(w[i]).concat((items[i].get_vector(base=0, size=1),
|
||||
w[i]))
|
||||
else:
|
||||
w = [x.get_vector(base=0, size=1) for x in items]
|
||||
res = [type(x).zip(x, y) for x, y in zip(w, v)]
|
||||
assert len(res) == len(items)
|
||||
for x in res:
|
||||
assert len(x) == k
|
||||
return res
|
||||
|
||||
def PreOpN(op, items):
|
||||
""" Naive PreOp algorithm """
|
||||
k = len(items)
|
||||
@@ -184,9 +218,9 @@ def PreOpN(op, items):
|
||||
output[i] = op(output[i-1], items[i])
|
||||
return output
|
||||
|
||||
def PreOR(a, kappa=None, raw=False):
|
||||
def PreOR(a=None, raw=False):
|
||||
if comparison.const_rounds:
|
||||
return PreORC(a, kappa, raw=raw)
|
||||
return PreORC(a, raw=raw)
|
||||
else:
|
||||
return PreOpL(or_op, a)
|
||||
|
||||
@@ -199,24 +233,24 @@ def KOpL(op, a):
|
||||
t2 = KOpL(op, a[k//2:])
|
||||
return op(t1, t2)
|
||||
|
||||
def KORL(a, kappa=None):
|
||||
def KORL(a):
|
||||
""" log rounds k-ary OR """
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return a[0]
|
||||
else:
|
||||
t1 = KORL(a[:k//2], kappa)
|
||||
t2 = KORL(a[k//2:], kappa)
|
||||
t1 = KORL(a[:k//2])
|
||||
t2 = KORL(a[k//2:])
|
||||
return t1 + t2 - t1.bit_and(t2)
|
||||
|
||||
def KORC(a, kappa):
|
||||
return PreORC(a, kappa, 1)[0]
|
||||
def KORC(a):
|
||||
return PreORC(a, 1)[0]
|
||||
|
||||
def KOR(a, kappa):
|
||||
def KOR(a):
|
||||
if comparison.const_rounds:
|
||||
return KORC(a, kappa)
|
||||
return KORC(a)
|
||||
else:
|
||||
return KORL(a, None)
|
||||
return KORL(a)
|
||||
|
||||
def KMul(a):
|
||||
if comparison.const_rounds:
|
||||
@@ -262,7 +296,7 @@ def BitAdd(a, b, bits_to_compute=None):
|
||||
s[k] = c[k-1]
|
||||
return s
|
||||
|
||||
def BitDec(a, k, m, kappa, bits_to_compute=None):
|
||||
def BitDec(a, k, m, bits_to_compute=None):
|
||||
return program.Program.prog.non_linear.bit_dec(a, k, m)
|
||||
|
||||
def BitDecRingRaw(a, k, m):
|
||||
@@ -270,7 +304,7 @@ def BitDecRingRaw(a, k, m):
|
||||
n_shift = int(program.Program.prog.options.ring) - m
|
||||
if program.Program.prog.use_split():
|
||||
x = a.split_to_two_summands(m)
|
||||
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
|
||||
bits = types._bitint.bit_adder(x[0], x[1])
|
||||
return bits[:m]
|
||||
else:
|
||||
if program.Program.prog.use_edabit():
|
||||
@@ -292,13 +326,14 @@ def BitDecRing(a, k, m):
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
def BitDecFieldRaw(a, k, m, bits_to_compute=None):
|
||||
instructions_base.set_global_vector_size(a.size)
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
r = [types.sint() for i in range(m)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, m)
|
||||
kappa = program.Program.prog.security
|
||||
pow2 = two_power(k + kappa)
|
||||
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
|
||||
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
|
||||
@@ -306,16 +341,16 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
return res
|
||||
|
||||
@instructions_base.bit_cisc
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
|
||||
def BitDecField(a, k, m, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, bits_to_compute)
|
||||
return [types.sintbit.conv(bit) for bit in res]
|
||||
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def Pow2(a, l, kappa):
|
||||
def Pow2(a, l):
|
||||
comparison.program.curr_tape.require_bit_length(l - 1)
|
||||
m = int(ceil(log(l, 2)))
|
||||
t = BitDec(a, m, m, kappa)
|
||||
t = BitDec(a, m, m)
|
||||
return Pow2_from_bits(t)
|
||||
|
||||
def Pow2_from_bits(bits):
|
||||
@@ -327,11 +362,12 @@ def Pow2_from_bits(bits):
|
||||
t[i] = t[i]*pow2k[i] + 1 - t[i]
|
||||
return KMul(t)
|
||||
|
||||
def B2U(a, l, kappa):
|
||||
pow2a = Pow2(a, l, kappa)
|
||||
return B2U_from_Pow2(pow2a, l, kappa), pow2a
|
||||
def B2U(a, l):
|
||||
pow2a = Pow2(a, l)
|
||||
return B2U_from_Pow2(pow2a, l), pow2a
|
||||
|
||||
def B2U_from_Pow2(pow2a, l, kappa):
|
||||
def B2U_from_Pow2(pow2a, l):
|
||||
kappa = program.Program.prog.security
|
||||
r = [types.sint() for i in range(l)]
|
||||
t = types.sint()
|
||||
c = types.cint()
|
||||
@@ -353,17 +389,17 @@ def B2U_from_Pow2(pow2a, l, kappa):
|
||||
c = list(r_bits[0].bit_decompose_clear(c, l))
|
||||
x = [r_bits[i].bit_xor(c[i]) for i in range(l)]
|
||||
#print ' '.join(str(b.value) for b in x)
|
||||
y = PreOR(x, kappa)
|
||||
y = PreOR(x)
|
||||
#print ' '.join(str(b.value) for b in y)
|
||||
return [types.sint.conv(1 - y[i]) for i in range(l)]
|
||||
|
||||
def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
def Trunc(a, l, m, compute_modulo=False, signed=False):
|
||||
""" Oblivious truncation by secret m """
|
||||
prog = program.Program.prog
|
||||
if util.is_constant(m) and not compute_modulo:
|
||||
# cheaper
|
||||
res = type(a)(size=a.size)
|
||||
comparison.Trunc(res, a, l, m, kappa, signed=signed)
|
||||
comparison.Trunc(res, a, l, m, signed=signed)
|
||||
return res
|
||||
if l == 1:
|
||||
if compute_modulo:
|
||||
@@ -371,9 +407,9 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
else:
|
||||
return a * (1 - m)
|
||||
if program.Program.prog.options.ring and not compute_modulo:
|
||||
return TruncInRing(a, l, Pow2(m, l, kappa))
|
||||
return TruncInRing(a, l, Pow2(m, l))
|
||||
else:
|
||||
kappa = kappa or program.Program.prog.security
|
||||
kappa = program.Program.prog.security
|
||||
r = [types.sint() for i in range(l)]
|
||||
r_dprime = types.sint(0)
|
||||
r_prime = types.sint(0)
|
||||
@@ -381,7 +417,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
c = types.cint()
|
||||
ci = [types.cint() for i in range(l)]
|
||||
d = types.sint()
|
||||
x, pow2m = B2U(m, l, kappa)
|
||||
x, pow2m = B2U(m, l)
|
||||
for i in range(l):
|
||||
bit(r[i])
|
||||
t1 = two_power(i) * r[i]
|
||||
@@ -398,7 +434,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
for i in range(1,l):
|
||||
ci[i] = c % two_power(i)
|
||||
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
|
||||
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
|
||||
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l)
|
||||
if compute_modulo:
|
||||
b = c_dprime - r_prime + pow2m * d
|
||||
return b, pow2m
|
||||
@@ -429,33 +465,33 @@ def TruncInRing(to_shift, l, pow2m):
|
||||
def SplitInRing(a, l, m):
|
||||
if l == 1:
|
||||
return m.if_else(a, 0), m.if_else(0, a), 1
|
||||
pow2m = Pow2(m, l, None)
|
||||
pow2m = Pow2(m, l)
|
||||
upper = TruncInRing(a, l, pow2m)
|
||||
lower = a - upper * pow2m
|
||||
return lower, upper, pow2m
|
||||
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1)
|
||||
s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False)
|
||||
return s, overflow
|
||||
|
||||
def Int2FL(a, gamma, l, kappa=None):
|
||||
def Int2FL(a, gamma, l):
|
||||
lam = gamma - 1
|
||||
s = a.less_than(0, gamma, security=kappa)
|
||||
z = a.equal(0, gamma, security=kappa)
|
||||
s = a.less_than(0, gamma)
|
||||
z = a.equal(0, gamma)
|
||||
a = s.if_else(-a, a)
|
||||
a_bits = a.bit_decompose(lam, security=kappa)
|
||||
a_bits = a.bit_decompose(lam)
|
||||
a_bits.reverse()
|
||||
b = PreOR(a_bits, kappa)
|
||||
b = PreOR(a_bits)
|
||||
t = a * (1 + a.bit_compose(1 - b_i for b_i in b))
|
||||
p = a.popcnt_bits(b) - lam
|
||||
if gamma - 1 > l:
|
||||
if types.sfloat.round_nearest:
|
||||
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
|
||||
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l)
|
||||
p = p + overflow
|
||||
else:
|
||||
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
|
||||
v = t.right_shift(gamma - l - 1, gamma - 1, signed=False)
|
||||
else:
|
||||
v = 2**(l-gamma+1) * t
|
||||
p = (p + gamma - 1 - l) * z.bit_not()
|
||||
@@ -466,37 +502,38 @@ def FLRound(x, mode):
|
||||
*mode*: 0 -> floor, 1 -> ceil, -1 > trunc """
|
||||
v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
|
||||
a = types.sint()
|
||||
comparison.LTZ(a, p1, k, x.kappa)
|
||||
b = p1.less_than(-l + 1, k, x.kappa)
|
||||
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
|
||||
c = EQZ(v2, l, x.kappa)
|
||||
comparison.LTZ(a, p1, k)
|
||||
b = p1.less_than(-l + 1, k)
|
||||
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, compute_modulo=True)
|
||||
c = EQZ(v2, l)
|
||||
if mode == -1:
|
||||
away_from_zero = 0
|
||||
mode = x.s
|
||||
else:
|
||||
away_from_zero = mode + s1 - 2 * mode * s1
|
||||
v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
|
||||
d = v.equal(two_power(l), l + 1, x.kappa)
|
||||
d = v.equal(two_power(l), l + 1)
|
||||
v = d * two_power(l-1) + (1 - d) * v
|
||||
v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
|
||||
s = (1 - b * mode) * s1
|
||||
z = or_op(EQZ(v, l, x.kappa), z1)
|
||||
z = or_op(EQZ(v, l), z1)
|
||||
v = v * (1 - z)
|
||||
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
|
||||
return v, p, z, s
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def TruncPr(a, k, m, kappa=None, signed=True):
|
||||
def TruncPr(a, k, m, signed=True):
|
||||
""" Probabilistic truncation [a/2^m + u]
|
||||
where Pr[u = 1] = (a % 2^m) / 2^m
|
||||
"""
|
||||
nl = program.Program.prog.non_linear
|
||||
nl.check_security(kappa)
|
||||
return nl.trunc_pr(a, k, m, signed)
|
||||
|
||||
def TruncPrRing(a, k, m, signed=True):
|
||||
if m == 0:
|
||||
return a
|
||||
prog = program.Program.prog
|
||||
prog.trunc_pr_warning()
|
||||
n_ring = int(program.Program.prog.options.ring)
|
||||
comparison.require_ring_size(k, 'truncation')
|
||||
if k == n_ring:
|
||||
@@ -517,7 +554,6 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
trunc_pr(res, a, k, m)
|
||||
else:
|
||||
# extra bit to mask overflow
|
||||
prog = program.Program.prog
|
||||
prog.curr_tape.require_bit_length(1)
|
||||
if prog.use_edabit() or prog.use_split() > 2:
|
||||
lower = sint.get_random_int(m)
|
||||
@@ -540,68 +576,67 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
res -= (1 << (k - m - 1))
|
||||
return res
|
||||
|
||||
def TruncPrField(a, k, m, kappa=None):
|
||||
def TruncPrField(a, k, m):
|
||||
if m == 0:
|
||||
return a
|
||||
if kappa is None:
|
||||
kappa = 40
|
||||
|
||||
program.Program.prog.trunc_pr_warning()
|
||||
b = two_power(k-1) + a
|
||||
r_prime, r_dprime = types.sint(), types.sint()
|
||||
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
|
||||
k, m, kappa, use_dabit=False)
|
||||
k, m, use_dabit=False)
|
||||
two_to_m = two_power(m)
|
||||
r = two_to_m * r_dprime + r_prime
|
||||
c = (b + r).reveal(False)
|
||||
c = (b + r).reveal(True)
|
||||
c_prime = c % two_to_m
|
||||
a_prime = c_prime - r_prime
|
||||
d = (a - a_prime).field_div(two_to_m)
|
||||
return d
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
def SDiv(a, b, l, round_nearest=False):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
w = types.cint(int(2.9142 * 2 ** l)) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, signed=True)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
for i in range(theta-1):
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest,
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l,
|
||||
nearest=round_nearest,
|
||||
signed=False)
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
|
||||
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
|
||||
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, nearest=round_nearest,
|
||||
signed=False)
|
||||
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
|
||||
x = x1 * x1 + x.round(2 * l + 1, l - 1, nearest=round_nearest,
|
||||
signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
|
||||
comparison.Mod2m(x2, x, 2 * l, l, signed=False)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest, signed=False)
|
||||
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest,
|
||||
signed=False)
|
||||
y = y.round(2 * l + 1, l + 1, nearest=round_nearest)
|
||||
return y
|
||||
|
||||
def SDiv_mono(a, b, l, kappa):
|
||||
def SDiv_mono(a, b, l):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
|
||||
y = TruncPr(y, 2 * l + 1, l + 1)
|
||||
for i in range(theta-1):
|
||||
y = y * (alpha + x)
|
||||
# keep y with l bits
|
||||
y = TruncPr(y, 3 * l, 2 * l, kappa)
|
||||
y = TruncPr(y, 3 * l, 2 * l)
|
||||
x = x**2
|
||||
# keep x with 2l bits
|
||||
x = TruncPr(x, 4 * l, 2 * l, kappa)
|
||||
x = TruncPr(x, 4 * l, 2 * l)
|
||||
y = y * (alpha + x)
|
||||
y = TruncPr(y, 3 * l, 2 * l, kappa)
|
||||
y = TruncPr(y, 3 * l, 2 * l)
|
||||
return y
|
||||
|
||||
# LT bit comparison on shared bit values
|
||||
|
||||
@@ -399,7 +399,7 @@ class stop(base.Instruction):
|
||||
arg_format = ['i']
|
||||
|
||||
class use(base.Instruction):
|
||||
""" Offline data usage. Necessary to avoid reusage while using
|
||||
r""" Offline data usage. Necessary to avoid reusage while using
|
||||
preprocessing from files. Also used to multithreading for expensive
|
||||
preprocessing.
|
||||
|
||||
@@ -419,7 +419,7 @@ class use(base.Instruction):
|
||||
args[2].i}
|
||||
|
||||
class use_inp(base.Instruction):
|
||||
""" Input usage. Necessary to avoid reusage while using
|
||||
r""" Input usage. Necessary to avoid reusage while using
|
||||
preprocessing from files.
|
||||
|
||||
:param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
|
||||
@@ -488,6 +488,70 @@ class join_tape(base.Instruction):
|
||||
code = base.opcodes['JOIN_TAPE']
|
||||
arg_format = ['int']
|
||||
|
||||
class call_tape(base.DoNotEliminateInstruction):
|
||||
""" Start tape/bytecode file in same thread. Arguments/return values
|
||||
starting from :py:obj:`direction` are optional.
|
||||
|
||||
:param: tape number (int)
|
||||
:param: arg (regint)
|
||||
:param: direction (0 for argument, 1 for return value)
|
||||
:param: register type (see :py:obj:`vm_types`)
|
||||
:param: register size (int)
|
||||
:param: destination register
|
||||
:param: source register
|
||||
:param: (repeat from direction)
|
||||
|
||||
"""
|
||||
code = base.opcodes['CALL_TAPE']
|
||||
arg_format = tools.chain(['int', 'ci'],
|
||||
tools.cycle(['int','int','int','*w','*']))
|
||||
|
||||
@staticmethod
|
||||
def type_check(reg, type_id):
|
||||
assert base.vm_types[reg.reg_type] == type_id
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(call_tape, self).__init__(*args, **kwargs)
|
||||
for i in range(2, len(args), 5):
|
||||
for reg in args[i + 3:i + 5]:
|
||||
self.type_check(reg, args[i + 1])
|
||||
assert reg.size == args[i + 2]
|
||||
assert args[i] in (0, 1)
|
||||
assert args[i + 4 - args[i]].program == program.curr_tape
|
||||
assert args[i + 3 + args[i]].program == program.tapes[args[0]]
|
||||
|
||||
def get_def(self):
|
||||
# hide registers from called tape
|
||||
for i in range(2, len(self.args), 5):
|
||||
if self.args[i]:
|
||||
yield self.args[i + 3]
|
||||
|
||||
def get_used(self):
|
||||
# hide registers from called tape
|
||||
yield self.args[1]
|
||||
for i in range(2, len(self.args), 5):
|
||||
if not self.args[i]:
|
||||
yield self.args[i + 4]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.num += program.tapes[self.args[0]].req_tree.aggregate()
|
||||
|
||||
class call_arg(base.DoNotEliminateInstruction, base.VectorInstruction):
|
||||
""" Pseudo instruction for arguments in connection with
|
||||
:py:class:`call_tape`.
|
||||
|
||||
:param: destination (register)
|
||||
:param: register type (see :py:obj:`vm_types`)
|
||||
|
||||
"""
|
||||
code = base.opcodes['CALL_ARG']
|
||||
arg_format = ['*w','int']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(call_arg, self).__init__(*args, **kwargs)
|
||||
for i in range(0, len(args), 2):
|
||||
call_tape.type_check(args[i], args[i + 1])
|
||||
|
||||
class crash(base.IOInstruction):
|
||||
""" Crash runtime if the value in the register is not zero.
|
||||
|
||||
@@ -687,6 +751,27 @@ class concats(base.VectorInstruction):
|
||||
for i in range(1, len(args), 2):
|
||||
assert args[i] == len(args[i + 1])
|
||||
|
||||
class zips(base.Instruction):
|
||||
""" Zip vectors.
|
||||
|
||||
:param: result (sint)
|
||||
:param: operand (sint)
|
||||
:param: operand (sint)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['ZIPS']
|
||||
arg_format = ['sw','s','s']
|
||||
is_vec = lambda self: True
|
||||
|
||||
def __init__(self, *args):
|
||||
super(zips, self).__init__(*args)
|
||||
assert len(args[0]) == len(args[1]) + len(args[2])
|
||||
assert len(args[1]) == len(args[2])
|
||||
|
||||
def get_code(self):
|
||||
return super(zips, self).get_code(len(self.args[1]))
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class mulc(base.MulBase):
|
||||
@@ -1653,7 +1738,7 @@ class print_reg_plains(base.IOInstruction):
|
||||
arg_format = ['s']
|
||||
|
||||
class cond_print_plain(base.IOInstruction):
|
||||
""" Conditionally output clear register (with precision).
|
||||
r""" Conditionally output clear register (with precision).
|
||||
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
|
||||
|
||||
:param: condition (cint, no output if zero)
|
||||
@@ -1904,7 +1989,7 @@ class closeclientconnection(base.IOInstruction):
|
||||
code = base.opcodes['CLOSECLIENTCONNECTION']
|
||||
arg_format = ['ci']
|
||||
|
||||
class writesharestofile(base.IOInstruction):
|
||||
class writesharestofile(base.VectorInstruction, base.IOInstruction):
|
||||
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
|
||||
(appending at the end).
|
||||
|
||||
@@ -1917,11 +2002,12 @@ class writesharestofile(base.IOInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['WRITEFILESHARE']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('s'))
|
||||
vector_index = 1
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class readsharesfromfile(base.IOInstruction):
|
||||
class readsharesfromfile(base.VectorInstruction, base.IOInstruction):
|
||||
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
|
||||
|
||||
:param: number of arguments to follow / number of shares plus two (int)
|
||||
@@ -1933,6 +2019,7 @@ class readsharesfromfile(base.IOInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['READFILESHARE']
|
||||
arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))
|
||||
vector_index = 2
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -2256,7 +2343,7 @@ class convint(base.Instruction):
|
||||
|
||||
@base.vectorize
|
||||
class convmodp(base.Instruction):
|
||||
""" Convert clear integer register (vector) to clear register
|
||||
r""" Convert clear integer register (vector) to clear register
|
||||
(vector). If the bit length is zero, the unsigned conversion is
|
||||
used, otherwise signed conversion is used. This makes a difference
|
||||
when computing modulo a prime :math:`p`. Signed conversion of
|
||||
@@ -2734,13 +2821,11 @@ class check(base.Instruction):
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class sqrs(base.CISC):
|
||||
""" Secret squaring $s_i = s_j \cdot s_j$. """
|
||||
r""" Secret squaring $s_i = s_j \cdot s_j$. """
|
||||
__slots__ = []
|
||||
arg_format = ['sw', 's']
|
||||
|
||||
def expand(self):
|
||||
if program.options.ring:
|
||||
return muls(self.args[0], self.args[1], self.args[1])
|
||||
s = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c = [program.curr_block.new_reg('c') for i in range(2)]
|
||||
square(s[0], s[1])
|
||||
|
||||
@@ -68,6 +68,8 @@ opcodes = dict(
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
CMDLINEARG = 0xEB,
|
||||
CALL_TAPE = 0xEC,
|
||||
CALL_ARG = 0xED,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -85,6 +87,7 @@ opcodes = dict(
|
||||
PREFIXSUMS = 0x2D,
|
||||
PICKS = 0x2E,
|
||||
CONCATS = 0x2F,
|
||||
ZIPS = 0x3F,
|
||||
# Multiplication/division
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
@@ -223,6 +226,17 @@ opcodes = dict(
|
||||
)
|
||||
|
||||
|
||||
vm_types = dict(
|
||||
ci = 0,
|
||||
sb = 1,
|
||||
cb = 2,
|
||||
s = 4,
|
||||
c = 5,
|
||||
sg = 6,
|
||||
cg = 7,
|
||||
)
|
||||
|
||||
|
||||
def int_to_bytes(x):
|
||||
""" 32 bit int to big-endian 4 byte conversion. """
|
||||
assert(x < 2**32 and x >= -2**32)
|
||||
@@ -491,7 +505,8 @@ def cisc(function, n_outputs=1):
|
||||
reset_global_vector_size()
|
||||
program.curr_tape = old_tape
|
||||
for x, bl in tape.req_bit_length.items():
|
||||
old_tape.require_bit_length(bl - 1, x)
|
||||
old_tape.require_bit_length(
|
||||
bl - 1, x, tape.bit_length_reason if x == 'p' else '')
|
||||
from Compiler.allocator import Merger
|
||||
merger = Merger(block, program.options,
|
||||
tuple(program.to_merge))
|
||||
@@ -516,40 +531,48 @@ def cisc(function, n_outputs=1):
|
||||
inst.copy(size, subs)
|
||||
reset_global_vector_size()
|
||||
|
||||
def expand_to_function(self, size, new_regs):
|
||||
key = size, program.curr_tape, \
|
||||
tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \
|
||||
tuple(type(reg) for reg in new_regs)
|
||||
if key not in self.functions:
|
||||
from Compiler import library, types
|
||||
class Arg:
|
||||
def __init__(self, reg):
|
||||
from Compiler.GC.types import bits
|
||||
class Arg:
|
||||
def __init__(self, reg):
|
||||
self.type = type(reg)
|
||||
self.binary = isinstance(reg, bits)
|
||||
self.reg = reg
|
||||
# if reg is not None:
|
||||
# program.base_addresses[reg] = None
|
||||
def new(self):
|
||||
if self.binary:
|
||||
return self.type()
|
||||
else:
|
||||
return self.type(size=size)
|
||||
def load(self):
|
||||
return self.reg
|
||||
def store(self, reg):
|
||||
if self.type != type(None):
|
||||
self.reg.update(reg)
|
||||
args = [Arg(x) for x in new_regs]
|
||||
self.type = type(reg)
|
||||
self.binary = isinstance(reg, bits)
|
||||
self.reg = reg
|
||||
def new(self, size):
|
||||
if self.binary:
|
||||
return self.type()
|
||||
else:
|
||||
return self.type(size=size)
|
||||
def load(self):
|
||||
return self.reg
|
||||
def store(self, reg):
|
||||
if self.type != type(None):
|
||||
self.reg.update(reg)
|
||||
def is_real(self):
|
||||
return self.reg is not None
|
||||
|
||||
def base_key(self, size, new_regs):
|
||||
return size, tuple(
|
||||
arg for arg, reg in zip(self.args, new_regs) if reg is None), \
|
||||
tuple(type(reg) for reg in new_regs)
|
||||
|
||||
@staticmethod
|
||||
def get_name(key):
|
||||
return '_'.join(['%s(%d)' % (function.__name__, key[0])] +
|
||||
[str(x) for x in key[1]])
|
||||
|
||||
def expand_to_function(self, size, new_regs):
|
||||
key = self.base_key(size, new_regs) + (program.curr_tape,)
|
||||
if key not in self.functions:
|
||||
args = [self.Arg(x) for x in new_regs]
|
||||
from Compiler import library, types
|
||||
@library.function_block
|
||||
def f():
|
||||
res = [arg.new() for arg in args[:n_outputs]]
|
||||
self.new_instructions(size,
|
||||
res + [arg.load() for arg in args[n_outputs:]])
|
||||
res = [arg.new(size) for arg in args[:n_outputs]]
|
||||
self.new_instructions(
|
||||
size, res + [arg.load() for arg in args[n_outputs:]])
|
||||
for reg, arg in zip(res, args):
|
||||
arg.store(reg)
|
||||
f.name = '_'.join(['%s(%d)' % (function.__name__, size)] +
|
||||
[str(x) for x in key[2]])
|
||||
f.name = self.get_name(key)
|
||||
self.functions[key] = f, args
|
||||
f, args = self.functions[key]
|
||||
for i in range(len(new_regs) - n_outputs):
|
||||
@@ -558,6 +581,31 @@ def cisc(function, n_outputs=1):
|
||||
for i in range(n_outputs):
|
||||
new_regs[i].link(args[i].load())
|
||||
|
||||
def expand_to_tape(self, size, new_regs):
|
||||
key = self.base_key(size, new_regs)
|
||||
args = [self.Arg(x) for x in new_regs]
|
||||
if key not in self.functions:
|
||||
from Compiler import library, types
|
||||
@library.function_call_tape
|
||||
def f(*in_args):
|
||||
res = [arg.new(size) for arg in args[:n_outputs]]
|
||||
in_args = list(in_args)
|
||||
my_args = list(res)
|
||||
for arg in args[n_outputs:]:
|
||||
if arg.is_real():
|
||||
my_args.append(in_args.pop(0))
|
||||
else:
|
||||
my_args.append(arg.reg)
|
||||
self.new_instructions(size, my_args)
|
||||
return res
|
||||
f.name = self.get_name(key)
|
||||
self.functions[key] = f
|
||||
f = self.functions[key]
|
||||
in_args = filter(lambda arg: arg.is_real(), args[n_outputs:])
|
||||
res = util.tuplify(f(*(arg.load() for arg in in_args)))
|
||||
for i in range(n_outputs):
|
||||
new_regs[i].link(res[i])
|
||||
|
||||
def expand_merged(self, skip):
|
||||
if function.__name__ in skip:
|
||||
good = True
|
||||
@@ -595,7 +643,11 @@ def cisc(function, n_outputs=1):
|
||||
raise
|
||||
if program.cisc_to_function and \
|
||||
(program.curr_tape.singular or program.n_running_threads):
|
||||
self.expand_to_function(size, new_regs)
|
||||
if (program.options.garbled or program.options.binary or \
|
||||
not program.use_tape_calls) and not program.force_cisc_tape:
|
||||
self.expand_to_function(size, new_regs)
|
||||
else:
|
||||
self.expand_to_tape(size, new_regs)
|
||||
else:
|
||||
self.new_instructions(size, new_regs)
|
||||
program.curr_block.n_rounds += self.n_rounds - 1
|
||||
@@ -795,6 +847,12 @@ class ClearIntAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearInt
|
||||
name = 'regint'
|
||||
|
||||
class AnyRegAF(RegisterArgFormat):
|
||||
reg_type = '*'
|
||||
@staticmethod
|
||||
def check(arg):
|
||||
assert isinstance(arg, program.curr_tape.Register)
|
||||
|
||||
class IntArgFormat(ArgFormat):
|
||||
n_bits = 32
|
||||
|
||||
@@ -898,6 +956,8 @@ ArgFormats = {
|
||||
'sgw': SecretGF2NAF,
|
||||
'ci': ClearIntAF,
|
||||
'ciw': ClearIntAF,
|
||||
'*': AnyRegAF,
|
||||
'*w': AnyRegAF,
|
||||
'i': ImmediateModpAF,
|
||||
'ig': ImmediateGF2NAF,
|
||||
'int': IntArgFormat,
|
||||
@@ -938,6 +998,7 @@ class Instruction(object):
|
||||
Instruction.count += 1
|
||||
if Instruction.count % 100000 == 0:
|
||||
print("Compiled %d lines at" % self.__class__.count, time.asctime())
|
||||
sys.stdout.flush()
|
||||
if Instruction.count > 10 ** 7:
|
||||
print("Compilation produced more that 10 million instructions. "
|
||||
"Consider using './compile.py -l' or replacing for loops "
|
||||
@@ -1139,9 +1200,11 @@ class VarArgsInstruction(Instruction):
|
||||
class VectorInstruction(Instruction):
|
||||
__slots__ = []
|
||||
is_vec = lambda self: True
|
||||
vector_index = 0
|
||||
|
||||
def get_code(self):
|
||||
return super(VectorInstruction, self).get_code(len(self.args[0]))
|
||||
return super(VectorInstruction, self).get_code(
|
||||
len(self.args[self.vector_index]))
|
||||
|
||||
class Ciscable(Instruction):
|
||||
def copy(self, size, subs):
|
||||
|
||||
@@ -107,7 +107,7 @@ def print_str(s, *args, print_secrets=False):
|
||||
elif isinstance(val, cfloat):
|
||||
val.print_float_plain()
|
||||
elif isinstance(val, (list, tuple, Array, SubMultiArray)):
|
||||
print_str(*_expand_to_print(val))
|
||||
print_str(*_expand_to_print(val), print_secrets=print_secrets)
|
||||
else:
|
||||
try:
|
||||
val.output()
|
||||
@@ -314,7 +314,7 @@ def get_cmdline_arg(idx):
|
||||
return localint(res)
|
||||
|
||||
def make_array(l, t=None):
|
||||
if isinstance(l, Tape.Register):
|
||||
if isinstance(l, types._structure):
|
||||
res = Array(len(l), t or type(l))
|
||||
res[:] = l
|
||||
else:
|
||||
@@ -334,13 +334,12 @@ class FunctionTapeCall:
|
||||
return self
|
||||
def join(self):
|
||||
self.thread.join()
|
||||
instructions.program.free(self.base, 'ci')
|
||||
for reg_type,addr in self.bases.items():
|
||||
get_program().free(addr, reg_type.reg_type)
|
||||
if self.base is not None:
|
||||
instructions.program.free(self.base, 'ci')
|
||||
|
||||
class Function:
|
||||
def __init__(self, function, name=None, compile_args=[]):
|
||||
self.type_args = {}
|
||||
self.last_key = None
|
||||
self.function = function
|
||||
self.name = name
|
||||
if name is None:
|
||||
@@ -348,46 +347,40 @@ class Function:
|
||||
self.compile_args = compile_args
|
||||
def __call__(self, *args):
|
||||
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
|
||||
from .types import _types
|
||||
get_reg_type = lambda x: \
|
||||
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
|
||||
key = len(args), get_tape()
|
||||
if key not in self.type_args:
|
||||
runtime_args = []
|
||||
reg_args = []
|
||||
key = self.base_key(),
|
||||
for i,arg in enumerate(args):
|
||||
if isinstance(arg, types._vectorizable):
|
||||
key += (arg.shape, arg.value_type)
|
||||
else:
|
||||
arg = MemValue(arg)
|
||||
reg_args.append(arg)
|
||||
t = arg.value_type
|
||||
key += (arg.size, t)
|
||||
runtime_args.append(arg)
|
||||
if key != self.last_key:
|
||||
# first call
|
||||
type_args = collections.defaultdict(list)
|
||||
for i,arg in enumerate(args):
|
||||
if not isinstance(arg, types._vectorizable):
|
||||
type_args[get_reg_type(arg)].append(i)
|
||||
outer_runtime_args = runtime_args
|
||||
def wrapped_function(*compile_args):
|
||||
base = get_arg()
|
||||
bases = dict((t, regint.load_mem(base + i)) \
|
||||
for i,t in enumerate(sorted(type_args,
|
||||
key=lambda x:
|
||||
x.reg_type)))
|
||||
runtime_args = list(args)
|
||||
for t in sorted(type_args, key=lambda x: x.reg_type):
|
||||
i = 0
|
||||
for i_arg in type_args[t]:
|
||||
runtime_args[i_arg] = t.load_mem(bases[t] + i)
|
||||
i += util.mem_size(t)
|
||||
return self.function(*(list(compile_args) + runtime_args))
|
||||
addresses = regint.Array(len(outer_runtime_args),
|
||||
address=get_arg())
|
||||
runtime_args = []
|
||||
for i, arg in enumerate(outer_runtime_args):
|
||||
if isinstance(arg, MemValue):
|
||||
arg = arg.value_type.load_mem(
|
||||
address=addresses[i], size=arg.size)
|
||||
runtime_args.append(arg)
|
||||
self.result = self.function(
|
||||
*(list(compile_args) + runtime_args))
|
||||
return self.result
|
||||
self.on_first_call(wrapped_function)
|
||||
self.type_args[key] = type_args
|
||||
type_args = self.type_args[key]
|
||||
base = instructions.program.malloc(len(type_args), 'ci')
|
||||
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
|
||||
for t in type_args)
|
||||
for i,reg_type in enumerate(sorted(type_args,
|
||||
key=lambda x: x.reg_type)):
|
||||
store_in_mem(bases[reg_type], base + i)
|
||||
j = 0
|
||||
for i_arg in type_args[reg_type]:
|
||||
if get_reg_type(args[i_arg]) != reg_type:
|
||||
raise CompilerError('type mismatch: "%s" not of type "%s"' %
|
||||
(args[i_arg], reg_type))
|
||||
store_in_mem(args[i_arg], bases[reg_type] + j)
|
||||
j += util.mem_size(reg_type)
|
||||
return self.on_call(base, bases)
|
||||
self.last_key = key
|
||||
addresses = regint.Array(len(runtime_args))
|
||||
for i, arg in enumerate(reg_args):
|
||||
addresses[i] = arg.address
|
||||
return self.on_call(addresses._address,
|
||||
[(arg.value_type, arg.address) for arg in reg_args])
|
||||
|
||||
class FunctionTape(Function):
|
||||
# not thread-safe
|
||||
@@ -401,6 +394,171 @@ class FunctionTape(Function):
|
||||
single_thread=self.single_thread)
|
||||
def on_call(self, base, bases):
|
||||
return FunctionTapeCall(self.thread, base, bases)
|
||||
@staticmethod
|
||||
def base_key():
|
||||
pass
|
||||
|
||||
class FunctionCallTape(FunctionTape):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FunctionTape, self).__init__(*args, **kwargs)
|
||||
self.instances = {}
|
||||
@staticmethod
|
||||
def get_key(args, kwargs):
|
||||
key = (get_program(),)
|
||||
def process_for_key(arg):
|
||||
nonlocal key
|
||||
if isinstance(arg, types._vectorizable):
|
||||
key += (arg.value_type, tuple(arg.shape))
|
||||
elif isinstance(arg, Tape.Register):
|
||||
key += (type(arg), arg.size)
|
||||
elif isinstance(arg, list):
|
||||
key += (tuple(arg), 'l')
|
||||
else:
|
||||
key += (arg,)
|
||||
for arg in args:
|
||||
process_for_key(arg)
|
||||
for name, arg in sorted(kwargs.items()):
|
||||
key += (name, 'kw')
|
||||
process_for_key(arg)
|
||||
return key
|
||||
def __call__(self, *args, **kwargs):
|
||||
key = self.get_key(args, kwargs)
|
||||
if key not in self.instances:
|
||||
my_args = []
|
||||
def wrapped_function():
|
||||
actual_call_args = []
|
||||
def process_for_call(arg):
|
||||
if isinstance(arg, Tape.Register):
|
||||
my_arg = arg.same_type()
|
||||
call_arg(my_arg, base.vm_types[my_arg.reg_type])
|
||||
my_args.append(my_arg)
|
||||
return my_arg
|
||||
elif isinstance(arg, types._vectorizable):
|
||||
my_arg = arg.same_shape(address=regint())
|
||||
call_arg(my_arg.address, base.vm_types['ci'])
|
||||
my_args.append(my_arg)
|
||||
my_arg = arg.same_shape(
|
||||
address=MemValue(my_arg.address))
|
||||
return my_arg
|
||||
actual_call_args.append(my_arg)
|
||||
else:
|
||||
my_args.append(arg)
|
||||
return arg
|
||||
for arg in args:
|
||||
actual_call_args.append(process_for_call(arg))
|
||||
actual_call_kwargs = {}
|
||||
for name, arg in sorted(kwargs.items()):
|
||||
actual_call_kwargs[name] = process_for_call(arg)
|
||||
self.result = self.function(*actual_call_args,
|
||||
**actual_call_kwargs)
|
||||
if self.result is not None:
|
||||
self.result = list(tuplify(self.result))
|
||||
for i, res in enumerate(self.result):
|
||||
if util.is_constant(res):
|
||||
self.result[i] = regint(res)
|
||||
self.on_first_call(wrapped_function, key, my_args)
|
||||
for name, arg in sorted(kwargs.items()):
|
||||
args += arg,
|
||||
return self.on_call(*self.instances[key], args)
|
||||
def on_first_call(self, wrapped_function, key, inside_args):
|
||||
program = get_program()
|
||||
program.curr_tape
|
||||
tape_handle = len(program.tapes)
|
||||
# entry for recursion
|
||||
self.instances[key] = tape_handle, None, inside_args
|
||||
assert tape_handle == program.new_tape(
|
||||
wrapped_function, name=self.name, args=self.compile_args,
|
||||
single_thread=get_tape().singular, finalize=False,
|
||||
thread_pool=get_tape().free_threads)
|
||||
tape = program.tapes[tape_handle]
|
||||
if self.result is not None:
|
||||
self.result = list(tuplify(self.result))
|
||||
for reg in self.result:
|
||||
reg.can_eliminate = False
|
||||
tape.return_values.append(reg)
|
||||
assert not tape.purged
|
||||
get_program().finalize_tape(tape)
|
||||
self.instances[key] = tape_handle, self.result, inside_args
|
||||
def on_call(self, tape_handle, result, inside_args, args):
|
||||
tape = get_program().tapes[tape_handle]
|
||||
if tape.ran_threads and tape.free_threads != get_tape().free_threads:
|
||||
raise CompilerError(
|
||||
'cannot call thread-running tape from another thread')
|
||||
assert len(inside_args) == len(args)
|
||||
out_result = []
|
||||
call_args = []
|
||||
if result is not None:
|
||||
out_result = [reg.same_type() for reg in result]
|
||||
for x, y in zip(out_result, result):
|
||||
call_args += [
|
||||
1, instructions_base.vm_types[x.reg_type],
|
||||
x.size_for_mem(), x, y]
|
||||
for x, y in zip(inside_args, args):
|
||||
if isinstance(x, Tape.Register):
|
||||
call_args += [
|
||||
0, instructions_base.vm_types[x.reg_type],
|
||||
x.size_for_mem(), x, y]
|
||||
elif isinstance(x, types._vectorizable):
|
||||
call_args += [0, base.vm_types['ci'], 1,
|
||||
x.address, regint.conv(y.address)]
|
||||
call_tape(tape_handle, regint(0),
|
||||
*call_args)
|
||||
break_point('call-%s' % self.name)
|
||||
return untuplify(tuple(out_result))
|
||||
|
||||
class ExportFunction(FunctionCallTape):
|
||||
def __init__(self, function):
|
||||
super(ExportFunction, self).__init__(function)
|
||||
self.done = set()
|
||||
def __call__(self, *args, **kwargs):
|
||||
if kwargs:
|
||||
raise CompilerError('keyword arguments not supported')
|
||||
def arg_signature(arg):
|
||||
if isinstance(arg, types._structure):
|
||||
return '%s:%d' % (arg.arg_type(), arg.size)
|
||||
elif isinstance(arg, types._vectorizable):
|
||||
from .GC.types import sbitvec
|
||||
if issubclass(arg.value_type, sbitvec):
|
||||
return 'sbv:[%dx%d]' % (arg.total_size(),
|
||||
arg.value_type.n_bits)
|
||||
else:
|
||||
return '%s:[%d]' % (arg.value_type.arg_type(),
|
||||
arg.total_size())
|
||||
else:
|
||||
raise CompilerError('argument not supported: %s' % arg)
|
||||
signature = []
|
||||
for arg in args:
|
||||
signature.append(arg_signature(arg))
|
||||
signature = tuple(signature)
|
||||
key = self.get_key(args, kwargs)
|
||||
if key in self.instances and signature not in self.done:
|
||||
raise CompilerError('signature conflict')
|
||||
super(ExportFunction, self).__call__(*args, **kwargs)
|
||||
if signature not in self.done:
|
||||
filename = '%s/%s/%s-%s' % (get_program().programs_dir, 'Functions',
|
||||
self.name, '-'.join(signature))
|
||||
print('Writing to', filename)
|
||||
out = open(filename, 'w')
|
||||
print(get_program().name, file=out)
|
||||
print(self.instances[key][0], file=out)
|
||||
result = self.instances[key][1]
|
||||
try:
|
||||
if result is not None:
|
||||
result = untuplify(result)
|
||||
print(arg_signature(result), result.i, file=out)
|
||||
else:
|
||||
print('- 0', file=out)
|
||||
except CompilerError:
|
||||
raise CompilerError('return type not supported: %s' % result)
|
||||
for arg in self.instances[key][2]:
|
||||
if isinstance(arg, types._structure):
|
||||
print(arg.i, end=' ', file=out)
|
||||
elif isinstance(arg, types._vectorizable):
|
||||
print(arg.address, end=' ', file=out)
|
||||
else:
|
||||
CompilerError('argument not supported: %s', arg)
|
||||
print(file=out)
|
||||
self.done.add(signature)
|
||||
|
||||
def function_tape(function):
|
||||
return FunctionTape(function)
|
||||
@@ -413,18 +571,111 @@ def function_tape_with_compile_args(*args):
|
||||
def single_thread_function_tape(function):
|
||||
return FunctionTape(function, single_thread=True)
|
||||
|
||||
def memorize(x):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(memorize(i) for i in x)
|
||||
def function_call_tape(function):
|
||||
if get_program().use_tape_calls:
|
||||
return FunctionCallTape(function)
|
||||
else:
|
||||
return MemValue(x)
|
||||
return function
|
||||
|
||||
def method_call_tape(function):
|
||||
tapes = {}
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def use(name):
|
||||
x = self.__dict__[name]
|
||||
return not isinstance(x, types.MultiArray) or \
|
||||
x.array._address is not None
|
||||
key = (type(self),) + tuple(filter(use, sorted(self.__dict__)))
|
||||
member_key = key[1:]
|
||||
if key not in tapes:
|
||||
def f(*args, **kwargs):
|
||||
class Dummy(type(self)):
|
||||
__init__ = lambda self: None
|
||||
dummy = Dummy()
|
||||
members = args[:len(member_key)]
|
||||
real_args = args[len(member_key):]
|
||||
addresses = {}
|
||||
for name, member in zip(member_key, members):
|
||||
dummy.__dict__[name] = member
|
||||
if isinstance(member, types._vectorizable):
|
||||
addresses[name] = member.address
|
||||
res = function(dummy, *real_args, **kwargs)
|
||||
for name, member in zip(member_key, members):
|
||||
new_member = dummy.__dict__[name]
|
||||
desc = '%s in %s.%s' % (name, type(self).__name__,
|
||||
function.__name__)
|
||||
if id(new_member) != id(member):
|
||||
raise CompilerError('cannot change members '
|
||||
'in method tape (%s)' % desc)
|
||||
if isinstance(member, types._vectorizable) and \
|
||||
id(new_member.address) != id(addresses[name]):
|
||||
raise CompilerError('cannot change memory address '
|
||||
'in method tape (%s)' % desc)
|
||||
if set(member_key) != set(dummy.__dict__):
|
||||
raise CompilerError('cannot add members '
|
||||
'in method tape (%s)' % desc)
|
||||
return res
|
||||
f.__name__ = '%s-%s' % (type(self).__name__, function.__name__)
|
||||
tapes[key] = function_call_tape(f)
|
||||
members = tuple(self.__dict__[x] for x in member_key)
|
||||
res = tapes[key](*(members + args), **kwargs)
|
||||
return res
|
||||
return wrapper
|
||||
|
||||
def function(function):
|
||||
""" Create a run-time function. The arguments can be memory or basic
|
||||
types, and return values can be basic types::
|
||||
|
||||
@function
|
||||
def f(x, y, z):
|
||||
y.write(1)
|
||||
z[0] = 2
|
||||
return x + 3
|
||||
|
||||
a = MemValue(sint(0))
|
||||
b = sint.Array(10)
|
||||
c = f(sint(4), a, b)
|
||||
|
||||
print_ln('%s %s %s', a.reveal(), b[0].reveal(), c.reveal())
|
||||
|
||||
This should output::
|
||||
|
||||
1 2 7
|
||||
|
||||
You can use run-time functions recursively but without return
|
||||
values in this case.
|
||||
|
||||
"""
|
||||
return FunctionCallTape(function)
|
||||
|
||||
def export(function):
|
||||
return ExportFunction(function)
|
||||
|
||||
def memorize(x, write=True):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(memorize(i, write=write) for i in x)
|
||||
elif x is None:
|
||||
return
|
||||
else:
|
||||
return MemValue(x, write=write)
|
||||
|
||||
def unmemorize(x):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(unmemorize(i) for i in x)
|
||||
elif x is None:
|
||||
return
|
||||
else:
|
||||
return x.read()
|
||||
|
||||
def write_mem(dest, source):
|
||||
if isinstance(dest, (tuple, list)):
|
||||
assert len(dest) == len(source)
|
||||
for x, y in zip(dest, source):
|
||||
write_mem(x, y)
|
||||
elif dest is None:
|
||||
return
|
||||
else:
|
||||
dest.write(source)
|
||||
|
||||
class FunctionBlock(Function):
|
||||
def on_first_call(self, wrapped_function):
|
||||
p_return_address = get_tape().program.malloc(1, 'ci')
|
||||
@@ -470,6 +721,10 @@ class FunctionBlock(Function):
|
||||
if self.result is not None:
|
||||
return unmemorize(self.result)
|
||||
|
||||
@staticmethod
|
||||
def base_key():
|
||||
return get_tape()
|
||||
|
||||
def function_block(function):
|
||||
return FunctionBlock(function)
|
||||
|
||||
@@ -704,8 +959,10 @@ def for_range(start, stop=None, step=None):
|
||||
|
||||
"""
|
||||
def decorator(loop_body):
|
||||
get_tape().unused_decorators.pop(decorator)
|
||||
range_loop(loop_body, start, stop, step)
|
||||
return loop_body
|
||||
get_tape().unused_decorators[decorator] = 'for_range'
|
||||
return decorator
|
||||
|
||||
def for_range_parallel(n_parallel, n_loops):
|
||||
@@ -754,7 +1011,7 @@ def for_range_opt(start, stop=None, step=None, budget=None):
|
||||
:param start/stop/step: int/regint/cint (used as in :py:func:`range`)
|
||||
or :py:obj:`start` only as list/tuple of int (see below)
|
||||
:param budget: number of instructions after which to start optimization
|
||||
(default is 100,000)
|
||||
(default is 1000 or as given with ``--budget``)
|
||||
|
||||
Example:
|
||||
|
||||
@@ -777,7 +1034,8 @@ def for_range_opt(start, stop=None, step=None, budget=None):
|
||||
if stop is not None:
|
||||
start, stop, step = _range_prep(start, stop, step)
|
||||
def wrapper(loop_body):
|
||||
n_loops = (step - 1 + stop - start) // step
|
||||
range_ = stop-start
|
||||
n_loops = ((range_% step) != 0) + range_ // step
|
||||
@for_range_opt(n_loops, budget=budget)
|
||||
def _(i):
|
||||
return loop_body(start + i * step)
|
||||
@@ -907,11 +1165,15 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
r = reducer(tuplify(loop_body(j)), mem_state)
|
||||
write_state_to_memory(r)
|
||||
state = mem_state
|
||||
for i,x in enumerate(state):
|
||||
if use_array:
|
||||
mem_state[i] = x
|
||||
else:
|
||||
mem_state[i].write(x)
|
||||
if use_array and len(state) and \
|
||||
isinstance(types._register, types._vectorizable):
|
||||
mem_state[:] = state.get_vector()
|
||||
else:
|
||||
for i,x in enumerate(state):
|
||||
if use_array:
|
||||
mem_state[i] = x
|
||||
else:
|
||||
mem_state[i].write(x)
|
||||
def returner():
|
||||
return untuplify(tuple(state))
|
||||
return returner
|
||||
@@ -987,7 +1249,7 @@ def multithread(n_threads, n_items=None, max_size=None):
|
||||
|
||||
.. code::
|
||||
|
||||
@multithread(8, 25)
|
||||
@multithread(3, 25)
|
||||
def f(base, size):
|
||||
...
|
||||
"""
|
||||
@@ -1077,7 +1339,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
return loop_body(base + i)
|
||||
prog = get_program()
|
||||
thread_args = []
|
||||
if prog.curr_tape == prog.tapes[0]:
|
||||
if prog.curr_tape.singular:
|
||||
prog.n_running_threads = n_threads
|
||||
if not util.is_zero(thread_rounds):
|
||||
prog.prevent_breaks = False
|
||||
@@ -1261,7 +1523,7 @@ def while_loop(loop_body, condition, arg=None, g=None):
|
||||
result = loop_body(arg)
|
||||
if isinstance(result, MemValue):
|
||||
result = result.read()
|
||||
arg.update(result)
|
||||
arg.link(type(arg)(result))
|
||||
return condition(result)
|
||||
if not isinstance(pre_condition, (bool,int)) or pre_condition:
|
||||
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))
|
||||
@@ -1288,11 +1550,26 @@ def while_do(condition, *args):
|
||||
return loop_body
|
||||
return decorator
|
||||
|
||||
def _run_and_link(function, g=None):
|
||||
def _run_and_link(function, g=None, lock_lists=True, allow_return=False):
|
||||
if g is None:
|
||||
g = function.__globals__
|
||||
if lock_lists:
|
||||
class A(list):
|
||||
def __init_(self, l):
|
||||
self[:] = l
|
||||
def __setitem__(*args):
|
||||
raise Exception('you cannot change lists in branches, '
|
||||
'use Array or MultiArray instead')
|
||||
__delitem__ = append = clear = extend = insert = __setitem__
|
||||
pop = remove = reverse = sort = __setitem__
|
||||
for x in g:
|
||||
if isinstance(g[x], list):
|
||||
g[x] = A(g[x])
|
||||
pre = copy.copy(g)
|
||||
res = function()
|
||||
if res is not None and not allow_return:
|
||||
raise CompilerError('Conditional blocks cannot return values. '
|
||||
'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else')
|
||||
_link(pre, g)
|
||||
return res
|
||||
|
||||
@@ -1325,7 +1602,7 @@ def do_while(loop_fn, g=None):
|
||||
name='begin-loop')
|
||||
get_tape().loop_breaks.append([])
|
||||
loop_block = instructions.program.curr_block
|
||||
condition = _run_and_link(loop_fn, g)
|
||||
condition = _run_and_link(loop_fn, g, allow_return=True)
|
||||
if callable(condition):
|
||||
condition = condition()
|
||||
branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False)
|
||||
@@ -1347,7 +1624,9 @@ def if_then(condition):
|
||||
condition = condition()
|
||||
try:
|
||||
if not condition.is_clear:
|
||||
raise CompilerError('cannot branch on secret values')
|
||||
raise CompilerError(
|
||||
'cannot branch on secret values, use if_else instead: '
|
||||
'https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.if_else')
|
||||
except AttributeError:
|
||||
pass
|
||||
state.condition = regint.conv(condition)
|
||||
@@ -1488,30 +1767,30 @@ def else_(body):
|
||||
end_if()
|
||||
|
||||
def and_(*terms):
|
||||
res = regint(0)
|
||||
for term in terms:
|
||||
if_then(term())
|
||||
old_res = res
|
||||
res = regint(1)
|
||||
res.link(old_res)
|
||||
for term in terms:
|
||||
else_then()
|
||||
end_if()
|
||||
def load_result():
|
||||
res = regint(0)
|
||||
for term in terms:
|
||||
if_then(term())
|
||||
old_res = res
|
||||
res = regint(1)
|
||||
res.link(old_res)
|
||||
for term in terms:
|
||||
else_then()
|
||||
end_if()
|
||||
return res
|
||||
return load_result
|
||||
|
||||
def or_(*terms):
|
||||
res = regint(1)
|
||||
for term in terms:
|
||||
if_then(term())
|
||||
else_then()
|
||||
old_res = res
|
||||
res = regint(0)
|
||||
res.link(old_res)
|
||||
for term in terms:
|
||||
end_if()
|
||||
def load_result():
|
||||
res = regint(1)
|
||||
for term in terms:
|
||||
if_then(term())
|
||||
else_then()
|
||||
old_res = res
|
||||
res = regint(0)
|
||||
res.link(old_res)
|
||||
for term in terms:
|
||||
end_if()
|
||||
return res
|
||||
return load_result
|
||||
|
||||
@@ -1570,14 +1849,25 @@ def listen_for_clients(port):
|
||||
"""
|
||||
instructions.listen(regint.conv(port))
|
||||
|
||||
def accept_client_connection(port):
|
||||
def accept_client_connection(port, players=None):
|
||||
""" Accept client connection on specific port base.
|
||||
|
||||
:param port: port base (int/regint/cint)
|
||||
:param players: subset of players (default: all)
|
||||
:returns: client id
|
||||
|
||||
"""
|
||||
res = regint()
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
if players is None:
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
else:
|
||||
@if_e(sum(regint(players) ==
|
||||
get_player_id()._v.expand_to_vector(len(players))))
|
||||
def _():
|
||||
res.update(accept_client_connection(port))
|
||||
@else_
|
||||
def _():
|
||||
res.update(-1)
|
||||
return res
|
||||
|
||||
def init_client_connection(host, port, my_id, relative_port=True):
|
||||
@@ -1697,14 +1987,14 @@ def cint_cint_division(a, b, k, f):
|
||||
from Compiler.program import Program
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def sint_cint_division(a, b, k, f, kappa, nearest=False):
|
||||
def sint_cint_division(a, b, k, f, nearest=False):
|
||||
"""
|
||||
type(a) = sint, type(b) = cint
|
||||
"""
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
two = cint(2) * two_power(f)
|
||||
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
|
||||
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
|
||||
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k)
|
||||
absolute_b = b * sign_b
|
||||
absolute_a = a * sign_a
|
||||
w0 = approximate_reciprocal(absolute_b, k, f, theta)
|
||||
@@ -1714,20 +2004,20 @@ def sint_cint_division(a, b, k, f, kappa, nearest=False):
|
||||
W = w0
|
||||
|
||||
for i in range(1, theta):
|
||||
A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True)
|
||||
A = (A * W).round(2 * k, f, nearest=nearest, signed=True)
|
||||
temp = (B * W + 2 * (f - 1)) >> f
|
||||
W = two - temp
|
||||
B = temp
|
||||
return (sign_a * sign_b) * A
|
||||
|
||||
def IntDiv(a, b, k, kappa=None):
|
||||
def IntDiv(a, b, k):
|
||||
l = 2 * k + 1
|
||||
b = a.conv(b)
|
||||
return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k,
|
||||
kappa, nearest=True)
|
||||
nearest=True)
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
def FPDiv(a, b, k, f, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Goldschmidt method as presented in Catrina10,
|
||||
"""
|
||||
@@ -1750,40 +2040,40 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
|
||||
base.set_global_vector_size(b.size)
|
||||
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
|
||||
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
|
||||
w = AppRcr(b, k, f, simplex_flag, nearest).extend(2 * k)
|
||||
x = alpha - b.extend(2 * k) * w
|
||||
base.reset_global_vector_size()
|
||||
|
||||
y = a.extend(l_y) * w
|
||||
y = y.round(l_y, f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, f, nearest, signed=True)
|
||||
|
||||
for i in range(theta - 1):
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(l_y) * (alpha + x).extend(l_y)
|
||||
x = x * x
|
||||
y = y.round(l_y, 2*f, kappa, nearest, signed=True)
|
||||
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, 2*f, nearest, signed=True)
|
||||
x = x.round(2*k, 2*f, nearest, signed=True)
|
||||
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(l_y) * (alpha + x).extend(l_y)
|
||||
y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, 3 * f - res_f, nearest, signed=True)
|
||||
return y
|
||||
|
||||
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
|
||||
def AppRcr(b, k, f, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Approximate reciprocal of [b]:
|
||||
Given [b], compute [1/b]
|
||||
"""
|
||||
alpha = b.get_type(2 * k)(int(2.9142 * 2**k))
|
||||
c, v = b.Norm(k, f, kappa, simplex_flag)
|
||||
c, v = b.Norm(k, f, simplex_flag)
|
||||
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
|
||||
d = alpha - 2 * c
|
||||
w = d * v
|
||||
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
|
||||
w = w.round(2 * k + 1, 2 * (k - f), nearest, signed=True)
|
||||
# now w * 2 ^ {-f} should be an initial approximation of 1/b
|
||||
return w
|
||||
|
||||
def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
def Norm(b, k, f, simplex_flag=False):
|
||||
"""
|
||||
Computes secret integer values [c] and [v_prime] st.
|
||||
2^{k-1} <= c < 2^k and c = b*v_prime
|
||||
@@ -1799,8 +2089,8 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
absolute_val = sign * b
|
||||
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
|
||||
suffixes = PreOR(bits, kappa)[::-1]
|
||||
bits = absolute_val.bit_decompose(k, maybe_mixed=True)[::-1]
|
||||
suffixes = PreOR(bits)[::-1]
|
||||
|
||||
z = [0] * k
|
||||
for i in range(k - 1):
|
||||
|
||||
290
Compiler/ml.py
290
Compiler/ml.py
@@ -203,6 +203,20 @@ def _no_mem_warnings(function):
|
||||
copy_doc(wrapper, function)
|
||||
return wrapper
|
||||
|
||||
def _layer_method_call_tape(function):
|
||||
function = method_call_tape(function)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._Y.alloc()
|
||||
if self.inputs and len(self.inputs) == 1:
|
||||
backup = self.inputs
|
||||
del self.inputs
|
||||
res = function(self, *args, **kwargs)
|
||||
self.inputs = backup
|
||||
return res
|
||||
else:
|
||||
return function(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class Tensor(MultiArray):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['alloc'] = False
|
||||
@@ -259,6 +273,7 @@ class Layer:
|
||||
def Y(self, value):
|
||||
self._Y = value
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch=None, training=None):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
@@ -834,7 +849,7 @@ class Dense(DenseBase):
|
||||
prod = MultiArray([N, self.d, self.d_out], sfix)
|
||||
else:
|
||||
prod = self.f_input
|
||||
max_size = get_program().budget // self.d_out
|
||||
max_size = get_program().budget
|
||||
@multithread(self.n_threads, N, max_size)
|
||||
def _(base, size):
|
||||
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
@@ -1045,7 +1060,8 @@ class ElementWiseLayer(NoVariableLayer):
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
n_per_item = reduce(operator.mul, self.X.sizes[1:])
|
||||
@multithread(self.n_threads, len(batch) * n_per_item)
|
||||
@multithread(self.n_threads, len(batch) * n_per_item,
|
||||
max_size=program.budget)
|
||||
def _(base, size):
|
||||
self.Y.assign_vector(self.f_part(base, size), base)
|
||||
|
||||
@@ -1195,6 +1211,7 @@ class MaxPool(PoolBase):
|
||||
list/tuple of integers
|
||||
|
||||
"""
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch=None, training=False):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
@@ -1252,26 +1269,38 @@ class Concat(NoVariableLayer):
|
||||
self.dimension = dimension
|
||||
shapes = [inp.shape for inp in inputs]
|
||||
assert dimension == 3
|
||||
assert len(shapes) == 2
|
||||
assert len(shapes[0]) == len(shapes[1])
|
||||
assert len(shapes[0]) == 4
|
||||
for shape in shapes:
|
||||
assert len(shape) == len(shapes[0])
|
||||
shape = []
|
||||
for i in range(len(shapes[0])):
|
||||
if i == dimension:
|
||||
shape.append(shapes[0][i] + shapes[1][i])
|
||||
shape.append(sum(x[i] for x in shapes))
|
||||
else:
|
||||
assert shapes[0][i] == shapes[1][i]
|
||||
shape.append(shapes[0][i])
|
||||
self.Y = Tensor(shape, sfix)
|
||||
self.bases = [sum(x[dimension] for x in shapes[:k])
|
||||
for k in range(len(shapes))]
|
||||
self.addresses = Array.create_from(regint(list(
|
||||
x.Y.address for x in inputs)))
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
assert len(batch) == 1
|
||||
@for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3])
|
||||
def _(i, j):
|
||||
X = [x.Y[batch[0]] for x in self.inputs]
|
||||
self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector())
|
||||
self.Y[batch[0]][i][j].assign_part_vector(
|
||||
X[1][i][j].get_vector(),
|
||||
len(X[0][i][j]))
|
||||
if len(set(self.bases)) == 1:
|
||||
@for_range(len(self.inputs))
|
||||
def _(k):
|
||||
self.Y[batch[0]][i][j].assign_part_vector(
|
||||
MultiArray(
|
||||
self.inputs[0].shape,
|
||||
address=self.addresses[k])[i][j].get_vector(),
|
||||
k * self.bases[1])
|
||||
else:
|
||||
X = [x.Y[batch[0]] for x in self.inputs]
|
||||
for k in range(len(self.inputs)):
|
||||
self.Y[batch[0]][i][j].assign_part_vector(
|
||||
X[k][i][j].get_vector(), self.bases[k])
|
||||
|
||||
class Add(NoVariableLayer):
|
||||
""" Fixed-point addition layer.
|
||||
@@ -1287,12 +1316,12 @@ class Add(NoVariableLayer):
|
||||
self.inputs = inputs
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
assert len(batch) == 1
|
||||
@multithread(self.n_threads, self.Y[0].total_size())
|
||||
def _(base, size):
|
||||
tmp = sum(inp.Y[batch[0]].get_vector(base, size)
|
||||
for inp in self.inputs)
|
||||
self.Y[batch[0]].assign_vector(tmp, base)
|
||||
for bb in batch:
|
||||
tmp = sum(inp.Y[bb].get_vector(base, size)
|
||||
for inp in self.inputs)
|
||||
self.Y[bb].assign_vector(tmp, base)
|
||||
|
||||
class FusedBatchNorm(Layer):
|
||||
""" Fixed-point fused batch normalization layer (inference only).
|
||||
@@ -1334,12 +1363,14 @@ class BatchNorm(Layer):
|
||||
|
||||
def __init__(self, shape, approx=True, args=None):
|
||||
assert len(shape) in (2, 3, 4)
|
||||
self.Y = sfix.Tensor(shape)
|
||||
if len(shape) == 4:
|
||||
shape = [shape[0], shape[1] * shape[2], shape[3]]
|
||||
elif len(shape) == 2:
|
||||
shape = [shape[0], 1, shape[1]]
|
||||
tensors = (Tensor(shape, sfix) for i in range(4))
|
||||
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
|
||||
self.my_Y = sfix.Tensor(shape, address=self.Y.address)
|
||||
tensors = (Tensor(shape, sfix) for i in range(3))
|
||||
self.X, self.nabla_X, self.nabla_Y = tensors
|
||||
arrays = (sfix.Array(shape[2]) for i in range(4))
|
||||
self.var, self.mu, self.weights, self.bias = arrays
|
||||
arrays = (sfix.Array(shape[2]) for i in range(4))
|
||||
@@ -1369,16 +1400,16 @@ class BatchNorm(Layer):
|
||||
|
||||
def _output(self, batch, mu, var):
|
||||
factor = sfix.Array(len(mu))
|
||||
factor[:] = self.InvertSqrt(var[:] + self.epsilon)
|
||||
factor[:] = self.InvertSqrt(var[:] + self.epsilon) * self.weights[:]
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
[len(batch), self.X.sizes[1]])
|
||||
def _(i, j):
|
||||
tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:]
|
||||
self.Y[i][j][:] = self.bias[:] + tmp
|
||||
tmp = (self.X[i][j][:] - mu[:]) * factor[:]
|
||||
self.my_Y[i][j][:] = self.bias[:] + tmp
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch, training=False):
|
||||
if training or not self.is_trained:
|
||||
self.is_trained = True
|
||||
d = self.X.sizes[1]
|
||||
d_in = self.X.sizes[2]
|
||||
s = sfix.Array(d_in)
|
||||
@@ -1411,7 +1442,7 @@ class BatchNorm(Layer):
|
||||
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
|
||||
print_ln('%s at (%s, %s, %s): in=%s out=%s',
|
||||
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
|
||||
self.Y[i][j][k].reveal())
|
||||
self.my_Y[i][j][k].reveal())
|
||||
else:
|
||||
self._output(batch, self.mu_hat, self.var_hat)
|
||||
|
||||
@@ -1471,6 +1502,10 @@ class BatchNorm(Layer):
|
||||
self.nabla_Y[i][j][k].reveal(),
|
||||
self.nabla_X[i][j][k].reveal())
|
||||
|
||||
def reveal_parameters_to_binary(self):
|
||||
for param in self.thetas() + (self.mu_hat, self.var_hat):
|
||||
param.reveal().binary_output()
|
||||
|
||||
class QuantBase(object):
|
||||
bias_before_reduction = True
|
||||
|
||||
@@ -1498,11 +1533,12 @@ class QuantBase(object):
|
||||
class FixBase:
|
||||
bias_before_reduction = False
|
||||
|
||||
@staticmethod
|
||||
def new_squant():
|
||||
class _(sfix):
|
||||
params = None
|
||||
return _
|
||||
class my_squant(sfix):
|
||||
params = None
|
||||
|
||||
@classmethod
|
||||
def new_squant(cls):
|
||||
return cls.my_squant
|
||||
|
||||
def input_params_from(self, player):
|
||||
pass
|
||||
@@ -1553,7 +1589,8 @@ class ConvBase(BaseLayer):
|
||||
cls.temp_inputs = sfix.Array(size)
|
||||
|
||||
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride,
|
||||
padding='SAME', tf_weight_format=False, inputs=None):
|
||||
padding='SAME', tf_weight_format=False, inputs=None,
|
||||
weight_type=None):
|
||||
super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs)
|
||||
|
||||
self.weight_shape = weight_shape
|
||||
@@ -1582,7 +1619,11 @@ class ConvBase(BaseLayer):
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.weight_squant = self.new_squant()
|
||||
if weight_type:
|
||||
self.weight_squant = weight_type
|
||||
else:
|
||||
self.weight_squant = self.new_squant()
|
||||
|
||||
self.bias_squant = self.new_squant()
|
||||
|
||||
self.weights = Tensor(weight_shape, self.weight_squant)
|
||||
@@ -1645,8 +1686,7 @@ class ConvBase(BaseLayer):
|
||||
n_summands = self.n_summands()
|
||||
#start_timer(2)
|
||||
n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:])
|
||||
@multithread(self.n_threads, n_outputs,
|
||||
1000 if sfix.round_nearest else 10 ** 6)
|
||||
@multithread(self.n_threads, n_outputs, max_size=program.budget)
|
||||
def _(base, n_per_thread):
|
||||
res = self.input_squant().unreduced(
|
||||
sint.load_mem(unreduced.address + base,
|
||||
@@ -1709,7 +1749,7 @@ class Conv2d(ConvBase):
|
||||
res += self.bias.expand_to_vector(j, res.size).v
|
||||
else:
|
||||
res += self.bias.expand_to_vector(j, res.size).v << \
|
||||
self.input_squant.f
|
||||
self.weight_squant.f
|
||||
addresses = regint.inc(res.size,
|
||||
self.unreduced[i * part_size].address + j,
|
||||
n_channels_out)
|
||||
@@ -2029,7 +2069,7 @@ class AveragePool2d(BaseLayer):
|
||||
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
|
||||
|
||||
def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
|
||||
padding=0):
|
||||
padding=0, **kwargs):
|
||||
""" More convenient interface to :py:class:`FixConv2d`.
|
||||
|
||||
:param input_shape: input shape (tuple/list of four int)
|
||||
@@ -2050,7 +2090,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
|
||||
padding = padding.upper() if isinstance(padding, str) \
|
||||
else padding
|
||||
return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape,
|
||||
stride, padding)
|
||||
stride, padding, **kwargs)
|
||||
|
||||
def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
|
||||
""" More convenient interface to :py:class:`MaxPool`.
|
||||
@@ -2102,6 +2142,7 @@ class FixAveragePool2d(PoolBase, FixBase):
|
||||
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
|
||||
[1] + list(filter_size) + [1], padding)
|
||||
self.pool_size = reduce(operator.mul, filter_size)
|
||||
self.d_out = self.Y.shape[-1]
|
||||
if output_shape:
|
||||
assert self.Y.shape == list(output_shape)
|
||||
|
||||
@@ -2192,7 +2233,7 @@ class Optimizer:
|
||||
res.output_stats = 'output_stats' in program.args
|
||||
return res
|
||||
|
||||
def __init__(self, layers=[], report_loss=None):
|
||||
def __init__(self, layers=[], report_loss=None, time_layers=False):
|
||||
if get_program().options.binary:
|
||||
raise CompilerError(
|
||||
'machine learning code not compatible with binary circuits')
|
||||
@@ -2207,6 +2248,10 @@ class Optimizer:
|
||||
self.stopped_on_loss = MemValue(0)
|
||||
self.stopped_on_low_loss = MemValue(0)
|
||||
self.layers = layers
|
||||
self.time_layers = time_layers
|
||||
if time_layers:
|
||||
for i, layer in enumerate(layers):
|
||||
print('Timer %d: %s' % (100 + i, repr(layer)))
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
@@ -2275,6 +2320,8 @@ class Optimizer:
|
||||
if self.time_layers:
|
||||
start_timer(100 + i)
|
||||
if i != len(self.layers) - 1 or run_last:
|
||||
for theta in layer.thetas():
|
||||
theta.alloc()
|
||||
layer.forward(batch=self.batch_for(layer, batch),
|
||||
training=training)
|
||||
if self.print_random_update:
|
||||
@@ -2322,6 +2369,10 @@ class Optimizer:
|
||||
self.forward(batch=batch, run_last=False)
|
||||
part = self.layers[-1].eval(batch_size, top=top)
|
||||
res.assign_part_vector(part.get_vector(), start)
|
||||
if self.output_stats:
|
||||
for layer in self.layers[:-1]:
|
||||
print_ln(layer)
|
||||
self.stat(' Y', layer.Y)
|
||||
self.run_in_batches(f, data, batch_size or len(self.layers[1].X))
|
||||
return res
|
||||
|
||||
@@ -2567,9 +2618,9 @@ class Optimizer:
|
||||
|
||||
def run_in_batches(self, f, data, batch_size, truth=None):
|
||||
batch_size = min(batch_size, data.sizes[0])
|
||||
training_data = self.layers[0].X.address
|
||||
training_data = self.layers[0]._X.array._address
|
||||
training_truth = self.layers[-1].Y.address
|
||||
self.layers[0].X.address = data.address
|
||||
self.layers[0]._X.address = data.address
|
||||
if truth:
|
||||
self.layers[-1].Y.address = truth.address
|
||||
N = data.sizes[0]
|
||||
@@ -2620,8 +2671,12 @@ class Optimizer:
|
||||
if model_input:
|
||||
for layer in self.layers:
|
||||
layer.input_from(0)
|
||||
elif reset:
|
||||
elif reset and not 'no_reset' in program.args:
|
||||
self.reset()
|
||||
else:
|
||||
for layer in self.layers:
|
||||
for theta in layer.thetas():
|
||||
theta.alloc()
|
||||
if 'one_iter' in program.args:
|
||||
print_float_prec(16)
|
||||
self.output_weights()
|
||||
@@ -2642,6 +2697,8 @@ class Optimizer:
|
||||
if 'bench10' in program.args or 'bench1' in program.args:
|
||||
n = 1 if 'bench1' in program.args else 10
|
||||
print('benchmarking %s iterations' % n)
|
||||
# force allocatoin
|
||||
self.layers[0].X, self.layers[-1].Y
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
batch = Array.create_from(regint.inc(batch_size))
|
||||
@@ -2748,13 +2805,15 @@ class Optimizer:
|
||||
return list(self.thetas)
|
||||
|
||||
def reveal_model_to_binary(self):
|
||||
input_shape = self.layers[0].X.shape
|
||||
""" Reveal model and store it in the binary output file, see
|
||||
:ref:`reveal-model` for details. """
|
||||
input_shape = self.layers[0]._X.shape
|
||||
for layer in self.layers:
|
||||
if len(input_shape) == 4 and isinstance(layer, DenseBase):
|
||||
layer.reveal_parameters_to_binary(reshape=input_shape[1:])
|
||||
else:
|
||||
layer.reveal_parameters_to_binary()
|
||||
input_shape = layer.Y.shape
|
||||
input_shape = layer._Y.shape
|
||||
|
||||
class Adam(Optimizer):
|
||||
""" Adam/AMSgrad optimizer.
|
||||
@@ -2866,10 +2925,12 @@ class SGD(Optimizer):
|
||||
self.layers = layers
|
||||
self.n_epochs = n_epochs
|
||||
self.nablas = []
|
||||
self.momentum_values = []
|
||||
self.delta_thetas = []
|
||||
for layer in layers:
|
||||
self.nablas.extend(layer.nablas())
|
||||
for theta in layer.thetas():
|
||||
self.momentum_values.append(theta.same_shape())
|
||||
self.delta_thetas.append(theta.same_shape())
|
||||
self.set_learning_rate(0.01)
|
||||
self.debug = debug
|
||||
@@ -2889,23 +2950,27 @@ class SGD(Optimizer):
|
||||
j = i + label * len(X_by_label[0])
|
||||
self.layers[0].X[j] = X[i]
|
||||
self.layers[-1].Y[j] = label
|
||||
for y in self.momentum_values:
|
||||
y.assign_all(0)
|
||||
for y in self.delta_thetas:
|
||||
y.assign_all(0)
|
||||
super(SGD, self).reset()
|
||||
|
||||
def _update(self, i_epoch, i_batch, batch):
|
||||
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
|
||||
self.delta_thetas):
|
||||
for nabla, theta, momentum_value, delta_theta in zip(self.nablas, self.thetas,
|
||||
self.momentum_values, self.delta_thetas):
|
||||
@multithread(self.n_threads, nabla.total_size())
|
||||
def _(base, size):
|
||||
old = delta_theta.get_vector(base, size)
|
||||
old = momentum_value.get_vector(base, size)
|
||||
red_old = self.momentum * old
|
||||
rate = self.gamma.expand_to_vector(size)
|
||||
nabla_vector = nabla.get_vector(base, size)
|
||||
log_batch_size = math.log(len(batch), 2)
|
||||
# divide by len(batch) by truncation
|
||||
# increased rate if len(batch) is not a power of two
|
||||
pre_trunc = nabla_vector.v * rate.v
|
||||
diff = red_old - nabla_vector
|
||||
pre_trunc = diff.v * rate.v
|
||||
momentum_value.assign_vector(diff, base)
|
||||
k = max(nabla_vector.k, rate.k) + rate.f
|
||||
m = rate.f + int(log_batch_size)
|
||||
if self.early_division:
|
||||
@@ -2914,8 +2979,7 @@ class SGD(Optimizer):
|
||||
v = pre_trunc.round(k, m, signed=True,
|
||||
nearest=sfix.round_nearest)
|
||||
new = nabla_vector._new(v)
|
||||
diff = red_old - new
|
||||
delta_theta.assign_vector(diff, base)
|
||||
delta_theta.assign_vector(new, base)
|
||||
theta.assign_vector(theta.get_vector(base, size) +
|
||||
delta_theta.get_vector(base, size), base)
|
||||
if self.print_update_average:
|
||||
@@ -3046,7 +3110,7 @@ class keras:
|
||||
def summary(self):
|
||||
self.opt.summary()
|
||||
|
||||
def build(self, input_shape, batch_size=128):
|
||||
def build(self, input_shape, batch_size=128, program=None):
|
||||
data_input_shape = input_shape
|
||||
if self.opt != None and \
|
||||
input_shape == self.opt.layers[0]._X.sizes and \
|
||||
@@ -3120,8 +3184,11 @@ class keras:
|
||||
if layers[-1].d_out == 1:
|
||||
layers.append(Output(data_input_shape[0]))
|
||||
else:
|
||||
layers.append(
|
||||
MultiOutput(data_input_shape[0], layers[-1].d_out))
|
||||
shape = data_input_shape[0], layers[-1].d_out
|
||||
if program:
|
||||
layers.append(MultiOutput.from_args(program, *shape))
|
||||
else:
|
||||
layers.append(MultiOutput(*shape))
|
||||
if self.optimizer[1]:
|
||||
raise Exception('use keyword arguments for optimizer')
|
||||
opt = self.optimizer[0]
|
||||
@@ -3186,11 +3253,11 @@ class keras:
|
||||
batch_size = min(batch_size, self.batch_size)
|
||||
return self.opt.eval(x, batch_size=batch_size)
|
||||
|
||||
def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
regression=False):
|
||||
""" Convert a PyTorch Sequential object to MP-SPDZ layers.
|
||||
def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
|
||||
regression=False, layer_args={}, program=None):
|
||||
""" Convert a PyTorch Module object to MP-SPDZ layers.
|
||||
|
||||
:param sequence: PyTorch Sequential object
|
||||
:param model: PyTorch Module object
|
||||
:param data_input_shape: input shape (list of four int)
|
||||
:param batch_size: batch size (int)
|
||||
:param input_via: player to input model data via (default: don't)
|
||||
@@ -3198,17 +3265,39 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
|
||||
"""
|
||||
layers = []
|
||||
named_layers = {}
|
||||
|
||||
def mul(x):
|
||||
return reduce(operator.mul, x)
|
||||
|
||||
def process(item):
|
||||
nonlocal input_shape
|
||||
import torch
|
||||
|
||||
def process(item, inputs, input_shape, args, kwargs={}):
|
||||
if item == torch.cat:
|
||||
if len(inputs) > 1:
|
||||
layers.append(
|
||||
Concat(inputs, dimension=len(inputs[0].shape) - 1))
|
||||
return
|
||||
elif item == operator.add:
|
||||
layers.append(Add(inputs))
|
||||
return
|
||||
elif item in (torch.flatten, 'flatten', 'size'):
|
||||
return
|
||||
elif item == 'view':
|
||||
assert -1 in args or \
|
||||
reduce(operator.mul, args) == reduce(operator.mul, input_shape)
|
||||
return
|
||||
elif item == torch.nn.functional.avg_pool2d:
|
||||
layers.append(FixAveragePool2d(input_shape, None, args[1],
|
||||
kwargs.get('stride', args[1]),
|
||||
kwargs.get('padding', 0)))
|
||||
input_shape = layers[-1].Y.shape
|
||||
return
|
||||
# single-input layers from here
|
||||
if inputs and len(inputs) > 1:
|
||||
raise CompilerError('multi-input layer %s not supported' % item)
|
||||
name = type(item).__name__
|
||||
if name == 'Sequential':
|
||||
for x in item:
|
||||
process(x)
|
||||
elif name == 'Linear':
|
||||
if name == 'Linear':
|
||||
assert mul(input_shape[1:]) == item.in_features
|
||||
assert item.bias is not None
|
||||
layers.append(Dense(input_shape[0], item.in_features,
|
||||
@@ -3239,7 +3328,7 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
elif name == 'Conv2d':
|
||||
layers.append(easyConv2d(input_shape, batch_size, item.out_channels,
|
||||
item.kernel_size, item.stride,
|
||||
item.padding))
|
||||
item.padding, **layer_args.get(item, {})))
|
||||
input_shape = layers[-1].Y.shape
|
||||
if input_via is not None:
|
||||
shapes = [x.shape for x in
|
||||
@@ -3247,11 +3336,14 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
import numpy
|
||||
swapped = numpy.moveaxis(
|
||||
numpy.array(item.weight.detach()), 1, -1)
|
||||
layers[-1].weights = sfix.input_tensor_via(input_via, swapped)
|
||||
layers[-1].bias = sfix.input_tensor_via(
|
||||
input_via, item.bias.detach())
|
||||
layers[-1].weights = \
|
||||
layers[-1].weights.value_type.input_tensor_via(
|
||||
input_via, swapped)
|
||||
assert layers[-1].weights.shape == shapes[0]
|
||||
assert layers[-1].bias.shape == shapes[1]
|
||||
if isinstance(item.bias, torch.Tensor):
|
||||
layers[-1].bias = sfix.input_tensor_via(
|
||||
input_via, item.bias.detach())
|
||||
assert layers[-1].bias.shape == shapes[1]
|
||||
elif name == 'MaxPool2d':
|
||||
layers.append(easyMaxPool(input_shape, item.kernel_size,
|
||||
item.stride, item.padding))
|
||||
@@ -3260,10 +3352,24 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
|
||||
item.stride, item.padding))
|
||||
input_shape = layers[-1].Y.shape
|
||||
elif name == 'ReLU':
|
||||
elif name == 'AdaptiveAvgPool2d' or \
|
||||
item == torch.nn.functional.adaptive_avg_pool2d:
|
||||
if name == 'AdaptiveAvgPool2d':
|
||||
output = item.output_size
|
||||
else:
|
||||
output = args[1]
|
||||
for i in (0, 1):
|
||||
assert input_shape[1 + i] % output[i] == 0
|
||||
stride = [input_shape[1 + i] // output[i] for i in (0, 1)]
|
||||
kernel_size = [input_shape[1 + i] - (output[i] - 1) * stride[i]
|
||||
for i in (0, 1)]
|
||||
layers.append(FixAveragePool2d(input_shape, None, kernel_size,
|
||||
stride, padding=0))
|
||||
input_shape = layers[-1].Y.shape
|
||||
elif name == 'ReLU' or item == torch.nn.functional.relu:
|
||||
layers.append(Relu(input_shape))
|
||||
elif name == 'Flatten':
|
||||
pass
|
||||
return
|
||||
elif name == 'BatchNorm2d':
|
||||
layers.append(BatchNorm(layers[-1].Y.sizes))
|
||||
if input_via is not None:
|
||||
@@ -3277,16 +3383,57 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
alpha=item.p))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
else:
|
||||
raise CompilerError('unknown PyTorch module: ' + name)
|
||||
raise CompilerError('unknown PyTorch module: %s' % item)
|
||||
layers[-1].inputs = inputs
|
||||
|
||||
input_shape = data_input_shape + [1] * (4 - len(data_input_shape))
|
||||
process(sequence)
|
||||
|
||||
torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes)
|
||||
for i, layer in enumerate(torch_layers[1:-1]):
|
||||
if layer.op == 'call_module':
|
||||
target = model
|
||||
for attr in layer.target.split('.'):
|
||||
target = getattr(target, attr)
|
||||
else:
|
||||
target = layer.target
|
||||
if not layers:
|
||||
assert layer.args == (torch_layers[i],)
|
||||
inputs = None
|
||||
else:
|
||||
if len(layer.args) < 2 or (layer.args[1] != 1 and
|
||||
layer.args[1] != (1, 1)):
|
||||
args = layer.args
|
||||
elif isinstance(layer.args[0], list):
|
||||
args = layer.args[0]
|
||||
else:
|
||||
args = layer.args[0],
|
||||
inputs = []
|
||||
try:
|
||||
for x in args:
|
||||
inputs.append(named_layers[x])
|
||||
except KeyError:
|
||||
pass
|
||||
if len(inputs) == 1:
|
||||
if isinstance(inputs[0], (Dropout, BatchNorm)):
|
||||
input_shape = inputs[0].inputs[0].Y.shape
|
||||
else:
|
||||
input_shape = inputs[0]._Y.shape
|
||||
else:
|
||||
input_shape = None
|
||||
process(target, inputs, input_shape, layer.args, layer.kwargs)
|
||||
if layers:
|
||||
named_layers[layer] = layers[-1]
|
||||
|
||||
if regression:
|
||||
layers.append(LinearOutput(data_input_shape[0], layers[-1].d_out))
|
||||
elif layers[-1].d_out == 1:
|
||||
layers.append(Output(data_input_shape[0]))
|
||||
else:
|
||||
layers.append(MultiOutput(data_input_shape[0], layers[-1].d_out))
|
||||
shape = data_input_shape[0], layers[-1].d_out
|
||||
if program:
|
||||
layers.append(MultiOutput.from_args(program, *shape))
|
||||
else:
|
||||
layers.append(MultiOutput(*shape))
|
||||
return layers
|
||||
|
||||
class OneLayerSGD:
|
||||
@@ -3335,6 +3482,9 @@ class OneLayerSGD:
|
||||
|
||||
class SGDLogistic(OneLayerSGD):
|
||||
""" Logistic regression using SGD.
|
||||
The member :py:obj:`opt` refers to the internal instance of
|
||||
:py:class:`Optimizer`, which allows to use the funcionality
|
||||
therein.
|
||||
|
||||
:param n_epochs: number of epochs
|
||||
:param batch_size: batch size
|
||||
@@ -3460,6 +3610,8 @@ def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False,
|
||||
|
||||
def mr(A, n_iterations, stop=False):
|
||||
""" Iterative matrix inverse approximation.
|
||||
This is based on the conjugate gradients algorithm in Section
|
||||
10.2.4 of `these lecture notes <https://graphics.stanford.edu/courses/cs205a-13-fall/assets/notes/cs205a_notes.pdf>`_.
|
||||
|
||||
:param A: matrix to invert
|
||||
:param n_iterations: maximum number of iterations
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
Module for math operations.
|
||||
|
||||
Implements trigonometric and logarithmic functions.
|
||||
Most of the functionality is due to `Aly and Smart
|
||||
<https://eprint.iacr.org/2019/354>`_ with some optimizations by
|
||||
`Keller and Sun <https://eprint.iacr.org/2022/933>`_.
|
||||
|
||||
This has to imported explicitly.
|
||||
"""
|
||||
@@ -98,7 +100,7 @@ pi_over_2 = math.radians(90)
|
||||
# @return truncated sint value of x
|
||||
def trunc(x):
|
||||
if isinstance(x, types._fix):
|
||||
return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True)
|
||||
return x.v.right_shift(x.f, x.k, signed=True)
|
||||
elif type(x) is types.sfloat:
|
||||
v, p, z, s = floatingpoint.FLRound(x, 0)
|
||||
#return types.sfloat(v, p, z, s, x.err)
|
||||
@@ -106,19 +108,6 @@ def trunc(x):
|
||||
return x
|
||||
|
||||
|
||||
##
|
||||
# loads integer to fractional type (sint)
|
||||
# @param x: coefficient to be truncated.
|
||||
#
|
||||
# @return returns sfix, sfloat loaded value
|
||||
def load_sint(x, l_type):
|
||||
if l_type is types.sfix:
|
||||
return types.sfix.from_sint(x)
|
||||
elif l_type is types.sfloat:
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
##
|
||||
# evaluates a Polynomial to a given x in a privacy preserving manner.
|
||||
# Inputs can be of any kind of register, secret or otherwise.
|
||||
@@ -433,7 +422,7 @@ def mux_exp(x, y, block_size=8):
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def log2_fx(x, use_division=True):
|
||||
"""
|
||||
r"""
|
||||
Returns the result of :math:`\log_2(x)` for any unbounded
|
||||
number. This is achieved by changing :py:obj:`x` into
|
||||
:math:`f \cdot 2^n` where f is bounded by :math:`[0.5, 1]`. Then the
|
||||
@@ -448,7 +437,7 @@ def log2_fx(x, use_division=True):
|
||||
if isinstance(x, types._fix):
|
||||
# transforms sfix to f*2^n, where f is [o.5,1] bounded
|
||||
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
|
||||
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
|
||||
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f)
|
||||
p -= x.f
|
||||
vlen = x.f
|
||||
v = x._new(v, k=x.k, f=x.f)
|
||||
@@ -473,8 +462,8 @@ def log2_fx(x, use_division=True):
|
||||
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
|
||||
|
||||
|
||||
def pow_fx(x, y):
|
||||
"""
|
||||
def pow_fx(x, y, zero_output=False):
|
||||
r"""
|
||||
Returns the value of the expression :math:`x^y` where both inputs
|
||||
are secret shared. It uses :py:func:`log2_fx` together with
|
||||
:py:func:`exp2_fx` to calculate the expression :math:`2^{y \log_2(x)}`.
|
||||
@@ -494,11 +483,11 @@ def pow_fx(x, y):
|
||||
# obtains y * log2(x)
|
||||
exp = y * log2_x
|
||||
# returns 2^(y*log2(x))
|
||||
return exp2_fx(exp)
|
||||
return exp2_fx(exp, zero_output)
|
||||
|
||||
|
||||
def log_fx(x, b):
|
||||
"""
|
||||
r"""
|
||||
Returns the value of the expression :math:`\log_b(x)` where
|
||||
:py:obj:`x` is secret shared. It uses :py:func:`log2_fx` to
|
||||
calculate the expression :math:`\log_b(2) \cdot \log_2(x)`.
|
||||
@@ -535,8 +524,8 @@ def abs_fx(x):
|
||||
#
|
||||
# @return floored sint value of x
|
||||
def floor_fx(x):
|
||||
return load_sint(x.v.right_shift(x.f, bit_length=x.k, security=x.kappa,
|
||||
signed=True), type(x))
|
||||
return type(x)(x.v.right_shift(x.f, bit_length=x.k, signed=True),
|
||||
k=x.k, f=x.f)
|
||||
|
||||
|
||||
### sqrt methods
|
||||
@@ -743,13 +732,13 @@ def lin_app_SQ(b, k, f):
|
||||
c, v, m, W = norm_SQ(types.sint(b), k)
|
||||
|
||||
# c is now escalated
|
||||
w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k
|
||||
w = alpha * c + beta # equation before b and reduction by order of k
|
||||
|
||||
|
||||
# m even or odd determination
|
||||
m_bit = types.sint()
|
||||
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False)
|
||||
m = load_sint(m_bit, types.sfix)
|
||||
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), signed=False)
|
||||
m = m_bit
|
||||
|
||||
# w times v this way both terms have 2^3k and can be symplified
|
||||
w = w * v
|
||||
@@ -774,7 +763,7 @@ def lin_app_SQ(b, k, f):
|
||||
def sqrt_fx(x_l, k, f):
|
||||
factor = 1.0 / (2.0 ** f)
|
||||
|
||||
x = load_sint(x_l, types.sfix) * factor
|
||||
x = x_l * factor
|
||||
|
||||
theta = int(math.ceil(math.log(k/5.4)))
|
||||
|
||||
@@ -870,7 +859,7 @@ def atan(x):
|
||||
|
||||
|
||||
def asin(x):
|
||||
"""
|
||||
r"""
|
||||
Returns the arcsine (sfix) of any given fractional value.
|
||||
|
||||
:param x: fractional input (sfix). valid interval is :math:`-1 \le x \le 1`
|
||||
@@ -886,7 +875,7 @@ def asin(x):
|
||||
|
||||
|
||||
def acos(x):
|
||||
"""
|
||||
r"""
|
||||
Returns the arccosine (sfix) of any given fractional value.
|
||||
|
||||
:param x: fractional input (sfix). :math:`-1 \le x \le 1`
|
||||
@@ -898,7 +887,7 @@ def acos(x):
|
||||
|
||||
|
||||
def tanh(x):
|
||||
"""
|
||||
r"""
|
||||
Hyperbolic tangent. For efficiency, accuracy is diminished
|
||||
around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and
|
||||
:math:`f` denote the fixed-point parameters.
|
||||
@@ -912,29 +901,29 @@ def tanh(x):
|
||||
|
||||
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
|
||||
|
||||
def Sep(x):
|
||||
def Sep(x, sfix=types.sfix):
|
||||
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
|
||||
bb = b[:]
|
||||
while len(bb) < 2 * x.f - 1:
|
||||
bb.insert(0, type(b[0])(0))
|
||||
t = x.v * (1 + x.v.bit_compose(b_i.bit_not()
|
||||
for b_i in bb[-2 * x.f + 1:]))
|
||||
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
|
||||
u = sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
|
||||
b += [b[0].long_one()]
|
||||
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
|
||||
|
||||
def SqrtComp(z, old=False):
|
||||
f = types.sfix.f
|
||||
def SqrtComp(z, old=False, sfix=types.sfix):
|
||||
f = sfix.f
|
||||
k = len(z)
|
||||
if isinstance(z[0], types.sint):
|
||||
return types.sfix._new(sum(z[i] * types.cfix(
|
||||
return sfix._new(sum(z[i] * types.cfix(
|
||||
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
|
||||
k_prime = k // 2
|
||||
f_prime = f // 2
|
||||
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
|
||||
c0 = types.sfix(2 ** (f / 2 + 1))
|
||||
c1 = sfix(2 ** ((f + 1) / 2 + 1))
|
||||
c0 = sfix(2 ** (f / 2 + 1))
|
||||
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
|
||||
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
|
||||
tmp = sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
|
||||
if old:
|
||||
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
|
||||
else:
|
||||
@@ -942,11 +931,15 @@ def SqrtComp(z, old=False):
|
||||
return types.sint.conv(b).if_else(c1, c0) * tmp
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def InvertSqrt(x, old=False):
|
||||
"""
|
||||
Reciprocal square root approximation by `Lu et al.
|
||||
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
|
||||
"""
|
||||
u, z = Sep(x)
|
||||
class my_sfix(types.sfix):
|
||||
f = x.f
|
||||
k = x.k
|
||||
u, z = Sep(x, sfix=my_sfix)
|
||||
c = 3.14736 + u * (4.63887 * u - 5.77789)
|
||||
return c * SqrtComp(z, old=old)
|
||||
return c * SqrtComp(z, old=old, sfix=my_sfix)
|
||||
|
||||
@@ -4,14 +4,6 @@ from .types import *
|
||||
from . import comparison, program
|
||||
|
||||
class NonLinear:
|
||||
kappa = None
|
||||
|
||||
def set_security(self, kappa):
|
||||
pass
|
||||
|
||||
def check_security(self, kappa):
|
||||
pass
|
||||
|
||||
def mod2m(self, a, k, m, signed):
|
||||
"""
|
||||
a_prime = a % 2^m
|
||||
@@ -45,18 +37,16 @@ class NonLinear:
|
||||
|
||||
def trunc_round_nearest(self, a, k, m, signed):
|
||||
res = sint()
|
||||
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, self.kappa,
|
||||
signed)
|
||||
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, signed)
|
||||
return res
|
||||
|
||||
def trunc(self, a, k, m, kappa, signed):
|
||||
self.check_security(kappa)
|
||||
def trunc(self, a, k, m, signed):
|
||||
if m == 0:
|
||||
return a
|
||||
return self._trunc(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
return -self.trunc(a, k, k - 1, kappa, True)
|
||||
def ltz(self, a, k):
|
||||
return -self.trunc(a, k, k - 1, True)
|
||||
|
||||
class Masking(NonLinear):
|
||||
def eqz(self, a, k):
|
||||
@@ -68,28 +58,19 @@ class Masking(NonLinear):
|
||||
|
||||
class Prime(Masking):
|
||||
""" Non-linear functionality modulo a prime with statistical masking. """
|
||||
def __init__(self, kappa):
|
||||
self.set_security(kappa)
|
||||
|
||||
def set_security(self, kappa):
|
||||
self.kappa = kappa
|
||||
|
||||
def check_security(self, kappa):
|
||||
assert self.kappa == kappa or kappa is None
|
||||
|
||||
def _mod2m(self, a, k, m, signed):
|
||||
res = sint()
|
||||
if m == 1:
|
||||
Mod2(res, a, k, self.kappa, signed)
|
||||
Mod2(res, a, k, signed)
|
||||
else:
|
||||
Mod2mField(res, a, k, m, self.kappa, signed)
|
||||
Mod2mField(res, a, k, m, signed)
|
||||
return res
|
||||
|
||||
def _mask(self, a, k):
|
||||
return maskField(a, k, self.kappa)
|
||||
return maskField(a, k)
|
||||
|
||||
def _trunc_pr(self, a, k, m, signed=None):
|
||||
return TruncPrField(a, k, m, self.kappa)
|
||||
return TruncPrField(a, k, m)
|
||||
|
||||
def _trunc(self, a, k, m, signed=None):
|
||||
a_prime = self.mod2m(a, k, m, signed)
|
||||
@@ -99,12 +80,12 @@ class Prime(Masking):
|
||||
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
if maybe_mixed:
|
||||
return BitDecFieldRaw(a, k, m, self.kappa)
|
||||
return BitDecFieldRaw(a, k, m)
|
||||
else:
|
||||
return BitDecField(a, k, m, self.kappa)
|
||||
return BitDecField(a, k, m)
|
||||
|
||||
def kor(self, d):
|
||||
return KOR(d, self.kappa)
|
||||
return KOR(d)
|
||||
|
||||
class KnownPrime(NonLinear):
|
||||
""" Non-linear functionality modulo a prime known at compile time. """
|
||||
@@ -144,13 +125,13 @@ class KnownPrime(NonLinear):
|
||||
a += two_power(k)
|
||||
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
def ltz(self, a, k):
|
||||
if k + 1 < self.prime.bit_length():
|
||||
# https://dl.acm.org/doi/10.1145/3474123.3486757
|
||||
# "negative" values wrap around when doubling, thus becoming odd
|
||||
return self.mod2m(2 * a, k + 1, 1, False)
|
||||
else:
|
||||
return super(KnownPrime, self).ltz(a, k, kappa)
|
||||
return super(KnownPrime, self).ltz(a, k)
|
||||
|
||||
class Ring(Masking):
|
||||
""" Non-linear functionality modulo a power of two known at compile time.
|
||||
@@ -189,5 +170,5 @@ class Ring(Masking):
|
||||
else:
|
||||
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
def ltz(self, a, k):
|
||||
return LtzRing(a, k)
|
||||
|
||||
@@ -9,6 +9,11 @@ secret index::
|
||||
i = sint.get_input_from(0)
|
||||
a[i] = sint.get_input_from(1)
|
||||
|
||||
`The introductory book by Evans et
|
||||
al. <https://securecomputation.org>`_ contains `a chapter dedicated to
|
||||
oblivious RAM
|
||||
<https://securecomputation.org/docs/ch5-obliviousdata.pdf>`_.
|
||||
|
||||
"""
|
||||
|
||||
import random
|
||||
@@ -41,6 +46,7 @@ debug_online = False
|
||||
crash_on_overflow = False
|
||||
use_insecure_randomness = False
|
||||
debug_ram_size = False
|
||||
single_thread = False
|
||||
|
||||
def maybe_start_timer(n):
|
||||
if detailed_timing:
|
||||
@@ -844,7 +850,7 @@ class TrivialORAM(RefTrivialORAM, AbstractORAM):
|
||||
start_timer()
|
||||
|
||||
def get_n_threads(n_loops):
|
||||
if n_threads is None:
|
||||
if n_threads is None and not single_thread:
|
||||
if n_loops > 2048:
|
||||
return 8
|
||||
else:
|
||||
@@ -877,7 +883,7 @@ class LinearORAM(TrivialORAM):
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
@map_sum(get_n_threads(self.size), None, self.size, \
|
||||
self.value_length + 1, t)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
@@ -897,7 +903,7 @@ class LinearORAM(TrivialORAM):
|
||||
new_value = make_array(
|
||||
new_value, self.value_type.get_type(
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
|
||||
@for_range_multithread(get_n_threads(self.size), None, self.size)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
access_here = self.index_vector[i]
|
||||
@@ -917,7 +923,7 @@ class LinearORAM(TrivialORAM):
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
new_empty = MemValue(new_empty)
|
||||
write = MemValue(write)
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
@map_sum(get_n_threads(self.size), None, self.size, \
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type] * self.value_length)
|
||||
def f(i):
|
||||
@@ -1038,7 +1044,7 @@ class LocalIndexStructure(List):
|
||||
__getitem__ = lambda self,index: List.__getitem__(self, index)[0]
|
||||
|
||||
def get_n_threads_for_tree(size):
|
||||
if n_threads_for_tree is None:
|
||||
if n_threads_for_tree is None and not single_thread:
|
||||
if size >= 2**13:
|
||||
return 8
|
||||
else:
|
||||
@@ -1340,8 +1346,8 @@ class TreeORAM(AbstractORAM):
|
||||
half = (empty_positions[i]+1 - parity) // 2
|
||||
half_max = self.bucket_size // 2
|
||||
|
||||
bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0]
|
||||
bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0]
|
||||
bits = floatingpoint.B2U(half, half_max)[0]
|
||||
bits2 = floatingpoint.B2U(half+parity, half_max)[0]
|
||||
# (doesn't work)
|
||||
#bits2 = [0] * half_max
|
||||
## second half with parity bit
|
||||
@@ -1350,7 +1356,8 @@ class TreeORAM(AbstractORAM):
|
||||
#bits2[0] = (1 - bits[0]) * parity
|
||||
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
|
||||
else:
|
||||
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
|
||||
bucket_bits = floatingpoint.B2U(empty_positions[i]+1,
|
||||
self.bucket_size)[0]
|
||||
assert len(bucket_bits) == self.bucket_size
|
||||
for j, b in enumerate(bucket_bits):
|
||||
pos_bits[i * self.bucket_size + j] = [b, leaf]
|
||||
@@ -1376,8 +1383,7 @@ class TreeORAM(AbstractORAM):
|
||||
Program.prog.curr_tape.start_new_basicblock()
|
||||
|
||||
bucket_sizes = Array(2**self.D, regint)
|
||||
for i in range(2**self.D):
|
||||
bucket_sizes[i] = 0
|
||||
bucket_sizes.assign_all(0)
|
||||
|
||||
@for_range_opt(len(entries))
|
||||
def _(k):
|
||||
@@ -1697,7 +1703,7 @@ class OneLevelORAM(TreeORAM):
|
||||
|
||||
class BinaryORAM:
|
||||
def __init__(self, size, value_type=None, **kwargs):
|
||||
import circuit_oram
|
||||
from Compiler import circuit_oram
|
||||
from Compiler.GC import types
|
||||
n_bits = int(get_program().options.binary)
|
||||
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""This module contains an implementation of the "Path Oblivious Heap"
|
||||
r"""This module contains an implementation of the "Path Oblivious Heap"
|
||||
oblivious priority queue as proposed by
|
||||
`Shi <https://eprint.iacr.org/2019/274.pdf>`_.
|
||||
|
||||
@@ -968,6 +968,7 @@ class PathObliviousHeap(AbstractMinPriorityQueue[_secret]):
|
||||
self.type_hiding_security = type_hiding_security
|
||||
self.capacity = capacity
|
||||
self.entry_size = entry_size
|
||||
self.size = MemValue(sint(0))
|
||||
|
||||
# Print debug messages
|
||||
dprint(f"[POH] __init__: Initializing a queue...")
|
||||
@@ -1024,6 +1025,7 @@ class PathObliviousHeap(AbstractMinPriorityQueue[_secret]):
|
||||
dprint_ln("\n[POH] insert")
|
||||
indent()
|
||||
self.tree.insert(value, priority, fake)
|
||||
self.size.iadd(fake.bit_not())
|
||||
outdent()
|
||||
|
||||
def _extract_min(self, fake: _secret) -> _secret:
|
||||
@@ -1031,6 +1033,7 @@ class PathObliviousHeap(AbstractMinPriorityQueue[_secret]):
|
||||
dprint_ln("\n[POH] extract_min")
|
||||
indent()
|
||||
value = self.tree.extract_min(fake)
|
||||
self.size.iadd(-fake.bit_not())
|
||||
outdent()
|
||||
if TRACE:
|
||||
dprint_ln("[POH] extract_min: extracted value %s", value.reveal())
|
||||
|
||||
@@ -18,7 +18,7 @@ from functools import reduce
|
||||
import Compiler.instructions
|
||||
import Compiler.instructions_base
|
||||
import Compiler.instructions_base as inst_base
|
||||
from Compiler.config import REG_MAX, USER_MEM, COST
|
||||
from Compiler.config import REG_MAX, USER_MEM, COST, MEM_MAX
|
||||
from Compiler.exceptions import CompilerError
|
||||
from Compiler.instructions_base import RegType
|
||||
|
||||
@@ -103,6 +103,38 @@ class Program(object):
|
||||
self.bit_length = int(options.binary) or int(options.field)
|
||||
if options.prime:
|
||||
self.prime = int(options.prime)
|
||||
print("WARNING: --prime/-P activates code that usually isn't "
|
||||
"the most efficient variant. Consider using --field/-F "
|
||||
"and set the prime only during the actual computation.")
|
||||
if not self.rabbit_gap() and self.prime > 2 ** 50:
|
||||
print("The chosen prime is particularly inefficient. "
|
||||
"Consider using a prime that is closer to a power "
|
||||
"of two", end='')
|
||||
try:
|
||||
import gmpy2
|
||||
bad_prime = self.prime
|
||||
self.prime = 2 ** int(
|
||||
round(math.log(self.prime, 2))) + 1
|
||||
while True:
|
||||
if self.prime > 2 ** 59:
|
||||
# LWE compatibility
|
||||
step = 2 ** 15
|
||||
else:
|
||||
step = 1
|
||||
if self.prime < bad_prime:
|
||||
self.prime += step
|
||||
else:
|
||||
self.prime -= step
|
||||
if gmpy2.is_prime(self.prime):
|
||||
break
|
||||
assert self.rabbit_gap()
|
||||
print(", for example, %d." % self.prime)
|
||||
self.prime = bad_prime
|
||||
except ImportError:
|
||||
print(".")
|
||||
if options.execute:
|
||||
print("Use '-- --prime <prime>' to specify the prime for "
|
||||
"execution only.")
|
||||
max_bit_length = int(options.prime).bit_length() - 2
|
||||
if self.bit_length > max_bit_length:
|
||||
raise CompilerError(
|
||||
@@ -111,7 +143,7 @@ class Program(object):
|
||||
self.bit_length = self.bit_length or max_bit_length
|
||||
self.non_linear = KnownPrime(self.prime)
|
||||
else:
|
||||
self.non_linear = Prime(self.security)
|
||||
self.non_linear = Prime()
|
||||
if not self.bit_length:
|
||||
self.bit_length = 64
|
||||
print("Default bit length for compilation:", self.bit_length)
|
||||
@@ -197,6 +229,9 @@ class Program(object):
|
||||
self.cisc_to_function = True
|
||||
if not self.options.cisc:
|
||||
self.options.cisc = not self.options.optimize_hard
|
||||
self.use_tape_calls = True
|
||||
self.force_cisc_tape = False
|
||||
self.have_warned_trunc_pr = False
|
||||
|
||||
Program.prog = self
|
||||
from . import comparison, instructions, instructions_base, types
|
||||
@@ -225,7 +260,7 @@ class Program(object):
|
||||
os.mkdir(dirname)
|
||||
|
||||
# create extra directories if needed
|
||||
for dirname in ["Public-Input", "Bytecode", "Schedules"]:
|
||||
for dirname in ["Public-Input", "Bytecode", "Schedules", "Functions"]:
|
||||
if not os.path.exists(self.programs_dir + "/" + dirname):
|
||||
os.mkdir(self.programs_dir + "/" + dirname)
|
||||
|
||||
@@ -278,7 +313,8 @@ class Program(object):
|
||||
self.non_linear = Ring(ring_size)
|
||||
self.options.ring = str(ring_size)
|
||||
|
||||
def new_tape(self, function, args=[], name=None, single_thread=False):
|
||||
def new_tape(self, function, args=[], name=None, single_thread=False,
|
||||
finalize=True, **kwargs):
|
||||
"""
|
||||
Create a new tape from a function. See
|
||||
:py:func:`~Compiler.library.multithread` and
|
||||
@@ -309,11 +345,12 @@ class Program(object):
|
||||
self.curr_tape
|
||||
tape_index = len(self.tapes)
|
||||
self.tape_stack.append(self.curr_tape)
|
||||
self.curr_tape = Tape(name, self)
|
||||
self.curr_tape = Tape(name, self, **kwargs)
|
||||
self.curr_tape.singular = single_thread
|
||||
self.tapes.append(self.curr_tape)
|
||||
function(*args)
|
||||
self.finalize_tape(self.curr_tape)
|
||||
if finalize:
|
||||
self.finalize_tape(self.curr_tape)
|
||||
if self.tape_stack:
|
||||
self.curr_tape = self.tape_stack.pop()
|
||||
return tape_index
|
||||
@@ -346,6 +383,7 @@ class Program(object):
|
||||
thread_numbers = []
|
||||
while len(thread_numbers) < len(args):
|
||||
free_threads = self.curr_tape.free_threads
|
||||
self.curr_tape.ran_threads = True
|
||||
if free_threads:
|
||||
thread_numbers.append(min(free_threads))
|
||||
free_threads.remove(thread_numbers[-1])
|
||||
@@ -417,7 +455,10 @@ class Program(object):
|
||||
|
||||
def finalize_tape(self, tape):
|
||||
if not tape.purged:
|
||||
curr_tape = self.curr_tape
|
||||
self.curr_tape = tape
|
||||
tape.optimize(self.options)
|
||||
self.curr_tape = curr_tape
|
||||
tape.write_bytes()
|
||||
if self.options.asmoutfile:
|
||||
tape.write_str(self.options.asmoutfile + "-" + tape.name)
|
||||
@@ -472,16 +513,18 @@ class Program(object):
|
||||
self.allocated_mem[mem_type] += size
|
||||
if len(str(addr)) != len(str(addr + size)) and self.verbose:
|
||||
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
|
||||
if addr + size >= 2**64:
|
||||
raise CompilerError("allocation exceeded for type '%s'" % mem_type)
|
||||
if addr + size >= MEM_MAX:
|
||||
raise CompilerError(
|
||||
"allocation exceeded for type '%s' after adding %d" % \
|
||||
(mem_type, size))
|
||||
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
|
||||
if single_size:
|
||||
from .library import get_thread_number, runtime_error_if
|
||||
from .library import get_arg, runtime_error_if
|
||||
bak = self.curr_tape.active_basicblock
|
||||
self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
|
||||
tn = get_thread_number()
|
||||
runtime_error_if(tn > self.n_running_threads, "malloc")
|
||||
res = addr + single_size * (tn - 1)
|
||||
arg = get_arg()
|
||||
runtime_error_if(arg >= self.n_running_threads, "malloc")
|
||||
res = addr + single_size * arg
|
||||
self.curr_tape.active_basicblock = bak
|
||||
self.base_addresses[res] = addr
|
||||
return res
|
||||
@@ -577,7 +620,6 @@ class Program(object):
|
||||
def set_security(self, security):
|
||||
changed = self._security != security
|
||||
self._security = security
|
||||
self.non_linear.set_security(security)
|
||||
if changed:
|
||||
print("Changed statistical security for comparison etc. to",
|
||||
security)
|
||||
@@ -612,6 +654,14 @@ class Program(object):
|
||||
def use_trunc_pr(self, change):
|
||||
self._use_trunc_pr = change
|
||||
|
||||
def trunc_pr_warning(self):
|
||||
if not self.have_warned_trunc_pr:
|
||||
print("WARNING: Probabilistic truncation leaks some information, "
|
||||
"see https://eprint.iacr.org/2024/1127 for discussion. "
|
||||
"Use 'sfix.round_nearest = True' to deactivate this for "
|
||||
"fixed-point operations.")
|
||||
self.have_warned_trunc_pr = True
|
||||
|
||||
def use_edabit(self, change=None):
|
||||
"""Setting whether to use edaBits for non-linear
|
||||
functionality (default: false).
|
||||
@@ -783,7 +833,7 @@ class Program(object):
|
||||
class Tape:
|
||||
"""A tape contains a list of basic blocks, onto which instructions are added."""
|
||||
|
||||
def __init__(self, name, program):
|
||||
def __init__(self, name, program, thread_pool=None):
|
||||
"""Set prime p and the initial instructions and registers."""
|
||||
self.program = program
|
||||
name += "-%d" % program.get_tape_counter()
|
||||
@@ -800,12 +850,16 @@ class Tape:
|
||||
self.merge_opens = True
|
||||
self.if_states = []
|
||||
self.req_bit_length = defaultdict(lambda: 0)
|
||||
self.bit_length_reason = None
|
||||
self.function_basicblocks = {}
|
||||
self.functions = []
|
||||
self.singular = True
|
||||
self.free_threads = set()
|
||||
self.free_threads = set() if thread_pool is None else thread_pool
|
||||
self.loop_breaks = []
|
||||
self.warned_about_mem = False
|
||||
self.return_values = []
|
||||
self.ran_threads = False
|
||||
self.unused_decorators = {}
|
||||
|
||||
class BasicBlock(object):
|
||||
def __init__(self, parent, name, scope, exit_condition=None,
|
||||
@@ -880,7 +934,8 @@ class Tape:
|
||||
self.usage_instructions = list(filter(relevant, self.instructions))
|
||||
else:
|
||||
self.usage_instructions = []
|
||||
if len(self.usage_instructions) > 1000:
|
||||
if len(self.usage_instructions) > 1000 and \
|
||||
self.parent.program.verbose:
|
||||
print("Retaining %d instructions" % len(self.usage_instructions))
|
||||
del self.instructions
|
||||
self.purged = True
|
||||
@@ -1000,6 +1055,10 @@ class Tape:
|
||||
print()
|
||||
raise CompilerError("Unclosed if/else blocks, see tracebacks above")
|
||||
|
||||
if self.unused_decorators:
|
||||
raise CompilerError("Unused branching decorators, make sure to write " + ",".join(
|
||||
"'@%s' instead of '%s'" % (x, x) for x in set(self.unused_decorators.values())))
|
||||
|
||||
if self.program.verbose:
|
||||
print(
|
||||
"Processing tape", self.name, "with %d blocks" % len(self.basicblocks)
|
||||
@@ -1107,6 +1166,9 @@ class Tape:
|
||||
if addr.program == self and self.basicblocks:
|
||||
allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool)
|
||||
|
||||
for reg in self.return_values:
|
||||
allocator.alloc_reg(reg, self.basicblocks[-1].alloc_pool)
|
||||
|
||||
seen = set()
|
||||
|
||||
def alloc(block):
|
||||
@@ -1214,7 +1276,10 @@ class Tape:
|
||||
Compiler.instructions.reqbl(bl, add_to_prog=False)
|
||||
)
|
||||
if self.program.verbose:
|
||||
print("Tape requires prime bit length", self.req_bit_length["p"])
|
||||
print("Tape requires prime bit length",
|
||||
self.req_bit_length["p"],
|
||||
('for %s' % self.bit_length_reason
|
||||
if self.bit_length_reason else ''))
|
||||
print("Tape requires galois bit length", self.req_bit_length["2"])
|
||||
|
||||
@unpurged
|
||||
@@ -1287,6 +1352,7 @@ class Tape:
|
||||
if "Bytecode" not in filename:
|
||||
filename = self.program.programs_dir + "/Bytecode/" + filename
|
||||
print("Writing to", filename)
|
||||
sys.stdout.flush()
|
||||
f = open(filename, "wb")
|
||||
h = hashlib.sha256()
|
||||
for i in self._get_instructions():
|
||||
@@ -1395,13 +1461,12 @@ class Tape:
|
||||
return repr(dict(self))
|
||||
|
||||
class ReqNode(object):
|
||||
__slots__ = ["num", "_children", "name", "blocks", "aggregated"]
|
||||
|
||||
def __init__(self, name):
|
||||
self._children = []
|
||||
self.name = name
|
||||
self.blocks = []
|
||||
self.aggregated = None
|
||||
self.num = None
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
@@ -1411,12 +1476,17 @@ class Tape:
|
||||
def aggregate(self, *args):
|
||||
if self.aggregated is not None:
|
||||
return self.aggregated
|
||||
self.recursion = self.num is not None
|
||||
if self.recursion:
|
||||
return Tape.ReqNum()
|
||||
self.num = Tape.ReqNum()
|
||||
for block in self.blocks:
|
||||
block.add_usage(self)
|
||||
res = reduce(
|
||||
lambda x, y: x + y.aggregate(self.name), self.children, self.num
|
||||
)
|
||||
if self.recursion:
|
||||
res *= float('inf')
|
||||
self.aggregated = res
|
||||
return res
|
||||
|
||||
@@ -1442,7 +1512,7 @@ class Tape:
|
||||
n_reps = self.aggregator([1])
|
||||
n_rounds = res["all", "round"]
|
||||
n_invs = res["all", "inv"]
|
||||
if (n_invs / n_rounds) * 1000 < n_reps:
|
||||
if (n_invs / n_rounds) * 1000 < n_reps and Program.prog.verbose:
|
||||
print(
|
||||
self.nodes[0].blocks[0].name,
|
||||
"blowing up rounds: ",
|
||||
@@ -1468,15 +1538,19 @@ class Tape:
|
||||
def close_scope(self, outer_scope, parent_req_node, name):
|
||||
self.start_new_basicblock(outer_scope, name, req_node=parent_req_node)
|
||||
|
||||
def require_bit_length(self, bit_length, t="p"):
|
||||
def require_bit_length(self, bit_length, t="p", reason=None):
|
||||
if t == "p":
|
||||
if self.program.prime:
|
||||
if bit_length >= self.program.prime.bit_length() - 1:
|
||||
raise CompilerError(
|
||||
"required bit length %d too much for %d"
|
||||
% (bit_length, self.program.prime)
|
||||
+ ('(for %s)' % reason if reason else '')
|
||||
)
|
||||
self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t])
|
||||
bit_length += 1
|
||||
if bit_length > self.req_bit_length[t]:
|
||||
self.req_bit_length[t] = bit_length
|
||||
self.bit_length_reason = reason
|
||||
else:
|
||||
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
|
||||
|
||||
@@ -1492,12 +1566,24 @@ class Tape:
|
||||
def __bool__(self):
|
||||
raise CompilerError(
|
||||
"Cannot derive truth value from register. "
|
||||
"This is a catch-all error appearing if you try to use a "
|
||||
"run-time value where the compiler expects a compile-time "
|
||||
"value, most likely a Python integer. "
|
||||
"In some cases, you can fix this by using 'compile.py -l'."
|
||||
"See https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#cannot-derive-truth-value-from-register"
|
||||
)
|
||||
|
||||
def __int__(self):
|
||||
raise CompilerError(
|
||||
"It is impossible to convert run-time types to compile-time "
|
||||
"Python types like int or float. The reason for this is that "
|
||||
"%s objects are only a placeholder during the execution in "
|
||||
"Python, the actual value of which is only defined in the "
|
||||
"virtual machine at a later time. See "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/journey.html "
|
||||
"to get an understanding of the overall design. "
|
||||
"In rare cases, you can fix this by using 'compile.py -l'." % \
|
||||
type(self).__name__
|
||||
)
|
||||
|
||||
__float__ = __int__
|
||||
|
||||
class Register(_no_truth):
|
||||
"""
|
||||
Class for creating new registers. The register's index is automatically assigned
|
||||
@@ -1515,6 +1601,7 @@ class Tape:
|
||||
"caller",
|
||||
"can_eliminate",
|
||||
"duplicates",
|
||||
"dup_count",
|
||||
"block",
|
||||
]
|
||||
maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1
|
||||
@@ -1547,6 +1634,7 @@ class Tape:
|
||||
self.vector = []
|
||||
self.can_eliminate = True
|
||||
self.duplicates = util.set_by_id([self])
|
||||
self.dup_count = None
|
||||
if Program.prog.DEBUG:
|
||||
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
else:
|
||||
@@ -1619,6 +1707,9 @@ class Tape:
|
||||
def copy(self):
|
||||
return Tape.Register(self.reg_type, Program.prog.curr_tape)
|
||||
|
||||
def same_type(self):
|
||||
return type(self)(size=self.size)
|
||||
|
||||
def link(self, other):
|
||||
if Program.prog.options.noreallocate:
|
||||
raise CompilerError("reallocation necessary for linking, "
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
from Compiler import types, library, instructions
|
||||
from Compiler import comparison, util
|
||||
|
||||
def dest_comp(B):
|
||||
Bt = B.transpose()
|
||||
@@ -10,7 +11,7 @@ def dest_comp(B):
|
||||
return sum(Tt) - 1
|
||||
|
||||
def reveal_sort(k, D, reverse=False):
|
||||
""" Sort in place according to "perfect" key. The name hints at the fact
|
||||
r""" Sort in place according to "perfect" key. The name hints at the fact
|
||||
that a random order of the keys is revealed.
|
||||
|
||||
:param k: vector or Array of sint containing exactly :math:`0,\dots,n-1`
|
||||
@@ -20,6 +21,7 @@ def reveal_sort(k, D, reverse=False):
|
||||
backward order
|
||||
|
||||
"""
|
||||
comparison.require_ring_size(util.log2(len(k)) + 1, 'sorting')
|
||||
assert len(k) == len(D)
|
||||
library.break_point()
|
||||
shuffle = types.sint.get_secure_shuffle(len(k))
|
||||
|
||||
@@ -109,7 +109,7 @@ class SqrtOram(Generic[T, B]):
|
||||
self.shuffle_used = cint.Array(self.n)
|
||||
# Random permutation on the data
|
||||
self.shufflei = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
self.index_type(regint.inc(self.n)))
|
||||
# Calculate the period if not given
|
||||
# upon recursion, the period should stay the same ("in sync"),
|
||||
# therefore it can be passed as a constructor parameter
|
||||
@@ -122,7 +122,7 @@ class SqrtOram(Generic[T, B]):
|
||||
# Note that self.shuffle_the_shuffle mutates this field
|
||||
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
|
||||
self.permutation = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
self.index_type(regint.inc(self.n)))
|
||||
# We allow the caller to postpone the initialization of the shuffle
|
||||
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
|
||||
# Note that if you do not initialize, the ORAM is insecure
|
||||
@@ -256,7 +256,6 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
return result
|
||||
|
||||
@lib.method_block
|
||||
def write(self, index: T, *value: T):
|
||||
global trace, n_parallel
|
||||
if trace:
|
||||
@@ -271,7 +270,12 @@ class SqrtOram(Generic[T, B]):
|
||||
else:
|
||||
raise Exception("Cannot handle type of value passed")
|
||||
print(self.entry_length, value, type(value),len(value))
|
||||
value = MemValue(value)
|
||||
|
||||
self._write(index, *value)
|
||||
|
||||
@lib.method_block
|
||||
def _write(self, index: T, *value: T):
|
||||
value = MemValue(self.value_type(value))
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@@ -513,14 +517,14 @@ class SqrtOram(Generic[T, B]):
|
||||
# Since the underlying memory of the position map is already aligned in
|
||||
# this packed structure, we can simply overwrite the memory while
|
||||
# maintaining the structure.
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
self.position_map.reinitialize(self.permutation)
|
||||
|
||||
def reinitialize(self, *data: T):
|
||||
def reinitialize(self, data: T):
|
||||
# Note that this method is only used during refresh, and as such is
|
||||
# only called with a permutation as data.
|
||||
|
||||
# The logical addresses of some previous permutation are irrelevant and must be reset
|
||||
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
|
||||
self.shufflei.assign_vector(self.index_type(regint.inc(self.n)))
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
@@ -530,10 +534,10 @@ class SqrtOram(Generic[T, B]):
|
||||
# This structure is preserved while overwriting the values using
|
||||
# assign_vector
|
||||
self.shuffle.assign_vector(self.value_type(
|
||||
data, size=self.n * self.entry_length))
|
||||
data[:], size=self.n * self.entry_length))
|
||||
# Note that this updates self.permutation (see constructor for explanation)
|
||||
self.shuffle_the_shuffle()
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
self.position_map.reinitialize(self.permutation)
|
||||
|
||||
def _reset_shuffle_used(self):
|
||||
global allow_memory_allocation
|
||||
@@ -568,7 +572,7 @@ class PositionMap(Generic[T, B]):
|
||||
print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
|
||||
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
def reinitialize(self, permutation: T):
|
||||
"""Reinitialize this PositionMap.
|
||||
|
||||
Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
|
||||
@@ -613,9 +617,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
packed_size = int(math.ceil(self.n / pack))
|
||||
packed_structure = MultiArray(
|
||||
(packed_size, pack), value_type=value_type)
|
||||
for i in range(packed_size):
|
||||
@lib.for_range(packed_size)
|
||||
def _(i):
|
||||
packed_structure[i] = Array.create_from(
|
||||
permutation[i*pack:(i+1)*pack])
|
||||
permutation.get_vector(base=i * pack, size=pack))
|
||||
|
||||
SqrtOram.__init__(self, packed_structure, value_type=value_type,
|
||||
period=period, entry_length=pack, k=self.depth,
|
||||
@@ -720,8 +725,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
|
||||
return p.reveal()
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
SqrtOram.reinitialize(self, *permutation)
|
||||
def reinitialize(self, permutation: T):
|
||||
SqrtOram.reinitialize(self, permutation)
|
||||
|
||||
|
||||
class LinearPositionMap(PositionMap):
|
||||
@@ -790,8 +795,8 @@ class LinearPositionMap(PositionMap):
|
||||
|
||||
return p.reveal()
|
||||
|
||||
def reinitialize(self, *data: T):
|
||||
self.physical.assign_vector(data)
|
||||
def reinitialize(self, data : T):
|
||||
self.physical.assign(data)
|
||||
|
||||
global allow_memory_allocation
|
||||
if allow_memory_allocation:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -28,11 +28,11 @@ def greater_than(a, b, bits):
|
||||
else:
|
||||
return a.greater_than(b, bits)
|
||||
|
||||
def pow2_value(a, bit_length=None, security=None):
|
||||
def pow2_value(a, bit_length=None):
|
||||
if is_constant_float(a):
|
||||
return 2**a
|
||||
else:
|
||||
return a.pow2(bit_length, security)
|
||||
return a.pow2(bit_length)
|
||||
|
||||
def mod2m(a, b, bits, signed):
|
||||
if isinstance(a, int):
|
||||
@@ -292,6 +292,9 @@ class dict_by_id(object):
|
||||
def __iter__(self):
|
||||
return self.keys()
|
||||
|
||||
def pop(self, key):
|
||||
return self.content.pop(id(key), None)
|
||||
|
||||
class defaultdict_by_id(dict_by_id):
|
||||
def __init__(self, default):
|
||||
dict_by_id.__init__(self)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
int main()
|
||||
{
|
||||
P256Element::init();
|
||||
P256Element::Scalar key;
|
||||
KeySetup<Share<P256Element::Scalar>> key;
|
||||
string prefix = PREP_DIR "ECDSA/";
|
||||
mkdir_p(prefix.c_str());
|
||||
write_online_setup(prefix, P256Element::Scalar::pr());
|
||||
|
||||
@@ -10,10 +10,12 @@
|
||||
#include "Math/gfp.h"
|
||||
#include "ECDSA/P256Element.h"
|
||||
#include "GC/VectorInput.h"
|
||||
#include "Protocols/SPDZ.h"
|
||||
|
||||
#include "ECDSA/preprocessing.hpp"
|
||||
#include "ECDSA/sign.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/Hemi.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
@@ -44,6 +46,7 @@ int main(int argc, const char** argv)
|
||||
typedef Share<P256Element::Scalar> pShare;
|
||||
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
|
||||
read_mac_key(prefix, N, keyp);
|
||||
pShare::set_mac_key(keyp);
|
||||
|
||||
pShare::MAC_Check::setup(P);
|
||||
Share<P256Element>::MAC_Check::setup(P);
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
#include "GC/RepPrep.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Machines/Shamir.hpp"
|
||||
#include "Machines/MalRep.hpp"
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "Math/gfp.h"
|
||||
#include "ECDSA/P256Element.h"
|
||||
#include "Protocols/SemiShare.h"
|
||||
#include "Protocols/SPDZ.h"
|
||||
#include "Processor/BaseMachine.h"
|
||||
|
||||
#include "ECDSA/preprocessing.hpp"
|
||||
@@ -15,6 +16,7 @@
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Protocols/Hemi.hpp"
|
||||
#include "Processor/Processor.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
|
||||
@@ -37,6 +37,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
EcdsaOptions opts)
|
||||
{
|
||||
bool prep_mul = opts.prep_mul;
|
||||
if (prep_mul)
|
||||
proc.protocol.init_mul();
|
||||
|
||||
Timer timer;
|
||||
timer.start();
|
||||
Player& P = proc.P;
|
||||
@@ -77,7 +80,6 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (prep_mul)
|
||||
{
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
protocol.prepare_mul(inv_ks[i], sk);
|
||||
protocol.start_exchange();
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "GC/SemiSecret.h"
|
||||
#include "GC/SemiPrep.h"
|
||||
|
||||
#include "Protocols/Hemi.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/SemiPrep.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
|
||||
@@ -24,6 +24,13 @@ Client::Client(const vector<string>& hostnames, int port_base,
|
||||
"P" + to_string(i), "C" + to_string(my_client_id), true);
|
||||
if (i == 0)
|
||||
specification.Receive(sockets[0]);
|
||||
else
|
||||
{
|
||||
octetStream spec;
|
||||
spec.Receive(sockets[i]);
|
||||
if (spec != specification)
|
||||
throw runtime_error("inconsistent specification");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,11 +17,10 @@
|
||||
* - share of winning unique id * random value [w]
|
||||
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
|
||||
*
|
||||
* To run with 2 parties / SPDZ engines:
|
||||
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
|
||||
* To run:
|
||||
* ./Scripts/setup-clients.sh to create SSL keys and certificates for clients
|
||||
* ./compile.py bankers_bonus
|
||||
* ./Scripts/run-online.sh bankers_bonus to run the engines.
|
||||
* ./Scripts/compile-run.py <protocol> bankers_bonus to compile and run the engines.
|
||||
* (See https://github.com/data61/MP-SPDZ/?tab=readme-ov-file#protocols for options.)
|
||||
*
|
||||
* ./bankers-bonus-client.x 0 2 100 0
|
||||
* ./bankers-bonus-client.x 1 2 200 0
|
||||
|
||||
@@ -31,6 +31,15 @@ def set_keepalive_osx(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
|
||||
|
||||
class Client:
|
||||
"""Client to servers running secure computation. Works both as a client
|
||||
to all parties or a trusted client to a single party.
|
||||
|
||||
:param hostnames: hostnames or IP addresses to connect to
|
||||
:param port_base: port number for first hostname,
|
||||
increases by one for every additional hostname
|
||||
:param my_client_id: number to identify client
|
||||
|
||||
"""
|
||||
def __init__(self, hostnames, port_base, my_client_id):
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
|
||||
name = 'C%d' % my_client_id
|
||||
@@ -62,6 +71,11 @@ class Client:
|
||||
|
||||
self.specification = octetStream()
|
||||
self.specification.Receive(self.sockets[0])
|
||||
for sock in self.sockets[1:]:
|
||||
specification = octetStream()
|
||||
specification.Receive(sock)
|
||||
if specification.buf != self.specification.buf:
|
||||
raise Exception('inconsistent specification')
|
||||
type = self.specification.get_int(4)
|
||||
if type == ord('R'):
|
||||
self.domain = Z2(self.specification.get_int(4))
|
||||
@@ -99,6 +113,12 @@ class Client:
|
||||
return triples
|
||||
|
||||
def send_private_inputs(self, values):
|
||||
""" Send inputs privately to the computation servers.
|
||||
This assumes that the client is connected to all servers.
|
||||
|
||||
:param values: list of input values
|
||||
|
||||
"""
|
||||
T = self.domain
|
||||
triples = self.receive_triples(T, len(values))
|
||||
os = octetStream()
|
||||
@@ -109,10 +129,46 @@ class Client:
|
||||
os.Send(socket)
|
||||
|
||||
def receive_outputs(self, n):
|
||||
""" Receive outputs privately from the computation servers.
|
||||
This assumes that the client is connected to all servers.
|
||||
|
||||
:param n: number of outputs
|
||||
|
||||
"""
|
||||
T = self.domain
|
||||
triples = self.receive_triples(T, n)
|
||||
return [int(self.clear_domain(triple[0].v)) for triple in triples]
|
||||
|
||||
def send_public_inputs(self, values):
|
||||
""" Send values in the clear. This works for public inputs
|
||||
to all servers or to send shares to a single server.
|
||||
|
||||
:param values: list of values
|
||||
|
||||
"""
|
||||
os = octetStream()
|
||||
for value in values:
|
||||
self.domain(value).pack(os)
|
||||
for socket in self.sockets:
|
||||
os.Send(socket)
|
||||
|
||||
def receive_plain_values(self, socket=None):
|
||||
""" Receive values in the clear. This works for public inputs
|
||||
to all servers or to send shares to a single server.
|
||||
|
||||
:param socket: socket to use (need to specify it there is more than one)
|
||||
|
||||
"""
|
||||
if socket is None:
|
||||
if len(self.sockets) != 1:
|
||||
raise Exception('need to specify socket')
|
||||
socket = self.sockets[0]
|
||||
os = octetStream()
|
||||
os.Receive(socket)
|
||||
assert len(os) % self.domain.size() == 0
|
||||
return [int(os.get(self.domain))
|
||||
for i in range(len(os) // self.domain.size())]
|
||||
|
||||
class octetStream:
|
||||
def __init__(self, value=None):
|
||||
self.buf = b''
|
||||
@@ -123,6 +179,8 @@ class octetStream:
|
||||
def get_length(self):
|
||||
return len(self.buf)
|
||||
|
||||
__len__ = get_length
|
||||
|
||||
def reset_write_head(self):
|
||||
self.buf = b''
|
||||
self.ptr = 0
|
||||
@@ -164,6 +222,11 @@ class octetStream:
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get(self, type):
|
||||
res = type()
|
||||
res.unpack(self)
|
||||
return res
|
||||
|
||||
def consume(self, length):
|
||||
self.ptr += length
|
||||
assert self.ptr <= len(self.buf)
|
||||
|
||||
@@ -7,7 +7,7 @@ class Domain:
|
||||
|
||||
def __int__(self):
|
||||
res = self.v % self.modulus
|
||||
return res if res < self.modulus / 2 else res - self.modulus
|
||||
return int(res if res < self.modulus / 2 else res - self.modulus)
|
||||
|
||||
def __add__(self, other):
|
||||
try:
|
||||
|
||||
19
ExternalIO/personal-client-example.py
Executable file
19
ExternalIO/personal-client-example.py
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/python3
|
||||
|
||||
import sys, random
|
||||
|
||||
sys.path.insert(0, 'ExternalIO')
|
||||
|
||||
from client import *
|
||||
|
||||
party = int(sys.argv[1])
|
||||
|
||||
client = Client(['localhost'], 15000 + party, 0)
|
||||
|
||||
n = 1000
|
||||
|
||||
if party < 2:
|
||||
client.send_public_inputs(random.gauss(0, 1) * 2 ** 16 for i in range(n))
|
||||
|
||||
x = [client.receive_plain_values() for i in range(2)]
|
||||
client.send_public_inputs(a + b for a, b in zip(*x))
|
||||
@@ -16,14 +16,14 @@ template<class T>
|
||||
class AddableVector: public vector<T>
|
||||
{
|
||||
public:
|
||||
AddableVector<T>() {}
|
||||
AddableVector<T>(size_t n, const T& x = T()) : vector<T>(n, x) {}
|
||||
AddableVector() {}
|
||||
AddableVector(size_t n, const T& x = T()) : vector<T>(n, x) {}
|
||||
template <class U, class FD, class S>
|
||||
AddableVector<T>(const Plaintext<U,FD,S>& other) :
|
||||
AddableVector<T>(other.get_poly()) {}
|
||||
AddableVector(const Plaintext<U,FD,S>& other) :
|
||||
AddableVector(other.get_poly()) {}
|
||||
|
||||
template <class U>
|
||||
AddableVector<T>(const vector<U>& other)
|
||||
AddableVector(const vector<U>& other)
|
||||
{
|
||||
this->assign(other.begin(), other.end());
|
||||
}
|
||||
@@ -129,29 +129,41 @@ public:
|
||||
(*this)[i].pack(os);
|
||||
}
|
||||
|
||||
void unpack_size(octetStream& os, const T& init = T())
|
||||
size_t unpack_size(octetStream& os)
|
||||
{
|
||||
unsigned int size;
|
||||
os.get(size);
|
||||
this->resize(size, init);
|
||||
this->reserve(size);
|
||||
return size;
|
||||
}
|
||||
|
||||
void unpack(octetStream& os, const T& init = T())
|
||||
{
|
||||
unpack_size(os, init);
|
||||
for (unsigned int i = 0; i < this->size(); i++)
|
||||
(*this)[i].unpack(os);
|
||||
size_t new_size = unpack_size(os);
|
||||
this->clear();
|
||||
for (unsigned int i = 0; i < new_size; i++)
|
||||
{
|
||||
this->push_back(init);
|
||||
this->back().unpack(os);
|
||||
}
|
||||
}
|
||||
|
||||
void add(octetStream& os, T& tmp)
|
||||
{
|
||||
unpack_size(os, tmp);
|
||||
size_t new_size = unpack_size(os);
|
||||
T init = tmp;
|
||||
T& item = tmp;
|
||||
for (unsigned int i = 0; i < this->size(); i++)
|
||||
{
|
||||
item.unpack(os);
|
||||
(*this)[i] += item;
|
||||
}
|
||||
for (size_t i = this->size(); i < new_size; i++)
|
||||
{
|
||||
item.unpack(os);
|
||||
this->push_back(init);
|
||||
this->back() += item;
|
||||
}
|
||||
}
|
||||
|
||||
T infinity_norm() const
|
||||
|
||||
@@ -61,9 +61,7 @@ bigint FHE_Params::Q() const
|
||||
|
||||
void FHE_Params::pack(octetStream& o) const
|
||||
{
|
||||
o.store(FFTData.size());
|
||||
for(auto& fd: FFTData)
|
||||
fd.pack(o);
|
||||
o.store(FFTData);
|
||||
Chi.pack(o);
|
||||
Bval.pack(o);
|
||||
o.store(sec_p);
|
||||
@@ -73,11 +71,7 @@ void FHE_Params::pack(octetStream& o) const
|
||||
|
||||
void FHE_Params::unpack(octetStream& o)
|
||||
{
|
||||
size_t size;
|
||||
o.get(size);
|
||||
FFTData.resize(size);
|
||||
for (auto& fd : FFTData)
|
||||
fd.unpack(o);
|
||||
o.get(FFTData);
|
||||
Chi.unpack(o);
|
||||
Bval.unpack(o);
|
||||
o.get(sec_p);
|
||||
|
||||
@@ -311,22 +311,18 @@ void imatrix::hash(octetStream& o) const
|
||||
|
||||
void imatrix::pack(octetStream& o) const
|
||||
{
|
||||
o.store(size());
|
||||
for (auto& x : *this)
|
||||
{
|
||||
assert(x.size() == size());
|
||||
x.pack(o);
|
||||
}
|
||||
o.store(static_cast<const super&>(*this));
|
||||
}
|
||||
|
||||
void imatrix::unpack(octetStream& o)
|
||||
{
|
||||
size_t size;
|
||||
o.get(size);
|
||||
resize(size);
|
||||
o.get(static_cast<super&>(*this));
|
||||
for (auto& x : *this)
|
||||
{
|
||||
x.resize(size);
|
||||
x.unpack(o);
|
||||
assert(x.size() == size());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ typedef vector< vector<bigint> > matrix;
|
||||
|
||||
class imatrix : public vector< BitVector >
|
||||
{
|
||||
typedef vector<BitVector> super;
|
||||
|
||||
public:
|
||||
bool operator!=(const imatrix& other) const;
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
#include "FHEOffline/Proof.h"
|
||||
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
@@ -735,9 +737,9 @@ void load_or_generate(P2Data& P2D, const Ring& R)
|
||||
{
|
||||
P2D.load(R);
|
||||
}
|
||||
catch (...)
|
||||
catch (exception& e)
|
||||
{
|
||||
cout << "Loading failed" << endl;
|
||||
cerr << "Loading parameters failed, generating (" << e.what() << ")" << endl;
|
||||
init(P2D,R);
|
||||
P2D.store(R);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "FHE/P2Data.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Math/fixint.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
#include <fstream>
|
||||
|
||||
|
||||
@@ -74,7 +75,6 @@ bool P2Data::operator!=(const P2Data& other) const
|
||||
|
||||
void P2Data::hash(octetStream& o) const
|
||||
{
|
||||
check_dimensions();
|
||||
o.store(gf2n_short::degree());
|
||||
o.store(slots);
|
||||
A.hash(o);
|
||||
@@ -113,17 +113,18 @@ string get_filename(const Ring& Rg)
|
||||
void P2Data::load(const Ring& Rg)
|
||||
{
|
||||
string filename = get_filename(Rg);
|
||||
cout << "Loading from " << filename << endl;
|
||||
ifstream s(filename);
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Loading from " << filename << endl;
|
||||
octetStream os;
|
||||
os.input(s);
|
||||
os.input(filename);
|
||||
unpack(os);
|
||||
}
|
||||
|
||||
void P2Data::store(const Ring& Rg) const
|
||||
{
|
||||
string filename = get_filename(Rg);
|
||||
cout << "Storing in " << filename << endl;
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Storing in " << filename << endl;
|
||||
ofstream s(filename);
|
||||
octetStream os;
|
||||
pack(os);
|
||||
|
||||
@@ -562,22 +562,17 @@ template <class T,class FD,class S>
|
||||
void Plaintext<T,FD,S>::pack(octetStream& o) const
|
||||
{
|
||||
to_poly();
|
||||
o.store((unsigned int)b.size());
|
||||
for (unsigned int i = 0; i < b.size(); i++)
|
||||
o.store(b[i]);
|
||||
o.store(b);
|
||||
}
|
||||
|
||||
template <class T,class FD,class S>
|
||||
void Plaintext<T,FD,S>::unpack(octetStream& o)
|
||||
{
|
||||
type = Polynomial;
|
||||
unsigned int size;
|
||||
o.get(size);
|
||||
allocate();
|
||||
o.get(b);
|
||||
auto size = b.size();
|
||||
allocate(Polynomial);
|
||||
if (size != b.size() and size != 0)
|
||||
throw length_error("unexpected length received");
|
||||
for (unsigned int i = 0; i < size; i++)
|
||||
b[i] = o.get<S>();
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -512,9 +512,7 @@ modp Ring_Element::get_constant() const
|
||||
void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
|
||||
{
|
||||
ZpD.pack(o);
|
||||
o.store((int)v.size());
|
||||
for (unsigned int i=0; i<v.size(); i++)
|
||||
{ v[i].pack(o,ZpD); }
|
||||
o.store(v);
|
||||
}
|
||||
|
||||
|
||||
@@ -526,16 +524,7 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
|
||||
throw runtime_error(
|
||||
"mismatch: " + to_string(check_Zpd.pr_bit_length) + "/"
|
||||
+ to_string(ZpD.pr_bit_length));
|
||||
unsigned int length;
|
||||
o.get(length);
|
||||
v.clear();
|
||||
v.reserve(length);
|
||||
modp tmp;
|
||||
for (unsigned int i=0; i<length; i++)
|
||||
{
|
||||
tmp.unpack(o,ZpD);
|
||||
v.push_back(tmp);
|
||||
}
|
||||
o.get(v);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef _Subroutines
|
||||
#define _Subroutines
|
||||
#ifndef FHE_SUBROUTINES_H_
|
||||
#define FHE_SUBROUTINES_H_
|
||||
|
||||
|
||||
#include "Math/Zp_Data.h"
|
||||
|
||||
@@ -27,8 +27,7 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
|
||||
|
||||
try
|
||||
{
|
||||
ifstream input(filename);
|
||||
os.input(input);
|
||||
os.input(filename);
|
||||
setup.unpack(os);
|
||||
machine.unpack(os);
|
||||
}
|
||||
@@ -44,11 +43,15 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
|
||||
}
|
||||
catch (mismatch_among_parties& e)
|
||||
{
|
||||
error = e.what();
|
||||
if (error.empty())
|
||||
error = e.what();
|
||||
}
|
||||
|
||||
if (not error.empty())
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("expect_setup"))
|
||||
throw runtime_error("error in setup: " + error);
|
||||
|
||||
cerr << "Running secrets generation because no suitable material "
|
||||
"from a previous run was found (" << error << ")" << endl;
|
||||
setup.key_and_mac_generation(P, machine, num_runs, V());
|
||||
|
||||
@@ -80,9 +80,8 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
|
||||
try
|
||||
{
|
||||
ifstream file(filename);
|
||||
octetStream os;
|
||||
os.input(file);
|
||||
os.input(filename);
|
||||
os.get(machine.extra_slack);
|
||||
setup.unpack(os);
|
||||
}
|
||||
@@ -95,13 +94,17 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
{
|
||||
setup.check(P, machine);
|
||||
}
|
||||
catch (exception& e)
|
||||
catch (mismatch_among_parties& e)
|
||||
{
|
||||
reason = e.what();
|
||||
if (reason.empty())
|
||||
reason = e.what();
|
||||
}
|
||||
|
||||
if (not reason.empty())
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("expect_setup"))
|
||||
throw runtime_error("error in setup: " + reason);
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Generating parameters for security " << sec
|
||||
<< " and field size ~2^" << plaintext_length
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "SimpleMachine.h"
|
||||
#include "Tools/mkpath.h"
|
||||
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
template<class FD>
|
||||
Producer<FD>::Producer(int output_thread, bool write_output) :
|
||||
n_slots(0), output_thread(output_thread), write_output(write_output),
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
#include "BitAdder.h"
|
||||
|
||||
#include "Protocols/BufferScope.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
template<class T>
|
||||
@@ -69,6 +71,9 @@ void BitAdder::add(vector<vector<T> >& res,
|
||||
|
||||
size_t n_items = end - begin;
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%lu ANDs in bit adder\n", length * n_items * n_bits);
|
||||
|
||||
if (supply)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
@@ -85,6 +90,7 @@ void BitAdder::add(vector<vector<T> >& res,
|
||||
vector<T> carries(n_items);
|
||||
vector<T> a(n_items), b(n_items);
|
||||
auto& protocol = proc.protocol;
|
||||
BufferScope scope(proc.DataF, n_items * length * n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
{
|
||||
assert(summands[i].size() == 2);
|
||||
|
||||
@@ -24,14 +24,14 @@ void FakeSecret::load_clear(int n, const Integer& x)
|
||||
*this = x;
|
||||
}
|
||||
|
||||
void FakeSecret::bitcom(Memory<FakeSecret>& S, const vector<int>& regs)
|
||||
void FakeSecret::bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
*this ^= (S[regs[i]] << i);
|
||||
}
|
||||
|
||||
void FakeSecret::bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const
|
||||
void FakeSecret::bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
|
||||
@@ -142,8 +142,8 @@ public:
|
||||
template <class T>
|
||||
void store(Memory<T>& mem, size_t address) { mem[address] = *this; }
|
||||
|
||||
void bitcom(Memory<FakeSecret>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const;
|
||||
void bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs);
|
||||
void bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const;
|
||||
|
||||
template <class T>
|
||||
void xor_(int n, const FakeSecret& x, const T& y)
|
||||
|
||||
@@ -82,10 +82,24 @@ void Instruction::parse(istream& s, int pos)
|
||||
#undef X
|
||||
default:
|
||||
ostringstream os;
|
||||
os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl;
|
||||
os << "This virtual machine executes binary circuits only." << endl;
|
||||
os << "Use 'compile.py -B'." << endl;
|
||||
throw Invalid_Instruction(os.str());
|
||||
os << "Code not defined for instruction ";
|
||||
bool known = true;
|
||||
switch (opcode)
|
||||
{
|
||||
#define X(NAME, PRE, CODE) case NAME: os << #NAME; break;
|
||||
ALL_INSTRUCTIONS
|
||||
#undef X
|
||||
default:
|
||||
known = false;
|
||||
os << showbase << hex << opcode << dec;
|
||||
}
|
||||
os << endl;
|
||||
if (known)
|
||||
{
|
||||
os << "This virtual machine executes binary circuits only. ";
|
||||
os << "Use 'compile.py -B'.";
|
||||
}
|
||||
exit_error(os.str());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ public:
|
||||
|
||||
typedef MaliciousRepMC<U> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef HashMaliciousRepMC<U> DefaultMC;
|
||||
|
||||
typedef ReplicatedInput<U> Input;
|
||||
typedef RepPrep<U> LivePrep;
|
||||
|
||||
15
GC/NoShare.h
15
GC/NoShare.h
@@ -52,7 +52,7 @@ public:
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
throw not_implemented();
|
||||
return DATA_GF2;
|
||||
}
|
||||
|
||||
static void init_minimum(int)
|
||||
@@ -80,7 +80,8 @@ public:
|
||||
|
||||
bool operator!=(NoValue) const { return false; }
|
||||
|
||||
bool operator==(int) { fail(); return false; }
|
||||
bool operator==(int) const { fail(); return false; }
|
||||
bool operator==(NoValue) const { fail(); return false; }
|
||||
|
||||
bool get_bit(int) { fail(); return 0; }
|
||||
|
||||
@@ -92,6 +93,8 @@ public:
|
||||
|
||||
void input(istream&, bool) { fail(); }
|
||||
void output(ostream&, bool) {}
|
||||
|
||||
void pack(octetStream&) const { fail(); }
|
||||
};
|
||||
|
||||
inline ostream& operator<<(ostream& o, NoValue)
|
||||
@@ -169,8 +172,8 @@ public:
|
||||
|
||||
void load_clear(Integer, Integer) { fail(); }
|
||||
void random_bit() { fail(); }
|
||||
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitdec(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitcom(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
|
||||
void assign(const char*) { fail(); }
|
||||
|
||||
@@ -190,6 +193,8 @@ public:
|
||||
|
||||
NoShare& operator+=(const NoShare&) { fail(); return *this; }
|
||||
|
||||
bool operator==(NoShare) const { fail(); return false; }
|
||||
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
|
||||
void xor_bit(int, NoShare) const { fail(); }
|
||||
@@ -201,6 +206,8 @@ public:
|
||||
|
||||
void input(istream&, bool) { fail(); }
|
||||
void output(ostream&, bool) { fail(); }
|
||||
|
||||
void pack(octetStream&) const { fail(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -79,10 +79,11 @@ template<class T>
|
||||
void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
|
||||
size_t begin, size_t end)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
|
||||
RunningTimer timer;
|
||||
#endif
|
||||
bool verbose = OnlineOptions::singleton.has_option("verbose_eda");
|
||||
if (verbose)
|
||||
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
|
||||
|
||||
auto& party = ShareThread<typename T::whole_type>::s();
|
||||
auto& MC = party.MC->get_part_MC();
|
||||
auto& P = *party.P;
|
||||
@@ -102,9 +103,9 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
|
||||
input.exchange();
|
||||
for (size_t i = begin; i < end; i++)
|
||||
triples[i][2] = input.finalize(input_player, T::default_length);
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
|
||||
#endif
|
||||
|
||||
if (verbose)
|
||||
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ using namespace std;
|
||||
#include "Math/Integer.h"
|
||||
#include "Processor/ProcessorBase.h"
|
||||
#include "Processor/Instruction.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -38,9 +39,9 @@ public:
|
||||
// rough measure for the memory usage
|
||||
size_t complexity;
|
||||
|
||||
Memory<T> S;
|
||||
Memory<Clear> C;
|
||||
Memory<Integer> I;
|
||||
StackedVector<T> S;
|
||||
StackedVector<Clear> C;
|
||||
StackedVector<Integer> I;
|
||||
|
||||
Timer xor_timer;
|
||||
|
||||
@@ -78,8 +79,8 @@ public:
|
||||
template<class U>
|
||||
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
|
||||
|
||||
template<class U>
|
||||
void mem_op(int n, Memory<U>& dest, const Memory<U>& source,
|
||||
template<class U, class V>
|
||||
void mem_op(int n, U& dest, const V& source,
|
||||
Integer dest_address, Integer source_address);
|
||||
|
||||
void xors(const vector<int>& args);
|
||||
@@ -105,6 +106,9 @@ public:
|
||||
template<int = 0>
|
||||
void convcbit2s(const BaseInstruction& instruction);
|
||||
|
||||
void convcbitvec(const BaseInstruction& instruction, StackedVector<Integer>& Ci,
|
||||
Player* P);
|
||||
|
||||
void print_reg(int reg, int n, int size);
|
||||
void print_reg_plain(Clear& value);
|
||||
void print_reg_signed(unsigned n_bits, Integer value);
|
||||
@@ -114,6 +118,13 @@ public:
|
||||
void print_float_prec(int n);
|
||||
|
||||
void incint(const BaseInstruction& instruction);
|
||||
|
||||
void push_stack();
|
||||
void push_args(const vector<int>& args);
|
||||
void pop_stack(const vector<int>& results);
|
||||
|
||||
template<class U>
|
||||
void call_tape(const BaseInstruction& instruction, U& dynamic_memory);
|
||||
};
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -18,7 +18,7 @@ using namespace std;
|
||||
#include "Math/BitVec.h"
|
||||
|
||||
#include "GC/Machine.hpp"
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "Processor/Processor.hpp"
|
||||
#include "Processor/IntInput.hpp"
|
||||
#include "Math/bigint.hpp"
|
||||
|
||||
@@ -53,9 +53,9 @@ template <class T>
|
||||
template <class U>
|
||||
void Processor<T>::reset(const U& program, int arg)
|
||||
{
|
||||
S.resize(program.num_reg(SBIT), "registers");
|
||||
C.resize(program.num_reg(CBIT), "registers");
|
||||
I.resize(program.num_reg(INT), "registers");
|
||||
S.resize(program.num_reg(SBIT));
|
||||
C.resize(program.num_reg(CBIT));
|
||||
I.resize(program.num_reg(INT));
|
||||
set_arg(arg);
|
||||
PC = 0;
|
||||
}
|
||||
@@ -202,14 +202,14 @@ void GC::Processor<T>::store_clear_in_dynamic(const vector<int>& args,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
|
||||
template<class U, class V>
|
||||
void Processor<T>::mem_op(int n, U& dest, const V& source,
|
||||
Integer dest_address, Integer source_address)
|
||||
{
|
||||
dest.check_index(dest_address + n - 1);
|
||||
source.check_index(source_address + n - 1);
|
||||
auto d = &dest[dest_address];
|
||||
auto s = &source[source_address];
|
||||
auto d = &dest[dest_address.get()];
|
||||
auto s = &source[source_address.get()];
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
*d++ = *s++;
|
||||
@@ -388,6 +388,30 @@ void Processor<T>::convcbit2s(const BaseInstruction& instruction)
|
||||
min(size_t(unit), instruction.get_n() - i * unit));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::convcbitvec(const BaseInstruction& instruction,
|
||||
StackedVector<Integer>& Ci, Player* P)
|
||||
{
|
||||
vector<Integer> bits;
|
||||
auto n = instruction.get_n();
|
||||
bits.reserve(n);
|
||||
for (size_t i = 0; i < instruction.get_n(); i++)
|
||||
{
|
||||
int i1 = i / GC::Clear::N_BITS;
|
||||
int i2 = i % GC::Clear::N_BITS;
|
||||
auto bit = C[instruction.get_r(1) + i1].get_bit(i2);
|
||||
bits.push_back(bit);
|
||||
}
|
||||
|
||||
if (P)
|
||||
sync<T>(bits, *P);
|
||||
else if (not T::symmetric)
|
||||
sync<T>(bits, *Thread<T>::s().P);
|
||||
|
||||
for (size_t i = 0; i < n; i++)
|
||||
Ci[instruction.get_r(0) + i] = bits[i];
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_reg(int reg, int n, int size)
|
||||
{
|
||||
@@ -417,7 +441,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
|
||||
{
|
||||
if (n_bits <= Clear::N_BITS)
|
||||
{
|
||||
auto value = C[reg];
|
||||
auto value = C[reg.get()];
|
||||
unsigned n_shift = 0;
|
||||
if (n_bits > 1)
|
||||
n_shift = sizeof(value.get()) * 8 - n_bits;
|
||||
@@ -477,6 +501,56 @@ void Processor<T>::incint(const BaseInstruction& instruction)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::Processor<T>::push_stack()
|
||||
{
|
||||
S.push_stack();
|
||||
C.push_stack();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::Processor<T>::push_args(const vector<int>& args)
|
||||
{
|
||||
S.push_args(args, SBIT);
|
||||
C.push_args(args, CBIT);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::Processor<T>::pop_stack(const vector<int>& results)
|
||||
{
|
||||
S.pop_stack(results, SBIT);
|
||||
C.pop_stack(results, CBIT);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void Processor<T>::call_tape(const BaseInstruction& instruction, U& dynamic_memory)
|
||||
{
|
||||
auto new_arg = I.at(instruction.get_r(1)).get();
|
||||
|
||||
PC_stack.push_back(PC);
|
||||
arg_stack.push_back(this->arg);
|
||||
push_stack();
|
||||
I.push_stack();
|
||||
|
||||
auto& tape = machine->progs.at(instruction.get_r(0));
|
||||
reset(tape, new_arg);
|
||||
|
||||
auto& args = instruction.get_start();
|
||||
push_args(args);
|
||||
I.push_args(args, INT);
|
||||
|
||||
tape.execute(*this, dynamic_memory, PC);
|
||||
|
||||
pop_stack(args);
|
||||
I.pop_stack(args, INT);
|
||||
|
||||
PC = PC_stack.back();
|
||||
PC_stack.pop_back();
|
||||
this->arg = arg_stack.back();
|
||||
arg_stack.pop_back();
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -71,9 +71,9 @@ void Program::parse(istream& s)
|
||||
CALLGRIND_STOP_INSTRUMENTATION;
|
||||
while (!s.eof())
|
||||
{
|
||||
instr.parse(s, pos);
|
||||
if (s.bad() or s.fail())
|
||||
throw runtime_error("error reading program");
|
||||
instr.parse(s, pos);
|
||||
p.push_back(instr);
|
||||
//cerr << "\t" << instr << endl;
|
||||
s.peek();
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "Protocols/Rep4.hpp"
|
||||
#include "Protocols/Rep4Input.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
|
||||
namespace GC
|
||||
|
||||
11
GC/Secret.h
11
GC/Secret.h
@@ -65,6 +65,8 @@ public:
|
||||
|
||||
typedef typename T::out_type out_type;
|
||||
|
||||
typedef void DefaultMC;
|
||||
|
||||
static string type_string() { return "evaluation secret"; }
|
||||
static string phase_name() { return T::name(); }
|
||||
|
||||
@@ -76,6 +78,10 @@ public:
|
||||
|
||||
static const bool actual_inputs = T::actual_inputs;
|
||||
|
||||
static const bool symmetric = true;
|
||||
|
||||
static bool real_shares(const Player&) { return true; }
|
||||
|
||||
static int threshold(int nplayers) { return T::threshold(nplayers); }
|
||||
|
||||
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
|
||||
@@ -148,9 +154,9 @@ public:
|
||||
Secret<T> operator>>(int i) const;
|
||||
|
||||
template<class U>
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitcom(StackedVector<U>& S, const vector<int>& regs);
|
||||
template<class U>
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
|
||||
|
||||
Secret<T> operator+(const Secret<T>& x) const;
|
||||
Secret<T>& operator+=(const Secret<T>& x) { *this = *this + x; return *this; }
|
||||
@@ -175,6 +181,7 @@ public:
|
||||
void finalize_input(U& inputter, int from, int n_bits);
|
||||
|
||||
int size() const { return registers.size(); }
|
||||
size_t maximum_size() const { return registers.size(); }
|
||||
RegVector& get_regs() { return registers; }
|
||||
const RegVector& get_regs() const { return registers; }
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ Secret<T> Secret<T>::operator>>(int i) const
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
void Secret<T>::bitcom(StackedVector<U>& S, const vector<int>& regs)
|
||||
{
|
||||
registers.clear();
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -210,7 +210,7 @@ void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
void Secret<T>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
void Secret<T>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
if (regs.size() > registers.size())
|
||||
throw overflow("not enough bits for bit decomposition", regs.size(),
|
||||
|
||||
@@ -17,7 +17,8 @@ namespace GC
|
||||
void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
|
||||
bool repeat)
|
||||
{
|
||||
if (repeat and OnlineOptions::singleton.live_prep)
|
||||
if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1)
|
||||
and P.num_players() == 2)
|
||||
{
|
||||
this->triples.push_back({{}});
|
||||
auto& triple = this->triples.back();
|
||||
@@ -35,6 +36,8 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
|
||||
|
||||
void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n)
|
||||
{
|
||||
if (n == -1)
|
||||
n = SemiSecret::default_length;
|
||||
super::prepare_mul(x.mask(n), y.mask(n), n);
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ public:
|
||||
|
||||
void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
|
||||
bool repeat);
|
||||
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n);
|
||||
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n = -1);
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -79,6 +79,9 @@ array<SemiSecret, 3> SemiPrep::get_mixed_triple(int n)
|
||||
if (mixed_triples.empty())
|
||||
{
|
||||
assert(this->triple_generator);
|
||||
this->triple_generator->set_batch_size(
|
||||
BaseMachine::batch_size<SemiSecret>(DATA_MIXED,
|
||||
this->buffer_size));
|
||||
this->triple_generator->generateMixedTriples();
|
||||
for (auto& x : this->triple_generator->mixedTriples)
|
||||
{
|
||||
|
||||
@@ -38,12 +38,24 @@ public:
|
||||
|
||||
static const int default_length = sizeof(BitVec) * 8;
|
||||
|
||||
static const bool symmetric = V::symmetric;
|
||||
|
||||
static bool real_shares(const Player& P)
|
||||
{
|
||||
return V::real_shares(P);
|
||||
}
|
||||
|
||||
static string type_string() { return "binary secret"; }
|
||||
static string phase_name() { return "Binary computation"; }
|
||||
|
||||
static void trans(Processor<T>& processor, int n_outputs,
|
||||
const vector<int>& args);
|
||||
|
||||
static size_t maximum_size()
|
||||
{
|
||||
return default_length;
|
||||
}
|
||||
|
||||
SemiSecretBase()
|
||||
{
|
||||
}
|
||||
@@ -64,8 +76,8 @@ public:
|
||||
|
||||
void load_clear(int n, const Integer& x);
|
||||
|
||||
void bitcom(Memory<T>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<T>& S, const vector<int>& regs) const;
|
||||
void bitcom(StackedVector<T>& S, const vector<int>& regs);
|
||||
void bitdec(StackedVector<T>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const T& x, const T& y)
|
||||
{ *this = BitVec(x ^ y).mask(n); }
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SEMISECRET_HPP_
|
||||
#define GC_SEMISECRET_HPP_
|
||||
|
||||
#include "GC/ShareParty.h"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
@@ -70,30 +73,40 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
|
||||
assert(protocol);
|
||||
protocol->init_mul();
|
||||
auto it = args.begin();
|
||||
int total_bits = 0, total_ops = 0;
|
||||
while (it < args.end())
|
||||
{
|
||||
int n_args = (*it++ - 3) / 2;
|
||||
int size = *it++;
|
||||
total_bits += n_args * size;
|
||||
it += n_args;
|
||||
int base = *it++;
|
||||
assert(n_args <= N_BITS);
|
||||
for (int i = 0; i < size; i += N_BITS)
|
||||
{
|
||||
square64 square;
|
||||
for (int j = 0; j < n_args; j++)
|
||||
square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get();
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
square.transpose(n_args, n_ops);
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
for (int k = 0; k < n_args; k += N_BITS)
|
||||
{
|
||||
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
|
||||
auto y_ext = SemiSecret(bit).extend_bit();
|
||||
protocol->prepare_mult(square.rows[j], y_ext, n_args, true);
|
||||
int left = min(N_BITS, n_args - k);
|
||||
square64 square;
|
||||
for (int j = 0; j < left; j++)
|
||||
square.rows[j] = processor.S.at(
|
||||
*(it + k + j) + i / N_BITS).get();
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
total_ops += n_ops;
|
||||
square.transpose(left, n_ops);
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
{
|
||||
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
|
||||
auto y_ext = SemiSecret(bit).extend_bit();
|
||||
protocol->prepare_mult(square.rows[j], y_ext, left, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
it += n_args;
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%d/%d repeat ANDs\n", total_bits, total_ops);
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
it = args.begin();
|
||||
@@ -103,13 +116,18 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
|
||||
int size = *it++;
|
||||
for (int i = 0; i < size; i += N_BITS)
|
||||
{
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
square64 square;
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
square.rows[j] = protocol->finalize_mul(n_args).get();
|
||||
square.transpose(n_ops, n_args);
|
||||
for (int j = 0; j < n_args; j++)
|
||||
processor.S.at(*(it + j) + i / N_BITS) = square.rows[j];
|
||||
for (int base = 0; base < n_args; base += N_BITS)
|
||||
{
|
||||
int left = min(N_BITS, n_args - base);
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
square64 square;
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
square.rows[j] = protocol->finalize_mul(left).get();
|
||||
square.transpose(n_ops, left);
|
||||
for (int j = 0; j < left; j++)
|
||||
processor.S.at(*(it + base + j) + i / N_BITS) =
|
||||
square.rows[j];
|
||||
}
|
||||
}
|
||||
it += 2 * n_args + 1;
|
||||
}
|
||||
@@ -123,7 +141,7 @@ void SemiSecretBase<T, V>::load_clear(int n, const Integer& x)
|
||||
}
|
||||
|
||||
template<class T, class V>
|
||||
void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
|
||||
void SemiSecretBase<T, V>::bitcom(StackedVector<T>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -131,7 +149,7 @@ void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
|
||||
}
|
||||
|
||||
template<class T, class V>
|
||||
void SemiSecretBase<T, V>::bitdec(Memory<T>& S,
|
||||
void SemiSecretBase<T, V>::bitdec(StackedVector<T>& S,
|
||||
const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -146,3 +164,5 @@ void SemiSecretBase<T, V>::reveal(size_t n_bits, Clear& x)
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -42,7 +42,7 @@ inline ShareParty<T>& ShareParty<T>::s()
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
throw runtime_error("no ShareParty singleton");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -108,18 +108,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
else
|
||||
P = new PlainPlayer(this->N, "shareparty");
|
||||
|
||||
try
|
||||
{
|
||||
read_mac_key(
|
||||
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
|
||||
this->N,
|
||||
this->mac_key);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
SeededPRNG G;
|
||||
this->mac_key.randomize(G);
|
||||
}
|
||||
T::read_or_generate_mac_key(
|
||||
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
|
||||
*P, this->mac_key);
|
||||
|
||||
T::MC::setup(*P);
|
||||
|
||||
|
||||
@@ -46,6 +46,9 @@ public:
|
||||
|
||||
static const bool is_real = true;
|
||||
static const bool actual_inputs = true;
|
||||
static const bool symmetric = true;
|
||||
|
||||
static bool real_shares(const Player&) { return true; }
|
||||
|
||||
static ShareThread<U>& get_party()
|
||||
{
|
||||
@@ -118,9 +121,12 @@ public:
|
||||
typedef BitVec open_type;
|
||||
typedef NoShare mac_type;
|
||||
typedef NoValue mac_key_type;
|
||||
typedef NoShare mac_share_type;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
typedef void DefaultMC;
|
||||
|
||||
static const int N_BITS = clear::N_BITS;
|
||||
|
||||
static const bool dishonest_majority = false;
|
||||
@@ -151,12 +157,22 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
static GC::NoValue get_mac_key()
|
||||
{
|
||||
throw runtime_error("no MAC");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
static string proto_fake_opts()
|
||||
{
|
||||
return T::fake_opts();
|
||||
}
|
||||
|
||||
static size_t maximum_size()
|
||||
{
|
||||
return default_length;
|
||||
}
|
||||
|
||||
RepSecretBase()
|
||||
{
|
||||
}
|
||||
@@ -166,8 +182,8 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
void bitcom(StackedVector<U>& S, const vector<int>& regs);
|
||||
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const This& x, const This& y)
|
||||
{ *this = (x ^ y).mask(n); }
|
||||
@@ -194,7 +210,7 @@ public:
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
static ReplicatedSecret constant(const typename super::clear& value,
|
||||
int my_num, typename super::mac_key_type, int = -1)
|
||||
int my_num, typename super::mac_key_type = {}, int = -1)
|
||||
{
|
||||
ReplicatedSecret res;
|
||||
if (my_num < 2)
|
||||
|
||||
@@ -54,7 +54,7 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
|
||||
}
|
||||
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
void RepSecretBase<U, L>::bitcom(StackedVector<U>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -62,7 +62,7 @@ void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
}
|
||||
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
void RepSecretBase<U, L>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
|
||||
@@ -67,7 +67,7 @@ inline ShareThread<T>& ShareThread<T>::s()
|
||||
if (singleton and T::is_real)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
throw runtime_error("no ShareThread singleton");
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -94,23 +94,35 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
processor.check_args(args, 4);
|
||||
protocol->init_mul();
|
||||
T x_ext, y_ext;
|
||||
int total_bits = 0;
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
total_bits += n_bits;
|
||||
int left = args[i + 2];
|
||||
int right = args[i + 3];
|
||||
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
|
||||
{
|
||||
int n = min(T::default_length, n_bits - j * T::default_length);
|
||||
|
||||
if (not repeat and n == T::default_length)
|
||||
{
|
||||
protocol->prepare_mul(processor.S[left + j], processor.S[right + j]);
|
||||
continue;
|
||||
}
|
||||
|
||||
processor.S[left + j].mask(x_ext, n);
|
||||
if (repeat)
|
||||
processor.S[right].extend_bit(y_ext, n);
|
||||
else
|
||||
processor.S[right + j].mask(y_ext, n);
|
||||
processor.S[left + j].mask(x_ext, n);
|
||||
protocol->prepare_mult(x_ext, y_ext, n, repeat);
|
||||
}
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%d%s ANDs\n", total_bits, repeat ? " repeat" : "");
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
@@ -121,10 +133,20 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
{
|
||||
int n = min(T::default_length, n_bits - j * T::default_length);
|
||||
auto& res = processor.S[out + j];
|
||||
|
||||
if (not repeat and n == T::default_length)
|
||||
{
|
||||
res = protocol->finalize_mul();
|
||||
continue;
|
||||
}
|
||||
|
||||
protocol->finalize_mult(res, n);
|
||||
res.mask(res, n);
|
||||
}
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("always_check"))
|
||||
protocol->check();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -136,10 +158,12 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
|
||||
protocol->init_mul();
|
||||
auto it = args.begin();
|
||||
T x_ext, y_ext;
|
||||
int total_bits = 0;
|
||||
while (it < args.end())
|
||||
{
|
||||
int n_args = (*it++ - 3) / 2;
|
||||
int size = *it++;
|
||||
total_bits += size * n_args;
|
||||
it += n_args;
|
||||
int base = *it++;
|
||||
for (int i = 0; i < size; i += N_BITS)
|
||||
@@ -155,6 +179,9 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
|
||||
it += n_args;
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%d repeat ANDs\n", total_bits);
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
it = args.begin();
|
||||
@@ -171,6 +198,9 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
|
||||
}
|
||||
it += 2 * n_args + 1;
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("always_check"))
|
||||
protocol->check();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -32,6 +32,8 @@ public:
|
||||
array<T, 3> get_triple_no_count(int n_bits)
|
||||
{
|
||||
int max_n_bits = T::default_length;
|
||||
if (n_bits == -1)
|
||||
n_bits = max_n_bits;
|
||||
assert(n_bits <= max_n_bits);
|
||||
assert(n_bits > 0);
|
||||
array<T, 3> res;
|
||||
|
||||
@@ -48,6 +48,7 @@ public:
|
||||
Thread(int thread_num, ThreadMaster<T>& master);
|
||||
virtual ~Thread();
|
||||
|
||||
void start();
|
||||
void run();
|
||||
virtual void pre_run() {}
|
||||
virtual void run(Program& program);
|
||||
|
||||
@@ -31,7 +31,12 @@ template<class T>
|
||||
Thread<T>::Thread(int thread_num, ThreadMaster<T>& master) :
|
||||
master(master), machine(master.machine), processor(machine),
|
||||
N(master.N), P(0),
|
||||
thread_num(thread_num)
|
||||
thread_num(thread_num), thread(0)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Thread<T>::start()
|
||||
{
|
||||
pthread_create(&thread, 0, run_thread, this);
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ public:
|
||||
virtual Thread<T>* new_thread(int i);
|
||||
|
||||
void run();
|
||||
void run_with_error();
|
||||
|
||||
virtual void post_run() {}
|
||||
};
|
||||
|
||||
@@ -59,6 +59,25 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
|
||||
|
||||
template<class T>
|
||||
void ThreadMaster<T>::run()
|
||||
{
|
||||
if (opts.has_option("throw_exceptions"))
|
||||
run_with_error();
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
run_with_error();
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
cerr << "Fatal error: " << e.what() << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ThreadMaster<T>::run_with_error()
|
||||
{
|
||||
if (not opts.live_prep)
|
||||
{
|
||||
@@ -72,6 +91,9 @@ void ThreadMaster<T>::run()
|
||||
|
||||
for (int i = 0; i < machine.nthreads; i++)
|
||||
threads.push_back(new_thread(i));
|
||||
// must start after constructor due to virtual functions
|
||||
for (auto thread : threads)
|
||||
thread->start();
|
||||
for (auto thread : threads)
|
||||
thread->join_tape();
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ public:
|
||||
return *part_MC;
|
||||
}
|
||||
|
||||
void init_open(const Player& P, int n)
|
||||
void init_open(const Player& P, int n = 0)
|
||||
{
|
||||
part_MC->init_open(P);
|
||||
sizes.clear();
|
||||
|
||||
@@ -53,9 +53,15 @@ public:
|
||||
static const bool malicious = T::malicious;
|
||||
static const bool expensive_triples = T::expensive_triples;
|
||||
static const bool randoms_for_opens = false;
|
||||
static const bool symmetric = true;
|
||||
|
||||
static const int default_length = 64;
|
||||
|
||||
static bool real_shares(const Player&)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static int size()
|
||||
{
|
||||
return part_type::size() * default_length;
|
||||
@@ -72,6 +78,11 @@ public:
|
||||
T::read_or_generate_mac_key(directory, P, key);
|
||||
}
|
||||
|
||||
static typename T::mac_type get_mac_key()
|
||||
{
|
||||
return T::get_mac_key();
|
||||
}
|
||||
|
||||
template<class U>
|
||||
static void reveal_inst(U& processor, const vector<int>& args)
|
||||
{
|
||||
|
||||
@@ -80,14 +80,15 @@
|
||||
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
|
||||
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
|
||||
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
|
||||
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
|
||||
X(CONVCBIT, Proc.write_Ci(R0, Proc.sync(PC1.get()))) \
|
||||
X(CONVCINTVEC, Proc.convcintvec(instruction)) \
|
||||
X(CONVCBITVEC, Proc.convcbitvec(instruction)) \
|
||||
X(CONVCBITVEC, Proc.Procb.convcbitvec(instruction, Proc.get_Ci(), &Proc.P)) \
|
||||
X(CONVCBIT2S, PROC.convcbit2s(instruction)) \
|
||||
X(DABIT, Proc.dabit(INST)) \
|
||||
X(EDABIT, Proc.edabit(INST)) \
|
||||
X(SEDABIT, Proc.edabit(INST, true)) \
|
||||
X(SPLIT, Proc.split(INST)) \
|
||||
X(CALL_ARG, ) \
|
||||
|
||||
#define GC_INSTRUCTIONS \
|
||||
X(INPUTB, T::inputb(PROC, EXTRA)) \
|
||||
@@ -101,6 +102,7 @@
|
||||
X(CONVCINT, C0 = PI1) \
|
||||
X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \
|
||||
X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \
|
||||
X(CONVCBITVEC, PROC.convcbitvec(instruction, Ci, 0)) \
|
||||
X(PRINTCHR, PROC.print_chr(IMM)) \
|
||||
X(PRINTSTR, PROC.print_str(IMM)) \
|
||||
X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \
|
||||
@@ -146,8 +148,11 @@
|
||||
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
|
||||
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
|
||||
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
|
||||
X(CRASH, if (I0.get()) throw crash_requested()) \
|
||||
X(CRASH, if (I0.get() and T::actual_inputs) throw crash_requested()) \
|
||||
X(ACTIVE, ) \
|
||||
X(LDTN, I0 = BaseMachine::thread_num) \
|
||||
X(CALL_TAPE, PROC.call_tape(INST, MD)) \
|
||||
X(CALL_ARG, ) \
|
||||
|
||||
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
The Software is copyright (c) 2023, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
|
||||
The Software is copyright (c) 2024, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
|
||||
|
||||
CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence.
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
#include "Protocols/AtlasPrep.h"
|
||||
#include "GC/AtlasSecret.h"
|
||||
|
||||
#include "ShamirMachine.hpp"
|
||||
#include "Protocols/Atlas.hpp"
|
||||
|
||||
#include "Shamir.hpp"
|
||||
|
||||
#endif /* MACHINES_ATLAS_HPP_ */
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Machines/Shamir.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Machines/MalRep.hpp"
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Machines/Shamir.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
|
||||
@@ -232,9 +232,9 @@ OTMachine::OTMachine(int argc, const char** argv)
|
||||
gettimeofday(&baseOTstart, NULL);
|
||||
// swap role for base OTs
|
||||
if (opt.isSet("-r"))
|
||||
bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role));
|
||||
bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role));
|
||||
else
|
||||
bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role));
|
||||
bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role));
|
||||
cout << "real mode " << opt.isSet("-r") << endl;
|
||||
BaseOT& bot = *bot_;
|
||||
bot.exec_base();
|
||||
|
||||
@@ -5,3 +5,5 @@
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
template class FieldMachine<Share, Share, DishonestMajorityMachine>;
|
||||
|
||||
template class Machine<Share<gfp_<0, 2>>, Share<gf2n>>;
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#ifndef MACHINES_SPDZ_HPP_
|
||||
#define MACHINES_SPDZ_HPP_
|
||||
|
||||
#include "Protocols/MAC_Check.h"
|
||||
#include "Protocols/SPDZ.h"
|
||||
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
|
||||
50
Machines/Shamir.hpp
Normal file
50
Machines/Shamir.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* ShamirMachine.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MACHINE_SHAMIR_HPP_
|
||||
#define MACHINE_SHAMIR_HPP_
|
||||
|
||||
#include "Protocols/ShamirOptions.h"
|
||||
#include "Protocols/ShamirShare.h"
|
||||
#include "Protocols/MaliciousShamirShare.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "GC/VectorProtocol.h"
|
||||
#include "GC/CcdPrep.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "Processor/FieldMachine.hpp"
|
||||
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Protocols/ShamirInput.hpp"
|
||||
#include "Protocols/Shamir.hpp"
|
||||
#include "Protocols/ShamirMC.hpp"
|
||||
#include "Protocols/MaliciousShamirMC.hpp"
|
||||
#include "Protocols/MaliciousShamirPO.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/Spdz2kPrep.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
template<template<class U> class T>
|
||||
ShamirMachineSpec<T>::ShamirMachineSpec(int argc, const char** argv)
|
||||
{
|
||||
auto& opts = ShamirOptions::singleton;
|
||||
ez::ezOptionParser opt;
|
||||
opts = {opt, argc, argv};
|
||||
HonestMajorityFieldMachine<T>(argc, argv, opt, opts.nparties);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "GC/PersonalPrep.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
//template class GC::ShareParty<GC::TinierSecret<gf2n_mac_key>>;
|
||||
template class GC::CcdPrep<GC::TinierSecret<gf2n_mac_key>>;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "Math/BitVec.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Machines/Shamir.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
8
Machines/export-atlas.cpp
Normal file
8
Machines/export-atlas.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
* export-vm.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "maximal.hpp"
|
||||
|
||||
template class Machine<AtlasShare<gfp_<0, 2>>>;
|
||||
8
Machines/export-cowgear.cpp
Normal file
8
Machines/export-cowgear.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
* export-vm.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "maximal.hpp"
|
||||
|
||||
template class Machine<CowGearShare<gfp_<0, 2>>>;
|
||||
8
Machines/export-dealer.cpp
Normal file
8
Machines/export-dealer.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
* export-vm.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "maximal.hpp"
|
||||
|
||||
template class Machine<DealerShare<SignedZ2<64>>>;
|
||||
8
Machines/export-hemi.cpp
Normal file
8
Machines/export-hemi.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
* export-vm.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "maximal.hpp"
|
||||
|
||||
template class Machine<HemiShare<gfp_<0, 2>>>;
|
||||
8
Machines/export-rep4-ring.cpp
Normal file
8
Machines/export-rep4-ring.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
* export-vm.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "maximal.hpp"
|
||||
|
||||
template class Machine<Rep4Share2<64>>;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user