mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Expected communication cost in compiler.
This commit is contained in:
@@ -1,5 +1,14 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.4.2 (Dec 24, 2025)
|
||||
|
||||
- Expected communication cost in compiler
|
||||
- Semi-honest option of Rep4
|
||||
- Reduced communication for preprocessing in Dealer protocol
|
||||
- Option of choosing SoftSpoken parameter at run-time
|
||||
- BERT functionality (@hiddely)
|
||||
- Recommended reading list in documentation
|
||||
|
||||
## 0.4.1 (May 30, 2025)
|
||||
|
||||
- Add protocols with function-dependent preprocessing (https://eprint.iacr.org/2025/919)
|
||||
|
||||
@@ -618,6 +618,10 @@ class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable):
|
||||
code = opcodes['REVEAL']
|
||||
arg_format = tools.cycle(['int','cbw','sb'])
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('bit', 'open'), sum(
|
||||
int(math.ceil(x / 64)) * 8 for x in self.args[0::3]))
|
||||
|
||||
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
|
||||
""" Copy private input to secret bit register vectors. The input is
|
||||
read as floating-point number, multiplied by a power of two, and then
|
||||
|
||||
@@ -24,9 +24,14 @@ from functools import reduce
|
||||
class _binary:
|
||||
def __or__(self, other):
|
||||
return self ^ other ^ (self & other)
|
||||
__ror__ = __or__
|
||||
def reveal_to(self, *args, **kwargs):
|
||||
raise CompilerError(
|
||||
'%s does not support revealing to individual players' % type(self))
|
||||
@staticmethod
|
||||
def direct_matrix_mul(*args, **kwargs):
|
||||
raise AttributeError('direct matrix multiplication only supported '
|
||||
'in arithmetic circuits')
|
||||
|
||||
class bits(Tape.Register, _structure, _bit, _binary):
|
||||
n = 40
|
||||
@@ -432,7 +437,7 @@ class cbits(bits):
|
||||
inst.convcbitvec(self.n, res, self)
|
||||
return res
|
||||
|
||||
class sbits(bits):
|
||||
class sbits(bits, Tape._no_secret_truth):
|
||||
"""
|
||||
Secret bits register. This type supports basic bit-wise operations::
|
||||
|
||||
@@ -697,7 +702,7 @@ class sbits(bits):
|
||||
def output(self):
|
||||
inst.print_reg_plainsb(self)
|
||||
|
||||
class sbitvec(_vec, _bit, _binary):
|
||||
class sbitvec(Tape._no_secret_truth, _vec, _bit, _binary):
|
||||
""" Vector of registers of secret bits, effectively a matrix of secret bits.
|
||||
This facilitates parallel arithmetic operations in binary circuits.
|
||||
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
|
||||
@@ -907,35 +912,27 @@ class sbitvec(_vec, _bit, _binary):
|
||||
def __init__(self, elements=None, length=None, input_length=None):
|
||||
if length:
|
||||
assert isinstance(elements, sint)
|
||||
if Program.prog.use_split():
|
||||
x = elements.split_to_two_summands(length)
|
||||
v = sbitint.bit_adder(x[0], x[1])
|
||||
else:
|
||||
prog = Program.prog
|
||||
if not prog.options.ring:
|
||||
# force the use of edaBits
|
||||
backup = prog.use_edabit()
|
||||
prog.use_edabit(True)
|
||||
self.v = prog.non_linear.bit_dec(
|
||||
elements, max(length, input_length or prog.bit_length),
|
||||
length, maybe_mixed=True)
|
||||
assert isinstance(self.v[0], sbits)
|
||||
prog.use_edabit(backup)
|
||||
return
|
||||
comparison.require_ring_size(length, 'A2B conversion')
|
||||
l = int(Program.prog.options.ring)
|
||||
r, r_bits = sint.get_edabit(length, size=elements.size)
|
||||
c = ((elements - r) << (l - length)).reveal()
|
||||
c >>= l - length
|
||||
cb = [(c >> i) for i in range(length)]
|
||||
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
|
||||
v = x.v
|
||||
self.v = v[:length]
|
||||
prog = Program.prog
|
||||
backup = prog.use_edabit()
|
||||
if not prog.have_a2b():
|
||||
# force the use of edaBits
|
||||
prog.use_edabit(True)
|
||||
self.v = prog.non_linear.bit_dec(
|
||||
elements, max(length, input_length or prog.bit_length),
|
||||
length, maybe_mixed=True)
|
||||
assert isinstance(self.v[0], sbits)
|
||||
prog.use_edabit(backup)
|
||||
elif isinstance(elements, sbitvec):
|
||||
self.v = elements.v
|
||||
elif isinstance(elements, (list, tuple)) and \
|
||||
isinstance(elements[0], sbitvec):
|
||||
self.v = sbitvec(sum((x.elements() for x in elements), [])).v
|
||||
elif elements is not None and not (util.is_constant(elements) and \
|
||||
elements == 0):
|
||||
self.v = sbits.trans(elements)
|
||||
def __str__(self):
|
||||
return 'sbitvec(%s/%s)' % (len(self.v), self.size)
|
||||
__repr__ = __str__
|
||||
def popcnt(self):
|
||||
""" Population count / Hamming weight.
|
||||
|
||||
@@ -961,7 +958,7 @@ class sbitvec(_vec, _bit, _binary):
|
||||
return self.from_vec(x ^ y for x, y in zip(*self.expand(other)))
|
||||
def __and__(self, other):
|
||||
return self.from_vec(x & y for x, y in zip(*self.expand(other)))
|
||||
__rxor__ = __xor__
|
||||
__add__ = __radd__ = __sub__ = __rsub__ =__rxor__ = __xor__
|
||||
__rand__ = __and__
|
||||
def __invert__(self):
|
||||
return self.from_vec(~x for x in self.v)
|
||||
@@ -969,10 +966,6 @@ class sbitvec(_vec, _bit, _binary):
|
||||
return util.if_else(self.v[0], x, y)
|
||||
def __iter__(self):
|
||||
return iter(self.elements())
|
||||
def __len__(self):
|
||||
return len(self.v)
|
||||
def __getitem__(self, index):
|
||||
return self.v[index]
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cls):
|
||||
@@ -999,8 +992,6 @@ class sbitvec(_vec, _bit, _binary):
|
||||
return util.untuplify([x.reveal() for x in self.elements()])
|
||||
def long_one(self):
|
||||
return [x.long_one() for x in self.v]
|
||||
def __rsub__(self, other):
|
||||
return self.from_vec(y - x for x, y in zip(self.v, other))
|
||||
def half_adder(self, other):
|
||||
other = self.coerce(other)
|
||||
res = zip(*(x.half_adder(y) for x, y in zip(self.v, other)))
|
||||
@@ -1014,7 +1005,7 @@ class sbitvec(_vec, _bit, _binary):
|
||||
elif len(self.v) == 1:
|
||||
self, other = other, self.v[0]
|
||||
else:
|
||||
raise CompilerError('no operand of lenght 1: %d/%d',
|
||||
raise CompilerError('no operand of length 1: %d/%d',
|
||||
(len(self.v), len(other.v)))
|
||||
if not isinstance(other, sbits):
|
||||
return NotImplemented
|
||||
@@ -1036,8 +1027,6 @@ class sbitvec(_vec, _bit, _binary):
|
||||
i += 1
|
||||
return sbitvec.from_vec(res)
|
||||
__rmul__ = __mul__
|
||||
def __add__(self, other):
|
||||
return self.from_vec(x + y for x, y in zip(self.v, other))
|
||||
def bit_and(self, other):
|
||||
return self & other
|
||||
def bit_xor(self, other):
|
||||
@@ -1060,7 +1049,12 @@ class sbitvec(_vec, _bit, _binary):
|
||||
@classmethod
|
||||
def comp_result(cls, x):
|
||||
return cls.get_type(1).from_vec([x])
|
||||
def expand(self, other, expand=True):
|
||||
@staticmethod
|
||||
def reverse_type(other):
|
||||
return isinstance(other, sbitfixvec)
|
||||
equal = __eq__ = _bitint.__eq__
|
||||
eqz = staticmethod(_bitint.eqz)
|
||||
def expand(self, other, expand=True, copy=False):
|
||||
assert not isinstance(other, sbitfixvec)
|
||||
m = 1
|
||||
for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []):
|
||||
@@ -1076,7 +1070,10 @@ class sbitvec(_vec, _bit, _binary):
|
||||
res.append([x * sbits.get_type(m)().long_one()
|
||||
for x in util.bit_decompose(y, len(self.v))])
|
||||
else:
|
||||
v = [type(x)(x) if isinstance(x, bits) else x for x in y.v]
|
||||
if copy:
|
||||
v = [type(x)(x) if isinstance(x, bits) else x for x in y.v]
|
||||
else:
|
||||
v = y.v
|
||||
res.append([x.expand(m) if (expand and isinstance(x, bits))
|
||||
else x for x in v])
|
||||
return res
|
||||
@@ -1364,7 +1361,21 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
|
||||
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
"""
|
||||
Vector of signed integers for parallel binary computation.
|
||||
Values and vectors of signed integers for parallel binary computation::
|
||||
|
||||
si32 = sbitintvec.get_type(32)
|
||||
print_ln('add: %s', (si32(5) + si32(3)).reveal())
|
||||
print_ln('sub: %s', (si32(5) - si32(3)).reveal())
|
||||
print_ln('mul: %s', (si32(5) * si32(3)).reveal())
|
||||
print_ln('lt: %s', (si32(5) < si32(3)).reveal())
|
||||
|
||||
This should output::
|
||||
|
||||
add: 8
|
||||
sub: 2
|
||||
mul: 15
|
||||
lt: 0
|
||||
|
||||
The following example uses vectors of size two::
|
||||
|
||||
sb32 = sbits.get_type(32)
|
||||
@@ -1389,7 +1400,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
|
||||
"""
|
||||
bit_extend = staticmethod(_complement_two_extend)
|
||||
mul_functions = {}
|
||||
functions = {}
|
||||
@classmethod
|
||||
def popcnt_bits(cls, bits):
|
||||
return sbitvec.from_vec(bits).popcnt()
|
||||
@@ -1406,8 +1417,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
if len(a) == 1:
|
||||
res = _bitint.bit_adder(a, b, get_carry=True)
|
||||
return self.get_type(32).from_vec(res, signed=False)
|
||||
v = sbitint.bit_adder(a, b)
|
||||
return self.get_type(len(v)).from_vec(v)
|
||||
return self.maybe_function(self.binary_add, a, b)
|
||||
__radd__ = __add__
|
||||
__sub__ = _bitint.__sub__
|
||||
def __rsub__(self, other):
|
||||
@@ -1424,9 +1434,12 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
elif isinstance(other, sbitfixvec):
|
||||
return NotImplemented
|
||||
try:
|
||||
my_bits, other_bits = self.expand(other, False)
|
||||
my_bits, other_bits = self.expand(other, False, copy=True)
|
||||
except:
|
||||
return NotImplemented
|
||||
return self.maybe_function(self.binary_mul, my_bits, other_bits)
|
||||
@classmethod
|
||||
def maybe_function(cls, call, my_bits, other_bits, result_length=None):
|
||||
m = float('inf')
|
||||
uniform = True
|
||||
for x in itertools.chain(my_bits, other_bits):
|
||||
@@ -1437,21 +1450,26 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
pass
|
||||
if uniform and Program.prog.options.cisc:
|
||||
bl = len(my_bits)
|
||||
key = bl, len(other_bits)
|
||||
if key not in self.mul_functions:
|
||||
ol = result_length or bl
|
||||
key = call.__name__, ol, bl, len(other_bits)
|
||||
if key not in cls.functions:
|
||||
def instruction(*args):
|
||||
res = self.binary_mul(args[bl:2 * bl], args[2 * bl:],
|
||||
args[0].n)
|
||||
res = call(args[ol:ol + bl], args[ol + bl:], args[0].n)
|
||||
for x, y in zip(sbitvec.from_vec(res).v, args):
|
||||
x.mov(y, x)
|
||||
instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits))
|
||||
self.mul_functions[key] = instructions_base.cisc(instruction,
|
||||
bl)
|
||||
res = [sbits.get_type(m)() for i in range(bl)]
|
||||
self.mul_functions[key](*(res + my_bits + other_bits))
|
||||
return self.from_vec(res)
|
||||
instruction.__name__ = '%s%sx%s' % (call.__name__, bl, len(other_bits))
|
||||
cls.functions[key] = instructions_base.cisc(instruction, ol)
|
||||
res = [sbits.get_type(m)() for i in range(ol)]
|
||||
cls.functions[key](*(res + my_bits + other_bits))
|
||||
if result_length:
|
||||
return res
|
||||
else:
|
||||
return cls.from_vec(res)
|
||||
else:
|
||||
return self.binary_mul(my_bits, other_bits, m)
|
||||
return call(my_bits, other_bits, m)
|
||||
@classmethod
|
||||
def binary_add(cls, a, b, m):
|
||||
return cls.from_vec(sbitint.bit_adder(a, b))
|
||||
@classmethod
|
||||
def binary_mul(cls, my_bits, other_bits, m):
|
||||
matrix = []
|
||||
@@ -1468,21 +1486,21 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
def TruncMul(self, other, k, m, kappa=None, nearest=False):
|
||||
if nearest:
|
||||
raise CompilerError('round to nearest not implemented')
|
||||
if not isinstance(other, sbitintvec):
|
||||
other = sbitintvec(other)
|
||||
if isinstance(other, int):
|
||||
b = other
|
||||
else:
|
||||
if not isinstance(other, sbitintvec):
|
||||
other = sbitintvec(other)
|
||||
b = self.get_type(k).from_vec(_complement_two_extend(other.v, k))
|
||||
a = self.get_type(k).from_vec(_complement_two_extend(self.v, k))
|
||||
b = self.get_type(k).from_vec(_complement_two_extend(other.v, k))
|
||||
tmp = a * b
|
||||
assert len(tmp.v) == k
|
||||
return self.get_type(k - m).from_vec(tmp[m:])
|
||||
return self.get_type(k - m).from_vec(tmp.v[m:])
|
||||
def pow2(self, k):
|
||||
""" Computer integer power of two.
|
||||
|
||||
:param k: bit length of input """
|
||||
return _sbitintbase.pow2(self, k)
|
||||
@staticmethod
|
||||
def reverse_type(other):
|
||||
return isinstance(other, sbitfixvec)
|
||||
|
||||
sbits.vec = sbitvec
|
||||
sbitint.vec = sbitintvec
|
||||
@@ -1511,6 +1529,14 @@ class cbitfix(object):
|
||||
v = self.v
|
||||
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
|
||||
cbits(0), cbits(0))
|
||||
def __iter__(self):
|
||||
return iter([self])
|
||||
def error(*args, **kwargs):
|
||||
raise CompilerError(
|
||||
'Support for revealed fixed-point values in binary circuits '
|
||||
'is currently limited to simple outputs. '
|
||||
'Please file a feature request if you need this for an application.')
|
||||
__add__ = __mul__ = __sub__ = error
|
||||
|
||||
class sbitfix(_fix, _binary):
|
||||
""" Secret signed fixed-point number in one binary register.
|
||||
@@ -1581,14 +1607,29 @@ class sbitfix(_fix, _binary):
|
||||
return cls._new(cls.int_type(other), k, f)
|
||||
|
||||
class sbitfixvec(_fix, _vec, _binary):
|
||||
""" Vector of fixed-point numbers for parallel binary computation.
|
||||
"""
|
||||
Values and vectors of fixed-point numbers for parallel binary computation::
|
||||
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
print_ln('add: %s', (sbitfixvec(0.5) + sbitfixvec(0.3)).reveal())
|
||||
print_ln('mul: %s', (sbitfixvec(0.5) * sbitfixvec(0.3)).reveal())
|
||||
print_ln('sub: %s', (sbitfixvec(0.5) - sbitfixvec(0.3)).reveal())
|
||||
print_ln('lt: %s', (sbitfixvec(0.5) < sbitfixvec(0.3)).reveal())
|
||||
|
||||
Example::
|
||||
will output roughly::
|
||||
|
||||
a = sbitfixvec([sbitfix(0.3), sbitfix(0.5)])
|
||||
b = sbitfixvec([sbitfix(0.4), sbitfix(0.6)])
|
||||
add: 0.800003
|
||||
mul: 0.149994
|
||||
sub: 0.199997
|
||||
lt: 0
|
||||
|
||||
Note that the default precision (16 bits after the dot, 31 bits in
|
||||
total) only allows numbers up to :math:`2^{31-16-1} \\approx
|
||||
16000`. You can increase this using :py:func:`set_precision`.
|
||||
|
||||
Refer to the following example for the vector functionality::
|
||||
|
||||
a = sbitfixvec([sbitfixvec(0.3), sbitfixvec(0.5)])
|
||||
b = sbitfixvec([sbitfixvec(0.4), sbitfixvec(0.6)])
|
||||
c = (a + b).elements()
|
||||
print_ln('add: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
c = (a * b).elements()
|
||||
@@ -1606,13 +1647,12 @@ class sbitfixvec(_fix, _vec, _binary):
|
||||
lt: 1, 1
|
||||
|
||||
"""
|
||||
int_type = sbitintvec.get_type(sbitfix.k)
|
||||
float_type = type(None)
|
||||
clear_type = cbitfix
|
||||
rep_type = staticmethod(lambda x: x)
|
||||
@property
|
||||
def bit_type(self):
|
||||
return type(self.v[0])
|
||||
return type(self.v.v[0])
|
||||
@classmethod
|
||||
def set_precision(cls, f, k=None):
|
||||
super(sbitfixvec, cls).set_precision(f=f, k=k)
|
||||
@@ -1637,7 +1677,8 @@ class sbitfixvec(_fix, _vec, _binary):
|
||||
value = self.int_type(value)
|
||||
super(sbitfixvec, self).__init__(value, *args, **kwargs)
|
||||
def elements(self):
|
||||
return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()]
|
||||
return [sbitfixvec._new(x, f=self.f, k=self.k)
|
||||
for x in self.v.elements()]
|
||||
def mul(self, other):
|
||||
if isinstance(other, sbits):
|
||||
return self._new(self.v * other)
|
||||
|
||||
@@ -712,7 +712,8 @@ class Merger:
|
||||
elif isinstance(instr, StackInstruction):
|
||||
keep_order(instr, n, StackInstruction)
|
||||
elif isinstance(instr, applyshuffle):
|
||||
shuffles[instr.args[3]].add(n)
|
||||
for handle in instr.handles():
|
||||
shuffles[handle].add(n)
|
||||
elif isinstance(instr, delshuffle):
|
||||
for i_inst in shuffles[instr.args[0]]:
|
||||
add_edge(i_inst, n)
|
||||
|
||||
@@ -61,14 +61,15 @@ class Circuit:
|
||||
return self.run(*inputs)
|
||||
|
||||
def run(self, *inputs):
|
||||
n = inputs[0][0].n, get_tape()
|
||||
inputs = [sbitvec.from_vec(x) for x in inputs]
|
||||
n = inputs[0].v[0].n, get_tape()
|
||||
if n not in self.functions:
|
||||
if get_program().force_cisc_tape:
|
||||
f = function_call_tape
|
||||
else:
|
||||
f = function_block
|
||||
self.functions[n] = f(lambda *args: self.compile(*args))
|
||||
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
|
||||
self.functions[n].name = '%s(%d)' % (self.name, inputs[0].v[0].n)
|
||||
flat_res = self.functions[n](*itertools.chain(*(
|
||||
sbitvec.from_vec(x).v for x in inputs)))
|
||||
res = []
|
||||
@@ -208,7 +209,7 @@ def sha3_256(x):
|
||||
for x in range(5):
|
||||
for i in range(w):
|
||||
j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
|
||||
res[x][y][i] = S_flat[1600 - 1 -j]
|
||||
res[x][y][i] = S_flat.v[1600 - 1 -j]
|
||||
return res
|
||||
|
||||
w = 64
|
||||
@@ -313,7 +314,7 @@ class ieee_float:
|
||||
|
||||
for i in range(2):
|
||||
for j in range(10):
|
||||
values.append(sbitint.get_type(64).get_input_from(i))
|
||||
values.append(sbitintvec.get_type(64).get_input_from(i))
|
||||
|
||||
fvalues = [ieee_float(x) for x in values]
|
||||
|
||||
|
||||
@@ -67,16 +67,21 @@ def ld2i(c, n):
|
||||
def maybe_mulm(res, x, y):
|
||||
# overwrite instruction for function-dependent preprocessing protocols
|
||||
from Compiler import types
|
||||
res.link(x * y)
|
||||
program.curr_block.replace_last_reg(res, x * y)
|
||||
|
||||
def require_ring_size(k, op, suffix=''):
|
||||
def require_ring_size(k, op, suffix='', slack=0):
|
||||
if not program.options.ring:
|
||||
return
|
||||
diff = slack * (not program.allow_tight_parameters)
|
||||
k += diff
|
||||
if int(program.options.ring) < k:
|
||||
msg = 'ring size too small for %s, compile ' \
|
||||
'with \'-R %d\' or more' % (op, k)
|
||||
if k > 64 and k < 128:
|
||||
msg += ' (maybe \'-R 128\' as it is supported by default)'
|
||||
if int(program.options.ring) >= k - diff:
|
||||
msg += ", alternatively set " \
|
||||
"'program.allow_tight_parameters=True' in the program"
|
||||
raise CompilerError(msg + suffix)
|
||||
program.curr_tape.require_bit_length(k)
|
||||
|
||||
@@ -97,13 +102,13 @@ def LtzRingRaw(a, k):
|
||||
from .types import sint, _bitint
|
||||
from .GC.types import sbitvec
|
||||
if program.use_split():
|
||||
program.reading('comparison', 'ABY3')
|
||||
program.reading('comparison', 'Keller25', 'Section 6')
|
||||
summands = a.split_to_two_summands(k)
|
||||
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
|
||||
msb = carry ^ summands[0][-1] ^ summands[1][-1]
|
||||
return msb
|
||||
else:
|
||||
program.reading('comparison', 'DEK20-pre')
|
||||
program.reading('comparison', 'DEK20-pre', 'Paragraph III.D.8')
|
||||
from . import floatingpoint
|
||||
require_ring_size(k, 'comparison')
|
||||
m = k - 1
|
||||
@@ -195,7 +200,7 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
if k == m:
|
||||
return 0
|
||||
assert k > m
|
||||
program.reading('truncation', 'DEK20-pre')
|
||||
program.reading('truncation', 'DEK20-pre', 'Paragraph III.D.4')
|
||||
require_ring_size(k, 'leaky truncation')
|
||||
from .types import sint, intbitint, cint, cgf2n
|
||||
n_bits = k - m
|
||||
@@ -239,7 +244,7 @@ def Mod2m(a_prime, a, k, m, signed):
|
||||
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
|
||||
|
||||
def Mod2mRing(a_prime, a, k, m, signed):
|
||||
program.reading('modulo', 'DEK20-pre')
|
||||
program.reading('modulo', 'DEK20-pre', 'Paragraph III.D.3')
|
||||
require_ring_size(k, 'modulo power of two')
|
||||
from Compiler.types import sint, intbitint, cint
|
||||
shift = int(program.options.ring) - m
|
||||
@@ -254,7 +259,7 @@ def Mod2mRing(a_prime, a, k, m, signed):
|
||||
return res
|
||||
|
||||
def Mod2mField(a_prime, a, k, m, signed):
|
||||
program.reading('modulo', 'CdH10')
|
||||
program.reading('modulo', 'CdH10', 'Protocol 3.2')
|
||||
from .types import sint
|
||||
r_dprime = program.curr_block.new_reg('s')
|
||||
r_prime = program.curr_block.new_reg('s')
|
||||
@@ -349,6 +354,8 @@ def BitLTC1(u, a, b):
|
||||
a: array of clear bits
|
||||
b: array of secret bits (same length as a)
|
||||
"""
|
||||
program.reading('constant-round bit-wise public-private comparison',
|
||||
'CdH10', 'Protocol 4.5')
|
||||
k = len(b)
|
||||
p = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
from . import floatingpoint
|
||||
@@ -489,6 +496,8 @@ def BitLTL(res, a, b):
|
||||
a: clear integer register
|
||||
b: array of secret bits (same length as a)
|
||||
"""
|
||||
program.reading('logarithmic-round bit-wise public-private comparison',
|
||||
'CdH10', 'Protocol 4.1')
|
||||
k = len(b)
|
||||
a_bits = b[0].bit_decompose_clear(a, k)
|
||||
from .types import sint
|
||||
@@ -655,7 +664,7 @@ def Mod2(a_0, a, k, signed):
|
||||
if k <= 1:
|
||||
movs(a_0, a)
|
||||
return
|
||||
program.reading('modulo', 'CdH10')
|
||||
program.reading('modulo', 'CdH10', 'Protocol 3.4')
|
||||
r_dprime = program.curr_block.new_reg('s')
|
||||
r_prime = program.curr_block.new_reg('s')
|
||||
r_0 = program.curr_block.new_reg('s')
|
||||
|
||||
@@ -587,6 +587,12 @@ class Compiler:
|
||||
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
|
||||
print("Memory size:", dict(self.prog.allocated_mem))
|
||||
|
||||
comm = self.prog.expected_communication()
|
||||
if sum(comm):
|
||||
print(
|
||||
"Expected communication is %g MB online and %g MB offline." % \
|
||||
(comm[0] / 1e6, comm[1] / 1e6))
|
||||
|
||||
return self.prog
|
||||
|
||||
match = {
|
||||
@@ -608,6 +614,13 @@ class Compiler:
|
||||
else:
|
||||
return protocol + "-party.x"
|
||||
|
||||
@classmethod
|
||||
def short_protocol_name(cls, protocol):
|
||||
for x in cls.match.items():
|
||||
if protocol == x[1]:
|
||||
return x[0]
|
||||
return re.sub('^malicious-', 'mal-', protocol)
|
||||
|
||||
def local_execution(self, args=None):
|
||||
if args is None:
|
||||
args = self.runtime_args
|
||||
@@ -651,6 +664,7 @@ class Compiler:
|
||||
destinations.append('.')
|
||||
connections = [Connection(hostname) for hostname in hostnames]
|
||||
print("Setting up players...")
|
||||
lockfile = ".transfer.lock"
|
||||
|
||||
def run(i):
|
||||
dest = destinations[i]
|
||||
@@ -658,6 +672,16 @@ class Compiler:
|
||||
connection.run(
|
||||
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
|
||||
dest)
|
||||
dest_lockfile = "%s/%s" % (dest, lockfile)
|
||||
try:
|
||||
connection.run("test -e %s && exit 1; touch %s" % (
|
||||
(dest_lockfile,) * 2))
|
||||
except:
|
||||
raise Exception(
|
||||
"Problem with %s on %s. You cannot use the same directory "
|
||||
"for several instances (including the control instance). "
|
||||
"Remove %s on %s if this has been left behind from an "
|
||||
"aborted exection." % ((dest_lockfile, hostnames[i]) * 2))
|
||||
# executable
|
||||
connection.put("%s/static/%s" % (self.root, vm), dest)
|
||||
# program
|
||||
@@ -676,10 +700,12 @@ class Compiler:
|
||||
dest + "Player-Data")
|
||||
for filename in glob.glob("Player-Data/*.0"):
|
||||
connection.put(filename, dest + "Player-Data")
|
||||
connection.run("rm %s" % dest_lockfile)
|
||||
|
||||
def run_with_error(i):
|
||||
try:
|
||||
run(i)
|
||||
copied[i] = True
|
||||
except IOError:
|
||||
print('IO error when copying files, does %s have enough space?' %
|
||||
hostnames[i])
|
||||
@@ -693,13 +719,19 @@ class Compiler:
|
||||
out = fn(i)
|
||||
outputs[i] = out
|
||||
|
||||
open(lockfile, "w")
|
||||
threads = []
|
||||
copied = [False] * len(hosts)
|
||||
for i in range(len(hosts)):
|
||||
threads.append(threading.Thread(target=run_with_error, args=(i,)))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
os.remove(lockfile)
|
||||
if False in copied:
|
||||
print("Error in remote copying, see above")
|
||||
sys.exit(1)
|
||||
|
||||
# execution
|
||||
threads = []
|
||||
|
||||
485
Compiler/cost.py
Normal file
485
Compiler/cost.py
Normal file
@@ -0,0 +1,485 @@
|
||||
import re
|
||||
import math
|
||||
import os
|
||||
import itertools
|
||||
|
||||
class Comm:
|
||||
def __init__(self, comm=0, offline=0):
|
||||
try:
|
||||
comm = comm()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
self.online, self.offline = comm
|
||||
assert not offline
|
||||
except:
|
||||
self.online = comm or 0
|
||||
assert isinstance(self.online, (int, float))
|
||||
self.offline = offline
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.offline if index else self.online
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.online, self.offline))
|
||||
|
||||
def __add__(self, other):
|
||||
return Comm(x + y for x, y in zip(self, other))
|
||||
|
||||
def __sub__(self, other):
|
||||
return self + -1 * other
|
||||
|
||||
def __mul__(self, other):
|
||||
return Comm(x * other for x in self)
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __repr__(self):
|
||||
return 'Comm(%d, %d)' % tuple(self)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(sum(self))
|
||||
|
||||
def sanitize(self):
|
||||
try:
|
||||
return tuple(int(x) for x in self)
|
||||
except:
|
||||
return (0, 0)
|
||||
|
||||
dishonest_majority = {
|
||||
'emi',
|
||||
'mascot',
|
||||
'spdz',
|
||||
'soho',
|
||||
'gear',
|
||||
}
|
||||
|
||||
semihonest = {
|
||||
'emi|soho',
|
||||
'atlas|^shamir',
|
||||
'dealer',
|
||||
}
|
||||
|
||||
ring = {
|
||||
'ring',
|
||||
'2k',
|
||||
}
|
||||
|
||||
fixed = {
|
||||
'^(ring|rep-field)': 3,
|
||||
'rep4': 6,
|
||||
'mal-rep-field': (6, 9),
|
||||
'mal-rep-ring': (lambda l: (6 * l, (l + 5) * 9)),
|
||||
'sy-rep-field': 6,
|
||||
'sy-rep-ring': lambda l: (6 * (l + 5), 0),
|
||||
'ps-rep-field': 9,
|
||||
'ps-rep-ring': lambda l: 9 * (l + 5),
|
||||
'brain': lambda l: (3 * 2 * l, 3 * (2 * (l + 5) + 3 * (2 * l + 15))),
|
||||
}
|
||||
|
||||
ot_cost = 64
|
||||
spdz2k_sec = 64
|
||||
|
||||
def lowgear_cipher_length(l):
|
||||
res = (30 + 2 * l) // 8
|
||||
return res
|
||||
|
||||
def highgear_cipher_lengths(l):
|
||||
res = 71 + 16 * l, 57 + 8 * l
|
||||
return res
|
||||
|
||||
def highgear_cipher_limbs(l):
|
||||
res = sum(int(math.ceil(x / 64)) for x in highgear_cipher_lengths(l))
|
||||
return res
|
||||
|
||||
def highgear_decrypt_length(l):
|
||||
return highgear_cipher_lengths(l)[0] / 8 + 1
|
||||
|
||||
def hemi_cipher_length(l):
|
||||
res = 16 * l + 77
|
||||
return res
|
||||
|
||||
def hemi_cipher_limbs(l):
|
||||
res = int(math.ceil(hemi_cipher_length(l) / 64))
|
||||
return res
|
||||
|
||||
variable = {
|
||||
'^shamir': lambda N: N * (N - 1) // 2,
|
||||
'atlas': lambda N: N // 2 * 4,
|
||||
'dealer': lambda N: (2 * (N - 1), 1),
|
||||
'semi': lambda N: lambda l: (
|
||||
4 * (N - 1) * l, N * (N - 1) * (l * (ot_cost + 8 * l))),
|
||||
'mascot': lambda N: lambda l: (
|
||||
4 * (N - 1) * l, N * (N - 1) * (l * (3 * ot_cost + 64 * l))),
|
||||
'spdz2k': lambda N: lambda l: (
|
||||
4 * (N - 1) * l,
|
||||
N * (N - 1) * (ot_cost * (2 * l + 4 * spdz2k_sec // 8) + \
|
||||
(l + spdz2k_sec // 8) * (4 * spdz2k_sec + 2 * l * 8) + \
|
||||
(5 * (l + 2 * spdz2k_sec // 8) * spdz2k_sec))),
|
||||
'hemi': lambda N: lambda l: (
|
||||
4 * (N - 1) * l, N * (N - 1) * hemi_cipher_limbs(l) * 8 * 2 * 2),
|
||||
'temi': lambda N: lambda l: (
|
||||
4 * (N - 1) * l, (N - 1) * (hemi_cipher_limbs(l) * 8 * 2 * 2 +
|
||||
hemi_cipher_length(l) / 8 + 1) * 2),
|
||||
'soho': lambda N: lambda l: (
|
||||
4 * (N - 1) * l,
|
||||
(N - 1) * (N * highgear_cipher_limbs(l) * 8 * 2 +
|
||||
highgear_decrypt_length(l)) * 2),
|
||||
'owgear': lambda N: lambda l: (
|
||||
4 * (N - 1) * l,
|
||||
N * ((N - 1) * (lowgear_cipher_length(l) * (128 + 48) + 64) + 2 * l)),
|
||||
'.*i.*gear': lambda N: lambda l: (
|
||||
4 * (N - 1) * l,
|
||||
(N - 1) * (highgear_cipher_limbs(l) * 96 * 3 +
|
||||
highgear_decrypt_length(l) * 16 + N * 192 + 6 * l)),
|
||||
'sy-shamir': lambda N: 2 * variable['^shamir'](N) + variable_random['^shamir|atlas'](N)
|
||||
}
|
||||
|
||||
variable_square = {
|
||||
'soho': lambda N: lambda l: (
|
||||
0, (N - 1) * (N * highgear_cipher_limbs(l) * 8 + 46) * 2),
|
||||
'i.*gear': lambda N: lambda l: (
|
||||
0, (N - 1) * (highgear_cipher_limbs(l) * 64 * 3 +
|
||||
highgear_decrypt_length(l) * 12 + N * 128 + 4 * l)),
|
||||
'ps-rep-ring': lambda N: lambda l: fixed['ps-rep-ring'](l),
|
||||
'sy-shamir': lambda N: (
|
||||
0, variable['sy-shamir'](N) + variable_random['sy-shamir'](N))
|
||||
}
|
||||
|
||||
matrix_triples = {
|
||||
'dealer': lambda N: (N - 1, 1),
|
||||
}
|
||||
|
||||
diag_matrix = {
|
||||
'hemi': lambda N, l, dims: N * (N - 1) * hemi_cipher_limbs(l) * 8 * 2 * \
|
||||
(dims[0] * dims[1] + dims[0] * dims[2]),
|
||||
'temi': lambda N, l, dims: (N - 1) * (
|
||||
hemi_cipher_limbs(l) * 8 * 2 * 2 * (
|
||||
dims[0] * dims[1] + dims[0] * dims[2]) +
|
||||
(hemi_cipher_length(l) / 8 + 1) * 2 * (dims[0] * dims[2])),
|
||||
}
|
||||
|
||||
fixed_bit = {
|
||||
'mal-rep-field': (0, 11),
|
||||
'rep4': (0, 8),
|
||||
}
|
||||
|
||||
fixed_square = {
|
||||
'mal-rep-ring': lambda l: fixed['mal-rep-ring'](l)[1],
|
||||
}
|
||||
|
||||
variable_bit = {
|
||||
'dealer': lambda N: (0, 1),
|
||||
# missing OT cost
|
||||
'emi': lambda N: lambda l: (0, l + ot_cost / 8) if N == 2 else None,
|
||||
'mal-shamir': lambda N: (
|
||||
0, variable_random['^shamir|atlas'](N) + \
|
||||
math.ceil(N / 2) * variable_input['^shamir|atlas'](N) + \
|
||||
(math.ceil(N / 2) - 0) * variable['^shamir'](N) + \
|
||||
2 * reveal_variable['(mal|sy)-shamir'](N)),
|
||||
}
|
||||
|
||||
fixed_and = {
|
||||
'(mal|sy|ps)-rep': lambda bucket_size=4: (6, 3 * (3 * bucket_size - 2)),
|
||||
}
|
||||
|
||||
variable_and = {
|
||||
'emi': lambda N: (4 * (N - 1), N * (N - 1) * ot_cost)
|
||||
}
|
||||
|
||||
trunc_pr = {
|
||||
'^ring': 4,
|
||||
'rep-field': 1,
|
||||
'rep4': 12,
|
||||
}
|
||||
|
||||
bit2a = {
|
||||
'^(ring|rep-field)': 3,
|
||||
}
|
||||
|
||||
dabit_from_bit = {
|
||||
'ring',
|
||||
'-rep-ring',
|
||||
'semi2k',
|
||||
}
|
||||
|
||||
bits_from_squares = {
|
||||
'atlas': lambda N: N > 4,
|
||||
'sy-shamir': lambda N: True,
|
||||
'soho': lambda N: True,
|
||||
'gear': lambda N: True,
|
||||
'ps-rep-ring': lambda N: True,
|
||||
'spdz2k': lambda N: True,
|
||||
'mascot': lambda N: True,
|
||||
'mal-rep-ring': lambda N: True,
|
||||
'emi$': lambda N: True,
|
||||
}
|
||||
|
||||
reveal = {
|
||||
'((^|rep.*)ring|rep-field|brain)': 3,
|
||||
'rep4': 4,
|
||||
}
|
||||
|
||||
reveal_variable = {
|
||||
'^shamir|atlas': lambda N: 3 * (N - 1) // 2,
|
||||
'(mal|sy)-shamir': lambda N: (N - 1) // 2 * 2 * N,
|
||||
'dealer': lambda N: 2 * (N - 2),
|
||||
'spdz2k': lambda N: N * variable_input['mascot|spdz2k'](N),
|
||||
}
|
||||
|
||||
fixed_input = {
|
||||
'(^|ps-|mal-)(ring|rep-)': 1,
|
||||
'sy-rep-ring': lambda l: 4 * (l + 5),
|
||||
'sy-rep-field': 4,
|
||||
'rep4': 2,
|
||||
}
|
||||
|
||||
variable_input = {
|
||||
'^shamir|atlas': lambda N: N // 2,
|
||||
'mal-shamir': lambda N: N // 2,
|
||||
'sy-shamir': lambda N: \
|
||||
N // 2 + variable['^shamir'](N) + variable_random['^shamir|atlas'](N),
|
||||
'mascot|spdz2k': lambda N: (N - 1) * Comm(1, ot_cost * 2),
|
||||
'owgear': lambda N: lambda l: (
|
||||
(N - 1) * l, (N - 1) * lowgear_cipher_length(l) * 16),
|
||||
'i.*gear': lambda N: lambda l: (
|
||||
(N - 1) * l, (N - 1) * (highgear_cipher_limbs(l) * 24 + 32 +
|
||||
highgear_decrypt_length(l) * 4)),
|
||||
}
|
||||
|
||||
variable_random = {
|
||||
'^shamir|atlas': lambda N: N * (N // 2) / ((N + 2) // 2),
|
||||
'mal-shamir': lambda N: N // 2 * N,
|
||||
'sy-shamir': lambda N: \
|
||||
2 * variable_random['^shamir|atlas'](N) + variable['^shamir'](N),
|
||||
}
|
||||
|
||||
# cut random values
|
||||
fixed_randoms = {
|
||||
'sy-rep-ring': lambda l: 3 * (l + 5),
|
||||
}
|
||||
|
||||
cheap_dot_product = {
|
||||
'^(ring|rep-field)',
|
||||
'sy-*',
|
||||
'^shamir',
|
||||
'rep4',
|
||||
'atlas',
|
||||
}
|
||||
|
||||
shuffle_application = {
|
||||
'^(ring|rep-field)': 6,
|
||||
'sy-rep-field': 12,
|
||||
'sy-rep-ring': lambda l: 12 * (l + 5)
|
||||
}
|
||||
|
||||
variable_edabit = {
|
||||
'dealer': lambda N: lambda n_bits: lambda l: l + n_bits / 8
|
||||
}
|
||||
|
||||
def find_match(data, protocol):
|
||||
for x in data:
|
||||
if re.search(x, protocol):
|
||||
return x
|
||||
|
||||
def get_match(data, protocol):
|
||||
x = find_match(data, protocol)
|
||||
try:
|
||||
return data.get(x)
|
||||
except:
|
||||
return bool(x)
|
||||
|
||||
def get_match_variable(data, protocol, n_parties):
|
||||
f = get_match(data, protocol)
|
||||
if f:
|
||||
return f(n_parties)
|
||||
|
||||
def apply_length(unit, length):
|
||||
try:
|
||||
return Comm(unit(length))
|
||||
except:
|
||||
return Comm(unit) * length
|
||||
|
||||
def get_cost(fixed, variable, protocol, n_parties):
|
||||
return get_match(fixed, protocol) or \
|
||||
get_match_variable(variable, protocol, n_parties)
|
||||
|
||||
def get_mul_cost(protocol, n_parties):
|
||||
return get_cost(fixed, variable, protocol, n_parties)
|
||||
|
||||
def get_and_cost(protocol, n_parties):
|
||||
return get_cost(fixed_and, variable_and, protocol, n_parties)
|
||||
|
||||
def expected_communication(protocol, req_num, length, n_parties=None,
|
||||
force_triple_use=False):
|
||||
from Compiler.instructions import shuffle_base
|
||||
from Compiler.program import Tape
|
||||
get_int = lambda x: req_num.get(('modp', x), 0)
|
||||
get_bit = lambda x: req_num.get(('bit', x), 0)
|
||||
res = Comm()
|
||||
if not protocol:
|
||||
return res
|
||||
if not n_parties:
|
||||
try:
|
||||
if get_match(fixed, protocol):
|
||||
raise TypeError()
|
||||
n_parties = int(os.getenv('PLAYERS'))
|
||||
except TypeError:
|
||||
if find_match(dishonest_majority, protocol):
|
||||
n_parties = 2
|
||||
else:
|
||||
n_parties = 3
|
||||
if find_match(dishonest_majority, protocol):
|
||||
threshold = n_parties - 1
|
||||
elif re.match('rep4', protocol):
|
||||
n_parties = 4
|
||||
threshold = 1
|
||||
elif re.match('dealer', protocol):
|
||||
threshold = 0
|
||||
else:
|
||||
threshold = n_parties // 2
|
||||
malicious = not find_match(semihonest, protocol)
|
||||
x = find_match(fixed, protocol)
|
||||
y = get_mul_cost(protocol, n_parties)
|
||||
unit = apply_length(y, length)
|
||||
n_mults = get_int('simple multiplication')
|
||||
matrix_cost = apply_length(
|
||||
get_match_variable(matrix_triples, protocol, n_parties), length)
|
||||
use_diag_matrix = get_match(diag_matrix, protocol)
|
||||
use_triple_number = False
|
||||
if find_match(cheap_dot_product, protocol):
|
||||
n_mults += get_int('dot product')
|
||||
elif (not matrix_cost and not use_diag_matrix) or force_triple_use:
|
||||
use_triple_number = True
|
||||
n_mults = get_int('triple')
|
||||
and_cost = get_and_cost(protocol, n_parties)
|
||||
if and_cost:
|
||||
res += Comm(and_cost) * math.ceil(get_bit('triple') / 8)
|
||||
else:
|
||||
n_mults += get_bit('triple') / (length * 8)
|
||||
bit_cost = Comm(apply_length(
|
||||
bit2a.get(x) or get_match(fixed_bit, protocol) or
|
||||
get_match_variable(variable_bit, protocol, n_parties),
|
||||
length))
|
||||
input_cost = apply_length(
|
||||
get_match(fixed_input, protocol) or \
|
||||
get_match_variable(variable_input, protocol, n_parties), length)
|
||||
output_cost = Comm(
|
||||
get_match(reveal, protocol) or \
|
||||
get_match_variable(reveal_variable, protocol, n_parties) or \
|
||||
(n_parties - 1) * 2)
|
||||
random_cost = apply_length(
|
||||
get_match_variable(variable_random, protocol, n_parties), length)
|
||||
if not random_cost:
|
||||
random_cost = n_parties * input_cost
|
||||
square_unit = get_cost(fixed_square, variable_square, protocol, n_parties)
|
||||
if not square_unit:
|
||||
def square_unit(l):
|
||||
unit = apply_length(y, l)
|
||||
return Comm(0, (unit[1] or unit[0]) + sum(
|
||||
apply_length(output_cost, l)))
|
||||
square_cost = apply_length(square_unit, length)
|
||||
res += square_cost * get_int('square')
|
||||
if bit_cost:
|
||||
res += bit_cost * get_int('bit')
|
||||
elif get_match_variable(bits_from_squares, protocol, n_parties):
|
||||
if square_cost:
|
||||
if get_match(ring, protocol):
|
||||
sb_cost = apply_length(square_unit, length + 1)
|
||||
else:
|
||||
sb_cost = square_cost
|
||||
bit_cost = Comm(0, offline=sum(sb_cost + output_cost * length))
|
||||
else:
|
||||
bit_cost = Comm(0, offline=sum(
|
||||
unit + random_cost + length * output_cost))
|
||||
res += bit_cost * get_int('bit')
|
||||
else:
|
||||
bit_cost = Comm(0, offline=sum(
|
||||
threshold * unit + (threshold + 1) * input_cost))
|
||||
res += bit_cost * get_int('bit')
|
||||
res += unit * n_mults
|
||||
if not unit:
|
||||
sh_protocol = re.sub('mal-', '', protocol)
|
||||
sh_unit = get_mul_cost(sh_protocol, n_parties)
|
||||
sh_random_unit = get_match_variable(
|
||||
variable_random, sh_protocol, n_parties)
|
||||
if sh_unit:
|
||||
res += length * Comm(
|
||||
sum(2 * output_cost * n_mults),
|
||||
int(n_mults * (3 * sh_random_unit + 2 * sh_unit + \
|
||||
2 * sum(output_cost))))
|
||||
res += Comm(get_match(trunc_pr, protocol)) * length * \
|
||||
get_int('probabilistic truncation')
|
||||
res += Comm(bit2a.get(x)) * length * get_int('bit2A')
|
||||
res += output_cost * length * get_int('open')
|
||||
res += output_cost * get_bit('open')
|
||||
res += get_match(dabit_from_bit, protocol) * bit_cost * get_int('dabit')
|
||||
res += random_cost * get_int('random')
|
||||
res += get_int('cut random') * apply_length(
|
||||
get_match(fixed_randoms, protocol), length)
|
||||
shuffle_correction = not find_match(shuffle_application, protocol)
|
||||
def get_node():
|
||||
req_node = Tape.ReqNode("")
|
||||
req_node.aggregate()
|
||||
return req_node
|
||||
for x in req_num:
|
||||
if len(x) >= 3 and x[0] == 'modp':
|
||||
if x[1] == 'input':
|
||||
res += input_cost * req_num[x]
|
||||
elif x[1] == 'shuffle application':
|
||||
shuffle_cost = apply_length(
|
||||
get_match(shuffle_application, protocol), length)
|
||||
if shuffle_cost:
|
||||
res += Comm(shuffle_cost) * req_num[x] * x[2]
|
||||
elif find_match(cheap_dot_product, protocol) or \
|
||||
'dealer' in protocol:
|
||||
res += shuffle_base.n_swaps(x[2]) * (threshold + 1) * \
|
||||
req_num[x] * unit * (x[3] + malicious)
|
||||
elif shuffle_correction:
|
||||
node = get_node()
|
||||
shuffle_base.add_apply_usage(
|
||||
node, x[2], x[3], add_shuffles=False)
|
||||
node.num = -node.num
|
||||
if not use_triple_number:
|
||||
node.num['modp', 'triple'] = 0
|
||||
shuffle_base.add_apply_usage(
|
||||
node, x[2], x[3], add_shuffles=False,
|
||||
malicious=malicious, n_relevant_parties=threshold + 1)
|
||||
res += req_num[x] * \
|
||||
expected_communication(protocol, node.num, length,
|
||||
force_triple_use=True)
|
||||
elif x[1] == 'shuffle generation':
|
||||
if 'dealer' in protocol:
|
||||
res += Comm(
|
||||
req_num[x] * shuffle_base.n_swaps(x[2]) * length)
|
||||
else:
|
||||
req_node = get_node()
|
||||
shuffle_base.add_gen_usage(
|
||||
req_node, x[2], add_shuffles=False)
|
||||
req_node.num = -req_node.num
|
||||
if shuffle_correction:
|
||||
shuffle_base.add_gen_usage(
|
||||
req_node, x[2], add_shuffles=False,
|
||||
malicious=malicious,
|
||||
n_relevant_parties=threshold + 1)
|
||||
res += req_num[x] * \
|
||||
expected_communication(protocol, req_node.num, length)
|
||||
elif x[0] == 'matmul':
|
||||
mm_unit = Comm()
|
||||
if use_diag_matrix:
|
||||
dims = list(x[1])
|
||||
if dims[0] > dims[2]:
|
||||
dims[0::2] = dims[2::-2]
|
||||
mm_unit += Comm(
|
||||
offline=use_diag_matrix(n_parties, length, dims))
|
||||
matrix_cost = Comm(unit.online / 2)
|
||||
for idx in ((0, 1), (1, 2)):
|
||||
mm_unit += Comm(matrix_cost.online) * \
|
||||
x[1][idx[0]] * x[1][idx[1]]
|
||||
mm_unit += Comm(offline=matrix_cost.offline) * x[1][0] * x[1][2]
|
||||
res += mm_unit * req_num[x]
|
||||
elif re.search('edabit', x[0]):
|
||||
edabit = get_match_variable(variable_edabit, protocol, n_parties)
|
||||
if edabit:
|
||||
res += Comm(offline=edabit(x[1])(length)) * req_num[x]
|
||||
res.n_parties = n_parties
|
||||
return res
|
||||
@@ -636,7 +636,8 @@ def preprocess_pandas(data):
|
||||
elif pandas.api.types.is_object_dtype(t):
|
||||
values = list(filter(lambda x: isinstance(x, str),
|
||||
list(data.iloc[:,i].unique())))
|
||||
print('converting the following to unary:', values)
|
||||
print('converting the following to unary from %d: %s' %
|
||||
(len(res), values))
|
||||
if len(values) == 2:
|
||||
res.append(data.iloc[:,i].to_numpy() == values[1])
|
||||
types.append('b')
|
||||
|
||||
@@ -98,7 +98,7 @@ class HeapQ(object):
|
||||
self.size = MemValue(int_type(0))
|
||||
self.int_type = int_type
|
||||
self.basic_type = basic_type
|
||||
prog.reading('heap queue', 'KS14')
|
||||
prog.reading('heap queue', 'KS14', 'Section 5.1')
|
||||
print('heap: %d levels, depth %d, size %d, index size %d' % \
|
||||
(self.levels, self.depth, self.heap.oram.size, self.value_index.size))
|
||||
def update(self, value, prio, for_real=True):
|
||||
@@ -243,7 +243,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None,
|
||||
:param int_type: secret integer type (default: sint)
|
||||
|
||||
"""
|
||||
prog.reading("Dijkstra's algorithm", "KS14")
|
||||
prog.reading("Dijkstra's algorithm", "KS14", "Section 5.2")
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else -1
|
||||
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
|
||||
|
||||
@@ -59,7 +59,7 @@ def EQZ(a, k):
|
||||
v = sbitvec(a, k).v
|
||||
bit = util.tree_reduce(operator.and_, (~b for b in v))
|
||||
return types.sintbit.conv(bit)
|
||||
prog.reading('equality', 'ABZS13')
|
||||
prog.reading('equality', 'CdH10', 'Protocol 3.7')
|
||||
return prog.non_linear.eqz(a, k)
|
||||
|
||||
def bits(a,m):
|
||||
@@ -313,9 +313,10 @@ def BitDecRingRaw(a, k, m):
|
||||
return bits[:m]
|
||||
else:
|
||||
if program.Program.prog.use_edabit():
|
||||
r, r_bits = types.sint.get_edabit(m, strict=False)
|
||||
r, r_bits = types.sint.get_edabit(m, strict=False, size=a.size)
|
||||
elif program.Program.prog.use_dabit:
|
||||
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
|
||||
r, r_bits = zip(*(types.sint.get_dabit(size=a.size)
|
||||
for i in range(m)))
|
||||
r = types.sint.bit_compose(r)
|
||||
else:
|
||||
r_bits = [types.sint.get_random_bit() for i in range(m)]
|
||||
@@ -334,7 +335,8 @@ def BitDecRing(a, k, m):
|
||||
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, bits_to_compute=None):
|
||||
comparison.program.reading('bit decomposition', 'ABZS13')
|
||||
comparison.program.reading('bit decomposition', 'CdH10-fixed',
|
||||
'Protocol 3.7')
|
||||
instructions_base.set_global_vector_size(a.size)
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
@@ -362,7 +364,7 @@ def Pow2(a, l):
|
||||
return Pow2_from_bits(t)
|
||||
|
||||
def Pow2_from_bits(bits):
|
||||
comparison.program.reading('power of two', 'ABZS13')
|
||||
comparison.program.reading('power of two', 'ABZS13', 'Section 3')
|
||||
m = len(bits)
|
||||
t = list(bits)
|
||||
pow2k = [None for i in range(m)]
|
||||
@@ -419,7 +421,7 @@ def Trunc(a, l, m, compute_modulo=False, signed=False):
|
||||
return TruncInRing(a, l, Pow2(m, l))
|
||||
else:
|
||||
kappa = program.Program.prog.security
|
||||
prog.reading('secret truncation', 'ABZS13')
|
||||
prog.reading('secret truncation', 'ABZS13', 'Section 3')
|
||||
r = [types.sint() for i in range(l)]
|
||||
r_dprime = types.sint(0)
|
||||
r_prime = types.sint(0)
|
||||
@@ -460,7 +462,7 @@ def Trunc(a, l, m, compute_modulo=False, signed=False):
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def TruncInRing(to_shift, l, pow2m):
|
||||
comparison.program.reading('secret truncation', 'DEK20')
|
||||
comparison.program.reading('secret truncation', 'DEK20', 'Section 3.2.3')
|
||||
n_shift = int(program.Program.prog.options.ring) - l
|
||||
bits = util.bit_decompose(to_shift, l)
|
||||
rev = types.sint.bit_compose(reversed(bits))
|
||||
@@ -564,7 +566,8 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
res = sint()
|
||||
trunc_pr(res, a, k, m)
|
||||
else:
|
||||
prog.reading('probabilistic truncation', 'CdH10-fixed')
|
||||
prog.reading('probabilistic truncation', 'CdH10-fixed',
|
||||
'Protocol 3.1')
|
||||
# extra bit to mask overflow
|
||||
prog.curr_tape.require_bit_length(1)
|
||||
if prog.use_edabit() or prog.use_split() > 2:
|
||||
@@ -594,7 +597,7 @@ def TruncPrField(a, k, m):
|
||||
|
||||
program.Program.prog.trunc_pr_warning()
|
||||
prog = program.Program.prog
|
||||
prog.reading('probabilistic truncation', 'CdH10-fixed')
|
||||
prog.reading('probabilistic truncation', 'CdH10-fixed', 'Protocol 3.1')
|
||||
b = two_power(k-1) + a
|
||||
r_prime, r_dprime = types.sint(), types.sint()
|
||||
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
|
||||
@@ -632,7 +635,7 @@ def SDiv(a, b, l, round_nearest=False):
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest,
|
||||
signed=False)
|
||||
y = y.round(2 * l + 1, l + 1, nearest=round_nearest)
|
||||
y = y.round(2 * l + 1, l + 1, nearest=round_nearest, signed=False)
|
||||
return y
|
||||
|
||||
def SDiv_mono(a, b, l):
|
||||
@@ -684,7 +687,7 @@ def BITLT(a, b, bit_length):
|
||||
def BitDecFull(a, n_bits=None, maybe_mixed=False):
|
||||
from .library import get_program, do_while, if_, break_point
|
||||
from .types import sint, regint, longint, cint
|
||||
get_program().reading('full bit decomposition', 'NO07')
|
||||
get_program().reading('full bit decomposition', 'NO07', 'Figure 2')
|
||||
p = get_program().prime
|
||||
assert p
|
||||
bit_length = p.bit_length()
|
||||
@@ -731,6 +734,7 @@ def BitDecFull(a, n_bits=None, maybe_mixed=False):
|
||||
r = sint.get_edabit(bit_length, True)
|
||||
bs[j].link(r[0])
|
||||
tbits[j].link(sbitvec.from_vec(r[1]))
|
||||
tbits[j] = tbits[j].v
|
||||
else:
|
||||
for i in range(bit_length):
|
||||
tbits[j][i].link(sint.get_random_bit())
|
||||
|
||||
@@ -1348,6 +1348,9 @@ class randoms(base.Instruction):
|
||||
arg_format = ['sw','int']
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, 'cut random'), self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
class randomfulls(base.DataInstruction):
|
||||
""" Store share(s) of a fresh secret random element in secret
|
||||
@@ -1365,6 +1368,12 @@ class randomfulls(base.DataInstruction):
|
||||
return len(self.args)
|
||||
|
||||
class unsplit(base.VectorInstruction, base.Ciscable):
|
||||
""" Bit injection (conversion from binary to arithmetic).
|
||||
|
||||
:param: destination (sint)
|
||||
:param: source (sbits)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['UNSPLIT']
|
||||
arg_format = tools.chain(['sb'], itertools.repeat('sw'))
|
||||
@@ -2568,6 +2577,12 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction,
|
||||
for reg in self.args[i + 2:i + self.args[i]]:
|
||||
yield reg
|
||||
|
||||
def add_usage(self, req_num):
|
||||
base.DataInstruction.add_usage(self, req_num)
|
||||
req_num.increment(
|
||||
(self.field_type, 'dot product'),
|
||||
self.get_size() * len(list(self.bases(iter(self.args)))))
|
||||
|
||||
class matmul_base(base.DataInstruction):
|
||||
data_type = 'triple'
|
||||
is_vec = lambda self: True
|
||||
@@ -2718,6 +2733,10 @@ class trunc_pr(base.VarArgsInstruction):
|
||||
code = base.opcodes['TRUNC_PR']
|
||||
arg_format = tools.cycle(['sw','s','int','int'])
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('modp', 'probabilistic truncation'),
|
||||
self.get_size() * len(self.args) // 4)
|
||||
|
||||
class shuffle_base(base.DataInstruction):
|
||||
n_relevant_parties = 2
|
||||
|
||||
@@ -2725,12 +2744,12 @@ class shuffle_base(base.DataInstruction):
|
||||
super(shuffle_base, self).__init__(*args, **kwargs)
|
||||
prog = base.program
|
||||
if re.match('ring|rep-field|sy-rep.*', prog.options.execute or ''):
|
||||
ref = 'AHIK+22'
|
||||
ref = 'AHIK+22', 'Protocol 3.2'
|
||||
elif prog.options.execute:
|
||||
ref = 'KS14'
|
||||
ref = 'KS14', 'Section 4.3'
|
||||
else:
|
||||
ref = ('AHIK+22', 'KS14')
|
||||
base.program.reading('secure shuffling', ref)
|
||||
ref = ('AHIK+22', 'KS14'), None
|
||||
base.program.reading('secure shuffling', *ref)
|
||||
|
||||
@staticmethod
|
||||
def logn(n):
|
||||
@@ -2741,26 +2760,39 @@ class shuffle_base(base.DataInstruction):
|
||||
logn = cls.logn(n)
|
||||
return logn * 2 ** logn - 2 ** logn + 1
|
||||
|
||||
def add_gen_usage(self, req_node, n):
|
||||
@classmethod
|
||||
def add_gen_usage(self, req_node, n, add_shuffles=True, malicious=True,
|
||||
n_relevant_parties=None):
|
||||
# hack for unknown usage
|
||||
req_node.increment(('bit', 'inverse'), float('inf'))
|
||||
# minimal usage with two relevant parties
|
||||
logn = self.logn(n)
|
||||
n_switches = self.n_swaps(n)
|
||||
for i in range(self.n_relevant_parties):
|
||||
n_relevant_parties = n_relevant_parties or self.n_relevant_parties
|
||||
for i in range(n_relevant_parties):
|
||||
req_node.increment((self.field_type, 'input', i), n_switches)
|
||||
# multiplications for bit check
|
||||
req_node.increment((self.field_type, 'triple'),
|
||||
n_switches * self.n_relevant_parties)
|
||||
if malicious:
|
||||
# multiplications for bit check
|
||||
req_node.increment((self.field_type, 'triple'),
|
||||
n_switches * n_relevant_parties)
|
||||
if add_shuffles:
|
||||
req_node.increment((self.field_type, 'shuffle generation', n))
|
||||
|
||||
def add_apply_usage(self, req_node, n, record_size):
|
||||
@classmethod
|
||||
def add_apply_usage(self, req_node, n, record_size, add_shuffles=True,
|
||||
malicious=True, n_relevant_parties=None):
|
||||
req_node.increment(('bit', 'inverse'), float('inf'))
|
||||
logn = self.logn(n)
|
||||
n_switches = self.n_swaps(n) * self.n_relevant_parties
|
||||
if n != 2 ** logn:
|
||||
n_switches = self.n_swaps(n) * \
|
||||
(n_relevant_parties or self.n_relevant_parties)
|
||||
real_record_size = record_size
|
||||
if n != 2 ** logn and malicious:
|
||||
record_size += 1
|
||||
req_node.increment((self.field_type, 'triple'),
|
||||
n_switches * record_size)
|
||||
if add_shuffles:
|
||||
req_node.increment(
|
||||
(self.field_type, 'shuffle application', n, real_record_size))
|
||||
|
||||
@base.gf2n
|
||||
class secshuffle(base.VectorInstruction, shuffle_base):
|
||||
@@ -2824,6 +2856,9 @@ class applyshuffle(shuffle_base, base.Mergeable):
|
||||
for i in range(0, len(self.args), 6):
|
||||
self.add_apply_usage(req_node, self.args[i], self.args[i + 3])
|
||||
|
||||
def handles(self):
|
||||
return self.args[::4]
|
||||
|
||||
class delshuffle(base.Instruction):
|
||||
""" Delete secure shuffle.
|
||||
|
||||
@@ -2874,7 +2909,7 @@ class sqrs(base.CISC):
|
||||
arg_format = ['sw', 's']
|
||||
|
||||
def expand(self):
|
||||
s = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
s = [type(self.args[0])() for i in range(6)]
|
||||
c = [self.args[0].clear_type() for i in range(2)]
|
||||
square(s[0], s[1])
|
||||
subs(s[2], self.args[1], s[0])
|
||||
|
||||
@@ -1112,6 +1112,8 @@ class Instruction(object):
|
||||
new_args.append(arg.copy())
|
||||
subs[arg] = new_args[-1]
|
||||
else:
|
||||
if isinstance(arg, program.curr_tape.Register) and arg.caller:
|
||||
print(util.format_trace(arg.caller), file=sys.stderr)
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ def print_str(s, *args, print_secrets=False):
|
||||
if print_secrets:
|
||||
val.output()
|
||||
else:
|
||||
secret_error()
|
||||
secret_error(args[i])
|
||||
elif isinstance(val, cfloat):
|
||||
val.print_float_plain()
|
||||
elif isinstance(val, (list, tuple)):
|
||||
@@ -831,7 +831,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
|
||||
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
|
||||
n_threads=None, key_indices=None):
|
||||
get_program().reading('sorting', 'KSS13')
|
||||
get_program().reading('sorting', 'KSS13', 'Section 6.1')
|
||||
a_in = a
|
||||
if isinstance(a_in, list):
|
||||
a = Array.create_from(a)
|
||||
@@ -1592,8 +1592,13 @@ def _run_and_link(function, g=None, lock_lists=True, allow_return=False):
|
||||
pre = copy.copy(g)
|
||||
res = function()
|
||||
if res is not None and not allow_return:
|
||||
if get_program().options.flow_optimization:
|
||||
suffix = ' and avoid -l/--flow-optimization to keep ' \
|
||||
'compile-time branching'
|
||||
else:
|
||||
suffix = ''
|
||||
raise CompilerError('Conditional blocks cannot return values. '
|
||||
'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else')
|
||||
'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else' + suffix)
|
||||
_link(pre, g)
|
||||
return res
|
||||
|
||||
@@ -2052,7 +2057,8 @@ def FPDiv(a, b, k, f, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Goldschmidt method as presented in Catrina10,
|
||||
"""
|
||||
get_program().reading('fixed-point division', 'CdH10-fixed')
|
||||
get_program().reading('fixed-point division', 'CdH10-fixed',
|
||||
'Protocol 3.3')
|
||||
prime = get_program().prime
|
||||
if 2 * k == int(get_program().options.ring) or \
|
||||
(prime and 2 * k <= (prime.bit_length() - 1)):
|
||||
|
||||
@@ -63,7 +63,7 @@ import re
|
||||
|
||||
from Compiler import mpc_math, util
|
||||
from Compiler.types import *
|
||||
from Compiler.types import _unreduced_squant
|
||||
from Compiler.types import _unreduced_squant, _single
|
||||
from Compiler.library import *
|
||||
from Compiler.util import is_zero, tree_reduce
|
||||
from Compiler.comparison import CarryOutRawLE
|
||||
@@ -927,6 +927,11 @@ class Dense(DenseBase):
|
||||
progress('f input')
|
||||
|
||||
def _forward(self, batch=None):
|
||||
if not issubclass(self.W.value_type, _single) \
|
||||
or not issubclass(self.X.value_type, _single):
|
||||
raise CompilerError(
|
||||
'dense inputs have to be sfix in arithmetic circuits')
|
||||
|
||||
if batch is None:
|
||||
batch = regint.Array(self.N)
|
||||
batch.assign(regint.inc(self.N))
|
||||
@@ -2160,6 +2165,11 @@ class Conv2d(ConvBase):
|
||||
return weights_h * weights_w * n_channels_in
|
||||
|
||||
def _forward(self, batch):
|
||||
if not issubclass(self.weights.value_type, _single) \
|
||||
or not issubclass(self.X.value_type, _single):
|
||||
raise CompilerError(
|
||||
'convolution inputs have to be sfix in arithmetic circuits')
|
||||
|
||||
if self.tf_weight_format:
|
||||
assert(self.weight_shape[3] == self.output_shape[-1])
|
||||
weights_h, weights_w, _, _ = self.weight_shape
|
||||
@@ -4058,7 +4068,8 @@ class SGD(Optimizer):
|
||||
# divide by len(batch) by truncation
|
||||
# increased rate if len(batch) is not a power of two
|
||||
diff = red_old - nabla_vector
|
||||
pre_trunc = diff.v * rate.v
|
||||
# assuming rate is already synchronized
|
||||
pre_trunc = diff.v.mul(rate.v, sync=False)
|
||||
momentum_value.assign_vector(diff, base)
|
||||
k = max(nabla_vector.k, rate.k) + rate.f
|
||||
m = rate.f + int(log_batch_size)
|
||||
|
||||
@@ -131,7 +131,7 @@ def p_eval(p_c, x):
|
||||
local_aggregation = 0
|
||||
# Evaluation of the Polynomial
|
||||
for i, pre_mult in zip(p_c[1:], pre_mults):
|
||||
local_aggregation += pre_mult.mul_no_reduce(x.coerce(i))
|
||||
local_aggregation += pre_mult.mul_no_reduce(i)
|
||||
return local_aggregation.reduce_after_mul() + p_c[0]
|
||||
|
||||
|
||||
@@ -148,7 +148,8 @@ def p_eval(p_c, x):
|
||||
# @return b2: \{0,1\} value. Returns one when reduction to
|
||||
# \pi is greater than \pi/2.
|
||||
def sTrigSub(x):
|
||||
library.get_program().reading('trigonometric functions', 'AS19')
|
||||
library.get_program().reading('trigonometric functions', 'AS19',
|
||||
'Section 4')
|
||||
# reduction to 2* \pi
|
||||
f = x * (1.0 / (2 * pi))
|
||||
f = trunc(f)
|
||||
@@ -267,7 +268,7 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
|
||||
:return: :math:`2^a` if it is within the range. Undefined otherwise
|
||||
"""
|
||||
library.get_program().reading('exponential', 'AS19')
|
||||
library.get_program().reading('exponential', 'AS19', 'Protocol 6')
|
||||
def exp_from_parts(whole_exp, frac):
|
||||
class my_fix(type(a)):
|
||||
pass
|
||||
@@ -316,7 +317,7 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
s = sint.conv(bits[-1])
|
||||
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
|
||||
else:
|
||||
bits = sbitvec(a.v, a.k)
|
||||
bits = sbitvec(a.v, a.k).v
|
||||
s = sint.conv(bits[-1])
|
||||
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
|
||||
higher_bits = bits[a.f:n_bits]
|
||||
@@ -437,7 +438,7 @@ def log2_fx(x, use_division=True):
|
||||
:return: (sfix) the value of :math:`\log_2(x)`
|
||||
|
||||
"""
|
||||
library.get_program().reading('logarithm', 'AS19')
|
||||
library.get_program().reading('logarithm', 'AS19', 'Section 5')
|
||||
if isinstance(x, types._fix):
|
||||
# transforms sfix to f*2^n, where f is [o.5,1] bounded
|
||||
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
|
||||
@@ -815,7 +816,7 @@ def sqrt(x, k=None, f=None):
|
||||
|
||||
:return: square root of :py:obj:`x` (sfix).
|
||||
"""
|
||||
library.get_program().reading('square root', 'AS19')
|
||||
library.get_program().reading('square root', 'AS19', 'Section 3')
|
||||
if k is None:
|
||||
k = x.k
|
||||
if f is None:
|
||||
@@ -837,7 +838,8 @@ def atan(x):
|
||||
|
||||
:return: arctan of :py:obj:`x` (sfix).
|
||||
"""
|
||||
library.get_program().reading('inverse trigonometric functions', 'AS19')
|
||||
library.get_program().reading('inverse trigonometric functions', 'AS19',
|
||||
'Protocol 5')
|
||||
# obtain absolute value of x
|
||||
s = x < 0
|
||||
x_abs = s.if_else(-x, x)
|
||||
|
||||
@@ -26,7 +26,7 @@ class NonLinear:
|
||||
if prog.use_trunc_pr and m and (
|
||||
not prog.options.ring or \
|
||||
prog.use_trunc_pr <= (int(prog.options.ring) - k)):
|
||||
prog.reading('probabilistic truncation', 'DEK20')
|
||||
prog.reading('probabilistic truncation', 'DEK20', 'Section 3.2.2')
|
||||
if prog.options.ring:
|
||||
comparison.require_ring_size(k, 'truncation')
|
||||
else:
|
||||
@@ -92,8 +92,9 @@ class Prime(Masking):
|
||||
def kor(self, d):
|
||||
return KOR(d)
|
||||
|
||||
def require_bit_length(self, bit_length, op):
|
||||
def require_bit_length(self, bit_length, op, slack=0):
|
||||
prog = program.Program.prog
|
||||
bit_length += slack * (not prog.allow_tight_parameters)
|
||||
if bit_length > 32:
|
||||
prog.curr_tape.require_bit_length(bit_length - 1, reason=op)
|
||||
|
||||
@@ -146,7 +147,7 @@ class KnownPrime(NonLinear):
|
||||
else:
|
||||
return super(KnownPrime, self).ltz(a, k)
|
||||
|
||||
def require_bit_length(self, bit_length, op):
|
||||
def require_bit_length(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
class Ring(Masking):
|
||||
@@ -189,5 +190,5 @@ class Ring(Masking):
|
||||
def ltz(self, a, k):
|
||||
return LtzRing(a, k)
|
||||
|
||||
def require_bit_length(self, bit_length, op):
|
||||
comparison.require_ring_size(bit_length, op)
|
||||
def require_bit_length(self, *args, **kwargs):
|
||||
comparison.require_ring_size(*args, **kwargs)
|
||||
|
||||
@@ -7,7 +7,7 @@ papers = {
|
||||
'AN17': 'https://eprint.iacr.org/2017/816',
|
||||
'AS19': 'https://eprint.iacr.org/2019/354',
|
||||
'AHIK+22': 'https://eprint.iacr.org/2022/1595',
|
||||
'CdH10': 'https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation, https://doi.org/10.1007/978-3-642-15317-4_13 (paywall)',
|
||||
'CdH10': 'https://www.researchgate.net/publication/225092133, https://doi.org/10.1007/978-3-642-15317-4_13 (paywall)',
|
||||
'CdH10-fixed': 'https://www.ifca.ai/pub/fc10/31_47.pdf',
|
||||
'CCD88': 'https://doi.org/10.1145/62212.62214',
|
||||
'DDNNT15': 'https://eprint.iacr.org/2015/1006',
|
||||
|
||||
@@ -25,6 +25,7 @@ from Compiler.instructions_base import RegType
|
||||
from . import allocator as al
|
||||
from . import util
|
||||
from .papers import *
|
||||
from .cost import expected_communication
|
||||
|
||||
data_types = dict(
|
||||
triple=0,
|
||||
@@ -131,7 +132,8 @@ class Program(object):
|
||||
assert self.rabbit_gap()
|
||||
print(", for example, %d." % self.prime)
|
||||
self.prime = bad_prime
|
||||
except ImportError:
|
||||
except (ImportError, AssertionError):
|
||||
self.prime = bad_prime
|
||||
print(".")
|
||||
if options.execute:
|
||||
print("Use '-- --prime <prime>' to specify the prime for "
|
||||
@@ -251,6 +253,15 @@ class Program(object):
|
||||
else:
|
||||
print("Use '--execute <protocol>' to see recommended reading "
|
||||
"on the basic protocol.")
|
||||
if self.options.garbled:
|
||||
if not self.options.binary:
|
||||
raise CompilerError(
|
||||
"You have to specify a default bit length using '--binary' "
|
||||
"for garbled circuits.")
|
||||
self.optimize_for_gc()
|
||||
self.allow_tight_parameters = True
|
||||
self.warned_about_tightness = False
|
||||
self.warned_about_a2b = False
|
||||
|
||||
Program.prog = self
|
||||
from . import comparison, instructions, instructions_base, types
|
||||
@@ -439,6 +450,9 @@ class Program(object):
|
||||
else:
|
||||
self.req_num += tape.req_num
|
||||
|
||||
def required_bit_length(self, t):
|
||||
return max(x.req_bit_length[t] for x in self.tapes)
|
||||
|
||||
def write_bytes(self):
|
||||
|
||||
"""Write all non-empty threads and schedule to files."""
|
||||
@@ -455,7 +469,7 @@ class Program(object):
|
||||
sch_file.write("1 0\n")
|
||||
sch_file.write("0\n")
|
||||
sch_file.write(" ".join(sys.argv) + "\n")
|
||||
req = max(x.req_bit_length["p"] for x in self.tapes)
|
||||
req = self.required_bit_length("p")
|
||||
if self.options.ring:
|
||||
sch_file.write("R:%s" % self.options.ring)
|
||||
elif self.options.prime:
|
||||
@@ -470,6 +484,14 @@ class Program(object):
|
||||
assert len(req2) <= 2
|
||||
if req2:
|
||||
sch_file.write("lg2:%s" % max(req2))
|
||||
sch_file.write("\n")
|
||||
exp = self.expected_communication()
|
||||
if exp:
|
||||
sch_file.write(
|
||||
"online:%d offline:%d n_parties:%d\n" % (
|
||||
exp.sanitize() + (exp.n_parties,)))
|
||||
else:
|
||||
sch_file.write('no expections\n')
|
||||
sch_file.close()
|
||||
h = hashlib.sha256()
|
||||
for tape in self.tapes:
|
||||
@@ -590,6 +612,10 @@ class Program(object):
|
||||
# communicate protocol compability
|
||||
Compiler.instructions.active(self._always_active)
|
||||
|
||||
# communicate mulm usage to VM
|
||||
if self.use_mulm != 1:
|
||||
self.relevant_opts.add("no_mulm")
|
||||
|
||||
self.write_bytes()
|
||||
|
||||
if self.options.asmoutfile:
|
||||
@@ -743,6 +769,15 @@ class Program(object):
|
||||
def used_splits(self):
|
||||
return self._split
|
||||
|
||||
def have_a2b(self):
|
||||
if self.use_split() or self.use_edabit() or self.use_dabit:
|
||||
return True
|
||||
if not self.warned_about_a2b:
|
||||
print(
|
||||
'WARNING: No option selected for A2B conversion, defaulting '
|
||||
'to edaBits. Use -X/-Y/-Z to get rid of this warning.')
|
||||
self.warned_about_a2b = True
|
||||
|
||||
def use_square(self, change=None):
|
||||
"""Setting whether to use preprocessed square tuples
|
||||
(default: false).
|
||||
@@ -869,15 +904,30 @@ class Program(object):
|
||||
bl = inst.args[0]
|
||||
return (abs(bl.i) + 63) // 64 * 8
|
||||
|
||||
def reading(self, concept, reference):
|
||||
key = concept, reference
|
||||
def reading(self, concept, reference, part=None):
|
||||
key = concept, reference, part
|
||||
if self.options.papers and key not in self.recommended:
|
||||
if isinstance(reference, tuple):
|
||||
assert part is None
|
||||
reference = ', '.join(papers.get(x) or x for x in reference)
|
||||
print('Recommended reading on %s: %s' % (
|
||||
concept, papers.get(reference) or reference))
|
||||
suffix = ' (%s)' % part or ''
|
||||
print('Recommended reading on %s: %s%s' % (
|
||||
concept, papers.get(reference) or reference, suffix))
|
||||
self.recommended.add(key)
|
||||
|
||||
def expected_communication(self):
|
||||
if self.options.ring:
|
||||
bit_length = int(self.options.ring)
|
||||
elif self.options.prime:
|
||||
bit_length = self.prime.bit_length()
|
||||
else:
|
||||
# check against OnlineOptions.cpp
|
||||
bit_length = max(self.required_bit_length("p"), 128)
|
||||
bit_length = int(math.ceil(bit_length / 64) * 64)
|
||||
length = int(math.ceil(bit_length / 8))
|
||||
return expected_communication(
|
||||
self.options.execute, self.req_num or Tape.ReqNum(), length)
|
||||
|
||||
class Tape:
|
||||
"""A tape contains a list of basic blocks, onto which instructions are added."""
|
||||
|
||||
@@ -1452,6 +1502,12 @@ class Tape:
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __neg__(self):
|
||||
res = Tape.ReqNum()
|
||||
for i, count in list(self.items()):
|
||||
res[i] = -count
|
||||
return res
|
||||
|
||||
def set_all(self, value):
|
||||
if Program.prog.options.verbose and \
|
||||
value == float("inf") and self["all", "inv"] > 0:
|
||||
@@ -1644,6 +1700,19 @@ class Tape:
|
||||
|
||||
__float__ = __int__
|
||||
|
||||
def __eq__(self, other):
|
||||
raise CompilerError("equality testing not implemented")
|
||||
|
||||
__ne__ = __eq__
|
||||
|
||||
class _no_secret_truth(_no_truth):
|
||||
def __bool__(self):
|
||||
raise CompilerError(
|
||||
"Cannot branch on secret values like %s. "
|
||||
"See https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#cannot-branch-on-secret-values. " % \
|
||||
type(self).__name__
|
||||
)
|
||||
|
||||
class Register(_no_truth):
|
||||
"""
|
||||
Class for creating new registers. The register's index is automatically assigned
|
||||
@@ -1755,6 +1824,17 @@ class Tape:
|
||||
return self.vector or [self]
|
||||
|
||||
def __getitem__(self, index):
|
||||
try:
|
||||
if isinstance(index, slice):
|
||||
for x in index.start, index.stop, index.step:
|
||||
if x is not None:
|
||||
int(x)
|
||||
else:
|
||||
int(index)
|
||||
except:
|
||||
raise CompilerError(
|
||||
'cannot address vectors with run-time indices, '
|
||||
'use (Multi)Array instead')
|
||||
if self.size == 1 and index == 0:
|
||||
return self
|
||||
if not self.vector:
|
||||
|
||||
@@ -591,7 +591,7 @@ class _structure(Tape._no_truth):
|
||||
return cls.int_type.reg_type
|
||||
raise CompilerError('type not supported as argument: %s' % cls)
|
||||
|
||||
class _secret_structure(_structure):
|
||||
class _secret_structure(Tape._no_secret_truth, _structure):
|
||||
@classmethod
|
||||
def input_tensor_from(cls, player, shape):
|
||||
""" Input tensor secretly from player.
|
||||
@@ -1105,10 +1105,9 @@ class cint(_clear, _int):
|
||||
@staticmethod
|
||||
def in_immediate_range(value, regint=False):
|
||||
if value and not regint:
|
||||
# +1 for sign
|
||||
bit_length = 1 + int(math.ceil(math.log(abs(value), 2)))
|
||||
# slack for sign
|
||||
program.non_linear.require_bit_length(
|
||||
bit_length, 'integer conversion')
|
||||
value.bit_length(), 'integer conversion', slack=1)
|
||||
return value < 2**31 and value >= -2**31
|
||||
|
||||
@vectorize_init
|
||||
@@ -1321,7 +1320,7 @@ class cint(_clear, _int):
|
||||
:param other: cint/regint/int """
|
||||
return self >> other
|
||||
|
||||
def round(self, k, m, nearest=None, signed=False):
|
||||
def round(self, k, m, nearest=None, signed=True):
|
||||
if signed:
|
||||
self += 2 ** (k - 1)
|
||||
self += 2 ** (m - 1)
|
||||
@@ -2383,7 +2382,7 @@ class _secret(_arithmetic_register, _secret_structure):
|
||||
@set_instruction_type
|
||||
@read_mem_value
|
||||
@vectorize
|
||||
def mul(self, other):
|
||||
def mul(self, other, sync=True):
|
||||
""" Secret multiplication. Either both operands have the same
|
||||
size or one size 1 for a value-vector multiplication.
|
||||
|
||||
@@ -2396,7 +2395,7 @@ class _secret(_arithmetic_register, _secret_structure):
|
||||
res = type(self)(size=x.size)
|
||||
mulrs(res, x, y)
|
||||
return res
|
||||
if program.use_mulm == 1:
|
||||
if program.use_mulm == 1 or not sync:
|
||||
mulm = instructions.mulm
|
||||
elif program.use_mulm == -1:
|
||||
mulm = lambda res, x, y: instructions.mulm(res, x, cint(regint(y)))
|
||||
@@ -2530,7 +2529,7 @@ class _secret(_arithmetic_register, _secret_structure):
|
||||
writesharestofile(regint.conv(position), *shares)
|
||||
|
||||
class sint(_secret, _int):
|
||||
"""
|
||||
r"""
|
||||
Secret integer in the protocol-specific domain. It supports
|
||||
operations with :py:class:`sint`, :py:class:`cint`,
|
||||
:py:class:`regint`, and Python integers. Operations where one of
|
||||
@@ -2559,6 +2558,12 @@ class sint(_secret, _int):
|
||||
undefined and potentially insecure if the operands are longer than
|
||||
the bit length.
|
||||
|
||||
Instances of sint are understood to be signed. This means that,
|
||||
for modulo :math:`N`, numbers in :math:`[0,N/2)` are understood as
|
||||
positive numbers whereas numbers in :math:`[N/2,N)` are understood
|
||||
to be negative, namely :math:`x-N`. This ensures expected
|
||||
arithmetic such as :math:`-1 + 1 = (N-1) + 1 = N = 0 \mod N`.
|
||||
|
||||
See :ref:`nonlinear` for an overview of how non-linear
|
||||
computation is implemented.
|
||||
|
||||
@@ -2826,7 +2831,7 @@ class sint(_secret, _int):
|
||||
self.load_other(val.v.round(val.k, val.f,
|
||||
nearest=val.round_nearest))
|
||||
elif isinstance(val, sbitvec):
|
||||
super(sint, self).__init__('s', val=val, size=val[0].n)
|
||||
super(sint, self).__init__('s', val=val, size=val.v[0].n)
|
||||
else:
|
||||
super(sint, self).__init__('s', val=val, size=size)
|
||||
|
||||
@@ -3001,13 +3006,19 @@ class sint(_secret, _int):
|
||||
maybe_mixed)
|
||||
|
||||
def TruncMul(self, other, k, m, nearest=False):
|
||||
if not nearest and not program.warned_about_tightness and \
|
||||
program.options.ring and int(program.options.ring) == k:
|
||||
print('WARNING: Using tight parameters. '
|
||||
'Increase ring size or reduce fixed-point precision '
|
||||
'for increased efficiency')
|
||||
program.warned_about_tightness = True
|
||||
return (self * other).round(k, m, nearest, signed=True)
|
||||
|
||||
def TruncPr(self, k, m, signed=True):
|
||||
return floatingpoint.TruncPr(self, k, m, signed=signed)
|
||||
|
||||
@vectorize
|
||||
def round(self, k, m, nearest=False, signed=False):
|
||||
def round(self, k, m, nearest=False, signed=True):
|
||||
""" Truncate and maybe round secret :py:obj:`k`-bit integer
|
||||
by :py:obj:`m` bits. :py:obj:`m` can be secret if
|
||||
:py:obj:`nearest` is false, in which case the truncation will be
|
||||
@@ -3625,7 +3636,7 @@ class _bitint(Tape._no_truth):
|
||||
return s ^ carry, a ^ (s & (carry ^ a))
|
||||
|
||||
@staticmethod
|
||||
def bit_comparator(a, b):
|
||||
def bit_comparator(a, b, m=None):
|
||||
long_one = util.long_one(a + b)
|
||||
op = lambda y,x,*args: (util.if_else(x[1], x[0], y[0]), \
|
||||
util.if_else(x[1], long_one, y[1]))
|
||||
@@ -3794,7 +3805,11 @@ class _bitint(Tape._no_truth):
|
||||
if const_rounds:
|
||||
return self.get_highest_different_bits(a, b, index)
|
||||
else:
|
||||
return self.bit_comparator(a, b)
|
||||
try:
|
||||
return self.maybe_function(
|
||||
self.bit_comparator, a, b, result_length=2)
|
||||
except:
|
||||
return self.bit_comparator(a, b)
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.reverse_type(other):
|
||||
@@ -3826,9 +3841,17 @@ class _bitint(Tape._no_truth):
|
||||
if self.reverse_type(other):
|
||||
return other == self
|
||||
diff = self ^ other
|
||||
diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]]
|
||||
return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y),
|
||||
diff_bits))
|
||||
diff_bits = diff.bit_decompose()[:bit_length]
|
||||
try:
|
||||
res = self.maybe_function(self.eqz, diff_bits, [], 1)
|
||||
except:
|
||||
res = self.eqz(diff_bits)
|
||||
return self.comp_result(res[0])
|
||||
|
||||
@staticmethod
|
||||
def eqz(bits, other_bits=None, m=None):
|
||||
diff_bits = [x.bit_not() for x in bits]
|
||||
return [util.tree_reduce(lambda x, y: x.bit_and(y), diff_bits)]
|
||||
|
||||
def __ne__(self, other):
|
||||
return (self == other).bit_not()
|
||||
@@ -4052,7 +4075,7 @@ class cfix(_number, _structure):
|
||||
:py:class:`cfix` if the other operand is public
|
||||
(cfix/regint/cint/int) or :py:class:`sfix` if the other operand is
|
||||
an sfix. It also support comparisons (``==, !=, <, <=, >, >=``),
|
||||
returning either :py:class:`regint` or :py:class:`sbitint`.
|
||||
returning either :py:class:`regint` or :py:class:`sintbit`.
|
||||
|
||||
Similarly to :py:class:`Compiler.types.cint`, this type is
|
||||
restricted to arithmetic circuits due to the fact that only
|
||||
@@ -4806,7 +4829,7 @@ class _fix(_single):
|
||||
return self._new(self.v[index])
|
||||
|
||||
def __iter__(self):
|
||||
return (self._new(x) for x in self.v)
|
||||
return (self._new(x, k=self.k, f=self.f) for x in self.v)
|
||||
|
||||
@vectorize
|
||||
def add(self, other):
|
||||
@@ -4839,7 +4862,8 @@ class _fix(_single):
|
||||
f -= 1
|
||||
v //= 2
|
||||
k = len(bin(abs(v))) - 1
|
||||
other = self.multipliable(v, k, f, self.size)
|
||||
val = self.v.TruncMul(v, self.k + f, f, nearest=self.round_nearest)
|
||||
return self._new(val, k=self.k, f=self.f)
|
||||
try:
|
||||
other = self.coerce(other, equal_precision=False)
|
||||
except:
|
||||
@@ -4983,19 +5007,29 @@ class sfix(_fix):
|
||||
|
||||
It supports basic arithmetic (``+, -, *, /``), returning
|
||||
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
|
||||
returning :py:class:`sbitint`. The other operand can be any of
|
||||
returning :py:class:`sintbit`. The other operand can be any of
|
||||
sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()``
|
||||
and ``**``.
|
||||
|
||||
Note that the default precision (16 bits after the dot, 31 bits in
|
||||
total) only allows numbers up to :math:`2^{31-16-1} \\approx
|
||||
16000` with the smallest non-zero number being :math:`2^{-16}`.
|
||||
16000` with the smallest non-zero number being :math:`2^{-16}
|
||||
\\approx 0.000015`.
|
||||
You can change this using :py:func:`set_precision`.
|
||||
|
||||
Fixed-point multiplication is not linear in the sense of the
|
||||
computation domain. Therefore, techniques from :ref:`nonlinear`
|
||||
have to be used.
|
||||
|
||||
Many operations (including multiplication and division) use
|
||||
probabilistic trunctation by default. This means that the results
|
||||
are not deterministc but random within a small range around the
|
||||
deterministic result. You can switch to (more expensive)
|
||||
deterministic computation by setting
|
||||
``sfix.round_nearest`` to true. See `Catrina and de Hoogh
|
||||
<https://www.ifca.ai/pub/fc10/31_47.pdf>`_ for an introduction to
|
||||
probabilistic truncation.
|
||||
|
||||
:params _v: int/float/regint/cint/sint/sfloat
|
||||
"""
|
||||
int_type = sint
|
||||
@@ -5079,7 +5113,10 @@ class sfix(_fix):
|
||||
return self.v
|
||||
|
||||
def mul_no_reduce(self, other, res_params=None):
|
||||
if not isinstance(other, type(self)):
|
||||
if util.is_constant_float(other):
|
||||
return self.unreduced(
|
||||
self.v * cfix.int_rep(other, k=self.k, f=self.f))
|
||||
elif not isinstance(other, type(self)):
|
||||
return self * other
|
||||
assert self.f == other.f
|
||||
assert self.k == other.k
|
||||
@@ -6040,7 +6077,11 @@ class Array(_vectorizable):
|
||||
@read_mem_value
|
||||
def get_address(self, index, size=None):
|
||||
if isinstance(index, (_secret, _single)):
|
||||
raise CompilerError('need cleartext index')
|
||||
raise CompilerError(
|
||||
'Need cleartext index to address Array. If you need to address '
|
||||
'using secret numbers, you need to use ORAM: '
|
||||
'https://mp-spdz.readthedocs.io/en/latest/Compiler.html#'
|
||||
'module-Compiler.oram')
|
||||
key = str(index), size or 1
|
||||
index = self.check(index, self.length, self.length)
|
||||
if (program.curr_block, key) not in self.address_cache:
|
||||
@@ -6486,6 +6527,10 @@ class Array(_vectorizable):
|
||||
M = Matrix(1, len(self), self.value_type, address=self.address)
|
||||
return M.dot(other)
|
||||
|
||||
def sum(self):
|
||||
""" Sum of elements. """
|
||||
return self[:].sum()
|
||||
|
||||
def shuffle(self):
|
||||
""" Insecure shuffle in place. """
|
||||
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
|
||||
@@ -7105,11 +7150,6 @@ class SubMultiArray(_vectorizable):
|
||||
res_matrix = Matrix(self.sizes[0], other.sizes[1], t)
|
||||
try:
|
||||
try:
|
||||
# force matmuls for smaller sizes
|
||||
a, c = res_matrix.sizes
|
||||
if a * c / (a + c) < 2 and \
|
||||
self.value_type == other.value_type:
|
||||
raise AttributeError()
|
||||
self.value_type.direct_matrix_mul
|
||||
skip_reduce = set((sint, sfix)) == \
|
||||
set((self.value_type, other.value_type))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
from Compiler.exceptions import *
|
||||
|
||||
def format_trace(trace, prefix=' '):
|
||||
if trace is None:
|
||||
@@ -91,8 +92,9 @@ def if_else(cond, a, b):
|
||||
else:
|
||||
return cond.if_else(a, b)
|
||||
except:
|
||||
print(cond, a, b)
|
||||
raise
|
||||
raise CompilerError(
|
||||
'incompatible types for ternary/if-else operator: %s' % '/'.join(
|
||||
type(x).__name__ for x in (cond, a, b)))
|
||||
|
||||
def cond_swap(cond, a, b):
|
||||
if isinstance(cond, (bool, int)):
|
||||
|
||||
@@ -18,7 +18,7 @@ int main()
|
||||
KeySetup<Share<P256Element::Scalar>> key;
|
||||
string prefix = PREP_DIR "ECDSA/";
|
||||
mkdir_p(prefix.c_str());
|
||||
write_online_setup(prefix, P256Element::Scalar::pr());
|
||||
P256Element::Scalar::write_setup(prefix);
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
generate_mac_keys<Share<P256Element::Scalar>>(key, 2, prefix, G);
|
||||
|
||||
@@ -60,7 +60,7 @@ void sub(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1)
|
||||
void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,
|
||||
const FHE_PK& pk)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
if (c0.params!=c1.params) { throw params_mismatch(); }
|
||||
if (ans.params!=c1.params) { throw params_mismatch(); }
|
||||
|
||||
@@ -9,7 +9,7 @@ Diagonalizer::Diagonalizer(const MatrixVector& matrices,
|
||||
const FFT_Data& FTD, const FHE_PK& pk) :
|
||||
FTD(FTD)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
assert(not matrices.empty());
|
||||
for (auto& matrix : matrices)
|
||||
|
||||
@@ -28,7 +28,7 @@ void NaiveFFT(vector<modp>& ans,vector<modp>& a,int N,const modp& theta,const Zp
|
||||
|
||||
void FFT(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
if (N==1) { return; }
|
||||
|
||||
@@ -141,7 +141,7 @@ void FFT_Iter(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD,
|
||||
void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
|
||||
const Zp_Data& PrD, bool start_with_one)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
assert(roots.size() > size_t(n));
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ class FFT_Data
|
||||
|
||||
const Zp_Data& get_prD() const { return prData; }
|
||||
const bigint& get_prime() const { return prData.pr; }
|
||||
|
||||
int phi_m() const { return R.phi_m(); }
|
||||
int m() const { return R.m(); }
|
||||
int num_slots() const { return R.phi_m(); }
|
||||
@@ -71,6 +72,8 @@ class FFT_Data
|
||||
|
||||
const Ring& get_R() const { return R; }
|
||||
|
||||
void write_setup(const string& dir) const { prData.write_setup(dir); }
|
||||
|
||||
bool operator==(const FFT_Data& other) const { return not (*this != other); }
|
||||
bool operator!=(const FFT_Data& other) const;
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "P2Data.h"
|
||||
#include "FFT_Data.h"
|
||||
#include "Tools/CodeLocations.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
@@ -66,7 +67,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
|
||||
int noise_boost)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
FHE_PK& PK = *this;
|
||||
|
||||
@@ -154,7 +155,7 @@ void FHE_PK::encrypt(Ciphertext& c,
|
||||
void FHE_PK::quasi_encrypt(Ciphertext& c,
|
||||
const Rq_Element& mess,const Random_Coins& rc) const
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (&rc.get_params()!=params) { throw params_mismatch(); }
|
||||
@@ -216,7 +217,7 @@ void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
|
||||
|
||||
Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
|
||||
@@ -284,6 +285,14 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
|
||||
PRNG G; G.ReSeed();
|
||||
bigint mask;
|
||||
bigint two_Bd = 2 * Bd;
|
||||
|
||||
bool verbose = OnlineOptions::singleton.has_option("verbose_dd");
|
||||
int max_bits = 0;
|
||||
|
||||
if (verbose)
|
||||
cerr << "Random bits in distributed decryption: " << two_Bd.numBits()
|
||||
<< endl;
|
||||
|
||||
for (int i=0; i<(*params).phi_m(); i++)
|
||||
{
|
||||
G.randomBnd(mask, two_Bd);
|
||||
@@ -292,7 +301,13 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
|
||||
vv[i] += mask;
|
||||
vv[i] %= mod;
|
||||
if (vv[i]<0) { vv[i]+=mod; }
|
||||
|
||||
if (verbose)
|
||||
max_bits = max(max_bits, vv[i].numBits());
|
||||
}
|
||||
|
||||
if (verbose)
|
||||
cerr << "Maximum bits in distributed decryption: " << max_bits << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ template <>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
int m = 1024;
|
||||
int lgp = plaintext_length;
|
||||
bigint p;
|
||||
@@ -95,7 +95,7 @@ template <>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, P2Data& P2D, bool round_up, int n)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
if (params.n_mults() > 0)
|
||||
throw runtime_error("only implemented for 0-level BGV");
|
||||
@@ -113,12 +113,13 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up)
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "Need ciphertext modulus of length " << lgp0;
|
||||
if (params.n_mults() > 0)
|
||||
cout << "+" << lgp1;
|
||||
cout << " and " << phi_N(m) << " slots" << endl;
|
||||
#endif
|
||||
if (OnlineOptions::singleton.has_option("verbose_he_setup"))
|
||||
{
|
||||
cout << "Need ciphertext modulus of length " << lgp0;
|
||||
if (params.n_mults() > 0)
|
||||
cout << "+" << lgp1;
|
||||
cout << " and " << phi_N(m) << " slots" << endl;
|
||||
}
|
||||
|
||||
int extra_slack = 0;
|
||||
if (round_up)
|
||||
@@ -160,13 +161,16 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
|
||||
{
|
||||
(void) lg2pi, (void) n;
|
||||
|
||||
#ifdef VERBOSE
|
||||
if (n >= 2 and n <= 10)
|
||||
cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
|
||||
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
|
||||
cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
|
||||
cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
|
||||
#endif
|
||||
bool verbose = OnlineOptions::singleton.has_option("verbose_he_setup");
|
||||
|
||||
if (verbose)
|
||||
{
|
||||
if (n >= 2 and n <= 10)
|
||||
cerr << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
|
||||
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
|
||||
cerr << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
|
||||
cerr << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
|
||||
}
|
||||
|
||||
int extra_slack = 0;
|
||||
if (round_up)
|
||||
@@ -185,15 +189,16 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
|
||||
extra_slack = 2 * i;
|
||||
lg2p0 += i;
|
||||
lg2p1 += i;
|
||||
#ifdef VERBOSE
|
||||
cout << "Rounding up to " << lg2p0 << "+" << lg2p1
|
||||
<< ", giving extra slack of " << extra_slack << " bits" << endl;
|
||||
#endif
|
||||
|
||||
if (verbose)
|
||||
cerr << "Rounding up to " << lg2p0 << "+" << lg2p1
|
||||
<< ", giving extra slack of " << extra_slack << " bits"
|
||||
<< endl;
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
cout << "Total length: " << lg2p0 + lg2p1 << endl;
|
||||
#endif
|
||||
if (verbose)
|
||||
cerr << "Total length: " << lg2p0 + lg2p1 << " = " << lg2p0 << " + "
|
||||
<< lg2p1 << endl;
|
||||
|
||||
return extra_slack;
|
||||
}
|
||||
@@ -305,7 +310,7 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
|
||||
template <>
|
||||
void Parameters::SPDZ_Data_Setup(FHE_Params& params, FFT_Data& FTD)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
bigint p;
|
||||
int idx, m;
|
||||
@@ -678,7 +683,7 @@ void char_2_dimension(int& m, int& lg2)
|
||||
template <>
|
||||
void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
int n = n_parties;
|
||||
int lg2 = plaintext_length;
|
||||
|
||||
@@ -22,9 +22,9 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
sigma *= 1.4;
|
||||
params.set_R(params.get_R() * 1.4);
|
||||
}
|
||||
#ifdef VERBOSE
|
||||
cerr << "Standard deviation: " << this->sigma << endl;
|
||||
#endif
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_he_setup"))
|
||||
cerr << "Standard deviation: " << this->sigma << endl;
|
||||
|
||||
produce_epsilon_constants();
|
||||
|
||||
@@ -40,24 +40,27 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
|
||||
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
|
||||
int matrix_dim = params.get_matrix_dim();
|
||||
#ifdef NOISY
|
||||
cout << "phi(m): " << phi_m << endl;
|
||||
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
|
||||
cout << "V_s: " << V_s << endl;
|
||||
cout << "c1: " << c1 << endl;
|
||||
cout << "c2: " << c2 << endl;
|
||||
cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl;
|
||||
cout << "log(slack): " << slack << endl;
|
||||
cout << "B_clean: " << B_clean << endl;
|
||||
cout << "B_scale: " << B_scale << endl;
|
||||
cout << "matrix dimension: " << matrix_dim << endl;
|
||||
cout << "drown sec: " << params.secp() << endl;
|
||||
cout << "sec: " << sec << endl;
|
||||
#endif
|
||||
|
||||
assert(matrix_dim > 0);
|
||||
assert(params.secp() >= 0);
|
||||
drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp());
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_he_setup"))
|
||||
{
|
||||
cerr << "phi(m): " << phi_m << endl;
|
||||
cerr << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
|
||||
cerr << "V_s: " << V_s << endl;
|
||||
cerr << "c1: " << c1 << endl;
|
||||
cerr << "c2: " << c2 << endl;
|
||||
cerr << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl;
|
||||
cerr << "log(slack): " << slack << endl;
|
||||
cerr << "B_clean bits: " << B_clean.numBits() << endl;
|
||||
cerr << "B_scale bits: " << B_scale.numBits() << endl;
|
||||
cerr << "matrix dimension: " << matrix_dim << endl;
|
||||
cerr << "drown sec: " << params.secp() << endl;
|
||||
cerr << "sec: " << sec << endl;
|
||||
cerr << "drown bits: " << drown.numBits() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
|
||||
@@ -118,7 +121,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants()
|
||||
|
||||
NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
|
||||
const FHE_Params& params) :
|
||||
SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, params)
|
||||
SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, sec, params)
|
||||
{
|
||||
B_KS = p * c2 * this->sigma * phi_m / sqrt(12);
|
||||
#ifdef NOISY
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
void P2Data::forward(vector<poly_type>& ans,const vector<gf2n_short>& a) const
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
int n=gf2n_short::degree();
|
||||
|
||||
@@ -32,7 +32,7 @@ void P2Data::forward(vector<poly_type>& ans,const vector<gf2n_short>& a) const
|
||||
|
||||
void P2Data::backward(vector<gf2n_short>& ans,const vector<poly_type>& a) const
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
int n=gf2n_short::degree();
|
||||
BitVector bv(a.size());
|
||||
|
||||
@@ -32,6 +32,7 @@ class P2Data
|
||||
void backward(vector<gf2n_short>& ans,const vector<poly_type>& a) const;
|
||||
|
||||
int get_prime() const { return 2; }
|
||||
void write_setup(const string&) const {}
|
||||
|
||||
bool operator!=(const P2Data& other) const;
|
||||
|
||||
@@ -47,6 +48,7 @@ class P2Data
|
||||
|
||||
void load_or_generate(const Ring& Rg);
|
||||
|
||||
|
||||
friend void init(P2Data& P2D,const Ring& Rg);
|
||||
};
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
|
||||
void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
assert(a.FFTD);
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
@@ -299,7 +299,7 @@ Ring_Element& Ring_Element::operator *=(const modp& other)
|
||||
|
||||
Ring_Element Ring_Element::mul_by_X_i(int j) const
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
assert(FFTD);
|
||||
Ring_Element ans;
|
||||
@@ -517,7 +517,9 @@ modp Ring_Element::get_constant() const
|
||||
void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
|
||||
{
|
||||
ZpD.pack(o);
|
||||
o.store(v);
|
||||
o.store((int)v.size());
|
||||
for (auto& x : v)
|
||||
x.pack(o, ZpD);
|
||||
}
|
||||
|
||||
|
||||
@@ -529,7 +531,16 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
|
||||
throw runtime_error(
|
||||
"mismatch: " + to_string(check_Zpd.pr_bit_length) + "/"
|
||||
+ to_string(ZpD.pr_bit_length));
|
||||
o.get(v);
|
||||
unsigned int length;
|
||||
o.get(length);
|
||||
v.clear();
|
||||
v.reserve(length);
|
||||
modp tmp;
|
||||
for (unsigned int i=0; i<length; i++)
|
||||
{
|
||||
tmp.unpack(o,ZpD);
|
||||
v.push_back(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ vector<bigint> Rq_Element::to_vec_bigint() const
|
||||
// result mod p0 = a[0]; result mod p1 = a[1]
|
||||
void Rq_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
{
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
a[0].to_vec_bigint(v);
|
||||
if (n_mults() == 0) {
|
||||
@@ -208,7 +208,7 @@ void Rq_Element::Scale(const bigint& p)
|
||||
{
|
||||
if (lev==0) { return; }
|
||||
|
||||
CODE_LOCATION
|
||||
CODE_LOCATION_NO_SCOPE
|
||||
|
||||
if (n_mults() == 0) {
|
||||
//for some reason we scale but we have just one level
|
||||
@@ -312,7 +312,17 @@ void Rq_Element::pack(octetStream& o, int) const
|
||||
void Rq_Element::unpack(octetStream& o, int)
|
||||
{
|
||||
unsigned int ll; o.get(ll); lev=ll;
|
||||
check_level();
|
||||
|
||||
try
|
||||
{
|
||||
check_level();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
lev = 0;
|
||||
throw;
|
||||
}
|
||||
|
||||
for (int i = 0; i <= lev; ++i)
|
||||
a[i].unpack(o);
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ void PartSetup<FD>::output(Names& N)
|
||||
{
|
||||
// Write outputs to file
|
||||
string dir = get_prep_sub_dir<Share<typename FD::T>>(N.num_players());
|
||||
write_online_setup(dir, FieldD.get_prime());
|
||||
FieldD.write_setup(dir);
|
||||
write_mac_key(dir, N.my_num(), N.num_players(), alphai);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
|
||||
vv.resize(pk.get_params().phi_m());
|
||||
vv1.resize(pk.get_params().phi_m());
|
||||
// extra limb for operations
|
||||
bigint limit = pk.get_params().Q() << 64;
|
||||
bigint limit = pk.get_params().p0() << 64;
|
||||
vv.allocate_slots(limit);
|
||||
vv1.allocate_slots(limit);
|
||||
mf.allocate_slots(pk.p() << 64);
|
||||
@@ -19,6 +19,8 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
|
||||
|
||||
class ModuloTreeSum : public TreeSum<bigint>
|
||||
{
|
||||
typedef TreeSum<bigint> super;
|
||||
|
||||
bigint modulo;
|
||||
|
||||
void post_add_process(vector<bigint>& values)
|
||||
@@ -32,6 +34,12 @@ public:
|
||||
modulo(modulo)
|
||||
{
|
||||
}
|
||||
|
||||
void run(vector<bigint>& values, const Player& P)
|
||||
{
|
||||
lengths.resize(values.size(), numBytes(modulo));
|
||||
super::run(values, P);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FD>
|
||||
@@ -48,13 +56,15 @@ Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
|
||||
if ((int)vv.size() != params.phi_m())
|
||||
throw length_error("wrong length of ring element");
|
||||
|
||||
size_t length = numBytes(pk.get_params().p0());
|
||||
|
||||
if (OnlineOptions::singleton.direct)
|
||||
{
|
||||
// Now pack into an octetStream for broadcasting
|
||||
vector<octetStream> os(P.num_players());
|
||||
|
||||
for (int i=0; i<params.phi_m(); i++)
|
||||
{ (os[P.my_num()]).store(vv[i]); }
|
||||
{ (os[P.my_num()]).store(vv[i], length); }
|
||||
|
||||
// Broadcast and Receive the values
|
||||
P.Broadcast_Receive(os);
|
||||
@@ -67,7 +77,7 @@ Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
|
||||
{
|
||||
for (int j = 0; j < params.phi_m(); j++)
|
||||
{
|
||||
os[i].get(vv1[j]);
|
||||
os[i].get(vv1[j], length);
|
||||
}
|
||||
share.dist_decrypt_2(vv, vv1);
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ void RealPairwiseMachine::init()
|
||||
gfp::init_field(p);
|
||||
ofstream outf;
|
||||
if (output)
|
||||
write_online_setup(get_prep_dir<FFT_Data>(P), p);
|
||||
gfp::write_setup(get_prep_dir<FFT_Data>(P));
|
||||
}
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
@@ -141,5 +141,10 @@ void PairwiseMachine::check(Player& P) const
|
||||
bundle.compare(P);
|
||||
}
|
||||
|
||||
int PairwiseMachine::comp_sec()
|
||||
{
|
||||
return NonInteractiveProof::comp_sec(sec);
|
||||
}
|
||||
|
||||
template void RealPairwiseMachine::setup_keys<FFT_Data>();
|
||||
template void RealPairwiseMachine::setup_keys<P2Data>();
|
||||
|
||||
@@ -31,6 +31,8 @@ public:
|
||||
void unpack(octetStream& os);
|
||||
|
||||
void check(Player& P) const;
|
||||
|
||||
int comp_sec();
|
||||
};
|
||||
|
||||
class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine
|
||||
|
||||
@@ -71,6 +71,7 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
string filename = PREP_DIR + T::name() + "-"
|
||||
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
|
||||
+ to_string(params.secp()) + "-"
|
||||
+ to_string(machine.comp_sec()) + "-"
|
||||
+ to_string(params.get_matrix_dim()) + "-"
|
||||
+ OnlineOptions::singleton.prime.get_str() + "-"
|
||||
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
|
||||
@@ -121,7 +122,7 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
os.output(file);
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
if (OnlineOptions::singleton.has_option("verbose_he"))
|
||||
{
|
||||
cerr << "Ciphertext length: " << params.p0().numBits();
|
||||
for (size_t i = 1; i < params.FFTD().size(); i++)
|
||||
@@ -131,6 +132,7 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
cerr << "+" << DIV_CEIL(params.FFTD()[i].get_prime().numBits(), 64);
|
||||
cerr << " limbs)";
|
||||
cerr << endl;
|
||||
cerr << "Number of slots: " << params.phi_m() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -184,3 +184,11 @@ void Proof::Preimages::check_sizes()
|
||||
if (m.size() != r.size())
|
||||
throw runtime_error("preimage sizes don't match");
|
||||
}
|
||||
|
||||
int NonInteractiveProof::comp_sec(int sec)
|
||||
{
|
||||
if (sec > 0)
|
||||
return OnlineOptions::singleton.comp_sec();
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ using namespace std;
|
||||
#include "FHE/Ciphertext.h"
|
||||
#include "FHE/AddableVector.h"
|
||||
#include "Protocols/CowGearOptions.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
#include "config.h"
|
||||
|
||||
@@ -90,9 +91,6 @@ class Proof
|
||||
{
|
||||
V = ceil((sec + 2) / log2(2 * phim + 1));
|
||||
U = 2 * V;
|
||||
#ifdef VERBOSE
|
||||
cerr << "Using " << U << " ciphertexts per proof" << endl;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -151,14 +149,24 @@ class Proof
|
||||
output += input.at(j);
|
||||
}
|
||||
}
|
||||
|
||||
void debugging()
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("verbose_he"))
|
||||
{
|
||||
cerr << "Using " << U << " ciphertexts per proof" << endl;
|
||||
cerr << "Plaintext bound check bit length: " << B_plain_length << endl;
|
||||
cerr << "Randomness bound check bit length: " << B_rand_length << endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class NonInteractiveProof : public Proof
|
||||
{
|
||||
// sec = 0 used for protocols without proofs
|
||||
static int comp_sec(int sec) { return sec > 0 ? max(COMP_SEC, sec) : 0; }
|
||||
|
||||
public:
|
||||
// sec = 0 used for protocols without proofs
|
||||
static int comp_sec(int sec);
|
||||
|
||||
bigint static slack(int sec, int phim)
|
||||
{ sec = comp_sec(sec); return bigint(phim * sec * sec) << (sec / 2 + 8); }
|
||||
|
||||
@@ -174,6 +182,7 @@ public:
|
||||
B_rand_length = numBits(B*3*phim*rho);
|
||||
plain_check = (bigint(1) << B_plain_length) - sec * tau;
|
||||
rand_check = (bigint(1) << B_rand_length) - sec * rho;
|
||||
debugging();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -194,6 +203,7 @@ public:
|
||||
// leeway for completeness
|
||||
plain_check = (bigint(2) << B_plain_length);
|
||||
rand_check = (bigint(2) << B_rand_length);
|
||||
debugging();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -161,6 +161,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
|
||||
template<class T,class FD,class S>
|
||||
void SimpleEncCommit<T,FD,S>::create_more()
|
||||
{
|
||||
CODE_LOCATION
|
||||
cout << "Generating more ciphertexts in round " << this->n_rounds << endl;
|
||||
octetStream ciphertexts, cleartexts;
|
||||
size_t prover_memory = this->generate_proof(this->c, this->m, ciphertexts, cleartexts);
|
||||
@@ -181,6 +182,7 @@ template <class FD>
|
||||
size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& ciphertexts,
|
||||
octetStream& cleartexts)
|
||||
{
|
||||
CODE_LOCATION
|
||||
AddableVector<Ciphertext> others_ciphertexts;
|
||||
others_ciphertexts.resize(proof.U, pk.get_params());
|
||||
for (int i = 1; i < P.num_players(); i++)
|
||||
@@ -244,6 +246,7 @@ SummingEncCommit<FD>::SummingEncCommit(const Player& P, const FHE_PK& pk,
|
||||
template<class FD>
|
||||
void SummingEncCommit<FD>::create_more()
|
||||
{
|
||||
CODE_LOCATION
|
||||
octetStream cleartexts;
|
||||
const Player& P = this->P;
|
||||
AddableVector<Ciphertext> commitments;
|
||||
@@ -267,10 +270,11 @@ void SummingEncCommit<FD>::create_more()
|
||||
this->c.unpack(ciphertexts, this->pk);
|
||||
commitments.unpack(ciphertexts, this->pk);
|
||||
|
||||
#ifdef VERBOSE_HE
|
||||
cout << "Tree-wise sum of ciphertexts with "
|
||||
<< 1e-9 * ciphertexts.get_length() << " GB" << endl;
|
||||
#endif
|
||||
if (OnlineOptions::singleton.has_option("verbose_he"))
|
||||
cerr << "Tree-wise sum of " << this->c.size()
|
||||
<< " ciphertexts with " << 1e-9 * ciphertexts.get_length()
|
||||
<< " GB" << endl;
|
||||
|
||||
this->timers["Exchanging ciphertexts"].start();
|
||||
tree_sum.run(this->c, P);
|
||||
tree_sum.run(commitments, P);
|
||||
|
||||
@@ -56,6 +56,9 @@ public:
|
||||
void unpack(octetStream&) {}
|
||||
|
||||
void check(Player&) const {}
|
||||
|
||||
// computational security doesn't matter in global proofs
|
||||
int comp_sec() { return 0; }
|
||||
};
|
||||
|
||||
class MultiplicativeMachineParams : public MachineBase
|
||||
|
||||
@@ -78,6 +78,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
const int CcdShare<T>::default_length;
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_CCDSHARE_H_ */
|
||||
|
||||
@@ -88,6 +88,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
const int MaliciousCcdShare<T>::default_length;
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_MALICIOUSCCDSHARE_H_ */
|
||||
|
||||
@@ -286,7 +286,12 @@ void Processor<T>::notcb(const ::BaseInstruction& instruction)
|
||||
template<class T>
|
||||
void Processor<T>::movsb(const ::BaseInstruction& instruction)
|
||||
{
|
||||
for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++)
|
||||
int n_blocks;
|
||||
if (instruction.get_n() < unsigned(T::default_length))
|
||||
n_blocks = 1;
|
||||
else
|
||||
n_blocks = DIV_CEIL(instruction.get_n(), T::default_length);
|
||||
for (int i = 0; i < n_blocks; i++)
|
||||
S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i];
|
||||
}
|
||||
|
||||
@@ -407,12 +412,14 @@ void Processor<T>::convcbitvec(const BaseInstruction& instruction,
|
||||
{
|
||||
auto proto = ShareThread<T>::s().protocol;
|
||||
auto P = ShareThread<T>::s().P;
|
||||
if (proto)
|
||||
// The default use case in the compiler doesn't require synchronization
|
||||
// with function-dependent protocols, but testing does.
|
||||
if (proto and OnlineOptions::singleton.has_option("convcbitvec_sync"))
|
||||
proto->sync(bits, *P);
|
||||
else
|
||||
throw exception();
|
||||
throw no_singleton();
|
||||
}
|
||||
catch (exception&)
|
||||
catch (no_singleton&)
|
||||
{
|
||||
if (P)
|
||||
ProtocolBase<T>::sync(bits, *P);
|
||||
|
||||
@@ -147,6 +147,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T, class V>
|
||||
const int SemiSecretBase<T, V>::default_length;
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_SEMISECRET_H_ */
|
||||
|
||||
@@ -16,9 +16,6 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T, class V>
|
||||
const int SemiSecretBase<T, V>::default_length;
|
||||
|
||||
inline
|
||||
SemiSecret::MC* SemiSecret::new_mc(
|
||||
typename super::mac_key_type)
|
||||
|
||||
@@ -67,7 +67,7 @@ inline ShareThread<T>& ShareThread<T>::s()
|
||||
if (singleton and T::is_real)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no ShareThread singleton");
|
||||
throw no_singleton("no ShareThread singleton");
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -120,6 +120,11 @@ public:
|
||||
return {S, left, right, n_full_blocks()};
|
||||
}
|
||||
|
||||
Range<StackedVector<T>> full_block_left_range(StackedVector<T>& S)
|
||||
{
|
||||
return {S, left, n_full_blocks()};
|
||||
}
|
||||
|
||||
DoubleIterator<T> partial_block(StackedVector<T>& S)
|
||||
{
|
||||
assert(n_blocks() != n_full_blocks());
|
||||
@@ -127,6 +132,17 @@ public:
|
||||
S.iterator_for_size(right + n_full_blocks(), 1)};
|
||||
}
|
||||
|
||||
typename CheckVector<T>::iterator partial_left_block(StackedVector<T>& S)
|
||||
{
|
||||
assert(n_blocks() != n_full_blocks());
|
||||
return S.iterator_for_size(left + n_full_blocks(), 1);
|
||||
}
|
||||
|
||||
T& get_right_base(StackedVector<T>& S)
|
||||
{
|
||||
return S[right];
|
||||
}
|
||||
|
||||
Range<StackedVector<T>> full_block_output_range(StackedVector<T>& S)
|
||||
{
|
||||
return {S, dest, n_full_blocks()};
|
||||
@@ -174,16 +190,17 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
for (auto info : infos)
|
||||
{
|
||||
int n = T::default_length;
|
||||
for (auto x : info.full_block_input_range(S))
|
||||
auto& y = info.get_right_base(S);
|
||||
for (auto x : info.full_block_left_range(S))
|
||||
{
|
||||
x.second.extend_bit(y_ext, n);
|
||||
protocol->prepare_mult(x.first, y_ext, n, true);
|
||||
y.extend_bit(y_ext, n);
|
||||
protocol->prepare_mult(x, y_ext, n, true);
|
||||
}
|
||||
n = info.last_length();
|
||||
if (n)
|
||||
{
|
||||
info.partial_block(S).left->mask(x_ext, n);
|
||||
info.partial_block(S).right->extend_bit(y_ext, n);
|
||||
info.partial_left_block(S)->mask(x_ext, n);
|
||||
y.extend_bit(y_ext, n);
|
||||
protocol->prepare_mult(x_ext, y_ext, n, true);
|
||||
}
|
||||
}
|
||||
@@ -193,7 +210,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
if (fast_mode)
|
||||
for (auto x : info.full_block_input_range(S))
|
||||
protocol->prepare_mul_fast(x.first, x.second);
|
||||
else
|
||||
else if (info.n_full_blocks())
|
||||
for (auto x : info.full_block_input_range(S))
|
||||
protocol->prepare_mul(x.first, x.second);
|
||||
int n = info.last_length();
|
||||
@@ -228,7 +245,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
if (fast_mode)
|
||||
for (auto& res : info.full_block_output_range(S))
|
||||
res = protocol->finalize_mul_fast();
|
||||
else
|
||||
else if (info.n_full_blocks())
|
||||
for (auto& res : info.full_block_output_range(S))
|
||||
res = protocol->finalize_mul();
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ public:
|
||||
|
||||
void join_tape();
|
||||
void finish();
|
||||
|
||||
virtual NamedCommStats extra_comm() { return {}; }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -115,6 +115,7 @@ void ThreadMaster<T>::run_with_error()
|
||||
for (auto thread : threads)
|
||||
{
|
||||
stats += thread->P->total_comm();
|
||||
stats += thread->extra_comm();
|
||||
exe_stats += thread->processor.stats;
|
||||
delete thread;
|
||||
}
|
||||
|
||||
@@ -145,6 +145,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
const int TinierShare<T>::default_length;
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_TINIERSHARE_H_ */
|
||||
|
||||
@@ -33,7 +33,7 @@ void TinierSharePrep<T>::buffer_secret_triples()
|
||||
assert(triple_generator != 0);
|
||||
params.generateBits = false;
|
||||
vector<array<T, 3>> triples;
|
||||
TripleShuffleSacrifice<T> sacrifice(DATA_GF2);
|
||||
TripleShuffleSacrifice<T> sacrifice;
|
||||
size_t required;
|
||||
required = sacrifice.minimum_n_inputs_with_combining(
|
||||
BaseMachine::batch_size<T>(DATA_TRIPLE));
|
||||
|
||||
@@ -38,7 +38,10 @@ int main(int argc, const char** argv)
|
||||
if (online_opts.prime_limbs() == 2)
|
||||
return run<2, 1>(machine);
|
||||
|
||||
cerr << "Not compiled for choice of parameters" << endl;
|
||||
cerr << "Try using '-lgp 128'" << endl;
|
||||
if (online_opts.prime_limbs() > 2)
|
||||
cerr << "Use MASCOT with large primes" << endl;
|
||||
else
|
||||
cerr << "Not compiled for choice of parameters" << endl;
|
||||
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ int main(int argc, const char** argv)
|
||||
{
|
||||
if (s == SPDZ2K_DEFAULT_SECURITY)
|
||||
{
|
||||
ring_domain_error(k);
|
||||
ring_domain_error(k, 72);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
2
Makefile
2
Makefile
@@ -7,7 +7,7 @@ TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp))
|
||||
|
||||
NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp))
|
||||
|
||||
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o
|
||||
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o Protocols/ShareInterface.o
|
||||
|
||||
FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ public:
|
||||
void randomize(PRNG& G);
|
||||
void almost_randomize(PRNG& G) { randomize(G); }
|
||||
|
||||
void output(ostream& s,bool human) const;
|
||||
void output(ostream& s, bool human, bool signed_ = true) const;
|
||||
void input(istream& s,bool human);
|
||||
|
||||
void pack(octetStream& os) const { os.store_int(a, sizeof(a)); }
|
||||
|
||||
@@ -15,7 +15,7 @@ inline void IntBase<T>::specification(octetStream& os)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void IntBase<T>::output(ostream& s,bool human) const
|
||||
void IntBase<T>::output(ostream& s, bool human, bool) const
|
||||
{
|
||||
if (human)
|
||||
s << a;
|
||||
|
||||
@@ -85,15 +85,17 @@ void generate_prime(bigint& p, int lgp, int m, bool force_degree)
|
||||
p = OnlineOptions::singleton.prime;
|
||||
if (!probPrime(p))
|
||||
{
|
||||
cerr << p << " is not a prime" << endl;
|
||||
exit(1);
|
||||
throw runtime_error(to_string(p) + " is not a prime");
|
||||
}
|
||||
else if (m != 1 and p % m != 1)
|
||||
{
|
||||
cerr << p
|
||||
<< " is not compatible with our encryption scheme, must be 1 modulo "
|
||||
<< m << endl;
|
||||
exit(1);
|
||||
throw runtime_error(
|
||||
to_string(p)
|
||||
+ " is not compatible with our encryption scheme, must be "
|
||||
"1 modulo " + to_string(m) + ". This is because "
|
||||
"the implementation relies on number theoretic transform. "
|
||||
"See https://eprint.iacr.org/2024/585.pdf for details, "
|
||||
"in particular Theorem 13.");
|
||||
}
|
||||
else
|
||||
return;
|
||||
@@ -125,8 +127,10 @@ void generate_prime(bigint& p, int lgp, int m, bool force_degree)
|
||||
}
|
||||
|
||||
|
||||
void write_online_setup(string dirname, const bigint& p)
|
||||
void Zp_Data::write_setup(const string& dirname) const
|
||||
{
|
||||
auto& p = pr;
|
||||
|
||||
if (p == 0)
|
||||
throw runtime_error("prime cannot be 0");
|
||||
|
||||
@@ -145,19 +149,23 @@ void write_online_setup(string dirname, const bigint& p)
|
||||
ofstream outf;
|
||||
outf.open(ss.str().c_str());
|
||||
outf << p << endl;
|
||||
outf << montgomery << endl;
|
||||
if (!outf.good())
|
||||
throw file_error("cannot write to " + ss.str());
|
||||
}
|
||||
|
||||
void check_setup(string dir, bigint pr)
|
||||
void Zp_Data::check_setup(const string& dir)
|
||||
{
|
||||
bigint p;
|
||||
bool mont = true;
|
||||
string filename = dir + "Params-Data";
|
||||
ifstream(filename) >> p;
|
||||
ifstream(filename) >> p >> mont;
|
||||
if (p == 0)
|
||||
throw setup_error("no modulus in " + filename);
|
||||
if (p != pr)
|
||||
throw setup_error("wrong modulus in " + filename);
|
||||
if (mont != montgomery)
|
||||
throw setup_error("Montgomery different in " + filename);
|
||||
}
|
||||
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
|
||||
|
||||
@@ -26,8 +26,6 @@ template<class T>
|
||||
void generate_prime_setup(string dir, int lgp);
|
||||
template<class T>
|
||||
void generate_online_setup(string dirname, bigint& p, int lgp);
|
||||
void write_online_setup(string dirname, const bigint& p);
|
||||
void check_setup(string dirname, bigint p);
|
||||
|
||||
// Setup primes only
|
||||
// Chooses a p of at least lgp bits
|
||||
|
||||
@@ -13,14 +13,15 @@ void generate_online_setup(string dirname, bigint& p, int lgp)
|
||||
{
|
||||
int idx, m;
|
||||
SPDZ_Data_Setup_Primes(p, lgp, idx, m);
|
||||
write_online_setup(dirname, p);
|
||||
T::init_field(p);
|
||||
T::write_setup(dirname);
|
||||
}
|
||||
|
||||
template<class T = gfp>
|
||||
void read_setup(const string& dir_prefix, int lgp = -1)
|
||||
{
|
||||
bigint p;
|
||||
bool montgomery = true;
|
||||
|
||||
string filename = dir_prefix + "Params-Data";
|
||||
|
||||
@@ -32,6 +33,8 @@ void read_setup(const string& dir_prefix, int lgp = -1)
|
||||
#endif
|
||||
ifstream inpf(filename.c_str());
|
||||
inpf >> p;
|
||||
inpf >> montgomery;
|
||||
|
||||
if (inpf.fail())
|
||||
{
|
||||
if (lgp > 0)
|
||||
@@ -45,9 +48,12 @@ void read_setup(const string& dir_prefix, int lgp = -1)
|
||||
throw file_error(filename.c_str());
|
||||
}
|
||||
else
|
||||
T::init_field(p);
|
||||
T::init_field(p, montgomery);
|
||||
|
||||
inpf.close();
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Using prime modulus " << T::pr() << endl;
|
||||
}
|
||||
|
||||
#endif /* MATH_SETUP_HPP_ */
|
||||
|
||||
@@ -230,3 +230,8 @@ void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const
|
||||
q_half = shanks_q_half;
|
||||
r = shanks_r;
|
||||
}
|
||||
|
||||
string Zp_Data::fake_opts() const
|
||||
{
|
||||
return "-P " + to_string(pr) + (montgomery ? "" : " -n");
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ using namespace std;
|
||||
|
||||
#ifndef MAX_MOD_SZ
|
||||
#if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 11
|
||||
#define MAX_MOD_SZ GFP_MOD_SZ
|
||||
#define MAX_MOD_SZ 2 * GFP_MOD_SZ
|
||||
#else
|
||||
#define MAX_MOD_SZ 11
|
||||
#endif
|
||||
@@ -94,6 +94,11 @@ class Zp_Data
|
||||
|
||||
void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const;
|
||||
|
||||
void write_setup(const string& directory) const;
|
||||
void check_setup(const string& directory);
|
||||
|
||||
string fake_opts() const;
|
||||
|
||||
template<int L> friend void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD);
|
||||
template<int L> friend void to_modp(modp_<L>& ans,const mpz_class& x,const Zp_Data& ZpD);
|
||||
|
||||
|
||||
@@ -45,7 +45,6 @@ int powerMod(int x,int e,int p)
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
size_t bigint::report_size(ReportType type) const
|
||||
{
|
||||
size_t res = 0;
|
||||
@@ -98,6 +97,16 @@ bigint::bigint(const mp_limb_t* data, size_t n_limbs)
|
||||
mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data);
|
||||
}
|
||||
|
||||
void bigint::pack(octetStream& os, int length) const
|
||||
{
|
||||
os.store(*this, length);
|
||||
}
|
||||
|
||||
void bigint::unpack(octetStream& os, int length)
|
||||
{
|
||||
os.get(*this, length);
|
||||
}
|
||||
|
||||
string to_string(const bigint& x)
|
||||
{
|
||||
stringstream ss;
|
||||
|
||||
@@ -134,8 +134,8 @@ public:
|
||||
void generateUniform(PRNG& G, int n_bits, bool positive = false)
|
||||
{ G.get(*this, n_bits, positive); }
|
||||
|
||||
void pack(octetStream& os, int = -1) const { os.store(*this); }
|
||||
void unpack(octetStream& os, int = -1) { os.get(*this); };
|
||||
void pack(octetStream& os, int = -1) const;
|
||||
void unpack(octetStream& os, int = -1);
|
||||
|
||||
size_t report_size(ReportType type) const;
|
||||
};
|
||||
|
||||
@@ -31,6 +31,9 @@ mpf_class bigint::get_float(T v, T p, T z, T s)
|
||||
Integer exp = Integer(p, 31).get();
|
||||
bigint tmp;
|
||||
tmp.from_signed(v);
|
||||
if (abs(tmp) == 1)
|
||||
BaseMachine::s().mini_warning = min(BaseMachine::s().mini_warning,
|
||||
int(exp.get()));
|
||||
mpf_class res = tmp;
|
||||
if (exp > 0)
|
||||
mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp.get());
|
||||
|
||||
@@ -15,6 +15,7 @@ class fixint : public SignedZ2<64 * (L + 1)>
|
||||
|
||||
public:
|
||||
typedef SignedZ2<64 * (L + 1)> super;
|
||||
typedef typename conditional<L == 0, super, SignedZ2<64 * L>>::type pack_type;
|
||||
|
||||
fixint()
|
||||
{
|
||||
@@ -56,6 +57,19 @@ public:
|
||||
*this = bigint::tmp;
|
||||
}
|
||||
|
||||
void pack(octetStream& os) const
|
||||
{
|
||||
pack_type tmp = *this;
|
||||
tmp.pack(os);
|
||||
}
|
||||
|
||||
void unpack(octetStream& os)
|
||||
{
|
||||
pack_type tmp;
|
||||
tmp.unpack(os);
|
||||
*this = tmp;
|
||||
}
|
||||
|
||||
int get_min_alloc() const
|
||||
{
|
||||
return this->N_BYTES;
|
||||
|
||||
@@ -36,8 +36,8 @@ template<class T> void generate_prime_setup(string, int, int);
|
||||
#define GFP_MOD_SZ 2
|
||||
#endif
|
||||
|
||||
#if GFP_MOD_SZ > MAX_MOD_SZ
|
||||
#error GFP_MOD_SZ must be at most MAX_MOD_SZ
|
||||
#if 2 * GFP_MOD_SZ > MAX_MOD_SZ
|
||||
#error 2 * GFP_MOD_SZ must be at most MAX_MOD_SZ
|
||||
#endif
|
||||
|
||||
/**
|
||||
@@ -105,9 +105,9 @@ class gfp_ : public ValueInterface
|
||||
static void write_setup(int nplayers)
|
||||
{ write_setup(get_prep_sub_dir<T>(nplayers)); }
|
||||
static void write_setup(string dir)
|
||||
{ write_online_setup(dir, pr()); }
|
||||
{ ZpD.write_setup(dir); }
|
||||
static void check_setup(string dir);
|
||||
static string fake_opts() { return " -P " + to_string(pr()); }
|
||||
static string fake_opts() { return ZpD.fake_opts(); }
|
||||
|
||||
/**
|
||||
* Get the prime modulus
|
||||
|
||||
@@ -28,7 +28,7 @@ inline void gfp_<X, L>::read_or_generate_setup(string dir,
|
||||
template<int X, int L>
|
||||
void gfp_<X, L>::check_setup(string dir)
|
||||
{
|
||||
::check_setup(dir, pr());
|
||||
ZpD.check_setup(dir);
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
@@ -201,7 +201,7 @@ bool gfp_<X, L>::allows(Dtype type)
|
||||
template<int X, int L>
|
||||
void gfp_<X, L>::specification(octetStream& os)
|
||||
{
|
||||
os.store(pr());
|
||||
ZpD.pack(os);
|
||||
}
|
||||
|
||||
template <int X, int L>
|
||||
|
||||
@@ -33,7 +33,7 @@ char gfpvar_<X, L>::type_char()
|
||||
template<int X, int L>
|
||||
void gfpvar_<X, L>::specification(octetStream& os)
|
||||
{
|
||||
os.store(pr());
|
||||
ZpD.pack(os);
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
@@ -101,13 +101,13 @@ const bigint& gfpvar_<X, L>::pr()
|
||||
template<int X, int L>
|
||||
void gfpvar_<X, L>::check_setup(string dir)
|
||||
{
|
||||
::check_setup(dir, pr());
|
||||
ZpD.check_setup(dir);
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
void gfpvar_<X, L>::write_setup(string dir)
|
||||
{
|
||||
write_online_setup(dir, pr());
|
||||
ZpD.write_setup(dir);
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
|
||||
@@ -332,7 +332,7 @@ void modp_<L>::output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_)
|
||||
if (human)
|
||||
{ bigint te;
|
||||
to_bigint(te, ZpD);
|
||||
if (te < ZpD.pr / 2 or not signed_)
|
||||
if (te <= ZpD.pr_half or not signed_)
|
||||
s << te;
|
||||
else
|
||||
s << (te - ZpD.pr);
|
||||
|
||||
@@ -24,16 +24,26 @@ public:
|
||||
delete N;
|
||||
}
|
||||
|
||||
void send_to_no_stats(int player, const octetStream& o) const
|
||||
void send_to(int player, const octetStream& o) const
|
||||
{
|
||||
P.send_to(player, o);
|
||||
}
|
||||
|
||||
void receive_player_no_stats(int i, octetStream& o) const
|
||||
void receive_player(int i, octetStream& o) const
|
||||
{
|
||||
P.receive_player(i, o);
|
||||
}
|
||||
|
||||
void send_to_no_stats(int, const octetStream&) const
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
void receive_player_no_stats(int, octetStream&) const
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
void send_receive_all_no_stats(const vector<vector<bool>>& channels,
|
||||
const vector<octetStream>& to_send,
|
||||
vector<octetStream>& to_receive) const
|
||||
|
||||
@@ -367,10 +367,7 @@ long MultiPlayer<T>::receive_long(int i) const
|
||||
|
||||
void Player::send_to(int player,const octetStream& o) const
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "sending to " << player << endl;
|
||||
#endif
|
||||
TimeScope ts(comm_stats["Sending directly"].add(o));
|
||||
TimeScope ts(comm_stats["Sending directly"].add(o, player));
|
||||
send_to_no_stats(player, o);
|
||||
sent += o.get_length();
|
||||
}
|
||||
@@ -405,12 +402,9 @@ void Player::receive_all(vector<octetStream>& os) const
|
||||
|
||||
void Player::receive_player(int i,octetStream& o) const
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "receiving from " << i << endl;
|
||||
#endif
|
||||
TimeScope ts(timer);
|
||||
receive_player_no_stats(i, o);
|
||||
comm_stats["Receiving directly"].add(o, ts);
|
||||
comm_stats["Receiving directly"].add(o, ts, i);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -484,10 +478,7 @@ void MultiPlayer<T>::exchange_no_stats(int other, const octetStream& o, octetStr
|
||||
|
||||
void Player::exchange(int other, const octetStream& o, octetStream& to_receive) const
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "Exchanging with " << other << endl;
|
||||
#endif
|
||||
TimeScope ts(comm_stats["Exchanging"].add(o));
|
||||
TimeScope ts(comm_stats["Exchanging"].add(o, other));
|
||||
exchange_no_stats(other, o, to_receive);
|
||||
sent += o.get_length();
|
||||
}
|
||||
@@ -603,9 +594,8 @@ void Player::send_receive_all(const vector<vector<bool>>& channels,
|
||||
if (i != my_num() and channels.at(my_num()).at(i))
|
||||
{
|
||||
data += to_send.at(i).get_length();
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "Send " << to_send.at(i).get_length() << " to " << i << endl;
|
||||
#endif
|
||||
if (OnlineOptions::singleton.has_option("detailed_verbose_comm"))
|
||||
cerr << "Send " << to_send.at(i).get_length() << " bytes to " << i << endl;
|
||||
}
|
||||
TimeScope ts(comm_stats["Sending/receiving"].add(data));
|
||||
sent += data;
|
||||
@@ -879,15 +869,22 @@ Timer& CommStatsWithName::add_length_only(size_t length)
|
||||
return stats.add_length_only(length);
|
||||
}
|
||||
|
||||
Timer& CommStatsWithName::add(const octetStream& os)
|
||||
Timer& CommStatsWithName::add(const octetStream& os, int player)
|
||||
{
|
||||
return add(os.get_length());
|
||||
return add(os.get_length(), player);
|
||||
}
|
||||
|
||||
Timer& CommStatsWithName::add(size_t length)
|
||||
Timer& CommStatsWithName::add(size_t length, int player)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("verbose_comm"))
|
||||
fprintf(stderr, "%s %zu bytes\n", name.c_str(), length);
|
||||
{
|
||||
if (player < 0)
|
||||
fprintf(stderr, "%s %zu bytes\n", name.c_str(), length);
|
||||
else
|
||||
fprintf(stderr, "%s %zu bytes with party %d\n", name.c_str(), length,
|
||||
player);
|
||||
}
|
||||
|
||||
return stats.add(length);
|
||||
}
|
||||
|
||||
|
||||
@@ -160,9 +160,10 @@ public:
|
||||
name(name), stats(stats) {}
|
||||
|
||||
Timer& add_length_only(size_t length);
|
||||
Timer& add(const octetStream& os);
|
||||
Timer& add(size_t length);
|
||||
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
|
||||
Timer& add(const octetStream& os, int player = -1);
|
||||
Timer& add(size_t length, int player = -1);
|
||||
void add(const octetStream& os, const TimeScope& scope, int player = -1)
|
||||
{ add(os, player) += scope; }
|
||||
};
|
||||
|
||||
class NamedCommStats : public map<string, CommStats>
|
||||
@@ -272,7 +273,7 @@ public:
|
||||
/**
|
||||
* Send to a specific player
|
||||
*/
|
||||
void send_to(int player,const octetStream& o) const;
|
||||
virtual void send_to(int player,const octetStream& o) const;
|
||||
virtual void send_to_no_stats(int player,const octetStream& o) const = 0;
|
||||
/**
|
||||
* Receive from all other players.
|
||||
@@ -282,7 +283,7 @@ public:
|
||||
/**
|
||||
* Receive from a specific player
|
||||
*/
|
||||
void receive_player(int i,octetStream& o) const;
|
||||
virtual void receive_player(int i,octetStream& o) const;
|
||||
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
|
||||
virtual void receive_player(int i,FlexBuffer& buffer) const;
|
||||
|
||||
@@ -546,6 +547,8 @@ public:
|
||||
|
||||
size_t send(const PlayerBuffer& buffer, bool block) const;
|
||||
size_t recv(const PlayerBuffer& buffer, bool block) const;
|
||||
|
||||
NamedCommStats get_comm_stats() const { return comm_stats; }
|
||||
};
|
||||
|
||||
class RealTwoPartyPlayer : public VirtualTwoPartyPlayer
|
||||
|
||||
@@ -243,16 +243,23 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
|
||||
CODE_LOCATION
|
||||
typedef typename W::input_type::share_type::open_type T;
|
||||
|
||||
auto nTriplesPerLoop = this->nTriplesPerLoop * 10;
|
||||
auto nTriplesPerLoop = this->nTriplesPerLoop;
|
||||
auto& valueBits = this->valueBits;
|
||||
auto& share_prg = this->share_prg;
|
||||
auto& ot_multipliers = this->ot_multipliers;
|
||||
auto& nparties = this->nparties;
|
||||
auto& globalPlayer = this->globalPlayer;
|
||||
|
||||
if (this->thread_num >= 0)
|
||||
nTriplesPerLoop *= 10;
|
||||
|
||||
// extra value for sacrifice
|
||||
int toCheck = nTriplesPerLoop
|
||||
+ DIV_CEIL(W::mac_key_type::size_in_bits(), T::size_in_bits());
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_input"))
|
||||
fprintf(stderr, "generating %d input tuples\n", toCheck);
|
||||
|
||||
valueBits.resize(1);
|
||||
this->signal_multipliers({player, toCheck});
|
||||
bool mine = player == globalPlayer.my_num();
|
||||
|
||||
@@ -77,6 +77,22 @@ void OTExtensionWithMatrix::protocol_agreement()
|
||||
if (OnlineOptions::singleton.has_option("high_softspoken"))
|
||||
softspoken_k = 8;
|
||||
|
||||
if (OnlineOptions::singleton.has_param("softspoken"))
|
||||
softspoken_k = OnlineOptions::singleton.get_param("softspoken");
|
||||
|
||||
int needed = DIV_CEIL(nbaseOTs, softspoken_k) * softspoken_k;
|
||||
|
||||
baseReceiverInput.resize_zero(needed);
|
||||
|
||||
for (int i = nbaseOTs; i < needed; i++)
|
||||
{
|
||||
auto zero = string(SEED_SIZE, '\0');
|
||||
G_receiver.push_back(zero);
|
||||
G_sender.push_back({});
|
||||
for (int j = 0; j < 2; j++)
|
||||
G_sender.back().push_back(zero);
|
||||
}
|
||||
|
||||
bundle.mine.store(softspoken_k);
|
||||
|
||||
player->unchecked_broadcast(bundle);
|
||||
@@ -177,7 +193,8 @@ void OTExtensionWithMatrix::soft_sender(size_t n)
|
||||
return;
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_ot"))
|
||||
fprintf(stderr, "%zu OTs as sender\n", n);
|
||||
fprintf(stderr, "%zu OTs as sender (%s)\n", n,
|
||||
passive_only ? "semi-honest" : "malicious");
|
||||
|
||||
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
|
||||
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(softspoken_k);
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
template <class T>
|
||||
void OTVoleBase<T>::evaluate(vector<T>& output, const vector<T>& newReceiverInput) {
|
||||
CODE_LOCATION
|
||||
if (OnlineOptions::singleton.has_option("verbose_vole"))
|
||||
fprintf(stderr, "%d-bit VOLE with %zu elements and S=%d\n", T::N_BITS,
|
||||
newReceiverInput.size(), S);
|
||||
|
||||
const int N1 = newReceiverInput.size() + 1;
|
||||
output.resize(newReceiverInput.size());
|
||||
auto& os = oss;
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <sodium.h>
|
||||
#include <regex>
|
||||
using namespace std;
|
||||
|
||||
BaseMachine* BaseMachine::singleton = 0;
|
||||
@@ -66,16 +67,21 @@ int BaseMachine::triple_bucket_size(DataFieldType type)
|
||||
int BaseMachine::bucket_size(size_t usage)
|
||||
{
|
||||
int res = OnlineOptions::singleton.bucket_size;
|
||||
int min = res;
|
||||
|
||||
if (usage)
|
||||
{
|
||||
for (int B = res; B <= 5; B++)
|
||||
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
|
||||
res = 5;
|
||||
for (int B = res; B >= min; B--)
|
||||
if (ShuffleSacrifice(B).minimum_n_outputs() > usage * 1.1)
|
||||
break;
|
||||
else
|
||||
res = B;
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
fprintf(stderr, "bucket_size=%d usage=%zu\n", res, usage);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -103,8 +109,13 @@ int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool BaseMachine::allow_mulm()
|
||||
{
|
||||
return singleton and singleton->relevant_opts.find("no_mulm") != string::npos;
|
||||
}
|
||||
|
||||
BaseMachine::BaseMachine() :
|
||||
nthreads(0), multithread(false), nan_warning(0)
|
||||
nthreads(0), multithread(false), nan_warning(0), mini_warning(0)
|
||||
{
|
||||
if (sodium_init() == -1)
|
||||
throw runtime_error("couldn't initialize libsodium");
|
||||
@@ -182,6 +193,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
getline(inpf, relevant_opts);
|
||||
getline(inpf, security);
|
||||
getline(inpf, gf2n);
|
||||
getline(inpf, expected_communication);
|
||||
inpf.close();
|
||||
}
|
||||
|
||||
@@ -320,17 +332,47 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
|
||||
Bundle<octetStream> bundle(P);
|
||||
bundle.mine.store(stats.sent);
|
||||
P.Broadcast_Receive_no_stats(bundle);
|
||||
size_t global = 0;
|
||||
long long global = 0;
|
||||
for (auto& os : bundle)
|
||||
global += os.get_int(8);
|
||||
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
|
||||
|
||||
smatch what;
|
||||
regex comm_regexp("online:([0-9]*) offline:([0-9]*) n_parties:([0-9]*)");
|
||||
if (regex_search(expected_communication, what, comm_regexp))
|
||||
{
|
||||
long long expected = stoll(what[1]) + stoll(what[2]);
|
||||
int n_parties = stoi(what[3]);
|
||||
if (expected and n_parties != P.num_players())
|
||||
{
|
||||
cerr << "Wrong number of parties in compiler's expectation: "
|
||||
<< n_parties << endl;
|
||||
}
|
||||
else if (expected)
|
||||
{
|
||||
double over = round(100. * (global - expected) / expected);
|
||||
if (over >= 5)
|
||||
cerr
|
||||
<< "Actual communication exceeds the compiler's expectation by "
|
||||
<< over << " percent." << endl;
|
||||
if (over < 0)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("overestimate"))
|
||||
cerr << "Actual communication is below the compiler's "
|
||||
"expectation by " << -over << " percent." << endl;
|
||||
else
|
||||
cerr << "The compiler overestimated the communication." << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
|
||||
{
|
||||
size_t rounds = 0;
|
||||
for (auto& x : comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
if (x.first.find("transmission") == string::npos)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << P.my_num() << " only";
|
||||
if (multithread)
|
||||
@@ -341,3 +383,9 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
|
||||
|
||||
print_global_comm(P, comm_stats);
|
||||
}
|
||||
|
||||
void BaseMachine::add_one_off(const NamedCommStats& comm)
|
||||
{
|
||||
if (has_singleton())
|
||||
s().one_off_comm += comm;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ void print_usage(ostream& o, const char* name, size_t capacity);
|
||||
class BaseMachine
|
||||
{
|
||||
friend class Program;
|
||||
template<class sint, class sgf2n> friend class thread_info;
|
||||
|
||||
protected:
|
||||
static BaseMachine* singleton;
|
||||
@@ -38,6 +39,9 @@ protected:
|
||||
string relevant_opts;
|
||||
string security;
|
||||
string gf2n;
|
||||
string expected_communication;
|
||||
|
||||
NamedCommStats one_off_comm;
|
||||
|
||||
virtual size_t load_program(const string& threadname,
|
||||
const string& filename);
|
||||
@@ -60,6 +64,7 @@ public:
|
||||
vector<Program> progs;
|
||||
|
||||
bool nan_warning;
|
||||
int mini_warning;
|
||||
|
||||
static BaseMachine& s();
|
||||
static bool has_singleton() { return singleton != 0; }
|
||||
@@ -75,7 +80,8 @@ public:
|
||||
static int security_from_schedule(string progname);
|
||||
|
||||
template<class T>
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0,
|
||||
int factor = 0);
|
||||
template<class T>
|
||||
static int input_batch_size(int player, int buffer_size = 0);
|
||||
template<class T>
|
||||
@@ -86,6 +92,10 @@ public:
|
||||
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
|
||||
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
|
||||
|
||||
static bool allow_mulm();
|
||||
|
||||
static void add_one_off(const NamedCommStats& comm);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
|
||||
@@ -110,6 +120,8 @@ public:
|
||||
void print_comm(Player& P, const NamedCommStats& stats);
|
||||
|
||||
virtual const Names& get_N() { throw not_implemented(); }
|
||||
|
||||
virtual void gap_warning(int) { throw not_implemented(); }
|
||||
};
|
||||
|
||||
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
@@ -118,7 +130,8 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback,
|
||||
int factor)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
|
||||
@@ -133,7 +146,8 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
else if (fallback > 0)
|
||||
n_opts = fallback;
|
||||
else
|
||||
n_opts = OnlineOptions::singleton.batch_size * T::default_length;
|
||||
n_opts = OnlineOptions::singleton.batch_size
|
||||
* max(factor, T::default_length);
|
||||
|
||||
if (buffer_size <= 0 and has_program())
|
||||
{
|
||||
@@ -180,9 +194,17 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
res = n_opts;
|
||||
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
{
|
||||
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
|
||||
<< " res=" << res << " n=" << n << " n_opts=" << n_opts
|
||||
<< " buffer_size=" << buffer_size << endl;
|
||||
<< " buffer_size=" << buffer_size << " bits/dabits="
|
||||
<< T::LivePrep::bits_from_dabits() << "/"
|
||||
<< T::LivePrep::dabits_from_bits() << " has_program="
|
||||
<< has_program();
|
||||
if (program)
|
||||
cerr << " program=" << program->get_name();
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
assert(res > 0);
|
||||
return res;
|
||||
|
||||
@@ -34,7 +34,9 @@ public:
|
||||
+ ", have you generated edaBits, "
|
||||
"for example by running "
|
||||
"'./Fake-Offline.x -e "
|
||||
+ to_string(n_bits) + " ...'?");
|
||||
+ to_string(n_bits)
|
||||
+ T::template proto_fake_opts<typename T::clear>()
|
||||
+ " ...'?");
|
||||
}
|
||||
|
||||
assert(BufferBase::file);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "GC/instructions.h"
|
||||
|
||||
#include "Memory.hpp"
|
||||
#include "Instruction.hpp"
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
|
||||
@@ -1501,8 +1501,6 @@ void Program::execute_with_errors(Processor<sint, sgf2n>& Proc) const
|
||||
auto& processor = Proc.Procb;
|
||||
auto& Ci = Proc.get_Ci();
|
||||
|
||||
BaseMachine::program = this;
|
||||
|
||||
while (Proc.PC<size)
|
||||
{
|
||||
Proc.last_PC = Proc.PC;
|
||||
@@ -1559,7 +1557,9 @@ void Program::execute_with_errors(Processor<sint, sgf2n>& Proc) const
|
||||
template<class T>
|
||||
void Program::mulm_check() const
|
||||
{
|
||||
if (T::function_dependent and not OnlineOptions::singleton.has_option("allow_mulm"))
|
||||
if (T::function_dependent
|
||||
and not (BaseMachine::allow_mulm()
|
||||
or OnlineOptions::singleton.has_option("allow_mulm")))
|
||||
throw runtime_error("Mixed multiplication not implemented for function-dependent preprocessing. "
|
||||
"Use '-E <protocol>' during compilation or state "
|
||||
"'program.use_mulm = False' at the beginning of your high-level program.");
|
||||
|
||||
@@ -54,6 +54,9 @@ class Machine : public BaseMachine
|
||||
|
||||
NamedCommStats max_comm;
|
||||
|
||||
int max_trunc_size;
|
||||
Lock warn_lock;
|
||||
|
||||
size_t load_program(const string& threadname, const string& filename);
|
||||
|
||||
void prepare(const string& progname_str);
|
||||
@@ -126,6 +129,8 @@ class Machine : public BaseMachine
|
||||
Player& get_player() { return *P; }
|
||||
|
||||
void check_program();
|
||||
|
||||
void gap_warning(int k);
|
||||
};
|
||||
|
||||
#endif /* MACHINE_H_ */
|
||||
|
||||
@@ -55,6 +55,7 @@ template<class sint, class sgf2n>
|
||||
Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
const OnlineOptions opts)
|
||||
: my_number(playerNames.my_num()), N(playerNames),
|
||||
max_trunc_size(0),
|
||||
use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts),
|
||||
external_clients(my_number)
|
||||
{
|
||||
@@ -607,6 +608,12 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
if (multithread)
|
||||
cerr << " (overall core time)";
|
||||
cerr << endl;
|
||||
auto& P = *this->P;
|
||||
auto one_off = TreeSum<Z2<64>>().run(
|
||||
this->one_off_comm.sent, P).get_limb(0);
|
||||
if (one_off)
|
||||
cerr << "One-off global communication: " << one_off * 1e-6 << " MB"
|
||||
<< endl;
|
||||
}
|
||||
|
||||
print_timers();
|
||||
@@ -685,12 +692,17 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
<< "have you considered using " << alt << " instead?" << endl;
|
||||
}
|
||||
|
||||
if (nan_warning and sint::real_shares(*P))
|
||||
if ((nan_warning or mini_warning) and sint::real_shares(*P))
|
||||
{
|
||||
cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. See ";
|
||||
cerr << "https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix";
|
||||
if (nan_warning)
|
||||
cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. ";
|
||||
if (mini_warning)
|
||||
cerr << pow(2, mini_warning) << " is the smallest non-zero number "
|
||||
<< "in a used fixed-point representation. ";
|
||||
cerr << "See https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix";
|
||||
cerr << " for details" << endl;
|
||||
nan_warning = false;
|
||||
mini_warning = 0;
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
@@ -743,6 +755,10 @@ void Machine<sint, sgf2n>::suggest_optimizations()
|
||||
cerr << "This program might benefit from some protocol options." << endl
|
||||
<< "Consider adding the following at the beginning of your code:"
|
||||
<< endl << optimizations;
|
||||
if (sint::clear::n_bits() < max_trunc_size)
|
||||
cerr << "The computation domain is too small "
|
||||
<< "for low-round truncation; it would need to have at least "
|
||||
<< max_trunc_size << " bits." << endl;
|
||||
#ifndef __clang__
|
||||
cerr << "This virtual machine was compiled with GCC. Recompile with "
|
||||
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;
|
||||
@@ -768,4 +784,15 @@ void Machine<sint, sgf2n>::check_program()
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Machine<sint, sgf2n>::gap_warning(int k)
|
||||
{
|
||||
if (k > max_trunc_size)
|
||||
{
|
||||
warn_lock.lock();
|
||||
max_trunc_size = max(k, max_trunc_size);
|
||||
warn_lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -37,6 +37,16 @@ public:
|
||||
#endif
|
||||
}
|
||||
|
||||
const T* begin() const
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
const T* end() const
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
virtual T& operator[](size_t i) = 0;
|
||||
virtual const T& operator[](size_t i) const = 0;
|
||||
|
||||
|
||||
@@ -268,7 +268,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
printf("\tClient %d about to run %d\n",num,program);
|
||||
#endif
|
||||
online_timer.start(P.total_comm());
|
||||
online_prep_timer -= Proc.DataF.total_time();
|
||||
online_prep_timer -= Proc.prep_time();
|
||||
Proc.reset(progs[program], job.arg);
|
||||
|
||||
// Bits, Triples, Squares, and Inverses skipping
|
||||
@@ -278,6 +278,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
//printf("\tExecuting program");
|
||||
// Execute the program
|
||||
BaseMachine::program = &progs[program];
|
||||
progs[program].execute(Proc);
|
||||
|
||||
// make sure values used in other threads are safe
|
||||
@@ -298,7 +299,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
"in thread %d\n", program, num);
|
||||
#endif
|
||||
online_timer.stop(P.total_comm());
|
||||
online_prep_timer += Proc.DataF.total_time();
|
||||
online_prep_timer += Proc.prep_time();
|
||||
wait_timer.start();
|
||||
queues->finished(job, P.total_comm());
|
||||
wait_timer.stop();
|
||||
@@ -307,10 +308,10 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
// final check
|
||||
online_timer.start(P.total_comm());
|
||||
online_prep_timer -= Proc.DataF.total_time();
|
||||
online_prep_timer -= Proc.prep_time();
|
||||
Proc.check();
|
||||
online_timer.stop(P.total_comm());
|
||||
online_prep_timer += Proc.DataF.total_time();
|
||||
online_prep_timer += Proc.prep_time();
|
||||
|
||||
if (machine.opts.file_prep_per_thread)
|
||||
Proc.DataF.prune();
|
||||
|
||||
@@ -9,10 +9,12 @@
|
||||
#include "Math/gfpvar.h"
|
||||
#include "Protocols/HemiOptions.h"
|
||||
#include "Protocols/config.h"
|
||||
#include "FHEOffline/config.h"
|
||||
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
#include <boost/filesystem.hpp>
|
||||
#include <regex>
|
||||
|
||||
using namespace std;
|
||||
|
||||
@@ -40,6 +42,8 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
max_broadcast = 0;
|
||||
receive_threads = false;
|
||||
code_locations = false;
|
||||
have_warned_about_comp_sec = false;
|
||||
semi_honest = false;
|
||||
#ifdef VERBOSE
|
||||
verbose = true;
|
||||
#else
|
||||
@@ -161,6 +165,10 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
|
||||
opt.get("--options")->getStrings(options);
|
||||
|
||||
for (auto& option : options)
|
||||
if (option.find("verbose") == 0)
|
||||
verbose = true;
|
||||
|
||||
code_locations = opt.isSet("--code-locations");
|
||||
|
||||
#ifdef THROW_EXCEPTIONS
|
||||
@@ -463,6 +471,7 @@ void OnlineOptions::finalize_with_error(ez::ezOptionParser& opt)
|
||||
o->getString(disk_memory);
|
||||
|
||||
receive_threads = opt.isSet("--threads");
|
||||
semi_honest = opt.isSet("--semi-honest");
|
||||
|
||||
if (use_security_parameter)
|
||||
{
|
||||
@@ -505,3 +514,35 @@ int OnlineOptions::prime_limbs()
|
||||
{
|
||||
return DIV_CEIL(prime_length(), 64);
|
||||
}
|
||||
|
||||
bool OnlineOptions::has_param(const string& param)
|
||||
{
|
||||
for (auto& x : options)
|
||||
if (x.find(param + "=") == 0)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
int OnlineOptions::get_param(const string& param)
|
||||
{
|
||||
basic_regex re(param + "=([0-9]+)");
|
||||
smatch match;
|
||||
for (auto& x : options)
|
||||
if (regex_match(x, match, re))
|
||||
return atoi(match[1].str().c_str());
|
||||
throw runtime_error("parameter not found: " + param);
|
||||
}
|
||||
|
||||
int OnlineOptions::comp_sec()
|
||||
{
|
||||
int res = COMP_SEC;
|
||||
if (has_param("comp_sec"))
|
||||
res = get_param("comp_sec");
|
||||
if (res < 128 and not have_warned_about_comp_sec)
|
||||
{
|
||||
cerr << "WARNING: computational security parameter " << res
|
||||
<< " suitable for testing only" << endl;
|
||||
have_warned_about_comp_sec = true;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ class OnlineOptions
|
||||
{
|
||||
void finalize_with_error(ez::ezOptionParser& opt);
|
||||
|
||||
bool have_warned_about_comp_sec;
|
||||
|
||||
public:
|
||||
static OnlineOptions singleton;
|
||||
|
||||
@@ -44,6 +46,7 @@ public:
|
||||
vector<string> options;
|
||||
string executable;
|
||||
bool code_locations;
|
||||
bool semi_honest;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
@@ -79,6 +82,11 @@ public:
|
||||
{
|
||||
return find(options.begin(), options.end(), option) != options.end();
|
||||
}
|
||||
|
||||
bool has_param(const string& param);
|
||||
int get_param(const string& param);
|
||||
|
||||
int comp_sec();
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_ONLINEOPTIONS_H_ */
|
||||
|
||||
@@ -97,6 +97,18 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
|
||||
if (T::semi_honest_option)
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Semi-honest operation (default: malicious security)"
|
||||
// Help description.
|
||||
"-sh", // Flag token.
|
||||
"--semi-honest" // Flag token.
|
||||
);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -310,6 +310,8 @@ class Processor : public ArithmeticProcessor
|
||||
|
||||
void call_tape(int tape_number, int arg, const vector<int>& results);
|
||||
|
||||
TimerWithComm prep_time();
|
||||
|
||||
private:
|
||||
|
||||
template<class T> friend class SPDZ;
|
||||
|
||||
@@ -638,9 +638,6 @@ void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
|
||||
int batchStartI = 0;
|
||||
int batchStartJ = 0;
|
||||
|
||||
size_t sourceSize = source.size();
|
||||
const T* sourceData = source.data();
|
||||
|
||||
protocol.init_dotprod();
|
||||
for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) {
|
||||
auto output = S.begin() + matmulArgs[0];
|
||||
@@ -654,27 +651,54 @@ void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
|
||||
|
||||
assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end());
|
||||
|
||||
for (int j = 0; j < resultNumberOfColumns; j += 1) {
|
||||
auto actualSecondFactorColumn =
|
||||
Proc->get_Ci().at(matmulArgs[9] + j).get();
|
||||
auto secondBase = source.begin() + secondFactorBase
|
||||
+ actualSecondFactorColumn;
|
||||
for (auto &x : Range(Proc->get_Ci(), matmulArgs[8],
|
||||
usedNumberOfFirstFactorColumns))
|
||||
assert(
|
||||
secondBase + x.get() * secondFactorTotalNumberOfColumns
|
||||
< source.end());
|
||||
}
|
||||
|
||||
vector<long> second_factors;
|
||||
second_factors.reserve(usedNumberOfFirstFactorColumns);
|
||||
|
||||
for (auto& x : Range(Proc->get_Ci(), matmulArgs[8],
|
||||
usedNumberOfFirstFactorColumns))
|
||||
second_factors.push_back(x.get() * secondFactorTotalNumberOfColumns);
|
||||
|
||||
for (int i = 0; i < resultNumberOfRows; i += 1) {
|
||||
auto actualFirstFactorRow = Proc->get_Ci().at(matmulArgs[6] + i).get();
|
||||
auto firstBase = source.begin() + firstFactorBase
|
||||
+ actualFirstFactorRow * firstFactorTotalNumberOfColumns;
|
||||
|
||||
for (auto& x : Range(Proc->get_Ci(), matmulArgs[7],
|
||||
usedNumberOfFirstFactorColumns))
|
||||
assert(firstBase + x.get() < source.end());
|
||||
|
||||
for (int j = 0; j < resultNumberOfColumns; j += 1) {
|
||||
auto actualSecondFactorColumn = Proc->get_Ci().at(matmulArgs[9] + j).get();
|
||||
auto secondBase = source.begin() + secondFactorBase
|
||||
+ actualSecondFactorColumn;
|
||||
|
||||
#ifdef MATMULSM_DEBUG
|
||||
cout << "Preparing " << i << "," << j << "(buffer size: " << protocol.get_buffer_size() << ")" << endl;
|
||||
#endif
|
||||
|
||||
for (int k = 0; k < usedNumberOfFirstFactorColumns; k += 1) {
|
||||
auto actualFirstFactorColumn = Proc->get_Ci().at(matmulArgs[7] + k).get();
|
||||
auto actualSecondFactorRow = Proc->get_Ci().at(matmulArgs[8] + k).get();
|
||||
auto second_it = second_factors.begin();
|
||||
|
||||
auto firstAddress = firstFactorBase + actualFirstFactorRow * firstFactorTotalNumberOfColumns + actualFirstFactorColumn;
|
||||
auto secondAddress = secondFactorBase + actualSecondFactorRow * secondFactorTotalNumberOfColumns + actualSecondFactorColumn;
|
||||
for (auto& x : Range(Proc->get_Ci(), matmulArgs[7],
|
||||
usedNumberOfFirstFactorColumns))
|
||||
{
|
||||
auto actualFirstFactorColumn = x.get();
|
||||
|
||||
assert(firstAddress < sourceSize);
|
||||
assert(secondAddress < sourceSize);
|
||||
auto first = firstBase + actualFirstFactorColumn;
|
||||
auto second = secondBase + *second_it++;
|
||||
|
||||
protocol.prepare_dotprod(sourceData[firstAddress], sourceData[secondAddress]);
|
||||
protocol.prepare_dotprod(*first, *second);
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
|
||||
@@ -905,9 +929,19 @@ void Conv2dTuple::post(StackedVector<T>& S, typename T::Protocol& protocol)
|
||||
template<class T>
|
||||
void SubProcessor<T>::secure_shuffle(const Instruction& instruction)
|
||||
{
|
||||
typename T::Protocol::Shuffler(S, instruction.get_size(),
|
||||
instruction.get_n(), instruction.get_r(0), instruction.get_r(1),
|
||||
*this);
|
||||
size_t n = instruction.get_size();
|
||||
size_t unit_size = instruction.get_n();
|
||||
size_t output_base = instruction.get_r(0);
|
||||
size_t input_base = instruction.get_r(1);
|
||||
|
||||
typename T::Protocol::Shuffler shuffler(*this);
|
||||
|
||||
typename T::Protocol::Shuffler::shuffle_type shuffle;
|
||||
shuffler.generate(n / unit_size, shuffle);
|
||||
|
||||
vector<ShuffleTuple<T>> shuffles{ShuffleTuple<T>(n, output_base,
|
||||
input_base, unit_size, shuffle, true)};
|
||||
shuffler.apply_multiple(S, shuffles);
|
||||
|
||||
maybe_check();
|
||||
}
|
||||
@@ -916,7 +950,10 @@ template<class T>
|
||||
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store)
|
||||
{
|
||||
return shuffler.generate(instruction.get_n(), shuffle_store);
|
||||
size_t n = instruction.get_n();
|
||||
auto res = shuffle_store.add(n);
|
||||
shuffler.generate(n, shuffle_store.get(res).second);
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -926,21 +963,18 @@ void SubProcessor<T>::apply_shuffle(const Instruction& instruction,
|
||||
const auto& args = instruction.get_start();
|
||||
|
||||
const auto n_shuffles = args.size() / 6;
|
||||
vector<size_t> sizes(n_shuffles, 0);
|
||||
vector<size_t> destinations(n_shuffles, 0);
|
||||
vector<size_t> sources(n_shuffles, 0);
|
||||
vector<size_t> unit_sizes(n_shuffles, 0);
|
||||
vector<size_t> shuffles(n_shuffles, 0);
|
||||
vector<bool> reverse(n_shuffles, false);
|
||||
for (size_t i = 0; i < n_shuffles; i++) {
|
||||
sizes[i] = args[6 * i];
|
||||
destinations[i] = args[6 * i + 1];
|
||||
sources[i] = args[6 * i + 2];
|
||||
unit_sizes[i] = args[6 * i + 3];
|
||||
shuffles[i] = Proc->read_Ci(args[6 * i + 4]);
|
||||
reverse[i] = args[6 * i + 5];
|
||||
vector<ShuffleTuple<T>> shuffles;
|
||||
|
||||
for (size_t i = 0; i < n_shuffles; i++)
|
||||
{
|
||||
shuffles.push_back(
|
||||
ShuffleTuple<T>(args[6 * i], args[6 * i + 1], args[6 * i + 2],
|
||||
args[6 * i + 3],
|
||||
shuffle_store.get(Proc->read_Ci(args[6 * i + 4])),
|
||||
bool(args[6 * i + 5])));
|
||||
}
|
||||
shuffler.apply_multiple(S, sizes, destinations, sources, unit_sizes, shuffles, reverse, shuffle_store);
|
||||
|
||||
shuffler.apply_multiple(S, shuffles);
|
||||
|
||||
maybe_check();
|
||||
}
|
||||
@@ -1184,4 +1218,13 @@ void Processor<sint, sgf2n>::call_tape(int tape_number, int arg,
|
||||
arg_stack.pop_back();
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
TimerWithComm Processor<sint, sgf2n>::prep_time()
|
||||
{
|
||||
auto res = DataF.total_time();
|
||||
res += Procp.protocol.prep_time();
|
||||
res += Proc2.protocol.prep_time();
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -42,6 +42,8 @@ class Program
|
||||
|
||||
size_t size() const { return p.size(); }
|
||||
|
||||
string get_name() const { return name; }
|
||||
|
||||
// Read in a program
|
||||
void parse(string filename);
|
||||
void parse_with_error(string filename);
|
||||
|
||||
@@ -31,10 +31,24 @@ HonestMajorityRingMachine<U, V>::HonestMajorityRingMachine(int argc, const char*
|
||||
RingMachine<U, V, HonestMajorityMachine>(argc, argv, opt, online_opts, nplayers);
|
||||
}
|
||||
|
||||
inline void ring_domain_error(int R)
|
||||
inline void ring_domain_error(int R, int max)
|
||||
{
|
||||
cerr << "not compiled for " << R << "-bit computation, " << endl;
|
||||
cerr << "compile with -DRING_SIZE=" << R << endl;
|
||||
cerr << "The virtual machine is not compiled for " << R
|
||||
<< "-bit computation." << endl;
|
||||
cerr << "Compile with 'MY_CFLAGS += -DRING_SIZE=" << R
|
||||
<< "' in 'CONFIG.mine'";
|
||||
(void) max;
|
||||
#ifndef FEWER_RINGS
|
||||
for (int r = 0; r <= max; r += 64)
|
||||
{
|
||||
if (r >= R)
|
||||
{
|
||||
cerr << " or try " << "'-R " << r << "'";
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
cerr << "." << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -60,7 +74,7 @@ RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
|
||||
X(RING_SIZE)
|
||||
#endif
|
||||
#undef X
|
||||
ring_domain_error(R);
|
||||
ring_domain_error(R, 192);
|
||||
}
|
||||
|
||||
template<template<int K, int S> class U, template<class T> class V>
|
||||
@@ -98,7 +112,7 @@ HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecuri
|
||||
X(72) X(128)
|
||||
#endif
|
||||
#undef X
|
||||
ring_domain_error(R);
|
||||
ring_domain_error(R, 128);
|
||||
}
|
||||
|
||||
#endif /* PROCESSOR_RINGMACHINE_HPP_ */
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include "OnlineOptions.h"
|
||||
#include "BaseMachine.h"
|
||||
#include "GC/ArgTuples.h"
|
||||
|
||||
template<class T> class StackedVector;
|
||||
@@ -106,9 +107,12 @@ public:
|
||||
TruncPrTupleWithGap(vector<int>::const_iterator it) :
|
||||
TruncPrTuple<T>(it)
|
||||
{
|
||||
big_gap_ = this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error;
|
||||
int min_size = this->k + OnlineOptions::singleton.trunc_error;
|
||||
big_gap_ = min_size <= T::n_bits();
|
||||
if (T::prime_field and small_gap())
|
||||
throw runtime_error("domain too small for chosen truncation error");
|
||||
if (small_gap() and BaseMachine::has_singleton())
|
||||
BaseMachine::s().gap_warning(min_size);
|
||||
}
|
||||
|
||||
T upper(T mask)
|
||||
|
||||
@@ -202,7 +202,7 @@
|
||||
*dest++ = *op1++ > *op2++) \
|
||||
X(EQC, auto dest = &Ci[r[0]]; auto op1 = &Ci[r[1]]; auto op2 = &Ci[r[2]], \
|
||||
*dest++ = *op1++ == *op2++) \
|
||||
X(PRINTINT, Proc.out << Proc.read_Ci(r[0]) << flush,) \
|
||||
X(PRINTINT, print(Proc.out, &Proc.get_Ci_ref(r[0])),) \
|
||||
X(PRINTFLOATPREC, Proc.out << setprecision(n),) \
|
||||
X(PRINTSTR, Proc.out << string((char*)&n,4) << flush,) \
|
||||
X(PRINTCHR, Proc.out << string((char*)&n,1) << flush,) \
|
||||
|
||||
5
Programs/Source/and-bench.py
Normal file
5
Programs/Source/and-bench.py
Normal file
@@ -0,0 +1,5 @@
|
||||
a = sbits.get_type(int(program.args[1]))(0)
|
||||
|
||||
@for_range(int(program.args[2]))
|
||||
def _(i):
|
||||
a & a
|
||||
11
Programs/Source/combo-bench.py
Normal file
11
Programs/Source/combo-bench.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import math
|
||||
|
||||
n = int(program.args[1])
|
||||
n_sqrt = int(math.sqrt(n))
|
||||
|
||||
sfix.Matrix(n_sqrt, 10) * sfix.Matrix(10, n_sqrt)
|
||||
(sfix(0, size=n) < 0).store_in_mem(0)
|
||||
|
||||
sint.Array(n).secure_shuffle()
|
||||
|
||||
sint(personal(0, cint(0, size=n)))
|
||||
15
Programs/Source/comp-bench.py
Normal file
15
Programs/Source/comp-bench.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#sfix.set_precision(32, 63)
|
||||
#program.use_trunc_pr = True
|
||||
#program.use_split(3)
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program)
|
||||
try:
|
||||
n_loops = int(program.args[2])
|
||||
except:
|
||||
n_loops = 1
|
||||
|
||||
a = sfix(cint(0, size=int(program.args[1])))
|
||||
|
||||
@for_range(n_loops)
|
||||
def _(i):
|
||||
(a < a)#.store_in_mem(0)
|
||||
10
Programs/Source/fdiv-bench.py
Normal file
10
Programs/Source/fdiv-bench.py
Normal file
@@ -0,0 +1,10 @@
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program)
|
||||
|
||||
n = int(program.args[1])
|
||||
m = int(program.args[2])
|
||||
a = sfix(0, size=n)
|
||||
|
||||
@for_range(m)
|
||||
def _(i):
|
||||
(a / a).store_in_mem(0)
|
||||
14
Programs/Source/fmul-bench.py
Normal file
14
Programs/Source/fmul-bench.py
Normal file
@@ -0,0 +1,14 @@
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program)
|
||||
|
||||
try:
|
||||
n = int(program.args[1])
|
||||
except:
|
||||
n = 10 ** 6
|
||||
|
||||
m = int(program.args[2])
|
||||
a = sfix(0, size=n)
|
||||
|
||||
@for_range(m)
|
||||
def _(i):
|
||||
(a * a).store_in_mem(0)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user