Multinode computation.

This commit is contained in:
Marcel Keller
2023-12-14 12:00:24 +11:00
parent 1f8f784611
commit cf4426fdb3
130 changed files with 2505 additions and 799 deletions

View File

@@ -10,6 +10,7 @@
#include <cstring>
#include <string>
#include <vector>
#include <stdint.h>
using namespace std;
#include "Tools/CheckVector.h"

View File

@@ -1,5 +1,23 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.8 (December 14, 2023)
- Functionality for multiple nodes per party
- Functionality to use disk space for high-level data structures
- True division is always fixed-point division (similar to Python 3)
- Compiler option to optimize for specific protocol
- Cleartext permutation
- Faster compilation and lower bytecode size
- Functionality to output secret shares from high-level code
- Run-time command-line arguments accessible from high-level code
- Client connection setup specifies cleartext domain
- Compile-time parameter for connection timeout
- Prevent connections from timing out (@ParallelogramPal)
- More ECDSA examples
- More flexible multiplication instruction
- Dot product instruction supports several operations at once
- Example-based virtual machine explanation
## 0.3.7 (August 14, 2023)
- Path Oblivious Heap (@tskovlund)

1
CONFIG
View File

@@ -87,6 +87,7 @@ LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
LDLIBS += $(BREW_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto
LDLIBS += -lboost_filesystem -lboost_iostreams
CFLAGS += -I./local/include

View File

@@ -80,17 +80,17 @@ opcodes = dict(
CONVCBITVEC = 0x231,
)
class BinaryVectorInstruction(base.Instruction):
is_vec = lambda self: True
class BinaryCiscable(base.Ciscable):
pass
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))
class BinaryVectorInstruction(BinaryCiscable):
is_vec = lambda self: True
class NonVectorInstruction(base.Instruction):
is_vec = lambda self: False
def __init__(self, *args, **kwargs):
assert(args[0].n <= args[0].unit)
assert(args[0].n is None or args[0].n <= args[0].unit)
super(NonVectorInstruction, self).__init__(*args, **kwargs)
class NonVectorInstruction1(base.Instruction):
@@ -163,7 +163,7 @@ class andrs(BinaryVectorInstruction):
sum(int(math.ceil(x / 64)) for x in self.args[::4]))
class andrsvec(base.VarArgsInstruction, base.Mergeable,
base.DynFormatInstruction):
base.DynFormatInstruction, BinaryCiscable):
""" Constant-vector AND of secret bit registers (vectorized version).
:param: total number of arguments to follow (int)
@@ -206,6 +206,9 @@ class andrsvec(base.VarArgsInstruction, base.Mergeable,
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
req_node.increment(('bit', 'mixed'), size)
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))
class ands(BinaryVectorInstruction):
""" Bitwise AND of secret bit register vector.
@@ -306,7 +309,7 @@ class bitcoms(NonVectorInstruction, base.VarArgsInstruction):
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))
class bitdecc(NonVectorInstruction, base.VarArgsInstruction):
""" Secret bit register decomposition.
""" Clear bit register decomposition.
:param: number of arguments to follow / number of bits plus one (int)
:param: source (sbit)
@@ -513,8 +516,8 @@ class convcbitvec(BinaryVectorInstruction):
"""
code = opcodes['CONVCBITVEC']
arg_format = ['int','ciw','cb']
def __init__(self, *args):
super(convcbitvec, self).__init__(*args)
def __init__(self, *args, **kwargs):
super(convcbitvec, self).__init__(*args, **kwargs)
assert(args[2].n == args[0])
args[1].set_size(args[0])
@@ -546,14 +549,14 @@ class split(base.Instruction):
super(split_class, self).__init__(*args, **kwargs)
assert (len(args) - 2) % args[0] == 0
class movsb(NonVectorInstruction):
class movsb(BinaryVectorInstruction):
""" Copy secret bit register.
:param: destination (sbit)
:param: source (sbit)
"""
code = opcodes['MOVSB']
arg_format = ['sbw','sb']
arg_format = ['int', 'sbw','sb']
class trans(base.VarArgsInstruction, base.DynFormatInstruction):
""" Secret bit register vector transpose. The first destination vector
@@ -568,8 +571,6 @@ class trans(base.VarArgsInstruction, base.DynFormatInstruction):
"""
code = opcodes['TRANS']
is_vec = lambda self: True
def __init__(self, *args):
super(trans, self).__init__(*args)
@classmethod
def dynamic_arg_format(cls, args):

View File

@@ -97,8 +97,9 @@ class bits(Tape.Register, _structure, _bit):
cbits.conv_cint_vec(a, *res)
return res
@classmethod
def malloc(cls, size, creator_tape=None):
return Program.prog.malloc(size, cls, creator_tape=creator_tape)
def malloc(cls, size, creator_tape=None, **kwargs):
return Program.prog.malloc(size, cls, creator_tape=creator_tape,
**kwargs)
@staticmethod
def n_elements():
return 1
@@ -254,6 +255,18 @@ class bits(Tape.Register, _structure, _bit):
return self.get_type(length).bit_compose([self] * length)
else:
raise CompilerError('cannot expand from %s to %s' % (self.n, length))
@classmethod
def new_vector(cls, size):
return cls.get_type(size)()
@classmethod
def concat(cls, parts):
return cls.bit_compose(
sum([part.bit_decompose() for part in parts], []))
def copy_from_part(self, source, base, size):
self.mov(self,
self.bit_compose(source.bit_decompose()[base:base + size]))
def vector_size(self):
return self.n
class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
@@ -425,7 +438,7 @@ class sbits(bits):
tmp = cbits.get_type(n)()
tmp.conv_regint_by_bit(n, tmp, other)
res.load_other(tmp)
mov = inst.movsb
mov = staticmethod(lambda x, y: inst.movsb(x.n, x, y))
types = {}
def __init__(self, *args, **kwargs):
bits.__init__(self, *args, **kwargs)
@@ -1048,6 +1061,9 @@ def result_conv(x, y):
class sbit(bit, sbits):
""" Single secret bit. """
@classmethod
def get_type(cls, length):
return sbits.get_type(length)
def if_else(self, x, y):
""" Non-vectorized oblivious selection::
@@ -1301,6 +1317,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
"""
bit_extend = staticmethod(_complement_two_extend)
mul_functions = {}
@classmethod
def popcnt_bits(cls, bits):
return sbitvec.from_vec(bits).popcnt()
@@ -1326,20 +1343,42 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
elif isinstance(other, sbitfixvec):
return NotImplemented
my_bits, other_bits = self.expand(other, False)
matrix = []
m = float('inf')
uniform = True
for x in itertools.chain(my_bits, other_bits):
try:
uniform &= type(x) == type(my_bits[0]) and x.n == my_bits[0].n
m = min(m, x.n)
except:
pass
if uniform and Program.prog.options.cisc:
bl = len(my_bits)
key = bl, len(other_bits)
if key not in self.mul_functions:
def instruction(*args):
res = self.binary_mul(args[bl:2 * bl], args[2 * bl:],
args[0].n)
for x, y in zip(res, args):
x.mov(y, x)
instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits))
self.mul_functions[key] = instructions_base.cisc(instruction,
bl)
res = [sbits.get_type(m)() for i in range(bl)]
self.mul_functions[key](*(res + my_bits + other_bits))
return self.from_vec(res)
else:
return self.binary_mul(my_bits, other_bits, m)
@classmethod
def binary_mul(cls, my_bits, other_bits, m):
matrix = []
for i, b in enumerate(other_bits):
if m == 1:
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
matrix.append([x * b for x in my_bits[:len(my_bits)-i]])
else:
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
matrix.append((
sbitvec.from_vec(my_bits[:len(my_bits)-i]) * b).v)
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v[:len(self.v)])
return cls.from_vec(v[:len(my_bits)])
__rmul__ = __mul__
reduce_after_mul = lambda x: x
def TruncMul(self, other, k, m, kappa=None, nearest=False):

View File

@@ -104,9 +104,10 @@ class AllocRange:
break
class AllocPool:
def __init__(self):
def __init__(self, parent=None):
self.ranges = defaultdict(lambda: [AllocRange()])
self.by_base = {}
self.parent = parent
def alloc(self, reg_type, size):
for r in self.ranges[reg_type]:
@@ -116,8 +117,17 @@ class AllocPool:
return res
def free(self, reg):
r = self.by_base.pop((reg.reg_type, reg.i))
r.free(reg.i, reg.size)
try:
r = self.by_base.pop((reg.reg_type, reg.i))
r.free(reg.i, reg.size)
except KeyError:
try:
self.parent.free(reg)
except:
if program.Program.prog.options.debug:
print('Error with freeing register with trace:')
print(util.format_trace(reg.caller))
print()
def new_ranges(self, min_usage):
for t, n in min_usage.items():
@@ -133,7 +143,10 @@ class AllocPool:
rr.consolidate()
def n_fragments(self):
return max(len(r) for r in self.ranges)
if self.ranges:
return max(len(r) for r in self.ranges)
else:
return 0
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
@@ -146,6 +159,7 @@ class StraightlineAllocator:
assert(n == REG_MAX)
self.program = program
self.old_pool = None
self.unused = defaultdict(lambda: 0)
def alloc_reg(self, reg, free):
base = reg.vectorbase
@@ -195,7 +209,8 @@ class StraightlineAllocator:
for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x)
free.free(base)
if reg not in self.program.base_addresses:
free.free(base)
if inst.is_vec() and base.vector:
self.defined[base] = inst
for i in base.vector:
@@ -220,8 +235,11 @@ class StraightlineAllocator:
if unused_regs and len(unused_regs) == len(list(i.get_def())) and \
self.program.verbose:
# only report if all assigned registers are unused
print("Register(s) %s never used, assigned by '%s' in %s" % \
(unused_regs,i,format_trace(i.caller)))
self.unused[type(i).__name__] += 1
if self.program.verbose > 1:
print(
"Register(s) %s never used, assigned by '%s' in %s" % \
(unused_regs,i,format_trace(i.caller)))
for j in i.get_used():
self.alloc_reg(j, alloc_pool)
@@ -277,6 +295,7 @@ class StraightlineAllocator:
x = reg.reg_type, reg.size
print('Used registers: ', end='')
p(sizes)
print('Unused instructions:', dict(self.unused))
def determine_scope(block, options):
last_def = defaultdict_by_id(lambda: -1)
@@ -421,6 +440,7 @@ class Merger:
last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque()
last_input = defaultdict(lambda: [None, None])
mem_scopes = defaultdict_by_id(lambda: MemScope())
depths = [0] * len(block.instructions)
self.depths = depths
@@ -429,6 +449,12 @@ class Merger:
self.sources = []
self.real_depths = [0] * len(block.instructions)
round_type = {}
shuffles = defaultdict_by_id(set)
class MemScope:
def __init__(self):
self.read = []
self.write = []
def add_edge(i, j):
if i in (-1, j):
@@ -581,14 +607,20 @@ class Merger:
depths[n] = depth
if isinstance(instr, ReadMemoryInstruction):
if options.preserve_mem_order or instr._protect:
if options.preserve_mem_order:
strict_mem_access(n, last_mem_read, last_mem_write)
elif not options.preserve_mem_order:
elif instr._protect:
scope = mem_scopes[instr._protect]
strict_mem_access(n, scope.read, scope.write)
if not options.preserve_mem_order:
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
elif isinstance(instr, WriteMemoryInstruction):
if options.preserve_mem_order or instr._protect:
if options.preserve_mem_order:
strict_mem_access(n, last_mem_write, last_mem_read)
elif not options.preserve_mem_order:
elif instr._protect:
scope = mem_scopes[instr._protect]
strict_mem_access(n, scope.write, scope.read)
if not options.preserve_mem_order:
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
elif isinstance(instr, matmulsm):
if options.preserve_mem_order:
@@ -608,6 +640,11 @@ class Merger:
keep_order(instr, n, instr.args[0])
elif isinstance(instr, StackInstruction):
keep_order(instr, n, StackInstruction)
elif isinstance(instr, applyshuffle):
shuffles[instr.args[3]].add(n)
elif isinstance(instr, delshuffle):
for i_inst in shuffles[instr.args[0]]:
add_edge(i_inst, n)
if not G.pred[n]:
self.sources.append(n)
@@ -683,6 +720,7 @@ class RegintOptimizer:
self.cache = util.dict_by_id()
self.offset_cache = util.dict_by_id()
self.rev_offset_cache = {}
self.range_cache = util.dict_by_id()
def add_offset(self, res, new_base, new_offset):
self.offset_cache[res] = new_base, new_offset
@@ -693,6 +731,12 @@ class RegintOptimizer:
for i, inst in enumerate(instructions):
if isinstance(inst, ldint_class):
self.cache[inst.args[0]] = inst.args[1]
elif isinstance(inst, incint):
if inst.args[2] == 1 and inst.args[3] == 1 and \
inst.args[4] == len(inst.args[0]) and \
inst.args[1] in self.cache:
self.range_cache[inst.args[0]] = \
len(inst.args[0]), self.cache[inst.args[1]]
elif isinstance(inst, IntegerInstruction):
if inst.args[1] in self.cache and inst.args[2] in self.cache:
res = inst.op(self.cache[inst.args[1]],
@@ -731,6 +775,10 @@ class RegintOptimizer:
base, offset = self.offset_cache[inst.args[1]]
addr = self.rev_offset_cache[base.i, offset]
inst.args[1] = addr
elif inst.args[1] in self.range_cache:
size, base = self.range_cache[inst.args[1]]
if size == len(inst.args[0]):
instructions[i] = inst.get_direct(base)
elif type(inst) == convint_class:
if inst.args[1] in self.cache:
res = self.cache[inst.args[1]]

View File

@@ -65,6 +65,8 @@ def ld2i(c, n):
movc(c, t1)
def require_ring_size(k, op):
if not program.options.ring:
return
if int(program.options.ring) < k:
msg = 'ring size too small for %s, compile ' \
'with \'-R %d\' or more' % (op, k)
@@ -140,7 +142,7 @@ def TruncRing(d, a, k, m, signed):
high = sint.conv(carries[length])
else:
if m == 1:
low = x[1][1]
low = x[0][1]
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
sint.conv(x[0][-1])
else:
@@ -181,7 +183,7 @@ def TruncLeakyInRing(a, k, m, signed):
if k == m:
return 0
assert k > m
assert int(program.options.ring) >= k
require_ring_size(k, 'leaky truncation')
from .types import sint, intbitint, cint, cgf2n
n_bits = k - m
n_shift = int(program.options.ring) - n_bits
@@ -228,7 +230,7 @@ def Mod2m(a_prime, a, k, m, kappa, signed):
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
def Mod2mRing(a_prime, a, k, m, signed):
assert(int(program.options.ring) >= k)
require_ring_size(k, 'modulo power of two')
from Compiler.types import sint, intbitint, cint
shift = int(program.options.ring) - m
r_prime, r_bin = MaskingBitsInRing(m, True)
@@ -404,7 +406,7 @@ def carry(b, a, compute_p=True):
return b
if b is None:
return a
t = [program.curr_block.new_reg('s') for i in range(3)]
t = [None] * 3
if compute_p:
t[0] = a[0].bit_and(b[0])
t[2] = a[0].bit_and(b[1]) + a[1]

View File

@@ -13,12 +13,28 @@ from .program import Program, defaults
class Compiler:
def __init__(self, custom_args=None, usage=None, execute=False):
def __init__(self, custom_args=None, usage=None, execute=False,
split_args=False):
if usage:
self.usage = usage
else:
self.usage = "usage: %prog [options] filename [args]"
self.execute = execute
self.runtime_args = []
if split_args:
if custom_args is None:
args = sys.argv
else:
args = custom_args
try:
split = args.index('--')
except ValueError:
split = len(args)
custom_args = args[1:split]
self.runtime_args = args[split + 1:]
self.custom_args = custom_args
self.build_option_parser()
self.VARS = {}
@@ -148,7 +164,8 @@ class Compiler:
"--prime",
dest="prime",
default=defaults.prime,
help="prime modulus (default: not specified)",
help="use bit decomposition with a specifed prime modulus "
"for non-linear computation (default: use the masking approach)",
)
parser.add_option(
"-I",
@@ -235,13 +252,31 @@ class Compiler:
dest="hostfile",
help="hosts to execute with",
)
else:
parser.add_option(
"-E",
"--execute",
dest="execute",
help="protocol to optimize for",
)
self.parser = parser
def parse_args(self):
self.options, self.args = self.parser.parse_args(self.custom_args)
if self.execute:
if not self.options.execute:
raise CompilerError("must give name of protocol with '-E'")
if len(self.args) > 1:
self.options.execute = self.args.pop(0)
else:
self.parser.error("missing protocol name")
if self.options.hostfile:
try:
open(self.options.hostfile)
except:
print('hostfile %s not found' % self.options.hostfile,
file=sys.stderr)
exit(1)
if self.options.execute:
protocol = self.options.execute
if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \
protocol.find("brain") >= 0 or protocol == "emulate":
@@ -268,14 +303,14 @@ class Compiler:
def build_program(self, name=None):
self.prog = Program(self.args, self.options, name=name)
if self.execute:
if self.options.execute:
if self.options.execute in \
("emulate", "ring", "rep-field", "rep4-ring"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"):
self.prog.use_split(3)
if self.options.execute in ("semi2k",):
self.prog.use_split(2)
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
if self.options.execute in ("rep4-ring",):
self.prog.use_split(4)
@@ -476,7 +511,9 @@ class Compiler:
else:
return protocol + "-party.x"
def local_execution(self, args=[]):
def local_execution(self, args=None):
if args is None:
args = self.runtime_args
executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...")
@@ -488,9 +525,13 @@ class Compiler:
"Note that compilation requires a few GB of RAM.")
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
sys.stdout.flush()
print("Compilation finished, running program...", file=sys.stderr)
sys.stderr.flush()
os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=[]):
def remote_execution(self, args=None):
if args is None:
args = self.runtime_args
vm = self.executable_from_protocol(self.options.execute)
hosts = list(x.strip()
for x in filter(None, open(self.options.hostfile)))

View File

@@ -286,6 +286,7 @@ def BitDecRingRaw(a, k, m):
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return bits
@instructions_base.bit_cisc
def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds
@@ -304,6 +305,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
instructions_base.reset_global_vector_size()
return res
@instructions_base.bit_cisc
def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
return [types.sintbit.conv(bit) for bit in res]
@@ -358,7 +360,6 @@ def B2U_from_Pow2(pow2a, l, kappa):
def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
""" Oblivious truncation by secret m """
prog = program.Program.prog
kappa = kappa or prog.security
if util.is_constant(m) and not compute_modulo:
# cheaper
res = type(a)(size=a.size)
@@ -371,6 +372,8 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
return a * (1 - m)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, Pow2(m, l, kappa))
else:
kappa = kappa or program.Program.prog.security
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
@@ -409,6 +412,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
b = shifted - d
return b
@instructions_base.ret_cisc
def TruncInRing(to_shift, l, pow2m):
n_shift = int(program.Program.prog.options.ring) - l
bits = BitDecRing(to_shift, l, l)
@@ -433,11 +437,7 @@ def SplitInRing(a, l, m):
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
if program.Program.prog.options.ring:
s = (1 - overflow) * t + \
comparison.TruncLeakyInRing(overflow * t, length, 1, False)
else:
s = (1 - overflow) * t + overflow * t / 2
s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False)
return s, overflow
def Int2FL(a, gamma, l, kappa=None):
@@ -555,7 +555,7 @@ def TruncPrField(a, k, m, kappa=None):
c = (b + r).reveal(False)
c_prime = c % two_to_m
a_prime = c_prime - r_prime
d = (a - a_prime) / two_to_m
d = (a - a_prime).field_div(two_to_m)
return d
@instructions_base.ret_cisc

View File

@@ -345,6 +345,17 @@ class starg(base.Instruction):
code = base.opcodes['STARG']
arg_format = ['ci']
@base.vectorize
class cmdlinearg(base.Instruction):
""" Load command-line argument.
:param: dest (regint)
:param: index (regint)
"""
code = base.opcodes['CMDLINEARG']
arg_format = ['ciw','ci']
@base.gf2n
class reqbl(base.Instruction):
""" Requirement on computation modulus. Minimal bit length of prime if
@@ -654,7 +665,7 @@ class picks(base.VectorInstruction):
def __init__(self, *args):
super(picks, self).__init__(*args)
assert 0 <= args[2] < len(args[1])
assert 0 <= args[2] + args[3] * len(args[0]) <= len(args[1])
assert 0 <= args[2] + args[3] * (len(args[0]) - 1) < len(args[1])
class concats(base.VectorInstruction):
""" Concatenate vectors.
@@ -1630,6 +1641,16 @@ class print_reg_plain(base.IOInstruction):
code = base.opcodes['PRINTREGPLAIN']
arg_format = ['c']
class print_reg_plains(base.IOInstruction):
""" Output secret register.
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['PRINTREGPLAINS']
arg_format = ['s']
class cond_print_plain(base.IOInstruction):
""" Conditionally output clear register (with precision).
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
@@ -1860,6 +1881,19 @@ class acceptclientconnection(base.IOInstruction):
code = base.opcodes['ACCEPTCLIENTCONNECTION']
arg_format = ['ciw', 'ci']
class initclientconnection(base.IOInstruction):
""" Initialize connection.
:param: client id destination (regint)
:param: port number (regint)
:param: my client id (regint)
:param: hostname (variable string)
"""
__slots__ = []
code = base.opcodes['INITCLIENTCONNECTION']
arg_format = ['ciw', 'ci', 'ci', 'varstr']
class closeclientconnection(base.IOInstruction):
""" Close connection to client.
@@ -1941,7 +1975,7 @@ class fixinput(base.PublicFileIOInstruction):
:param: player (int)
:param: destination (cint)
:param: exponent (int)
:param: exponent (int, for float/double) / byte length (1/8, for integer)
:param: input type (0: 64-bit integer, 1: float, 2: double)
"""
@@ -2284,30 +2318,30 @@ class asm_open(base.VarArgsInstruction, base.DataInstruction):
self.args += other.args[1:]
@base.gf2n
@base.vectorize
class muls(base.VarArgsInstruction, base.DataInstruction):
class muls(base.VarArgsInstruction, base.DataInstruction, base.Ciscable):
""" (Element-wise) multiplication of secret registers (vectors).
:param: number of arguments to follow (multiple of three)
:param: number of arguments to follow (multiple of four)
:param: vector size (int)
:param: result (sint)
:param: factor (sint)
:param: factor (sint)
:param: (repeat the last three)...
:param: (repeat the last four)...
"""
__slots__ = []
code = base.opcodes['MULS']
arg_format = tools.cycle(['sw','s','s'])
arg_format = tools.cycle(['int','sw','s','s'])
data_type = 'triple'
is_vec = lambda self: True
def __init__(self, *args, **kwargs):
super(muls_class, self).__init__(*args, **kwargs)
for i in range(0, len(args), 4):
for j in range(3):
assert args[i + j + 1].size == args[i]
def get_repeat(self):
return len(self.args) // 3
def merge_id(self):
# can merge different sizes
# but not if large
if self.get_size() is None or self.get_size() > 100:
return type(self), self.get_size()
return type(self)
return sum(self.args[::4])
# def expand(self):
# s = [program.curr_block.new_reg('s') for i in range(9)]
@@ -2324,6 +2358,16 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
# adds(s[8], s[7], s[6])
# addm(self.args[0], s[8], c[2])
# compatibility
try:
vmuls = muls_class
muls_bak = muls
muls = lambda *args: muls_bak(args[0].size, *args)
vgmuls = gmuls_class = gmuls
gmuls = lambda *args: gmuls_class(args[0].size, *args)
except NameError:
pass
@base.gf2n
class mulrs(base.VarArgsInstruction, base.DataInstruction):
""" Constant-vector multiplication of secret registers.
@@ -2403,8 +2447,8 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction,
return self.arg_format()
def get_repeat(self):
return sum(self.args[i] // 2
for i, n in self.bases(iter(self.args))) * self.get_size()
return sum(self.args[i] // 2 - 1
for i, n in self.bases(iter(self.args)))
def get_def(self):
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
@@ -2421,7 +2465,7 @@ class matmul_base(base.DataInstruction):
def get_repeat(self):
return reduce(operator.mul, self.args[3:6])
class matmuls(matmul_base):
class matmuls(matmul_base, base.Mergeable):
""" Secret matrix multiplication from registers. All matrices are
represented as vectors in row-first order.
@@ -2433,7 +2477,11 @@ class matmuls(matmul_base):
:param: number of columns in second factor and result (int)
"""
code = base.opcodes['MATMULS']
arg_format = ['sw','s','s','int','int','int']
arg_format = itertools.cycle(['sw','s','s','int','int','int'])
def get_repeat(self):
return sum(reduce(operator.mul, self.args[i + 3:i + 6])
for i in range(0, len(self.args), 6))
class matmulsm(matmul_base):
""" Secret matrix multiplication reading directly from memory.

View File

@@ -67,6 +67,7 @@ opcodes = dict(
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
CMDLINEARG = 0xEB,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -152,7 +153,7 @@ opcodes = dict(
LISTEN = 0x6c,
ACCEPTCLIENTCONNECTION = 0x6d,
CLOSECLIENTCONNECTION = 0x6e,
READCLIENTPUBLICKEY = 0x6f,
INITCLIENTCONNECTION = 0x6f,
# Bitwise logic
ANDC = 0x70,
XORC = 0x71,
@@ -196,6 +197,7 @@ opcodes = dict(
PRINTREG = 0XB1,
RAND = 0xB2,
PRINTREGPLAIN = 0xB3,
PRINTREGPLAINS = 0xEA,
PRINTCHR = 0xB4,
PRINTSTR = 0xB5,
PUBINPUT = 0xB6,
@@ -422,28 +424,31 @@ def gf2n(instruction):
class Mergeable:
pass
def cisc(function):
def cisc(function, n_outputs=1):
class MergeCISC(Mergeable):
instructions = {}
functions = {}
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.security = program.security
self.security = program._security
self.calls = [(args, kwargs)]
self.params = []
self.used = []
for arg in self.args[1:]:
for arg in self.args[n_outputs:]:
if isinstance(arg, program.curr_tape.Register):
self.used.append(arg)
self.params.append(type(arg))
else:
self.params.append(arg)
self.function = function
self.caller = None
program.curr_block.instructions.append(self)
def get_def(self):
return [call[0][0] for call in self.calls]
return sum(([call[0][i] for call in self.calls]
for i in range(n_outputs)), [])
def get_used(self):
return self.used
@@ -460,7 +465,7 @@ def cisc(function):
self.used += other.used
def get_size(self):
return self.args[0].size
return self.args[0].vector_size()
def new_instructions(self, size, regs):
if self.merge_id() not in self.instructions:
@@ -474,11 +479,11 @@ def cisc(function):
args = []
for arg in self.args:
try:
args.append(type(arg)(size=None))
args.append(arg.new_vector(size=None))
except:
args.append(arg)
program.options.cisc = False
old_security = program.security
old_security = program._security
program.security = self.security
self.function(*args, **self.kwargs)
program.security = old_security
@@ -490,7 +495,8 @@ def cisc(function):
from Compiler.allocator import Merger
merger = Merger(block, program.options,
tuple(program.to_merge))
args[0].can_eliminate = False
for i in range(n_outputs):
args[i].can_eliminate = False
merger.eliminate_dead_code()
assert int(program.options.max_parallel_open) == 0, \
'merging restriction not compatible with ' \
@@ -501,52 +507,105 @@ def cisc(function):
n_rounds
template, args, self.n_rounds = self.instructions[self.merge_id()]
subs = util.dict_by_id()
from Compiler import types
for arg, reg in zip(args, regs):
subs[arg] = reg
if isinstance(arg, program.curr_tape.Register):
subs[arg] = reg
set_global_vector_size(size)
for inst in template:
inst.copy(size, subs)
reset_global_vector_size()
def expand_to_function(self, size, new_regs):
key = size, program.curr_tape, \
tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \
tuple(type(reg) for reg in new_regs)
if key not in self.functions:
from Compiler import library, types
from Compiler.GC.types import bits
class Arg:
def __init__(self, reg):
self.type = type(reg)
self.binary = isinstance(reg, bits)
self.reg = reg
# if reg is not None:
# program.base_addresses[reg] = None
def new(self):
if self.binary:
return self.type()
else:
return self.type(size=size)
def load(self):
return self.reg
def store(self, reg):
if self.type != type(None):
self.reg.update(reg)
args = [Arg(x) for x in new_regs]
@library.function_block
def f():
res = [arg.new() for arg in args[:n_outputs]]
self.new_instructions(size,
res + [arg.load() for arg in args[n_outputs:]])
for reg, arg in zip(res, args):
arg.store(reg)
f.name = '_'.join(['%s(%d)' % (function.__name__, size)] +
[str(x) for x in key[2]])
self.functions[key] = f, args
f, args = self.functions[key]
for i in range(len(new_regs) - n_outputs):
args[n_outputs + i].store(new_regs[n_outputs + i])
f()
for i in range(n_outputs):
new_regs[i].link(args[i].load())
def expand_merged(self, skip):
if function.__name__ in skip:
good = True
for call in self.calls:
if not good:
break
for arg in call[0]:
if isinstance(arg, program.curr_tape.Register) and \
not issubclass(type(self.calls[0][0][0]), type(arg)):
good = False
for i in range(n_outputs):
for arg in call[0]:
if isinstance(arg, program.curr_tape.Register) and \
not issubclass(type(self.calls[0][0][0]),
type(arg)):
good = False
if good:
return [self], 0
return program.curr_block.instructions.append(self)
if program.verbose:
print('expanding', self.function.__name__)
tape = program.curr_tape
block = tape.BasicBlock(tape, None, None)
tape.active_basicblock = block
size = sum(call[0][0].size for call in self.calls)
tape.start_new_basicblock()
size = sum(call[0][0].vector_size() for call in self.calls)
new_regs = []
for i, arg in enumerate(self.args):
try:
if i == 0:
new_regs.append(type(arg)(size=size))
if i < n_outputs:
new_regs.append(arg.new_vector(size=size))
else:
new_regs.append(type(arg).concat(
call[0][i] for call in self.calls))
assert len(new_regs[-1]) == size
assert new_regs[-1].vector_size() == size
except (TypeError, AttributeError):
if not isinstance(arg, int):
if not isinstance(arg, (int, type(None))):
raise
break
new_regs.append(None)
except:
print([call[0][0].size for call in self.calls])
print([call[0][0].vector_size() for call in self.calls])
raise
self.new_instructions(size, new_regs)
if program.cisc_to_function and \
(program.curr_tape.singular or program.n_running_threads):
self.expand_to_function(size, new_regs)
else:
self.new_instructions(size, new_regs)
program.curr_block.n_rounds += self.n_rounds - 1
base = 0
for call in self.calls:
reg = call[0][0]
reg.copy_from_part(new_regs[0], base, reg.size)
base += reg.size
return block.instructions, self.n_rounds - 1
for i in range(n_outputs):
reg = call[0][i]
reg.copy_from_part(new_regs[i], base, reg.vector_size())
base += reg.vector_size()
tape.start_new_basicblock()
def add_usage(self, *args):
pass
@@ -605,10 +664,10 @@ def ret_cisc(function):
from Compiler import types
if not (program.options.cisc and isinstance(args[0], types._register)):
return function(*args, **kwargs)
if isinstance(args[0], types._clear):
res_type = type(args[1])
else:
res_type = type(args[0])
for arg in args:
if isinstance(arg, types._secret):
res_type = type(arg)
break
res = res_type(size=args[0].size)
instruction(res, *args, **kwargs)
return res
@@ -642,6 +701,24 @@ def sfix_cisc(function):
copy_doc(wrapper, function)
return wrapper
bit_instructions = {}
def bit_cisc(function):
def wrapper(a, k, m, *args, **kwargs):
key = function, m
if key not in bit_instructions:
def instruction(*args, **kwargs):
res = function(*args[m:], **kwargs)
for x, y in zip(res, args):
x.mov(y, x)
instruction.__name__ = '%s(%d)' % (function.__name__, m)
bit_instructions[key] = cisc(instruction, m)
from Compiler.types import sintbit
res = [sintbit() for i in range(m)]
bit_instructions[function, m](*res, a, k, m, *args, **kwargs)
return res
return wrapper
class RegType(object):
""" enum-like static class for Register types """
ClearModp = 'c'
@@ -793,6 +870,23 @@ class String(ArgFormat):
def __str__(self):
return self.str
class VarString(ArgFormat):
@classmethod
def check(cls, arg):
if not isinstance(arg, str):
raise ArgumentError(arg, 'Argument is not string')
@classmethod
def encode(cls, arg):
return int_to_bytes(len(arg)) + list(bytearray(arg, 'ascii'))
def __init__(self, f):
length = IntArgFormat(f).i
self.str = str(f.read(length), 'ascii')
def __str__(self):
return self.str
ArgFormats = {
'c': ClearModpAF,
's': SecretModpAF,
@@ -810,6 +904,7 @@ ArgFormats = {
'long': LongArgFormat,
'p': PlayerNoAF,
'str': String,
'varstr': VarString,
}
def format_str_is_reg(format_str):
@@ -930,7 +1025,7 @@ class Instruction(object):
self.args += other.args
def expand_vector_args(self):
if self.is_vec():
if self.is_vec() and self.get_size() != 1:
for arg in self.args:
arg.create_vector_elements()
res = sum(list(zip(*self.args)), ())
@@ -939,7 +1034,7 @@ class Instruction(object):
return self.args
def expand_merged(self, skip):
return [self], 0
program.curr_block.instructions.append(self)
def get_new_args(self, size, subs):
new_args = []
@@ -956,6 +1051,10 @@ class Instruction(object):
new_args.append(arg)
return new_args
def copy(self, *args, **kwargs):
raise CompilerError("%s instruction not compatible with CISC-style "
"merging. Compile with '-O'." % type(self))
@staticmethod
def get_usage(args):
return {}
@@ -990,9 +1089,9 @@ class ParsedInstruction:
pass
read = lambda: struct.unpack('>I', f.read(4))[0]
full_code = struct.unpack('>Q', f.read(8))[0]
code = full_code % (1 << Instruction.code_length)
self.code = full_code % (1 << Instruction.code_length)
self.size = full_code >> Instruction.code_length
self.type = cls.reverse_opcodes[code]
self.type = cls.reverse_opcodes[self.code]
t = self.type
name = t.__name__
try:
@@ -1044,6 +1143,10 @@ class VectorInstruction(Instruction):
def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0]))
class Ciscable(Instruction):
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs), copying=True)
class DynFormatInstruction(Instruction):
__slots__ = []

View File

@@ -7,7 +7,8 @@ from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,c
from Compiler.instructions import *
from Compiler.util import tuplify,untuplify,is_zero
from Compiler.allocator import RegintOptimizer, AllocPool
from Compiler import instructions,instructions_base,comparison,program,util
from Compiler.program import Tape
from Compiler import instructions,instructions_base,comparison,util,types
import inspect,math
import random
import collections
@@ -42,7 +43,7 @@ def vectorize(function):
def set_instruction_type(function):
def instruction_typed_function(*args, **kwargs):
if len(args) > 0 and isinstance(args[0], program.Tape.Register):
if len(args) > 0 and isinstance(args[0], Tape.Register):
if args[0].is_gf2n:
instructions_base.set_global_instruction_type('gf2n')
else:
@@ -59,9 +60,15 @@ def set_instruction_type(function):
def _expand_to_print(val):
return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val)
def print_str(s, *args):
def print_str(s, *args, print_secrets=False):
""" Print a string, with optional args for adding
variables/registers with ``%s``. """
variables/registers with ``%s``.
:param s: format string
:param args: arguments (any type)
:param print_secrets: whether to output secret shares
"""
def print_plain_str(ss):
""" Print a plain string (no custom formatting options) """
ss = bytearray(ss, 'utf8')
@@ -84,11 +91,15 @@ def print_str(s, *args):
val = args[i].read()
else:
val = args[i]
if isinstance(val, program.Tape.Register):
if isinstance(val, Tape.Register):
if val.is_clear:
val.print_reg_plain()
elif print_secrets and isinstance(val, sint):
val.output()
else:
raise CompilerError('Cannot print secret value:', args[i])
raise CompilerError(
'Cannot print secret value %s, activate printing of shares with '
"'print_secrets=True'" % args[i])
elif isinstance(val, cfix):
val.print_plain()
elif isinstance(val, sfix) or isinstance(val, sfloat):
@@ -100,16 +111,17 @@ def print_str(s, *args):
else:
try:
val.output()
except AttributeError:
except (AttributeError, TypeError):
print_plain_str(str(val))
def print_ln(s='', *args):
def print_ln(s='', *args, **kwargs):
""" Print line, with optional args for adding variables/registers
with ``%s``. By default only player 0 outputs, but the ``-I``
command-line option changes that.
:param s: Python string with same number of ``%s`` as length of :py:obj:`args`
:param args: list of public values (regint/cint/int/cfix/cfloat/localint)
:param print_secrets: whether to output secret shares
Example:
@@ -117,7 +129,7 @@ def print_ln(s='', *args):
print_ln('a is %s.', a.reveal())
"""
print_str(str(s) + '\n', *args)
print_str(str(s) + '\n', *args, **kwargs)
def print_both(s, end='\n'):
""" Print line during compilation and execution. """
@@ -169,7 +181,7 @@ def print_str_if(cond, ss, *args):
def print_ln_to(player, ss, *args):
""" Print line at :py:obj:`player` only. Note that printing is
disabled by default except at player 0. Activate interactive mode
with `-I` to enable it for all players.
with `-I` or use `-OF .` to enable it for all players.
:param player: int
:param ss: Python string
@@ -295,8 +307,14 @@ def get_arg():
ldarg(res)
return res
def get_cmdline_arg(idx):
""" Return run-time command-line argument. """
res = regint()
cmdlinearg(res, regint.conv(idx))
return localint(res)
def make_array(l, t=None):
if isinstance(l, program.Tape.Register):
if isinstance(l, Tape.Register):
res = Array(len(l), t or type(l))
res[:] = l
else:
@@ -337,14 +355,15 @@ class Function:
# first call
type_args = collections.defaultdict(list)
for i,arg in enumerate(args):
type_args[get_reg_type(arg)].append(i)
if not isinstance(arg, types._vectorizable):
type_args[get_reg_type(arg)].append(i)
def wrapped_function(*compile_args):
base = get_arg()
bases = dict((t, regint.load_mem(base + i)) \
for i,t in enumerate(sorted(type_args,
key=lambda x:
x.reg_type)))
runtime_args = [None] * len(args)
runtime_args = list(args)
for t in sorted(type_args, key=lambda x: x.reg_type):
i = 0
for i_arg in type_args[t]:
@@ -407,13 +426,14 @@ def unmemorize(x):
class FunctionBlock(Function):
def on_first_call(self, wrapped_function):
p_return_address = get_tape().program.malloc(1, 'ci')
old_block = get_tape().active_basicblock
parent_node = get_tape().req_node
parent_node = old_block.req_node
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
block = get_tape().active_basicblock
block.alloc_pool = AllocPool()
block.alloc_pool = AllocPool(parent=block.alloc_pool)
del parent_node.children[-1]
self.node = get_tape().req_node
self.node = block.req_node
if get_program().verbose:
print('Compiling function', self.name)
result = wrapped_function(*self.compile_args)
@@ -423,7 +443,6 @@ class FunctionBlock(Function):
self.result = None
if get_program().verbose:
print('Done compiling function', self.name)
p_return_address = get_tape().program.malloc(1, 'ci')
get_tape().function_basicblocks[block] = p_return_address
return_address = regint.load_mem(p_return_address)
get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False))
@@ -446,7 +465,7 @@ class FunctionBlock(Function):
return_address.store_in_mem(p_return_address)
get_tape().start_new_basicblock(name='call-' + self.name)
get_tape().active_basicblock.set_return(old_block, self.last_sub_block)
get_tape().req_node.children.append(self.node)
get_block().req_node.children.append(self.node)
if self.result is not None:
return unmemorize(self.result)
@@ -654,7 +673,7 @@ def range_loop(loop_body, start, stop=None, step=None):
and isinstance(step, int):
# known loop count
if condition(start):
get_tape().req_node.children[-1].aggregator = \
get_block().req_node.children[-1].aggregator = \
lambda x: int(ceil(((stop - start) / step))) * x[0]
def for_range(start, stop=None, step=None):
@@ -680,9 +699,6 @@ def for_range(start, stop=None, step=None):
x.update(x + 1)
print_ln('%s', x.reveal())
Note that you cannot overwrite data structures such as
:py:class:`~Compiler.types.Array` in a loop. Use
:py:func:`~Compiler.types.Array.assign` instead.
"""
def decorator(loop_body):
range_loop(loop_body, start, stop, step)
@@ -791,7 +807,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
loop_rounds = n_loops // n_parallel \
if n_parallel < n_loops else 0
else:
loop_rounds = n_loops / n_parallel
loop_rounds = n_loops // n_parallel
def write_state_to_memory(r):
if use_array:
mem_state.assign(r)
@@ -821,6 +837,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
n_opt_loops_reg = regint(0)
n_opt_loops_inst = get_block().instructions[-1]
parent_block = get_block()
prevent_breaks = get_program().prevent_breaks
get_program().prevent_breaks = False
@while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0))
def _(i):
state = tuplify(initializer())
@@ -846,6 +864,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
loop_rounds = n_loops // my_n_parallel
blocks = get_tape().basicblocks
n_to_merge = 5
get_program().prevent_breaks = prevent_breaks
if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]:
# merge blocks started by if and do_while
def exit_elimination(block):
@@ -857,19 +876,22 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
merged.exit_condition = blocks[-1].exit_condition
merged.exit_block = blocks[-1].exit_block
assert parent_block is blocks[-n_to_merge]
assert blocks[-n_to_merge + 1] is \
get_tape().req_node.children[-1].nodes[0].blocks[0]
assert blocks[-n_to_merge + 1].req_node is \
get_block().req_node.children[-1].nodes[0]
for block in blocks[-n_to_merge + 1:]:
merged.instructions += block.instructions
exit_elimination(block)
block.purge(retain_usage=False)
del blocks[-n_to_merge + 1:]
del get_tape().req_node.children[-1]
del get_block().req_node.children[-1]
merged.children = []
RegintOptimizer().run(merged.instructions, get_program())
get_tape().active_basicblock = merged
else:
req_node = get_tape().req_node.children[-1].nodes[0]
if get_program().verbose:
print(n_opt_loops, 'repetitions')
assert not get_program().prevent_breaks
req_node = get_block().req_node.children[-1].nodes[0]
if util.is_constant(loop_rounds):
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
if isinstance(n_loops, int):
@@ -892,7 +914,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
return returner
return decorator
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={},
budget=None):
"""
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
threads, up to :py:obj:`n_parallel` in parallel per thread.
@@ -902,9 +925,10 @@ def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
"""
return map_reduce(n_threads, n_parallel, n_loops, \
lambda *x: [], lambda *x: [], thread_mem_req)
lambda *x: [], lambda *x: [], thread_mem_req,
budget=budget)
def for_range_opt_multithread(n_threads, n_loops):
def for_range_opt_multithread(n_threads, n_loops, budget=None):
"""
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
threads, in parallel up to an optimization budget per thread
@@ -943,7 +967,7 @@ def for_range_opt_multithread(n_threads, n_loops):
b = a + 1
"""
return for_range_multithread(n_threads, None, n_loops)
return for_range_multithread(n_threads, None, n_loops, budget=budget)
def multithread(n_threads, n_items=None, max_size=None):
"""
@@ -983,7 +1007,7 @@ def multithread(n_threads, n_items=None, max_size=None):
return wrapper
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
thread_mem_req={}, looping=True):
thread_mem_req={}, looping=True, budget=None):
assert(n_threads != 0)
if isinstance(n_loops, (list, tuple)):
split = n_loops
@@ -1025,10 +1049,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
state_type = type(state[0])
else:
state_type = type(state)
prevent_breaks = get_program().prevent_breaks
def f(inc):
get_program().prevent_breaks = prevent_breaks
base = args[get_arg()][0]
get_program().base_addresses[base] = None
if not util.is_constant(thread_rounds):
i = base / thread_rounds
i = base // thread_rounds
overhang = n_loops % n_threads
inc = i < overhang
base += inc.if_else(i, overhang)
@@ -1050,6 +1077,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if prog.curr_tape == prog.tapes[0]:
prog.n_running_threads = n_threads
if not util.is_zero(thread_rounds):
prog.prevent_breaks = False
tape = prog.new_tape(f, (0,), 'multithread')
for i in range(n_threads - remainder):
mem_state = make_array(initializer())
@@ -1058,6 +1086,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
args[remainder + i][1] = mem_state.address
thread_args.append((tape, remainder + i))
if remainder:
prog.prevent_breaks = False
tape1 = prog.new_tape(f, (1,), 'multithread1')
for i in range(remainder):
mem_state = make_array(initializer())
@@ -1066,10 +1095,12 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
args[i][1] = mem_state.address
thread_args.append((tape1, i))
prog.n_running_threads = None
prog.prevent_breaks = False
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
prog.free_later()
prog.prevent_breaks = prevent_breaks
if len(state):
if thread_rounds:
for i in range(n_threads - remainder):
@@ -1266,7 +1297,7 @@ def _link(pre, g):
if g:
from .types import _single
for name, var in pre.items():
if isinstance(var, (program.Tape.Register, _single, _vec)):
if isinstance(var, (Tape.Register, _single, _vec)):
new_var = g[name]
if util.is_constant_float(new_var):
raise CompilerError('cannot reassign constants in blocks')
@@ -1285,7 +1316,7 @@ def do_while(loop_fn, g=None):
return regint(0)
"""
scope = instructions.program.curr_block
parent_node = get_tape().req_node
parent_node = get_block().req_node
# possibly unknown loop count
get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \
name='begin-loop')
@@ -1334,9 +1365,10 @@ def else_then():
raise CompilerError('else block already defined')
# run the else block
state.if_exit_block = instructions.program.curr_block
state.req_child.add_node(get_tape(), 'else-block')
req_node = state.req_child.add_node(get_tape(), 'else-block')
instructions.program.curr_tape.start_new_basicblock(state.start_block, \
name='else-block')
name='else-block',
req_node=req_node)
state.else_block = instructions.program.curr_block
state.has_else = True
@@ -1545,6 +1577,23 @@ def accept_client_connection(port):
instructions.acceptclientconnection(res, regint.conv(port))
return res
def init_client_connection(host, port, my_id, relative_port=True):
""" Initiate connection to another party as client.
:param host: hostname
:param port: port base (int/regint/cint)
:param my_id: client id to use
:param relative_port: whether to add party number to port number
:returns: connection id
"""
if relative_port:
port = (port + get_player_id())._v
res = regint()
instructions.initclientconnection(
res, regint.conv(port), regint.conv(my_id), host)
return res
def break_point(name=''):
"""
Insert break point. This makes sure that all following code
@@ -1643,7 +1692,9 @@ def cint_cint_division(a, b, k, f):
return (sign_a * sign_b) * A
from Compiler.program import Program
def sint_cint_division(a, b, k, f, kappa):
@instructions_base.ret_cisc
def sint_cint_division(a, b, k, f, kappa, nearest=False):
"""
type(a) = sint, type(b) = cint
"""
@@ -1659,12 +1710,11 @@ def sint_cint_division(a, b, k, f, kappa):
B = absolute_b
W = w0
@for_range(1, theta)
def block(i):
A.link(TruncPr(A * W, 2*k, f, kappa))
temp = (B * W) >> f
W.link(two - temp)
B.link(temp)
for i in range(1, theta):
A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True)
temp = (B * W + 2 * (f - 1)) >> f
W = two - temp
B = temp
return (sign_a * sign_b) * A
def IntDiv(a, b, k, kappa=None):
@@ -1691,13 +1741,16 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
assert 2 * f > k - nearest
theta = int(ceil(log(k/3.5) / log(2)))
l_y = k + 3 * f - res_f
comparison.require_ring_size(
l_y, 'division (https://www.ifca.ai/pub/fc10/31_47.pdf)')
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()
l_y = k + 3 * f - res_f
y = a.extend(l_y) * w
y = y.round(l_y, f, kappa, nearest, signed=True)

View File

@@ -834,7 +834,7 @@ class Dense(DenseBase):
prod = MultiArray([N, self.d, self.d_out], sfix)
else:
prod = self.f_input
max_size = program.Program.prog.budget // self.d_out
max_size = get_program().budget // self.d_out
@multithread(self.n_threads, N, max_size)
def _(base, size):
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
@@ -1038,16 +1038,16 @@ class ElementWiseLayer(NoVariableLayer):
self.inputs = inputs
def f_part(self, base, size):
return self.f(self.X.get_part_vector(base, size))
return self.f(self.X.get_vector(base, size))
def f_prime_part(self, base, size):
return self.f_prime(self.Y.get_vector(base, size))
def _forward(self, batch=[0]):
n_per_item = reduce(operator.mul, self.X.sizes[1:])
@multithread(self.n_threads, len(batch), max(1, 1000 // n_per_item))
@multithread(self.n_threads, len(batch) * n_per_item)
def _(base, size):
self.Y.assign_part_vector(self.f_part(base, size), base)
self.Y.assign_vector(self.f_part(base, size), base)
if self.debug_output:
name = self
@@ -1095,9 +1095,9 @@ class Relu(ElementWiseLayer):
self.comparisons = MultiArray(shape, sint)
def f_part(self, base, size):
x = self.X.get_part_vector(base, size)
x = self.X.get_vector(base, size)
c = x > 0
self.comparisons.assign_part_vector(c, base)
self.comparisons.assign_vector(c, base)
return c.if_else(x, 0)
def f_prime_part(self, base, size):
@@ -1686,12 +1686,9 @@ class Conv2d(ConvBase):
padding_h, padding_w = self.padding
if self.use_conv2ds:
n_parts = max(1, round((self.n_threads or 1) / n_channels_out))
while len(batch) % n_parts != 0:
n_parts -= 1
print('Convolution in %d parts' % n_parts)
part_size = len(batch) // n_parts
@for_range_multithread(self.n_threads, 1, [n_parts, n_channels_out])
part_size = 1
@for_range_opt_multithread(self.n_threads,
[len(batch), n_channels_out])
def _(i, j):
inputs = self.X.get_slice_vector(
batch.get_part(i * part_size, part_size))
@@ -2507,6 +2504,10 @@ class Optimizer:
loss = self.layers[-1].average_loss(N)
res = (loss < stop_on_loss) * (loss >= -1)
self.stopped_on_loss.write(1 - res)
print_ln_if(
self.stopped_on_loss,
'aborting epoch because loss is outside range: %s',
loss)
return res
if self.print_losses:
print_ln()
@@ -2545,7 +2546,7 @@ class Optimizer:
loss = MemValue(sfix(0))
def f(start, batch_size, batch):
batch.assign_vector(regint.inc(batch_size, start))
self.forward(batch=batch)
self.forward(batch=batch, run_last=False)
part_truth = truth.get_part(start, batch_size)
n_correct.iadd(
self.layers[-1].reveal_correctness(batch_size, part_truth))
@@ -2644,7 +2645,7 @@ class Optimizer:
batch = Array.create_from(regint.inc(batch_size))
self.forward(batch=batch, training=True)
self.backward(batch=batch)
self.update(0, batch=batch)
self.update(0, batch=batch, i_batch=0)
return
@for_range(n_runs)
def _(i):
@@ -2697,6 +2698,8 @@ class Optimizer:
if depreciation:
self.gamma.imul(depreciation)
print_ln('reducing learning rate to %s', self.gamma)
print_ln_if(self.stopped_on_low_loss,
'aborting run because of low loss')
return 1 - self.stopped_on_low_loss
if self.missing_newline:
print_ln('')

View File

@@ -20,7 +20,7 @@ import sys
from functools import reduce
from Compiler.types import *
from Compiler.types import _secret
from Compiler.types import _secret, _register
from Compiler.library import *
from Compiler.program import Program
from Compiler import floatingpoint,comparison,permutation
@@ -77,8 +77,8 @@ class intBlock(Block):
self.lower, self.shift = \
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
Program.prog.security, True)
trunc = (self.value - self.lower) / self.shift
self.slice = trunc.mod2m(length, self.n_bits, False)
trunc = (self.value - self.lower).field_div(self.shift)
self.slice = trunc.mod2m(length, self.n_bits, signed=False)
self.upper = (trunc - self.slice) * self.shift
def get_slice(self):
total_length = sum(self.lengths)
@@ -89,13 +89,11 @@ class intBlock(Block):
res = []
remainder = self.slice
for length,start in zip(self.lengths[:-1],series(self.lengths)):
res.append(remainder.mod2m(length, total_length - start, False))
res.append(remainder.mod2m(length, total_length - start,
signed=False))
remainder -= res[-1]
if Program.prog.options.ring:
remainder = remainder.trunc_zeros(length,
total_length - start, False)
else:
remainder /= floatingpoint.two_power(length)
remainder = remainder.trunc_zeros(length,
total_length - start, False)
res.append(remainder)
return res
def set_slice(self, value):
@@ -208,23 +206,39 @@ def demux_list(x):
return res
def demux_array(x, res=None):
tmp = demux_matrix(x).array
if res:
try:
assert issubclass(x.value_type, _register)
res[:] = tmp[:]
except:
@for_range(len(res))
def _(i):
res[i] = tmp[i]
else:
res = tmp
return res
def demux_matrix(x, n_threads=None):
n = len(x)
if res is None:
res = Array(2**n, type(x[0]))
if n == 0:
return [1]
m = len(x[0])
t = type(x[0])
res = Matrix(2**n, m, type(x[0]))
if n == 1:
res[0] = 1 - x[0]
res[1] = x[0]
else:
a = Array(2**(n//2), type(x[0]))
a = Matrix(2**(n//2), m, type(x[0]))
a.assign(demux(x[:n//2]))
b = Array(2**(n-n//2), type(x[0]))
b = Matrix(2**(n-n//2), m, type(x[0]))
b.assign(demux(x[n//2:]))
@for_range_multithread(get_n_threads(len(res)), \
max(1, n_parallel // len(b)), len(a))
@for_range_opt_multithread(n_threads, len(a))
def f(i):
@for_range_parallel(n_parallel, len(b))
@for_range_opt(len(b))
def f(j):
res[j * len(a) + i] = a[i] * b[j]
res[j * len(a) + i][:] = a[i][:] * b[j][:]
return res
def get_first_one(x):
@@ -1717,7 +1731,8 @@ class BinaryORAM:
def OptimalORAM(size,*args,**kwargs):
""" Create an ORAM instance suitable for the size based on
experiments.
experiments. This uses the approach by `Keller and Scholl
<https://eprint.iacr.org/2014/137>`_.
:param size: number of elements
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /

View File

@@ -11,6 +11,7 @@ import os
import re
import sys
import hashlib
import random
from collections import defaultdict, deque
from functools import reduce
@@ -74,8 +75,8 @@ class Program(object):
"""A program consists of a list of tapes representing the whole
computation.
When compiling an :file:`.mpc` file, the single instances is
available as :py:obj:`program` in order. When compiling directly
When compiling an :file:`.mpc` file, the single instance is
available as :py:obj:`program`. When compiling directly
from Python code, an instance has to be created before running any
instructions.
"""
@@ -89,6 +90,7 @@ class Program(object):
self.name = name
self.init_names(args)
self._security = 40
self.used_security = 0
self.prime = None
self.tapes = []
if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1:
@@ -113,7 +115,8 @@ class Program(object):
if not self.bit_length:
self.bit_length = 64
print("Default bit length for compilation:", self.bit_length)
print("Default security parameter for compilation:", self.security)
if not (options.binary or options.garbled):
print("Default security parameter for compilation:", self._security)
self.galois_length = int(options.galois)
if self.verbose:
print("Galois length:", self.galois_length)
@@ -185,10 +188,13 @@ class Program(object):
self.relevant_opts = set()
self.n_running_threads = None
self.input_files = {}
self.base_addresses = {}
self.base_addresses = util.dict_by_id()
self._protect_memory = False
self.mem_protect_stack = []
self._always_active = True
self.active = True
self.prevent_breaks = False
self.cisc_to_function = True
if not self.options.cisc:
self.options.cisc = not self.options.optimize_hard
@@ -249,7 +255,7 @@ class Program(object):
else:
raise CompilerError(
"found none of the potential input files: " +
", ".join("'%s'" % x for x in [args[0]] + infiles))
", ".join("'%s'" % x for x in infiles))
"""
self.name is input file name (minus extension) + any optional arguments.
Used to generate output filenames
@@ -352,7 +358,8 @@ class Program(object):
)
self.curr_tape.start_new_basicblock(name="post-run_tape")
for arg in args:
self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree)
self.curr_block.req_node.children.append(
self.tapes[arg[0]].req_tree)
return thread_numbers
def join_tape(self, thread_number):
@@ -400,6 +407,7 @@ class Program(object):
sch_file.write("lgp:%s" % req)
sch_file.write("\n")
sch_file.write("opts: %s\n" % " ".join(self.relevant_opts))
sch_file.write("sec:%d\n" % self.used_security)
sch_file.close()
h = hashlib.sha256()
for tape in self.tapes:
@@ -433,7 +441,7 @@ class Program(object):
"""The basic block that is currently being created."""
return self.curr_tape.active_basicblock
def malloc(self, size, mem_type, reg_type=None, creator_tape=None):
def malloc(self, size, mem_type, reg_type=None, creator_tape=None, use_freed=True):
"""Allocate memory from the top"""
if not isinstance(size, int):
raise CompilerError("size must be known at compile time")
@@ -456,7 +464,7 @@ class Program(object):
else:
raise CompilerError("cannot allocate memory " "outside main thread")
blocks = self.free_mem_blocks[mem_type]
addr = blocks.pop(size)
addr = blocks.pop(size) if use_freed else None
if addr is not None:
self.saved += size
else:
@@ -469,11 +477,13 @@ class Program(object):
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
if single_size:
from .library import get_thread_number, runtime_error_if
bak = self.curr_tape.active_basicblock
self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
tn = get_thread_number()
runtime_error_if(tn > self.n_running_threads, "malloc")
res = addr + single_size * (tn - 1)
self.base_addresses[str(res)] = addr
self.curr_tape.active_basicblock = bak
self.base_addresses[res] = addr
return res
else:
return addr
@@ -482,7 +492,7 @@ class Program(object):
"""Free memory"""
now = True
if not util.is_constant(addr):
addr = self.base_addresses[str(addr)]
addr = self.base_addresses[addr]
now = self.curr_tape == self.tapes[0]
size, pool = self.allocated_mem_blocks[addr, mem_type]
if self.curr_block.alloc_pool is not pool:
@@ -524,7 +534,8 @@ class Program(object):
self.public_input_file.close()
def finalize_memory(self):
self.curr_tape.start_new_basicblock(None, "memory-usage")
self.curr_tape.start_new_basicblock(None, "memory-usage",
req_node=self.curr_tape.req_tree)
# reset register counter to 0
if not self.options.noreallocate:
self.curr_tape.init_registers()
@@ -575,6 +586,7 @@ class Program(object):
def security(self):
"""The statistical security parameter for non-linear
functions."""
self.used_security = max(self.used_security, self._security)
return self._security
@security.setter
@@ -701,6 +713,13 @@ class Program(object):
""" Enable or disable memory protection. """
self._protect_memory = status
def open_memory_scope(self, key=None):
self.mem_protect_stack.append(self._protect_memory)
self.protect_memory(key or object())
def close_memory_scope(self):
self.protect_memory(self.mem_protect_stack.pop())
def use_cisc(self):
return self.options.cisc and (not self.prime or self.rabbit_gap()) \
and not self.options.max_parallel_open
@@ -725,7 +744,7 @@ class Program(object):
self._always_active = False
@staticmethod
def read_tapes(schedule):
def read_schedule(schedule):
m = re.search(r"([^/]*)\.mpc", schedule)
if m:
schedule = m.group(1)
@@ -733,7 +752,7 @@ class Program(object):
schedule = "Programs/Schedules/%s.sch" % schedule
try:
lines = open(schedule).readlines()
return open(schedule).readlines()
except FileNotFoundError:
print(
"%s not found, have you compiled the program?" % schedule,
@@ -741,9 +760,25 @@ class Program(object):
)
sys.exit(1)
@classmethod
def read_tapes(cls, schedule):
lines = cls.read_schedule(schedule)
for tapename in lines[2].split(" "):
yield tapename.strip().split(":")[0]
@classmethod
def read_n_threads(cls, schedule):
return int(cls.read_schedule(schedule)[0])
@classmethod
def read_domain_size(cls, schedule):
from Compiler.instructions import reqbl_class
tapename = cls.read_schedule(schedule)[2].strip().split(":")[0]
for inst in Tape.read_instructions(tapename):
if inst.code == reqbl_class.code:
bl = inst.args[0]
return (abs(bl.i) + 63) // 64 * 8
class Tape:
"""A tape contains a list of basic blocks, onto which instructions are added."""
@@ -755,13 +790,12 @@ class Tape:
self.init_names(name)
self.init_registers()
self.req_tree = self.ReqNode(name)
self.req_node = self.req_tree
self.basicblocks = []
self.purged = False
self.block_counter = 0
self.active_basicblock = None
self.old_allocated_mem = program.allocated_mem.copy()
self.start_new_basicblock()
self.start_new_basicblock(req_node=self.req_tree)
self._is_empty = False
self.merge_opens = True
self.if_states = []
@@ -774,7 +808,8 @@ class Tape:
self.warned_about_mem = False
class BasicBlock(object):
def __init__(self, parent, name, scope, exit_condition=None):
def __init__(self, parent, name, scope, exit_condition=None,
req_node=None):
self.parent = parent
self.instructions = []
self.name = name
@@ -794,6 +829,8 @@ class Tape:
self.n_to_merge = 0
self.rounds = Tape.ReqNum()
self.warn_about_mem = parent.program.warn_about_mem[-1]
self.req_node = req_node
self.used_from_scope = set()
def __len__(self):
return len(self.instructions)
@@ -860,17 +897,25 @@ class Tape:
req_node.num += self.rounds
def expand_cisc(self):
new_instructions = []
if self.parent.program.options.keep_cisc is not None:
skip = ["LTZ", "Trunc"]
skip = ["LTZ", "Trunc", "EQZ"]
skip += self.parent.program.options.keep_cisc.split(",")
else:
skip = []
tape = self.parent
tape.start_new_basicblock(scope=self.scope, req_node=self.req_node,
name="cisc")
start_block = tape.basicblocks[-1]
start_block.alloc_pool = self.alloc_pool
for inst in self.instructions:
new_inst, n_rounds = inst.expand_merged(skip)
new_instructions.extend(new_inst)
self.n_rounds += n_rounds
self.instructions = new_instructions
inst.expand_merged(skip)
self.instructions = tape.active_basicblock.instructions
if start_block == tape.basicblocks[-1]:
res = self
else:
res = start_block
tape.basicblocks[-1] = self
return res
def __str__(self):
return self.name
@@ -885,7 +930,8 @@ class Tape:
self._is_empty = len(self.basicblocks) == 0
return self._is_empty
def start_new_basicblock(self, scope=False, name=""):
def start_new_basicblock(self, scope=False, name="", req_node=None):
assert not self.program.prevent_breaks
if self.program.verbose and self.active_basicblock and \
self.program.allocated_mem != self.old_allocated_mem:
print("New allocated memory in %s " % self.active_basicblock.name,
@@ -900,10 +946,12 @@ class Tape:
scope = self.active_basicblock
suffix = "%s-%d" % (name, self.block_counter)
self.block_counter += 1
sub = self.BasicBlock(self, self.name + "-" + suffix, scope)
if req_node is None:
req_node = self.active_basicblock.req_node
sub = self.BasicBlock(self, self.name + "-" + suffix, scope,
req_node=req_node)
self.basicblocks.append(sub)
self.active_basicblock = sub
self.req_node.add_block(sub)
# print 'Compiling basic block', sub.name
def init_registers(self):
@@ -1054,12 +1102,20 @@ class Tape:
print("Re-allocating...")
allocator = al.StraightlineAllocator(REG_MAX, self.program)
# make addresses available in functions
for addr in self.program.base_addresses:
if addr.program == self and self.basicblocks:
allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool)
seen = set()
def alloc(block):
allocator.update_usage(block.alloc_pool)
for reg in sorted(
block.used_from_scope, key=lambda x: (x.reg_type, x.i)
):
allocator.alloc_reg(reg, block.alloc_pool)
seen.add(block)
def alloc_loop(block):
left = deque([block])
@@ -1067,7 +1123,8 @@ class Tape:
block = left.popleft()
alloc(block)
for child in block.children:
left.append(child)
if child not in seen:
left.append(child)
allocator.old_pool = None
for i, block in enumerate(reversed(self.basicblocks)):
@@ -1101,6 +1158,8 @@ class Tape:
# offline data requirements
if self.program.verbose:
print("Compile offline data requirements...")
for block in self.basicblocks:
block.req_node.add_block(block)
self.req_num = self.req_tree.aggregate()
if self.program.verbose:
print("Tape requires", self.req_num)
@@ -1160,8 +1219,24 @@ class Tape:
@unpurged
def expand_cisc(self):
mapping = {None: None}
blocks = self.basicblocks[:]
self.basicblocks = []
for block in blocks:
expanded = block.expand_cisc()
mapping[block] = expanded
for block in self.basicblocks:
block.expand_cisc()
if block not in mapping:
mapping[block] = block
for block in self.basicblocks:
block.exit_block = mapping[block.exit_block]
if block.exit_block is not None:
assert block.exit_block in self.basicblocks
if block.previous_block and mapping[block] != block:
mapping[block].previous_block = block.previous_block
mapping[block].sub_block = block.sub_block
block.previous_block = None
del block.sub_block
@unpurged
def _get_instructions(self):
@@ -1320,27 +1395,38 @@ class Tape:
return repr(dict(self))
class ReqNode(object):
__slots__ = ["num", "children", "name", "blocks"]
__slots__ = ["num", "_children", "name", "blocks", "aggregated"]
def __init__(self, name):
self.children = []
self._children = []
self.name = name
self.blocks = []
self.aggregated = None
@property
def children(self):
self.aggregated = None
return self._children
def aggregate(self, *args):
if self.aggregated is not None:
return self.aggregated
self.num = Tape.ReqNum()
for block in self.blocks:
block.add_usage(self)
res = reduce(
lambda x, y: x + y.aggregate(self.name), self.children, self.num
)
self.aggregated = res
return res
def increment(self, data_type, num=1):
self.num[data_type] += num
self.aggregated = None
def add_block(self, block):
self.blocks.append(block)
self.aggregated = None
class ReqChild(object):
__slots__ = ["aggregator", "nodes", "parent"]
@@ -1369,18 +1455,18 @@ class Tape:
def add_node(self, tape, name):
new_node = Tape.ReqNode(name)
self.nodes.append(new_node)
tape.req_node = new_node
return new_node
def open_scope(self, aggregator, scope=False, name=""):
child = self.ReqChild(aggregator, self.req_node)
self.req_node.children.append(child)
child.add_node(self, "%s-%d" % (name, len(self.basicblocks)))
self.start_new_basicblock(name=name)
req_node = self.active_basicblock.req_node
child = self.ReqChild(aggregator, req_node)
req_node.children.append(child)
node = child.add_node(self, "%s-%d" % (name, len(self.basicblocks)))
self.start_new_basicblock(name=name, req_node=node)
return child
def close_scope(self, outer_scope, parent_req_node, name):
self.req_node = parent_req_node
self.start_new_basicblock(outer_scope, name)
self.start_new_basicblock(outer_scope, name, req_node=parent_req_node)
def require_bit_length(self, bit_length, t="p"):
if t == "p":
@@ -1553,7 +1639,7 @@ class Tape:
diff_block = isinstance(other, Tape.Register) and self.block != other.block
other = type(self)(other)
if not diff_block:
self.program.start_new_basicblock()
self.program.start_new_basicblock(name="update")
if self.program != other.program:
raise CompilerError(
'cannot update register with one from another thread')
@@ -1575,6 +1661,8 @@ class Tape:
)
def __str__(self):
return self.reg_type + str(self.i)
return self.reg_type + str(self.i) + \
("(%d)" % self.size if self.size is not None and self.size > 1
else "")
__repr__ = __str__

View File

@@ -20,7 +20,7 @@ correct signature.
Basic types
-----------
All basic can be used as vectors, that is one instance representing
All basic types can be used as vectors, that is one instance representing
several values, with all operations being executed element-wise. For
example, the following computes ten multiplications of integers input
by party 0 and 1::
@@ -588,7 +588,7 @@ class _secret_structure(_structure):
@classmethod
def input_tensor_via(cls, player, content=None, shape=None, binary=True,
one_hot=False):
one_hot=False, skip_input=False, n_bytes=None):
"""
Input tensor-like data via a player. This overwrites the input
file for the relevant player. The following returns an
@@ -630,7 +630,10 @@ class _secret_structure(_structure):
else:
t = numpy.single
else:
t = numpy.int64
if n_bytes == 1:
t = numpy.int8
else:
t = numpy.int64
if one_hot:
content = numpy.eye(content.max() + 1)[content]
content = content.astype(t)
@@ -666,9 +669,10 @@ class _secret_structure(_structure):
if requested_shape is not None and \
list(shape) != list(requested_shape):
raise CompilerError('content contradicts shape')
res = cls.Tensor(shape)
res.input_from(player, binary=binary)
return res
if not skip_input:
res = cls.Tensor(shape)
res.input_from(player, binary=binary, n_bytes=n_bytes)
return res
class _vec(Tape._no_truth):
def link(self, other):
@@ -681,6 +685,13 @@ class _register(Tape.Register, _number, _structure):
def n_elements():
return 1
@classmethod
def new_vector(cls, size):
return cls(size=size)
def vector_size(self):
return self.size
@vectorized_classmethod
def conv(cls, val):
if isinstance(val, MemValue):
@@ -707,7 +718,7 @@ class _register(Tape.Register, _number, _structure):
except AttributeError:
try:
return type(val)(cls.hard_conv(v) for v in val)
except TypeError:
except (TypeError, CompilerError):
pass
return cls(val)
@@ -756,11 +767,11 @@ class _register(Tape.Register, _number, _structure):
return sum(cls.conv(b) << i for i,b in enumerate(bits))
@classmethod
def malloc(cls, size, creator_tape=None):
def malloc(cls, size, creator_tape=None, **kwargs):
""" Allocate memory (statically).
:param size: compile-time (int) """
return program.malloc(size, cls, creator_tape=creator_tape)
return program.malloc(size, cls, creator_tape=creator_tape, **kwargs)
@classmethod
def free(cls, addr):
@@ -783,7 +794,11 @@ class _register(Tape.Register, _number, _structure):
else:
self[i].load_other(x)
elif val is not None:
self.load_other(val)
try:
self.load_other(val)
except:
raise CompilerError(
"cannot convert '%s' to '%s'" % (type(val), type(self)))
def _new_by_number(self, i, size=1):
res = type(self)(size=size)
@@ -943,16 +958,15 @@ class _clear(_arithmetic_register):
return self.clear_op(other, subc, subcfi, True)
__rsub__.__doc__ = __sub__.__doc__
def __truediv__(self, other):
def field_div(self, other):
""" Field division of public values. Not available for
computation modulo a power of two.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, divc, divci)
def __rtruediv__(self, other):
return self.coerce_op(other, divc, True)
__rtruediv__.__doc__ = __truediv__.__doc__
try:
return other._rfield_div(self)
except AttributeError:
return self.clear_op(other, divc, divci)
def __and__(self, other):
""" Bit-wise AND of public values.
@@ -1107,6 +1121,20 @@ class cint(_clear, _int):
def __rfloordiv__(self, other):
return self.coerce_op(other, floordivc, True)
def __truediv__(self, other):
""" Clear fixed-point division.
:param other: any compatible type """
if isinstance(other, cint):
return other.__rtruediv__(self)
try:
return cfix._new(self) / cfix._new(cint(other))
except:
return NotImplemented
def __rtruediv__(self, other):
return cfix._new(other) / cfix._new(self)
@vectorize
def less_than(self, other, bit_length):
""" Clear comparison for particular bit length.
@@ -1348,6 +1376,11 @@ class cgf2n(_clear, _gf2n):
""" Identity. """
return self
__truediv__ = _clear.field_div
def __rtruediv__(self, other):
return self.coerce_op(other, divc, True)
@vectorize
def __invert__(self):
""" Clear bit-wise inversion. """
@@ -1594,8 +1627,14 @@ class regint(_register, _int):
return self.int_op(other, divint, True)
__rfloordiv__.__doc__ = __floordiv__.__doc__
__truediv__ = __floordiv__
__rtruediv__ = __rfloordiv__
def __truediv__(self, other):
if isinstance(other, _gf2n):
return NotImplemented
else:
return cint(self) / other
def __rtruediv__(self, other):
return other / cint(self)
def __mod__(self, other):
""" Clear modulo computation.
@@ -1603,7 +1642,7 @@ class regint(_register, _int):
:param other: regint/cint/int """
if util.is_constant(other) and other >= 2 ** 64:
return self
return self - (self / other) * other
return self - (self // other) * other
def __rmod__(self, other):
""" Clear modulo computation.
@@ -1661,7 +1700,7 @@ class regint(_register, _int):
def __rshift__(self, other):
if isinstance(other, int):
return self / 2**other
return self // 2**other
else:
return self.cint_op(other, operator.rshift)
@@ -1793,6 +1832,9 @@ class localint(Tape._no_truth):
__eq__ = lambda self, other: localint(self._v == other)
__ne__ = lambda self, other: localint(self._v != other)
__add__ = lambda self, other: localint(self._v + other)
__radd__ = lambda self, other: localint(self._v + other)
class personal(Tape._no_truth):
""" Value known to one player. Supports operations with public
values and personal values known to the same player. Can be used
@@ -1812,7 +1854,7 @@ class personal(Tape._no_truth):
self._v = value
@classmethod
def read_int(cls, player):
def read_int(cls, player, n_bytes=None):
""" Read integer from
``Player-Data/Input-Binary-P<player>-<threadnum>`` only on
party :py:obj:`player`.
@@ -1822,7 +1864,7 @@ class personal(Tape._no_truth):
"""
tmp = cint()
fixinput(player, tmp, 0, 0)
fixinput(player, tmp, n_bytes or 0, 0)
return cls(player, tmp)
@classmethod
@@ -2229,7 +2271,7 @@ class _secret(_arithmetic_register, _secret_structure):
return self.secret_op(other, subs, submr, subsfi, True)
__rsub__.__doc__ = __sub__.__doc__
def __truediv__(self, other):
def field_div(self, other):
""" Secret field division.
:param other: any compatible type """
@@ -2237,13 +2279,12 @@ class _secret(_arithmetic_register, _secret_structure):
one = self.clear_type(1, size=other.size)
except AttributeError:
one = self.clear_type(1)
return self * (one / other)
return self * one.field_div(other)
@vectorize
def __rtruediv__(self, other):
def _rfield_div(self, other):
a,b = self.get_random_inverse()
return other * a / (a * self).reveal()
__rtruediv__.__doc__ = __truediv__.__doc__
return other * a.field_div((a * self).reveal())
@set_instruction_type
@vectorize
@@ -2311,8 +2352,8 @@ class sint(_secret, _int):
The following operations work as expected in the computation
domain (modulo a prime or a power of two): ``+, -, *``. ``/``
denotes the field division modulo a prime. It will reveal if the
divisor is zero. Comparisons operators (``==, !=, <, <=, >, >=``)
denotes a fixed-point division.
Comparisons operators (``==, !=, <, <=, >, >=``)
assume that the element in the computation domain represents a
signed integer in a restricted range, see below. The same holds
for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and
@@ -2398,14 +2439,14 @@ class sint(_secret, _int):
return res
@vectorized_classmethod
def get_input_from(cls, player, binary=False):
def get_input_from(cls, player, binary=False, n_bytes=None):
""" Secret input.
:param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
if binary:
return cls(personal.read_int(player))
return cls(personal.read_int(player, n_bytes=n_bytes))
else:
res = cls()
inputmixed('int', res, player)
@@ -2540,6 +2581,12 @@ class sint(_secret, _int):
"""
writesockets(client_id, message_type, values[0].size, *values)
@vectorize
def write_fully_to_socket(self, client_id,
message_type=ClientMessageType.NoType):
""" Send full secret to socket """
writesockets(client_id, message_type, self.size, self)
@vectorize
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send only share to socket """
@@ -2557,7 +2604,9 @@ class sint(_secret, _int):
@classmethod
def read_from_file(cls, start, n_items):
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
""" Read shares from
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning (int/regint/cint)
:param n_items: number of items (int)
@@ -2572,7 +2621,8 @@ class sint(_secret, _int):
@staticmethod
def write_to_file(shares, position=None):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
(appending at the end). See :ref:`this section <persistence>`
for details on the data format.
:param shares: (list or iterable of sint)
:param position: start position (int/regint/cint),
@@ -2641,7 +2691,7 @@ class sint(_secret, _int):
res = sintbit()
comparison.LTZ(res, self - other,
(bit_length or program.bit_length) + 1,
security or program.security)
security)
return res
@read_mem_value
@@ -2651,7 +2701,7 @@ class sint(_secret, _int):
res = sintbit()
comparison.LTZ(res, other - self,
(bit_length or program.bit_length) + 1,
security or program.security)
security)
return res
@read_mem_value
@@ -2670,7 +2720,7 @@ class sint(_secret, _int):
def __eq__(self, other, bit_length=None, security=None):
return sintbit.conv(
floatingpoint.EQZ(self - other, bit_length or program.bit_length,
security or program.security))
security))
@read_mem_value
@type_comp
@@ -2709,7 +2759,6 @@ class sint(_secret, _int):
:param bit_length: bit length of input (default: global bit length)
"""
bit_length = bit_length or program.bit_length
security = security or program.security
if isinstance(m, int):
if m == 0:
return 0
@@ -2737,7 +2786,7 @@ class sint(_secret, _int):
:param bit_length: bit length of input (default: global bit length)
"""
return floatingpoint.Pow2(self, bit_length or program.bit_length, \
security or program.security)
security)
def __lshift__(self, other, bit_length=None, security=None):
""" Secret left shift.
@@ -2756,7 +2805,6 @@ class sint(_secret, _int):
:param bit_length: bit length of input (default: global bit length)
"""
bit_length = bit_length or program.bit_length
security = security or program.security
if isinstance(other, int):
if other == 0:
return self
@@ -2783,7 +2831,7 @@ class sint(_secret, _int):
""" Secret right shift.
:param other: secret or public integer (sint/cint/regint/int) of globale bit length if secret """
return floatingpoint.Trunc(other, program.bit_length, self, program.security)
return floatingpoint.Trunc(other, program.bit_length, self)
@vectorize
def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False):
@@ -2791,7 +2839,7 @@ class sint(_secret, _int):
if bit_length == 0:
return []
bit_length = bit_length or program.bit_length
assert program.security == security or program.security
program.non_linear.check_security(security)
return program.non_linear.bit_dec(self, bit_length, bit_length,
maybe_mixed)
@@ -2815,7 +2863,6 @@ class sint(_secret, _int):
:param kappa: statistical security parameter (int)
:param nearest: bool
:param signed: bool """
kappa = kappa or program.security
secret = isinstance(m, sint)
if nearest:
if secret:
@@ -2830,6 +2877,20 @@ class sint(_secret, _int):
def Norm(self, k, f, kappa=None, simplex_flag=False):
return library.Norm(self, k, f, kappa, simplex_flag)
def __truediv__(self, other):
""" Secret fixed-point division.
:param other: any compatible type """
if isinstance(other, sint):
return other.__rtruediv__(self)
try:
return sfix._new(self) / cfix._new(cint(other), f=sfix.f, k=sfix.k)
except:
return NotImplemented
def __rtruediv__(self, other):
return sfix._new(other) / sfix._new(self)
@vectorize
def int_div(self, other, bit_length=None, security=None):
""" Secret integer division. Note that the domain bit length
@@ -2839,7 +2900,7 @@ class sint(_secret, _int):
:param bit_length: bit length of input (default: global bit length)
"""
k = bit_length or program.bit_length
kappa = security or program.security
kappa = security
tmp = library.IntDiv(self, other, k, kappa)
res = type(self)()
comparison.Trunc(res, tmp, 2 * k, k, kappa, True)
@@ -2963,6 +3024,7 @@ class sint(_secret, _int):
gensecshuffle(res, n)
return res
@read_mem_value
def secure_permute(self, shuffle, unit_size=1, reverse=False):
res = sint(size=self.size)
applyshuffle(res, self, unit_size, shuffle, reverse)
@@ -3005,6 +3067,21 @@ class sint(_secret, _int):
def copy_from_part(self, source, base, size):
picks(self, source, base, 1)
def get_reverse_vector(self):
res = type(self)(size=self.size)
picks(res, self, self.size - 1, -1)
return res
def get_vector(self, base=0, size=None):
if size is None:
size = len(self) - base
if base == 0 and size == len(self):
return self
assert base + size <= len(self)
res = type(self)(size=size)
picks(res, self, base, 1)
return res
@classmethod
def concat(cls, parts):
parts = list(parts)
@@ -3013,6 +3090,10 @@ class sint(_secret, _int):
concats(res, *args)
return res
@vectorize
def output(self):
print_reg_plains(self)
class sintbit(sint):
""" :py:class:`sint` holding a bit, supporting binary operations
(``&, |, ^``). """
@@ -3137,6 +3218,7 @@ class sgf2n(_secret, _gf2n):
""" Store in memory by public address. """
self._store_in_mem(address, gstms, gstmsi)
@vectorize_init
def __init__(self, val=None, size=None):
super(sgf2n, self).__init__('sg', val=val, size=size)
@@ -3144,6 +3226,9 @@ class sgf2n(_secret, _gf2n):
""" Identity. """
return self
__truediv__ = _secret.field_div
__rtruediv__ = _secret._rfield_div
@vectorize
def __invert__(self):
""" Secret bit-wise inversion. """
@@ -3637,6 +3722,7 @@ class sgf2nint(_bitint, sgf2n):
raise CompilerError('Invalid signed %d-bit integer: %d' % \
(self.n_bits, other))
@vectorize
def load_other(self, other):
if isinstance(other, sgf2nint):
gmovs(self, self.compose(other.bit_decompose(self.n_bits)))
@@ -4200,6 +4286,15 @@ class _single(_number, _secret_structure):
cls.int_type.write_shares_to_socket(
client_id, [x.v for x in values], message_type)
@vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
return util.untuplify([cls._new(x) for x in util.tuplify(
cls.int_type.read_from_socket(client_id, n))])
@classmethod
def write_to_socket(cls, client_id, values):
cls.int_type.write_to_socket(client_id, [x.v for x in values])
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
@@ -4273,7 +4368,8 @@ class _single(_number, _secret_structure):
@classmethod
def read_from_file(cls, *args, **kwargs):
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
Precision must be the same as when storing.
Precision must be the same as when storing. See :ref:`this
section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning
(int/regint/cint)
@@ -4288,7 +4384,8 @@ class _single(_number, _secret_structure):
@classmethod
def write_to_file(cls, shares, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data``.
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param shares: (list or iterable of sfix)
:param position: start position (int/regint/cint),
@@ -4588,7 +4685,8 @@ class _fix(_single):
nearest=self.round_nearest)
elif isinstance(other, cfix):
v = library.sint_cint_division(self.v, other.v, self.k, self.f,
self.kappa)
self.kappa,
nearest=self.round_nearest)
else:
raise TypeError('Incompatible fixed point types in division')
return self._new(v, k=self.k, f=self.f)
@@ -4656,7 +4754,7 @@ class sfix(_fix):
default_type = sint
@vectorized_classmethod
def get_input_from(cls, player, binary=False):
def get_input_from(cls, player, binary=False, n_bytes=None):
""" Secret fixed-point input.
:param player: public (regint/cint/int)
@@ -4770,6 +4868,13 @@ class sfix(_fix):
def multipliable(v, k, f, size):
return cfix._new(cint.conv(v, size=size), k, f)
def dot(self, other):
""" Dot product with :py:class:`sint:`. """
if isinstance(other, sint):
return self._new(sint.dot_product(self.v, other), k=self.k, f=self.f)
else:
raise NotImplementedError()
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
@@ -4793,6 +4898,12 @@ class sfix(_fix):
def sum(self):
return self._new(self.v.sum())
def get_reverse_vector(self):
return self._new(self.v.get_reverse_vector(), k=self.k, f=self.f)
def get_vector(self, *args, **kwargs):
return self._new(self.v.get_vector(*args, **kwargs), k=self.k, f=self.f)
@classmethod
def concat(cls, parts):
parts = list(parts)
@@ -5486,6 +5597,11 @@ class sfloat(_number, _secret_structure):
self.z.update(other.z)
self.s.update(other.s)
def for_mux(self, other):
other = self.coerce(other)
f = lambda x: type(self)(*x)
return f, sint(list(self)), sint(list(other))
class cfloat(Tape._no_truth):
""" Helper class for printing revealed sfloats. """
__slots__ = ['v', 'p', 'z', 's', 'nan']
@@ -5763,6 +5879,9 @@ class Array(_vectorizable):
tmp.store_in_mem(address)
def __len__(self):
if self.length is None:
raise CompilerError('this functionality is not available '
'for variable-length arrays')
return self.length
def total_size(self):
@@ -5887,6 +6006,19 @@ class Array(_vectorizable):
addresses = self.get_slice_addresses(slice)
vector.store_in_mem(addresses)
def permute(self, permutation, reverse=False, n_threads=None):
""" Public permutation.
:param permutation: cleartext :py:class`Array` containing number
in :math:`[0,n-1]` where :math:`n` is the length of this array
:param reverse: whether to apply the inverse of the permutation
"""
if reverse:
self.assign_slice_vector(permutation, self.get_vector())
else:
self.assign_vector(self.get_slice_vector(permutation))
def expand_to_vector(self, index, size):
""" Create vector from single entry.
@@ -5930,7 +6062,8 @@ class Array(_vectorizable):
def read_from_file(self, start):
""" Read content from ``Persistence/Transactions-P<playerno>.data``.
Precision must be the same as when storing if applicable.
Precision must be the same as when storing if applicable. See
:ref:`this section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning
(int/regint/cint)
@@ -5943,13 +6076,36 @@ class Array(_vectorizable):
def write_to_file(self, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data``.
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param position: start position (int/regint/cint),
defaults to end of file
"""
self.value_type.write_to_file(list(self), position)
def read_from_socket(self, socket, debug=False):
""" Read content from socket. """
if debug:
library.print_str('reading %s...' % self)
# hard-coded budget for interopability
@library.multithread(None, len(self), max_size=10 ** 6)
def _(base, size):
self.assign_vector(
self.value_type.read_from_socket(socket, size=size), base=base)
if debug:
library.print_ln('done')
def write_to_socket(self, socket, debug=False):
""" Write content to socket. """
if debug:
library.print_ln('writing %s' % self)
# hard-coded budget for interopability
@library.multithread(None, len(self), max_size=10 ** 6)
def _(base, size):
self.value_type.write_to_socket(
socket, [self.get_vector(base=base, size=size)])
def __add__(self, other):
""" Vector addition.
@@ -6254,9 +6410,13 @@ class SubMultiArray(_vectorizable):
""" Assign container to content. Not implemented for floating-point.
:param other: container of matching size and type """
if self.value_type.n_elements() > 1:
assert self.sizes == other.sizes
self.assign_vector(other.get_vector())
try:
if self.value_type.n_elements() > 1:
assert self.sizes == other.sizes
self.assign_vector(other.get_vector())
except:
for i, x in enumerate(other):
self[i].assign(x)
def get_part_vector(self, base=0, size=None):
""" Vector from range of the first dimension, including all
@@ -6297,15 +6457,41 @@ class SubMultiArray(_vectorizable):
addresses = self.get_slice_addresses(slice)
vector.store_in_mem(self.address + addresses)
def get_slice_addresses(self, slice):
def get_part_size(self):
assert self.value_type.n_elements() == 1
part_size = reduce(operator.mul, self.sizes[1:])
return reduce(operator.mul, self.sizes[1:])
def get_slice_addresses(self, slice, part_size=None):
part_size = part_size or self.get_part_size()
assert len(slice) * part_size <= self.total_size()
base = regint.inc(len(slice) * part_size, slice.address, 1, part_size)
inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size)
addresses = slice.value_type.load_mem(base) * part_size + inc
return addresses
def permute(self, permutation, reverse=False, n_threads=None):
""" Public permutation along first dimension.
:param permutation: cleartext :py:class`Array` containing number
in :math:`[0,n-1]` where :math:`n` is the length of this array
:param reverse: whether to apply the inverse of the permutation
"""
@library.multithread(n_threads, self.get_part_size())
def _(base, size):
addresses = self.get_slice_addresses(permutation, part_size=1)
addresses *= self.get_part_size()
@library.for_range_opt(size)
def _(j):
i = base + j
if reverse:
v = self.get_column(i)
v.store_in_mem(self.address + i + addresses)
else:
v = self.value_type.load_mem(
self.address + i + addresses)
self.set_column(i, v)
def get_addresses(self, *indices):
assert self.value_type.n_elements() == 1
assert len(indices) == len(self.sizes)
@@ -6389,7 +6575,8 @@ class SubMultiArray(_vectorizable):
def write_to_file(self, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data``.
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param position: start position (int/regint/cint),
defaults to end of file
@@ -6404,7 +6591,8 @@ class SubMultiArray(_vectorizable):
def read_from_file(self, start):
""" Read content from ``Persistence/Transactions-P<playerno>.data``.
Precision must be the same as when storing if applicable.
Precision must be the same as when storing if applicable. See
:ref:`this section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning
(int/regint/cint)
@@ -6417,6 +6605,14 @@ class SubMultiArray(_vectorizable):
start.write(self[i].read_from_file(start))
return start
def write_to_socket(self, socket, debug=False):
""" Write content to socket. """
self.array.write_to_socket(socket, debug=debug)
def read_from_socket(self, socket, debug=False):
""" Read content from socket. """
self.array.read_from_socket(socket, debug=debug)
def schur(self, other):
""" Element-wise product.
@@ -6758,7 +6954,28 @@ class SubMultiArray(_vectorizable):
res = self.value_type.dot_product(a, b)
return res
def transpose(self):
def get_column(self, index):
""" Get matrix column as vector.
:param index: regint/cint/int
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.get_part_size())
return self.value_type.load_mem(addresses)
def set_column(self, index, vector):
""" Change column.
:param index: regint/cint/int
:param vector: short enought vector of compatible type
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.get_part_size())
self.value_type.conv(vector).store_in_mem(addresses)
def transpose(self, n_threads=None):
""" Matrix transpose.
:param self: two-dimensional """
@@ -6766,13 +6983,24 @@ class SubMultiArray(_vectorizable):
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
library.break_point()
if self.value_type.n_elements() == 1:
nr = self.sizes[1]
nc = self.sizes[0]
a = regint.inc(nr * nc, 0, nr, 1, nc)
b = regint.inc(nr * nc, 0, 1, nc)
res[:] = self.value_type.load_mem(self.address + a + b)
if self.sizes[0] < program.budget:
if self.sizes[1] < program.budget:
nr = self.sizes[1]
nc = self.sizes[0]
a = regint.inc(nr * nc, 0, nr, 1, nc)
b = regint.inc(nr * nc, 0, 1, nc)
res[:] = self.value_type.load_mem(self.address + a + b)
else:
@library.for_range_multithread(n_threads, 1, self.sizes[0])
def _(i):
res.set_column(i, self[i][:])
else:
@library.for_range_multithread(n_threads, 1, self.sizes[1])
def _(i):
res[i][:] = self.get_column(i)
else:
@library.for_range_opt(self.sizes[1], budget=100)
@library.for_range_opt_multithread(n_threads, self.sizes[1],
budget=100)
def _(i):
@library.for_range_opt(self.sizes[0], budget=100)
def _(j):
@@ -6801,7 +7029,7 @@ class SubMultiArray(_vectorizable):
"""
self.assign_vector(self.get_vector().secure_shuffle(self.part_size()))
def secure_permute(self, permutation, reverse=False):
def secure_permute(self, permutation, reverse=False, n_threads=None):
""" Securely permute rows (first index). See
:py:func:`secure_shuffle` for references.
@@ -6809,8 +7037,12 @@ class SubMultiArray(_vectorizable):
:param reverse: whether to apply inverse (default: False)
"""
self.assign_vector(self.get_vector().secure_permute(
permutation, self.part_size(), reverse))
if n_threads is not None:
permutation = MemValue(permutation)
@library.for_range_multithread(n_threads, 1, self.get_part_size())
def _(i):
self.set_column(i, self.get_column(i).secure_permute(
permutation, reverse=reverse))
def sort(self, key_indices=None, n_bits=None):
""" Sort sub-arrays (different first index) in place.
@@ -6829,6 +7061,9 @@ class SubMultiArray(_vectorizable):
return
if key_indices is None:
key_indices = (0,) * (len(self.sizes) - 1)
if len(key_indices) != len(self.sizes) - 1:
raise CompilerError('length of key_indices has to be one less '
'than the dimension')
key_indices = (None,) + util.tuplify(key_indices)
from . import sorting
keys = self.get_vector_by_indices(*key_indices)
@@ -6971,7 +7206,8 @@ class Matrix(MultiArray):
@staticmethod
def create_from(rows):
rows = list(rows)
if not isinstance(rows, _vectorizable):
rows = list(rows)
if isinstance(rows[0], (list, tuple, Array)):
t = type(rows[0][0])
else:
@@ -6983,20 +7219,15 @@ class Matrix(MultiArray):
raise CompilerError(
'accidental shortening by creating matrix')
res = Matrix(len(rows), len(rows[0]), t)
for i in range(len(rows)):
res[i].assign(rows[i])
if isinstance(rows, _vectorizable):
@library.for_range_opt(len(rows))
def _(i):
res[i].assign(rows[i])
else:
for i in range(len(rows)):
res[i].assign(rows[i])
return res
def get_column(self, index):
""" Get column as vector.
:param index: regint/cint/int
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.sizes[1])
return self.value_type.load_mem(addresses)
def get_columns(self):
return (self.get_column(i) for i in range(self.sizes[1]))
@@ -7006,17 +7237,6 @@ class Matrix(MultiArray):
regint.inc(len(rows), self.address + column, 0)
return self.value_type.load_mem(addresses)
def set_column(self, index, vector):
""" Change column.
:param index: regint/cint/int
:param vector: short enought vector of compatible type
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.sizes[1])
self.value_type.conv(vector).store_in_mem(addresses)
def concat_columns(self, other):
""" Concatenate two matrices by columns. """
assert self.sizes[0] == other.sizes[0]
@@ -7217,6 +7437,8 @@ class MemValue(_mem):
bit_and = lambda self,other: self.read().bit_and(other)
bit_not = lambda self: self.read().bit_not()
print_if = lambda self,*args,**kwargs: self.read().print_if(*args, **kwargs)
def expand_to_vector(self, size=None):
if program.curr_block == self.last_write_block:
return self.read().expand_to_vector(size)

View File

@@ -64,14 +64,14 @@ RUN pip install --upgrade pip ipython
COPY . .
ARG arch=native
ARG arch=
ARG cxx=clang++-11
ARG use_ntl=0
ARG prep_dir="Player-Data"
ARG ssl_dir="Player-Data"
RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \
&& echo "CXX = ${cxx}" >> CONFIG.mine \
RUN if test -n "${arch}"; then echo "ARCH = -march=${arch}" >> CONFIG.mine; fi
RUN echo "CXX = ${cxx}" >> CONFIG.mine \
&& echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \
&& echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \
&& echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \

View File

@@ -50,7 +50,6 @@ public:
void assign_zero() { *this = 0; }
bool is_zero() { return *this == 0; }
void add(octetStream& os) { *this += os.get<CurveElement>(); }
void pack(octetStream& os) const;
void unpack(octetStream& os);

View File

@@ -166,13 +166,3 @@ bool P256Element::operator !=(const P256Element& other) const
{
return not (*this == other);
}
octetStream P256Element::hash(size_t n_bytes) const
{
octetStream os;
pack(os);
auto res = os.hash();
assert(n_bytes >= res.get_length());
res.resize_precise(n_bytes);
return res;
}

View File

@@ -56,15 +56,9 @@ public:
bool operator==(const P256Element& other) const;
bool operator!=(const P256Element& other) const;
void assign_zero() { *this = {}; }
bool is_zero() { return *this == P256Element(); }
void add(octetStream& os, int = -1) { *this += os.get<P256Element>(); }
void pack(octetStream& os, int = -1) const;
void unpack(octetStream& os, int = -1);
octetStream hash(size_t n_bytes) const;
friend ostream& operator<<(ostream& s, const P256Element& x);
};

View File

@@ -131,6 +131,12 @@ int main(int argc, char** argv)
case 'R':
{
int R = specification.get<int>();
int R2 = specification.get<int>();
if (R2 != 64)
{
cerr << R2 << "-bit ring not implemented" << endl;
}
switch (R)
{
case 64:

View File

@@ -14,24 +14,17 @@ finish = int(sys.argv[4])
client = Client(['localhost'] * n_parties, 14000, client_id)
type = client.specification.get_int(4)
if type == ord('R'):
domain = Z2(client.specification.get_int(4))
elif type == ord('p'):
domain = Fp(client.specification.get_bigint())
else:
raise Exception('invalid type')
for socket in client.sockets:
os = octetStream()
os.store(finish)
os.Send(socket)
def run(x):
client.send_private_inputs([x])
print('Winning client id is :', client.receive_outputs(1)[0])
# running two rounds
# first for sint, then for sfix
for x in bonus, bonus * 2 ** 16:
client.send_private_inputs([domain(x)])
print('Winning client id is :',
int(client.receive_outputs(domain, 1)[0]))
run(bonus)
run(bonus * 2 ** 16)

View File

@@ -2,6 +2,7 @@ import platform
import socket, ssl
import struct
import time
from domains import *
# The following function is either taken directly or derived from:
# https://stackoverflow.com/questions/12248132/how-to-change-tcp-keepalive-timer-using-python-script
@@ -61,6 +62,15 @@ class Client:
self.specification = octetStream()
self.specification.Receive(self.sockets[0])
type = self.specification.get_int(4)
if type == ord('R'):
self.domain = Z2(self.specification.get_int(4))
self.clear_domain = Z2(self.specification.get_int(4))
elif type == ord('p'):
self.domain = Fp(self.specification.get_bigint())
self.clear_domain = self.domain
else:
raise Exception('invalid type')
def receive_triples(self, T, n):
triples = [[0, 0, 0] for i in range(n)]
@@ -89,18 +99,19 @@ class Client:
return triples
def send_private_inputs(self, values):
T = type(values[0])
T = self.domain
triples = self.receive_triples(T, len(values))
os = octetStream()
assert len(values) == len(triples)
for value, triple in zip(values, triples):
(value + triple[0]).pack(os)
(T(value) + triple[0]).pack(os)
for socket in self.sockets:
os.Send(socket)
def receive_outputs(self, T, n):
def receive_outputs(self, n):
T = self.domain
triples = self.receive_triples(T, n)
return [triple[0] for triple in triples]
return [int(self.clear_domain(triple[0].v)) for triple in triples]
class octetStream:
def __init__(self, value=None):

View File

@@ -67,20 +67,15 @@ void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,
cc0.Scale(pk.p()); cc1.Scale(pk.p());
// Now do the multiply
Rq_Element d0,d1,d2;
mul(d0,cc0.cc0,cc1.cc0);
mul(d1,cc0.cc0,cc1.cc1);
mul(d2,cc0.cc1,cc1.cc0);
add(d1,d1,d2);
mul(d2,cc0.cc1,cc1.cc1);
auto d0 = cc0.cc0 * cc1.cc0;
auto d1 = cc0.cc0 * cc1.cc1 + cc0.cc1 * cc1.cc0;
auto d2 = cc0.cc1 * cc1.cc1;
d2.negate();
// Now do the switch key
d2.raise_level();
Rq_Element t;
d0.mul_by_p1();
mul(t,pk.bs(),d2);
auto t = pk.bs()* d2;
add(d0,d0,t);
d1.mul_by_p1();

View File

@@ -29,7 +29,6 @@ class Ciphertext
word pk_id;
public:
static int size() { return 0; }
const FHE_Params& get_params() const { return *params; }

View File

@@ -94,8 +94,8 @@ void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
add(PK.Sw_b,PK.Sw_b,es);
// bs=bs-p1*s^2
Rq_Element s2;
mul(s2,sk,sk); // Mult at level 0
// Mult at level 0
auto s2 = sk * sk;
s2.mul_by_p1(); // This raises back to level 1
sub(PK.Sw_b,PK.Sw_b,s2);
}
@@ -155,17 +155,12 @@ void FHE_PK::quasi_encrypt(Ciphertext& c,
if (&rc.get_params()!=params) { throw params_mismatch(); }
assert(pr != 0);
Rq_Element ed,edd,c0,c1,aa;
// c1=a0*u+p*v
mul(aa,a0,rc.u());
mul(ed,rc.v(),pr);
add(c1,aa,ed);
auto c1 = a0 * rc.u() + rc.v() * pr;
// c0 = b0 * u + p * w + mess
mul(c0,b0,rc.u());
mul(edd,rc.w(),pr);
add(edd,edd,mess);
auto c0 = b0 * rc.u();
auto edd = rc.w() * pr + mess;
if (params->n_mults() == 0)
edd.change_rep(evaluation);
else
@@ -218,10 +213,7 @@ Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
auto ans = c.c0() - c.c1() * sk;
ans.change_rep(polynomial);
return ans;
}
@@ -267,8 +259,7 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
Ciphertext cc=ctx; cc.Scale(pr);
// First do the basic decryption
Rq_Element dec_sh;
mul(dec_sh,cc.c1(),sk);
auto dec_sh = cc.c1() * sk;
if (player_number==0)
{ sub(dec_sh,cc.c0(),dec_sh); }
else

View File

@@ -89,9 +89,6 @@ class FHE_SK
bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; }
void add(octetStream& os, int = -1)
{ FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const;
template<class FD>
@@ -120,10 +117,12 @@ class FHE_PK
bigint p() const { return pr; }
void assign(const Rq_Element& a,const Rq_Element& b,
const Rq_Element& sa = {},const Rq_Element& sb = {}
const Rq_Element& sa,const Rq_Element& sb
)
{ a0=a; b0=b; Sw_a=sa; Sw_b=sb; }
void assign(const Rq_Element& a,const Rq_Element& b)
{ a0=a; b0=b; }
FHE_PK(const FHE_Params& pms);

View File

@@ -27,6 +27,8 @@ class RingReadIterator;
class Ring_Element
{
friend class Rq_Element;
RepType rep;
/* FFTD is defined as a pointer so each different Ring_Element
@@ -41,6 +43,9 @@ class Ring_Element
vector<modp> element;
/* Careful calling this one, as FFTD will not be defined */
Ring_Element(RepType r=polynomial) : FFTD(0) { rep=r; }
public:
// Used to basically make sure *this is able to cope
@@ -63,9 +68,6 @@ class Ring_Element
void assign_zero();
void assign_one();
/* Careful calling this one, as FFTD will not be defined */
Ring_Element(RepType r=polynomial) : FFTD(0) { rep=r; }
Ring_Element(const FFT_Data& prd,RepType r=polynomial);
template<class T>

View File

@@ -23,7 +23,7 @@ Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
void Rq_Element::set_data(const vector<FFT_Data>& prd)
{
a.resize(prd.size());
a.resize(prd.size(), {});
for(size_t i = 0; i < a.size(); i++)
a[i].set_data(prd[i]);
lev=n_mults();
@@ -50,7 +50,7 @@ void Rq_Element::assign_one()
void Rq_Element::partial_assign(const Rq_Element& other)
{
lev=other.lev;
a.resize(other.a.size());
a.resize(other.a.size(), {});
}
void Rq_Element::negate()
@@ -112,13 +112,6 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
}
}
void Rq_Element::add(octetStream& os, int)
{
Rq_Element tmp(*this);
tmp.unpack(os);
*this += tmp;
}
void Rq_Element::randomize(PRNG& G,int l)
{
set_level(l);

View File

@@ -33,6 +33,10 @@ protected:
vector<Ring_Element> a;
int lev;
// Must be careful not to call by mistake
Rq_Element(RepType r0=evaluation,RepType r1=polynomial) :
a({r0, r1}), lev(n_mults()) {}
public:
int n_mults() const { return a.size() - 1; }
@@ -46,10 +50,6 @@ protected:
void assign_one();
void partial_assign(const Rq_Element& e);
// Must be careful not to call by mistake
Rq_Element(RepType r0=evaluation,RepType r1=polynomial) :
a({r0, r1}), lev(n_mults()) {}
// Pass in a pair of FFT_Data as a vector
Rq_Element(const vector<FFT_Data>& prd, RepType r0 = evaluation,
RepType r1 = polynomial);
@@ -97,8 +97,6 @@ protected:
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
void add(octetStream& os, int = -1);
template<class S>
Rq_Element& operator+=(const vector<S>& other);

View File

@@ -15,13 +15,10 @@
void Encrypt_Rq_Element(Ciphertext& c,const Rq_Element& mess, const Random_Coins& rc,
const FHE_PK& pk)
{
Rq_Element ed, edd, c0, c1;
mul(c1, pk.a(), rc.u());
mul(ed, rc.v(), pk.p());
add(c1, c1, ed);
auto c1 = pk.a() * rc.u() + rc.v() * pk.p();
auto c0 = pk.b() * rc.u();
auto edd = rc.w() * pk.p();
mul(c0, pk.b(), rc.u());
mul(edd, rc.w(), pk.p());
edd.change_rep(evaluation,evaluation);
add(edd,edd,mess);
add(c0,c0,edd);

View File

@@ -87,6 +87,7 @@ public:
void xorc(const ::BaseInstruction& instruction);
void nots(const ::BaseInstruction& instruction);
void notcb(const ::BaseInstruction& instruction);
void movsb(const ::BaseInstruction& instruction);
void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat);
void andrs(const vector<int>& args) { and_(args, true); }

View File

@@ -283,6 +283,13 @@ void Processor<T>::notcb(const ::BaseInstruction& instruction)
}
}
template<class T>
void Processor<T>::movsb(const ::BaseInstruction& instruction)
{
for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++)
S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i];
}
template<class T>
void Processor<T>::andm(const ::BaseInstruction& instruction)
{

View File

@@ -32,7 +32,7 @@ void TinierSharePrep<T>::buffer_secret_triples()
assert(triple_generator != 0);
params.generateBits = false;
vector<array<T, 3>> triples;
TripleShuffleSacrifice<T> sacrifice;
TripleShuffleSacrifice<T> sacrifice(DATA_GF2);
size_t required;
required = sacrifice.minimum_n_inputs_with_combining(
BaseMachine::batch_size<T>(DATA_TRIPLE));

View File

@@ -65,7 +65,7 @@
X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \
X(LDMCBI, PROC.mem_op(SIZE, PROC.C, MMC, R0, Ci[REG1])) \
X(STMCBI, PROC.mem_op(SIZE, MMC, PROC.C, Ci[REG1], R0)) \
X(MOVSB, S0 = PS1) \
X(MOVSB, PROC.movsb(INST)) \
X(TRANS, T::trans(PROC, IMM, EXTRA)) \
X(BITB, PROC.random_bit(S0)) \
X(REVEAL, T::reveal_inst(PROC, EXTRA)) \
@@ -123,7 +123,7 @@
X(LDMINTI, I0 = MII) \
X(STMINTI, MII = I0) \
X(PUSHINT, PROC.pushi(I0.get())) \
X(POPINT, long x; PROC.popi(x); I0 = x) \
X(POPINT, PROC.popi(I0)) \
X(MOVINT, I0 = PI1) \
X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \
X(LDARG, I0 = PROC.get_arg()) \

View File

@@ -72,6 +72,12 @@ ShamirOptions::ShamirOptions(ez::ezOptionParser& opt, int argc, const char** arg
);
opt.parse(argc, argv);
opt.get("-N")->getInt(nparties);
if (nparties < 3)
{
cerr << "Protocols based on Shamir secret sharing require at least "
<< "three parties." << endl;
exit(1);
}
set_threshold(opt);
opt.resetArgs();
}

View File

@@ -26,23 +26,9 @@ int main(int argc, const char** argv)
ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv);
online_opts = {opt, argc, argv, FakeShare<SignedZ2<64>>()};
opt.parse(argc, argv);
opt.syntax = string(argv[0]) + " <progname>";
string progname;
if (opt.firstArgs.size() > 1)
progname = *opt.firstArgs.at(1);
else if (not opt.lastArgs.empty())
progname = *opt.lastArgs.at(0);
else if (not opt.unknownArgs.empty())
progname = *opt.unknownArgs.at(0);
else
{
string usage;
opt.getUsage(usage);
cerr << usage << endl;
exit(1);
}
online_opts.finalize(opt, argc, argv, false);
string& progname = online_opts.progname;
#ifdef ROUND_NEAREST_IN_EMULATION
cerr << "Using nearest rounding instead of probabilistic truncation" << endl;

View File

@@ -7,6 +7,7 @@
#include "Processor/OnlineMachine.hpp"
#include "Processor/Machine.hpp"
#include "Processor/OnlineOptions.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/ReplicatedPrep.hpp"
@@ -17,7 +18,7 @@
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
OnlineOptions::singleton = {opt, argc, argv};
OnlineOptions::singleton = {opt, argc, argv, NoShare<gf2n>()};
OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton);
OnlineOptions::singleton.finalize(opt, argc, argv);
machine.start_networking();

View File

@@ -7,6 +7,7 @@
#include "Processor/Machine.h"
#include "Processor/RingOptions.h"
#include "Protocols/Spdz2kShare.h"
#include "Protocols/SPDZ2k.h"
#include "Math/gf2n.h"
#include "Networking/Server.h"
@@ -62,8 +63,10 @@ int main(int argc, const char** argv)
cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line "
<< (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+"
<< s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl;
cerr << "Alternatively, compile with -DRING_SIZE=" << k
<< " and -DSPDZ2K_DEFAULT_SECURITY=" << s << endl;
cerr << "Alternatively, put 'MY_CFLAGS += -DRING_SIZE=" << k
<< " -DSPDZ2K_DEFAULT_SECURITY=" << s
<< "' in 'CONFIG.mine' before running 'make spdz2k-party.x'"
<< endl;
}
exit(1);
}

View File

@@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x
ifeq ($(OS), Darwin)
setup: mac-setup
else
setup: boost linux-machine-setup
setup: maybe-boost linux-machine-setup
endif
tldr: setup
@@ -296,13 +296,17 @@ deps/SimplestOT_C/ref10/Makefile:
.PHONY: Programs/Circuits
Programs/Circuits:
git submodule update --init Programs/Circuits
git submodule update --init Programs/Circuits || git clone https://github.com/mkskeller/bristol-fashion Programs/Circuits
deps/libOTe/libOTe:
git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe
boost: deps/libOTe/libOTe
cd deps/libOTe; \
python3 build.py --setup --boost --install=$(CURDIR)/local
maybe-boost: deps/libOTe/libOTe
cd `mktemp -d`; \
PATH="$(CURDIR)/local/bin:$(PATH)" cmake $(CURDIR)/deps/libOTe || \
{ cd -; make boost; }
OTE_OPTS += -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX) -DCMAKE_INSTALL_LIBDIR=lib
@@ -334,11 +338,12 @@ OT/OTExtensionWithMatrix.o: $(OTE)
endif
local/lib/liblibOTe.a: deps/libOTe/libOTe
make maybe-boost; \
cd deps/libOTe; \
PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS) && \
touch ../../local/lib/liblibOTe.a
$(SHARED_OTE): deps/libOTe/libOTe
$(SHARED_OTE): deps/libOTe/libOTe maybe-boost
cd deps/libOTe; \
python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS)

View File

@@ -51,11 +51,6 @@ public:
return other * *this;
}
void add(octetStream& os, int = -1)
{
*this += os.get<Bit>();
}
void pack(octetStream& os, int = -1) const
{
super::pack(os, 1);

View File

@@ -56,13 +56,6 @@ public:
void extend_bit(BitVec_& res, int) const { res = extend_bit(); }
void add(octetStream& os, int n_bits)
{
BitVec_ tmp;
tmp.unpack(os, n_bits);
*this += tmp;
}
void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; }
void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = this->mask(n); }

View File

@@ -138,12 +138,6 @@ public:
v[i] = (x.v[i] * y.v[i]);
}
void add(octetStream& os)
{
for (int i = 0; i < L; i++)
v[i].add(os);
}
void negate()
{
for (auto& x : v)

View File

@@ -8,6 +8,8 @@
#include <sys/stat.h>
const false_type ValueInterface::binary;
void ValueInterface::check_setup(const string& directory)
{
struct stat sb;

View File

@@ -156,8 +156,6 @@ public:
bool operator==(const Z2<K>& other) const;
bool operator!=(const Z2<K>& other) const { return not (*this == other); }
void add(octetStream& os, int = -1) { *this += (os.consume(size())); }
Z2 lazy_add(const Z2& x) const;
Z2 lazy_mul(const Z2& x) const;

View File

@@ -136,7 +136,7 @@ inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y
template<>
inline void Zp_Data::Add<1>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
{
#if defined(__clang__) || !defined(__x86_64__)
#if defined(__clang__) || !defined(__x86_64__) || (__GNUC__ == 10)
Add<0>(ans, x, y);
#else
*ans = *x + *y;

View File

@@ -87,12 +87,6 @@ bigint::bigint(const mp_limb_t* data, size_t n_limbs)
mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data);
}
void bigint::add(octetStream& os, int)
{
tmp.unpack(os);
*this += tmp;
}
string to_string(const bigint& x)
{
stringstream ss;

View File

@@ -102,8 +102,6 @@ public:
void mul(const bigint& x, const bigint& y) { *this = x * y; }
void add(octetStream& os, int = -1);
#ifdef REALLOC_POLICE
~bigint() { lottery(); }
void lottery();

View File

@@ -136,10 +136,6 @@ protected:
// x+y
void add(const gf2n_& x,const gf2n_& y)
{ a=x.a^y.a; }
void add(octet* x)
{ a^=*(U*)(x); }
void add(octetStream& os, int = -1)
{ add(os.consume(size())); }
void sub(const gf2n_& x,const gf2n_& y)
{ a=x.a^y.a; }
// = x * y

View File

@@ -189,12 +189,8 @@ class gfp_ : public ValueInterface
bool operator!=(const gfp_& y) const { return !equal(y); }
// x+y
void add(octetStream& os, int = -1)
{ add(os.consume(size())); }
void add(const gfp_& x,const gfp_& y)
{ ZpD.Add<L>(a.x,x.a.x,y.a.x); }
void add(void* x)
{ ZpD.Add<L>(a.x,a.x,(mp_limb_t*)x); }
void sub(const gfp_& x,const gfp_& y)
{ ZpD.Sub<L>(a.x,x.a.x,y.a.x); }
// = x * y

View File

@@ -295,12 +295,6 @@ bool gfpvar_<X, L>::operator !=(const gfpvar_<X, L>& other) const
return not (*this == other);
}
template<int X, int L>
void gfpvar_<X, L>::add(octetStream& other, int)
{
*this += other.get<gfpvar_<X, L>>();
}
template<int X, int L>
void gfpvar_<X, L>::negate()
{

View File

@@ -149,8 +149,6 @@ public:
bool operator==(const gfpvar_& other) const;
bool operator!=(const gfpvar_& other) const;
void add(octetStream& other, int = -1);
void negate();
gfpvar_ invert() const;

View File

@@ -272,14 +272,14 @@ PlayerBase::~PlayerBase()
// Set up nmachines client and server sockets to send data back and fro
// A machine is a server between it and player i if i<=my_number
// A machine is a server between it and player i if i>=my_number
// Can also communicate with myself, but only with send_to and receive_from
void PlainPlayer::setup_sockets(const vector<string>& names,
const vector<int>& ports, const string& id_base, ServerSocket& server)
{
sockets.resize(nplayers);
// Set up the client side
for (int i=player_no; i<nplayers; i++) {
for (int i=0; i<=player_no; i++) {
auto pn=id_base+"P"+to_string(player_no);
if (i==player_no) {
const char* localhost = "127.0.0.1";
@@ -300,7 +300,7 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
}
send_to_self_socket = sockets[player_no];
// Setting up the server side
for (int i=0; i<=player_no; i++) {
for (int i=player_no; i<nplayers; i++) {
auto id=id_base+"P"+to_string(i);
#ifdef DEBUG_NETWORKING
fprintf(stderr,

View File

@@ -140,12 +140,12 @@ void ServerSocket::accept_clients()
fprintf(stderr, "Accepting...\n");
#endif
int consocket;
for (int i = 0; i < 25; i++)
for (int i = 0; i < 1000; i++)
{
consocket = accept(main_socket, (struct sockaddr*) &dest,
(socklen_t*) &socksize);
if (consocket < 0)
usleep(1 << i);
usleep(min(1 << i, 1000));
else
break;
}
@@ -204,8 +204,10 @@ int ServerSocket::get_connection_socket(const string& id)
while (clients.find(id) == clients.end())
{
if (data_signal.wait(60) == ETIMEDOUT)
throw runtime_error("No client after one minute");
if (data_signal.wait(CONNECTION_TIMEOUT) == ETIMEDOUT)
throw runtime_error("Timed out waiting for peer. See "
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
"for details on networking.");
}
int client_socket = clients[id];
@@ -228,7 +230,7 @@ void AnonymousServerSocket::init()
void AnonymousServerSocket::process_client(const string& client_id)
{
if (clients.find(client_id) != clients.end())
close_client_socket(clients[client_id]);
throw runtime_error("client " + client_id + " already connected");
client_connection_queue.push(client_id);
}
@@ -236,9 +238,14 @@ int AnonymousServerSocket::get_connection_socket(string& client_id)
{
data_signal.lock();
//while (clients.find(next_client_id) == clients.end())
while (client_connection_queue.empty())
data_signal.wait();
{
int res = data_signal.wait(CONNECTION_TIMEOUT);
if (res == ETIMEDOUT)
throw runtime_error("timed out while waiting for client");
else if (res)
throw runtime_error("waiting error");
}
client_id = client_connection_queue.front();
client_connection_queue.pop();

View File

@@ -29,7 +29,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
gethostname((char*)my_name,512);
int erp;
for (int i = 0; i < 60; i++)
for (int i = 0; i < CONNECTION_TIMEOUT; i++)
{ erp=getaddrinfo (hostname, NULL, &hints, &ai);
if (erp == 0)
{ break; }
@@ -90,7 +90,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
if (fl != 0)
{
close(mysocket);
usleep(wait *= 2);
usleep(wait < 1000 ? wait *= 2 : wait);
#ifdef DEBUG_NETWORKING
string msg = "Connecting to " + string(hostname) + ":" +
to_string(Portnum) + " failed";
@@ -102,7 +102,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
}
while (fl == -1
&& (errno == ECONNREFUSED || errno == ETIMEDOUT || errno == EINPROGRESS)
&& timer.elapsed() < 60);
&& timer.elapsed() < CONNECTION_TIMEOUT);
if (fl < 0)
{

View File

@@ -23,6 +23,10 @@
#include <iostream>
using namespace std;
// default to one minute
#ifndef CONNECTION_TIMEOUT
#define CONNECTION_TIMEOUT 60
#endif
void error(const char *str);
@@ -38,10 +42,15 @@ void receive(T& socket, size_t& a, size_t len);
inline size_t send_non_blocking(int socket, octet* msg, size_t len)
{
#ifdef __APPLE__
int j = send(socket,msg,min(len,10000lu),MSG_DONTWAIT);
#else
int j = send(socket,msg,len,MSG_DONTWAIT);
#endif
if (j < 0)
{
if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK)
if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK and
errno != ENOBUFS)
{ error("Send error - 1 "); }
else
return 0;

View File

@@ -41,6 +41,8 @@ union square128 {
int16_t doublebytes[128][8];
int32_t words[128][4];
square128() {}
bool get_bit(int x, int y)
{ return (bytes[x][y/8] >> (y % 8)) & 1; }

View File

@@ -8,6 +8,7 @@
#include "Math/Setup.h"
#include "Tools/Bundle.h"
#include "Instruction.hpp"
#include "Protocols/ShuffleSacrifice.hpp"
#include <iostream>
@@ -38,12 +39,27 @@ bool BaseMachine::has_program()
}
int BaseMachine::edabit_bucket_size(int n_bits)
{
size_t usage = 0;
if (has_program())
usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
return bucket_size(usage);
}
int BaseMachine::triple_bucket_size(DataFieldType type)
{
size_t usage = 0;
if (has_program())
usage = s().progs[0].get_offline_data_used().files[type][DATA_TRIPLE];
return bucket_size(usage);
}
int BaseMachine::bucket_size(size_t usage)
{
int res = OnlineOptions::singleton.bucket_size;
if (has_program())
if (usage)
{
auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
for (int B = res; B <= 5; B++)
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
break;
@@ -91,7 +107,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
string threadname;
for (int i=0; i<nprogs; i++)
{ inpf >> threadname;
size_t split = threadname.find(":");
size_t split = threadname.find_last_of(":");
long expected = -1;
if (split != string::npos)
{
@@ -125,6 +141,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
getline(inpf, compiler);
getline(inpf, domain);
getline(inpf, relevant_opts);
getline(inpf, security);
inpf.close();
}
@@ -184,17 +201,19 @@ string BaseMachine::memory_filename(const string& type_short, int my_number)
string BaseMachine::get_domain(string progname)
{
if (singleton)
{
assert(s().progname == progname);
return s().domain;
}
return get_basics(progname).domain;
}
assert(not singleton);
BaseMachine BaseMachine::get_basics(string progname)
{
if (singleton and s().progname == progname)
return s();
auto backup = singleton;
BaseMachine machine;
singleton = 0;
singleton = backup;
machine.load_schedule(progname, false);
return machine.domain;
return machine;
}
int BaseMachine::ring_size_from_schedule(string progname)
@@ -226,6 +245,15 @@ bigint BaseMachine::prime_from_schedule(string progname)
return 0;
}
int BaseMachine::security_from_schedule(string progname)
{
string sec = get_basics(progname).security;
if (sec.substr(0, 4).compare("sec:") == 0)
return stoi(sec.substr(4));
else
return 0;
}
NamedCommStats BaseMachine::total_comm()
{
NamedCommStats res;

View File

@@ -32,10 +32,13 @@ protected:
string compiler;
string domain;
string relevant_opts;
string security;
virtual size_t load_program(const string& threadname,
const string& filename);
static BaseMachine get_basics(string progname);
public:
static thread_local int thread_num;
@@ -58,12 +61,15 @@ public:
static int ring_size_from_schedule(string progname);
static int prime_length_from_schedule(string progname);
static bigint prime_from_schedule(string progname);
static int security_from_schedule(string progname);
template<class T>
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
template<class T>
static int edabit_batch_size(int n_bits, int buffer_size = 0);
static int edabit_bucket_size(int n_bits);
static int triple_bucket_size(DataFieldType type);
static int bucket_size(size_t usage);
BaseMachine();
virtual ~BaseMachine() {}

View File

@@ -1,5 +1,7 @@
#include "Processor/ExternalClients.h"
#include "Processor/OnlineOptions.h"
#include "Networking/ServerSocket.h"
#include "Networking/ssl_sockets.h"
#include <netinet/in.h>
#include <arpa/inet.h>
#include <thread>
@@ -25,6 +27,8 @@ ExternalClients::~ExternalClients()
}
if (ctx)
delete ctx;
for (auto it = peer_ctxs.begin(); it != peer_ctxs.end(); it++)
delete it->second;
}
void ExternalClients::start_listening(int portnum_base)
@@ -32,8 +36,9 @@ void ExternalClients::start_listening(int portnum_base)
ScopeLock _(lock);
client_connection_servers[portnum_base] = new AnonymousServerSocket(portnum_base + get_party_num());
client_connection_servers[portnum_base]->init();
cerr << "Start listening on thread " << this_thread::get_id() << endl;
cerr << "Party " << get_party_num() << " is listening on port " << (portnum_base + get_party_num())
if (OnlineOptions::singleton.verbose)
cerr << "Party " << get_party_num() << " is listening on port "
<< (portnum_base + get_party_num())
<< " for external client connections." << endl;
}
@@ -46,7 +51,6 @@ int ExternalClients::get_client_connection(int portnum_base)
cerr << "Thread " << this_thread::get_id() << " didn't find server." << endl;
throw runtime_error("No connection on port " + to_string(portnum_base));
}
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
int client_id, socket;
string client;
socket = client_connection_servers[portnum_base]->get_connection_socket(
@@ -57,10 +61,38 @@ int ExternalClients::get_client_connection(int portnum_base)
external_client_sockets[client_id] = new client_socket(io_service, *ctx, socket,
"C" + to_string(client_id), "P" + to_string(get_party_num()), false);
client_ports[client_id] = portnum_base;
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
if (OnlineOptions::singleton.verbose)
cerr << "Party " << get_party_num()
<< " received external client connection from client id: " << dec
<< client_id << endl;
return client_id;
}
int ExternalClients::init_client_connection(const string& host, int portnum,
int my_client_id)
{
ScopeLock _(lock);
int plain_socket;
set_up_client_socket(plain_socket, host.c_str(), portnum);
octetStream(to_string(my_client_id)).Send(plain_socket);
string my_client_name = "C" + to_string(my_client_id);
if (peer_ctxs.find(my_client_id) == peer_ctxs.end())
peer_ctxs[my_client_id] = new client_ctx(my_client_name);
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
true);
if (party_num == 0)
{
octetStream specification;
specification.Receive(socket);
}
int id = -1;
if (not external_client_sockets.empty())
id = min(id, external_client_sockets.begin()->first);
external_client_sockets[id] = socket;
return id;
}
void ExternalClients::close_connection(int client_id)
{
ScopeLock _(lock);

View File

@@ -32,6 +32,7 @@ class ExternalClients
ssl_service io_service;
client_ctx* ctx;
map<int, client_ctx*> peer_ctxs;
Lock lock;
@@ -43,6 +44,7 @@ class ExternalClients
void start_listening(int portnum_base);
int get_client_connection(int portnum_base);
int init_client_connection(const string& host, int portnum, int my_client_id);
void close_connection(int client_id);

View File

@@ -12,7 +12,6 @@
#include "OnlineMachine.hpp"
#include "OnlineOptions.hpp"
template<template<class U> class T, class V>
HonestMajorityFieldMachine<T, V>::HonestMajorityFieldMachine(int argc,
const char **argv)
@@ -34,6 +33,7 @@ template<template<class U> class T, template<class U> class V, class W, class X>
FieldMachine<T, V, W, X>::FieldMachine(int argc, const char** argv,
ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers)
{
assert(nplayers or T<gfpvar>::variable_players);
W machine(argc, argv, opt, online_opts, X(), nplayers);
int n_limbs = online_opts.prime_limbs();
switch (n_limbs)

View File

@@ -10,11 +10,13 @@
#include "Math/gf2n.h"
#include "GC/instructions.h"
#include "Memory.hpp"
#include <iomanip>
template<class cgf2n>
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
vector<cgf2n>& memory, ArithmeticProcessor& Proc) const
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
{
auto& C2 = registers;
auto& M2C = memory;
@@ -123,6 +125,6 @@ ostream& operator<<(ostream& s, const Instruction& instr)
}
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
vector<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
vector<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;

View File

@@ -72,6 +72,7 @@ enum
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
CMDLINEARG = 0xEB,
// Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -153,7 +154,7 @@ enum
LISTEN = 0x6c,
ACCEPTCLIENTCONNECTION = 0x6d,
CLOSECLIENTCONNECTION = 0x6e,
READCLIENTPUBLICKEY = 0x6f,
INITCLIENTCONNECTION = 0x6f,
// Bitwise logic
ANDC = 0x70,
XORC = 0x71,
@@ -197,6 +198,7 @@ enum
PRINTREG = 0XB1,
RAND = 0xB2,
PRINTREGPLAIN = 0xB3,
PRINTREGPLAINS = 0xEA,
PRINTCHR = 0xB4,
PRINTSTR = 0xB5,
PUBINPUT = 0xB6,
@@ -345,6 +347,7 @@ protected:
int r[4]; // Fixed parameter registers
size_t n; // Possible immediate value
vector<int> start; // Values for a start/stop open
string str;
public:
virtual ~BaseInstruction() {};
@@ -387,7 +390,7 @@ public:
void execute(Processor<sint, sgf2n>& Proc) const;
template<class cgf2n>
void execute_clear_gf2n(vector<cgf2n>& registers, vector<cgf2n>& memory,
void execute_clear_gf2n(vector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
ArithmeticProcessor& Proc) const;
template<class cgf2n>

View File

@@ -105,7 +105,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case STMCBI:
case MOVC:
case MOVS:
case MOVSB:
case MOVINT:
case LDMINTI:
case STMINTI:
@@ -131,6 +130,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case SHUFFLE:
case ACCEPTCLIENTCONNECTION:
case PREFIXSUMS:
case CMDLINEARG:
get_ints(r, s, 2);
break;
// instructions with 1 register operand
@@ -139,6 +139,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case RANDOMFULLS:
case PRINTREGPLAIN:
case PRINTREGPLAINB:
case PRINTREGPLAINS:
case LDTN:
case LDARG:
case STARG:
@@ -316,13 +317,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case TRUNC_PR:
case RUN_TAPE:
case CONV2DS:
case MATMULS:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
case MATMULS:
get_ints(r, s, 3);
get_vector(3, start, s);
break;
case MATMULSM:
get_ints(r, s, 3);
get_vector(9, start, s);
@@ -358,7 +356,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
n = get_int(s);
get_vector(num_var_args, start, s);
break;
case READCLIENTPUBLICKEY:
case INITCLIENTCONNECTION:
get_ints(r, s, 3);
get_string(str, s);
break;
case INITSECURESOCKET:
case RESPSECURESOCKET:
throw runtime_error("VM-controlled encryption not supported any more");
@@ -459,6 +460,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case CONVCBIT2S:
case NOTS:
case NOTCB:
case MOVSB:
n = get_int(s);
get_ints(r, s, 2);
break;
@@ -566,7 +568,7 @@ int BaseInstruction::get_reg_type() const
case MOVINT:
case READSOCKETINT:
case WRITESOCKETINT:
case READCLIENTPUBLICKEY:
case INITCLIENTCONNECTION:
case INITSECURESOCKET:
case RESPSECURESOCKET:
case LDARG:
@@ -584,6 +586,7 @@ int BaseInstruction::get_reg_type() const
case INTOUTPUT:
case ACCEPTCLIENTCONNECTION:
case GENSECSHUFFLE:
case CMDLINEARG:
return INT;
case PREP:
case GPREP:
@@ -723,6 +726,15 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
return res;
}
case MATMULS:
{
int res = 0;
for (auto it = start.begin(); it < start.end(); it += 6)
{
int tmp = *it + *(it + 3) * *(it + 5);
res = max(res, tmp);
}
return res;
}
case MATMULSM:
return r[0] + start[0] * start[2];
case CONV2DS:
@@ -817,7 +829,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
while (it < start.end())
{
int n = *it - n_prefix;
int size = DIV_CEIL(*(it + 1), 64);
size = max((long long) size, DIV_CEIL(*(it + 1), 64));
it += n_prefix;
assert(it + n <= start.end());
for (int i = 0; i < n; i++)
@@ -922,16 +934,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.write_Cp(r[0],Proc.machine.Mp.read_C(n));
n++;
break;
case LDMCI:
Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1])));
break;
case STMC:
Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0]));
n++;
break;
case STMCI:
Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0]));
break;
case MOVC:
Proc.write_Cp(r[0],Proc.read_Cp(r[1]));
break;
@@ -1089,10 +1095,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.Proc2.POpen(*this);
return;
case MULS:
Proc.Procp.muls(start, size);
Proc.Procp.muls(start);
return;
case GMULS:
Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size);
Proc.Proc2.muls(start);
return;
case MULRS:
Proc.Procp.mulrs(start);
@@ -1107,7 +1113,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.Proc2.dotprods(start, size);
return;
case MATMULS:
Proc.Procp.matmuls(Proc.Procp.get_S(), *this, r[1], r[2]);
Proc.Procp.matmuls(Proc.Procp.get_S(), *this);
return;
case MATMULSM:
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
@@ -1126,13 +1132,15 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.Proc2.secure_shuffle(*this);
return;
case GENSECSHUFFLE:
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this));
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this,
Proc.machine.shuffle_store));
return;
case APPLYSHUFFLE:
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)));
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)),
Proc.machine.shuffle_store);
return;
case DELSHUFFLE:
Proc.Procp.delete_shuffle(Proc.read_Ci(r[0]));
Proc.machine.shuffle_store.del(Proc.read_Ci(r[0]));
return;
case INVPERM:
Proc.Procp.inverse_permutation(*this);
@@ -1170,6 +1178,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case PRINTREGPLAIN:
print(Proc.out, &Proc.read_Cp(r[0]));
return;
case PRINTREGPLAINS:
Proc.out << Proc.read_Sp(r[0]);
return;
case CONDPRINTPLAIN:
if (not Proc.read_Cp(r[0]).is_zero())
{
@@ -1237,6 +1248,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case PLAYERID:
Proc.write_Ci(r[0], Proc.P.my_num());
break;
case CMDLINEARG:
{
size_t idx = Proc.read_Ci(r[1]);
auto& args = OnlineOptions::singleton.args;
if (idx < args.size())
Proc.write_Ci(r[0], args[idx]);
else
{
cerr << idx << "-th command-line argument not given" << endl;
exit(1);
}
break;
}
// ***
// TODO: read/write shared GF(2^n) data instructions
// ***
@@ -1255,11 +1279,17 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
octetStream os;
os.store(int(sint::open_type::type_char()));
sint::specification(os);
sint::clear::specification(os);
os.Send(Proc.external_clients.get_socket(client_handle));
}
Proc.write_Ci(r[0], client_handle);
break;
}
case INITCLIENTCONNECTION:
Proc.write_Ci(r[0],
Proc.external_clients.init_client_connection(str,
Proc.read_Ci(r[1]), Proc.read_Ci(r[2])));
break;
case CLOSECLIENTCONNECTION:
Proc.external_clients.close_connection(Proc.read_Ci(r[0]));
break;

View File

@@ -20,6 +20,8 @@
#include "Tools/time-func.h"
#include "Tools/ExecutionStats.h"
#include "Protocols/SecureShuffle.h"
#include <vector>
#include <map>
#include <atomic>
@@ -70,6 +72,8 @@ class Machine : public BaseMachine
ExternalClients external_clients;
typename sint::Protocol::Shuffler::store_type shuffle_store;
static void init_binary_domains(int security_parameter, int lg2);
Machine(Names& playerNames, bool use_encryption = true,

View File

@@ -60,11 +60,21 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
{
OnlineOptions::singleton = opts;
if (N.num_players() == 1 and sint::is_real)
int min_players = 3 - sint::dishonest_majority;
if (sint::is_real)
{
cerr << "Need more than one player to run a protocol." << endl;
cerr << "Use 'emulate.x' for just running the virtual machine" << endl;
exit(1);
if (N.num_players() == 1)
{
cerr << "Need more than one player to run a protocol." << endl;
cerr << "Use 'emulate.x' for just running the virtual machine" << endl;
exit(1);
}
else if (N.num_players() < min_players)
{
cerr << "Need at least " << min_players << " players for this protocol."
<< endl;
exit(1);
}
}
// Set the prime modulus from command line or program if applicable
@@ -480,8 +490,10 @@ void Machine<sint, sgf2n>::run(const string& progname)
if (opts.verbose)
{
cerr << "Communication details "
"(rounds in parallel threads counted double):" << endl;
cerr << "Communication details";
if (multithread)
cerr << " (rounds in parallel threads counted double)";
cerr << ":" << endl;
comm_stats.print();
cerr << "CPU time = " << proc_timer.elapsed();
if (multithread)
@@ -547,6 +559,14 @@ void Machine<sint, sgf2n>::run(const string& progname)
suggest_optimizations();
if (N.num_players() > 4)
{
string alt = sint::alt();
if (alt.size())
cerr << "This protocol doesn't scale well with the number of parties, "
<< "have you considered using " << alt << " instead?" << endl;
}
#ifdef VERBOSE
cerr << "End of prog" << endl;
#endif

View File

@@ -14,34 +14,90 @@ template<class T> istream& operator>>(istream& s,Memory<T>& M);
#include "Processor/Program.h"
#include "Tools/CheckVector.h"
#include "Tools/DiskVector.h"
template<class T>
class MemoryPart : public CheckVector<T>
class MemoryPart
{
public:
template<class U>
static void check_index(const vector<U>& M, size_t i)
virtual ~MemoryPart() {}
virtual size_t size() const = 0;
virtual void resize(size_t) = 0;
virtual T* data() = 0;
virtual const T* data() const = 0;
void check_index(size_t i) const
{
(void) M, (void) i;
(void) i;
#ifndef NO_CHECK_INDEX
if (i >= M.size())
throw overflow(U::type_string() + " memory", i, M.size());
if (i >= this->size())
throw overflow(T::type_string() + " memory", i, this->size());
#endif
}
virtual T& operator[](size_t i) = 0;
virtual const T& operator[](size_t i) const = 0;
virtual T& at(size_t i) = 0;
virtual const T& at(size_t i) const = 0;
template<class U>
void indirect_read(const Instruction& inst, vector<T>& regs,
const U& indices);
template<class U>
void indirect_write(const Instruction& inst, vector<T>& regs,
const U& indices);
void minimum_size(size_t size);
};
template<class T, template<class> class V>
class MemoryPartImpl : public MemoryPart<T>, public V<T>
{
public:
size_t size() const
{
return V<T>::size();
}
void resize(size_t size)
{
V<T>::resize(size);
}
T* data()
{
return V<T>::data();
}
const T* data() const
{
return V<T>::data();
}
T& operator[](size_t i)
{
check_index(*this, i);
return CheckVector<T>::operator[](i);
this->check_index(i);
return V<T>::operator[](i);
}
const T& operator[](size_t i) const
{
check_index(*this, i);
return CheckVector<T>::operator[](i);
this->check_index(i);
return V<T>::operator[](i);
}
void minimum_size(size_t size);
T& at(size_t i)
{
return V<T>::at(i);
}
const T& at(size_t i) const
{
return V<T>::at(i);
}
};
template<class T>
@@ -49,8 +105,11 @@ class Memory
{
public:
MemoryPart<T> MS;
MemoryPart<typename T::clear> MC;
MemoryPart<T>& MS;
MemoryPartImpl<typename T::clear, CheckVector> MC;
Memory();
~Memory();
void resize_s(size_t sz)
{ MS.resize(sz); }

View File

@@ -3,6 +3,54 @@
#include <fstream>
template<class T>
template<class U>
void MemoryPart<T>::indirect_read(const Instruction& inst,
vector<T>& regs, const U& indices)
{
size_t n = inst.get_size();
auto dest = regs.begin() + inst.get_r(0);
auto start = indices.begin() + inst.get_r(1);
#ifdef CHECK_SIZE
assert(start + n <= indices.end());
assert(dest + n <= regs.end());
#endif
long size = this->size();
const T* data = this->data();
for (auto it = start; it < start + n; it++)
{
#ifndef NO_CHECK_SIZE
if (*it >= size)
throw overflow(T::type_string() + " memory read", it->get(), size);
#endif
*dest++ = data[it->get()];
}
}
template<class T>
template<class U>
void MemoryPart<T>::indirect_write(const Instruction& inst,
vector<T>& regs, const U& indices)
{
size_t n = inst.get_size();
auto source = regs.begin() + inst.get_r(0);
auto start = indices.begin() + inst.get_r(1);
#ifdef CHECK_SIZE
assert(start + n <= indices.end());
assert(source + n <= regs.end());
#endif
long size = this->size();
T* data = this->data();
for (auto it = start; it < start + n; it++)
{
#ifndef NO_CHECK_SIZE
if (*it >= size)
throw overflow(T::type_string() + " memory write", it->get(), size);
#endif
data[it->get()] = *source++;
}
}
template<class T>
void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
const Program &program, const string& threadname)
@@ -29,6 +77,21 @@ void MemoryPart<T>::minimum_size(size_t size)
}
}
template<class T>
Memory<T>::Memory() :
MS(
*(OnlineOptions::singleton.disk_memory.size() ?
static_cast<MemoryPart<T>*>(new MemoryPartImpl<T, DiskVector>) :
static_cast<MemoryPart<T>*>(new MemoryPartImpl<T, CheckVector>)))
{
}
template<class T>
Memory<T>::~Memory()
{
delete &MS;
}
template<class T>
ostream& operator<<(ostream& s,const Memory<T>& M)
{

View File

@@ -71,18 +71,6 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op
"--ip-file-name" // Flag token.
);
if (nplayers == 0)
opt.add(
"2", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of players (default: 2). "
"Ignored if external server is used.", // Help description.
"-N", // Flag token.
"--nparties" // Flag token.
);
opt.add(
"", // Default.
0, // Required?

View File

@@ -22,12 +22,13 @@ OnlineOptions::OnlineOptions() : playerno(-1)
interactive = false;
lgp = gfp0::MAX_N_BITS;
live_prep = true;
batch_size = 10000;
batch_size = 1000;
memtype = "empty";
bits_from_squares = false;
direct = false;
bucket_size = 4;
security_parameter = DEFAULT_SECURITY;
use_security_parameter = false;
cmd_private_input_file = "Player-Data/Input";
cmd_private_output_file = "";
file_prep_per_thread = false;
@@ -46,6 +47,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
const char** argv, bool security) :
OnlineOptions()
{
use_security_parameter = security;
opt.syntax = std::string(argv[0]) + " [OPTIONS] [<playerno>] <progname>";
opt.add(
@@ -116,7 +119,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
("Security parameter (default: " + to_string(security_parameter)
("Statistical ecurity parameter (default: " + to_string(security_parameter)
+ ")").c_str(), // Help description.
"-S", // Flag token.
"--security" // Flag token.
@@ -138,7 +141,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
if (security)
{
opt.get("-S")->getInt(security_parameter);
cerr << "Using security parameter " << security_parameter << endl;
if (security_parameter <= 0)
{
cerr << "Invalid security parameter: " << security_parameter << endl;
@@ -280,7 +282,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
}
void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
const char** argv)
const char** argv, bool networking)
{
opt.resetArgs();
opt.parse(argc, argv);
@@ -292,17 +294,21 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
vector<string> badOptions;
unsigned int i;
opt.footer += "\nSee also https://mp-spdz.readthedocs.io/en/latest/networking.html "
"for documentation on the networking setup.\n";
if (networking)
opt.footer += "See also "
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
"for documentation on the networking setup.\n\n";
if (allArgs.size() != 3u - opt.isSet("-p"))
size_t name_index = 1 + networking - opt.isSet("-p");
if (allArgs.size() < name_index + 1)
{
opt.getUsage(usage);
cout << usage;
cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl;
cerr << "Arguments given were:\n";
for (unsigned int j = 1; j < allArgs.size(); j++)
cout << "'" << *allArgs[j] << "'" << endl;
opt.getUsage(usage);
cout << usage;
exit(1);
}
else
@@ -311,25 +317,25 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
opt.get("-p")->getInt(playerno);
else
sscanf((*allArgs[1]).c_str(), "%d", &playerno);
progname = *allArgs[2 - opt.isSet("-p")];
progname = *allArgs.at(name_index);
}
if (!opt.gotRequired(badOptions))
{
for (i = 0; i < badOptions.size(); ++i)
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
opt.getUsage(usage);
cout << usage;
for (i = 0; i < badOptions.size(); ++i)
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
exit(1);
}
if (!opt.gotExpected(badOptions))
{
opt.getUsage(usage);
cout << usage;
for (i = 0; i < badOptions.size(); ++i)
cerr << "ERROR: Got unexpected number of arguments for option "
<< badOptions[i] << ".";
opt.getUsage(usage);
cout << usage;
exit(1);
}
@@ -347,6 +353,22 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
prime = schedule_prime;
}
for (size_t i = name_index + 1; i < allArgs.size(); i++)
{
try
{
args.push_back(stol(*allArgs[i]));
}
catch (exception& e)
{
opt.getUsage(usage);
cerr << usage;
cerr << "Additional argument has to be integer: " << *allArgs[i]
<< endl;
exit(1);
}
}
// ignore program if length explicitly set from command line
if (opt.get("-lgp") and not opt.isSet("-lgp"))
{
@@ -367,7 +389,29 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
if (o)
o->getInt(max_broadcast);
o = opt.get("--disk-memory");
if (o)
o->getString(disk_memory);
receive_threads = opt.isSet("--threads");
if (use_security_parameter)
{
int program_sec = BaseMachine::security_from_schedule(progname);
if (program_sec > 0)
{
if (not opt.isSet("-S"))
security_parameter = program_sec;
if (program_sec < security_parameter)
{
cerr << "Security parameter used in compilation is insufficient" << endl;
exit(1);
}
}
cerr << "Using statistical security parameter " << security_parameter << endl;
}
}
void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt)

View File

@@ -27,6 +27,7 @@ public:
bool direct;
int bucket_size;
int security_parameter;
bool use_security_parameter;
std::string cmd_private_input_file;
std::string cmd_private_output_file;
bool verbose;
@@ -34,6 +35,8 @@ public:
int trunc_error;
int opening_sum, max_broadcast;
bool receive_threads;
std::string disk_memory;
vector<long> args;
OnlineOptions();
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
@@ -48,7 +51,8 @@ public:
OnlineOptions(T);
~OnlineOptions() {}
void finalize(ez::ezOptionParser& opt, int argc, const char** argv);
void finalize(ez::ezOptionParser& opt, int argc, const char** argv,
bool networking = true);
void set_trunc_error(ez::ezOptionParser& opt);

View File

@@ -11,7 +11,7 @@
template<class T>
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
const char** argv, T, bool default_live_prep) :
OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0,
OnlineOptions(opt, argc, argv, OnlineOptions(T()).batch_size,
default_live_prep, T::clear::prime_field)
{
if (T::has_trunc_pr)
@@ -56,13 +56,39 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"--max-broadcast" // Flag token.
);
}
if (not T::clear::binary)
opt.add(
"", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Use directory on disk for memory (container data structures) "
"instead of RAM", // Help description.
"-D", // Flag token.
"--disk-memory" // Flag token.
);
if (T::variable_players)
opt.add(
T::dishonest_majority ? "2" : "3", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
("Number of players (default: "
+ (T::dishonest_majority ?
to_string("2") : to_string("3")) + "). " +
"Ignored if external server is used.").c_str(), // Help description.
"-N", // Flag token.
"--nparties" // Flag token.
);
}
template<class T>
OnlineOptions::OnlineOptions(T) : OnlineOptions()
{
if (T::dishonest_majority)
batch_size = 1000;
if (not T::dishonest_majority)
batch_size = 10000;
}
#endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */

View File

@@ -36,7 +36,7 @@ class SubProcessor
void resize(size_t size) { C.resize(size); S.resize(size); }
void matmulsm_prep(int ii, int j, const CheckVector<T>& source,
void matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
const vector<int>& dim, size_t a, size_t b);
void matmulsm_finalize(int i, int j, const vector<int>& dim,
typename vector<T>::iterator C);
@@ -48,6 +48,8 @@ class SubProcessor
typedef typename T::bit_type::part_type BT;
typedef typename T::Protocol::Shuffler::store_type ShuffleStore;
public:
ArithmeticProcessor* Proc;
typename T::MAC_Check& MC;
@@ -71,19 +73,19 @@ public:
// Access to PO (via calls to POpen start/stop)
void POpen(const Instruction& inst);
void muls(const vector<int>& reg, int size);
void muls(const vector<int>& reg);
void mulrs(const vector<int>& reg);
void dotprods(const vector<int>& reg, int size);
void matmuls(const vector<T>& source, const Instruction& instruction, size_t a,
size_t b);
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, size_t a,
void matmuls(const vector<T>& source, const Instruction& instruction);
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction, size_t a,
size_t b);
void conv2ds(const Instruction& instruction);
void secure_shuffle(const Instruction& instruction);
size_t generate_secure_shuffle(const Instruction& instruction);
void apply_shuffle(const Instruction& instruction, int handle);
void delete_shuffle(int handle);
size_t generate_secure_shuffle(const Instruction& instruction,
ShuffleStore& shuffle_store);
void apply_shuffle(const Instruction& instruction, int handle,
ShuffleStore& shuffle_store);
void inverse_permutation(const Instruction& instruction);
void input_personal(const vector<int>& args);
@@ -116,7 +118,7 @@ public:
class ArithmeticProcessor : public ProcessorBase
{
protected:
CheckVector<long> Ci;
CheckVector<Integer> Ci;
ofstream public_output;
ofstream binary_output;
@@ -162,13 +164,13 @@ public:
return thread_num;
}
const long& read_Ci(size_t i) const
{ return Ci[i]; }
long& get_Ci_ref(size_t i)
long read_Ci(size_t i) const
{ return Ci[i].get(); }
Integer& get_Ci_ref(size_t i)
{ return Ci[i]; }
void write_Ci(size_t i, const long& x)
{ Ci[i]=x; }
CheckVector<long>& get_Ci()
CheckVector<Integer>& get_Ci()
{ return Ci; }
virtual ofstream& get_public_output()

View File

@@ -379,9 +379,20 @@ void Processor<sint, sgf2n>::read_socket_private(int client_id,
client_timer.stop();
client_stats.add(socket_stream.get_length());
for (int j = 0; j < size; j++)
for (int i = 0; i < m; i++)
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
int j, i;
try
{
for (j = 0; j < size; j++)
for (i = 0; i < m; i++)
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
}
catch (exception& e)
{
throw insufficient_shares(m * size, j * m + i, e);
}
if (socket_stream.left())
throw runtime_error("unexpected share data");
}
@@ -468,28 +479,29 @@ void SubProcessor<T>::POpen(const Instruction& inst)
}
template<class T>
void SubProcessor<T>::muls(const vector<int>& reg, int size)
void SubProcessor<T>::muls(const vector<int>& reg)
{
assert(reg.size() % 3 == 0);
int n = reg.size() / 3;
assert(reg.size() % 4 == 0);
int n = reg.size() / 4;
SubProcessor<T>& proc = *this;
protocol.init_mul();
for (int i = 0; i < n; i++)
for (int j = 0; j < size; j++)
for (int j = 0; j < reg[4 * i]; j++)
{
auto& x = proc.S[reg[3 * i + 1] + j];
auto& y = proc.S[reg[3 * i + 2] + j];
auto& x = proc.S[reg[4 * i + 2] + j];
auto& y = proc.S[reg[4 * i + 3] + j];
protocol.prepare_mul(x, y);
}
protocol.exchange();
for (int i = 0; i < n; i++)
for (int j = 0; j < size; j++)
{
for (int j = 0; j < reg[4 * i]; j++)
{
proc.S[reg[3 * i] + j] = protocol.finalize_mul();
proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul();
}
protocol.counter += n * size;
protocol.counter += n * reg[4 * i];
}
}
template<class T>
@@ -553,33 +565,46 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
template<class T>
void SubProcessor<T>::matmuls(const vector<T>& source,
const Instruction& instruction, size_t a, size_t b)
const Instruction& instruction)
{
auto& dim = instruction.get_start();
auto A = source.begin() + a;
auto B = source.begin() + b;
auto C = S.begin() + (instruction.get_r(0));
assert(A + dim[0] * dim[1] <= source.end());
assert(B + dim[1] * dim[2] <= source.end());
assert(C + dim[0] * dim[2] <= S.end());
protocol.init_dotprod();
for (int i = 0; i < dim[0]; i++)
for (int j = 0; j < dim[2]; j++)
{
for (int k = 0; k < dim[1]; k++)
protocol.prepare_dotprod(*(A + i * dim[1] + k),
*(B + k * dim[2] + j));
protocol.next_dotprod();
}
auto& start = instruction.get_start();
assert(start.size() % 6 == 0);
for(auto it = start.begin(); it < start.end(); it += 6)
{
auto dim = it + 3;
auto A = source.begin() + *(it + 1);
auto B = source.begin() + *(it + 2);
assert(A + dim[0] * dim[1] <= source.end());
assert(B + dim[1] * dim[2] <= source.end());
for (int i = 0; i < dim[0]; i++)
for (int j = 0; j < dim[2]; j++)
{
for (int k = 0; k < dim[1]; k++)
protocol.prepare_dotprod(*(A + i * dim[1] + k),
*(B + k * dim[2] + j));
protocol.next_dotprod();
}
}
protocol.exchange();
for (int i = 0; i < dim[0]; i++)
for (int j = 0; j < dim[2]; j++)
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
for(auto it = start.begin(); it < start.end(); it += 6)
{
auto C = S.begin() + *it;
auto dim = it + 3;
assert(C + dim[0] * dim[2] <= S.end());
for (int i = 0; i < dim[0]; i++)
for (int j = 0; j < dim[2]; j++)
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
}
}
template<class T>
void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
const Instruction& instruction, size_t a, size_t b)
{
auto& dim = instruction.get_start();
@@ -592,7 +617,7 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
protocol.init_dotprod();
for (int i = 0; i < dim[0]; i++)
{
auto ii = Proc->get_Ci().at(dim[3] + i);
auto ii = Proc->get_Ci().at(dim[3] + i).get();
for (int j = 0; j < dim[2]; j++)
{
#ifdef DEBUG_MATMULSM
@@ -628,16 +653,21 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
}
template<class T>
void SubProcessor<T>::matmulsm_prep(int ii, int j, const CheckVector<T>& source,
void SubProcessor<T>::matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
const vector<int>& dim, size_t a, size_t b)
{
auto jj = Proc->get_Ci().at(dim[6] + j);
auto jj = Proc->get_Ci().at(dim[6] + j).get();
const T* base = source.data();
size_t size = source.size();
for (int k = 0; k < dim[1]; k++)
{
auto kk = Proc->get_Ci().at(dim[4] + k);
auto ll = Proc->get_Ci().at(dim[5] + k);
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
source.at(b + ll * dim[8] + jj));
auto kk = Proc->get_Ci().at(dim[4] + k).get();
auto ll = Proc->get_Ci().at(dim[5] + k).get();
auto aa = a + ii * dim[7] + kk;
auto bb = b + ll * dim[8] + jj;
assert(aa < size);
assert(bb < size);
protocol.prepare_dotprod(base[aa], base[bb]);
}
protocol.next_dotprod();
}
@@ -655,16 +685,22 @@ void SubProcessor<T>::matmulsm_finalize(int i, int j, const vector<int>& dim,
template<class T>
void SubProcessor<T>::conv2ds(const Instruction& instruction)
{
protocol.init_dotprod();
auto& args = instruction.get_start();
vector<Conv2dTuple> tuples;
for (size_t i = 0; i < args.size(); i += 15)
tuples.push_back(Conv2dTuple(args, i));
for (auto& tuple : tuples)
tuple.pre(S, protocol);
protocol.exchange();
for (auto& tuple : tuples)
tuple.post(S, protocol);
size_t done = 0;
while (done < tuples.size())
{
protocol.init_dotprod();
size_t i;
for (i = done; i < tuples.size() and protocol.get_buffer_size() <
OnlineOptions::singleton.batch_size; i++)
tuples[i].pre(S, protocol);
protocol.exchange();
for (; done < i; done++)
tuples[done].post(S, protocol);
}
}
inline
@@ -766,25 +802,22 @@ void SubProcessor<T>::secure_shuffle(const Instruction& instruction)
}
template<class T>
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction)
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
ShuffleStore& shuffle_store)
{
return shuffler.generate(instruction.get_n());
return shuffler.generate(instruction.get_n(), shuffle_store);
}
template<class T>
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle)
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle,
ShuffleStore& shuffle_store)
{
shuffler.apply(S, instruction.get_size(), instruction.get_start()[2],
instruction.get_start()[0], instruction.get_start()[1], handle,
instruction.get_start()[0], instruction.get_start()[1],
shuffle_store.get(handle),
instruction.get_start()[4]);
}
template<class T>
void SubProcessor<T>::delete_shuffle(int handle)
{
shuffler.del(handle);
}
template<class T>
void SubProcessor<T>::inverse_permutation(const Instruction& instruction) {
shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0],
@@ -796,17 +829,26 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
{
input.reset_all(P);
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
if (args[i + 1] == P.my_num())
{
if (args[i + 1] == P.my_num())
input.add_mine(C[args[i + 3] + j]);
else
input.add_other(args[i + 1]);
auto begin = C.begin() + args[i + 3];
auto end = begin + args[i];
assert(end <= C.end());
for (auto it = begin; it < end; it++)
input.add_mine(*it);
}
else
for (int j = 0; j < args[i]; j++)
input.add_other(args[i + 1]);
input.exchange();
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
S[args[i + 2] + j] = input.finalize(args[i + 1]);
{
auto begin = S.begin() + args[i + 2];
auto end = begin + args[i];
assert(end <= S.end());
for (auto it = begin; it < end; it++)
*it = input.finalize(args[i + 1]);
}
}
/**
@@ -858,6 +900,16 @@ typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
return inverses2m[m];
}
template<class T, class U>
void fixinput_int(T& proc, const Instruction& instruction, U)
{
U* x = new U[instruction.get_size()];
proc.binary_input.read((char*) x, sizeof(U) * instruction.get_size());
for (int i = 0; i < instruction.get_size(); i++)
proc.write_Cp(instruction.get_r(0) + i, x[i]);
delete[] x;
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
{
@@ -878,19 +930,24 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
throw runtime_error("unknown format for fixed-point input");
}
for (int i = 0; i < instruction.get_size(); i++)
if (binary_input.fail())
throw IO_Error("failure reading from " + binary_input_filename);
if (binary_input.peek() == EOF)
throw IO_Error("not enough inputs in " + binary_input_filename);
if (instruction.get_r(2) == 0)
{
if (binary_input.peek() == EOF)
throw IO_Error("not enough inputs in " + binary_input_filename);
double buf;
if (instruction.get_r(2) == 0)
{
int64_t x;
binary_input.read((char*) &x, sizeof(x));
tmp = x;
}
if (instruction.get_r(1) == 1)
fixinput_int(*this, instruction, int8_t());
else
fixinput_int(*this, instruction, int64_t());
}
else
{
for (int i = 0; i < instruction.get_size(); i++)
{
double buf;
if (use_double)
binary_input.read((char*) &buf, sizeof(double));
else
@@ -900,11 +957,12 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
buf = x;
}
tmp = bigint::tmp = round(buf * exp2(instruction.get_r(1)));
write_Cp(instruction.get_r(0) + i, tmp);
}
if (binary_input.fail())
throw IO_Error("failure reading from " + binary_input_filename);
write_Cp(instruction.get_r(0) + i, tmp);
}
if (binary_input.fail())
throw IO_Error("failure reading from " + binary_input_filename);
}
}

View File

@@ -14,11 +14,12 @@ using namespace std;
#include "Tools/ExecutionStats.h"
#include "Tools/SwitchableOutput.h"
#include "OnlineOptions.h"
#include "Math/Integer.h"
class ProcessorBase
{
// Stack
stack<long> stacki;
stack<Integer> stacki;
ifstream input_file;
string input_filename;
@@ -26,7 +27,7 @@ class ProcessorBase
protected:
// Optional argument to tape
int arg;
Integer arg;
string get_parameterized_filename(int my_num, int thread_num,
const string& prefix);
@@ -38,15 +39,15 @@ public:
ProcessorBase();
void pushi(long x) { stacki.push(x); }
void popi(long& x) { x = stacki.top(); stacki.pop(); }
void pushi(Integer x) { stacki.push(x); }
void popi(Integer& x) { x = stacki.top(); stacki.pop(); }
int get_arg() const
Integer get_arg() const
{
return arg;
}
void set_arg(int new_arg)
void set_arg(Integer new_arg)
{
arg=new_arg;
}

View File

@@ -41,6 +41,7 @@ template<template<int L> class U, template<class T> class V, class W>
RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers)
{
assert(nplayers or U<64>::variable_players);
RingOptions opts(opt, argc, argv);
W machine(argc, argv, opt, online_opts, gf2n(), nplayers);
int R = opts.ring_size_from_opts_or_schedule(online_opts.progname);
@@ -65,7 +66,7 @@ template<template<int K, int S> class U, template<class T> class V>
HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecurity(
int argc, const char** argv, ez::ezOptionParser& opt)
{
OnlineOptions online_opts(opt, argc, argv);
OnlineOptions online_opts(opt, argc, argv, U<64, 40>());
RingOptions opts(opt, argc, argv);
HonestMajorityMachine machine(argc, argv, opt, online_opts);
int R = opts.ring_size_from_opts_or_schedule(online_opts.progname);

View File

@@ -18,10 +18,10 @@
*dest++ = *source++) \
X(STMS, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.machine.Mp.MS[n], \
*dest++ = *source++) \
X(LDMSI, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
*dest++ = Proc.machine.Mp.read_S(*source++)) \
X(STMSI, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
Proc.machine.Mp.write_S(*dest++, *source++)) \
X(LDMSI, Proc.machine.Mp.MS.indirect_read(instruction, Procp.get_S(), Proc.get_Ci()),) \
X(STMSI, Proc.machine.Mp.MS.indirect_write(instruction, Procp.get_S(), Proc.get_Ci()),) \
X(LDMCI, Proc.machine.Mp.MC.indirect_read(instruction, Procp.get_C(), Proc.get_Ci()),) \
X(STMCI, Proc.machine.Mp.MC.indirect_write(instruction, Procp.get_C(), Proc.get_Ci()),) \
X(MOVS, auto dest = &Procp.get_S()[r[0]]; auto source = &Procp.get_S()[r[1]], \
*dest++ = *source++) \
X(ADDS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
@@ -121,10 +121,8 @@
*dest++ = *source++) \
X(GSTMS, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.machine.M2.MS[n], \
*dest++ = *source++) \
X(GLDMSI, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
*dest++ = Proc.machine.M2.read_S(*source++)) \
X(GSTMSI, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
Proc.machine.M2.write_S(*dest++, *source++)) \
X(GLDMSI, Proc.machine.M2.MS.indirect_read(instruction, Proc2.get_S(), Proc.get_Ci()),) \
X(GSTMSI, Proc.machine.M2.MS.indirect_write(instruction, Proc2.get_S(), Proc.get_Ci()),) \
X(GMOVS, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc2.get_S()[r[1]], \
*dest++ = *source++) \
X(GADDS, auto dest = &Proc2.get_S()[r[0]]; auto op1 = &Proc2.get_S()[r[1]]; \
@@ -171,10 +169,8 @@
*dest++ = (*source).get(); source++) \
X(STMINT, auto dest = &Mi[n]; auto source = &Proc.get_Ci()[r[0]], \
*dest++ = *source++) \
X(LDMINTI, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \
*dest++ = Mi[*source].get(); source++) \
X(STMINTI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &Ci[r[0]], \
Mi[*dest] = *source++; dest++) \
X(LDMINTI, Mi.indirect_read(*this, Proc.get_Ci(), Proc.get_Ci()),) \
X(STMINTI, Mi.indirect_write(*this, Proc.get_Ci(), Proc.get_Ci()),) \
X(MOVINT, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \
*dest++ = *source++) \
X(PUSHINT, Proc.pushi(Ci[r[0]]),) \
@@ -213,7 +209,7 @@
X(SHUFFLE, shuffle(Proc),) \
X(BITDECINT, bitdecint(Proc),) \
X(RAND, auto dest = &Ci[r[0]]; auto source = &Ci[r[1]], \
*dest++ = Proc.shared_prng.get_uint() % (1 << *source++)) \
*dest++ = Proc.shared_prng.get_uint() % (1 << (*source++).get())) \
#define CLEAR_GF2N_INSTRUCTIONS \
X(GLDI, auto dest = &C2[r[0]]; cgf2n tmp = int(n), \
@@ -222,10 +218,8 @@
*dest++ = (*source).get(); source++) \
X(GSTMC, auto dest = &M2C[n]; auto source = &C2[r[0]], \
*dest++ = *source++) \
X(GLDMCI, auto dest = &C2[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
*dest++ = M2C[*source++]) \
X(GSTMCI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &C2[r[0]], \
M2C[*dest++] = *source++) \
X(GLDMCI, M2C.indirect_read(*this, C2, Proc.get_Ci()),) \
X(GSTMCI, M2C.indirect_write(*this, C2, Proc.get_Ci()),) \
X(GMOVC, auto dest = &C2[r[0]]; auto source = &C2[r[1]], \
*dest++ = *source++) \
X(GADDC, auto dest = &C2[r[0]]; auto op1 = &C2[r[1]]; \
@@ -288,9 +282,7 @@
#define REMAINING_INSTRUCTIONS \
X(CONVMODP, throw not_implemented(),) \
X(LDMC, throw not_implemented(),) \
X(LDMCI, throw not_implemented(),) \
X(STMC, throw not_implemented(),) \
X(STMCI, throw not_implemented(),) \
X(MOVC, throw not_implemented(),) \
X(DIVC, throw not_implemented(),) \
X(GDIVC, throw not_implemented(),) \
@@ -390,6 +382,8 @@
X(APPLYSHUFFLE, throw not_implemented(),) \
X(DELSHUFFLE, throw not_implemented(),) \
X(ACTIVE, throw not_implemented(),) \
X(FIXINPUT, throw not_implemented(),) \
X(CONCATS, throw not_implemented(),) \
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS

View File

@@ -51,7 +51,7 @@ except:
pass
layers = [
ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [n_examples, 24, 24, 16],
ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [N, 24, 24, 16],
(1, 1), 'VALID'),
ml.MaxPool([N, 24, 24, 16]),
ml.Relu([N, 12, 12, 16]),

View File

@@ -0,0 +1,34 @@
import random
n_nodes_per_party = int(program.args[1])
n_threads_per_node = int(program.args[2])
n_ops_per_thread = int(program.args[3])
n_ops_per_node = n_threads_per_node * n_ops_per_thread
n_ops = n_nodes_per_party * n_ops_per_node
data = Array.create_from(sint(regint.inc(n_ops)))
listen_for_clients(15000)
ready = regint.Array(n_nodes_per_party)
@for_range(n_nodes_per_party)
def _(i):
ready[accept_client_connection(15000)] = 1
runtime_error_if(sum(ready) != n_nodes_per_party, 'connection problems')
@for_range(n_nodes_per_party)
def _(i):
data.get_vector(base=i * n_ops_per_node,
size=n_ops_per_node).write_fully_to_socket(i)
@for_range(n_nodes_per_party)
def _(i):
data.assign_vector(sint.read_from_socket(i, size=n_ops_per_node),
base=i * n_ops_per_node)
for i in range(10):
index = random.randrange(n_ops)
value = data[index].reveal()
runtime_error_if(value != index ** 2, '%s != %s', value, index ** 2)

View File

@@ -0,0 +1,21 @@
n_threads = int(program.args[1])
n_ops_per_thread = int(program.args[2])
worker_id = int(program.args[3])
if len(program.args) > 4:
host = program.args[4]
else:
host = 'localhost'
n_ops = n_threads * n_ops_per_thread
data = sint.Array(n_ops)
main = init_client_connection(host, 15000, worker_id)
data.read_from_socket(main)
@for_range_opt_multithread(n_threads, n_ops)
def _(i):
data[i] = data[i] ** 2
data.write_to_socket(main)

View File

@@ -1,4 +1,5 @@
# sint: secret integers
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint
# you can assign public numbers to sint
@@ -14,6 +15,7 @@ def test(actual, expected):
# private inputs are read from Player-Data/Input-P<i>-0
# or from standard input if using command-line option -I
# see https://mp-spdz.readthedocs.io/en/latest/io.html for more options
for i in 0, 1:
print_ln('got %s from player %s', sint.get_input_from(i).reveal(), i)
@@ -62,6 +64,7 @@ test(a[99], 99 * 98)
# test(a, 99)
# sfix: fixed-point numbers
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix
# set the precision after the dot and in total

View File

@@ -35,6 +35,11 @@ public:
typedef GC::AtlasSecret bit_type;
#endif
static string alt()
{
return "";
}
AtlasShare()
{
}

View File

@@ -7,6 +7,7 @@
#define PROTOCOLS_FAKEPROTOCOL_H_
#include "Replicated.h"
#include "SecureShuffle.h"
#include "Math/Z2k.h"
#include "Processor/Instruction.h"
#include "Processor/TruncPrTuple.h"
@@ -17,6 +18,8 @@ template<class T>
class FakeShuffle
{
public:
typedef ShuffleStore<int> store_type;
FakeShuffle(SubProcessor<T>&)
{
}
@@ -27,9 +30,9 @@ public:
apply(a, n, unit_size, output_base, input_base, 0, 0);
}
size_t generate(size_t)
size_t generate(size_t, store_type& store)
{
return 0;
return store.add();
}
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
@@ -49,10 +52,6 @@ public:
}
}
void del(size_t)
{
}
void inverse_permutation(vector<T>&, size_t, size_t, size_t)
{
}
@@ -280,6 +279,19 @@ public:
}
}
}
else if (tag == string("EQZ\0", 4))
{
for (size_t i = 0; i < args.size(); i += args[i])
{
assert(i + args[i] <= args.size());
assert(args[i] == 6);
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
res = processor.get_S()[args[i + 3] + j] == 0;
}
}
}
else if (tag == "Trun")
{
for (size_t i = 0; i < args.size(); i += args[i])

View File

@@ -35,6 +35,7 @@ public:
static const bool dishonest_majority = false;
static const bool malicious = false;
static const bool is_real = false;
static const bool variable_players = false;
static string type_short()
{

View File

@@ -33,7 +33,7 @@ public:
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
SubProcessor<T>& processor);
void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
void matmulsm(SubProcessor<T>& processor, MemoryPart<T>& source,
const Instruction& instruction, int a, int b);
void conv2ds(SubProcessor<T>& processor, const Instruction& instruction);
};

View File

@@ -33,7 +33,7 @@ typename T::MatrixPrep& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
}
template<class T>
void Hemi<T>::matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
void Hemi<T>::matmulsm(SubProcessor<T>& processor, MemoryPart<T>& source,
const Instruction& instruction, int a, int b)
{
if (HemiOptions::singleton.plain_matmul
@@ -61,16 +61,16 @@ void Hemi<T>::matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
for (int i = 0; i < dim[0]; i++)
for (int k = 0; k < dim[1]; k++)
{
auto kk = Proc->get_Ci().at(dim[4] + k);
auto ii = Proc->get_Ci().at(dim[3] + i);
auto kk = Proc->get_Ci().at(dim[4] + k).get();
auto ii = Proc->get_Ci().at(dim[3] + i).get();
A.entries.v.push_back(source.at(a + ii * dim[7] + kk));
}
for (int k = 0; k < dim[1]; k++)
for (int j = 0; j < dim[2]; j++)
{
auto jj = Proc->get_Ci().at(dim[6] + j);
auto ll = Proc->get_Ci().at(dim[5] + k);
auto jj = Proc->get_Ci().at(dim[6] + j).get();
auto ll = Proc->get_Ci().at(dim[5] + k).get();
B.entries.v.push_back(source.at(b + ll * dim[8] + jj));
}

View File

@@ -34,6 +34,11 @@ public:
static const bool local_mul = true;
static true_type triple_matmul;
static string alt()
{
return "Temi";
}
HemiShare()
{
}

View File

@@ -250,8 +250,11 @@ TreeSum<T>::~TreeSum()
template<class T>
void TreeSum<T>::run(vector<T>& values, const Player& P)
{
start(values, P);
finish(values, P);
if (not values.empty())
{
start(values, P);
finish(values, P);
}
}
template<class T>
@@ -300,9 +303,11 @@ void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
P.wait_receive(sender, oss[j]);
MC.player_timers[sender].stop();
MC.timers[SUM].start();
T tmp = values.at(0);
for (unsigned int i=0; i<values.size(); i++)
{
values[i].add(oss[j], use_lengths ? lengths[i] : -1);
tmp.unpack(oss[j], use_lengths ? lengths[i] : -1);
values[i] += tmp;
}
post_add_process(values);
MC.timers[SUM].stop();

View File

@@ -176,7 +176,7 @@ void MAC_Check_<U>::Check(const Player& P)
for (auto& os : bundle)
if (&os != &bundle.mine)
delta += os.get<typename U::mac_type>();
if (not delta.is_zero())
if (delta != 0)
throw mac_fail();
}
}
@@ -194,8 +194,6 @@ void MAC_Check_<U>::Check(const Player& P)
typename U::mac_type a,gami,temp;
typename U::mac_type::Scalar h;
vector<typename U::mac_type> tau(P.num_players());
a.assign_zero();
gami.assign_zero();
for (int i=0; i<popen_cnt; i++)
{
h.almost_randomize(G);
@@ -217,10 +215,10 @@ void MAC_Check_<U>::Check(const Player& P)
//cerr << "\tFinal Check" << endl;
typename U::mac_type t;
t.assign_zero();
for (int i=0; i<P.num_players(); i++)
{ t += tau[i]; }
if (!t.is_zero()) { throw mac_fail(); }
if (t != 0)
throw mac_fail();
}
vals.erase(vals.begin(), vals.begin() + popen_cnt);

View File

@@ -136,6 +136,12 @@ TripleShuffleSacrifice<T>::TripleShuffleSacrifice(int B, int C) :
{
}
template<class T>
TripleShuffleSacrifice<T>::TripleShuffleSacrifice(DataFieldType type) :
ShuffleSacrifice(BaseMachine::bucket_size(type))
{
}
template<class T>
void TripleShuffleSacrifice<T>::triple_sacrifice(vector<array<T, 3>>& triples,
vector<array<T, 3>>& check_triples, Player& P,

View File

@@ -44,6 +44,10 @@ public:
// default private output facility (using input tuples)
typedef ::PrivateOutput<NoShare> PrivateOutput;
// indicate whether protocol allows dishonest majority and variable players
static const bool dishonest_majority = true;
static const bool variable_players = true;
// description used for debugging output
static string type_string()
{
@@ -187,4 +191,11 @@ public:
}
};
template<class T>
inline ostream& operator<<(ostream& o, NoShare<T>)
{
throw runtime_error("no output");
return o;
}
#endif /* PROTOCOLS_NOSHARE_H_ */

View File

@@ -6,12 +6,17 @@
#ifndef PROTOCOLS_REP3SHUFFLER_H_
#define PROTOCOLS_REP3SHUFFLER_H_
#include "SecureShuffle.h"
template<class T>
class Rep3Shuffler
{
SubProcessor<T>& proc;
public:
typedef array<vector<int>, 2> shuffle_type;
typedef ShuffleStore<shuffle_type> store_type;
vector<array<vector<int>, 2>> shuffles;
private:
SubProcessor<T>& proc;
public:
Rep3Shuffler(vector<T>& a, size_t n, int unit_size, size_t output_base,
@@ -19,15 +24,13 @@ public:
Rep3Shuffler(SubProcessor<T>& proc);
int generate(int n_shuffle);
int generate(int n_shuffle, store_type& store);
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int handle, bool reverse);
size_t input_base, shuffle_type& shuffle, bool reverse);
void inverse_permutation(vector<T>& stack, size_t n, size_t output_base,
size_t input_base);
void del(int handle);
};
#endif /* PROTOCOLS_REP3SHUFFLER_H_ */

View File

@@ -13,9 +13,10 @@ Rep3Shuffler<T>::Rep3Shuffler(vector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, SubProcessor<T>& proc) :
proc(proc)
{
apply(a, n, unit_size, output_base, input_base, generate(n / unit_size),
store_type store;
int handle = generate(n / unit_size, store);
apply(a, n, unit_size, output_base, input_base, store.get(handle),
false);
shuffles.pop_back();
}
template<class T>
@@ -25,10 +26,10 @@ Rep3Shuffler<T>::Rep3Shuffler(SubProcessor<T>& proc) :
}
template<class T>
int Rep3Shuffler<T>::generate(int n_shuffle)
int Rep3Shuffler<T>::generate(int n_shuffle, store_type& store)
{
shuffles.push_back({});
auto& shuffle = shuffles.back();
int res = store.add();
auto& shuffle = store.get(res);
for (int i = 0; i < 2; i++)
{
auto& perm = shuffle[i];
@@ -40,19 +41,22 @@ int Rep3Shuffler<T>::generate(int n_shuffle)
swap(perm[k], perm[k + j]);
}
}
return shuffles.size() - 1;
return res;
}
template<class T>
void Rep3Shuffler<T>::apply(vector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, int handle, bool reverse)
size_t output_base, size_t input_base, shuffle_type& shuffle,
bool reverse)
{
assert(proc.P.num_players() == 3);
assert(not T::malicious);
assert(not T::dishonest_majority);
assert(n % unit_size == 0);
auto& shuffle = shuffles.at(handle);
if (shuffle.empty())
throw runtime_error("shuffle has been deleted");
vector<T> to_shuffle;
for (size_t i = 0; i < n; i++)
to_shuffle.push_back(a[input_base + i]);
@@ -115,12 +119,6 @@ void Rep3Shuffler<T>::apply(vector<T>& a, size_t n, int unit_size,
a[output_base + i] = to_shuffle[i];
}
template<class T>
void Rep3Shuffler<T>::del(int handle)
{
shuffles.at(handle) = {};
}
template<class T>
void Rep3Shuffler<T>::inverse_permutation(vector<T>&, size_t, size_t, size_t)
{

View File

@@ -41,6 +41,7 @@ public:
typedef GC::Rep4Secret bit_type;
static const bool malicious = true;
static const bool variable_players = false;
static string type_short()
{

View File

@@ -15,6 +15,7 @@ using namespace std;
#include "Tools/random.h"
#include "Tools/PointerVector.h"
#include "Networking/Player.h"
#include "Processor/Memory.h"
template<class T> class SubProcessor;
template<class T> class ReplicatedMC;
@@ -68,8 +69,6 @@ public:
ProtocolBase();
virtual ~ProtocolBase();
void muls(const vector<int>& reg, SubProcessor<T>& proc, typename T::MAC_Check& MC,
int size);
void mulrs(const vector<int>& reg, SubProcessor<T>& proc);
void multiply(vector<T>& products, vector<pair<T, T>>& multiplicands,
@@ -111,7 +110,7 @@ public:
virtual void randoms_inst(vector<T>&, const Instruction&);
template<int = 0>
void matmulsm(SubProcessor<T> & proc, CheckVector<T>& source,
void matmulsm(SubProcessor<T> & proc, MemoryPart<T>& source,
const Instruction& instruction, int a, int b)
{ proc.matmulsm(source, instruction, a, b); }

View File

@@ -78,14 +78,6 @@ ProtocolBase<T>::~ProtocolBase()
#endif
}
template<class T>
void ProtocolBase<T>::muls(const vector<int>& reg,
SubProcessor<T>& proc, typename T::MAC_Check& MC, int size)
{
(void)MC;
proc.muls(reg, size);
}
template<class T>
void ProtocolBase<T>::mulrs(const vector<int>& reg,
SubProcessor<T>& proc)

View File

@@ -9,18 +9,42 @@
#include <vector>
using namespace std;
#include "Tools/Lock.h"
template<class T> class SubProcessor;
template<class T>
class ShuffleStore
{
typedef T shuffle_type;
deque<shuffle_type> shuffles;
Lock store_lock;
void lock();
void unlock();
public:
int add();
shuffle_type& get(int handle);
void del(int handle);
};
template<class T>
class SecureShuffle
{
public:
typedef vector<vector<vector<T>>> shuffle_type;
typedef ShuffleStore<shuffle_type> store_type;
private:
SubProcessor<T>& proc;
vector<T> to_shuffle;
vector<vector<T>> config;
vector<T> tmp;
int unit_size;
vector<vector<vector<vector<T>>>> shuffles;
size_t n_shuffle;
bool exact;
@@ -62,7 +86,7 @@ public:
SecureShuffle(SubProcessor<T>& proc);
int generate(int n_shuffle);
int generate(int n_shuffle, store_type& store);
/**
*
@@ -73,12 +97,12 @@ public:
* would result in [3,4,1,2]
* @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to)
* @param input_base The starting address of the input vector (i.e. the location from which to read the permutation)
* @param handle The integer identifying the preconfigured waksman network (shuffle) to use. Such a handle can be obtained from calling
* @param shuffle The preconfigured waksman network (shuffle) to use
* @param reverse Boolean indicating whether to apply the inverse of the permutation
* @see SecureShuffle::generate for obtaining a shuffle handle
*/
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int handle, bool reverse);
size_t input_base, shuffle_type& shuffle, bool reverse);
/**
* Calculate the secret inverse permutation of stack given secret permutation.
@@ -94,8 +118,6 @@ public:
* @param input_base The starting address of the input vector (i.e. the location from which to read the permutation)
*/
void inverse_permutation(vector<T>& stack, size_t n, size_t output_base, size_t input_base);
void del(int handle);
};
#endif /* PROTOCOLS_SECURESHUFFLE_H_ */

View File

@@ -12,6 +12,45 @@
#include <math.h>
#include <algorithm>
template<class T>
void ShuffleStore<T>::lock()
{
store_lock.lock();
}
template<class T>
void ShuffleStore<T>::unlock()
{
store_lock.unlock();
}
template<class T>
int ShuffleStore<T>::add()
{
lock();
int res = shuffles.size();
shuffles.push_back({});
unlock();
return res;
}
template<class T>
typename ShuffleStore<T>::shuffle_type& ShuffleStore<T>::get(int handle)
{
lock();
auto& res = shuffles.at(handle);
unlock();
return res;
}
template<class T>
void ShuffleStore<T>::del(int handle)
{
lock();
shuffles.at(handle) = {};
unlock();
}
template<class T>
SecureShuffle<T>::SecureShuffle(SubProcessor<T>& proc) :
proc(proc), unit_size(0), n_shuffle(0), exact(false)
@@ -33,13 +72,12 @@ SecureShuffle<T>::SecureShuffle(vector<T>& a, size_t n, int unit_size,
template<class T>
void SecureShuffle<T>::apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int handle, bool reverse)
size_t input_base, shuffle_type& shuffle, bool reverse)
{
this->unit_size = unit_size;
pre(a, n, input_base);
auto& shuffle = shuffles.at(handle);
assert(shuffle.size() == proc.protocol.get_relevant_players().size());
if (reverse)
@@ -134,12 +172,6 @@ void SecureShuffle<T>::inverse_permutation(vector<T> &stack, size_t n, size_t ou
post(stack, n, output_base);
}
template<class T>
void SecureShuffle<T>::del(int handle)
{
shuffles.at(handle).clear();
}
template<class T>
void SecureShuffle<T>::pre(vector<T>& a, size_t n, size_t input_base)
{
@@ -230,11 +262,10 @@ void SecureShuffle<T>::player_round(int config_player) {
}
template<class T>
int SecureShuffle<T>::generate(int n_shuffle)
int SecureShuffle<T>::generate(int n_shuffle, store_type& store)
{
int res = shuffles.size();
shuffles.push_back({});
auto& shuffle = shuffles.back();
int res = store.add();
auto& shuffle = store.get(res);
for (auto i: proc.protocol.get_relevant_players()) {
vector<int> perm;

View File

@@ -65,6 +65,11 @@ public:
return "Shamir " + T::type_string();
}
static string alt()
{
return "ATLAS";
}
static int threshold(int)
{
return ShamirMachine::s().threshold;

Some files were not shown because too many files have changed in this diff Show More