Merge branch 'refs/heads/master' into parallel_permutations

# Conflicts:
#	Processor/Processor.hpp
#	Protocols/FakeProtocol.h
#	Protocols/Rep3Shuffler.h
#	Protocols/Rep3Shuffler.hpp
#	Protocols/SecureShuffle.h
#	Protocols/SecureShuffle.hpp
This commit is contained in:
Vincent Ehrmanntraut
2024-12-12 09:25:02 +01:00
393 changed files with 8477 additions and 2586 deletions

1
.gitignore vendored
View File

@@ -49,6 +49,7 @@ callgrind.out.*
Programs/Bytecode/*
Programs/Schedules/*
Programs/Public-Input/*
Programs/Functions
*.com
*.class
*.dll

3
.gitmodules vendored
View File

@@ -13,3 +13,6 @@
[submodule "deps/SimplestOT_C"]
path = deps/SimplestOT_C
url = https://github.com/mkskeller/SimplestOT_C
[submodule "deps/sse2neon"]
path = deps/sse2neon
url = https://github.com/DLTcollab/sse2neon

View File

@@ -15,7 +15,7 @@ int AndJob::run()
#endif
__m128i* prf_output = new __m128i[PAD_TO_8(ProgramParty::s().get_n_parties())];
auto gate = gates.begin();
vector< GC::Secret<EvalRegister> >& S = *this->S;
auto& S = *this->S;
const vector<int>& args = *this->args;
int i_gate = 0;
for (size_t i = start; i < end; i += 4)

View File

@@ -15,7 +15,7 @@ using namespace std;
class AndJob
{
vector< GC::Secret<EvalRegister> >* S;
StackedVector< GC::Secret<EvalRegister> >* S;
const vector<int>* args;
public:
@@ -25,7 +25,7 @@ public:
AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {}
void reset(vector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
void reset(StackedVector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
size_t start, gate_id_t gate_id, size_t n_gates, int n_parties)
{
this->S = &S;

View File

@@ -1,5 +1,31 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.4.0 (November 21, 2024)
- Functionality to call high-level code from C++
- Matrix triples from file for all appropriate protocols
- Exit with message on errors instead of uncaught exceptions
- Reduce memory usage for binary memory
- Optimized cint-regint conversion in Dealer protocol
- Fixed security bug: missing MAC check in probabilistic truncation
## 0.3.9 (July 9, 2024)
- Inference with non-sequential PyTorch networks
- SHA-3 for any input length (@hiddely)
- Improved client facilities
- Shuffling with malicious security for SPDZ-wise protocols by [Asharov et al.](https://ia.cr/2022/1595)
- More reusable bytecode via in-thread calling facility
- Recursive functions without return values
- Fewer rounds for parallel matrix multiplications (@vincent-ehrmanntraut)
- Optimized usage of SoftSpokenOT in semi-honest protocols
- More integrity checks on storage in MAC-based protocols
- Use C++17
- Use glibc 2.18 for the binaries
- Fixed security bugs: remotely caused buffer overflows (#1382)
- Fixed security bug: Missing randomization before revealing to client
- Fixed security bug: Bias in Rep3 secure shuffling
## 0.3.8 (December 14, 2023)
- Functionality for multiple nodes per party

8
CONFIG
View File

@@ -48,6 +48,8 @@ ARCH =
AVX_OT = 0
endif
AVX_SIMPLEOT := $(AVX_OT)
ifeq ($(OS), Darwin)
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
@@ -69,11 +71,13 @@ CXX = clang++
# use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine
AVX_SIMPLEOT := $(AVX_OT)
ifeq ($(USE_GF2N_LONG),1)
GF2N_LONG = -DUSE_GF2N_LONG
endif
ifeq ($(AVX_OT), 0)
ifeq ($(AVX_SIMPLEOT), 0)
CFLAGS += -DNO_AVX_OT
endif
@@ -106,7 +110,7 @@ else
BOOST = -lboost_thread $(MY_BOOST)
endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++17 -Werror
CFLAGS += $(BREW_CFLAGS)
CPPFLAGS = $(CFLAGS)
LD = $(CXX)

View File

@@ -203,8 +203,10 @@ class andrsvec(base.VarArgsInstruction, base.Mergeable,
def add_usage(self, req_node):
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
req_node.increment(('bit', 'mixed'), size)
n = (n - 3) // 2
req_node.increment(('bit', 'triple'), size * n)
if n > 1:
req_node.increment(('bit', 'mixed'), size * ((n + 63) // 64))
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))
@@ -538,8 +540,8 @@ class split(base.Instruction):
:param: number of arguments to follow (number of bits times number of additive shares plus one)
:param: source (sint)
:param: first share of least significant bit
:param: second share of least significant bit
:param: first share of least significant bit (sbit)
:param: second share of least significant bit (sbit)
:param: (remaining share of least significant bit)...
:param: (repeat from first share for bit one step higher)...
"""

View File

@@ -13,7 +13,7 @@ from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint,
from Compiler.types import vectorized_classmethod
from Compiler.program import Tape, Program
from Compiler.exceptions import *
from Compiler import util, oram, floatingpoint, library
from Compiler import util, oram, floatingpoint, library, comparison
from Compiler import instructions_base
import Compiler.GC.instructions as inst
import operator
@@ -21,6 +21,11 @@ import math
import itertools
from functools import reduce
class _binary:
def reveal_to(self, *args, **kwargs):
raise CompilerError(
'%s does not support revealing to indivual players' % type(self))
class bits(Tape.Register, _structure, _bit):
n = 40
unit = 64
@@ -149,6 +154,12 @@ class bits(Tape.Register, _structure, _bit):
self.n = n
def set_size(self, size):
pass
def load_int(self, value):
n_limbs = math.ceil(self.n / self.unit)
for i in range(n_limbs):
self.conv_regint(min(self.unit, self.n - i * self.unit),
self[i], regint(value % 2 ** self.unit))
value >>= self.unit
def load_other(self, other):
if isinstance(other, cint):
assert(self.n == other.size)
@@ -236,12 +247,14 @@ class bits(Tape.Register, _structure, _bit):
return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
Bit-wise oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
This will output 1 because it selects the two least
significant bits from 5 and the rest of the bits from 2.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
def zero_if_not(self, condition):
@@ -268,6 +281,9 @@ class bits(Tape.Register, _structure, _bit):
self.bit_compose(source.bit_decompose()[base:base + size]))
def vector_size(self):
return self.n
@staticmethod
def size_for_mem():
return 1
class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
@@ -302,13 +318,6 @@ class cbits(bits):
else:
return super(cbits, cls).conv(other)
types = {}
def load_int(self, value):
n_limbs = math.ceil(self.n / self.unit)
tmp = regint(size=n_limbs)
for i in range(n_limbs):
tmp[i].load_int(value % 2 ** self.unit)
value >>= self.unit
self.load_other(tmp)
def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op):
@@ -502,11 +511,7 @@ class sbits(bits):
if self.n <= 32:
inst.ldbits(self, self.n, value)
else:
size = math.ceil(self.n / self.unit)
tmp = regint(size=size)
for i in range(size):
tmp[i].load_int((value >> (i * 64)) % 2**64)
self.load_other(tmp)
bits.load_int(self, value)
def load_other(self, other):
if isinstance(other, cbits) and self.n == other.n:
inst.convcbit2s(self.n, self, other)
@@ -675,7 +680,7 @@ class sbits(bits):
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
class sbitvec(_vec, _bit):
class sbitvec(_vec, _bit, _binary):
""" Vector of registers of secret bits, effectively a matrix of secret bits.
This facilitates parallel arithmetic operations in binary circuits.
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -732,15 +737,20 @@ class sbitvec(_vec, _bit):
:py:obj:`v` and the columns by calling :py:obj:`elements`.
"""
class sbitvecn(cls, _structure):
n_bits = n
@staticmethod
def malloc(size, creator_tape=None):
return sbit.malloc(size * n, creator_tape=creator_tape)
def get_type(n):
return cls.get_type(n)
@classmethod
def malloc(cls, size, creator_tape=None):
return sbit.malloc(
size * cls.mem_size(), creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
@staticmethod
def mem_size():
return n
return sbits.get_type(n).mem_size()
@classmethod
def get_input_from(cls, player, size=1, f=0):
""" Secret input from :py:obj:`player`. The input is decomposed
@@ -748,17 +758,19 @@ class sbitvec(_vec, _bit):
:param: player (int)
"""
v = [0] * n
sbits._check_input_player(player)
instructions_base.check_vector_size(size)
for i in range(size):
vv = [sbit() for i in range(n)]
inst.inputbvec(n + 3, f, player, *vv)
for j in range(n):
tmp = vv[j] << i
v[j] = tmp ^ v[j]
sbits._check_input_player(player)
return cls.from_vec(v)
if size == 1:
res = cls.from_vec(sbit() for i in range(n))
inst.inputbvec(n + 3, f, player, *res.v)
return res
else:
elements = []
for i in range(size):
v = sbits.get_type(n)()
inst.inputb(player, n, f, v)
elements.append(v)
return cls(elements)
get_raw_input_from = get_input_from
@classmethod
def from_vec(cls, vector):
@@ -780,38 +792,27 @@ class sbitvec(_vec, _bit):
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
assert size is None or size == self.v[0].n
@vectorized_classmethod
def load_mem(cls, address):
size = instructions_base.get_global_vector_size()
if size not in (None, 1):
assert isinstance(address, int) or len(address) == 1
sb = sbits.get_type(size)
return cls.from_vec(sb.bit_compose(
sbit.load_mem(address + i + j * n) for j in range(size))
for i in range(n))
if not isinstance(address, int):
v = [sbit.load_mem(x, size=n).v[0] for x in address]
return cls(v)
@classmethod
def load_mem(cls, address, size=None):
if isinstance(address, int) or len(address) == 1:
address = [address + i * cls.mem_size()
for i in range(size or 1)]
else:
return cls.from_vec(sbit.load_mem(address + i)
for i in range(n))
assert size == None
return cls(
[sbits.get_type(n).load_mem(x) for x in address])
def store_in_mem(self, address):
size = 1
for x in self.v:
if not util.is_constant(x):
size = max(size, x.n)
v = [sbits.get_type(size).conv(x) for x in self.v]
if not isinstance(address, int) and len(address) != 1:
v = self.elements()
assert len(v) == len(address)
for x, y in zip(v, address):
for i, xx in enumerate(x.bit_decompose(n)):
xx.store_in_mem(y + i)
if isinstance(address, int) or len(address) == 1:
address = [address + i * self.mem_size()
for i in range(size)]
else:
assert isinstance(address, int) or len(address) == 1
for i in range(n):
for j, x in enumerate(v[i].bit_decompose()):
x.store_in_mem(address + i + j * n)
assert size == len(address)
for x, dest in zip(self.elements(), address):
x.store_in_mem(dest)
@classmethod
def two_power(cls, nn, size=1):
return cls.from_vec(
@@ -864,7 +865,7 @@ class sbitvec(_vec, _bit):
assert isinstance(elements, sint)
if Program.prog.use_split():
x = elements.split_to_two_summands(length)
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
v = sbitint.bit_adder(x[0], x[1])
else:
prog = Program.prog
if not prog.options.ring:
@@ -877,6 +878,7 @@ class sbitvec(_vec, _bit):
length, prog.security)
prog.use_edabit(backup)
return
comparison.require_ring_size(length, 'A2B conversion')
l = int(Program.prog.options.ring)
r, r_bits = sint.get_edabit(length, size=elements.size)
c = ((elements - r) << (l - length)).reveal()
@@ -885,6 +887,8 @@ class sbitvec(_vec, _bit):
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
v = x.v
self.v = v[:length]
elif isinstance(elements, sbitvec):
self.v = elements.v
elif elements is not None and not (util.is_constant(elements) and \
elements == 0):
self.v = sbits.trans(elements)
@@ -1347,13 +1351,19 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
def __add__(self, other):
if util.is_zero(other):
return self
a, b = self.expand(other)
try:
a, b = self.expand(other)
except:
return NotImplemented
v = sbitint.bit_adder(a, b)
return self.get_type(len(v)).from_vec(v)
__radd__ = __add__
__sub__ = _bitint.__sub__
def __rsub__(self, other):
a, b = self.expand(other)
try:
a, b = self.expand(other)
except:
return NotImplemented
return self.from_vec(b) - self.from_vec(a)
def __mul__(self, other):
if isinstance(other, sbits):
@@ -1447,7 +1457,7 @@ class cbitfix(object):
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))
class sbitfix(_fix):
class sbitfix(_fix, _binary):
""" Secret signed fixed-point number in one binary register.
Use :py:obj:`set_precision()` to change the precision.
@@ -1515,7 +1525,7 @@ class sbitfix(_fix):
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)
class sbitfixvec(_fix, _vec):
class sbitfixvec(_fix, _vec, _binary):
""" Vector of fixed-point numbers for parallel binary computation.
Use :py:obj:`set_precision()` to change the precision.

View File

@@ -76,7 +76,7 @@ class AllocRange:
self.top += size
self.limit = max(self.limit, self.top)
if res >= REG_MAX:
raise RegisterOverflowError()
raise RegisterOverflowError(size)
return res
def free(self, base, size):
@@ -178,6 +178,8 @@ class StraightlineAllocator:
dup = dup.vectorbase
self.alloc[dup] = self.alloc[base]
dup.i = self.alloc[base]
if not dup.dup_count:
dup.dup_count = len(base.duplicates)
def dealloc_reg(self, reg, inst, free):
if reg.vector:
@@ -209,7 +211,8 @@ class StraightlineAllocator:
for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x)
if reg not in self.program.base_addresses:
if reg not in self.program.base_addresses \
and not isinstance(inst, call_arg):
free.free(base)
if inst.is_vec() and base.vector:
self.defined[base] = inst
@@ -274,8 +277,9 @@ class StraightlineAllocator:
for reg in self.alloc:
for x in reg.get_all():
if x not in self.dealloc and reg not in self.dealloc \
and len(x.duplicates) == 0:
print('Warning: read before write at register', x)
and len(x.duplicates) == x.dup_count:
print('Warning: read before write at register %s/%x' %
(x, id(x)))
print('\tregister trace: %s' % format_trace(x.caller,
'\t\t'))
if options.stop:
@@ -485,7 +489,11 @@ class Merger:
if this[-1] < other[0]:
del this[:]
this.append(n)
for inst in other:
if last_access_this_kind == last_mem_write_of:
insts = itertools.chain(other, this)
else:
insts = other
for inst in insts:
add_edge(inst, n)
def mem_access(n, instr, last_access_this_kind, last_access_other_kind):
@@ -518,7 +526,11 @@ class Merger:
last_other_kind[-1] > last_this_kind[-1]:
last_this_kind[:] = []
last_this_kind.append(n)
for i in last_other_kind:
if last_this_kind == last_mem_write:
insts = itertools.chain(last_other_kind, last_this_kind)
else:
insts = last_other_kind
for i in insts:
add_edge(i, n)
def keep_order(instr, n, t, arg_index=None):
@@ -608,7 +620,8 @@ class Merger:
# so this threshold should lead to acceptable compile times even on slower processors.
first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4]
second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5]
max_dependencies_per_matrix = 1500**2
max_dependencies_per_matrix = \
self.block.parent.program.budget
if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix:
if block.warn_about_mem and not block.parent.warned_about_mem:
print('WARNING: Order of memory instructions not preserved due to long vector, errors possible')
@@ -748,6 +761,8 @@ class Merger:
G.remove_node(i)
merge_nodes.discard(i)
stats[type(instructions[i]).__name__] += 1
for reg in instructions[i].get_def():
self.block.parent.program.base_addresses.pop(reg)
instructions[i] = None
if unused_result:
eliminate(i)
@@ -779,10 +794,10 @@ class RegintOptimizer:
self.rev_offset_cache = {}
self.range_cache = util.dict_by_id()
def add_offset(self, res, new_base, new_offset):
self.offset_cache[res] = new_base, new_offset
if (new_base.i, new_offset) not in self.rev_offset_cache:
self.rev_offset_cache[new_base.i, new_offset] = res
def add_offset(self, res, new_base, new_offset, multiplier):
self.offset_cache[res] = new_base, new_offset, multiplier
if (new_base.i, new_offset, multiplier) not in self.rev_offset_cache:
self.rev_offset_cache[new_base.i, new_offset, multiplier] = res
def run(self, instructions, program):
for i, inst in enumerate(instructions):
@@ -806,31 +821,40 @@ class RegintOptimizer:
def f(base, delta_reg):
delta = self.cache[delta_reg]
if base in self.offset_cache:
reg, offset = self.offset_cache[base]
reg, offset, mult = self.offset_cache[base]
new_base, new_offset = reg, offset + delta
else:
new_base, new_offset = base, delta
self.add_offset(inst.args[0], new_base, new_offset)
mult = 1
self.add_offset(inst.args[0], new_base, new_offset,
mult)
if inst.args[1] in self.cache:
f(inst.args[2], inst.args[1])
elif inst.args[2] in self.cache:
f(inst.args[1], inst.args[2])
elif isinstance(inst, subint_class) and \
inst.args[2] in self.cache:
delta = self.cache[inst.args[2]]
if inst.args[1] in self.offset_cache:
reg, offset = self.offset_cache[inst.args[1]]
new_base, new_offset = reg, offset - delta
else:
new_base, new_offset = inst.args[1], -delta
self.add_offset(inst.args[0], new_base, new_offset)
elif isinstance(inst, subint_class):
def f(reg, cached, reverse):
delta = self.cache[cached]
if reg in self.offset_cache:
reg, offset, mult = self.offset_cache[reg]
new_base, new_offset = reg, offset - delta
else:
new_base = reg
new_offset = -delta if reverse else delta
mult = 1
self.add_offset(inst.args[0], new_base, new_offset,
mult if reverse else -mult)
if inst.args[1] in self.cache:
f(inst.args[2], inst.args[1], False)
elif inst.args[2] in self.cache:
f(inst.args[1], inst.args[2], True)
elif isinstance(inst, IndirectMemoryInstruction):
if inst.args[1] in self.cache:
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
instructions[i]._protect = inst._protect
elif inst.args[1] in self.offset_cache:
base, offset = self.offset_cache[inst.args[1]]
addr = self.rev_offset_cache[base.i, offset]
base, offset, mult = self.offset_cache[inst.args[1]]
addr = self.rev_offset_cache[base.i, offset, mult]
inst.args[1] = addr
elif inst.args[1] in self.range_cache:
size, base = self.range_cache[inst.args[1]]

View File

@@ -5,7 +5,7 @@ the ones used below into ``Programs/Circuits`` as follows::
make Programs/Circuits
.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC
.. _`Bristol Fashion`: https://nigelsmart.github.io/MPC-Circuits
"""
import math
@@ -15,6 +15,7 @@ from Compiler.library import function_block, get_tape
from Compiler import util
import itertools
import struct
import os
class Circuit:
"""
@@ -47,7 +48,12 @@ class Circuit:
"""
def __init__(self, name):
self.name = name
self.filename = 'Programs/Circuits/%s.txt' % name
if not os.path.exists(self.filename):
if os.system('make Programs/Circuits'):
raise CompilerError('Cannot download circuit descriptions. '
'Make sure make and git are installed.')
f = open(self.filename)
self.functions = {}
@@ -57,8 +63,9 @@ class Circuit:
def run(self, *inputs):
n = inputs[0][0].n, get_tape()
if n not in self.functions:
self.functions[n] = function_block(lambda *args:
self.compile(*args))
self.functions[n] = function_block(
lambda *args: self.compile(*args))
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
flat_res = self.functions[n](*itertools.chain(*inputs))
res = []
i = 0
@@ -124,7 +131,7 @@ Keccak_f = None
def sha3_256(x):
"""
This function implements SHA3-256 for inputs of up to 1080 bits::
This function implements SHA3-256 for inputs of any length::
from circuit import sha3_256
a = sbitvec.from_vec([])
@@ -138,7 +145,8 @@ def sha3_256(x):
for x in a, b, c, d, e, f, g, h:
sha3_256(x).reveal_print_hex()
This should output the `test vectors
This should output the hashes of the above inputs, beginning with
the `test vectors
<https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the
0 byte::

View File

@@ -76,13 +76,13 @@ def require_ring_size(k, op):
program.curr_tape.require_bit_length(k)
@instructions_base.cisc
def LTZ(s, a, k, kappa):
def LTZ(s, a, k):
"""
s = (a ?< 0)
k: bit length of a
"""
movs(s, program.non_linear.ltz(a, k, kappa))
movs(s, program.non_linear.ltz(a, k))
def LtzRing(a, k):
from .types import sint, _bitint
@@ -105,14 +105,14 @@ def LtzRing(a, k):
u = CarryOutRaw(a[::-1], b[::-1])
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
def LessThanZero(a, k, kappa):
def LessThanZero(a, k):
from . import types
res = types.sint()
LTZ(res, a, k, kappa)
LTZ(res, a, k)
return res
@instructions_base.cisc
def Trunc(d, a, k, m, kappa, signed):
def Trunc(d, a, k, m, signed):
"""
d = a >> m
@@ -124,7 +124,7 @@ def Trunc(d, a, k, m, kappa, signed):
movs(d, a)
return
else:
movs(d, program.non_linear.trunc(a, k, m, kappa, signed))
movs(d, program.non_linear.trunc(a, k, m, signed=signed))
def TruncRing(d, a, k, m, signed):
program.curr_tape.require_bit_length(1)
@@ -197,13 +197,13 @@ def TruncLeakyInRing(a, k, m, signed):
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
masked = shifted >> n_shift
u = sint()
BitLTL(u, masked, r_bits[:n_bits], 0)
BitLTL(u, masked, r_bits[:n_bits])
res = (u << n_bits) + masked - r
if signed:
res -= (1 << (n_bits - 1))
return res
def TruncRoundNearest(a, k, m, kappa, signed=False):
def TruncRoundNearest(a, k, m, signed=False):
"""
Returns a / 2^m, rounded to the nearest integer.
@@ -212,12 +212,10 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
"""
if m == 0:
return a
nl = program.non_linear
nl.check_security(kappa)
return program.non_linear.trunc_round_nearest(a, k, m, signed)
@instructions_base.cisc
def Mod2m(a_prime, a, k, m, kappa, signed):
def Mod2m(a_prime, a, k, m, signed):
"""
a_prime = a % 2^m
@@ -225,8 +223,6 @@ def Mod2m(a_prime, a, k, m, kappa, signed):
m: compile-time integer
signed: True/False, describes a
"""
nl = program.non_linear
nl.check_security(kappa)
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
def Mod2mRing(a_prime, a, k, m, signed):
@@ -237,13 +233,13 @@ def Mod2mRing(a_prime, a, k, m, signed):
tmp = a + r_prime
c_prime = (tmp << shift).reveal(False) >> shift
u = sint()
BitLTL(u, c_prime, r_bin[:m], 0)
BitLTL(u, c_prime, r_bin[:m])
res = (u << m) + c_prime - r_prime
if a_prime is not None:
movs(a_prime, res)
return res
def Mod2mField(a_prime, a, k, m, kappa, signed):
def Mod2mField(a_prime, a, k, m, signed):
from .types import sint
r_dprime = program.curr_block.new_reg('s')
r_prime = program.curr_block.new_reg('s')
@@ -255,7 +251,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
t = [program.curr_block.new_reg('s') for i in range(6)]
c2m = program.curr_block.new_reg('c')
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, r, k, m, kappa)
PRandM(r_dprime, r_prime, r, k, m)
ld2i(c2m, m)
mulm(t[0], r_dprime, c2m)
if signed:
@@ -268,9 +264,9 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
asm_open(True, c, t[3])
modc(c_prime, c, c2m)
if const_rounds:
BitLTC1(u, c_prime, r, kappa)
BitLTC1(u, c_prime, r)
else:
BitLTL(u, c_prime, r, kappa)
BitLTL(u, c_prime, r)
mulm(t[4], u, c2m)
submr(t[5], c_prime, r_prime)
adds(a_prime, t[5], t[4])
@@ -288,13 +284,15 @@ def MaskingBitsInRing(m, strict=False):
r_bin = r
return sint.bit_compose(r), r_bin
def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
def PRandM(r_dprime, r_prime, b, k, m, use_dabit=True):
"""
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
r_prime = random secret integer in range [0, 2^m - 1]
b = array containing bits of r_prime
"""
program.curr_tape.require_bit_length(k + kappa)
assert k >= m
kappa = program.security
program.curr_tape.require_bit_length(k + kappa, reason='statistical masking as in https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation')
from .types import sint
if program.use_edabit() and not const_rounds:
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
@@ -329,7 +327,7 @@ def PRandInt(r, k):
bit(t[1][i])
adds(t[2][i], t[0][i], t[1][i])
def BitLTC1(u, a, b, kappa):
def BitLTC1(u, a, b):
"""
u = a <? b
@@ -395,7 +393,7 @@ def BitLTC1(u, a, b, kappa):
subcfi(c[3][i], a_bits[i], 1)
mulm(t[3][i], s[i], c[3][i])
adds(t[4][i], t[4][i-1], t[3][i])
Mod2(u, t[4][k-1], k, kappa, False)
Mod2(u, t[4][k-1], k, False)
return p, a_bits, d, s, t, c, b, pre_input
def carry(b, a, compute_p=True):
@@ -414,7 +412,7 @@ def carry(b, a, compute_p=True):
# from WP9 report
# length of a is even
def CarryOutAux(a, kappa):
def CarryOutAux(a):
k = len(a)
if k > 1 and k % 2 == 1:
a.append(None)
@@ -424,12 +422,12 @@ def CarryOutAux(a, kappa):
if k > 1:
for i in range(k//2):
u[i] = carry(a[2*i+1], a[2*i], i != k//2-1)
return CarryOutAux(u[:k//2][::-1], kappa)
return CarryOutAux(u[:k//2][::-1])
else:
return a[0][1]
# carry out with carry-in bit c
def CarryOut(res, a, b, c=0, kappa=None):
def CarryOut(res, a, b, c=0):
"""
res = last carry bit in addition of a and b
@@ -456,7 +454,7 @@ def CarryOutRaw(a, b, c=0):
s[0] = d[-1][0].bit_and(c)
s[1] = d[-1][1] + s[0]
d[-1][1] = s[1]
return CarryOutAux(d[::-1], None)
return CarryOutAux(d[::-1])
def CarryOutRawLE(a, b, c=0):
""" Little-endian version """
@@ -469,7 +467,7 @@ def CarryOutLE(a, b, c=0):
CarryOut(res, a[::-1], b[::-1], c)
return res
def BitLTL(res, a, b, kappa):
def BitLTL(res, a, b):
"""
res = a <? b (logarithmic rounds version)
@@ -624,7 +622,7 @@ def KMulC(a):
PreMulC_without_inverses(p, a)
return p
def Mod2(a_0, a, k, kappa, signed):
def Mod2(a_0, a, k, signed):
"""
a_0 = a % 2
@@ -641,7 +639,7 @@ def Mod2(a_0, a, k, kappa, signed):
tc = program.curr_block.new_reg('c')
t = [program.curr_block.new_reg('s') for i in range(6)]
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
PRandM(r_dprime, r_prime, [r_0], k, 1)
r_0 = r_prime
mulsi(t[0], r_dprime, 2)
if signed:

View File

@@ -13,8 +13,18 @@ from .program import Program, defaults
class Compiler:
singleton = None
def __init__(self, custom_args=None, usage=None, execute=False,
split_args=False):
if Compiler.singleton:
raise CompilerError(
"Cannot have more than one compiler instance. "
"It's not possible to run direct compilation programs with "
"compile.py or compile-run.py.")
else:
Compiler.singleton = self
if usage:
self.usage = usage
else:
@@ -165,7 +175,8 @@ class Compiler:
dest="prime",
default=defaults.prime,
help="use bit decomposition with a specifed prime modulus "
"for non-linear computation (default: use the masking approach)",
"for non-linear computation (default: use the masking approach). "
"Don't use this unless you're certain that you need it.",
)
parser.add_option(
"-I",
@@ -252,6 +263,13 @@ class Compiler:
dest="hostfile",
help="hosts to execute with",
)
parser.add_option(
"-t",
"--tidy_output",
action="store_true",
dest="tidy_output",
help="make output prints tidy and grouped by party (note: delays the prints)",
)
else:
parser.add_option(
"-E",
@@ -263,6 +281,14 @@ class Compiler:
def parse_args(self):
self.options, self.args = self.parser.parse_args(self.custom_args)
if self.options.verbose:
self.runtime_args += ["--verbose"]
if self.options.execute:
self.options.execute = re.sub("-party.x$", "", self.options.execute)
for s, l in self.match.items():
if self.options.execute == l:
self.options.execute = s
break
if self.execute:
if not self.options.execute:
if len(self.args) > 1:
@@ -313,6 +339,8 @@ class Compiler:
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
if self.options.execute in ("rep4-ring",):
self.prog.use_split(4)
if self.options.execute.find("dealer") >= 0:
self.prog.use_edabit(True)
def build_vars(self):
from . import comparison, floatingpoint, instructions, library, types
@@ -368,7 +396,14 @@ class Compiler:
"cfloat",
"squant",
]:
del self.VARS[i]
class dummy:
def __init__(self, *args):
raise CompilerError(self.error)
dummy.error = i + " not availabe with binary circuits"
if i in ("cint", "cfix"):
dummy.error += ". See https://mp-spdz.readthedocs.io/en/" \
"latest/Compiler.html#Compiler.types." + i
self.VARS[i] = dummy
else:
self.sint = types.sint
self.sfix = types.sfix
@@ -418,6 +453,8 @@ class Compiler:
continue
m = re.match(r"(\s*)if(\W.*):", line)
if m:
while if_stack and if_stack[-1][0] == m.group(1):
if_stack.pop()
if_stack.append((m.group(1), len(output)))
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
output.append("%sdef _():\n" % (m.group(1)))
@@ -503,13 +540,15 @@ class Compiler:
return self.prog
@staticmethod
def executable_from_protocol(protocol):
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
@classmethod
def executable_from_protocol(cls, protocol):
match = cls.match
if protocol in match:
protocol = match[protocol]
if protocol.find("bmr") == -1:
@@ -588,11 +627,25 @@ class Compiler:
for filename in glob.glob("Player-Data/*.0"):
connection.put(filename, dest + "Player-Data")
def run_with_error(i):
try:
run(i)
except IOError:
print('IO error when copying files, does %s have enough space?' %
hostnames[i])
raise
import threading
import random
import io
def run_and_capture_outputs(outputs, fn, i):
out = fn(i)
outputs[i] = out
threads = []
for i in range(len(hosts)):
threads.append(threading.Thread(target=run, args=(i,)))
threads.append(threading.Thread(target=run_with_error, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
@@ -600,6 +653,14 @@ class Compiler:
# execution
threads = []
# tidy up output prints
hide_option = False
if self.options.tidy_output:
outputs = []
for i in range(len(connections)):
outputs += [""]
hide_option = True
# random port numbers to avoid conflict
port = 10000 + random.randrange(40000)
if '@' in hostnames[0]:
@@ -614,9 +675,15 @@ class Compiler:
run = lambda i: connections[i].run(
"cd %s; ./%s -p %d %s -h %s -pn %d %s" % \
(destinations[i], vm, i, self.prog.name, party0, port,
' '.join(args + N)))
threads.append(threading.Thread(target=run, args=(i,)))
' '.join(args + N)), hide=hide_option)
if self.options.tidy_output:
threads.append(threading.Thread(target=run_and_capture_outputs, args=(outputs, run, i,)))
else:
threads.append(threading.Thread(target=run, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
if self.options.tidy_output:
for out in outputs:
print(out)

View File

@@ -2,6 +2,7 @@ from collections import defaultdict
REG_MAX = 2 ** 32
USER_MEM = 8192
MEM_MAX = 2 ** 64
P_VALUES = { 32: 2147565569, \
64: 9223372036855103489, \

View File

@@ -1,3 +1,8 @@
""" This module implements `Dijkstra's algorithm
<https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm>`_ based on
oblivious RAM. """
from Compiler.oram import *
from Compiler.program import Program
@@ -222,7 +227,21 @@ class HeapQ(object):
print_ln()
print_ln()
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None,
debug=False):
""" Securely compute Dijstra's algorithm on a secret graph. See
:download:`../Programs/Source/dijkstra_example.mpc` for an
explanation of the required inputs.
:param source: source node (secret or clear-text integer)
:param edges: ORAM representation of edges
:param e_index: ORAM representation of vertices
:param oram_type: ORAM type to use internally (default:
:py:func:`~Compiler.oram.OptimalORAM`)
:param n_loops: when to stop (default: number of edges)
:param int_type: secret integer type (default: sint)
"""
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else -1
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
@@ -246,10 +265,12 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
last_edge = MemValue(basic_type(1))
i_edge = MemValue(int_type(0))
u = MemValue(basic_type(0))
running = MemValue(basic_type(1))
@for_range(n_loops or edges.size)
def f(i):
print_ln('loop %s', i)
time()
running.write(last_edge.bit_not().bit_or(Q.size > 0).bit_and(running))
u.write(if_else(last_edge, Q.pop(last_edge), u))
#visited.access(u, True, last_edge)
i_edge.write(int_type(if_else(last_edge, e_index[u], i_edge)))
@@ -261,30 +282,50 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
dv, not_visited = dist.read(v)
# relying on default dv negative here
is_shorter = (alt < int_type(dv[0])) + not_visited
is_shorter *= running
dist.access(v, (basic_type(alt), u), is_shorter)
#previous.access(v, u, is_shorter)
Q.update(v, basic_type(alt), is_shorter)
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s', \
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), \
not_visited.reveal())
if debug:
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s, '
'shorter: %s, running: %s, queue size: %s, last edge: %s',
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(),
not_visited.reveal(), is_shorter.reveal(),
running.reveal(), Q.size.reveal(), last_edge.reveal())
return dist
def convert_graph(G):
edges = [None] * (2 * G.size())
e_index = [None] * (len(G))
i = 0
for v in G:
e_index[v] = i
for u in G[v]:
edges[i] = [u, G[v][u]['weight'], 0]
i += 1
edges[i-1][-1] = 1
return edges, e_index
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
""" Convert a `NetworkX directed graph
<https://networkx.org/documentation/stable/reference/classes/digraph.html>`_
to the cleartext representation of what :py:func:`dijkstra` expects. """
G = G.copy()
for u in G:
for v in G[u]:
G[u][v].setdefault('weight', 1)
edges = [None] * (2 * G.size())
e_index = [None] * (len(G))
i = 0
for v in sorted(G):
e_index[v] = i
for u in sorted(G[v]):
edges[i] = [u, G[v][u]['weight'], 0]
i += 1
if not G[v]:
edges[i] = [v, 0, 0]
i += 1
edges[i-1][-1] = 1
return list(filter(lambda x: x, edges)), e_index
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None,
int_type=sint):
""" Securely compute Dijstra's algorithm on a cleartext graph.
:param G: directed graph with NetworkX interface
:param source: source node (secret or clear-text integer)
:param n_loops: when to stop (default: number of edges)
:param int_type: secret integer type (default: sint)
"""
edges_list, e_index_list = convert_graph(G)
edges = oram_type(len(edges_list), \
entry_size=(log2(len(G)), log2(len(G)), 1), \
@@ -558,7 +599,7 @@ def test_stupid_dijkstra_on_cycle(n, n_loops=None):
@for_range(n)
def f(i):
M[i][(i+1)%n] = ExtInt(1)
M[i][(i-1)%n] = ExtInt(1)
M[i][(i-1+n)%n] = ExtInt(1)
if n_loops is not None:
stop_timer(1)
start_timer()

View File

@@ -39,26 +39,25 @@ def maskRing(a, k):
c = ((a + r_prime) << shift).reveal(False) >> shift
return c, r
def maskField(a, k, kappa):
def maskField(a, k):
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
r = [types.sint() for i in range(k)]
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
comparison.PRandM(r_dprime, r_prime, r, k, k)
# always signed due to usage in equality testing
a += two_power(k)
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
return c, r
@instructions_base.ret_cisc
def EQZ(a, k, kappa):
def EQZ(a, k):
prog = program.Program.prog
if prog.use_split():
from GC.types import sbitvec
from Compiler.GC.types import sbitvec
v = sbitvec(a, k).v
bit = util.tree_reduce(operator.and_, (~b for b in v))
return types.sintbit.conv(bit)
prog.non_linear.check_security(kappa)
return prog.non_linear.eqz(a, k)
def bits(a,m):
@@ -99,12 +98,12 @@ def or_op(a, b, void=None):
def mul_op(a, b, void=None):
return a * b
def PreORC(a, kappa=None, m=None, raw=False):
def PreORC(a, m=None, raw=False):
k = len(a)
if k == 1:
return [a[0]]
prog = program.Program.prog
kappa = kappa or prog.security
kappa = prog.security
m = m or k
if isinstance(a[0], types.sgf2n):
max_k = program.Program.prog.galois_length - 1
@@ -128,13 +127,13 @@ def PreORC(a, kappa=None, m=None, raw=False):
t = [types.sint() for i in range(m)]
b = comparison.PreMulC([a[i] + 1 for i in range(k)])
for i in range(m):
comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
comparison.Mod2(t[i], b[k-1-i], k, False)
p[m-1-i] = 1 - t[i]
return p
else:
# not constant-round anymore
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
s = [PreORC(a[i:i+max_k], raw=raw) for i in range(0,k,max_k)]
t = PreORC([si[-1] for si in s[:-1]], raw=raw)
return sum(([or_op(x, y) for x in si]
for si,y in zip(s[1:],t)), s[0])[-m:]
@@ -175,6 +174,41 @@ def PreOpL2(op, items):
output[2 * i] = op(v[i - 1], items[2 * i])
return output
def PreOpL2_vec(op, *items):
""" Vectorized version of :py:func:`PreOpL2` """
k = len(items[0])
for x in items:
assert len(x) == k
if k == 1:
return items
half = k // 2
other_half = (k + 1) // 2 - 1
u = op([x.get_vector(base=0, size=half, skip=2) for x in items],
[x.get_vector(base=1, size=half, skip=2) for x in items])
assert len(u) == len(items)
assert len(u[0]) == half
v = PreOpL2_vec(op, *u)
if other_half:
w = op([x.get_vector(base=0, size=other_half) for x in v],
[x.get_vector(base=2, size=other_half, skip=2) for x in items])
if half == other_half:
res = [type(x).zip(x, y) for x, y in zip(v, w)]
for i in range(len(res)):
res[i] = type(res[i]).concat((items[i].get_vector(base=0, size=1),
res[i]))
else:
if other_half:
for i in range(len(w)):
w[i] = type(w[i]).concat((items[i].get_vector(base=0, size=1),
w[i]))
else:
w = [x.get_vector(base=0, size=1) for x in items]
res = [type(x).zip(x, y) for x, y in zip(w, v)]
assert len(res) == len(items)
for x in res:
assert len(x) == k
return res
def PreOpN(op, items):
""" Naive PreOp algorithm """
k = len(items)
@@ -184,9 +218,9 @@ def PreOpN(op, items):
output[i] = op(output[i-1], items[i])
return output
def PreOR(a, kappa=None, raw=False):
def PreOR(a=None, raw=False):
if comparison.const_rounds:
return PreORC(a, kappa, raw=raw)
return PreORC(a, raw=raw)
else:
return PreOpL(or_op, a)
@@ -199,24 +233,24 @@ def KOpL(op, a):
t2 = KOpL(op, a[k//2:])
return op(t1, t2)
def KORL(a, kappa=None):
def KORL(a):
""" log rounds k-ary OR """
k = len(a)
if k == 1:
return a[0]
else:
t1 = KORL(a[:k//2], kappa)
t2 = KORL(a[k//2:], kappa)
t1 = KORL(a[:k//2])
t2 = KORL(a[k//2:])
return t1 + t2 - t1.bit_and(t2)
def KORC(a, kappa):
return PreORC(a, kappa, 1)[0]
def KORC(a):
return PreORC(a, 1)[0]
def KOR(a, kappa):
def KOR(a):
if comparison.const_rounds:
return KORC(a, kappa)
return KORC(a)
else:
return KORL(a, None)
return KORL(a)
def KMul(a):
if comparison.const_rounds:
@@ -262,7 +296,7 @@ def BitAdd(a, b, bits_to_compute=None):
s[k] = c[k-1]
return s
def BitDec(a, k, m, kappa, bits_to_compute=None):
def BitDec(a, k, m, bits_to_compute=None):
return program.Program.prog.non_linear.bit_dec(a, k, m)
def BitDecRingRaw(a, k, m):
@@ -270,7 +304,7 @@ def BitDecRingRaw(a, k, m):
n_shift = int(program.Program.prog.options.ring) - m
if program.Program.prog.use_split():
x = a.split_to_two_summands(m)
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
bits = types._bitint.bit_adder(x[0], x[1])
return bits[:m]
else:
if program.Program.prog.use_edabit():
@@ -292,13 +326,14 @@ def BitDecRing(a, k, m):
# reversing to reduce number of rounds
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
def BitDecFieldRaw(a, k, m, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size)
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
r = [types.sint() for i in range(m)]
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
comparison.PRandM(r_dprime, r_prime, r, k, m)
kappa = program.Program.prog.security
pow2 = two_power(k + kappa)
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
@@ -306,16 +341,16 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
return res
@instructions_base.bit_cisc
def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
def BitDecField(a, k, m, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, bits_to_compute)
return [types.sintbit.conv(bit) for bit in res]
@instructions_base.ret_cisc
def Pow2(a, l, kappa):
def Pow2(a, l):
comparison.program.curr_tape.require_bit_length(l - 1)
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
t = BitDec(a, m, m)
return Pow2_from_bits(t)
def Pow2_from_bits(bits):
@@ -327,11 +362,12 @@ def Pow2_from_bits(bits):
t[i] = t[i]*pow2k[i] + 1 - t[i]
return KMul(t)
def B2U(a, l, kappa):
pow2a = Pow2(a, l, kappa)
return B2U_from_Pow2(pow2a, l, kappa), pow2a
def B2U(a, l):
pow2a = Pow2(a, l)
return B2U_from_Pow2(pow2a, l), pow2a
def B2U_from_Pow2(pow2a, l, kappa):
def B2U_from_Pow2(pow2a, l):
kappa = program.Program.prog.security
r = [types.sint() for i in range(l)]
t = types.sint()
c = types.cint()
@@ -353,17 +389,17 @@ def B2U_from_Pow2(pow2a, l, kappa):
c = list(r_bits[0].bit_decompose_clear(c, l))
x = [r_bits[i].bit_xor(c[i]) for i in range(l)]
#print ' '.join(str(b.value) for b in x)
y = PreOR(x, kappa)
y = PreOR(x)
#print ' '.join(str(b.value) for b in y)
return [types.sint.conv(1 - y[i]) for i in range(l)]
def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
def Trunc(a, l, m, compute_modulo=False, signed=False):
""" Oblivious truncation by secret m """
prog = program.Program.prog
if util.is_constant(m) and not compute_modulo:
# cheaper
res = type(a)(size=a.size)
comparison.Trunc(res, a, l, m, kappa, signed=signed)
comparison.Trunc(res, a, l, m, signed=signed)
return res
if l == 1:
if compute_modulo:
@@ -371,9 +407,9 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
else:
return a * (1 - m)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, Pow2(m, l, kappa))
return TruncInRing(a, l, Pow2(m, l))
else:
kappa = kappa or program.Program.prog.security
kappa = program.Program.prog.security
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
@@ -381,7 +417,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
c = types.cint()
ci = [types.cint() for i in range(l)]
d = types.sint()
x, pow2m = B2U(m, l, kappa)
x, pow2m = B2U(m, l)
for i in range(l):
bit(r[i])
t1 = two_power(i) * r[i]
@@ -398,7 +434,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
for i in range(1,l):
ci[i] = c % two_power(i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l)
if compute_modulo:
b = c_dprime - r_prime + pow2m * d
return b, pow2m
@@ -429,33 +465,33 @@ def TruncInRing(to_shift, l, pow2m):
def SplitInRing(a, l, m):
if l == 1:
return m.if_else(a, 0), m.if_else(0, a), 1
pow2m = Pow2(m, l, None)
pow2m = Pow2(m, l)
upper = TruncInRing(a, l, pow2m)
lower = a - upper * pow2m
return lower, upper, pow2m
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
def TruncRoundNearestAdjustOverflow(a, length, target_length):
t = comparison.TruncRoundNearest(a, length, length - target_length)
overflow = t.greater_equal(two_power(target_length), target_length + 1)
s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False)
return s, overflow
def Int2FL(a, gamma, l, kappa=None):
def Int2FL(a, gamma, l):
lam = gamma - 1
s = a.less_than(0, gamma, security=kappa)
z = a.equal(0, gamma, security=kappa)
s = a.less_than(0, gamma)
z = a.equal(0, gamma)
a = s.if_else(-a, a)
a_bits = a.bit_decompose(lam, security=kappa)
a_bits = a.bit_decompose(lam)
a_bits.reverse()
b = PreOR(a_bits, kappa)
b = PreOR(a_bits)
t = a * (1 + a.bit_compose(1 - b_i for b_i in b))
p = a.popcnt_bits(b) - lam
if gamma - 1 > l:
if types.sfloat.round_nearest:
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l)
p = p + overflow
else:
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
v = t.right_shift(gamma - l - 1, gamma - 1, signed=False)
else:
v = 2**(l-gamma+1) * t
p = (p + gamma - 1 - l) * z.bit_not()
@@ -466,37 +502,38 @@ def FLRound(x, mode):
*mode*: 0 -> floor, 1 -> ceil, -1 > trunc """
v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
a = types.sint()
comparison.LTZ(a, p1, k, x.kappa)
b = p1.less_than(-l + 1, k, x.kappa)
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
c = EQZ(v2, l, x.kappa)
comparison.LTZ(a, p1, k)
b = p1.less_than(-l + 1, k)
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, compute_modulo=True)
c = EQZ(v2, l)
if mode == -1:
away_from_zero = 0
mode = x.s
else:
away_from_zero = mode + s1 - 2 * mode * s1
v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
d = v.equal(two_power(l), l + 1, x.kappa)
d = v.equal(two_power(l), l + 1)
v = d * two_power(l-1) + (1 - d) * v
v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
s = (1 - b * mode) * s1
z = or_op(EQZ(v, l, x.kappa), z1)
z = or_op(EQZ(v, l), z1)
v = v * (1 - z)
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
return v, p, z, s
@instructions_base.ret_cisc
def TruncPr(a, k, m, kappa=None, signed=True):
def TruncPr(a, k, m, signed=True):
""" Probabilistic truncation [a/2^m + u]
where Pr[u = 1] = (a % 2^m) / 2^m
"""
nl = program.Program.prog.non_linear
nl.check_security(kappa)
return nl.trunc_pr(a, k, m, signed)
def TruncPrRing(a, k, m, signed=True):
if m == 0:
return a
prog = program.Program.prog
prog.trunc_pr_warning()
n_ring = int(program.Program.prog.options.ring)
comparison.require_ring_size(k, 'truncation')
if k == n_ring:
@@ -517,7 +554,6 @@ def TruncPrRing(a, k, m, signed=True):
trunc_pr(res, a, k, m)
else:
# extra bit to mask overflow
prog = program.Program.prog
prog.curr_tape.require_bit_length(1)
if prog.use_edabit() or prog.use_split() > 2:
lower = sint.get_random_int(m)
@@ -540,68 +576,67 @@ def TruncPrRing(a, k, m, signed=True):
res -= (1 << (k - m - 1))
return res
def TruncPrField(a, k, m, kappa=None):
def TruncPrField(a, k, m):
if m == 0:
return a
if kappa is None:
kappa = 40
program.Program.prog.trunc_pr_warning()
b = two_power(k-1) + a
r_prime, r_dprime = types.sint(), types.sint()
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
k, m, kappa, use_dabit=False)
k, m, use_dabit=False)
two_to_m = two_power(m)
r = two_to_m * r_dprime + r_prime
c = (b + r).reveal(False)
c = (b + r).reveal(True)
c_prime = c % two_to_m
a_prime = c_prime - r_prime
d = (a - a_prime).field_div(two_to_m)
return d
@instructions_base.ret_cisc
def SDiv(a, b, l, kappa, round_nearest=False):
def SDiv(a, b, l, round_nearest=False):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
w = types.cint(int(2.9142 * 2 ** l)) - 2 * b
x = alpha - b * w
y = a * w
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
comparison.Mod2m(x2, x, 2 * l + 1, l, signed=True)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest,
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l,
nearest=round_nearest,
signed=False)
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, nearest=round_nearest,
signed=False)
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
x = x1 * x1 + x.round(2 * l + 1, l - 1, nearest=round_nearest,
signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
comparison.Mod2m(x2, x, 2 * l, l, signed=False)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest, signed=False)
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest,
signed=False)
y = y.round(2 * l + 1, l + 1, nearest=round_nearest)
return y
def SDiv_mono(a, b, l, kappa):
def SDiv_mono(a, b, l):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w
y = a * w
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
y = TruncPr(y, 2 * l + 1, l + 1)
for i in range(theta-1):
y = y * (alpha + x)
# keep y with l bits
y = TruncPr(y, 3 * l, 2 * l, kappa)
y = TruncPr(y, 3 * l, 2 * l)
x = x**2
# keep x with 2l bits
x = TruncPr(x, 4 * l, 2 * l, kappa)
x = TruncPr(x, 4 * l, 2 * l)
y = y * (alpha + x)
y = TruncPr(y, 3 * l, 2 * l, kappa)
y = TruncPr(y, 3 * l, 2 * l)
return y
# LT bit comparison on shared bit values

View File

@@ -399,7 +399,7 @@ class stop(base.Instruction):
arg_format = ['i']
class use(base.Instruction):
""" Offline data usage. Necessary to avoid reusage while using
r""" Offline data usage. Necessary to avoid reusage while using
preprocessing from files. Also used to multithreading for expensive
preprocessing.
@@ -419,7 +419,7 @@ class use(base.Instruction):
args[2].i}
class use_inp(base.Instruction):
""" Input usage. Necessary to avoid reusage while using
r""" Input usage. Necessary to avoid reusage while using
preprocessing from files.
:param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
@@ -488,6 +488,70 @@ class join_tape(base.Instruction):
code = base.opcodes['JOIN_TAPE']
arg_format = ['int']
class call_tape(base.DoNotEliminateInstruction):
""" Start tape/bytecode file in same thread. Arguments/return values
starting from :py:obj:`direction` are optional.
:param: tape number (int)
:param: arg (regint)
:param: direction (0 for argument, 1 for return value)
:param: register type (see :py:obj:`vm_types`)
:param: register size (int)
:param: destination register
:param: source register
:param: (repeat from direction)
"""
code = base.opcodes['CALL_TAPE']
arg_format = tools.chain(['int', 'ci'],
tools.cycle(['int','int','int','*w','*']))
@staticmethod
def type_check(reg, type_id):
assert base.vm_types[reg.reg_type] == type_id
def __init__(self, *args, **kwargs):
super(call_tape, self).__init__(*args, **kwargs)
for i in range(2, len(args), 5):
for reg in args[i + 3:i + 5]:
self.type_check(reg, args[i + 1])
assert reg.size == args[i + 2]
assert args[i] in (0, 1)
assert args[i + 4 - args[i]].program == program.curr_tape
assert args[i + 3 + args[i]].program == program.tapes[args[0]]
def get_def(self):
# hide registers from called tape
for i in range(2, len(self.args), 5):
if self.args[i]:
yield self.args[i + 3]
def get_used(self):
# hide registers from called tape
yield self.args[1]
for i in range(2, len(self.args), 5):
if not self.args[i]:
yield self.args[i + 4]
def add_usage(self, req_node):
req_node.num += program.tapes[self.args[0]].req_tree.aggregate()
class call_arg(base.DoNotEliminateInstruction, base.VectorInstruction):
""" Pseudo instruction for arguments in connection with
:py:class:`call_tape`.
:param: destination (register)
:param: register type (see :py:obj:`vm_types`)
"""
code = base.opcodes['CALL_ARG']
arg_format = ['*w','int']
def __init__(self, *args, **kwargs):
super(call_arg, self).__init__(*args, **kwargs)
for i in range(0, len(args), 2):
call_tape.type_check(args[i], args[i + 1])
class crash(base.IOInstruction):
""" Crash runtime if the value in the register is not zero.
@@ -687,6 +751,27 @@ class concats(base.VectorInstruction):
for i in range(1, len(args), 2):
assert args[i] == len(args[i + 1])
class zips(base.Instruction):
""" Zip vectors.
:param: result (sint)
:param: operand (sint)
:param: operand (sint)
"""
__slots__ = []
code = base.opcodes['ZIPS']
arg_format = ['sw','s','s']
is_vec = lambda self: True
def __init__(self, *args):
super(zips, self).__init__(*args)
assert len(args[0]) == len(args[1]) + len(args[2])
assert len(args[1]) == len(args[2])
def get_code(self):
return super(zips, self).get_code(len(self.args[1]))
@base.gf2n
@base.vectorize
class mulc(base.MulBase):
@@ -1653,7 +1738,7 @@ class print_reg_plains(base.IOInstruction):
arg_format = ['s']
class cond_print_plain(base.IOInstruction):
""" Conditionally output clear register (with precision).
r""" Conditionally output clear register (with precision).
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
:param: condition (cint, no output if zero)
@@ -1904,7 +1989,7 @@ class closeclientconnection(base.IOInstruction):
code = base.opcodes['CLOSECLIENTCONNECTION']
arg_format = ['ci']
class writesharestofile(base.IOInstruction):
class writesharestofile(base.VectorInstruction, base.IOInstruction):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
@@ -1917,11 +2002,12 @@ class writesharestofile(base.IOInstruction):
__slots__ = []
code = base.opcodes['WRITEFILESHARE']
arg_format = tools.chain(['ci'], itertools.repeat('s'))
vector_index = 1
def has_var_args(self):
return True
class readsharesfromfile(base.IOInstruction):
class readsharesfromfile(base.VectorInstruction, base.IOInstruction):
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
:param: number of arguments to follow / number of shares plus two (int)
@@ -1933,6 +2019,7 @@ class readsharesfromfile(base.IOInstruction):
__slots__ = []
code = base.opcodes['READFILESHARE']
arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))
vector_index = 2
def has_var_args(self):
return True
@@ -2256,7 +2343,7 @@ class convint(base.Instruction):
@base.vectorize
class convmodp(base.Instruction):
""" Convert clear integer register (vector) to clear register
r""" Convert clear integer register (vector) to clear register
(vector). If the bit length is zero, the unsigned conversion is
used, otherwise signed conversion is used. This makes a difference
when computing modulo a prime :math:`p`. Signed conversion of
@@ -2734,13 +2821,11 @@ class check(base.Instruction):
@base.gf2n
@base.vectorize
class sqrs(base.CISC):
""" Secret squaring $s_i = s_j \cdot s_j$. """
r""" Secret squaring $s_i = s_j \cdot s_j$. """
__slots__ = []
arg_format = ['sw', 's']
def expand(self):
if program.options.ring:
return muls(self.args[0], self.args[1], self.args[1])
s = [program.curr_block.new_reg('s') for i in range(6)]
c = [program.curr_block.new_reg('c') for i in range(2)]
square(s[0], s[1])

View File

@@ -68,6 +68,8 @@ opcodes = dict(
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
CMDLINEARG = 0xEB,
CALL_TAPE = 0xEC,
CALL_ARG = 0xED,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -85,6 +87,7 @@ opcodes = dict(
PREFIXSUMS = 0x2D,
PICKS = 0x2E,
CONCATS = 0x2F,
ZIPS = 0x3F,
# Multiplication/division
MULC = 0x30,
MULM = 0x31,
@@ -223,6 +226,17 @@ opcodes = dict(
)
vm_types = dict(
ci = 0,
sb = 1,
cb = 2,
s = 4,
c = 5,
sg = 6,
cg = 7,
)
def int_to_bytes(x):
""" 32 bit int to big-endian 4 byte conversion. """
assert(x < 2**32 and x >= -2**32)
@@ -491,7 +505,8 @@ def cisc(function, n_outputs=1):
reset_global_vector_size()
program.curr_tape = old_tape
for x, bl in tape.req_bit_length.items():
old_tape.require_bit_length(bl - 1, x)
old_tape.require_bit_length(
bl - 1, x, tape.bit_length_reason if x == 'p' else '')
from Compiler.allocator import Merger
merger = Merger(block, program.options,
tuple(program.to_merge))
@@ -516,40 +531,48 @@ def cisc(function, n_outputs=1):
inst.copy(size, subs)
reset_global_vector_size()
def expand_to_function(self, size, new_regs):
key = size, program.curr_tape, \
tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \
tuple(type(reg) for reg in new_regs)
if key not in self.functions:
from Compiler import library, types
class Arg:
def __init__(self, reg):
from Compiler.GC.types import bits
class Arg:
def __init__(self, reg):
self.type = type(reg)
self.binary = isinstance(reg, bits)
self.reg = reg
# if reg is not None:
# program.base_addresses[reg] = None
def new(self):
if self.binary:
return self.type()
else:
return self.type(size=size)
def load(self):
return self.reg
def store(self, reg):
if self.type != type(None):
self.reg.update(reg)
args = [Arg(x) for x in new_regs]
self.type = type(reg)
self.binary = isinstance(reg, bits)
self.reg = reg
def new(self, size):
if self.binary:
return self.type()
else:
return self.type(size=size)
def load(self):
return self.reg
def store(self, reg):
if self.type != type(None):
self.reg.update(reg)
def is_real(self):
return self.reg is not None
def base_key(self, size, new_regs):
return size, tuple(
arg for arg, reg in zip(self.args, new_regs) if reg is None), \
tuple(type(reg) for reg in new_regs)
@staticmethod
def get_name(key):
return '_'.join(['%s(%d)' % (function.__name__, key[0])] +
[str(x) for x in key[1]])
def expand_to_function(self, size, new_regs):
key = self.base_key(size, new_regs) + (program.curr_tape,)
if key not in self.functions:
args = [self.Arg(x) for x in new_regs]
from Compiler import library, types
@library.function_block
def f():
res = [arg.new() for arg in args[:n_outputs]]
self.new_instructions(size,
res + [arg.load() for arg in args[n_outputs:]])
res = [arg.new(size) for arg in args[:n_outputs]]
self.new_instructions(
size, res + [arg.load() for arg in args[n_outputs:]])
for reg, arg in zip(res, args):
arg.store(reg)
f.name = '_'.join(['%s(%d)' % (function.__name__, size)] +
[str(x) for x in key[2]])
f.name = self.get_name(key)
self.functions[key] = f, args
f, args = self.functions[key]
for i in range(len(new_regs) - n_outputs):
@@ -558,6 +581,31 @@ def cisc(function, n_outputs=1):
for i in range(n_outputs):
new_regs[i].link(args[i].load())
def expand_to_tape(self, size, new_regs):
key = self.base_key(size, new_regs)
args = [self.Arg(x) for x in new_regs]
if key not in self.functions:
from Compiler import library, types
@library.function_call_tape
def f(*in_args):
res = [arg.new(size) for arg in args[:n_outputs]]
in_args = list(in_args)
my_args = list(res)
for arg in args[n_outputs:]:
if arg.is_real():
my_args.append(in_args.pop(0))
else:
my_args.append(arg.reg)
self.new_instructions(size, my_args)
return res
f.name = self.get_name(key)
self.functions[key] = f
f = self.functions[key]
in_args = filter(lambda arg: arg.is_real(), args[n_outputs:])
res = util.tuplify(f(*(arg.load() for arg in in_args)))
for i in range(n_outputs):
new_regs[i].link(res[i])
def expand_merged(self, skip):
if function.__name__ in skip:
good = True
@@ -595,7 +643,11 @@ def cisc(function, n_outputs=1):
raise
if program.cisc_to_function and \
(program.curr_tape.singular or program.n_running_threads):
self.expand_to_function(size, new_regs)
if (program.options.garbled or program.options.binary or \
not program.use_tape_calls) and not program.force_cisc_tape:
self.expand_to_function(size, new_regs)
else:
self.expand_to_tape(size, new_regs)
else:
self.new_instructions(size, new_regs)
program.curr_block.n_rounds += self.n_rounds - 1
@@ -795,6 +847,12 @@ class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt
name = 'regint'
class AnyRegAF(RegisterArgFormat):
reg_type = '*'
@staticmethod
def check(arg):
assert isinstance(arg, program.curr_tape.Register)
class IntArgFormat(ArgFormat):
n_bits = 32
@@ -898,6 +956,8 @@ ArgFormats = {
'sgw': SecretGF2NAF,
'ci': ClearIntAF,
'ciw': ClearIntAF,
'*': AnyRegAF,
'*w': AnyRegAF,
'i': ImmediateModpAF,
'ig': ImmediateGF2NAF,
'int': IntArgFormat,
@@ -938,6 +998,7 @@ class Instruction(object):
Instruction.count += 1
if Instruction.count % 100000 == 0:
print("Compiled %d lines at" % self.__class__.count, time.asctime())
sys.stdout.flush()
if Instruction.count > 10 ** 7:
print("Compilation produced more that 10 million instructions. "
"Consider using './compile.py -l' or replacing for loops "
@@ -1139,9 +1200,11 @@ class VarArgsInstruction(Instruction):
class VectorInstruction(Instruction):
__slots__ = []
is_vec = lambda self: True
vector_index = 0
def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0]))
return super(VectorInstruction, self).get_code(
len(self.args[self.vector_index]))
class Ciscable(Instruction):
def copy(self, size, subs):

View File

@@ -107,7 +107,7 @@ def print_str(s, *args, print_secrets=False):
elif isinstance(val, cfloat):
val.print_float_plain()
elif isinstance(val, (list, tuple, Array, SubMultiArray)):
print_str(*_expand_to_print(val))
print_str(*_expand_to_print(val), print_secrets=print_secrets)
else:
try:
val.output()
@@ -314,7 +314,7 @@ def get_cmdline_arg(idx):
return localint(res)
def make_array(l, t=None):
if isinstance(l, Tape.Register):
if isinstance(l, types._structure):
res = Array(len(l), t or type(l))
res[:] = l
else:
@@ -334,13 +334,12 @@ class FunctionTapeCall:
return self
def join(self):
self.thread.join()
instructions.program.free(self.base, 'ci')
for reg_type,addr in self.bases.items():
get_program().free(addr, reg_type.reg_type)
if self.base is not None:
instructions.program.free(self.base, 'ci')
class Function:
def __init__(self, function, name=None, compile_args=[]):
self.type_args = {}
self.last_key = None
self.function = function
self.name = name
if name is None:
@@ -348,46 +347,40 @@ class Function:
self.compile_args = compile_args
def __call__(self, *args):
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
from .types import _types
get_reg_type = lambda x: \
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
key = len(args), get_tape()
if key not in self.type_args:
runtime_args = []
reg_args = []
key = self.base_key(),
for i,arg in enumerate(args):
if isinstance(arg, types._vectorizable):
key += (arg.shape, arg.value_type)
else:
arg = MemValue(arg)
reg_args.append(arg)
t = arg.value_type
key += (arg.size, t)
runtime_args.append(arg)
if key != self.last_key:
# first call
type_args = collections.defaultdict(list)
for i,arg in enumerate(args):
if not isinstance(arg, types._vectorizable):
type_args[get_reg_type(arg)].append(i)
outer_runtime_args = runtime_args
def wrapped_function(*compile_args):
base = get_arg()
bases = dict((t, regint.load_mem(base + i)) \
for i,t in enumerate(sorted(type_args,
key=lambda x:
x.reg_type)))
runtime_args = list(args)
for t in sorted(type_args, key=lambda x: x.reg_type):
i = 0
for i_arg in type_args[t]:
runtime_args[i_arg] = t.load_mem(bases[t] + i)
i += util.mem_size(t)
return self.function(*(list(compile_args) + runtime_args))
addresses = regint.Array(len(outer_runtime_args),
address=get_arg())
runtime_args = []
for i, arg in enumerate(outer_runtime_args):
if isinstance(arg, MemValue):
arg = arg.value_type.load_mem(
address=addresses[i], size=arg.size)
runtime_args.append(arg)
self.result = self.function(
*(list(compile_args) + runtime_args))
return self.result
self.on_first_call(wrapped_function)
self.type_args[key] = type_args
type_args = self.type_args[key]
base = instructions.program.malloc(len(type_args), 'ci')
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
for t in type_args)
for i,reg_type in enumerate(sorted(type_args,
key=lambda x: x.reg_type)):
store_in_mem(bases[reg_type], base + i)
j = 0
for i_arg in type_args[reg_type]:
if get_reg_type(args[i_arg]) != reg_type:
raise CompilerError('type mismatch: "%s" not of type "%s"' %
(args[i_arg], reg_type))
store_in_mem(args[i_arg], bases[reg_type] + j)
j += util.mem_size(reg_type)
return self.on_call(base, bases)
self.last_key = key
addresses = regint.Array(len(runtime_args))
for i, arg in enumerate(reg_args):
addresses[i] = arg.address
return self.on_call(addresses._address,
[(arg.value_type, arg.address) for arg in reg_args])
class FunctionTape(Function):
# not thread-safe
@@ -401,6 +394,171 @@ class FunctionTape(Function):
single_thread=self.single_thread)
def on_call(self, base, bases):
return FunctionTapeCall(self.thread, base, bases)
@staticmethod
def base_key():
pass
class FunctionCallTape(FunctionTape):
def __init__(self, *args, **kwargs):
super(FunctionTape, self).__init__(*args, **kwargs)
self.instances = {}
@staticmethod
def get_key(args, kwargs):
key = (get_program(),)
def process_for_key(arg):
nonlocal key
if isinstance(arg, types._vectorizable):
key += (arg.value_type, tuple(arg.shape))
elif isinstance(arg, Tape.Register):
key += (type(arg), arg.size)
elif isinstance(arg, list):
key += (tuple(arg), 'l')
else:
key += (arg,)
for arg in args:
process_for_key(arg)
for name, arg in sorted(kwargs.items()):
key += (name, 'kw')
process_for_key(arg)
return key
def __call__(self, *args, **kwargs):
key = self.get_key(args, kwargs)
if key not in self.instances:
my_args = []
def wrapped_function():
actual_call_args = []
def process_for_call(arg):
if isinstance(arg, Tape.Register):
my_arg = arg.same_type()
call_arg(my_arg, base.vm_types[my_arg.reg_type])
my_args.append(my_arg)
return my_arg
elif isinstance(arg, types._vectorizable):
my_arg = arg.same_shape(address=regint())
call_arg(my_arg.address, base.vm_types['ci'])
my_args.append(my_arg)
my_arg = arg.same_shape(
address=MemValue(my_arg.address))
return my_arg
actual_call_args.append(my_arg)
else:
my_args.append(arg)
return arg
for arg in args:
actual_call_args.append(process_for_call(arg))
actual_call_kwargs = {}
for name, arg in sorted(kwargs.items()):
actual_call_kwargs[name] = process_for_call(arg)
self.result = self.function(*actual_call_args,
**actual_call_kwargs)
if self.result is not None:
self.result = list(tuplify(self.result))
for i, res in enumerate(self.result):
if util.is_constant(res):
self.result[i] = regint(res)
self.on_first_call(wrapped_function, key, my_args)
for name, arg in sorted(kwargs.items()):
args += arg,
return self.on_call(*self.instances[key], args)
def on_first_call(self, wrapped_function, key, inside_args):
program = get_program()
program.curr_tape
tape_handle = len(program.tapes)
# entry for recursion
self.instances[key] = tape_handle, None, inside_args
assert tape_handle == program.new_tape(
wrapped_function, name=self.name, args=self.compile_args,
single_thread=get_tape().singular, finalize=False,
thread_pool=get_tape().free_threads)
tape = program.tapes[tape_handle]
if self.result is not None:
self.result = list(tuplify(self.result))
for reg in self.result:
reg.can_eliminate = False
tape.return_values.append(reg)
assert not tape.purged
get_program().finalize_tape(tape)
self.instances[key] = tape_handle, self.result, inside_args
def on_call(self, tape_handle, result, inside_args, args):
tape = get_program().tapes[tape_handle]
if tape.ran_threads and tape.free_threads != get_tape().free_threads:
raise CompilerError(
'cannot call thread-running tape from another thread')
assert len(inside_args) == len(args)
out_result = []
call_args = []
if result is not None:
out_result = [reg.same_type() for reg in result]
for x, y in zip(out_result, result):
call_args += [
1, instructions_base.vm_types[x.reg_type],
x.size_for_mem(), x, y]
for x, y in zip(inside_args, args):
if isinstance(x, Tape.Register):
call_args += [
0, instructions_base.vm_types[x.reg_type],
x.size_for_mem(), x, y]
elif isinstance(x, types._vectorizable):
call_args += [0, base.vm_types['ci'], 1,
x.address, regint.conv(y.address)]
call_tape(tape_handle, regint(0),
*call_args)
break_point('call-%s' % self.name)
return untuplify(tuple(out_result))
class ExportFunction(FunctionCallTape):
def __init__(self, function):
super(ExportFunction, self).__init__(function)
self.done = set()
def __call__(self, *args, **kwargs):
if kwargs:
raise CompilerError('keyword arguments not supported')
def arg_signature(arg):
if isinstance(arg, types._structure):
return '%s:%d' % (arg.arg_type(), arg.size)
elif isinstance(arg, types._vectorizable):
from .GC.types import sbitvec
if issubclass(arg.value_type, sbitvec):
return 'sbv:[%dx%d]' % (arg.total_size(),
arg.value_type.n_bits)
else:
return '%s:[%d]' % (arg.value_type.arg_type(),
arg.total_size())
else:
raise CompilerError('argument not supported: %s' % arg)
signature = []
for arg in args:
signature.append(arg_signature(arg))
signature = tuple(signature)
key = self.get_key(args, kwargs)
if key in self.instances and signature not in self.done:
raise CompilerError('signature conflict')
super(ExportFunction, self).__call__(*args, **kwargs)
if signature not in self.done:
filename = '%s/%s/%s-%s' % (get_program().programs_dir, 'Functions',
self.name, '-'.join(signature))
print('Writing to', filename)
out = open(filename, 'w')
print(get_program().name, file=out)
print(self.instances[key][0], file=out)
result = self.instances[key][1]
try:
if result is not None:
result = untuplify(result)
print(arg_signature(result), result.i, file=out)
else:
print('- 0', file=out)
except CompilerError:
raise CompilerError('return type not supported: %s' % result)
for arg in self.instances[key][2]:
if isinstance(arg, types._structure):
print(arg.i, end=' ', file=out)
elif isinstance(arg, types._vectorizable):
print(arg.address, end=' ', file=out)
else:
CompilerError('argument not supported: %s', arg)
print(file=out)
self.done.add(signature)
def function_tape(function):
return FunctionTape(function)
@@ -413,18 +571,111 @@ def function_tape_with_compile_args(*args):
def single_thread_function_tape(function):
return FunctionTape(function, single_thread=True)
def memorize(x):
if isinstance(x, (tuple, list)):
return tuple(memorize(i) for i in x)
def function_call_tape(function):
if get_program().use_tape_calls:
return FunctionCallTape(function)
else:
return MemValue(x)
return function
def method_call_tape(function):
tapes = {}
def wrapper(self, *args, **kwargs):
def use(name):
x = self.__dict__[name]
return not isinstance(x, types.MultiArray) or \
x.array._address is not None
key = (type(self),) + tuple(filter(use, sorted(self.__dict__)))
member_key = key[1:]
if key not in tapes:
def f(*args, **kwargs):
class Dummy(type(self)):
__init__ = lambda self: None
dummy = Dummy()
members = args[:len(member_key)]
real_args = args[len(member_key):]
addresses = {}
for name, member in zip(member_key, members):
dummy.__dict__[name] = member
if isinstance(member, types._vectorizable):
addresses[name] = member.address
res = function(dummy, *real_args, **kwargs)
for name, member in zip(member_key, members):
new_member = dummy.__dict__[name]
desc = '%s in %s.%s' % (name, type(self).__name__,
function.__name__)
if id(new_member) != id(member):
raise CompilerError('cannot change members '
'in method tape (%s)' % desc)
if isinstance(member, types._vectorizable) and \
id(new_member.address) != id(addresses[name]):
raise CompilerError('cannot change memory address '
'in method tape (%s)' % desc)
if set(member_key) != set(dummy.__dict__):
raise CompilerError('cannot add members '
'in method tape (%s)' % desc)
return res
f.__name__ = '%s-%s' % (type(self).__name__, function.__name__)
tapes[key] = function_call_tape(f)
members = tuple(self.__dict__[x] for x in member_key)
res = tapes[key](*(members + args), **kwargs)
return res
return wrapper
def function(function):
""" Create a run-time function. The arguments can be memory or basic
types, and return values can be basic types::
@function
def f(x, y, z):
y.write(1)
z[0] = 2
return x + 3
a = MemValue(sint(0))
b = sint.Array(10)
c = f(sint(4), a, b)
print_ln('%s %s %s', a.reveal(), b[0].reveal(), c.reveal())
This should output::
1 2 7
You can use run-time functions recursively but without return
values in this case.
"""
return FunctionCallTape(function)
def export(function):
return ExportFunction(function)
def memorize(x, write=True):
if isinstance(x, (tuple, list)):
return tuple(memorize(i, write=write) for i in x)
elif x is None:
return
else:
return MemValue(x, write=write)
def unmemorize(x):
if isinstance(x, (tuple, list)):
return tuple(unmemorize(i) for i in x)
elif x is None:
return
else:
return x.read()
def write_mem(dest, source):
if isinstance(dest, (tuple, list)):
assert len(dest) == len(source)
for x, y in zip(dest, source):
write_mem(x, y)
elif dest is None:
return
else:
dest.write(source)
class FunctionBlock(Function):
def on_first_call(self, wrapped_function):
p_return_address = get_tape().program.malloc(1, 'ci')
@@ -470,6 +721,10 @@ class FunctionBlock(Function):
if self.result is not None:
return unmemorize(self.result)
@staticmethod
def base_key():
return get_tape()
def function_block(function):
return FunctionBlock(function)
@@ -704,8 +959,10 @@ def for_range(start, stop=None, step=None):
"""
def decorator(loop_body):
get_tape().unused_decorators.pop(decorator)
range_loop(loop_body, start, stop, step)
return loop_body
get_tape().unused_decorators[decorator] = 'for_range'
return decorator
def for_range_parallel(n_parallel, n_loops):
@@ -754,7 +1011,7 @@ def for_range_opt(start, stop=None, step=None, budget=None):
:param start/stop/step: int/regint/cint (used as in :py:func:`range`)
or :py:obj:`start` only as list/tuple of int (see below)
:param budget: number of instructions after which to start optimization
(default is 100,000)
(default is 1000 or as given with ``--budget``)
Example:
@@ -777,7 +1034,8 @@ def for_range_opt(start, stop=None, step=None, budget=None):
if stop is not None:
start, stop, step = _range_prep(start, stop, step)
def wrapper(loop_body):
n_loops = (step - 1 + stop - start) // step
range_ = stop-start
n_loops = ((range_% step) != 0) + range_ // step
@for_range_opt(n_loops, budget=budget)
def _(i):
return loop_body(start + i * step)
@@ -907,11 +1165,15 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
r = reducer(tuplify(loop_body(j)), mem_state)
write_state_to_memory(r)
state = mem_state
for i,x in enumerate(state):
if use_array:
mem_state[i] = x
else:
mem_state[i].write(x)
if use_array and len(state) and \
isinstance(types._register, types._vectorizable):
mem_state[:] = state.get_vector()
else:
for i,x in enumerate(state):
if use_array:
mem_state[i] = x
else:
mem_state[i].write(x)
def returner():
return untuplify(tuple(state))
return returner
@@ -987,7 +1249,7 @@ def multithread(n_threads, n_items=None, max_size=None):
.. code::
@multithread(8, 25)
@multithread(3, 25)
def f(base, size):
...
"""
@@ -1077,7 +1339,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
return loop_body(base + i)
prog = get_program()
thread_args = []
if prog.curr_tape == prog.tapes[0]:
if prog.curr_tape.singular:
prog.n_running_threads = n_threads
if not util.is_zero(thread_rounds):
prog.prevent_breaks = False
@@ -1261,7 +1523,7 @@ def while_loop(loop_body, condition, arg=None, g=None):
result = loop_body(arg)
if isinstance(result, MemValue):
result = result.read()
arg.update(result)
arg.link(type(arg)(result))
return condition(result)
if not isinstance(pre_condition, (bool,int)) or pre_condition:
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))
@@ -1288,11 +1550,26 @@ def while_do(condition, *args):
return loop_body
return decorator
def _run_and_link(function, g=None):
def _run_and_link(function, g=None, lock_lists=True, allow_return=False):
if g is None:
g = function.__globals__
if lock_lists:
class A(list):
def __init_(self, l):
self[:] = l
def __setitem__(*args):
raise Exception('you cannot change lists in branches, '
'use Array or MultiArray instead')
__delitem__ = append = clear = extend = insert = __setitem__
pop = remove = reverse = sort = __setitem__
for x in g:
if isinstance(g[x], list):
g[x] = A(g[x])
pre = copy.copy(g)
res = function()
if res is not None and not allow_return:
raise CompilerError('Conditional blocks cannot return values. '
'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else')
_link(pre, g)
return res
@@ -1325,7 +1602,7 @@ def do_while(loop_fn, g=None):
name='begin-loop')
get_tape().loop_breaks.append([])
loop_block = instructions.program.curr_block
condition = _run_and_link(loop_fn, g)
condition = _run_and_link(loop_fn, g, allow_return=True)
if callable(condition):
condition = condition()
branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False)
@@ -1347,7 +1624,9 @@ def if_then(condition):
condition = condition()
try:
if not condition.is_clear:
raise CompilerError('cannot branch on secret values')
raise CompilerError(
'cannot branch on secret values, use if_else instead: '
'https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.if_else')
except AttributeError:
pass
state.condition = regint.conv(condition)
@@ -1488,30 +1767,30 @@ def else_(body):
end_if()
def and_(*terms):
res = regint(0)
for term in terms:
if_then(term())
old_res = res
res = regint(1)
res.link(old_res)
for term in terms:
else_then()
end_if()
def load_result():
res = regint(0)
for term in terms:
if_then(term())
old_res = res
res = regint(1)
res.link(old_res)
for term in terms:
else_then()
end_if()
return res
return load_result
def or_(*terms):
res = regint(1)
for term in terms:
if_then(term())
else_then()
old_res = res
res = regint(0)
res.link(old_res)
for term in terms:
end_if()
def load_result():
res = regint(1)
for term in terms:
if_then(term())
else_then()
old_res = res
res = regint(0)
res.link(old_res)
for term in terms:
end_if()
return res
return load_result
@@ -1570,14 +1849,25 @@ def listen_for_clients(port):
"""
instructions.listen(regint.conv(port))
def accept_client_connection(port):
def accept_client_connection(port, players=None):
""" Accept client connection on specific port base.
:param port: port base (int/regint/cint)
:param players: subset of players (default: all)
:returns: client id
"""
res = regint()
instructions.acceptclientconnection(res, regint.conv(port))
if players is None:
instructions.acceptclientconnection(res, regint.conv(port))
else:
@if_e(sum(regint(players) ==
get_player_id()._v.expand_to_vector(len(players))))
def _():
res.update(accept_client_connection(port))
@else_
def _():
res.update(-1)
return res
def init_client_connection(host, port, my_id, relative_port=True):
@@ -1697,14 +1987,14 @@ def cint_cint_division(a, b, k, f):
from Compiler.program import Program
@instructions_base.ret_cisc
def sint_cint_division(a, b, k, f, kappa, nearest=False):
def sint_cint_division(a, b, k, f, nearest=False):
"""
type(a) = sint, type(b) = cint
"""
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k)
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)
@@ -1714,20 +2004,20 @@ def sint_cint_division(a, b, k, f, kappa, nearest=False):
W = w0
for i in range(1, theta):
A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True)
A = (A * W).round(2 * k, f, nearest=nearest, signed=True)
temp = (B * W + 2 * (f - 1)) >> f
W = two - temp
B = temp
return (sign_a * sign_b) * A
def IntDiv(a, b, k, kappa=None):
def IntDiv(a, b, k):
l = 2 * k + 1
b = a.conv(b)
return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k,
kappa, nearest=True)
nearest=True)
@instructions_base.ret_cisc
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
def FPDiv(a, b, k, f, simplex_flag=False, nearest=False):
"""
Goldschmidt method as presented in Catrina10,
"""
@@ -1750,40 +2040,40 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
w = AppRcr(b, k, f, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()
y = a.extend(l_y) * w
y = y.round(l_y, f, kappa, nearest, signed=True)
y = y.round(l_y, f, nearest, signed=True)
for i in range(theta - 1):
x = x.extend(2 * k)
y = y.extend(l_y) * (alpha + x).extend(l_y)
x = x * x
y = y.round(l_y, 2*f, kappa, nearest, signed=True)
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
y = y.round(l_y, 2*f, nearest, signed=True)
x = x.round(2*k, 2*f, nearest, signed=True)
x = x.extend(2 * k)
y = y.extend(l_y) * (alpha + x).extend(l_y)
y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True)
y = y.round(l_y, 3 * f - res_f, nearest, signed=True)
return y
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
def AppRcr(b, k, f, simplex_flag=False, nearest=False):
"""
Approximate reciprocal of [b]:
Given [b], compute [1/b]
"""
alpha = b.get_type(2 * k)(int(2.9142 * 2**k))
c, v = b.Norm(k, f, kappa, simplex_flag)
c, v = b.Norm(k, f, simplex_flag)
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
d = alpha - 2 * c
w = d * v
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
w = w.round(2 * k + 1, 2 * (k - f), nearest, signed=True)
# now w * 2 ^ {-f} should be an initial approximation of 1/b
return w
def Norm(b, k, f, kappa, simplex_flag=False):
def Norm(b, k, f, simplex_flag=False):
"""
Computes secret integer values [c] and [v_prime] st.
2^{k-1} <= c < 2^k and c = b*v_prime
@@ -1799,8 +2089,8 @@ def Norm(b, k, f, kappa, simplex_flag=False):
absolute_val = sign * b
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
suffixes = PreOR(bits, kappa)[::-1]
bits = absolute_val.bit_decompose(k, maybe_mixed=True)[::-1]
suffixes = PreOR(bits)[::-1]
z = [0] * k
for i in range(k - 1):

View File

@@ -203,6 +203,20 @@ def _no_mem_warnings(function):
copy_doc(wrapper, function)
return wrapper
def _layer_method_call_tape(function):
function = method_call_tape(function)
def wrapper(self, *args, **kwargs):
self._Y.alloc()
if self.inputs and len(self.inputs) == 1:
backup = self.inputs
del self.inputs
res = function(self, *args, **kwargs)
self.inputs = backup
return res
else:
return function(self, *args, **kwargs)
return wrapper
class Tensor(MultiArray):
def __init__(self, *args, **kwargs):
kwargs['alloc'] = False
@@ -259,6 +273,7 @@ class Layer:
def Y(self, value):
self._Y = value
@_layer_method_call_tape
def forward(self, batch=None, training=None):
if batch is None:
batch = Array.create_from(regint(0))
@@ -834,7 +849,7 @@ class Dense(DenseBase):
prod = MultiArray([N, self.d, self.d_out], sfix)
else:
prod = self.f_input
max_size = get_program().budget // self.d_out
max_size = get_program().budget
@multithread(self.n_threads, N, max_size)
def _(base, size):
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
@@ -1045,7 +1060,8 @@ class ElementWiseLayer(NoVariableLayer):
def _forward(self, batch=[0]):
n_per_item = reduce(operator.mul, self.X.sizes[1:])
@multithread(self.n_threads, len(batch) * n_per_item)
@multithread(self.n_threads, len(batch) * n_per_item,
max_size=program.budget)
def _(base, size):
self.Y.assign_vector(self.f_part(base, size), base)
@@ -1195,6 +1211,7 @@ class MaxPool(PoolBase):
list/tuple of integers
"""
@_layer_method_call_tape
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
@@ -1252,26 +1269,38 @@ class Concat(NoVariableLayer):
self.dimension = dimension
shapes = [inp.shape for inp in inputs]
assert dimension == 3
assert len(shapes) == 2
assert len(shapes[0]) == len(shapes[1])
assert len(shapes[0]) == 4
for shape in shapes:
assert len(shape) == len(shapes[0])
shape = []
for i in range(len(shapes[0])):
if i == dimension:
shape.append(shapes[0][i] + shapes[1][i])
shape.append(sum(x[i] for x in shapes))
else:
assert shapes[0][i] == shapes[1][i]
shape.append(shapes[0][i])
self.Y = Tensor(shape, sfix)
self.bases = [sum(x[dimension] for x in shapes[:k])
for k in range(len(shapes))]
self.addresses = Array.create_from(regint(list(
x.Y.address for x in inputs)))
def _forward(self, batch=[0]):
assert len(batch) == 1
@for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3])
def _(i, j):
X = [x.Y[batch[0]] for x in self.inputs]
self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector())
self.Y[batch[0]][i][j].assign_part_vector(
X[1][i][j].get_vector(),
len(X[0][i][j]))
if len(set(self.bases)) == 1:
@for_range(len(self.inputs))
def _(k):
self.Y[batch[0]][i][j].assign_part_vector(
MultiArray(
self.inputs[0].shape,
address=self.addresses[k])[i][j].get_vector(),
k * self.bases[1])
else:
X = [x.Y[batch[0]] for x in self.inputs]
for k in range(len(self.inputs)):
self.Y[batch[0]][i][j].assign_part_vector(
X[k][i][j].get_vector(), self.bases[k])
class Add(NoVariableLayer):
""" Fixed-point addition layer.
@@ -1287,12 +1316,12 @@ class Add(NoVariableLayer):
self.inputs = inputs
def _forward(self, batch=[0]):
assert len(batch) == 1
@multithread(self.n_threads, self.Y[0].total_size())
def _(base, size):
tmp = sum(inp.Y[batch[0]].get_vector(base, size)
for inp in self.inputs)
self.Y[batch[0]].assign_vector(tmp, base)
for bb in batch:
tmp = sum(inp.Y[bb].get_vector(base, size)
for inp in self.inputs)
self.Y[bb].assign_vector(tmp, base)
class FusedBatchNorm(Layer):
""" Fixed-point fused batch normalization layer (inference only).
@@ -1334,12 +1363,14 @@ class BatchNorm(Layer):
def __init__(self, shape, approx=True, args=None):
assert len(shape) in (2, 3, 4)
self.Y = sfix.Tensor(shape)
if len(shape) == 4:
shape = [shape[0], shape[1] * shape[2], shape[3]]
elif len(shape) == 2:
shape = [shape[0], 1, shape[1]]
tensors = (Tensor(shape, sfix) for i in range(4))
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
self.my_Y = sfix.Tensor(shape, address=self.Y.address)
tensors = (Tensor(shape, sfix) for i in range(3))
self.X, self.nabla_X, self.nabla_Y = tensors
arrays = (sfix.Array(shape[2]) for i in range(4))
self.var, self.mu, self.weights, self.bias = arrays
arrays = (sfix.Array(shape[2]) for i in range(4))
@@ -1369,16 +1400,16 @@ class BatchNorm(Layer):
def _output(self, batch, mu, var):
factor = sfix.Array(len(mu))
factor[:] = self.InvertSqrt(var[:] + self.epsilon)
factor[:] = self.InvertSqrt(var[:] + self.epsilon) * self.weights[:]
@for_range_opt_multithread(self.n_threads,
[len(batch), self.X.sizes[1]])
def _(i, j):
tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:]
self.Y[i][j][:] = self.bias[:] + tmp
tmp = (self.X[i][j][:] - mu[:]) * factor[:]
self.my_Y[i][j][:] = self.bias[:] + tmp
@_layer_method_call_tape
def forward(self, batch, training=False):
if training or not self.is_trained:
self.is_trained = True
d = self.X.sizes[1]
d_in = self.X.sizes[2]
s = sfix.Array(d_in)
@@ -1411,7 +1442,7 @@ class BatchNorm(Layer):
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
print_ln('%s at (%s, %s, %s): in=%s out=%s',
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
self.Y[i][j][k].reveal())
self.my_Y[i][j][k].reveal())
else:
self._output(batch, self.mu_hat, self.var_hat)
@@ -1471,6 +1502,10 @@ class BatchNorm(Layer):
self.nabla_Y[i][j][k].reveal(),
self.nabla_X[i][j][k].reveal())
def reveal_parameters_to_binary(self):
for param in self.thetas() + (self.mu_hat, self.var_hat):
param.reveal().binary_output()
class QuantBase(object):
bias_before_reduction = True
@@ -1498,11 +1533,12 @@ class QuantBase(object):
class FixBase:
bias_before_reduction = False
@staticmethod
def new_squant():
class _(sfix):
params = None
return _
class my_squant(sfix):
params = None
@classmethod
def new_squant(cls):
return cls.my_squant
def input_params_from(self, player):
pass
@@ -1553,7 +1589,8 @@ class ConvBase(BaseLayer):
cls.temp_inputs = sfix.Array(size)
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride,
padding='SAME', tf_weight_format=False, inputs=None):
padding='SAME', tf_weight_format=False, inputs=None,
weight_type=None):
super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs)
self.weight_shape = weight_shape
@@ -1582,7 +1619,11 @@ class ConvBase(BaseLayer):
else:
self.padding = padding
self.weight_squant = self.new_squant()
if weight_type:
self.weight_squant = weight_type
else:
self.weight_squant = self.new_squant()
self.bias_squant = self.new_squant()
self.weights = Tensor(weight_shape, self.weight_squant)
@@ -1645,8 +1686,7 @@ class ConvBase(BaseLayer):
n_summands = self.n_summands()
#start_timer(2)
n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:])
@multithread(self.n_threads, n_outputs,
1000 if sfix.round_nearest else 10 ** 6)
@multithread(self.n_threads, n_outputs, max_size=program.budget)
def _(base, n_per_thread):
res = self.input_squant().unreduced(
sint.load_mem(unreduced.address + base,
@@ -1709,7 +1749,7 @@ class Conv2d(ConvBase):
res += self.bias.expand_to_vector(j, res.size).v
else:
res += self.bias.expand_to_vector(j, res.size).v << \
self.input_squant.f
self.weight_squant.f
addresses = regint.inc(res.size,
self.unreduced[i * part_size].address + j,
n_channels_out)
@@ -2029,7 +2069,7 @@ class AveragePool2d(BaseLayer):
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
padding=0):
padding=0, **kwargs):
""" More convenient interface to :py:class:`FixConv2d`.
:param input_shape: input shape (tuple/list of four int)
@@ -2050,7 +2090,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
padding = padding.upper() if isinstance(padding, str) \
else padding
return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape,
stride, padding)
stride, padding, **kwargs)
def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
""" More convenient interface to :py:class:`MaxPool`.
@@ -2102,6 +2142,7 @@ class FixAveragePool2d(PoolBase, FixBase):
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
[1] + list(filter_size) + [1], padding)
self.pool_size = reduce(operator.mul, filter_size)
self.d_out = self.Y.shape[-1]
if output_shape:
assert self.Y.shape == list(output_shape)
@@ -2192,7 +2233,7 @@ class Optimizer:
res.output_stats = 'output_stats' in program.args
return res
def __init__(self, layers=[], report_loss=None):
def __init__(self, layers=[], report_loss=None, time_layers=False):
if get_program().options.binary:
raise CompilerError(
'machine learning code not compatible with binary circuits')
@@ -2207,6 +2248,10 @@ class Optimizer:
self.stopped_on_loss = MemValue(0)
self.stopped_on_low_loss = MemValue(0)
self.layers = layers
self.time_layers = time_layers
if time_layers:
for i, layer in enumerate(layers):
print('Timer %d: %s' % (100 + i, repr(layer)))
@property
def layers(self):
@@ -2275,6 +2320,8 @@ class Optimizer:
if self.time_layers:
start_timer(100 + i)
if i != len(self.layers) - 1 or run_last:
for theta in layer.thetas():
theta.alloc()
layer.forward(batch=self.batch_for(layer, batch),
training=training)
if self.print_random_update:
@@ -2322,6 +2369,10 @@ class Optimizer:
self.forward(batch=batch, run_last=False)
part = self.layers[-1].eval(batch_size, top=top)
res.assign_part_vector(part.get_vector(), start)
if self.output_stats:
for layer in self.layers[:-1]:
print_ln(layer)
self.stat(' Y', layer.Y)
self.run_in_batches(f, data, batch_size or len(self.layers[1].X))
return res
@@ -2567,9 +2618,9 @@ class Optimizer:
def run_in_batches(self, f, data, batch_size, truth=None):
batch_size = min(batch_size, data.sizes[0])
training_data = self.layers[0].X.address
training_data = self.layers[0]._X.array._address
training_truth = self.layers[-1].Y.address
self.layers[0].X.address = data.address
self.layers[0]._X.address = data.address
if truth:
self.layers[-1].Y.address = truth.address
N = data.sizes[0]
@@ -2620,8 +2671,12 @@ class Optimizer:
if model_input:
for layer in self.layers:
layer.input_from(0)
elif reset:
elif reset and not 'no_reset' in program.args:
self.reset()
else:
for layer in self.layers:
for theta in layer.thetas():
theta.alloc()
if 'one_iter' in program.args:
print_float_prec(16)
self.output_weights()
@@ -2642,6 +2697,8 @@ class Optimizer:
if 'bench10' in program.args or 'bench1' in program.args:
n = 1 if 'bench1' in program.args else 10
print('benchmarking %s iterations' % n)
# force allocatoin
self.layers[0].X, self.layers[-1].Y
@for_range(n)
def _(i):
batch = Array.create_from(regint.inc(batch_size))
@@ -2748,13 +2805,15 @@ class Optimizer:
return list(self.thetas)
def reveal_model_to_binary(self):
input_shape = self.layers[0].X.shape
""" Reveal model and store it in the binary output file, see
:ref:`reveal-model` for details. """
input_shape = self.layers[0]._X.shape
for layer in self.layers:
if len(input_shape) == 4 and isinstance(layer, DenseBase):
layer.reveal_parameters_to_binary(reshape=input_shape[1:])
else:
layer.reveal_parameters_to_binary()
input_shape = layer.Y.shape
input_shape = layer._Y.shape
class Adam(Optimizer):
""" Adam/AMSgrad optimizer.
@@ -2866,10 +2925,12 @@ class SGD(Optimizer):
self.layers = layers
self.n_epochs = n_epochs
self.nablas = []
self.momentum_values = []
self.delta_thetas = []
for layer in layers:
self.nablas.extend(layer.nablas())
for theta in layer.thetas():
self.momentum_values.append(theta.same_shape())
self.delta_thetas.append(theta.same_shape())
self.set_learning_rate(0.01)
self.debug = debug
@@ -2889,23 +2950,27 @@ class SGD(Optimizer):
j = i + label * len(X_by_label[0])
self.layers[0].X[j] = X[i]
self.layers[-1].Y[j] = label
for y in self.momentum_values:
y.assign_all(0)
for y in self.delta_thetas:
y.assign_all(0)
super(SGD, self).reset()
def _update(self, i_epoch, i_batch, batch):
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
self.delta_thetas):
for nabla, theta, momentum_value, delta_theta in zip(self.nablas, self.thetas,
self.momentum_values, self.delta_thetas):
@multithread(self.n_threads, nabla.total_size())
def _(base, size):
old = delta_theta.get_vector(base, size)
old = momentum_value.get_vector(base, size)
red_old = self.momentum * old
rate = self.gamma.expand_to_vector(size)
nabla_vector = nabla.get_vector(base, size)
log_batch_size = math.log(len(batch), 2)
# divide by len(batch) by truncation
# increased rate if len(batch) is not a power of two
pre_trunc = nabla_vector.v * rate.v
diff = red_old - nabla_vector
pre_trunc = diff.v * rate.v
momentum_value.assign_vector(diff, base)
k = max(nabla_vector.k, rate.k) + rate.f
m = rate.f + int(log_batch_size)
if self.early_division:
@@ -2914,8 +2979,7 @@ class SGD(Optimizer):
v = pre_trunc.round(k, m, signed=True,
nearest=sfix.round_nearest)
new = nabla_vector._new(v)
diff = red_old - new
delta_theta.assign_vector(diff, base)
delta_theta.assign_vector(new, base)
theta.assign_vector(theta.get_vector(base, size) +
delta_theta.get_vector(base, size), base)
if self.print_update_average:
@@ -3046,7 +3110,7 @@ class keras:
def summary(self):
self.opt.summary()
def build(self, input_shape, batch_size=128):
def build(self, input_shape, batch_size=128, program=None):
data_input_shape = input_shape
if self.opt != None and \
input_shape == self.opt.layers[0]._X.sizes and \
@@ -3120,8 +3184,11 @@ class keras:
if layers[-1].d_out == 1:
layers.append(Output(data_input_shape[0]))
else:
layers.append(
MultiOutput(data_input_shape[0], layers[-1].d_out))
shape = data_input_shape[0], layers[-1].d_out
if program:
layers.append(MultiOutput.from_args(program, *shape))
else:
layers.append(MultiOutput(*shape))
if self.optimizer[1]:
raise Exception('use keyword arguments for optimizer')
opt = self.optimizer[0]
@@ -3186,11 +3253,11 @@ class keras:
batch_size = min(batch_size, self.batch_size)
return self.opt.eval(x, batch_size=batch_size)
def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
regression=False):
""" Convert a PyTorch Sequential object to MP-SPDZ layers.
def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
regression=False, layer_args={}, program=None):
""" Convert a PyTorch Module object to MP-SPDZ layers.
:param sequence: PyTorch Sequential object
:param model: PyTorch Module object
:param data_input_shape: input shape (list of four int)
:param batch_size: batch size (int)
:param input_via: player to input model data via (default: don't)
@@ -3198,17 +3265,39 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
"""
layers = []
named_layers = {}
def mul(x):
return reduce(operator.mul, x)
def process(item):
nonlocal input_shape
import torch
def process(item, inputs, input_shape, args, kwargs={}):
if item == torch.cat:
if len(inputs) > 1:
layers.append(
Concat(inputs, dimension=len(inputs[0].shape) - 1))
return
elif item == operator.add:
layers.append(Add(inputs))
return
elif item in (torch.flatten, 'flatten', 'size'):
return
elif item == 'view':
assert -1 in args or \
reduce(operator.mul, args) == reduce(operator.mul, input_shape)
return
elif item == torch.nn.functional.avg_pool2d:
layers.append(FixAveragePool2d(input_shape, None, args[1],
kwargs.get('stride', args[1]),
kwargs.get('padding', 0)))
input_shape = layers[-1].Y.shape
return
# single-input layers from here
if inputs and len(inputs) > 1:
raise CompilerError('multi-input layer %s not supported' % item)
name = type(item).__name__
if name == 'Sequential':
for x in item:
process(x)
elif name == 'Linear':
if name == 'Linear':
assert mul(input_shape[1:]) == item.in_features
assert item.bias is not None
layers.append(Dense(input_shape[0], item.in_features,
@@ -3239,7 +3328,7 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
elif name == 'Conv2d':
layers.append(easyConv2d(input_shape, batch_size, item.out_channels,
item.kernel_size, item.stride,
item.padding))
item.padding, **layer_args.get(item, {})))
input_shape = layers[-1].Y.shape
if input_via is not None:
shapes = [x.shape for x in
@@ -3247,11 +3336,14 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
import numpy
swapped = numpy.moveaxis(
numpy.array(item.weight.detach()), 1, -1)
layers[-1].weights = sfix.input_tensor_via(input_via, swapped)
layers[-1].bias = sfix.input_tensor_via(
input_via, item.bias.detach())
layers[-1].weights = \
layers[-1].weights.value_type.input_tensor_via(
input_via, swapped)
assert layers[-1].weights.shape == shapes[0]
assert layers[-1].bias.shape == shapes[1]
if isinstance(item.bias, torch.Tensor):
layers[-1].bias = sfix.input_tensor_via(
input_via, item.bias.detach())
assert layers[-1].bias.shape == shapes[1]
elif name == 'MaxPool2d':
layers.append(easyMaxPool(input_shape, item.kernel_size,
item.stride, item.padding))
@@ -3260,10 +3352,24 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
item.stride, item.padding))
input_shape = layers[-1].Y.shape
elif name == 'ReLU':
elif name == 'AdaptiveAvgPool2d' or \
item == torch.nn.functional.adaptive_avg_pool2d:
if name == 'AdaptiveAvgPool2d':
output = item.output_size
else:
output = args[1]
for i in (0, 1):
assert input_shape[1 + i] % output[i] == 0
stride = [input_shape[1 + i] // output[i] for i in (0, 1)]
kernel_size = [input_shape[1 + i] - (output[i] - 1) * stride[i]
for i in (0, 1)]
layers.append(FixAveragePool2d(input_shape, None, kernel_size,
stride, padding=0))
input_shape = layers[-1].Y.shape
elif name == 'ReLU' or item == torch.nn.functional.relu:
layers.append(Relu(input_shape))
elif name == 'Flatten':
pass
return
elif name == 'BatchNorm2d':
layers.append(BatchNorm(layers[-1].Y.sizes))
if input_via is not None:
@@ -3277,16 +3383,57 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
alpha=item.p))
input_shape = layers[-1].Y.sizes
else:
raise CompilerError('unknown PyTorch module: ' + name)
raise CompilerError('unknown PyTorch module: %s' % item)
layers[-1].inputs = inputs
input_shape = data_input_shape + [1] * (4 - len(data_input_shape))
process(sequence)
torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes)
for i, layer in enumerate(torch_layers[1:-1]):
if layer.op == 'call_module':
target = model
for attr in layer.target.split('.'):
target = getattr(target, attr)
else:
target = layer.target
if not layers:
assert layer.args == (torch_layers[i],)
inputs = None
else:
if len(layer.args) < 2 or (layer.args[1] != 1 and
layer.args[1] != (1, 1)):
args = layer.args
elif isinstance(layer.args[0], list):
args = layer.args[0]
else:
args = layer.args[0],
inputs = []
try:
for x in args:
inputs.append(named_layers[x])
except KeyError:
pass
if len(inputs) == 1:
if isinstance(inputs[0], (Dropout, BatchNorm)):
input_shape = inputs[0].inputs[0].Y.shape
else:
input_shape = inputs[0]._Y.shape
else:
input_shape = None
process(target, inputs, input_shape, layer.args, layer.kwargs)
if layers:
named_layers[layer] = layers[-1]
if regression:
layers.append(LinearOutput(data_input_shape[0], layers[-1].d_out))
elif layers[-1].d_out == 1:
layers.append(Output(data_input_shape[0]))
else:
layers.append(MultiOutput(data_input_shape[0], layers[-1].d_out))
shape = data_input_shape[0], layers[-1].d_out
if program:
layers.append(MultiOutput.from_args(program, *shape))
else:
layers.append(MultiOutput(*shape))
return layers
class OneLayerSGD:
@@ -3335,6 +3482,9 @@ class OneLayerSGD:
class SGDLogistic(OneLayerSGD):
""" Logistic regression using SGD.
The member :py:obj:`opt` refers to the internal instance of
:py:class:`Optimizer`, which allows to use the funcionality
therein.
:param n_epochs: number of epochs
:param batch_size: batch size
@@ -3460,6 +3610,8 @@ def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False,
def mr(A, n_iterations, stop=False):
""" Iterative matrix inverse approximation.
This is based on the conjugate gradients algorithm in Section
10.2.4 of `these lecture notes <https://graphics.stanford.edu/courses/cs205a-13-fall/assets/notes/cs205a_notes.pdf>`_.
:param A: matrix to invert
:param n_iterations: maximum number of iterations

View File

@@ -1,7 +1,9 @@
"""
Module for math operations.
Implements trigonometric and logarithmic functions.
Most of the functionality is due to `Aly and Smart
<https://eprint.iacr.org/2019/354>`_ with some optimizations by
`Keller and Sun <https://eprint.iacr.org/2022/933>`_.
This has to imported explicitly.
"""
@@ -98,7 +100,7 @@ pi_over_2 = math.radians(90)
# @return truncated sint value of x
def trunc(x):
if isinstance(x, types._fix):
return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True)
return x.v.right_shift(x.f, x.k, signed=True)
elif type(x) is types.sfloat:
v, p, z, s = floatingpoint.FLRound(x, 0)
#return types.sfloat(v, p, z, s, x.err)
@@ -106,19 +108,6 @@ def trunc(x):
return x
##
# loads integer to fractional type (sint)
# @param x: coefficient to be truncated.
#
# @return returns sfix, sfloat loaded value
def load_sint(x, l_type):
if l_type is types.sfix:
return types.sfix.from_sint(x)
elif l_type is types.sfloat:
return x
return x
##
# evaluates a Polynomial to a given x in a privacy preserving manner.
# Inputs can be of any kind of register, secret or otherwise.
@@ -433,7 +422,7 @@ def mux_exp(x, y, block_size=8):
@types.vectorize
@instructions_base.sfix_cisc
def log2_fx(x, use_division=True):
"""
r"""
Returns the result of :math:`\log_2(x)` for any unbounded
number. This is achieved by changing :py:obj:`x` into
:math:`f \cdot 2^n` where f is bounded by :math:`[0.5, 1]`. Then the
@@ -448,7 +437,7 @@ def log2_fx(x, use_division=True):
if isinstance(x, types._fix):
# transforms sfix to f*2^n, where f is [o.5,1] bounded
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f)
p -= x.f
vlen = x.f
v = x._new(v, k=x.k, f=x.f)
@@ -473,8 +462,8 @@ def log2_fx(x, use_division=True):
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
def pow_fx(x, y):
"""
def pow_fx(x, y, zero_output=False):
r"""
Returns the value of the expression :math:`x^y` where both inputs
are secret shared. It uses :py:func:`log2_fx` together with
:py:func:`exp2_fx` to calculate the expression :math:`2^{y \log_2(x)}`.
@@ -494,11 +483,11 @@ def pow_fx(x, y):
# obtains y * log2(x)
exp = y * log2_x
# returns 2^(y*log2(x))
return exp2_fx(exp)
return exp2_fx(exp, zero_output)
def log_fx(x, b):
"""
r"""
Returns the value of the expression :math:`\log_b(x)` where
:py:obj:`x` is secret shared. It uses :py:func:`log2_fx` to
calculate the expression :math:`\log_b(2) \cdot \log_2(x)`.
@@ -535,8 +524,8 @@ def abs_fx(x):
#
# @return floored sint value of x
def floor_fx(x):
return load_sint(x.v.right_shift(x.f, bit_length=x.k, security=x.kappa,
signed=True), type(x))
return type(x)(x.v.right_shift(x.f, bit_length=x.k, signed=True),
k=x.k, f=x.f)
### sqrt methods
@@ -743,13 +732,13 @@ def lin_app_SQ(b, k, f):
c, v, m, W = norm_SQ(types.sint(b), k)
# c is now escalated
w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k
w = alpha * c + beta # equation before b and reduction by order of k
# m even or odd determination
m_bit = types.sint()
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False)
m = load_sint(m_bit, types.sfix)
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), signed=False)
m = m_bit
# w times v this way both terms have 2^3k and can be symplified
w = w * v
@@ -774,7 +763,7 @@ def lin_app_SQ(b, k, f):
def sqrt_fx(x_l, k, f):
factor = 1.0 / (2.0 ** f)
x = load_sint(x_l, types.sfix) * factor
x = x_l * factor
theta = int(math.ceil(math.log(k/5.4)))
@@ -870,7 +859,7 @@ def atan(x):
def asin(x):
"""
r"""
Returns the arcsine (sfix) of any given fractional value.
:param x: fractional input (sfix). valid interval is :math:`-1 \le x \le 1`
@@ -886,7 +875,7 @@ def asin(x):
def acos(x):
"""
r"""
Returns the arccosine (sfix) of any given fractional value.
:param x: fractional input (sfix). :math:`-1 \le x \le 1`
@@ -898,7 +887,7 @@ def acos(x):
def tanh(x):
"""
r"""
Hyperbolic tangent. For efficiency, accuracy is diminished
around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and
:math:`f` denote the fixed-point parameters.
@@ -912,29 +901,29 @@ def tanh(x):
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
def Sep(x):
def Sep(x, sfix=types.sfix):
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
bb = b[:]
while len(bb) < 2 * x.f - 1:
bb.insert(0, type(b[0])(0))
t = x.v * (1 + x.v.bit_compose(b_i.bit_not()
for b_i in bb[-2 * x.f + 1:]))
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
u = sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
b += [b[0].long_one()]
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
def SqrtComp(z, old=False):
f = types.sfix.f
def SqrtComp(z, old=False, sfix=types.sfix):
f = sfix.f
k = len(z)
if isinstance(z[0], types.sint):
return types.sfix._new(sum(z[i] * types.cfix(
return sfix._new(sum(z[i] * types.cfix(
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
k_prime = k // 2
f_prime = f // 2
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
c0 = types.sfix(2 ** (f / 2 + 1))
c1 = sfix(2 ** ((f + 1) / 2 + 1))
c0 = sfix(2 ** (f / 2 + 1))
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
tmp = sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
if old:
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
else:
@@ -942,11 +931,15 @@ def SqrtComp(z, old=False):
return types.sint.conv(b).if_else(c1, c0) * tmp
@types.vectorize
@instructions_base.sfix_cisc
def InvertSqrt(x, old=False):
"""
Reciprocal square root approximation by `Lu et al.
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
"""
u, z = Sep(x)
class my_sfix(types.sfix):
f = x.f
k = x.k
u, z = Sep(x, sfix=my_sfix)
c = 3.14736 + u * (4.63887 * u - 5.77789)
return c * SqrtComp(z, old=old)
return c * SqrtComp(z, old=old, sfix=my_sfix)

View File

@@ -4,14 +4,6 @@ from .types import *
from . import comparison, program
class NonLinear:
kappa = None
def set_security(self, kappa):
pass
def check_security(self, kappa):
pass
def mod2m(self, a, k, m, signed):
"""
a_prime = a % 2^m
@@ -45,18 +37,16 @@ class NonLinear:
def trunc_round_nearest(self, a, k, m, signed):
res = sint()
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, self.kappa,
signed)
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, signed)
return res
def trunc(self, a, k, m, kappa, signed):
self.check_security(kappa)
def trunc(self, a, k, m, signed):
if m == 0:
return a
return self._trunc(a, k, m, signed)
def ltz(self, a, k, kappa=None):
return -self.trunc(a, k, k - 1, kappa, True)
def ltz(self, a, k):
return -self.trunc(a, k, k - 1, True)
class Masking(NonLinear):
def eqz(self, a, k):
@@ -68,28 +58,19 @@ class Masking(NonLinear):
class Prime(Masking):
""" Non-linear functionality modulo a prime with statistical masking. """
def __init__(self, kappa):
self.set_security(kappa)
def set_security(self, kappa):
self.kappa = kappa
def check_security(self, kappa):
assert self.kappa == kappa or kappa is None
def _mod2m(self, a, k, m, signed):
res = sint()
if m == 1:
Mod2(res, a, k, self.kappa, signed)
Mod2(res, a, k, signed)
else:
Mod2mField(res, a, k, m, self.kappa, signed)
Mod2mField(res, a, k, m, signed)
return res
def _mask(self, a, k):
return maskField(a, k, self.kappa)
return maskField(a, k)
def _trunc_pr(self, a, k, m, signed=None):
return TruncPrField(a, k, m, self.kappa)
return TruncPrField(a, k, m)
def _trunc(self, a, k, m, signed=None):
a_prime = self.mod2m(a, k, m, signed)
@@ -99,12 +80,12 @@ class Prime(Masking):
def bit_dec(self, a, k, m, maybe_mixed=False):
if maybe_mixed:
return BitDecFieldRaw(a, k, m, self.kappa)
return BitDecFieldRaw(a, k, m)
else:
return BitDecField(a, k, m, self.kappa)
return BitDecField(a, k, m)
def kor(self, d):
return KOR(d, self.kappa)
return KOR(d)
class KnownPrime(NonLinear):
""" Non-linear functionality modulo a prime known at compile time. """
@@ -144,13 +125,13 @@ class KnownPrime(NonLinear):
a += two_power(k)
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
def ltz(self, a, k, kappa=None):
def ltz(self, a, k):
if k + 1 < self.prime.bit_length():
# https://dl.acm.org/doi/10.1145/3474123.3486757
# "negative" values wrap around when doubling, thus becoming odd
return self.mod2m(2 * a, k + 1, 1, False)
else:
return super(KnownPrime, self).ltz(a, k, kappa)
return super(KnownPrime, self).ltz(a, k)
class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time.
@@ -189,5 +170,5 @@ class Ring(Masking):
else:
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
def ltz(self, a, k, kappa=None):
def ltz(self, a, k):
return LtzRing(a, k)

View File

@@ -9,6 +9,11 @@ secret index::
i = sint.get_input_from(0)
a[i] = sint.get_input_from(1)
`The introductory book by Evans et
al. <https://securecomputation.org>`_ contains `a chapter dedicated to
oblivious RAM
<https://securecomputation.org/docs/ch5-obliviousdata.pdf>`_.
"""
import random
@@ -41,6 +46,7 @@ debug_online = False
crash_on_overflow = False
use_insecure_randomness = False
debug_ram_size = False
single_thread = False
def maybe_start_timer(n):
if detailed_timing:
@@ -844,7 +850,7 @@ class TrivialORAM(RefTrivialORAM, AbstractORAM):
start_timer()
def get_n_threads(n_loops):
if n_threads is None:
if n_threads is None and not single_thread:
if n_loops > 2048:
return 8
else:
@@ -877,7 +883,7 @@ class LinearORAM(TrivialORAM):
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
@map_sum(get_n_threads(self.size), None, self.size, \
self.value_length + 1, t)
def f(i):
entry = self.ram[i]
@@ -897,7 +903,7 @@ class LinearORAM(TrivialORAM):
new_value = make_array(
new_value, self.value_type.get_type(
max(x or 0 for x in self.entry_size)))
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
@for_range_multithread(get_n_threads(self.size), None, self.size)
def f(i):
entry = self.ram[i]
access_here = self.index_vector[i]
@@ -917,7 +923,7 @@ class LinearORAM(TrivialORAM):
max(x or 0 for x in self.entry_size)))
new_empty = MemValue(new_empty)
write = MemValue(write)
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
@map_sum(get_n_threads(self.size), None, self.size, \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type] * self.value_length)
def f(i):
@@ -1038,7 +1044,7 @@ class LocalIndexStructure(List):
__getitem__ = lambda self,index: List.__getitem__(self, index)[0]
def get_n_threads_for_tree(size):
if n_threads_for_tree is None:
if n_threads_for_tree is None and not single_thread:
if size >= 2**13:
return 8
else:
@@ -1340,8 +1346,8 @@ class TreeORAM(AbstractORAM):
half = (empty_positions[i]+1 - parity) // 2
half_max = self.bucket_size // 2
bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0]
bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0]
bits = floatingpoint.B2U(half, half_max)[0]
bits2 = floatingpoint.B2U(half+parity, half_max)[0]
# (doesn't work)
#bits2 = [0] * half_max
## second half with parity bit
@@ -1350,7 +1356,8 @@ class TreeORAM(AbstractORAM):
#bits2[0] = (1 - bits[0]) * parity
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
else:
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
bucket_bits = floatingpoint.B2U(empty_positions[i]+1,
self.bucket_size)[0]
assert len(bucket_bits) == self.bucket_size
for j, b in enumerate(bucket_bits):
pos_bits[i * self.bucket_size + j] = [b, leaf]
@@ -1376,8 +1383,7 @@ class TreeORAM(AbstractORAM):
Program.prog.curr_tape.start_new_basicblock()
bucket_sizes = Array(2**self.D, regint)
for i in range(2**self.D):
bucket_sizes[i] = 0
bucket_sizes.assign_all(0)
@for_range_opt(len(entries))
def _(k):
@@ -1697,7 +1703,7 @@ class OneLevelORAM(TreeORAM):
class BinaryORAM:
def __init__(self, size, value_type=None, **kwargs):
import circuit_oram
from Compiler import circuit_oram
from Compiler.GC import types
n_bits = int(get_program().options.binary)
self.value_type = value_type or types.sbitintvec.get_type(n_bits)

View File

@@ -1,4 +1,4 @@
"""This module contains an implementation of the "Path Oblivious Heap"
r"""This module contains an implementation of the "Path Oblivious Heap"
oblivious priority queue as proposed by
`Shi <https://eprint.iacr.org/2019/274.pdf>`_.
@@ -968,6 +968,7 @@ class PathObliviousHeap(AbstractMinPriorityQueue[_secret]):
self.type_hiding_security = type_hiding_security
self.capacity = capacity
self.entry_size = entry_size
self.size = MemValue(sint(0))
# Print debug messages
dprint(f"[POH] __init__: Initializing a queue...")
@@ -1024,6 +1025,7 @@ class PathObliviousHeap(AbstractMinPriorityQueue[_secret]):
dprint_ln("\n[POH] insert")
indent()
self.tree.insert(value, priority, fake)
self.size.iadd(fake.bit_not())
outdent()
def _extract_min(self, fake: _secret) -> _secret:
@@ -1031,6 +1033,7 @@ class PathObliviousHeap(AbstractMinPriorityQueue[_secret]):
dprint_ln("\n[POH] extract_min")
indent()
value = self.tree.extract_min(fake)
self.size.iadd(-fake.bit_not())
outdent()
if TRACE:
dprint_ln("[POH] extract_min: extracted value %s", value.reveal())

View File

@@ -18,7 +18,7 @@ from functools import reduce
import Compiler.instructions
import Compiler.instructions_base
import Compiler.instructions_base as inst_base
from Compiler.config import REG_MAX, USER_MEM, COST
from Compiler.config import REG_MAX, USER_MEM, COST, MEM_MAX
from Compiler.exceptions import CompilerError
from Compiler.instructions_base import RegType
@@ -103,6 +103,38 @@ class Program(object):
self.bit_length = int(options.binary) or int(options.field)
if options.prime:
self.prime = int(options.prime)
print("WARNING: --prime/-P activates code that usually isn't "
"the most efficient variant. Consider using --field/-F "
"and set the prime only during the actual computation.")
if not self.rabbit_gap() and self.prime > 2 ** 50:
print("The chosen prime is particularly inefficient. "
"Consider using a prime that is closer to a power "
"of two", end='')
try:
import gmpy2
bad_prime = self.prime
self.prime = 2 ** int(
round(math.log(self.prime, 2))) + 1
while True:
if self.prime > 2 ** 59:
# LWE compatibility
step = 2 ** 15
else:
step = 1
if self.prime < bad_prime:
self.prime += step
else:
self.prime -= step
if gmpy2.is_prime(self.prime):
break
assert self.rabbit_gap()
print(", for example, %d." % self.prime)
self.prime = bad_prime
except ImportError:
print(".")
if options.execute:
print("Use '-- --prime <prime>' to specify the prime for "
"execution only.")
max_bit_length = int(options.prime).bit_length() - 2
if self.bit_length > max_bit_length:
raise CompilerError(
@@ -111,7 +143,7 @@ class Program(object):
self.bit_length = self.bit_length or max_bit_length
self.non_linear = KnownPrime(self.prime)
else:
self.non_linear = Prime(self.security)
self.non_linear = Prime()
if not self.bit_length:
self.bit_length = 64
print("Default bit length for compilation:", self.bit_length)
@@ -197,6 +229,9 @@ class Program(object):
self.cisc_to_function = True
if not self.options.cisc:
self.options.cisc = not self.options.optimize_hard
self.use_tape_calls = True
self.force_cisc_tape = False
self.have_warned_trunc_pr = False
Program.prog = self
from . import comparison, instructions, instructions_base, types
@@ -225,7 +260,7 @@ class Program(object):
os.mkdir(dirname)
# create extra directories if needed
for dirname in ["Public-Input", "Bytecode", "Schedules"]:
for dirname in ["Public-Input", "Bytecode", "Schedules", "Functions"]:
if not os.path.exists(self.programs_dir + "/" + dirname):
os.mkdir(self.programs_dir + "/" + dirname)
@@ -278,7 +313,8 @@ class Program(object):
self.non_linear = Ring(ring_size)
self.options.ring = str(ring_size)
def new_tape(self, function, args=[], name=None, single_thread=False):
def new_tape(self, function, args=[], name=None, single_thread=False,
finalize=True, **kwargs):
"""
Create a new tape from a function. See
:py:func:`~Compiler.library.multithread` and
@@ -309,11 +345,12 @@ class Program(object):
self.curr_tape
tape_index = len(self.tapes)
self.tape_stack.append(self.curr_tape)
self.curr_tape = Tape(name, self)
self.curr_tape = Tape(name, self, **kwargs)
self.curr_tape.singular = single_thread
self.tapes.append(self.curr_tape)
function(*args)
self.finalize_tape(self.curr_tape)
if finalize:
self.finalize_tape(self.curr_tape)
if self.tape_stack:
self.curr_tape = self.tape_stack.pop()
return tape_index
@@ -346,6 +383,7 @@ class Program(object):
thread_numbers = []
while len(thread_numbers) < len(args):
free_threads = self.curr_tape.free_threads
self.curr_tape.ran_threads = True
if free_threads:
thread_numbers.append(min(free_threads))
free_threads.remove(thread_numbers[-1])
@@ -417,7 +455,10 @@ class Program(object):
def finalize_tape(self, tape):
if not tape.purged:
curr_tape = self.curr_tape
self.curr_tape = tape
tape.optimize(self.options)
self.curr_tape = curr_tape
tape.write_bytes()
if self.options.asmoutfile:
tape.write_str(self.options.asmoutfile + "-" + tape.name)
@@ -472,16 +513,18 @@ class Program(object):
self.allocated_mem[mem_type] += size
if len(str(addr)) != len(str(addr + size)) and self.verbose:
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
if addr + size >= 2**64:
raise CompilerError("allocation exceeded for type '%s'" % mem_type)
if addr + size >= MEM_MAX:
raise CompilerError(
"allocation exceeded for type '%s' after adding %d" % \
(mem_type, size))
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
if single_size:
from .library import get_thread_number, runtime_error_if
from .library import get_arg, runtime_error_if
bak = self.curr_tape.active_basicblock
self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
tn = get_thread_number()
runtime_error_if(tn > self.n_running_threads, "malloc")
res = addr + single_size * (tn - 1)
arg = get_arg()
runtime_error_if(arg >= self.n_running_threads, "malloc")
res = addr + single_size * arg
self.curr_tape.active_basicblock = bak
self.base_addresses[res] = addr
return res
@@ -577,7 +620,6 @@ class Program(object):
def set_security(self, security):
changed = self._security != security
self._security = security
self.non_linear.set_security(security)
if changed:
print("Changed statistical security for comparison etc. to",
security)
@@ -612,6 +654,14 @@ class Program(object):
def use_trunc_pr(self, change):
self._use_trunc_pr = change
def trunc_pr_warning(self):
if not self.have_warned_trunc_pr:
print("WARNING: Probabilistic truncation leaks some information, "
"see https://eprint.iacr.org/2024/1127 for discussion. "
"Use 'sfix.round_nearest = True' to deactivate this for "
"fixed-point operations.")
self.have_warned_trunc_pr = True
def use_edabit(self, change=None):
"""Setting whether to use edaBits for non-linear
functionality (default: false).
@@ -783,7 +833,7 @@ class Program(object):
class Tape:
"""A tape contains a list of basic blocks, onto which instructions are added."""
def __init__(self, name, program):
def __init__(self, name, program, thread_pool=None):
"""Set prime p and the initial instructions and registers."""
self.program = program
name += "-%d" % program.get_tape_counter()
@@ -800,12 +850,16 @@ class Tape:
self.merge_opens = True
self.if_states = []
self.req_bit_length = defaultdict(lambda: 0)
self.bit_length_reason = None
self.function_basicblocks = {}
self.functions = []
self.singular = True
self.free_threads = set()
self.free_threads = set() if thread_pool is None else thread_pool
self.loop_breaks = []
self.warned_about_mem = False
self.return_values = []
self.ran_threads = False
self.unused_decorators = {}
class BasicBlock(object):
def __init__(self, parent, name, scope, exit_condition=None,
@@ -880,7 +934,8 @@ class Tape:
self.usage_instructions = list(filter(relevant, self.instructions))
else:
self.usage_instructions = []
if len(self.usage_instructions) > 1000:
if len(self.usage_instructions) > 1000 and \
self.parent.program.verbose:
print("Retaining %d instructions" % len(self.usage_instructions))
del self.instructions
self.purged = True
@@ -1000,6 +1055,10 @@ class Tape:
print()
raise CompilerError("Unclosed if/else blocks, see tracebacks above")
if self.unused_decorators:
raise CompilerError("Unused branching decorators, make sure to write " + ",".join(
"'@%s' instead of '%s'" % (x, x) for x in set(self.unused_decorators.values())))
if self.program.verbose:
print(
"Processing tape", self.name, "with %d blocks" % len(self.basicblocks)
@@ -1107,6 +1166,9 @@ class Tape:
if addr.program == self and self.basicblocks:
allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool)
for reg in self.return_values:
allocator.alloc_reg(reg, self.basicblocks[-1].alloc_pool)
seen = set()
def alloc(block):
@@ -1214,7 +1276,10 @@ class Tape:
Compiler.instructions.reqbl(bl, add_to_prog=False)
)
if self.program.verbose:
print("Tape requires prime bit length", self.req_bit_length["p"])
print("Tape requires prime bit length",
self.req_bit_length["p"],
('for %s' % self.bit_length_reason
if self.bit_length_reason else ''))
print("Tape requires galois bit length", self.req_bit_length["2"])
@unpurged
@@ -1287,6 +1352,7 @@ class Tape:
if "Bytecode" not in filename:
filename = self.program.programs_dir + "/Bytecode/" + filename
print("Writing to", filename)
sys.stdout.flush()
f = open(filename, "wb")
h = hashlib.sha256()
for i in self._get_instructions():
@@ -1395,13 +1461,12 @@ class Tape:
return repr(dict(self))
class ReqNode(object):
__slots__ = ["num", "_children", "name", "blocks", "aggregated"]
def __init__(self, name):
self._children = []
self.name = name
self.blocks = []
self.aggregated = None
self.num = None
@property
def children(self):
@@ -1411,12 +1476,17 @@ class Tape:
def aggregate(self, *args):
if self.aggregated is not None:
return self.aggregated
self.recursion = self.num is not None
if self.recursion:
return Tape.ReqNum()
self.num = Tape.ReqNum()
for block in self.blocks:
block.add_usage(self)
res = reduce(
lambda x, y: x + y.aggregate(self.name), self.children, self.num
)
if self.recursion:
res *= float('inf')
self.aggregated = res
return res
@@ -1442,7 +1512,7 @@ class Tape:
n_reps = self.aggregator([1])
n_rounds = res["all", "round"]
n_invs = res["all", "inv"]
if (n_invs / n_rounds) * 1000 < n_reps:
if (n_invs / n_rounds) * 1000 < n_reps and Program.prog.verbose:
print(
self.nodes[0].blocks[0].name,
"blowing up rounds: ",
@@ -1468,15 +1538,19 @@ class Tape:
def close_scope(self, outer_scope, parent_req_node, name):
self.start_new_basicblock(outer_scope, name, req_node=parent_req_node)
def require_bit_length(self, bit_length, t="p"):
def require_bit_length(self, bit_length, t="p", reason=None):
if t == "p":
if self.program.prime:
if bit_length >= self.program.prime.bit_length() - 1:
raise CompilerError(
"required bit length %d too much for %d"
% (bit_length, self.program.prime)
+ ('(for %s)' % reason if reason else '')
)
self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t])
bit_length += 1
if bit_length > self.req_bit_length[t]:
self.req_bit_length[t] = bit_length
self.bit_length_reason = reason
else:
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
@@ -1492,12 +1566,24 @@ class Tape:
def __bool__(self):
raise CompilerError(
"Cannot derive truth value from register. "
"This is a catch-all error appearing if you try to use a "
"run-time value where the compiler expects a compile-time "
"value, most likely a Python integer. "
"In some cases, you can fix this by using 'compile.py -l'."
"See https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#cannot-derive-truth-value-from-register"
)
def __int__(self):
raise CompilerError(
"It is impossible to convert run-time types to compile-time "
"Python types like int or float. The reason for this is that "
"%s objects are only a placeholder during the execution in "
"Python, the actual value of which is only defined in the "
"virtual machine at a later time. See "
"https://mp-spdz.readthedocs.io/en/latest/journey.html "
"to get an understanding of the overall design. "
"In rare cases, you can fix this by using 'compile.py -l'." % \
type(self).__name__
)
__float__ = __int__
class Register(_no_truth):
"""
Class for creating new registers. The register's index is automatically assigned
@@ -1515,6 +1601,7 @@ class Tape:
"caller",
"can_eliminate",
"duplicates",
"dup_count",
"block",
]
maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1
@@ -1547,6 +1634,7 @@ class Tape:
self.vector = []
self.can_eliminate = True
self.duplicates = util.set_by_id([self])
self.dup_count = None
if Program.prog.DEBUG:
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
else:
@@ -1619,6 +1707,9 @@ class Tape:
def copy(self):
return Tape.Register(self.reg_type, Program.prog.curr_tape)
def same_type(self):
return type(self)(size=self.size)
def link(self, other):
if Program.prog.options.noreallocate:
raise CompilerError("reallocation necessary for linking, "

View File

@@ -1,5 +1,6 @@
import itertools
from Compiler import types, library, instructions
from Compiler import comparison, util
def dest_comp(B):
Bt = B.transpose()
@@ -10,7 +11,7 @@ def dest_comp(B):
return sum(Tt) - 1
def reveal_sort(k, D, reverse=False):
""" Sort in place according to "perfect" key. The name hints at the fact
r""" Sort in place according to "perfect" key. The name hints at the fact
that a random order of the keys is revealed.
:param k: vector or Array of sint containing exactly :math:`0,\dots,n-1`
@@ -20,6 +21,7 @@ def reveal_sort(k, D, reverse=False):
backward order
"""
comparison.require_ring_size(util.log2(len(k)) + 1, 'sorting')
assert len(k) == len(D)
library.break_point()
shuffle = types.sint.get_secure_shuffle(len(k))

View File

@@ -109,7 +109,7 @@ class SqrtOram(Generic[T, B]):
self.shuffle_used = cint.Array(self.n)
# Random permutation on the data
self.shufflei = Array.create_from(
[self.index_type(i) for i in range(self.n)])
self.index_type(regint.inc(self.n)))
# Calculate the period if not given
# upon recursion, the period should stay the same ("in sync"),
# therefore it can be passed as a constructor parameter
@@ -122,7 +122,7 @@ class SqrtOram(Generic[T, B]):
# Note that self.shuffle_the_shuffle mutates this field
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
self.permutation = Array.create_from(
[self.index_type(i) for i in range(self.n)])
self.index_type(regint.inc(self.n)))
# We allow the caller to postpone the initialization of the shuffle
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
# Note that if you do not initialize, the ORAM is insecure
@@ -256,7 +256,6 @@ class SqrtOram(Generic[T, B]):
return result
@lib.method_block
def write(self, index: T, *value: T):
global trace, n_parallel
if trace:
@@ -271,7 +270,12 @@ class SqrtOram(Generic[T, B]):
else:
raise Exception("Cannot handle type of value passed")
print(self.entry_length, value, type(value),len(value))
value = MemValue(value)
self._write(index, *value)
@lib.method_block
def _write(self, index: T, *value: T):
value = MemValue(self.value_type(value))
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@@ -513,14 +517,14 @@ class SqrtOram(Generic[T, B]):
# Since the underlying memory of the position map is already aligned in
# this packed structure, we can simply overwrite the memory while
# maintaining the structure.
self.position_map.reinitialize(*self.permutation)
self.position_map.reinitialize(self.permutation)
def reinitialize(self, *data: T):
def reinitialize(self, data: T):
# Note that this method is only used during refresh, and as such is
# only called with a permutation as data.
# The logical addresses of some previous permutation are irrelevant and must be reset
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
self.shufflei.assign_vector(self.index_type(regint.inc(self.n)))
# Reset the clock
self.t.write(0)
# Reset shuffle_used
@@ -530,10 +534,10 @@ class SqrtOram(Generic[T, B]):
# This structure is preserved while overwriting the values using
# assign_vector
self.shuffle.assign_vector(self.value_type(
data, size=self.n * self.entry_length))
data[:], size=self.n * self.entry_length))
# Note that this updates self.permutation (see constructor for explanation)
self.shuffle_the_shuffle()
self.position_map.reinitialize(*self.permutation)
self.position_map.reinitialize(self.permutation)
def _reset_shuffle_used(self):
global allow_memory_allocation
@@ -568,7 +572,7 @@ class PositionMap(Generic[T, B]):
print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
def reinitialize(self, *permutation: T):
def reinitialize(self, permutation: T):
"""Reinitialize this PositionMap.
Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
@@ -613,9 +617,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
packed_size = int(math.ceil(self.n / pack))
packed_structure = MultiArray(
(packed_size, pack), value_type=value_type)
for i in range(packed_size):
@lib.for_range(packed_size)
def _(i):
packed_structure[i] = Array.create_from(
permutation[i*pack:(i+1)*pack])
permutation.get_vector(base=i * pack, size=pack))
SqrtOram.__init__(self, packed_structure, value_type=value_type,
period=period, entry_length=pack, k=self.depth,
@@ -720,8 +725,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
return p.reveal()
def reinitialize(self, *permutation: T):
SqrtOram.reinitialize(self, *permutation)
def reinitialize(self, permutation: T):
SqrtOram.reinitialize(self, permutation)
class LinearPositionMap(PositionMap):
@@ -790,8 +795,8 @@ class LinearPositionMap(PositionMap):
return p.reveal()
def reinitialize(self, *data: T):
self.physical.assign_vector(data)
def reinitialize(self, data : T):
self.physical.assign(data)
global allow_memory_allocation
if allow_memory_allocation:

File diff suppressed because it is too large Load Diff

View File

@@ -28,11 +28,11 @@ def greater_than(a, b, bits):
else:
return a.greater_than(b, bits)
def pow2_value(a, bit_length=None, security=None):
def pow2_value(a, bit_length=None):
if is_constant_float(a):
return 2**a
else:
return a.pow2(bit_length, security)
return a.pow2(bit_length)
def mod2m(a, b, bits, signed):
if isinstance(a, int):
@@ -292,6 +292,9 @@ class dict_by_id(object):
def __iter__(self):
return self.keys()
def pop(self, key):
return self.content.pop(id(key), None)
class defaultdict_by_id(dict_by_id):
def __init__(self, default):
dict_by_id.__init__(self)

View File

@@ -15,7 +15,7 @@
int main()
{
P256Element::init();
P256Element::Scalar key;
KeySetup<Share<P256Element::Scalar>> key;
string prefix = PREP_DIR "ECDSA/";
mkdir_p(prefix.c_str());
write_online_setup(prefix, P256Element::Scalar::pr());

View File

@@ -10,10 +10,12 @@
#include "Math/gfp.h"
#include "ECDSA/P256Element.h"
#include "GC/VectorInput.h"
#include "Protocols/SPDZ.h"
#include "ECDSA/preprocessing.hpp"
#include "ECDSA/sign.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/Hemi.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/Share.hpp"
#include "Protocols/MAC_Check.hpp"
@@ -44,6 +46,7 @@ int main(int argc, const char** argv)
typedef Share<P256Element::Scalar> pShare;
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
read_mac_key(prefix, N, keyp);
pShare::set_mac_key(keyp);
pShare::MAC_Check::setup(P);
Share<P256Element>::MAC_Check::setup(P);

View File

@@ -32,7 +32,7 @@
#include "GC/RepPrep.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/Shamir.hpp"
#include "Machines/MalRep.hpp"
#include "Machines/Rep.hpp"

View File

@@ -8,6 +8,7 @@
#include "Math/gfp.h"
#include "ECDSA/P256Element.h"
#include "Protocols/SemiShare.h"
#include "Protocols/SPDZ.h"
#include "Processor/BaseMachine.h"
#include "ECDSA/preprocessing.hpp"
@@ -15,6 +16,7 @@
#include "Protocols/Beaver.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/MascotPrep.hpp"
#include "Protocols/Hemi.hpp"
#include "Processor/Processor.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Input.hpp"

View File

@@ -37,6 +37,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
EcdsaOptions opts)
{
bool prep_mul = opts.prep_mul;
if (prep_mul)
proc.protocol.init_mul();
Timer timer;
timer.start();
Player& P = proc.P;
@@ -77,7 +80,6 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (prep_mul)
{
protocol.init_mul();
for (int i = 0; i < buffer_size; i++)
protocol.prepare_mul(inv_ks[i], sk);
protocol.start_exchange();

View File

@@ -6,6 +6,7 @@
#include "GC/SemiSecret.h"
#include "GC/SemiPrep.h"
#include "Protocols/Hemi.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/SemiPrep.hpp"
#include "Protocols/SemiInput.hpp"

View File

@@ -23,6 +23,7 @@
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"

View File

@@ -24,6 +24,13 @@ Client::Client(const vector<string>& hostnames, int port_base,
"P" + to_string(i), "C" + to_string(my_client_id), true);
if (i == 0)
specification.Receive(sockets[0]);
else
{
octetStream spec;
spec.Receive(sockets[i]);
if (spec != specification)
throw runtime_error("inconsistent specification");
}
}
}

View File

@@ -17,11 +17,10 @@
* - share of winning unique id * random value [w]
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
*
* To run with 2 parties / SPDZ engines:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* To run:
* ./Scripts/setup-clients.sh to create SSL keys and certificates for clients
* ./compile.py bankers_bonus
* ./Scripts/run-online.sh bankers_bonus to run the engines.
* ./Scripts/compile-run.py <protocol> bankers_bonus to compile and run the engines.
* (See https://github.com/data61/MP-SPDZ/?tab=readme-ov-file#protocols for options.)
*
* ./bankers-bonus-client.x 0 2 100 0
* ./bankers-bonus-client.x 1 2 200 0

View File

@@ -31,6 +31,15 @@ def set_keepalive_osx(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
class Client:
"""Client to servers running secure computation. Works both as a client
to all parties or a trusted client to a single party.
:param hostnames: hostnames or IP addresses to connect to
:param port_base: port number for first hostname,
increases by one for every additional hostname
:param my_client_id: number to identify client
"""
def __init__(self, hostnames, port_base, my_client_id):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
name = 'C%d' % my_client_id
@@ -62,6 +71,11 @@ class Client:
self.specification = octetStream()
self.specification.Receive(self.sockets[0])
for sock in self.sockets[1:]:
specification = octetStream()
specification.Receive(sock)
if specification.buf != self.specification.buf:
raise Exception('inconsistent specification')
type = self.specification.get_int(4)
if type == ord('R'):
self.domain = Z2(self.specification.get_int(4))
@@ -99,6 +113,12 @@ class Client:
return triples
def send_private_inputs(self, values):
""" Send inputs privately to the computation servers.
This assumes that the client is connected to all servers.
:param values: list of input values
"""
T = self.domain
triples = self.receive_triples(T, len(values))
os = octetStream()
@@ -109,10 +129,46 @@ class Client:
os.Send(socket)
def receive_outputs(self, n):
""" Receive outputs privately from the computation servers.
This assumes that the client is connected to all servers.
:param n: number of outputs
"""
T = self.domain
triples = self.receive_triples(T, n)
return [int(self.clear_domain(triple[0].v)) for triple in triples]
def send_public_inputs(self, values):
""" Send values in the clear. This works for public inputs
to all servers or to send shares to a single server.
:param values: list of values
"""
os = octetStream()
for value in values:
self.domain(value).pack(os)
for socket in self.sockets:
os.Send(socket)
def receive_plain_values(self, socket=None):
""" Receive values in the clear. This works for public inputs
to all servers or to send shares to a single server.
:param socket: socket to use (need to specify it there is more than one)
"""
if socket is None:
if len(self.sockets) != 1:
raise Exception('need to specify socket')
socket = self.sockets[0]
os = octetStream()
os.Receive(socket)
assert len(os) % self.domain.size() == 0
return [int(os.get(self.domain))
for i in range(len(os) // self.domain.size())]
class octetStream:
def __init__(self, value=None):
self.buf = b''
@@ -123,6 +179,8 @@ class octetStream:
def get_length(self):
return len(self.buf)
__len__ = get_length
def reset_write_head(self):
self.buf = b''
self.ptr = 0
@@ -164,6 +222,11 @@ class octetStream:
else:
return 0
def get(self, type):
res = type()
res.unpack(self)
return res
def consume(self, length):
self.ptr += length
assert self.ptr <= len(self.buf)

View File

@@ -7,7 +7,7 @@ class Domain:
def __int__(self):
res = self.v % self.modulus
return res if res < self.modulus / 2 else res - self.modulus
return int(res if res < self.modulus / 2 else res - self.modulus)
def __add__(self, other):
try:

View File

@@ -0,0 +1,19 @@
#!/usr/bin/python3
import sys, random
sys.path.insert(0, 'ExternalIO')
from client import *
party = int(sys.argv[1])
client = Client(['localhost'], 15000 + party, 0)
n = 1000
if party < 2:
client.send_public_inputs(random.gauss(0, 1) * 2 ** 16 for i in range(n))
x = [client.receive_plain_values() for i in range(2)]
client.send_public_inputs(a + b for a, b in zip(*x))

View File

@@ -16,14 +16,14 @@ template<class T>
class AddableVector: public vector<T>
{
public:
AddableVector<T>() {}
AddableVector<T>(size_t n, const T& x = T()) : vector<T>(n, x) {}
AddableVector() {}
AddableVector(size_t n, const T& x = T()) : vector<T>(n, x) {}
template <class U, class FD, class S>
AddableVector<T>(const Plaintext<U,FD,S>& other) :
AddableVector<T>(other.get_poly()) {}
AddableVector(const Plaintext<U,FD,S>& other) :
AddableVector(other.get_poly()) {}
template <class U>
AddableVector<T>(const vector<U>& other)
AddableVector(const vector<U>& other)
{
this->assign(other.begin(), other.end());
}
@@ -129,29 +129,41 @@ public:
(*this)[i].pack(os);
}
void unpack_size(octetStream& os, const T& init = T())
size_t unpack_size(octetStream& os)
{
unsigned int size;
os.get(size);
this->resize(size, init);
this->reserve(size);
return size;
}
void unpack(octetStream& os, const T& init = T())
{
unpack_size(os, init);
for (unsigned int i = 0; i < this->size(); i++)
(*this)[i].unpack(os);
size_t new_size = unpack_size(os);
this->clear();
for (unsigned int i = 0; i < new_size; i++)
{
this->push_back(init);
this->back().unpack(os);
}
}
void add(octetStream& os, T& tmp)
{
unpack_size(os, tmp);
size_t new_size = unpack_size(os);
T init = tmp;
T& item = tmp;
for (unsigned int i = 0; i < this->size(); i++)
{
item.unpack(os);
(*this)[i] += item;
}
for (size_t i = this->size(); i < new_size; i++)
{
item.unpack(os);
this->push_back(init);
this->back() += item;
}
}
T infinity_norm() const

View File

@@ -61,9 +61,7 @@ bigint FHE_Params::Q() const
void FHE_Params::pack(octetStream& o) const
{
o.store(FFTData.size());
for(auto& fd: FFTData)
fd.pack(o);
o.store(FFTData);
Chi.pack(o);
Bval.pack(o);
o.store(sec_p);
@@ -73,11 +71,7 @@ void FHE_Params::pack(octetStream& o) const
void FHE_Params::unpack(octetStream& o)
{
size_t size;
o.get(size);
FFTData.resize(size);
for (auto& fd : FFTData)
fd.unpack(o);
o.get(FFTData);
Chi.unpack(o);
Bval.unpack(o);
o.get(sec_p);

View File

@@ -311,22 +311,18 @@ void imatrix::hash(octetStream& o) const
void imatrix::pack(octetStream& o) const
{
o.store(size());
for (auto& x : *this)
{
assert(x.size() == size());
x.pack(o);
}
o.store(static_cast<const super&>(*this));
}
void imatrix::unpack(octetStream& o)
{
size_t size;
o.get(size);
resize(size);
o.get(static_cast<super&>(*this));
for (auto& x : *this)
{
x.resize(size);
x.unpack(o);
assert(x.size() == size());
}
}

View File

@@ -13,6 +13,8 @@ typedef vector< vector<bigint> > matrix;
class imatrix : public vector< BitVector >
{
typedef vector<BitVector> super;
public:
bool operator!=(const imatrix& other) const;

View File

@@ -13,6 +13,8 @@
#include "FHEOffline/Proof.h"
#include "Processor/OnlineOptions.h"
#include <fstream>
using namespace std;
@@ -735,9 +737,9 @@ void load_or_generate(P2Data& P2D, const Ring& R)
{
P2D.load(R);
}
catch (...)
catch (exception& e)
{
cout << "Loading failed" << endl;
cerr << "Loading parameters failed, generating (" << e.what() << ")" << endl;
init(P2D,R);
P2D.store(R);
}

View File

@@ -2,6 +2,7 @@
#include "FHE/P2Data.h"
#include "Math/Setup.h"
#include "Math/fixint.h"
#include "Processor/OnlineOptions.h"
#include <fstream>
@@ -74,7 +75,6 @@ bool P2Data::operator!=(const P2Data& other) const
void P2Data::hash(octetStream& o) const
{
check_dimensions();
o.store(gf2n_short::degree());
o.store(slots);
A.hash(o);
@@ -113,17 +113,18 @@ string get_filename(const Ring& Rg)
void P2Data::load(const Ring& Rg)
{
string filename = get_filename(Rg);
cout << "Loading from " << filename << endl;
ifstream s(filename);
if (OnlineOptions::singleton.verbose)
cerr << "Loading from " << filename << endl;
octetStream os;
os.input(s);
os.input(filename);
unpack(os);
}
void P2Data::store(const Ring& Rg) const
{
string filename = get_filename(Rg);
cout << "Storing in " << filename << endl;
if (OnlineOptions::singleton.verbose)
cerr << "Storing in " << filename << endl;
ofstream s(filename);
octetStream os;
pack(os);

View File

@@ -562,22 +562,17 @@ template <class T,class FD,class S>
void Plaintext<T,FD,S>::pack(octetStream& o) const
{
to_poly();
o.store((unsigned int)b.size());
for (unsigned int i = 0; i < b.size(); i++)
o.store(b[i]);
o.store(b);
}
template <class T,class FD,class S>
void Plaintext<T,FD,S>::unpack(octetStream& o)
{
type = Polynomial;
unsigned int size;
o.get(size);
allocate();
o.get(b);
auto size = b.size();
allocate(Polynomial);
if (size != b.size() and size != 0)
throw length_error("unexpected length received");
for (unsigned int i = 0; i < size; i++)
b[i] = o.get<S>();
}

View File

@@ -512,9 +512,7 @@ modp Ring_Element::get_constant() const
void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
{
ZpD.pack(o);
o.store((int)v.size());
for (unsigned int i=0; i<v.size(); i++)
{ v[i].pack(o,ZpD); }
o.store(v);
}
@@ -526,16 +524,7 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
throw runtime_error(
"mismatch: " + to_string(check_Zpd.pr_bit_length) + "/"
+ to_string(ZpD.pr_bit_length));
unsigned int length;
o.get(length);
v.clear();
v.reserve(length);
modp tmp;
for (unsigned int i=0; i<length; i++)
{
tmp.unpack(o,ZpD);
v.push_back(tmp);
}
o.get(v);
}

View File

@@ -1,5 +1,5 @@
#ifndef _Subroutines
#define _Subroutines
#ifndef FHE_SUBROUTINES_H_
#define FHE_SUBROUTINES_H_
#include "Math/Zp_Data.h"

View File

@@ -27,8 +27,7 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
try
{
ifstream input(filename);
os.input(input);
os.input(filename);
setup.unpack(os);
machine.unpack(os);
}
@@ -44,11 +43,15 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
}
catch (mismatch_among_parties& e)
{
error = e.what();
if (error.empty())
error = e.what();
}
if (not error.empty())
{
if (OnlineOptions::singleton.has_option("expect_setup"))
throw runtime_error("error in setup: " + error);
cerr << "Running secrets generation because no suitable material "
"from a previous run was found (" << error << ")" << endl;
setup.key_and_mac_generation(P, machine, num_runs, V());

View File

@@ -80,9 +80,8 @@ void secure_init(T& setup, Player& P, U& machine,
try
{
ifstream file(filename);
octetStream os;
os.input(file);
os.input(filename);
os.get(machine.extra_slack);
setup.unpack(os);
}
@@ -95,13 +94,17 @@ void secure_init(T& setup, Player& P, U& machine,
{
setup.check(P, machine);
}
catch (exception& e)
catch (mismatch_among_parties& e)
{
reason = e.what();
if (reason.empty())
reason = e.what();
}
if (not reason.empty())
{
if (OnlineOptions::singleton.has_option("expect_setup"))
throw runtime_error("error in setup: " + reason);
if (OnlineOptions::singleton.verbose)
cerr << "Generating parameters for security " << sec
<< " and field size ~2^" << plaintext_length

View File

@@ -13,6 +13,8 @@
#include "SimpleMachine.h"
#include "Tools/mkpath.h"
#include "Protocols/Share.hpp"
template<class FD>
Producer<FD>::Producer(int output_thread, bool write_output) :
n_slots(0), output_thread(output_thread), write_output(write_output),

View File

@@ -8,6 +8,8 @@
#include "BitAdder.h"
#include "Protocols/BufferScope.h"
#include <assert.h>
template<class T>
@@ -69,6 +71,9 @@ void BitAdder::add(vector<vector<T> >& res,
size_t n_items = end - begin;
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%lu ANDs in bit adder\n", length * n_items * n_bits);
if (supply)
{
#ifdef VERBOSE_EDA
@@ -85,6 +90,7 @@ void BitAdder::add(vector<vector<T> >& res,
vector<T> carries(n_items);
vector<T> a(n_items), b(n_items);
auto& protocol = proc.protocol;
BufferScope scope(proc.DataF, n_items * length * n_bits);
for (int i = 0; i < n_bits; i++)
{
assert(summands[i].size() == 2);

View File

@@ -24,14 +24,14 @@ void FakeSecret::load_clear(int n, const Integer& x)
*this = x;
}
void FakeSecret::bitcom(Memory<FakeSecret>& S, const vector<int>& regs)
void FakeSecret::bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
void FakeSecret::bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const
void FakeSecret::bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;

View File

@@ -142,8 +142,8 @@ public:
template <class T>
void store(Memory<T>& mem, size_t address) { mem[address] = *this; }
void bitcom(Memory<FakeSecret>& S, const vector<int>& regs);
void bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const;
void bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs);
void bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const;
template <class T>
void xor_(int n, const FakeSecret& x, const T& y)

View File

@@ -82,10 +82,24 @@ void Instruction::parse(istream& s, int pos)
#undef X
default:
ostringstream os;
os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl;
os << "This virtual machine executes binary circuits only." << endl;
os << "Use 'compile.py -B'." << endl;
throw Invalid_Instruction(os.str());
os << "Code not defined for instruction ";
bool known = true;
switch (opcode)
{
#define X(NAME, PRE, CODE) case NAME: os << #NAME; break;
ALL_INSTRUCTIONS
#undef X
default:
known = false;
os << showbase << hex << opcode << dec;
}
os << endl;
if (known)
{
os << "This virtual machine executes binary circuits only. ";
os << "Use 'compile.py -B'.";
}
exit_error(os.str());
break;
}
}

View File

@@ -62,6 +62,7 @@ public:
typedef MaliciousRepMC<U> MC;
typedef MC MAC_Check;
typedef HashMaliciousRepMC<U> DefaultMC;
typedef ReplicatedInput<U> Input;
typedef RepPrep<U> LivePrep;

View File

@@ -52,7 +52,7 @@ public:
static DataFieldType field_type()
{
throw not_implemented();
return DATA_GF2;
}
static void init_minimum(int)
@@ -80,7 +80,8 @@ public:
bool operator!=(NoValue) const { return false; }
bool operator==(int) { fail(); return false; }
bool operator==(int) const { fail(); return false; }
bool operator==(NoValue) const { fail(); return false; }
bool get_bit(int) { fail(); return 0; }
@@ -92,6 +93,8 @@ public:
void input(istream&, bool) { fail(); }
void output(ostream&, bool) {}
void pack(octetStream&) const { fail(); }
};
inline ostream& operator<<(ostream& o, NoValue)
@@ -169,8 +172,8 @@ public:
void load_clear(Integer, Integer) { fail(); }
void random_bit() { fail(); }
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
void bitdec(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
void bitcom(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
void assign(const char*) { fail(); }
@@ -190,6 +193,8 @@ public:
NoShare& operator+=(const NoShare&) { fail(); return *this; }
bool operator==(NoShare) const { fail(); return false; }
NoShare get_bit(int) const { fail(); return {}; }
void xor_bit(int, NoShare) const { fail(); }
@@ -201,6 +206,8 @@ public:
void input(istream&, bool) { fail(); }
void output(ostream&, bool) { fail(); }
void pack(octetStream&) const { fail(); }
};
} /* namespace GC */

View File

@@ -79,10 +79,11 @@ template<class T>
void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
size_t begin, size_t end)
{
#ifdef VERBOSE_EDA
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
RunningTimer timer;
#endif
bool verbose = OnlineOptions::singleton.has_option("verbose_eda");
if (verbose)
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
auto& party = ShareThread<typename T::whole_type>::s();
auto& MC = party.MC->get_part_MC();
auto& P = *party.P;
@@ -102,9 +103,9 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
input.exchange();
for (size_t i = begin; i < end; i++)
triples[i][2] = input.finalize(input_player, T::default_length);
#ifdef VERBOSE_EDA
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
#endif
if (verbose)
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
}
}

View File

@@ -16,6 +16,7 @@ using namespace std;
#include "Math/Integer.h"
#include "Processor/ProcessorBase.h"
#include "Processor/Instruction.h"
#include "Tools/CheckVector.h"
namespace GC
{
@@ -38,9 +39,9 @@ public:
// rough measure for the memory usage
size_t complexity;
Memory<T> S;
Memory<Clear> C;
Memory<Integer> I;
StackedVector<T> S;
StackedVector<Clear> C;
StackedVector<Integer> I;
Timer xor_timer;
@@ -78,8 +79,8 @@ public:
template<class U>
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
template<class U>
void mem_op(int n, Memory<U>& dest, const Memory<U>& source,
template<class U, class V>
void mem_op(int n, U& dest, const V& source,
Integer dest_address, Integer source_address);
void xors(const vector<int>& args);
@@ -105,6 +106,9 @@ public:
template<int = 0>
void convcbit2s(const BaseInstruction& instruction);
void convcbitvec(const BaseInstruction& instruction, StackedVector<Integer>& Ci,
Player* P);
void print_reg(int reg, int n, int size);
void print_reg_plain(Clear& value);
void print_reg_signed(unsigned n_bits, Integer value);
@@ -114,6 +118,13 @@ public:
void print_float_prec(int n);
void incint(const BaseInstruction& instruction);
void push_stack();
void push_args(const vector<int>& args);
void pop_stack(const vector<int>& results);
template<class U>
void call_tape(const BaseInstruction& instruction, U& dynamic_memory);
};
template <class T>

View File

@@ -18,7 +18,7 @@ using namespace std;
#include "Math/BitVec.h"
#include "GC/Machine.hpp"
#include "Processor/ProcessorBase.hpp"
#include "Processor/Processor.hpp"
#include "Processor/IntInput.hpp"
#include "Math/bigint.hpp"
@@ -53,9 +53,9 @@ template <class T>
template <class U>
void Processor<T>::reset(const U& program, int arg)
{
S.resize(program.num_reg(SBIT), "registers");
C.resize(program.num_reg(CBIT), "registers");
I.resize(program.num_reg(INT), "registers");
S.resize(program.num_reg(SBIT));
C.resize(program.num_reg(CBIT));
I.resize(program.num_reg(INT));
set_arg(arg);
PC = 0;
}
@@ -202,14 +202,14 @@ void GC::Processor<T>::store_clear_in_dynamic(const vector<int>& args,
}
template<class T>
template<class U>
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
template<class U, class V>
void Processor<T>::mem_op(int n, U& dest, const V& source,
Integer dest_address, Integer source_address)
{
dest.check_index(dest_address + n - 1);
source.check_index(source_address + n - 1);
auto d = &dest[dest_address];
auto s = &source[source_address];
auto d = &dest[dest_address.get()];
auto s = &source[source_address.get()];
for (int i = 0; i < n; i++)
{
*d++ = *s++;
@@ -388,6 +388,30 @@ void Processor<T>::convcbit2s(const BaseInstruction& instruction)
min(size_t(unit), instruction.get_n() - i * unit));
}
template<class T>
void Processor<T>::convcbitvec(const BaseInstruction& instruction,
StackedVector<Integer>& Ci, Player* P)
{
vector<Integer> bits;
auto n = instruction.get_n();
bits.reserve(n);
for (size_t i = 0; i < instruction.get_n(); i++)
{
int i1 = i / GC::Clear::N_BITS;
int i2 = i % GC::Clear::N_BITS;
auto bit = C[instruction.get_r(1) + i1].get_bit(i2);
bits.push_back(bit);
}
if (P)
sync<T>(bits, *P);
else if (not T::symmetric)
sync<T>(bits, *Thread<T>::s().P);
for (size_t i = 0; i < n; i++)
Ci[instruction.get_r(0) + i] = bits[i];
}
template <class T>
void Processor<T>::print_reg(int reg, int n, int size)
{
@@ -417,7 +441,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
{
if (n_bits <= Clear::N_BITS)
{
auto value = C[reg];
auto value = C[reg.get()];
unsigned n_shift = 0;
if (n_bits > 1)
n_shift = sizeof(value.get()) * 8 - n_bits;
@@ -477,6 +501,56 @@ void Processor<T>::incint(const BaseInstruction& instruction)
}
}
template<class T>
void GC::Processor<T>::push_stack()
{
S.push_stack();
C.push_stack();
}
template<class T>
void GC::Processor<T>::push_args(const vector<int>& args)
{
S.push_args(args, SBIT);
C.push_args(args, CBIT);
}
template<class T>
void GC::Processor<T>::pop_stack(const vector<int>& results)
{
S.pop_stack(results, SBIT);
C.pop_stack(results, CBIT);
}
template<class T>
template<class U>
void Processor<T>::call_tape(const BaseInstruction& instruction, U& dynamic_memory)
{
auto new_arg = I.at(instruction.get_r(1)).get();
PC_stack.push_back(PC);
arg_stack.push_back(this->arg);
push_stack();
I.push_stack();
auto& tape = machine->progs.at(instruction.get_r(0));
reset(tape, new_arg);
auto& args = instruction.get_start();
push_args(args);
I.push_args(args, INT);
tape.execute(*this, dynamic_memory, PC);
pop_stack(args);
I.pop_stack(args, INT);
PC = PC_stack.back();
PC_stack.pop_back();
this->arg = arg_stack.back();
arg_stack.pop_back();
}
} /* namespace GC */
#endif

View File

@@ -71,9 +71,9 @@ void Program::parse(istream& s)
CALLGRIND_STOP_INSTRUMENTATION;
while (!s.eof())
{
instr.parse(s, pos);
if (s.bad() or s.fail())
throw runtime_error("error reading program");
instr.parse(s, pos);
p.push_back(instr);
//cerr << "\t" << instr << endl;
s.peek();

View File

@@ -7,6 +7,7 @@
#include "Protocols/Rep4.hpp"
#include "Protocols/Rep4Input.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/ReplicatedPrep.hpp"
namespace GC

View File

@@ -65,6 +65,8 @@ public:
typedef typename T::out_type out_type;
typedef void DefaultMC;
static string type_string() { return "evaluation secret"; }
static string phase_name() { return T::name(); }
@@ -76,6 +78,10 @@ public:
static const bool actual_inputs = T::actual_inputs;
static const bool symmetric = true;
static bool real_shares(const Player&) { return true; }
static int threshold(int nplayers) { return T::threshold(nplayers); }
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
@@ -148,9 +154,9 @@ public:
Secret<T> operator>>(int i) const;
template<class U>
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitcom(StackedVector<U>& S, const vector<int>& regs);
template<class U>
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
Secret<T> operator+(const Secret<T>& x) const;
Secret<T>& operator+=(const Secret<T>& x) { *this = *this + x; return *this; }
@@ -175,6 +181,7 @@ public:
void finalize_input(U& inputter, int from, int n_bits);
int size() const { return registers.size(); }
size_t maximum_size() const { return registers.size(); }
RegVector& get_regs() { return registers; }
const RegVector& get_regs() const { return registers; }

View File

@@ -197,7 +197,7 @@ Secret<T> Secret<T>::operator>>(int i) const
template <class T>
template <class U>
void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
void Secret<T>::bitcom(StackedVector<U>& S, const vector<int>& regs)
{
registers.clear();
for (unsigned int i = 0; i < regs.size(); i++)
@@ -210,7 +210,7 @@ void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
template <class T>
template <class U>
void Secret<T>::bitdec(Memory<U>& S, const vector<int>& regs) const
void Secret<T>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
{
if (regs.size() > registers.size())
throw overflow("not enough bits for bit decomposition", regs.size(),

View File

@@ -17,7 +17,8 @@ namespace GC
void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat)
{
if (repeat and OnlineOptions::singleton.live_prep)
if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1)
and P.num_players() == 2)
{
this->triples.push_back({{}});
auto& triple = this->triples.back();
@@ -35,6 +36,8 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n)
{
if (n == -1)
n = SemiSecret::default_length;
super::prepare_mul(x.mask(n), y.mask(n), n);
}

View File

@@ -24,7 +24,7 @@ public:
void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat);
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n);
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n = -1);
};
} /* namespace GC */

View File

@@ -79,6 +79,9 @@ array<SemiSecret, 3> SemiPrep::get_mixed_triple(int n)
if (mixed_triples.empty())
{
assert(this->triple_generator);
this->triple_generator->set_batch_size(
BaseMachine::batch_size<SemiSecret>(DATA_MIXED,
this->buffer_size));
this->triple_generator->generateMixedTriples();
for (auto& x : this->triple_generator->mixedTriples)
{

View File

@@ -38,12 +38,24 @@ public:
static const int default_length = sizeof(BitVec) * 8;
static const bool symmetric = V::symmetric;
static bool real_shares(const Player& P)
{
return V::real_shares(P);
}
static string type_string() { return "binary secret"; }
static string phase_name() { return "Binary computation"; }
static void trans(Processor<T>& processor, int n_outputs,
const vector<int>& args);
static size_t maximum_size()
{
return default_length;
}
SemiSecretBase()
{
}
@@ -64,8 +76,8 @@ public:
void load_clear(int n, const Integer& x);
void bitcom(Memory<T>& S, const vector<int>& regs);
void bitdec(Memory<T>& S, const vector<int>& regs) const;
void bitcom(StackedVector<T>& S, const vector<int>& regs);
void bitdec(StackedVector<T>& S, const vector<int>& regs) const;
void xor_(int n, const T& x, const T& y)
{ *this = BitVec(x ^ y).mask(n); }

View File

@@ -3,6 +3,9 @@
*
*/
#ifndef GC_SEMISECRET_HPP_
#define GC_SEMISECRET_HPP_
#include "GC/ShareParty.h"
#include "GC/ShareSecret.hpp"
#include "Protocols/MAC_Check_Base.hpp"
@@ -70,30 +73,40 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
assert(protocol);
protocol->init_mul();
auto it = args.begin();
int total_bits = 0, total_ops = 0;
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
total_bits += n_args * size;
it += n_args;
int base = *it++;
assert(n_args <= N_BITS);
for (int i = 0; i < size; i += N_BITS)
{
square64 square;
for (int j = 0; j < n_args; j++)
square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get();
int n_ops = min(N_BITS, size - i);
square.transpose(n_args, n_ops);
for (int j = 0; j < n_ops; j++)
for (int k = 0; k < n_args; k += N_BITS)
{
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
auto y_ext = SemiSecret(bit).extend_bit();
protocol->prepare_mult(square.rows[j], y_ext, n_args, true);
int left = min(N_BITS, n_args - k);
square64 square;
for (int j = 0; j < left; j++)
square.rows[j] = processor.S.at(
*(it + k + j) + i / N_BITS).get();
int n_ops = min(N_BITS, size - i);
total_ops += n_ops;
square.transpose(left, n_ops);
for (int j = 0; j < n_ops; j++)
{
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
auto y_ext = SemiSecret(bit).extend_bit();
protocol->prepare_mult(square.rows[j], y_ext, left, true);
}
}
}
it += n_args;
}
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%d/%d repeat ANDs\n", total_bits, total_ops);
protocol->exchange();
it = args.begin();
@@ -103,13 +116,18 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
int size = *it++;
for (int i = 0; i < size; i += N_BITS)
{
int n_ops = min(N_BITS, size - i);
square64 square;
for (int j = 0; j < n_ops; j++)
square.rows[j] = protocol->finalize_mul(n_args).get();
square.transpose(n_ops, n_args);
for (int j = 0; j < n_args; j++)
processor.S.at(*(it + j) + i / N_BITS) = square.rows[j];
for (int base = 0; base < n_args; base += N_BITS)
{
int left = min(N_BITS, n_args - base);
int n_ops = min(N_BITS, size - i);
square64 square;
for (int j = 0; j < n_ops; j++)
square.rows[j] = protocol->finalize_mul(left).get();
square.transpose(n_ops, left);
for (int j = 0; j < left; j++)
processor.S.at(*(it + base + j) + i / N_BITS) =
square.rows[j];
}
}
it += 2 * n_args + 1;
}
@@ -123,7 +141,7 @@ void SemiSecretBase<T, V>::load_clear(int n, const Integer& x)
}
template<class T, class V>
void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
void SemiSecretBase<T, V>::bitcom(StackedVector<T>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
@@ -131,7 +149,7 @@ void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
}
template<class T, class V>
void SemiSecretBase<T, V>::bitdec(Memory<T>& S,
void SemiSecretBase<T, V>::bitdec(StackedVector<T>& S,
const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
@@ -146,3 +164,5 @@ void SemiSecretBase<T, V>::reveal(size_t n_bits, Clear& x)
}
} /* namespace GC */
#endif

View File

@@ -42,7 +42,7 @@ inline ShareParty<T>& ShareParty<T>::s()
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
throw runtime_error("no ShareParty singleton");
}
}

View File

@@ -108,18 +108,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
else
P = new PlainPlayer(this->N, "shareparty");
try
{
read_mac_key(
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
this->N,
this->mac_key);
}
catch (exception& e)
{
SeededPRNG G;
this->mac_key.randomize(G);
}
T::read_or_generate_mac_key(
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
*P, this->mac_key);
T::MC::setup(*P);

View File

@@ -46,6 +46,9 @@ public:
static const bool is_real = true;
static const bool actual_inputs = true;
static const bool symmetric = true;
static bool real_shares(const Player&) { return true; }
static ShareThread<U>& get_party()
{
@@ -118,9 +121,12 @@ public:
typedef BitVec open_type;
typedef NoShare mac_type;
typedef NoValue mac_key_type;
typedef NoShare mac_share_type;
typedef NoShare bit_type;
typedef void DefaultMC;
static const int N_BITS = clear::N_BITS;
static const bool dishonest_majority = false;
@@ -151,12 +157,22 @@ public:
{
}
static GC::NoValue get_mac_key()
{
throw runtime_error("no MAC");
}
template<class T>
static string proto_fake_opts()
{
return T::fake_opts();
}
static size_t maximum_size()
{
return default_length;
}
RepSecretBase()
{
}
@@ -166,8 +182,8 @@ public:
{
}
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void bitcom(StackedVector<U>& S, const vector<int>& regs);
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
void xor_(int n, const This& x, const This& y)
{ *this = (x ^ y).mask(n); }
@@ -194,7 +210,7 @@ public:
typedef ReplicatedBase Protocol;
static ReplicatedSecret constant(const typename super::clear& value,
int my_num, typename super::mac_key_type, int = -1)
int my_num, typename super::mac_key_type = {}, int = -1)
{
ReplicatedSecret res;
if (my_num < 2)

View File

@@ -54,7 +54,7 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
}
template<class U, int L>
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
void RepSecretBase<U, L>::bitcom(StackedVector<U>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
@@ -62,7 +62,7 @@ void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
}
template<class U, int L>
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
void RepSecretBase<U, L>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;

View File

@@ -67,7 +67,7 @@ inline ShareThread<T>& ShareThread<T>::s()
if (singleton and T::is_real)
return *singleton;
else
throw runtime_error("no singleton");
throw runtime_error("no ShareThread singleton");
}
} /* namespace GC */

View File

@@ -94,23 +94,35 @@ void ShareThread<T>::and_(Processor<T>& processor,
processor.check_args(args, 4);
protocol->init_mul();
T x_ext, y_ext;
int total_bits = 0;
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
total_bits += n_bits;
int left = args[i + 2];
int right = args[i + 3];
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
{
int n = min(T::default_length, n_bits - j * T::default_length);
if (not repeat and n == T::default_length)
{
protocol->prepare_mul(processor.S[left + j], processor.S[right + j]);
continue;
}
processor.S[left + j].mask(x_ext, n);
if (repeat)
processor.S[right].extend_bit(y_ext, n);
else
processor.S[right + j].mask(y_ext, n);
processor.S[left + j].mask(x_ext, n);
protocol->prepare_mult(x_ext, y_ext, n, repeat);
}
}
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%d%s ANDs\n", total_bits, repeat ? " repeat" : "");
protocol->exchange();
for (size_t i = 0; i < args.size(); i += 4)
@@ -121,10 +133,20 @@ void ShareThread<T>::and_(Processor<T>& processor,
{
int n = min(T::default_length, n_bits - j * T::default_length);
auto& res = processor.S[out + j];
if (not repeat and n == T::default_length)
{
res = protocol->finalize_mul();
continue;
}
protocol->finalize_mult(res, n);
res.mask(res, n);
}
}
if (OnlineOptions::singleton.has_option("always_check"))
protocol->check();
}
template<class T>
@@ -136,10 +158,12 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
protocol->init_mul();
auto it = args.begin();
T x_ext, y_ext;
int total_bits = 0;
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
total_bits += size * n_args;
it += n_args;
int base = *it++;
for (int i = 0; i < size; i += N_BITS)
@@ -155,6 +179,9 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
it += n_args;
}
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%d repeat ANDs\n", total_bits);
protocol->exchange();
it = args.begin();
@@ -171,6 +198,9 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
}
it += 2 * n_args + 1;
}
if (OnlineOptions::singleton.has_option("always_check"))
protocol->check();
}
template<class T>

View File

@@ -32,6 +32,8 @@ public:
array<T, 3> get_triple_no_count(int n_bits)
{
int max_n_bits = T::default_length;
if (n_bits == -1)
n_bits = max_n_bits;
assert(n_bits <= max_n_bits);
assert(n_bits > 0);
array<T, 3> res;

View File

@@ -48,6 +48,7 @@ public:
Thread(int thread_num, ThreadMaster<T>& master);
virtual ~Thread();
void start();
void run();
virtual void pre_run() {}
virtual void run(Program& program);

View File

@@ -31,7 +31,12 @@ template<class T>
Thread<T>::Thread(int thread_num, ThreadMaster<T>& master) :
master(master), machine(master.machine), processor(machine),
N(master.N), P(0),
thread_num(thread_num)
thread_num(thread_num), thread(0)
{
}
template<class T>
void Thread<T>::start()
{
pthread_create(&thread, 0, run_thread, this);
}

View File

@@ -60,6 +60,7 @@ public:
virtual Thread<T>* new_thread(int i);
void run();
void run_with_error();
virtual void post_run() {}
};

View File

@@ -59,6 +59,25 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
template<class T>
void ThreadMaster<T>::run()
{
if (opts.has_option("throw_exceptions"))
run_with_error();
else
{
try
{
run_with_error();
}
catch (exception& e)
{
cerr << "Fatal error: " << e.what() << endl;
exit(1);
}
}
}
template<class T>
void ThreadMaster<T>::run_with_error()
{
if (not opts.live_prep)
{
@@ -72,6 +91,9 @@ void ThreadMaster<T>::run()
for (int i = 0; i < machine.nthreads; i++)
threads.push_back(new_thread(i));
// must start after constructor due to virtual functions
for (auto thread : threads)
thread->start();
for (auto thread : threads)
thread->join_tape();

View File

@@ -50,7 +50,7 @@ public:
return *part_MC;
}
void init_open(const Player& P, int n)
void init_open(const Player& P, int n = 0)
{
part_MC->init_open(P);
sizes.clear();

View File

@@ -53,9 +53,15 @@ public:
static const bool malicious = T::malicious;
static const bool expensive_triples = T::expensive_triples;
static const bool randoms_for_opens = false;
static const bool symmetric = true;
static const int default_length = 64;
static bool real_shares(const Player&)
{
return true;
}
static int size()
{
return part_type::size() * default_length;
@@ -72,6 +78,11 @@ public:
T::read_or_generate_mac_key(directory, P, key);
}
static typename T::mac_type get_mac_key()
{
return T::get_mac_key();
}
template<class U>
static void reveal_inst(U& processor, const vector<int>& args)
{

View File

@@ -80,14 +80,15 @@
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
X(CONVCBIT, Proc.write_Ci(R0, Proc.sync(PC1.get()))) \
X(CONVCINTVEC, Proc.convcintvec(instruction)) \
X(CONVCBITVEC, Proc.convcbitvec(instruction)) \
X(CONVCBITVEC, Proc.Procb.convcbitvec(instruction, Proc.get_Ci(), &Proc.P)) \
X(CONVCBIT2S, PROC.convcbit2s(instruction)) \
X(DABIT, Proc.dabit(INST)) \
X(EDABIT, Proc.edabit(INST)) \
X(SEDABIT, Proc.edabit(INST, true)) \
X(SPLIT, Proc.split(INST)) \
X(CALL_ARG, ) \
#define GC_INSTRUCTIONS \
X(INPUTB, T::inputb(PROC, EXTRA)) \
@@ -101,6 +102,7 @@
X(CONVCINT, C0 = PI1) \
X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \
X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \
X(CONVCBITVEC, PROC.convcbitvec(instruction, Ci, 0)) \
X(PRINTCHR, PROC.print_chr(IMM)) \
X(PRINTSTR, PROC.print_str(IMM)) \
X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \
@@ -146,8 +148,11 @@
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
X(CRASH, if (I0.get()) throw crash_requested()) \
X(CRASH, if (I0.get() and T::actual_inputs) throw crash_requested()) \
X(ACTIVE, ) \
X(LDTN, I0 = BaseMachine::thread_num) \
X(CALL_TAPE, PROC.call_tape(INST, MD)) \
X(CALL_ARG, ) \
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS

View File

@@ -1,4 +1,4 @@
The Software is copyright (c) 2023, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
The Software is copyright (c) 2024, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence.

View File

@@ -10,7 +10,8 @@
#include "Protocols/AtlasPrep.h"
#include "GC/AtlasSecret.h"
#include "ShamirMachine.hpp"
#include "Protocols/Atlas.hpp"
#include "Shamir.hpp"
#endif /* MACHINES_ATLAS_HPP_ */

View File

@@ -4,7 +4,7 @@
*/
#include "BMR/RealProgramParty.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/Shamir.hpp"
#include "Math/Z2k.hpp"
#include "Machines/MalRep.hpp"

View File

@@ -4,7 +4,7 @@
*/
#include "BMR/RealProgramParty.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/Shamir.hpp"
#include "Math/Z2k.hpp"
int main(int argc, const char** argv)

View File

@@ -232,9 +232,9 @@ OTMachine::OTMachine(int argc, const char** argv)
gettimeofday(&baseOTstart, NULL);
// swap role for base OTs
if (opt.isSet("-r"))
bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role));
bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role));
else
bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role));
bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role));
cout << "real mode " << opt.isSet("-r") << endl;
BaseOT& bot = *bot_;
bot.exec_base();

View File

@@ -5,3 +5,5 @@
#include "Math/gfp.hpp"
template class FieldMachine<Share, Share, DishonestMajorityMachine>;
template class Machine<Share<gfp_<0, 2>>, Share<gf2n>>;

View File

@@ -6,6 +6,9 @@
#ifndef MACHINES_SPDZ_HPP_
#define MACHINES_SPDZ_HPP_
#include "Protocols/MAC_Check.h"
#include "Protocols/SPDZ.h"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"

50
Machines/Shamir.hpp Normal file
View File

@@ -0,0 +1,50 @@
/*
* ShamirMachine.cpp
*
*/
#ifndef MACHINE_SHAMIR_HPP_
#define MACHINE_SHAMIR_HPP_
#include "Protocols/ShamirOptions.h"
#include "Protocols/ShamirShare.h"
#include "Protocols/MaliciousShamirShare.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "GC/VectorProtocol.h"
#include "GC/CcdPrep.h"
#include "GC/TinyMC.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/VectorInput.h"
#include "Processor/FieldMachine.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/ShamirInput.hpp"
#include "Protocols/Shamir.hpp"
#include "Protocols/ShamirMC.hpp"
#include "Protocols/MaliciousShamirMC.hpp"
#include "Protocols/MaliciousShamirPO.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/Spdz2kPrep.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Math/gfp.hpp"
template<template<class U> class T>
ShamirMachineSpec<T>::ShamirMachineSpec(int argc, const char** argv)
{
auto& opts = ShamirOptions::singleton;
ez::ezOptionParser opt;
opts = {opt, argc, argv};
HonestMajorityFieldMachine<T>(argc, argv, opt, opts.nparties);
}
#endif

View File

@@ -14,6 +14,7 @@
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "GC/PersonalPrep.hpp"
#include "Protocols/Share.hpp"
//template class GC::ShareParty<GC::TinierSecret<gf2n_mac_key>>;
template class GC::CcdPrep<GC::TinierSecret<gf2n_mac_key>>;

View File

@@ -15,6 +15,7 @@
#include "Math/BitVec.h"
#include "GC/TinierSecret.h"
#include "Protocols/Share.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/MascotPrep.hpp"
#include "Math/Z2k.hpp"

View File

@@ -13,7 +13,7 @@
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/Shamir.hpp"
int main(int argc, const char** argv)
{

View File

@@ -0,0 +1,8 @@
/*
* export-vm.cpp
*
*/
#include "maximal.hpp"
template class Machine<AtlasShare<gfp_<0, 2>>>;

View File

@@ -0,0 +1,8 @@
/*
* export-vm.cpp
*
*/
#include "maximal.hpp"
template class Machine<CowGearShare<gfp_<0, 2>>>;

View File

@@ -0,0 +1,8 @@
/*
* export-vm.cpp
*
*/
#include "maximal.hpp"
template class Machine<DealerShare<SignedZ2<64>>>;

8
Machines/export-hemi.cpp Normal file
View File

@@ -0,0 +1,8 @@
/*
* export-vm.cpp
*
*/
#include "maximal.hpp"
template class Machine<HemiShare<gfp_<0, 2>>>;

View File

@@ -0,0 +1,8 @@
/*
* export-vm.cpp
*
*/
#include "maximal.hpp"
template class Machine<Rep4Share2<64>>;

Some files were not shown because too many files have changed in this diff Show More