Expected communication cost in compiler.

This commit is contained in:
Marcel Keller
2025-12-24 13:46:43 +11:00
parent f10864f85e
commit bf7f8f4b65
194 changed files with 2768 additions and 884 deletions

View File

@@ -1,5 +1,14 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.4.2 (Dec 24, 2025)
- Expected communication cost in compiler
- Semi-honest option of Rep4
- Reduced communication for preprocessing in Dealer protocol
- Option of choosing SoftSpoken parameter at run-time
- BERT functionality (@hiddely)
- Recommended reading list in documentation
## 0.4.1 (May 30, 2025)
- Add protocols with function-dependent preprocessing (https://eprint.iacr.org/2025/919)

View File

@@ -618,6 +618,10 @@ class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable):
code = opcodes['REVEAL']
arg_format = tools.cycle(['int','cbw','sb'])
def add_usage(self, req_node):
req_node.increment(('bit', 'open'), sum(
int(math.ceil(x / 64)) * 8 for x in self.args[0::3]))
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
""" Copy private input to secret bit register vectors. The input is
read as floating-point number, multiplied by a power of two, and then

View File

@@ -24,9 +24,14 @@ from functools import reduce
class _binary:
def __or__(self, other):
return self ^ other ^ (self & other)
__ror__ = __or__
def reveal_to(self, *args, **kwargs):
raise CompilerError(
'%s does not support revealing to individual players' % type(self))
@staticmethod
def direct_matrix_mul(*args, **kwargs):
raise AttributeError('direct matrix multiplication only supported '
'in arithmetic circuits')
class bits(Tape.Register, _structure, _bit, _binary):
n = 40
@@ -432,7 +437,7 @@ class cbits(bits):
inst.convcbitvec(self.n, res, self)
return res
class sbits(bits):
class sbits(bits, Tape._no_secret_truth):
"""
Secret bits register. This type supports basic bit-wise operations::
@@ -697,7 +702,7 @@ class sbits(bits):
def output(self):
inst.print_reg_plainsb(self)
class sbitvec(_vec, _bit, _binary):
class sbitvec(Tape._no_secret_truth, _vec, _bit, _binary):
""" Vector of registers of secret bits, effectively a matrix of secret bits.
This facilitates parallel arithmetic operations in binary circuits.
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -907,35 +912,27 @@ class sbitvec(_vec, _bit, _binary):
def __init__(self, elements=None, length=None, input_length=None):
if length:
assert isinstance(elements, sint)
if Program.prog.use_split():
x = elements.split_to_two_summands(length)
v = sbitint.bit_adder(x[0], x[1])
else:
prog = Program.prog
if not prog.options.ring:
# force the use of edaBits
backup = prog.use_edabit()
prog.use_edabit(True)
self.v = prog.non_linear.bit_dec(
elements, max(length, input_length or prog.bit_length),
length, maybe_mixed=True)
assert isinstance(self.v[0], sbits)
prog.use_edabit(backup)
return
comparison.require_ring_size(length, 'A2B conversion')
l = int(Program.prog.options.ring)
r, r_bits = sint.get_edabit(length, size=elements.size)
c = ((elements - r) << (l - length)).reveal()
c >>= l - length
cb = [(c >> i) for i in range(length)]
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
v = x.v
self.v = v[:length]
prog = Program.prog
backup = prog.use_edabit()
if not prog.have_a2b():
# force the use of edaBits
prog.use_edabit(True)
self.v = prog.non_linear.bit_dec(
elements, max(length, input_length or prog.bit_length),
length, maybe_mixed=True)
assert isinstance(self.v[0], sbits)
prog.use_edabit(backup)
elif isinstance(elements, sbitvec):
self.v = elements.v
elif isinstance(elements, (list, tuple)) and \
isinstance(elements[0], sbitvec):
self.v = sbitvec(sum((x.elements() for x in elements), [])).v
elif elements is not None and not (util.is_constant(elements) and \
elements == 0):
self.v = sbits.trans(elements)
def __str__(self):
return 'sbitvec(%s/%s)' % (len(self.v), self.size)
__repr__ = __str__
def popcnt(self):
""" Population count / Hamming weight.
@@ -961,7 +958,7 @@ class sbitvec(_vec, _bit, _binary):
return self.from_vec(x ^ y for x, y in zip(*self.expand(other)))
def __and__(self, other):
return self.from_vec(x & y for x, y in zip(*self.expand(other)))
__rxor__ = __xor__
__add__ = __radd__ = __sub__ = __rsub__ =__rxor__ = __xor__
__rand__ = __and__
def __invert__(self):
return self.from_vec(~x for x in self.v)
@@ -969,10 +966,6 @@ class sbitvec(_vec, _bit, _binary):
return util.if_else(self.v[0], x, y)
def __iter__(self):
return iter(self.elements())
def __len__(self):
return len(self.v)
def __getitem__(self, index):
return self.v[index]
@classmethod
def conv(cls, other):
if isinstance(other, cls):
@@ -999,8 +992,6 @@ class sbitvec(_vec, _bit, _binary):
return util.untuplify([x.reveal() for x in self.elements()])
def long_one(self):
return [x.long_one() for x in self.v]
def __rsub__(self, other):
return self.from_vec(y - x for x, y in zip(self.v, other))
def half_adder(self, other):
other = self.coerce(other)
res = zip(*(x.half_adder(y) for x, y in zip(self.v, other)))
@@ -1014,7 +1005,7 @@ class sbitvec(_vec, _bit, _binary):
elif len(self.v) == 1:
self, other = other, self.v[0]
else:
raise CompilerError('no operand of lenght 1: %d/%d',
raise CompilerError('no operand of length 1: %d/%d',
(len(self.v), len(other.v)))
if not isinstance(other, sbits):
return NotImplemented
@@ -1036,8 +1027,6 @@ class sbitvec(_vec, _bit, _binary):
i += 1
return sbitvec.from_vec(res)
__rmul__ = __mul__
def __add__(self, other):
return self.from_vec(x + y for x, y in zip(self.v, other))
def bit_and(self, other):
return self & other
def bit_xor(self, other):
@@ -1060,7 +1049,12 @@ class sbitvec(_vec, _bit, _binary):
@classmethod
def comp_result(cls, x):
return cls.get_type(1).from_vec([x])
def expand(self, other, expand=True):
@staticmethod
def reverse_type(other):
return isinstance(other, sbitfixvec)
equal = __eq__ = _bitint.__eq__
eqz = staticmethod(_bitint.eqz)
def expand(self, other, expand=True, copy=False):
assert not isinstance(other, sbitfixvec)
m = 1
for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []):
@@ -1076,7 +1070,10 @@ class sbitvec(_vec, _bit, _binary):
res.append([x * sbits.get_type(m)().long_one()
for x in util.bit_decompose(y, len(self.v))])
else:
v = [type(x)(x) if isinstance(x, bits) else x for x in y.v]
if copy:
v = [type(x)(x) if isinstance(x, bits) else x for x in y.v]
else:
v = y.v
res.append([x.expand(m) if (expand and isinstance(x, bits))
else x for x in v])
return res
@@ -1364,7 +1361,21 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
"""
Vector of signed integers for parallel binary computation.
Values and vectors of signed integers for parallel binary computation::
si32 = sbitintvec.get_type(32)
print_ln('add: %s', (si32(5) + si32(3)).reveal())
print_ln('sub: %s', (si32(5) - si32(3)).reveal())
print_ln('mul: %s', (si32(5) * si32(3)).reveal())
print_ln('lt: %s', (si32(5) < si32(3)).reveal())
This should output::
add: 8
sub: 2
mul: 15
lt: 0
The following example uses vectors of size two::
sb32 = sbits.get_type(32)
@@ -1389,7 +1400,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
"""
bit_extend = staticmethod(_complement_two_extend)
mul_functions = {}
functions = {}
@classmethod
def popcnt_bits(cls, bits):
return sbitvec.from_vec(bits).popcnt()
@@ -1406,8 +1417,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
if len(a) == 1:
res = _bitint.bit_adder(a, b, get_carry=True)
return self.get_type(32).from_vec(res, signed=False)
v = sbitint.bit_adder(a, b)
return self.get_type(len(v)).from_vec(v)
return self.maybe_function(self.binary_add, a, b)
__radd__ = __add__
__sub__ = _bitint.__sub__
def __rsub__(self, other):
@@ -1424,9 +1434,12 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
elif isinstance(other, sbitfixvec):
return NotImplemented
try:
my_bits, other_bits = self.expand(other, False)
my_bits, other_bits = self.expand(other, False, copy=True)
except:
return NotImplemented
return self.maybe_function(self.binary_mul, my_bits, other_bits)
@classmethod
def maybe_function(cls, call, my_bits, other_bits, result_length=None):
m = float('inf')
uniform = True
for x in itertools.chain(my_bits, other_bits):
@@ -1437,21 +1450,26 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
pass
if uniform and Program.prog.options.cisc:
bl = len(my_bits)
key = bl, len(other_bits)
if key not in self.mul_functions:
ol = result_length or bl
key = call.__name__, ol, bl, len(other_bits)
if key not in cls.functions:
def instruction(*args):
res = self.binary_mul(args[bl:2 * bl], args[2 * bl:],
args[0].n)
res = call(args[ol:ol + bl], args[ol + bl:], args[0].n)
for x, y in zip(sbitvec.from_vec(res).v, args):
x.mov(y, x)
instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits))
self.mul_functions[key] = instructions_base.cisc(instruction,
bl)
res = [sbits.get_type(m)() for i in range(bl)]
self.mul_functions[key](*(res + my_bits + other_bits))
return self.from_vec(res)
instruction.__name__ = '%s%sx%s' % (call.__name__, bl, len(other_bits))
cls.functions[key] = instructions_base.cisc(instruction, ol)
res = [sbits.get_type(m)() for i in range(ol)]
cls.functions[key](*(res + my_bits + other_bits))
if result_length:
return res
else:
return cls.from_vec(res)
else:
return self.binary_mul(my_bits, other_bits, m)
return call(my_bits, other_bits, m)
@classmethod
def binary_add(cls, a, b, m):
return cls.from_vec(sbitint.bit_adder(a, b))
@classmethod
def binary_mul(cls, my_bits, other_bits, m):
matrix = []
@@ -1468,21 +1486,21 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
def TruncMul(self, other, k, m, kappa=None, nearest=False):
if nearest:
raise CompilerError('round to nearest not implemented')
if not isinstance(other, sbitintvec):
other = sbitintvec(other)
if isinstance(other, int):
b = other
else:
if not isinstance(other, sbitintvec):
other = sbitintvec(other)
b = self.get_type(k).from_vec(_complement_two_extend(other.v, k))
a = self.get_type(k).from_vec(_complement_two_extend(self.v, k))
b = self.get_type(k).from_vec(_complement_two_extend(other.v, k))
tmp = a * b
assert len(tmp.v) == k
return self.get_type(k - m).from_vec(tmp[m:])
return self.get_type(k - m).from_vec(tmp.v[m:])
def pow2(self, k):
""" Computer integer power of two.
:param k: bit length of input """
return _sbitintbase.pow2(self, k)
@staticmethod
def reverse_type(other):
return isinstance(other, sbitfixvec)
sbits.vec = sbitvec
sbitint.vec = sbitintvec
@@ -1511,6 +1529,14 @@ class cbitfix(object):
v = self.v
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))
def __iter__(self):
return iter([self])
def error(*args, **kwargs):
raise CompilerError(
'Support for revealed fixed-point values in binary circuits '
'is currently limited to simple outputs. '
'Please file a feature request if you need this for an application.')
__add__ = __mul__ = __sub__ = error
class sbitfix(_fix, _binary):
""" Secret signed fixed-point number in one binary register.
@@ -1581,14 +1607,29 @@ class sbitfix(_fix, _binary):
return cls._new(cls.int_type(other), k, f)
class sbitfixvec(_fix, _vec, _binary):
""" Vector of fixed-point numbers for parallel binary computation.
"""
Values and vectors of fixed-point numbers for parallel binary computation::
Use :py:obj:`set_precision()` to change the precision.
print_ln('add: %s', (sbitfixvec(0.5) + sbitfixvec(0.3)).reveal())
print_ln('mul: %s', (sbitfixvec(0.5) * sbitfixvec(0.3)).reveal())
print_ln('sub: %s', (sbitfixvec(0.5) - sbitfixvec(0.3)).reveal())
print_ln('lt: %s', (sbitfixvec(0.5) < sbitfixvec(0.3)).reveal())
Example::
will output roughly::
a = sbitfixvec([sbitfix(0.3), sbitfix(0.5)])
b = sbitfixvec([sbitfix(0.4), sbitfix(0.6)])
add: 0.800003
mul: 0.149994
sub: 0.199997
lt: 0
Note that the default precision (16 bits after the dot, 31 bits in
total) only allows numbers up to :math:`2^{31-16-1} \\approx
16000`. You can increase this using :py:func:`set_precision`.
Refer to the following example for the vector functionality::
a = sbitfixvec([sbitfixvec(0.3), sbitfixvec(0.5)])
b = sbitfixvec([sbitfixvec(0.4), sbitfixvec(0.6)])
c = (a + b).elements()
print_ln('add: %s, %s', c[0].reveal(), c[1].reveal())
c = (a * b).elements()
@@ -1606,13 +1647,12 @@ class sbitfixvec(_fix, _vec, _binary):
lt: 1, 1
"""
int_type = sbitintvec.get_type(sbitfix.k)
float_type = type(None)
clear_type = cbitfix
rep_type = staticmethod(lambda x: x)
@property
def bit_type(self):
return type(self.v[0])
return type(self.v.v[0])
@classmethod
def set_precision(cls, f, k=None):
super(sbitfixvec, cls).set_precision(f=f, k=k)
@@ -1637,7 +1677,8 @@ class sbitfixvec(_fix, _vec, _binary):
value = self.int_type(value)
super(sbitfixvec, self).__init__(value, *args, **kwargs)
def elements(self):
return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()]
return [sbitfixvec._new(x, f=self.f, k=self.k)
for x in self.v.elements()]
def mul(self, other):
if isinstance(other, sbits):
return self._new(self.v * other)

View File

@@ -712,7 +712,8 @@ class Merger:
elif isinstance(instr, StackInstruction):
keep_order(instr, n, StackInstruction)
elif isinstance(instr, applyshuffle):
shuffles[instr.args[3]].add(n)
for handle in instr.handles():
shuffles[handle].add(n)
elif isinstance(instr, delshuffle):
for i_inst in shuffles[instr.args[0]]:
add_edge(i_inst, n)

View File

@@ -61,14 +61,15 @@ class Circuit:
return self.run(*inputs)
def run(self, *inputs):
n = inputs[0][0].n, get_tape()
inputs = [sbitvec.from_vec(x) for x in inputs]
n = inputs[0].v[0].n, get_tape()
if n not in self.functions:
if get_program().force_cisc_tape:
f = function_call_tape
else:
f = function_block
self.functions[n] = f(lambda *args: self.compile(*args))
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
self.functions[n].name = '%s(%d)' % (self.name, inputs[0].v[0].n)
flat_res = self.functions[n](*itertools.chain(*(
sbitvec.from_vec(x).v for x in inputs)))
res = []
@@ -208,7 +209,7 @@ def sha3_256(x):
for x in range(5):
for i in range(w):
j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
res[x][y][i] = S_flat[1600 - 1 -j]
res[x][y][i] = S_flat.v[1600 - 1 -j]
return res
w = 64
@@ -313,7 +314,7 @@ class ieee_float:
for i in range(2):
for j in range(10):
values.append(sbitint.get_type(64).get_input_from(i))
values.append(sbitintvec.get_type(64).get_input_from(i))
fvalues = [ieee_float(x) for x in values]

View File

@@ -67,16 +67,21 @@ def ld2i(c, n):
def maybe_mulm(res, x, y):
# overwrite instruction for function-dependent preprocessing protocols
from Compiler import types
res.link(x * y)
program.curr_block.replace_last_reg(res, x * y)
def require_ring_size(k, op, suffix=''):
def require_ring_size(k, op, suffix='', slack=0):
if not program.options.ring:
return
diff = slack * (not program.allow_tight_parameters)
k += diff
if int(program.options.ring) < k:
msg = 'ring size too small for %s, compile ' \
'with \'-R %d\' or more' % (op, k)
if k > 64 and k < 128:
msg += ' (maybe \'-R 128\' as it is supported by default)'
if int(program.options.ring) >= k - diff:
msg += ", alternatively set " \
"'program.allow_tight_parameters=True' in the program"
raise CompilerError(msg + suffix)
program.curr_tape.require_bit_length(k)
@@ -97,13 +102,13 @@ def LtzRingRaw(a, k):
from .types import sint, _bitint
from .GC.types import sbitvec
if program.use_split():
program.reading('comparison', 'ABY3')
program.reading('comparison', 'Keller25', 'Section 6')
summands = a.split_to_two_summands(k)
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
msb = carry ^ summands[0][-1] ^ summands[1][-1]
return msb
else:
program.reading('comparison', 'DEK20-pre')
program.reading('comparison', 'DEK20-pre', 'Paragraph III.D.8')
from . import floatingpoint
require_ring_size(k, 'comparison')
m = k - 1
@@ -195,7 +200,7 @@ def TruncLeakyInRing(a, k, m, signed):
if k == m:
return 0
assert k > m
program.reading('truncation', 'DEK20-pre')
program.reading('truncation', 'DEK20-pre', 'Paragraph III.D.4')
require_ring_size(k, 'leaky truncation')
from .types import sint, intbitint, cint, cgf2n
n_bits = k - m
@@ -239,7 +244,7 @@ def Mod2m(a_prime, a, k, m, signed):
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
def Mod2mRing(a_prime, a, k, m, signed):
program.reading('modulo', 'DEK20-pre')
program.reading('modulo', 'DEK20-pre', 'Paragraph III.D.3')
require_ring_size(k, 'modulo power of two')
from Compiler.types import sint, intbitint, cint
shift = int(program.options.ring) - m
@@ -254,7 +259,7 @@ def Mod2mRing(a_prime, a, k, m, signed):
return res
def Mod2mField(a_prime, a, k, m, signed):
program.reading('modulo', 'CdH10')
program.reading('modulo', 'CdH10', 'Protocol 3.2')
from .types import sint
r_dprime = program.curr_block.new_reg('s')
r_prime = program.curr_block.new_reg('s')
@@ -349,6 +354,8 @@ def BitLTC1(u, a, b):
a: array of clear bits
b: array of secret bits (same length as a)
"""
program.reading('constant-round bit-wise public-private comparison',
'CdH10', 'Protocol 4.5')
k = len(b)
p = [program.curr_block.new_reg('s') for i in range(k)]
from . import floatingpoint
@@ -489,6 +496,8 @@ def BitLTL(res, a, b):
a: clear integer register
b: array of secret bits (same length as a)
"""
program.reading('logarithmic-round bit-wise public-private comparison',
'CdH10', 'Protocol 4.1')
k = len(b)
a_bits = b[0].bit_decompose_clear(a, k)
from .types import sint
@@ -655,7 +664,7 @@ def Mod2(a_0, a, k, signed):
if k <= 1:
movs(a_0, a)
return
program.reading('modulo', 'CdH10')
program.reading('modulo', 'CdH10', 'Protocol 3.4')
r_dprime = program.curr_block.new_reg('s')
r_prime = program.curr_block.new_reg('s')
r_0 = program.curr_block.new_reg('s')

View File

@@ -587,6 +587,12 @@ class Compiler:
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
print("Memory size:", dict(self.prog.allocated_mem))
comm = self.prog.expected_communication()
if sum(comm):
print(
"Expected communication is %g MB online and %g MB offline." % \
(comm[0] / 1e6, comm[1] / 1e6))
return self.prog
match = {
@@ -608,6 +614,13 @@ class Compiler:
else:
return protocol + "-party.x"
@classmethod
def short_protocol_name(cls, protocol):
for x in cls.match.items():
if protocol == x[1]:
return x[0]
return re.sub('^malicious-', 'mal-', protocol)
def local_execution(self, args=None):
if args is None:
args = self.runtime_args
@@ -651,6 +664,7 @@ class Compiler:
destinations.append('.')
connections = [Connection(hostname) for hostname in hostnames]
print("Setting up players...")
lockfile = ".transfer.lock"
def run(i):
dest = destinations[i]
@@ -658,6 +672,16 @@ class Compiler:
connection.run(
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest)
dest_lockfile = "%s/%s" % (dest, lockfile)
try:
connection.run("test -e %s && exit 1; touch %s" % (
(dest_lockfile,) * 2))
except:
raise Exception(
"Problem with %s on %s. You cannot use the same directory "
"for several instances (including the control instance). "
"Remove %s on %s if this has been left behind from an "
"aborted exection." % ((dest_lockfile, hostnames[i]) * 2))
# executable
connection.put("%s/static/%s" % (self.root, vm), dest)
# program
@@ -676,10 +700,12 @@ class Compiler:
dest + "Player-Data")
for filename in glob.glob("Player-Data/*.0"):
connection.put(filename, dest + "Player-Data")
connection.run("rm %s" % dest_lockfile)
def run_with_error(i):
try:
run(i)
copied[i] = True
except IOError:
print('IO error when copying files, does %s have enough space?' %
hostnames[i])
@@ -693,13 +719,19 @@ class Compiler:
out = fn(i)
outputs[i] = out
open(lockfile, "w")
threads = []
copied = [False] * len(hosts)
for i in range(len(hosts)):
threads.append(threading.Thread(target=run_with_error, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
os.remove(lockfile)
if False in copied:
print("Error in remote copying, see above")
sys.exit(1)
# execution
threads = []

485
Compiler/cost.py Normal file
View File

@@ -0,0 +1,485 @@
import re
import math
import os
import itertools
class Comm:
def __init__(self, comm=0, offline=0):
try:
comm = comm()
except:
pass
try:
self.online, self.offline = comm
assert not offline
except:
self.online = comm or 0
assert isinstance(self.online, (int, float))
self.offline = offline
def __getitem__(self, index):
return self.offline if index else self.online
def __iter__(self):
return iter((self.online, self.offline))
def __add__(self, other):
return Comm(x + y for x, y in zip(self, other))
def __sub__(self, other):
return self + -1 * other
def __mul__(self, other):
return Comm(x * other for x in self)
__rmul__ = __mul__
def __repr__(self):
return 'Comm(%d, %d)' % tuple(self)
def __bool__(self):
return bool(sum(self))
def sanitize(self):
try:
return tuple(int(x) for x in self)
except:
return (0, 0)
dishonest_majority = {
'emi',
'mascot',
'spdz',
'soho',
'gear',
}
semihonest = {
'emi|soho',
'atlas|^shamir',
'dealer',
}
ring = {
'ring',
'2k',
}
fixed = {
'^(ring|rep-field)': 3,
'rep4': 6,
'mal-rep-field': (6, 9),
'mal-rep-ring': (lambda l: (6 * l, (l + 5) * 9)),
'sy-rep-field': 6,
'sy-rep-ring': lambda l: (6 * (l + 5), 0),
'ps-rep-field': 9,
'ps-rep-ring': lambda l: 9 * (l + 5),
'brain': lambda l: (3 * 2 * l, 3 * (2 * (l + 5) + 3 * (2 * l + 15))),
}
ot_cost = 64
spdz2k_sec = 64
def lowgear_cipher_length(l):
res = (30 + 2 * l) // 8
return res
def highgear_cipher_lengths(l):
res = 71 + 16 * l, 57 + 8 * l
return res
def highgear_cipher_limbs(l):
res = sum(int(math.ceil(x / 64)) for x in highgear_cipher_lengths(l))
return res
def highgear_decrypt_length(l):
return highgear_cipher_lengths(l)[0] / 8 + 1
def hemi_cipher_length(l):
res = 16 * l + 77
return res
def hemi_cipher_limbs(l):
res = int(math.ceil(hemi_cipher_length(l) / 64))
return res
variable = {
'^shamir': lambda N: N * (N - 1) // 2,
'atlas': lambda N: N // 2 * 4,
'dealer': lambda N: (2 * (N - 1), 1),
'semi': lambda N: lambda l: (
4 * (N - 1) * l, N * (N - 1) * (l * (ot_cost + 8 * l))),
'mascot': lambda N: lambda l: (
4 * (N - 1) * l, N * (N - 1) * (l * (3 * ot_cost + 64 * l))),
'spdz2k': lambda N: lambda l: (
4 * (N - 1) * l,
N * (N - 1) * (ot_cost * (2 * l + 4 * spdz2k_sec // 8) + \
(l + spdz2k_sec // 8) * (4 * spdz2k_sec + 2 * l * 8) + \
(5 * (l + 2 * spdz2k_sec // 8) * spdz2k_sec))),
'hemi': lambda N: lambda l: (
4 * (N - 1) * l, N * (N - 1) * hemi_cipher_limbs(l) * 8 * 2 * 2),
'temi': lambda N: lambda l: (
4 * (N - 1) * l, (N - 1) * (hemi_cipher_limbs(l) * 8 * 2 * 2 +
hemi_cipher_length(l) / 8 + 1) * 2),
'soho': lambda N: lambda l: (
4 * (N - 1) * l,
(N - 1) * (N * highgear_cipher_limbs(l) * 8 * 2 +
highgear_decrypt_length(l)) * 2),
'owgear': lambda N: lambda l: (
4 * (N - 1) * l,
N * ((N - 1) * (lowgear_cipher_length(l) * (128 + 48) + 64) + 2 * l)),
'.*i.*gear': lambda N: lambda l: (
4 * (N - 1) * l,
(N - 1) * (highgear_cipher_limbs(l) * 96 * 3 +
highgear_decrypt_length(l) * 16 + N * 192 + 6 * l)),
'sy-shamir': lambda N: 2 * variable['^shamir'](N) + variable_random['^shamir|atlas'](N)
}
variable_square = {
'soho': lambda N: lambda l: (
0, (N - 1) * (N * highgear_cipher_limbs(l) * 8 + 46) * 2),
'i.*gear': lambda N: lambda l: (
0, (N - 1) * (highgear_cipher_limbs(l) * 64 * 3 +
highgear_decrypt_length(l) * 12 + N * 128 + 4 * l)),
'ps-rep-ring': lambda N: lambda l: fixed['ps-rep-ring'](l),
'sy-shamir': lambda N: (
0, variable['sy-shamir'](N) + variable_random['sy-shamir'](N))
}
matrix_triples = {
'dealer': lambda N: (N - 1, 1),
}
diag_matrix = {
'hemi': lambda N, l, dims: N * (N - 1) * hemi_cipher_limbs(l) * 8 * 2 * \
(dims[0] * dims[1] + dims[0] * dims[2]),
'temi': lambda N, l, dims: (N - 1) * (
hemi_cipher_limbs(l) * 8 * 2 * 2 * (
dims[0] * dims[1] + dims[0] * dims[2]) +
(hemi_cipher_length(l) / 8 + 1) * 2 * (dims[0] * dims[2])),
}
fixed_bit = {
'mal-rep-field': (0, 11),
'rep4': (0, 8),
}
fixed_square = {
'mal-rep-ring': lambda l: fixed['mal-rep-ring'](l)[1],
}
variable_bit = {
'dealer': lambda N: (0, 1),
# missing OT cost
'emi': lambda N: lambda l: (0, l + ot_cost / 8) if N == 2 else None,
'mal-shamir': lambda N: (
0, variable_random['^shamir|atlas'](N) + \
math.ceil(N / 2) * variable_input['^shamir|atlas'](N) + \
(math.ceil(N / 2) - 0) * variable['^shamir'](N) + \
2 * reveal_variable['(mal|sy)-shamir'](N)),
}
fixed_and = {
'(mal|sy|ps)-rep': lambda bucket_size=4: (6, 3 * (3 * bucket_size - 2)),
}
variable_and = {
'emi': lambda N: (4 * (N - 1), N * (N - 1) * ot_cost)
}
trunc_pr = {
'^ring': 4,
'rep-field': 1,
'rep4': 12,
}
bit2a = {
'^(ring|rep-field)': 3,
}
dabit_from_bit = {
'ring',
'-rep-ring',
'semi2k',
}
bits_from_squares = {
'atlas': lambda N: N > 4,
'sy-shamir': lambda N: True,
'soho': lambda N: True,
'gear': lambda N: True,
'ps-rep-ring': lambda N: True,
'spdz2k': lambda N: True,
'mascot': lambda N: True,
'mal-rep-ring': lambda N: True,
'emi$': lambda N: True,
}
reveal = {
'((^|rep.*)ring|rep-field|brain)': 3,
'rep4': 4,
}
reveal_variable = {
'^shamir|atlas': lambda N: 3 * (N - 1) // 2,
'(mal|sy)-shamir': lambda N: (N - 1) // 2 * 2 * N,
'dealer': lambda N: 2 * (N - 2),
'spdz2k': lambda N: N * variable_input['mascot|spdz2k'](N),
}
fixed_input = {
'(^|ps-|mal-)(ring|rep-)': 1,
'sy-rep-ring': lambda l: 4 * (l + 5),
'sy-rep-field': 4,
'rep4': 2,
}
variable_input = {
'^shamir|atlas': lambda N: N // 2,
'mal-shamir': lambda N: N // 2,
'sy-shamir': lambda N: \
N // 2 + variable['^shamir'](N) + variable_random['^shamir|atlas'](N),
'mascot|spdz2k': lambda N: (N - 1) * Comm(1, ot_cost * 2),
'owgear': lambda N: lambda l: (
(N - 1) * l, (N - 1) * lowgear_cipher_length(l) * 16),
'i.*gear': lambda N: lambda l: (
(N - 1) * l, (N - 1) * (highgear_cipher_limbs(l) * 24 + 32 +
highgear_decrypt_length(l) * 4)),
}
variable_random = {
'^shamir|atlas': lambda N: N * (N // 2) / ((N + 2) // 2),
'mal-shamir': lambda N: N // 2 * N,
'sy-shamir': lambda N: \
2 * variable_random['^shamir|atlas'](N) + variable['^shamir'](N),
}
# cut random values
fixed_randoms = {
'sy-rep-ring': lambda l: 3 * (l + 5),
}
cheap_dot_product = {
'^(ring|rep-field)',
'sy-*',
'^shamir',
'rep4',
'atlas',
}
shuffle_application = {
'^(ring|rep-field)': 6,
'sy-rep-field': 12,
'sy-rep-ring': lambda l: 12 * (l + 5)
}
variable_edabit = {
'dealer': lambda N: lambda n_bits: lambda l: l + n_bits / 8
}
def find_match(data, protocol):
for x in data:
if re.search(x, protocol):
return x
def get_match(data, protocol):
x = find_match(data, protocol)
try:
return data.get(x)
except:
return bool(x)
def get_match_variable(data, protocol, n_parties):
f = get_match(data, protocol)
if f:
return f(n_parties)
def apply_length(unit, length):
try:
return Comm(unit(length))
except:
return Comm(unit) * length
def get_cost(fixed, variable, protocol, n_parties):
return get_match(fixed, protocol) or \
get_match_variable(variable, protocol, n_parties)
def get_mul_cost(protocol, n_parties):
return get_cost(fixed, variable, protocol, n_parties)
def get_and_cost(protocol, n_parties):
return get_cost(fixed_and, variable_and, protocol, n_parties)
def expected_communication(protocol, req_num, length, n_parties=None,
force_triple_use=False):
from Compiler.instructions import shuffle_base
from Compiler.program import Tape
get_int = lambda x: req_num.get(('modp', x), 0)
get_bit = lambda x: req_num.get(('bit', x), 0)
res = Comm()
if not protocol:
return res
if not n_parties:
try:
if get_match(fixed, protocol):
raise TypeError()
n_parties = int(os.getenv('PLAYERS'))
except TypeError:
if find_match(dishonest_majority, protocol):
n_parties = 2
else:
n_parties = 3
if find_match(dishonest_majority, protocol):
threshold = n_parties - 1
elif re.match('rep4', protocol):
n_parties = 4
threshold = 1
elif re.match('dealer', protocol):
threshold = 0
else:
threshold = n_parties // 2
malicious = not find_match(semihonest, protocol)
x = find_match(fixed, protocol)
y = get_mul_cost(protocol, n_parties)
unit = apply_length(y, length)
n_mults = get_int('simple multiplication')
matrix_cost = apply_length(
get_match_variable(matrix_triples, protocol, n_parties), length)
use_diag_matrix = get_match(diag_matrix, protocol)
use_triple_number = False
if find_match(cheap_dot_product, protocol):
n_mults += get_int('dot product')
elif (not matrix_cost and not use_diag_matrix) or force_triple_use:
use_triple_number = True
n_mults = get_int('triple')
and_cost = get_and_cost(protocol, n_parties)
if and_cost:
res += Comm(and_cost) * math.ceil(get_bit('triple') / 8)
else:
n_mults += get_bit('triple') / (length * 8)
bit_cost = Comm(apply_length(
bit2a.get(x) or get_match(fixed_bit, protocol) or
get_match_variable(variable_bit, protocol, n_parties),
length))
input_cost = apply_length(
get_match(fixed_input, protocol) or \
get_match_variable(variable_input, protocol, n_parties), length)
output_cost = Comm(
get_match(reveal, protocol) or \
get_match_variable(reveal_variable, protocol, n_parties) or \
(n_parties - 1) * 2)
random_cost = apply_length(
get_match_variable(variable_random, protocol, n_parties), length)
if not random_cost:
random_cost = n_parties * input_cost
square_unit = get_cost(fixed_square, variable_square, protocol, n_parties)
if not square_unit:
def square_unit(l):
unit = apply_length(y, l)
return Comm(0, (unit[1] or unit[0]) + sum(
apply_length(output_cost, l)))
square_cost = apply_length(square_unit, length)
res += square_cost * get_int('square')
if bit_cost:
res += bit_cost * get_int('bit')
elif get_match_variable(bits_from_squares, protocol, n_parties):
if square_cost:
if get_match(ring, protocol):
sb_cost = apply_length(square_unit, length + 1)
else:
sb_cost = square_cost
bit_cost = Comm(0, offline=sum(sb_cost + output_cost * length))
else:
bit_cost = Comm(0, offline=sum(
unit + random_cost + length * output_cost))
res += bit_cost * get_int('bit')
else:
bit_cost = Comm(0, offline=sum(
threshold * unit + (threshold + 1) * input_cost))
res += bit_cost * get_int('bit')
res += unit * n_mults
if not unit:
sh_protocol = re.sub('mal-', '', protocol)
sh_unit = get_mul_cost(sh_protocol, n_parties)
sh_random_unit = get_match_variable(
variable_random, sh_protocol, n_parties)
if sh_unit:
res += length * Comm(
sum(2 * output_cost * n_mults),
int(n_mults * (3 * sh_random_unit + 2 * sh_unit + \
2 * sum(output_cost))))
res += Comm(get_match(trunc_pr, protocol)) * length * \
get_int('probabilistic truncation')
res += Comm(bit2a.get(x)) * length * get_int('bit2A')
res += output_cost * length * get_int('open')
res += output_cost * get_bit('open')
res += get_match(dabit_from_bit, protocol) * bit_cost * get_int('dabit')
res += random_cost * get_int('random')
res += get_int('cut random') * apply_length(
get_match(fixed_randoms, protocol), length)
shuffle_correction = not find_match(shuffle_application, protocol)
def get_node():
req_node = Tape.ReqNode("")
req_node.aggregate()
return req_node
for x in req_num:
if len(x) >= 3 and x[0] == 'modp':
if x[1] == 'input':
res += input_cost * req_num[x]
elif x[1] == 'shuffle application':
shuffle_cost = apply_length(
get_match(shuffle_application, protocol), length)
if shuffle_cost:
res += Comm(shuffle_cost) * req_num[x] * x[2]
elif find_match(cheap_dot_product, protocol) or \
'dealer' in protocol:
res += shuffle_base.n_swaps(x[2]) * (threshold + 1) * \
req_num[x] * unit * (x[3] + malicious)
elif shuffle_correction:
node = get_node()
shuffle_base.add_apply_usage(
node, x[2], x[3], add_shuffles=False)
node.num = -node.num
if not use_triple_number:
node.num['modp', 'triple'] = 0
shuffle_base.add_apply_usage(
node, x[2], x[3], add_shuffles=False,
malicious=malicious, n_relevant_parties=threshold + 1)
res += req_num[x] * \
expected_communication(protocol, node.num, length,
force_triple_use=True)
elif x[1] == 'shuffle generation':
if 'dealer' in protocol:
res += Comm(
req_num[x] * shuffle_base.n_swaps(x[2]) * length)
else:
req_node = get_node()
shuffle_base.add_gen_usage(
req_node, x[2], add_shuffles=False)
req_node.num = -req_node.num
if shuffle_correction:
shuffle_base.add_gen_usage(
req_node, x[2], add_shuffles=False,
malicious=malicious,
n_relevant_parties=threshold + 1)
res += req_num[x] * \
expected_communication(protocol, req_node.num, length)
elif x[0] == 'matmul':
mm_unit = Comm()
if use_diag_matrix:
dims = list(x[1])
if dims[0] > dims[2]:
dims[0::2] = dims[2::-2]
mm_unit += Comm(
offline=use_diag_matrix(n_parties, length, dims))
matrix_cost = Comm(unit.online / 2)
for idx in ((0, 1), (1, 2)):
mm_unit += Comm(matrix_cost.online) * \
x[1][idx[0]] * x[1][idx[1]]
mm_unit += Comm(offline=matrix_cost.offline) * x[1][0] * x[1][2]
res += mm_unit * req_num[x]
elif re.search('edabit', x[0]):
edabit = get_match_variable(variable_edabit, protocol, n_parties)
if edabit:
res += Comm(offline=edabit(x[1])(length)) * req_num[x]
res.n_parties = n_parties
return res

View File

@@ -636,7 +636,8 @@ def preprocess_pandas(data):
elif pandas.api.types.is_object_dtype(t):
values = list(filter(lambda x: isinstance(x, str),
list(data.iloc[:,i].unique())))
print('converting the following to unary:', values)
print('converting the following to unary from %d: %s' %
(len(res), values))
if len(values) == 2:
res.append(data.iloc[:,i].to_numpy() == values[1])
types.append('b')

View File

@@ -98,7 +98,7 @@ class HeapQ(object):
self.size = MemValue(int_type(0))
self.int_type = int_type
self.basic_type = basic_type
prog.reading('heap queue', 'KS14')
prog.reading('heap queue', 'KS14', 'Section 5.1')
print('heap: %d levels, depth %d, size %d, index size %d' % \
(self.levels, self.depth, self.heap.oram.size, self.value_index.size))
def update(self, value, prio, for_real=True):
@@ -243,7 +243,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None,
:param int_type: secret integer type (default: sint)
"""
prog.reading("Dijkstra's algorithm", "KS14")
prog.reading("Dijkstra's algorithm", "KS14", "Section 5.2")
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else -1
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \

View File

@@ -59,7 +59,7 @@ def EQZ(a, k):
v = sbitvec(a, k).v
bit = util.tree_reduce(operator.and_, (~b for b in v))
return types.sintbit.conv(bit)
prog.reading('equality', 'ABZS13')
prog.reading('equality', 'CdH10', 'Protocol 3.7')
return prog.non_linear.eqz(a, k)
def bits(a,m):
@@ -313,9 +313,10 @@ def BitDecRingRaw(a, k, m):
return bits[:m]
else:
if program.Program.prog.use_edabit():
r, r_bits = types.sint.get_edabit(m, strict=False)
r, r_bits = types.sint.get_edabit(m, strict=False, size=a.size)
elif program.Program.prog.use_dabit:
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
r, r_bits = zip(*(types.sint.get_dabit(size=a.size)
for i in range(m)))
r = types.sint.bit_compose(r)
else:
r_bits = [types.sint.get_random_bit() for i in range(m)]
@@ -334,7 +335,8 @@ def BitDecRing(a, k, m):
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, bits_to_compute=None):
comparison.program.reading('bit decomposition', 'ABZS13')
comparison.program.reading('bit decomposition', 'CdH10-fixed',
'Protocol 3.7')
instructions_base.set_global_vector_size(a.size)
r_dprime = types.sint()
r_prime = types.sint()
@@ -362,7 +364,7 @@ def Pow2(a, l):
return Pow2_from_bits(t)
def Pow2_from_bits(bits):
comparison.program.reading('power of two', 'ABZS13')
comparison.program.reading('power of two', 'ABZS13', 'Section 3')
m = len(bits)
t = list(bits)
pow2k = [None for i in range(m)]
@@ -419,7 +421,7 @@ def Trunc(a, l, m, compute_modulo=False, signed=False):
return TruncInRing(a, l, Pow2(m, l))
else:
kappa = program.Program.prog.security
prog.reading('secret truncation', 'ABZS13')
prog.reading('secret truncation', 'ABZS13', 'Section 3')
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
@@ -460,7 +462,7 @@ def Trunc(a, l, m, compute_modulo=False, signed=False):
@instructions_base.ret_cisc
def TruncInRing(to_shift, l, pow2m):
comparison.program.reading('secret truncation', 'DEK20')
comparison.program.reading('secret truncation', 'DEK20', 'Section 3.2.3')
n_shift = int(program.Program.prog.options.ring) - l
bits = util.bit_decompose(to_shift, l)
rev = types.sint.bit_compose(reversed(bits))
@@ -564,7 +566,8 @@ def TruncPrRing(a, k, m, signed=True):
res = sint()
trunc_pr(res, a, k, m)
else:
prog.reading('probabilistic truncation', 'CdH10-fixed')
prog.reading('probabilistic truncation', 'CdH10-fixed',
'Protocol 3.1')
# extra bit to mask overflow
prog.curr_tape.require_bit_length(1)
if prog.use_edabit() or prog.use_split() > 2:
@@ -594,7 +597,7 @@ def TruncPrField(a, k, m):
program.Program.prog.trunc_pr_warning()
prog = program.Program.prog
prog.reading('probabilistic truncation', 'CdH10-fixed')
prog.reading('probabilistic truncation', 'CdH10-fixed', 'Protocol 3.1')
b = two_power(k-1) + a
r_prime, r_dprime = types.sint(), types.sint()
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
@@ -632,7 +635,7 @@ def SDiv(a, b, l, round_nearest=False):
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest,
signed=False)
y = y.round(2 * l + 1, l + 1, nearest=round_nearest)
y = y.round(2 * l + 1, l + 1, nearest=round_nearest, signed=False)
return y
def SDiv_mono(a, b, l):
@@ -684,7 +687,7 @@ def BITLT(a, b, bit_length):
def BitDecFull(a, n_bits=None, maybe_mixed=False):
from .library import get_program, do_while, if_, break_point
from .types import sint, regint, longint, cint
get_program().reading('full bit decomposition', 'NO07')
get_program().reading('full bit decomposition', 'NO07', 'Figure 2')
p = get_program().prime
assert p
bit_length = p.bit_length()
@@ -731,6 +734,7 @@ def BitDecFull(a, n_bits=None, maybe_mixed=False):
r = sint.get_edabit(bit_length, True)
bs[j].link(r[0])
tbits[j].link(sbitvec.from_vec(r[1]))
tbits[j] = tbits[j].v
else:
for i in range(bit_length):
tbits[j][i].link(sint.get_random_bit())

View File

@@ -1348,6 +1348,9 @@ class randoms(base.Instruction):
arg_format = ['sw','int']
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, 'cut random'), self.get_size())
@base.vectorize
class randomfulls(base.DataInstruction):
""" Store share(s) of a fresh secret random element in secret
@@ -1365,6 +1368,12 @@ class randomfulls(base.DataInstruction):
return len(self.args)
class unsplit(base.VectorInstruction, base.Ciscable):
""" Bit injection (conversion from binary to arithmetic).
:param: destination (sint)
:param: source (sbits)
"""
__slots__ = []
code = base.opcodes['UNSPLIT']
arg_format = tools.chain(['sb'], itertools.repeat('sw'))
@@ -2568,6 +2577,12 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction,
for reg in self.args[i + 2:i + self.args[i]]:
yield reg
def add_usage(self, req_num):
base.DataInstruction.add_usage(self, req_num)
req_num.increment(
(self.field_type, 'dot product'),
self.get_size() * len(list(self.bases(iter(self.args)))))
class matmul_base(base.DataInstruction):
data_type = 'triple'
is_vec = lambda self: True
@@ -2718,6 +2733,10 @@ class trunc_pr(base.VarArgsInstruction):
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])
def add_usage(self, req_node):
req_node.increment(('modp', 'probabilistic truncation'),
self.get_size() * len(self.args) // 4)
class shuffle_base(base.DataInstruction):
n_relevant_parties = 2
@@ -2725,12 +2744,12 @@ class shuffle_base(base.DataInstruction):
super(shuffle_base, self).__init__(*args, **kwargs)
prog = base.program
if re.match('ring|rep-field|sy-rep.*', prog.options.execute or ''):
ref = 'AHIK+22'
ref = 'AHIK+22', 'Protocol 3.2'
elif prog.options.execute:
ref = 'KS14'
ref = 'KS14', 'Section 4.3'
else:
ref = ('AHIK+22', 'KS14')
base.program.reading('secure shuffling', ref)
ref = ('AHIK+22', 'KS14'), None
base.program.reading('secure shuffling', *ref)
@staticmethod
def logn(n):
@@ -2741,26 +2760,39 @@ class shuffle_base(base.DataInstruction):
logn = cls.logn(n)
return logn * 2 ** logn - 2 ** logn + 1
def add_gen_usage(self, req_node, n):
@classmethod
def add_gen_usage(self, req_node, n, add_shuffles=True, malicious=True,
n_relevant_parties=None):
# hack for unknown usage
req_node.increment(('bit', 'inverse'), float('inf'))
# minimal usage with two relevant parties
logn = self.logn(n)
n_switches = self.n_swaps(n)
for i in range(self.n_relevant_parties):
n_relevant_parties = n_relevant_parties or self.n_relevant_parties
for i in range(n_relevant_parties):
req_node.increment((self.field_type, 'input', i), n_switches)
# multiplications for bit check
req_node.increment((self.field_type, 'triple'),
n_switches * self.n_relevant_parties)
if malicious:
# multiplications for bit check
req_node.increment((self.field_type, 'triple'),
n_switches * n_relevant_parties)
if add_shuffles:
req_node.increment((self.field_type, 'shuffle generation', n))
def add_apply_usage(self, req_node, n, record_size):
@classmethod
def add_apply_usage(self, req_node, n, record_size, add_shuffles=True,
malicious=True, n_relevant_parties=None):
req_node.increment(('bit', 'inverse'), float('inf'))
logn = self.logn(n)
n_switches = self.n_swaps(n) * self.n_relevant_parties
if n != 2 ** logn:
n_switches = self.n_swaps(n) * \
(n_relevant_parties or self.n_relevant_parties)
real_record_size = record_size
if n != 2 ** logn and malicious:
record_size += 1
req_node.increment((self.field_type, 'triple'),
n_switches * record_size)
if add_shuffles:
req_node.increment(
(self.field_type, 'shuffle application', n, real_record_size))
@base.gf2n
class secshuffle(base.VectorInstruction, shuffle_base):
@@ -2824,6 +2856,9 @@ class applyshuffle(shuffle_base, base.Mergeable):
for i in range(0, len(self.args), 6):
self.add_apply_usage(req_node, self.args[i], self.args[i + 3])
def handles(self):
return self.args[::4]
class delshuffle(base.Instruction):
""" Delete secure shuffle.
@@ -2874,7 +2909,7 @@ class sqrs(base.CISC):
arg_format = ['sw', 's']
def expand(self):
s = [program.curr_block.new_reg('s') for i in range(6)]
s = [type(self.args[0])() for i in range(6)]
c = [self.args[0].clear_type() for i in range(2)]
square(s[0], s[1])
subs(s[2], self.args[1], s[0])

View File

@@ -1112,6 +1112,8 @@ class Instruction(object):
new_args.append(arg.copy())
subs[arg] = new_args[-1]
else:
if isinstance(arg, program.curr_tape.Register) and arg.caller:
print(util.format_trace(arg.caller), file=sys.stderr)
new_args.append(arg)
return new_args

View File

@@ -109,7 +109,7 @@ def print_str(s, *args, print_secrets=False):
if print_secrets:
val.output()
else:
secret_error()
secret_error(args[i])
elif isinstance(val, cfloat):
val.print_float_plain()
elif isinstance(val, (list, tuple)):
@@ -831,7 +831,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
n_threads=None, key_indices=None):
get_program().reading('sorting', 'KSS13')
get_program().reading('sorting', 'KSS13', 'Section 6.1')
a_in = a
if isinstance(a_in, list):
a = Array.create_from(a)
@@ -1592,8 +1592,13 @@ def _run_and_link(function, g=None, lock_lists=True, allow_return=False):
pre = copy.copy(g)
res = function()
if res is not None and not allow_return:
if get_program().options.flow_optimization:
suffix = ' and avoid -l/--flow-optimization to keep ' \
'compile-time branching'
else:
suffix = ''
raise CompilerError('Conditional blocks cannot return values. '
'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else')
'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else' + suffix)
_link(pre, g)
return res
@@ -2052,7 +2057,8 @@ def FPDiv(a, b, k, f, simplex_flag=False, nearest=False):
"""
Goldschmidt method as presented in Catrina10,
"""
get_program().reading('fixed-point division', 'CdH10-fixed')
get_program().reading('fixed-point division', 'CdH10-fixed',
'Protocol 3.3')
prime = get_program().prime
if 2 * k == int(get_program().options.ring) or \
(prime and 2 * k <= (prime.bit_length() - 1)):

View File

@@ -63,7 +63,7 @@ import re
from Compiler import mpc_math, util
from Compiler.types import *
from Compiler.types import _unreduced_squant
from Compiler.types import _unreduced_squant, _single
from Compiler.library import *
from Compiler.util import is_zero, tree_reduce
from Compiler.comparison import CarryOutRawLE
@@ -927,6 +927,11 @@ class Dense(DenseBase):
progress('f input')
def _forward(self, batch=None):
if not issubclass(self.W.value_type, _single) \
or not issubclass(self.X.value_type, _single):
raise CompilerError(
'dense inputs have to be sfix in arithmetic circuits')
if batch is None:
batch = regint.Array(self.N)
batch.assign(regint.inc(self.N))
@@ -2160,6 +2165,11 @@ class Conv2d(ConvBase):
return weights_h * weights_w * n_channels_in
def _forward(self, batch):
if not issubclass(self.weights.value_type, _single) \
or not issubclass(self.X.value_type, _single):
raise CompilerError(
'convolution inputs have to be sfix in arithmetic circuits')
if self.tf_weight_format:
assert(self.weight_shape[3] == self.output_shape[-1])
weights_h, weights_w, _, _ = self.weight_shape
@@ -4058,7 +4068,8 @@ class SGD(Optimizer):
# divide by len(batch) by truncation
# increased rate if len(batch) is not a power of two
diff = red_old - nabla_vector
pre_trunc = diff.v * rate.v
# assuming rate is already synchronized
pre_trunc = diff.v.mul(rate.v, sync=False)
momentum_value.assign_vector(diff, base)
k = max(nabla_vector.k, rate.k) + rate.f
m = rate.f + int(log_batch_size)

View File

@@ -131,7 +131,7 @@ def p_eval(p_c, x):
local_aggregation = 0
# Evaluation of the Polynomial
for i, pre_mult in zip(p_c[1:], pre_mults):
local_aggregation += pre_mult.mul_no_reduce(x.coerce(i))
local_aggregation += pre_mult.mul_no_reduce(i)
return local_aggregation.reduce_after_mul() + p_c[0]
@@ -148,7 +148,8 @@ def p_eval(p_c, x):
# @return b2: \{0,1\} value. Returns one when reduction to
# \pi is greater than \pi/2.
def sTrigSub(x):
library.get_program().reading('trigonometric functions', 'AS19')
library.get_program().reading('trigonometric functions', 'AS19',
'Section 4')
# reduction to 2* \pi
f = x * (1.0 / (2 * pi))
f = trunc(f)
@@ -267,7 +268,7 @@ def exp2_fx(a, zero_output=False, as19=False):
:return: :math:`2^a` if it is within the range. Undefined otherwise
"""
library.get_program().reading('exponential', 'AS19')
library.get_program().reading('exponential', 'AS19', 'Protocol 6')
def exp_from_parts(whole_exp, frac):
class my_fix(type(a)):
pass
@@ -316,7 +317,7 @@ def exp2_fx(a, zero_output=False, as19=False):
s = sint.conv(bits[-1])
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
else:
bits = sbitvec(a.v, a.k)
bits = sbitvec(a.v, a.k).v
s = sint.conv(bits[-1])
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
higher_bits = bits[a.f:n_bits]
@@ -437,7 +438,7 @@ def log2_fx(x, use_division=True):
:return: (sfix) the value of :math:`\log_2(x)`
"""
library.get_program().reading('logarithm', 'AS19')
library.get_program().reading('logarithm', 'AS19', 'Section 5')
if isinstance(x, types._fix):
# transforms sfix to f*2^n, where f is [o.5,1] bounded
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
@@ -815,7 +816,7 @@ def sqrt(x, k=None, f=None):
:return: square root of :py:obj:`x` (sfix).
"""
library.get_program().reading('square root', 'AS19')
library.get_program().reading('square root', 'AS19', 'Section 3')
if k is None:
k = x.k
if f is None:
@@ -837,7 +838,8 @@ def atan(x):
:return: arctan of :py:obj:`x` (sfix).
"""
library.get_program().reading('inverse trigonometric functions', 'AS19')
library.get_program().reading('inverse trigonometric functions', 'AS19',
'Protocol 5')
# obtain absolute value of x
s = x < 0
x_abs = s.if_else(-x, x)

View File

@@ -26,7 +26,7 @@ class NonLinear:
if prog.use_trunc_pr and m and (
not prog.options.ring or \
prog.use_trunc_pr <= (int(prog.options.ring) - k)):
prog.reading('probabilistic truncation', 'DEK20')
prog.reading('probabilistic truncation', 'DEK20', 'Section 3.2.2')
if prog.options.ring:
comparison.require_ring_size(k, 'truncation')
else:
@@ -92,8 +92,9 @@ class Prime(Masking):
def kor(self, d):
return KOR(d)
def require_bit_length(self, bit_length, op):
def require_bit_length(self, bit_length, op, slack=0):
prog = program.Program.prog
bit_length += slack * (not prog.allow_tight_parameters)
if bit_length > 32:
prog.curr_tape.require_bit_length(bit_length - 1, reason=op)
@@ -146,7 +147,7 @@ class KnownPrime(NonLinear):
else:
return super(KnownPrime, self).ltz(a, k)
def require_bit_length(self, bit_length, op):
def require_bit_length(self, *args, **kwargs):
pass
class Ring(Masking):
@@ -189,5 +190,5 @@ class Ring(Masking):
def ltz(self, a, k):
return LtzRing(a, k)
def require_bit_length(self, bit_length, op):
comparison.require_ring_size(bit_length, op)
def require_bit_length(self, *args, **kwargs):
comparison.require_ring_size(*args, **kwargs)

View File

@@ -7,7 +7,7 @@ papers = {
'AN17': 'https://eprint.iacr.org/2017/816',
'AS19': 'https://eprint.iacr.org/2019/354',
'AHIK+22': 'https://eprint.iacr.org/2022/1595',
'CdH10': 'https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation, https://doi.org/10.1007/978-3-642-15317-4_13 (paywall)',
'CdH10': 'https://www.researchgate.net/publication/225092133, https://doi.org/10.1007/978-3-642-15317-4_13 (paywall)',
'CdH10-fixed': 'https://www.ifca.ai/pub/fc10/31_47.pdf',
'CCD88': 'https://doi.org/10.1145/62212.62214',
'DDNNT15': 'https://eprint.iacr.org/2015/1006',

View File

@@ -25,6 +25,7 @@ from Compiler.instructions_base import RegType
from . import allocator as al
from . import util
from .papers import *
from .cost import expected_communication
data_types = dict(
triple=0,
@@ -131,7 +132,8 @@ class Program(object):
assert self.rabbit_gap()
print(", for example, %d." % self.prime)
self.prime = bad_prime
except ImportError:
except (ImportError, AssertionError):
self.prime = bad_prime
print(".")
if options.execute:
print("Use '-- --prime <prime>' to specify the prime for "
@@ -251,6 +253,15 @@ class Program(object):
else:
print("Use '--execute <protocol>' to see recommended reading "
"on the basic protocol.")
if self.options.garbled:
if not self.options.binary:
raise CompilerError(
"You have to specify a default bit length using '--binary' "
"for garbled circuits.")
self.optimize_for_gc()
self.allow_tight_parameters = True
self.warned_about_tightness = False
self.warned_about_a2b = False
Program.prog = self
from . import comparison, instructions, instructions_base, types
@@ -439,6 +450,9 @@ class Program(object):
else:
self.req_num += tape.req_num
def required_bit_length(self, t):
return max(x.req_bit_length[t] for x in self.tapes)
def write_bytes(self):
"""Write all non-empty threads and schedule to files."""
@@ -455,7 +469,7 @@ class Program(object):
sch_file.write("1 0\n")
sch_file.write("0\n")
sch_file.write(" ".join(sys.argv) + "\n")
req = max(x.req_bit_length["p"] for x in self.tapes)
req = self.required_bit_length("p")
if self.options.ring:
sch_file.write("R:%s" % self.options.ring)
elif self.options.prime:
@@ -470,6 +484,14 @@ class Program(object):
assert len(req2) <= 2
if req2:
sch_file.write("lg2:%s" % max(req2))
sch_file.write("\n")
exp = self.expected_communication()
if exp:
sch_file.write(
"online:%d offline:%d n_parties:%d\n" % (
exp.sanitize() + (exp.n_parties,)))
else:
sch_file.write('no expections\n')
sch_file.close()
h = hashlib.sha256()
for tape in self.tapes:
@@ -590,6 +612,10 @@ class Program(object):
# communicate protocol compability
Compiler.instructions.active(self._always_active)
# communicate mulm usage to VM
if self.use_mulm != 1:
self.relevant_opts.add("no_mulm")
self.write_bytes()
if self.options.asmoutfile:
@@ -743,6 +769,15 @@ class Program(object):
def used_splits(self):
return self._split
def have_a2b(self):
if self.use_split() or self.use_edabit() or self.use_dabit:
return True
if not self.warned_about_a2b:
print(
'WARNING: No option selected for A2B conversion, defaulting '
'to edaBits. Use -X/-Y/-Z to get rid of this warning.')
self.warned_about_a2b = True
def use_square(self, change=None):
"""Setting whether to use preprocessed square tuples
(default: false).
@@ -869,15 +904,30 @@ class Program(object):
bl = inst.args[0]
return (abs(bl.i) + 63) // 64 * 8
def reading(self, concept, reference):
key = concept, reference
def reading(self, concept, reference, part=None):
key = concept, reference, part
if self.options.papers and key not in self.recommended:
if isinstance(reference, tuple):
assert part is None
reference = ', '.join(papers.get(x) or x for x in reference)
print('Recommended reading on %s: %s' % (
concept, papers.get(reference) or reference))
suffix = ' (%s)' % part or ''
print('Recommended reading on %s: %s%s' % (
concept, papers.get(reference) or reference, suffix))
self.recommended.add(key)
def expected_communication(self):
if self.options.ring:
bit_length = int(self.options.ring)
elif self.options.prime:
bit_length = self.prime.bit_length()
else:
# check against OnlineOptions.cpp
bit_length = max(self.required_bit_length("p"), 128)
bit_length = int(math.ceil(bit_length / 64) * 64)
length = int(math.ceil(bit_length / 8))
return expected_communication(
self.options.execute, self.req_num or Tape.ReqNum(), length)
class Tape:
"""A tape contains a list of basic blocks, onto which instructions are added."""
@@ -1452,6 +1502,12 @@ class Tape:
__rmul__ = __mul__
def __neg__(self):
res = Tape.ReqNum()
for i, count in list(self.items()):
res[i] = -count
return res
def set_all(self, value):
if Program.prog.options.verbose and \
value == float("inf") and self["all", "inv"] > 0:
@@ -1644,6 +1700,19 @@ class Tape:
__float__ = __int__
def __eq__(self, other):
raise CompilerError("equality testing not implemented")
__ne__ = __eq__
class _no_secret_truth(_no_truth):
def __bool__(self):
raise CompilerError(
"Cannot branch on secret values like %s. "
"See https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#cannot-branch-on-secret-values. " % \
type(self).__name__
)
class Register(_no_truth):
"""
Class for creating new registers. The register's index is automatically assigned
@@ -1755,6 +1824,17 @@ class Tape:
return self.vector or [self]
def __getitem__(self, index):
try:
if isinstance(index, slice):
for x in index.start, index.stop, index.step:
if x is not None:
int(x)
else:
int(index)
except:
raise CompilerError(
'cannot address vectors with run-time indices, '
'use (Multi)Array instead')
if self.size == 1 and index == 0:
return self
if not self.vector:

View File

@@ -591,7 +591,7 @@ class _structure(Tape._no_truth):
return cls.int_type.reg_type
raise CompilerError('type not supported as argument: %s' % cls)
class _secret_structure(_structure):
class _secret_structure(Tape._no_secret_truth, _structure):
@classmethod
def input_tensor_from(cls, player, shape):
""" Input tensor secretly from player.
@@ -1105,10 +1105,9 @@ class cint(_clear, _int):
@staticmethod
def in_immediate_range(value, regint=False):
if value and not regint:
# +1 for sign
bit_length = 1 + int(math.ceil(math.log(abs(value), 2)))
# slack for sign
program.non_linear.require_bit_length(
bit_length, 'integer conversion')
value.bit_length(), 'integer conversion', slack=1)
return value < 2**31 and value >= -2**31
@vectorize_init
@@ -1321,7 +1320,7 @@ class cint(_clear, _int):
:param other: cint/regint/int """
return self >> other
def round(self, k, m, nearest=None, signed=False):
def round(self, k, m, nearest=None, signed=True):
if signed:
self += 2 ** (k - 1)
self += 2 ** (m - 1)
@@ -2383,7 +2382,7 @@ class _secret(_arithmetic_register, _secret_structure):
@set_instruction_type
@read_mem_value
@vectorize
def mul(self, other):
def mul(self, other, sync=True):
""" Secret multiplication. Either both operands have the same
size or one size 1 for a value-vector multiplication.
@@ -2396,7 +2395,7 @@ class _secret(_arithmetic_register, _secret_structure):
res = type(self)(size=x.size)
mulrs(res, x, y)
return res
if program.use_mulm == 1:
if program.use_mulm == 1 or not sync:
mulm = instructions.mulm
elif program.use_mulm == -1:
mulm = lambda res, x, y: instructions.mulm(res, x, cint(regint(y)))
@@ -2530,7 +2529,7 @@ class _secret(_arithmetic_register, _secret_structure):
writesharestofile(regint.conv(position), *shares)
class sint(_secret, _int):
"""
r"""
Secret integer in the protocol-specific domain. It supports
operations with :py:class:`sint`, :py:class:`cint`,
:py:class:`regint`, and Python integers. Operations where one of
@@ -2559,6 +2558,12 @@ class sint(_secret, _int):
undefined and potentially insecure if the operands are longer than
the bit length.
Instances of sint are understood to be signed. This means that,
for modulo :math:`N`, numbers in :math:`[0,N/2)` are understood as
positive numbers whereas numbers in :math:`[N/2,N)` are understood
to be negative, namely :math:`x-N`. This ensures expected
arithmetic such as :math:`-1 + 1 = (N-1) + 1 = N = 0 \mod N`.
See :ref:`nonlinear` for an overview of how non-linear
computation is implemented.
@@ -2826,7 +2831,7 @@ class sint(_secret, _int):
self.load_other(val.v.round(val.k, val.f,
nearest=val.round_nearest))
elif isinstance(val, sbitvec):
super(sint, self).__init__('s', val=val, size=val[0].n)
super(sint, self).__init__('s', val=val, size=val.v[0].n)
else:
super(sint, self).__init__('s', val=val, size=size)
@@ -3001,13 +3006,19 @@ class sint(_secret, _int):
maybe_mixed)
def TruncMul(self, other, k, m, nearest=False):
if not nearest and not program.warned_about_tightness and \
program.options.ring and int(program.options.ring) == k:
print('WARNING: Using tight parameters. '
'Increase ring size or reduce fixed-point precision '
'for increased efficiency')
program.warned_about_tightness = True
return (self * other).round(k, m, nearest, signed=True)
def TruncPr(self, k, m, signed=True):
return floatingpoint.TruncPr(self, k, m, signed=signed)
@vectorize
def round(self, k, m, nearest=False, signed=False):
def round(self, k, m, nearest=False, signed=True):
""" Truncate and maybe round secret :py:obj:`k`-bit integer
by :py:obj:`m` bits. :py:obj:`m` can be secret if
:py:obj:`nearest` is false, in which case the truncation will be
@@ -3625,7 +3636,7 @@ class _bitint(Tape._no_truth):
return s ^ carry, a ^ (s & (carry ^ a))
@staticmethod
def bit_comparator(a, b):
def bit_comparator(a, b, m=None):
long_one = util.long_one(a + b)
op = lambda y,x,*args: (util.if_else(x[1], x[0], y[0]), \
util.if_else(x[1], long_one, y[1]))
@@ -3794,7 +3805,11 @@ class _bitint(Tape._no_truth):
if const_rounds:
return self.get_highest_different_bits(a, b, index)
else:
return self.bit_comparator(a, b)
try:
return self.maybe_function(
self.bit_comparator, a, b, result_length=2)
except:
return self.bit_comparator(a, b)
def __lt__(self, other):
if self.reverse_type(other):
@@ -3826,9 +3841,17 @@ class _bitint(Tape._no_truth):
if self.reverse_type(other):
return other == self
diff = self ^ other
diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]]
return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y),
diff_bits))
diff_bits = diff.bit_decompose()[:bit_length]
try:
res = self.maybe_function(self.eqz, diff_bits, [], 1)
except:
res = self.eqz(diff_bits)
return self.comp_result(res[0])
@staticmethod
def eqz(bits, other_bits=None, m=None):
diff_bits = [x.bit_not() for x in bits]
return [util.tree_reduce(lambda x, y: x.bit_and(y), diff_bits)]
def __ne__(self, other):
return (self == other).bit_not()
@@ -4052,7 +4075,7 @@ class cfix(_number, _structure):
:py:class:`cfix` if the other operand is public
(cfix/regint/cint/int) or :py:class:`sfix` if the other operand is
an sfix. It also support comparisons (``==, !=, <, <=, >, >=``),
returning either :py:class:`regint` or :py:class:`sbitint`.
returning either :py:class:`regint` or :py:class:`sintbit`.
Similarly to :py:class:`Compiler.types.cint`, this type is
restricted to arithmetic circuits due to the fact that only
@@ -4806,7 +4829,7 @@ class _fix(_single):
return self._new(self.v[index])
def __iter__(self):
return (self._new(x) for x in self.v)
return (self._new(x, k=self.k, f=self.f) for x in self.v)
@vectorize
def add(self, other):
@@ -4839,7 +4862,8 @@ class _fix(_single):
f -= 1
v //= 2
k = len(bin(abs(v))) - 1
other = self.multipliable(v, k, f, self.size)
val = self.v.TruncMul(v, self.k + f, f, nearest=self.round_nearest)
return self._new(val, k=self.k, f=self.f)
try:
other = self.coerce(other, equal_precision=False)
except:
@@ -4983,19 +5007,29 @@ class sfix(_fix):
It supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
returning :py:class:`sbitint`. The other operand can be any of
returning :py:class:`sintbit`. The other operand can be any of
sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()``
and ``**``.
Note that the default precision (16 bits after the dot, 31 bits in
total) only allows numbers up to :math:`2^{31-16-1} \\approx
16000` with the smallest non-zero number being :math:`2^{-16}`.
16000` with the smallest non-zero number being :math:`2^{-16}
\\approx 0.000015`.
You can change this using :py:func:`set_precision`.
Fixed-point multiplication is not linear in the sense of the
computation domain. Therefore, techniques from :ref:`nonlinear`
have to be used.
Many operations (including multiplication and division) use
probabilistic trunctation by default. This means that the results
are not deterministc but random within a small range around the
deterministic result. You can switch to (more expensive)
deterministic computation by setting
``sfix.round_nearest`` to true. See `Catrina and de Hoogh
<https://www.ifca.ai/pub/fc10/31_47.pdf>`_ for an introduction to
probabilistic truncation.
:params _v: int/float/regint/cint/sint/sfloat
"""
int_type = sint
@@ -5079,7 +5113,10 @@ class sfix(_fix):
return self.v
def mul_no_reduce(self, other, res_params=None):
if not isinstance(other, type(self)):
if util.is_constant_float(other):
return self.unreduced(
self.v * cfix.int_rep(other, k=self.k, f=self.f))
elif not isinstance(other, type(self)):
return self * other
assert self.f == other.f
assert self.k == other.k
@@ -6040,7 +6077,11 @@ class Array(_vectorizable):
@read_mem_value
def get_address(self, index, size=None):
if isinstance(index, (_secret, _single)):
raise CompilerError('need cleartext index')
raise CompilerError(
'Need cleartext index to address Array. If you need to address '
'using secret numbers, you need to use ORAM: '
'https://mp-spdz.readthedocs.io/en/latest/Compiler.html#'
'module-Compiler.oram')
key = str(index), size or 1
index = self.check(index, self.length, self.length)
if (program.curr_block, key) not in self.address_cache:
@@ -6486,6 +6527,10 @@ class Array(_vectorizable):
M = Matrix(1, len(self), self.value_type, address=self.address)
return M.dot(other)
def sum(self):
""" Sum of elements. """
return self[:].sum()
def shuffle(self):
""" Insecure shuffle in place. """
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
@@ -7105,11 +7150,6 @@ class SubMultiArray(_vectorizable):
res_matrix = Matrix(self.sizes[0], other.sizes[1], t)
try:
try:
# force matmuls for smaller sizes
a, c = res_matrix.sizes
if a * c / (a + c) < 2 and \
self.value_type == other.value_type:
raise AttributeError()
self.value_type.direct_matrix_mul
skip_reduce = set((sint, sfix)) == \
set((self.value_type, other.value_type))

View File

@@ -1,6 +1,7 @@
import math
import operator
from functools import reduce
from Compiler.exceptions import *
def format_trace(trace, prefix=' '):
if trace is None:
@@ -91,8 +92,9 @@ def if_else(cond, a, b):
else:
return cond.if_else(a, b)
except:
print(cond, a, b)
raise
raise CompilerError(
'incompatible types for ternary/if-else operator: %s' % '/'.join(
type(x).__name__ for x in (cond, a, b)))
def cond_swap(cond, a, b):
if isinstance(cond, (bool, int)):

View File

@@ -18,7 +18,7 @@ int main()
KeySetup<Share<P256Element::Scalar>> key;
string prefix = PREP_DIR "ECDSA/";
mkdir_p(prefix.c_str());
write_online_setup(prefix, P256Element::Scalar::pr());
P256Element::Scalar::write_setup(prefix);
PRNG G;
G.ReSeed();
generate_mac_keys<Share<P256Element::Scalar>>(key, 2, prefix, G);

View File

@@ -60,7 +60,7 @@ void sub(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1)
void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,
const FHE_PK& pk)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
if (c0.params!=c1.params) { throw params_mismatch(); }
if (ans.params!=c1.params) { throw params_mismatch(); }

View File

@@ -9,7 +9,7 @@ Diagonalizer::Diagonalizer(const MatrixVector& matrices,
const FFT_Data& FTD, const FHE_PK& pk) :
FTD(FTD)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
assert(not matrices.empty());
for (auto& matrix : matrices)

View File

@@ -28,7 +28,7 @@ void NaiveFFT(vector<modp>& ans,vector<modp>& a,int N,const modp& theta,const Zp
void FFT(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
if (N==1) { return; }
@@ -141,7 +141,7 @@ void FFT_Iter(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD,
void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD, bool start_with_one)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
assert(roots.size() > size_t(n));

View File

@@ -56,6 +56,7 @@ class FFT_Data
const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; }
int phi_m() const { return R.phi_m(); }
int m() const { return R.m(); }
int num_slots() const { return R.phi_m(); }
@@ -71,6 +72,8 @@ class FFT_Data
const Ring& get_R() const { return R; }
void write_setup(const string& dir) const { prData.write_setup(dir); }
bool operator==(const FFT_Data& other) const { return not (*this != other); }
bool operator!=(const FFT_Data& other) const;

View File

@@ -4,6 +4,7 @@
#include "P2Data.h"
#include "FFT_Data.h"
#include "Tools/CodeLocations.h"
#include "Processor/OnlineOptions.h"
#include "Math/modp.hpp"
@@ -66,7 +67,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
FHE_PK& PK = *this;
@@ -154,7 +155,7 @@ void FHE_PK::encrypt(Ciphertext& c,
void FHE_PK::quasi_encrypt(Ciphertext& c,
const Rq_Element& mess,const Random_Coins& rc) const
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); }
@@ -216,7 +217,7 @@ void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
if (&c.get_params()!=params) { throw params_mismatch(); }
@@ -284,6 +285,14 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
PRNG G; G.ReSeed();
bigint mask;
bigint two_Bd = 2 * Bd;
bool verbose = OnlineOptions::singleton.has_option("verbose_dd");
int max_bits = 0;
if (verbose)
cerr << "Random bits in distributed decryption: " << two_Bd.numBits()
<< endl;
for (int i=0; i<(*params).phi_m(); i++)
{
G.randomBnd(mask, two_Bd);
@@ -292,7 +301,13 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
vv[i] += mask;
vv[i] %= mod;
if (vv[i]<0) { vv[i]+=mod; }
if (verbose)
max_bits = max(max_bits, vv[i].numBits());
}
if (verbose)
cerr << "Maximum bits in distributed decryption: " << max_bits << endl;
}

View File

@@ -52,7 +52,7 @@ template <>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
int m = 1024;
int lgp = plaintext_length;
bigint p;
@@ -95,7 +95,7 @@ template <>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, P2Data& P2D, bool round_up, int n)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
if (params.n_mults() > 0)
throw runtime_error("only implemented for 0-level BGV");
@@ -113,12 +113,13 @@ int generate_semi_setup(int plaintext_length, int sec,
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up)
{
#ifdef VERBOSE
cout << "Need ciphertext modulus of length " << lgp0;
if (params.n_mults() > 0)
cout << "+" << lgp1;
cout << " and " << phi_N(m) << " slots" << endl;
#endif
if (OnlineOptions::singleton.has_option("verbose_he_setup"))
{
cout << "Need ciphertext modulus of length " << lgp0;
if (params.n_mults() > 0)
cout << "+" << lgp1;
cout << " and " << phi_N(m) << " slots" << endl;
}
int extra_slack = 0;
if (round_up)
@@ -160,13 +161,16 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
{
(void) lg2pi, (void) n;
#ifdef VERBOSE
if (n >= 2 and n <= 10)
cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
#endif
bool verbose = OnlineOptions::singleton.has_option("verbose_he_setup");
if (verbose)
{
if (n >= 2 and n <= 10)
cerr << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
cerr << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
cerr << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
}
int extra_slack = 0;
if (round_up)
@@ -185,15 +189,16 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
extra_slack = 2 * i;
lg2p0 += i;
lg2p1 += i;
#ifdef VERBOSE
cout << "Rounding up to " << lg2p0 << "+" << lg2p1
<< ", giving extra slack of " << extra_slack << " bits" << endl;
#endif
if (verbose)
cerr << "Rounding up to " << lg2p0 << "+" << lg2p1
<< ", giving extra slack of " << extra_slack << " bits"
<< endl;
}
#ifdef VERBOSE
cout << "Total length: " << lg2p0 + lg2p1 << endl;
#endif
if (verbose)
cerr << "Total length: " << lg2p0 + lg2p1 << " = " << lg2p0 << " + "
<< lg2p1 << endl;
return extra_slack;
}
@@ -305,7 +310,7 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
template <>
void Parameters::SPDZ_Data_Setup(FHE_Params& params, FFT_Data& FTD)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
bigint p;
int idx, m;
@@ -678,7 +683,7 @@ void char_2_dimension(int& m, int& lg2)
template <>
void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
int n = n_parties;
int lg2 = plaintext_length;

View File

@@ -22,9 +22,9 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
sigma *= 1.4;
params.set_R(params.get_R() * 1.4);
}
#ifdef VERBOSE
cerr << "Standard deviation: " << this->sigma << endl;
#endif
if (OnlineOptions::singleton.has_option("verbose_he_setup"))
cerr << "Standard deviation: " << this->sigma << endl;
produce_epsilon_constants();
@@ -40,24 +40,27 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
int matrix_dim = params.get_matrix_dim();
#ifdef NOISY
cout << "phi(m): " << phi_m << endl;
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cout << "V_s: " << V_s << endl;
cout << "c1: " << c1 << endl;
cout << "c2: " << c2 << endl;
cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl;
cout << "log(slack): " << slack << endl;
cout << "B_clean: " << B_clean << endl;
cout << "B_scale: " << B_scale << endl;
cout << "matrix dimension: " << matrix_dim << endl;
cout << "drown sec: " << params.secp() << endl;
cout << "sec: " << sec << endl;
#endif
assert(matrix_dim > 0);
assert(params.secp() >= 0);
drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp());
if (OnlineOptions::singleton.has_option("verbose_he_setup"))
{
cerr << "phi(m): " << phi_m << endl;
cerr << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cerr << "V_s: " << V_s << endl;
cerr << "c1: " << c1 << endl;
cerr << "c2: " << c2 << endl;
cerr << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl;
cerr << "log(slack): " << slack << endl;
cerr << "B_clean bits: " << B_clean.numBits() << endl;
cerr << "B_scale bits: " << B_scale.numBits() << endl;
cerr << "matrix dimension: " << matrix_dim << endl;
cerr << "drown sec: " << params.secp() << endl;
cerr << "sec: " << sec << endl;
cerr << "drown bits: " << drown.numBits() << endl;
}
}
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
@@ -118,7 +121,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants()
NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
const FHE_Params& params) :
SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, params)
SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, sec, params)
{
B_KS = p * c2 * this->sigma * phi_m / sqrt(12);
#ifdef NOISY

View File

@@ -9,7 +9,7 @@
void P2Data::forward(vector<poly_type>& ans,const vector<gf2n_short>& a) const
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
int n=gf2n_short::degree();
@@ -32,7 +32,7 @@ void P2Data::forward(vector<poly_type>& ans,const vector<gf2n_short>& a) const
void P2Data::backward(vector<gf2n_short>& ans,const vector<poly_type>& a) const
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
int n=gf2n_short::degree();
BitVector bv(a.size());

View File

@@ -32,6 +32,7 @@ class P2Data
void backward(vector<gf2n_short>& ans,const vector<poly_type>& a) const;
int get_prime() const { return 2; }
void write_setup(const string&) const {}
bool operator!=(const P2Data& other) const;
@@ -47,6 +48,7 @@ class P2Data
void load_or_generate(const Ring& Rg);
friend void init(P2Data& P2D,const Ring& Rg);
};

View File

@@ -154,7 +154,7 @@ void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
assert(a.FFTD);
if (a.rep!=b.rep) { throw rep_mismatch(); }
@@ -299,7 +299,7 @@ Ring_Element& Ring_Element::operator *=(const modp& other)
Ring_Element Ring_Element::mul_by_X_i(int j) const
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
assert(FFTD);
Ring_Element ans;
@@ -517,7 +517,9 @@ modp Ring_Element::get_constant() const
void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
{
ZpD.pack(o);
o.store(v);
o.store((int)v.size());
for (auto& x : v)
x.pack(o, ZpD);
}
@@ -529,7 +531,16 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
throw runtime_error(
"mismatch: " + to_string(check_Zpd.pr_bit_length) + "/"
+ to_string(ZpD.pr_bit_length));
o.get(v);
unsigned int length;
o.get(length);
v.clear();
v.reserve(length);
modp tmp;
for (unsigned int i=0; i<length; i++)
{
tmp.unpack(o,ZpD);
v.push_back(tmp);
}
}

View File

@@ -140,7 +140,7 @@ vector<bigint> Rq_Element::to_vec_bigint() const
// result mod p0 = a[0]; result mod p1 = a[1]
void Rq_Element::to_vec_bigint(vector<bigint>& v) const
{
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
a[0].to_vec_bigint(v);
if (n_mults() == 0) {
@@ -208,7 +208,7 @@ void Rq_Element::Scale(const bigint& p)
{
if (lev==0) { return; }
CODE_LOCATION
CODE_LOCATION_NO_SCOPE
if (n_mults() == 0) {
//for some reason we scale but we have just one level
@@ -312,7 +312,17 @@ void Rq_Element::pack(octetStream& o, int) const
void Rq_Element::unpack(octetStream& o, int)
{
unsigned int ll; o.get(ll); lev=ll;
check_level();
try
{
check_level();
}
catch (...)
{
lev = 0;
throw;
}
for (int i = 0; i <= lev; ++i)
a[i].unpack(o);
}

View File

@@ -132,7 +132,7 @@ void PartSetup<FD>::output(Names& N)
{
// Write outputs to file
string dir = get_prep_sub_dir<Share<typename FD::T>>(N.num_players());
write_online_setup(dir, FieldD.get_prime());
FieldD.write_setup(dir);
write_mac_key(dir, N.my_num(), N.num_players(), alphai);
}

View File

@@ -11,7 +11,7 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
vv.resize(pk.get_params().phi_m());
vv1.resize(pk.get_params().phi_m());
// extra limb for operations
bigint limit = pk.get_params().Q() << 64;
bigint limit = pk.get_params().p0() << 64;
vv.allocate_slots(limit);
vv1.allocate_slots(limit);
mf.allocate_slots(pk.p() << 64);
@@ -19,6 +19,8 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
class ModuloTreeSum : public TreeSum<bigint>
{
typedef TreeSum<bigint> super;
bigint modulo;
void post_add_process(vector<bigint>& values)
@@ -32,6 +34,12 @@ public:
modulo(modulo)
{
}
void run(vector<bigint>& values, const Player& P)
{
lengths.resize(values.size(), numBytes(modulo));
super::run(values, P);
}
};
template<class FD>
@@ -48,13 +56,15 @@ Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
if ((int)vv.size() != params.phi_m())
throw length_error("wrong length of ring element");
size_t length = numBytes(pk.get_params().p0());
if (OnlineOptions::singleton.direct)
{
// Now pack into an octetStream for broadcasting
vector<octetStream> os(P.num_players());
for (int i=0; i<params.phi_m(); i++)
{ (os[P.my_num()]).store(vv[i]); }
{ (os[P.my_num()]).store(vv[i], length); }
// Broadcast and Receive the values
P.Broadcast_Receive(os);
@@ -67,7 +77,7 @@ Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
{
for (int j = 0; j < params.phi_m(); j++)
{
os[i].get(vv1[j]);
os[i].get(vv1[j], length);
}
share.dist_decrypt_2(vv, vv1);
}

View File

@@ -38,7 +38,7 @@ void RealPairwiseMachine::init()
gfp::init_field(p);
ofstream outf;
if (output)
write_online_setup(get_prep_dir<FFT_Data>(P), p);
gfp::write_setup(get_prep_dir<FFT_Data>(P));
}
for (int i = 0; i < nthreads; i++)
@@ -141,5 +141,10 @@ void PairwiseMachine::check(Player& P) const
bundle.compare(P);
}
int PairwiseMachine::comp_sec()
{
return NonInteractiveProof::comp_sec(sec);
}
template void RealPairwiseMachine::setup_keys<FFT_Data>();
template void RealPairwiseMachine::setup_keys<P2Data>();

View File

@@ -31,6 +31,8 @@ public:
void unpack(octetStream& os);
void check(Player& P) const;
int comp_sec();
};
class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine

View File

@@ -71,6 +71,7 @@ void secure_init(T& setup, Player& P, U& machine,
string filename = PREP_DIR + T::name() + "-"
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
+ to_string(params.secp()) + "-"
+ to_string(machine.comp_sec()) + "-"
+ to_string(params.get_matrix_dim()) + "-"
+ OnlineOptions::singleton.prime.get_str() + "-"
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
@@ -121,7 +122,7 @@ void secure_init(T& setup, Player& P, U& machine,
os.output(file);
}
if (OnlineOptions::singleton.verbose)
if (OnlineOptions::singleton.has_option("verbose_he"))
{
cerr << "Ciphertext length: " << params.p0().numBits();
for (size_t i = 1; i < params.FFTD().size(); i++)
@@ -131,6 +132,7 @@ void secure_init(T& setup, Player& P, U& machine,
cerr << "+" << DIV_CEIL(params.FFTD()[i].get_prime().numBits(), 64);
cerr << " limbs)";
cerr << endl;
cerr << "Number of slots: " << params.phi_m() << endl;
}
}

View File

@@ -184,3 +184,11 @@ void Proof::Preimages::check_sizes()
if (m.size() != r.size())
throw runtime_error("preimage sizes don't match");
}
int NonInteractiveProof::comp_sec(int sec)
{
if (sec > 0)
return OnlineOptions::singleton.comp_sec();
else
return 0;
}

View File

@@ -9,6 +9,7 @@ using namespace std;
#include "FHE/Ciphertext.h"
#include "FHE/AddableVector.h"
#include "Protocols/CowGearOptions.h"
#include "Processor/OnlineOptions.h"
#include "config.h"
@@ -90,9 +91,6 @@ class Proof
{
V = ceil((sec + 2) / log2(2 * phim + 1));
U = 2 * V;
#ifdef VERBOSE
cerr << "Using " << U << " ciphertexts per proof" << endl;
#endif
}
else
{
@@ -151,14 +149,24 @@ class Proof
output += input.at(j);
}
}
void debugging()
{
if (OnlineOptions::singleton.has_option("verbose_he"))
{
cerr << "Using " << U << " ciphertexts per proof" << endl;
cerr << "Plaintext bound check bit length: " << B_plain_length << endl;
cerr << "Randomness bound check bit length: " << B_rand_length << endl;
}
}
};
class NonInteractiveProof : public Proof
{
// sec = 0 used for protocols without proofs
static int comp_sec(int sec) { return sec > 0 ? max(COMP_SEC, sec) : 0; }
public:
// sec = 0 used for protocols without proofs
static int comp_sec(int sec);
bigint static slack(int sec, int phim)
{ sec = comp_sec(sec); return bigint(phim * sec * sec) << (sec / 2 + 8); }
@@ -174,6 +182,7 @@ public:
B_rand_length = numBits(B*3*phim*rho);
plain_check = (bigint(1) << B_plain_length) - sec * tau;
rand_check = (bigint(1) << B_rand_length) - sec * rho;
debugging();
}
};
@@ -194,6 +203,7 @@ public:
// leeway for completeness
plain_check = (bigint(2) << B_plain_length);
rand_check = (bigint(2) << B_rand_length);
debugging();
}
};

View File

@@ -161,6 +161,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
template<class T,class FD,class S>
void SimpleEncCommit<T,FD,S>::create_more()
{
CODE_LOCATION
cout << "Generating more ciphertexts in round " << this->n_rounds << endl;
octetStream ciphertexts, cleartexts;
size_t prover_memory = this->generate_proof(this->c, this->m, ciphertexts, cleartexts);
@@ -181,6 +182,7 @@ template <class FD>
size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& ciphertexts,
octetStream& cleartexts)
{
CODE_LOCATION
AddableVector<Ciphertext> others_ciphertexts;
others_ciphertexts.resize(proof.U, pk.get_params());
for (int i = 1; i < P.num_players(); i++)
@@ -244,6 +246,7 @@ SummingEncCommit<FD>::SummingEncCommit(const Player& P, const FHE_PK& pk,
template<class FD>
void SummingEncCommit<FD>::create_more()
{
CODE_LOCATION
octetStream cleartexts;
const Player& P = this->P;
AddableVector<Ciphertext> commitments;
@@ -267,10 +270,11 @@ void SummingEncCommit<FD>::create_more()
this->c.unpack(ciphertexts, this->pk);
commitments.unpack(ciphertexts, this->pk);
#ifdef VERBOSE_HE
cout << "Tree-wise sum of ciphertexts with "
<< 1e-9 * ciphertexts.get_length() << " GB" << endl;
#endif
if (OnlineOptions::singleton.has_option("verbose_he"))
cerr << "Tree-wise sum of " << this->c.size()
<< " ciphertexts with " << 1e-9 * ciphertexts.get_length()
<< " GB" << endl;
this->timers["Exchanging ciphertexts"].start();
tree_sum.run(this->c, P);
tree_sum.run(commitments, P);

View File

@@ -56,6 +56,9 @@ public:
void unpack(octetStream&) {}
void check(Player&) const {}
// computational security doesn't matter in global proofs
int comp_sec() { return 0; }
};
class MultiplicativeMachineParams : public MachineBase

View File

@@ -78,6 +78,9 @@ public:
}
};
template<class T>
const int CcdShare<T>::default_length;
}
#endif /* GC_CCDSHARE_H_ */

View File

@@ -88,6 +88,9 @@ public:
}
};
template<class T>
const int MaliciousCcdShare<T>::default_length;
} /* namespace GC */
#endif /* GC_MALICIOUSCCDSHARE_H_ */

View File

@@ -286,7 +286,12 @@ void Processor<T>::notcb(const ::BaseInstruction& instruction)
template<class T>
void Processor<T>::movsb(const ::BaseInstruction& instruction)
{
for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++)
int n_blocks;
if (instruction.get_n() < unsigned(T::default_length))
n_blocks = 1;
else
n_blocks = DIV_CEIL(instruction.get_n(), T::default_length);
for (int i = 0; i < n_blocks; i++)
S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i];
}
@@ -407,12 +412,14 @@ void Processor<T>::convcbitvec(const BaseInstruction& instruction,
{
auto proto = ShareThread<T>::s().protocol;
auto P = ShareThread<T>::s().P;
if (proto)
// The default use case in the compiler doesn't require synchronization
// with function-dependent protocols, but testing does.
if (proto and OnlineOptions::singleton.has_option("convcbitvec_sync"))
proto->sync(bits, *P);
else
throw exception();
throw no_singleton();
}
catch (exception&)
catch (no_singleton&)
{
if (P)
ProtocolBase<T>::sync(bits, *P);

View File

@@ -147,6 +147,9 @@ public:
}
};
template<class T, class V>
const int SemiSecretBase<T, V>::default_length;
} /* namespace GC */
#endif /* GC_SEMISECRET_H_ */

View File

@@ -16,9 +16,6 @@
namespace GC
{
template<class T, class V>
const int SemiSecretBase<T, V>::default_length;
inline
SemiSecret::MC* SemiSecret::new_mc(
typename super::mac_key_type)

View File

@@ -67,7 +67,7 @@ inline ShareThread<T>& ShareThread<T>::s()
if (singleton and T::is_real)
return *singleton;
else
throw runtime_error("no ShareThread singleton");
throw no_singleton("no ShareThread singleton");
}
} /* namespace GC */

View File

@@ -120,6 +120,11 @@ public:
return {S, left, right, n_full_blocks()};
}
Range<StackedVector<T>> full_block_left_range(StackedVector<T>& S)
{
return {S, left, n_full_blocks()};
}
DoubleIterator<T> partial_block(StackedVector<T>& S)
{
assert(n_blocks() != n_full_blocks());
@@ -127,6 +132,17 @@ public:
S.iterator_for_size(right + n_full_blocks(), 1)};
}
typename CheckVector<T>::iterator partial_left_block(StackedVector<T>& S)
{
assert(n_blocks() != n_full_blocks());
return S.iterator_for_size(left + n_full_blocks(), 1);
}
T& get_right_base(StackedVector<T>& S)
{
return S[right];
}
Range<StackedVector<T>> full_block_output_range(StackedVector<T>& S)
{
return {S, dest, n_full_blocks()};
@@ -174,16 +190,17 @@ void ShareThread<T>::and_(Processor<T>& processor,
for (auto info : infos)
{
int n = T::default_length;
for (auto x : info.full_block_input_range(S))
auto& y = info.get_right_base(S);
for (auto x : info.full_block_left_range(S))
{
x.second.extend_bit(y_ext, n);
protocol->prepare_mult(x.first, y_ext, n, true);
y.extend_bit(y_ext, n);
protocol->prepare_mult(x, y_ext, n, true);
}
n = info.last_length();
if (n)
{
info.partial_block(S).left->mask(x_ext, n);
info.partial_block(S).right->extend_bit(y_ext, n);
info.partial_left_block(S)->mask(x_ext, n);
y.extend_bit(y_ext, n);
protocol->prepare_mult(x_ext, y_ext, n, true);
}
}
@@ -193,7 +210,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
if (fast_mode)
for (auto x : info.full_block_input_range(S))
protocol->prepare_mul_fast(x.first, x.second);
else
else if (info.n_full_blocks())
for (auto x : info.full_block_input_range(S))
protocol->prepare_mul(x.first, x.second);
int n = info.last_length();
@@ -228,7 +245,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
if (fast_mode)
for (auto& res : info.full_block_output_range(S))
res = protocol->finalize_mul_fast();
else
else if (info.n_full_blocks())
for (auto& res : info.full_block_output_range(S))
res = protocol->finalize_mul();

View File

@@ -56,6 +56,8 @@ public:
void join_tape();
void finish();
virtual NamedCommStats extra_comm() { return {}; }
};
template<class T>

View File

@@ -115,6 +115,7 @@ void ThreadMaster<T>::run_with_error()
for (auto thread : threads)
{
stats += thread->P->total_comm();
stats += thread->extra_comm();
exe_stats += thread->processor.stats;
delete thread;
}

View File

@@ -145,6 +145,9 @@ public:
}
};
template<class T>
const int TinierShare<T>::default_length;
} /* namespace GC */
#endif /* GC_TINIERSHARE_H_ */

View File

@@ -33,7 +33,7 @@ void TinierSharePrep<T>::buffer_secret_triples()
assert(triple_generator != 0);
params.generateBits = false;
vector<array<T, 3>> triples;
TripleShuffleSacrifice<T> sacrifice(DATA_GF2);
TripleShuffleSacrifice<T> sacrifice;
size_t required;
required = sacrifice.minimum_n_inputs_with_combining(
BaseMachine::batch_size<T>(DATA_TRIPLE));

View File

@@ -38,7 +38,10 @@ int main(int argc, const char** argv)
if (online_opts.prime_limbs() == 2)
return run<2, 1>(machine);
cerr << "Not compiled for choice of parameters" << endl;
cerr << "Try using '-lgp 128'" << endl;
if (online_opts.prime_limbs() > 2)
cerr << "Use MASCOT with large primes" << endl;
else
cerr << "Not compiled for choice of parameters" << endl;
exit(1);
}

View File

@@ -55,7 +55,7 @@ int main(int argc, const char** argv)
{
if (s == SPDZ2K_DEFAULT_SECURITY)
{
ring_domain_error(k);
ring_domain_error(k, 72);
}
else
{

View File

@@ -7,7 +7,7 @@ TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp))
NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp))
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o Protocols/ShareInterface.o
FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o

View File

@@ -103,7 +103,7 @@ public:
void randomize(PRNG& G);
void almost_randomize(PRNG& G) { randomize(G); }
void output(ostream& s,bool human) const;
void output(ostream& s, bool human, bool signed_ = true) const;
void input(istream& s,bool human);
void pack(octetStream& os) const { os.store_int(a, sizeof(a)); }

View File

@@ -15,7 +15,7 @@ inline void IntBase<T>::specification(octetStream& os)
}
template<class T>
void IntBase<T>::output(ostream& s,bool human) const
void IntBase<T>::output(ostream& s, bool human, bool) const
{
if (human)
s << a;

View File

@@ -85,15 +85,17 @@ void generate_prime(bigint& p, int lgp, int m, bool force_degree)
p = OnlineOptions::singleton.prime;
if (!probPrime(p))
{
cerr << p << " is not a prime" << endl;
exit(1);
throw runtime_error(to_string(p) + " is not a prime");
}
else if (m != 1 and p % m != 1)
{
cerr << p
<< " is not compatible with our encryption scheme, must be 1 modulo "
<< m << endl;
exit(1);
throw runtime_error(
to_string(p)
+ " is not compatible with our encryption scheme, must be "
"1 modulo " + to_string(m) + ". This is because "
"the implementation relies on number theoretic transform. "
"See https://eprint.iacr.org/2024/585.pdf for details, "
"in particular Theorem 13.");
}
else
return;
@@ -125,8 +127,10 @@ void generate_prime(bigint& p, int lgp, int m, bool force_degree)
}
void write_online_setup(string dirname, const bigint& p)
void Zp_Data::write_setup(const string& dirname) const
{
auto& p = pr;
if (p == 0)
throw runtime_error("prime cannot be 0");
@@ -145,19 +149,23 @@ void write_online_setup(string dirname, const bigint& p)
ofstream outf;
outf.open(ss.str().c_str());
outf << p << endl;
outf << montgomery << endl;
if (!outf.good())
throw file_error("cannot write to " + ss.str());
}
void check_setup(string dir, bigint pr)
void Zp_Data::check_setup(const string& dir)
{
bigint p;
bool mont = true;
string filename = dir + "Params-Data";
ifstream(filename) >> p;
ifstream(filename) >> p >> mont;
if (p == 0)
throw setup_error("no modulus in " + filename);
if (p != pr)
throw setup_error("wrong modulus in " + filename);
if (mont != montgomery)
throw setup_error("Montgomery different in " + filename);
}
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,

View File

@@ -26,8 +26,6 @@ template<class T>
void generate_prime_setup(string dir, int lgp);
template<class T>
void generate_online_setup(string dirname, bigint& p, int lgp);
void write_online_setup(string dirname, const bigint& p);
void check_setup(string dirname, bigint p);
// Setup primes only
// Chooses a p of at least lgp bits

View File

@@ -13,14 +13,15 @@ void generate_online_setup(string dirname, bigint& p, int lgp)
{
int idx, m;
SPDZ_Data_Setup_Primes(p, lgp, idx, m);
write_online_setup(dirname, p);
T::init_field(p);
T::write_setup(dirname);
}
template<class T = gfp>
void read_setup(const string& dir_prefix, int lgp = -1)
{
bigint p;
bool montgomery = true;
string filename = dir_prefix + "Params-Data";
@@ -32,6 +33,8 @@ void read_setup(const string& dir_prefix, int lgp = -1)
#endif
ifstream inpf(filename.c_str());
inpf >> p;
inpf >> montgomery;
if (inpf.fail())
{
if (lgp > 0)
@@ -45,9 +48,12 @@ void read_setup(const string& dir_prefix, int lgp = -1)
throw file_error(filename.c_str());
}
else
T::init_field(p);
T::init_field(p, montgomery);
inpf.close();
if (OnlineOptions::singleton.verbose)
cerr << "Using prime modulus " << T::pr() << endl;
}
#endif /* MATH_SETUP_HPP_ */

View File

@@ -230,3 +230,8 @@ void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const
q_half = shanks_q_half;
r = shanks_r;
}
string Zp_Data::fake_opts() const
{
return "-P " + to_string(pr) + (montgomery ? "" : " -n");
}

View File

@@ -21,7 +21,7 @@ using namespace std;
#ifndef MAX_MOD_SZ
#if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 11
#define MAX_MOD_SZ GFP_MOD_SZ
#define MAX_MOD_SZ 2 * GFP_MOD_SZ
#else
#define MAX_MOD_SZ 11
#endif
@@ -94,6 +94,11 @@ class Zp_Data
void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const;
void write_setup(const string& directory) const;
void check_setup(const string& directory);
string fake_opts() const;
template<int L> friend void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD);
template<int L> friend void to_modp(modp_<L>& ans,const mpz_class& x,const Zp_Data& ZpD);

View File

@@ -45,7 +45,6 @@ int powerMod(int x,int e,int p)
return ans;
}
size_t bigint::report_size(ReportType type) const
{
size_t res = 0;
@@ -98,6 +97,16 @@ bigint::bigint(const mp_limb_t* data, size_t n_limbs)
mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data);
}
void bigint::pack(octetStream& os, int length) const
{
os.store(*this, length);
}
void bigint::unpack(octetStream& os, int length)
{
os.get(*this, length);
}
string to_string(const bigint& x)
{
stringstream ss;

View File

@@ -134,8 +134,8 @@ public:
void generateUniform(PRNG& G, int n_bits, bool positive = false)
{ G.get(*this, n_bits, positive); }
void pack(octetStream& os, int = -1) const { os.store(*this); }
void unpack(octetStream& os, int = -1) { os.get(*this); };
void pack(octetStream& os, int = -1) const;
void unpack(octetStream& os, int = -1);
size_t report_size(ReportType type) const;
};

View File

@@ -31,6 +31,9 @@ mpf_class bigint::get_float(T v, T p, T z, T s)
Integer exp = Integer(p, 31).get();
bigint tmp;
tmp.from_signed(v);
if (abs(tmp) == 1)
BaseMachine::s().mini_warning = min(BaseMachine::s().mini_warning,
int(exp.get()));
mpf_class res = tmp;
if (exp > 0)
mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp.get());

View File

@@ -15,6 +15,7 @@ class fixint : public SignedZ2<64 * (L + 1)>
public:
typedef SignedZ2<64 * (L + 1)> super;
typedef typename conditional<L == 0, super, SignedZ2<64 * L>>::type pack_type;
fixint()
{
@@ -56,6 +57,19 @@ public:
*this = bigint::tmp;
}
void pack(octetStream& os) const
{
pack_type tmp = *this;
tmp.pack(os);
}
void unpack(octetStream& os)
{
pack_type tmp;
tmp.unpack(os);
*this = tmp;
}
int get_min_alloc() const
{
return this->N_BYTES;

View File

@@ -36,8 +36,8 @@ template<class T> void generate_prime_setup(string, int, int);
#define GFP_MOD_SZ 2
#endif
#if GFP_MOD_SZ > MAX_MOD_SZ
#error GFP_MOD_SZ must be at most MAX_MOD_SZ
#if 2 * GFP_MOD_SZ > MAX_MOD_SZ
#error 2 * GFP_MOD_SZ must be at most MAX_MOD_SZ
#endif
/**
@@ -105,9 +105,9 @@ class gfp_ : public ValueInterface
static void write_setup(int nplayers)
{ write_setup(get_prep_sub_dir<T>(nplayers)); }
static void write_setup(string dir)
{ write_online_setup(dir, pr()); }
{ ZpD.write_setup(dir); }
static void check_setup(string dir);
static string fake_opts() { return " -P " + to_string(pr()); }
static string fake_opts() { return ZpD.fake_opts(); }
/**
* Get the prime modulus

View File

@@ -28,7 +28,7 @@ inline void gfp_<X, L>::read_or_generate_setup(string dir,
template<int X, int L>
void gfp_<X, L>::check_setup(string dir)
{
::check_setup(dir, pr());
ZpD.check_setup(dir);
}
template<int X, int L>
@@ -201,7 +201,7 @@ bool gfp_<X, L>::allows(Dtype type)
template<int X, int L>
void gfp_<X, L>::specification(octetStream& os)
{
os.store(pr());
ZpD.pack(os);
}
template <int X, int L>

View File

@@ -33,7 +33,7 @@ char gfpvar_<X, L>::type_char()
template<int X, int L>
void gfpvar_<X, L>::specification(octetStream& os)
{
os.store(pr());
ZpD.pack(os);
}
template<int X, int L>
@@ -101,13 +101,13 @@ const bigint& gfpvar_<X, L>::pr()
template<int X, int L>
void gfpvar_<X, L>::check_setup(string dir)
{
::check_setup(dir, pr());
ZpD.check_setup(dir);
}
template<int X, int L>
void gfpvar_<X, L>::write_setup(string dir)
{
write_online_setup(dir, pr());
ZpD.write_setup(dir);
}
template<int X, int L>

View File

@@ -332,7 +332,7 @@ void modp_<L>::output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_)
if (human)
{ bigint te;
to_bigint(te, ZpD);
if (te < ZpD.pr / 2 or not signed_)
if (te <= ZpD.pr_half or not signed_)
s << te;
else
s << (te - ZpD.pr);

View File

@@ -24,16 +24,26 @@ public:
delete N;
}
void send_to_no_stats(int player, const octetStream& o) const
void send_to(int player, const octetStream& o) const
{
P.send_to(player, o);
}
void receive_player_no_stats(int i, octetStream& o) const
void receive_player(int i, octetStream& o) const
{
P.receive_player(i, o);
}
void send_to_no_stats(int, const octetStream&) const
{
throw not_implemented();
}
void receive_player_no_stats(int, octetStream&) const
{
throw not_implemented();
}
void send_receive_all_no_stats(const vector<vector<bool>>& channels,
const vector<octetStream>& to_send,
vector<octetStream>& to_receive) const

View File

@@ -367,10 +367,7 @@ long MultiPlayer<T>::receive_long(int i) const
void Player::send_to(int player,const octetStream& o) const
{
#ifdef VERBOSE_COMM
cerr << "sending to " << player << endl;
#endif
TimeScope ts(comm_stats["Sending directly"].add(o));
TimeScope ts(comm_stats["Sending directly"].add(o, player));
send_to_no_stats(player, o);
sent += o.get_length();
}
@@ -405,12 +402,9 @@ void Player::receive_all(vector<octetStream>& os) const
void Player::receive_player(int i,octetStream& o) const
{
#ifdef VERBOSE_COMM
cerr << "receiving from " << i << endl;
#endif
TimeScope ts(timer);
receive_player_no_stats(i, o);
comm_stats["Receiving directly"].add(o, ts);
comm_stats["Receiving directly"].add(o, ts, i);
}
template<class T>
@@ -484,10 +478,7 @@ void MultiPlayer<T>::exchange_no_stats(int other, const octetStream& o, octetStr
void Player::exchange(int other, const octetStream& o, octetStream& to_receive) const
{
#ifdef VERBOSE_COMM
cerr << "Exchanging with " << other << endl;
#endif
TimeScope ts(comm_stats["Exchanging"].add(o));
TimeScope ts(comm_stats["Exchanging"].add(o, other));
exchange_no_stats(other, o, to_receive);
sent += o.get_length();
}
@@ -603,9 +594,8 @@ void Player::send_receive_all(const vector<vector<bool>>& channels,
if (i != my_num() and channels.at(my_num()).at(i))
{
data += to_send.at(i).get_length();
#ifdef VERBOSE_COMM
cerr << "Send " << to_send.at(i).get_length() << " to " << i << endl;
#endif
if (OnlineOptions::singleton.has_option("detailed_verbose_comm"))
cerr << "Send " << to_send.at(i).get_length() << " bytes to " << i << endl;
}
TimeScope ts(comm_stats["Sending/receiving"].add(data));
sent += data;
@@ -879,15 +869,22 @@ Timer& CommStatsWithName::add_length_only(size_t length)
return stats.add_length_only(length);
}
Timer& CommStatsWithName::add(const octetStream& os)
Timer& CommStatsWithName::add(const octetStream& os, int player)
{
return add(os.get_length());
return add(os.get_length(), player);
}
Timer& CommStatsWithName::add(size_t length)
Timer& CommStatsWithName::add(size_t length, int player)
{
if (OnlineOptions::singleton.has_option("verbose_comm"))
fprintf(stderr, "%s %zu bytes\n", name.c_str(), length);
{
if (player < 0)
fprintf(stderr, "%s %zu bytes\n", name.c_str(), length);
else
fprintf(stderr, "%s %zu bytes with party %d\n", name.c_str(), length,
player);
}
return stats.add(length);
}

View File

@@ -160,9 +160,10 @@ public:
name(name), stats(stats) {}
Timer& add_length_only(size_t length);
Timer& add(const octetStream& os);
Timer& add(size_t length);
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
Timer& add(const octetStream& os, int player = -1);
Timer& add(size_t length, int player = -1);
void add(const octetStream& os, const TimeScope& scope, int player = -1)
{ add(os, player) += scope; }
};
class NamedCommStats : public map<string, CommStats>
@@ -272,7 +273,7 @@ public:
/**
* Send to a specific player
*/
void send_to(int player,const octetStream& o) const;
virtual void send_to(int player,const octetStream& o) const;
virtual void send_to_no_stats(int player,const octetStream& o) const = 0;
/**
* Receive from all other players.
@@ -282,7 +283,7 @@ public:
/**
* Receive from a specific player
*/
void receive_player(int i,octetStream& o) const;
virtual void receive_player(int i,octetStream& o) const;
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
virtual void receive_player(int i,FlexBuffer& buffer) const;
@@ -546,6 +547,8 @@ public:
size_t send(const PlayerBuffer& buffer, bool block) const;
size_t recv(const PlayerBuffer& buffer, bool block) const;
NamedCommStats get_comm_stats() const { return comm_stats; }
};
class RealTwoPartyPlayer : public VirtualTwoPartyPlayer

View File

@@ -243,16 +243,23 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
CODE_LOCATION
typedef typename W::input_type::share_type::open_type T;
auto nTriplesPerLoop = this->nTriplesPerLoop * 10;
auto nTriplesPerLoop = this->nTriplesPerLoop;
auto& valueBits = this->valueBits;
auto& share_prg = this->share_prg;
auto& ot_multipliers = this->ot_multipliers;
auto& nparties = this->nparties;
auto& globalPlayer = this->globalPlayer;
if (this->thread_num >= 0)
nTriplesPerLoop *= 10;
// extra value for sacrifice
int toCheck = nTriplesPerLoop
+ DIV_CEIL(W::mac_key_type::size_in_bits(), T::size_in_bits());
if (OnlineOptions::singleton.has_option("verbose_input"))
fprintf(stderr, "generating %d input tuples\n", toCheck);
valueBits.resize(1);
this->signal_multipliers({player, toCheck});
bool mine = player == globalPlayer.my_num();

View File

@@ -77,6 +77,22 @@ void OTExtensionWithMatrix::protocol_agreement()
if (OnlineOptions::singleton.has_option("high_softspoken"))
softspoken_k = 8;
if (OnlineOptions::singleton.has_param("softspoken"))
softspoken_k = OnlineOptions::singleton.get_param("softspoken");
int needed = DIV_CEIL(nbaseOTs, softspoken_k) * softspoken_k;
baseReceiverInput.resize_zero(needed);
for (int i = nbaseOTs; i < needed; i++)
{
auto zero = string(SEED_SIZE, '\0');
G_receiver.push_back(zero);
G_sender.push_back({});
for (int j = 0; j < 2; j++)
G_sender.back().push_back(zero);
}
bundle.mine.store(softspoken_k);
player->unchecked_broadcast(bundle);
@@ -177,7 +193,8 @@ void OTExtensionWithMatrix::soft_sender(size_t n)
return;
if (OnlineOptions::singleton.has_option("verbose_ot"))
fprintf(stderr, "%zu OTs as sender\n", n);
fprintf(stderr, "%zu OTs as sender (%s)\n", n,
passive_only ? "semi-honest" : "malicious");
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(softspoken_k);

View File

@@ -9,6 +9,10 @@
template <class T>
void OTVoleBase<T>::evaluate(vector<T>& output, const vector<T>& newReceiverInput) {
CODE_LOCATION
if (OnlineOptions::singleton.has_option("verbose_vole"))
fprintf(stderr, "%d-bit VOLE with %zu elements and S=%d\n", T::N_BITS,
newReceiverInput.size(), S);
const int N1 = newReceiverInput.size() + 1;
output.resize(newReceiverInput.size());
auto& os = oss;

View File

@@ -13,6 +13,7 @@
#include <iostream>
#include <sodium.h>
#include <regex>
using namespace std;
BaseMachine* BaseMachine::singleton = 0;
@@ -66,16 +67,21 @@ int BaseMachine::triple_bucket_size(DataFieldType type)
int BaseMachine::bucket_size(size_t usage)
{
int res = OnlineOptions::singleton.bucket_size;
int min = res;
if (usage)
{
for (int B = res; B <= 5; B++)
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
res = 5;
for (int B = res; B >= min; B--)
if (ShuffleSacrifice(B).minimum_n_outputs() > usage * 1.1)
break;
else
res = B;
}
if (OnlineOptions::singleton.has_option("debug_batch_size"))
fprintf(stderr, "bucket_size=%d usage=%zu\n", res, usage);
return res;
}
@@ -103,8 +109,13 @@ int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
return -1;
}
bool BaseMachine::allow_mulm()
{
return singleton and singleton->relevant_opts.find("no_mulm") != string::npos;
}
BaseMachine::BaseMachine() :
nthreads(0), multithread(false), nan_warning(0)
nthreads(0), multithread(false), nan_warning(0), mini_warning(0)
{
if (sodium_init() == -1)
throw runtime_error("couldn't initialize libsodium");
@@ -182,6 +193,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
getline(inpf, relevant_opts);
getline(inpf, security);
getline(inpf, gf2n);
getline(inpf, expected_communication);
inpf.close();
}
@@ -320,17 +332,47 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
Bundle<octetStream> bundle(P);
bundle.mine.store(stats.sent);
P.Broadcast_Receive_no_stats(bundle);
size_t global = 0;
long long global = 0;
for (auto& os : bundle)
global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
smatch what;
regex comm_regexp("online:([0-9]*) offline:([0-9]*) n_parties:([0-9]*)");
if (regex_search(expected_communication, what, comm_regexp))
{
long long expected = stoll(what[1]) + stoll(what[2]);
int n_parties = stoi(what[3]);
if (expected and n_parties != P.num_players())
{
cerr << "Wrong number of parties in compiler's expectation: "
<< n_parties << endl;
}
else if (expected)
{
double over = round(100. * (global - expected) / expected);
if (over >= 5)
cerr
<< "Actual communication exceeds the compiler's expectation by "
<< over << " percent." << endl;
if (over < 0)
{
if (OnlineOptions::singleton.has_option("overestimate"))
cerr << "Actual communication is below the compiler's "
"expectation by " << -over << " percent." << endl;
else
cerr << "The compiler overestimated the communication." << endl;
}
}
}
}
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
{
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
if (x.first.find("transmission") == string::npos)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << P.my_num() << " only";
if (multithread)
@@ -341,3 +383,9 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
print_global_comm(P, comm_stats);
}
void BaseMachine::add_one_off(const NamedCommStats& comm)
{
if (has_singleton())
s().one_off_comm += comm;
}

View File

@@ -23,6 +23,7 @@ void print_usage(ostream& o, const char* name, size_t capacity);
class BaseMachine
{
friend class Program;
template<class sint, class sgf2n> friend class thread_info;
protected:
static BaseMachine* singleton;
@@ -38,6 +39,9 @@ protected:
string relevant_opts;
string security;
string gf2n;
string expected_communication;
NamedCommStats one_off_comm;
virtual size_t load_program(const string& threadname,
const string& filename);
@@ -60,6 +64,7 @@ public:
vector<Program> progs;
bool nan_warning;
int mini_warning;
static BaseMachine& s();
static bool has_singleton() { return singleton != 0; }
@@ -75,7 +80,8 @@ public:
static int security_from_schedule(string progname);
template<class T>
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0,
int factor = 0);
template<class T>
static int input_batch_size(int player, int buffer_size = 0);
template<class T>
@@ -86,6 +92,10 @@ public:
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
static bool allow_mulm();
static void add_one_off(const NamedCommStats& comm);
BaseMachine();
virtual ~BaseMachine() {}
@@ -110,6 +120,8 @@ public:
void print_comm(Player& P, const NamedCommStats& stats);
virtual const Names& get_N() { throw not_implemented(); }
virtual void gap_warning(int) { throw not_implemented(); }
};
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
@@ -118,7 +130,8 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
}
template<class T>
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback,
int factor)
{
if (OnlineOptions::singleton.has_option("debug_batch_size"))
fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
@@ -133,7 +146,8 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
else if (fallback > 0)
n_opts = fallback;
else
n_opts = OnlineOptions::singleton.batch_size * T::default_length;
n_opts = OnlineOptions::singleton.batch_size
* max(factor, T::default_length);
if (buffer_size <= 0 and has_program())
{
@@ -180,9 +194,17 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
res = n_opts;
if (OnlineOptions::singleton.has_option("debug_batch_size"))
{
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
<< " res=" << res << " n=" << n << " n_opts=" << n_opts
<< " buffer_size=" << buffer_size << endl;
<< " buffer_size=" << buffer_size << " bits/dabits="
<< T::LivePrep::bits_from_dabits() << "/"
<< T::LivePrep::dabits_from_bits() << " has_program="
<< has_program();
if (program)
cerr << " program=" << program->get_name();
cerr << endl;
}
assert(res > 0);
return res;

View File

@@ -34,7 +34,9 @@ public:
+ ", have you generated edaBits, "
"for example by running "
"'./Fake-Offline.x -e "
+ to_string(n_bits) + " ...'?");
+ to_string(n_bits)
+ T::template proto_fake_opts<typename T::clear>()
+ " ...'?");
}
assert(BufferBase::file);

View File

@@ -11,6 +11,7 @@
#include "GC/instructions.h"
#include "Memory.hpp"
#include "Instruction.hpp"
#include <iomanip>

View File

@@ -1501,8 +1501,6 @@ void Program::execute_with_errors(Processor<sint, sgf2n>& Proc) const
auto& processor = Proc.Procb;
auto& Ci = Proc.get_Ci();
BaseMachine::program = this;
while (Proc.PC<size)
{
Proc.last_PC = Proc.PC;
@@ -1559,7 +1557,9 @@ void Program::execute_with_errors(Processor<sint, sgf2n>& Proc) const
template<class T>
void Program::mulm_check() const
{
if (T::function_dependent and not OnlineOptions::singleton.has_option("allow_mulm"))
if (T::function_dependent
and not (BaseMachine::allow_mulm()
or OnlineOptions::singleton.has_option("allow_mulm")))
throw runtime_error("Mixed multiplication not implemented for function-dependent preprocessing. "
"Use '-E <protocol>' during compilation or state "
"'program.use_mulm = False' at the beginning of your high-level program.");

View File

@@ -54,6 +54,9 @@ class Machine : public BaseMachine
NamedCommStats max_comm;
int max_trunc_size;
Lock warn_lock;
size_t load_program(const string& threadname, const string& filename);
void prepare(const string& progname_str);
@@ -126,6 +129,8 @@ class Machine : public BaseMachine
Player& get_player() { return *P; }
void check_program();
void gap_warning(int k);
};
#endif /* MACHINE_H_ */

View File

@@ -55,6 +55,7 @@ template<class sint, class sgf2n>
Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
const OnlineOptions opts)
: my_number(playerNames.my_num()), N(playerNames),
max_trunc_size(0),
use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts),
external_clients(my_number)
{
@@ -607,6 +608,12 @@ void Machine<sint, sgf2n>::run(const string& progname)
if (multithread)
cerr << " (overall core time)";
cerr << endl;
auto& P = *this->P;
auto one_off = TreeSum<Z2<64>>().run(
this->one_off_comm.sent, P).get_limb(0);
if (one_off)
cerr << "One-off global communication: " << one_off * 1e-6 << " MB"
<< endl;
}
print_timers();
@@ -685,12 +692,17 @@ void Machine<sint, sgf2n>::run(const string& progname)
<< "have you considered using " << alt << " instead?" << endl;
}
if (nan_warning and sint::real_shares(*P))
if ((nan_warning or mini_warning) and sint::real_shares(*P))
{
cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. See ";
cerr << "https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix";
if (nan_warning)
cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. ";
if (mini_warning)
cerr << pow(2, mini_warning) << " is the smallest non-zero number "
<< "in a used fixed-point representation. ";
cerr << "See https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix";
cerr << " for details" << endl;
nan_warning = false;
mini_warning = 0;
}
#ifdef VERBOSE
@@ -743,6 +755,10 @@ void Machine<sint, sgf2n>::suggest_optimizations()
cerr << "This program might benefit from some protocol options." << endl
<< "Consider adding the following at the beginning of your code:"
<< endl << optimizations;
if (sint::clear::n_bits() < max_trunc_size)
cerr << "The computation domain is too small "
<< "for low-round truncation; it would need to have at least "
<< max_trunc_size << " bits." << endl;
#ifndef __clang__
cerr << "This virtual machine was compiled with GCC. Recompile with "
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;
@@ -768,4 +784,15 @@ void Machine<sint, sgf2n>::check_program()
}
}
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::gap_warning(int k)
{
if (k > max_trunc_size)
{
warn_lock.lock();
max_trunc_size = max(k, max_trunc_size);
warn_lock.unlock();
}
}
#endif

View File

@@ -37,6 +37,16 @@ public:
#endif
}
const T* begin() const
{
return data();
}
const T* end() const
{
return data() + size();
}
virtual T& operator[](size_t i) = 0;
virtual const T& operator[](size_t i) const = 0;

View File

@@ -268,7 +268,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
printf("\tClient %d about to run %d\n",num,program);
#endif
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
online_prep_timer -= Proc.prep_time();
Proc.reset(progs[program], job.arg);
// Bits, Triples, Squares, and Inverses skipping
@@ -278,6 +278,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
//printf("\tExecuting program");
// Execute the program
BaseMachine::program = &progs[program];
progs[program].execute(Proc);
// make sure values used in other threads are safe
@@ -298,7 +299,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
"in thread %d\n", program, num);
#endif
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
online_prep_timer += Proc.prep_time();
wait_timer.start();
queues->finished(job, P.total_comm());
wait_timer.stop();
@@ -307,10 +308,10 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
// final check
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
online_prep_timer -= Proc.prep_time();
Proc.check();
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
online_prep_timer += Proc.prep_time();
if (machine.opts.file_prep_per_thread)
Proc.DataF.prune();

View File

@@ -9,10 +9,12 @@
#include "Math/gfpvar.h"
#include "Protocols/HemiOptions.h"
#include "Protocols/config.h"
#include "FHEOffline/config.h"
#include "Math/gfp.hpp"
#include <boost/filesystem.hpp>
#include <regex>
using namespace std;
@@ -40,6 +42,8 @@ OnlineOptions::OnlineOptions() : playerno(-1)
max_broadcast = 0;
receive_threads = false;
code_locations = false;
have_warned_about_comp_sec = false;
semi_honest = false;
#ifdef VERBOSE
verbose = true;
#else
@@ -161,6 +165,10 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
opt.get("--options")->getStrings(options);
for (auto& option : options)
if (option.find("verbose") == 0)
verbose = true;
code_locations = opt.isSet("--code-locations");
#ifdef THROW_EXCEPTIONS
@@ -463,6 +471,7 @@ void OnlineOptions::finalize_with_error(ez::ezOptionParser& opt)
o->getString(disk_memory);
receive_threads = opt.isSet("--threads");
semi_honest = opt.isSet("--semi-honest");
if (use_security_parameter)
{
@@ -505,3 +514,35 @@ int OnlineOptions::prime_limbs()
{
return DIV_CEIL(prime_length(), 64);
}
bool OnlineOptions::has_param(const string& param)
{
for (auto& x : options)
if (x.find(param + "=") == 0)
return true;
return false;
}
int OnlineOptions::get_param(const string& param)
{
basic_regex re(param + "=([0-9]+)");
smatch match;
for (auto& x : options)
if (regex_match(x, match, re))
return atoi(match[1].str().c_str());
throw runtime_error("parameter not found: " + param);
}
int OnlineOptions::comp_sec()
{
int res = COMP_SEC;
if (has_param("comp_sec"))
res = get_param("comp_sec");
if (res < 128 and not have_warned_about_comp_sec)
{
cerr << "WARNING: computational security parameter " << res
<< " suitable for testing only" << endl;
have_warned_about_comp_sec = true;
}
return res;
}

View File

@@ -15,6 +15,8 @@ class OnlineOptions
{
void finalize_with_error(ez::ezOptionParser& opt);
bool have_warned_about_comp_sec;
public:
static OnlineOptions singleton;
@@ -44,6 +46,7 @@ public:
vector<string> options;
string executable;
bool code_locations;
bool semi_honest;
OnlineOptions();
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
@@ -79,6 +82,11 @@ public:
{
return find(options.begin(), options.end(), option) != options.end();
}
bool has_param(const string& param);
int get_param(const string& param);
int comp_sec();
};
#endif /* PROCESSOR_ONLINEOPTIONS_H_ */

View File

@@ -97,6 +97,18 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"-N", // Flag token.
"--nparties" // Flag token.
);
if (T::semi_honest_option)
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Semi-honest operation (default: malicious security)"
// Help description.
"-sh", // Flag token.
"--semi-honest" // Flag token.
);
}
template<class T>

View File

@@ -310,6 +310,8 @@ class Processor : public ArithmeticProcessor
void call_tape(int tape_number, int arg, const vector<int>& results);
TimerWithComm prep_time();
private:
template<class T> friend class SPDZ;

View File

@@ -638,9 +638,6 @@ void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
int batchStartI = 0;
int batchStartJ = 0;
size_t sourceSize = source.size();
const T* sourceData = source.data();
protocol.init_dotprod();
for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) {
auto output = S.begin() + matmulArgs[0];
@@ -654,27 +651,54 @@ void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end());
for (int j = 0; j < resultNumberOfColumns; j += 1) {
auto actualSecondFactorColumn =
Proc->get_Ci().at(matmulArgs[9] + j).get();
auto secondBase = source.begin() + secondFactorBase
+ actualSecondFactorColumn;
for (auto &x : Range(Proc->get_Ci(), matmulArgs[8],
usedNumberOfFirstFactorColumns))
assert(
secondBase + x.get() * secondFactorTotalNumberOfColumns
< source.end());
}
vector<long> second_factors;
second_factors.reserve(usedNumberOfFirstFactorColumns);
for (auto& x : Range(Proc->get_Ci(), matmulArgs[8],
usedNumberOfFirstFactorColumns))
second_factors.push_back(x.get() * secondFactorTotalNumberOfColumns);
for (int i = 0; i < resultNumberOfRows; i += 1) {
auto actualFirstFactorRow = Proc->get_Ci().at(matmulArgs[6] + i).get();
auto firstBase = source.begin() + firstFactorBase
+ actualFirstFactorRow * firstFactorTotalNumberOfColumns;
for (auto& x : Range(Proc->get_Ci(), matmulArgs[7],
usedNumberOfFirstFactorColumns))
assert(firstBase + x.get() < source.end());
for (int j = 0; j < resultNumberOfColumns; j += 1) {
auto actualSecondFactorColumn = Proc->get_Ci().at(matmulArgs[9] + j).get();
auto secondBase = source.begin() + secondFactorBase
+ actualSecondFactorColumn;
#ifdef MATMULSM_DEBUG
cout << "Preparing " << i << "," << j << "(buffer size: " << protocol.get_buffer_size() << ")" << endl;
#endif
for (int k = 0; k < usedNumberOfFirstFactorColumns; k += 1) {
auto actualFirstFactorColumn = Proc->get_Ci().at(matmulArgs[7] + k).get();
auto actualSecondFactorRow = Proc->get_Ci().at(matmulArgs[8] + k).get();
auto second_it = second_factors.begin();
auto firstAddress = firstFactorBase + actualFirstFactorRow * firstFactorTotalNumberOfColumns + actualFirstFactorColumn;
auto secondAddress = secondFactorBase + actualSecondFactorRow * secondFactorTotalNumberOfColumns + actualSecondFactorColumn;
for (auto& x : Range(Proc->get_Ci(), matmulArgs[7],
usedNumberOfFirstFactorColumns))
{
auto actualFirstFactorColumn = x.get();
assert(firstAddress < sourceSize);
assert(secondAddress < sourceSize);
auto first = firstBase + actualFirstFactorColumn;
auto second = secondBase + *second_it++;
protocol.prepare_dotprod(sourceData[firstAddress], sourceData[secondAddress]);
protocol.prepare_dotprod(*first, *second);
}
protocol.next_dotprod();
@@ -905,9 +929,19 @@ void Conv2dTuple::post(StackedVector<T>& S, typename T::Protocol& protocol)
template<class T>
void SubProcessor<T>::secure_shuffle(const Instruction& instruction)
{
typename T::Protocol::Shuffler(S, instruction.get_size(),
instruction.get_n(), instruction.get_r(0), instruction.get_r(1),
*this);
size_t n = instruction.get_size();
size_t unit_size = instruction.get_n();
size_t output_base = instruction.get_r(0);
size_t input_base = instruction.get_r(1);
typename T::Protocol::Shuffler shuffler(*this);
typename T::Protocol::Shuffler::shuffle_type shuffle;
shuffler.generate(n / unit_size, shuffle);
vector<ShuffleTuple<T>> shuffles{ShuffleTuple<T>(n, output_base,
input_base, unit_size, shuffle, true)};
shuffler.apply_multiple(S, shuffles);
maybe_check();
}
@@ -916,7 +950,10 @@ template<class T>
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
ShuffleStore& shuffle_store)
{
return shuffler.generate(instruction.get_n(), shuffle_store);
size_t n = instruction.get_n();
auto res = shuffle_store.add(n);
shuffler.generate(n, shuffle_store.get(res).second);
return res;
}
template<class T>
@@ -926,21 +963,18 @@ void SubProcessor<T>::apply_shuffle(const Instruction& instruction,
const auto& args = instruction.get_start();
const auto n_shuffles = args.size() / 6;
vector<size_t> sizes(n_shuffles, 0);
vector<size_t> destinations(n_shuffles, 0);
vector<size_t> sources(n_shuffles, 0);
vector<size_t> unit_sizes(n_shuffles, 0);
vector<size_t> shuffles(n_shuffles, 0);
vector<bool> reverse(n_shuffles, false);
for (size_t i = 0; i < n_shuffles; i++) {
sizes[i] = args[6 * i];
destinations[i] = args[6 * i + 1];
sources[i] = args[6 * i + 2];
unit_sizes[i] = args[6 * i + 3];
shuffles[i] = Proc->read_Ci(args[6 * i + 4]);
reverse[i] = args[6 * i + 5];
vector<ShuffleTuple<T>> shuffles;
for (size_t i = 0; i < n_shuffles; i++)
{
shuffles.push_back(
ShuffleTuple<T>(args[6 * i], args[6 * i + 1], args[6 * i + 2],
args[6 * i + 3],
shuffle_store.get(Proc->read_Ci(args[6 * i + 4])),
bool(args[6 * i + 5])));
}
shuffler.apply_multiple(S, sizes, destinations, sources, unit_sizes, shuffles, reverse, shuffle_store);
shuffler.apply_multiple(S, shuffles);
maybe_check();
}
@@ -1184,4 +1218,13 @@ void Processor<sint, sgf2n>::call_tape(int tape_number, int arg,
arg_stack.pop_back();
}
template<class sint, class sgf2n>
TimerWithComm Processor<sint, sgf2n>::prep_time()
{
auto res = DataF.total_time();
res += Procp.protocol.prep_time();
res += Proc2.protocol.prep_time();
return res;
}
#endif

View File

@@ -42,6 +42,8 @@ class Program
size_t size() const { return p.size(); }
string get_name() const { return name; }
// Read in a program
void parse(string filename);
void parse_with_error(string filename);

View File

@@ -31,10 +31,24 @@ HonestMajorityRingMachine<U, V>::HonestMajorityRingMachine(int argc, const char*
RingMachine<U, V, HonestMajorityMachine>(argc, argv, opt, online_opts, nplayers);
}
inline void ring_domain_error(int R)
inline void ring_domain_error(int R, int max)
{
cerr << "not compiled for " << R << "-bit computation, " << endl;
cerr << "compile with -DRING_SIZE=" << R << endl;
cerr << "The virtual machine is not compiled for " << R
<< "-bit computation." << endl;
cerr << "Compile with 'MY_CFLAGS += -DRING_SIZE=" << R
<< "' in 'CONFIG.mine'";
(void) max;
#ifndef FEWER_RINGS
for (int r = 0; r <= max; r += 64)
{
if (r >= R)
{
cerr << " or try " << "'-R " << r << "'";
break;
}
}
#endif
cerr << "." << endl;
exit(1);
}
@@ -60,7 +74,7 @@ RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
X(RING_SIZE)
#endif
#undef X
ring_domain_error(R);
ring_domain_error(R, 192);
}
template<template<int K, int S> class U, template<class T> class V>
@@ -98,7 +112,7 @@ HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecuri
X(72) X(128)
#endif
#undef X
ring_domain_error(R);
ring_domain_error(R, 128);
}
#endif /* PROCESSOR_RINGMACHINE_HPP_ */

View File

@@ -11,6 +11,7 @@
using namespace std;
#include "OnlineOptions.h"
#include "BaseMachine.h"
#include "GC/ArgTuples.h"
template<class T> class StackedVector;
@@ -106,9 +107,12 @@ public:
TruncPrTupleWithGap(vector<int>::const_iterator it) :
TruncPrTuple<T>(it)
{
big_gap_ = this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error;
int min_size = this->k + OnlineOptions::singleton.trunc_error;
big_gap_ = min_size <= T::n_bits();
if (T::prime_field and small_gap())
throw runtime_error("domain too small for chosen truncation error");
if (small_gap() and BaseMachine::has_singleton())
BaseMachine::s().gap_warning(min_size);
}
T upper(T mask)

View File

@@ -202,7 +202,7 @@
*dest++ = *op1++ > *op2++) \
X(EQC, auto dest = &Ci[r[0]]; auto op1 = &Ci[r[1]]; auto op2 = &Ci[r[2]], \
*dest++ = *op1++ == *op2++) \
X(PRINTINT, Proc.out << Proc.read_Ci(r[0]) << flush,) \
X(PRINTINT, print(Proc.out, &Proc.get_Ci_ref(r[0])),) \
X(PRINTFLOATPREC, Proc.out << setprecision(n),) \
X(PRINTSTR, Proc.out << string((char*)&n,4) << flush,) \
X(PRINTCHR, Proc.out << string((char*)&n,1) << flush,) \

View File

@@ -0,0 +1,5 @@
a = sbits.get_type(int(program.args[1]))(0)
@for_range(int(program.args[2]))
def _(i):
a & a

View File

@@ -0,0 +1,11 @@
import math
n = int(program.args[1])
n_sqrt = int(math.sqrt(n))
sfix.Matrix(n_sqrt, 10) * sfix.Matrix(10, n_sqrt)
(sfix(0, size=n) < 0).store_in_mem(0)
sint.Array(n).secure_shuffle()
sint(personal(0, cint(0, size=n)))

View File

@@ -0,0 +1,15 @@
#sfix.set_precision(32, 63)
#program.use_trunc_pr = True
#program.use_split(3)
program.options_from_args()
sfix.set_precision_from_args(program)
try:
n_loops = int(program.args[2])
except:
n_loops = 1
a = sfix(cint(0, size=int(program.args[1])))
@for_range(n_loops)
def _(i):
(a < a)#.store_in_mem(0)

View File

@@ -0,0 +1,10 @@
program.options_from_args()
sfix.set_precision_from_args(program)
n = int(program.args[1])
m = int(program.args[2])
a = sfix(0, size=n)
@for_range(m)
def _(i):
(a / a).store_in_mem(0)

View File

@@ -0,0 +1,14 @@
program.options_from_args()
sfix.set_precision_from_args(program)
try:
n = int(program.args[1])
except:
n = 10 ** 6
m = int(program.args[2])
a = sfix(0, size=n)
@for_range(m)
def _(i):
(a * a).store_in_mem(0)

Some files were not shown because too many files have changed in this diff Show More