Maintenance.

This commit is contained in:
Marcel Keller
2024-07-09 12:17:25 +10:00
parent b0dc2b36f8
commit 78fe3d8bad
234 changed files with 4273 additions and 1367 deletions

View File

@@ -15,7 +15,7 @@ int AndJob::run()
#endif
__m128i* prf_output = new __m128i[PAD_TO_8(ProgramParty::s().get_n_parties())];
auto gate = gates.begin();
vector< GC::Secret<EvalRegister> >& S = *this->S;
auto& S = *this->S;
const vector<int>& args = *this->args;
int i_gate = 0;
for (size_t i = start; i < end; i += 4)

View File

@@ -15,7 +15,7 @@ using namespace std;
class AndJob
{
vector< GC::Secret<EvalRegister> >* S;
StackedVector< GC::Secret<EvalRegister> >* S;
const vector<int>* args;
public:
@@ -25,7 +25,7 @@ public:
AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {}
void reset(vector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
void reset(StackedVector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
size_t start, gate_id_t gate_id, size_t n_gates, int n_parties)
{
this->S = &S;

View File

@@ -1,5 +1,22 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.9 (July 9, 2024)
- Inference with non-sequential PyTorch networks
- SHA-3 for any input length (@hiddely)
- Improved client facilities
- Shuffling with malicious security for SPDZ-wise protocols by [Asharov et al.](https://ia.cr/2022/1595)
- More reusable bytecode via in-thread calling facility
- Recursive functions without return values
- Fewer rounds for parallel matrix multiplications (@vincent-ehrmanntraut)
- Optimized usage of SoftSpokenOT in semi-honest protocols
- More integrity checks on storage in MAC-based protocols
- Use C++17
- Use glibc 2.18 for the binaries
- Fixed security bugs: remotely caused buffer overflows (#1382)
- Fixed security bug: Missing randomization before revealing to client
- Fixed security bug: Bias in Rep3 secure shuffling
## 0.3.8 (December 14, 2023)
- Functionality for multiple nodes per party

2
CONFIG
View File

@@ -106,7 +106,7 @@ else
BOOST = -lboost_thread $(MY_BOOST)
endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++17 -Werror
CFLAGS += $(BREW_CFLAGS)
CPPFLAGS = $(CFLAGS)
LD = $(CXX)

View File

@@ -203,8 +203,10 @@ class andrsvec(base.VarArgsInstruction, base.Mergeable,
def add_usage(self, req_node):
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
req_node.increment(('bit', 'mixed'), size)
n = (n - 3) // 2
req_node.increment(('bit', 'triple'), size * n)
if n > 1:
req_node.increment(('bit', 'mixed'), size * ((n + 63) // 64))
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))

View File

@@ -13,7 +13,7 @@ from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint,
from Compiler.types import vectorized_classmethod
from Compiler.program import Tape, Program
from Compiler.exceptions import *
from Compiler import util, oram, floatingpoint, library
from Compiler import util, oram, floatingpoint, library, comparison
from Compiler import instructions_base
import Compiler.GC.instructions as inst
import operator
@@ -21,6 +21,11 @@ import math
import itertools
from functools import reduce
class _binary:
def reveal_to(self, *args, **kwargs):
raise CompilerError(
'%s does not support revealing to indivual players' % type(self))
class bits(Tape.Register, _structure, _bit):
n = 40
unit = 64
@@ -149,6 +154,12 @@ class bits(Tape.Register, _structure, _bit):
self.n = n
def set_size(self, size):
pass
def load_int(self, value):
n_limbs = math.ceil(self.n / self.unit)
for i in range(n_limbs):
self.conv_regint(min(self.unit, self.n - i * self.unit),
self[i], regint(value % 2 ** self.unit))
value >>= self.unit
def load_other(self, other):
if isinstance(other, cint):
assert(self.n == other.size)
@@ -236,12 +247,14 @@ class bits(Tape.Register, _structure, _bit):
return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
Bit-wise oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
This will output 1 because it selects the two least
significant bits from 5 and the rest of the bits from 2.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
def zero_if_not(self, condition):
@@ -268,6 +281,9 @@ class bits(Tape.Register, _structure, _bit):
self.bit_compose(source.bit_decompose()[base:base + size]))
def vector_size(self):
return self.n
@staticmethod
def size_for_mem():
return 1
class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
@@ -302,13 +318,6 @@ class cbits(bits):
else:
return super(cbits, cls).conv(other)
types = {}
def load_int(self, value):
n_limbs = math.ceil(self.n / self.unit)
tmp = regint(size=n_limbs)
for i in range(n_limbs):
tmp[i].load_int(value % 2 ** self.unit)
value >>= self.unit
self.load_other(tmp)
def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op):
@@ -502,11 +511,7 @@ class sbits(bits):
if self.n <= 32:
inst.ldbits(self, self.n, value)
else:
size = math.ceil(self.n / self.unit)
tmp = regint(size=size)
for i in range(size):
tmp[i].load_int((value >> (i * 64)) % 2**64)
self.load_other(tmp)
bits.load_int(self, value)
def load_other(self, other):
if isinstance(other, cbits) and self.n == other.n:
inst.convcbit2s(self.n, self, other)
@@ -675,7 +680,7 @@ class sbits(bits):
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
class sbitvec(_vec, _bit):
class sbitvec(_vec, _bit, _binary):
""" Vector of registers of secret bits, effectively a matrix of secret bits.
This facilitates parallel arithmetic operations in binary circuits.
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -732,15 +737,16 @@ class sbitvec(_vec, _bit):
:py:obj:`v` and the columns by calling :py:obj:`elements`.
"""
class sbitvecn(cls, _structure):
@staticmethod
def malloc(size, creator_tape=None):
return sbit.malloc(size * n, creator_tape=creator_tape)
@classmethod
def malloc(cls, size, creator_tape=None):
return sbit.malloc(
size * cls.mem_size(), creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
@staticmethod
def mem_size():
return n
return sbits.get_type(n).mem_size()
@classmethod
def get_input_from(cls, player, size=1, f=0):
""" Secret input from :py:obj:`player`. The input is decomposed
@@ -780,38 +786,28 @@ class sbitvec(_vec, _bit):
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
assert size is None or size == self.v[0].n
@vectorized_classmethod
def load_mem(cls, address):
size = instructions_base.get_global_vector_size()
if size not in (None, 1):
assert isinstance(address, int) or len(address) == 1
sb = sbits.get_type(size)
return cls.from_vec(sb.bit_compose(
sbit.load_mem(address + i + j * n) for j in range(size))
for i in range(n))
if not isinstance(address, int):
v = [sbit.load_mem(x, size=n).v[0] for x in address]
return cls(v)
@classmethod
def load_mem(cls, address, size=None):
if isinstance(address, int) or len(address) == 1:
address = [address + i for i in range(size or 1)]
else:
return cls.from_vec(sbit.load_mem(address + i)
for i in range(n))
assert size == None
return cls(
[sbits.get_type(n).load_mem(x) for x in address])
def store_in_mem(self, address):
size = 1
for x in self.v:
if not util.is_constant(x):
size = max(size, x.n)
v = [sbits.get_type(size).conv(x) for x in self.v]
if not isinstance(address, int) and len(address) != 1:
v = self.elements()
assert len(v) == len(address)
for x, y in zip(v, address):
for i, xx in enumerate(x.bit_decompose(n)):
xx.store_in_mem(y + i)
if isinstance(address, int):
address = range(address, address + size)
elif len(address) == 1:
address = [address + i * self.mem_size()
for i in range(size)]
else:
assert isinstance(address, int) or len(address) == 1
for i in range(n):
for j, x in enumerate(v[i].bit_decompose()):
x.store_in_mem(address + i + j * n)
assert size == len(address)
for x, dest in zip(self.elements(), address):
x.store_in_mem(dest)
@classmethod
def two_power(cls, nn, size=1):
return cls.from_vec(
@@ -864,7 +860,7 @@ class sbitvec(_vec, _bit):
assert isinstance(elements, sint)
if Program.prog.use_split():
x = elements.split_to_two_summands(length)
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
v = sbitint.bit_adder(x[0], x[1])
else:
prog = Program.prog
if not prog.options.ring:
@@ -877,6 +873,7 @@ class sbitvec(_vec, _bit):
length, prog.security)
prog.use_edabit(backup)
return
comparison.require_ring_size(length, 'A2B conversion')
l = int(Program.prog.options.ring)
r, r_bits = sint.get_edabit(length, size=elements.size)
c = ((elements - r) << (l - length)).reveal()
@@ -885,6 +882,8 @@ class sbitvec(_vec, _bit):
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
v = x.v
self.v = v[:length]
elif isinstance(elements, sbitvec):
self.v = elements.v
elif elements is not None and not (util.is_constant(elements) and \
elements == 0):
self.v = sbits.trans(elements)
@@ -1347,13 +1346,19 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
def __add__(self, other):
if util.is_zero(other):
return self
a, b = self.expand(other)
try:
a, b = self.expand(other)
except:
return NotImplemented
v = sbitint.bit_adder(a, b)
return self.get_type(len(v)).from_vec(v)
__radd__ = __add__
__sub__ = _bitint.__sub__
def __rsub__(self, other):
a, b = self.expand(other)
try:
a, b = self.expand(other)
except:
return NotImplemented
return self.from_vec(b) - self.from_vec(a)
def __mul__(self, other):
if isinstance(other, sbits):
@@ -1447,7 +1452,7 @@ class cbitfix(object):
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))
class sbitfix(_fix):
class sbitfix(_fix, _binary):
""" Secret signed fixed-point number in one binary register.
Use :py:obj:`set_precision()` to change the precision.
@@ -1515,7 +1520,7 @@ class sbitfix(_fix):
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)
class sbitfixvec(_fix, _vec):
class sbitfixvec(_fix, _vec, _binary):
""" Vector of fixed-point numbers for parallel binary computation.
Use :py:obj:`set_precision()` to change the precision.

View File

@@ -76,7 +76,7 @@ class AllocRange:
self.top += size
self.limit = max(self.limit, self.top)
if res >= REG_MAX:
raise RegisterOverflowError()
raise RegisterOverflowError(size)
return res
def free(self, base, size):
@@ -209,7 +209,8 @@ class StraightlineAllocator:
for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x)
if reg not in self.program.base_addresses:
if reg not in self.program.base_addresses \
and not isinstance(inst, call_arg):
free.free(base)
if inst.is_vec() and base.vector:
self.defined[base] = inst
@@ -608,7 +609,8 @@ class Merger:
# so this threshold should lead to acceptable compile times even on slower processors.
first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4]
second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5]
max_dependencies_per_matrix = 1500**2
max_dependencies_per_matrix = \
self.block.parent.program.budget
if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix:
if block.warn_about_mem and not block.parent.warned_about_mem:
print('WARNING: Order of memory instructions not preserved due to long vector, errors possible')

View File

@@ -5,7 +5,7 @@ the ones used below into ``Programs/Circuits`` as follows::
make Programs/Circuits
.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC
.. _`Bristol Fashion`: https://nigelsmart.github.io/MPC-Circuits
"""
import math
@@ -15,6 +15,7 @@ from Compiler.library import function_block, get_tape
from Compiler import util
import itertools
import struct
import os
class Circuit:
"""
@@ -47,7 +48,12 @@ class Circuit:
"""
def __init__(self, name):
self.name = name
self.filename = 'Programs/Circuits/%s.txt' % name
if not os.path.exists(self.filename):
if os.system('make Programs/Circuits'):
raise CompilerError('Cannot download circuit descriptions. '
'Make sure make and git are installed.')
f = open(self.filename)
self.functions = {}
@@ -57,8 +63,9 @@ class Circuit:
def run(self, *inputs):
n = inputs[0][0].n, get_tape()
if n not in self.functions:
self.functions[n] = function_block(lambda *args:
self.compile(*args))
self.functions[n] = function_block(
lambda *args: self.compile(*args))
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
flat_res = self.functions[n](*itertools.chain(*inputs))
res = []
i = 0
@@ -124,7 +131,7 @@ Keccak_f = None
def sha3_256(x):
"""
This function implements SHA3-256 for inputs of up to 1080 bits::
This function implements SHA3-256 for inputs of any length::
from circuit import sha3_256
a = sbitvec.from_vec([])
@@ -138,7 +145,8 @@ def sha3_256(x):
for x in a, b, c, d, e, f, g, h:
sha3_256(x).reveal_print_hex()
This should output the `test vectors
This should output the hashes of the above inputs, beginning with
the `test vectors
<https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the
0 byte::

View File

@@ -76,13 +76,13 @@ def require_ring_size(k, op):
program.curr_tape.require_bit_length(k)
@instructions_base.cisc
def LTZ(s, a, k, kappa):
def LTZ(s, a, k):
"""
s = (a ?< 0)
k: bit length of a
"""
movs(s, program.non_linear.ltz(a, k, kappa))
movs(s, program.non_linear.ltz(a, k))
def LtzRing(a, k):
from .types import sint, _bitint
@@ -105,14 +105,14 @@ def LtzRing(a, k):
u = CarryOutRaw(a[::-1], b[::-1])
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
def LessThanZero(a, k, kappa):
def LessThanZero(a, k):
from . import types
res = types.sint()
LTZ(res, a, k, kappa)
LTZ(res, a, k)
return res
@instructions_base.cisc
def Trunc(d, a, k, m, kappa, signed):
def Trunc(d, a, k, m, signed):
"""
d = a >> m
@@ -124,7 +124,7 @@ def Trunc(d, a, k, m, kappa, signed):
movs(d, a)
return
else:
movs(d, program.non_linear.trunc(a, k, m, kappa, signed))
movs(d, program.non_linear.trunc(a, k, m, signed=signed))
def TruncRing(d, a, k, m, signed):
program.curr_tape.require_bit_length(1)
@@ -197,13 +197,13 @@ def TruncLeakyInRing(a, k, m, signed):
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
masked = shifted >> n_shift
u = sint()
BitLTL(u, masked, r_bits[:n_bits], 0)
BitLTL(u, masked, r_bits[:n_bits])
res = (u << n_bits) + masked - r
if signed:
res -= (1 << (n_bits - 1))
return res
def TruncRoundNearest(a, k, m, kappa, signed=False):
def TruncRoundNearest(a, k, m, signed=False):
"""
Returns a / 2^m, rounded to the nearest integer.
@@ -212,12 +212,10 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
"""
if m == 0:
return a
nl = program.non_linear
nl.check_security(kappa)
return program.non_linear.trunc_round_nearest(a, k, m, signed)
@instructions_base.cisc
def Mod2m(a_prime, a, k, m, kappa, signed):
def Mod2m(a_prime, a, k, m, signed):
"""
a_prime = a % 2^m
@@ -225,8 +223,6 @@ def Mod2m(a_prime, a, k, m, kappa, signed):
m: compile-time integer
signed: True/False, describes a
"""
nl = program.non_linear
nl.check_security(kappa)
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
def Mod2mRing(a_prime, a, k, m, signed):
@@ -237,13 +233,13 @@ def Mod2mRing(a_prime, a, k, m, signed):
tmp = a + r_prime
c_prime = (tmp << shift).reveal(False) >> shift
u = sint()
BitLTL(u, c_prime, r_bin[:m], 0)
BitLTL(u, c_prime, r_bin[:m])
res = (u << m) + c_prime - r_prime
if a_prime is not None:
movs(a_prime, res)
return res
def Mod2mField(a_prime, a, k, m, kappa, signed):
def Mod2mField(a_prime, a, k, m, signed):
from .types import sint
r_dprime = program.curr_block.new_reg('s')
r_prime = program.curr_block.new_reg('s')
@@ -255,7 +251,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
t = [program.curr_block.new_reg('s') for i in range(6)]
c2m = program.curr_block.new_reg('c')
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, r, k, m, kappa)
PRandM(r_dprime, r_prime, r, k, m)
ld2i(c2m, m)
mulm(t[0], r_dprime, c2m)
if signed:
@@ -268,9 +264,9 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
asm_open(True, c, t[3])
modc(c_prime, c, c2m)
if const_rounds:
BitLTC1(u, c_prime, r, kappa)
BitLTC1(u, c_prime, r)
else:
BitLTL(u, c_prime, r, kappa)
BitLTL(u, c_prime, r)
mulm(t[4], u, c2m)
submr(t[5], c_prime, r_prime)
adds(a_prime, t[5], t[4])
@@ -288,13 +284,15 @@ def MaskingBitsInRing(m, strict=False):
r_bin = r
return sint.bit_compose(r), r_bin
def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
def PRandM(r_dprime, r_prime, b, k, m, use_dabit=True):
"""
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
r_prime = random secret integer in range [0, 2^m - 1]
b = array containing bits of r_prime
"""
program.curr_tape.require_bit_length(k + kappa)
assert k >= m
kappa = program.security
program.curr_tape.require_bit_length(k + kappa, reason='statistical masking as in https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation')
from .types import sint
if program.use_edabit() and not const_rounds:
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
@@ -329,7 +327,7 @@ def PRandInt(r, k):
bit(t[1][i])
adds(t[2][i], t[0][i], t[1][i])
def BitLTC1(u, a, b, kappa):
def BitLTC1(u, a, b):
"""
u = a <? b
@@ -395,7 +393,7 @@ def BitLTC1(u, a, b, kappa):
subcfi(c[3][i], a_bits[i], 1)
mulm(t[3][i], s[i], c[3][i])
adds(t[4][i], t[4][i-1], t[3][i])
Mod2(u, t[4][k-1], k, kappa, False)
Mod2(u, t[4][k-1], k, False)
return p, a_bits, d, s, t, c, b, pre_input
def carry(b, a, compute_p=True):
@@ -414,7 +412,7 @@ def carry(b, a, compute_p=True):
# from WP9 report
# length of a is even
def CarryOutAux(a, kappa):
def CarryOutAux(a):
k = len(a)
if k > 1 and k % 2 == 1:
a.append(None)
@@ -424,12 +422,12 @@ def CarryOutAux(a, kappa):
if k > 1:
for i in range(k//2):
u[i] = carry(a[2*i+1], a[2*i], i != k//2-1)
return CarryOutAux(u[:k//2][::-1], kappa)
return CarryOutAux(u[:k//2][::-1])
else:
return a[0][1]
# carry out with carry-in bit c
def CarryOut(res, a, b, c=0, kappa=None):
def CarryOut(res, a, b, c=0):
"""
res = last carry bit in addition of a and b
@@ -456,7 +454,7 @@ def CarryOutRaw(a, b, c=0):
s[0] = d[-1][0].bit_and(c)
s[1] = d[-1][1] + s[0]
d[-1][1] = s[1]
return CarryOutAux(d[::-1], None)
return CarryOutAux(d[::-1])
def CarryOutRawLE(a, b, c=0):
""" Little-endian version """
@@ -469,7 +467,7 @@ def CarryOutLE(a, b, c=0):
CarryOut(res, a[::-1], b[::-1], c)
return res
def BitLTL(res, a, b, kappa):
def BitLTL(res, a, b):
"""
res = a <? b (logarithmic rounds version)
@@ -624,7 +622,7 @@ def KMulC(a):
PreMulC_without_inverses(p, a)
return p
def Mod2(a_0, a, k, kappa, signed):
def Mod2(a_0, a, k, signed):
"""
a_0 = a % 2
@@ -641,7 +639,7 @@ def Mod2(a_0, a, k, kappa, signed):
tc = program.curr_block.new_reg('c')
t = [program.curr_block.new_reg('s') for i in range(6)]
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
PRandM(r_dprime, r_prime, [r_0], k, 1)
r_0 = r_prime
mulsi(t[0], r_dprime, 2)
if signed:

View File

@@ -165,7 +165,8 @@ class Compiler:
dest="prime",
default=defaults.prime,
help="use bit decomposition with a specifed prime modulus "
"for non-linear computation (default: use the masking approach)",
"for non-linear computation (default: use the masking approach). "
"Don't use this unless you're certain that you need it.",
)
parser.add_option(
"-I",
@@ -263,6 +264,14 @@ class Compiler:
def parse_args(self):
self.options, self.args = self.parser.parse_args(self.custom_args)
if self.options.verbose:
self.runtime_args += ["--verbose"]
if self.options.execute:
self.options.execute = re.sub("-party.x$", "", self.options.execute)
for s, l in self.match.items():
if self.options.execute == l:
self.options.execute = s
break
if self.execute:
if not self.options.execute:
if len(self.args) > 1:
@@ -313,6 +322,8 @@ class Compiler:
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
if self.options.execute in ("rep4-ring",):
self.prog.use_split(4)
if self.options.execute.find("dealer") >= 0:
self.prog.use_edabit(True)
def build_vars(self):
from . import comparison, floatingpoint, instructions, library, types
@@ -368,7 +379,14 @@ class Compiler:
"cfloat",
"squant",
]:
del self.VARS[i]
class dummy:
def __init__(self, *args):
raise CompilerError(self.error)
dummy.error = i + " not availabe with binary circuits"
if i in ("cint", "cfix"):
dummy.error += ". See https://mp-spdz.readthedocs.io/en/" \
"latest/Compiler.html#Compiler.types." + i
self.VARS[i] = dummy
else:
self.sint = types.sint
self.sfix = types.sfix
@@ -503,13 +521,15 @@ class Compiler:
return self.prog
@staticmethod
def executable_from_protocol(protocol):
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
@classmethod
def executable_from_protocol(cls, protocol):
match = cls.match
if protocol in match:
protocol = match[protocol]
if protocol.find("bmr") == -1:
@@ -588,11 +608,19 @@ class Compiler:
for filename in glob.glob("Player-Data/*.0"):
connection.put(filename, dest + "Player-Data")
def run_with_error(i):
try:
run(i)
except IOError:
print('IO error when copying files, does %s have enough space?' %
hostnames[i])
raise
import threading
import random
threads = []
for i in range(len(hosts)):
threads.append(threading.Thread(target=run, args=(i,)))
threads.append(threading.Thread(target=run_with_error, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:

View File

@@ -2,6 +2,7 @@ from collections import defaultdict
REG_MAX = 2 ** 32
USER_MEM = 8192
MEM_MAX = 2 ** 64
P_VALUES = { 32: 2147565569, \
64: 9223372036855103489, \

View File

@@ -558,7 +558,7 @@ def test_stupid_dijkstra_on_cycle(n, n_loops=None):
@for_range(n)
def f(i):
M[i][(i+1)%n] = ExtInt(1)
M[i][(i-1)%n] = ExtInt(1)
M[i][(i-1+n)%n] = ExtInt(1)
if n_loops is not None:
stop_timer(1)
start_timer()

View File

@@ -39,26 +39,25 @@ def maskRing(a, k):
c = ((a + r_prime) << shift).reveal(False) >> shift
return c, r
def maskField(a, k, kappa):
def maskField(a, k):
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
r = [types.sint() for i in range(k)]
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
comparison.PRandM(r_dprime, r_prime, r, k, k)
# always signed due to usage in equality testing
a += two_power(k)
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
return c, r
@instructions_base.ret_cisc
def EQZ(a, k, kappa):
def EQZ(a, k):
prog = program.Program.prog
if prog.use_split():
from GC.types import sbitvec
v = sbitvec(a, k).v
bit = util.tree_reduce(operator.and_, (~b for b in v))
return types.sintbit.conv(bit)
prog.non_linear.check_security(kappa)
return prog.non_linear.eqz(a, k)
def bits(a,m):
@@ -99,12 +98,12 @@ def or_op(a, b, void=None):
def mul_op(a, b, void=None):
return a * b
def PreORC(a, kappa=None, m=None, raw=False):
def PreORC(a, m=None, raw=False):
k = len(a)
if k == 1:
return [a[0]]
prog = program.Program.prog
kappa = kappa or prog.security
kappa = prog.security
m = m or k
if isinstance(a[0], types.sgf2n):
max_k = program.Program.prog.galois_length - 1
@@ -128,13 +127,13 @@ def PreORC(a, kappa=None, m=None, raw=False):
t = [types.sint() for i in range(m)]
b = comparison.PreMulC([a[i] + 1 for i in range(k)])
for i in range(m):
comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
comparison.Mod2(t[i], b[k-1-i], k, False)
p[m-1-i] = 1 - t[i]
return p
else:
# not constant-round anymore
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
s = [PreORC(a[i:i+max_k], raw=raw) for i in range(0,k,max_k)]
t = PreORC([si[-1] for si in s[:-1]], raw=raw)
return sum(([or_op(x, y) for x in si]
for si,y in zip(s[1:],t)), s[0])[-m:]
@@ -175,6 +174,41 @@ def PreOpL2(op, items):
output[2 * i] = op(v[i - 1], items[2 * i])
return output
def PreOpL2_vec(op, *items):
""" Vectorized version of :py:func:`PreOpL2` """
k = len(items[0])
for x in items:
assert len(x) == k
if k == 1:
return items
half = k // 2
other_half = (k + 1) // 2 - 1
u = op([x.get_vector(base=0, size=half, skip=2) for x in items],
[x.get_vector(base=1, size=half, skip=2) for x in items])
assert len(u) == len(items)
assert len(u[0]) == half
v = PreOpL2_vec(op, *u)
if other_half:
w = op([x.get_vector(base=0, size=other_half) for x in v],
[x.get_vector(base=2, size=other_half, skip=2) for x in items])
if half == other_half:
res = [type(x).zip(x, y) for x, y in zip(v, w)]
for i in range(len(res)):
res[i] = type(res[i]).concat((items[i].get_vector(base=0, size=1),
res[i]))
else:
if other_half:
for i in range(len(w)):
w[i] = type(w[i]).concat((items[i].get_vector(base=0, size=1),
w[i]))
else:
w = [x.get_vector(base=0, size=1) for x in items]
res = [type(x).zip(x, y) for x, y in zip(w, v)]
assert len(res) == len(items)
for x in res:
assert len(x) == k
return res
def PreOpN(op, items):
""" Naive PreOp algorithm """
k = len(items)
@@ -184,9 +218,9 @@ def PreOpN(op, items):
output[i] = op(output[i-1], items[i])
return output
def PreOR(a, kappa=None, raw=False):
def PreOR(a=None, raw=False):
if comparison.const_rounds:
return PreORC(a, kappa, raw=raw)
return PreORC(a, raw=raw)
else:
return PreOpL(or_op, a)
@@ -199,24 +233,24 @@ def KOpL(op, a):
t2 = KOpL(op, a[k//2:])
return op(t1, t2)
def KORL(a, kappa=None):
def KORL(a):
""" log rounds k-ary OR """
k = len(a)
if k == 1:
return a[0]
else:
t1 = KORL(a[:k//2], kappa)
t2 = KORL(a[k//2:], kappa)
t1 = KORL(a[:k//2])
t2 = KORL(a[k//2:])
return t1 + t2 - t1.bit_and(t2)
def KORC(a, kappa):
return PreORC(a, kappa, 1)[0]
def KORC(a):
return PreORC(a, 1)[0]
def KOR(a, kappa):
def KOR(a):
if comparison.const_rounds:
return KORC(a, kappa)
return KORC(a)
else:
return KORL(a, None)
return KORL(a)
def KMul(a):
if comparison.const_rounds:
@@ -262,7 +296,7 @@ def BitAdd(a, b, bits_to_compute=None):
s[k] = c[k-1]
return s
def BitDec(a, k, m, kappa, bits_to_compute=None):
def BitDec(a, k, m, bits_to_compute=None):
return program.Program.prog.non_linear.bit_dec(a, k, m)
def BitDecRingRaw(a, k, m):
@@ -270,7 +304,7 @@ def BitDecRingRaw(a, k, m):
n_shift = int(program.Program.prog.options.ring) - m
if program.Program.prog.use_split():
x = a.split_to_two_summands(m)
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
bits = types._bitint.bit_adder(x[0], x[1])
return bits[:m]
else:
if program.Program.prog.use_edabit():
@@ -292,13 +326,14 @@ def BitDecRing(a, k, m):
# reversing to reduce number of rounds
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
def BitDecFieldRaw(a, k, m, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size)
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
r = [types.sint() for i in range(m)]
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
comparison.PRandM(r_dprime, r_prime, r, k, m)
kappa = program.Program.prog.security
pow2 = two_power(k + kappa)
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
@@ -306,16 +341,16 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
return res
@instructions_base.bit_cisc
def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
def BitDecField(a, k, m, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, bits_to_compute)
return [types.sintbit.conv(bit) for bit in res]
@instructions_base.ret_cisc
def Pow2(a, l, kappa):
def Pow2(a, l):
comparison.program.curr_tape.require_bit_length(l - 1)
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
t = BitDec(a, m, m)
return Pow2_from_bits(t)
def Pow2_from_bits(bits):
@@ -327,11 +362,12 @@ def Pow2_from_bits(bits):
t[i] = t[i]*pow2k[i] + 1 - t[i]
return KMul(t)
def B2U(a, l, kappa):
pow2a = Pow2(a, l, kappa)
return B2U_from_Pow2(pow2a, l, kappa), pow2a
def B2U(a, l):
pow2a = Pow2(a, l)
return B2U_from_Pow2(pow2a, l), pow2a
def B2U_from_Pow2(pow2a, l, kappa):
def B2U_from_Pow2(pow2a, l):
kappa = program.Program.prog.security
r = [types.sint() for i in range(l)]
t = types.sint()
c = types.cint()
@@ -353,17 +389,17 @@ def B2U_from_Pow2(pow2a, l, kappa):
c = list(r_bits[0].bit_decompose_clear(c, l))
x = [r_bits[i].bit_xor(c[i]) for i in range(l)]
#print ' '.join(str(b.value) for b in x)
y = PreOR(x, kappa)
y = PreOR(x)
#print ' '.join(str(b.value) for b in y)
return [types.sint.conv(1 - y[i]) for i in range(l)]
def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
def Trunc(a, l, m, compute_modulo=False, signed=False):
""" Oblivious truncation by secret m """
prog = program.Program.prog
if util.is_constant(m) and not compute_modulo:
# cheaper
res = type(a)(size=a.size)
comparison.Trunc(res, a, l, m, kappa, signed=signed)
comparison.Trunc(res, a, l, m, signed=signed)
return res
if l == 1:
if compute_modulo:
@@ -371,9 +407,9 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
else:
return a * (1 - m)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, Pow2(m, l, kappa))
return TruncInRing(a, l, Pow2(m, l))
else:
kappa = kappa or program.Program.prog.security
kappa = program.Program.prog.security
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
@@ -381,7 +417,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
c = types.cint()
ci = [types.cint() for i in range(l)]
d = types.sint()
x, pow2m = B2U(m, l, kappa)
x, pow2m = B2U(m, l)
for i in range(l):
bit(r[i])
t1 = two_power(i) * r[i]
@@ -398,7 +434,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
for i in range(1,l):
ci[i] = c % two_power(i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l)
if compute_modulo:
b = c_dprime - r_prime + pow2m * d
return b, pow2m
@@ -429,33 +465,33 @@ def TruncInRing(to_shift, l, pow2m):
def SplitInRing(a, l, m):
if l == 1:
return m.if_else(a, 0), m.if_else(0, a), 1
pow2m = Pow2(m, l, None)
pow2m = Pow2(m, l)
upper = TruncInRing(a, l, pow2m)
lower = a - upper * pow2m
return lower, upper, pow2m
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
def TruncRoundNearestAdjustOverflow(a, length, target_length):
t = comparison.TruncRoundNearest(a, length, length - target_length)
overflow = t.greater_equal(two_power(target_length), target_length + 1)
s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False)
return s, overflow
def Int2FL(a, gamma, l, kappa=None):
def Int2FL(a, gamma, l):
lam = gamma - 1
s = a.less_than(0, gamma, security=kappa)
z = a.equal(0, gamma, security=kappa)
s = a.less_than(0, gamma)
z = a.equal(0, gamma)
a = s.if_else(-a, a)
a_bits = a.bit_decompose(lam, security=kappa)
a_bits = a.bit_decompose(lam)
a_bits.reverse()
b = PreOR(a_bits, kappa)
b = PreOR(a_bits)
t = a * (1 + a.bit_compose(1 - b_i for b_i in b))
p = a.popcnt_bits(b) - lam
if gamma - 1 > l:
if types.sfloat.round_nearest:
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l)
p = p + overflow
else:
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
v = t.right_shift(gamma - l - 1, gamma - 1, signed=False)
else:
v = 2**(l-gamma+1) * t
p = (p + gamma - 1 - l) * z.bit_not()
@@ -466,32 +502,31 @@ def FLRound(x, mode):
*mode*: 0 -> floor, 1 -> ceil, -1 > trunc """
v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
a = types.sint()
comparison.LTZ(a, p1, k, x.kappa)
b = p1.less_than(-l + 1, k, x.kappa)
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
c = EQZ(v2, l, x.kappa)
comparison.LTZ(a, p1, k)
b = p1.less_than(-l + 1, k)
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, compute_modulo=True)
c = EQZ(v2, l)
if mode == -1:
away_from_zero = 0
mode = x.s
else:
away_from_zero = mode + s1 - 2 * mode * s1
v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
d = v.equal(two_power(l), l + 1, x.kappa)
d = v.equal(two_power(l), l + 1)
v = d * two_power(l-1) + (1 - d) * v
v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
s = (1 - b * mode) * s1
z = or_op(EQZ(v, l, x.kappa), z1)
z = or_op(EQZ(v, l), z1)
v = v * (1 - z)
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
return v, p, z, s
@instructions_base.ret_cisc
def TruncPr(a, k, m, kappa=None, signed=True):
def TruncPr(a, k, m, signed=True):
""" Probabilistic truncation [a/2^m + u]
where Pr[u = 1] = (a % 2^m) / 2^m
"""
nl = program.Program.prog.non_linear
nl.check_security(kappa)
return nl.trunc_pr(a, k, m, signed)
def TruncPrRing(a, k, m, signed=True):
@@ -540,16 +575,14 @@ def TruncPrRing(a, k, m, signed=True):
res -= (1 << (k - m - 1))
return res
def TruncPrField(a, k, m, kappa=None):
def TruncPrField(a, k, m):
if m == 0:
return a
if kappa is None:
kappa = 40
b = two_power(k-1) + a
r_prime, r_dprime = types.sint(), types.sint()
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
k, m, kappa, use_dabit=False)
k, m, use_dabit=False)
two_to_m = two_power(m)
r = two_to_m * r_dprime + r_prime
c = (b + r).reveal(False)
@@ -559,49 +592,49 @@ def TruncPrField(a, k, m, kappa=None):
return d
@instructions_base.ret_cisc
def SDiv(a, b, l, kappa, round_nearest=False):
def SDiv(a, b, l, round_nearest=False):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
w = types.cint(int(2.9142 * 2 ** l)) - 2 * b
x = alpha - b * w
y = a * w
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
comparison.Mod2m(x2, x, 2 * l + 1, l, signed=True)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest,
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l,
nearest=round_nearest,
signed=False)
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, nearest=round_nearest,
signed=False)
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
x = x1 * x1 + x.round(2 * l + 1, l - 1, nearest=round_nearest,
signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
comparison.Mod2m(x2, x, 2 * l, l, signed=False)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest, signed=False)
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest,
signed=False)
y = y.round(2 * l + 1, l + 1, nearest=round_nearest)
return y
def SDiv_mono(a, b, l, kappa):
def SDiv_mono(a, b, l):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w
y = a * w
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
y = TruncPr(y, 2 * l + 1, l + 1)
for i in range(theta-1):
y = y * (alpha + x)
# keep y with l bits
y = TruncPr(y, 3 * l, 2 * l, kappa)
y = TruncPr(y, 3 * l, 2 * l)
x = x**2
# keep x with 2l bits
x = TruncPr(x, 4 * l, 2 * l, kappa)
x = TruncPr(x, 4 * l, 2 * l)
y = y * (alpha + x)
y = TruncPr(y, 3 * l, 2 * l, kappa)
y = TruncPr(y, 3 * l, 2 * l)
return y
# LT bit comparison on shared bit values

View File

@@ -488,6 +488,70 @@ class join_tape(base.Instruction):
code = base.opcodes['JOIN_TAPE']
arg_format = ['int']
class call_tape(base.DoNotEliminateInstruction):
""" Start tape/bytecode file in same thread. Arguments/return values
starting from :py:obj:`direction` are optional.
:param: tape number (int)
:param: arg (regint)
:param: direction (0 for argument, 1 for return value)
:param: register type (see :py:obj:`vm_types`)
:param: register size (int)
:param: destination register
:param: source register
:param: (repeat from direction)
"""
code = base.opcodes['CALL_TAPE']
arg_format = tools.chain(['int', 'ci'],
tools.cycle(['int','int','int','*w','*']))
@staticmethod
def type_check(reg, type_id):
assert base.vm_types[reg.reg_type] == type_id
def __init__(self, *args, **kwargs):
super(call_tape, self).__init__(*args, **kwargs)
for i in range(2, len(args), 5):
for reg in args[i + 3:i + 5]:
self.type_check(reg, args[i + 1])
assert reg.size == args[i + 2]
assert args[i] in (0, 1)
assert args[i + 4 - args[i]].program == program.curr_tape
assert args[i + 3 + args[i]].program == program.tapes[args[0]]
def get_def(self):
# hide registers from called tape
for i in range(2, len(self.args), 5):
if self.args[i]:
yield self.args[i + 3]
def get_used(self):
# hide registers from called tape
yield self.args[1]
for i in range(2, len(self.args), 5):
if not self.args[i]:
yield self.args[i + 4]
def add_usage(self, req_node):
req_node.num += program.tapes[self.args[0]].req_tree.aggregate()
class call_arg(base.DoNotEliminateInstruction, base.VectorInstruction):
""" Pseudo instruction for arguments in connection with
:py:class:`call_tape`.
:param: destination (register)
:param: register type (see :py:obj:`vm_types`)
"""
code = base.opcodes['CALL_ARG']
arg_format = ['*w','int']
def __init__(self, *args, **kwargs):
super(call_arg, self).__init__(*args, **kwargs)
for i in range(0, len(args), 2):
call_tape.type_check(args[i], args[i + 1])
class crash(base.IOInstruction):
""" Crash runtime if the value in the register is not zero.
@@ -687,6 +751,27 @@ class concats(base.VectorInstruction):
for i in range(1, len(args), 2):
assert args[i] == len(args[i + 1])
class zips(base.Instruction):
""" Zip vectors.
:param: result (sint)
:param: operand (sint)
:param: operand (sint)
"""
__slots__ = []
code = base.opcodes['ZIPS']
arg_format = ['sw','s','s']
is_vec = lambda self: True
def __init__(self, *args):
super(zips, self).__init__(*args)
assert len(args[0]) == len(args[1]) + len(args[2])
assert len(args[1]) == len(args[2])
def get_code(self):
return super(zips, self).get_code(len(self.args[1]))
@base.gf2n
@base.vectorize
class mulc(base.MulBase):

View File

@@ -68,6 +68,8 @@ opcodes = dict(
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
CMDLINEARG = 0xEB,
CALL_TAPE = 0xEC,
CALL_ARG = 0xED,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -85,6 +87,7 @@ opcodes = dict(
PREFIXSUMS = 0x2D,
PICKS = 0x2E,
CONCATS = 0x2F,
ZIPS = 0x3F,
# Multiplication/division
MULC = 0x30,
MULM = 0x31,
@@ -223,6 +226,17 @@ opcodes = dict(
)
vm_types = dict(
ci = 0,
sb = 1,
cb = 2,
s = 4,
c = 5,
sg = 6,
cg = 7,
)
def int_to_bytes(x):
""" 32 bit int to big-endian 4 byte conversion. """
assert(x < 2**32 and x >= -2**32)
@@ -491,7 +505,8 @@ def cisc(function, n_outputs=1):
reset_global_vector_size()
program.curr_tape = old_tape
for x, bl in tape.req_bit_length.items():
old_tape.require_bit_length(bl - 1, x)
old_tape.require_bit_length(
bl - 1, x, tape.bit_length_reason if x == 'p' else '')
from Compiler.allocator import Merger
merger = Merger(block, program.options,
tuple(program.to_merge))
@@ -516,40 +531,48 @@ def cisc(function, n_outputs=1):
inst.copy(size, subs)
reset_global_vector_size()
def expand_to_function(self, size, new_regs):
key = size, program.curr_tape, \
tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \
tuple(type(reg) for reg in new_regs)
if key not in self.functions:
from Compiler import library, types
class Arg:
def __init__(self, reg):
from Compiler.GC.types import bits
class Arg:
def __init__(self, reg):
self.type = type(reg)
self.binary = isinstance(reg, bits)
self.reg = reg
# if reg is not None:
# program.base_addresses[reg] = None
def new(self):
if self.binary:
return self.type()
else:
return self.type(size=size)
def load(self):
return self.reg
def store(self, reg):
if self.type != type(None):
self.reg.update(reg)
args = [Arg(x) for x in new_regs]
self.type = type(reg)
self.binary = isinstance(reg, bits)
self.reg = reg
def new(self, size):
if self.binary:
return self.type()
else:
return self.type(size=size)
def load(self):
return self.reg
def store(self, reg):
if self.type != type(None):
self.reg.update(reg)
def is_real(self):
return self.reg is not None
def base_key(self, size, new_regs):
return size, tuple(
arg for arg, reg in zip(self.args, new_regs) if reg is None), \
tuple(type(reg) for reg in new_regs)
@staticmethod
def get_name(key):
return '_'.join(['%s(%d)' % (function.__name__, key[0])] +
[str(x) for x in key[1]])
def expand_to_function(self, size, new_regs):
key = self.base_key(size, new_regs) + (program.curr_tape,)
if key not in self.functions:
args = [self.Arg(x) for x in new_regs]
from Compiler import library, types
@library.function_block
def f():
res = [arg.new() for arg in args[:n_outputs]]
self.new_instructions(size,
res + [arg.load() for arg in args[n_outputs:]])
res = [arg.new(size) for arg in args[:n_outputs]]
self.new_instructions(
size, res + [arg.load() for arg in args[n_outputs:]])
for reg, arg in zip(res, args):
arg.store(reg)
f.name = '_'.join(['%s(%d)' % (function.__name__, size)] +
[str(x) for x in key[2]])
f.name = self.get_name(key)
self.functions[key] = f, args
f, args = self.functions[key]
for i in range(len(new_regs) - n_outputs):
@@ -558,6 +581,31 @@ def cisc(function, n_outputs=1):
for i in range(n_outputs):
new_regs[i].link(args[i].load())
def expand_to_tape(self, size, new_regs):
key = self.base_key(size, new_regs)
args = [self.Arg(x) for x in new_regs]
if key not in self.functions:
from Compiler import library, types
@library.function_call_tape
def f(*in_args):
res = [arg.new(size) for arg in args[:n_outputs]]
in_args = list(in_args)
my_args = list(res)
for arg in args[n_outputs:]:
if arg.is_real():
my_args.append(in_args.pop(0))
else:
my_args.append(arg.reg)
self.new_instructions(size, my_args)
return res
f.name = self.get_name(key)
self.functions[key] = f
f = self.functions[key]
in_args = filter(lambda arg: arg.is_real(), args[n_outputs:])
res = util.tuplify(f(*(arg.load() for arg in in_args)))
for i in range(n_outputs):
new_regs[i].link(res[i])
def expand_merged(self, skip):
if function.__name__ in skip:
good = True
@@ -595,7 +643,11 @@ def cisc(function, n_outputs=1):
raise
if program.cisc_to_function and \
(program.curr_tape.singular or program.n_running_threads):
self.expand_to_function(size, new_regs)
if (program.options.garbled or program.options.binary or \
not program.use_tape_calls) and not program.force_cisc_tape:
self.expand_to_function(size, new_regs)
else:
self.expand_to_tape(size, new_regs)
else:
self.new_instructions(size, new_regs)
program.curr_block.n_rounds += self.n_rounds - 1
@@ -795,6 +847,12 @@ class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt
name = 'regint'
class AnyRegAF(RegisterArgFormat):
reg_type = '*'
@staticmethod
def check(arg):
assert isinstance(arg, program.curr_tape.Register)
class IntArgFormat(ArgFormat):
n_bits = 32
@@ -898,6 +956,8 @@ ArgFormats = {
'sgw': SecretGF2NAF,
'ci': ClearIntAF,
'ciw': ClearIntAF,
'*': AnyRegAF,
'*w': AnyRegAF,
'i': ImmediateModpAF,
'ig': ImmediateGF2NAF,
'int': IntArgFormat,
@@ -938,6 +998,7 @@ class Instruction(object):
Instruction.count += 1
if Instruction.count % 100000 == 0:
print("Compiled %d lines at" % self.__class__.count, time.asctime())
sys.stdout.flush()
if Instruction.count > 10 ** 7:
print("Compilation produced more that 10 million instructions. "
"Consider using './compile.py -l' or replacing for loops "

View File

@@ -107,7 +107,7 @@ def print_str(s, *args, print_secrets=False):
elif isinstance(val, cfloat):
val.print_float_plain()
elif isinstance(val, (list, tuple, Array, SubMultiArray)):
print_str(*_expand_to_print(val))
print_str(*_expand_to_print(val), print_secrets=print_secrets)
else:
try:
val.output()
@@ -314,7 +314,7 @@ def get_cmdline_arg(idx):
return localint(res)
def make_array(l, t=None):
if isinstance(l, Tape.Register):
if isinstance(l, types._structure):
res = Array(len(l), t or type(l))
res[:] = l
else:
@@ -334,13 +334,12 @@ class FunctionTapeCall:
return self
def join(self):
self.thread.join()
instructions.program.free(self.base, 'ci')
for reg_type,addr in self.bases.items():
get_program().free(addr, reg_type.reg_type)
if self.base is not None:
instructions.program.free(self.base, 'ci')
class Function:
def __init__(self, function, name=None, compile_args=[]):
self.type_args = {}
self.last_key = None
self.function = function
self.name = name
if name is None:
@@ -348,46 +347,40 @@ class Function:
self.compile_args = compile_args
def __call__(self, *args):
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
from .types import _types
get_reg_type = lambda x: \
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
key = len(args), get_tape()
if key not in self.type_args:
runtime_args = []
reg_args = []
key = self.base_key(),
for i,arg in enumerate(args):
if isinstance(arg, types._vectorizable):
key += (arg.shape, arg.value_type)
else:
arg = MemValue(arg)
reg_args.append(arg)
t = arg.value_type
key += (arg.size, t)
runtime_args.append(arg)
if key != self.last_key:
# first call
type_args = collections.defaultdict(list)
for i,arg in enumerate(args):
if not isinstance(arg, types._vectorizable):
type_args[get_reg_type(arg)].append(i)
outer_runtime_args = runtime_args
def wrapped_function(*compile_args):
base = get_arg()
bases = dict((t, regint.load_mem(base + i)) \
for i,t in enumerate(sorted(type_args,
key=lambda x:
x.reg_type)))
runtime_args = list(args)
for t in sorted(type_args, key=lambda x: x.reg_type):
i = 0
for i_arg in type_args[t]:
runtime_args[i_arg] = t.load_mem(bases[t] + i)
i += util.mem_size(t)
return self.function(*(list(compile_args) + runtime_args))
addresses = regint.Array(len(outer_runtime_args),
address=get_arg())
runtime_args = []
for i, arg in enumerate(outer_runtime_args):
if isinstance(arg, MemValue):
arg = arg.value_type.load_mem(
address=addresses[i], size=arg.size)
runtime_args.append(arg)
self.result = self.function(
*(list(compile_args) + runtime_args))
return self.result
self.on_first_call(wrapped_function)
self.type_args[key] = type_args
type_args = self.type_args[key]
base = instructions.program.malloc(len(type_args), 'ci')
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
for t in type_args)
for i,reg_type in enumerate(sorted(type_args,
key=lambda x: x.reg_type)):
store_in_mem(bases[reg_type], base + i)
j = 0
for i_arg in type_args[reg_type]:
if get_reg_type(args[i_arg]) != reg_type:
raise CompilerError('type mismatch: "%s" not of type "%s"' %
(args[i_arg], reg_type))
store_in_mem(args[i_arg], bases[reg_type] + j)
j += util.mem_size(reg_type)
return self.on_call(base, bases)
self.last_key = key
addresses = regint.Array(len(runtime_args))
for i, arg in enumerate(reg_args):
addresses[i] = arg.address
return self.on_call(addresses._address,
[(arg.value_type, arg.address) for arg in reg_args])
class FunctionTape(Function):
# not thread-safe
@@ -401,6 +394,113 @@ class FunctionTape(Function):
single_thread=self.single_thread)
def on_call(self, base, bases):
return FunctionTapeCall(self.thread, base, bases)
@staticmethod
def base_key():
pass
class FunctionCallTape(FunctionTape):
def __init__(self, *args, **kwargs):
super(FunctionTape, self).__init__(*args, **kwargs)
self.instances = {}
def __call__(self, *args, **kwargs):
key = ()
def process_for_key(arg):
nonlocal key
if isinstance(arg, types._vectorizable):
key += (arg.value_type, tuple(arg.shape))
elif isinstance(arg, Tape.Register):
key += (type(arg), arg.size)
elif isinstance(arg, list):
key += (tuple(arg), 'l')
else:
key += (arg,)
for arg in args:
process_for_key(arg)
for name, arg in sorted(kwargs.items()):
key += (name, 'kw')
process_for_key(arg)
if key not in self.instances:
my_args = []
def wrapped_function():
actual_call_args = []
def process_for_call(arg):
if isinstance(arg, Tape.Register):
my_arg = arg.same_type()
call_arg(my_arg, base.vm_types[my_arg.reg_type])
my_args.append(my_arg)
return my_arg
elif isinstance(arg, types._vectorizable):
my_arg = arg.same_shape(address=regint())
call_arg(my_arg.address, base.vm_types['ci'])
my_args.append(my_arg)
my_arg = arg.same_shape(
address=MemValue(my_arg.address))
return my_arg
actual_call_args.append(my_arg)
else:
my_args.append(arg)
return arg
for arg in args:
actual_call_args.append(process_for_call(arg))
actual_call_kwargs = {}
for name, arg in sorted(kwargs.items()):
actual_call_kwargs[name] = process_for_call(arg)
self.result = self.function(*actual_call_args,
**actual_call_kwargs)
if self.result is not None:
self.result = list(tuplify(self.result))
for i, res in enumerate(self.result):
if util.is_constant(res):
self.result[i] = regint(res)
self.on_first_call(wrapped_function, key, my_args)
for name, arg in sorted(kwargs.items()):
args += arg,
return self.on_call(*self.instances[key], args)
def on_first_call(self, wrapped_function, key, inside_args):
program = get_program()
program.curr_tape
tape_handle = len(program.tapes)
# entry for recursion
self.instances[key] = tape_handle, None, inside_args
assert tape_handle == program.new_tape(
wrapped_function, name=self.name, args=self.compile_args,
single_thread=get_tape().singular, finalize=False,
thread_pool=get_tape().free_threads)
tape = program.tapes[tape_handle]
if self.result is not None:
self.result = list(tuplify(self.result))
for reg in self.result:
reg.can_eliminate = False
tape.return_values.append(reg)
assert not tape.purged
get_program().finalize_tape(tape)
self.instances[key] = tape_handle, self.result, inside_args
def on_call(self, tape_handle, result, inside_args, args):
tape = get_program().tapes[tape_handle]
if tape.ran_threads and tape.free_threads != get_tape().free_threads:
raise CompilerError(
'cannot call thread-running tape from another thread')
assert len(inside_args) == len(args)
out_result = []
call_args = []
if result is not None:
out_result = [reg.same_type() for reg in result]
for x, y in zip(out_result, result):
call_args += [
1, instructions_base.vm_types[x.reg_type],
x.size_for_mem(), x, y]
for x, y in zip(inside_args, args):
if isinstance(x, Tape.Register):
call_args += [
0, instructions_base.vm_types[x.reg_type],
x.size_for_mem(), x, y]
elif isinstance(x, types._vectorizable):
call_args += [0, base.vm_types['ci'], 1,
x.address, regint.conv(y.address)]
call_tape(tape_handle, regint(0),
*call_args)
break_point('call-%s' % self.name)
return untuplify(tuple(out_result))
def function_tape(function):
return FunctionTape(function)
@@ -413,18 +513,108 @@ def function_tape_with_compile_args(*args):
def single_thread_function_tape(function):
return FunctionTape(function, single_thread=True)
def memorize(x):
if isinstance(x, (tuple, list)):
return tuple(memorize(i) for i in x)
def function_call_tape(function):
if get_program().use_tape_calls:
return FunctionCallTape(function)
else:
return MemValue(x)
return function
def method_call_tape(function):
tapes = {}
def wrapper(self, *args, **kwargs):
def use(name):
x = self.__dict__[name]
return not isinstance(x, types.MultiArray) or \
x.array._address is not None
key = (type(self),) + tuple(filter(use, sorted(self.__dict__)))
member_key = key[1:]
if key not in tapes:
def f(*args, **kwargs):
class Dummy(type(self)):
__init__ = lambda self: None
dummy = Dummy()
members = args[:len(member_key)]
real_args = args[len(member_key):]
addresses = {}
for name, member in zip(member_key, members):
dummy.__dict__[name] = member
if isinstance(member, types._vectorizable):
addresses[name] = member.address
res = function(dummy, *real_args, **kwargs)
for name, member in zip(member_key, members):
new_member = dummy.__dict__[name]
desc = '%s in %s.%s' % (name, type(self).__name__,
function.__name__)
if id(new_member) != id(member):
raise CompilerError('cannot change members '
'in method tape (%s)' % desc)
if isinstance(member, types._vectorizable) and \
id(new_member.address) != id(addresses[name]):
raise CompilerError('cannot change memory address '
'in method tape (%s)' % desc)
if set(member_key) != set(dummy.__dict__):
raise CompilerError('cannot add members '
'in method tape (%s)' % desc)
return res
f.__name__ = '%s-%s' % (type(self).__name__, function.__name__)
tapes[key] = function_call_tape(f)
members = tuple(self.__dict__[x] for x in member_key)
res = tapes[key](*(members + args), **kwargs)
return res
return wrapper
def function(function):
""" Create a run-time function. The arguments can be memory or basic
types, and return values can be basic types::
@function
def f(x, y, z):
y.write(1)
z[0] = 2
return x + 3
a = MemValue(sint(0))
b = sint.Array(10)
c = f(sint(4), a, b)
print_ln('%s %s %s', a.reveal(), b[0].reveal(), c.reveal())
This should output::
1 2 7
You can use run-time functions recursively but without return
values in this case.
"""
return FunctionCallTape(function)
def memorize(x, write=True):
if isinstance(x, (tuple, list)):
return tuple(memorize(i, write=write) for i in x)
elif x is None:
return
else:
return MemValue(x, write=write)
def unmemorize(x):
if isinstance(x, (tuple, list)):
return tuple(unmemorize(i) for i in x)
elif x is None:
return
else:
return x.read()
def write_mem(dest, source):
if isinstance(dest, (tuple, list)):
assert len(dest) == len(source)
for x, y in zip(dest, source):
write_mem(x, y)
elif dest is None:
return
else:
dest.write(source)
class FunctionBlock(Function):
def on_first_call(self, wrapped_function):
p_return_address = get_tape().program.malloc(1, 'ci')
@@ -470,6 +660,10 @@ class FunctionBlock(Function):
if self.result is not None:
return unmemorize(self.result)
@staticmethod
def base_key():
return get_tape()
def function_block(function):
return FunctionBlock(function)
@@ -907,11 +1101,15 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
r = reducer(tuplify(loop_body(j)), mem_state)
write_state_to_memory(r)
state = mem_state
for i,x in enumerate(state):
if use_array:
mem_state[i] = x
else:
mem_state[i].write(x)
if use_array and len(state) and \
isinstance(types._register, types._vectorizable):
mem_state[:] = state.get_vector()
else:
for i,x in enumerate(state):
if use_array:
mem_state[i] = x
else:
mem_state[i].write(x)
def returner():
return untuplify(tuple(state))
return returner
@@ -987,7 +1185,7 @@ def multithread(n_threads, n_items=None, max_size=None):
.. code::
@multithread(8, 25)
@multithread(3, 25)
def f(base, size):
...
"""
@@ -1077,7 +1275,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
return loop_body(base + i)
prog = get_program()
thread_args = []
if prog.curr_tape == prog.tapes[0]:
if prog.curr_tape.singular:
prog.n_running_threads = n_threads
if not util.is_zero(thread_rounds):
prog.prevent_breaks = False
@@ -1288,9 +1486,21 @@ def while_do(condition, *args):
return loop_body
return decorator
def _run_and_link(function, g=None):
def _run_and_link(function, g=None, lock_lists=True):
if g is None:
g = function.__globals__
if lock_lists:
class A(list):
def __init_(self, l):
self[:] = l
def __setitem__(*args):
raise Exception('you cannot change lists in branches, '
'use Array or MultiArray instead')
__delitem__ = append = clear = extend = insert = __setitem__
pop = remove = reverse = sort = __setitem__
for x in g:
if isinstance(g[x], list):
g[x] = A(g[x])
pre = copy.copy(g)
res = function()
_link(pre, g)
@@ -1570,14 +1780,25 @@ def listen_for_clients(port):
"""
instructions.listen(regint.conv(port))
def accept_client_connection(port):
def accept_client_connection(port, players=None):
""" Accept client connection on specific port base.
:param port: port base (int/regint/cint)
:param players: subset of players (default: all)
:returns: client id
"""
res = regint()
instructions.acceptclientconnection(res, regint.conv(port))
if players is None:
instructions.acceptclientconnection(res, regint.conv(port))
else:
@if_e(sum(regint(players) ==
get_player_id()._v.expand_to_vector(len(players))))
def _():
res.update(accept_client_connection(port))
@else_
def _():
res.update(-1)
return res
def init_client_connection(host, port, my_id, relative_port=True):
@@ -1697,14 +1918,14 @@ def cint_cint_division(a, b, k, f):
from Compiler.program import Program
@instructions_base.ret_cisc
def sint_cint_division(a, b, k, f, kappa, nearest=False):
def sint_cint_division(a, b, k, f, nearest=False):
"""
type(a) = sint, type(b) = cint
"""
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k)
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)
@@ -1714,20 +1935,20 @@ def sint_cint_division(a, b, k, f, kappa, nearest=False):
W = w0
for i in range(1, theta):
A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True)
A = (A * W).round(2 * k, f, nearest=nearest, signed=True)
temp = (B * W + 2 * (f - 1)) >> f
W = two - temp
B = temp
return (sign_a * sign_b) * A
def IntDiv(a, b, k, kappa=None):
def IntDiv(a, b, k):
l = 2 * k + 1
b = a.conv(b)
return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k,
kappa, nearest=True)
nearest=True)
@instructions_base.ret_cisc
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
def FPDiv(a, b, k, f, simplex_flag=False, nearest=False):
"""
Goldschmidt method as presented in Catrina10,
"""
@@ -1750,40 +1971,40 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
w = AppRcr(b, k, f, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()
y = a.extend(l_y) * w
y = y.round(l_y, f, kappa, nearest, signed=True)
y = y.round(l_y, f, nearest, signed=True)
for i in range(theta - 1):
x = x.extend(2 * k)
y = y.extend(l_y) * (alpha + x).extend(l_y)
x = x * x
y = y.round(l_y, 2*f, kappa, nearest, signed=True)
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
y = y.round(l_y, 2*f, nearest, signed=True)
x = x.round(2*k, 2*f, nearest, signed=True)
x = x.extend(2 * k)
y = y.extend(l_y) * (alpha + x).extend(l_y)
y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True)
y = y.round(l_y, 3 * f - res_f, nearest, signed=True)
return y
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
def AppRcr(b, k, f, simplex_flag=False, nearest=False):
"""
Approximate reciprocal of [b]:
Given [b], compute [1/b]
"""
alpha = b.get_type(2 * k)(int(2.9142 * 2**k))
c, v = b.Norm(k, f, kappa, simplex_flag)
c, v = b.Norm(k, f, simplex_flag)
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
d = alpha - 2 * c
w = d * v
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
w = w.round(2 * k + 1, 2 * (k - f), nearest, signed=True)
# now w * 2 ^ {-f} should be an initial approximation of 1/b
return w
def Norm(b, k, f, kappa, simplex_flag=False):
def Norm(b, k, f, simplex_flag=False):
"""
Computes secret integer values [c] and [v_prime] st.
2^{k-1} <= c < 2^k and c = b*v_prime
@@ -1799,8 +2020,8 @@ def Norm(b, k, f, kappa, simplex_flag=False):
absolute_val = sign * b
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
suffixes = PreOR(bits, kappa)[::-1]
bits = absolute_val.bit_decompose(k, maybe_mixed=True)[::-1]
suffixes = PreOR(bits)[::-1]
z = [0] * k
for i in range(k - 1):

View File

@@ -203,6 +203,20 @@ def _no_mem_warnings(function):
copy_doc(wrapper, function)
return wrapper
def _layer_method_call_tape(function):
function = method_call_tape(function)
def wrapper(self, *args, **kwargs):
self._Y.alloc()
if self.inputs and len(self.inputs) == 1:
backup = self.inputs
del self.inputs
res = function(self, *args, **kwargs)
self.inputs = backup
return res
else:
return function(self, *args, **kwargs)
return wrapper
class Tensor(MultiArray):
def __init__(self, *args, **kwargs):
kwargs['alloc'] = False
@@ -259,6 +273,7 @@ class Layer:
def Y(self, value):
self._Y = value
@_layer_method_call_tape
def forward(self, batch=None, training=None):
if batch is None:
batch = Array.create_from(regint(0))
@@ -1045,7 +1060,8 @@ class ElementWiseLayer(NoVariableLayer):
def _forward(self, batch=[0]):
n_per_item = reduce(operator.mul, self.X.sizes[1:])
@multithread(self.n_threads, len(batch) * n_per_item)
@multithread(self.n_threads, len(batch) * n_per_item,
max_size=program.budget)
def _(base, size):
self.Y.assign_vector(self.f_part(base, size), base)
@@ -1195,6 +1211,7 @@ class MaxPool(PoolBase):
list/tuple of integers
"""
@_layer_method_call_tape
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
@@ -1252,26 +1269,38 @@ class Concat(NoVariableLayer):
self.dimension = dimension
shapes = [inp.shape for inp in inputs]
assert dimension == 3
assert len(shapes) == 2
assert len(shapes[0]) == len(shapes[1])
assert len(shapes[0]) == 4
for shape in shapes:
assert len(shape) == len(shapes[0])
shape = []
for i in range(len(shapes[0])):
if i == dimension:
shape.append(shapes[0][i] + shapes[1][i])
shape.append(sum(x[i] for x in shapes))
else:
assert shapes[0][i] == shapes[1][i]
shape.append(shapes[0][i])
self.Y = Tensor(shape, sfix)
self.bases = [sum(x[dimension] for x in shapes[:k])
for k in range(len(shapes))]
self.addresses = Array.create_from(regint(list(
x.Y.address for x in inputs)))
def _forward(self, batch=[0]):
assert len(batch) == 1
@for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3])
def _(i, j):
X = [x.Y[batch[0]] for x in self.inputs]
self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector())
self.Y[batch[0]][i][j].assign_part_vector(
X[1][i][j].get_vector(),
len(X[0][i][j]))
if len(set(self.bases)) == 1:
@for_range(len(self.inputs))
def _(k):
self.Y[batch[0]][i][j].assign_part_vector(
MultiArray(
self.inputs[0].shape,
address=self.addresses[k])[i][j].get_vector(),
k * self.bases[1])
else:
X = [x.Y[batch[0]] for x in self.inputs]
for k in range(len(self.inputs)):
self.Y[batch[0]][i][j].assign_part_vector(
X[k][i][j].get_vector(), self.bases[k])
class Add(NoVariableLayer):
""" Fixed-point addition layer.
@@ -1334,12 +1363,14 @@ class BatchNorm(Layer):
def __init__(self, shape, approx=True, args=None):
assert len(shape) in (2, 3, 4)
self.Y = sfix.Tensor(shape)
if len(shape) == 4:
shape = [shape[0], shape[1] * shape[2], shape[3]]
elif len(shape) == 2:
shape = [shape[0], 1, shape[1]]
tensors = (Tensor(shape, sfix) for i in range(4))
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
self.my_Y = sfix.Tensor(shape, address=self.Y.address)
tensors = (Tensor(shape, sfix) for i in range(3))
self.X, self.nabla_X, self.nabla_Y = tensors
arrays = (sfix.Array(shape[2]) for i in range(4))
self.var, self.mu, self.weights, self.bias = arrays
arrays = (sfix.Array(shape[2]) for i in range(4))
@@ -1374,11 +1405,11 @@ class BatchNorm(Layer):
[len(batch), self.X.sizes[1]])
def _(i, j):
tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:]
self.Y[i][j][:] = self.bias[:] + tmp
self.my_Y[i][j][:] = self.bias[:] + tmp
@_layer_method_call_tape
def forward(self, batch, training=False):
if training or not self.is_trained:
self.is_trained = True
d = self.X.sizes[1]
d_in = self.X.sizes[2]
s = sfix.Array(d_in)
@@ -1411,7 +1442,7 @@ class BatchNorm(Layer):
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
print_ln('%s at (%s, %s, %s): in=%s out=%s',
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
self.Y[i][j][k].reveal())
self.my_Y[i][j][k].reveal())
else:
self._output(batch, self.mu_hat, self.var_hat)
@@ -1471,6 +1502,10 @@ class BatchNorm(Layer):
self.nabla_Y[i][j][k].reveal(),
self.nabla_X[i][j][k].reveal())
def reveal_parameters_to_binary(self):
for param in self.thetas() + (self.mu_hat, self.var_hat):
param.reveal().binary_output()
class QuantBase(object):
bias_before_reduction = True
@@ -1498,11 +1533,12 @@ class QuantBase(object):
class FixBase:
bias_before_reduction = False
@staticmethod
def new_squant():
class _(sfix):
params = None
return _
class my_squant(sfix):
params = None
@classmethod
def new_squant(cls):
return cls.my_squant
def input_params_from(self, player):
pass
@@ -1553,7 +1589,8 @@ class ConvBase(BaseLayer):
cls.temp_inputs = sfix.Array(size)
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride,
padding='SAME', tf_weight_format=False, inputs=None):
padding='SAME', tf_weight_format=False, inputs=None,
weight_type=None):
super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs)
self.weight_shape = weight_shape
@@ -1582,7 +1619,11 @@ class ConvBase(BaseLayer):
else:
self.padding = padding
self.weight_squant = self.new_squant()
if weight_type:
self.weight_squant = weight_type
else:
self.weight_squant = self.new_squant()
self.bias_squant = self.new_squant()
self.weights = Tensor(weight_shape, self.weight_squant)
@@ -1645,8 +1686,7 @@ class ConvBase(BaseLayer):
n_summands = self.n_summands()
#start_timer(2)
n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:])
@multithread(self.n_threads, n_outputs,
1000 if sfix.round_nearest else 10 ** 6)
@multithread(self.n_threads, n_outputs, max_size=program.budget)
def _(base, n_per_thread):
res = self.input_squant().unreduced(
sint.load_mem(unreduced.address + base,
@@ -1709,7 +1749,7 @@ class Conv2d(ConvBase):
res += self.bias.expand_to_vector(j, res.size).v
else:
res += self.bias.expand_to_vector(j, res.size).v << \
self.input_squant.f
self.weight_squant.f
addresses = regint.inc(res.size,
self.unreduced[i * part_size].address + j,
n_channels_out)
@@ -2029,7 +2069,7 @@ class AveragePool2d(BaseLayer):
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
padding=0):
padding=0, **kwargs):
""" More convenient interface to :py:class:`FixConv2d`.
:param input_shape: input shape (tuple/list of four int)
@@ -2050,7 +2090,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
padding = padding.upper() if isinstance(padding, str) \
else padding
return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape,
stride, padding)
stride, padding, **kwargs)
def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
""" More convenient interface to :py:class:`MaxPool`.
@@ -2102,6 +2142,7 @@ class FixAveragePool2d(PoolBase, FixBase):
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
[1] + list(filter_size) + [1], padding)
self.pool_size = reduce(operator.mul, filter_size)
self.d_out = self.Y.shape[-1]
if output_shape:
assert self.Y.shape == list(output_shape)
@@ -2322,6 +2363,10 @@ class Optimizer:
self.forward(batch=batch, run_last=False)
part = self.layers[-1].eval(batch_size, top=top)
res.assign_part_vector(part.get_vector(), start)
if self.output_stats:
for layer in self.layers[:-1]:
print_ln(layer)
self.stat(' Y', layer.Y)
self.run_in_batches(f, data, batch_size or len(self.layers[1].X))
return res
@@ -2567,9 +2612,9 @@ class Optimizer:
def run_in_batches(self, f, data, batch_size, truth=None):
batch_size = min(batch_size, data.sizes[0])
training_data = self.layers[0].X.address
training_data = self.layers[0]._X.array._address
training_truth = self.layers[-1].Y.address
self.layers[0].X.address = data.address
self.layers[0]._X.address = data.address
if truth:
self.layers[-1].Y.address = truth.address
N = data.sizes[0]
@@ -2748,13 +2793,15 @@ class Optimizer:
return list(self.thetas)
def reveal_model_to_binary(self):
input_shape = self.layers[0].X.shape
""" Reveal model and store it in the binary output file, see
:ref:`reveal-model` for details. """
input_shape = self.layers[0]._X.shape
for layer in self.layers:
if len(input_shape) == 4 and isinstance(layer, DenseBase):
layer.reveal_parameters_to_binary(reshape=input_shape[1:])
else:
layer.reveal_parameters_to_binary()
input_shape = layer.Y.shape
input_shape = layer._Y.shape
class Adam(Optimizer):
""" Adam/AMSgrad optimizer.
@@ -3046,7 +3093,7 @@ class keras:
def summary(self):
self.opt.summary()
def build(self, input_shape, batch_size=128):
def build(self, input_shape, batch_size=128, program=None):
data_input_shape = input_shape
if self.opt != None and \
input_shape == self.opt.layers[0]._X.sizes and \
@@ -3120,8 +3167,11 @@ class keras:
if layers[-1].d_out == 1:
layers.append(Output(data_input_shape[0]))
else:
layers.append(
MultiOutput(data_input_shape[0], layers[-1].d_out))
shape = data_input_shape[0], layers[-1].d_out
if program:
layers.append(MultiOutput.from_args(program, *shape))
else:
layers.append(MultiOutput(*shape))
if self.optimizer[1]:
raise Exception('use keyword arguments for optimizer')
opt = self.optimizer[0]
@@ -3186,11 +3236,11 @@ class keras:
batch_size = min(batch_size, self.batch_size)
return self.opt.eval(x, batch_size=batch_size)
def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
regression=False):
""" Convert a PyTorch Sequential object to MP-SPDZ layers.
def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
regression=False, layer_args={}, program=None):
""" Convert a PyTorch Module object to MP-SPDZ layers.
:param sequence: PyTorch Sequential object
:param model: PyTorch Module object
:param data_input_shape: input shape (list of four int)
:param batch_size: batch size (int)
:param input_via: player to input model data via (default: don't)
@@ -3198,17 +3248,29 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
"""
layers = []
named_layers = {}
def mul(x):
return reduce(operator.mul, x)
def process(item):
nonlocal input_shape
import torch
def process(item, inputs, input_shape, args):
if item == torch.cat:
if len(inputs) > 1:
layers.append(
Concat(inputs, dimension=len(inputs[0].shape) - 1))
return
elif item == operator.add:
layers.append(Add(inputs))
return
elif item == torch.flatten:
return
# single-input layers from here
if inputs and len(inputs) > 1:
raise CompilerError('multi-input layer %s not supported' % item)
name = type(item).__name__
if name == 'Sequential':
for x in item:
process(x)
elif name == 'Linear':
if name == 'Linear':
assert mul(input_shape[1:]) == item.in_features
assert item.bias is not None
layers.append(Dense(input_shape[0], item.in_features,
@@ -3239,7 +3301,7 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
elif name == 'Conv2d':
layers.append(easyConv2d(input_shape, batch_size, item.out_channels,
item.kernel_size, item.stride,
item.padding))
item.padding, **layer_args.get(item, {})))
input_shape = layers[-1].Y.shape
if input_via is not None:
shapes = [x.shape for x in
@@ -3247,11 +3309,14 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
import numpy
swapped = numpy.moveaxis(
numpy.array(item.weight.detach()), 1, -1)
layers[-1].weights = sfix.input_tensor_via(input_via, swapped)
layers[-1].bias = sfix.input_tensor_via(
input_via, item.bias.detach())
layers[-1].weights = \
layers[-1].weights.value_type.input_tensor_via(
input_via, swapped)
assert layers[-1].weights.shape == shapes[0]
assert layers[-1].bias.shape == shapes[1]
if isinstance(item.bias, torch.Tensor):
layers[-1].bias = sfix.input_tensor_via(
input_via, item.bias.detach())
assert layers[-1].bias.shape == shapes[1]
elif name == 'MaxPool2d':
layers.append(easyMaxPool(input_shape, item.kernel_size,
item.stride, item.padding))
@@ -3260,10 +3325,24 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
item.stride, item.padding))
input_shape = layers[-1].Y.shape
elif name == 'ReLU':
elif name == 'AdaptiveAvgPool2d' or \
item == torch.nn.functional.adaptive_avg_pool2d:
if name == 'AdaptiveAvgPool2d':
output = item.output_size
else:
output = args[1]
for i in (0, 1):
assert input_shape[1 + i] % output[i] == 0
stride = [input_shape[1 + i] // output[i] for i in (0, 1)]
kernel_size = [input_shape[1 + i] - (output[i] - 1) * stride[i]
for i in (0, 1)]
layers.append(FixAveragePool2d(input_shape, None, kernel_size,
stride, padding=0))
input_shape = layers[-1].Y.shape
elif name == 'ReLU' or item == torch.nn.functional.relu:
layers.append(Relu(input_shape))
elif name == 'Flatten':
pass
return
elif name == 'BatchNorm2d':
layers.append(BatchNorm(layers[-1].Y.sizes))
if input_via is not None:
@@ -3277,16 +3356,52 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
alpha=item.p))
input_shape = layers[-1].Y.sizes
else:
raise CompilerError('unknown PyTorch module: ' + name)
raise CompilerError('unknown PyTorch module: %s' % item)
layers[-1].inputs = inputs
input_shape = data_input_shape + [1] * (4 - len(data_input_shape))
process(sequence)
torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes)
for i, layer in enumerate(torch_layers[1:-1]):
if layer.op == 'call_module':
target = model
for attr in layer.target.split('.'):
target = getattr(target, attr)
else:
target = layer.target
if not layers:
assert layer.args == (torch_layers[i],)
inputs = None
else:
if len(layer.args) < 2 or (layer.args[1] != 1 and
layer.args[1] != (1, 1)):
args = layer.args
elif isinstance(layer.args[0], list):
args = layer.args[0]
else:
args = layer.args[0],
inputs = [named_layers[x] for x in args]
if len(inputs) == 1:
if isinstance(inputs[0], (Dropout, BatchNorm)):
input_shape = inputs[0].inputs[0].Y.shape
else:
input_shape = inputs[0]._Y.shape
else:
input_shape = None
process(target, inputs, input_shape, layer.args)
if layers:
named_layers[layer] = layers[-1]
if regression:
layers.append(LinearOutput(data_input_shape[0], layers[-1].d_out))
elif layers[-1].d_out == 1:
layers.append(Output(data_input_shape[0]))
else:
layers.append(MultiOutput(data_input_shape[0], layers[-1].d_out))
shape = data_input_shape[0], layers[-1].d_out
if program:
layers.append(MultiOutput.from_args(program, *shape))
else:
layers.append(MultiOutput(*shape))
return layers
class OneLayerSGD:
@@ -3335,6 +3450,9 @@ class OneLayerSGD:
class SGDLogistic(OneLayerSGD):
""" Logistic regression using SGD.
The member :py:obj:`opt` refers to the internal instance of
:py:class:`Optimizer`, which allows to use the funcionality
therein.
:param n_epochs: number of epochs
:param batch_size: batch size
@@ -3460,6 +3578,8 @@ def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False,
def mr(A, n_iterations, stop=False):
""" Iterative matrix inverse approximation.
This is based on the conjugate gradients algorithm in Section
10.2.4 of `these lecture notes <https://graphics.stanford.edu/courses/cs205a-13-fall/assets/notes/cs205a_notes.pdf>`_.
:param A: matrix to invert
:param n_iterations: maximum number of iterations

View File

@@ -1,7 +1,9 @@
"""
Module for math operations.
Implements trigonometric and logarithmic functions.
Most of the functionality is due to `Aly and Smart
<https://eprint.iacr.org/2019/354>`_ with some optimizations by
`Keller and Sun <https://eprint.iacr.org/2022/933>`_.
This has to imported explicitly.
"""
@@ -98,7 +100,7 @@ pi_over_2 = math.radians(90)
# @return truncated sint value of x
def trunc(x):
if isinstance(x, types._fix):
return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True)
return x.v.right_shift(x.f, x.k, signed=True)
elif type(x) is types.sfloat:
v, p, z, s = floatingpoint.FLRound(x, 0)
#return types.sfloat(v, p, z, s, x.err)
@@ -106,19 +108,6 @@ def trunc(x):
return x
##
# loads integer to fractional type (sint)
# @param x: coefficient to be truncated.
#
# @return returns sfix, sfloat loaded value
def load_sint(x, l_type):
if l_type is types.sfix:
return types.sfix.from_sint(x)
elif l_type is types.sfloat:
return x
return x
##
# evaluates a Polynomial to a given x in a privacy preserving manner.
# Inputs can be of any kind of register, secret or otherwise.
@@ -448,7 +437,7 @@ def log2_fx(x, use_division=True):
if isinstance(x, types._fix):
# transforms sfix to f*2^n, where f is [o.5,1] bounded
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f)
p -= x.f
vlen = x.f
v = x._new(v, k=x.k, f=x.f)
@@ -473,7 +462,7 @@ def log2_fx(x, use_division=True):
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
def pow_fx(x, y):
def pow_fx(x, y, zero_output=False):
"""
Returns the value of the expression :math:`x^y` where both inputs
are secret shared. It uses :py:func:`log2_fx` together with
@@ -494,7 +483,7 @@ def pow_fx(x, y):
# obtains y * log2(x)
exp = y * log2_x
# returns 2^(y*log2(x))
return exp2_fx(exp)
return exp2_fx(exp, zero_output)
def log_fx(x, b):
@@ -535,8 +524,8 @@ def abs_fx(x):
#
# @return floored sint value of x
def floor_fx(x):
return load_sint(x.v.right_shift(x.f, bit_length=x.k, security=x.kappa,
signed=True), type(x))
return type(x)(x.v.right_shift(x.f, bit_length=x.k, signed=True),
k=x.k, f=x.f)
### sqrt methods
@@ -743,13 +732,13 @@ def lin_app_SQ(b, k, f):
c, v, m, W = norm_SQ(types.sint(b), k)
# c is now escalated
w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k
w = alpha * c + beta # equation before b and reduction by order of k
# m even or odd determination
m_bit = types.sint()
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False)
m = load_sint(m_bit, types.sfix)
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), signed=False)
m = m_bit
# w times v this way both terms have 2^3k and can be symplified
w = w * v
@@ -774,7 +763,7 @@ def lin_app_SQ(b, k, f):
def sqrt_fx(x_l, k, f):
factor = 1.0 / (2.0 ** f)
x = load_sint(x_l, types.sfix) * factor
x = x_l * factor
theta = int(math.ceil(math.log(k/5.4)))
@@ -912,29 +901,29 @@ def tanh(x):
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
def Sep(x):
def Sep(x, sfix=types.sfix):
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
bb = b[:]
while len(bb) < 2 * x.f - 1:
bb.insert(0, type(b[0])(0))
t = x.v * (1 + x.v.bit_compose(b_i.bit_not()
for b_i in bb[-2 * x.f + 1:]))
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
u = sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
b += [b[0].long_one()]
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
def SqrtComp(z, old=False):
f = types.sfix.f
def SqrtComp(z, old=False, sfix=types.sfix):
f = sfix.f
k = len(z)
if isinstance(z[0], types.sint):
return types.sfix._new(sum(z[i] * types.cfix(
return sfix._new(sum(z[i] * types.cfix(
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
k_prime = k // 2
f_prime = f // 2
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
c0 = types.sfix(2 ** (f / 2 + 1))
c1 = sfix(2 ** ((f + 1) / 2 + 1))
c0 = sfix(2 ** (f / 2 + 1))
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
tmp = sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
if old:
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
else:
@@ -942,11 +931,15 @@ def SqrtComp(z, old=False):
return types.sint.conv(b).if_else(c1, c0) * tmp
@types.vectorize
@instructions_base.sfix_cisc
def InvertSqrt(x, old=False):
"""
Reciprocal square root approximation by `Lu et al.
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
"""
u, z = Sep(x)
class my_sfix(types.sfix):
f = x.f
k = x.k
u, z = Sep(x, sfix=my_sfix)
c = 3.14736 + u * (4.63887 * u - 5.77789)
return c * SqrtComp(z, old=old)
return c * SqrtComp(z, old=old, sfix=my_sfix)

View File

@@ -4,14 +4,6 @@ from .types import *
from . import comparison, program
class NonLinear:
kappa = None
def set_security(self, kappa):
pass
def check_security(self, kappa):
pass
def mod2m(self, a, k, m, signed):
"""
a_prime = a % 2^m
@@ -45,18 +37,16 @@ class NonLinear:
def trunc_round_nearest(self, a, k, m, signed):
res = sint()
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, self.kappa,
signed)
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, signed)
return res
def trunc(self, a, k, m, kappa, signed):
self.check_security(kappa)
def trunc(self, a, k, m, signed):
if m == 0:
return a
return self._trunc(a, k, m, signed)
def ltz(self, a, k, kappa=None):
return -self.trunc(a, k, k - 1, kappa, True)
def ltz(self, a, k):
return -self.trunc(a, k, k - 1, True)
class Masking(NonLinear):
def eqz(self, a, k):
@@ -68,28 +58,19 @@ class Masking(NonLinear):
class Prime(Masking):
""" Non-linear functionality modulo a prime with statistical masking. """
def __init__(self, kappa):
self.set_security(kappa)
def set_security(self, kappa):
self.kappa = kappa
def check_security(self, kappa):
assert self.kappa == kappa or kappa is None
def _mod2m(self, a, k, m, signed):
res = sint()
if m == 1:
Mod2(res, a, k, self.kappa, signed)
Mod2(res, a, k, signed)
else:
Mod2mField(res, a, k, m, self.kappa, signed)
Mod2mField(res, a, k, m, signed)
return res
def _mask(self, a, k):
return maskField(a, k, self.kappa)
return maskField(a, k)
def _trunc_pr(self, a, k, m, signed=None):
return TruncPrField(a, k, m, self.kappa)
return TruncPrField(a, k, m)
def _trunc(self, a, k, m, signed=None):
a_prime = self.mod2m(a, k, m, signed)
@@ -99,12 +80,12 @@ class Prime(Masking):
def bit_dec(self, a, k, m, maybe_mixed=False):
if maybe_mixed:
return BitDecFieldRaw(a, k, m, self.kappa)
return BitDecFieldRaw(a, k, m)
else:
return BitDecField(a, k, m, self.kappa)
return BitDecField(a, k, m)
def kor(self, d):
return KOR(d, self.kappa)
return KOR(d)
class KnownPrime(NonLinear):
""" Non-linear functionality modulo a prime known at compile time. """
@@ -144,13 +125,13 @@ class KnownPrime(NonLinear):
a += two_power(k)
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
def ltz(self, a, k, kappa=None):
def ltz(self, a, k):
if k + 1 < self.prime.bit_length():
# https://dl.acm.org/doi/10.1145/3474123.3486757
# "negative" values wrap around when doubling, thus becoming odd
return self.mod2m(2 * a, k + 1, 1, False)
else:
return super(KnownPrime, self).ltz(a, k, kappa)
return super(KnownPrime, self).ltz(a, k)
class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time.
@@ -189,5 +170,5 @@ class Ring(Masking):
else:
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
def ltz(self, a, k, kappa=None):
def ltz(self, a, k):
return LtzRing(a, k)

View File

@@ -877,7 +877,7 @@ class LinearORAM(TrivialORAM):
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
@map_sum(get_n_threads(self.size), None, self.size, \
self.value_length + 1, t)
def f(i):
entry = self.ram[i]
@@ -897,7 +897,7 @@ class LinearORAM(TrivialORAM):
new_value = make_array(
new_value, self.value_type.get_type(
max(x or 0 for x in self.entry_size)))
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
@for_range_multithread(get_n_threads(self.size), None, self.size)
def f(i):
entry = self.ram[i]
access_here = self.index_vector[i]
@@ -917,7 +917,7 @@ class LinearORAM(TrivialORAM):
max(x or 0 for x in self.entry_size)))
new_empty = MemValue(new_empty)
write = MemValue(write)
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
@map_sum(get_n_threads(self.size), None, self.size, \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type] * self.value_length)
def f(i):
@@ -1340,8 +1340,8 @@ class TreeORAM(AbstractORAM):
half = (empty_positions[i]+1 - parity) // 2
half_max = self.bucket_size // 2
bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0]
bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0]
bits = floatingpoint.B2U(half, half_max)[0]
bits2 = floatingpoint.B2U(half+parity, half_max)[0]
# (doesn't work)
#bits2 = [0] * half_max
## second half with parity bit
@@ -1350,7 +1350,8 @@ class TreeORAM(AbstractORAM):
#bits2[0] = (1 - bits[0]) * parity
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
else:
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
bucket_bits = floatingpoint.B2U(empty_positions[i]+1,
self.bucket_size)[0]
assert len(bucket_bits) == self.bucket_size
for j, b in enumerate(bucket_bits):
pos_bits[i * self.bucket_size + j] = [b, leaf]
@@ -1376,8 +1377,7 @@ class TreeORAM(AbstractORAM):
Program.prog.curr_tape.start_new_basicblock()
bucket_sizes = Array(2**self.D, regint)
for i in range(2**self.D):
bucket_sizes[i] = 0
bucket_sizes.assign_all(0)
@for_range_opt(len(entries))
def _(k):
@@ -1697,7 +1697,7 @@ class OneLevelORAM(TreeORAM):
class BinaryORAM:
def __init__(self, size, value_type=None, **kwargs):
import circuit_oram
from Compiler import circuit_oram
from Compiler.GC import types
n_bits = int(get_program().options.binary)
self.value_type = value_type or types.sbitintvec.get_type(n_bits)

View File

@@ -18,7 +18,7 @@ from functools import reduce
import Compiler.instructions
import Compiler.instructions_base
import Compiler.instructions_base as inst_base
from Compiler.config import REG_MAX, USER_MEM, COST
from Compiler.config import REG_MAX, USER_MEM, COST, MEM_MAX
from Compiler.exceptions import CompilerError
from Compiler.instructions_base import RegType
@@ -103,6 +103,38 @@ class Program(object):
self.bit_length = int(options.binary) or int(options.field)
if options.prime:
self.prime = int(options.prime)
print("WARNING: --prime/-P activates code that usually isn't "
"the most efficient variant. Consider using --field/-F "
"and set the prime only during the actual computation.")
if not self.rabbit_gap() and self.prime > 2 ** 50:
print("The chosen prime is particularly inefficient. "
"Consider using a prime that is closer to a power "
"of two", end='')
try:
import gmpy2
bad_prime = self.prime
self.prime = 2 ** int(
round(math.log(self.prime, 2))) + 1
while True:
if self.prime > 2 ** 59:
# LWE compatibility
step = 2 ** 15
else:
step = 1
if self.prime < bad_prime:
self.prime += step
else:
self.prime -= step
if gmpy2.is_prime(self.prime):
break
assert self.rabbit_gap()
print(", for example, %d." % self.prime)
self.prime = bad_prime
except ImportError:
print(".")
if options.execute:
print("Use '-- --prime <prime>' to specify the prime for "
"execution only.")
max_bit_length = int(options.prime).bit_length() - 2
if self.bit_length > max_bit_length:
raise CompilerError(
@@ -111,7 +143,7 @@ class Program(object):
self.bit_length = self.bit_length or max_bit_length
self.non_linear = KnownPrime(self.prime)
else:
self.non_linear = Prime(self.security)
self.non_linear = Prime()
if not self.bit_length:
self.bit_length = 64
print("Default bit length for compilation:", self.bit_length)
@@ -197,6 +229,8 @@ class Program(object):
self.cisc_to_function = True
if not self.options.cisc:
self.options.cisc = not self.options.optimize_hard
self.use_tape_calls = True
self.force_cisc_tape = False
Program.prog = self
from . import comparison, instructions, instructions_base, types
@@ -278,7 +312,8 @@ class Program(object):
self.non_linear = Ring(ring_size)
self.options.ring = str(ring_size)
def new_tape(self, function, args=[], name=None, single_thread=False):
def new_tape(self, function, args=[], name=None, single_thread=False,
finalize=True, **kwargs):
"""
Create a new tape from a function. See
:py:func:`~Compiler.library.multithread` and
@@ -309,11 +344,12 @@ class Program(object):
self.curr_tape
tape_index = len(self.tapes)
self.tape_stack.append(self.curr_tape)
self.curr_tape = Tape(name, self)
self.curr_tape = Tape(name, self, **kwargs)
self.curr_tape.singular = single_thread
self.tapes.append(self.curr_tape)
function(*args)
self.finalize_tape(self.curr_tape)
if finalize:
self.finalize_tape(self.curr_tape)
if self.tape_stack:
self.curr_tape = self.tape_stack.pop()
return tape_index
@@ -346,6 +382,7 @@ class Program(object):
thread_numbers = []
while len(thread_numbers) < len(args):
free_threads = self.curr_tape.free_threads
self.curr_tape.ran_threads = True
if free_threads:
thread_numbers.append(min(free_threads))
free_threads.remove(thread_numbers[-1])
@@ -417,7 +454,10 @@ class Program(object):
def finalize_tape(self, tape):
if not tape.purged:
curr_tape = self.curr_tape
self.curr_tape = tape
tape.optimize(self.options)
self.curr_tape = curr_tape
tape.write_bytes()
if self.options.asmoutfile:
tape.write_str(self.options.asmoutfile + "-" + tape.name)
@@ -472,16 +512,18 @@ class Program(object):
self.allocated_mem[mem_type] += size
if len(str(addr)) != len(str(addr + size)) and self.verbose:
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
if addr + size >= 2**64:
raise CompilerError("allocation exceeded for type '%s'" % mem_type)
if addr + size >= MEM_MAX:
raise CompilerError(
"allocation exceeded for type '%s' after adding %d" % \
(mem_type, size))
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
if single_size:
from .library import get_thread_number, runtime_error_if
from .library import get_arg, runtime_error_if
bak = self.curr_tape.active_basicblock
self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
tn = get_thread_number()
runtime_error_if(tn > self.n_running_threads, "malloc")
res = addr + single_size * (tn - 1)
arg = get_arg()
runtime_error_if(arg >= self.n_running_threads, "malloc")
res = addr + single_size * arg
self.curr_tape.active_basicblock = bak
self.base_addresses[res] = addr
return res
@@ -577,7 +619,6 @@ class Program(object):
def set_security(self, security):
changed = self._security != security
self._security = security
self.non_linear.set_security(security)
if changed:
print("Changed statistical security for comparison etc. to",
security)
@@ -783,7 +824,7 @@ class Program(object):
class Tape:
"""A tape contains a list of basic blocks, onto which instructions are added."""
def __init__(self, name, program):
def __init__(self, name, program, thread_pool=None):
"""Set prime p and the initial instructions and registers."""
self.program = program
name += "-%d" % program.get_tape_counter()
@@ -800,12 +841,15 @@ class Tape:
self.merge_opens = True
self.if_states = []
self.req_bit_length = defaultdict(lambda: 0)
self.bit_length_reason = None
self.function_basicblocks = {}
self.functions = []
self.singular = True
self.free_threads = set()
self.free_threads = set() if thread_pool is None else thread_pool
self.loop_breaks = []
self.warned_about_mem = False
self.return_values = []
self.ran_threads = False
class BasicBlock(object):
def __init__(self, parent, name, scope, exit_condition=None,
@@ -880,7 +924,8 @@ class Tape:
self.usage_instructions = list(filter(relevant, self.instructions))
else:
self.usage_instructions = []
if len(self.usage_instructions) > 1000:
if len(self.usage_instructions) > 1000 and \
self.parent.program.verbose:
print("Retaining %d instructions" % len(self.usage_instructions))
del self.instructions
self.purged = True
@@ -1107,6 +1152,9 @@ class Tape:
if addr.program == self and self.basicblocks:
allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool)
for reg in self.return_values:
allocator.alloc_reg(reg, self.basicblocks[-1].alloc_pool)
seen = set()
def alloc(block):
@@ -1214,7 +1262,10 @@ class Tape:
Compiler.instructions.reqbl(bl, add_to_prog=False)
)
if self.program.verbose:
print("Tape requires prime bit length", self.req_bit_length["p"])
print("Tape requires prime bit length",
self.req_bit_length["p"],
('for %s' % self.bit_length_reason
if self.bit_length_reason else ''))
print("Tape requires galois bit length", self.req_bit_length["2"])
@unpurged
@@ -1287,6 +1338,7 @@ class Tape:
if "Bytecode" not in filename:
filename = self.program.programs_dir + "/Bytecode/" + filename
print("Writing to", filename)
sys.stdout.flush()
f = open(filename, "wb")
h = hashlib.sha256()
for i in self._get_instructions():
@@ -1395,13 +1447,12 @@ class Tape:
return repr(dict(self))
class ReqNode(object):
__slots__ = ["num", "_children", "name", "blocks", "aggregated"]
def __init__(self, name):
self._children = []
self.name = name
self.blocks = []
self.aggregated = None
self.num = None
@property
def children(self):
@@ -1411,12 +1462,17 @@ class Tape:
def aggregate(self, *args):
if self.aggregated is not None:
return self.aggregated
self.recursion = self.num is not None
if self.recursion:
return Tape.ReqNum()
self.num = Tape.ReqNum()
for block in self.blocks:
block.add_usage(self)
res = reduce(
lambda x, y: x + y.aggregate(self.name), self.children, self.num
)
if self.recursion:
res *= float('inf')
self.aggregated = res
return res
@@ -1442,7 +1498,7 @@ class Tape:
n_reps = self.aggregator([1])
n_rounds = res["all", "round"]
n_invs = res["all", "inv"]
if (n_invs / n_rounds) * 1000 < n_reps:
if (n_invs / n_rounds) * 1000 < n_reps and Program.prog.verbose:
print(
self.nodes[0].blocks[0].name,
"blowing up rounds: ",
@@ -1468,15 +1524,19 @@ class Tape:
def close_scope(self, outer_scope, parent_req_node, name):
self.start_new_basicblock(outer_scope, name, req_node=parent_req_node)
def require_bit_length(self, bit_length, t="p"):
def require_bit_length(self, bit_length, t="p", reason=None):
if t == "p":
if self.program.prime:
if bit_length >= self.program.prime.bit_length() - 1:
raise CompilerError(
"required bit length %d too much for %d"
% (bit_length, self.program.prime)
+ ('(for %s)' % reason if reason else '')
)
self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t])
bit_length += 1
if bit_length > self.req_bit_length[t]:
self.req_bit_length[t] = bit_length
self.bit_length_reason = reason
else:
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
@@ -1498,6 +1558,21 @@ class Tape:
"In some cases, you can fix this by using 'compile.py -l'."
)
def __int__(self):
raise CompilerError(
"It is impossible to convert run-time types to compile-time "
"Python types like int or float. The reason for this is that "
"%s objects are only a placeholder during the execution in "
"Python, the actual value of which is only defined in the "
"virtual machine at a later time. See "
"https://mp-spdz.readthedocs.io/en/latest/journey.html "
"to get an understanding of the overall design. "
"In rare cases, you can fix this by using 'compile.py -l'." % \
type(self).__name__
)
__float__ = __int__
class Register(_no_truth):
"""
Class for creating new registers. The register's index is automatically assigned
@@ -1619,6 +1694,9 @@ class Tape:
def copy(self):
return Tape.Register(self.reg_type, Program.prog.curr_tape)
def same_type(self):
return type(self)(size=self.size)
def link(self, other):
if Program.prog.options.noreallocate:
raise CompilerError("reallocation necessary for linking, "

View File

@@ -1,5 +1,6 @@
import itertools
from Compiler import types, library, instructions
from Compiler import comparison, util
def dest_comp(B):
Bt = B.transpose()
@@ -20,6 +21,7 @@ def reveal_sort(k, D, reverse=False):
backward order
"""
comparison.require_ring_size(util.log2(len(k)) + 1, 'sorting')
assert len(k) == len(D)
library.break_point()
shuffle = types.sint.get_secure_shuffle(len(k))

View File

@@ -109,7 +109,7 @@ class SqrtOram(Generic[T, B]):
self.shuffle_used = cint.Array(self.n)
# Random permutation on the data
self.shufflei = Array.create_from(
[self.index_type(i) for i in range(self.n)])
self.index_type(regint.inc(self.n)))
# Calculate the period if not given
# upon recursion, the period should stay the same ("in sync"),
# therefore it can be passed as a constructor parameter
@@ -122,7 +122,7 @@ class SqrtOram(Generic[T, B]):
# Note that self.shuffle_the_shuffle mutates this field
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
self.permutation = Array.create_from(
[self.index_type(i) for i in range(self.n)])
self.index_type(regint.inc(self.n)))
# We allow the caller to postpone the initialization of the shuffle
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
# Note that if you do not initialize, the ORAM is insecure
@@ -256,7 +256,6 @@ class SqrtOram(Generic[T, B]):
return result
@lib.method_block
def write(self, index: T, *value: T):
global trace, n_parallel
if trace:
@@ -271,7 +270,12 @@ class SqrtOram(Generic[T, B]):
else:
raise Exception("Cannot handle type of value passed")
print(self.entry_length, value, type(value),len(value))
value = MemValue(value)
self._write(index, *value)
@lib.method_block
def _write(self, index: T, *value: T):
value = MemValue(self.value_type(value))
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@@ -513,14 +517,14 @@ class SqrtOram(Generic[T, B]):
# Since the underlying memory of the position map is already aligned in
# this packed structure, we can simply overwrite the memory while
# maintaining the structure.
self.position_map.reinitialize(*self.permutation)
self.position_map.reinitialize(self.permutation)
def reinitialize(self, *data: T):
def reinitialize(self, data: T):
# Note that this method is only used during refresh, and as such is
# only called with a permutation as data.
# The logical addresses of some previous permutation are irrelevant and must be reset
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
self.shufflei.assign_vector(self.index_type(regint.inc(self.n)))
# Reset the clock
self.t.write(0)
# Reset shuffle_used
@@ -530,10 +534,10 @@ class SqrtOram(Generic[T, B]):
# This structure is preserved while overwriting the values using
# assign_vector
self.shuffle.assign_vector(self.value_type(
data, size=self.n * self.entry_length))
data[:], size=self.n * self.entry_length))
# Note that this updates self.permutation (see constructor for explanation)
self.shuffle_the_shuffle()
self.position_map.reinitialize(*self.permutation)
self.position_map.reinitialize(self.permutation)
def _reset_shuffle_used(self):
global allow_memory_allocation
@@ -568,7 +572,7 @@ class PositionMap(Generic[T, B]):
print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
def reinitialize(self, *permutation: T):
def reinitialize(self, permutation: T):
"""Reinitialize this PositionMap.
Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
@@ -613,9 +617,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
packed_size = int(math.ceil(self.n / pack))
packed_structure = MultiArray(
(packed_size, pack), value_type=value_type)
for i in range(packed_size):
@lib.for_range(packed_size)
def _(i):
packed_structure[i] = Array.create_from(
permutation[i*pack:(i+1)*pack])
permutation.get_vector(base=i * pack, size=pack))
SqrtOram.__init__(self, packed_structure, value_type=value_type,
period=period, entry_length=pack, k=self.depth,
@@ -720,8 +725,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
return p.reveal()
def reinitialize(self, *permutation: T):
SqrtOram.reinitialize(self, *permutation)
def reinitialize(self, permutation: T):
SqrtOram.reinitialize(self, permutation)
class LinearPositionMap(PositionMap):
@@ -790,8 +795,8 @@ class LinearPositionMap(PositionMap):
return p.reveal()
def reinitialize(self, *data: T):
self.physical.assign_vector(data)
def reinitialize(self, data : T):
self.physical.assign(data)
global allow_memory_allocation
if allow_memory_allocation:

File diff suppressed because it is too large Load Diff

View File

@@ -28,11 +28,11 @@ def greater_than(a, b, bits):
else:
return a.greater_than(b, bits)
def pow2_value(a, bit_length=None, security=None):
def pow2_value(a, bit_length=None):
if is_constant_float(a):
return 2**a
else:
return a.pow2(bit_length, security)
return a.pow2(bit_length)
def mod2m(a, b, bits, signed):
if isinstance(a, int):

View File

@@ -15,7 +15,7 @@
int main()
{
P256Element::init();
P256Element::Scalar key;
KeySetup<Share<P256Element::Scalar>> key;
string prefix = PREP_DIR "ECDSA/";
mkdir_p(prefix.c_str());
write_online_setup(prefix, P256Element::Scalar::pr());

View File

@@ -44,6 +44,7 @@ int main(int argc, const char** argv)
typedef Share<P256Element::Scalar> pShare;
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
read_mac_key(prefix, N, keyp);
pShare::set_mac_key(keyp);
pShare::MAC_Check::setup(P);
Share<P256Element>::MAC_Check::setup(P);

View File

@@ -37,6 +37,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
EcdsaOptions opts)
{
bool prep_mul = opts.prep_mul;
if (prep_mul)
proc.protocol.init_mul();
Timer timer;
timer.start();
Player& P = proc.P;
@@ -77,7 +80,6 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (prep_mul)
{
protocol.init_mul();
for (int i = 0; i < buffer_size; i++)
protocol.prepare_mul(inv_ks[i], sk);
protocol.start_exchange();

View File

@@ -23,6 +23,7 @@
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"

View File

@@ -24,6 +24,13 @@ Client::Client(const vector<string>& hostnames, int port_base,
"P" + to_string(i), "C" + to_string(my_client_id), true);
if (i == 0)
specification.Receive(sockets[0]);
else
{
octetStream spec;
spec.Receive(sockets[i]);
if (spec != specification)
throw runtime_error("inconsistent specification");
}
}
}

View File

@@ -31,6 +31,15 @@ def set_keepalive_osx(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
class Client:
"""Client to servers running secure computation. Works both as a client
to all parties or a trusted client to a single party.
:param hostnames: hostnames or IP addresses to connect to
:param port_base: port number for first hostname,
increases by one for every additional hostname
:param my_client_id: number to identify client
"""
def __init__(self, hostnames, port_base, my_client_id):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
name = 'C%d' % my_client_id
@@ -62,6 +71,11 @@ class Client:
self.specification = octetStream()
self.specification.Receive(self.sockets[0])
for sock in self.sockets[1:]:
specification = octetStream()
specification.Receive(sock)
if specification.buf != self.specification.buf:
raise Exception('inconsistent specification')
type = self.specification.get_int(4)
if type == ord('R'):
self.domain = Z2(self.specification.get_int(4))
@@ -99,6 +113,12 @@ class Client:
return triples
def send_private_inputs(self, values):
""" Send inputs privately to the computation servers.
This assumes that the client is connected to all servers.
:param values: list of input values
"""
T = self.domain
triples = self.receive_triples(T, len(values))
os = octetStream()
@@ -109,10 +129,46 @@ class Client:
os.Send(socket)
def receive_outputs(self, n):
""" Receive outputs privately from the computation servers.
This assumes that the client is connected to all servers.
:param n: number of outputs
"""
T = self.domain
triples = self.receive_triples(T, n)
return [int(self.clear_domain(triple[0].v)) for triple in triples]
def send_public_inputs(self, values):
""" Send values in the clear. This works for public inputs
to all servers or to send shares to a single server.
:param values: list of values
"""
os = octetStream()
for value in values:
self.domain(value).pack(os)
for socket in self.sockets:
os.Send(socket)
def receive_plain_values(self, socket=None):
""" Receive values in the clear. This works for public inputs
to all servers or to send shares to a single server.
:param socket: socket to use (need to specify it there is more than one)
"""
if socket is None:
if len(self.sockets) != 1:
raise Exception('need to specify socket')
socket = self.sockets[0]
os = octetStream()
os.Receive(socket)
assert len(os) % self.domain.size() == 0
return [int(os.get(self.domain))
for i in range(len(os) // self.domain.size())]
class octetStream:
def __init__(self, value=None):
self.buf = b''
@@ -123,6 +179,8 @@ class octetStream:
def get_length(self):
return len(self.buf)
__len__ = get_length
def reset_write_head(self):
self.buf = b''
self.ptr = 0
@@ -164,6 +222,11 @@ class octetStream:
else:
return 0
def get(self, type):
res = type()
res.unpack(self)
return res
def consume(self, length):
self.ptr += length
assert self.ptr <= len(self.buf)

View File

@@ -7,7 +7,7 @@ class Domain:
def __int__(self):
res = self.v % self.modulus
return res if res < self.modulus / 2 else res - self.modulus
return int(res if res < self.modulus / 2 else res - self.modulus)
def __add__(self, other):
try:

View File

@@ -0,0 +1,19 @@
#!/usr/bin/python3
import sys, random
sys.path.insert(0, 'ExternalIO')
from client import *
party = int(sys.argv[1])
client = Client(['localhost'], 15000 + party, 0)
n = 1000
if party < 2:
client.send_public_inputs(random.gauss(0, 1) * 2 ** 16 for i in range(n))
x = [client.receive_plain_values() for i in range(2)]
client.send_public_inputs(a + b for a, b in zip(*x))

View File

@@ -13,6 +13,8 @@
#include "SimpleMachine.h"
#include "Tools/mkpath.h"
#include "Protocols/Share.hpp"
template<class FD>
Producer<FD>::Producer(int output_thread, bool write_output) :
n_slots(0), output_thread(output_thread), write_output(write_output),

View File

@@ -8,6 +8,8 @@
#include "BitAdder.h"
#include "Protocols/BufferScope.h"
#include <assert.h>
template<class T>
@@ -69,6 +71,9 @@ void BitAdder::add(vector<vector<T> >& res,
size_t n_items = end - begin;
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%lu ANDs in bit adder\n", length * n_items * n_bits);
if (supply)
{
#ifdef VERBOSE_EDA
@@ -85,6 +90,7 @@ void BitAdder::add(vector<vector<T> >& res,
vector<T> carries(n_items);
vector<T> a(n_items), b(n_items);
auto& protocol = proc.protocol;
BufferScope scope(proc.DataF, n_items * length * n_bits);
for (int i = 0; i < n_bits; i++)
{
assert(summands[i].size() == 2);

View File

@@ -24,14 +24,14 @@ void FakeSecret::load_clear(int n, const Integer& x)
*this = x;
}
void FakeSecret::bitcom(Memory<FakeSecret>& S, const vector<int>& regs)
void FakeSecret::bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
void FakeSecret::bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const
void FakeSecret::bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;

View File

@@ -142,8 +142,8 @@ public:
template <class T>
void store(Memory<T>& mem, size_t address) { mem[address] = *this; }
void bitcom(Memory<FakeSecret>& S, const vector<int>& regs);
void bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const;
void bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs);
void bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const;
template <class T>
void xor_(int n, const FakeSecret& x, const T& y)

View File

@@ -84,8 +84,8 @@ void Instruction::parse(istream& s, int pos)
ostringstream os;
os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl;
os << "This virtual machine executes binary circuits only." << endl;
os << "Use 'compile.py -B'." << endl;
throw Invalid_Instruction(os.str());
os << "Use 'compile.py -B'.";
exit_error(os.str());
break;
}
}

View File

@@ -52,7 +52,7 @@ public:
static DataFieldType field_type()
{
throw not_implemented();
return DATA_GF2;
}
static void init_minimum(int)
@@ -80,7 +80,8 @@ public:
bool operator!=(NoValue) const { return false; }
bool operator==(int) { fail(); return false; }
bool operator==(int) const { fail(); return false; }
bool operator==(NoValue) const { fail(); return false; }
bool get_bit(int) { fail(); return 0; }
@@ -92,6 +93,8 @@ public:
void input(istream&, bool) { fail(); }
void output(ostream&, bool) {}
void pack(octetStream&) const { fail(); }
};
inline ostream& operator<<(ostream& o, NoValue)
@@ -169,8 +172,8 @@ public:
void load_clear(Integer, Integer) { fail(); }
void random_bit() { fail(); }
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
void bitdec(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
void bitcom(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
void assign(const char*) { fail(); }
@@ -190,6 +193,8 @@ public:
NoShare& operator+=(const NoShare&) { fail(); return *this; }
bool operator==(NoShare) const { fail(); return false; }
NoShare get_bit(int) const { fail(); return {}; }
void xor_bit(int, NoShare) const { fail(); }
@@ -201,6 +206,8 @@ public:
void input(istream&, bool) { fail(); }
void output(ostream&, bool) { fail(); }
void pack(octetStream&) const { fail(); }
};
} /* namespace GC */

View File

@@ -79,10 +79,11 @@ template<class T>
void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
size_t begin, size_t end)
{
#ifdef VERBOSE_EDA
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
RunningTimer timer;
#endif
bool verbose = OnlineOptions::singleton.has_option("verbose_eda");
if (verbose)
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
auto& party = ShareThread<typename T::whole_type>::s();
auto& MC = party.MC->get_part_MC();
auto& P = *party.P;
@@ -102,9 +103,9 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
input.exchange();
for (size_t i = begin; i < end; i++)
triples[i][2] = input.finalize(input_player, T::default_length);
#ifdef VERBOSE_EDA
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
#endif
if (verbose)
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
}
}

View File

@@ -16,6 +16,7 @@ using namespace std;
#include "Math/Integer.h"
#include "Processor/ProcessorBase.h"
#include "Processor/Instruction.h"
#include "Tools/CheckVector.h"
namespace GC
{
@@ -38,9 +39,9 @@ public:
// rough measure for the memory usage
size_t complexity;
Memory<T> S;
Memory<Clear> C;
Memory<Integer> I;
StackedVector<T> S;
StackedVector<Clear> C;
StackedVector<Integer> I;
Timer xor_timer;
@@ -78,8 +79,8 @@ public:
template<class U>
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
template<class U>
void mem_op(int n, Memory<U>& dest, const Memory<U>& source,
template<class U, class V>
void mem_op(int n, U& dest, const V& source,
Integer dest_address, Integer source_address);
void xors(const vector<int>& args);
@@ -105,6 +106,9 @@ public:
template<int = 0>
void convcbit2s(const BaseInstruction& instruction);
void convcbitvec(const BaseInstruction& instruction, StackedVector<Integer>& Ci,
Player* P);
void print_reg(int reg, int n, int size);
void print_reg_plain(Clear& value);
void print_reg_signed(unsigned n_bits, Integer value);
@@ -114,6 +118,13 @@ public:
void print_float_prec(int n);
void incint(const BaseInstruction& instruction);
void push_stack();
void push_args(const vector<int>& args);
void pop_stack(const vector<int>& results);
template<class U>
void call_tape(const BaseInstruction& instruction, U& dynamic_memory);
};
template <class T>

View File

@@ -18,7 +18,7 @@ using namespace std;
#include "Math/BitVec.h"
#include "GC/Machine.hpp"
#include "Processor/ProcessorBase.hpp"
#include "Processor/Processor.hpp"
#include "Processor/IntInput.hpp"
#include "Math/bigint.hpp"
@@ -53,9 +53,9 @@ template <class T>
template <class U>
void Processor<T>::reset(const U& program, int arg)
{
S.resize(program.num_reg(SBIT), "registers");
C.resize(program.num_reg(CBIT), "registers");
I.resize(program.num_reg(INT), "registers");
S.resize(program.num_reg(SBIT));
C.resize(program.num_reg(CBIT));
I.resize(program.num_reg(INT));
set_arg(arg);
PC = 0;
}
@@ -202,14 +202,14 @@ void GC::Processor<T>::store_clear_in_dynamic(const vector<int>& args,
}
template<class T>
template<class U>
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
template<class U, class V>
void Processor<T>::mem_op(int n, U& dest, const V& source,
Integer dest_address, Integer source_address)
{
dest.check_index(dest_address + n - 1);
source.check_index(source_address + n - 1);
auto d = &dest[dest_address];
auto s = &source[source_address];
auto d = &dest[dest_address.get()];
auto s = &source[source_address.get()];
for (int i = 0; i < n; i++)
{
*d++ = *s++;
@@ -388,6 +388,30 @@ void Processor<T>::convcbit2s(const BaseInstruction& instruction)
min(size_t(unit), instruction.get_n() - i * unit));
}
template<class T>
void Processor<T>::convcbitvec(const BaseInstruction& instruction,
StackedVector<Integer>& Ci, Player* P)
{
vector<Integer> bits;
auto n = instruction.get_n();
bits.reserve(n);
for (size_t i = 0; i < instruction.get_n(); i++)
{
int i1 = i / GC::Clear::N_BITS;
int i2 = i % GC::Clear::N_BITS;
auto bit = C[instruction.get_r(1) + i1].get_bit(i2);
bits.push_back(bit);
}
if (P)
sync<T>(bits, *P);
else if (not T::symmetric)
sync<T>(bits, *Thread<T>::s().P);
for (size_t i = 0; i < n; i++)
Ci[instruction.get_r(0) + i] = bits[i];
}
template <class T>
void Processor<T>::print_reg(int reg, int n, int size)
{
@@ -417,7 +441,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
{
if (n_bits <= Clear::N_BITS)
{
auto value = C[reg];
auto value = C[reg.get()];
unsigned n_shift = 0;
if (n_bits > 1)
n_shift = sizeof(value.get()) * 8 - n_bits;
@@ -477,6 +501,56 @@ void Processor<T>::incint(const BaseInstruction& instruction)
}
}
template<class T>
void GC::Processor<T>::push_stack()
{
S.push_stack();
C.push_stack();
}
template<class T>
void GC::Processor<T>::push_args(const vector<int>& args)
{
S.push_args(args, SBIT);
C.push_args(args, CBIT);
}
template<class T>
void GC::Processor<T>::pop_stack(const vector<int>& results)
{
S.pop_stack(results, SBIT);
C.pop_stack(results, CBIT);
}
template<class T>
template<class U>
void Processor<T>::call_tape(const BaseInstruction& instruction, U& dynamic_memory)
{
auto new_arg = I.at(instruction.get_r(1)).get();
PC_stack.push_back(PC);
arg_stack.push_back(this->arg);
push_stack();
I.push_stack();
auto& tape = machine->progs.at(instruction.get_r(0));
reset(tape, new_arg);
auto& args = instruction.get_start();
push_args(args);
I.push_args(args, INT);
tape.execute(*this, dynamic_memory, PC);
pop_stack(args);
I.pop_stack(args, INT);
PC = PC_stack.back();
PC_stack.pop_back();
this->arg = arg_stack.back();
arg_stack.pop_back();
}
} /* namespace GC */
#endif

View File

@@ -7,6 +7,7 @@
#include "Protocols/Rep4.hpp"
#include "Protocols/Rep4Input.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/ReplicatedPrep.hpp"
namespace GC

View File

@@ -76,6 +76,10 @@ public:
static const bool actual_inputs = T::actual_inputs;
static const bool symmetric = true;
static bool real_shares(const Player&) { return true; }
static int threshold(int nplayers) { return T::threshold(nplayers); }
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
@@ -148,9 +152,9 @@ public:
Secret<T> operator>>(int i) const;
template<class U>
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitcom(StackedVector<U>& S, const vector<int>& regs);
template<class U>
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
Secret<T> operator+(const Secret<T>& x) const;
Secret<T>& operator+=(const Secret<T>& x) { *this = *this + x; return *this; }

View File

@@ -197,7 +197,7 @@ Secret<T> Secret<T>::operator>>(int i) const
template <class T>
template <class U>
void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
void Secret<T>::bitcom(StackedVector<U>& S, const vector<int>& regs)
{
registers.clear();
for (unsigned int i = 0; i < regs.size(); i++)
@@ -210,7 +210,7 @@ void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
template <class T>
template <class U>
void Secret<T>::bitdec(Memory<U>& S, const vector<int>& regs) const
void Secret<T>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
{
if (regs.size() > registers.size())
throw overflow("not enough bits for bit decomposition", regs.size(),

View File

@@ -17,7 +17,7 @@ namespace GC
void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat)
{
if (repeat and OnlineOptions::singleton.live_prep)
if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1))
{
this->triples.push_back({{}});
auto& triple = this->triples.back();
@@ -35,6 +35,8 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n)
{
if (n == -1)
n = SemiSecret::default_length;
super::prepare_mul(x.mask(n), y.mask(n), n);
}

View File

@@ -24,7 +24,7 @@ public:
void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat);
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n);
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n = -1);
};
} /* namespace GC */

View File

@@ -79,6 +79,9 @@ array<SemiSecret, 3> SemiPrep::get_mixed_triple(int n)
if (mixed_triples.empty())
{
assert(this->triple_generator);
this->triple_generator->set_batch_size(
BaseMachine::batch_size<SemiSecret>(DATA_MIXED,
this->buffer_size));
this->triple_generator->generateMixedTriples();
for (auto& x : this->triple_generator->mixedTriples)
{

View File

@@ -38,6 +38,13 @@ public:
static const int default_length = sizeof(BitVec) * 8;
static const bool symmetric = V::symmetric;
static bool real_shares(const Player& P)
{
return V::real_shares(P);
}
static string type_string() { return "binary secret"; }
static string phase_name() { return "Binary computation"; }
@@ -64,8 +71,8 @@ public:
void load_clear(int n, const Integer& x);
void bitcom(Memory<T>& S, const vector<int>& regs);
void bitdec(Memory<T>& S, const vector<int>& regs) const;
void bitcom(StackedVector<T>& S, const vector<int>& regs);
void bitdec(StackedVector<T>& S, const vector<int>& regs) const;
void xor_(int n, const T& x, const T& y)
{ *this = BitVec(x ^ y).mask(n); }

View File

@@ -70,30 +70,40 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
assert(protocol);
protocol->init_mul();
auto it = args.begin();
int total_bits = 0, total_ops = 0;
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
total_bits += n_args * size;
it += n_args;
int base = *it++;
assert(n_args <= N_BITS);
for (int i = 0; i < size; i += N_BITS)
{
square64 square;
for (int j = 0; j < n_args; j++)
square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get();
int n_ops = min(N_BITS, size - i);
square.transpose(n_args, n_ops);
for (int j = 0; j < n_ops; j++)
for (int k = 0; k < n_args; k += N_BITS)
{
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
auto y_ext = SemiSecret(bit).extend_bit();
protocol->prepare_mult(square.rows[j], y_ext, n_args, true);
int left = min(N_BITS, n_args - k);
square64 square;
for (int j = 0; j < left; j++)
square.rows[j] = processor.S.at(
*(it + k + j) + i / N_BITS).get();
int n_ops = min(N_BITS, size - i);
total_ops += n_ops;
square.transpose(left, n_ops);
for (int j = 0; j < n_ops; j++)
{
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
auto y_ext = SemiSecret(bit).extend_bit();
protocol->prepare_mult(square.rows[j], y_ext, left, true);
}
}
}
it += n_args;
}
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%d/%d repeat ANDs\n", total_bits, total_ops);
protocol->exchange();
it = args.begin();
@@ -103,13 +113,18 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
int size = *it++;
for (int i = 0; i < size; i += N_BITS)
{
int n_ops = min(N_BITS, size - i);
square64 square;
for (int j = 0; j < n_ops; j++)
square.rows[j] = protocol->finalize_mul(n_args).get();
square.transpose(n_ops, n_args);
for (int j = 0; j < n_args; j++)
processor.S.at(*(it + j) + i / N_BITS) = square.rows[j];
for (int base = 0; base < n_args; base += N_BITS)
{
int left = min(N_BITS, n_args - base);
int n_ops = min(N_BITS, size - i);
square64 square;
for (int j = 0; j < n_ops; j++)
square.rows[j] = protocol->finalize_mul(left).get();
square.transpose(n_ops, left);
for (int j = 0; j < left; j++)
processor.S.at(*(it + base + j) + i / N_BITS) =
square.rows[j];
}
}
it += 2 * n_args + 1;
}
@@ -123,7 +138,7 @@ void SemiSecretBase<T, V>::load_clear(int n, const Integer& x)
}
template<class T, class V>
void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
void SemiSecretBase<T, V>::bitcom(StackedVector<T>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
@@ -131,7 +146,7 @@ void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
}
template<class T, class V>
void SemiSecretBase<T, V>::bitdec(Memory<T>& S,
void SemiSecretBase<T, V>::bitdec(StackedVector<T>& S,
const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)

View File

@@ -108,18 +108,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
else
P = new PlainPlayer(this->N, "shareparty");
try
{
read_mac_key(
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
this->N,
this->mac_key);
}
catch (exception& e)
{
SeededPRNG G;
this->mac_key.randomize(G);
}
T::read_or_generate_mac_key(
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
*P, this->mac_key);
T::MC::setup(*P);

View File

@@ -46,6 +46,9 @@ public:
static const bool is_real = true;
static const bool actual_inputs = true;
static const bool symmetric = true;
static bool real_shares(const Player&) { return true; }
static ShareThread<U>& get_party()
{
@@ -118,6 +121,7 @@ public:
typedef BitVec open_type;
typedef NoShare mac_type;
typedef NoValue mac_key_type;
typedef NoShare mac_share_type;
typedef NoShare bit_type;
@@ -151,6 +155,11 @@ public:
{
}
static GC::NoValue get_mac_key()
{
throw runtime_error("no MAC");
}
template<class T>
static string proto_fake_opts()
{
@@ -166,8 +175,8 @@ public:
{
}
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void bitcom(StackedVector<U>& S, const vector<int>& regs);
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
void xor_(int n, const This& x, const This& y)
{ *this = (x ^ y).mask(n); }

View File

@@ -54,7 +54,7 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
}
template<class U, int L>
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
void RepSecretBase<U, L>::bitcom(StackedVector<U>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
@@ -62,7 +62,7 @@ void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
}
template<class U, int L>
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
void RepSecretBase<U, L>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;

View File

@@ -94,23 +94,35 @@ void ShareThread<T>::and_(Processor<T>& processor,
processor.check_args(args, 4);
protocol->init_mul();
T x_ext, y_ext;
int total_bits = 0;
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
total_bits += n_bits;
int left = args[i + 2];
int right = args[i + 3];
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
{
int n = min(T::default_length, n_bits - j * T::default_length);
if (not repeat and n == T::default_length)
{
protocol->prepare_mul(processor.S[left + j], processor.S[right + j]);
continue;
}
processor.S[left + j].mask(x_ext, n);
if (repeat)
processor.S[right].extend_bit(y_ext, n);
else
processor.S[right + j].mask(y_ext, n);
processor.S[left + j].mask(x_ext, n);
protocol->prepare_mult(x_ext, y_ext, n, repeat);
}
}
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%d%s ANDs\n", total_bits, repeat ? " repeat" : "");
protocol->exchange();
for (size_t i = 0; i < args.size(); i += 4)
@@ -121,6 +133,13 @@ void ShareThread<T>::and_(Processor<T>& processor,
{
int n = min(T::default_length, n_bits - j * T::default_length);
auto& res = processor.S[out + j];
if (not repeat and n == T::default_length)
{
res = protocol->finalize_mul();
continue;
}
protocol->finalize_mult(res, n);
res.mask(res, n);
}
@@ -136,10 +155,12 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
protocol->init_mul();
auto it = args.begin();
T x_ext, y_ext;
int total_bits = 0;
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
total_bits += size * n_args;
it += n_args;
int base = *it++;
for (int i = 0; i < size; i += N_BITS)
@@ -155,6 +176,9 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
it += n_args;
}
if (OnlineOptions::singleton.has_option("verbose_and"))
fprintf(stderr, "%d repeat ANDs\n", total_bits);
protocol->exchange();
it = args.begin();

View File

@@ -32,6 +32,8 @@ public:
array<T, 3> get_triple_no_count(int n_bits)
{
int max_n_bits = T::default_length;
if (n_bits == -1)
n_bits = max_n_bits;
assert(n_bits <= max_n_bits);
assert(n_bits > 0);
array<T, 3> res;

View File

@@ -53,9 +53,15 @@ public:
static const bool malicious = T::malicious;
static const bool expensive_triples = T::expensive_triples;
static const bool randoms_for_opens = false;
static const bool symmetric = true;
static const int default_length = 64;
static bool real_shares(const Player&)
{
return true;
}
static int size()
{
return part_type::size() * default_length;
@@ -72,6 +78,11 @@ public:
T::read_or_generate_mac_key(directory, P, key);
}
static typename T::mac_type get_mac_key()
{
return T::get_mac_key();
}
template<class U>
static void reveal_inst(U& processor, const vector<int>& args)
{

View File

@@ -80,14 +80,15 @@
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
X(CONVCBIT, Proc.write_Ci(R0, Proc.sync(PC1.get()))) \
X(CONVCINTVEC, Proc.convcintvec(instruction)) \
X(CONVCBITVEC, Proc.convcbitvec(instruction)) \
X(CONVCBITVEC, Proc.Procb.convcbitvec(instruction, Proc.get_Ci(), &Proc.P)) \
X(CONVCBIT2S, PROC.convcbit2s(instruction)) \
X(DABIT, Proc.dabit(INST)) \
X(EDABIT, Proc.edabit(INST)) \
X(SEDABIT, Proc.edabit(INST, true)) \
X(SPLIT, Proc.split(INST)) \
X(CALL_ARG, ) \
#define GC_INSTRUCTIONS \
X(INPUTB, T::inputb(PROC, EXTRA)) \
@@ -101,6 +102,7 @@
X(CONVCINT, C0 = PI1) \
X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \
X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \
X(CONVCBITVEC, PROC.convcbitvec(instruction, Ci, 0)) \
X(PRINTCHR, PROC.print_chr(IMM)) \
X(PRINTSTR, PROC.print_str(IMM)) \
X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \
@@ -146,8 +148,11 @@
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
X(CRASH, if (I0.get()) throw crash_requested()) \
X(CRASH, if (I0.get() and T::actual_inputs) throw crash_requested()) \
X(ACTIVE, ) \
X(LDTN, I0 = BaseMachine::thread_num) \
X(CALL_TAPE, PROC.call_tape(INST, MD)) \
X(CALL_ARG, ) \
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS

View File

@@ -1,4 +1,4 @@
The Software is copyright (c) 2023, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
The Software is copyright (c) 2024, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence.

View File

@@ -232,9 +232,9 @@ OTMachine::OTMachine(int argc, const char** argv)
gettimeofday(&baseOTstart, NULL);
// swap role for base OTs
if (opt.isSet("-r"))
bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role));
bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role));
else
bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role));
bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role));
cout << "real mode " << opt.isSet("-r") << endl;
BaseOT& bot = *bot_;
bot.exec_base();

View File

@@ -14,6 +14,7 @@
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "GC/PersonalPrep.hpp"
#include "Protocols/Share.hpp"
//template class GC::ShareParty<GC::TinierSecret<gf2n_mac_key>>;
template class GC::CcdPrep<GC::TinierSecret<gf2n_mac_key>>;

View File

@@ -15,6 +15,7 @@
#include "Math/BitVec.h"
#include "GC/TinierSecret.h"
#include "Protocols/Share.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/MascotPrep.hpp"
#include "Math/Z2k.hpp"

View File

@@ -39,5 +39,6 @@ int main(int argc, const char** argv)
return run<2, 1>(machine);
cerr << "Not compiled for choice of parameters" << endl;
cerr << "Try using '-lgp 128'" << endl;
exit(1);
}

View File

@@ -23,6 +23,7 @@
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"

View File

@@ -22,6 +22,7 @@
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
#include "Protocols/PostSacrifice.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/MaliciousRepPrep.hpp"

View File

@@ -26,6 +26,7 @@
#include "Protocols/Beaver.hpp"
#include "Protocols/MascotPrep.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/Share.hpp"
int main(int argc, const char** argv)
{

View File

@@ -26,6 +26,7 @@
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MascotPrep.hpp"
#include "Protocols/Share.hpp"
int main(int argc, const char** argv)
{

View File

@@ -52,7 +52,7 @@ endif
endif
# used for dependency generation
OBJS = $(patsubst %.cpp,%.o,$(wildcard */*.cpp)) $(STATIC_OTE)
OBJS = $(patsubst %.cpp,%.o,$(wildcard */*.cpp */*/*.cpp)) $(STATIC_OTE)
DEPS := $(wildcard */*.d */*/*.d)
# never delete
@@ -150,13 +150,17 @@ static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a
$(MAKE) static-dir
$(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static/%.x: Machines/BMR/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a local/lib/liblibOTe.a
$(MAKE) static-dir
$(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static-dir:
@ mkdir static 2> /dev/null; true
static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) static/emulate.x
static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) $(patsubst Machines/BMR/%.cpp, static/%.x, $(wildcard Machines/BMR/*-party.cpp)) static/emulate.x
Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o
$(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS)
@@ -352,7 +356,7 @@ cmake:
wget https://github.com/Kitware/CMake/releases/download/v3.24.1/cmake-3.24.1.tar.gz
tar xzvf cmake-3.24.1.tar.gz
cd cmake-3.24.1; \
./bootstrap --parallel=8 --prefix=../local && make && make install
./bootstrap --parallel=8 --prefix=../local && make -j8 && make install
mac-setup: mac-machine-setup
brew install openssl boost libsodium gmp yasm ntl cmake

View File

@@ -67,8 +67,8 @@ public:
{
if (n == -1)
pack(os);
else if (n == 1)
os.store_bit(this->a);
else if (n < 8)
os.store_bits(this->a, n);
else
os.store_int(super::mask(n).get(), DIV_CEIL(n, 8));
}
@@ -77,8 +77,8 @@ public:
{
if (n == -1)
unpack(os);
else if (n == 1)
this->a = os.get_bit();
else if (n < 8)
this->a = os.get_bits(n);
else
this->a = os.get_int(DIV_CEIL(n, 8));
}

View File

@@ -161,6 +161,11 @@ public:
return equal(1);
}
bool operator==(const FixedVec<T, L>& other) const
{
return equal(other);
}
bool operator!=(const FixedVec<T, L>& other) const
{
return not equal(other);

View File

@@ -306,7 +306,7 @@ public:
return operator*(SignedZ2<64>(other));
}
void output(ostream& s, bool human = true) const;
void output(ostream& s, bool human = true, bool signed_ = true) const;
};
template<int K>
@@ -479,12 +479,17 @@ SignedZ2<K> abs(const SignedZ2<K>& x)
}
template<int K>
void SignedZ2<K>::output(ostream& s, bool human) const
void SignedZ2<K>::output(ostream& s, bool human, bool signed_) const
{
if (human)
{
bigint::tmp = *this;
s << bigint::tmp;
if (signed_)
{
bigint::tmp = *this;
s << bigint::tmp;
}
else
Z2<K>::output(s, human);
}
else
Z2<K>::output(s, false);
@@ -493,7 +498,7 @@ void SignedZ2<K>::output(ostream& s, bool human) const
template<int K>
ostream& operator<<(ostream& o, const SignedZ2<K>& x)
{
x.output(o, true);
x.output(o, true, false);
return o;
}

View File

@@ -510,6 +510,7 @@ gf2n_short::gf2n_short(const int128& a)
// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40)
void expand_byte(gf2n_short& a,int b)
{
gf2n_short::init_field(40);
gf2n_short x,xp;
x = (32+1);
xp.assign_one();

View File

@@ -107,7 +107,7 @@ class gfp_ : public ValueInterface
static void write_setup(string dir)
{ write_online_setup(dir, pr()); }
static void check_setup(string dir);
static string fake_opts() { return " -lgp " + to_string(length()); }
static string fake_opts() { return " -P " + to_string(pr()); }
/**
* Get the prime modulus
@@ -229,18 +229,25 @@ class gfp_ : public ValueInterface
// faster randomization, see implementation for explanation
void almost_randomize(PRNG& G);
void output(ostream& s,bool human) const
{ a.output(s,ZpD,human); }
/**
* Output.
* @param s output stream
* @param human human-readable or binary
* @param signed_ signed representation (range `[-p/2,p/2]` instead of `[0,p]`)
*/
void output(ostream& s, bool human, bool signed_ = false) const
{ a.output(s,ZpD, human, signed_); }
void input(istream& s,bool human)
{ a.input(s,ZpD,human); }
/**
* Human-readable output in the range `[-p/2, p/2]`.
* Human-readable output in the range `[0, p]`.
* @param s output stream
* @param x value
*/
friend ostream& operator<<(ostream& s,const gfp_& x)
{ x.output(s,true);
{
x.output(s, true, false);
return s;
}
/**

View File

@@ -82,7 +82,7 @@ public:
{
write_setup(get_prep_sub_dir<T>(nplayers));
}
static string fake_opts() { return " -lgp " + to_string(length()); }
static string fake_opts() { return " -P " + to_string(pr()); }
gfpvar_();
gfpvar_(int other);

View File

@@ -132,7 +132,7 @@ class modp_
// - Can do in human or machine only format (later should be faster)
// - If human output appends a space to help with reading
// and also convert back/forth from Montgomery if needed
void output(ostream& s,const Zp_Data& ZpD,bool human) const;
void output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_ = false) const;
void input(istream& s,const Zp_Data& ZpD,bool human);
template<int X, int K>

View File

@@ -327,12 +327,12 @@ void Power(modp_<L>& ans,const modp_<L>& x,const bigint& exp,const Zp_Data& ZpD)
template<int L>
void modp_<L>::output(ostream& s,const Zp_Data& ZpD,bool human) const
void modp_<L>::output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_) const
{
if (human)
{ bigint te;
to_bigint(te, ZpD);
if (te < ZpD.pr / 2)
if (te < ZpD.pr / 2 or not signed_)
s << te;
else
s << (te - ZpD.pr);

View File

@@ -10,12 +10,12 @@
void check_ssl_file(string filename)
{
if (not ifstream(filename))
throw runtime_error("Cannot access " + filename
exit_error("Cannot access " + filename
+ ". Have you set up SSL?\n"
"You can use `Scripts/setup-ssl.sh <nparties>`.");
}
void ssl_error(string side, string other, string me)
void ssl_error(string side, string other, string me, exception& e)
{
cerr << side << "-side handshake with " << other
<< " failed. Make sure both sides "
@@ -48,6 +48,8 @@ void ssl_error(string side, string other, string me)
cerr << "/";
}
cerr << endl;
cerr << "SSL error: " << e.what() << endl;
exit(1);
}
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :

View File

@@ -7,6 +7,7 @@
#include "Networking/Server.h"
#include "Networking/ServerSocket.h"
#include "Networking/Exchanger.h"
#include "Processor/OnlineOptions.h"
#include <sys/select.h>
#include <utility>
@@ -78,7 +79,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
}
}
if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
throw runtime_error("not enough hosts in " + filename);
exit_error("not enough hosts in " + filename);
#ifdef DEBUG_NETWORKING
cerr << "Got list of " << nplayers << " players from file: " << endl;
for (unsigned int i = 0; i < names.size(); i++)
@@ -127,7 +128,17 @@ void Names::setup_names(const char *servername, int my_port)
int socket_num;
int pn = portnum_base;
set_up_client_socket(socket_num, servername, pn);
try
{
set_up_client_socket(socket_num, servername, pn);
}
catch (exception& e)
{
exit_error(
string("cannot reach coordination server: ") + e.what());
}
octetStream("P" + to_string(player_no)).Send(socket_num);
#ifdef DEBUG_NETWORKING
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
@@ -155,11 +166,11 @@ void Names::setup_names(const char *servername, int my_port)
}
catch (exception& e)
{
throw runtime_error(string("error in network setup: ") + e.what());
exit_error(string("error in network setup: ") + e.what());
}
if (names.size() != ports.size())
throw runtime_error("invalid network setup");
exit_error("invalid network setup");
nplayers = names.size();
#ifdef VERBOSE
for (int i = 0; i < nplayers; i++)
@@ -288,7 +299,15 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
"Setting up send to self socket to %s:%d with id %s\n",
localhost, ports[i], pn.c_str());
#endif
set_up_client_socket(sockets[i],localhost,ports[i]);
try
{
set_up_client_socket(sockets[i],localhost,ports[i]);
}
catch (exception& e)
{
exit_error("cannot connect to myself, "
"maybe check your firewall configuration");
}
} else {
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up client to %s:%d with id %s\n",
@@ -762,7 +781,7 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other)
{
sent += other.sent;
for (auto it = other.begin(); it != other.end(); it++)
(*this)[it->first] += it->second;
map<string, CommStats>::operator[](it->first) += it->second;
return *this;
}
@@ -786,7 +805,7 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const
NamedCommStats res = *this;
res.sent = sent - other.sent;
for (auto it = other.begin(); it != other.end(); it++)
res[it->first] -= it->second;
res.map<string, CommStats>::operator[](it->first) -= it->second;
return res;
}
@@ -818,9 +837,25 @@ Timer& NamedCommStats::add_to_last_round(const string& name, size_t length)
}
}
void PlayerBase::reset_stats()
Timer& CommStatsWithName::add_length_only(size_t length)
{
if (OnlineOptions::singleton.has_option("verbose_comm"))
fprintf(stderr, "%s %zu bytes in same round\n", name.c_str(), length);
return stats.add_length_only(length);
}
Timer& CommStatsWithName::add(const octetStream& os)
{
if (OnlineOptions::singleton.has_option("verbose_comm"))
fprintf(stderr, "%s %zu bytes\n", name.c_str(), os.get_length());
return stats.add(os);
}
void Player::reset_stats()
{
comm_stats.reset();
for (auto& x : thread_stats)
x.reset();
}
NamedCommStats Player::total_comm() const

View File

@@ -141,18 +141,28 @@ struct CommStats
}
Timer& add_length_only(size_t length)
{
#ifdef VERBOSE_COMM
cout << "add " << length << endl;
#endif
data += length;
return timer;
}
Timer& add(const octetStream& os) { return add(os.get_length()); }
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
CommStats& operator+=(const CommStats& other);
CommStats& operator-=(const CommStats& other);
};
class CommStatsWithName
{
const string& name;
CommStats& stats;
public:
CommStatsWithName(const string& name, CommStats& stats) :
name(name), stats(stats) {}
Timer& add_length_only(size_t length);
Timer& add(const octetStream& os);
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
};
class NamedCommStats : public map<string, CommStats>
{
public:
@@ -167,14 +177,8 @@ public:
void print(bool newline = false);
void reset();
Timer& add_to_last_round(const string& name, size_t length);
#ifdef VERBOSE_COMM
CommStats& operator[](const string& name)
{
auto& res = map<string, CommStats>::operator[](name);
cout << name << " after " << res.data << endl;
return res;
}
#endif
CommStatsWithName operator[](const string& name)
{ return {name, map<string, CommStats>::operator[](name)}; }
};
/**
@@ -209,8 +213,6 @@ public:
virtual void send_receive_all(const vector<octetStream>&,
vector<octetStream>&) const
{ throw not_implemented(); }
void reset_stats();
};
/**
@@ -394,6 +396,7 @@ public:
{ receive_player(i, o); }
NamedCommStats total_comm() const;
void reset_stats();
};
/**

View File

@@ -168,8 +168,13 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers,
cerr << "Starting networking for " << my_num << "/" << nplayers
<< " with server on " << hostname << ":" << (portnum) << endl;
#endif
assert(my_num >= 0);
assert(my_num < nplayers);
if (my_num < 0 or my_num >= nplayers)
{
cerr << "Player number " << my_num << " outside range: 0-"
<< nplayers - 1 << endl;
exit(1);
}
Server* server = 0;
pthread_t thread;
if (my_num == 0)

View File

@@ -205,7 +205,7 @@ int ServerSocket::get_connection_socket(const string& id)
while (clients.find(id) == clients.end())
{
if (data_signal.wait(CONNECTION_TIMEOUT) == ETIMEDOUT)
throw runtime_error("Timed out waiting for peer. See "
exit_error("Timed out waiting for peer. See "
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
"for details on networking.");
}
@@ -230,7 +230,7 @@ void AnonymousServerSocket::init()
void AnonymousServerSocket::process_client(const string& client_id)
{
if (clients.find(client_id) != clients.end())
throw runtime_error("client " + client_id + " already connected");
exit_error("client " + client_id + " already connected");
client_connection_queue.push(client_id);
}
@@ -242,7 +242,7 @@ int AnonymousServerSocket::get_connection_socket(string& client_id)
{
int res = data_signal.wait(CONNECTION_TIMEOUT);
if (res == ETIMEDOUT)
throw runtime_error("timed out while waiting for client");
exit_error("timed out while waiting for client");
else if (res)
throw runtime_error("waiting error");
}

View File

@@ -14,7 +14,7 @@ void error(const char *str)
gethostname(err,1000);
strcat(err," : ");
strcat(err,str);
throw runtime_error(string() + err + " : " + strerror(old_errno));
exit_error(string() + err + " : " + strerror(old_errno));
}
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
@@ -62,7 +62,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
{
for (rp = ai; rp != NULL; rp = rp->ai_next)
cerr << "Family on offer: " << ai->ai_family << endl;
runtime_error(string("No AF_INET for ") + (char*)hostname + " on " + (char*)my_name);
exit_error(string("No AF_INET for ") + (char*)hostname + " on " + (char*)my_name);
}
@@ -106,10 +106,11 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
if (fl < 0)
{
throw runtime_error(
exit_error(
string() + "cannot connect from " + my_name + " to " + hostname + ":"
+ to_string(Portnum) + " after " + to_string(attempts)
+ " attempts in one minute because " + strerror(connect_errno) + ". "
+ " attempts in " + to_string(CONNECTION_TIMEOUT)
+ " seconds because " + strerror(connect_errno) + ". "
"https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#"
"connection-failures has more information on port requirements.");
}

View File

@@ -21,7 +21,7 @@
typedef boost::asio::io_service ssl_service;
void check_ssl_file(string filename);
void ssl_error(string side, string other, string server);
void ssl_error(string side, string other, string server, exception& e);
class ssl_ctx : public boost::asio::ssl::context
{
@@ -62,9 +62,9 @@ public:
try
{
handshake(ssl_socket::client);
} catch (...)
} catch (exception& e)
{
ssl_error("Client", other, me);
ssl_error("Client", other, me, e);
throw;
}
else
@@ -72,9 +72,9 @@ public:
try
{
handshake(ssl_socket::server);
} catch (...)
} catch (exception& e)
{
ssl_error("Server", other, me);
ssl_error("Server", other, me, e);
throw;
}

View File

@@ -47,12 +47,12 @@ public:
vector<BitVector> receiver_outputs;
TwoPartyPlayer* P;
/// Number of OTs
int nOT, ot_length;
int nOT;
/// Which role(s) on this side
OT_ROLE ot_role;
BaseOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH)
: P(player), nOT(nOT), ot_length(ot_length), ot_role(role)
BaseOT(int nOT, TwoPartyPlayer* player, OT_ROLE role=BOTH)
: P(player), nOT(nOT), ot_role(role)
{
receiver_inputs.resize(nOT);
sender_inputs.resize(nOT);
@@ -69,14 +69,12 @@ public:
}
BaseOT(TwoPartyPlayer* player, OT_ROLE role) :
BaseOT(128, 128, player, role)
BaseOT(128, player, role)
{
}
virtual ~BaseOT() {}
int length() { return ot_length; }
/// Set choice bits
void set_receiver_inputs(const BitVector& new_inputs)
{
@@ -126,8 +124,8 @@ protected:
class FakeOT : public BaseOT
{
public:
FakeOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) :
BaseOT(nOT, ot_length, player, role) {}
FakeOT(int nOT, TwoPartyPlayer* player, OT_ROLE role=BOTH) :
BaseOT(nOT, player, role) {}
void exec_base(bool new_receiver_inputs=true);
};

View File

@@ -8,12 +8,13 @@
void BitDiagonal::pack(octetStream& os) const
{
for (int i = 0; i < N_ROWS; i++)
os.store_int(rows[i].get_bit(i), 1);
os.store_bit(rows[i].get_bit(i));
os.append(0);
}
void BitDiagonal::unpack(octetStream& os)
{
*this = {};
for (int i = 0; i < N_ROWS; i++)
rows[i] = os.get_int(1) << i;
rows[i] = RowType(os.get_bit()) << i;
}

View File

@@ -72,6 +72,11 @@ Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(const OTTripleSetup& setup,
template<class T>
void OTTripleGenerator<T>::set_batch_size(int batch_size)
{
// limit to ~1 GB
batch_size = min(batch_size, int(1e7 / sizeof(T) / sizeof(T)));
if (OnlineOptions::singleton.has_option("verbose_ot"))
fprintf(stderr, "OT batch size %d (share size %d)\n", batch_size,
int(sizeof(T)));
nTriplesPerLoop = DIV_CEIL(batch_size, nloops);
nTriples = nTriplesPerLoop * nloops;
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
@@ -198,9 +203,9 @@ void NPartyTripleGenerator<T>::generate()
{
outputFile.open(ss.str().c_str());
if (machine.generateMACs or not T::clear::invertible)
file_signature<T>().output(outputFile);
file_signature<T>(this->mac_key).output(outputFile);
else
file_signature<typename T::clear>().output(outputFile);
file_signature<SemiShare<typename T::clear>>().output(outputFile);
}
if (machine.generateBits)
@@ -250,6 +255,7 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
inputs.resize(toCheck);
auto mac_key = this->get_mac_key();
SemiInput<SemiShare<T>> input(0, globalPlayer);
input.maybe_init(globalPlayer);
input.reset_all(globalPlayer);
vector<T> secrets(toCheck);
if (mine)
@@ -528,8 +534,10 @@ void OTTripleGenerator<T>::generateMixedTriples()
machine.set_passive();
machine.output = false;
int n = multiple_minimum(100 * nPreampTriplesPerLoop,
T::open_type::size_in_bits());
int n = multiple_minimum(nPreampTriplesPerLoop, 8);
if (OnlineOptions::singleton.has_option("verbose_mixed"))
fprintf(stderr, "generating %d mixed triples\n", n);
valueBits.resize(2);
valueBits[0].resize(n);
@@ -556,6 +564,9 @@ void OTTripleGenerator<T>::generateMixedTriples()
template<class U>
void OTTripleGenerator<U>::plainTripleRound(int k)
{
if (OnlineOptions::singleton.has_option("verbose_triples"))
fprintf(stderr, "generating %d triples\n", nPreampTriplesPerLoop);
typedef typename U::open_type T;
if (not (machine.amplify or machine.output))

View File

@@ -78,6 +78,11 @@ void OTCorrelator<U>::correlate(int start, int slice,
Slice<U> t1Slice(t1, start, slice);
Slice<U> uSlice(u, start, slice);
if (OnlineOptions::singleton.has_option("verbose_correlate"))
fprintf(stderr, "correlate %d matrices of size %d*%d, %u bits\n", slice,
int(U::PartType::n_rows()), int(U::PartType::n_columns()),
newReceiverInput.size());
// create correlation
if (ot_role & RECEIVER)
{

View File

@@ -20,7 +20,7 @@ osuCrypto::IOService ot_extension_ios;
OTExtensionWithMatrix OTExtensionWithMatrix::setup(TwoPartyPlayer& player,
int128 delta, OT_ROLE role, bool passive)
{
BaseOT baseOT(128, 128, &player, INV_ROLE(role));
BaseOT baseOT(128, &player, INV_ROLE(role));
PRNG G;
G.ReSeed();
baseOT.set_receiver_inputs(delta);
@@ -30,6 +30,11 @@ OTExtensionWithMatrix OTExtensionWithMatrix::setup(TwoPartyPlayer& player,
OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player,
bool passive) : OTCorrelator(baseOT, player, passive)
{
init_me();
}
void OTExtensionWithMatrix::init_me()
{
G.ReSeed();
nsubloops = 1;
@@ -37,6 +42,7 @@ OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* pla
#ifndef USE_KOS
channel = 0;
#endif
softspoken_k = 2;
}
OTExtensionWithMatrix::~OTExtensionWithMatrix()
@@ -47,27 +53,43 @@ OTExtensionWithMatrix::~OTExtensionWithMatrix()
#endif
}
bool OTExtensionWithMatrix::use_kos()
{
#ifdef USE_KOS
return true;
#else
return OnlineOptions::singleton.has_option("use_kos");
#endif
}
void OTExtensionWithMatrix::protocol_agreement()
{
if (agreed)
return;
Bundle<octetStream> bundle(*player);
#ifdef USE_KOS
bundle.mine = string("KOS15");
#else
bundle.mine = string("SoftSpokenOT");
#endif
if (use_kos())
bundle.mine = string("KOS15");
else
bundle.mine = string("SoftSpokenOT");
if (OnlineOptions::singleton.has_option("high_softspoken"))
softspoken_k = 8;
bundle.mine.store(softspoken_k);
player->unchecked_broadcast(bundle);
try
{
bundle.compare(*player);
agreed = true;
}
catch (mismatch_among_parties&)
{
cerr << "Parties compiled with different OT extensions" << endl;
cerr << "Set \"USE_KOS\" to the same value on all parties" << endl;
cerr << "and make sure that the SoftSpokenOT parameter is the same" << endl;
exit(1);
}
}
@@ -104,16 +126,27 @@ void OTExtensionWithMatrix::transfer(int nOTs,
#endif
}
void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newReceiverInput)
void OTExtensionWithMatrix::extend(int nOTs_requested,
const BitVector& newReceiverInput, bool hash)
{
protocol_agreement();
if (use_kos())
{
extend_correlated(nOTs_requested, newReceiverInput);
if (hash)
hash_outputs(nOTs_requested);
return;
}
#ifdef USE_KOS
extend_correlated(nOTs_requested, newReceiverInput);
hash_outputs(nOTs_requested);
assert(use_kos());
#else
resize(nOTs_requested);
if (nOTs_requested == 0)
return;
if (not channel)
channel = new osuCrypto::Channel(ot_extension_ios, new PlayerCtSocket(*player));
@@ -141,14 +174,18 @@ void OTExtensionWithMatrix::soft_sender(size_t n)
if (not (ot_role & SENDER))
return;
if (OnlineOptions::singleton.has_option("verbose_ot"))
fprintf(stderr, "%zu OTs as sender\n", n);
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(2);
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(softspoken_k);
vector<osuCrypto::block> outputs;
for (auto& x : G_receiver)
{
outputs.push_back(x.get_doubleword());
}
sender.malicious = not passive_only;
sender.setBaseOts(outputs,
{baseReceiverInput.get_ptr(), sender.baseOtCount()}, prng,
*channel);
@@ -171,8 +208,11 @@ void OTExtensionWithMatrix::soft_receiver(size_t n,
if (not (ot_role & RECEIVER))
return;
if (OnlineOptions::singleton.has_option("verbose_ot"))
fprintf(stderr, "%zu OTs as receiver\n", n);
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(2);
osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(softspoken_k);
vector<array<osuCrypto::block, 2>> inputs;
for (auto& x : G_sender)
@@ -181,6 +221,7 @@ void OTExtensionWithMatrix::soft_receiver(size_t n,
for (int i = 0; i < 2; i++)
inputs.back()[i] = x[i].get_doubleword();
}
recver.malicious = not passive_only;
recver.setBaseOts(inputs, prng, *channel);
// Choose which messages should be received.

View File

@@ -63,6 +63,10 @@ class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
bool agreed;
int softspoken_k;
void init_me();
public:
PRNG G;
@@ -76,21 +80,18 @@ public:
: OTCorrelator<BitMatrix>(player, role, passive),
nsubloops(nsubloops)
{
G.ReSeed();
agreed = false;
#ifndef USE_KOS
channel = 0;
#endif
init_me();
}
OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
~OTExtensionWithMatrix();
bool use_kos();
void protocol_agreement();
void transfer(int nOTs, const BitVector& receiverInput, int nloops);
void extend(int nOTs, const BitVector& newReceiverInput);
void extend(int nOTs, const BitVector& newReceiverInput, bool hash = true);
void extend_correlated(const BitVector& newReceiverInput);
void extend_correlated(int nOTs, const BitVector& newReceiverInput);
void transpose(int start = 0, int slice = -1);

View File

@@ -77,6 +77,8 @@ public:
OTMultiplier(OTTripleGenerator<T>& generator, int thread_num);
virtual ~OTMultiplier();
void init();
void multiply();
};

View File

@@ -88,11 +88,10 @@ OTMultiplier<T>::~OTMultiplier()
}
template<class T>
void OTMultiplier<T>::multiply()
void OTMultiplier<T>::init()
{
keyBits.set(generator.get_mac_key());
rot_ext.extend(keyBits.size(), keyBits);
this->outbox.push({});
senderOutput.resize(keyBits.size());
for (size_t j = 0; j < keyBits.size(); j++)
{
@@ -106,10 +105,18 @@ void OTMultiplier<T>::multiply()
assert(receiverOutput.size() >= keyBits.size());
receiverOutput.resize(keyBits.size());
init_authenticator(keyBits, senderOutput, receiverOutput);
}
template<class T>
void OTMultiplier<T>::multiply()
{
this->outbox.push({});
MultJob job;
while (this->inbox.pop(job))
{
if (receiverOutput.empty())
init();
if (job.input)
{
if (job.player == generator.my_num
@@ -155,13 +162,14 @@ void SemiMultiplier<T>::multiplyForBits()
otCorrelator.set_role(role);
BitVector aBits = this->generator.valueBits[0];
rot_ext.extend_correlated(aBits);
rot_ext.extend(aBits.size(), aBits, not rot_ext.use_kos());
typedef typename T::Rectangle X;
vector<Matrix<X> >& baseSenderOutputs = otCorrelator.matrices;
Matrix<X>& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput,
rot_ext.use_kos());
int n_squares = otCorrelator.receiverOutputMatrix.squares.size();
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
@@ -201,12 +209,13 @@ void SemiMultiplier<T>::multiplyForMixed()
this->generator.players[this->thread_num], BOTH, true);
BitVector aBits = this->generator.valueBits[0];
rot_ext.extend_correlated(aBits);
rot_ext.extend(aBits.size(), aBits, not rot_ext.use_kos());
auto& baseSenderOutputs = otCorrelator.matrices;
auto& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput,
rot_ext.use_kos());
if (this->generator.get_player().num_players() == 2)
{
@@ -265,16 +274,17 @@ void OTMultiplier<W>::multiplyForTriples()
//timers["Extension"].start();
if (generator.machine.use_extension)
{
#ifdef USE_KOS
rot_ext.extend_correlated(aBits);
#else
rot_ext.extend(aBits.size(), aBits);
corr_hash = false;
#endif
if (rot_ext.use_kos())
rot_ext.extend_correlated(aBits);
else
{
rot_ext.extend(aBits.size(), aBits);
corr_hash = false;
}
}
else
{
BaseOT bot(aBits.size(), -1, generator.players[thread_num]);
BaseOT bot(aBits.size(), generator.players[thread_num]);
bot.set_receiver_inputs(aBits);
bot.exec_base(false);
for (size_t i = 0; i < aBits.size(); i++)

View File

@@ -1,5 +1,11 @@
#include "OTTripleSetup.h"
void* run_ot(void* job)
{
((OTTripleSetup::SetupJob*)job)->run();
return 0;
}
void OTTripleSetup::setup()
{
timeval baseOTstart, baseOTend;
@@ -12,13 +18,13 @@ void OTTripleSetup::setup()
}
//baseReceiverInput.randomize(G);
vector<SetupJob> threads;
for (int i = 0; i < nparties - 1; i++)
{
baseOTs[i]->set_receiver_inputs(base_receiver_inputs);
baseOTs[i]->exec_base(false);
baseSenderInputs[i] = baseOTs[i]->sender_inputs;
baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs;
}
threads.push_back({*this, i});
for (int i = 0; i < nparties - 1; i++)
pthread_create(&threads[i].thread, 0, run_ot, &threads[i]);
for (int i = 0; i < nparties - 1; i++)
pthread_join(threads[i].thread, 0);
gettimeofday(&baseOTend, NULL);
#ifdef VERBOSE_BASEOT
double basetime = timeval_diff(&baseOTstart, &baseOTend);
@@ -34,6 +40,14 @@ void OTTripleSetup::setup()
// (since Sender finishes baseOTs before Receiver)
}
void OTTripleSetup::run(int i)
{
baseOTs[i]->set_receiver_inputs(base_receiver_inputs);
baseOTs[i]->exec_base(false);
baseSenderInputs[i] = baseOTs[i]->sender_inputs;
baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs;
}
void OTTripleSetup::close_connections()
{
for (size_t i = 0; i < players.size(); i++)
@@ -47,7 +61,7 @@ OTTripleSetup OTTripleSetup::get_fresh()
OTTripleSetup res = *this;
for (int i = 0; i < nparties - 1; i++)
{
BaseOT bot(nbase, 128, 0);
BaseOT bot(nbase, 0);
bot.sender_inputs = baseSenderInputs[i];
bot.receiver_outputs = baseReceiverOutputs[i];
bot.set_seeds();

View File

@@ -13,6 +13,8 @@
*/
class OTTripleSetup
{
void run(int i);
BitVector base_receiver_inputs;
vector<BaseOT*> baseOTs;
@@ -22,8 +24,27 @@ class OTTripleSetup
int nbase;
public:
class SetupJob
{
OTTripleSetup& setup;
int i;
public:
pthread_t thread;
SetupJob(OTTripleSetup& setup, int i) :
setup(setup), i(i), thread(0)
{
}
void run()
{
setup.run(i);
}
};
map<string,Timer> timers;
vector<OffsetPlayer*> players;
vector<TwoPartyPlayer*> players;
vector< vector< array<BitVector, 2> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
@@ -56,16 +77,16 @@ public:
else
other_player = i;
players.push_back(new OffsetPlayer(N, N.get_offset(other_player)));
players.push_back(new VirtualTwoPartyPlayer(N, other_player));
// sets up a pair of base OTs, playing both roles
if (real_OTs)
{
baseOTs[i] = new BaseOT(nbase, 128, players[i]);
baseOTs[i] = new BaseOT(nbase, players[i]);
}
else
{
baseOTs[i] = new FakeOT(nbase, 128, players[i]);
baseOTs[i] = new FakeOT(nbase, players[i]);
}
}

View File

@@ -70,6 +70,29 @@ int BaseMachine::bucket_size(size_t usage)
return res;
}
int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols)
{
unsigned res = min(100, OnlineOptions::singleton.batch_size);
if (has_program())
res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols));
return res;
}
int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
{
if (has_program())
{
auto res = s().progs[0].get_offline_data_used().matmuls[
{n_rows, n_inner, n_cols}];
if (res)
return res;
else
return -1;
}
else
return -1;
}
BaseMachine::BaseMachine() : nthreads(0)
{
if (sodium_init() == -1)
@@ -287,7 +310,7 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << P.my_num() << " only";
if (nthreads > 1)
if (multithread)
cerr << "; rounds counted double due to multi-threading";
if (not OnlineOptions::singleton.verbose)
cerr << "; use '-v' for more details";

View File

@@ -44,6 +44,7 @@ public:
string progname;
int nthreads;
bool multithread;
ThreadQueues queues;
@@ -66,10 +67,14 @@ public:
template<class T>
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
template<class T>
static int input_batch_size(int player, int buffer_size = 0);
template<class T>
static int edabit_batch_size(int n_bits, int buffer_size = 0);
static int edabit_bucket_size(int n_bits);
static int triple_bucket_size(DataFieldType type);
static int bucket_size(size_t usage);
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
BaseMachine();
virtual ~BaseMachine() {}
@@ -105,6 +110,10 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
template<class T>
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
{
if (OnlineOptions::singleton.has_option("debug_batch_size"))
fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
fallback);
int n_opts;
int n = 0;
int res = 0;
@@ -114,7 +123,7 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
else if (fallback > 0)
n_opts = fallback;
else
n_opts = OnlineOptions::singleton.batch_size;
n_opts = OnlineOptions::singleton.batch_size * T::default_length;
if (buffer_size <= 0 and has_program())
{
@@ -132,7 +141,6 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
{
n = buffer_size;
buffer_size = 0;
n_opts = OnlineOptions::singleton.batch_size;
}
if (n > 0 and not (buffer_size > 0))
@@ -161,16 +169,33 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
else
res = n_opts;
#ifdef DEBUG_BATCH_SIZE
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
<< " res=" << res << " n="
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
#endif
if (OnlineOptions::singleton.has_option("debug_batch_size"))
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
<< " res=" << res << " n=" << n << " n_opts=" << n_opts
<< " buffer_size=" << buffer_size << endl;
assert(res > 0);
return res;
}
template<class T>
int BaseMachine::input_batch_size(int player, int buffer_size)
{
if (buffer_size)
return buffer_size;
if (has_program())
{
auto res =
s().progs[0].get_offline_data_used(
).inputs[player][T::clear::field_type()];
if (res > 0)
return res;
}
return OnlineOptions::singleton.batch_size;
}
template<class T>
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
{

View File

@@ -29,10 +29,12 @@ public:
Conv2dTuple(const vector<int>& args, int start);
array<int, 3> matrix_dimensions();
template<class T>
void pre(vector<T>& S, typename T::Protocol& protocol);
void pre(StackedVector<T>& S, typename T::Protocol& protocol);
template<class T>
void post(vector<T>& S, typename T::Protocol& protocol);
void post(StackedVector<T>& S, typename T::Protocol& protocol);
template<class T>
void run_matrix(SubProcessor<T>& processor);

View File

@@ -12,8 +12,10 @@
#include "Networking/Player.h"
#include "Protocols/edabit.h"
#include "PrepBase.h"
#include "PrepBuffer.h"
#include "EdabitBuffer.h"
#include "Tools/TimerWithComm.h"
#include "Tools/CheckVector.h"
#include <fstream>
#include <map>
@@ -104,8 +106,6 @@ class Preprocessing : public PrepBase
protected:
static const bool use_part = false;
DataPositions& usage;
bool do_count;
void count(Dtype dtype, int n = 1)
@@ -115,9 +115,9 @@ protected:
template<int>
void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
template<int>
void get_edabits(bool, size_t, T*, vector<typename T::bit_type>&,
void get_edabits(bool, size_t, T*, StackedVector<typename T::bit_type>&,
const vector<int>&, true_type)
{ throw not_implemented(); }
@@ -126,6 +126,8 @@ protected:
T get_random_from_inputs(int nplayers);
public:
int buffer_size;
template<class U, class V>
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
SubProcessor<T>* proc);
@@ -135,7 +137,8 @@ public:
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
DataPositions& usage);
Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
Preprocessing(DataPositions& usage) :
PrepBase(usage), do_count(true), buffer_size(0) {}
virtual ~Preprocessing() {}
virtual void set_protocol(typename T::Protocol&) {};
@@ -151,7 +154,7 @@ public:
virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); }
virtual void get_input_no_count(T&, typename T::open_type&, int)
{ throw not_implemented() ; }
virtual void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
virtual void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
{ throw not_implemented(); }
void get(Dtype dtype, T* a);
@@ -159,7 +162,7 @@ public:
void get_two(Dtype dtype, T& a, T& b);
void get_one(Dtype dtype, T& a);
void get_input(T& a, typename T::open_type& x, int i);
void get(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
/// Get fresh random multiplication triple
virtual array<T, 3> get_triple(int n_bits);
@@ -174,7 +177,7 @@ public:
virtual void get_dabit(T& a, typename T::bit_type& b);
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs)
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs)
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
virtual void get_edabit_no_count(bool, int, edabit<T>&)
{ throw runtime_error("no edaBits"); }
@@ -201,11 +204,11 @@ class Sub_Data_Files : public Preprocessing<T>
static int tuple_length(int dtype);
BufferOwner<T, T> buffers[N_DTYPE];
vector<BufferOwner<T, T>> input_buffers;
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
map<DataTag, BufferOwner<T, T> > extended;
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
array<PrepBuffer<T>, N_DTYPE> buffers;
vector<PrepBuffer<T>> input_buffers;
PrepBuffer<InputTuple<T>, RefInputTuple<T>, T> my_input_buffers;
map<DataTag, PrepBuffer<T> > extended;
PrepBuffer<dabit<T>, dabit<T>, T> dabit_buffer;
map<int, EdabitBuffer<T>> edabit_buffers;
map<int, edabitvec<T>> my_edabits;
@@ -284,7 +287,7 @@ public:
}
void setup_extended(const DataTag& tag, int tuple_size = 0);
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get_dabit_no_count(T& a, typename T::bit_type& b);
part_type& get_part();
@@ -397,7 +400,7 @@ inline void Preprocessing<T>::get_input(T& a, typename T::open_type& x, int i)
}
template<class T>
inline void Preprocessing<T>::get(vector<T>& S, DataTag tag,
inline void Preprocessing<T>::get(StackedVector<T>& S, DataTag tag,
const vector<int>& regs, int vector_size)
{
usage.count(T::clear::field_type(), tag, vector_size);

View File

@@ -143,14 +143,14 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
{
if (T::clear::allows(Dtype(dtype)))
{
buffers[dtype].setup(
buffers[dtype].setup(num_players,
PrepBase::get_filename(prep_data_dir, Dtype(dtype), type_short,
my_num, thread_num), tuple_length(dtype), type_string,
DataPositions::dtype_names[dtype]);
}
}
dabit_buffer.setup(
dabit_buffer.setup(num_players,
PrepBase::get_filename(prep_data_dir, DATA_DABIT,
type_short, my_num, thread_num), dabit<T>::size(), type_string,
DataPositions::dtype_names[DATA_DABIT]);
@@ -161,10 +161,10 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
string filename = PrepBase::get_input_filename(prep_data_dir,
type_short, i, my_num, thread_num);
if (i == my_num)
my_input_buffers.setup(filename,
my_input_buffers.setup(num_players, filename,
InputTuple<T>::size(), type_string);
else
input_buffers[i].setup(filename,
input_buffers[i].setup(num_players, filename,
T::size(), type_string);
}
@@ -344,14 +344,14 @@ void Sub_Data_Files<T>::setup_extended(const DataTag& tag, int tuple_size)
{
stringstream ss;
ss << prep_data_dir << tag.get_string() << "-" << T::type_short() << "-P" << my_num;
buffer.setup(ss.str(), tuple_length);
buffer.setup(num_players, ss.str(), tuple_length);
}
buffer.check_tuple_length(tuple_length);
}
template<class T>
void Sub_Data_Files<T>::get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
void Sub_Data_Files<T>::get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
{
setup_extended(tag, regs.size());
for (int j = 0; j < vector_size; j++)

View File

@@ -12,6 +12,7 @@ using namespace std;
#include "Math/BitVec.h"
#include "Data_Files.h"
#include "Protocols/Replicated.h"
#include "Protocols/ReplicatedPrep.h"
#include "Protocols/MAC_Check_Base.h"
#include "Processor/Input.h"
@@ -109,7 +110,7 @@ public:
};
template<class T>
class DummyLivePrep : public Preprocessing<T>
class DummyLivePrep : public BufferPrep<T>
{
public:
static const bool homomorphic = true;
@@ -133,16 +134,16 @@ public:
}
DummyLivePrep(DataPositions& usage, GC::ShareThread<T>&) :
Preprocessing<T>(usage)
BufferPrep<T>(usage)
{
}
DummyLivePrep(DataPositions& usage, bool = true) :
Preprocessing<T>(usage)
BufferPrep<T>(usage)
{
}
DummyLivePrep(SubProcessor<T>*, DataPositions& usage) :
Preprocessing<T>(usage)
BufferPrep<T>(usage)
{
}
@@ -165,7 +166,7 @@ public:
{
fail();
}
void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
{
fail();
}

View File

@@ -81,7 +81,6 @@ int ExternalClients::init_client_connection(const string& host, int portnum,
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
true);
if (party_num == 0)
{
octetStream specification;
specification.Receive(socket);

View File

@@ -15,7 +15,7 @@
#include <iomanip>
template<class cgf2n>
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
void Instruction::execute_clear_gf2n(StackedVector<cgf2n>& registers,
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
{
auto& C2 = registers;
@@ -30,7 +30,7 @@ void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
}
template<class cgf2n>
void Instruction::gbitdec(vector<cgf2n>& registers) const
void Instruction::gbitdec(StackedVector<cgf2n>& registers) const
{
for (int j = 0; j < size; j++)
{
@@ -44,7 +44,7 @@ void Instruction::gbitdec(vector<cgf2n>& registers) const
}
template<class cgf2n>
void Instruction::gbitcom(vector<cgf2n>& registers) const
void Instruction::gbitcom(StackedVector<cgf2n>& registers) const
{
for (int j = 0; j < size; j++)
{
@@ -124,7 +124,7 @@ ostream& operator<<(ostream& s, const Instruction& instr)
return s;
}
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_short>& registers,
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_long>& registers,
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;

Some files were not shown because too many files have changed in this diff Show More