Maintenance.

This commit is contained in:
Marcel Keller
2023-08-14 18:29:08 +10:00
parent 7bc156e581
commit 2813c0ef0f
140 changed files with 1598 additions and 388 deletions

View File

@@ -1,5 +1,19 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.7 (August 14, 2023)
- Path Oblivious Heap (@tskovlund)
- Adjust batch and bucket size to program
- Direct communication available in more protocols
- Option for seed in fake preprocessing (@strieflin)
- Lower memory usage due to improved register allocation
- New instructions to speed up CISC compilation
- Protocol implementation example
- Fixed security bug: missing MAC checks in multi-threaded programs
- Fixed security bug: race condition in MAC check
- Fixed security bug: missing shuffling check in PS mod 2^k and Brain
- Fixed security bug: insufficient drowning in pairwise protocols
## 0.3.6 (May 9, 2023)
- More extensive benchmarking outputs

View File

@@ -781,8 +781,6 @@ class sbitvec(_vec, _bit):
for i in range(n):
for j, x in enumerate(v[i].bit_decompose()):
x.store_in_mem(address + i + j * n)
def reveal(self):
return util.untuplify([x.reveal() for x in self.elements()])
@classmethod
def two_power(cls, nn, size=1):
return cls.from_vec(
@@ -919,8 +917,7 @@ class sbitvec(_vec, _bit):
return self.v[:n_bits]
bit_compose = from_vec
def reveal(self):
assert len(self) == 1
return self.v[0].reveal()
return util.untuplify([x.reveal() for x in self.elements()])
def long_one(self):
return [x.long_one() for x in self.v]
def __rsub__(self, other):
@@ -1279,7 +1276,8 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
"""
Vector of signed integers for parallel binary computation::
Vector of signed integers for parallel binary computation.
The following example uses vectors of size two::
sb32 = sbits.get_type(32)
siv32 = sbitintvec.get_type(32)
@@ -1291,7 +1289,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal())
c = (a - b).elements()
print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal())
c = (a < b).bit_decompose()
c = (a < b).elements()
print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal())
This should output::
@@ -1467,7 +1465,7 @@ class sbitfixvec(_fix, _vec):
print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal())
c = (a - b).elements()
print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal())
c = (a < b).bit_decompose()
c = (a < b).elements()
print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal())
This should output roughly::

View File

@@ -43,7 +43,7 @@ class BlockAllocator:
else:
done = False
for x in self.by_logsize[logsize + 1:]:
for block_size, addresses in x.items():
for block_size, addresses in sorted(x.items()):
if len(addresses) > 0:
done = True
break
@@ -60,16 +60,92 @@ class BlockAllocator:
self.by_address[addr + size] = diff
return addr
class AllocRange:
def __init__(self, base=0):
self.base = base
self.top = base
self.limit = base
self.grow = True
self.pool = defaultdict(set)
def alloc(self, size):
if self.pool[size]:
return self.pool[size].pop()
elif self.grow or self.top + size <= self.limit:
res = self.top
self.top += size
self.limit = max(self.limit, self.top)
if res >= REG_MAX:
raise RegisterOverflowError()
return res
def free(self, base, size):
assert self.base <= base < self.top
self.pool[size].add(base)
def stop_growing(self):
self.grow = False
def consolidate(self):
regs = []
for size, pool in self.pool.items():
for base in pool:
regs.append((base, size))
for base, size in reversed(sorted(regs)):
if base + size == self.top:
self.top -= size
self.pool[size].remove(base)
regs.pop()
else:
if program.Program.prog.verbose:
print('cannot free %d register blocks '
'by a gap of %d at %d' %
(len(regs), self.top - size - base, base))
break
class AllocPool:
def __init__(self):
self.ranges = defaultdict(lambda: [AllocRange()])
self.by_base = {}
def alloc(self, reg_type, size):
for r in self.ranges[reg_type]:
res = r.alloc(size)
if res is not None:
self.by_base[reg_type, res] = r
return res
def free(self, reg):
r = self.by_base.pop((reg.reg_type, reg.i))
r.free(reg.i, reg.size)
def new_ranges(self, min_usage):
for t, n in min_usage.items():
r = self.ranges[t][-1]
assert (n >= r.limit)
if r.limit < n:
r.stop_growing()
self.ranges[t].append(AllocRange(n))
def consolidate(self):
for r in self.ranges.values():
for rr in r:
rr.consolidate()
def n_fragments(self):
return max(len(r) for r in self.ranges)
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
It is based on the precondition that every register is only defined once."""
def __init__(self, n, program):
self.alloc = dict_by_id()
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
self.max_usage = defaultdict(lambda: 0)
self.defined = dict_by_id()
self.dealloc = set_by_id()
self.n = n
assert(n == REG_MAX)
self.program = program
self.old_pool = None
def alloc_reg(self, reg, free):
base = reg.vectorbase
@@ -79,14 +155,7 @@ class StraightlineAllocator:
reg_type = reg.reg_type
size = base.size
if free[reg_type, size]:
res = free[reg_type, size].pop()
else:
if self.usage[reg_type] < self.n:
res = self.usage[reg_type]
self.usage[reg_type] += size
else:
raise RegisterOverflowError()
res = free.alloc(reg_type, size)
self.alloc[base] = res
base.i = self.alloc[base]
@@ -126,7 +195,7 @@ class StraightlineAllocator:
for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x)
free[reg.reg_type, base.size].append(self.alloc[base])
free.free(base)
if inst.is_vec() and base.vector:
self.defined[base] = inst
for i in base.vector:
@@ -135,6 +204,7 @@ class StraightlineAllocator:
self.defined[reg] = inst
def process(self, program, alloc_pool):
self.update_usage(alloc_pool)
for k,i in enumerate(reversed(program)):
unused_regs = []
for j in i.get_def():
@@ -161,12 +231,26 @@ class StraightlineAllocator:
if k % 1000000 == 0 and k > 0:
print("Allocated registers for %d instructions at" % k, time.asctime())
self.update_max_usage(alloc_pool)
alloc_pool.consolidate()
# print "Successfully allocated registers"
# print "modp usage: %d clear, %d secret" % \
# (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp])
# print "GF2N usage: %d clear, %d secret" % \
# (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N])
return self.usage
return self.max_usage
def update_max_usage(self, alloc_pool):
for t, r in alloc_pool.ranges.items():
self.max_usage[t] = max(self.max_usage[t], r[-1].limit)
def update_usage(self, alloc_pool):
if self.old_pool:
self.update_max_usage(self.old_pool)
if id(self.old_pool) != id(alloc_pool):
alloc_pool.new_ranges(self.max_usage)
self.old_pool = alloc_pool
def finalize(self, options):
for reg in self.alloc:
@@ -178,6 +262,21 @@ class StraightlineAllocator:
'\t\t'))
if options.stop:
sys.exit(1)
if self.program.verbose:
def p(sizes):
total = defaultdict(lambda: 0)
for (t, size) in sorted(sizes):
n = sizes[t, size]
total[t] += size * n
print('%s:%d*%d' % (t, size, n), end=' ')
print()
print('Total:', dict(total))
sizes = defaultdict(lambda: 0)
for reg in self.alloc:
x = reg.reg_type, reg.size
print('Used registers: ', end='')
p(sizes)
def determine_scope(block, options):
last_def = defaultdict_by_id(lambda: -1)

View File

@@ -10,7 +10,7 @@ the ones used below into ``Programs/Circuits`` as follows::
"""
from Compiler.GC.types import *
from Compiler.library import function_block
from Compiler.library import function_block, get_tape
from Compiler import util
import itertools
import struct
@@ -54,7 +54,7 @@ class Circuit:
return self.run(*inputs)
def run(self, *inputs):
n = inputs[0][0].n
n = inputs[0][0].n, get_tape()
if n not in self.functions:
self.functions[n] = function_block(lambda *args:
self.compile(*args))

View File

@@ -270,9 +270,9 @@ class Compiler:
self.prog = Program(self.args, self.options, name=name)
if self.execute:
if self.options.execute in \
("emulate", "ring", "rep-field"):
("emulate", "ring", "rep-field", "rep4-ring"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring",):
if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"):
self.prog.use_split(3)
if self.options.execute in ("semi2k",):
self.prog.use_split(2)
@@ -487,6 +487,7 @@ class Compiler:
"Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.")
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
sys.stdout.flush()
os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=[]):

View File

@@ -633,7 +633,8 @@ def preprocess_pandas(data):
res.append(data.iloc[:,i].to_numpy())
types.append('c')
elif pandas.api.types.is_object_dtype(t):
values = data.iloc[:,i].unique()
values = list(filter(lambda x: isinstance(x, str),
list(data.iloc[:,i].unique())))
print('converting the following to unary:', values)
if len(values) == 2:
res.append(data.iloc[:,i].to_numpy() == values[1])

View File

@@ -638,6 +638,44 @@ class prefixsums(base.Instruction):
code = base.opcodes['PREFIXSUMS']
arg_format = ['sw','s']
class picks(base.VectorInstruction):
""" Extract part of vector.
:param: result (sint)
:param: input (sint)
:param: start offset (int)
:param: step
"""
__slots__ = []
code = base.opcodes['PICKS']
arg_format = ['sw','s','int','int']
def __init__(self, *args):
super(picks, self).__init__(*args)
assert 0 <= args[2] < len(args[1])
assert 0 <= args[2] + args[3] * len(args[0]) <= len(args[1])
class concats(base.VectorInstruction):
""" Concatenate vectors.
:param: result (sint)
:param: start offset (int)
:param: input (sint)
:param: (repeat from offset)...
"""
__slots__ = []
code = base.opcodes['CONCATS']
arg_format = tools.chain(['sw'], tools.cycle(['int','s']))
def __init__(self, *args):
super(concats, self).__init__(*args)
assert len(args) % 2 == 1
assert len(args[0]) == sum(args[1::2])
for i in range(1, len(args), 2):
assert args[i] == len(args[i + 1])
@base.gf2n
@base.vectorize
class mulc(base.MulBase):

View File

@@ -82,6 +82,8 @@ opcodes = dict(
SUBCFI = 0x2B,
SUBSFI = 0x2C,
PREFIXSUMS = 0x2D,
PICKS = 0x2E,
CONCATS = 0x2F,
# Multiplication/division
MULC = 0x30,
MULM = 0x31,
@@ -523,29 +525,26 @@ def cisc(function):
tape.active_basicblock = block
size = sum(call[0][0].size for call in self.calls)
new_regs = []
for arg in self.args:
for i, arg in enumerate(self.args):
try:
new_regs.append(type(arg)(size=size))
except TypeError:
if i == 0:
new_regs.append(type(arg)(size=size))
else:
new_regs.append(type(arg).concat(
call[0][i] for call in self.calls))
assert len(new_regs[-1]) == size
except (TypeError, AttributeError):
if not isinstance(arg, int):
raise
break
except:
print([call[0][0].size for call in self.calls])
raise
assert len(new_regs) > 1
base = 0
for call in self.calls:
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
set_global_vector_size(reg.size)
reg.mov(new_reg.get_vector(base, reg.size), reg)
reset_global_vector_size()
base += reg.size
self.new_instructions(size, new_regs)
base = 0
for call in self.calls:
reg = call[0][0]
set_global_vector_size(reg.size)
reg.mov(reg, new_regs[0].get_vector(base, reg.size))
reset_global_vector_size()
reg.copy_from_part(new_regs[0], base, reg.size)
base += reg.size
return block.instructions, self.n_rounds - 1
@@ -628,7 +627,7 @@ def sfix_cisc(function):
instruction = cisc(instruction)
def wrapper(*args, **kwargs):
if isinstance(args[0], sfix):
if isinstance(args[0], sfix) and program.options.cisc:
for arg in args[1:]:
assert util.is_constant(arg)
assert not kwargs
@@ -844,6 +843,12 @@ class Instruction(object):
Instruction.count += 1
if Instruction.count % 100000 == 0:
print("Compiled %d lines at" % self.__class__.count, time.asctime())
if Instruction.count > 10 ** 7:
print("Compilation produced more that 10 million instructions. "
"Consider using './compile.py -l' or replacing for loops "
"with @for_range_opt: "
"https://mp-spdz.readthedocs.io/en/latest/Compiler.html#"
"Compiler.library.for_range_opt")
def get_code(self, prefix=0):
return (prefix << self.code_length) + self.code

View File

@@ -6,7 +6,7 @@ in particularly providing flow control and output.
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc, _vec
from Compiler.instructions import *
from Compiler.util import tuplify,untuplify,is_zero
from Compiler.allocator import RegintOptimizer
from Compiler.allocator import RegintOptimizer, AllocPool
from Compiler import instructions,instructions_base,comparison,program,util
import inspect,math
import random
@@ -411,7 +411,7 @@ class FunctionBlock(Function):
parent_node = get_tape().req_node
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
block = get_tape().active_basicblock
block.alloc_pool = defaultdict(list)
block.alloc_pool = AllocPool()
del parent_node.children[-1]
self.node = get_tape().req_node
if get_program().verbose:
@@ -935,7 +935,7 @@ def for_range_opt_multithread(n_threads, n_loops):
...
Note that you cannot use registers across threads. Use
:py:class:`MemValue` instead::
:py:class:`~Compiler.types.MemValue` instead::
a = MemValue(sint(0))
@for_range_opt_multithread(8, 80)
@@ -1069,6 +1069,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
prog.free_later()
if len(state):
if thread_rounds:
for i in range(n_threads - remainder):
@@ -1320,6 +1321,7 @@ def if_then(condition):
state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
name='if-block')
state.has_else = False
state.closed_if = False
state.caller = [frame[1:] for frame in inspect.stack()[1:]]
instructions.program.curr_tape.if_states.append(state)
@@ -1434,6 +1436,7 @@ def if_e(condition):
else:
if_then(condition)
_run_and_link(body)
get_tape().if_states[-1].closed_if = True
return decorator
def else_(body):
@@ -1443,6 +1446,8 @@ def else_(body):
_run_and_link(body)
if_states.pop()
else:
if not if_states[-1].closed_if:
raise CompilerError('@if_e not closed before else block')
else_then()
_run_and_link(body)
end_if()

View File

@@ -822,10 +822,10 @@ class Dense(DenseBase):
self.W.randomize(-r, r, n_threads=self.n_threads)
self.b.assign_all(0)
def input_from(self, player, raw=False):
self.W.input_from(player, raw=raw)
def input_from(self, player, **kwargs):
self.W.input_from(player, **kwargs)
if self.input_bias:
self.b.input_from(player, raw=raw)
self.b.input_from(player, **kwargs)
def compute_f_input(self, batch):
N = len(batch)
@@ -1088,10 +1088,7 @@ class Relu(ElementWiseLayer):
:param shape: input/output shape (tuple/list of int)
"""
f = staticmethod(relu)
f_prime = staticmethod(relu_prime)
prime_type = sint
comparisons = None
def __init__(self, shape, inputs=None):
super(Relu, self).__init__(shape)
@@ -1310,12 +1307,12 @@ class FusedBatchNorm(Layer):
self.bias = sfix.Array(shape[3])
self.inputs = inputs
def input_from(self, player, raw=False):
self.weights.input_from(player, raw=raw)
self.bias.input_from(player, raw=raw)
def input_from(self, player, **kwargs):
self.weights.input_from(player, **kwargs)
self.bias.input_from(player, **kwargs)
tmp = sfix.Array(len(self.bias))
tmp.input_from(player, raw=raw)
tmp.input_from(player, raw=raw)
tmp.input_from(player, **kwargs)
tmp.input_from(player, **kwargs)
def _forward(self, batch=[0]):
assert len(batch) == 1
@@ -1611,11 +1608,11 @@ class ConvBase(BaseLayer):
self.bias_shape, self.Y.sizes, self.stride, repr(self.padding),
self.tf_weight_format)
def input_from(self, player, raw=False):
def input_from(self, player, **kwargs):
self.input_params_from(player)
self.weights.input_from(player, budget=100000, raw=raw)
self.weights.input_from(player, budget=100000, **kwargs)
if self.input_bias:
self.bias.input_from(player, raw=raw)
self.bias.input_from(player, **kwargs)
def output_weights(self):
self.weights.print_reveal_nested()

View File

@@ -112,8 +112,8 @@ class Program(object):
self.non_linear = Prime(self.security)
if not self.bit_length:
self.bit_length = 64
print("Default bit length:", self.bit_length)
print("Default security parameter:", self.security)
print("Default bit length for compilation:", self.bit_length)
print("Default security parameter for compilation:", self.security)
self.galois_length = int(options.galois)
if self.verbose:
print("Galois length:", self.galois_length)
@@ -122,6 +122,7 @@ class Program(object):
self.DEBUG = options.debug
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
self.free_mem_blocks = defaultdict(al.BlockAllocator)
self.later_mem_blocks = defaultdict(list)
self.allocated_mem_blocks = {}
self.saved = 0
self.req_num = None
@@ -229,24 +230,26 @@ class Program(object):
if self.name.endswith(ext):
self.name = self.name[:-len(ext)]
if os.path.exists(args[0]):
self.infile = args[0]
infiles = [args[0]]
for x in (self.programs_dir, sys.path[0] + "/Programs"):
for ext in exts:
filename = args[0]
if not filename.endswith(ext):
filename += ext
filename = x + "/Source/" + filename
if os.path.abspath(filename) not in \
[os.path.abspath(f) for f in infiles]:
infiles += [filename]
existing = [f for f in infiles if os.path.exists(f)]
if len(existing) == 1:
self.infile = existing[0]
elif len(existing) > 1:
raise CompilerError("ambiguous input files: " +
", ".join(existing))
else:
infiles = []
for x in (self.programs_dir, sys.path[0] + "/Programs"):
for ext in exts:
filename = args[0]
if not filename.endswith(ext):
filename += ext
infiles += [x + "/Source/" + filename]
for f in infiles:
if os.path.exists(f):
self.infile = f
break
else:
raise CompilerError(
"found none of the potential input files: " +
", ".join("'%s'" % x for x in [args[0]] + infiles))
raise CompilerError(
"found none of the potential input files: " +
", ".join("'%s'" % x for x in [args[0]] + infiles))
"""
self.name is input file name (minus extension) + any optional arguments.
Used to generate output filenames
@@ -463,7 +466,7 @@ class Program(object):
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
if addr + size >= 2**64:
raise CompilerError("allocation exceeded for type '%s'" % mem_type)
self.allocated_mem_blocks[addr, mem_type] = size
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
if single_size:
from .library import get_thread_number, runtime_error_if
@@ -477,12 +480,24 @@ class Program(object):
def free(self, addr, mem_type):
"""Free memory"""
if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool:
raise CompilerError("Cannot free memory within function block")
now = True
if not util.is_constant(addr):
addr = self.base_addresses[str(addr)]
size = self.allocated_mem_blocks.pop((addr, mem_type))
self.free_mem_blocks[mem_type].push(addr, size)
now = self.curr_tape == self.tapes[0]
size, pool = self.allocated_mem_blocks[addr, mem_type]
if self.curr_block.alloc_pool is not pool:
raise CompilerError("Cannot free memory across function blocks")
self.allocated_mem_blocks.pop((addr, mem_type))
if now:
self.free_mem_blocks[mem_type].push(addr, size)
else:
self.later_mem_blocks[mem_type].append((addr, size))
def free_later(self):
for mem_type in self.later_mem_blocks:
for block in self.later_mem_blocks[mem_type]:
self.free_mem_blocks[mem_type].push(*block)
self.later_mem_blocks.clear()
def finalize(self):
# optimize the tapes
@@ -744,6 +759,7 @@ class Tape:
self.purged = False
self.block_counter = 0
self.active_basicblock = None
self.old_allocated_mem = program.allocated_mem.copy()
self.start_new_basicblock()
self._is_empty = False
self.merge_opens = True
@@ -771,7 +787,7 @@ class Tape:
scope.children.append(self)
self.alloc_pool = scope.alloc_pool
else:
self.alloc_pool = defaultdict(list)
self.alloc_pool = al.AllocPool()
self.purged = False
self.n_rounds = 0
self.n_to_merge = 0
@@ -869,6 +885,15 @@ class Tape:
return self._is_empty
def start_new_basicblock(self, scope=False, name=""):
if self.program.verbose and self.active_basicblock and \
self.program.allocated_mem != self.old_allocated_mem:
print("New allocated memory in %s " % self.active_basicblock.name,
end="")
for t, n in self.program.allocated_mem.items():
if n != self.old_allocated_mem[t]:
print("%s:%d " % (t, n - self.old_allocated_mem[t]), end="")
print()
self.old_allocated_mem = self.program.allocated_mem.copy()
# use False because None means no scope
if scope is False:
scope = self.active_basicblock
@@ -1029,6 +1054,7 @@ class Tape:
allocator = al.StraightlineAllocator(REG_MAX, self.program)
def alloc(block):
allocator.update_usage(block.alloc_pool)
for reg in sorted(
block.used_from_scope, key=lambda x: (x.reg_type, x.i)
):
@@ -1042,6 +1068,7 @@ class Tape:
for child in block.children:
left.append(child)
allocator.old_pool = None
for i, block in enumerate(reversed(self.basicblocks)):
if len(block.instructions) > 1000000:
print(
@@ -1055,10 +1082,20 @@ class Tape:
and block.exit_block.scope is not None
):
alloc_loop(block.exit_block.scope)
usage = allocator.max_usage.copy()
allocator.process(block.instructions, block.alloc_pool)
if self.program.verbose and usage != allocator.max_usage:
print("Allocated registers in %s " % block.name, end="")
for t, n in allocator.max_usage.items():
if n > usage[t]:
print("%s:%d " % (t, n - usage[t]), end="")
print()
allocator.finalize(options)
if self.program.verbose:
print("Tape register usage:", dict(allocator.usage))
print("Tape register usage:", dict(allocator.max_usage))
scopes = set(block.alloc_pool for block in self.basicblocks)
n_fragments = sum(scope.n_fragments() for scope in scopes)
print("%d register fragments in %d scopes" % (n_fragments, len(scopes)))
# offline data requirements
if self.program.verbose:
@@ -1499,6 +1536,7 @@ class Tape:
if Program.prog.options.noreallocate:
raise CompilerError("reallocation necessary for linking, "
"remove option -u")
assert self.reg_type == other.reg_type
self.duplicates |= other.duplicates
for dup in self.duplicates:
dup.duplicates = self.duplicates

View File

@@ -38,6 +38,7 @@ by party 0 and 1::
sfloat
sgf2n
cgf2n
personal
Container types
---------------
@@ -392,7 +393,7 @@ class _int(Tape._no_truth):
return a - prod, b + prod
def bit_xor(self, other):
""" XOR in arithmetic circuits.
""" Single-bit XOR in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:return: type depends on inputs (secret if any of them is) """
@@ -404,7 +405,7 @@ class _int(Tape._no_truth):
return self + other - 2 * self * other
def bit_or(self, other):
""" OR in arithmetic circuits.
""" Single-bit OR in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:return: type depends on inputs (secret if any of them is) """
@@ -416,14 +417,14 @@ class _int(Tape._no_truth):
return self + other - self * other
def bit_and(self, other):
""" AND in arithmetic circuits.
""" Single-bit AND in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:rtype: depending on inputs (secret if any of them is) """
return self * other
def bit_not(self):
""" NOT in arithmetic circuits. """
""" Single-bit NOT in arithmetic circuits. """
return 1 - self
def half_adder(self, other):
@@ -611,6 +612,8 @@ class _secret_structure(_structure):
if program.curr_tape != program.tapes[0]:
raise CompilerError('only available in main thread')
if content is not None:
if isinstance(content, (_vectorizable, Tape.Register)):
raise CompilerError('cannot input data already in the VM')
requested_shape = shape
if binary:
import numpy
@@ -800,11 +803,31 @@ class _register(Tape.Register, _number, _structure):
if self.size == size:
return self
assert self.size == 1
return self._expand_to_vector(size)
def _expand_to_vector(self, size):
res = type(self)(size=size)
for i in range(size):
self.mov(res[i], self)
return res
def copy_from_part(self, source, base, size):
set_global_vector_size(size)
self.mov(self, source.get_vector(base, size))
reset_global_vector_size()
@classmethod
def concat(cls, parts):
parts = list(parts)
res = cls(size=sum(len(part) for part in parts))
base = 0
for reg in parts:
set_global_vector_size(reg.size)
reg.mov(res.get_vector(base, reg.size), reg)
reset_global_vector_size()
base += reg.size
return res
class _arithmetic_register(_register):
""" Arithmetic circuit type. """
def __init__(self, *args, **kwargs):
@@ -1508,6 +1531,14 @@ class regint(_register, _int):
raise CompilerError("Cannot convert '%s' to integer" % \
type(val))
def expand_to_vector(self, size=None):
if size is None:
size = get_global_vector_size()
if self.size == size:
return self
assert self.size == 1
return self.inc(size, self, 0)
@vectorize
@read_mem_value
def int_op(self, other, inst, reverse=False):
@@ -1762,7 +1793,8 @@ class localint(Tape._no_truth):
class personal(Tape._no_truth):
""" Value known to one player. Supports operations with public
values and personal values known to the same player. Can be used
with :py:func:`~Compiler.library.print_ln_to`.
with :py:func:`~Compiler.library.print_ln_to`. It is possible to
convert to secret types like :py:class:`sint`.
:param player: player (int)
:param value: cleartext value (cint, cfix, cfloat) or array thereof
@@ -2023,11 +2055,16 @@ class _secret(_arithmetic_register, _secret_structure):
:rtype: same as inputs
"""
x = list(x)
set_global_vector_size(x[0].size)
res = cls()
dotprods(res, x, y)
reset_global_vector_size()
if isinstance(x, cls) and isinstance(y, cls):
assert len(x) == len(y)
res = cls()
matmuls(res, x, y, 1, len(x), 1)
else:
x = list(x)
set_global_vector_size(x[0].size)
res = cls()
dotprods(res, x, y)
reset_global_vector_size()
return res
@classmethod
@@ -2952,6 +2989,27 @@ class sint(_secret, _int):
prefixsums(res, self)
return res
def sum(self):
res = type(self)(size=1)
picks(res, self.prefix_sum(), len(self) - 1, 0)
return res
def _expand_to_vector(self, size):
res = type(self)(size=size)
picks(res, self, 0, 0)
return res
def copy_from_part(self, source, base, size):
picks(self, source, base, 1)
@classmethod
def concat(cls, parts):
parts = list(parts)
res = cls(size=sum(len(part) for part in parts))
args = sum(([len(part), part] for part in parts), [])
concats(res, *args)
return res
class sintbit(sint):
""" :py:class:`sint` holding a bit, supporting binary operations
(``&, |, ^``). """
@@ -4726,6 +4784,24 @@ class sfix(_fix):
def prefix_sum(self):
return self._new(self.v.prefix_sum(), k=self.k, f=self.f)
def sum(self):
return self._new(self.v.sum())
@classmethod
def concat(cls, parts):
parts = list(parts)
int_parts = []
f = parts[0].f
k = parts[0].k
for part in parts:
assert part.f == f
assert part.k == k
int_parts.append(part.v)
return cls._new(cls.int_type.concat(int_parts), k=k, f=f)
def __repr__(self):
return '<sfix{f=%d,k=%d} at %s>' % (self.f, self.k, self.v)
class unreduced_sfix(_single):
int_type = sint
@@ -5739,9 +5815,11 @@ class Array(_vectorizable):
if value.size != 1:
raise CompilerError('cannot assign vector to all elements')
mem_value = MemValue(value)
self.address = MemValue.if_necessary(self.address)
n_threads = 8 if use_threads and util.is_constant(self.length) and \
len(self) > 2**20 and not program.options.garbled else None
len(self) > 2**20 and not program.options.garbled and \
program.curr_tape.singular else None
if n_threads is not None:
self.address = MemValue.if_necessary(self.address)
@library.multithread(n_threads, self.length)
def _(base, size):
if use_vector:
@@ -6444,12 +6522,20 @@ class SubMultiArray(_vectorizable):
try:
try:
self.value_type.direct_matrix_mul
assert self.value_type == other.value_type
skip_reduce = set((sint, sfix)) == \
set((self.value_type, other.value_type))
assert self.value_type == other.value_type or skip_reduce
max_size = _register.maximum_size // res_matrix.sizes[1]
@library.multithread(n_threads, self.sizes[0], max_size)
def _(base, size):
res_matrix.assign_part_vector(
self.get_part(base, size).direct_mul(other), base)
tmp = self.get_part(base, size).direct_mul(
other, reduce=not skip_reduce,
res_type=sfix if skip_reduce else None)
if skip_reduce:
tmp = t._new(tmp.v)
else:
tmp = tmp.reduce_after_mul()
res_matrix.assign_part_vector(tmp, base)
except AttributeError:
assert n_threads is None
if max(res_matrix.sizes) > 1000:
@@ -6483,7 +6569,7 @@ class SubMultiArray(_vectorizable):
else:
raise NotImplementedError
def direct_mul(self, other, reduce=True, indices=None):
def direct_mul(self, other, reduce=True, indices=None, res_type=None):
""" Matrix multiplication in the virtual machine.
Unlike :py:func:`dot`, this only works for sint and sfix, and it
returns a vector instead of a data structure.
@@ -6511,10 +6597,15 @@ class SubMultiArray(_vectorizable):
other_sizes = other.sizes
assert len(other.sizes) == 2
assert self.sizes[1] == other_sizes[0]
assert self.value_type == other.value_type
return self.value_type.direct_matrix_mul(self.address, other.address,
self.sizes[0], *other_sizes,
reduce=reduce, indices=indices)
if self.value_type == other.value_type:
assert res_type in (self.value_type, None)
res_type = self.value_type
else:
assert not reduce
assert res_type
return res_type.direct_matrix_mul(self.address, other.address,
self.sizes[0], *other_sizes,
reduce=reduce, indices=indices)
def direct_mul_trans(self, other, reduce=True, indices=None):
"""
@@ -6988,7 +7079,10 @@ class _mem(_number):
__ilshift__ = lambda self,other: self.write(self.read() << other)
__irshift__ = lambda self,other: self.write(self.read() >> other)
iadd = __iadd__
def iadd(self, other):
""" Addition assignment. """
return self.__iadd__(other)
isub = __isub__
imul = __imul__
itruediv = __itruediv__
@@ -7016,7 +7110,7 @@ class MemValue(_mem):
@classmethod
def if_necessary(cls, value):
if util.is_constant_float(value):
if util.is_constant_float(value) or isinstance(value, MemValue):
return value
else:
return cls(value)
@@ -7121,7 +7215,7 @@ class MemValue(_mem):
return self.value_type.load_mem(addresses)
def __repr__(self):
return 'MemValue(%s,%d)' % (self.value_type, self.address)
return 'MemValue(%s,%s)' % (self.value_type, self.address)
class MemFloat(MemValue):

View File

@@ -28,6 +28,8 @@ for socket in client.sockets:
os.store(finish)
os.Send(socket)
# running two rounds
# first for sint, then for sfix
for x in bonus, bonus * 2 ** 16:
client.send_private_inputs([domain(x)])

View File

@@ -124,7 +124,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk)
{
Rq_Element tmp(*params);
SeededPRNG G;
vector<FFT_Data::S> r(params->FFTD()[0].m());
vector<FFT_Data::S> r(params->FFTD()[0].phi_m());
bigint p = pk.p();
assert(p != 0);
for (auto& x : r)

View File

@@ -2,6 +2,13 @@
#include "Ring.h"
#include "Tools/Exceptions.h"
Ring::Ring(int m) :
mm(0), phim(0)
{
if (m != 0)
init(*this, m);
}
void Ring::pack(octetStream& o) const
{
o.store(mm);

View File

@@ -22,7 +22,7 @@ class Ring
public:
Ring() : mm(0), phim(0) { ; }
Ring(int m = 0);
~Ring() { ; }
// Rely on default copy assignment/constructor
@@ -40,6 +40,7 @@ class Ring
void unpack(octetStream& o);
bool operator!=(const Ring& other) const;
bool operator==(const Ring& other) const { return not (*this != other); }
};
void init(Ring& Rg, int m, bool generate_poly = false);

View File

@@ -44,6 +44,7 @@ void Ring_Element::prepare(const Ring_Element& other)
void Ring_Element::prepare_push()
{
element.clear();
assert(FFTD);
element.reserve(FFTD->phi_m());
}
@@ -63,6 +64,7 @@ void Ring_Element::assign_zero()
void Ring_Element::assign_one()
{
assert(FFTD);
allocate();
modp fill;
if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); }
@@ -79,6 +81,7 @@ void Ring_Element::negate()
if (element.empty())
return;
assert(FFTD);
for (int i=0; i<(*FFTD).phi_m(); i++)
{ Negate(element[i],element[i],(*FFTD).get_prD()); }
}
@@ -87,6 +90,7 @@ void Ring_Element::negate()
void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
assert(a.FFTD);
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
if (a.element.empty())
{
@@ -119,6 +123,7 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
assert(a.FFTD);
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
if (a.element.empty())
@@ -148,6 +153,7 @@ void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
assert(a.FFTD);
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
if (a.element.empty() or b.element.empty())
@@ -200,9 +206,11 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
}
else if ((*a.FFTD).get_twop()==0)
{ // m a power of two case
ans.partial_assign(a);
Ring_Element aa(*ans.FFTD,ans.rep);
aa.partial_assign(a);
modp temp;
cerr << "slow polynomial multiplication "
"(change representation to change this)..." << endl;
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ for (int j=0; j<(*ans.FFTD).phi_m(); j++)
{ Mul(temp,a.element[i],b.element[j],(*a.FFTD).get_prD());
@@ -213,7 +221,9 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
}
Add(aa.element[k],aa.element[k],temp,(*a.FFTD).get_prD());
}
cerr << "\r" << i << "/" << ans.FFTD->phi_m();
}
cerr << endl;
ans=aa;
}
else
@@ -241,6 +251,7 @@ void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD);
assert(FFTD == other.FFTD);
assert(rep == other.rep);
for (size_t i = 0; i < element.size(); i++)
@@ -252,6 +263,7 @@ Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD);
assert(FFTD == other.FFTD);
assert(rep == other.rep);
for (size_t i = 0; i < element.size(); i++)
@@ -263,6 +275,7 @@ Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD);
assert(FFTD == other.FFTD);
assert(rep == other.rep);
assert(rep == evaluation);
@@ -274,6 +287,7 @@ Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
Ring_Element& Ring_Element::operator *=(const modp& other)
{
assert(FFTD);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].mul(other, FFTD->get_prD());
return *this;
@@ -282,6 +296,7 @@ Ring_Element& Ring_Element::operator *=(const modp& other)
Ring_Element Ring_Element::mul_by_X_i(int j) const
{
assert(FFTD);
Ring_Element ans;
ans.prepare(*this);
if (element.empty())
@@ -331,6 +346,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const
void Ring_Element::randomize(PRNG& G,bool Diag)
{
assert(FFTD);
allocate();
if (Diag==false)
{ for (int i=0; i<(*FFTD).phi_m(); i++)
@@ -352,6 +368,7 @@ void Ring_Element::randomize(PRNG& G,bool Diag)
void Ring_Element::change_rep(RepType r)
{
assert(FFTD);
if (element.empty())
{
rep = r;
@@ -403,6 +420,7 @@ void Ring_Element::change_rep(RepType r)
bool Ring_Element::equals(const Ring_Element& a) const
{
assert(FFTD);
if (rep!=a.rep) { throw rep_mismatch(); }
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
@@ -417,6 +435,7 @@ bool Ring_Element::equals(const Ring_Element& a) const
bool Ring_Element::is_zero() const
{
assert(FFTD);
if (element.empty())
return true;
for (auto& x : element)
@@ -428,6 +447,7 @@ bool Ring_Element::is_zero() const
ConversionIterator Ring_Element::get_iterator() const
{
assert(FFTD);
if (rep != polynomial)
throw runtime_error("simple iterator only available in polynomial represention");
assert(not element.empty());
@@ -436,16 +456,19 @@ ConversionIterator Ring_Element::get_iterator() const
RingReadIterator Ring_Element::get_copy_iterator() const
{
assert(FFTD);
return *this;
}
RingWriteIterator Ring_Element::get_write_iterator()
{
assert(FFTD);
return *this;
}
vector<bigint> Ring_Element::to_vec_bigint() const
{
assert(FFTD);
vector<bigint> v;
to_vec_bigint(v);
return v;
@@ -454,6 +477,7 @@ vector<bigint> Ring_Element::to_vec_bigint() const
void Ring_Element::to_vec_bigint(vector<bigint>& v) const
{
assert(FFTD);
v.resize(FFTD->phi_m());
if (element.empty())
return;
@@ -476,6 +500,7 @@ void Ring_Element::to_vec_bigint(vector<bigint>& v) const
modp Ring_Element::get_constant() const
{
assert(FFTD);
if (element.empty())
return {};
else
@@ -516,6 +541,7 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
void Ring_Element::pack(octetStream& o) const
{
assert(FFTD);
check_size();
o.store(unsigned(rep));
store(o,element,(*FFTD).get_prD());
@@ -524,6 +550,7 @@ void Ring_Element::pack(octetStream& o) const
void Ring_Element::unpack(octetStream& o)
{
assert(FFTD);
unsigned int a;
o.get(a);
rep=(RepType) a;
@@ -542,12 +569,14 @@ void Ring_Element::check_rep()
void Ring_Element::check_size() const
{
assert(FFTD);
if (not element.empty() and (int)element.size() != FFTD->phi_m())
throw runtime_error("invalid element size");
}
void Ring_Element::output(ostream& s) const
{
assert(FFTD);
s.write((char*)&rep, sizeof(rep));
auto size = element.size();
s.write((char*)&size, sizeof(size));
@@ -558,6 +587,7 @@ void Ring_Element::output(ostream& s) const
void Ring_Element::input(istream& s)
{
assert(FFTD);
s.read((char*)&rep, sizeof(rep));
check_rep();
auto size = element.size();
@@ -579,6 +609,7 @@ void Ring_Element::check(const FFT_Data& FFTD) const
size_t Ring_Element::report_size(ReportType type) const
{
assert(FFTD);
if (type == CAPACITY)
return sizeof(modp) * element.capacity();
else

View File

@@ -56,9 +56,9 @@ class Ring_Element
void allocate();
void set_data(const FFT_Data& prd) { FFTD=&prd; }
const FFT_Data& get_FFTD() const { return *FFTD; }
const Zp_Data& get_prD() const { return (*FFTD).get_prD(); }
const bigint& get_prime() const { return (*FFTD).get_prime(); }
const FFT_Data& get_FFTD() const { assert(FFTD); return *FFTD; }
const Zp_Data& get_prD() const { return get_FFTD().get_prD(); }
const bigint& get_prime() const { return get_FFTD().get_prime(); }
void assign_zero();
void assign_one();
@@ -120,6 +120,7 @@ class Ring_Element
template <class T>
void from(const vector<T>& source)
{
assert(source.size() == (size_t) get_FFTD().phi_m());
from(Iterator<T>(source));
}
@@ -162,7 +163,7 @@ class RingWriteIterator : public WriteConversionIterator
RepType rep;
public:
RingWriteIterator(Ring_Element& element) :
WriteConversionIterator(element.element, element.FFTD->get_prD()),
WriteConversionIterator(element.element, element.get_FFTD().get_prD()),
element(element), rep(element.rep)
{
element.rep = polynomial;
@@ -177,7 +178,7 @@ class RingReadIterator : public ConversionIterator
Ring_Element element;
public:
RingReadIterator(const Ring_Element& element) :
ConversionIterator(this->element.element, element.FFTD->get_prD()),
ConversionIterator(this->element.element, element.get_FFTD().get_prD()),
element(element)
{
this->element.change_rep(polynomial);
@@ -198,6 +199,7 @@ void Ring_Element::from(const Generator<T>& generator)
T tmp;
modp tmp2;
prepare_push();
assert(FFTD);
for (int i=0; i<(*FFTD).phi_m(); i++)
{
generator.get(tmp);

View File

@@ -14,7 +14,10 @@ Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
if (prd.size() > 0)
a.push_back({prd[0], r0});
if (prd.size() > 1)
{
assert(prd[0].get_R() == prd[1].get_R());
a.push_back({prd[1], r1});
}
lev = n_mults();
}
@@ -155,6 +158,7 @@ void Rq_Element::to_vec_bigint(vector<bigint>& v) const
if (lev==1)
{ vector<bigint> v1;
a[1].to_vec_bigint(v1);
assert(v.size() == v1.size());
bigint p0=a[0].get_prime();
bigint p1=a[1].get_prime();
bigint p0i,lambda,Q=p0*p1;

View File

@@ -63,7 +63,10 @@ protected:
Rq_Element(const FHE_PK& pk);
Rq_Element(const Ring_Element& b0,const Ring_Element& b1) :
a({b0, b1}), lev(n_mults()) {}
a({b0, b1}), lev(n_mults())
{
assert(b0.get_FFTD().get_R() == b1.get_FFTD().get_R());
}
Rq_Element(const Ring_Element& b0) :
a({b0}), lev(n_mults()) {}
@@ -139,6 +142,8 @@ protected:
template <class T>
void from(const vector<T>& source, int level=-1)
{
for (auto& x : a)
assert(source.size() == (size_t) x.get_FFTD().phi_m());
from(Iterator<T>(source), level);
}

View File

@@ -76,11 +76,14 @@ modp Find_Primitive_Root_2m(int m,const vector<int>& poly,const Zp_Data& ZpD)
*/
modp Find_Primitive_Root_2power(int m,const Zp_Data& ZpD)
{
assert((m & (m - 1)) == 0);
assert(m > 1);
modp ans,e,one,base;
assignOne(one,ZpD);
assignOne(base,ZpD);
bigint exp;
exp=(ZpD.pr-1)/m;
assert(exp * m == ZpD.pr - 1);
bool flag=true;
while (flag)
{ Add(base,base,one,ZpD); // Keep incrementing base until we hit the answer

View File

@@ -17,6 +17,23 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
mf.allocate_slots(pk.p() << 64);
}
class ModuloTreeSum : public TreeSum<bigint>
{
bigint modulo;
void post_add_process(vector<bigint>& values)
{
for (auto& v : values)
v %= modulo;
}
public:
ModuloTreeSum(bigint modulo) :
modulo(modulo)
{
}
};
template<class FD>
Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
{
@@ -57,10 +74,7 @@ Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
}
else
{
TreeSum<bigint>().run(vv, P);
bigint mod=params.p0();
for (auto& v : vv)
v %= mod;
ModuloTreeSum(params.p0()).run(vv, P);
}
// Now get the final message

View File

@@ -76,6 +76,7 @@ void secure_init(T& setup, Player& P, U& machine,
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
string reason;
auto base_setup = setup;
try
{
@@ -107,7 +108,7 @@ void secure_init(T& setup, Player& P, U& machine,
<< " because no suitable material "
"from a previous run was found (" << reason << ")"
<< endl;
setup = {};
setup = base_setup;
setup.generate(P, machine, plaintext_length, sec);
setup.check(P, machine);
octetStream os;
@@ -122,6 +123,10 @@ void secure_init(T& setup, Player& P, U& machine,
cerr << "Ciphertext length: " << params.p0().numBits();
for (size_t i = 1; i < params.FFTD().size(); i++)
cerr << "+" << params.FFTD()[i].get_prime().numBits();
cerr << " (" << DIV_CEIL(params.p0().numBits(), 64);
for (size_t i = 1; i < params.FFTD().size(); i++)
cerr << "+" << DIV_CEIL(params.FFTD()[i].get_prime().numBits(), 64);
cerr << " limbs)";
cerr << endl;
}
}

View File

@@ -33,7 +33,11 @@ public:
{
ProtocolSetup<DealerShare<BitVec>> setup(*P);
ProtocolSet<DealerShare<BitVec>> set(*P, setup);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
int buffer_size = DIV_CEIL(
BaseMachine::batch_size<DealerSecret>(DATA_TRIPLE),
DealerSecret::default_length);
set.preprocessing.buffer_extra(DATA_TRIPLE, buffer_size);
for (int i = 0; i < buffer_size; i++)
{
auto triple = set.preprocessing.get_triple(
DealerSecret::default_length);

View File

@@ -70,6 +70,7 @@ public:
typedef U whole_type;
static const bool expensive_triples = true;
static const bool malicious = true;
static MC* new_mc(typename super::mac_key_type)
{

View File

@@ -26,6 +26,7 @@ public:
typedef Rep4Input<This> Input;
static const bool expensive_triples = false;
static const bool malicious = true;
static MC* new_mc(typename super::mac_key_type) { return new MC; }

View File

@@ -135,6 +135,9 @@ public:
static void run_tapes(const vector<int>& args) { T::run_tapes(args); }
template<class U>
static string proto_fake_opts() { return U::fake_opts(); }
Secret();
Secret(const Integer& x) { *this = x; }

View File

@@ -43,6 +43,9 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol)
void SemiPrep::buffer_triples()
{
assert(this->triple_generator);
this->triple_generator->set_batch_size(
DIV_CEIL(BaseMachine::batch_size<SemiSecret>(DATA_TRIPLE,
this->buffer_size), 64));
this->triple_generator->generatePlainTriples();
for (auto& x : this->triple_generator->plainTriples)
{

View File

@@ -151,6 +151,12 @@ public:
{
}
template<class T>
static string proto_fake_opts()
{
return T::fake_opts();
}
RepSecretBase()
{
}
@@ -258,6 +264,7 @@ public:
typedef SemiHonestRepSecret whole_type;
static const bool expensive_triples = false;
static const bool malicious = false;
static MC* new_mc(mac_key_type) { return new MC; }

View File

@@ -184,12 +184,15 @@ void ShareThread<T>::xors(Processor<T>& processor, const vector<int>& args)
int out = args[i + 1];
int left = args[i + 2];
int right = args[i + 3];
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
{
int n = min(T::default_length, n_bits - j * T::default_length);
processor.S[out + j].xor_(n, processor.S[left + j],
processor.S[right + j]);
}
if (n_bits == 1)
processor.S[out].xor_(1, processor.S[left], processor.S[right]);
else
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
{
int n = min(T::default_length, n_bits - j * T::default_length);
processor.S[out + j].xor_(n, processor.S[left + j],
processor.S[right + j]);
}
}
}

View File

@@ -62,6 +62,7 @@ public:
typedef TinierSecret<T> whole_type;
static const int default_length = 1;
static const bool expensive_triples = true;
static string name()
{

View File

@@ -34,7 +34,9 @@ void TinierSharePrep<T>::buffer_secret_triples()
vector<array<T, 3>> triples;
TripleShuffleSacrifice<T> sacrifice;
size_t required;
required = sacrifice.minimum_n_inputs_with_combining();
required = sacrifice.minimum_n_inputs_with_combining(
BaseMachine::batch_size<T>(DATA_TRIPLE));
triple_generator->set_batch_size(DIV_CEIL(required, 64));
while (triples.size() < required)
{
triple_generator->generatePlainTriples();

View File

@@ -50,7 +50,8 @@ public:
static const bool variable_players = T::variable_players;
static const bool needs_ot = T::needs_ot;
static const bool has_mac = T::has_mac;
static const bool expensive_triples = false;
static const bool malicious = T::malicious;
static const bool expensive_triples = T::expensive_triples;
static const bool randoms_for_opens = false;
static const int default_length = 64;

View File

@@ -82,7 +82,7 @@ CONFIG.mine:
%.o: %.cpp
$(CXX) -o $@ $< $(CFLAGS) -MMD -MP -c
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x mascot-party.x
offline: $(OT_EXE) Check-Offline.x mascot-offline.x cowgear-offline.x mal-shamir-offline.x

View File

@@ -154,9 +154,9 @@ void check_setup(string dir, bigint pr)
string filename = dir + "Params-Data";
ifstream(filename) >> p;
if (p == 0)
throw runtime_error("no modulus in " + filename);
throw setup_error("no modulus in " + filename);
if (p != pr)
throw runtime_error("wrong modulus in " + filename);
throw setup_error("wrong modulus in " + filename);
}
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,

View File

@@ -12,7 +12,7 @@ void ValueInterface::check_setup(const string& directory)
{
struct stat sb;
if (stat(directory.c_str(), &sb) != 0)
throw runtime_error(directory + " does not exist");
throw setup_error(directory + " does not exist");
if (not (sb.st_mode & S_IFDIR))
throw runtime_error(directory + " is not a directory");
throw setup_error(directory + " is not a directory");
}

View File

@@ -109,6 +109,11 @@ public:
void assign(const void* buffer) { avx_memcpy(a, buffer, N_BYTES); normalize(); }
void assign(int x) { *this = x; }
/**
* Get 64-bit part.
*
* @param i return word containing 64*i- to 64*i+63-least significant bits
*/
mp_limb_t get_limb(int i) const { return a[i]; }
bool get_bit(int i) const;

View File

@@ -8,15 +8,17 @@ void Zp_Data::init(const bigint& p,bool mont)
{
lock.lock();
#ifdef VERBOSE
if (pr != 0)
{
#ifdef VERBOSE
if (pr != p)
cerr << "Changing prime from " << pr << " to " << p << endl;
if (mont != montgomery)
cerr << "Changing Montgomery" << endl;
}
#endif
if (pr != p or mont != montgomery)
throw runtime_error("Zp_Data instance already initialized");
}
if (not probPrime(p))
throw runtime_error(p.get_str() + " is not a prime");

View File

@@ -47,7 +47,7 @@ class Zp_Data
void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const;
void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
{ Mont_Mult(z, x, y, t); }
{ Mont_Mult(z, x, y, get_t()); }
void Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x, const mp_limb_t* y,
int max_t) const;
@@ -61,7 +61,7 @@ class Zp_Data
void assign(const Zp_Data& Zp);
void init(const bigint& p,bool mont=true);
int get_t() const { return t; }
int get_t() const { assert(t > 0); return t; }
const mp_limb_t* get_prA() const { return prA; }
bool get_mont() const { return montgomery; }
mp_limb_t overhang_mask() const;
@@ -73,8 +73,9 @@ class Zp_Data
Zp_Data() :
montgomery(0), pi(0), mask(0), pr_byte_length(0), pr_bit_length(0)
{
t = MAX_MOD_SZ;
t = -1;
overhang = 0;
shanks_r = 0;
}
// The main init funciton

View File

@@ -91,6 +91,7 @@ public:
template<int K>
bigint& operator=(const SignedZ2<K>& x);
/// Convert to signed representation in :math:`[-p/2,p/2]`.
template<int X, int L>
bigint& from_signed(const gfp_<X, L>& other);
template<class T>

View File

@@ -61,6 +61,8 @@ class gfp_ : public ValueInterface
static thread_local vector<gfp_> powers;
static gfp_ two;
public:
typedef gfp_ value_type;
@@ -317,6 +319,8 @@ gfp_<X, L>::gfp_(long x)
assign_zero();
else if (x == 1)
assign_one();
else if (x == 2)
*this = two;
else
*this = bigint::tmp = x;
}

View File

@@ -16,6 +16,8 @@ template<int X, int L>
const true_type gfp_<X, L>::prime_field;
template<int X, int L>
const int gfp_<X, L>::MAX_N_BITS;
template<int X, int L>
gfp_<X, L> gfp_<X, L>::two;
template<int X, int L>
inline void gfp_<X, L>::read_or_generate_setup(string dir,
@@ -50,6 +52,7 @@ void gfp_<X, L>::init_field(const bigint& p, bool mont)
else
cerr << name << " larger than necessary for modulus " << p << endl;
}
two = bigint::tmp = 2;
}
template <int X, int L>

View File

@@ -80,6 +80,12 @@ void gfpvar_<X, L>::init_default(int lgp, bool montgomery)
init_field(SPDZ_Data_Setup_Primes(lgp), montgomery);
}
template<int X, int L>
inline void gfpvar_<X, L>::reset()
{
ZpD = {};
}
template<int X, int L>
const Zp_Data& gfpvar_<X, L>::get_ZpD()
{

View File

@@ -68,6 +68,7 @@ public:
{
init_field(T::pr(), montgomery);
}
static void reset();
static const Zp_Data& get_ZpD();
static const bigint& pr();

View File

@@ -102,7 +102,7 @@ bool isZero(const modp_<L>& ans,const Zp_Data& ZpD)
template<int L>
void assignOne(modp_<L>& x,const Zp_Data& ZpD)
{ if (ZpD.montgomery)
{ mpn_copyi(x.x,ZpD.R,ZpD.t); }
{ mpn_copyi(x.x,ZpD.R,ZpD.get_t()); }
else
{ assignZero(x,ZpD);
x.x[0]=1;
@@ -177,7 +177,7 @@ void modp_<L>::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const
template<int L>
void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD)
{
inline_mpn_zero(ans.x,ZpD.t);
inline_mpn_zero(ans.x,ZpD.get_t());
if (x>=0)
{ ans.x[0]=x;
if (ZpD.t==1) { ans.x[0]=ans.x[0]%ZpD.prA[0]; }
@@ -232,13 +232,13 @@ void modp_<L>::convert_destroy(const fixint<M>& xx,
SignedZ2<64 * L> tmp = xx;
if (xx.negative())
tmp += ZpD.pr;
convert(tmp.get(), ZpD.t, ZpD, false);
convert(tmp.get(), ZpD.get_t(), ZpD, false);
}
template<int L>
void modp_<L>::convert(const mp_limb_t* source, mp_size_t size, const Zp_Data& ZpD, bool negative)
{
assert(size <= ZpD.t);
assert(size <= ZpD.get_t());
if (negative)
mpn_sub(x, ZpD.prA, ZpD.t, source, size);
else

View File

@@ -110,6 +110,8 @@ public:
Player* parentPlayer = 0);
~OTTripleGenerator();
void set_batch_size(int nTriples);
void generate() { throw not_implemented(); }
void generatePlainTriples();

View File

@@ -69,6 +69,14 @@ Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(const OTTripleSetup& setup,
{
}
template<class T>
void OTTripleGenerator<T>::set_batch_size(int batch_size)
{
nTriplesPerLoop = DIV_CEIL(batch_size, nloops);
nTriples = nTriplesPerLoop * nloops;
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
}
template<class T>
OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
@@ -84,11 +92,9 @@ OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
machine(machine),
MC(0)
{
nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
nTriples = nTriplesPerLoop * nloops;
field_size = T::open_type::size() * 8;
nAmplify = machine.amplify ? N_AMPLIFY : 1;
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
set_batch_size(_nTriples);
int n = nparties;
//baseReceiverInput = machines[0]->baseReceiverInput;

View File

@@ -13,6 +13,7 @@ using namespace std;
#include "OT/OTVole.h"
#include "OT/Rectangle.h"
#include "Tools/random.h"
#include "Tools/CheckVector.h"
template<class T>
class NPartyTripleGenerator;
@@ -187,7 +188,7 @@ class SemiMultiplier : public OTMultiplier<T>
}
public:
vector<typename T::open_type> c_output;
CheckVector<typename T::open_type> c_output;
SemiMultiplier(OTTripleGenerator<T>& generator, int i) :
OTMultiplier<T>(generator, i)

View File

@@ -163,10 +163,10 @@ void SemiMultiplier<T>::multiplyForBits()
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
int n_squares = otCorrelator.receiverOutputMatrix.squares.size();
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
baseReceiverOutput);
otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(),
this->generator.valueBits[0], false, -1);
otCorrelator.correlate(0, n_squares, aBits, false, -1);
c_output.clear();

View File

@@ -114,7 +114,7 @@ void OTVoleBase<T>::hash_row(__m128i res[2], const U& row,
int num_blocks = DIV_CEIL(row.size() * T::size(), 16);
__m128i buffer[T::size()];
size_t next = 0;
while (next + 16 < row.size())
while (next + 16 <= row.size())
{
for (int j = 0; j < 16; j++)
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
@@ -124,6 +124,8 @@ void OTVoleBase<T>::hash_row(__m128i res[2], const U& row,
for (int j = 0; j < 16; j++)
if (next < row.size())
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
else
memset((char*) buffer + j * T::size(), 0, T::size());
for (int j = 0; j < num_blocks % T::size(); j++)
add_mul(res, buffer[j], *coefficients++);
assert(coefficients == coeff_base + num_blocks);

View File

@@ -8,6 +8,8 @@
#include "Math/Setup.h"
#include "Tools/Bundle.h"
#include "Protocols/ShuffleSacrifice.hpp"
#include <iostream>
#include <sodium.h>
using namespace std;
@@ -30,6 +32,28 @@ BaseMachine& BaseMachine::s()
throw runtime_error("no singleton");
}
bool BaseMachine::has_program()
{
return has_singleton() and not s().progs.empty();
}
int BaseMachine::edabit_bucket_size(int n_bits)
{
int res = OnlineOptions::singleton.bucket_size;
if (has_program())
{
auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
for (int B = res; B <= 5; B++)
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
break;
else
res = B;
}
return res;
}
BaseMachine::BaseMachine() : nthreads(0)
{
if (sodium_init() == -1)

View File

@@ -11,6 +11,8 @@
#include "OT/OTTripleSetup.h"
#include "ThreadJob.h"
#include "ThreadQueues.h"
#include "Program.h"
#include "OnlineOptions.h"
#include <map>
#include <fstream>
@@ -44,8 +46,11 @@ public:
vector<string> bc_filenames;
vector<Program> progs;
static BaseMachine& s();
static bool has_singleton() { return singleton != 0; }
static bool has_program();
static string memory_filename(const string& type_short, int my_number);
@@ -54,6 +59,12 @@ public:
static int prime_length_from_schedule(string progname);
static bigint prime_from_schedule(string progname);
template<class T>
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
template<class T>
static int edabit_batch_size(int n_bits, int buffer_size = 0);
static int edabit_bucket_size(int n_bits);
BaseMachine();
virtual ~BaseMachine() {}
@@ -76,6 +87,8 @@ public:
void print_global_comm(Player& P, const NamedCommStats& stats);
void print_comm(Player& P, const NamedCommStats& stats);
virtual const Names& get_N() { throw not_implemented(); }
};
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
@@ -83,4 +96,105 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
return ot_setup.get_fresh(P);
}
template<class T>
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
{
int n_opts;
int n = 0;
int res = 0;
if (buffer_size > 0)
n_opts = buffer_size;
else if (fallback > 0)
n_opts = fallback;
else
n_opts = OnlineOptions::singleton.batch_size;
if (buffer_size <= 0 and has_program())
{
auto files = s().progs[0].get_offline_data_used().files;
auto usage = files[T::clear::field_type()];
if (type == DATA_DABIT and T::LivePrep::bits_from_dabits())
n = usage[DATA_BIT] + usage[DATA_DABIT];
else if (type == DATA_BIT and T::LivePrep::dabits_from_bits())
n = usage[DATA_BIT] + usage[DATA_DABIT];
else
n = usage[type];
}
else if (type != DATA_DABIT)
{
n = buffer_size;
buffer_size = 0;
n_opts = OnlineOptions::singleton.batch_size;
}
if (n > 0 and not (buffer_size > 0))
{
bool used_frac = false;
if (n > n_opts)
{
// finding the right fraction
for (int i = 1; i <= 10; i++)
{
int frac = DIV_CEIL(n, i);
if (frac <= n_opts)
{
res = frac;
used_frac = true;
#ifdef DEBUG_BATCH_SIZE
cerr << "found fraction " << frac << endl;
#endif
break;
}
}
}
if (not used_frac)
res = min(n, n_opts);
}
else
res = n_opts;
#ifdef DEBUG_BATCH_SIZE
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
<< " res=" << res << " n="
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
#endif
assert(res > 0);
return res;
}
template<class T>
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
{
int n_opts;
int n = 0;
int res;
if (buffer_size > 0)
n_opts = buffer_size;
else
n_opts = OnlineOptions::singleton.batch_size;
if (has_program())
{
n = s().progs[0].get_offline_data_used().total_edabits(n_bits);
}
if (n > 0 and not (buffer_size > 0))
res = min(n, n_opts);
else
res = n_opts;
#ifdef DEBUG_BATCH_SIZE
cerr << "edaBits " << T::type_string() << " (" << n_bits
<< ") res=" << res << " n="
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
#endif
assert(res > 0);
return res;
}
#endif /* PROCESSOR_BASEMACHINE_H_ */

View File

@@ -229,3 +229,9 @@ bool DataPositions::any_more(const DataPositions& other) const
return false;
}
long long DataPositions::total_edabits(int n_bits) const
{
auto usage = edabits;
return usage[{false, n_bits}] + usage[{true, n_bits}];
}

View File

@@ -85,6 +85,8 @@ public:
void print_cost() const;
bool empty() const;
bool any_more(const DataPositions& other) const;
long long total_edabits(int n_bits) const;
};
template<class sint, class sgf2n> class Processor;
@@ -229,6 +231,10 @@ public:
static long additional_inputs(const DataPositions& usage);
static string get_prep_dir(const Names& N);
static void check_setup(const Names& N);
static void check_setup(int num_players, const string& prep_dir);
Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir,
DataPositions& usage, int thread_num = -1);
Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num = -1);
@@ -299,7 +305,7 @@ class Data_Files
Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<sint>* procp = 0,
SubProcessor<sgf2n>* proc2 = 0);
Data_Files(const Names& N);
Data_Files(const Names& N, int thread_num = -1);
~Data_Files();
DataPositions tellg() { return usage; }

View File

@@ -60,8 +60,7 @@ T Preprocessing<T>::get_random_from_inputs(int nplayers)
template<class T>
Sub_Data_Files<T>::Sub_Data_Files(const Names& N, DataPositions& usage,
int thread_num) :
Sub_Data_Files(N,
OnlineOptions::singleton.prep_dir_prefix<T>(N.num_players()), usage,
Sub_Data_Files(N, get_prep_dir(N), usage,
thread_num)
{
}
@@ -98,6 +97,32 @@ string Sub_Data_Files<T>::get_edabit_filename(const Names& N, int n_bits,
get_prep_sub_dir<T>(N.num_players()), n_bits, N.my_num(), thread_num);
}
template<class T>
string Sub_Data_Files<T>::get_prep_dir(const Names& N)
{
return OnlineOptions::singleton.prep_dir_prefix<T>(N.num_players());
}
template<class T>
void Sub_Data_Files<T>::check_setup(const Names& N)
{
return check_setup(N.num_players(), get_prep_dir(N));
}
template<class T>
void Sub_Data_Files<T>::check_setup(int num_players, const string& prep_dir)
{
try
{
T::clear::check_setup(prep_dir);
}
catch (exception& e)
{
throw prep_setup_error(e.what(), num_players,
T::template proto_fake_opts<typename T::clear>());
}
}
template<class T>
Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
const string& prep_data_dir, DataPositions& usage, int thread_num) :
@@ -109,19 +134,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
#endif
try
{
T::clear::check_setup(prep_data_dir);
}
catch (...)
{
cerr << "Something is wrong with the preprocessing data on disk." << endl;
cerr
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
<< num_players
<< T::clear::fake_opts() << "'?" << endl;
throw;
}
check_setup(num_players, prep_data_dir);
string type_short = T::type_short();
string type_string = T::type_string();
@@ -173,11 +186,11 @@ Data_Files<sint, sgf2n>::Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<
}
template<class sint, class sgf2n>
Data_Files<sint, sgf2n>::Data_Files(const Names& N) :
Data_Files<sint, sgf2n>::Data_Files(const Names& N, int thread_num) :
usage(N.num_players()),
DataFp(*new Sub_Data_Files<sint>(N, usage)),
DataF2(*new Sub_Data_Files<sgf2n>(N, usage)),
DataFb(*new Sub_Data_Files<typename sint::bit_type>(N, usage))
DataFp(*new Sub_Data_Files<sint>(N, usage, thread_num)),
DataF2(*new Sub_Data_Files<sgf2n>(N, usage, thread_num)),
DataFb(*new Sub_Data_Files<typename sint::bit_type>(N, usage, thread_num))
{
}

View File

@@ -112,6 +112,8 @@ template<class T>
class DummyLivePrep : public Preprocessing<T>
{
public:
static const bool homomorphic = true;
static void basic_setup(Player&)
{
}
@@ -125,6 +127,11 @@ public:
"live preprocessing not implemented for " + T::type_string());
}
static bool bits_from_dabits()
{
return false;
}
DummyLivePrep(DataPositions& usage, GC::ShareThread<T>&) :
Preprocessing<T>(usage)
{

View File

@@ -29,7 +29,12 @@ public:
if (not BufferBase::file)
{
if (this->open()->fail())
throw runtime_error("error opening " + this->filename);
throw runtime_error(
"error opening " + this->filename
+ ", have you generated edaBits, "
"for example by running "
"'./Fake-Offline.x -e "
+ to_string(n_bits) + " ...'?");
}
assert(BufferBase::file);

View File

@@ -6,6 +6,7 @@
#include "Instruction.h"
#include "instructions.h"
#include "Processor.h"
#include "Memory.h"
#include "Math/gf2n.h"
#include "GC/instructions.h"
@@ -54,7 +55,7 @@ void Instruction::gbitcom(vector<cgf2n>& registers) const
}
}
void Instruction::execute_regint(ArithmeticProcessor& Proc, vector<Integer>& Mi) const
void Instruction::execute_regint(ArithmeticProcessor& Proc, MemoryPart<Integer>& Mi) const
{
(void) Mi;
auto& Ci = Proc.get_Ci();

View File

@@ -14,6 +14,7 @@ using namespace std;
template<class sint, class sgf2n> class Machine;
template<class sint, class sgf2n> class Processor;
template<class T> class SubProcessor;
template<class T> class MemoryPart;
class ArithmeticProcessor;
class SwitchableOutput;
@@ -86,6 +87,8 @@ enum
SUBCFI = 0x2B,
SUBSFI = 0x2C,
PREFIXSUMS = 0x2D,
PICKS = 0x2E,
CONCATS = 0x2F,
// Multiplication/division/other arithmetic
MULC = 0x30,
MULM = 0x31,
@@ -392,7 +395,7 @@ public:
template<class cgf2n>
void gbitcom(vector<cgf2n>& registers) const;
void execute_regint(ArithmeticProcessor& Proc, vector<Integer>& Mi) const;
void execute_regint(ArithmeticProcessor& Proc, MemoryPart<Integer>& Mi) const;
void shuffle(ArithmeticProcessor& Proc) const;
void bitdecint(ArithmeticProcessor& Proc) const;

View File

@@ -208,6 +208,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_ints(r, s, 2);
n = get_int(s);
break;
case PICKS:
get_ints(r, s, 3);
n = get_int(s);
break;
case USE:
case USE_INP:
case USE_EDABIT:
@@ -392,6 +396,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case EDABIT:
case SEDABIT:
case WRITEFILESHARE:
case CONCATS:
num_var_args = get_int(s) - 1;
r[0] = get_int(s);
get_vector(num_var_args, start, s);
@@ -930,6 +935,20 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case MOVC:
Proc.write_Cp(r[0],Proc.read_Cp(r[1]));
break;
case CONCATS:
{
auto& S = Proc.Procp.get_S();
auto dest = S.begin() + r[0];
for (auto j = start.begin(); j < start.end(); j += 2)
{
auto source = S.begin() + *(j + 1);
assert(dest + *j <= S.end());
assert(source + *j <= S.end());
for (int k = 0; k < *j; k++)
*dest++ = *source++;
}
return;
}
case DIVC:
Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2]));
break;

View File

@@ -53,8 +53,6 @@ class Machine : public BaseMachine
public:
vector<Program> progs;
Memory<sgf2n> M2;
Memory<sint> Mp;
Memory<Integer> Mi;
@@ -63,10 +61,6 @@ class Machine : public BaseMachine
vector<Timer> join_timer;
Timer finish_timer;
bool direct;
int opening_sum;
bool receive_threads;
int max_broadcast;
bool use_encryption;
bool live_prep;

View File

@@ -55,8 +55,6 @@ template<class sint, class sgf2n>
Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
const OnlineOptions opts, int lg2)
: my_number(playerNames.my_num()), N(playerNames),
direct(opts.direct), opening_sum(opts.opening_sum),
receive_threads(opts.receive_threads), max_broadcast(opts.max_broadcast),
use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts),
external_clients(my_number)
{
@@ -69,11 +67,6 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
exit(1);
}
if (opening_sum < 2)
this->opening_sum = N.num_players();
if (max_broadcast < 2)
this->max_broadcast = N.num_players();
// Set the prime modulus from command line or program if applicable
if (opts.prime)
sint::clear::init_field(opts.prime);
@@ -102,7 +95,17 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
sint::bit_type::MAC_Check::setup(*P);
sgf2n::MAC_Check::setup(*P);
alphapi = read_generate_write_mac_key<sint>(*P);
if (opts.live_prep)
alphapi = read_generate_write_mac_key<sint>(*P);
else
{
// check for directory
Sub_Data_Files<sint>::check_setup(N);
// require existing MAC key
if (sint::has_mac)
read_mac_key<sint>(N, alphapi);
}
alpha2i = read_generate_write_mac_key<sgf2n>(*P);
alphabi = read_generate_write_mac_key<typename
sint::bit_type::part_type>(*P);
@@ -451,6 +454,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
finish_timer.start();
// actual usage
bool multithread = nthreads > 1;
auto res = stop_threads();
DataPositions& pos = res.first;
@@ -479,7 +483,10 @@ void Machine<sint, sgf2n>::run(const string& progname)
cerr << "Communication details "
"(rounds in parallel threads counted double):" << endl;
comm_stats.print();
cerr << "CPU time = " << proc_timer.elapsed() << endl;
cerr << "CPU time = " << proc_timer.elapsed();
if (multithread)
cerr << " (overall core time)";
cerr << endl;
}
print_timers();

View File

@@ -19,6 +19,28 @@ template<class T>
class MemoryPart : public CheckVector<T>
{
public:
template<class U>
static void check_index(const vector<U>& M, size_t i)
{
(void) M, (void) i;
#ifndef NO_CHECK_INDEX
if (i >= M.size())
throw overflow(U::type_string() + " memory", i, M.size());
#endif
}
T& operator[](size_t i)
{
check_index(*this, i);
return CheckVector<T>::operator[](i);
}
const T& operator[](size_t i) const
{
check_index(*this, i);
return CheckVector<T>::operator[](i);
}
void minimum_size(size_t size);
};
@@ -40,35 +62,21 @@ class Memory
size_t size_c()
{ return MC.size(); }
template<class U>
static void check_index(const vector<U>& M, size_t i)
{
(void) M, (void) i;
#ifndef NO_CHECK_INDEX
if (i >= M.size())
throw overflow(U::type_string() + " memory", i, M.size());
#endif
}
const typename T::clear& read_C(size_t i) const
{
check_index(MC, i);
return MC[i];
}
const T& read_S(size_t i) const
{
check_index(MS, i);
return MS[i];
}
void write_C(size_t i,const typename T::clear& x)
{
check_index(MC, i);
MC[i]=x;
}
void write_S(size_t i,const T& x)
{
check_index(MS, i);
MS[i]=x;
}

View File

@@ -12,13 +12,11 @@
#include "Networking/CryptoPlayer.h"
template<class W>
class OfflineMachine : public W
class OfflineMachine : public W, BaseMachine
{
DataPositions usage;
BaseMachine machine;
Names& playerNames;
Player& P;
int n_threads;
template<class T>
void generate();
@@ -34,6 +32,8 @@ public:
template<class T, class U>
int run();
const Names& get_N();
};
#endif /* PROCESSOR_OFFLINEMACHINE_H_ */

View File

@@ -18,18 +18,19 @@ OfflineMachine<W>::OfflineMachine(int argc, const char** argv,
W(argc, argv, opt, online_opts, V(), nplayers), playerNames(
W::playerNames), P(*this->new_player("machine"))
{
machine.load_schedule(online_opts.progname, false);
load_schedule(online_opts.progname, false);
Program program(playerNames.num_players());
program.parse(machine.bc_filenames[0]);
program.parse(bc_filenames[0]);
progs.push_back(program);
if (program.usage_unknown())
{
cerr << "Preprocessing might be insufficient "
cerr << "Preprocessing will be insufficient "
<< "due to unknown requirements" << endl;
exit(1);
}
usage = program.get_offline_data_used();
n_threads = machine.nthreads;
}
template<class W>
@@ -73,7 +74,7 @@ int OfflineMachine<W>::run()
template<class W>
int OfflineMachine<W>::buffered_total(size_t required, size_t batch)
{
return DIV_CEIL(required, batch) * batch + (n_threads - 1) * batch;
return DIV_CEIL(required, batch) * batch + (nthreads - 1) * batch;
}
template<class W>
@@ -183,4 +184,10 @@ void OfflineMachine<W>::generate()
output.Check(P);
}
template<class W>
const Names& OfflineMachine<W>::get_N()
{
return playerNames;
}
#endif /* PROCESSOR_OFFLINEMACHINE_HPP_ */

View File

@@ -43,6 +43,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
BaseMachine::s().thread_num = num;
auto& queues = machine.queues[num];
auto& opts = machine.opts;
queues->next();
ThreadQueue::thread_queue = queues;
@@ -58,7 +59,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
#endif
player = new CryptoPlayer(*(tinfo->Nms), id);
}
else if (!machine.receive_threads or machine.direct)
else if (!opts.receive_threads or opts.direct)
{
#ifdef VERBOSE_OPTIONS
cerr << "Using single-threaded receiving" << endl;
@@ -80,7 +81,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
typename sgf2n::MAC_Check* MC2;
typename sint::MAC_Check* MCp;
if (machine.direct)
if (opts.direct)
{
#ifdef VERBOSE_OPTIONS
cerr << "Using direct communication." << endl;
@@ -93,8 +94,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
#ifdef VERBOSE_OPTIONS
cerr << "Using indirect communication." << endl;
#endif
MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), machine.opening_sum, machine.max_broadcast);
MCp = new typename sint::MAC_Check(*(tinfo->alphapi), machine.opening_sum, machine.max_broadcast);
MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), opts.opening_sum, opts.max_broadcast);
MCp = new typename sint::MAC_Check(*(tinfo->alphapi), opts.opening_sum, opts.max_broadcast);
}
// Allocate memory for first program before starting the clock
@@ -376,6 +377,10 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
{
ti.Sub_Main_Func();
}
catch (setup_error&)
{
throw;
}
catch (...)
{
thread_info<sint, sgf2n>* ti = (thread_info<sint, sgf2n>*)ptr;
@@ -393,16 +398,20 @@ void thread_info<sint, sgf2n>::purge_preprocessing(const Names& N, int thread_nu
cerr << "Purging preprocessed data because something is wrong" << endl;
try
{
Data_Files<sint, sgf2n> df(N);
Data_Files<sint, sgf2n> df(N, thread_num);
df.purge();
DataPositions pos;
Sub_Data_Files<typename sint::bit_type> bit_df(N, pos, thread_num);
bit_df.get_part();
bit_df.purge();
}
catch(...)
catch(setup_error&)
{
}
catch(exception& e)
{
cerr << "Purging failed. This might be because preprocessed data is incomplete." << endl
<< "SECURITY FAILURE; YOU ARE ON YOUR OWN NOW!" << endl;
cerr << "Reason: " << e.what() << endl;
}
}

View File

@@ -41,25 +41,28 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir,
}
void PrepBase::print_left(const char* name, size_t n, const string& type_string,
size_t used)
size_t used, bool large)
{
if (n > 0 and OnlineOptions::singleton.verbose)
cerr << "\t" << n << " " << name << " of " << type_string << " left"
<< endl;
if (n > used / 10)
if (n > used / 10 and n >= 64)
{
cerr << "Significant amount of unused " << name << " of " << type_string
<< ". For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size." << endl;
cerr
<< "Note that some protocols have larger minimum batch sizes."
<< endl;
<< " distorting the benchmark. ";
if (large)
cerr << "This protocol has a large minimum batch size, "
<< "which makes this unavoidable for small programs.";
else
cerr << "For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size.";
cerr << endl;
}
}
void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used)
int n_bits, size_t used, bool malicious)
{
if (n > 0 and OnlineOptions::singleton.verbose)
{
@@ -70,8 +73,15 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
}
if (n * n_batch > used / 10)
{
cerr << "Significant amount of unused edaBits of size " << n_bits
<< ". For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size "
<< "or increasing the bucket size with --bucket-size." << endl;
<< ". ";
if (malicious)
cerr << "This protocol has a large minimum batch size, "
<< "which makes this unavoidable for small programs.";
else
cerr << "For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size.";
cerr << endl;
}
}

View File

@@ -26,9 +26,9 @@ public:
int my_num, int thread_num = 0);
static void print_left(const char* name, size_t n,
const string& type_string, size_t used);
const string& type_string, size_t used, bool large = false);
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used);
int n_bits, size_t used, bool malicious);
TimerWithComm prep_timer;
};

View File

@@ -65,6 +65,8 @@
X(PREFIXSUMS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
sint s, \
s += *op1++; *dest++ = s) \
X(PICKS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1] + r[2]], \
*dest++ = *op1; op1 += int(n)) \
X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
auto op2 = &Procp.get_C()[r[2]], \
*dest++ = *op1++ * *op2++) \

View File

@@ -38,7 +38,7 @@ except:
try:
batch_size = int(program.args[2])
except:
batch_size = N
batch_size = min(N, 128)
if 'savemem' in program.args:
N = batch_size

View File

@@ -38,7 +38,7 @@ except:
try:
batch_size = int(program.args[2])
except:
batch_size = N
batch_size = min(N, 128)
if 'savemem' in program.args:
N = batch_size

View File

@@ -38,7 +38,7 @@ except:
try:
batch_size = int(program.args[2])
except:
batch_size = N
batch_size = min(N, 128)
assert batch_size <= N
ml.Layer.back_batch_size = batch_size

View File

@@ -26,12 +26,14 @@ exec(subprocess.check_output(['Scripts/process-tf.py', program.args[1]]))
opt = ml.Optimizer()
opt.set_layers_with_inputs(layers)
layers[0].X.input_from(0)
layers[0].X.input_from(0, binary=True)
for layer in layers:
layer.input_from(0, raw='raw' in program.args)
layer.input_from(0, binary=True)
sint(0).reveal().store_in_mem(0)
opt.time_layers = 'time_layers' in program.args
start_timer(1)
opt.forward(1, keep_intermediate=False)
stop_timer(1)

View File

@@ -115,7 +115,7 @@ void BrainPrep<T>::buffer_triples()
+ to_string(ZProtocol<T>::share_type::clear::N_BITS)
+ "-bit integer computation");
typedef Rep3Share<gfp2> pShare;
auto buffer_size = OnlineOptions::singleton.batch_size;
auto buffer_size = BaseMachine::batch_size<T>(DATA_TRIPLE);
Player& P = this->protocol->P;
vector<array<ZShare<T>, 3>> triples;
vector<array<Rep3Share<gfp2>, 3>> check_triples;

32
Protocols/BufferScope.h Normal file
View File

@@ -0,0 +1,32 @@
/*
* BufferScope.h
*
*/
#ifndef PROTOCOLS_BUFFERSCOPE_H_
#define PROTOCOLS_BUFFERSCOPE_H_
template<class T> class BufferPrep;
template<class T> class Preprocessing;
template<class T>
class BufferScope
{
BufferPrep<T>& prep;
int bak;
public:
BufferScope(Preprocessing<T> & prep, int buffer_size) :
prep(dynamic_cast<BufferPrep<T>&>(prep))
{
bak = this->prep.buffer_size;
this->prep.buffer_size = buffer_size;
}
~BufferScope()
{
prep.buffer_size = bak;
}
};
#endif /* PROTOCOLS_BUFFERSCOPE_H_ */

View File

@@ -33,6 +33,8 @@ class ChaiGearPrep : public MaliciousRingPrep<T>
void buffer_bits(false_type);
public:
static const bool homomorphic = true;
static void basic_setup(Player& P);
static void key_setup(Player& P, mac_key_type alphai);
static void teardown();

View File

@@ -33,6 +33,8 @@ class CowGearPrep : public MaliciousRingPrep<T>
void buffer_bits(false_type);
public:
static const bool homomorphic = true;
static void basic_setup(Player& P);
static void key_setup(Player& P, mac_key_type alphai);
static void setup(Player& P, mac_key_type alphai);

View File

@@ -112,9 +112,6 @@ PairwiseGenerator<typename T::clear::FD>& CowGearPrep<T>::get_generator()
{
auto& machine = *pairwise_machine;
typedef typename T::open_type::FD FD;
// generate minimal number of items
this->buffer_size = min(machine.setup<FD>().alpha.num_slots(),
(unsigned)OnlineOptions::singleton.batch_size);
pairwise_generator = new PairwiseGenerator<FD>(0, machine, &proc->P);
}
return *pairwise_generator;

View File

@@ -6,21 +6,28 @@
#ifndef PROTOCOLS_DABITSACRIFICE_H_
#define PROTOCOLS_DABITSACRIFICE_H_
#include "Processor/BaseMachine.h"
template<class T>
class DabitSacrifice
{
const int S;
size_t n_masks, n_produced;
public:
DabitSacrifice();
~DabitSacrifice();
int minimum_n_inputs(int n_outputs = 0)
{
if (n_outputs < 1)
n_outputs = OnlineOptions::singleton.batch_size;
if (T::clear::N_BITS < 0)
// sacrifice uses S^2 random bits
n_outputs = max(n_outputs, 10 * S * S);
n_outputs = BaseMachine::batch_size<T>(DATA_DABIT,
n_outputs, max(n_outputs, 10 * S * S));
else
n_outputs = BaseMachine::batch_size<T>(DATA_DABIT, n_outputs);
assert(n_outputs > 0);
return n_outputs + S;
}

View File

@@ -7,13 +7,15 @@
#define PROTOCOLS_DABITSACRIFICE_HPP_
#include "DabitSacrifice.h"
#include "BufferScope.h"
#include "Tools/PointerVector.h"
#include <math.h>
template<class T>
DabitSacrifice<T>::DabitSacrifice() :
S(OnlineOptions::singleton.security_parameter)
S(OnlineOptions::singleton.security_parameter),
n_masks(0), n_produced()
{
}
@@ -36,10 +38,15 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
timer.start();
#endif
int n = check_dabits.size() - S;
n_masks += S;
assert(n > 0);
GlobalPRNG G(proc.P);
typedef typename T::bit_type::part_type BT;
vector<T> shares;
vector<BT> bit_shares;
if (T::clear::N_BITS <= 0)
dynamic_cast<BufferPrep<T>&>(proc.DataF).buffer_extra(DATA_BIT,
S * (ceil(log2(n)) + S));
for (int i = 0; i < S; i++)
{
dabit<T> to_check;
@@ -58,6 +65,7 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
T tmp;
proc.DataF.get_one(DATA_BIT, tmp);
masked += tmp << (1 + j);
n_masks++;
}
shares.push_back(masked);
bit_shares.push_back(to_check.second);
@@ -84,6 +92,7 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
}
}
dabits.insert(dabits.end(), check_dabits.begin(), check_dabits.begin() + n);
n_produced += n;
MCBB.Check(proc.P);
delete &MCBB;
#ifdef VERBOSE_DABIT
@@ -92,6 +101,17 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
#endif
}
template<class T>
DabitSacrifice<T>::~DabitSacrifice()
{
#ifdef DABIT_WASTAGE
if (n_produced > 0)
{
cerr << "daBit wastage: " << float(n_masks) / n_produced << endl;
}
#endif
}
template<class T>
void DabitSacrifice<T>::sacrifice_and_check_bits(vector<dabit<T> >& dabits,
vector<dabit<T> >& check_dabits, SubProcessor<T>& proc,
@@ -113,7 +133,10 @@ void DabitSacrifice<T>::sacrifice_and_check_bits(vector<dabit<T> >& dabits,
queues->wrap_up(job);
}
else
{
BufferScope<T> scope(proc.DataF, multiplicands.size());
protocol.multiply(products, multiplicands, 0, multiplicands.size(), proc);
}
vector<T> check_for_zero;
for (auto& x : to_check)
check_for_zero.push_back(x.first - products.next());

View File

@@ -17,11 +17,13 @@ void DealerPrep<T>::buffer_triples()
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
int buffer_size = BaseMachine::batch_size<T>(DATA_TRIPLE,
this->buffer_size);
if (this->proc->input.is_dealer())
{
SeededPRNG G;
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
{
T triples[3];
for (int i = 0; i < 2; i++)
@@ -41,7 +43,7 @@ void DealerPrep<T>::buffer_triples()
else
{
P.send_receive_all(senders, os, to_receive);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
this->triples.push_back(to_receive.back().get<FixedVec<T, 3>>().get());
}
}
@@ -68,11 +70,12 @@ void DealerPrep<T>::buffer_inverses(true_type)
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
int buffer_size = BaseMachine::batch_size<T>(DATA_INVERSE);
if (this->proc->input.is_dealer())
{
SeededPRNG G;
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
{
T tuple[2];
while (tuple[0] == 0)
@@ -92,7 +95,7 @@ void DealerPrep<T>::buffer_inverses(true_type)
else
{
P.send_receive_all(senders, os, to_receive);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
this->inverses.push_back(to_receive.back().get<FixedVec<T, 2>>().get());
}
}
@@ -105,11 +108,12 @@ void DealerPrep<T>::buffer_bits()
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
int buffer_size = BaseMachine::batch_size<T>(DATA_BIT);
if (this->proc->input.is_dealer())
{
SeededPRNG G;
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
{
T bit = G.get_bit();
make_share(shares.data(), typename T::clear(bit),
@@ -123,7 +127,7 @@ void DealerPrep<T>::buffer_bits()
else
{
P.send_receive_all(senders, os, to_receive);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
this->bits.push_back(to_receive.back().get<T>());
}
}
@@ -136,12 +140,13 @@ void DealerPrep<T>::buffer_dabits(ThreadQueues*)
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
int buffer_size = BaseMachine::batch_size<T>(DATA_DABIT);
if (this->proc->input.is_dealer())
{
SeededPRNG G;
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
vector<GC::SemiSecret> bit_shares(P.num_players() - 1);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
{
auto bit = G.get_bit();
make_share(shares.data(), typename T::clear(bit),
@@ -160,7 +165,7 @@ void DealerPrep<T>::buffer_dabits(ThreadQueues*)
else
{
P.send_receive_all(senders, os, to_receive);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
for (int i = 0; i < buffer_size; i++)
{
this->dabits.push_back({to_receive.back().get<T>(),
to_receive.back().get<typename T::bit_type>()});
@@ -200,7 +205,8 @@ void DealerPrep<T>::buffer_edabits(int length, false_type)
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
int n_vecs = OnlineOptions::singleton.batch_size / edabitvec<T>::MAX_SIZE;
int n_vecs = DIV_CEIL(BaseMachine::edabit_batch_size<T>(length),
edabitvec<T>::MAX_SIZE);
auto& buffer = this->edabits[{false, length}];
if (this->proc->input.is_dealer())
{

View File

@@ -28,6 +28,8 @@ class HemiMatrixPrep : public BufferPrep<ShareMatrix<T>>
HemiMatrixPrep(const HemiMatrixPrep&) = delete;
public:
static const bool homomorphic = true;
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep,
DataPositions& usage) :
super(usage), n_rows(n_rows), n_inner(n_inner),

View File

@@ -35,6 +35,8 @@ class HemiPrep : public SemiHonestRingPrep<T>
SemiPrep<T>& get_two_party_prep();
public:
static const bool homomorphic = true;
static void basic_setup(Player& P);
static void teardown();

View File

@@ -142,6 +142,8 @@ void HemiPrep<T>::buffer_bits()
if (this->proc->P.num_players() == 2)
{
auto& prep = get_two_party_prep();
prep.buffer_size = BaseMachine::batch_size<T>(DATA_BIT,
this->buffer_size);
prep.buffer_dabits(0);
for (auto& x : prep.dabits)
this->bits.push_back(x.first);
@@ -158,6 +160,8 @@ void HemiPrep<T>::buffer_dabits(ThreadQueues* queues)
if (this->proc->P.num_players() == 2)
{
auto& prep = get_two_party_prep();
prep.buffer_size = BaseMachine::batch_size<T>(DATA_DABIT,
this->buffer_size);
prep.buffer_dabits(queues);
this->dabits = prep.dabits;
prep.dabits.clear();

View File

@@ -24,7 +24,7 @@ KeyGenProtocol<X, L>::KeyGenProtocol(Player& P, const FHE_Params& params,
int level) :
P(P), params(params), fftd(params.FFTD().at(level)), usage(P)
{
open_type::init_field(params.FFTD().at(level).get_prD().pr);
open_type::init_field(params.FFTD().at(level).get_prD().pr, false);
typename share_type::mac_key_type alphai;
auto& batch_size = OnlineOptions::singleton.batch_size;
@@ -54,7 +54,8 @@ KeyGenProtocol<X, L>::~KeyGenProtocol()
{
MC->Check(P);
usage.print_cost();
if (OnlineOptions::singleton.verbose)
usage.print_cost();
delete proc;
delete prep;
@@ -63,6 +64,7 @@ KeyGenProtocol<X, L>::~KeyGenProtocol()
MC->teardown();
OnlineOptions::singleton.batch_size = backup_batch_size;
open_type::reset();
}
template<int X, int L>

View File

@@ -12,6 +12,7 @@ using namespace std;
#include "Protocols/MAC_Check_Base.h"
#include "Tools/time-func.h"
#include "Tools/Coordinator.h"
#include "Processor/OnlineOptions.h"
/* The MAX number of things we will partially open before running
@@ -38,6 +39,8 @@ class TreeSum
void add_openings(vector<T>& values, const Player& P, int sum_players,
int last_sum_players, int send_player);
virtual void post_add_process(vector<T>&) {}
protected:
int base_player;
int opening_sum;
@@ -54,7 +57,9 @@ public:
vector<Timer> timers;
vector<Timer> player_timers;
TreeSum(int opening_sum = 10, int max_broadcast = 10, int base_player = 0);
TreeSum(int opening_sum = OnlineOptions::singleton.opening_sum,
int max_broadcast = OnlineOptions::singleton.max_broadcast,
int base_player = 0);
virtual ~TreeSum();
void run(vector<T>& values, const Player& P);
@@ -114,7 +119,7 @@ Coordinator* Tree_MAC_Check<U>::coordinator = 0;
* SPDZ opening protocol with MAC check (indirect communication)
*/
template<class U>
class MAC_Check_ : public Tree_MAC_Check<U>
class MAC_Check_ : public virtual Tree_MAC_Check<U>
{
public:
MAC_Check_(const typename U::mac_key_type::Scalar& ai, int opening_sum = 10,
@@ -135,7 +140,7 @@ template<class T> class MascotPrep;
* SPDZ2k opening protocol with MAC check
*/
template<class T, class U, class V, class W>
class MAC_Check_Z2k : public Tree_MAC_Check<W>
class MAC_Check_Z2k : public virtual Tree_MAC_Check<W>
{
protected:
Preprocessing<W>* prep;
@@ -161,12 +166,11 @@ template<class W>
using MAC_Check_Z2k_ = MAC_Check_Z2k<typename W::open_type,
typename W::mac_key_type, typename W::open_type, W>;
/**
* SPDZ opening protocol with MAC check (pairwise communication)
*/
template<class T>
class Direct_MAC_Check: public MAC_Check_<T>
class Direct_MAC_Check: public virtual MAC_Check_<T>
{
typedef MAC_Check_<T> super;
@@ -186,7 +190,35 @@ public:
void init_open(const Player& P, int n = 0);
void prepare_open(const T& secret, int = -1);
void exchange(const Player& P);
virtual void exchange(const Player& P);
};
template<class T>
class Direct_MAC_Check_Z2k: virtual public MAC_Check_Z2k_<T>,
virtual public Direct_MAC_Check<T>
{
public:
Direct_MAC_Check_Z2k(const typename T::mac_key_type& ai) :
Tree_MAC_Check<T>(ai), MAC_Check_Z2k_<T>(ai), MAC_Check_<T>(ai),
Direct_MAC_Check<T>(ai)
{
}
void prepare_open(const T& secret, int = -1)
{
MAC_Check_Z2k_<T>::prepare_open(secret);
}
void exchange(const Player& P)
{
Direct_MAC_Check<T>::exchange(P);
assert(this->WaitingForCheck() > 0);
}
void Check(const Player& P)
{
MAC_Check_Z2k_<T>::Check(P);
}
};
@@ -272,6 +304,7 @@ void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
{
values[i].add(oss[j], use_lengths ? lengths[i] : -1);
}
post_add_process(values);
MC.timers[SUM].stop();
}
}
@@ -279,6 +312,11 @@ void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
template<class T>
void TreeSum<T>::start(vector<T>& values, const Player& P)
{
if (opening_sum < 2)
opening_sum = P.num_players();
if (max_broadcast < 2)
max_broadcast = P.num_players();
os.reset_write_head();
int sum_players = P.num_players();
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());

View File

@@ -98,6 +98,7 @@ void Tree_MAC_Check<U>::init_open(const Player&, int n)
template<class U>
void Tree_MAC_Check<U>::prepare_open(const U& secret, int)
{
assert(U::mac_type::invertible);
this->values.push_back(secret.get_share());
macs.push_back(secret.get_mac());
}
@@ -344,7 +345,7 @@ Direct_MAC_Check<T>::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai
template<class T>
Direct_MAC_Check<T>::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai) :
MAC_Check_<T>(ai)
Tree_MAC_Check<T>(ai), MAC_Check_<T>(ai)
{
open_counter = 0;
}
@@ -405,6 +406,7 @@ void Direct_MAC_Check<T>::init_open(const Player& P, int n)
template<class T>
void Direct_MAC_Check<T>::prepare_open(const T& secret, int)
{
assert(T::mac_type::invertible);
this->values.push_back(secret.get_share());
this->macs.push_back(secret.get_mac());
}

View File

@@ -77,6 +77,11 @@ class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
public virtual SimplerMalRepRingPrep<T>
{
public:
static bool dabits_from_bits()
{
return true;
}
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
void set_protocol(typename T::Protocol& protocol)

View File

@@ -64,7 +64,8 @@ void MalRepRingPrep<T>::buffer_squares()
MaliciousRepPrep<prep_type> prep(_);
assert(this->proc != 0);
prep.init_honest(this->proc->P);
prep.buffer_size = this->buffer_size;
prep.buffer_size = BaseMachine::batch_size<T>(DATA_SQUARE,
this->buffer_size);
prep.buffer_squares();
for (auto& x : prep.squares)
this->squares.push_back({{x[0], x[1]}});

View File

@@ -35,11 +35,6 @@ public:
typedef MalRepRingShare<K + 2, S> SquareToBitShare;
typedef MalRepRingPrep<MalRepRingShare> SquarePrep;
static string type_short()
{
return "RR";
}
MalRepRingShare()
{
}

View File

@@ -53,6 +53,10 @@ public:
{
return "M" + string(1, T::type_char());
}
static string type_string()
{
return "malicious " + super::type_string();
}
MaliciousRep3Share()
{

View File

@@ -76,7 +76,8 @@ void MaliciousRepPrep<T>::buffer_triples()
{
check_field_size<typename T::open_type>();
auto& triples = this->triples;
auto buffer_size = this->buffer_size;
auto buffer_size = BaseMachine::batch_size<T>(DATA_TRIPLE,
this->buffer_size);
auto& honest_proc = this->honest_proc;
assert(honest_proc != 0);
Player& P = honest_proc->P;
@@ -140,7 +141,7 @@ void MaliciousRepPrep<T>::buffer_squares()
vector<typename T::open_type> opened;
vector<array<T, 2>> check_squares;
auto& squares = this->squares;
auto buffer_size = this->buffer_size;
auto buffer_size = BaseMachine::batch_size<T>(DATA_SQUARE, this->buffer_size);
auto& honest_prep = this->honest_prep;
auto& honest_proc = this->honest_proc;
auto& MC = this->MC;
@@ -186,7 +187,8 @@ void MaliciousBitOnlyRepPrep<T>::buffer_bits()
vector<typename T::open_type> opened;
vector<array<T, 2>> check_squares;
auto& bits = this->bits;
auto buffer_size = this->buffer_size;
auto buffer_size = BaseMachine::batch_size<T>(DATA_BIT,
this->buffer_size);
assert(honest_proc);
Player& P = honest_proc->P;
honest_prep.buffer_size = buffer_size;

View File

@@ -32,9 +32,8 @@ void MaliciousDabitOnlyPrep<T>::buffer_dabits(ThreadQueues* queues, false_type,
{
assert(this->proc != 0);
vector<dabit<T>> check_dabits;
DabitSacrifice<T> dabit_sacrifice;
this->buffer_dabits_without_check(check_dabits,
dabit_sacrifice.minimum_n_inputs(), queues);
dabit_sacrifice.minimum_n_inputs(this->buffer_size), queues);
dabit_sacrifice.sacrifice_and_check_bits(this->dabits, check_dabits,
*this->proc, queues);
}

View File

@@ -90,6 +90,8 @@ class MascotPrep : public virtual MaliciousRingPrep<T>,
public virtual MascotDabitOnlyPrep<T>
{
public:
static bool bits_from_triples() { return true; }
MascotPrep(SubProcessor<T>* proc, DataPositions& usage) :
BufferPrep<T>(usage), BitPrep<T>(proc, usage),
RingPrep<T>(proc, usage),

View File

@@ -62,8 +62,11 @@ void MascotTriplePrep<T>::buffer_triples()
auto& params = this->params;
auto& triple_generator = this->triple_generator;
params.generateBits = false;
triple_generator->set_batch_size(
BaseMachine::batch_size<T>(DATA_TRIPLE, this->buffer_size));
triple_generator->generate();
triple_generator->unlock();
triple_generator->set_batch_size(OnlineOptions::singleton.batch_size);
assert(triple_generator->uncheckedTriples.size() != 0);
for (auto& triple : triple_generator->uncheckedTriples)
this->triples.push_back(

View File

@@ -118,8 +118,7 @@ void Rep3Shuffler<T>::apply(vector<T>& a, size_t n, int unit_size,
template<class T>
void Rep3Shuffler<T>::del(int handle)
{
for (int i = 0; i < 2; i++)
shuffles.at(handle)[i].clear();
shuffles.at(handle) = {};
}
template<class T>

View File

@@ -42,6 +42,11 @@ class Rep4RingOnlyPrep : public virtual Rep4RingPrep<T>,
}
public:
static bool dabits_from_bits()
{
return true;
}
static void edabit_sacrifice_buckets(vector<edabit<T>>&, size_t, bool, int,
SubProcessor<T>&, int, int, const void* = 0)
{

View File

@@ -46,7 +46,7 @@ void Rep4RingPrep<T>::buffer_inputs(int player)
template<class T>
void Rep4RingPrep<T>::buffer_triples()
{
generate_triples(this->triples, OnlineOptions::singleton.batch_size,
generate_triples(this->triples, BaseMachine::batch_size<T>(DATA_TRIPLE),
this->protocol);
}
@@ -78,7 +78,7 @@ void Rep4RingPrep<T>::buffer_bits()
auto& protocol = this->proc->protocol;
vector<typename T::open_type> bits;
int batch_size = OnlineOptions::singleton.batch_size;
int batch_size = BaseMachine::batch_size<T>(DATA_BIT);
bits.reserve(batch_size);
for (int i = 0; i < batch_size; i++)
bits.push_back(G.get_bit());

View File

@@ -12,7 +12,7 @@ void RepRingOnlyEdabitPrep<T>::buffer_edabits(int n_bits, ThreadQueues*)
{
assert(this->proc);
int dl = T::bit_type::default_length;
int buffer_size = DIV_CEIL(this->buffer_size, dl) * dl;
int buffer_size = DIV_CEIL(BaseMachine::edabit_batch_size<T>(n_bits, this->buffer_size), dl) * dl;
vector<T> wholes;
wholes.resize(buffer_size);
Instruction inst;
@@ -49,5 +49,5 @@ void RepRingOnlyEdabitPrep<T>::buffer_edabits(int n_bits, ThreadQueues*)
SubProcessor<bt> bit_proc(party.MC->get_part_MC(), this->proc->bit_prep, P);
bit_adder.multi_add(sums, summands, 0, buffer_size / dl, bit_proc, dl, 0);
this->push_edabits(this->edabits[{false, n_bits}], wholes, sums, buffer_size);
this->push_edabits(this->edabits[{false, n_bits}], wholes, sums);
}

View File

@@ -15,6 +15,7 @@
#include "Protocols/ShuffleSacrifice.h"
#include "Tools/TimerWithComm.h"
#include "edabit.h"
#include "DabitSacrifice.h"
#include <array>
@@ -36,13 +37,13 @@ class BufferPrep : public Preprocessing<T>
friend class InScope;
static const bool homomorphic = false;
template<int>
void buffer_inverses(true_type);
template<int>
void buffer_inverses(false_type) { throw runtime_error("no inverses"); }
virtual bool bits_from_dabits() { return false; }
protected:
vector<array<T, 3>> triples;
vector<array<T, 2>> squares;
@@ -83,8 +84,9 @@ protected:
{ throw runtime_error("no personal daBits"); }
void push_edabits(vector<edabitvec<T>>& edabits,
const vector<T>& sums, const vector<vector<typename T::bit_type::part_type>>& bits,
int buffer_size);
const vector<T>& sums,
const vector<vector<typename T::bit_type::part_type>>& bits);
public:
typedef T share_type;
@@ -103,6 +105,10 @@ public:
throw runtime_error("sacrifice not available");
}
static bool bits_from_dabits() { return false; }
static bool bits_from_triples() { return false; }
static bool dabits_from_bits() { return false; }
BufferPrep(DataPositions& usage);
virtual ~BufferPrep();
@@ -135,6 +141,8 @@ public:
SubProcessor<T>* get_proc() { return proc; }
void set_proc(SubProcessor<T>* proc) { this->proc = proc; }
void buffer_extra(Dtype type, int n_items);
};
/**
@@ -272,7 +280,7 @@ public:
void buffer_edabits(int n_bits, false_type)
{ this->template buffer_edabits_without_check<0>(n_bits,
this->edabits[{false, n_bits}],
OnlineOptions::singleton.batch_size); }
BaseMachine::edabit_batch_size<T>(n_bits, this->buffer_size)); }
template<int>
void buffer_edabits(int, true_type)
{ throw not_implemented(); }
@@ -286,6 +294,8 @@ public:
template<class T>
class MaliciousDabitOnlyPrep : public virtual RingPrep<T>
{
DabitSacrifice<T> dabit_sacrifice;
template<int>
void buffer_dabits(ThreadQueues* queues, true_type, false_type);
template<int>
@@ -312,6 +322,8 @@ class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep<T>
{
typedef typename T::bit_type::part_type BT;
DabitSacrifice<T> dabit_sacrifice;
protected:
void buffer_personal_edabits(int n_bits, vector<T>& sums,
vector<vector<BT>>& bits, SubProcessor<BT>& proc, int input_player,
@@ -343,7 +355,7 @@ public:
bool strict, int player, SubProcessor<T>& proc, int begin, int end,
const void* supply = 0)
{
EdabitShuffleSacrifice<T>().edabit_sacrifice_buckets(to_check, n_bits, strict,
EdabitShuffleSacrifice<T>(n_bits).edabit_sacrifice_buckets(to_check, strict,
player, proc, begin, end, supply);
}

Some files were not shown because too many files have changed in this diff Show More