mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Multinode computation.
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdint.h>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
18
CHANGELOG.md
18
CHANGELOG.md
@@ -1,5 +1,23 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.3.8 (December 14, 2023)
|
||||
|
||||
- Functionality for multiple nodes per party
|
||||
- Functionality to use disk space for high-level data structures
|
||||
- True division is always fixed-point division (similar to Python 3)
|
||||
- Compiler option to optimize for specific protocol
|
||||
- Cleartext permutation
|
||||
- Faster compilation and lower bytecode size
|
||||
- Functionality to output secret shares from high-level code
|
||||
- Run-time command-line arguments accessible from high-level code
|
||||
- Client connection setup specifies cleartext domain
|
||||
- Compile-time parameter for connection timeout
|
||||
- Prevent connections from timing out (@ParallelogramPal)
|
||||
- More ECDSA examples
|
||||
- More flexible multiplication instruction
|
||||
- Dot product instruction supports several operations at once
|
||||
- Example-based virtual machine explanation
|
||||
|
||||
## 0.3.7 (August 14, 2023)
|
||||
|
||||
- Path Oblivious Heap (@tskovlund)
|
||||
|
||||
1
CONFIG
1
CONFIG
@@ -87,6 +87,7 @@ LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
|
||||
LDLIBS += $(BREW_LDLIBS)
|
||||
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
|
||||
LDLIBS += -lboost_system -lssl -lcrypto
|
||||
LDLIBS += -lboost_filesystem -lboost_iostreams
|
||||
|
||||
CFLAGS += -I./local/include
|
||||
|
||||
|
||||
@@ -80,17 +80,17 @@ opcodes = dict(
|
||||
CONVCBITVEC = 0x231,
|
||||
)
|
||||
|
||||
class BinaryVectorInstruction(base.Instruction):
|
||||
is_vec = lambda self: True
|
||||
class BinaryCiscable(base.Ciscable):
|
||||
pass
|
||||
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs))
|
||||
class BinaryVectorInstruction(BinaryCiscable):
|
||||
is_vec = lambda self: True
|
||||
|
||||
class NonVectorInstruction(base.Instruction):
|
||||
is_vec = lambda self: False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert(args[0].n <= args[0].unit)
|
||||
assert(args[0].n is None or args[0].n <= args[0].unit)
|
||||
super(NonVectorInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
class NonVectorInstruction1(base.Instruction):
|
||||
@@ -163,7 +163,7 @@ class andrs(BinaryVectorInstruction):
|
||||
sum(int(math.ceil(x / 64)) for x in self.args[::4]))
|
||||
|
||||
class andrsvec(base.VarArgsInstruction, base.Mergeable,
|
||||
base.DynFormatInstruction):
|
||||
base.DynFormatInstruction, BinaryCiscable):
|
||||
""" Constant-vector AND of secret bit registers (vectorized version).
|
||||
|
||||
:param: total number of arguments to follow (int)
|
||||
@@ -206,6 +206,9 @@ class andrsvec(base.VarArgsInstruction, base.Mergeable,
|
||||
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
|
||||
req_node.increment(('bit', 'mixed'), size)
|
||||
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs))
|
||||
|
||||
class ands(BinaryVectorInstruction):
|
||||
""" Bitwise AND of secret bit register vector.
|
||||
|
||||
@@ -306,7 +309,7 @@ class bitcoms(NonVectorInstruction, base.VarArgsInstruction):
|
||||
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))
|
||||
|
||||
class bitdecc(NonVectorInstruction, base.VarArgsInstruction):
|
||||
""" Secret bit register decomposition.
|
||||
""" Clear bit register decomposition.
|
||||
|
||||
:param: number of arguments to follow / number of bits plus one (int)
|
||||
:param: source (sbit)
|
||||
@@ -513,8 +516,8 @@ class convcbitvec(BinaryVectorInstruction):
|
||||
"""
|
||||
code = opcodes['CONVCBITVEC']
|
||||
arg_format = ['int','ciw','cb']
|
||||
def __init__(self, *args):
|
||||
super(convcbitvec, self).__init__(*args)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(convcbitvec, self).__init__(*args, **kwargs)
|
||||
assert(args[2].n == args[0])
|
||||
args[1].set_size(args[0])
|
||||
|
||||
@@ -546,14 +549,14 @@ class split(base.Instruction):
|
||||
super(split_class, self).__init__(*args, **kwargs)
|
||||
assert (len(args) - 2) % args[0] == 0
|
||||
|
||||
class movsb(NonVectorInstruction):
|
||||
class movsb(BinaryVectorInstruction):
|
||||
""" Copy secret bit register.
|
||||
|
||||
:param: destination (sbit)
|
||||
:param: source (sbit)
|
||||
"""
|
||||
code = opcodes['MOVSB']
|
||||
arg_format = ['sbw','sb']
|
||||
arg_format = ['int', 'sbw','sb']
|
||||
|
||||
class trans(base.VarArgsInstruction, base.DynFormatInstruction):
|
||||
""" Secret bit register vector transpose. The first destination vector
|
||||
@@ -568,8 +571,6 @@ class trans(base.VarArgsInstruction, base.DynFormatInstruction):
|
||||
"""
|
||||
code = opcodes['TRANS']
|
||||
is_vec = lambda self: True
|
||||
def __init__(self, *args):
|
||||
super(trans, self).__init__(*args)
|
||||
|
||||
@classmethod
|
||||
def dynamic_arg_format(cls, args):
|
||||
|
||||
@@ -97,8 +97,9 @@ class bits(Tape.Register, _structure, _bit):
|
||||
cbits.conv_cint_vec(a, *res)
|
||||
return res
|
||||
@classmethod
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
return Program.prog.malloc(size, cls, creator_tape=creator_tape)
|
||||
def malloc(cls, size, creator_tape=None, **kwargs):
|
||||
return Program.prog.malloc(size, cls, creator_tape=creator_tape,
|
||||
**kwargs)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
@@ -254,6 +255,18 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return self.get_type(length).bit_compose([self] * length)
|
||||
else:
|
||||
raise CompilerError('cannot expand from %s to %s' % (self.n, length))
|
||||
@classmethod
|
||||
def new_vector(cls, size):
|
||||
return cls.get_type(size)()
|
||||
@classmethod
|
||||
def concat(cls, parts):
|
||||
return cls.bit_compose(
|
||||
sum([part.bit_decompose() for part in parts], []))
|
||||
def copy_from_part(self, source, base, size):
|
||||
self.mov(self,
|
||||
self.bit_compose(source.bit_decompose()[base:base + size]))
|
||||
def vector_size(self):
|
||||
return self.n
|
||||
|
||||
class cbits(bits):
|
||||
""" Clear bits register. Helper type with limited functionality. """
|
||||
@@ -425,7 +438,7 @@ class sbits(bits):
|
||||
tmp = cbits.get_type(n)()
|
||||
tmp.conv_regint_by_bit(n, tmp, other)
|
||||
res.load_other(tmp)
|
||||
mov = inst.movsb
|
||||
mov = staticmethod(lambda x, y: inst.movsb(x.n, x, y))
|
||||
types = {}
|
||||
def __init__(self, *args, **kwargs):
|
||||
bits.__init__(self, *args, **kwargs)
|
||||
@@ -1048,6 +1061,9 @@ def result_conv(x, y):
|
||||
|
||||
class sbit(bit, sbits):
|
||||
""" Single secret bit. """
|
||||
@classmethod
|
||||
def get_type(cls, length):
|
||||
return sbits.get_type(length)
|
||||
def if_else(self, x, y):
|
||||
""" Non-vectorized oblivious selection::
|
||||
|
||||
@@ -1301,6 +1317,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
|
||||
"""
|
||||
bit_extend = staticmethod(_complement_two_extend)
|
||||
mul_functions = {}
|
||||
@classmethod
|
||||
def popcnt_bits(cls, bits):
|
||||
return sbitvec.from_vec(bits).popcnt()
|
||||
@@ -1326,20 +1343,42 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
elif isinstance(other, sbitfixvec):
|
||||
return NotImplemented
|
||||
my_bits, other_bits = self.expand(other, False)
|
||||
matrix = []
|
||||
m = float('inf')
|
||||
uniform = True
|
||||
for x in itertools.chain(my_bits, other_bits):
|
||||
try:
|
||||
uniform &= type(x) == type(my_bits[0]) and x.n == my_bits[0].n
|
||||
m = min(m, x.n)
|
||||
except:
|
||||
pass
|
||||
if uniform and Program.prog.options.cisc:
|
||||
bl = len(my_bits)
|
||||
key = bl, len(other_bits)
|
||||
if key not in self.mul_functions:
|
||||
def instruction(*args):
|
||||
res = self.binary_mul(args[bl:2 * bl], args[2 * bl:],
|
||||
args[0].n)
|
||||
for x, y in zip(res, args):
|
||||
x.mov(y, x)
|
||||
instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits))
|
||||
self.mul_functions[key] = instructions_base.cisc(instruction,
|
||||
bl)
|
||||
res = [sbits.get_type(m)() for i in range(bl)]
|
||||
self.mul_functions[key](*(res + my_bits + other_bits))
|
||||
return self.from_vec(res)
|
||||
else:
|
||||
return self.binary_mul(my_bits, other_bits, m)
|
||||
@classmethod
|
||||
def binary_mul(cls, my_bits, other_bits, m):
|
||||
matrix = []
|
||||
for i, b in enumerate(other_bits):
|
||||
if m == 1:
|
||||
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
|
||||
matrix.append([x * b for x in my_bits[:len(my_bits)-i]])
|
||||
else:
|
||||
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
|
||||
matrix.append((
|
||||
sbitvec.from_vec(my_bits[:len(my_bits)-i]) * b).v)
|
||||
v = sbitint.wallace_tree_from_matrix(matrix)
|
||||
return self.from_vec(v[:len(self.v)])
|
||||
return cls.from_vec(v[:len(my_bits)])
|
||||
__rmul__ = __mul__
|
||||
reduce_after_mul = lambda x: x
|
||||
def TruncMul(self, other, k, m, kappa=None, nearest=False):
|
||||
|
||||
@@ -104,9 +104,10 @@ class AllocRange:
|
||||
break
|
||||
|
||||
class AllocPool:
|
||||
def __init__(self):
|
||||
def __init__(self, parent=None):
|
||||
self.ranges = defaultdict(lambda: [AllocRange()])
|
||||
self.by_base = {}
|
||||
self.parent = parent
|
||||
|
||||
def alloc(self, reg_type, size):
|
||||
for r in self.ranges[reg_type]:
|
||||
@@ -116,8 +117,17 @@ class AllocPool:
|
||||
return res
|
||||
|
||||
def free(self, reg):
|
||||
r = self.by_base.pop((reg.reg_type, reg.i))
|
||||
r.free(reg.i, reg.size)
|
||||
try:
|
||||
r = self.by_base.pop((reg.reg_type, reg.i))
|
||||
r.free(reg.i, reg.size)
|
||||
except KeyError:
|
||||
try:
|
||||
self.parent.free(reg)
|
||||
except:
|
||||
if program.Program.prog.options.debug:
|
||||
print('Error with freeing register with trace:')
|
||||
print(util.format_trace(reg.caller))
|
||||
print()
|
||||
|
||||
def new_ranges(self, min_usage):
|
||||
for t, n in min_usage.items():
|
||||
@@ -133,7 +143,10 @@ class AllocPool:
|
||||
rr.consolidate()
|
||||
|
||||
def n_fragments(self):
|
||||
return max(len(r) for r in self.ranges)
|
||||
if self.ranges:
|
||||
return max(len(r) for r in self.ranges)
|
||||
else:
|
||||
return 0
|
||||
|
||||
class StraightlineAllocator:
|
||||
"""Allocate variables in a straightline program using n registers.
|
||||
@@ -146,6 +159,7 @@ class StraightlineAllocator:
|
||||
assert(n == REG_MAX)
|
||||
self.program = program
|
||||
self.old_pool = None
|
||||
self.unused = defaultdict(lambda: 0)
|
||||
|
||||
def alloc_reg(self, reg, free):
|
||||
base = reg.vectorbase
|
||||
@@ -195,7 +209,8 @@ class StraightlineAllocator:
|
||||
for x in itertools.chain(dup.duplicates, base.duplicates):
|
||||
to_check.add(x)
|
||||
|
||||
free.free(base)
|
||||
if reg not in self.program.base_addresses:
|
||||
free.free(base)
|
||||
if inst.is_vec() and base.vector:
|
||||
self.defined[base] = inst
|
||||
for i in base.vector:
|
||||
@@ -220,8 +235,11 @@ class StraightlineAllocator:
|
||||
if unused_regs and len(unused_regs) == len(list(i.get_def())) and \
|
||||
self.program.verbose:
|
||||
# only report if all assigned registers are unused
|
||||
print("Register(s) %s never used, assigned by '%s' in %s" % \
|
||||
(unused_regs,i,format_trace(i.caller)))
|
||||
self.unused[type(i).__name__] += 1
|
||||
if self.program.verbose > 1:
|
||||
print(
|
||||
"Register(s) %s never used, assigned by '%s' in %s" % \
|
||||
(unused_regs,i,format_trace(i.caller)))
|
||||
|
||||
for j in i.get_used():
|
||||
self.alloc_reg(j, alloc_pool)
|
||||
@@ -277,6 +295,7 @@ class StraightlineAllocator:
|
||||
x = reg.reg_type, reg.size
|
||||
print('Used registers: ', end='')
|
||||
p(sizes)
|
||||
print('Unused instructions:', dict(self.unused))
|
||||
|
||||
def determine_scope(block, options):
|
||||
last_def = defaultdict_by_id(lambda: -1)
|
||||
@@ -421,6 +440,7 @@ class Merger:
|
||||
last = defaultdict(lambda: defaultdict(lambda: None))
|
||||
last_open = deque()
|
||||
last_input = defaultdict(lambda: [None, None])
|
||||
mem_scopes = defaultdict_by_id(lambda: MemScope())
|
||||
|
||||
depths = [0] * len(block.instructions)
|
||||
self.depths = depths
|
||||
@@ -429,6 +449,12 @@ class Merger:
|
||||
self.sources = []
|
||||
self.real_depths = [0] * len(block.instructions)
|
||||
round_type = {}
|
||||
shuffles = defaultdict_by_id(set)
|
||||
|
||||
class MemScope:
|
||||
def __init__(self):
|
||||
self.read = []
|
||||
self.write = []
|
||||
|
||||
def add_edge(i, j):
|
||||
if i in (-1, j):
|
||||
@@ -581,14 +607,20 @@ class Merger:
|
||||
depths[n] = depth
|
||||
|
||||
if isinstance(instr, ReadMemoryInstruction):
|
||||
if options.preserve_mem_order or instr._protect:
|
||||
if options.preserve_mem_order:
|
||||
strict_mem_access(n, last_mem_read, last_mem_write)
|
||||
elif not options.preserve_mem_order:
|
||||
elif instr._protect:
|
||||
scope = mem_scopes[instr._protect]
|
||||
strict_mem_access(n, scope.read, scope.write)
|
||||
if not options.preserve_mem_order:
|
||||
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
|
||||
elif isinstance(instr, WriteMemoryInstruction):
|
||||
if options.preserve_mem_order or instr._protect:
|
||||
if options.preserve_mem_order:
|
||||
strict_mem_access(n, last_mem_write, last_mem_read)
|
||||
elif not options.preserve_mem_order:
|
||||
elif instr._protect:
|
||||
scope = mem_scopes[instr._protect]
|
||||
strict_mem_access(n, scope.write, scope.read)
|
||||
if not options.preserve_mem_order:
|
||||
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
|
||||
elif isinstance(instr, matmulsm):
|
||||
if options.preserve_mem_order:
|
||||
@@ -608,6 +640,11 @@ class Merger:
|
||||
keep_order(instr, n, instr.args[0])
|
||||
elif isinstance(instr, StackInstruction):
|
||||
keep_order(instr, n, StackInstruction)
|
||||
elif isinstance(instr, applyshuffle):
|
||||
shuffles[instr.args[3]].add(n)
|
||||
elif isinstance(instr, delshuffle):
|
||||
for i_inst in shuffles[instr.args[0]]:
|
||||
add_edge(i_inst, n)
|
||||
|
||||
if not G.pred[n]:
|
||||
self.sources.append(n)
|
||||
@@ -683,6 +720,7 @@ class RegintOptimizer:
|
||||
self.cache = util.dict_by_id()
|
||||
self.offset_cache = util.dict_by_id()
|
||||
self.rev_offset_cache = {}
|
||||
self.range_cache = util.dict_by_id()
|
||||
|
||||
def add_offset(self, res, new_base, new_offset):
|
||||
self.offset_cache[res] = new_base, new_offset
|
||||
@@ -693,6 +731,12 @@ class RegintOptimizer:
|
||||
for i, inst in enumerate(instructions):
|
||||
if isinstance(inst, ldint_class):
|
||||
self.cache[inst.args[0]] = inst.args[1]
|
||||
elif isinstance(inst, incint):
|
||||
if inst.args[2] == 1 and inst.args[3] == 1 and \
|
||||
inst.args[4] == len(inst.args[0]) and \
|
||||
inst.args[1] in self.cache:
|
||||
self.range_cache[inst.args[0]] = \
|
||||
len(inst.args[0]), self.cache[inst.args[1]]
|
||||
elif isinstance(inst, IntegerInstruction):
|
||||
if inst.args[1] in self.cache and inst.args[2] in self.cache:
|
||||
res = inst.op(self.cache[inst.args[1]],
|
||||
@@ -731,6 +775,10 @@ class RegintOptimizer:
|
||||
base, offset = self.offset_cache[inst.args[1]]
|
||||
addr = self.rev_offset_cache[base.i, offset]
|
||||
inst.args[1] = addr
|
||||
elif inst.args[1] in self.range_cache:
|
||||
size, base = self.range_cache[inst.args[1]]
|
||||
if size == len(inst.args[0]):
|
||||
instructions[i] = inst.get_direct(base)
|
||||
elif type(inst) == convint_class:
|
||||
if inst.args[1] in self.cache:
|
||||
res = self.cache[inst.args[1]]
|
||||
|
||||
@@ -65,6 +65,8 @@ def ld2i(c, n):
|
||||
movc(c, t1)
|
||||
|
||||
def require_ring_size(k, op):
|
||||
if not program.options.ring:
|
||||
return
|
||||
if int(program.options.ring) < k:
|
||||
msg = 'ring size too small for %s, compile ' \
|
||||
'with \'-R %d\' or more' % (op, k)
|
||||
@@ -140,7 +142,7 @@ def TruncRing(d, a, k, m, signed):
|
||||
high = sint.conv(carries[length])
|
||||
else:
|
||||
if m == 1:
|
||||
low = x[1][1]
|
||||
low = x[0][1]
|
||||
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
|
||||
sint.conv(x[0][-1])
|
||||
else:
|
||||
@@ -181,7 +183,7 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
if k == m:
|
||||
return 0
|
||||
assert k > m
|
||||
assert int(program.options.ring) >= k
|
||||
require_ring_size(k, 'leaky truncation')
|
||||
from .types import sint, intbitint, cint, cgf2n
|
||||
n_bits = k - m
|
||||
n_shift = int(program.options.ring) - n_bits
|
||||
@@ -228,7 +230,7 @@ def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
|
||||
|
||||
def Mod2mRing(a_prime, a, k, m, signed):
|
||||
assert(int(program.options.ring) >= k)
|
||||
require_ring_size(k, 'modulo power of two')
|
||||
from Compiler.types import sint, intbitint, cint
|
||||
shift = int(program.options.ring) - m
|
||||
r_prime, r_bin = MaskingBitsInRing(m, True)
|
||||
@@ -404,7 +406,7 @@ def carry(b, a, compute_p=True):
|
||||
return b
|
||||
if b is None:
|
||||
return a
|
||||
t = [program.curr_block.new_reg('s') for i in range(3)]
|
||||
t = [None] * 3
|
||||
if compute_p:
|
||||
t[0] = a[0].bit_and(b[0])
|
||||
t[2] = a[0].bit_and(b[1]) + a[1]
|
||||
|
||||
@@ -13,12 +13,28 @@ from .program import Program, defaults
|
||||
|
||||
|
||||
class Compiler:
|
||||
def __init__(self, custom_args=None, usage=None, execute=False):
|
||||
def __init__(self, custom_args=None, usage=None, execute=False,
|
||||
split_args=False):
|
||||
if usage:
|
||||
self.usage = usage
|
||||
else:
|
||||
self.usage = "usage: %prog [options] filename [args]"
|
||||
self.execute = execute
|
||||
self.runtime_args = []
|
||||
|
||||
if split_args:
|
||||
if custom_args is None:
|
||||
args = sys.argv
|
||||
else:
|
||||
args = custom_args
|
||||
try:
|
||||
split = args.index('--')
|
||||
except ValueError:
|
||||
split = len(args)
|
||||
|
||||
custom_args = args[1:split]
|
||||
self.runtime_args = args[split + 1:]
|
||||
|
||||
self.custom_args = custom_args
|
||||
self.build_option_parser()
|
||||
self.VARS = {}
|
||||
@@ -148,7 +164,8 @@ class Compiler:
|
||||
"--prime",
|
||||
dest="prime",
|
||||
default=defaults.prime,
|
||||
help="prime modulus (default: not specified)",
|
||||
help="use bit decomposition with a specifed prime modulus "
|
||||
"for non-linear computation (default: use the masking approach)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-I",
|
||||
@@ -235,13 +252,31 @@ class Compiler:
|
||||
dest="hostfile",
|
||||
help="hosts to execute with",
|
||||
)
|
||||
else:
|
||||
parser.add_option(
|
||||
"-E",
|
||||
"--execute",
|
||||
dest="execute",
|
||||
help="protocol to optimize for",
|
||||
)
|
||||
self.parser = parser
|
||||
|
||||
def parse_args(self):
|
||||
self.options, self.args = self.parser.parse_args(self.custom_args)
|
||||
if self.execute:
|
||||
if not self.options.execute:
|
||||
raise CompilerError("must give name of protocol with '-E'")
|
||||
if len(self.args) > 1:
|
||||
self.options.execute = self.args.pop(0)
|
||||
else:
|
||||
self.parser.error("missing protocol name")
|
||||
if self.options.hostfile:
|
||||
try:
|
||||
open(self.options.hostfile)
|
||||
except:
|
||||
print('hostfile %s not found' % self.options.hostfile,
|
||||
file=sys.stderr)
|
||||
exit(1)
|
||||
if self.options.execute:
|
||||
protocol = self.options.execute
|
||||
if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \
|
||||
protocol.find("brain") >= 0 or protocol == "emulate":
|
||||
@@ -268,14 +303,14 @@ class Compiler:
|
||||
|
||||
def build_program(self, name=None):
|
||||
self.prog = Program(self.args, self.options, name=name)
|
||||
if self.execute:
|
||||
if self.options.execute:
|
||||
if self.options.execute in \
|
||||
("emulate", "ring", "rep-field", "rep4-ring"):
|
||||
self.prog.use_trunc_pr = True
|
||||
if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"):
|
||||
self.prog.use_split(3)
|
||||
if self.options.execute in ("semi2k",):
|
||||
self.prog.use_split(2)
|
||||
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
|
||||
if self.options.execute in ("rep4-ring",):
|
||||
self.prog.use_split(4)
|
||||
|
||||
@@ -476,7 +511,9 @@ class Compiler:
|
||||
else:
|
||||
return protocol + "-party.x"
|
||||
|
||||
def local_execution(self, args=[]):
|
||||
def local_execution(self, args=None):
|
||||
if args is None:
|
||||
args = self.runtime_args
|
||||
executable = self.executable_from_protocol(self.options.execute)
|
||||
if not os.path.exists("%s/%s" % (self.root, executable)):
|
||||
print("Creating binary for virtual machine...")
|
||||
@@ -488,9 +525,13 @@ class Compiler:
|
||||
"Note that compilation requires a few GB of RAM.")
|
||||
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
|
||||
sys.stdout.flush()
|
||||
print("Compilation finished, running program...", file=sys.stderr)
|
||||
sys.stderr.flush()
|
||||
os.execl(vm, vm, self.prog.name, *args)
|
||||
|
||||
def remote_execution(self, args=[]):
|
||||
def remote_execution(self, args=None):
|
||||
if args is None:
|
||||
args = self.runtime_args
|
||||
vm = self.executable_from_protocol(self.options.execute)
|
||||
hosts = list(x.strip()
|
||||
for x in filter(None, open(self.options.hostfile)))
|
||||
|
||||
@@ -286,6 +286,7 @@ def BitDecRingRaw(a, k, m):
|
||||
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
|
||||
return bits
|
||||
|
||||
@instructions_base.bit_cisc
|
||||
def BitDecRing(a, k, m):
|
||||
bits = BitDecRingRaw(a, k, m)
|
||||
# reversing to reduce number of rounds
|
||||
@@ -304,6 +305,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
instructions_base.reset_global_vector_size()
|
||||
return res
|
||||
|
||||
@instructions_base.bit_cisc
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
|
||||
return [types.sintbit.conv(bit) for bit in res]
|
||||
@@ -358,7 +360,6 @@ def B2U_from_Pow2(pow2a, l, kappa):
|
||||
def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
""" Oblivious truncation by secret m """
|
||||
prog = program.Program.prog
|
||||
kappa = kappa or prog.security
|
||||
if util.is_constant(m) and not compute_modulo:
|
||||
# cheaper
|
||||
res = type(a)(size=a.size)
|
||||
@@ -371,6 +372,8 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
return a * (1 - m)
|
||||
if program.Program.prog.options.ring and not compute_modulo:
|
||||
return TruncInRing(a, l, Pow2(m, l, kappa))
|
||||
else:
|
||||
kappa = kappa or program.Program.prog.security
|
||||
r = [types.sint() for i in range(l)]
|
||||
r_dprime = types.sint(0)
|
||||
r_prime = types.sint(0)
|
||||
@@ -409,6 +412,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
b = shifted - d
|
||||
return b
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def TruncInRing(to_shift, l, pow2m):
|
||||
n_shift = int(program.Program.prog.options.ring) - l
|
||||
bits = BitDecRing(to_shift, l, l)
|
||||
@@ -433,11 +437,7 @@ def SplitInRing(a, l, m):
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
|
||||
if program.Program.prog.options.ring:
|
||||
s = (1 - overflow) * t + \
|
||||
comparison.TruncLeakyInRing(overflow * t, length, 1, False)
|
||||
else:
|
||||
s = (1 - overflow) * t + overflow * t / 2
|
||||
s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False)
|
||||
return s, overflow
|
||||
|
||||
def Int2FL(a, gamma, l, kappa=None):
|
||||
@@ -555,7 +555,7 @@ def TruncPrField(a, k, m, kappa=None):
|
||||
c = (b + r).reveal(False)
|
||||
c_prime = c % two_to_m
|
||||
a_prime = c_prime - r_prime
|
||||
d = (a - a_prime) / two_to_m
|
||||
d = (a - a_prime).field_div(two_to_m)
|
||||
return d
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
|
||||
@@ -345,6 +345,17 @@ class starg(base.Instruction):
|
||||
code = base.opcodes['STARG']
|
||||
arg_format = ['ci']
|
||||
|
||||
@base.vectorize
|
||||
class cmdlinearg(base.Instruction):
|
||||
""" Load command-line argument.
|
||||
|
||||
:param: dest (regint)
|
||||
:param: index (regint)
|
||||
|
||||
"""
|
||||
code = base.opcodes['CMDLINEARG']
|
||||
arg_format = ['ciw','ci']
|
||||
|
||||
@base.gf2n
|
||||
class reqbl(base.Instruction):
|
||||
""" Requirement on computation modulus. Minimal bit length of prime if
|
||||
@@ -654,7 +665,7 @@ class picks(base.VectorInstruction):
|
||||
def __init__(self, *args):
|
||||
super(picks, self).__init__(*args)
|
||||
assert 0 <= args[2] < len(args[1])
|
||||
assert 0 <= args[2] + args[3] * len(args[0]) <= len(args[1])
|
||||
assert 0 <= args[2] + args[3] * (len(args[0]) - 1) < len(args[1])
|
||||
|
||||
class concats(base.VectorInstruction):
|
||||
""" Concatenate vectors.
|
||||
@@ -1630,6 +1641,16 @@ class print_reg_plain(base.IOInstruction):
|
||||
code = base.opcodes['PRINTREGPLAIN']
|
||||
arg_format = ['c']
|
||||
|
||||
class print_reg_plains(base.IOInstruction):
|
||||
""" Output secret register.
|
||||
|
||||
:param: source (sint)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['PRINTREGPLAINS']
|
||||
arg_format = ['s']
|
||||
|
||||
class cond_print_plain(base.IOInstruction):
|
||||
""" Conditionally output clear register (with precision).
|
||||
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
|
||||
@@ -1860,6 +1881,19 @@ class acceptclientconnection(base.IOInstruction):
|
||||
code = base.opcodes['ACCEPTCLIENTCONNECTION']
|
||||
arg_format = ['ciw', 'ci']
|
||||
|
||||
class initclientconnection(base.IOInstruction):
|
||||
""" Initialize connection.
|
||||
|
||||
:param: client id destination (regint)
|
||||
:param: port number (regint)
|
||||
:param: my client id (regint)
|
||||
:param: hostname (variable string)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['INITCLIENTCONNECTION']
|
||||
arg_format = ['ciw', 'ci', 'ci', 'varstr']
|
||||
|
||||
class closeclientconnection(base.IOInstruction):
|
||||
""" Close connection to client.
|
||||
|
||||
@@ -1941,7 +1975,7 @@ class fixinput(base.PublicFileIOInstruction):
|
||||
|
||||
:param: player (int)
|
||||
:param: destination (cint)
|
||||
:param: exponent (int)
|
||||
:param: exponent (int, for float/double) / byte length (1/8, for integer)
|
||||
:param: input type (0: 64-bit integer, 1: float, 2: double)
|
||||
|
||||
"""
|
||||
@@ -2284,30 +2318,30 @@ class asm_open(base.VarArgsInstruction, base.DataInstruction):
|
||||
self.args += other.args[1:]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class muls(base.VarArgsInstruction, base.DataInstruction):
|
||||
class muls(base.VarArgsInstruction, base.DataInstruction, base.Ciscable):
|
||||
""" (Element-wise) multiplication of secret registers (vectors).
|
||||
|
||||
:param: number of arguments to follow (multiple of three)
|
||||
:param: number of arguments to follow (multiple of four)
|
||||
:param: vector size (int)
|
||||
:param: result (sint)
|
||||
:param: factor (sint)
|
||||
:param: factor (sint)
|
||||
:param: (repeat the last three)...
|
||||
:param: (repeat the last four)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['MULS']
|
||||
arg_format = tools.cycle(['sw','s','s'])
|
||||
arg_format = tools.cycle(['int','sw','s','s'])
|
||||
data_type = 'triple'
|
||||
is_vec = lambda self: True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(muls_class, self).__init__(*args, **kwargs)
|
||||
for i in range(0, len(args), 4):
|
||||
for j in range(3):
|
||||
assert args[i + j + 1].size == args[i]
|
||||
|
||||
def get_repeat(self):
|
||||
return len(self.args) // 3
|
||||
|
||||
def merge_id(self):
|
||||
# can merge different sizes
|
||||
# but not if large
|
||||
if self.get_size() is None or self.get_size() > 100:
|
||||
return type(self), self.get_size()
|
||||
return type(self)
|
||||
return sum(self.args[::4])
|
||||
|
||||
# def expand(self):
|
||||
# s = [program.curr_block.new_reg('s') for i in range(9)]
|
||||
@@ -2324,6 +2358,16 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
|
||||
# adds(s[8], s[7], s[6])
|
||||
# addm(self.args[0], s[8], c[2])
|
||||
|
||||
# compatibility
|
||||
try:
|
||||
vmuls = muls_class
|
||||
muls_bak = muls
|
||||
muls = lambda *args: muls_bak(args[0].size, *args)
|
||||
vgmuls = gmuls_class = gmuls
|
||||
gmuls = lambda *args: gmuls_class(args[0].size, *args)
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
@base.gf2n
|
||||
class mulrs(base.VarArgsInstruction, base.DataInstruction):
|
||||
""" Constant-vector multiplication of secret registers.
|
||||
@@ -2403,8 +2447,8 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction,
|
||||
return self.arg_format()
|
||||
|
||||
def get_repeat(self):
|
||||
return sum(self.args[i] // 2
|
||||
for i, n in self.bases(iter(self.args))) * self.get_size()
|
||||
return sum(self.args[i] // 2 - 1
|
||||
for i, n in self.bases(iter(self.args)))
|
||||
|
||||
def get_def(self):
|
||||
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
|
||||
@@ -2421,7 +2465,7 @@ class matmul_base(base.DataInstruction):
|
||||
def get_repeat(self):
|
||||
return reduce(operator.mul, self.args[3:6])
|
||||
|
||||
class matmuls(matmul_base):
|
||||
class matmuls(matmul_base, base.Mergeable):
|
||||
""" Secret matrix multiplication from registers. All matrices are
|
||||
represented as vectors in row-first order.
|
||||
|
||||
@@ -2433,7 +2477,11 @@ class matmuls(matmul_base):
|
||||
:param: number of columns in second factor and result (int)
|
||||
"""
|
||||
code = base.opcodes['MATMULS']
|
||||
arg_format = ['sw','s','s','int','int','int']
|
||||
arg_format = itertools.cycle(['sw','s','s','int','int','int'])
|
||||
|
||||
def get_repeat(self):
|
||||
return sum(reduce(operator.mul, self.args[i + 3:i + 6])
|
||||
for i in range(0, len(self.args), 6))
|
||||
|
||||
class matmulsm(matmul_base):
|
||||
""" Secret matrix multiplication reading directly from memory.
|
||||
|
||||
@@ -67,6 +67,7 @@ opcodes = dict(
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
CMDLINEARG = 0xEB,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -152,7 +153,7 @@ opcodes = dict(
|
||||
LISTEN = 0x6c,
|
||||
ACCEPTCLIENTCONNECTION = 0x6d,
|
||||
CLOSECLIENTCONNECTION = 0x6e,
|
||||
READCLIENTPUBLICKEY = 0x6f,
|
||||
INITCLIENTCONNECTION = 0x6f,
|
||||
# Bitwise logic
|
||||
ANDC = 0x70,
|
||||
XORC = 0x71,
|
||||
@@ -196,6 +197,7 @@ opcodes = dict(
|
||||
PRINTREG = 0XB1,
|
||||
RAND = 0xB2,
|
||||
PRINTREGPLAIN = 0xB3,
|
||||
PRINTREGPLAINS = 0xEA,
|
||||
PRINTCHR = 0xB4,
|
||||
PRINTSTR = 0xB5,
|
||||
PUBINPUT = 0xB6,
|
||||
@@ -422,28 +424,31 @@ def gf2n(instruction):
|
||||
class Mergeable:
|
||||
pass
|
||||
|
||||
def cisc(function):
|
||||
def cisc(function, n_outputs=1):
|
||||
class MergeCISC(Mergeable):
|
||||
instructions = {}
|
||||
functions = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.security = program.security
|
||||
self.security = program._security
|
||||
self.calls = [(args, kwargs)]
|
||||
self.params = []
|
||||
self.used = []
|
||||
for arg in self.args[1:]:
|
||||
for arg in self.args[n_outputs:]:
|
||||
if isinstance(arg, program.curr_tape.Register):
|
||||
self.used.append(arg)
|
||||
self.params.append(type(arg))
|
||||
else:
|
||||
self.params.append(arg)
|
||||
self.function = function
|
||||
self.caller = None
|
||||
program.curr_block.instructions.append(self)
|
||||
|
||||
def get_def(self):
|
||||
return [call[0][0] for call in self.calls]
|
||||
return sum(([call[0][i] for call in self.calls]
|
||||
for i in range(n_outputs)), [])
|
||||
|
||||
def get_used(self):
|
||||
return self.used
|
||||
@@ -460,7 +465,7 @@ def cisc(function):
|
||||
self.used += other.used
|
||||
|
||||
def get_size(self):
|
||||
return self.args[0].size
|
||||
return self.args[0].vector_size()
|
||||
|
||||
def new_instructions(self, size, regs):
|
||||
if self.merge_id() not in self.instructions:
|
||||
@@ -474,11 +479,11 @@ def cisc(function):
|
||||
args = []
|
||||
for arg in self.args:
|
||||
try:
|
||||
args.append(type(arg)(size=None))
|
||||
args.append(arg.new_vector(size=None))
|
||||
except:
|
||||
args.append(arg)
|
||||
program.options.cisc = False
|
||||
old_security = program.security
|
||||
old_security = program._security
|
||||
program.security = self.security
|
||||
self.function(*args, **self.kwargs)
|
||||
program.security = old_security
|
||||
@@ -490,7 +495,8 @@ def cisc(function):
|
||||
from Compiler.allocator import Merger
|
||||
merger = Merger(block, program.options,
|
||||
tuple(program.to_merge))
|
||||
args[0].can_eliminate = False
|
||||
for i in range(n_outputs):
|
||||
args[i].can_eliminate = False
|
||||
merger.eliminate_dead_code()
|
||||
assert int(program.options.max_parallel_open) == 0, \
|
||||
'merging restriction not compatible with ' \
|
||||
@@ -501,52 +507,105 @@ def cisc(function):
|
||||
n_rounds
|
||||
template, args, self.n_rounds = self.instructions[self.merge_id()]
|
||||
subs = util.dict_by_id()
|
||||
from Compiler import types
|
||||
for arg, reg in zip(args, regs):
|
||||
subs[arg] = reg
|
||||
if isinstance(arg, program.curr_tape.Register):
|
||||
subs[arg] = reg
|
||||
set_global_vector_size(size)
|
||||
for inst in template:
|
||||
inst.copy(size, subs)
|
||||
reset_global_vector_size()
|
||||
|
||||
def expand_to_function(self, size, new_regs):
|
||||
key = size, program.curr_tape, \
|
||||
tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \
|
||||
tuple(type(reg) for reg in new_regs)
|
||||
if key not in self.functions:
|
||||
from Compiler import library, types
|
||||
from Compiler.GC.types import bits
|
||||
class Arg:
|
||||
def __init__(self, reg):
|
||||
self.type = type(reg)
|
||||
self.binary = isinstance(reg, bits)
|
||||
self.reg = reg
|
||||
# if reg is not None:
|
||||
# program.base_addresses[reg] = None
|
||||
def new(self):
|
||||
if self.binary:
|
||||
return self.type()
|
||||
else:
|
||||
return self.type(size=size)
|
||||
def load(self):
|
||||
return self.reg
|
||||
def store(self, reg):
|
||||
if self.type != type(None):
|
||||
self.reg.update(reg)
|
||||
args = [Arg(x) for x in new_regs]
|
||||
@library.function_block
|
||||
def f():
|
||||
res = [arg.new() for arg in args[:n_outputs]]
|
||||
self.new_instructions(size,
|
||||
res + [arg.load() for arg in args[n_outputs:]])
|
||||
for reg, arg in zip(res, args):
|
||||
arg.store(reg)
|
||||
f.name = '_'.join(['%s(%d)' % (function.__name__, size)] +
|
||||
[str(x) for x in key[2]])
|
||||
self.functions[key] = f, args
|
||||
f, args = self.functions[key]
|
||||
for i in range(len(new_regs) - n_outputs):
|
||||
args[n_outputs + i].store(new_regs[n_outputs + i])
|
||||
f()
|
||||
for i in range(n_outputs):
|
||||
new_regs[i].link(args[i].load())
|
||||
|
||||
def expand_merged(self, skip):
|
||||
if function.__name__ in skip:
|
||||
good = True
|
||||
for call in self.calls:
|
||||
if not good:
|
||||
break
|
||||
for arg in call[0]:
|
||||
if isinstance(arg, program.curr_tape.Register) and \
|
||||
not issubclass(type(self.calls[0][0][0]), type(arg)):
|
||||
good = False
|
||||
for i in range(n_outputs):
|
||||
for arg in call[0]:
|
||||
if isinstance(arg, program.curr_tape.Register) and \
|
||||
not issubclass(type(self.calls[0][0][0]),
|
||||
type(arg)):
|
||||
good = False
|
||||
if good:
|
||||
return [self], 0
|
||||
return program.curr_block.instructions.append(self)
|
||||
if program.verbose:
|
||||
print('expanding', self.function.__name__)
|
||||
tape = program.curr_tape
|
||||
block = tape.BasicBlock(tape, None, None)
|
||||
tape.active_basicblock = block
|
||||
size = sum(call[0][0].size for call in self.calls)
|
||||
tape.start_new_basicblock()
|
||||
size = sum(call[0][0].vector_size() for call in self.calls)
|
||||
new_regs = []
|
||||
for i, arg in enumerate(self.args):
|
||||
try:
|
||||
if i == 0:
|
||||
new_regs.append(type(arg)(size=size))
|
||||
if i < n_outputs:
|
||||
new_regs.append(arg.new_vector(size=size))
|
||||
else:
|
||||
new_regs.append(type(arg).concat(
|
||||
call[0][i] for call in self.calls))
|
||||
assert len(new_regs[-1]) == size
|
||||
assert new_regs[-1].vector_size() == size
|
||||
except (TypeError, AttributeError):
|
||||
if not isinstance(arg, int):
|
||||
if not isinstance(arg, (int, type(None))):
|
||||
raise
|
||||
break
|
||||
new_regs.append(None)
|
||||
except:
|
||||
print([call[0][0].size for call in self.calls])
|
||||
print([call[0][0].vector_size() for call in self.calls])
|
||||
raise
|
||||
self.new_instructions(size, new_regs)
|
||||
if program.cisc_to_function and \
|
||||
(program.curr_tape.singular or program.n_running_threads):
|
||||
self.expand_to_function(size, new_regs)
|
||||
else:
|
||||
self.new_instructions(size, new_regs)
|
||||
program.curr_block.n_rounds += self.n_rounds - 1
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
reg = call[0][0]
|
||||
reg.copy_from_part(new_regs[0], base, reg.size)
|
||||
base += reg.size
|
||||
return block.instructions, self.n_rounds - 1
|
||||
for i in range(n_outputs):
|
||||
reg = call[0][i]
|
||||
reg.copy_from_part(new_regs[i], base, reg.vector_size())
|
||||
base += reg.vector_size()
|
||||
tape.start_new_basicblock()
|
||||
|
||||
def add_usage(self, *args):
|
||||
pass
|
||||
@@ -605,10 +664,10 @@ def ret_cisc(function):
|
||||
from Compiler import types
|
||||
if not (program.options.cisc and isinstance(args[0], types._register)):
|
||||
return function(*args, **kwargs)
|
||||
if isinstance(args[0], types._clear):
|
||||
res_type = type(args[1])
|
||||
else:
|
||||
res_type = type(args[0])
|
||||
for arg in args:
|
||||
if isinstance(arg, types._secret):
|
||||
res_type = type(arg)
|
||||
break
|
||||
res = res_type(size=args[0].size)
|
||||
instruction(res, *args, **kwargs)
|
||||
return res
|
||||
@@ -642,6 +701,24 @@ def sfix_cisc(function):
|
||||
copy_doc(wrapper, function)
|
||||
return wrapper
|
||||
|
||||
bit_instructions = {}
|
||||
|
||||
def bit_cisc(function):
|
||||
def wrapper(a, k, m, *args, **kwargs):
|
||||
key = function, m
|
||||
if key not in bit_instructions:
|
||||
def instruction(*args, **kwargs):
|
||||
res = function(*args[m:], **kwargs)
|
||||
for x, y in zip(res, args):
|
||||
x.mov(y, x)
|
||||
instruction.__name__ = '%s(%d)' % (function.__name__, m)
|
||||
bit_instructions[key] = cisc(instruction, m)
|
||||
from Compiler.types import sintbit
|
||||
res = [sintbit() for i in range(m)]
|
||||
bit_instructions[function, m](*res, a, k, m, *args, **kwargs)
|
||||
return res
|
||||
return wrapper
|
||||
|
||||
class RegType(object):
|
||||
""" enum-like static class for Register types """
|
||||
ClearModp = 'c'
|
||||
@@ -793,6 +870,23 @@ class String(ArgFormat):
|
||||
def __str__(self):
|
||||
return self.str
|
||||
|
||||
class VarString(ArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, str):
|
||||
raise ArgumentError(arg, 'Argument is not string')
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return int_to_bytes(len(arg)) + list(bytearray(arg, 'ascii'))
|
||||
|
||||
def __init__(self, f):
|
||||
length = IntArgFormat(f).i
|
||||
self.str = str(f.read(length), 'ascii')
|
||||
|
||||
def __str__(self):
|
||||
return self.str
|
||||
|
||||
ArgFormats = {
|
||||
'c': ClearModpAF,
|
||||
's': SecretModpAF,
|
||||
@@ -810,6 +904,7 @@ ArgFormats = {
|
||||
'long': LongArgFormat,
|
||||
'p': PlayerNoAF,
|
||||
'str': String,
|
||||
'varstr': VarString,
|
||||
}
|
||||
|
||||
def format_str_is_reg(format_str):
|
||||
@@ -930,7 +1025,7 @@ class Instruction(object):
|
||||
self.args += other.args
|
||||
|
||||
def expand_vector_args(self):
|
||||
if self.is_vec():
|
||||
if self.is_vec() and self.get_size() != 1:
|
||||
for arg in self.args:
|
||||
arg.create_vector_elements()
|
||||
res = sum(list(zip(*self.args)), ())
|
||||
@@ -939,7 +1034,7 @@ class Instruction(object):
|
||||
return self.args
|
||||
|
||||
def expand_merged(self, skip):
|
||||
return [self], 0
|
||||
program.curr_block.instructions.append(self)
|
||||
|
||||
def get_new_args(self, size, subs):
|
||||
new_args = []
|
||||
@@ -956,6 +1051,10 @@ class Instruction(object):
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
def copy(self, *args, **kwargs):
|
||||
raise CompilerError("%s instruction not compatible with CISC-style "
|
||||
"merging. Compile with '-O'." % type(self))
|
||||
|
||||
@staticmethod
|
||||
def get_usage(args):
|
||||
return {}
|
||||
@@ -990,9 +1089,9 @@ class ParsedInstruction:
|
||||
pass
|
||||
read = lambda: struct.unpack('>I', f.read(4))[0]
|
||||
full_code = struct.unpack('>Q', f.read(8))[0]
|
||||
code = full_code % (1 << Instruction.code_length)
|
||||
self.code = full_code % (1 << Instruction.code_length)
|
||||
self.size = full_code >> Instruction.code_length
|
||||
self.type = cls.reverse_opcodes[code]
|
||||
self.type = cls.reverse_opcodes[self.code]
|
||||
t = self.type
|
||||
name = t.__name__
|
||||
try:
|
||||
@@ -1044,6 +1143,10 @@ class VectorInstruction(Instruction):
|
||||
def get_code(self):
|
||||
return super(VectorInstruction, self).get_code(len(self.args[0]))
|
||||
|
||||
class Ciscable(Instruction):
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs), copying=True)
|
||||
|
||||
class DynFormatInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,c
|
||||
from Compiler.instructions import *
|
||||
from Compiler.util import tuplify,untuplify,is_zero
|
||||
from Compiler.allocator import RegintOptimizer, AllocPool
|
||||
from Compiler import instructions,instructions_base,comparison,program,util
|
||||
from Compiler.program import Tape
|
||||
from Compiler import instructions,instructions_base,comparison,util,types
|
||||
import inspect,math
|
||||
import random
|
||||
import collections
|
||||
@@ -42,7 +43,7 @@ def vectorize(function):
|
||||
|
||||
def set_instruction_type(function):
|
||||
def instruction_typed_function(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], program.Tape.Register):
|
||||
if len(args) > 0 and isinstance(args[0], Tape.Register):
|
||||
if args[0].is_gf2n:
|
||||
instructions_base.set_global_instruction_type('gf2n')
|
||||
else:
|
||||
@@ -59,9 +60,15 @@ def set_instruction_type(function):
|
||||
def _expand_to_print(val):
|
||||
return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val)
|
||||
|
||||
def print_str(s, *args):
|
||||
def print_str(s, *args, print_secrets=False):
|
||||
""" Print a string, with optional args for adding
|
||||
variables/registers with ``%s``. """
|
||||
variables/registers with ``%s``.
|
||||
|
||||
:param s: format string
|
||||
:param args: arguments (any type)
|
||||
:param print_secrets: whether to output secret shares
|
||||
|
||||
"""
|
||||
def print_plain_str(ss):
|
||||
""" Print a plain string (no custom formatting options) """
|
||||
ss = bytearray(ss, 'utf8')
|
||||
@@ -84,11 +91,15 @@ def print_str(s, *args):
|
||||
val = args[i].read()
|
||||
else:
|
||||
val = args[i]
|
||||
if isinstance(val, program.Tape.Register):
|
||||
if isinstance(val, Tape.Register):
|
||||
if val.is_clear:
|
||||
val.print_reg_plain()
|
||||
elif print_secrets and isinstance(val, sint):
|
||||
val.output()
|
||||
else:
|
||||
raise CompilerError('Cannot print secret value:', args[i])
|
||||
raise CompilerError(
|
||||
'Cannot print secret value %s, activate printing of shares with '
|
||||
"'print_secrets=True'" % args[i])
|
||||
elif isinstance(val, cfix):
|
||||
val.print_plain()
|
||||
elif isinstance(val, sfix) or isinstance(val, sfloat):
|
||||
@@ -100,16 +111,17 @@ def print_str(s, *args):
|
||||
else:
|
||||
try:
|
||||
val.output()
|
||||
except AttributeError:
|
||||
except (AttributeError, TypeError):
|
||||
print_plain_str(str(val))
|
||||
|
||||
def print_ln(s='', *args):
|
||||
def print_ln(s='', *args, **kwargs):
|
||||
""" Print line, with optional args for adding variables/registers
|
||||
with ``%s``. By default only player 0 outputs, but the ``-I``
|
||||
command-line option changes that.
|
||||
|
||||
:param s: Python string with same number of ``%s`` as length of :py:obj:`args`
|
||||
:param args: list of public values (regint/cint/int/cfix/cfloat/localint)
|
||||
:param print_secrets: whether to output secret shares
|
||||
|
||||
Example:
|
||||
|
||||
@@ -117,7 +129,7 @@ def print_ln(s='', *args):
|
||||
|
||||
print_ln('a is %s.', a.reveal())
|
||||
"""
|
||||
print_str(str(s) + '\n', *args)
|
||||
print_str(str(s) + '\n', *args, **kwargs)
|
||||
|
||||
def print_both(s, end='\n'):
|
||||
""" Print line during compilation and execution. """
|
||||
@@ -169,7 +181,7 @@ def print_str_if(cond, ss, *args):
|
||||
def print_ln_to(player, ss, *args):
|
||||
""" Print line at :py:obj:`player` only. Note that printing is
|
||||
disabled by default except at player 0. Activate interactive mode
|
||||
with `-I` to enable it for all players.
|
||||
with `-I` or use `-OF .` to enable it for all players.
|
||||
|
||||
:param player: int
|
||||
:param ss: Python string
|
||||
@@ -295,8 +307,14 @@ def get_arg():
|
||||
ldarg(res)
|
||||
return res
|
||||
|
||||
def get_cmdline_arg(idx):
|
||||
""" Return run-time command-line argument. """
|
||||
res = regint()
|
||||
cmdlinearg(res, regint.conv(idx))
|
||||
return localint(res)
|
||||
|
||||
def make_array(l, t=None):
|
||||
if isinstance(l, program.Tape.Register):
|
||||
if isinstance(l, Tape.Register):
|
||||
res = Array(len(l), t or type(l))
|
||||
res[:] = l
|
||||
else:
|
||||
@@ -337,14 +355,15 @@ class Function:
|
||||
# first call
|
||||
type_args = collections.defaultdict(list)
|
||||
for i,arg in enumerate(args):
|
||||
type_args[get_reg_type(arg)].append(i)
|
||||
if not isinstance(arg, types._vectorizable):
|
||||
type_args[get_reg_type(arg)].append(i)
|
||||
def wrapped_function(*compile_args):
|
||||
base = get_arg()
|
||||
bases = dict((t, regint.load_mem(base + i)) \
|
||||
for i,t in enumerate(sorted(type_args,
|
||||
key=lambda x:
|
||||
x.reg_type)))
|
||||
runtime_args = [None] * len(args)
|
||||
runtime_args = list(args)
|
||||
for t in sorted(type_args, key=lambda x: x.reg_type):
|
||||
i = 0
|
||||
for i_arg in type_args[t]:
|
||||
@@ -407,13 +426,14 @@ def unmemorize(x):
|
||||
|
||||
class FunctionBlock(Function):
|
||||
def on_first_call(self, wrapped_function):
|
||||
p_return_address = get_tape().program.malloc(1, 'ci')
|
||||
old_block = get_tape().active_basicblock
|
||||
parent_node = get_tape().req_node
|
||||
parent_node = old_block.req_node
|
||||
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
|
||||
block = get_tape().active_basicblock
|
||||
block.alloc_pool = AllocPool()
|
||||
block.alloc_pool = AllocPool(parent=block.alloc_pool)
|
||||
del parent_node.children[-1]
|
||||
self.node = get_tape().req_node
|
||||
self.node = block.req_node
|
||||
if get_program().verbose:
|
||||
print('Compiling function', self.name)
|
||||
result = wrapped_function(*self.compile_args)
|
||||
@@ -423,7 +443,6 @@ class FunctionBlock(Function):
|
||||
self.result = None
|
||||
if get_program().verbose:
|
||||
print('Done compiling function', self.name)
|
||||
p_return_address = get_tape().program.malloc(1, 'ci')
|
||||
get_tape().function_basicblocks[block] = p_return_address
|
||||
return_address = regint.load_mem(p_return_address)
|
||||
get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False))
|
||||
@@ -446,7 +465,7 @@ class FunctionBlock(Function):
|
||||
return_address.store_in_mem(p_return_address)
|
||||
get_tape().start_new_basicblock(name='call-' + self.name)
|
||||
get_tape().active_basicblock.set_return(old_block, self.last_sub_block)
|
||||
get_tape().req_node.children.append(self.node)
|
||||
get_block().req_node.children.append(self.node)
|
||||
if self.result is not None:
|
||||
return unmemorize(self.result)
|
||||
|
||||
@@ -654,7 +673,7 @@ def range_loop(loop_body, start, stop=None, step=None):
|
||||
and isinstance(step, int):
|
||||
# known loop count
|
||||
if condition(start):
|
||||
get_tape().req_node.children[-1].aggregator = \
|
||||
get_block().req_node.children[-1].aggregator = \
|
||||
lambda x: int(ceil(((stop - start) / step))) * x[0]
|
||||
|
||||
def for_range(start, stop=None, step=None):
|
||||
@@ -680,9 +699,6 @@ def for_range(start, stop=None, step=None):
|
||||
x.update(x + 1)
|
||||
print_ln('%s', x.reveal())
|
||||
|
||||
Note that you cannot overwrite data structures such as
|
||||
:py:class:`~Compiler.types.Array` in a loop. Use
|
||||
:py:func:`~Compiler.types.Array.assign` instead.
|
||||
"""
|
||||
def decorator(loop_body):
|
||||
range_loop(loop_body, start, stop, step)
|
||||
@@ -791,7 +807,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
loop_rounds = n_loops // n_parallel \
|
||||
if n_parallel < n_loops else 0
|
||||
else:
|
||||
loop_rounds = n_loops / n_parallel
|
||||
loop_rounds = n_loops // n_parallel
|
||||
def write_state_to_memory(r):
|
||||
if use_array:
|
||||
mem_state.assign(r)
|
||||
@@ -821,6 +837,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
n_opt_loops_reg = regint(0)
|
||||
n_opt_loops_inst = get_block().instructions[-1]
|
||||
parent_block = get_block()
|
||||
prevent_breaks = get_program().prevent_breaks
|
||||
get_program().prevent_breaks = False
|
||||
@while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0))
|
||||
def _(i):
|
||||
state = tuplify(initializer())
|
||||
@@ -846,6 +864,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
loop_rounds = n_loops // my_n_parallel
|
||||
blocks = get_tape().basicblocks
|
||||
n_to_merge = 5
|
||||
get_program().prevent_breaks = prevent_breaks
|
||||
if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]:
|
||||
# merge blocks started by if and do_while
|
||||
def exit_elimination(block):
|
||||
@@ -857,19 +876,22 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
merged.exit_condition = blocks[-1].exit_condition
|
||||
merged.exit_block = blocks[-1].exit_block
|
||||
assert parent_block is blocks[-n_to_merge]
|
||||
assert blocks[-n_to_merge + 1] is \
|
||||
get_tape().req_node.children[-1].nodes[0].blocks[0]
|
||||
assert blocks[-n_to_merge + 1].req_node is \
|
||||
get_block().req_node.children[-1].nodes[0]
|
||||
for block in blocks[-n_to_merge + 1:]:
|
||||
merged.instructions += block.instructions
|
||||
exit_elimination(block)
|
||||
block.purge(retain_usage=False)
|
||||
del blocks[-n_to_merge + 1:]
|
||||
del get_tape().req_node.children[-1]
|
||||
del get_block().req_node.children[-1]
|
||||
merged.children = []
|
||||
RegintOptimizer().run(merged.instructions, get_program())
|
||||
get_tape().active_basicblock = merged
|
||||
else:
|
||||
req_node = get_tape().req_node.children[-1].nodes[0]
|
||||
if get_program().verbose:
|
||||
print(n_opt_loops, 'repetitions')
|
||||
assert not get_program().prevent_breaks
|
||||
req_node = get_block().req_node.children[-1].nodes[0]
|
||||
if util.is_constant(loop_rounds):
|
||||
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
|
||||
if isinstance(n_loops, int):
|
||||
@@ -892,7 +914,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
return returner
|
||||
return decorator
|
||||
|
||||
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
|
||||
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={},
|
||||
budget=None):
|
||||
"""
|
||||
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
|
||||
threads, up to :py:obj:`n_parallel` in parallel per thread.
|
||||
@@ -902,9 +925,10 @@ def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
|
||||
|
||||
"""
|
||||
return map_reduce(n_threads, n_parallel, n_loops, \
|
||||
lambda *x: [], lambda *x: [], thread_mem_req)
|
||||
lambda *x: [], lambda *x: [], thread_mem_req,
|
||||
budget=budget)
|
||||
|
||||
def for_range_opt_multithread(n_threads, n_loops):
|
||||
def for_range_opt_multithread(n_threads, n_loops, budget=None):
|
||||
"""
|
||||
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
|
||||
threads, in parallel up to an optimization budget per thread
|
||||
@@ -943,7 +967,7 @@ def for_range_opt_multithread(n_threads, n_loops):
|
||||
b = a + 1
|
||||
|
||||
"""
|
||||
return for_range_multithread(n_threads, None, n_loops)
|
||||
return for_range_multithread(n_threads, None, n_loops, budget=budget)
|
||||
|
||||
def multithread(n_threads, n_items=None, max_size=None):
|
||||
"""
|
||||
@@ -983,7 +1007,7 @@ def multithread(n_threads, n_items=None, max_size=None):
|
||||
return wrapper
|
||||
|
||||
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
thread_mem_req={}, looping=True):
|
||||
thread_mem_req={}, looping=True, budget=None):
|
||||
assert(n_threads != 0)
|
||||
if isinstance(n_loops, (list, tuple)):
|
||||
split = n_loops
|
||||
@@ -1025,10 +1049,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
state_type = type(state[0])
|
||||
else:
|
||||
state_type = type(state)
|
||||
prevent_breaks = get_program().prevent_breaks
|
||||
def f(inc):
|
||||
get_program().prevent_breaks = prevent_breaks
|
||||
base = args[get_arg()][0]
|
||||
get_program().base_addresses[base] = None
|
||||
if not util.is_constant(thread_rounds):
|
||||
i = base / thread_rounds
|
||||
i = base // thread_rounds
|
||||
overhang = n_loops % n_threads
|
||||
inc = i < overhang
|
||||
base += inc.if_else(i, overhang)
|
||||
@@ -1050,6 +1077,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
if prog.curr_tape == prog.tapes[0]:
|
||||
prog.n_running_threads = n_threads
|
||||
if not util.is_zero(thread_rounds):
|
||||
prog.prevent_breaks = False
|
||||
tape = prog.new_tape(f, (0,), 'multithread')
|
||||
for i in range(n_threads - remainder):
|
||||
mem_state = make_array(initializer())
|
||||
@@ -1058,6 +1086,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
args[remainder + i][1] = mem_state.address
|
||||
thread_args.append((tape, remainder + i))
|
||||
if remainder:
|
||||
prog.prevent_breaks = False
|
||||
tape1 = prog.new_tape(f, (1,), 'multithread1')
|
||||
for i in range(remainder):
|
||||
mem_state = make_array(initializer())
|
||||
@@ -1066,10 +1095,12 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
args[i][1] = mem_state.address
|
||||
thread_args.append((tape1, i))
|
||||
prog.n_running_threads = None
|
||||
prog.prevent_breaks = False
|
||||
threads = prog.run_tapes(thread_args)
|
||||
for thread in threads:
|
||||
prog.join_tape(thread)
|
||||
prog.free_later()
|
||||
prog.prevent_breaks = prevent_breaks
|
||||
if len(state):
|
||||
if thread_rounds:
|
||||
for i in range(n_threads - remainder):
|
||||
@@ -1266,7 +1297,7 @@ def _link(pre, g):
|
||||
if g:
|
||||
from .types import _single
|
||||
for name, var in pre.items():
|
||||
if isinstance(var, (program.Tape.Register, _single, _vec)):
|
||||
if isinstance(var, (Tape.Register, _single, _vec)):
|
||||
new_var = g[name]
|
||||
if util.is_constant_float(new_var):
|
||||
raise CompilerError('cannot reassign constants in blocks')
|
||||
@@ -1285,7 +1316,7 @@ def do_while(loop_fn, g=None):
|
||||
return regint(0)
|
||||
"""
|
||||
scope = instructions.program.curr_block
|
||||
parent_node = get_tape().req_node
|
||||
parent_node = get_block().req_node
|
||||
# possibly unknown loop count
|
||||
get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \
|
||||
name='begin-loop')
|
||||
@@ -1334,9 +1365,10 @@ def else_then():
|
||||
raise CompilerError('else block already defined')
|
||||
# run the else block
|
||||
state.if_exit_block = instructions.program.curr_block
|
||||
state.req_child.add_node(get_tape(), 'else-block')
|
||||
req_node = state.req_child.add_node(get_tape(), 'else-block')
|
||||
instructions.program.curr_tape.start_new_basicblock(state.start_block, \
|
||||
name='else-block')
|
||||
name='else-block',
|
||||
req_node=req_node)
|
||||
state.else_block = instructions.program.curr_block
|
||||
state.has_else = True
|
||||
|
||||
@@ -1545,6 +1577,23 @@ def accept_client_connection(port):
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
return res
|
||||
|
||||
def init_client_connection(host, port, my_id, relative_port=True):
|
||||
""" Initiate connection to another party as client.
|
||||
|
||||
:param host: hostname
|
||||
:param port: port base (int/regint/cint)
|
||||
:param my_id: client id to use
|
||||
:param relative_port: whether to add party number to port number
|
||||
:returns: connection id
|
||||
|
||||
"""
|
||||
if relative_port:
|
||||
port = (port + get_player_id())._v
|
||||
res = regint()
|
||||
instructions.initclientconnection(
|
||||
res, regint.conv(port), regint.conv(my_id), host)
|
||||
return res
|
||||
|
||||
def break_point(name=''):
|
||||
"""
|
||||
Insert break point. This makes sure that all following code
|
||||
@@ -1643,7 +1692,9 @@ def cint_cint_division(a, b, k, f):
|
||||
return (sign_a * sign_b) * A
|
||||
|
||||
from Compiler.program import Program
|
||||
def sint_cint_division(a, b, k, f, kappa):
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def sint_cint_division(a, b, k, f, kappa, nearest=False):
|
||||
"""
|
||||
type(a) = sint, type(b) = cint
|
||||
"""
|
||||
@@ -1659,12 +1710,11 @@ def sint_cint_division(a, b, k, f, kappa):
|
||||
B = absolute_b
|
||||
W = w0
|
||||
|
||||
@for_range(1, theta)
|
||||
def block(i):
|
||||
A.link(TruncPr(A * W, 2*k, f, kappa))
|
||||
temp = (B * W) >> f
|
||||
W.link(two - temp)
|
||||
B.link(temp)
|
||||
for i in range(1, theta):
|
||||
A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True)
|
||||
temp = (B * W + 2 * (f - 1)) >> f
|
||||
W = two - temp
|
||||
B = temp
|
||||
return (sign_a * sign_b) * A
|
||||
|
||||
def IntDiv(a, b, k, kappa=None):
|
||||
@@ -1691,13 +1741,16 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
assert 2 * f > k - nearest
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
|
||||
l_y = k + 3 * f - res_f
|
||||
comparison.require_ring_size(
|
||||
l_y, 'division (https://www.ifca.ai/pub/fc10/31_47.pdf)')
|
||||
|
||||
base.set_global_vector_size(b.size)
|
||||
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
|
||||
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
|
||||
x = alpha - b.extend(2 * k) * w
|
||||
base.reset_global_vector_size()
|
||||
|
||||
l_y = k + 3 * f - res_f
|
||||
y = a.extend(l_y) * w
|
||||
y = y.round(l_y, f, kappa, nearest, signed=True)
|
||||
|
||||
|
||||
@@ -834,7 +834,7 @@ class Dense(DenseBase):
|
||||
prod = MultiArray([N, self.d, self.d_out], sfix)
|
||||
else:
|
||||
prod = self.f_input
|
||||
max_size = program.Program.prog.budget // self.d_out
|
||||
max_size = get_program().budget // self.d_out
|
||||
@multithread(self.n_threads, N, max_size)
|
||||
def _(base, size):
|
||||
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
@@ -1038,16 +1038,16 @@ class ElementWiseLayer(NoVariableLayer):
|
||||
self.inputs = inputs
|
||||
|
||||
def f_part(self, base, size):
|
||||
return self.f(self.X.get_part_vector(base, size))
|
||||
return self.f(self.X.get_vector(base, size))
|
||||
|
||||
def f_prime_part(self, base, size):
|
||||
return self.f_prime(self.Y.get_vector(base, size))
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
n_per_item = reduce(operator.mul, self.X.sizes[1:])
|
||||
@multithread(self.n_threads, len(batch), max(1, 1000 // n_per_item))
|
||||
@multithread(self.n_threads, len(batch) * n_per_item)
|
||||
def _(base, size):
|
||||
self.Y.assign_part_vector(self.f_part(base, size), base)
|
||||
self.Y.assign_vector(self.f_part(base, size), base)
|
||||
|
||||
if self.debug_output:
|
||||
name = self
|
||||
@@ -1095,9 +1095,9 @@ class Relu(ElementWiseLayer):
|
||||
self.comparisons = MultiArray(shape, sint)
|
||||
|
||||
def f_part(self, base, size):
|
||||
x = self.X.get_part_vector(base, size)
|
||||
x = self.X.get_vector(base, size)
|
||||
c = x > 0
|
||||
self.comparisons.assign_part_vector(c, base)
|
||||
self.comparisons.assign_vector(c, base)
|
||||
return c.if_else(x, 0)
|
||||
|
||||
def f_prime_part(self, base, size):
|
||||
@@ -1686,12 +1686,9 @@ class Conv2d(ConvBase):
|
||||
padding_h, padding_w = self.padding
|
||||
|
||||
if self.use_conv2ds:
|
||||
n_parts = max(1, round((self.n_threads or 1) / n_channels_out))
|
||||
while len(batch) % n_parts != 0:
|
||||
n_parts -= 1
|
||||
print('Convolution in %d parts' % n_parts)
|
||||
part_size = len(batch) // n_parts
|
||||
@for_range_multithread(self.n_threads, 1, [n_parts, n_channels_out])
|
||||
part_size = 1
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
[len(batch), n_channels_out])
|
||||
def _(i, j):
|
||||
inputs = self.X.get_slice_vector(
|
||||
batch.get_part(i * part_size, part_size))
|
||||
@@ -2507,6 +2504,10 @@ class Optimizer:
|
||||
loss = self.layers[-1].average_loss(N)
|
||||
res = (loss < stop_on_loss) * (loss >= -1)
|
||||
self.stopped_on_loss.write(1 - res)
|
||||
print_ln_if(
|
||||
self.stopped_on_loss,
|
||||
'aborting epoch because loss is outside range: %s',
|
||||
loss)
|
||||
return res
|
||||
if self.print_losses:
|
||||
print_ln()
|
||||
@@ -2545,7 +2546,7 @@ class Optimizer:
|
||||
loss = MemValue(sfix(0))
|
||||
def f(start, batch_size, batch):
|
||||
batch.assign_vector(regint.inc(batch_size, start))
|
||||
self.forward(batch=batch)
|
||||
self.forward(batch=batch, run_last=False)
|
||||
part_truth = truth.get_part(start, batch_size)
|
||||
n_correct.iadd(
|
||||
self.layers[-1].reveal_correctness(batch_size, part_truth))
|
||||
@@ -2644,7 +2645,7 @@ class Optimizer:
|
||||
batch = Array.create_from(regint.inc(batch_size))
|
||||
self.forward(batch=batch, training=True)
|
||||
self.backward(batch=batch)
|
||||
self.update(0, batch=batch)
|
||||
self.update(0, batch=batch, i_batch=0)
|
||||
return
|
||||
@for_range(n_runs)
|
||||
def _(i):
|
||||
@@ -2697,6 +2698,8 @@ class Optimizer:
|
||||
if depreciation:
|
||||
self.gamma.imul(depreciation)
|
||||
print_ln('reducing learning rate to %s', self.gamma)
|
||||
print_ln_if(self.stopped_on_low_loss,
|
||||
'aborting run because of low loss')
|
||||
return 1 - self.stopped_on_low_loss
|
||||
if self.missing_newline:
|
||||
print_ln('')
|
||||
|
||||
@@ -20,7 +20,7 @@ import sys
|
||||
from functools import reduce
|
||||
|
||||
from Compiler.types import *
|
||||
from Compiler.types import _secret
|
||||
from Compiler.types import _secret, _register
|
||||
from Compiler.library import *
|
||||
from Compiler.program import Program
|
||||
from Compiler import floatingpoint,comparison,permutation
|
||||
@@ -77,8 +77,8 @@ class intBlock(Block):
|
||||
self.lower, self.shift = \
|
||||
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
|
||||
Program.prog.security, True)
|
||||
trunc = (self.value - self.lower) / self.shift
|
||||
self.slice = trunc.mod2m(length, self.n_bits, False)
|
||||
trunc = (self.value - self.lower).field_div(self.shift)
|
||||
self.slice = trunc.mod2m(length, self.n_bits, signed=False)
|
||||
self.upper = (trunc - self.slice) * self.shift
|
||||
def get_slice(self):
|
||||
total_length = sum(self.lengths)
|
||||
@@ -89,13 +89,11 @@ class intBlock(Block):
|
||||
res = []
|
||||
remainder = self.slice
|
||||
for length,start in zip(self.lengths[:-1],series(self.lengths)):
|
||||
res.append(remainder.mod2m(length, total_length - start, False))
|
||||
res.append(remainder.mod2m(length, total_length - start,
|
||||
signed=False))
|
||||
remainder -= res[-1]
|
||||
if Program.prog.options.ring:
|
||||
remainder = remainder.trunc_zeros(length,
|
||||
total_length - start, False)
|
||||
else:
|
||||
remainder /= floatingpoint.two_power(length)
|
||||
remainder = remainder.trunc_zeros(length,
|
||||
total_length - start, False)
|
||||
res.append(remainder)
|
||||
return res
|
||||
def set_slice(self, value):
|
||||
@@ -208,23 +206,39 @@ def demux_list(x):
|
||||
return res
|
||||
|
||||
def demux_array(x, res=None):
|
||||
tmp = demux_matrix(x).array
|
||||
if res:
|
||||
try:
|
||||
assert issubclass(x.value_type, _register)
|
||||
res[:] = tmp[:]
|
||||
except:
|
||||
@for_range(len(res))
|
||||
def _(i):
|
||||
res[i] = tmp[i]
|
||||
else:
|
||||
res = tmp
|
||||
return res
|
||||
|
||||
def demux_matrix(x, n_threads=None):
|
||||
n = len(x)
|
||||
if res is None:
|
||||
res = Array(2**n, type(x[0]))
|
||||
if n == 0:
|
||||
return [1]
|
||||
m = len(x[0])
|
||||
t = type(x[0])
|
||||
res = Matrix(2**n, m, type(x[0]))
|
||||
if n == 1:
|
||||
res[0] = 1 - x[0]
|
||||
res[1] = x[0]
|
||||
else:
|
||||
a = Array(2**(n//2), type(x[0]))
|
||||
a = Matrix(2**(n//2), m, type(x[0]))
|
||||
a.assign(demux(x[:n//2]))
|
||||
b = Array(2**(n-n//2), type(x[0]))
|
||||
b = Matrix(2**(n-n//2), m, type(x[0]))
|
||||
b.assign(demux(x[n//2:]))
|
||||
@for_range_multithread(get_n_threads(len(res)), \
|
||||
max(1, n_parallel // len(b)), len(a))
|
||||
@for_range_opt_multithread(n_threads, len(a))
|
||||
def f(i):
|
||||
@for_range_parallel(n_parallel, len(b))
|
||||
@for_range_opt(len(b))
|
||||
def f(j):
|
||||
res[j * len(a) + i] = a[i] * b[j]
|
||||
res[j * len(a) + i][:] = a[i][:] * b[j][:]
|
||||
return res
|
||||
|
||||
def get_first_one(x):
|
||||
@@ -1717,7 +1731,8 @@ class BinaryORAM:
|
||||
|
||||
def OptimalORAM(size,*args,**kwargs):
|
||||
""" Create an ORAM instance suitable for the size based on
|
||||
experiments.
|
||||
experiments. This uses the approach by `Keller and Scholl
|
||||
<https://eprint.iacr.org/2014/137>`_.
|
||||
|
||||
:param size: number of elements
|
||||
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
|
||||
|
||||
@@ -11,6 +11,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import hashlib
|
||||
import random
|
||||
from collections import defaultdict, deque
|
||||
from functools import reduce
|
||||
|
||||
@@ -74,8 +75,8 @@ class Program(object):
|
||||
"""A program consists of a list of tapes representing the whole
|
||||
computation.
|
||||
|
||||
When compiling an :file:`.mpc` file, the single instances is
|
||||
available as :py:obj:`program` in order. When compiling directly
|
||||
When compiling an :file:`.mpc` file, the single instance is
|
||||
available as :py:obj:`program`. When compiling directly
|
||||
from Python code, an instance has to be created before running any
|
||||
instructions.
|
||||
"""
|
||||
@@ -89,6 +90,7 @@ class Program(object):
|
||||
self.name = name
|
||||
self.init_names(args)
|
||||
self._security = 40
|
||||
self.used_security = 0
|
||||
self.prime = None
|
||||
self.tapes = []
|
||||
if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1:
|
||||
@@ -113,7 +115,8 @@ class Program(object):
|
||||
if not self.bit_length:
|
||||
self.bit_length = 64
|
||||
print("Default bit length for compilation:", self.bit_length)
|
||||
print("Default security parameter for compilation:", self.security)
|
||||
if not (options.binary or options.garbled):
|
||||
print("Default security parameter for compilation:", self._security)
|
||||
self.galois_length = int(options.galois)
|
||||
if self.verbose:
|
||||
print("Galois length:", self.galois_length)
|
||||
@@ -185,10 +188,13 @@ class Program(object):
|
||||
self.relevant_opts = set()
|
||||
self.n_running_threads = None
|
||||
self.input_files = {}
|
||||
self.base_addresses = {}
|
||||
self.base_addresses = util.dict_by_id()
|
||||
self._protect_memory = False
|
||||
self.mem_protect_stack = []
|
||||
self._always_active = True
|
||||
self.active = True
|
||||
self.prevent_breaks = False
|
||||
self.cisc_to_function = True
|
||||
if not self.options.cisc:
|
||||
self.options.cisc = not self.options.optimize_hard
|
||||
|
||||
@@ -249,7 +255,7 @@ class Program(object):
|
||||
else:
|
||||
raise CompilerError(
|
||||
"found none of the potential input files: " +
|
||||
", ".join("'%s'" % x for x in [args[0]] + infiles))
|
||||
", ".join("'%s'" % x for x in infiles))
|
||||
"""
|
||||
self.name is input file name (minus extension) + any optional arguments.
|
||||
Used to generate output filenames
|
||||
@@ -352,7 +358,8 @@ class Program(object):
|
||||
)
|
||||
self.curr_tape.start_new_basicblock(name="post-run_tape")
|
||||
for arg in args:
|
||||
self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree)
|
||||
self.curr_block.req_node.children.append(
|
||||
self.tapes[arg[0]].req_tree)
|
||||
return thread_numbers
|
||||
|
||||
def join_tape(self, thread_number):
|
||||
@@ -400,6 +407,7 @@ class Program(object):
|
||||
sch_file.write("lgp:%s" % req)
|
||||
sch_file.write("\n")
|
||||
sch_file.write("opts: %s\n" % " ".join(self.relevant_opts))
|
||||
sch_file.write("sec:%d\n" % self.used_security)
|
||||
sch_file.close()
|
||||
h = hashlib.sha256()
|
||||
for tape in self.tapes:
|
||||
@@ -433,7 +441,7 @@ class Program(object):
|
||||
"""The basic block that is currently being created."""
|
||||
return self.curr_tape.active_basicblock
|
||||
|
||||
def malloc(self, size, mem_type, reg_type=None, creator_tape=None):
|
||||
def malloc(self, size, mem_type, reg_type=None, creator_tape=None, use_freed=True):
|
||||
"""Allocate memory from the top"""
|
||||
if not isinstance(size, int):
|
||||
raise CompilerError("size must be known at compile time")
|
||||
@@ -456,7 +464,7 @@ class Program(object):
|
||||
else:
|
||||
raise CompilerError("cannot allocate memory " "outside main thread")
|
||||
blocks = self.free_mem_blocks[mem_type]
|
||||
addr = blocks.pop(size)
|
||||
addr = blocks.pop(size) if use_freed else None
|
||||
if addr is not None:
|
||||
self.saved += size
|
||||
else:
|
||||
@@ -469,11 +477,13 @@ class Program(object):
|
||||
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
|
||||
if single_size:
|
||||
from .library import get_thread_number, runtime_error_if
|
||||
|
||||
bak = self.curr_tape.active_basicblock
|
||||
self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
|
||||
tn = get_thread_number()
|
||||
runtime_error_if(tn > self.n_running_threads, "malloc")
|
||||
res = addr + single_size * (tn - 1)
|
||||
self.base_addresses[str(res)] = addr
|
||||
self.curr_tape.active_basicblock = bak
|
||||
self.base_addresses[res] = addr
|
||||
return res
|
||||
else:
|
||||
return addr
|
||||
@@ -482,7 +492,7 @@ class Program(object):
|
||||
"""Free memory"""
|
||||
now = True
|
||||
if not util.is_constant(addr):
|
||||
addr = self.base_addresses[str(addr)]
|
||||
addr = self.base_addresses[addr]
|
||||
now = self.curr_tape == self.tapes[0]
|
||||
size, pool = self.allocated_mem_blocks[addr, mem_type]
|
||||
if self.curr_block.alloc_pool is not pool:
|
||||
@@ -524,7 +534,8 @@ class Program(object):
|
||||
self.public_input_file.close()
|
||||
|
||||
def finalize_memory(self):
|
||||
self.curr_tape.start_new_basicblock(None, "memory-usage")
|
||||
self.curr_tape.start_new_basicblock(None, "memory-usage",
|
||||
req_node=self.curr_tape.req_tree)
|
||||
# reset register counter to 0
|
||||
if not self.options.noreallocate:
|
||||
self.curr_tape.init_registers()
|
||||
@@ -575,6 +586,7 @@ class Program(object):
|
||||
def security(self):
|
||||
"""The statistical security parameter for non-linear
|
||||
functions."""
|
||||
self.used_security = max(self.used_security, self._security)
|
||||
return self._security
|
||||
|
||||
@security.setter
|
||||
@@ -701,6 +713,13 @@ class Program(object):
|
||||
""" Enable or disable memory protection. """
|
||||
self._protect_memory = status
|
||||
|
||||
def open_memory_scope(self, key=None):
|
||||
self.mem_protect_stack.append(self._protect_memory)
|
||||
self.protect_memory(key or object())
|
||||
|
||||
def close_memory_scope(self):
|
||||
self.protect_memory(self.mem_protect_stack.pop())
|
||||
|
||||
def use_cisc(self):
|
||||
return self.options.cisc and (not self.prime or self.rabbit_gap()) \
|
||||
and not self.options.max_parallel_open
|
||||
@@ -725,7 +744,7 @@ class Program(object):
|
||||
self._always_active = False
|
||||
|
||||
@staticmethod
|
||||
def read_tapes(schedule):
|
||||
def read_schedule(schedule):
|
||||
m = re.search(r"([^/]*)\.mpc", schedule)
|
||||
if m:
|
||||
schedule = m.group(1)
|
||||
@@ -733,7 +752,7 @@ class Program(object):
|
||||
schedule = "Programs/Schedules/%s.sch" % schedule
|
||||
|
||||
try:
|
||||
lines = open(schedule).readlines()
|
||||
return open(schedule).readlines()
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
"%s not found, have you compiled the program?" % schedule,
|
||||
@@ -741,9 +760,25 @@ class Program(object):
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
@classmethod
|
||||
def read_tapes(cls, schedule):
|
||||
lines = cls.read_schedule(schedule)
|
||||
for tapename in lines[2].split(" "):
|
||||
yield tapename.strip().split(":")[0]
|
||||
|
||||
@classmethod
|
||||
def read_n_threads(cls, schedule):
|
||||
return int(cls.read_schedule(schedule)[0])
|
||||
|
||||
@classmethod
|
||||
def read_domain_size(cls, schedule):
|
||||
from Compiler.instructions import reqbl_class
|
||||
tapename = cls.read_schedule(schedule)[2].strip().split(":")[0]
|
||||
for inst in Tape.read_instructions(tapename):
|
||||
if inst.code == reqbl_class.code:
|
||||
bl = inst.args[0]
|
||||
return (abs(bl.i) + 63) // 64 * 8
|
||||
|
||||
|
||||
class Tape:
|
||||
"""A tape contains a list of basic blocks, onto which instructions are added."""
|
||||
@@ -755,13 +790,12 @@ class Tape:
|
||||
self.init_names(name)
|
||||
self.init_registers()
|
||||
self.req_tree = self.ReqNode(name)
|
||||
self.req_node = self.req_tree
|
||||
self.basicblocks = []
|
||||
self.purged = False
|
||||
self.block_counter = 0
|
||||
self.active_basicblock = None
|
||||
self.old_allocated_mem = program.allocated_mem.copy()
|
||||
self.start_new_basicblock()
|
||||
self.start_new_basicblock(req_node=self.req_tree)
|
||||
self._is_empty = False
|
||||
self.merge_opens = True
|
||||
self.if_states = []
|
||||
@@ -774,7 +808,8 @@ class Tape:
|
||||
self.warned_about_mem = False
|
||||
|
||||
class BasicBlock(object):
|
||||
def __init__(self, parent, name, scope, exit_condition=None):
|
||||
def __init__(self, parent, name, scope, exit_condition=None,
|
||||
req_node=None):
|
||||
self.parent = parent
|
||||
self.instructions = []
|
||||
self.name = name
|
||||
@@ -794,6 +829,8 @@ class Tape:
|
||||
self.n_to_merge = 0
|
||||
self.rounds = Tape.ReqNum()
|
||||
self.warn_about_mem = parent.program.warn_about_mem[-1]
|
||||
self.req_node = req_node
|
||||
self.used_from_scope = set()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.instructions)
|
||||
@@ -860,17 +897,25 @@ class Tape:
|
||||
req_node.num += self.rounds
|
||||
|
||||
def expand_cisc(self):
|
||||
new_instructions = []
|
||||
if self.parent.program.options.keep_cisc is not None:
|
||||
skip = ["LTZ", "Trunc"]
|
||||
skip = ["LTZ", "Trunc", "EQZ"]
|
||||
skip += self.parent.program.options.keep_cisc.split(",")
|
||||
else:
|
||||
skip = []
|
||||
tape = self.parent
|
||||
tape.start_new_basicblock(scope=self.scope, req_node=self.req_node,
|
||||
name="cisc")
|
||||
start_block = tape.basicblocks[-1]
|
||||
start_block.alloc_pool = self.alloc_pool
|
||||
for inst in self.instructions:
|
||||
new_inst, n_rounds = inst.expand_merged(skip)
|
||||
new_instructions.extend(new_inst)
|
||||
self.n_rounds += n_rounds
|
||||
self.instructions = new_instructions
|
||||
inst.expand_merged(skip)
|
||||
self.instructions = tape.active_basicblock.instructions
|
||||
if start_block == tape.basicblocks[-1]:
|
||||
res = self
|
||||
else:
|
||||
res = start_block
|
||||
tape.basicblocks[-1] = self
|
||||
return res
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@@ -885,7 +930,8 @@ class Tape:
|
||||
self._is_empty = len(self.basicblocks) == 0
|
||||
return self._is_empty
|
||||
|
||||
def start_new_basicblock(self, scope=False, name=""):
|
||||
def start_new_basicblock(self, scope=False, name="", req_node=None):
|
||||
assert not self.program.prevent_breaks
|
||||
if self.program.verbose and self.active_basicblock and \
|
||||
self.program.allocated_mem != self.old_allocated_mem:
|
||||
print("New allocated memory in %s " % self.active_basicblock.name,
|
||||
@@ -900,10 +946,12 @@ class Tape:
|
||||
scope = self.active_basicblock
|
||||
suffix = "%s-%d" % (name, self.block_counter)
|
||||
self.block_counter += 1
|
||||
sub = self.BasicBlock(self, self.name + "-" + suffix, scope)
|
||||
if req_node is None:
|
||||
req_node = self.active_basicblock.req_node
|
||||
sub = self.BasicBlock(self, self.name + "-" + suffix, scope,
|
||||
req_node=req_node)
|
||||
self.basicblocks.append(sub)
|
||||
self.active_basicblock = sub
|
||||
self.req_node.add_block(sub)
|
||||
# print 'Compiling basic block', sub.name
|
||||
|
||||
def init_registers(self):
|
||||
@@ -1054,12 +1102,20 @@ class Tape:
|
||||
print("Re-allocating...")
|
||||
allocator = al.StraightlineAllocator(REG_MAX, self.program)
|
||||
|
||||
# make addresses available in functions
|
||||
for addr in self.program.base_addresses:
|
||||
if addr.program == self and self.basicblocks:
|
||||
allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool)
|
||||
|
||||
seen = set()
|
||||
|
||||
def alloc(block):
|
||||
allocator.update_usage(block.alloc_pool)
|
||||
for reg in sorted(
|
||||
block.used_from_scope, key=lambda x: (x.reg_type, x.i)
|
||||
):
|
||||
allocator.alloc_reg(reg, block.alloc_pool)
|
||||
seen.add(block)
|
||||
|
||||
def alloc_loop(block):
|
||||
left = deque([block])
|
||||
@@ -1067,7 +1123,8 @@ class Tape:
|
||||
block = left.popleft()
|
||||
alloc(block)
|
||||
for child in block.children:
|
||||
left.append(child)
|
||||
if child not in seen:
|
||||
left.append(child)
|
||||
|
||||
allocator.old_pool = None
|
||||
for i, block in enumerate(reversed(self.basicblocks)):
|
||||
@@ -1101,6 +1158,8 @@ class Tape:
|
||||
# offline data requirements
|
||||
if self.program.verbose:
|
||||
print("Compile offline data requirements...")
|
||||
for block in self.basicblocks:
|
||||
block.req_node.add_block(block)
|
||||
self.req_num = self.req_tree.aggregate()
|
||||
if self.program.verbose:
|
||||
print("Tape requires", self.req_num)
|
||||
@@ -1160,8 +1219,24 @@ class Tape:
|
||||
|
||||
@unpurged
|
||||
def expand_cisc(self):
|
||||
mapping = {None: None}
|
||||
blocks = self.basicblocks[:]
|
||||
self.basicblocks = []
|
||||
for block in blocks:
|
||||
expanded = block.expand_cisc()
|
||||
mapping[block] = expanded
|
||||
for block in self.basicblocks:
|
||||
block.expand_cisc()
|
||||
if block not in mapping:
|
||||
mapping[block] = block
|
||||
for block in self.basicblocks:
|
||||
block.exit_block = mapping[block.exit_block]
|
||||
if block.exit_block is not None:
|
||||
assert block.exit_block in self.basicblocks
|
||||
if block.previous_block and mapping[block] != block:
|
||||
mapping[block].previous_block = block.previous_block
|
||||
mapping[block].sub_block = block.sub_block
|
||||
block.previous_block = None
|
||||
del block.sub_block
|
||||
|
||||
@unpurged
|
||||
def _get_instructions(self):
|
||||
@@ -1320,27 +1395,38 @@ class Tape:
|
||||
return repr(dict(self))
|
||||
|
||||
class ReqNode(object):
|
||||
__slots__ = ["num", "children", "name", "blocks"]
|
||||
__slots__ = ["num", "_children", "name", "blocks", "aggregated"]
|
||||
|
||||
def __init__(self, name):
|
||||
self.children = []
|
||||
self._children = []
|
||||
self.name = name
|
||||
self.blocks = []
|
||||
self.aggregated = None
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
self.aggregated = None
|
||||
return self._children
|
||||
|
||||
def aggregate(self, *args):
|
||||
if self.aggregated is not None:
|
||||
return self.aggregated
|
||||
self.num = Tape.ReqNum()
|
||||
for block in self.blocks:
|
||||
block.add_usage(self)
|
||||
res = reduce(
|
||||
lambda x, y: x + y.aggregate(self.name), self.children, self.num
|
||||
)
|
||||
self.aggregated = res
|
||||
return res
|
||||
|
||||
def increment(self, data_type, num=1):
|
||||
self.num[data_type] += num
|
||||
self.aggregated = None
|
||||
|
||||
def add_block(self, block):
|
||||
self.blocks.append(block)
|
||||
self.aggregated = None
|
||||
|
||||
class ReqChild(object):
|
||||
__slots__ = ["aggregator", "nodes", "parent"]
|
||||
@@ -1369,18 +1455,18 @@ class Tape:
|
||||
def add_node(self, tape, name):
|
||||
new_node = Tape.ReqNode(name)
|
||||
self.nodes.append(new_node)
|
||||
tape.req_node = new_node
|
||||
return new_node
|
||||
|
||||
def open_scope(self, aggregator, scope=False, name=""):
|
||||
child = self.ReqChild(aggregator, self.req_node)
|
||||
self.req_node.children.append(child)
|
||||
child.add_node(self, "%s-%d" % (name, len(self.basicblocks)))
|
||||
self.start_new_basicblock(name=name)
|
||||
req_node = self.active_basicblock.req_node
|
||||
child = self.ReqChild(aggregator, req_node)
|
||||
req_node.children.append(child)
|
||||
node = child.add_node(self, "%s-%d" % (name, len(self.basicblocks)))
|
||||
self.start_new_basicblock(name=name, req_node=node)
|
||||
return child
|
||||
|
||||
def close_scope(self, outer_scope, parent_req_node, name):
|
||||
self.req_node = parent_req_node
|
||||
self.start_new_basicblock(outer_scope, name)
|
||||
self.start_new_basicblock(outer_scope, name, req_node=parent_req_node)
|
||||
|
||||
def require_bit_length(self, bit_length, t="p"):
|
||||
if t == "p":
|
||||
@@ -1553,7 +1639,7 @@ class Tape:
|
||||
diff_block = isinstance(other, Tape.Register) and self.block != other.block
|
||||
other = type(self)(other)
|
||||
if not diff_block:
|
||||
self.program.start_new_basicblock()
|
||||
self.program.start_new_basicblock(name="update")
|
||||
if self.program != other.program:
|
||||
raise CompilerError(
|
||||
'cannot update register with one from another thread')
|
||||
@@ -1575,6 +1661,8 @@ class Tape:
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
return self.reg_type + str(self.i) + \
|
||||
("(%d)" % self.size if self.size is not None and self.size > 1
|
||||
else "")
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
@@ -20,7 +20,7 @@ correct signature.
|
||||
Basic types
|
||||
-----------
|
||||
|
||||
All basic can be used as vectors, that is one instance representing
|
||||
All basic types can be used as vectors, that is one instance representing
|
||||
several values, with all operations being executed element-wise. For
|
||||
example, the following computes ten multiplications of integers input
|
||||
by party 0 and 1::
|
||||
@@ -588,7 +588,7 @@ class _secret_structure(_structure):
|
||||
|
||||
@classmethod
|
||||
def input_tensor_via(cls, player, content=None, shape=None, binary=True,
|
||||
one_hot=False):
|
||||
one_hot=False, skip_input=False, n_bytes=None):
|
||||
"""
|
||||
Input tensor-like data via a player. This overwrites the input
|
||||
file for the relevant player. The following returns an
|
||||
@@ -630,7 +630,10 @@ class _secret_structure(_structure):
|
||||
else:
|
||||
t = numpy.single
|
||||
else:
|
||||
t = numpy.int64
|
||||
if n_bytes == 1:
|
||||
t = numpy.int8
|
||||
else:
|
||||
t = numpy.int64
|
||||
if one_hot:
|
||||
content = numpy.eye(content.max() + 1)[content]
|
||||
content = content.astype(t)
|
||||
@@ -666,9 +669,10 @@ class _secret_structure(_structure):
|
||||
if requested_shape is not None and \
|
||||
list(shape) != list(requested_shape):
|
||||
raise CompilerError('content contradicts shape')
|
||||
res = cls.Tensor(shape)
|
||||
res.input_from(player, binary=binary)
|
||||
return res
|
||||
if not skip_input:
|
||||
res = cls.Tensor(shape)
|
||||
res.input_from(player, binary=binary, n_bytes=n_bytes)
|
||||
return res
|
||||
|
||||
class _vec(Tape._no_truth):
|
||||
def link(self, other):
|
||||
@@ -681,6 +685,13 @@ class _register(Tape.Register, _number, _structure):
|
||||
def n_elements():
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def new_vector(cls, size):
|
||||
return cls(size=size)
|
||||
|
||||
def vector_size(self):
|
||||
return self.size
|
||||
|
||||
@vectorized_classmethod
|
||||
def conv(cls, val):
|
||||
if isinstance(val, MemValue):
|
||||
@@ -707,7 +718,7 @@ class _register(Tape.Register, _number, _structure):
|
||||
except AttributeError:
|
||||
try:
|
||||
return type(val)(cls.hard_conv(v) for v in val)
|
||||
except TypeError:
|
||||
except (TypeError, CompilerError):
|
||||
pass
|
||||
return cls(val)
|
||||
|
||||
@@ -756,11 +767,11 @@ class _register(Tape.Register, _number, _structure):
|
||||
return sum(cls.conv(b) << i for i,b in enumerate(bits))
|
||||
|
||||
@classmethod
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
def malloc(cls, size, creator_tape=None, **kwargs):
|
||||
""" Allocate memory (statically).
|
||||
|
||||
:param size: compile-time (int) """
|
||||
return program.malloc(size, cls, creator_tape=creator_tape)
|
||||
return program.malloc(size, cls, creator_tape=creator_tape, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def free(cls, addr):
|
||||
@@ -783,7 +794,11 @@ class _register(Tape.Register, _number, _structure):
|
||||
else:
|
||||
self[i].load_other(x)
|
||||
elif val is not None:
|
||||
self.load_other(val)
|
||||
try:
|
||||
self.load_other(val)
|
||||
except:
|
||||
raise CompilerError(
|
||||
"cannot convert '%s' to '%s'" % (type(val), type(self)))
|
||||
|
||||
def _new_by_number(self, i, size=1):
|
||||
res = type(self)(size=size)
|
||||
@@ -943,16 +958,15 @@ class _clear(_arithmetic_register):
|
||||
return self.clear_op(other, subc, subcfi, True)
|
||||
__rsub__.__doc__ = __sub__.__doc__
|
||||
|
||||
def __truediv__(self, other):
|
||||
def field_div(self, other):
|
||||
""" Field division of public values. Not available for
|
||||
computation modulo a power of two.
|
||||
|
||||
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
|
||||
return self.clear_op(other, divc, divci)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return self.coerce_op(other, divc, True)
|
||||
__rtruediv__.__doc__ = __truediv__.__doc__
|
||||
try:
|
||||
return other._rfield_div(self)
|
||||
except AttributeError:
|
||||
return self.clear_op(other, divc, divci)
|
||||
|
||||
def __and__(self, other):
|
||||
""" Bit-wise AND of public values.
|
||||
@@ -1107,6 +1121,20 @@ class cint(_clear, _int):
|
||||
def __rfloordiv__(self, other):
|
||||
return self.coerce_op(other, floordivc, True)
|
||||
|
||||
def __truediv__(self, other):
|
||||
""" Clear fixed-point division.
|
||||
|
||||
:param other: any compatible type """
|
||||
if isinstance(other, cint):
|
||||
return other.__rtruediv__(self)
|
||||
try:
|
||||
return cfix._new(self) / cfix._new(cint(other))
|
||||
except:
|
||||
return NotImplemented
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return cfix._new(other) / cfix._new(self)
|
||||
|
||||
@vectorize
|
||||
def less_than(self, other, bit_length):
|
||||
""" Clear comparison for particular bit length.
|
||||
@@ -1348,6 +1376,11 @@ class cgf2n(_clear, _gf2n):
|
||||
""" Identity. """
|
||||
return self
|
||||
|
||||
__truediv__ = _clear.field_div
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return self.coerce_op(other, divc, True)
|
||||
|
||||
@vectorize
|
||||
def __invert__(self):
|
||||
""" Clear bit-wise inversion. """
|
||||
@@ -1594,8 +1627,14 @@ class regint(_register, _int):
|
||||
return self.int_op(other, divint, True)
|
||||
__rfloordiv__.__doc__ = __floordiv__.__doc__
|
||||
|
||||
__truediv__ = __floordiv__
|
||||
__rtruediv__ = __rfloordiv__
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, _gf2n):
|
||||
return NotImplemented
|
||||
else:
|
||||
return cint(self) / other
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return other / cint(self)
|
||||
|
||||
def __mod__(self, other):
|
||||
""" Clear modulo computation.
|
||||
@@ -1603,7 +1642,7 @@ class regint(_register, _int):
|
||||
:param other: regint/cint/int """
|
||||
if util.is_constant(other) and other >= 2 ** 64:
|
||||
return self
|
||||
return self - (self / other) * other
|
||||
return self - (self // other) * other
|
||||
|
||||
def __rmod__(self, other):
|
||||
""" Clear modulo computation.
|
||||
@@ -1661,7 +1700,7 @@ class regint(_register, _int):
|
||||
|
||||
def __rshift__(self, other):
|
||||
if isinstance(other, int):
|
||||
return self / 2**other
|
||||
return self // 2**other
|
||||
else:
|
||||
return self.cint_op(other, operator.rshift)
|
||||
|
||||
@@ -1793,6 +1832,9 @@ class localint(Tape._no_truth):
|
||||
__eq__ = lambda self, other: localint(self._v == other)
|
||||
__ne__ = lambda self, other: localint(self._v != other)
|
||||
|
||||
__add__ = lambda self, other: localint(self._v + other)
|
||||
__radd__ = lambda self, other: localint(self._v + other)
|
||||
|
||||
class personal(Tape._no_truth):
|
||||
""" Value known to one player. Supports operations with public
|
||||
values and personal values known to the same player. Can be used
|
||||
@@ -1812,7 +1854,7 @@ class personal(Tape._no_truth):
|
||||
self._v = value
|
||||
|
||||
@classmethod
|
||||
def read_int(cls, player):
|
||||
def read_int(cls, player, n_bytes=None):
|
||||
""" Read integer from
|
||||
``Player-Data/Input-Binary-P<player>-<threadnum>`` only on
|
||||
party :py:obj:`player`.
|
||||
@@ -1822,7 +1864,7 @@ class personal(Tape._no_truth):
|
||||
|
||||
"""
|
||||
tmp = cint()
|
||||
fixinput(player, tmp, 0, 0)
|
||||
fixinput(player, tmp, n_bytes or 0, 0)
|
||||
return cls(player, tmp)
|
||||
|
||||
@classmethod
|
||||
@@ -2229,7 +2271,7 @@ class _secret(_arithmetic_register, _secret_structure):
|
||||
return self.secret_op(other, subs, submr, subsfi, True)
|
||||
__rsub__.__doc__ = __sub__.__doc__
|
||||
|
||||
def __truediv__(self, other):
|
||||
def field_div(self, other):
|
||||
""" Secret field division.
|
||||
|
||||
:param other: any compatible type """
|
||||
@@ -2237,13 +2279,12 @@ class _secret(_arithmetic_register, _secret_structure):
|
||||
one = self.clear_type(1, size=other.size)
|
||||
except AttributeError:
|
||||
one = self.clear_type(1)
|
||||
return self * (one / other)
|
||||
return self * one.field_div(other)
|
||||
|
||||
@vectorize
|
||||
def __rtruediv__(self, other):
|
||||
def _rfield_div(self, other):
|
||||
a,b = self.get_random_inverse()
|
||||
return other * a / (a * self).reveal()
|
||||
__rtruediv__.__doc__ = __truediv__.__doc__
|
||||
return other * a.field_div((a * self).reveal())
|
||||
|
||||
@set_instruction_type
|
||||
@vectorize
|
||||
@@ -2311,8 +2352,8 @@ class sint(_secret, _int):
|
||||
|
||||
The following operations work as expected in the computation
|
||||
domain (modulo a prime or a power of two): ``+, -, *``. ``/``
|
||||
denotes the field division modulo a prime. It will reveal if the
|
||||
divisor is zero. Comparisons operators (``==, !=, <, <=, >, >=``)
|
||||
denotes a fixed-point division.
|
||||
Comparisons operators (``==, !=, <, <=, >, >=``)
|
||||
assume that the element in the computation domain represents a
|
||||
signed integer in a restricted range, see below. The same holds
|
||||
for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and
|
||||
@@ -2398,14 +2439,14 @@ class sint(_secret, _int):
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_input_from(cls, player, binary=False):
|
||||
def get_input_from(cls, player, binary=False, n_bytes=None):
|
||||
""" Secret input.
|
||||
|
||||
:param player: public (regint/cint/int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
if binary:
|
||||
return cls(personal.read_int(player))
|
||||
return cls(personal.read_int(player, n_bytes=n_bytes))
|
||||
else:
|
||||
res = cls()
|
||||
inputmixed('int', res, player)
|
||||
@@ -2540,6 +2581,12 @@ class sint(_secret, _int):
|
||||
"""
|
||||
writesockets(client_id, message_type, values[0].size, *values)
|
||||
|
||||
@vectorize
|
||||
def write_fully_to_socket(self, client_id,
|
||||
message_type=ClientMessageType.NoType):
|
||||
""" Send full secret to socket """
|
||||
writesockets(client_id, message_type, self.size, self)
|
||||
|
||||
@vectorize
|
||||
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
|
||||
""" Send only share to socket """
|
||||
@@ -2557,7 +2604,9 @@ class sint(_secret, _int):
|
||||
|
||||
@classmethod
|
||||
def read_from_file(cls, start, n_items):
|
||||
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
|
||||
""" Read shares from
|
||||
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
|
||||
section <persistence>` for details on the data format.
|
||||
|
||||
:param start: starting position in number of shares from beginning (int/regint/cint)
|
||||
:param n_items: number of items (int)
|
||||
@@ -2572,7 +2621,8 @@ class sint(_secret, _int):
|
||||
@staticmethod
|
||||
def write_to_file(shares, position=None):
|
||||
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
|
||||
(appending at the end).
|
||||
(appending at the end). See :ref:`this section <persistence>`
|
||||
for details on the data format.
|
||||
|
||||
:param shares: (list or iterable of sint)
|
||||
:param position: start position (int/regint/cint),
|
||||
@@ -2641,7 +2691,7 @@ class sint(_secret, _int):
|
||||
res = sintbit()
|
||||
comparison.LTZ(res, self - other,
|
||||
(bit_length or program.bit_length) + 1,
|
||||
security or program.security)
|
||||
security)
|
||||
return res
|
||||
|
||||
@read_mem_value
|
||||
@@ -2651,7 +2701,7 @@ class sint(_secret, _int):
|
||||
res = sintbit()
|
||||
comparison.LTZ(res, other - self,
|
||||
(bit_length or program.bit_length) + 1,
|
||||
security or program.security)
|
||||
security)
|
||||
return res
|
||||
|
||||
@read_mem_value
|
||||
@@ -2670,7 +2720,7 @@ class sint(_secret, _int):
|
||||
def __eq__(self, other, bit_length=None, security=None):
|
||||
return sintbit.conv(
|
||||
floatingpoint.EQZ(self - other, bit_length or program.bit_length,
|
||||
security or program.security))
|
||||
security))
|
||||
|
||||
@read_mem_value
|
||||
@type_comp
|
||||
@@ -2709,7 +2759,6 @@ class sint(_secret, _int):
|
||||
:param bit_length: bit length of input (default: global bit length)
|
||||
"""
|
||||
bit_length = bit_length or program.bit_length
|
||||
security = security or program.security
|
||||
if isinstance(m, int):
|
||||
if m == 0:
|
||||
return 0
|
||||
@@ -2737,7 +2786,7 @@ class sint(_secret, _int):
|
||||
:param bit_length: bit length of input (default: global bit length)
|
||||
"""
|
||||
return floatingpoint.Pow2(self, bit_length or program.bit_length, \
|
||||
security or program.security)
|
||||
security)
|
||||
|
||||
def __lshift__(self, other, bit_length=None, security=None):
|
||||
""" Secret left shift.
|
||||
@@ -2756,7 +2805,6 @@ class sint(_secret, _int):
|
||||
:param bit_length: bit length of input (default: global bit length)
|
||||
"""
|
||||
bit_length = bit_length or program.bit_length
|
||||
security = security or program.security
|
||||
if isinstance(other, int):
|
||||
if other == 0:
|
||||
return self
|
||||
@@ -2783,7 +2831,7 @@ class sint(_secret, _int):
|
||||
""" Secret right shift.
|
||||
|
||||
:param other: secret or public integer (sint/cint/regint/int) of globale bit length if secret """
|
||||
return floatingpoint.Trunc(other, program.bit_length, self, program.security)
|
||||
return floatingpoint.Trunc(other, program.bit_length, self)
|
||||
|
||||
@vectorize
|
||||
def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False):
|
||||
@@ -2791,7 +2839,7 @@ class sint(_secret, _int):
|
||||
if bit_length == 0:
|
||||
return []
|
||||
bit_length = bit_length or program.bit_length
|
||||
assert program.security == security or program.security
|
||||
program.non_linear.check_security(security)
|
||||
return program.non_linear.bit_dec(self, bit_length, bit_length,
|
||||
maybe_mixed)
|
||||
|
||||
@@ -2815,7 +2863,6 @@ class sint(_secret, _int):
|
||||
:param kappa: statistical security parameter (int)
|
||||
:param nearest: bool
|
||||
:param signed: bool """
|
||||
kappa = kappa or program.security
|
||||
secret = isinstance(m, sint)
|
||||
if nearest:
|
||||
if secret:
|
||||
@@ -2830,6 +2877,20 @@ class sint(_secret, _int):
|
||||
def Norm(self, k, f, kappa=None, simplex_flag=False):
|
||||
return library.Norm(self, k, f, kappa, simplex_flag)
|
||||
|
||||
def __truediv__(self, other):
|
||||
""" Secret fixed-point division.
|
||||
|
||||
:param other: any compatible type """
|
||||
if isinstance(other, sint):
|
||||
return other.__rtruediv__(self)
|
||||
try:
|
||||
return sfix._new(self) / cfix._new(cint(other), f=sfix.f, k=sfix.k)
|
||||
except:
|
||||
return NotImplemented
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return sfix._new(other) / sfix._new(self)
|
||||
|
||||
@vectorize
|
||||
def int_div(self, other, bit_length=None, security=None):
|
||||
""" Secret integer division. Note that the domain bit length
|
||||
@@ -2839,7 +2900,7 @@ class sint(_secret, _int):
|
||||
:param bit_length: bit length of input (default: global bit length)
|
||||
"""
|
||||
k = bit_length or program.bit_length
|
||||
kappa = security or program.security
|
||||
kappa = security
|
||||
tmp = library.IntDiv(self, other, k, kappa)
|
||||
res = type(self)()
|
||||
comparison.Trunc(res, tmp, 2 * k, k, kappa, True)
|
||||
@@ -2963,6 +3024,7 @@ class sint(_secret, _int):
|
||||
gensecshuffle(res, n)
|
||||
return res
|
||||
|
||||
@read_mem_value
|
||||
def secure_permute(self, shuffle, unit_size=1, reverse=False):
|
||||
res = sint(size=self.size)
|
||||
applyshuffle(res, self, unit_size, shuffle, reverse)
|
||||
@@ -3005,6 +3067,21 @@ class sint(_secret, _int):
|
||||
def copy_from_part(self, source, base, size):
|
||||
picks(self, source, base, 1)
|
||||
|
||||
def get_reverse_vector(self):
|
||||
res = type(self)(size=self.size)
|
||||
picks(res, self, self.size - 1, -1)
|
||||
return res
|
||||
|
||||
def get_vector(self, base=0, size=None):
|
||||
if size is None:
|
||||
size = len(self) - base
|
||||
if base == 0 and size == len(self):
|
||||
return self
|
||||
assert base + size <= len(self)
|
||||
res = type(self)(size=size)
|
||||
picks(res, self, base, 1)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def concat(cls, parts):
|
||||
parts = list(parts)
|
||||
@@ -3013,6 +3090,10 @@ class sint(_secret, _int):
|
||||
concats(res, *args)
|
||||
return res
|
||||
|
||||
@vectorize
|
||||
def output(self):
|
||||
print_reg_plains(self)
|
||||
|
||||
class sintbit(sint):
|
||||
""" :py:class:`sint` holding a bit, supporting binary operations
|
||||
(``&, |, ^``). """
|
||||
@@ -3137,6 +3218,7 @@ class sgf2n(_secret, _gf2n):
|
||||
""" Store in memory by public address. """
|
||||
self._store_in_mem(address, gstms, gstmsi)
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, val=None, size=None):
|
||||
super(sgf2n, self).__init__('sg', val=val, size=size)
|
||||
|
||||
@@ -3144,6 +3226,9 @@ class sgf2n(_secret, _gf2n):
|
||||
""" Identity. """
|
||||
return self
|
||||
|
||||
__truediv__ = _secret.field_div
|
||||
__rtruediv__ = _secret._rfield_div
|
||||
|
||||
@vectorize
|
||||
def __invert__(self):
|
||||
""" Secret bit-wise inversion. """
|
||||
@@ -3637,6 +3722,7 @@ class sgf2nint(_bitint, sgf2n):
|
||||
raise CompilerError('Invalid signed %d-bit integer: %d' % \
|
||||
(self.n_bits, other))
|
||||
|
||||
@vectorize
|
||||
def load_other(self, other):
|
||||
if isinstance(other, sgf2nint):
|
||||
gmovs(self, self.compose(other.bit_decompose(self.n_bits)))
|
||||
@@ -4200,6 +4286,15 @@ class _single(_number, _secret_structure):
|
||||
cls.int_type.write_shares_to_socket(
|
||||
client_id, [x.v for x in values], message_type)
|
||||
|
||||
@vectorized_classmethod
|
||||
def read_from_socket(cls, client_id, n=1):
|
||||
return util.untuplify([cls._new(x) for x in util.tuplify(
|
||||
cls.int_type.read_from_socket(client_id, n))])
|
||||
|
||||
@classmethod
|
||||
def write_to_socket(cls, client_id, values):
|
||||
cls.int_type.write_to_socket(client_id, [x.v for x in values])
|
||||
|
||||
@vectorized_classmethod
|
||||
def load_mem(cls, address, mem_type=None):
|
||||
""" Load from memory by public address. """
|
||||
@@ -4273,7 +4368,8 @@ class _single(_number, _secret_structure):
|
||||
@classmethod
|
||||
def read_from_file(cls, *args, **kwargs):
|
||||
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
|
||||
Precision must be the same as when storing.
|
||||
Precision must be the same as when storing. See :ref:`this
|
||||
section <persistence>` for details on the data format.
|
||||
|
||||
:param start: starting position in number of shares from beginning
|
||||
(int/regint/cint)
|
||||
@@ -4288,7 +4384,8 @@ class _single(_number, _secret_structure):
|
||||
@classmethod
|
||||
def write_to_file(cls, shares, position=None):
|
||||
""" Write shares of integer representation to
|
||||
``Persistence/Transactions-P<playerno>.data``.
|
||||
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
|
||||
section <persistence>` for details on the data format.
|
||||
|
||||
:param shares: (list or iterable of sfix)
|
||||
:param position: start position (int/regint/cint),
|
||||
@@ -4588,7 +4685,8 @@ class _fix(_single):
|
||||
nearest=self.round_nearest)
|
||||
elif isinstance(other, cfix):
|
||||
v = library.sint_cint_division(self.v, other.v, self.k, self.f,
|
||||
self.kappa)
|
||||
self.kappa,
|
||||
nearest=self.round_nearest)
|
||||
else:
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
return self._new(v, k=self.k, f=self.f)
|
||||
@@ -4656,7 +4754,7 @@ class sfix(_fix):
|
||||
default_type = sint
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_input_from(cls, player, binary=False):
|
||||
def get_input_from(cls, player, binary=False, n_bytes=None):
|
||||
""" Secret fixed-point input.
|
||||
|
||||
:param player: public (regint/cint/int)
|
||||
@@ -4770,6 +4868,13 @@ class sfix(_fix):
|
||||
def multipliable(v, k, f, size):
|
||||
return cfix._new(cint.conv(v, size=size), k, f)
|
||||
|
||||
def dot(self, other):
|
||||
""" Dot product with :py:class:`sint:`. """
|
||||
if isinstance(other, sint):
|
||||
return self._new(sint.dot_product(self.v, other), k=self.k, f=self.f)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
|
||||
@@ -4793,6 +4898,12 @@ class sfix(_fix):
|
||||
def sum(self):
|
||||
return self._new(self.v.sum())
|
||||
|
||||
def get_reverse_vector(self):
|
||||
return self._new(self.v.get_reverse_vector(), k=self.k, f=self.f)
|
||||
|
||||
def get_vector(self, *args, **kwargs):
|
||||
return self._new(self.v.get_vector(*args, **kwargs), k=self.k, f=self.f)
|
||||
|
||||
@classmethod
|
||||
def concat(cls, parts):
|
||||
parts = list(parts)
|
||||
@@ -5486,6 +5597,11 @@ class sfloat(_number, _secret_structure):
|
||||
self.z.update(other.z)
|
||||
self.s.update(other.s)
|
||||
|
||||
def for_mux(self, other):
|
||||
other = self.coerce(other)
|
||||
f = lambda x: type(self)(*x)
|
||||
return f, sint(list(self)), sint(list(other))
|
||||
|
||||
class cfloat(Tape._no_truth):
|
||||
""" Helper class for printing revealed sfloats. """
|
||||
__slots__ = ['v', 'p', 'z', 's', 'nan']
|
||||
@@ -5763,6 +5879,9 @@ class Array(_vectorizable):
|
||||
tmp.store_in_mem(address)
|
||||
|
||||
def __len__(self):
|
||||
if self.length is None:
|
||||
raise CompilerError('this functionality is not available '
|
||||
'for variable-length arrays')
|
||||
return self.length
|
||||
|
||||
def total_size(self):
|
||||
@@ -5887,6 +6006,19 @@ class Array(_vectorizable):
|
||||
addresses = self.get_slice_addresses(slice)
|
||||
vector.store_in_mem(addresses)
|
||||
|
||||
def permute(self, permutation, reverse=False, n_threads=None):
|
||||
""" Public permutation.
|
||||
|
||||
:param permutation: cleartext :py:class`Array` containing number
|
||||
in :math:`[0,n-1]` where :math:`n` is the length of this array
|
||||
:param reverse: whether to apply the inverse of the permutation
|
||||
|
||||
"""
|
||||
if reverse:
|
||||
self.assign_slice_vector(permutation, self.get_vector())
|
||||
else:
|
||||
self.assign_vector(self.get_slice_vector(permutation))
|
||||
|
||||
def expand_to_vector(self, index, size):
|
||||
""" Create vector from single entry.
|
||||
|
||||
@@ -5930,7 +6062,8 @@ class Array(_vectorizable):
|
||||
|
||||
def read_from_file(self, start):
|
||||
""" Read content from ``Persistence/Transactions-P<playerno>.data``.
|
||||
Precision must be the same as when storing if applicable.
|
||||
Precision must be the same as when storing if applicable. See
|
||||
:ref:`this section <persistence>` for details on the data format.
|
||||
|
||||
:param start: starting position in number of shares from beginning
|
||||
(int/regint/cint)
|
||||
@@ -5943,13 +6076,36 @@ class Array(_vectorizable):
|
||||
|
||||
def write_to_file(self, position=None):
|
||||
""" Write shares of integer representation to
|
||||
``Persistence/Transactions-P<playerno>.data``.
|
||||
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
|
||||
section <persistence>` for details on the data format.
|
||||
|
||||
:param position: start position (int/regint/cint),
|
||||
defaults to end of file
|
||||
"""
|
||||
self.value_type.write_to_file(list(self), position)
|
||||
|
||||
def read_from_socket(self, socket, debug=False):
|
||||
""" Read content from socket. """
|
||||
if debug:
|
||||
library.print_str('reading %s...' % self)
|
||||
# hard-coded budget for interopability
|
||||
@library.multithread(None, len(self), max_size=10 ** 6)
|
||||
def _(base, size):
|
||||
self.assign_vector(
|
||||
self.value_type.read_from_socket(socket, size=size), base=base)
|
||||
if debug:
|
||||
library.print_ln('done')
|
||||
|
||||
def write_to_socket(self, socket, debug=False):
|
||||
""" Write content to socket. """
|
||||
if debug:
|
||||
library.print_ln('writing %s' % self)
|
||||
# hard-coded budget for interopability
|
||||
@library.multithread(None, len(self), max_size=10 ** 6)
|
||||
def _(base, size):
|
||||
self.value_type.write_to_socket(
|
||||
socket, [self.get_vector(base=base, size=size)])
|
||||
|
||||
def __add__(self, other):
|
||||
""" Vector addition.
|
||||
|
||||
@@ -6254,9 +6410,13 @@ class SubMultiArray(_vectorizable):
|
||||
""" Assign container to content. Not implemented for floating-point.
|
||||
|
||||
:param other: container of matching size and type """
|
||||
if self.value_type.n_elements() > 1:
|
||||
assert self.sizes == other.sizes
|
||||
self.assign_vector(other.get_vector())
|
||||
try:
|
||||
if self.value_type.n_elements() > 1:
|
||||
assert self.sizes == other.sizes
|
||||
self.assign_vector(other.get_vector())
|
||||
except:
|
||||
for i, x in enumerate(other):
|
||||
self[i].assign(x)
|
||||
|
||||
def get_part_vector(self, base=0, size=None):
|
||||
""" Vector from range of the first dimension, including all
|
||||
@@ -6297,15 +6457,41 @@ class SubMultiArray(_vectorizable):
|
||||
addresses = self.get_slice_addresses(slice)
|
||||
vector.store_in_mem(self.address + addresses)
|
||||
|
||||
def get_slice_addresses(self, slice):
|
||||
def get_part_size(self):
|
||||
assert self.value_type.n_elements() == 1
|
||||
part_size = reduce(operator.mul, self.sizes[1:])
|
||||
return reduce(operator.mul, self.sizes[1:])
|
||||
|
||||
def get_slice_addresses(self, slice, part_size=None):
|
||||
part_size = part_size or self.get_part_size()
|
||||
assert len(slice) * part_size <= self.total_size()
|
||||
base = regint.inc(len(slice) * part_size, slice.address, 1, part_size)
|
||||
inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size)
|
||||
addresses = slice.value_type.load_mem(base) * part_size + inc
|
||||
return addresses
|
||||
|
||||
def permute(self, permutation, reverse=False, n_threads=None):
|
||||
""" Public permutation along first dimension.
|
||||
|
||||
:param permutation: cleartext :py:class`Array` containing number
|
||||
in :math:`[0,n-1]` where :math:`n` is the length of this array
|
||||
:param reverse: whether to apply the inverse of the permutation
|
||||
|
||||
"""
|
||||
@library.multithread(n_threads, self.get_part_size())
|
||||
def _(base, size):
|
||||
addresses = self.get_slice_addresses(permutation, part_size=1)
|
||||
addresses *= self.get_part_size()
|
||||
@library.for_range_opt(size)
|
||||
def _(j):
|
||||
i = base + j
|
||||
if reverse:
|
||||
v = self.get_column(i)
|
||||
v.store_in_mem(self.address + i + addresses)
|
||||
else:
|
||||
v = self.value_type.load_mem(
|
||||
self.address + i + addresses)
|
||||
self.set_column(i, v)
|
||||
|
||||
def get_addresses(self, *indices):
|
||||
assert self.value_type.n_elements() == 1
|
||||
assert len(indices) == len(self.sizes)
|
||||
@@ -6389,7 +6575,8 @@ class SubMultiArray(_vectorizable):
|
||||
|
||||
def write_to_file(self, position=None):
|
||||
""" Write shares of integer representation to
|
||||
``Persistence/Transactions-P<playerno>.data``.
|
||||
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
|
||||
section <persistence>` for details on the data format.
|
||||
|
||||
:param position: start position (int/regint/cint),
|
||||
defaults to end of file
|
||||
@@ -6404,7 +6591,8 @@ class SubMultiArray(_vectorizable):
|
||||
|
||||
def read_from_file(self, start):
|
||||
""" Read content from ``Persistence/Transactions-P<playerno>.data``.
|
||||
Precision must be the same as when storing if applicable.
|
||||
Precision must be the same as when storing if applicable. See
|
||||
:ref:`this section <persistence>` for details on the data format.
|
||||
|
||||
:param start: starting position in number of shares from beginning
|
||||
(int/regint/cint)
|
||||
@@ -6417,6 +6605,14 @@ class SubMultiArray(_vectorizable):
|
||||
start.write(self[i].read_from_file(start))
|
||||
return start
|
||||
|
||||
def write_to_socket(self, socket, debug=False):
|
||||
""" Write content to socket. """
|
||||
self.array.write_to_socket(socket, debug=debug)
|
||||
|
||||
def read_from_socket(self, socket, debug=False):
|
||||
""" Read content from socket. """
|
||||
self.array.read_from_socket(socket, debug=debug)
|
||||
|
||||
def schur(self, other):
|
||||
""" Element-wise product.
|
||||
|
||||
@@ -6758,7 +6954,28 @@ class SubMultiArray(_vectorizable):
|
||||
res = self.value_type.dot_product(a, b)
|
||||
return res
|
||||
|
||||
def transpose(self):
|
||||
def get_column(self, index):
|
||||
""" Get matrix column as vector.
|
||||
|
||||
:param index: regint/cint/int
|
||||
"""
|
||||
assert self.value_type.n_elements() == 1
|
||||
addresses = regint.inc(self.sizes[0], self.address + index,
|
||||
self.get_part_size())
|
||||
return self.value_type.load_mem(addresses)
|
||||
|
||||
def set_column(self, index, vector):
|
||||
""" Change column.
|
||||
|
||||
:param index: regint/cint/int
|
||||
:param vector: short enought vector of compatible type
|
||||
"""
|
||||
assert self.value_type.n_elements() == 1
|
||||
addresses = regint.inc(self.sizes[0], self.address + index,
|
||||
self.get_part_size())
|
||||
self.value_type.conv(vector).store_in_mem(addresses)
|
||||
|
||||
def transpose(self, n_threads=None):
|
||||
""" Matrix transpose.
|
||||
|
||||
:param self: two-dimensional """
|
||||
@@ -6766,13 +6983,24 @@ class SubMultiArray(_vectorizable):
|
||||
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
|
||||
library.break_point()
|
||||
if self.value_type.n_elements() == 1:
|
||||
nr = self.sizes[1]
|
||||
nc = self.sizes[0]
|
||||
a = regint.inc(nr * nc, 0, nr, 1, nc)
|
||||
b = regint.inc(nr * nc, 0, 1, nc)
|
||||
res[:] = self.value_type.load_mem(self.address + a + b)
|
||||
if self.sizes[0] < program.budget:
|
||||
if self.sizes[1] < program.budget:
|
||||
nr = self.sizes[1]
|
||||
nc = self.sizes[0]
|
||||
a = regint.inc(nr * nc, 0, nr, 1, nc)
|
||||
b = regint.inc(nr * nc, 0, 1, nc)
|
||||
res[:] = self.value_type.load_mem(self.address + a + b)
|
||||
else:
|
||||
@library.for_range_multithread(n_threads, 1, self.sizes[0])
|
||||
def _(i):
|
||||
res.set_column(i, self[i][:])
|
||||
else:
|
||||
@library.for_range_multithread(n_threads, 1, self.sizes[1])
|
||||
def _(i):
|
||||
res[i][:] = self.get_column(i)
|
||||
else:
|
||||
@library.for_range_opt(self.sizes[1], budget=100)
|
||||
@library.for_range_opt_multithread(n_threads, self.sizes[1],
|
||||
budget=100)
|
||||
def _(i):
|
||||
@library.for_range_opt(self.sizes[0], budget=100)
|
||||
def _(j):
|
||||
@@ -6801,7 +7029,7 @@ class SubMultiArray(_vectorizable):
|
||||
"""
|
||||
self.assign_vector(self.get_vector().secure_shuffle(self.part_size()))
|
||||
|
||||
def secure_permute(self, permutation, reverse=False):
|
||||
def secure_permute(self, permutation, reverse=False, n_threads=None):
|
||||
""" Securely permute rows (first index). See
|
||||
:py:func:`secure_shuffle` for references.
|
||||
|
||||
@@ -6809,8 +7037,12 @@ class SubMultiArray(_vectorizable):
|
||||
:param reverse: whether to apply inverse (default: False)
|
||||
|
||||
"""
|
||||
self.assign_vector(self.get_vector().secure_permute(
|
||||
permutation, self.part_size(), reverse))
|
||||
if n_threads is not None:
|
||||
permutation = MemValue(permutation)
|
||||
@library.for_range_multithread(n_threads, 1, self.get_part_size())
|
||||
def _(i):
|
||||
self.set_column(i, self.get_column(i).secure_permute(
|
||||
permutation, reverse=reverse))
|
||||
|
||||
def sort(self, key_indices=None, n_bits=None):
|
||||
""" Sort sub-arrays (different first index) in place.
|
||||
@@ -6829,6 +7061,9 @@ class SubMultiArray(_vectorizable):
|
||||
return
|
||||
if key_indices is None:
|
||||
key_indices = (0,) * (len(self.sizes) - 1)
|
||||
if len(key_indices) != len(self.sizes) - 1:
|
||||
raise CompilerError('length of key_indices has to be one less '
|
||||
'than the dimension')
|
||||
key_indices = (None,) + util.tuplify(key_indices)
|
||||
from . import sorting
|
||||
keys = self.get_vector_by_indices(*key_indices)
|
||||
@@ -6971,7 +7206,8 @@ class Matrix(MultiArray):
|
||||
|
||||
@staticmethod
|
||||
def create_from(rows):
|
||||
rows = list(rows)
|
||||
if not isinstance(rows, _vectorizable):
|
||||
rows = list(rows)
|
||||
if isinstance(rows[0], (list, tuple, Array)):
|
||||
t = type(rows[0][0])
|
||||
else:
|
||||
@@ -6983,20 +7219,15 @@ class Matrix(MultiArray):
|
||||
raise CompilerError(
|
||||
'accidental shortening by creating matrix')
|
||||
res = Matrix(len(rows), len(rows[0]), t)
|
||||
for i in range(len(rows)):
|
||||
res[i].assign(rows[i])
|
||||
if isinstance(rows, _vectorizable):
|
||||
@library.for_range_opt(len(rows))
|
||||
def _(i):
|
||||
res[i].assign(rows[i])
|
||||
else:
|
||||
for i in range(len(rows)):
|
||||
res[i].assign(rows[i])
|
||||
return res
|
||||
|
||||
def get_column(self, index):
|
||||
""" Get column as vector.
|
||||
|
||||
:param index: regint/cint/int
|
||||
"""
|
||||
assert self.value_type.n_elements() == 1
|
||||
addresses = regint.inc(self.sizes[0], self.address + index,
|
||||
self.sizes[1])
|
||||
return self.value_type.load_mem(addresses)
|
||||
|
||||
def get_columns(self):
|
||||
return (self.get_column(i) for i in range(self.sizes[1]))
|
||||
|
||||
@@ -7006,17 +7237,6 @@ class Matrix(MultiArray):
|
||||
regint.inc(len(rows), self.address + column, 0)
|
||||
return self.value_type.load_mem(addresses)
|
||||
|
||||
def set_column(self, index, vector):
|
||||
""" Change column.
|
||||
|
||||
:param index: regint/cint/int
|
||||
:param vector: short enought vector of compatible type
|
||||
"""
|
||||
assert self.value_type.n_elements() == 1
|
||||
addresses = regint.inc(self.sizes[0], self.address + index,
|
||||
self.sizes[1])
|
||||
self.value_type.conv(vector).store_in_mem(addresses)
|
||||
|
||||
def concat_columns(self, other):
|
||||
""" Concatenate two matrices by columns. """
|
||||
assert self.sizes[0] == other.sizes[0]
|
||||
@@ -7217,6 +7437,8 @@ class MemValue(_mem):
|
||||
bit_and = lambda self,other: self.read().bit_and(other)
|
||||
bit_not = lambda self: self.read().bit_not()
|
||||
|
||||
print_if = lambda self,*args,**kwargs: self.read().print_if(*args, **kwargs)
|
||||
|
||||
def expand_to_vector(self, size=None):
|
||||
if program.curr_block == self.last_write_block:
|
||||
return self.read().expand_to_vector(size)
|
||||
|
||||
@@ -64,14 +64,14 @@ RUN pip install --upgrade pip ipython
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG arch=native
|
||||
ARG arch=
|
||||
ARG cxx=clang++-11
|
||||
ARG use_ntl=0
|
||||
ARG prep_dir="Player-Data"
|
||||
ARG ssl_dir="Player-Data"
|
||||
|
||||
RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \
|
||||
&& echo "CXX = ${cxx}" >> CONFIG.mine \
|
||||
RUN if test -n "${arch}"; then echo "ARCH = -march=${arch}" >> CONFIG.mine; fi
|
||||
RUN echo "CXX = ${cxx}" >> CONFIG.mine \
|
||||
&& echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \
|
||||
&& echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \
|
||||
&& echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \
|
||||
|
||||
@@ -50,7 +50,6 @@ public:
|
||||
|
||||
void assign_zero() { *this = 0; }
|
||||
bool is_zero() { return *this == 0; }
|
||||
void add(octetStream& os) { *this += os.get<CurveElement>(); }
|
||||
|
||||
void pack(octetStream& os) const;
|
||||
void unpack(octetStream& os);
|
||||
|
||||
@@ -166,13 +166,3 @@ bool P256Element::operator !=(const P256Element& other) const
|
||||
{
|
||||
return not (*this == other);
|
||||
}
|
||||
|
||||
octetStream P256Element::hash(size_t n_bytes) const
|
||||
{
|
||||
octetStream os;
|
||||
pack(os);
|
||||
auto res = os.hash();
|
||||
assert(n_bytes >= res.get_length());
|
||||
res.resize_precise(n_bytes);
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -56,15 +56,9 @@ public:
|
||||
bool operator==(const P256Element& other) const;
|
||||
bool operator!=(const P256Element& other) const;
|
||||
|
||||
void assign_zero() { *this = {}; }
|
||||
bool is_zero() { return *this == P256Element(); }
|
||||
void add(octetStream& os, int = -1) { *this += os.get<P256Element>(); }
|
||||
|
||||
void pack(octetStream& os, int = -1) const;
|
||||
void unpack(octetStream& os, int = -1);
|
||||
|
||||
octetStream hash(size_t n_bytes) const;
|
||||
|
||||
friend ostream& operator<<(ostream& s, const P256Element& x);
|
||||
};
|
||||
|
||||
|
||||
@@ -131,6 +131,12 @@ int main(int argc, char** argv)
|
||||
case 'R':
|
||||
{
|
||||
int R = specification.get<int>();
|
||||
int R2 = specification.get<int>();
|
||||
if (R2 != 64)
|
||||
{
|
||||
cerr << R2 << "-bit ring not implemented" << endl;
|
||||
}
|
||||
|
||||
switch (R)
|
||||
{
|
||||
case 64:
|
||||
|
||||
@@ -14,24 +14,17 @@ finish = int(sys.argv[4])
|
||||
|
||||
client = Client(['localhost'] * n_parties, 14000, client_id)
|
||||
|
||||
type = client.specification.get_int(4)
|
||||
|
||||
if type == ord('R'):
|
||||
domain = Z2(client.specification.get_int(4))
|
||||
elif type == ord('p'):
|
||||
domain = Fp(client.specification.get_bigint())
|
||||
else:
|
||||
raise Exception('invalid type')
|
||||
|
||||
for socket in client.sockets:
|
||||
os = octetStream()
|
||||
os.store(finish)
|
||||
os.Send(socket)
|
||||
|
||||
def run(x):
|
||||
client.send_private_inputs([x])
|
||||
|
||||
print('Winning client id is :', client.receive_outputs(1)[0])
|
||||
|
||||
# running two rounds
|
||||
# first for sint, then for sfix
|
||||
for x in bonus, bonus * 2 ** 16:
|
||||
client.send_private_inputs([domain(x)])
|
||||
|
||||
print('Winning client id is :',
|
||||
int(client.receive_outputs(domain, 1)[0]))
|
||||
run(bonus)
|
||||
run(bonus * 2 ** 16)
|
||||
|
||||
@@ -2,6 +2,7 @@ import platform
|
||||
import socket, ssl
|
||||
import struct
|
||||
import time
|
||||
from domains import *
|
||||
|
||||
# The following function is either taken directly or derived from:
|
||||
# https://stackoverflow.com/questions/12248132/how-to-change-tcp-keepalive-timer-using-python-script
|
||||
@@ -61,6 +62,15 @@ class Client:
|
||||
|
||||
self.specification = octetStream()
|
||||
self.specification.Receive(self.sockets[0])
|
||||
type = self.specification.get_int(4)
|
||||
if type == ord('R'):
|
||||
self.domain = Z2(self.specification.get_int(4))
|
||||
self.clear_domain = Z2(self.specification.get_int(4))
|
||||
elif type == ord('p'):
|
||||
self.domain = Fp(self.specification.get_bigint())
|
||||
self.clear_domain = self.domain
|
||||
else:
|
||||
raise Exception('invalid type')
|
||||
|
||||
def receive_triples(self, T, n):
|
||||
triples = [[0, 0, 0] for i in range(n)]
|
||||
@@ -89,18 +99,19 @@ class Client:
|
||||
return triples
|
||||
|
||||
def send_private_inputs(self, values):
|
||||
T = type(values[0])
|
||||
T = self.domain
|
||||
triples = self.receive_triples(T, len(values))
|
||||
os = octetStream()
|
||||
assert len(values) == len(triples)
|
||||
for value, triple in zip(values, triples):
|
||||
(value + triple[0]).pack(os)
|
||||
(T(value) + triple[0]).pack(os)
|
||||
for socket in self.sockets:
|
||||
os.Send(socket)
|
||||
|
||||
def receive_outputs(self, T, n):
|
||||
def receive_outputs(self, n):
|
||||
T = self.domain
|
||||
triples = self.receive_triples(T, n)
|
||||
return [triple[0] for triple in triples]
|
||||
return [int(self.clear_domain(triple[0].v)) for triple in triples]
|
||||
|
||||
class octetStream:
|
||||
def __init__(self, value=None):
|
||||
|
||||
@@ -67,20 +67,15 @@ void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,
|
||||
cc0.Scale(pk.p()); cc1.Scale(pk.p());
|
||||
|
||||
// Now do the multiply
|
||||
Rq_Element d0,d1,d2;
|
||||
|
||||
mul(d0,cc0.cc0,cc1.cc0);
|
||||
mul(d1,cc0.cc0,cc1.cc1);
|
||||
mul(d2,cc0.cc1,cc1.cc0);
|
||||
add(d1,d1,d2);
|
||||
mul(d2,cc0.cc1,cc1.cc1);
|
||||
auto d0 = cc0.cc0 * cc1.cc0;
|
||||
auto d1 = cc0.cc0 * cc1.cc1 + cc0.cc1 * cc1.cc0;
|
||||
auto d2 = cc0.cc1 * cc1.cc1;
|
||||
d2.negate();
|
||||
|
||||
// Now do the switch key
|
||||
d2.raise_level();
|
||||
Rq_Element t;
|
||||
d0.mul_by_p1();
|
||||
mul(t,pk.bs(),d2);
|
||||
auto t = pk.bs()* d2;
|
||||
add(d0,d0,t);
|
||||
|
||||
d1.mul_by_p1();
|
||||
|
||||
@@ -29,7 +29,6 @@ class Ciphertext
|
||||
word pk_id;
|
||||
|
||||
public:
|
||||
static int size() { return 0; }
|
||||
|
||||
const FHE_Params& get_params() const { return *params; }
|
||||
|
||||
|
||||
@@ -94,8 +94,8 @@ void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
|
||||
add(PK.Sw_b,PK.Sw_b,es);
|
||||
|
||||
// bs=bs-p1*s^2
|
||||
Rq_Element s2;
|
||||
mul(s2,sk,sk); // Mult at level 0
|
||||
// Mult at level 0
|
||||
auto s2 = sk * sk;
|
||||
s2.mul_by_p1(); // This raises back to level 1
|
||||
sub(PK.Sw_b,PK.Sw_b,s2);
|
||||
}
|
||||
@@ -155,17 +155,12 @@ void FHE_PK::quasi_encrypt(Ciphertext& c,
|
||||
if (&rc.get_params()!=params) { throw params_mismatch(); }
|
||||
assert(pr != 0);
|
||||
|
||||
Rq_Element ed,edd,c0,c1,aa;
|
||||
|
||||
// c1=a0*u+p*v
|
||||
mul(aa,a0,rc.u());
|
||||
mul(ed,rc.v(),pr);
|
||||
add(c1,aa,ed);
|
||||
auto c1 = a0 * rc.u() + rc.v() * pr;
|
||||
|
||||
// c0 = b0 * u + p * w + mess
|
||||
mul(c0,b0,rc.u());
|
||||
mul(edd,rc.w(),pr);
|
||||
add(edd,edd,mess);
|
||||
auto c0 = b0 * rc.u();
|
||||
auto edd = rc.w() * pr + mess;
|
||||
if (params->n_mults() == 0)
|
||||
edd.change_rep(evaluation);
|
||||
else
|
||||
@@ -218,10 +213,7 @@ Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
|
||||
Rq_Element ans;
|
||||
|
||||
mul(ans,c.c1(),sk);
|
||||
sub(ans,c.c0(),ans);
|
||||
auto ans = c.c0() - c.c1() * sk;
|
||||
ans.change_rep(polynomial);
|
||||
return ans;
|
||||
}
|
||||
@@ -267,8 +259,7 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
|
||||
Ciphertext cc=ctx; cc.Scale(pr);
|
||||
|
||||
// First do the basic decryption
|
||||
Rq_Element dec_sh;
|
||||
mul(dec_sh,cc.c1(),sk);
|
||||
auto dec_sh = cc.c1() * sk;
|
||||
if (player_number==0)
|
||||
{ sub(dec_sh,cc.c0(),dec_sh); }
|
||||
else
|
||||
|
||||
@@ -89,9 +89,6 @@ class FHE_SK
|
||||
|
||||
bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; }
|
||||
|
||||
void add(octetStream& os, int = -1)
|
||||
{ FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; }
|
||||
|
||||
void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const;
|
||||
|
||||
template<class FD>
|
||||
@@ -120,10 +117,12 @@ class FHE_PK
|
||||
bigint p() const { return pr; }
|
||||
|
||||
void assign(const Rq_Element& a,const Rq_Element& b,
|
||||
const Rq_Element& sa = {},const Rq_Element& sb = {}
|
||||
const Rq_Element& sa,const Rq_Element& sb
|
||||
)
|
||||
{ a0=a; b0=b; Sw_a=sa; Sw_b=sb; }
|
||||
|
||||
void assign(const Rq_Element& a,const Rq_Element& b)
|
||||
{ a0=a; b0=b; }
|
||||
|
||||
FHE_PK(const FHE_Params& pms);
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ class RingReadIterator;
|
||||
|
||||
class Ring_Element
|
||||
{
|
||||
friend class Rq_Element;
|
||||
|
||||
RepType rep;
|
||||
|
||||
/* FFTD is defined as a pointer so each different Ring_Element
|
||||
@@ -41,6 +43,9 @@ class Ring_Element
|
||||
|
||||
vector<modp> element;
|
||||
|
||||
/* Careful calling this one, as FFTD will not be defined */
|
||||
Ring_Element(RepType r=polynomial) : FFTD(0) { rep=r; }
|
||||
|
||||
public:
|
||||
|
||||
// Used to basically make sure *this is able to cope
|
||||
@@ -63,9 +68,6 @@ class Ring_Element
|
||||
void assign_zero();
|
||||
void assign_one();
|
||||
|
||||
/* Careful calling this one, as FFTD will not be defined */
|
||||
Ring_Element(RepType r=polynomial) : FFTD(0) { rep=r; }
|
||||
|
||||
Ring_Element(const FFT_Data& prd,RepType r=polynomial);
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -23,7 +23,7 @@ Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
|
||||
|
||||
void Rq_Element::set_data(const vector<FFT_Data>& prd)
|
||||
{
|
||||
a.resize(prd.size());
|
||||
a.resize(prd.size(), {});
|
||||
for(size_t i = 0; i < a.size(); i++)
|
||||
a[i].set_data(prd[i]);
|
||||
lev=n_mults();
|
||||
@@ -50,7 +50,7 @@ void Rq_Element::assign_one()
|
||||
void Rq_Element::partial_assign(const Rq_Element& other)
|
||||
{
|
||||
lev=other.lev;
|
||||
a.resize(other.a.size());
|
||||
a.resize(other.a.size(), {});
|
||||
}
|
||||
|
||||
void Rq_Element::negate()
|
||||
@@ -112,13 +112,6 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
|
||||
}
|
||||
}
|
||||
|
||||
void Rq_Element::add(octetStream& os, int)
|
||||
{
|
||||
Rq_Element tmp(*this);
|
||||
tmp.unpack(os);
|
||||
*this += tmp;
|
||||
}
|
||||
|
||||
void Rq_Element::randomize(PRNG& G,int l)
|
||||
{
|
||||
set_level(l);
|
||||
|
||||
@@ -33,6 +33,10 @@ protected:
|
||||
vector<Ring_Element> a;
|
||||
int lev;
|
||||
|
||||
// Must be careful not to call by mistake
|
||||
Rq_Element(RepType r0=evaluation,RepType r1=polynomial) :
|
||||
a({r0, r1}), lev(n_mults()) {}
|
||||
|
||||
public:
|
||||
|
||||
int n_mults() const { return a.size() - 1; }
|
||||
@@ -46,10 +50,6 @@ protected:
|
||||
void assign_one();
|
||||
void partial_assign(const Rq_Element& e);
|
||||
|
||||
// Must be careful not to call by mistake
|
||||
Rq_Element(RepType r0=evaluation,RepType r1=polynomial) :
|
||||
a({r0, r1}), lev(n_mults()) {}
|
||||
|
||||
// Pass in a pair of FFT_Data as a vector
|
||||
Rq_Element(const vector<FFT_Data>& prd, RepType r0 = evaluation,
|
||||
RepType r1 = polynomial);
|
||||
@@ -97,8 +97,6 @@ protected:
|
||||
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
|
||||
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
|
||||
|
||||
void add(octetStream& os, int = -1);
|
||||
|
||||
template<class S>
|
||||
Rq_Element& operator+=(const vector<S>& other);
|
||||
|
||||
|
||||
@@ -15,13 +15,10 @@
|
||||
void Encrypt_Rq_Element(Ciphertext& c,const Rq_Element& mess, const Random_Coins& rc,
|
||||
const FHE_PK& pk)
|
||||
{
|
||||
Rq_Element ed, edd, c0, c1;
|
||||
mul(c1, pk.a(), rc.u());
|
||||
mul(ed, rc.v(), pk.p());
|
||||
add(c1, c1, ed);
|
||||
auto c1 = pk.a() * rc.u() + rc.v() * pk.p();
|
||||
auto c0 = pk.b() * rc.u();
|
||||
auto edd = rc.w() * pk.p();
|
||||
|
||||
mul(c0, pk.b(), rc.u());
|
||||
mul(edd, rc.w(), pk.p());
|
||||
edd.change_rep(evaluation,evaluation);
|
||||
add(edd,edd,mess);
|
||||
add(c0,c0,edd);
|
||||
|
||||
@@ -87,6 +87,7 @@ public:
|
||||
void xorc(const ::BaseInstruction& instruction);
|
||||
void nots(const ::BaseInstruction& instruction);
|
||||
void notcb(const ::BaseInstruction& instruction);
|
||||
void movsb(const ::BaseInstruction& instruction);
|
||||
void andm(const ::BaseInstruction& instruction);
|
||||
void and_(const vector<int>& args, bool repeat);
|
||||
void andrs(const vector<int>& args) { and_(args, true); }
|
||||
|
||||
@@ -283,6 +283,13 @@ void Processor<T>::notcb(const ::BaseInstruction& instruction)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::movsb(const ::BaseInstruction& instruction)
|
||||
{
|
||||
for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++)
|
||||
S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::andm(const ::BaseInstruction& instruction)
|
||||
{
|
||||
|
||||
@@ -32,7 +32,7 @@ void TinierSharePrep<T>::buffer_secret_triples()
|
||||
assert(triple_generator != 0);
|
||||
params.generateBits = false;
|
||||
vector<array<T, 3>> triples;
|
||||
TripleShuffleSacrifice<T> sacrifice;
|
||||
TripleShuffleSacrifice<T> sacrifice(DATA_GF2);
|
||||
size_t required;
|
||||
required = sacrifice.minimum_n_inputs_with_combining(
|
||||
BaseMachine::batch_size<T>(DATA_TRIPLE));
|
||||
|
||||
@@ -65,7 +65,7 @@
|
||||
X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \
|
||||
X(LDMCBI, PROC.mem_op(SIZE, PROC.C, MMC, R0, Ci[REG1])) \
|
||||
X(STMCBI, PROC.mem_op(SIZE, MMC, PROC.C, Ci[REG1], R0)) \
|
||||
X(MOVSB, S0 = PS1) \
|
||||
X(MOVSB, PROC.movsb(INST)) \
|
||||
X(TRANS, T::trans(PROC, IMM, EXTRA)) \
|
||||
X(BITB, PROC.random_bit(S0)) \
|
||||
X(REVEAL, T::reveal_inst(PROC, EXTRA)) \
|
||||
@@ -123,7 +123,7 @@
|
||||
X(LDMINTI, I0 = MII) \
|
||||
X(STMINTI, MII = I0) \
|
||||
X(PUSHINT, PROC.pushi(I0.get())) \
|
||||
X(POPINT, long x; PROC.popi(x); I0 = x) \
|
||||
X(POPINT, PROC.popi(I0)) \
|
||||
X(MOVINT, I0 = PI1) \
|
||||
X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \
|
||||
X(LDARG, I0 = PROC.get_arg()) \
|
||||
|
||||
@@ -72,6 +72,12 @@ ShamirOptions::ShamirOptions(ez::ezOptionParser& opt, int argc, const char** arg
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
opt.get("-N")->getInt(nparties);
|
||||
if (nparties < 3)
|
||||
{
|
||||
cerr << "Protocols based on Shamir secret sharing require at least "
|
||||
<< "three parties." << endl;
|
||||
exit(1);
|
||||
}
|
||||
set_threshold(opt);
|
||||
opt.resetArgs();
|
||||
}
|
||||
|
||||
@@ -26,23 +26,9 @@ int main(int argc, const char** argv)
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions ring_opts(opt, argc, argv);
|
||||
online_opts = {opt, argc, argv, FakeShare<SignedZ2<64>>()};
|
||||
opt.parse(argc, argv);
|
||||
opt.syntax = string(argv[0]) + " <progname>";
|
||||
|
||||
string progname;
|
||||
if (opt.firstArgs.size() > 1)
|
||||
progname = *opt.firstArgs.at(1);
|
||||
else if (not opt.lastArgs.empty())
|
||||
progname = *opt.lastArgs.at(0);
|
||||
else if (not opt.unknownArgs.empty())
|
||||
progname = *opt.unknownArgs.at(0);
|
||||
else
|
||||
{
|
||||
string usage;
|
||||
opt.getUsage(usage);
|
||||
cerr << usage << endl;
|
||||
exit(1);
|
||||
}
|
||||
online_opts.finalize(opt, argc, argv, false);
|
||||
string& progname = online_opts.progname;
|
||||
|
||||
#ifdef ROUND_NEAREST_IN_EMULATION
|
||||
cerr << "Using nearest rounding instead of probabilistic truncation" << endl;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "Processor/OnlineMachine.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/OnlineOptions.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
@@ -17,7 +18,7 @@
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
OnlineOptions::singleton = {opt, argc, argv};
|
||||
OnlineOptions::singleton = {opt, argc, argv, NoShare<gf2n>()};
|
||||
OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton);
|
||||
OnlineOptions::singleton.finalize(opt, argc, argv);
|
||||
machine.start_networking();
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "Processor/Machine.h"
|
||||
#include "Processor/RingOptions.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Protocols/SPDZ2k.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Networking/Server.h"
|
||||
|
||||
@@ -62,8 +63,10 @@ int main(int argc, const char** argv)
|
||||
cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line "
|
||||
<< (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+"
|
||||
<< s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl;
|
||||
cerr << "Alternatively, compile with -DRING_SIZE=" << k
|
||||
<< " and -DSPDZ2K_DEFAULT_SECURITY=" << s << endl;
|
||||
cerr << "Alternatively, put 'MY_CFLAGS += -DRING_SIZE=" << k
|
||||
<< " -DSPDZ2K_DEFAULT_SECURITY=" << s
|
||||
<< "' in 'CONFIG.mine' before running 'make spdz2k-party.x'"
|
||||
<< endl;
|
||||
}
|
||||
exit(1);
|
||||
}
|
||||
|
||||
11
Makefile
11
Makefile
@@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x
|
||||
ifeq ($(OS), Darwin)
|
||||
setup: mac-setup
|
||||
else
|
||||
setup: boost linux-machine-setup
|
||||
setup: maybe-boost linux-machine-setup
|
||||
endif
|
||||
|
||||
tldr: setup
|
||||
@@ -296,13 +296,17 @@ deps/SimplestOT_C/ref10/Makefile:
|
||||
|
||||
.PHONY: Programs/Circuits
|
||||
Programs/Circuits:
|
||||
git submodule update --init Programs/Circuits
|
||||
git submodule update --init Programs/Circuits || git clone https://github.com/mkskeller/bristol-fashion Programs/Circuits
|
||||
|
||||
deps/libOTe/libOTe:
|
||||
git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe
|
||||
boost: deps/libOTe/libOTe
|
||||
cd deps/libOTe; \
|
||||
python3 build.py --setup --boost --install=$(CURDIR)/local
|
||||
maybe-boost: deps/libOTe/libOTe
|
||||
cd `mktemp -d`; \
|
||||
PATH="$(CURDIR)/local/bin:$(PATH)" cmake $(CURDIR)/deps/libOTe || \
|
||||
{ cd -; make boost; }
|
||||
|
||||
OTE_OPTS += -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX) -DCMAKE_INSTALL_LIBDIR=lib
|
||||
|
||||
@@ -334,11 +338,12 @@ OT/OTExtensionWithMatrix.o: $(OTE)
|
||||
endif
|
||||
|
||||
local/lib/liblibOTe.a: deps/libOTe/libOTe
|
||||
make maybe-boost; \
|
||||
cd deps/libOTe; \
|
||||
PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS) && \
|
||||
touch ../../local/lib/liblibOTe.a
|
||||
|
||||
$(SHARED_OTE): deps/libOTe/libOTe
|
||||
$(SHARED_OTE): deps/libOTe/libOTe maybe-boost
|
||||
cd deps/libOTe; \
|
||||
python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS)
|
||||
|
||||
|
||||
@@ -51,11 +51,6 @@ public:
|
||||
return other * *this;
|
||||
}
|
||||
|
||||
void add(octetStream& os, int = -1)
|
||||
{
|
||||
*this += os.get<Bit>();
|
||||
}
|
||||
|
||||
void pack(octetStream& os, int = -1) const
|
||||
{
|
||||
super::pack(os, 1);
|
||||
|
||||
@@ -56,13 +56,6 @@ public:
|
||||
|
||||
void extend_bit(BitVec_& res, int) const { res = extend_bit(); }
|
||||
|
||||
void add(octetStream& os, int n_bits)
|
||||
{
|
||||
BitVec_ tmp;
|
||||
tmp.unpack(os, n_bits);
|
||||
*this += tmp;
|
||||
}
|
||||
|
||||
void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; }
|
||||
|
||||
void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = this->mask(n); }
|
||||
|
||||
@@ -138,12 +138,6 @@ public:
|
||||
v[i] = (x.v[i] * y.v[i]);
|
||||
}
|
||||
|
||||
void add(octetStream& os)
|
||||
{
|
||||
for (int i = 0; i < L; i++)
|
||||
v[i].add(os);
|
||||
}
|
||||
|
||||
void negate()
|
||||
{
|
||||
for (auto& x : v)
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
#include <sys/stat.h>
|
||||
|
||||
const false_type ValueInterface::binary;
|
||||
|
||||
void ValueInterface::check_setup(const string& directory)
|
||||
{
|
||||
struct stat sb;
|
||||
|
||||
@@ -156,8 +156,6 @@ public:
|
||||
bool operator==(const Z2<K>& other) const;
|
||||
bool operator!=(const Z2<K>& other) const { return not (*this == other); }
|
||||
|
||||
void add(octetStream& os, int = -1) { *this += (os.consume(size())); }
|
||||
|
||||
Z2 lazy_add(const Z2& x) const;
|
||||
Z2 lazy_mul(const Z2& x) const;
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y
|
||||
template<>
|
||||
inline void Zp_Data::Add<1>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{
|
||||
#if defined(__clang__) || !defined(__x86_64__)
|
||||
#if defined(__clang__) || !defined(__x86_64__) || (__GNUC__ == 10)
|
||||
Add<0>(ans, x, y);
|
||||
#else
|
||||
*ans = *x + *y;
|
||||
|
||||
@@ -87,12 +87,6 @@ bigint::bigint(const mp_limb_t* data, size_t n_limbs)
|
||||
mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data);
|
||||
}
|
||||
|
||||
void bigint::add(octetStream& os, int)
|
||||
{
|
||||
tmp.unpack(os);
|
||||
*this += tmp;
|
||||
}
|
||||
|
||||
string to_string(const bigint& x)
|
||||
{
|
||||
stringstream ss;
|
||||
|
||||
@@ -102,8 +102,6 @@ public:
|
||||
|
||||
void mul(const bigint& x, const bigint& y) { *this = x * y; }
|
||||
|
||||
void add(octetStream& os, int = -1);
|
||||
|
||||
#ifdef REALLOC_POLICE
|
||||
~bigint() { lottery(); }
|
||||
void lottery();
|
||||
|
||||
@@ -136,10 +136,6 @@ protected:
|
||||
// x+y
|
||||
void add(const gf2n_& x,const gf2n_& y)
|
||||
{ a=x.a^y.a; }
|
||||
void add(octet* x)
|
||||
{ a^=*(U*)(x); }
|
||||
void add(octetStream& os, int = -1)
|
||||
{ add(os.consume(size())); }
|
||||
void sub(const gf2n_& x,const gf2n_& y)
|
||||
{ a=x.a^y.a; }
|
||||
// = x * y
|
||||
|
||||
@@ -189,12 +189,8 @@ class gfp_ : public ValueInterface
|
||||
bool operator!=(const gfp_& y) const { return !equal(y); }
|
||||
|
||||
// x+y
|
||||
void add(octetStream& os, int = -1)
|
||||
{ add(os.consume(size())); }
|
||||
void add(const gfp_& x,const gfp_& y)
|
||||
{ ZpD.Add<L>(a.x,x.a.x,y.a.x); }
|
||||
void add(void* x)
|
||||
{ ZpD.Add<L>(a.x,a.x,(mp_limb_t*)x); }
|
||||
void sub(const gfp_& x,const gfp_& y)
|
||||
{ ZpD.Sub<L>(a.x,x.a.x,y.a.x); }
|
||||
// = x * y
|
||||
|
||||
@@ -295,12 +295,6 @@ bool gfpvar_<X, L>::operator !=(const gfpvar_<X, L>& other) const
|
||||
return not (*this == other);
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
void gfpvar_<X, L>::add(octetStream& other, int)
|
||||
{
|
||||
*this += other.get<gfpvar_<X, L>>();
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
void gfpvar_<X, L>::negate()
|
||||
{
|
||||
|
||||
@@ -149,8 +149,6 @@ public:
|
||||
bool operator==(const gfpvar_& other) const;
|
||||
bool operator!=(const gfpvar_& other) const;
|
||||
|
||||
void add(octetStream& other, int = -1);
|
||||
|
||||
void negate();
|
||||
|
||||
gfpvar_ invert() const;
|
||||
|
||||
@@ -272,14 +272,14 @@ PlayerBase::~PlayerBase()
|
||||
|
||||
|
||||
// Set up nmachines client and server sockets to send data back and fro
|
||||
// A machine is a server between it and player i if i<=my_number
|
||||
// A machine is a server between it and player i if i>=my_number
|
||||
// Can also communicate with myself, but only with send_to and receive_from
|
||||
void PlainPlayer::setup_sockets(const vector<string>& names,
|
||||
const vector<int>& ports, const string& id_base, ServerSocket& server)
|
||||
{
|
||||
sockets.resize(nplayers);
|
||||
// Set up the client side
|
||||
for (int i=player_no; i<nplayers; i++) {
|
||||
for (int i=0; i<=player_no; i++) {
|
||||
auto pn=id_base+"P"+to_string(player_no);
|
||||
if (i==player_no) {
|
||||
const char* localhost = "127.0.0.1";
|
||||
@@ -300,7 +300,7 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
|
||||
}
|
||||
send_to_self_socket = sockets[player_no];
|
||||
// Setting up the server side
|
||||
for (int i=0; i<=player_no; i++) {
|
||||
for (int i=player_no; i<nplayers; i++) {
|
||||
auto id=id_base+"P"+to_string(i);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr,
|
||||
|
||||
@@ -140,12 +140,12 @@ void ServerSocket::accept_clients()
|
||||
fprintf(stderr, "Accepting...\n");
|
||||
#endif
|
||||
int consocket;
|
||||
for (int i = 0; i < 25; i++)
|
||||
for (int i = 0; i < 1000; i++)
|
||||
{
|
||||
consocket = accept(main_socket, (struct sockaddr*) &dest,
|
||||
(socklen_t*) &socksize);
|
||||
if (consocket < 0)
|
||||
usleep(1 << i);
|
||||
usleep(min(1 << i, 1000));
|
||||
else
|
||||
break;
|
||||
}
|
||||
@@ -204,8 +204,10 @@ int ServerSocket::get_connection_socket(const string& id)
|
||||
|
||||
while (clients.find(id) == clients.end())
|
||||
{
|
||||
if (data_signal.wait(60) == ETIMEDOUT)
|
||||
throw runtime_error("No client after one minute");
|
||||
if (data_signal.wait(CONNECTION_TIMEOUT) == ETIMEDOUT)
|
||||
throw runtime_error("Timed out waiting for peer. See "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
|
||||
"for details on networking.");
|
||||
}
|
||||
|
||||
int client_socket = clients[id];
|
||||
@@ -228,7 +230,7 @@ void AnonymousServerSocket::init()
|
||||
void AnonymousServerSocket::process_client(const string& client_id)
|
||||
{
|
||||
if (clients.find(client_id) != clients.end())
|
||||
close_client_socket(clients[client_id]);
|
||||
throw runtime_error("client " + client_id + " already connected");
|
||||
client_connection_queue.push(client_id);
|
||||
}
|
||||
|
||||
@@ -236,9 +238,14 @@ int AnonymousServerSocket::get_connection_socket(string& client_id)
|
||||
{
|
||||
data_signal.lock();
|
||||
|
||||
//while (clients.find(next_client_id) == clients.end())
|
||||
while (client_connection_queue.empty())
|
||||
data_signal.wait();
|
||||
{
|
||||
int res = data_signal.wait(CONNECTION_TIMEOUT);
|
||||
if (res == ETIMEDOUT)
|
||||
throw runtime_error("timed out while waiting for client");
|
||||
else if (res)
|
||||
throw runtime_error("waiting error");
|
||||
}
|
||||
|
||||
client_id = client_connection_queue.front();
|
||||
client_connection_queue.pop();
|
||||
|
||||
@@ -29,7 +29,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
gethostname((char*)my_name,512);
|
||||
|
||||
int erp;
|
||||
for (int i = 0; i < 60; i++)
|
||||
for (int i = 0; i < CONNECTION_TIMEOUT; i++)
|
||||
{ erp=getaddrinfo (hostname, NULL, &hints, &ai);
|
||||
if (erp == 0)
|
||||
{ break; }
|
||||
@@ -90,7 +90,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
if (fl != 0)
|
||||
{
|
||||
close(mysocket);
|
||||
usleep(wait *= 2);
|
||||
usleep(wait < 1000 ? wait *= 2 : wait);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
string msg = "Connecting to " + string(hostname) + ":" +
|
||||
to_string(Portnum) + " failed";
|
||||
@@ -102,7 +102,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
}
|
||||
while (fl == -1
|
||||
&& (errno == ECONNREFUSED || errno == ETIMEDOUT || errno == EINPROGRESS)
|
||||
&& timer.elapsed() < 60);
|
||||
&& timer.elapsed() < CONNECTION_TIMEOUT);
|
||||
|
||||
if (fl < 0)
|
||||
{
|
||||
|
||||
@@ -23,6 +23,10 @@
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
// default to one minute
|
||||
#ifndef CONNECTION_TIMEOUT
|
||||
#define CONNECTION_TIMEOUT 60
|
||||
#endif
|
||||
|
||||
void error(const char *str);
|
||||
|
||||
@@ -38,10 +42,15 @@ void receive(T& socket, size_t& a, size_t len);
|
||||
|
||||
inline size_t send_non_blocking(int socket, octet* msg, size_t len)
|
||||
{
|
||||
#ifdef __APPLE__
|
||||
int j = send(socket,msg,min(len,10000lu),MSG_DONTWAIT);
|
||||
#else
|
||||
int j = send(socket,msg,len,MSG_DONTWAIT);
|
||||
#endif
|
||||
if (j < 0)
|
||||
{
|
||||
if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK)
|
||||
if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK and
|
||||
errno != ENOBUFS)
|
||||
{ error("Send error - 1 "); }
|
||||
else
|
||||
return 0;
|
||||
|
||||
@@ -41,6 +41,8 @@ union square128 {
|
||||
int16_t doublebytes[128][8];
|
||||
int32_t words[128][4];
|
||||
|
||||
square128() {}
|
||||
|
||||
bool get_bit(int x, int y)
|
||||
{ return (bytes[x][y/8] >> (y % 8)) & 1; }
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
#include "Instruction.hpp"
|
||||
#include "Protocols/ShuffleSacrifice.hpp"
|
||||
|
||||
#include <iostream>
|
||||
@@ -38,12 +39,27 @@ bool BaseMachine::has_program()
|
||||
}
|
||||
|
||||
int BaseMachine::edabit_bucket_size(int n_bits)
|
||||
{
|
||||
size_t usage = 0;
|
||||
if (has_program())
|
||||
usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
return bucket_size(usage);
|
||||
}
|
||||
|
||||
int BaseMachine::triple_bucket_size(DataFieldType type)
|
||||
{
|
||||
size_t usage = 0;
|
||||
if (has_program())
|
||||
usage = s().progs[0].get_offline_data_used().files[type][DATA_TRIPLE];
|
||||
return bucket_size(usage);
|
||||
}
|
||||
|
||||
int BaseMachine::bucket_size(size_t usage)
|
||||
{
|
||||
int res = OnlineOptions::singleton.bucket_size;
|
||||
|
||||
if (has_program())
|
||||
if (usage)
|
||||
{
|
||||
auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
for (int B = res; B <= 5; B++)
|
||||
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
|
||||
break;
|
||||
@@ -91,7 +107,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
string threadname;
|
||||
for (int i=0; i<nprogs; i++)
|
||||
{ inpf >> threadname;
|
||||
size_t split = threadname.find(":");
|
||||
size_t split = threadname.find_last_of(":");
|
||||
long expected = -1;
|
||||
if (split != string::npos)
|
||||
{
|
||||
@@ -125,6 +141,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
getline(inpf, compiler);
|
||||
getline(inpf, domain);
|
||||
getline(inpf, relevant_opts);
|
||||
getline(inpf, security);
|
||||
inpf.close();
|
||||
}
|
||||
|
||||
@@ -184,17 +201,19 @@ string BaseMachine::memory_filename(const string& type_short, int my_number)
|
||||
|
||||
string BaseMachine::get_domain(string progname)
|
||||
{
|
||||
if (singleton)
|
||||
{
|
||||
assert(s().progname == progname);
|
||||
return s().domain;
|
||||
}
|
||||
return get_basics(progname).domain;
|
||||
}
|
||||
|
||||
assert(not singleton);
|
||||
BaseMachine BaseMachine::get_basics(string progname)
|
||||
{
|
||||
if (singleton and s().progname == progname)
|
||||
return s();
|
||||
|
||||
auto backup = singleton;
|
||||
BaseMachine machine;
|
||||
singleton = 0;
|
||||
singleton = backup;
|
||||
machine.load_schedule(progname, false);
|
||||
return machine.domain;
|
||||
return machine;
|
||||
}
|
||||
|
||||
int BaseMachine::ring_size_from_schedule(string progname)
|
||||
@@ -226,6 +245,15 @@ bigint BaseMachine::prime_from_schedule(string progname)
|
||||
return 0;
|
||||
}
|
||||
|
||||
int BaseMachine::security_from_schedule(string progname)
|
||||
{
|
||||
string sec = get_basics(progname).security;
|
||||
if (sec.substr(0, 4).compare("sec:") == 0)
|
||||
return stoi(sec.substr(4));
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
NamedCommStats BaseMachine::total_comm()
|
||||
{
|
||||
NamedCommStats res;
|
||||
|
||||
@@ -32,10 +32,13 @@ protected:
|
||||
string compiler;
|
||||
string domain;
|
||||
string relevant_opts;
|
||||
string security;
|
||||
|
||||
virtual size_t load_program(const string& threadname,
|
||||
const string& filename);
|
||||
|
||||
static BaseMachine get_basics(string progname);
|
||||
|
||||
public:
|
||||
static thread_local int thread_num;
|
||||
|
||||
@@ -58,12 +61,15 @@ public:
|
||||
static int ring_size_from_schedule(string progname);
|
||||
static int prime_length_from_schedule(string progname);
|
||||
static bigint prime_from_schedule(string progname);
|
||||
static int security_from_schedule(string progname);
|
||||
|
||||
template<class T>
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
|
||||
template<class T>
|
||||
static int edabit_batch_size(int n_bits, int buffer_size = 0);
|
||||
static int edabit_bucket_size(int n_bits);
|
||||
static int triple_bucket_size(DataFieldType type);
|
||||
static int bucket_size(size_t usage);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "Processor/ExternalClients.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
#include "Networking/ServerSocket.h"
|
||||
#include "Networking/ssl_sockets.h"
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <thread>
|
||||
@@ -25,6 +27,8 @@ ExternalClients::~ExternalClients()
|
||||
}
|
||||
if (ctx)
|
||||
delete ctx;
|
||||
for (auto it = peer_ctxs.begin(); it != peer_ctxs.end(); it++)
|
||||
delete it->second;
|
||||
}
|
||||
|
||||
void ExternalClients::start_listening(int portnum_base)
|
||||
@@ -32,8 +36,9 @@ void ExternalClients::start_listening(int portnum_base)
|
||||
ScopeLock _(lock);
|
||||
client_connection_servers[portnum_base] = new AnonymousServerSocket(portnum_base + get_party_num());
|
||||
client_connection_servers[portnum_base]->init();
|
||||
cerr << "Start listening on thread " << this_thread::get_id() << endl;
|
||||
cerr << "Party " << get_party_num() << " is listening on port " << (portnum_base + get_party_num())
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Party " << get_party_num() << " is listening on port "
|
||||
<< (portnum_base + get_party_num())
|
||||
<< " for external client connections." << endl;
|
||||
}
|
||||
|
||||
@@ -46,7 +51,6 @@ int ExternalClients::get_client_connection(int portnum_base)
|
||||
cerr << "Thread " << this_thread::get_id() << " didn't find server." << endl;
|
||||
throw runtime_error("No connection on port " + to_string(portnum_base));
|
||||
}
|
||||
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
|
||||
int client_id, socket;
|
||||
string client;
|
||||
socket = client_connection_servers[portnum_base]->get_connection_socket(
|
||||
@@ -57,10 +61,38 @@ int ExternalClients::get_client_connection(int portnum_base)
|
||||
external_client_sockets[client_id] = new client_socket(io_service, *ctx, socket,
|
||||
"C" + to_string(client_id), "P" + to_string(get_party_num()), false);
|
||||
client_ports[client_id] = portnum_base;
|
||||
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Party " << get_party_num()
|
||||
<< " received external client connection from client id: " << dec
|
||||
<< client_id << endl;
|
||||
return client_id;
|
||||
}
|
||||
|
||||
int ExternalClients::init_client_connection(const string& host, int portnum,
|
||||
int my_client_id)
|
||||
{
|
||||
ScopeLock _(lock);
|
||||
int plain_socket;
|
||||
set_up_client_socket(plain_socket, host.c_str(), portnum);
|
||||
octetStream(to_string(my_client_id)).Send(plain_socket);
|
||||
string my_client_name = "C" + to_string(my_client_id);
|
||||
if (peer_ctxs.find(my_client_id) == peer_ctxs.end())
|
||||
peer_ctxs[my_client_id] = new client_ctx(my_client_name);
|
||||
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
|
||||
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
|
||||
true);
|
||||
if (party_num == 0)
|
||||
{
|
||||
octetStream specification;
|
||||
specification.Receive(socket);
|
||||
}
|
||||
int id = -1;
|
||||
if (not external_client_sockets.empty())
|
||||
id = min(id, external_client_sockets.begin()->first);
|
||||
external_client_sockets[id] = socket;
|
||||
return id;
|
||||
}
|
||||
|
||||
void ExternalClients::close_connection(int client_id)
|
||||
{
|
||||
ScopeLock _(lock);
|
||||
|
||||
@@ -32,6 +32,7 @@ class ExternalClients
|
||||
|
||||
ssl_service io_service;
|
||||
client_ctx* ctx;
|
||||
map<int, client_ctx*> peer_ctxs;
|
||||
|
||||
Lock lock;
|
||||
|
||||
@@ -43,6 +44,7 @@ class ExternalClients
|
||||
void start_listening(int portnum_base);
|
||||
|
||||
int get_client_connection(int portnum_base);
|
||||
int init_client_connection(const string& host, int portnum, int my_client_id);
|
||||
|
||||
void close_connection(int client_id);
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "OnlineMachine.hpp"
|
||||
#include "OnlineOptions.hpp"
|
||||
|
||||
|
||||
template<template<class U> class T, class V>
|
||||
HonestMajorityFieldMachine<T, V>::HonestMajorityFieldMachine(int argc,
|
||||
const char **argv)
|
||||
@@ -34,6 +33,7 @@ template<template<class U> class T, template<class U> class V, class W, class X>
|
||||
FieldMachine<T, V, W, X>::FieldMachine(int argc, const char** argv,
|
||||
ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers)
|
||||
{
|
||||
assert(nplayers or T<gfpvar>::variable_players);
|
||||
W machine(argc, argv, opt, online_opts, X(), nplayers);
|
||||
int n_limbs = online_opts.prime_limbs();
|
||||
switch (n_limbs)
|
||||
|
||||
@@ -10,11 +10,13 @@
|
||||
#include "Math/gf2n.h"
|
||||
#include "GC/instructions.h"
|
||||
|
||||
#include "Memory.hpp"
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
|
||||
vector<cgf2n>& memory, ArithmeticProcessor& Proc) const
|
||||
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
|
||||
{
|
||||
auto& C2 = registers;
|
||||
auto& M2C = memory;
|
||||
@@ -123,6 +125,6 @@ ostream& operator<<(ostream& s, const Instruction& instr)
|
||||
}
|
||||
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
|
||||
vector<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
|
||||
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
|
||||
vector<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
|
||||
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
|
||||
|
||||
@@ -72,6 +72,7 @@ enum
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
CMDLINEARG = 0xEB,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -153,7 +154,7 @@ enum
|
||||
LISTEN = 0x6c,
|
||||
ACCEPTCLIENTCONNECTION = 0x6d,
|
||||
CLOSECLIENTCONNECTION = 0x6e,
|
||||
READCLIENTPUBLICKEY = 0x6f,
|
||||
INITCLIENTCONNECTION = 0x6f,
|
||||
// Bitwise logic
|
||||
ANDC = 0x70,
|
||||
XORC = 0x71,
|
||||
@@ -197,6 +198,7 @@ enum
|
||||
PRINTREG = 0XB1,
|
||||
RAND = 0xB2,
|
||||
PRINTREGPLAIN = 0xB3,
|
||||
PRINTREGPLAINS = 0xEA,
|
||||
PRINTCHR = 0xB4,
|
||||
PRINTSTR = 0xB5,
|
||||
PUBINPUT = 0xB6,
|
||||
@@ -345,6 +347,7 @@ protected:
|
||||
int r[4]; // Fixed parameter registers
|
||||
size_t n; // Possible immediate value
|
||||
vector<int> start; // Values for a start/stop open
|
||||
string str;
|
||||
|
||||
public:
|
||||
virtual ~BaseInstruction() {};
|
||||
@@ -387,7 +390,7 @@ public:
|
||||
void execute(Processor<sint, sgf2n>& Proc) const;
|
||||
|
||||
template<class cgf2n>
|
||||
void execute_clear_gf2n(vector<cgf2n>& registers, vector<cgf2n>& memory,
|
||||
void execute_clear_gf2n(vector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
|
||||
ArithmeticProcessor& Proc) const;
|
||||
|
||||
template<class cgf2n>
|
||||
|
||||
@@ -105,7 +105,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case STMCBI:
|
||||
case MOVC:
|
||||
case MOVS:
|
||||
case MOVSB:
|
||||
case MOVINT:
|
||||
case LDMINTI:
|
||||
case STMINTI:
|
||||
@@ -131,6 +130,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case SHUFFLE:
|
||||
case ACCEPTCLIENTCONNECTION:
|
||||
case PREFIXSUMS:
|
||||
case CMDLINEARG:
|
||||
get_ints(r, s, 2);
|
||||
break;
|
||||
// instructions with 1 register operand
|
||||
@@ -139,6 +139,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case RANDOMFULLS:
|
||||
case PRINTREGPLAIN:
|
||||
case PRINTREGPLAINB:
|
||||
case PRINTREGPLAINS:
|
||||
case LDTN:
|
||||
case LDARG:
|
||||
case STARG:
|
||||
@@ -316,13 +317,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case TRUNC_PR:
|
||||
case RUN_TAPE:
|
||||
case CONV2DS:
|
||||
case MATMULS:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case MATMULS:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(3, start, s);
|
||||
break;
|
||||
case MATMULSM:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(9, start, s);
|
||||
@@ -358,7 +356,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
n = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case READCLIENTPUBLICKEY:
|
||||
case INITCLIENTCONNECTION:
|
||||
get_ints(r, s, 3);
|
||||
get_string(str, s);
|
||||
break;
|
||||
case INITSECURESOCKET:
|
||||
case RESPSECURESOCKET:
|
||||
throw runtime_error("VM-controlled encryption not supported any more");
|
||||
@@ -459,6 +460,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case CONVCBIT2S:
|
||||
case NOTS:
|
||||
case NOTCB:
|
||||
case MOVSB:
|
||||
n = get_int(s);
|
||||
get_ints(r, s, 2);
|
||||
break;
|
||||
@@ -566,7 +568,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case MOVINT:
|
||||
case READSOCKETINT:
|
||||
case WRITESOCKETINT:
|
||||
case READCLIENTPUBLICKEY:
|
||||
case INITCLIENTCONNECTION:
|
||||
case INITSECURESOCKET:
|
||||
case RESPSECURESOCKET:
|
||||
case LDARG:
|
||||
@@ -584,6 +586,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case INTOUTPUT:
|
||||
case ACCEPTCLIENTCONNECTION:
|
||||
case GENSECSHUFFLE:
|
||||
case CMDLINEARG:
|
||||
return INT;
|
||||
case PREP:
|
||||
case GPREP:
|
||||
@@ -723,6 +726,15 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return res;
|
||||
}
|
||||
case MATMULS:
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += 6)
|
||||
{
|
||||
int tmp = *it + *(it + 3) * *(it + 5);
|
||||
res = max(res, tmp);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
case MATMULSM:
|
||||
return r[0] + start[0] * start[2];
|
||||
case CONV2DS:
|
||||
@@ -817,7 +829,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
while (it < start.end())
|
||||
{
|
||||
int n = *it - n_prefix;
|
||||
int size = DIV_CEIL(*(it + 1), 64);
|
||||
size = max((long long) size, DIV_CEIL(*(it + 1), 64));
|
||||
it += n_prefix;
|
||||
assert(it + n <= start.end());
|
||||
for (int i = 0; i < n; i++)
|
||||
@@ -922,16 +934,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.write_Cp(r[0],Proc.machine.Mp.read_C(n));
|
||||
n++;
|
||||
break;
|
||||
case LDMCI:
|
||||
Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1])));
|
||||
break;
|
||||
case STMC:
|
||||
Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0]));
|
||||
n++;
|
||||
break;
|
||||
case STMCI:
|
||||
Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0]));
|
||||
break;
|
||||
case MOVC:
|
||||
Proc.write_Cp(r[0],Proc.read_Cp(r[1]));
|
||||
break;
|
||||
@@ -1089,10 +1095,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Proc2.POpen(*this);
|
||||
return;
|
||||
case MULS:
|
||||
Proc.Procp.muls(start, size);
|
||||
Proc.Procp.muls(start);
|
||||
return;
|
||||
case GMULS:
|
||||
Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size);
|
||||
Proc.Proc2.muls(start);
|
||||
return;
|
||||
case MULRS:
|
||||
Proc.Procp.mulrs(start);
|
||||
@@ -1107,7 +1113,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Proc2.dotprods(start, size);
|
||||
return;
|
||||
case MATMULS:
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this, r[1], r[2]);
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this);
|
||||
return;
|
||||
case MATMULSM:
|
||||
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
|
||||
@@ -1126,13 +1132,15 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Proc2.secure_shuffle(*this);
|
||||
return;
|
||||
case GENSECSHUFFLE:
|
||||
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this));
|
||||
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this,
|
||||
Proc.machine.shuffle_store));
|
||||
return;
|
||||
case APPLYSHUFFLE:
|
||||
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)));
|
||||
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)),
|
||||
Proc.machine.shuffle_store);
|
||||
return;
|
||||
case DELSHUFFLE:
|
||||
Proc.Procp.delete_shuffle(Proc.read_Ci(r[0]));
|
||||
Proc.machine.shuffle_store.del(Proc.read_Ci(r[0]));
|
||||
return;
|
||||
case INVPERM:
|
||||
Proc.Procp.inverse_permutation(*this);
|
||||
@@ -1170,6 +1178,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case PRINTREGPLAIN:
|
||||
print(Proc.out, &Proc.read_Cp(r[0]));
|
||||
return;
|
||||
case PRINTREGPLAINS:
|
||||
Proc.out << Proc.read_Sp(r[0]);
|
||||
return;
|
||||
case CONDPRINTPLAIN:
|
||||
if (not Proc.read_Cp(r[0]).is_zero())
|
||||
{
|
||||
@@ -1237,6 +1248,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case PLAYERID:
|
||||
Proc.write_Ci(r[0], Proc.P.my_num());
|
||||
break;
|
||||
case CMDLINEARG:
|
||||
{
|
||||
size_t idx = Proc.read_Ci(r[1]);
|
||||
auto& args = OnlineOptions::singleton.args;
|
||||
if (idx < args.size())
|
||||
Proc.write_Ci(r[0], args[idx]);
|
||||
else
|
||||
{
|
||||
cerr << idx << "-th command-line argument not given" << endl;
|
||||
exit(1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// ***
|
||||
// TODO: read/write shared GF(2^n) data instructions
|
||||
// ***
|
||||
@@ -1255,11 +1279,17 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
octetStream os;
|
||||
os.store(int(sint::open_type::type_char()));
|
||||
sint::specification(os);
|
||||
sint::clear::specification(os);
|
||||
os.Send(Proc.external_clients.get_socket(client_handle));
|
||||
}
|
||||
Proc.write_Ci(r[0], client_handle);
|
||||
break;
|
||||
}
|
||||
case INITCLIENTCONNECTION:
|
||||
Proc.write_Ci(r[0],
|
||||
Proc.external_clients.init_client_connection(str,
|
||||
Proc.read_Ci(r[1]), Proc.read_Ci(r[2])));
|
||||
break;
|
||||
case CLOSECLIENTCONNECTION:
|
||||
Proc.external_clients.close_connection(Proc.read_Ci(r[0]));
|
||||
break;
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/ExecutionStats.h"
|
||||
|
||||
#include "Protocols/SecureShuffle.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <atomic>
|
||||
@@ -70,6 +72,8 @@ class Machine : public BaseMachine
|
||||
|
||||
ExternalClients external_clients;
|
||||
|
||||
typename sint::Protocol::Shuffler::store_type shuffle_store;
|
||||
|
||||
static void init_binary_domains(int security_parameter, int lg2);
|
||||
|
||||
Machine(Names& playerNames, bool use_encryption = true,
|
||||
|
||||
@@ -60,11 +60,21 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
{
|
||||
OnlineOptions::singleton = opts;
|
||||
|
||||
if (N.num_players() == 1 and sint::is_real)
|
||||
int min_players = 3 - sint::dishonest_majority;
|
||||
if (sint::is_real)
|
||||
{
|
||||
cerr << "Need more than one player to run a protocol." << endl;
|
||||
cerr << "Use 'emulate.x' for just running the virtual machine" << endl;
|
||||
exit(1);
|
||||
if (N.num_players() == 1)
|
||||
{
|
||||
cerr << "Need more than one player to run a protocol." << endl;
|
||||
cerr << "Use 'emulate.x' for just running the virtual machine" << endl;
|
||||
exit(1);
|
||||
}
|
||||
else if (N.num_players() < min_players)
|
||||
{
|
||||
cerr << "Need at least " << min_players << " players for this protocol."
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Set the prime modulus from command line or program if applicable
|
||||
@@ -480,8 +490,10 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Communication details "
|
||||
"(rounds in parallel threads counted double):" << endl;
|
||||
cerr << "Communication details";
|
||||
if (multithread)
|
||||
cerr << " (rounds in parallel threads counted double)";
|
||||
cerr << ":" << endl;
|
||||
comm_stats.print();
|
||||
cerr << "CPU time = " << proc_timer.elapsed();
|
||||
if (multithread)
|
||||
@@ -547,6 +559,14 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
|
||||
suggest_optimizations();
|
||||
|
||||
if (N.num_players() > 4)
|
||||
{
|
||||
string alt = sint::alt();
|
||||
if (alt.size())
|
||||
cerr << "This protocol doesn't scale well with the number of parties, "
|
||||
<< "have you considered using " << alt << " instead?" << endl;
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "End of prog" << endl;
|
||||
#endif
|
||||
|
||||
@@ -14,34 +14,90 @@ template<class T> istream& operator>>(istream& s,Memory<T>& M);
|
||||
|
||||
#include "Processor/Program.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
#include "Tools/DiskVector.h"
|
||||
|
||||
template<class T>
|
||||
class MemoryPart : public CheckVector<T>
|
||||
class MemoryPart
|
||||
{
|
||||
public:
|
||||
template<class U>
|
||||
static void check_index(const vector<U>& M, size_t i)
|
||||
virtual ~MemoryPart() {}
|
||||
|
||||
virtual size_t size() const = 0;
|
||||
virtual void resize(size_t) = 0;
|
||||
|
||||
virtual T* data() = 0;
|
||||
virtual const T* data() const = 0;
|
||||
|
||||
void check_index(size_t i) const
|
||||
{
|
||||
(void) M, (void) i;
|
||||
(void) i;
|
||||
#ifndef NO_CHECK_INDEX
|
||||
if (i >= M.size())
|
||||
throw overflow(U::type_string() + " memory", i, M.size());
|
||||
if (i >= this->size())
|
||||
throw overflow(T::type_string() + " memory", i, this->size());
|
||||
#endif
|
||||
}
|
||||
|
||||
virtual T& operator[](size_t i) = 0;
|
||||
virtual const T& operator[](size_t i) const = 0;
|
||||
|
||||
virtual T& at(size_t i) = 0;
|
||||
virtual const T& at(size_t i) const = 0;
|
||||
|
||||
template<class U>
|
||||
void indirect_read(const Instruction& inst, vector<T>& regs,
|
||||
const U& indices);
|
||||
template<class U>
|
||||
void indirect_write(const Instruction& inst, vector<T>& regs,
|
||||
const U& indices);
|
||||
|
||||
void minimum_size(size_t size);
|
||||
};
|
||||
|
||||
template<class T, template<class> class V>
|
||||
class MemoryPartImpl : public MemoryPart<T>, public V<T>
|
||||
{
|
||||
public:
|
||||
size_t size() const
|
||||
{
|
||||
return V<T>::size();
|
||||
}
|
||||
|
||||
void resize(size_t size)
|
||||
{
|
||||
V<T>::resize(size);
|
||||
}
|
||||
|
||||
T* data()
|
||||
{
|
||||
return V<T>::data();
|
||||
}
|
||||
|
||||
const T* data() const
|
||||
{
|
||||
return V<T>::data();
|
||||
}
|
||||
|
||||
T& operator[](size_t i)
|
||||
{
|
||||
check_index(*this, i);
|
||||
return CheckVector<T>::operator[](i);
|
||||
this->check_index(i);
|
||||
return V<T>::operator[](i);
|
||||
}
|
||||
|
||||
const T& operator[](size_t i) const
|
||||
{
|
||||
check_index(*this, i);
|
||||
return CheckVector<T>::operator[](i);
|
||||
this->check_index(i);
|
||||
return V<T>::operator[](i);
|
||||
}
|
||||
|
||||
void minimum_size(size_t size);
|
||||
T& at(size_t i)
|
||||
{
|
||||
return V<T>::at(i);
|
||||
}
|
||||
|
||||
const T& at(size_t i) const
|
||||
{
|
||||
return V<T>::at(i);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -49,8 +105,11 @@ class Memory
|
||||
{
|
||||
public:
|
||||
|
||||
MemoryPart<T> MS;
|
||||
MemoryPart<typename T::clear> MC;
|
||||
MemoryPart<T>& MS;
|
||||
MemoryPartImpl<typename T::clear, CheckVector> MC;
|
||||
|
||||
Memory();
|
||||
~Memory();
|
||||
|
||||
void resize_s(size_t sz)
|
||||
{ MS.resize(sz); }
|
||||
|
||||
@@ -3,6 +3,54 @@
|
||||
|
||||
#include <fstream>
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void MemoryPart<T>::indirect_read(const Instruction& inst,
|
||||
vector<T>& regs, const U& indices)
|
||||
{
|
||||
size_t n = inst.get_size();
|
||||
auto dest = regs.begin() + inst.get_r(0);
|
||||
auto start = indices.begin() + inst.get_r(1);
|
||||
#ifdef CHECK_SIZE
|
||||
assert(start + n <= indices.end());
|
||||
assert(dest + n <= regs.end());
|
||||
#endif
|
||||
long size = this->size();
|
||||
const T* data = this->data();
|
||||
for (auto it = start; it < start + n; it++)
|
||||
{
|
||||
#ifndef NO_CHECK_SIZE
|
||||
if (*it >= size)
|
||||
throw overflow(T::type_string() + " memory read", it->get(), size);
|
||||
#endif
|
||||
*dest++ = data[it->get()];
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void MemoryPart<T>::indirect_write(const Instruction& inst,
|
||||
vector<T>& regs, const U& indices)
|
||||
{
|
||||
size_t n = inst.get_size();
|
||||
auto source = regs.begin() + inst.get_r(0);
|
||||
auto start = indices.begin() + inst.get_r(1);
|
||||
#ifdef CHECK_SIZE
|
||||
assert(start + n <= indices.end());
|
||||
assert(source + n <= regs.end());
|
||||
#endif
|
||||
long size = this->size();
|
||||
T* data = this->data();
|
||||
for (auto it = start; it < start + n; it++)
|
||||
{
|
||||
#ifndef NO_CHECK_SIZE
|
||||
if (*it >= size)
|
||||
throw overflow(T::type_string() + " memory write", it->get(), size);
|
||||
#endif
|
||||
data[it->get()] = *source++;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
|
||||
const Program &program, const string& threadname)
|
||||
@@ -29,6 +77,21 @@ void MemoryPart<T>::minimum_size(size_t size)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Memory<T>::Memory() :
|
||||
MS(
|
||||
*(OnlineOptions::singleton.disk_memory.size() ?
|
||||
static_cast<MemoryPart<T>*>(new MemoryPartImpl<T, DiskVector>) :
|
||||
static_cast<MemoryPart<T>*>(new MemoryPartImpl<T, CheckVector>)))
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Memory<T>::~Memory()
|
||||
{
|
||||
delete &MS;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ostream& operator<<(ostream& s,const Memory<T>& M)
|
||||
{
|
||||
|
||||
@@ -71,18 +71,6 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op
|
||||
"--ip-file-name" // Flag token.
|
||||
);
|
||||
|
||||
if (nplayers == 0)
|
||||
opt.add(
|
||||
"2", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of players (default: 2). "
|
||||
"Ignored if external server is used.", // Help description.
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
|
||||
@@ -22,12 +22,13 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
interactive = false;
|
||||
lgp = gfp0::MAX_N_BITS;
|
||||
live_prep = true;
|
||||
batch_size = 10000;
|
||||
batch_size = 1000;
|
||||
memtype = "empty";
|
||||
bits_from_squares = false;
|
||||
direct = false;
|
||||
bucket_size = 4;
|
||||
security_parameter = DEFAULT_SECURITY;
|
||||
use_security_parameter = false;
|
||||
cmd_private_input_file = "Player-Data/Input";
|
||||
cmd_private_output_file = "";
|
||||
file_prep_per_thread = false;
|
||||
@@ -46,6 +47,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, bool security) :
|
||||
OnlineOptions()
|
||||
{
|
||||
use_security_parameter = security;
|
||||
|
||||
opt.syntax = std::string(argv[0]) + " [OPTIONS] [<playerno>] <progname>";
|
||||
|
||||
opt.add(
|
||||
@@ -116,7 +119,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Security parameter (default: " + to_string(security_parameter)
|
||||
("Statistical ecurity parameter (default: " + to_string(security_parameter)
|
||||
+ ")").c_str(), // Help description.
|
||||
"-S", // Flag token.
|
||||
"--security" // Flag token.
|
||||
@@ -138,7 +141,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
if (security)
|
||||
{
|
||||
opt.get("-S")->getInt(security_parameter);
|
||||
cerr << "Using security parameter " << security_parameter << endl;
|
||||
if (security_parameter <= 0)
|
||||
{
|
||||
cerr << "Invalid security parameter: " << security_parameter << endl;
|
||||
@@ -280,7 +282,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
}
|
||||
|
||||
void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv)
|
||||
const char** argv, bool networking)
|
||||
{
|
||||
opt.resetArgs();
|
||||
opt.parse(argc, argv);
|
||||
@@ -292,17 +294,21 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
vector<string> badOptions;
|
||||
unsigned int i;
|
||||
|
||||
opt.footer += "\nSee also https://mp-spdz.readthedocs.io/en/latest/networking.html "
|
||||
"for documentation on the networking setup.\n";
|
||||
if (networking)
|
||||
opt.footer += "See also "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
|
||||
"for documentation on the networking setup.\n\n";
|
||||
|
||||
if (allArgs.size() != 3u - opt.isSet("-p"))
|
||||
size_t name_index = 1 + networking - opt.isSet("-p");
|
||||
|
||||
if (allArgs.size() < name_index + 1)
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl;
|
||||
cerr << "Arguments given were:\n";
|
||||
for (unsigned int j = 1; j < allArgs.size(); j++)
|
||||
cout << "'" << *allArgs[j] << "'" << endl;
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
exit(1);
|
||||
}
|
||||
else
|
||||
@@ -311,25 +317,25 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
opt.get("-p")->getInt(playerno);
|
||||
else
|
||||
sscanf((*allArgs[1]).c_str(), "%d", &playerno);
|
||||
progname = *allArgs[2 - opt.isSet("-p")];
|
||||
progname = *allArgs.at(name_index);
|
||||
}
|
||||
|
||||
if (!opt.gotRequired(badOptions))
|
||||
{
|
||||
for (i = 0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
for (i = 0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (!opt.gotExpected(badOptions))
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
for (i = 0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Got unexpected number of arguments for option "
|
||||
<< badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -347,6 +353,22 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
prime = schedule_prime;
|
||||
}
|
||||
|
||||
for (size_t i = name_index + 1; i < allArgs.size(); i++)
|
||||
{
|
||||
try
|
||||
{
|
||||
args.push_back(stol(*allArgs[i]));
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cerr << usage;
|
||||
cerr << "Additional argument has to be integer: " << *allArgs[i]
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// ignore program if length explicitly set from command line
|
||||
if (opt.get("-lgp") and not opt.isSet("-lgp"))
|
||||
{
|
||||
@@ -367,7 +389,29 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
if (o)
|
||||
o->getInt(max_broadcast);
|
||||
|
||||
o = opt.get("--disk-memory");
|
||||
if (o)
|
||||
o->getString(disk_memory);
|
||||
|
||||
receive_threads = opt.isSet("--threads");
|
||||
|
||||
if (use_security_parameter)
|
||||
{
|
||||
int program_sec = BaseMachine::security_from_schedule(progname);
|
||||
|
||||
if (program_sec > 0)
|
||||
{
|
||||
if (not opt.isSet("-S"))
|
||||
security_parameter = program_sec;
|
||||
if (program_sec < security_parameter)
|
||||
{
|
||||
cerr << "Security parameter used in compilation is insufficient" << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "Using statistical security parameter " << security_parameter << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt)
|
||||
|
||||
@@ -27,6 +27,7 @@ public:
|
||||
bool direct;
|
||||
int bucket_size;
|
||||
int security_parameter;
|
||||
bool use_security_parameter;
|
||||
std::string cmd_private_input_file;
|
||||
std::string cmd_private_output_file;
|
||||
bool verbose;
|
||||
@@ -34,6 +35,8 @@ public:
|
||||
int trunc_error;
|
||||
int opening_sum, max_broadcast;
|
||||
bool receive_threads;
|
||||
std::string disk_memory;
|
||||
vector<long> args;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
@@ -48,7 +51,8 @@ public:
|
||||
OnlineOptions(T);
|
||||
~OnlineOptions() {}
|
||||
|
||||
void finalize(ez::ezOptionParser& opt, int argc, const char** argv);
|
||||
void finalize(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
bool networking = true);
|
||||
|
||||
void set_trunc_error(ez::ezOptionParser& opt);
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
template<class T>
|
||||
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, T, bool default_live_prep) :
|
||||
OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0,
|
||||
OnlineOptions(opt, argc, argv, OnlineOptions(T()).batch_size,
|
||||
default_live_prep, T::clear::prime_field)
|
||||
{
|
||||
if (T::has_trunc_pr)
|
||||
@@ -56,13 +56,39 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"--max-broadcast" // Flag token.
|
||||
);
|
||||
}
|
||||
|
||||
if (not T::clear::binary)
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Use directory on disk for memory (container data structures) "
|
||||
"instead of RAM", // Help description.
|
||||
"-D", // Flag token.
|
||||
"--disk-memory" // Flag token.
|
||||
);
|
||||
|
||||
if (T::variable_players)
|
||||
opt.add(
|
||||
T::dishonest_majority ? "2" : "3", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Number of players (default: "
|
||||
+ (T::dishonest_majority ?
|
||||
to_string("2") : to_string("3")) + "). " +
|
||||
"Ignored if external server is used.").c_str(), // Help description.
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
OnlineOptions::OnlineOptions(T) : OnlineOptions()
|
||||
{
|
||||
if (T::dishonest_majority)
|
||||
batch_size = 1000;
|
||||
if (not T::dishonest_majority)
|
||||
batch_size = 10000;
|
||||
}
|
||||
|
||||
#endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */
|
||||
|
||||
@@ -36,7 +36,7 @@ class SubProcessor
|
||||
|
||||
void resize(size_t size) { C.resize(size); S.resize(size); }
|
||||
|
||||
void matmulsm_prep(int ii, int j, const CheckVector<T>& source,
|
||||
void matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b);
|
||||
void matmulsm_finalize(int i, int j, const vector<int>& dim,
|
||||
typename vector<T>::iterator C);
|
||||
@@ -48,6 +48,8 @@ class SubProcessor
|
||||
|
||||
typedef typename T::bit_type::part_type BT;
|
||||
|
||||
typedef typename T::Protocol::Shuffler::store_type ShuffleStore;
|
||||
|
||||
public:
|
||||
ArithmeticProcessor* Proc;
|
||||
typename T::MAC_Check& MC;
|
||||
@@ -71,19 +73,19 @@ public:
|
||||
// Access to PO (via calls to POpen start/stop)
|
||||
void POpen(const Instruction& inst);
|
||||
|
||||
void muls(const vector<int>& reg, int size);
|
||||
void muls(const vector<int>& reg);
|
||||
void mulrs(const vector<int>& reg);
|
||||
void dotprods(const vector<int>& reg, int size);
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, size_t a,
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction);
|
||||
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void conv2ds(const Instruction& instruction);
|
||||
|
||||
void secure_shuffle(const Instruction& instruction);
|
||||
size_t generate_secure_shuffle(const Instruction& instruction);
|
||||
void apply_shuffle(const Instruction& instruction, int handle);
|
||||
void delete_shuffle(int handle);
|
||||
size_t generate_secure_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store);
|
||||
void apply_shuffle(const Instruction& instruction, int handle,
|
||||
ShuffleStore& shuffle_store);
|
||||
void inverse_permutation(const Instruction& instruction);
|
||||
|
||||
void input_personal(const vector<int>& args);
|
||||
@@ -116,7 +118,7 @@ public:
|
||||
class ArithmeticProcessor : public ProcessorBase
|
||||
{
|
||||
protected:
|
||||
CheckVector<long> Ci;
|
||||
CheckVector<Integer> Ci;
|
||||
|
||||
ofstream public_output;
|
||||
ofstream binary_output;
|
||||
@@ -162,13 +164,13 @@ public:
|
||||
return thread_num;
|
||||
}
|
||||
|
||||
const long& read_Ci(size_t i) const
|
||||
{ return Ci[i]; }
|
||||
long& get_Ci_ref(size_t i)
|
||||
long read_Ci(size_t i) const
|
||||
{ return Ci[i].get(); }
|
||||
Integer& get_Ci_ref(size_t i)
|
||||
{ return Ci[i]; }
|
||||
void write_Ci(size_t i, const long& x)
|
||||
{ Ci[i]=x; }
|
||||
CheckVector<long>& get_Ci()
|
||||
CheckVector<Integer>& get_Ci()
|
||||
{ return Ci; }
|
||||
|
||||
virtual ofstream& get_public_output()
|
||||
|
||||
@@ -379,9 +379,20 @@ void Processor<sint, sgf2n>::read_socket_private(int client_id,
|
||||
client_timer.stop();
|
||||
client_stats.add(socket_stream.get_length());
|
||||
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int i = 0; i < m; i++)
|
||||
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
|
||||
int j, i;
|
||||
try
|
||||
{
|
||||
for (j = 0; j < size; j++)
|
||||
for (i = 0; i < m; i++)
|
||||
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
throw insufficient_shares(m * size, j * m + i, e);
|
||||
}
|
||||
|
||||
if (socket_stream.left())
|
||||
throw runtime_error("unexpected share data");
|
||||
}
|
||||
|
||||
|
||||
@@ -468,28 +479,29 @@ void SubProcessor<T>::POpen(const Instruction& inst)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::muls(const vector<int>& reg, int size)
|
||||
void SubProcessor<T>::muls(const vector<int>& reg)
|
||||
{
|
||||
assert(reg.size() % 3 == 0);
|
||||
int n = reg.size() / 3;
|
||||
assert(reg.size() % 4 == 0);
|
||||
int n = reg.size() / 4;
|
||||
|
||||
SubProcessor<T>& proc = *this;
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < n; i++)
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int j = 0; j < reg[4 * i]; j++)
|
||||
{
|
||||
auto& x = proc.S[reg[3 * i + 1] + j];
|
||||
auto& y = proc.S[reg[3 * i + 2] + j];
|
||||
auto& x = proc.S[reg[4 * i + 2] + j];
|
||||
auto& y = proc.S[reg[4 * i + 3] + j];
|
||||
protocol.prepare_mul(x, y);
|
||||
}
|
||||
protocol.exchange();
|
||||
for (int i = 0; i < n; i++)
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
for (int j = 0; j < reg[4 * i]; j++)
|
||||
{
|
||||
proc.S[reg[3 * i] + j] = protocol.finalize_mul();
|
||||
proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul();
|
||||
}
|
||||
|
||||
protocol.counter += n * size;
|
||||
protocol.counter += n * reg[4 * i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -553,33 +565,46 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
const Instruction& instruction)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
auto A = source.begin() + a;
|
||||
auto B = source.begin() + b;
|
||||
auto C = S.begin() + (instruction.get_r(0));
|
||||
assert(A + dim[0] * dim[1] <= source.end());
|
||||
assert(B + dim[1] * dim[2] <= source.end());
|
||||
assert(C + dim[0] * dim[2] <= S.end());
|
||||
|
||||
protocol.init_dotprod();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
protocol.prepare_dotprod(*(A + i * dim[1] + k),
|
||||
*(B + k * dim[2] + j));
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
|
||||
auto& start = instruction.get_start();
|
||||
assert(start.size() % 6 == 0);
|
||||
|
||||
for(auto it = start.begin(); it < start.end(); it += 6)
|
||||
{
|
||||
auto dim = it + 3;
|
||||
auto A = source.begin() + *(it + 1);
|
||||
auto B = source.begin() + *(it + 2);
|
||||
assert(A + dim[0] * dim[1] <= source.end());
|
||||
assert(B + dim[1] * dim[2] <= source.end());
|
||||
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
protocol.prepare_dotprod(*(A + i * dim[1] + k),
|
||||
*(B + k * dim[2] + j));
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
}
|
||||
|
||||
protocol.exchange();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
|
||||
|
||||
for(auto it = start.begin(); it < start.end(); it += 6)
|
||||
{
|
||||
auto C = S.begin() + *it;
|
||||
auto dim = it + 3;
|
||||
assert(C + dim[0] * dim[2] <= S.end());
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
@@ -592,7 +617,7 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
protocol.init_dotprod();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
{
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i);
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i).get();
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
#ifdef DEBUG_MATMULSM
|
||||
@@ -628,16 +653,21 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm_prep(int ii, int j, const CheckVector<T>& source,
|
||||
void SubProcessor<T>::matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b)
|
||||
{
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j);
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j).get();
|
||||
const T* base = source.data();
|
||||
size_t size = source.size();
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
{
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k);
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k);
|
||||
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
|
||||
source.at(b + ll * dim[8] + jj));
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k).get();
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k).get();
|
||||
auto aa = a + ii * dim[7] + kk;
|
||||
auto bb = b + ll * dim[8] + jj;
|
||||
assert(aa < size);
|
||||
assert(bb < size);
|
||||
protocol.prepare_dotprod(base[aa], base[bb]);
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
@@ -655,16 +685,22 @@ void SubProcessor<T>::matmulsm_finalize(int i, int j, const vector<int>& dim,
|
||||
template<class T>
|
||||
void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
auto& args = instruction.get_start();
|
||||
vector<Conv2dTuple> tuples;
|
||||
for (size_t i = 0; i < args.size(); i += 15)
|
||||
tuples.push_back(Conv2dTuple(args, i));
|
||||
for (auto& tuple : tuples)
|
||||
tuple.pre(S, protocol);
|
||||
protocol.exchange();
|
||||
for (auto& tuple : tuples)
|
||||
tuple.post(S, protocol);
|
||||
size_t done = 0;
|
||||
while (done < tuples.size())
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
size_t i;
|
||||
for (i = done; i < tuples.size() and protocol.get_buffer_size() <
|
||||
OnlineOptions::singleton.batch_size; i++)
|
||||
tuples[i].pre(S, protocol);
|
||||
protocol.exchange();
|
||||
for (; done < i; done++)
|
||||
tuples[done].post(S, protocol);
|
||||
}
|
||||
}
|
||||
|
||||
inline
|
||||
@@ -766,25 +802,22 @@ void SubProcessor<T>::secure_shuffle(const Instruction& instruction)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction)
|
||||
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store)
|
||||
{
|
||||
return shuffler.generate(instruction.get_n());
|
||||
return shuffler.generate(instruction.get_n(), shuffle_store);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle)
|
||||
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle,
|
||||
ShuffleStore& shuffle_store)
|
||||
{
|
||||
shuffler.apply(S, instruction.get_size(), instruction.get_start()[2],
|
||||
instruction.get_start()[0], instruction.get_start()[1], handle,
|
||||
instruction.get_start()[0], instruction.get_start()[1],
|
||||
shuffle_store.get(handle),
|
||||
instruction.get_start()[4]);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::delete_shuffle(int handle)
|
||||
{
|
||||
shuffler.del(handle);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::inverse_permutation(const Instruction& instruction) {
|
||||
shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0],
|
||||
@@ -796,17 +829,26 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
|
||||
{
|
||||
input.reset_all(P);
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
if (args[i + 1] == P.my_num())
|
||||
{
|
||||
if (args[i + 1] == P.my_num())
|
||||
input.add_mine(C[args[i + 3] + j]);
|
||||
else
|
||||
input.add_other(args[i + 1]);
|
||||
auto begin = C.begin() + args[i + 3];
|
||||
auto end = begin + args[i];
|
||||
assert(end <= C.end());
|
||||
for (auto it = begin; it < end; it++)
|
||||
input.add_mine(*it);
|
||||
}
|
||||
else
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
input.add_other(args[i + 1]);
|
||||
input.exchange();
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
S[args[i + 2] + j] = input.finalize(args[i + 1]);
|
||||
{
|
||||
auto begin = S.begin() + args[i + 2];
|
||||
auto end = begin + args[i];
|
||||
assert(end <= S.end());
|
||||
for (auto it = begin; it < end; it++)
|
||||
*it = input.finalize(args[i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -858,6 +900,16 @@ typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
|
||||
return inverses2m[m];
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
void fixinput_int(T& proc, const Instruction& instruction, U)
|
||||
{
|
||||
U* x = new U[instruction.get_size()];
|
||||
proc.binary_input.read((char*) x, sizeof(U) * instruction.get_size());
|
||||
for (int i = 0; i < instruction.get_size(); i++)
|
||||
proc.write_Cp(instruction.get_r(0) + i, x[i]);
|
||||
delete[] x;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
{
|
||||
@@ -878,19 +930,24 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
throw runtime_error("unknown format for fixed-point input");
|
||||
}
|
||||
|
||||
for (int i = 0; i < instruction.get_size(); i++)
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
|
||||
if (binary_input.peek() == EOF)
|
||||
throw IO_Error("not enough inputs in " + binary_input_filename);
|
||||
|
||||
if (instruction.get_r(2) == 0)
|
||||
{
|
||||
if (binary_input.peek() == EOF)
|
||||
throw IO_Error("not enough inputs in " + binary_input_filename);
|
||||
double buf;
|
||||
if (instruction.get_r(2) == 0)
|
||||
{
|
||||
int64_t x;
|
||||
binary_input.read((char*) &x, sizeof(x));
|
||||
tmp = x;
|
||||
}
|
||||
if (instruction.get_r(1) == 1)
|
||||
fixinput_int(*this, instruction, int8_t());
|
||||
else
|
||||
fixinput_int(*this, instruction, int64_t());
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < instruction.get_size(); i++)
|
||||
{
|
||||
double buf;
|
||||
if (use_double)
|
||||
binary_input.read((char*) &buf, sizeof(double));
|
||||
else
|
||||
@@ -900,11 +957,12 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
buf = x;
|
||||
}
|
||||
tmp = bigint::tmp = round(buf * exp2(instruction.get_r(1)));
|
||||
write_Cp(instruction.get_r(0) + i, tmp);
|
||||
}
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
write_Cp(instruction.get_r(0) + i, tmp);
|
||||
}
|
||||
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,12 @@ using namespace std;
|
||||
#include "Tools/ExecutionStats.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
#include "OnlineOptions.h"
|
||||
#include "Math/Integer.h"
|
||||
|
||||
class ProcessorBase
|
||||
{
|
||||
// Stack
|
||||
stack<long> stacki;
|
||||
stack<Integer> stacki;
|
||||
|
||||
ifstream input_file;
|
||||
string input_filename;
|
||||
@@ -26,7 +27,7 @@ class ProcessorBase
|
||||
|
||||
protected:
|
||||
// Optional argument to tape
|
||||
int arg;
|
||||
Integer arg;
|
||||
|
||||
string get_parameterized_filename(int my_num, int thread_num,
|
||||
const string& prefix);
|
||||
@@ -38,15 +39,15 @@ public:
|
||||
|
||||
ProcessorBase();
|
||||
|
||||
void pushi(long x) { stacki.push(x); }
|
||||
void popi(long& x) { x = stacki.top(); stacki.pop(); }
|
||||
void pushi(Integer x) { stacki.push(x); }
|
||||
void popi(Integer& x) { x = stacki.top(); stacki.pop(); }
|
||||
|
||||
int get_arg() const
|
||||
Integer get_arg() const
|
||||
{
|
||||
return arg;
|
||||
}
|
||||
|
||||
void set_arg(int new_arg)
|
||||
void set_arg(Integer new_arg)
|
||||
{
|
||||
arg=new_arg;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ template<template<int L> class U, template<class T> class V, class W>
|
||||
RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
|
||||
ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers)
|
||||
{
|
||||
assert(nplayers or U<64>::variable_players);
|
||||
RingOptions opts(opt, argc, argv);
|
||||
W machine(argc, argv, opt, online_opts, gf2n(), nplayers);
|
||||
int R = opts.ring_size_from_opts_or_schedule(online_opts.progname);
|
||||
@@ -65,7 +66,7 @@ template<template<int K, int S> class U, template<class T> class V>
|
||||
HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecurity(
|
||||
int argc, const char** argv, ez::ezOptionParser& opt)
|
||||
{
|
||||
OnlineOptions online_opts(opt, argc, argv);
|
||||
OnlineOptions online_opts(opt, argc, argv, U<64, 40>());
|
||||
RingOptions opts(opt, argc, argv);
|
||||
HonestMajorityMachine machine(argc, argv, opt, online_opts);
|
||||
int R = opts.ring_size_from_opts_or_schedule(online_opts.progname);
|
||||
|
||||
@@ -18,10 +18,10 @@
|
||||
*dest++ = *source++) \
|
||||
X(STMS, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.machine.Mp.MS[n], \
|
||||
*dest++ = *source++) \
|
||||
X(LDMSI, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = Proc.machine.Mp.read_S(*source++)) \
|
||||
X(STMSI, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
|
||||
Proc.machine.Mp.write_S(*dest++, *source++)) \
|
||||
X(LDMSI, Proc.machine.Mp.MS.indirect_read(instruction, Procp.get_S(), Proc.get_Ci()),) \
|
||||
X(STMSI, Proc.machine.Mp.MS.indirect_write(instruction, Procp.get_S(), Proc.get_Ci()),) \
|
||||
X(LDMCI, Proc.machine.Mp.MC.indirect_read(instruction, Procp.get_C(), Proc.get_Ci()),) \
|
||||
X(STMCI, Proc.machine.Mp.MC.indirect_write(instruction, Procp.get_C(), Proc.get_Ci()),) \
|
||||
X(MOVS, auto dest = &Procp.get_S()[r[0]]; auto source = &Procp.get_S()[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(ADDS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
@@ -121,10 +121,8 @@
|
||||
*dest++ = *source++) \
|
||||
X(GSTMS, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.machine.M2.MS[n], \
|
||||
*dest++ = *source++) \
|
||||
X(GLDMSI, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = Proc.machine.M2.read_S(*source++)) \
|
||||
X(GSTMSI, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
|
||||
Proc.machine.M2.write_S(*dest++, *source++)) \
|
||||
X(GLDMSI, Proc.machine.M2.MS.indirect_read(instruction, Proc2.get_S(), Proc.get_Ci()),) \
|
||||
X(GSTMSI, Proc.machine.M2.MS.indirect_write(instruction, Proc2.get_S(), Proc.get_Ci()),) \
|
||||
X(GMOVS, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc2.get_S()[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(GADDS, auto dest = &Proc2.get_S()[r[0]]; auto op1 = &Proc2.get_S()[r[1]]; \
|
||||
@@ -171,10 +169,8 @@
|
||||
*dest++ = (*source).get(); source++) \
|
||||
X(STMINT, auto dest = &Mi[n]; auto source = &Proc.get_Ci()[r[0]], \
|
||||
*dest++ = *source++) \
|
||||
X(LDMINTI, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \
|
||||
*dest++ = Mi[*source].get(); source++) \
|
||||
X(STMINTI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &Ci[r[0]], \
|
||||
Mi[*dest] = *source++; dest++) \
|
||||
X(LDMINTI, Mi.indirect_read(*this, Proc.get_Ci(), Proc.get_Ci()),) \
|
||||
X(STMINTI, Mi.indirect_write(*this, Proc.get_Ci(), Proc.get_Ci()),) \
|
||||
X(MOVINT, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(PUSHINT, Proc.pushi(Ci[r[0]]),) \
|
||||
@@ -213,7 +209,7 @@
|
||||
X(SHUFFLE, shuffle(Proc),) \
|
||||
X(BITDECINT, bitdecint(Proc),) \
|
||||
X(RAND, auto dest = &Ci[r[0]]; auto source = &Ci[r[1]], \
|
||||
*dest++ = Proc.shared_prng.get_uint() % (1 << *source++)) \
|
||||
*dest++ = Proc.shared_prng.get_uint() % (1 << (*source++).get())) \
|
||||
|
||||
#define CLEAR_GF2N_INSTRUCTIONS \
|
||||
X(GLDI, auto dest = &C2[r[0]]; cgf2n tmp = int(n), \
|
||||
@@ -222,10 +218,8 @@
|
||||
*dest++ = (*source).get(); source++) \
|
||||
X(GSTMC, auto dest = &M2C[n]; auto source = &C2[r[0]], \
|
||||
*dest++ = *source++) \
|
||||
X(GLDMCI, auto dest = &C2[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = M2C[*source++]) \
|
||||
X(GSTMCI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &C2[r[0]], \
|
||||
M2C[*dest++] = *source++) \
|
||||
X(GLDMCI, M2C.indirect_read(*this, C2, Proc.get_Ci()),) \
|
||||
X(GSTMCI, M2C.indirect_write(*this, C2, Proc.get_Ci()),) \
|
||||
X(GMOVC, auto dest = &C2[r[0]]; auto source = &C2[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(GADDC, auto dest = &C2[r[0]]; auto op1 = &C2[r[1]]; \
|
||||
@@ -288,9 +282,7 @@
|
||||
#define REMAINING_INSTRUCTIONS \
|
||||
X(CONVMODP, throw not_implemented(),) \
|
||||
X(LDMC, throw not_implemented(),) \
|
||||
X(LDMCI, throw not_implemented(),) \
|
||||
X(STMC, throw not_implemented(),) \
|
||||
X(STMCI, throw not_implemented(),) \
|
||||
X(MOVC, throw not_implemented(),) \
|
||||
X(DIVC, throw not_implemented(),) \
|
||||
X(GDIVC, throw not_implemented(),) \
|
||||
@@ -390,6 +382,8 @@
|
||||
X(APPLYSHUFFLE, throw not_implemented(),) \
|
||||
X(DELSHUFFLE, throw not_implemented(),) \
|
||||
X(ACTIVE, throw not_implemented(),) \
|
||||
X(FIXINPUT, throw not_implemented(),) \
|
||||
X(CONCATS, throw not_implemented(),) \
|
||||
|
||||
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
|
||||
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS
|
||||
|
||||
@@ -51,7 +51,7 @@ except:
|
||||
pass
|
||||
|
||||
layers = [
|
||||
ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [n_examples, 24, 24, 16],
|
||||
ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [N, 24, 24, 16],
|
||||
(1, 1), 'VALID'),
|
||||
ml.MaxPool([N, 24, 24, 16]),
|
||||
ml.Relu([N, 12, 12, 16]),
|
||||
|
||||
34
Programs/Source/multinode_example_main.py
Normal file
34
Programs/Source/multinode_example_main.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import random
|
||||
|
||||
n_nodes_per_party = int(program.args[1])
|
||||
n_threads_per_node = int(program.args[2])
|
||||
n_ops_per_thread = int(program.args[3])
|
||||
|
||||
n_ops_per_node = n_threads_per_node * n_ops_per_thread
|
||||
n_ops = n_nodes_per_party * n_ops_per_node
|
||||
data = Array.create_from(sint(regint.inc(n_ops)))
|
||||
|
||||
listen_for_clients(15000)
|
||||
|
||||
ready = regint.Array(n_nodes_per_party)
|
||||
|
||||
@for_range(n_nodes_per_party)
|
||||
def _(i):
|
||||
ready[accept_client_connection(15000)] = 1
|
||||
|
||||
runtime_error_if(sum(ready) != n_nodes_per_party, 'connection problems')
|
||||
|
||||
@for_range(n_nodes_per_party)
|
||||
def _(i):
|
||||
data.get_vector(base=i * n_ops_per_node,
|
||||
size=n_ops_per_node).write_fully_to_socket(i)
|
||||
|
||||
@for_range(n_nodes_per_party)
|
||||
def _(i):
|
||||
data.assign_vector(sint.read_from_socket(i, size=n_ops_per_node),
|
||||
base=i * n_ops_per_node)
|
||||
|
||||
for i in range(10):
|
||||
index = random.randrange(n_ops)
|
||||
value = data[index].reveal()
|
||||
runtime_error_if(value != index ** 2, '%s != %s', value, index ** 2)
|
||||
21
Programs/Source/multinode_example_worker.py
Normal file
21
Programs/Source/multinode_example_worker.py
Normal file
@@ -0,0 +1,21 @@
|
||||
n_threads = int(program.args[1])
|
||||
n_ops_per_thread = int(program.args[2])
|
||||
worker_id = int(program.args[3])
|
||||
|
||||
if len(program.args) > 4:
|
||||
host = program.args[4]
|
||||
else:
|
||||
host = 'localhost'
|
||||
|
||||
n_ops = n_threads * n_ops_per_thread
|
||||
data = sint.Array(n_ops)
|
||||
|
||||
main = init_client_connection(host, 15000, worker_id)
|
||||
|
||||
data.read_from_socket(main)
|
||||
|
||||
@for_range_opt_multithread(n_threads, n_ops)
|
||||
def _(i):
|
||||
data[i] = data[i] ** 2
|
||||
|
||||
data.write_to_socket(main)
|
||||
@@ -1,4 +1,5 @@
|
||||
# sint: secret integers
|
||||
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint
|
||||
|
||||
# you can assign public numbers to sint
|
||||
|
||||
@@ -14,6 +15,7 @@ def test(actual, expected):
|
||||
|
||||
# private inputs are read from Player-Data/Input-P<i>-0
|
||||
# or from standard input if using command-line option -I
|
||||
# see https://mp-spdz.readthedocs.io/en/latest/io.html for more options
|
||||
|
||||
for i in 0, 1:
|
||||
print_ln('got %s from player %s', sint.get_input_from(i).reveal(), i)
|
||||
@@ -62,6 +64,7 @@ test(a[99], 99 * 98)
|
||||
# test(a, 99)
|
||||
|
||||
# sfix: fixed-point numbers
|
||||
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix
|
||||
|
||||
# set the precision after the dot and in total
|
||||
|
||||
|
||||
@@ -35,6 +35,11 @@ public:
|
||||
typedef GC::AtlasSecret bit_type;
|
||||
#endif
|
||||
|
||||
static string alt()
|
||||
{
|
||||
return "";
|
||||
}
|
||||
|
||||
AtlasShare()
|
||||
{
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define PROTOCOLS_FAKEPROTOCOL_H_
|
||||
|
||||
#include "Replicated.h"
|
||||
#include "SecureShuffle.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Processor/Instruction.h"
|
||||
#include "Processor/TruncPrTuple.h"
|
||||
@@ -17,6 +18,8 @@ template<class T>
|
||||
class FakeShuffle
|
||||
{
|
||||
public:
|
||||
typedef ShuffleStore<int> store_type;
|
||||
|
||||
FakeShuffle(SubProcessor<T>&)
|
||||
{
|
||||
}
|
||||
@@ -27,9 +30,9 @@ public:
|
||||
apply(a, n, unit_size, output_base, input_base, 0, 0);
|
||||
}
|
||||
|
||||
size_t generate(size_t)
|
||||
size_t generate(size_t, store_type& store)
|
||||
{
|
||||
return 0;
|
||||
return store.add();
|
||||
}
|
||||
|
||||
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
|
||||
@@ -49,10 +52,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void del(size_t)
|
||||
{
|
||||
}
|
||||
|
||||
void inverse_permutation(vector<T>&, size_t, size_t, size_t)
|
||||
{
|
||||
}
|
||||
@@ -280,6 +279,19 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (tag == string("EQZ\0", 4))
|
||||
{
|
||||
for (size_t i = 0; i < args.size(); i += args[i])
|
||||
{
|
||||
assert(i + args[i] <= args.size());
|
||||
assert(args[i] == 6);
|
||||
for (int j = 0; j < args[i + 1]; j++)
|
||||
{
|
||||
auto& res = processor.get_S()[args[i + 2] + j];
|
||||
res = processor.get_S()[args[i + 3] + j] == 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (tag == "Trun")
|
||||
{
|
||||
for (size_t i = 0; i < args.size(); i += args[i])
|
||||
|
||||
@@ -35,6 +35,7 @@ public:
|
||||
static const bool dishonest_majority = false;
|
||||
static const bool malicious = false;
|
||||
static const bool is_real = false;
|
||||
static const bool variable_players = false;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
|
||||
@@ -33,7 +33,7 @@ public:
|
||||
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
|
||||
SubProcessor<T>& processor);
|
||||
|
||||
void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
|
||||
void matmulsm(SubProcessor<T>& processor, MemoryPart<T>& source,
|
||||
const Instruction& instruction, int a, int b);
|
||||
void conv2ds(SubProcessor<T>& processor, const Instruction& instruction);
|
||||
};
|
||||
|
||||
@@ -33,7 +33,7 @@ typename T::MatrixPrep& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Hemi<T>::matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
|
||||
void Hemi<T>::matmulsm(SubProcessor<T>& processor, MemoryPart<T>& source,
|
||||
const Instruction& instruction, int a, int b)
|
||||
{
|
||||
if (HemiOptions::singleton.plain_matmul
|
||||
@@ -61,16 +61,16 @@ void Hemi<T>::matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
{
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k);
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i);
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k).get();
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i).get();
|
||||
A.entries.v.push_back(source.at(a + ii * dim[7] + kk));
|
||||
}
|
||||
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j);
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k);
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j).get();
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k).get();
|
||||
B.entries.v.push_back(source.at(b + ll * dim[8] + jj));
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,11 @@ public:
|
||||
static const bool local_mul = true;
|
||||
static true_type triple_matmul;
|
||||
|
||||
static string alt()
|
||||
{
|
||||
return "Temi";
|
||||
}
|
||||
|
||||
HemiShare()
|
||||
{
|
||||
}
|
||||
|
||||
@@ -250,8 +250,11 @@ TreeSum<T>::~TreeSum()
|
||||
template<class T>
|
||||
void TreeSum<T>::run(vector<T>& values, const Player& P)
|
||||
{
|
||||
start(values, P);
|
||||
finish(values, P);
|
||||
if (not values.empty())
|
||||
{
|
||||
start(values, P);
|
||||
finish(values, P);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -300,9 +303,11 @@ void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
|
||||
P.wait_receive(sender, oss[j]);
|
||||
MC.player_timers[sender].stop();
|
||||
MC.timers[SUM].start();
|
||||
T tmp = values.at(0);
|
||||
for (unsigned int i=0; i<values.size(); i++)
|
||||
{
|
||||
values[i].add(oss[j], use_lengths ? lengths[i] : -1);
|
||||
tmp.unpack(oss[j], use_lengths ? lengths[i] : -1);
|
||||
values[i] += tmp;
|
||||
}
|
||||
post_add_process(values);
|
||||
MC.timers[SUM].stop();
|
||||
|
||||
@@ -176,7 +176,7 @@ void MAC_Check_<U>::Check(const Player& P)
|
||||
for (auto& os : bundle)
|
||||
if (&os != &bundle.mine)
|
||||
delta += os.get<typename U::mac_type>();
|
||||
if (not delta.is_zero())
|
||||
if (delta != 0)
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
@@ -194,8 +194,6 @@ void MAC_Check_<U>::Check(const Player& P)
|
||||
typename U::mac_type a,gami,temp;
|
||||
typename U::mac_type::Scalar h;
|
||||
vector<typename U::mac_type> tau(P.num_players());
|
||||
a.assign_zero();
|
||||
gami.assign_zero();
|
||||
for (int i=0; i<popen_cnt; i++)
|
||||
{
|
||||
h.almost_randomize(G);
|
||||
@@ -217,10 +215,10 @@ void MAC_Check_<U>::Check(const Player& P)
|
||||
//cerr << "\tFinal Check" << endl;
|
||||
|
||||
typename U::mac_type t;
|
||||
t.assign_zero();
|
||||
for (int i=0; i<P.num_players(); i++)
|
||||
{ t += tau[i]; }
|
||||
if (!t.is_zero()) { throw mac_fail(); }
|
||||
if (t != 0)
|
||||
throw mac_fail();
|
||||
}
|
||||
|
||||
vals.erase(vals.begin(), vals.begin() + popen_cnt);
|
||||
|
||||
@@ -136,6 +136,12 @@ TripleShuffleSacrifice<T>::TripleShuffleSacrifice(int B, int C) :
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TripleShuffleSacrifice<T>::TripleShuffleSacrifice(DataFieldType type) :
|
||||
ShuffleSacrifice(BaseMachine::bucket_size(type))
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TripleShuffleSacrifice<T>::triple_sacrifice(vector<array<T, 3>>& triples,
|
||||
vector<array<T, 3>>& check_triples, Player& P,
|
||||
|
||||
@@ -44,6 +44,10 @@ public:
|
||||
// default private output facility (using input tuples)
|
||||
typedef ::PrivateOutput<NoShare> PrivateOutput;
|
||||
|
||||
// indicate whether protocol allows dishonest majority and variable players
|
||||
static const bool dishonest_majority = true;
|
||||
static const bool variable_players = true;
|
||||
|
||||
// description used for debugging output
|
||||
static string type_string()
|
||||
{
|
||||
@@ -187,4 +191,11 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
inline ostream& operator<<(ostream& o, NoShare<T>)
|
||||
{
|
||||
throw runtime_error("no output");
|
||||
return o;
|
||||
}
|
||||
|
||||
#endif /* PROTOCOLS_NOSHARE_H_ */
|
||||
|
||||
@@ -6,12 +6,17 @@
|
||||
#ifndef PROTOCOLS_REP3SHUFFLER_H_
|
||||
#define PROTOCOLS_REP3SHUFFLER_H_
|
||||
|
||||
#include "SecureShuffle.h"
|
||||
|
||||
template<class T>
|
||||
class Rep3Shuffler
|
||||
{
|
||||
SubProcessor<T>& proc;
|
||||
public:
|
||||
typedef array<vector<int>, 2> shuffle_type;
|
||||
typedef ShuffleStore<shuffle_type> store_type;
|
||||
|
||||
vector<array<vector<int>, 2>> shuffles;
|
||||
private:
|
||||
SubProcessor<T>& proc;
|
||||
|
||||
public:
|
||||
Rep3Shuffler(vector<T>& a, size_t n, int unit_size, size_t output_base,
|
||||
@@ -19,15 +24,13 @@ public:
|
||||
|
||||
Rep3Shuffler(SubProcessor<T>& proc);
|
||||
|
||||
int generate(int n_shuffle);
|
||||
int generate(int n_shuffle, store_type& store);
|
||||
|
||||
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
|
||||
size_t input_base, int handle, bool reverse);
|
||||
size_t input_base, shuffle_type& shuffle, bool reverse);
|
||||
|
||||
void inverse_permutation(vector<T>& stack, size_t n, size_t output_base,
|
||||
size_t input_base);
|
||||
|
||||
void del(int handle);
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_REP3SHUFFLER_H_ */
|
||||
|
||||
@@ -13,9 +13,10 @@ Rep3Shuffler<T>::Rep3Shuffler(vector<T>& a, size_t n, int unit_size,
|
||||
size_t output_base, size_t input_base, SubProcessor<T>& proc) :
|
||||
proc(proc)
|
||||
{
|
||||
apply(a, n, unit_size, output_base, input_base, generate(n / unit_size),
|
||||
store_type store;
|
||||
int handle = generate(n / unit_size, store);
|
||||
apply(a, n, unit_size, output_base, input_base, store.get(handle),
|
||||
false);
|
||||
shuffles.pop_back();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -25,10 +26,10 @@ Rep3Shuffler<T>::Rep3Shuffler(SubProcessor<T>& proc) :
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int Rep3Shuffler<T>::generate(int n_shuffle)
|
||||
int Rep3Shuffler<T>::generate(int n_shuffle, store_type& store)
|
||||
{
|
||||
shuffles.push_back({});
|
||||
auto& shuffle = shuffles.back();
|
||||
int res = store.add();
|
||||
auto& shuffle = store.get(res);
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
auto& perm = shuffle[i];
|
||||
@@ -40,19 +41,22 @@ int Rep3Shuffler<T>::generate(int n_shuffle)
|
||||
swap(perm[k], perm[k + j]);
|
||||
}
|
||||
}
|
||||
return shuffles.size() - 1;
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Rep3Shuffler<T>::apply(vector<T>& a, size_t n, int unit_size,
|
||||
size_t output_base, size_t input_base, int handle, bool reverse)
|
||||
size_t output_base, size_t input_base, shuffle_type& shuffle,
|
||||
bool reverse)
|
||||
{
|
||||
assert(proc.P.num_players() == 3);
|
||||
assert(not T::malicious);
|
||||
assert(not T::dishonest_majority);
|
||||
assert(n % unit_size == 0);
|
||||
|
||||
auto& shuffle = shuffles.at(handle);
|
||||
if (shuffle.empty())
|
||||
throw runtime_error("shuffle has been deleted");
|
||||
|
||||
vector<T> to_shuffle;
|
||||
for (size_t i = 0; i < n; i++)
|
||||
to_shuffle.push_back(a[input_base + i]);
|
||||
@@ -115,12 +119,6 @@ void Rep3Shuffler<T>::apply(vector<T>& a, size_t n, int unit_size,
|
||||
a[output_base + i] = to_shuffle[i];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Rep3Shuffler<T>::del(int handle)
|
||||
{
|
||||
shuffles.at(handle) = {};
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Rep3Shuffler<T>::inverse_permutation(vector<T>&, size_t, size_t, size_t)
|
||||
{
|
||||
|
||||
@@ -41,6 +41,7 @@ public:
|
||||
typedef GC::Rep4Secret bit_type;
|
||||
|
||||
static const bool malicious = true;
|
||||
static const bool variable_players = false;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
|
||||
@@ -15,6 +15,7 @@ using namespace std;
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/PointerVector.h"
|
||||
#include "Networking/Player.h"
|
||||
#include "Processor/Memory.h"
|
||||
|
||||
template<class T> class SubProcessor;
|
||||
template<class T> class ReplicatedMC;
|
||||
@@ -68,8 +69,6 @@ public:
|
||||
ProtocolBase();
|
||||
virtual ~ProtocolBase();
|
||||
|
||||
void muls(const vector<int>& reg, SubProcessor<T>& proc, typename T::MAC_Check& MC,
|
||||
int size);
|
||||
void mulrs(const vector<int>& reg, SubProcessor<T>& proc);
|
||||
|
||||
void multiply(vector<T>& products, vector<pair<T, T>>& multiplicands,
|
||||
@@ -111,7 +110,7 @@ public:
|
||||
virtual void randoms_inst(vector<T>&, const Instruction&);
|
||||
|
||||
template<int = 0>
|
||||
void matmulsm(SubProcessor<T> & proc, CheckVector<T>& source,
|
||||
void matmulsm(SubProcessor<T> & proc, MemoryPart<T>& source,
|
||||
const Instruction& instruction, int a, int b)
|
||||
{ proc.matmulsm(source, instruction, a, b); }
|
||||
|
||||
|
||||
@@ -78,14 +78,6 @@ ProtocolBase<T>::~ProtocolBase()
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ProtocolBase<T>::muls(const vector<int>& reg,
|
||||
SubProcessor<T>& proc, typename T::MAC_Check& MC, int size)
|
||||
{
|
||||
(void)MC;
|
||||
proc.muls(reg, size);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ProtocolBase<T>::mulrs(const vector<int>& reg,
|
||||
SubProcessor<T>& proc)
|
||||
|
||||
@@ -9,18 +9,42 @@
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/Lock.h"
|
||||
|
||||
template<class T> class SubProcessor;
|
||||
|
||||
template<class T>
|
||||
class ShuffleStore
|
||||
{
|
||||
typedef T shuffle_type;
|
||||
|
||||
deque<shuffle_type> shuffles;
|
||||
|
||||
Lock store_lock;
|
||||
|
||||
void lock();
|
||||
void unlock();
|
||||
|
||||
public:
|
||||
int add();
|
||||
shuffle_type& get(int handle);
|
||||
void del(int handle);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class SecureShuffle
|
||||
{
|
||||
public:
|
||||
typedef vector<vector<vector<T>>> shuffle_type;
|
||||
typedef ShuffleStore<shuffle_type> store_type;
|
||||
|
||||
private:
|
||||
SubProcessor<T>& proc;
|
||||
vector<T> to_shuffle;
|
||||
vector<vector<T>> config;
|
||||
vector<T> tmp;
|
||||
int unit_size;
|
||||
|
||||
vector<vector<vector<vector<T>>>> shuffles;
|
||||
size_t n_shuffle;
|
||||
bool exact;
|
||||
|
||||
@@ -62,7 +86,7 @@ public:
|
||||
|
||||
SecureShuffle(SubProcessor<T>& proc);
|
||||
|
||||
int generate(int n_shuffle);
|
||||
int generate(int n_shuffle, store_type& store);
|
||||
|
||||
/**
|
||||
*
|
||||
@@ -73,12 +97,12 @@ public:
|
||||
* would result in [3,4,1,2]
|
||||
* @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to)
|
||||
* @param input_base The starting address of the input vector (i.e. the location from which to read the permutation)
|
||||
* @param handle The integer identifying the preconfigured waksman network (shuffle) to use. Such a handle can be obtained from calling
|
||||
* @param shuffle The preconfigured waksman network (shuffle) to use
|
||||
* @param reverse Boolean indicating whether to apply the inverse of the permutation
|
||||
* @see SecureShuffle::generate for obtaining a shuffle handle
|
||||
*/
|
||||
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
|
||||
size_t input_base, int handle, bool reverse);
|
||||
size_t input_base, shuffle_type& shuffle, bool reverse);
|
||||
|
||||
/**
|
||||
* Calculate the secret inverse permutation of stack given secret permutation.
|
||||
@@ -94,8 +118,6 @@ public:
|
||||
* @param input_base The starting address of the input vector (i.e. the location from which to read the permutation)
|
||||
*/
|
||||
void inverse_permutation(vector<T>& stack, size_t n, size_t output_base, size_t input_base);
|
||||
|
||||
void del(int handle);
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_SECURESHUFFLE_H_ */
|
||||
|
||||
@@ -12,6 +12,45 @@
|
||||
#include <math.h>
|
||||
#include <algorithm>
|
||||
|
||||
template<class T>
|
||||
void ShuffleStore<T>::lock()
|
||||
{
|
||||
store_lock.lock();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ShuffleStore<T>::unlock()
|
||||
{
|
||||
store_lock.unlock();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int ShuffleStore<T>::add()
|
||||
{
|
||||
lock();
|
||||
int res = shuffles.size();
|
||||
shuffles.push_back({});
|
||||
unlock();
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
typename ShuffleStore<T>::shuffle_type& ShuffleStore<T>::get(int handle)
|
||||
{
|
||||
lock();
|
||||
auto& res = shuffles.at(handle);
|
||||
unlock();
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ShuffleStore<T>::del(int handle)
|
||||
{
|
||||
lock();
|
||||
shuffles.at(handle) = {};
|
||||
unlock();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
SecureShuffle<T>::SecureShuffle(SubProcessor<T>& proc) :
|
||||
proc(proc), unit_size(0), n_shuffle(0), exact(false)
|
||||
@@ -33,13 +72,12 @@ SecureShuffle<T>::SecureShuffle(vector<T>& a, size_t n, int unit_size,
|
||||
|
||||
template<class T>
|
||||
void SecureShuffle<T>::apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
|
||||
size_t input_base, int handle, bool reverse)
|
||||
size_t input_base, shuffle_type& shuffle, bool reverse)
|
||||
{
|
||||
this->unit_size = unit_size;
|
||||
|
||||
pre(a, n, input_base);
|
||||
|
||||
auto& shuffle = shuffles.at(handle);
|
||||
assert(shuffle.size() == proc.protocol.get_relevant_players().size());
|
||||
|
||||
if (reverse)
|
||||
@@ -134,12 +172,6 @@ void SecureShuffle<T>::inverse_permutation(vector<T> &stack, size_t n, size_t ou
|
||||
post(stack, n, output_base);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SecureShuffle<T>::del(int handle)
|
||||
{
|
||||
shuffles.at(handle).clear();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SecureShuffle<T>::pre(vector<T>& a, size_t n, size_t input_base)
|
||||
{
|
||||
@@ -230,11 +262,10 @@ void SecureShuffle<T>::player_round(int config_player) {
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int SecureShuffle<T>::generate(int n_shuffle)
|
||||
int SecureShuffle<T>::generate(int n_shuffle, store_type& store)
|
||||
{
|
||||
int res = shuffles.size();
|
||||
shuffles.push_back({});
|
||||
auto& shuffle = shuffles.back();
|
||||
int res = store.add();
|
||||
auto& shuffle = store.get(res);
|
||||
|
||||
for (auto i: proc.protocol.get_relevant_players()) {
|
||||
vector<int> perm;
|
||||
|
||||
@@ -65,6 +65,11 @@ public:
|
||||
return "Shamir " + T::type_string();
|
||||
}
|
||||
|
||||
static string alt()
|
||||
{
|
||||
return "ATLAS";
|
||||
}
|
||||
|
||||
static int threshold(int)
|
||||
{
|
||||
return ShamirMachine::s().threshold;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user