mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Maintenance.
This commit is contained in:
14
CHANGELOG.md
14
CHANGELOG.md
@@ -1,5 +1,19 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.3.7 (August 14, 2023)
|
||||
|
||||
- Path Oblivious Heap (@tskovlund)
|
||||
- Adjust batch and bucket size to program
|
||||
- Direct communication available in more protocols
|
||||
- Option for seed in fake preprocessing (@strieflin)
|
||||
- Lower memory usage due to improved register allocation
|
||||
- New instructions to speed up CISC compilation
|
||||
- Protocol implementation example
|
||||
- Fixed security bug: missing MAC checks in multi-threaded programs
|
||||
- Fixed security bug: race condition in MAC check
|
||||
- Fixed security bug: missing shuffling check in PS mod 2^k and Brain
|
||||
- Fixed security bug: insufficient drowning in pairwise protocols
|
||||
|
||||
## 0.3.6 (May 9, 2023)
|
||||
|
||||
- More extensive benchmarking outputs
|
||||
|
||||
@@ -781,8 +781,6 @@ class sbitvec(_vec, _bit):
|
||||
for i in range(n):
|
||||
for j, x in enumerate(v[i].bit_decompose()):
|
||||
x.store_in_mem(address + i + j * n)
|
||||
def reveal(self):
|
||||
return util.untuplify([x.reveal() for x in self.elements()])
|
||||
@classmethod
|
||||
def two_power(cls, nn, size=1):
|
||||
return cls.from_vec(
|
||||
@@ -919,8 +917,7 @@ class sbitvec(_vec, _bit):
|
||||
return self.v[:n_bits]
|
||||
bit_compose = from_vec
|
||||
def reveal(self):
|
||||
assert len(self) == 1
|
||||
return self.v[0].reveal()
|
||||
return util.untuplify([x.reveal() for x in self.elements()])
|
||||
def long_one(self):
|
||||
return [x.long_one() for x in self.v]
|
||||
def __rsub__(self, other):
|
||||
@@ -1279,7 +1276,8 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
|
||||
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
"""
|
||||
Vector of signed integers for parallel binary computation::
|
||||
Vector of signed integers for parallel binary computation.
|
||||
The following example uses vectors of size two::
|
||||
|
||||
sb32 = sbits.get_type(32)
|
||||
siv32 = sbitintvec.get_type(32)
|
||||
@@ -1291,7 +1289,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
c = (a - b).elements()
|
||||
print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
c = (a < b).bit_decompose()
|
||||
c = (a < b).elements()
|
||||
print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
|
||||
This should output::
|
||||
@@ -1467,7 +1465,7 @@ class sbitfixvec(_fix, _vec):
|
||||
print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
c = (a - b).elements()
|
||||
print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
c = (a < b).bit_decompose()
|
||||
c = (a < b).elements()
|
||||
print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal())
|
||||
|
||||
This should output roughly::
|
||||
|
||||
@@ -43,7 +43,7 @@ class BlockAllocator:
|
||||
else:
|
||||
done = False
|
||||
for x in self.by_logsize[logsize + 1:]:
|
||||
for block_size, addresses in x.items():
|
||||
for block_size, addresses in sorted(x.items()):
|
||||
if len(addresses) > 0:
|
||||
done = True
|
||||
break
|
||||
@@ -60,16 +60,92 @@ class BlockAllocator:
|
||||
self.by_address[addr + size] = diff
|
||||
return addr
|
||||
|
||||
class AllocRange:
|
||||
def __init__(self, base=0):
|
||||
self.base = base
|
||||
self.top = base
|
||||
self.limit = base
|
||||
self.grow = True
|
||||
self.pool = defaultdict(set)
|
||||
|
||||
def alloc(self, size):
|
||||
if self.pool[size]:
|
||||
return self.pool[size].pop()
|
||||
elif self.grow or self.top + size <= self.limit:
|
||||
res = self.top
|
||||
self.top += size
|
||||
self.limit = max(self.limit, self.top)
|
||||
if res >= REG_MAX:
|
||||
raise RegisterOverflowError()
|
||||
return res
|
||||
|
||||
def free(self, base, size):
|
||||
assert self.base <= base < self.top
|
||||
self.pool[size].add(base)
|
||||
|
||||
def stop_growing(self):
|
||||
self.grow = False
|
||||
|
||||
def consolidate(self):
|
||||
regs = []
|
||||
for size, pool in self.pool.items():
|
||||
for base in pool:
|
||||
regs.append((base, size))
|
||||
for base, size in reversed(sorted(regs)):
|
||||
if base + size == self.top:
|
||||
self.top -= size
|
||||
self.pool[size].remove(base)
|
||||
regs.pop()
|
||||
else:
|
||||
if program.Program.prog.verbose:
|
||||
print('cannot free %d register blocks '
|
||||
'by a gap of %d at %d' %
|
||||
(len(regs), self.top - size - base, base))
|
||||
break
|
||||
|
||||
class AllocPool:
|
||||
def __init__(self):
|
||||
self.ranges = defaultdict(lambda: [AllocRange()])
|
||||
self.by_base = {}
|
||||
|
||||
def alloc(self, reg_type, size):
|
||||
for r in self.ranges[reg_type]:
|
||||
res = r.alloc(size)
|
||||
if res is not None:
|
||||
self.by_base[reg_type, res] = r
|
||||
return res
|
||||
|
||||
def free(self, reg):
|
||||
r = self.by_base.pop((reg.reg_type, reg.i))
|
||||
r.free(reg.i, reg.size)
|
||||
|
||||
def new_ranges(self, min_usage):
|
||||
for t, n in min_usage.items():
|
||||
r = self.ranges[t][-1]
|
||||
assert (n >= r.limit)
|
||||
if r.limit < n:
|
||||
r.stop_growing()
|
||||
self.ranges[t].append(AllocRange(n))
|
||||
|
||||
def consolidate(self):
|
||||
for r in self.ranges.values():
|
||||
for rr in r:
|
||||
rr.consolidate()
|
||||
|
||||
def n_fragments(self):
|
||||
return max(len(r) for r in self.ranges)
|
||||
|
||||
class StraightlineAllocator:
|
||||
"""Allocate variables in a straightline program using n registers.
|
||||
It is based on the precondition that every register is only defined once."""
|
||||
def __init__(self, n, program):
|
||||
self.alloc = dict_by_id()
|
||||
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
|
||||
self.max_usage = defaultdict(lambda: 0)
|
||||
self.defined = dict_by_id()
|
||||
self.dealloc = set_by_id()
|
||||
self.n = n
|
||||
assert(n == REG_MAX)
|
||||
self.program = program
|
||||
self.old_pool = None
|
||||
|
||||
def alloc_reg(self, reg, free):
|
||||
base = reg.vectorbase
|
||||
@@ -79,14 +155,7 @@ class StraightlineAllocator:
|
||||
|
||||
reg_type = reg.reg_type
|
||||
size = base.size
|
||||
if free[reg_type, size]:
|
||||
res = free[reg_type, size].pop()
|
||||
else:
|
||||
if self.usage[reg_type] < self.n:
|
||||
res = self.usage[reg_type]
|
||||
self.usage[reg_type] += size
|
||||
else:
|
||||
raise RegisterOverflowError()
|
||||
res = free.alloc(reg_type, size)
|
||||
self.alloc[base] = res
|
||||
|
||||
base.i = self.alloc[base]
|
||||
@@ -126,7 +195,7 @@ class StraightlineAllocator:
|
||||
for x in itertools.chain(dup.duplicates, base.duplicates):
|
||||
to_check.add(x)
|
||||
|
||||
free[reg.reg_type, base.size].append(self.alloc[base])
|
||||
free.free(base)
|
||||
if inst.is_vec() and base.vector:
|
||||
self.defined[base] = inst
|
||||
for i in base.vector:
|
||||
@@ -135,6 +204,7 @@ class StraightlineAllocator:
|
||||
self.defined[reg] = inst
|
||||
|
||||
def process(self, program, alloc_pool):
|
||||
self.update_usage(alloc_pool)
|
||||
for k,i in enumerate(reversed(program)):
|
||||
unused_regs = []
|
||||
for j in i.get_def():
|
||||
@@ -161,12 +231,26 @@ class StraightlineAllocator:
|
||||
if k % 1000000 == 0 and k > 0:
|
||||
print("Allocated registers for %d instructions at" % k, time.asctime())
|
||||
|
||||
self.update_max_usage(alloc_pool)
|
||||
alloc_pool.consolidate()
|
||||
|
||||
# print "Successfully allocated registers"
|
||||
# print "modp usage: %d clear, %d secret" % \
|
||||
# (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp])
|
||||
# print "GF2N usage: %d clear, %d secret" % \
|
||||
# (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N])
|
||||
return self.usage
|
||||
return self.max_usage
|
||||
|
||||
def update_max_usage(self, alloc_pool):
|
||||
for t, r in alloc_pool.ranges.items():
|
||||
self.max_usage[t] = max(self.max_usage[t], r[-1].limit)
|
||||
|
||||
def update_usage(self, alloc_pool):
|
||||
if self.old_pool:
|
||||
self.update_max_usage(self.old_pool)
|
||||
if id(self.old_pool) != id(alloc_pool):
|
||||
alloc_pool.new_ranges(self.max_usage)
|
||||
self.old_pool = alloc_pool
|
||||
|
||||
def finalize(self, options):
|
||||
for reg in self.alloc:
|
||||
@@ -178,6 +262,21 @@ class StraightlineAllocator:
|
||||
'\t\t'))
|
||||
if options.stop:
|
||||
sys.exit(1)
|
||||
if self.program.verbose:
|
||||
def p(sizes):
|
||||
total = defaultdict(lambda: 0)
|
||||
for (t, size) in sorted(sizes):
|
||||
n = sizes[t, size]
|
||||
total[t] += size * n
|
||||
print('%s:%d*%d' % (t, size, n), end=' ')
|
||||
print()
|
||||
print('Total:', dict(total))
|
||||
|
||||
sizes = defaultdict(lambda: 0)
|
||||
for reg in self.alloc:
|
||||
x = reg.reg_type, reg.size
|
||||
print('Used registers: ', end='')
|
||||
p(sizes)
|
||||
|
||||
def determine_scope(block, options):
|
||||
last_def = defaultdict_by_id(lambda: -1)
|
||||
|
||||
@@ -10,7 +10,7 @@ the ones used below into ``Programs/Circuits`` as follows::
|
||||
"""
|
||||
|
||||
from Compiler.GC.types import *
|
||||
from Compiler.library import function_block
|
||||
from Compiler.library import function_block, get_tape
|
||||
from Compiler import util
|
||||
import itertools
|
||||
import struct
|
||||
@@ -54,7 +54,7 @@ class Circuit:
|
||||
return self.run(*inputs)
|
||||
|
||||
def run(self, *inputs):
|
||||
n = inputs[0][0].n
|
||||
n = inputs[0][0].n, get_tape()
|
||||
if n not in self.functions:
|
||||
self.functions[n] = function_block(lambda *args:
|
||||
self.compile(*args))
|
||||
|
||||
@@ -270,9 +270,9 @@ class Compiler:
|
||||
self.prog = Program(self.args, self.options, name=name)
|
||||
if self.execute:
|
||||
if self.options.execute in \
|
||||
("emulate", "ring", "rep-field"):
|
||||
("emulate", "ring", "rep-field", "rep4-ring"):
|
||||
self.prog.use_trunc_pr = True
|
||||
if self.options.execute in ("ring",):
|
||||
if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"):
|
||||
self.prog.use_split(3)
|
||||
if self.options.execute in ("semi2k",):
|
||||
self.prog.use_split(2)
|
||||
@@ -487,6 +487,7 @@ class Compiler:
|
||||
"Cannot produce %s. " % executable + \
|
||||
"Note that compilation requires a few GB of RAM.")
|
||||
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
|
||||
sys.stdout.flush()
|
||||
os.execl(vm, vm, self.prog.name, *args)
|
||||
|
||||
def remote_execution(self, args=[]):
|
||||
|
||||
@@ -633,7 +633,8 @@ def preprocess_pandas(data):
|
||||
res.append(data.iloc[:,i].to_numpy())
|
||||
types.append('c')
|
||||
elif pandas.api.types.is_object_dtype(t):
|
||||
values = data.iloc[:,i].unique()
|
||||
values = list(filter(lambda x: isinstance(x, str),
|
||||
list(data.iloc[:,i].unique())))
|
||||
print('converting the following to unary:', values)
|
||||
if len(values) == 2:
|
||||
res.append(data.iloc[:,i].to_numpy() == values[1])
|
||||
|
||||
@@ -638,6 +638,44 @@ class prefixsums(base.Instruction):
|
||||
code = base.opcodes['PREFIXSUMS']
|
||||
arg_format = ['sw','s']
|
||||
|
||||
class picks(base.VectorInstruction):
|
||||
""" Extract part of vector.
|
||||
|
||||
:param: result (sint)
|
||||
:param: input (sint)
|
||||
:param: start offset (int)
|
||||
:param: step
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['PICKS']
|
||||
arg_format = ['sw','s','int','int']
|
||||
|
||||
def __init__(self, *args):
|
||||
super(picks, self).__init__(*args)
|
||||
assert 0 <= args[2] < len(args[1])
|
||||
assert 0 <= args[2] + args[3] * len(args[0]) <= len(args[1])
|
||||
|
||||
class concats(base.VectorInstruction):
|
||||
""" Concatenate vectors.
|
||||
|
||||
:param: result (sint)
|
||||
:param: start offset (int)
|
||||
:param: input (sint)
|
||||
:param: (repeat from offset)...
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['CONCATS']
|
||||
arg_format = tools.chain(['sw'], tools.cycle(['int','s']))
|
||||
|
||||
def __init__(self, *args):
|
||||
super(concats, self).__init__(*args)
|
||||
assert len(args) % 2 == 1
|
||||
assert len(args[0]) == sum(args[1::2])
|
||||
for i in range(1, len(args), 2):
|
||||
assert args[i] == len(args[i + 1])
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class mulc(base.MulBase):
|
||||
|
||||
@@ -82,6 +82,8 @@ opcodes = dict(
|
||||
SUBCFI = 0x2B,
|
||||
SUBSFI = 0x2C,
|
||||
PREFIXSUMS = 0x2D,
|
||||
PICKS = 0x2E,
|
||||
CONCATS = 0x2F,
|
||||
# Multiplication/division
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
@@ -523,29 +525,26 @@ def cisc(function):
|
||||
tape.active_basicblock = block
|
||||
size = sum(call[0][0].size for call in self.calls)
|
||||
new_regs = []
|
||||
for arg in self.args:
|
||||
for i, arg in enumerate(self.args):
|
||||
try:
|
||||
new_regs.append(type(arg)(size=size))
|
||||
except TypeError:
|
||||
if i == 0:
|
||||
new_regs.append(type(arg)(size=size))
|
||||
else:
|
||||
new_regs.append(type(arg).concat(
|
||||
call[0][i] for call in self.calls))
|
||||
assert len(new_regs[-1]) == size
|
||||
except (TypeError, AttributeError):
|
||||
if not isinstance(arg, int):
|
||||
raise
|
||||
break
|
||||
except:
|
||||
print([call[0][0].size for call in self.calls])
|
||||
raise
|
||||
assert len(new_regs) > 1
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
|
||||
set_global_vector_size(reg.size)
|
||||
reg.mov(new_reg.get_vector(base, reg.size), reg)
|
||||
reset_global_vector_size()
|
||||
base += reg.size
|
||||
self.new_instructions(size, new_regs)
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
reg = call[0][0]
|
||||
set_global_vector_size(reg.size)
|
||||
reg.mov(reg, new_regs[0].get_vector(base, reg.size))
|
||||
reset_global_vector_size()
|
||||
reg.copy_from_part(new_regs[0], base, reg.size)
|
||||
base += reg.size
|
||||
return block.instructions, self.n_rounds - 1
|
||||
|
||||
@@ -628,7 +627,7 @@ def sfix_cisc(function):
|
||||
instruction = cisc(instruction)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if isinstance(args[0], sfix):
|
||||
if isinstance(args[0], sfix) and program.options.cisc:
|
||||
for arg in args[1:]:
|
||||
assert util.is_constant(arg)
|
||||
assert not kwargs
|
||||
@@ -844,6 +843,12 @@ class Instruction(object):
|
||||
Instruction.count += 1
|
||||
if Instruction.count % 100000 == 0:
|
||||
print("Compiled %d lines at" % self.__class__.count, time.asctime())
|
||||
if Instruction.count > 10 ** 7:
|
||||
print("Compilation produced more that 10 million instructions. "
|
||||
"Consider using './compile.py -l' or replacing for loops "
|
||||
"with @for_range_opt: "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/Compiler.html#"
|
||||
"Compiler.library.for_range_opt")
|
||||
|
||||
def get_code(self, prefix=0):
|
||||
return (prefix << self.code_length) + self.code
|
||||
|
||||
@@ -6,7 +6,7 @@ in particularly providing flow control and output.
|
||||
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc, _vec
|
||||
from Compiler.instructions import *
|
||||
from Compiler.util import tuplify,untuplify,is_zero
|
||||
from Compiler.allocator import RegintOptimizer
|
||||
from Compiler.allocator import RegintOptimizer, AllocPool
|
||||
from Compiler import instructions,instructions_base,comparison,program,util
|
||||
import inspect,math
|
||||
import random
|
||||
@@ -411,7 +411,7 @@ class FunctionBlock(Function):
|
||||
parent_node = get_tape().req_node
|
||||
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
|
||||
block = get_tape().active_basicblock
|
||||
block.alloc_pool = defaultdict(list)
|
||||
block.alloc_pool = AllocPool()
|
||||
del parent_node.children[-1]
|
||||
self.node = get_tape().req_node
|
||||
if get_program().verbose:
|
||||
@@ -935,7 +935,7 @@ def for_range_opt_multithread(n_threads, n_loops):
|
||||
...
|
||||
|
||||
Note that you cannot use registers across threads. Use
|
||||
:py:class:`MemValue` instead::
|
||||
:py:class:`~Compiler.types.MemValue` instead::
|
||||
|
||||
a = MemValue(sint(0))
|
||||
@for_range_opt_multithread(8, 80)
|
||||
@@ -1069,6 +1069,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
threads = prog.run_tapes(thread_args)
|
||||
for thread in threads:
|
||||
prog.join_tape(thread)
|
||||
prog.free_later()
|
||||
if len(state):
|
||||
if thread_rounds:
|
||||
for i in range(n_threads - remainder):
|
||||
@@ -1320,6 +1321,7 @@ def if_then(condition):
|
||||
state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
|
||||
name='if-block')
|
||||
state.has_else = False
|
||||
state.closed_if = False
|
||||
state.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
instructions.program.curr_tape.if_states.append(state)
|
||||
|
||||
@@ -1434,6 +1436,7 @@ def if_e(condition):
|
||||
else:
|
||||
if_then(condition)
|
||||
_run_and_link(body)
|
||||
get_tape().if_states[-1].closed_if = True
|
||||
return decorator
|
||||
|
||||
def else_(body):
|
||||
@@ -1443,6 +1446,8 @@ def else_(body):
|
||||
_run_and_link(body)
|
||||
if_states.pop()
|
||||
else:
|
||||
if not if_states[-1].closed_if:
|
||||
raise CompilerError('@if_e not closed before else block')
|
||||
else_then()
|
||||
_run_and_link(body)
|
||||
end_if()
|
||||
|
||||
@@ -822,10 +822,10 @@ class Dense(DenseBase):
|
||||
self.W.randomize(-r, r, n_threads=self.n_threads)
|
||||
self.b.assign_all(0)
|
||||
|
||||
def input_from(self, player, raw=False):
|
||||
self.W.input_from(player, raw=raw)
|
||||
def input_from(self, player, **kwargs):
|
||||
self.W.input_from(player, **kwargs)
|
||||
if self.input_bias:
|
||||
self.b.input_from(player, raw=raw)
|
||||
self.b.input_from(player, **kwargs)
|
||||
|
||||
def compute_f_input(self, batch):
|
||||
N = len(batch)
|
||||
@@ -1088,10 +1088,7 @@ class Relu(ElementWiseLayer):
|
||||
|
||||
:param shape: input/output shape (tuple/list of int)
|
||||
"""
|
||||
f = staticmethod(relu)
|
||||
f_prime = staticmethod(relu_prime)
|
||||
prime_type = sint
|
||||
comparisons = None
|
||||
|
||||
def __init__(self, shape, inputs=None):
|
||||
super(Relu, self).__init__(shape)
|
||||
@@ -1310,12 +1307,12 @@ class FusedBatchNorm(Layer):
|
||||
self.bias = sfix.Array(shape[3])
|
||||
self.inputs = inputs
|
||||
|
||||
def input_from(self, player, raw=False):
|
||||
self.weights.input_from(player, raw=raw)
|
||||
self.bias.input_from(player, raw=raw)
|
||||
def input_from(self, player, **kwargs):
|
||||
self.weights.input_from(player, **kwargs)
|
||||
self.bias.input_from(player, **kwargs)
|
||||
tmp = sfix.Array(len(self.bias))
|
||||
tmp.input_from(player, raw=raw)
|
||||
tmp.input_from(player, raw=raw)
|
||||
tmp.input_from(player, **kwargs)
|
||||
tmp.input_from(player, **kwargs)
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
assert len(batch) == 1
|
||||
@@ -1611,11 +1608,11 @@ class ConvBase(BaseLayer):
|
||||
self.bias_shape, self.Y.sizes, self.stride, repr(self.padding),
|
||||
self.tf_weight_format)
|
||||
|
||||
def input_from(self, player, raw=False):
|
||||
def input_from(self, player, **kwargs):
|
||||
self.input_params_from(player)
|
||||
self.weights.input_from(player, budget=100000, raw=raw)
|
||||
self.weights.input_from(player, budget=100000, **kwargs)
|
||||
if self.input_bias:
|
||||
self.bias.input_from(player, raw=raw)
|
||||
self.bias.input_from(player, **kwargs)
|
||||
|
||||
def output_weights(self):
|
||||
self.weights.print_reveal_nested()
|
||||
|
||||
@@ -112,8 +112,8 @@ class Program(object):
|
||||
self.non_linear = Prime(self.security)
|
||||
if not self.bit_length:
|
||||
self.bit_length = 64
|
||||
print("Default bit length:", self.bit_length)
|
||||
print("Default security parameter:", self.security)
|
||||
print("Default bit length for compilation:", self.bit_length)
|
||||
print("Default security parameter for compilation:", self.security)
|
||||
self.galois_length = int(options.galois)
|
||||
if self.verbose:
|
||||
print("Galois length:", self.galois_length)
|
||||
@@ -122,6 +122,7 @@ class Program(object):
|
||||
self.DEBUG = options.debug
|
||||
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
|
||||
self.free_mem_blocks = defaultdict(al.BlockAllocator)
|
||||
self.later_mem_blocks = defaultdict(list)
|
||||
self.allocated_mem_blocks = {}
|
||||
self.saved = 0
|
||||
self.req_num = None
|
||||
@@ -229,24 +230,26 @@ class Program(object):
|
||||
if self.name.endswith(ext):
|
||||
self.name = self.name[:-len(ext)]
|
||||
|
||||
if os.path.exists(args[0]):
|
||||
self.infile = args[0]
|
||||
infiles = [args[0]]
|
||||
for x in (self.programs_dir, sys.path[0] + "/Programs"):
|
||||
for ext in exts:
|
||||
filename = args[0]
|
||||
if not filename.endswith(ext):
|
||||
filename += ext
|
||||
filename = x + "/Source/" + filename
|
||||
if os.path.abspath(filename) not in \
|
||||
[os.path.abspath(f) for f in infiles]:
|
||||
infiles += [filename]
|
||||
existing = [f for f in infiles if os.path.exists(f)]
|
||||
if len(existing) == 1:
|
||||
self.infile = existing[0]
|
||||
elif len(existing) > 1:
|
||||
raise CompilerError("ambiguous input files: " +
|
||||
", ".join(existing))
|
||||
else:
|
||||
infiles = []
|
||||
for x in (self.programs_dir, sys.path[0] + "/Programs"):
|
||||
for ext in exts:
|
||||
filename = args[0]
|
||||
if not filename.endswith(ext):
|
||||
filename += ext
|
||||
infiles += [x + "/Source/" + filename]
|
||||
for f in infiles:
|
||||
if os.path.exists(f):
|
||||
self.infile = f
|
||||
break
|
||||
else:
|
||||
raise CompilerError(
|
||||
"found none of the potential input files: " +
|
||||
", ".join("'%s'" % x for x in [args[0]] + infiles))
|
||||
raise CompilerError(
|
||||
"found none of the potential input files: " +
|
||||
", ".join("'%s'" % x for x in [args[0]] + infiles))
|
||||
"""
|
||||
self.name is input file name (minus extension) + any optional arguments.
|
||||
Used to generate output filenames
|
||||
@@ -463,7 +466,7 @@ class Program(object):
|
||||
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
|
||||
if addr + size >= 2**64:
|
||||
raise CompilerError("allocation exceeded for type '%s'" % mem_type)
|
||||
self.allocated_mem_blocks[addr, mem_type] = size
|
||||
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
|
||||
if single_size:
|
||||
from .library import get_thread_number, runtime_error_if
|
||||
|
||||
@@ -477,12 +480,24 @@ class Program(object):
|
||||
|
||||
def free(self, addr, mem_type):
|
||||
"""Free memory"""
|
||||
if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool:
|
||||
raise CompilerError("Cannot free memory within function block")
|
||||
now = True
|
||||
if not util.is_constant(addr):
|
||||
addr = self.base_addresses[str(addr)]
|
||||
size = self.allocated_mem_blocks.pop((addr, mem_type))
|
||||
self.free_mem_blocks[mem_type].push(addr, size)
|
||||
now = self.curr_tape == self.tapes[0]
|
||||
size, pool = self.allocated_mem_blocks[addr, mem_type]
|
||||
if self.curr_block.alloc_pool is not pool:
|
||||
raise CompilerError("Cannot free memory across function blocks")
|
||||
self.allocated_mem_blocks.pop((addr, mem_type))
|
||||
if now:
|
||||
self.free_mem_blocks[mem_type].push(addr, size)
|
||||
else:
|
||||
self.later_mem_blocks[mem_type].append((addr, size))
|
||||
|
||||
def free_later(self):
|
||||
for mem_type in self.later_mem_blocks:
|
||||
for block in self.later_mem_blocks[mem_type]:
|
||||
self.free_mem_blocks[mem_type].push(*block)
|
||||
self.later_mem_blocks.clear()
|
||||
|
||||
def finalize(self):
|
||||
# optimize the tapes
|
||||
@@ -744,6 +759,7 @@ class Tape:
|
||||
self.purged = False
|
||||
self.block_counter = 0
|
||||
self.active_basicblock = None
|
||||
self.old_allocated_mem = program.allocated_mem.copy()
|
||||
self.start_new_basicblock()
|
||||
self._is_empty = False
|
||||
self.merge_opens = True
|
||||
@@ -771,7 +787,7 @@ class Tape:
|
||||
scope.children.append(self)
|
||||
self.alloc_pool = scope.alloc_pool
|
||||
else:
|
||||
self.alloc_pool = defaultdict(list)
|
||||
self.alloc_pool = al.AllocPool()
|
||||
self.purged = False
|
||||
self.n_rounds = 0
|
||||
self.n_to_merge = 0
|
||||
@@ -869,6 +885,15 @@ class Tape:
|
||||
return self._is_empty
|
||||
|
||||
def start_new_basicblock(self, scope=False, name=""):
|
||||
if self.program.verbose and self.active_basicblock and \
|
||||
self.program.allocated_mem != self.old_allocated_mem:
|
||||
print("New allocated memory in %s " % self.active_basicblock.name,
|
||||
end="")
|
||||
for t, n in self.program.allocated_mem.items():
|
||||
if n != self.old_allocated_mem[t]:
|
||||
print("%s:%d " % (t, n - self.old_allocated_mem[t]), end="")
|
||||
print()
|
||||
self.old_allocated_mem = self.program.allocated_mem.copy()
|
||||
# use False because None means no scope
|
||||
if scope is False:
|
||||
scope = self.active_basicblock
|
||||
@@ -1029,6 +1054,7 @@ class Tape:
|
||||
allocator = al.StraightlineAllocator(REG_MAX, self.program)
|
||||
|
||||
def alloc(block):
|
||||
allocator.update_usage(block.alloc_pool)
|
||||
for reg in sorted(
|
||||
block.used_from_scope, key=lambda x: (x.reg_type, x.i)
|
||||
):
|
||||
@@ -1042,6 +1068,7 @@ class Tape:
|
||||
for child in block.children:
|
||||
left.append(child)
|
||||
|
||||
allocator.old_pool = None
|
||||
for i, block in enumerate(reversed(self.basicblocks)):
|
||||
if len(block.instructions) > 1000000:
|
||||
print(
|
||||
@@ -1055,10 +1082,20 @@ class Tape:
|
||||
and block.exit_block.scope is not None
|
||||
):
|
||||
alloc_loop(block.exit_block.scope)
|
||||
usage = allocator.max_usage.copy()
|
||||
allocator.process(block.instructions, block.alloc_pool)
|
||||
if self.program.verbose and usage != allocator.max_usage:
|
||||
print("Allocated registers in %s " % block.name, end="")
|
||||
for t, n in allocator.max_usage.items():
|
||||
if n > usage[t]:
|
||||
print("%s:%d " % (t, n - usage[t]), end="")
|
||||
print()
|
||||
allocator.finalize(options)
|
||||
if self.program.verbose:
|
||||
print("Tape register usage:", dict(allocator.usage))
|
||||
print("Tape register usage:", dict(allocator.max_usage))
|
||||
scopes = set(block.alloc_pool for block in self.basicblocks)
|
||||
n_fragments = sum(scope.n_fragments() for scope in scopes)
|
||||
print("%d register fragments in %d scopes" % (n_fragments, len(scopes)))
|
||||
|
||||
# offline data requirements
|
||||
if self.program.verbose:
|
||||
@@ -1499,6 +1536,7 @@ class Tape:
|
||||
if Program.prog.options.noreallocate:
|
||||
raise CompilerError("reallocation necessary for linking, "
|
||||
"remove option -u")
|
||||
assert self.reg_type == other.reg_type
|
||||
self.duplicates |= other.duplicates
|
||||
for dup in self.duplicates:
|
||||
dup.duplicates = self.duplicates
|
||||
|
||||
@@ -38,6 +38,7 @@ by party 0 and 1::
|
||||
sfloat
|
||||
sgf2n
|
||||
cgf2n
|
||||
personal
|
||||
|
||||
Container types
|
||||
---------------
|
||||
@@ -392,7 +393,7 @@ class _int(Tape._no_truth):
|
||||
return a - prod, b + prod
|
||||
|
||||
def bit_xor(self, other):
|
||||
""" XOR in arithmetic circuits.
|
||||
""" Single-bit XOR in arithmetic circuits.
|
||||
|
||||
:param self/other: 0 or 1 (any compatible type)
|
||||
:return: type depends on inputs (secret if any of them is) """
|
||||
@@ -404,7 +405,7 @@ class _int(Tape._no_truth):
|
||||
return self + other - 2 * self * other
|
||||
|
||||
def bit_or(self, other):
|
||||
""" OR in arithmetic circuits.
|
||||
""" Single-bit OR in arithmetic circuits.
|
||||
|
||||
:param self/other: 0 or 1 (any compatible type)
|
||||
:return: type depends on inputs (secret if any of them is) """
|
||||
@@ -416,14 +417,14 @@ class _int(Tape._no_truth):
|
||||
return self + other - self * other
|
||||
|
||||
def bit_and(self, other):
|
||||
""" AND in arithmetic circuits.
|
||||
""" Single-bit AND in arithmetic circuits.
|
||||
|
||||
:param self/other: 0 or 1 (any compatible type)
|
||||
:rtype: depending on inputs (secret if any of them is) """
|
||||
return self * other
|
||||
|
||||
def bit_not(self):
|
||||
""" NOT in arithmetic circuits. """
|
||||
""" Single-bit NOT in arithmetic circuits. """
|
||||
return 1 - self
|
||||
|
||||
def half_adder(self, other):
|
||||
@@ -611,6 +612,8 @@ class _secret_structure(_structure):
|
||||
if program.curr_tape != program.tapes[0]:
|
||||
raise CompilerError('only available in main thread')
|
||||
if content is not None:
|
||||
if isinstance(content, (_vectorizable, Tape.Register)):
|
||||
raise CompilerError('cannot input data already in the VM')
|
||||
requested_shape = shape
|
||||
if binary:
|
||||
import numpy
|
||||
@@ -800,11 +803,31 @@ class _register(Tape.Register, _number, _structure):
|
||||
if self.size == size:
|
||||
return self
|
||||
assert self.size == 1
|
||||
return self._expand_to_vector(size)
|
||||
|
||||
def _expand_to_vector(self, size):
|
||||
res = type(self)(size=size)
|
||||
for i in range(size):
|
||||
self.mov(res[i], self)
|
||||
return res
|
||||
|
||||
def copy_from_part(self, source, base, size):
|
||||
set_global_vector_size(size)
|
||||
self.mov(self, source.get_vector(base, size))
|
||||
reset_global_vector_size()
|
||||
|
||||
@classmethod
|
||||
def concat(cls, parts):
|
||||
parts = list(parts)
|
||||
res = cls(size=sum(len(part) for part in parts))
|
||||
base = 0
|
||||
for reg in parts:
|
||||
set_global_vector_size(reg.size)
|
||||
reg.mov(res.get_vector(base, reg.size), reg)
|
||||
reset_global_vector_size()
|
||||
base += reg.size
|
||||
return res
|
||||
|
||||
class _arithmetic_register(_register):
|
||||
""" Arithmetic circuit type. """
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -1508,6 +1531,14 @@ class regint(_register, _int):
|
||||
raise CompilerError("Cannot convert '%s' to integer" % \
|
||||
type(val))
|
||||
|
||||
def expand_to_vector(self, size=None):
|
||||
if size is None:
|
||||
size = get_global_vector_size()
|
||||
if self.size == size:
|
||||
return self
|
||||
assert self.size == 1
|
||||
return self.inc(size, self, 0)
|
||||
|
||||
@vectorize
|
||||
@read_mem_value
|
||||
def int_op(self, other, inst, reverse=False):
|
||||
@@ -1762,7 +1793,8 @@ class localint(Tape._no_truth):
|
||||
class personal(Tape._no_truth):
|
||||
""" Value known to one player. Supports operations with public
|
||||
values and personal values known to the same player. Can be used
|
||||
with :py:func:`~Compiler.library.print_ln_to`.
|
||||
with :py:func:`~Compiler.library.print_ln_to`. It is possible to
|
||||
convert to secret types like :py:class:`sint`.
|
||||
|
||||
:param player: player (int)
|
||||
:param value: cleartext value (cint, cfix, cfloat) or array thereof
|
||||
@@ -2023,11 +2055,16 @@ class _secret(_arithmetic_register, _secret_structure):
|
||||
|
||||
:rtype: same as inputs
|
||||
"""
|
||||
x = list(x)
|
||||
set_global_vector_size(x[0].size)
|
||||
res = cls()
|
||||
dotprods(res, x, y)
|
||||
reset_global_vector_size()
|
||||
if isinstance(x, cls) and isinstance(y, cls):
|
||||
assert len(x) == len(y)
|
||||
res = cls()
|
||||
matmuls(res, x, y, 1, len(x), 1)
|
||||
else:
|
||||
x = list(x)
|
||||
set_global_vector_size(x[0].size)
|
||||
res = cls()
|
||||
dotprods(res, x, y)
|
||||
reset_global_vector_size()
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@@ -2952,6 +2989,27 @@ class sint(_secret, _int):
|
||||
prefixsums(res, self)
|
||||
return res
|
||||
|
||||
def sum(self):
|
||||
res = type(self)(size=1)
|
||||
picks(res, self.prefix_sum(), len(self) - 1, 0)
|
||||
return res
|
||||
|
||||
def _expand_to_vector(self, size):
|
||||
res = type(self)(size=size)
|
||||
picks(res, self, 0, 0)
|
||||
return res
|
||||
|
||||
def copy_from_part(self, source, base, size):
|
||||
picks(self, source, base, 1)
|
||||
|
||||
@classmethod
|
||||
def concat(cls, parts):
|
||||
parts = list(parts)
|
||||
res = cls(size=sum(len(part) for part in parts))
|
||||
args = sum(([len(part), part] for part in parts), [])
|
||||
concats(res, *args)
|
||||
return res
|
||||
|
||||
class sintbit(sint):
|
||||
""" :py:class:`sint` holding a bit, supporting binary operations
|
||||
(``&, |, ^``). """
|
||||
@@ -4726,6 +4784,24 @@ class sfix(_fix):
|
||||
def prefix_sum(self):
|
||||
return self._new(self.v.prefix_sum(), k=self.k, f=self.f)
|
||||
|
||||
def sum(self):
|
||||
return self._new(self.v.sum())
|
||||
|
||||
@classmethod
|
||||
def concat(cls, parts):
|
||||
parts = list(parts)
|
||||
int_parts = []
|
||||
f = parts[0].f
|
||||
k = parts[0].k
|
||||
for part in parts:
|
||||
assert part.f == f
|
||||
assert part.k == k
|
||||
int_parts.append(part.v)
|
||||
return cls._new(cls.int_type.concat(int_parts), k=k, f=f)
|
||||
|
||||
def __repr__(self):
|
||||
return '<sfix{f=%d,k=%d} at %s>' % (self.f, self.k, self.v)
|
||||
|
||||
class unreduced_sfix(_single):
|
||||
int_type = sint
|
||||
|
||||
@@ -5739,9 +5815,11 @@ class Array(_vectorizable):
|
||||
if value.size != 1:
|
||||
raise CompilerError('cannot assign vector to all elements')
|
||||
mem_value = MemValue(value)
|
||||
self.address = MemValue.if_necessary(self.address)
|
||||
n_threads = 8 if use_threads and util.is_constant(self.length) and \
|
||||
len(self) > 2**20 and not program.options.garbled else None
|
||||
len(self) > 2**20 and not program.options.garbled and \
|
||||
program.curr_tape.singular else None
|
||||
if n_threads is not None:
|
||||
self.address = MemValue.if_necessary(self.address)
|
||||
@library.multithread(n_threads, self.length)
|
||||
def _(base, size):
|
||||
if use_vector:
|
||||
@@ -6444,12 +6522,20 @@ class SubMultiArray(_vectorizable):
|
||||
try:
|
||||
try:
|
||||
self.value_type.direct_matrix_mul
|
||||
assert self.value_type == other.value_type
|
||||
skip_reduce = set((sint, sfix)) == \
|
||||
set((self.value_type, other.value_type))
|
||||
assert self.value_type == other.value_type or skip_reduce
|
||||
max_size = _register.maximum_size // res_matrix.sizes[1]
|
||||
@library.multithread(n_threads, self.sizes[0], max_size)
|
||||
def _(base, size):
|
||||
res_matrix.assign_part_vector(
|
||||
self.get_part(base, size).direct_mul(other), base)
|
||||
tmp = self.get_part(base, size).direct_mul(
|
||||
other, reduce=not skip_reduce,
|
||||
res_type=sfix if skip_reduce else None)
|
||||
if skip_reduce:
|
||||
tmp = t._new(tmp.v)
|
||||
else:
|
||||
tmp = tmp.reduce_after_mul()
|
||||
res_matrix.assign_part_vector(tmp, base)
|
||||
except AttributeError:
|
||||
assert n_threads is None
|
||||
if max(res_matrix.sizes) > 1000:
|
||||
@@ -6483,7 +6569,7 @@ class SubMultiArray(_vectorizable):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def direct_mul(self, other, reduce=True, indices=None):
|
||||
def direct_mul(self, other, reduce=True, indices=None, res_type=None):
|
||||
""" Matrix multiplication in the virtual machine.
|
||||
Unlike :py:func:`dot`, this only works for sint and sfix, and it
|
||||
returns a vector instead of a data structure.
|
||||
@@ -6511,10 +6597,15 @@ class SubMultiArray(_vectorizable):
|
||||
other_sizes = other.sizes
|
||||
assert len(other.sizes) == 2
|
||||
assert self.sizes[1] == other_sizes[0]
|
||||
assert self.value_type == other.value_type
|
||||
return self.value_type.direct_matrix_mul(self.address, other.address,
|
||||
self.sizes[0], *other_sizes,
|
||||
reduce=reduce, indices=indices)
|
||||
if self.value_type == other.value_type:
|
||||
assert res_type in (self.value_type, None)
|
||||
res_type = self.value_type
|
||||
else:
|
||||
assert not reduce
|
||||
assert res_type
|
||||
return res_type.direct_matrix_mul(self.address, other.address,
|
||||
self.sizes[0], *other_sizes,
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
def direct_mul_trans(self, other, reduce=True, indices=None):
|
||||
"""
|
||||
@@ -6988,7 +7079,10 @@ class _mem(_number):
|
||||
__ilshift__ = lambda self,other: self.write(self.read() << other)
|
||||
__irshift__ = lambda self,other: self.write(self.read() >> other)
|
||||
|
||||
iadd = __iadd__
|
||||
def iadd(self, other):
|
||||
""" Addition assignment. """
|
||||
return self.__iadd__(other)
|
||||
|
||||
isub = __isub__
|
||||
imul = __imul__
|
||||
itruediv = __itruediv__
|
||||
@@ -7016,7 +7110,7 @@ class MemValue(_mem):
|
||||
|
||||
@classmethod
|
||||
def if_necessary(cls, value):
|
||||
if util.is_constant_float(value):
|
||||
if util.is_constant_float(value) or isinstance(value, MemValue):
|
||||
return value
|
||||
else:
|
||||
return cls(value)
|
||||
@@ -7121,7 +7215,7 @@ class MemValue(_mem):
|
||||
return self.value_type.load_mem(addresses)
|
||||
|
||||
def __repr__(self):
|
||||
return 'MemValue(%s,%d)' % (self.value_type, self.address)
|
||||
return 'MemValue(%s,%s)' % (self.value_type, self.address)
|
||||
|
||||
|
||||
class MemFloat(MemValue):
|
||||
|
||||
@@ -28,6 +28,8 @@ for socket in client.sockets:
|
||||
os.store(finish)
|
||||
os.Send(socket)
|
||||
|
||||
# running two rounds
|
||||
# first for sint, then for sfix
|
||||
for x in bonus, bonus * 2 ** 16:
|
||||
client.send_private_inputs([domain(x)])
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk)
|
||||
{
|
||||
Rq_Element tmp(*params);
|
||||
SeededPRNG G;
|
||||
vector<FFT_Data::S> r(params->FFTD()[0].m());
|
||||
vector<FFT_Data::S> r(params->FFTD()[0].phi_m());
|
||||
bigint p = pk.p();
|
||||
assert(p != 0);
|
||||
for (auto& x : r)
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
#include "Ring.h"
|
||||
#include "Tools/Exceptions.h"
|
||||
|
||||
Ring::Ring(int m) :
|
||||
mm(0), phim(0)
|
||||
{
|
||||
if (m != 0)
|
||||
init(*this, m);
|
||||
}
|
||||
|
||||
void Ring::pack(octetStream& o) const
|
||||
{
|
||||
o.store(mm);
|
||||
|
||||
@@ -22,7 +22,7 @@ class Ring
|
||||
public:
|
||||
|
||||
|
||||
Ring() : mm(0), phim(0) { ; }
|
||||
Ring(int m = 0);
|
||||
~Ring() { ; }
|
||||
|
||||
// Rely on default copy assignment/constructor
|
||||
@@ -40,6 +40,7 @@ class Ring
|
||||
void unpack(octetStream& o);
|
||||
|
||||
bool operator!=(const Ring& other) const;
|
||||
bool operator==(const Ring& other) const { return not (*this != other); }
|
||||
};
|
||||
|
||||
void init(Ring& Rg, int m, bool generate_poly = false);
|
||||
|
||||
@@ -44,6 +44,7 @@ void Ring_Element::prepare(const Ring_Element& other)
|
||||
void Ring_Element::prepare_push()
|
||||
{
|
||||
element.clear();
|
||||
assert(FFTD);
|
||||
element.reserve(FFTD->phi_m());
|
||||
}
|
||||
|
||||
@@ -63,6 +64,7 @@ void Ring_Element::assign_zero()
|
||||
|
||||
void Ring_Element::assign_one()
|
||||
{
|
||||
assert(FFTD);
|
||||
allocate();
|
||||
modp fill;
|
||||
if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); }
|
||||
@@ -79,6 +81,7 @@ void Ring_Element::negate()
|
||||
if (element.empty())
|
||||
return;
|
||||
|
||||
assert(FFTD);
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ Negate(element[i],element[i],(*FFTD).get_prD()); }
|
||||
}
|
||||
@@ -87,6 +90,7 @@ void Ring_Element::negate()
|
||||
|
||||
void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
assert(a.FFTD);
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
if (a.element.empty())
|
||||
{
|
||||
@@ -119,6 +123,7 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
|
||||
void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
assert(a.FFTD);
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
if (a.element.empty())
|
||||
@@ -148,6 +153,7 @@ void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
|
||||
void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
assert(a.FFTD);
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
if (a.element.empty() or b.element.empty())
|
||||
@@ -200,9 +206,11 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
}
|
||||
else if ((*a.FFTD).get_twop()==0)
|
||||
{ // m a power of two case
|
||||
ans.partial_assign(a);
|
||||
Ring_Element aa(*ans.FFTD,ans.rep);
|
||||
aa.partial_assign(a);
|
||||
modp temp;
|
||||
cerr << "slow polynomial multiplication "
|
||||
"(change representation to change this)..." << endl;
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
{ for (int j=0; j<(*ans.FFTD).phi_m(); j++)
|
||||
{ Mul(temp,a.element[i],b.element[j],(*a.FFTD).get_prD());
|
||||
@@ -213,7 +221,9 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
}
|
||||
Add(aa.element[k],aa.element[k],temp,(*a.FFTD).get_prD());
|
||||
}
|
||||
cerr << "\r" << i << "/" << ans.FFTD->phi_m();
|
||||
}
|
||||
cerr << endl;
|
||||
ans=aa;
|
||||
}
|
||||
else
|
||||
@@ -241,6 +251,7 @@ void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
|
||||
Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
|
||||
{
|
||||
assert(element.size() == other.element.size());
|
||||
assert(FFTD);
|
||||
assert(FFTD == other.FFTD);
|
||||
assert(rep == other.rep);
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
@@ -252,6 +263,7 @@ Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
|
||||
Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
|
||||
{
|
||||
assert(element.size() == other.element.size());
|
||||
assert(FFTD);
|
||||
assert(FFTD == other.FFTD);
|
||||
assert(rep == other.rep);
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
@@ -263,6 +275,7 @@ Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
|
||||
Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
|
||||
{
|
||||
assert(element.size() == other.element.size());
|
||||
assert(FFTD);
|
||||
assert(FFTD == other.FFTD);
|
||||
assert(rep == other.rep);
|
||||
assert(rep == evaluation);
|
||||
@@ -274,6 +287,7 @@ Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
|
||||
|
||||
Ring_Element& Ring_Element::operator *=(const modp& other)
|
||||
{
|
||||
assert(FFTD);
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
element[i] = element[i].mul(other, FFTD->get_prD());
|
||||
return *this;
|
||||
@@ -282,6 +296,7 @@ Ring_Element& Ring_Element::operator *=(const modp& other)
|
||||
|
||||
Ring_Element Ring_Element::mul_by_X_i(int j) const
|
||||
{
|
||||
assert(FFTD);
|
||||
Ring_Element ans;
|
||||
ans.prepare(*this);
|
||||
if (element.empty())
|
||||
@@ -331,6 +346,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const
|
||||
|
||||
void Ring_Element::randomize(PRNG& G,bool Diag)
|
||||
{
|
||||
assert(FFTD);
|
||||
allocate();
|
||||
if (Diag==false)
|
||||
{ for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
@@ -352,6 +368,7 @@ void Ring_Element::randomize(PRNG& G,bool Diag)
|
||||
|
||||
void Ring_Element::change_rep(RepType r)
|
||||
{
|
||||
assert(FFTD);
|
||||
if (element.empty())
|
||||
{
|
||||
rep = r;
|
||||
@@ -403,6 +420,7 @@ void Ring_Element::change_rep(RepType r)
|
||||
|
||||
bool Ring_Element::equals(const Ring_Element& a) const
|
||||
{
|
||||
assert(FFTD);
|
||||
if (rep!=a.rep) { throw rep_mismatch(); }
|
||||
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
|
||||
|
||||
@@ -417,6 +435,7 @@ bool Ring_Element::equals(const Ring_Element& a) const
|
||||
|
||||
bool Ring_Element::is_zero() const
|
||||
{
|
||||
assert(FFTD);
|
||||
if (element.empty())
|
||||
return true;
|
||||
for (auto& x : element)
|
||||
@@ -428,6 +447,7 @@ bool Ring_Element::is_zero() const
|
||||
|
||||
ConversionIterator Ring_Element::get_iterator() const
|
||||
{
|
||||
assert(FFTD);
|
||||
if (rep != polynomial)
|
||||
throw runtime_error("simple iterator only available in polynomial represention");
|
||||
assert(not element.empty());
|
||||
@@ -436,16 +456,19 @@ ConversionIterator Ring_Element::get_iterator() const
|
||||
|
||||
RingReadIterator Ring_Element::get_copy_iterator() const
|
||||
{
|
||||
assert(FFTD);
|
||||
return *this;
|
||||
}
|
||||
|
||||
RingWriteIterator Ring_Element::get_write_iterator()
|
||||
{
|
||||
assert(FFTD);
|
||||
return *this;
|
||||
}
|
||||
|
||||
vector<bigint> Ring_Element::to_vec_bigint() const
|
||||
{
|
||||
assert(FFTD);
|
||||
vector<bigint> v;
|
||||
to_vec_bigint(v);
|
||||
return v;
|
||||
@@ -454,6 +477,7 @@ vector<bigint> Ring_Element::to_vec_bigint() const
|
||||
|
||||
void Ring_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
{
|
||||
assert(FFTD);
|
||||
v.resize(FFTD->phi_m());
|
||||
if (element.empty())
|
||||
return;
|
||||
@@ -476,6 +500,7 @@ void Ring_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
|
||||
modp Ring_Element::get_constant() const
|
||||
{
|
||||
assert(FFTD);
|
||||
if (element.empty())
|
||||
return {};
|
||||
else
|
||||
@@ -516,6 +541,7 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
|
||||
|
||||
void Ring_Element::pack(octetStream& o) const
|
||||
{
|
||||
assert(FFTD);
|
||||
check_size();
|
||||
o.store(unsigned(rep));
|
||||
store(o,element,(*FFTD).get_prD());
|
||||
@@ -524,6 +550,7 @@ void Ring_Element::pack(octetStream& o) const
|
||||
|
||||
void Ring_Element::unpack(octetStream& o)
|
||||
{
|
||||
assert(FFTD);
|
||||
unsigned int a;
|
||||
o.get(a);
|
||||
rep=(RepType) a;
|
||||
@@ -542,12 +569,14 @@ void Ring_Element::check_rep()
|
||||
|
||||
void Ring_Element::check_size() const
|
||||
{
|
||||
assert(FFTD);
|
||||
if (not element.empty() and (int)element.size() != FFTD->phi_m())
|
||||
throw runtime_error("invalid element size");
|
||||
}
|
||||
|
||||
void Ring_Element::output(ostream& s) const
|
||||
{
|
||||
assert(FFTD);
|
||||
s.write((char*)&rep, sizeof(rep));
|
||||
auto size = element.size();
|
||||
s.write((char*)&size, sizeof(size));
|
||||
@@ -558,6 +587,7 @@ void Ring_Element::output(ostream& s) const
|
||||
|
||||
void Ring_Element::input(istream& s)
|
||||
{
|
||||
assert(FFTD);
|
||||
s.read((char*)&rep, sizeof(rep));
|
||||
check_rep();
|
||||
auto size = element.size();
|
||||
@@ -579,6 +609,7 @@ void Ring_Element::check(const FFT_Data& FFTD) const
|
||||
|
||||
size_t Ring_Element::report_size(ReportType type) const
|
||||
{
|
||||
assert(FFTD);
|
||||
if (type == CAPACITY)
|
||||
return sizeof(modp) * element.capacity();
|
||||
else
|
||||
|
||||
@@ -56,9 +56,9 @@ class Ring_Element
|
||||
void allocate();
|
||||
|
||||
void set_data(const FFT_Data& prd) { FFTD=&prd; }
|
||||
const FFT_Data& get_FFTD() const { return *FFTD; }
|
||||
const Zp_Data& get_prD() const { return (*FFTD).get_prD(); }
|
||||
const bigint& get_prime() const { return (*FFTD).get_prime(); }
|
||||
const FFT_Data& get_FFTD() const { assert(FFTD); return *FFTD; }
|
||||
const Zp_Data& get_prD() const { return get_FFTD().get_prD(); }
|
||||
const bigint& get_prime() const { return get_FFTD().get_prime(); }
|
||||
|
||||
void assign_zero();
|
||||
void assign_one();
|
||||
@@ -120,6 +120,7 @@ class Ring_Element
|
||||
template <class T>
|
||||
void from(const vector<T>& source)
|
||||
{
|
||||
assert(source.size() == (size_t) get_FFTD().phi_m());
|
||||
from(Iterator<T>(source));
|
||||
}
|
||||
|
||||
@@ -162,7 +163,7 @@ class RingWriteIterator : public WriteConversionIterator
|
||||
RepType rep;
|
||||
public:
|
||||
RingWriteIterator(Ring_Element& element) :
|
||||
WriteConversionIterator(element.element, element.FFTD->get_prD()),
|
||||
WriteConversionIterator(element.element, element.get_FFTD().get_prD()),
|
||||
element(element), rep(element.rep)
|
||||
{
|
||||
element.rep = polynomial;
|
||||
@@ -177,7 +178,7 @@ class RingReadIterator : public ConversionIterator
|
||||
Ring_Element element;
|
||||
public:
|
||||
RingReadIterator(const Ring_Element& element) :
|
||||
ConversionIterator(this->element.element, element.FFTD->get_prD()),
|
||||
ConversionIterator(this->element.element, element.get_FFTD().get_prD()),
|
||||
element(element)
|
||||
{
|
||||
this->element.change_rep(polynomial);
|
||||
@@ -198,6 +199,7 @@ void Ring_Element::from(const Generator<T>& generator)
|
||||
T tmp;
|
||||
modp tmp2;
|
||||
prepare_push();
|
||||
assert(FFTD);
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{
|
||||
generator.get(tmp);
|
||||
|
||||
@@ -14,7 +14,10 @@ Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
|
||||
if (prd.size() > 0)
|
||||
a.push_back({prd[0], r0});
|
||||
if (prd.size() > 1)
|
||||
{
|
||||
assert(prd[0].get_R() == prd[1].get_R());
|
||||
a.push_back({prd[1], r1});
|
||||
}
|
||||
lev = n_mults();
|
||||
}
|
||||
|
||||
@@ -155,6 +158,7 @@ void Rq_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
if (lev==1)
|
||||
{ vector<bigint> v1;
|
||||
a[1].to_vec_bigint(v1);
|
||||
assert(v.size() == v1.size());
|
||||
bigint p0=a[0].get_prime();
|
||||
bigint p1=a[1].get_prime();
|
||||
bigint p0i,lambda,Q=p0*p1;
|
||||
|
||||
@@ -63,7 +63,10 @@ protected:
|
||||
Rq_Element(const FHE_PK& pk);
|
||||
|
||||
Rq_Element(const Ring_Element& b0,const Ring_Element& b1) :
|
||||
a({b0, b1}), lev(n_mults()) {}
|
||||
a({b0, b1}), lev(n_mults())
|
||||
{
|
||||
assert(b0.get_FFTD().get_R() == b1.get_FFTD().get_R());
|
||||
}
|
||||
|
||||
Rq_Element(const Ring_Element& b0) :
|
||||
a({b0}), lev(n_mults()) {}
|
||||
@@ -139,6 +142,8 @@ protected:
|
||||
template <class T>
|
||||
void from(const vector<T>& source, int level=-1)
|
||||
{
|
||||
for (auto& x : a)
|
||||
assert(source.size() == (size_t) x.get_FFTD().phi_m());
|
||||
from(Iterator<T>(source), level);
|
||||
}
|
||||
|
||||
|
||||
@@ -76,11 +76,14 @@ modp Find_Primitive_Root_2m(int m,const vector<int>& poly,const Zp_Data& ZpD)
|
||||
*/
|
||||
modp Find_Primitive_Root_2power(int m,const Zp_Data& ZpD)
|
||||
{
|
||||
assert((m & (m - 1)) == 0);
|
||||
assert(m > 1);
|
||||
modp ans,e,one,base;
|
||||
assignOne(one,ZpD);
|
||||
assignOne(base,ZpD);
|
||||
bigint exp;
|
||||
exp=(ZpD.pr-1)/m;
|
||||
assert(exp * m == ZpD.pr - 1);
|
||||
bool flag=true;
|
||||
while (flag)
|
||||
{ Add(base,base,one,ZpD); // Keep incrementing base until we hit the answer
|
||||
|
||||
@@ -17,6 +17,23 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
|
||||
mf.allocate_slots(pk.p() << 64);
|
||||
}
|
||||
|
||||
class ModuloTreeSum : public TreeSum<bigint>
|
||||
{
|
||||
bigint modulo;
|
||||
|
||||
void post_add_process(vector<bigint>& values)
|
||||
{
|
||||
for (auto& v : values)
|
||||
v %= modulo;
|
||||
}
|
||||
|
||||
public:
|
||||
ModuloTreeSum(bigint modulo) :
|
||||
modulo(modulo)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template<class FD>
|
||||
Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
|
||||
{
|
||||
@@ -57,10 +74,7 @@ Plaintext_<FD>& DistDecrypt<FD>::run(const Ciphertext& ctx, bool NewCiphertext)
|
||||
}
|
||||
else
|
||||
{
|
||||
TreeSum<bigint>().run(vv, P);
|
||||
bigint mod=params.p0();
|
||||
for (auto& v : vv)
|
||||
v %= mod;
|
||||
ModuloTreeSum(params.p0()).run(vv, P);
|
||||
}
|
||||
|
||||
// Now get the final message
|
||||
|
||||
@@ -76,6 +76,7 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
|
||||
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
|
||||
string reason;
|
||||
auto base_setup = setup;
|
||||
|
||||
try
|
||||
{
|
||||
@@ -107,7 +108,7 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
<< " because no suitable material "
|
||||
"from a previous run was found (" << reason << ")"
|
||||
<< endl;
|
||||
setup = {};
|
||||
setup = base_setup;
|
||||
setup.generate(P, machine, plaintext_length, sec);
|
||||
setup.check(P, machine);
|
||||
octetStream os;
|
||||
@@ -122,6 +123,10 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
cerr << "Ciphertext length: " << params.p0().numBits();
|
||||
for (size_t i = 1; i < params.FFTD().size(); i++)
|
||||
cerr << "+" << params.FFTD()[i].get_prime().numBits();
|
||||
cerr << " (" << DIV_CEIL(params.p0().numBits(), 64);
|
||||
for (size_t i = 1; i < params.FFTD().size(); i++)
|
||||
cerr << "+" << DIV_CEIL(params.FFTD()[i].get_prime().numBits(), 64);
|
||||
cerr << " limbs)";
|
||||
cerr << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,11 @@ public:
|
||||
{
|
||||
ProtocolSetup<DealerShare<BitVec>> setup(*P);
|
||||
ProtocolSet<DealerShare<BitVec>> set(*P, setup);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
int buffer_size = DIV_CEIL(
|
||||
BaseMachine::batch_size<DealerSecret>(DATA_TRIPLE),
|
||||
DealerSecret::default_length);
|
||||
set.preprocessing.buffer_extra(DATA_TRIPLE, buffer_size);
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
auto triple = set.preprocessing.get_triple(
|
||||
DealerSecret::default_length);
|
||||
|
||||
@@ -70,6 +70,7 @@ public:
|
||||
typedef U whole_type;
|
||||
|
||||
static const bool expensive_triples = true;
|
||||
static const bool malicious = true;
|
||||
|
||||
static MC* new_mc(typename super::mac_key_type)
|
||||
{
|
||||
|
||||
@@ -26,6 +26,7 @@ public:
|
||||
typedef Rep4Input<This> Input;
|
||||
|
||||
static const bool expensive_triples = false;
|
||||
static const bool malicious = true;
|
||||
|
||||
static MC* new_mc(typename super::mac_key_type) { return new MC; }
|
||||
|
||||
|
||||
@@ -135,6 +135,9 @@ public:
|
||||
|
||||
static void run_tapes(const vector<int>& args) { T::run_tapes(args); }
|
||||
|
||||
template<class U>
|
||||
static string proto_fake_opts() { return U::fake_opts(); }
|
||||
|
||||
Secret();
|
||||
Secret(const Integer& x) { *this = x; }
|
||||
|
||||
|
||||
@@ -43,6 +43,9 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol)
|
||||
void SemiPrep::buffer_triples()
|
||||
{
|
||||
assert(this->triple_generator);
|
||||
this->triple_generator->set_batch_size(
|
||||
DIV_CEIL(BaseMachine::batch_size<SemiSecret>(DATA_TRIPLE,
|
||||
this->buffer_size), 64));
|
||||
this->triple_generator->generatePlainTriples();
|
||||
for (auto& x : this->triple_generator->plainTriples)
|
||||
{
|
||||
|
||||
@@ -151,6 +151,12 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
static string proto_fake_opts()
|
||||
{
|
||||
return T::fake_opts();
|
||||
}
|
||||
|
||||
RepSecretBase()
|
||||
{
|
||||
}
|
||||
@@ -258,6 +264,7 @@ public:
|
||||
typedef SemiHonestRepSecret whole_type;
|
||||
|
||||
static const bool expensive_triples = false;
|
||||
static const bool malicious = false;
|
||||
|
||||
static MC* new_mc(mac_key_type) { return new MC; }
|
||||
|
||||
|
||||
@@ -184,12 +184,15 @@ void ShareThread<T>::xors(Processor<T>& processor, const vector<int>& args)
|
||||
int out = args[i + 1];
|
||||
int left = args[i + 2];
|
||||
int right = args[i + 3];
|
||||
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
|
||||
{
|
||||
int n = min(T::default_length, n_bits - j * T::default_length);
|
||||
processor.S[out + j].xor_(n, processor.S[left + j],
|
||||
processor.S[right + j]);
|
||||
}
|
||||
if (n_bits == 1)
|
||||
processor.S[out].xor_(1, processor.S[left], processor.S[right]);
|
||||
else
|
||||
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
|
||||
{
|
||||
int n = min(T::default_length, n_bits - j * T::default_length);
|
||||
processor.S[out + j].xor_(n, processor.S[left + j],
|
||||
processor.S[right + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ public:
|
||||
typedef TinierSecret<T> whole_type;
|
||||
|
||||
static const int default_length = 1;
|
||||
static const bool expensive_triples = true;
|
||||
|
||||
static string name()
|
||||
{
|
||||
|
||||
@@ -34,7 +34,9 @@ void TinierSharePrep<T>::buffer_secret_triples()
|
||||
vector<array<T, 3>> triples;
|
||||
TripleShuffleSacrifice<T> sacrifice;
|
||||
size_t required;
|
||||
required = sacrifice.minimum_n_inputs_with_combining();
|
||||
required = sacrifice.minimum_n_inputs_with_combining(
|
||||
BaseMachine::batch_size<T>(DATA_TRIPLE));
|
||||
triple_generator->set_batch_size(DIV_CEIL(required, 64));
|
||||
while (triples.size() < required)
|
||||
{
|
||||
triple_generator->generatePlainTriples();
|
||||
|
||||
@@ -50,7 +50,8 @@ public:
|
||||
static const bool variable_players = T::variable_players;
|
||||
static const bool needs_ot = T::needs_ot;
|
||||
static const bool has_mac = T::has_mac;
|
||||
static const bool expensive_triples = false;
|
||||
static const bool malicious = T::malicious;
|
||||
static const bool expensive_triples = T::expensive_triples;
|
||||
static const bool randoms_for_opens = false;
|
||||
|
||||
static const int default_length = 64;
|
||||
|
||||
2
Makefile
2
Makefile
@@ -82,7 +82,7 @@ CONFIG.mine:
|
||||
%.o: %.cpp
|
||||
$(CXX) -o $@ $< $(CFLAGS) -MMD -MP -c
|
||||
|
||||
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x
|
||||
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x mascot-party.x
|
||||
|
||||
offline: $(OT_EXE) Check-Offline.x mascot-offline.x cowgear-offline.x mal-shamir-offline.x
|
||||
|
||||
|
||||
@@ -154,9 +154,9 @@ void check_setup(string dir, bigint pr)
|
||||
string filename = dir + "Params-Data";
|
||||
ifstream(filename) >> p;
|
||||
if (p == 0)
|
||||
throw runtime_error("no modulus in " + filename);
|
||||
throw setup_error("no modulus in " + filename);
|
||||
if (p != pr)
|
||||
throw runtime_error("wrong modulus in " + filename);
|
||||
throw setup_error("wrong modulus in " + filename);
|
||||
}
|
||||
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
|
||||
|
||||
@@ -12,7 +12,7 @@ void ValueInterface::check_setup(const string& directory)
|
||||
{
|
||||
struct stat sb;
|
||||
if (stat(directory.c_str(), &sb) != 0)
|
||||
throw runtime_error(directory + " does not exist");
|
||||
throw setup_error(directory + " does not exist");
|
||||
if (not (sb.st_mode & S_IFDIR))
|
||||
throw runtime_error(directory + " is not a directory");
|
||||
throw setup_error(directory + " is not a directory");
|
||||
}
|
||||
|
||||
@@ -109,6 +109,11 @@ public:
|
||||
void assign(const void* buffer) { avx_memcpy(a, buffer, N_BYTES); normalize(); }
|
||||
void assign(int x) { *this = x; }
|
||||
|
||||
/**
|
||||
* Get 64-bit part.
|
||||
*
|
||||
* @param i return word containing 64*i- to 64*i+63-least significant bits
|
||||
*/
|
||||
mp_limb_t get_limb(int i) const { return a[i]; }
|
||||
bool get_bit(int i) const;
|
||||
|
||||
|
||||
@@ -8,15 +8,17 @@ void Zp_Data::init(const bigint& p,bool mont)
|
||||
{
|
||||
lock.lock();
|
||||
|
||||
#ifdef VERBOSE
|
||||
if (pr != 0)
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
if (pr != p)
|
||||
cerr << "Changing prime from " << pr << " to " << p << endl;
|
||||
if (mont != montgomery)
|
||||
cerr << "Changing Montgomery" << endl;
|
||||
}
|
||||
#endif
|
||||
if (pr != p or mont != montgomery)
|
||||
throw runtime_error("Zp_Data instance already initialized");
|
||||
}
|
||||
|
||||
if (not probPrime(p))
|
||||
throw runtime_error(p.get_str() + " is not a prime");
|
||||
|
||||
@@ -47,7 +47,7 @@ class Zp_Data
|
||||
void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const;
|
||||
void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{ Mont_Mult(z, x, y, t); }
|
||||
{ Mont_Mult(z, x, y, get_t()); }
|
||||
void Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x, const mp_limb_t* y,
|
||||
int max_t) const;
|
||||
|
||||
@@ -61,7 +61,7 @@ class Zp_Data
|
||||
|
||||
void assign(const Zp_Data& Zp);
|
||||
void init(const bigint& p,bool mont=true);
|
||||
int get_t() const { return t; }
|
||||
int get_t() const { assert(t > 0); return t; }
|
||||
const mp_limb_t* get_prA() const { return prA; }
|
||||
bool get_mont() const { return montgomery; }
|
||||
mp_limb_t overhang_mask() const;
|
||||
@@ -73,8 +73,9 @@ class Zp_Data
|
||||
Zp_Data() :
|
||||
montgomery(0), pi(0), mask(0), pr_byte_length(0), pr_bit_length(0)
|
||||
{
|
||||
t = MAX_MOD_SZ;
|
||||
t = -1;
|
||||
overhang = 0;
|
||||
shanks_r = 0;
|
||||
}
|
||||
|
||||
// The main init funciton
|
||||
|
||||
@@ -91,6 +91,7 @@ public:
|
||||
template<int K>
|
||||
bigint& operator=(const SignedZ2<K>& x);
|
||||
|
||||
/// Convert to signed representation in :math:`[-p/2,p/2]`.
|
||||
template<int X, int L>
|
||||
bigint& from_signed(const gfp_<X, L>& other);
|
||||
template<class T>
|
||||
|
||||
@@ -61,6 +61,8 @@ class gfp_ : public ValueInterface
|
||||
|
||||
static thread_local vector<gfp_> powers;
|
||||
|
||||
static gfp_ two;
|
||||
|
||||
public:
|
||||
|
||||
typedef gfp_ value_type;
|
||||
@@ -317,6 +319,8 @@ gfp_<X, L>::gfp_(long x)
|
||||
assign_zero();
|
||||
else if (x == 1)
|
||||
assign_one();
|
||||
else if (x == 2)
|
||||
*this = two;
|
||||
else
|
||||
*this = bigint::tmp = x;
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ template<int X, int L>
|
||||
const true_type gfp_<X, L>::prime_field;
|
||||
template<int X, int L>
|
||||
const int gfp_<X, L>::MAX_N_BITS;
|
||||
template<int X, int L>
|
||||
gfp_<X, L> gfp_<X, L>::two;
|
||||
|
||||
template<int X, int L>
|
||||
inline void gfp_<X, L>::read_or_generate_setup(string dir,
|
||||
@@ -50,6 +52,7 @@ void gfp_<X, L>::init_field(const bigint& p, bool mont)
|
||||
else
|
||||
cerr << name << " larger than necessary for modulus " << p << endl;
|
||||
}
|
||||
two = bigint::tmp = 2;
|
||||
}
|
||||
|
||||
template <int X, int L>
|
||||
|
||||
@@ -80,6 +80,12 @@ void gfpvar_<X, L>::init_default(int lgp, bool montgomery)
|
||||
init_field(SPDZ_Data_Setup_Primes(lgp), montgomery);
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
inline void gfpvar_<X, L>::reset()
|
||||
{
|
||||
ZpD = {};
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
const Zp_Data& gfpvar_<X, L>::get_ZpD()
|
||||
{
|
||||
|
||||
@@ -68,6 +68,7 @@ public:
|
||||
{
|
||||
init_field(T::pr(), montgomery);
|
||||
}
|
||||
static void reset();
|
||||
|
||||
static const Zp_Data& get_ZpD();
|
||||
static const bigint& pr();
|
||||
|
||||
@@ -102,7 +102,7 @@ bool isZero(const modp_<L>& ans,const Zp_Data& ZpD)
|
||||
template<int L>
|
||||
void assignOne(modp_<L>& x,const Zp_Data& ZpD)
|
||||
{ if (ZpD.montgomery)
|
||||
{ mpn_copyi(x.x,ZpD.R,ZpD.t); }
|
||||
{ mpn_copyi(x.x,ZpD.R,ZpD.get_t()); }
|
||||
else
|
||||
{ assignZero(x,ZpD);
|
||||
x.x[0]=1;
|
||||
@@ -177,7 +177,7 @@ void modp_<L>::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const
|
||||
template<int L>
|
||||
void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD)
|
||||
{
|
||||
inline_mpn_zero(ans.x,ZpD.t);
|
||||
inline_mpn_zero(ans.x,ZpD.get_t());
|
||||
if (x>=0)
|
||||
{ ans.x[0]=x;
|
||||
if (ZpD.t==1) { ans.x[0]=ans.x[0]%ZpD.prA[0]; }
|
||||
@@ -232,13 +232,13 @@ void modp_<L>::convert_destroy(const fixint<M>& xx,
|
||||
SignedZ2<64 * L> tmp = xx;
|
||||
if (xx.negative())
|
||||
tmp += ZpD.pr;
|
||||
convert(tmp.get(), ZpD.t, ZpD, false);
|
||||
convert(tmp.get(), ZpD.get_t(), ZpD, false);
|
||||
}
|
||||
|
||||
template<int L>
|
||||
void modp_<L>::convert(const mp_limb_t* source, mp_size_t size, const Zp_Data& ZpD, bool negative)
|
||||
{
|
||||
assert(size <= ZpD.t);
|
||||
assert(size <= ZpD.get_t());
|
||||
if (negative)
|
||||
mpn_sub(x, ZpD.prA, ZpD.t, source, size);
|
||||
else
|
||||
|
||||
@@ -110,6 +110,8 @@ public:
|
||||
Player* parentPlayer = 0);
|
||||
~OTTripleGenerator();
|
||||
|
||||
void set_batch_size(int nTriples);
|
||||
|
||||
void generate() { throw not_implemented(); }
|
||||
|
||||
void generatePlainTriples();
|
||||
|
||||
@@ -69,6 +69,14 @@ Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(const OTTripleSetup& setup,
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void OTTripleGenerator<T>::set_batch_size(int batch_size)
|
||||
{
|
||||
nTriplesPerLoop = DIV_CEIL(batch_size, nloops);
|
||||
nTriples = nTriplesPerLoop * nloops;
|
||||
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
@@ -84,11 +92,9 @@ OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
|
||||
machine(machine),
|
||||
MC(0)
|
||||
{
|
||||
nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
|
||||
nTriples = nTriplesPerLoop * nloops;
|
||||
field_size = T::open_type::size() * 8;
|
||||
nAmplify = machine.amplify ? N_AMPLIFY : 1;
|
||||
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
|
||||
set_batch_size(_nTriples);
|
||||
|
||||
int n = nparties;
|
||||
//baseReceiverInput = machines[0]->baseReceiverInput;
|
||||
|
||||
@@ -13,6 +13,7 @@ using namespace std;
|
||||
#include "OT/OTVole.h"
|
||||
#include "OT/Rectangle.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
template<class T>
|
||||
class NPartyTripleGenerator;
|
||||
@@ -187,7 +188,7 @@ class SemiMultiplier : public OTMultiplier<T>
|
||||
}
|
||||
|
||||
public:
|
||||
vector<typename T::open_type> c_output;
|
||||
CheckVector<typename T::open_type> c_output;
|
||||
|
||||
SemiMultiplier(OTTripleGenerator<T>& generator, int i) :
|
||||
OTMultiplier<T>(generator, i)
|
||||
|
||||
@@ -163,10 +163,10 @@ void SemiMultiplier<T>::multiplyForBits()
|
||||
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
|
||||
|
||||
int n_squares = otCorrelator.receiverOutputMatrix.squares.size();
|
||||
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
|
||||
baseReceiverOutput);
|
||||
otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(),
|
||||
this->generator.valueBits[0], false, -1);
|
||||
otCorrelator.correlate(0, n_squares, aBits, false, -1);
|
||||
|
||||
c_output.clear();
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ void OTVoleBase<T>::hash_row(__m128i res[2], const U& row,
|
||||
int num_blocks = DIV_CEIL(row.size() * T::size(), 16);
|
||||
__m128i buffer[T::size()];
|
||||
size_t next = 0;
|
||||
while (next + 16 < row.size())
|
||||
while (next + 16 <= row.size())
|
||||
{
|
||||
for (int j = 0; j < 16; j++)
|
||||
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
|
||||
@@ -124,6 +124,8 @@ void OTVoleBase<T>::hash_row(__m128i res[2], const U& row,
|
||||
for (int j = 0; j < 16; j++)
|
||||
if (next < row.size())
|
||||
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
|
||||
else
|
||||
memset((char*) buffer + j * T::size(), 0, T::size());
|
||||
for (int j = 0; j < num_blocks % T::size(); j++)
|
||||
add_mul(res, buffer[j], *coefficients++);
|
||||
assert(coefficients == coeff_base + num_blocks);
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
#include "Protocols/ShuffleSacrifice.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <sodium.h>
|
||||
using namespace std;
|
||||
@@ -30,6 +32,28 @@ BaseMachine& BaseMachine::s()
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
bool BaseMachine::has_program()
|
||||
{
|
||||
return has_singleton() and not s().progs.empty();
|
||||
}
|
||||
|
||||
int BaseMachine::edabit_bucket_size(int n_bits)
|
||||
{
|
||||
int res = OnlineOptions::singleton.bucket_size;
|
||||
|
||||
if (has_program())
|
||||
{
|
||||
auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
for (int B = res; B <= 5; B++)
|
||||
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
|
||||
break;
|
||||
else
|
||||
res = B;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
BaseMachine::BaseMachine() : nthreads(0)
|
||||
{
|
||||
if (sodium_init() == -1)
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "ThreadJob.h"
|
||||
#include "ThreadQueues.h"
|
||||
#include "Program.h"
|
||||
#include "OnlineOptions.h"
|
||||
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
@@ -44,8 +46,11 @@ public:
|
||||
|
||||
vector<string> bc_filenames;
|
||||
|
||||
vector<Program> progs;
|
||||
|
||||
static BaseMachine& s();
|
||||
static bool has_singleton() { return singleton != 0; }
|
||||
static bool has_program();
|
||||
|
||||
static string memory_filename(const string& type_short, int my_number);
|
||||
|
||||
@@ -54,6 +59,12 @@ public:
|
||||
static int prime_length_from_schedule(string progname);
|
||||
static bigint prime_from_schedule(string progname);
|
||||
|
||||
template<class T>
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
|
||||
template<class T>
|
||||
static int edabit_batch_size(int n_bits, int buffer_size = 0);
|
||||
static int edabit_bucket_size(int n_bits);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
|
||||
@@ -76,6 +87,8 @@ public:
|
||||
|
||||
void print_global_comm(Player& P, const NamedCommStats& stats);
|
||||
void print_comm(Player& P, const NamedCommStats& stats);
|
||||
|
||||
virtual const Names& get_N() { throw not_implemented(); }
|
||||
};
|
||||
|
||||
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
@@ -83,4 +96,105 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
return ot_setup.get_fresh(P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
{
|
||||
int n_opts;
|
||||
int n = 0;
|
||||
int res = 0;
|
||||
|
||||
if (buffer_size > 0)
|
||||
n_opts = buffer_size;
|
||||
else if (fallback > 0)
|
||||
n_opts = fallback;
|
||||
else
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
|
||||
if (buffer_size <= 0 and has_program())
|
||||
{
|
||||
auto files = s().progs[0].get_offline_data_used().files;
|
||||
auto usage = files[T::clear::field_type()];
|
||||
|
||||
if (type == DATA_DABIT and T::LivePrep::bits_from_dabits())
|
||||
n = usage[DATA_BIT] + usage[DATA_DABIT];
|
||||
else if (type == DATA_BIT and T::LivePrep::dabits_from_bits())
|
||||
n = usage[DATA_BIT] + usage[DATA_DABIT];
|
||||
else
|
||||
n = usage[type];
|
||||
}
|
||||
else if (type != DATA_DABIT)
|
||||
{
|
||||
n = buffer_size;
|
||||
buffer_size = 0;
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
}
|
||||
|
||||
if (n > 0 and not (buffer_size > 0))
|
||||
{
|
||||
bool used_frac = false;
|
||||
if (n > n_opts)
|
||||
{
|
||||
// finding the right fraction
|
||||
for (int i = 1; i <= 10; i++)
|
||||
{
|
||||
int frac = DIV_CEIL(n, i);
|
||||
if (frac <= n_opts)
|
||||
{
|
||||
res = frac;
|
||||
used_frac = true;
|
||||
#ifdef DEBUG_BATCH_SIZE
|
||||
cerr << "found fraction " << frac << endl;
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (not used_frac)
|
||||
res = min(n, n_opts);
|
||||
}
|
||||
else
|
||||
res = n_opts;
|
||||
|
||||
#ifdef DEBUG_BATCH_SIZE
|
||||
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
|
||||
<< " res=" << res << " n="
|
||||
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
|
||||
#endif
|
||||
|
||||
assert(res > 0);
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
|
||||
{
|
||||
int n_opts;
|
||||
int n = 0;
|
||||
int res;
|
||||
|
||||
if (buffer_size > 0)
|
||||
n_opts = buffer_size;
|
||||
else
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
|
||||
if (has_program())
|
||||
{
|
||||
n = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
}
|
||||
|
||||
if (n > 0 and not (buffer_size > 0))
|
||||
res = min(n, n_opts);
|
||||
else
|
||||
res = n_opts;
|
||||
|
||||
#ifdef DEBUG_BATCH_SIZE
|
||||
cerr << "edaBits " << T::type_string() << " (" << n_bits
|
||||
<< ") res=" << res << " n="
|
||||
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
|
||||
#endif
|
||||
|
||||
assert(res > 0);
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif /* PROCESSOR_BASEMACHINE_H_ */
|
||||
|
||||
@@ -229,3 +229,9 @@ bool DataPositions::any_more(const DataPositions& other) const
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
long long DataPositions::total_edabits(int n_bits) const
|
||||
{
|
||||
auto usage = edabits;
|
||||
return usage[{false, n_bits}] + usage[{true, n_bits}];
|
||||
}
|
||||
|
||||
@@ -85,6 +85,8 @@ public:
|
||||
void print_cost() const;
|
||||
bool empty() const;
|
||||
bool any_more(const DataPositions& other) const;
|
||||
|
||||
long long total_edabits(int n_bits) const;
|
||||
};
|
||||
|
||||
template<class sint, class sgf2n> class Processor;
|
||||
@@ -229,6 +231,10 @@ public:
|
||||
|
||||
static long additional_inputs(const DataPositions& usage);
|
||||
|
||||
static string get_prep_dir(const Names& N);
|
||||
static void check_setup(const Names& N);
|
||||
static void check_setup(int num_players, const string& prep_dir);
|
||||
|
||||
Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir,
|
||||
DataPositions& usage, int thread_num = -1);
|
||||
Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num = -1);
|
||||
@@ -299,7 +305,7 @@ class Data_Files
|
||||
|
||||
Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<sint>* procp = 0,
|
||||
SubProcessor<sgf2n>* proc2 = 0);
|
||||
Data_Files(const Names& N);
|
||||
Data_Files(const Names& N, int thread_num = -1);
|
||||
~Data_Files();
|
||||
|
||||
DataPositions tellg() { return usage; }
|
||||
|
||||
@@ -60,8 +60,7 @@ T Preprocessing<T>::get_random_from_inputs(int nplayers)
|
||||
template<class T>
|
||||
Sub_Data_Files<T>::Sub_Data_Files(const Names& N, DataPositions& usage,
|
||||
int thread_num) :
|
||||
Sub_Data_Files(N,
|
||||
OnlineOptions::singleton.prep_dir_prefix<T>(N.num_players()), usage,
|
||||
Sub_Data_Files(N, get_prep_dir(N), usage,
|
||||
thread_num)
|
||||
{
|
||||
}
|
||||
@@ -98,6 +97,32 @@ string Sub_Data_Files<T>::get_edabit_filename(const Names& N, int n_bits,
|
||||
get_prep_sub_dir<T>(N.num_players()), n_bits, N.my_num(), thread_num);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
string Sub_Data_Files<T>::get_prep_dir(const Names& N)
|
||||
{
|
||||
return OnlineOptions::singleton.prep_dir_prefix<T>(N.num_players());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::check_setup(const Names& N)
|
||||
{
|
||||
return check_setup(N.num_players(), get_prep_dir(N));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::check_setup(int num_players, const string& prep_dir)
|
||||
{
|
||||
try
|
||||
{
|
||||
T::clear::check_setup(prep_dir);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
throw prep_setup_error(e.what(), num_players,
|
||||
T::template proto_fake_opts<typename T::clear>());
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
const string& prep_data_dir, DataPositions& usage, int thread_num) :
|
||||
@@ -109,19 +134,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
|
||||
#endif
|
||||
|
||||
try
|
||||
{
|
||||
T::clear::check_setup(prep_data_dir);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
cerr << "Something is wrong with the preprocessing data on disk." << endl;
|
||||
cerr
|
||||
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
|
||||
<< num_players
|
||||
<< T::clear::fake_opts() << "'?" << endl;
|
||||
throw;
|
||||
}
|
||||
check_setup(num_players, prep_data_dir);
|
||||
|
||||
string type_short = T::type_short();
|
||||
string type_string = T::type_string();
|
||||
@@ -173,11 +186,11 @@ Data_Files<sint, sgf2n>::Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
Data_Files<sint, sgf2n>::Data_Files(const Names& N) :
|
||||
Data_Files<sint, sgf2n>::Data_Files(const Names& N, int thread_num) :
|
||||
usage(N.num_players()),
|
||||
DataFp(*new Sub_Data_Files<sint>(N, usage)),
|
||||
DataF2(*new Sub_Data_Files<sgf2n>(N, usage)),
|
||||
DataFb(*new Sub_Data_Files<typename sint::bit_type>(N, usage))
|
||||
DataFp(*new Sub_Data_Files<sint>(N, usage, thread_num)),
|
||||
DataF2(*new Sub_Data_Files<sgf2n>(N, usage, thread_num)),
|
||||
DataFb(*new Sub_Data_Files<typename sint::bit_type>(N, usage, thread_num))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -112,6 +112,8 @@ template<class T>
|
||||
class DummyLivePrep : public Preprocessing<T>
|
||||
{
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
|
||||
static void basic_setup(Player&)
|
||||
{
|
||||
}
|
||||
@@ -125,6 +127,11 @@ public:
|
||||
"live preprocessing not implemented for " + T::type_string());
|
||||
}
|
||||
|
||||
static bool bits_from_dabits()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
DummyLivePrep(DataPositions& usage, GC::ShareThread<T>&) :
|
||||
Preprocessing<T>(usage)
|
||||
{
|
||||
|
||||
@@ -29,7 +29,12 @@ public:
|
||||
if (not BufferBase::file)
|
||||
{
|
||||
if (this->open()->fail())
|
||||
throw runtime_error("error opening " + this->filename);
|
||||
throw runtime_error(
|
||||
"error opening " + this->filename
|
||||
+ ", have you generated edaBits, "
|
||||
"for example by running "
|
||||
"'./Fake-Offline.x -e "
|
||||
+ to_string(n_bits) + " ...'?");
|
||||
}
|
||||
|
||||
assert(BufferBase::file);
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "Instruction.h"
|
||||
#include "instructions.h"
|
||||
#include "Processor.h"
|
||||
#include "Memory.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "GC/instructions.h"
|
||||
|
||||
@@ -54,7 +55,7 @@ void Instruction::gbitcom(vector<cgf2n>& registers) const
|
||||
}
|
||||
}
|
||||
|
||||
void Instruction::execute_regint(ArithmeticProcessor& Proc, vector<Integer>& Mi) const
|
||||
void Instruction::execute_regint(ArithmeticProcessor& Proc, MemoryPart<Integer>& Mi) const
|
||||
{
|
||||
(void) Mi;
|
||||
auto& Ci = Proc.get_Ci();
|
||||
|
||||
@@ -14,6 +14,7 @@ using namespace std;
|
||||
template<class sint, class sgf2n> class Machine;
|
||||
template<class sint, class sgf2n> class Processor;
|
||||
template<class T> class SubProcessor;
|
||||
template<class T> class MemoryPart;
|
||||
class ArithmeticProcessor;
|
||||
class SwitchableOutput;
|
||||
|
||||
@@ -86,6 +87,8 @@ enum
|
||||
SUBCFI = 0x2B,
|
||||
SUBSFI = 0x2C,
|
||||
PREFIXSUMS = 0x2D,
|
||||
PICKS = 0x2E,
|
||||
CONCATS = 0x2F,
|
||||
// Multiplication/division/other arithmetic
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
@@ -392,7 +395,7 @@ public:
|
||||
template<class cgf2n>
|
||||
void gbitcom(vector<cgf2n>& registers) const;
|
||||
|
||||
void execute_regint(ArithmeticProcessor& Proc, vector<Integer>& Mi) const;
|
||||
void execute_regint(ArithmeticProcessor& Proc, MemoryPart<Integer>& Mi) const;
|
||||
|
||||
void shuffle(ArithmeticProcessor& Proc) const;
|
||||
void bitdecint(ArithmeticProcessor& Proc) const;
|
||||
|
||||
@@ -208,6 +208,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
get_ints(r, s, 2);
|
||||
n = get_int(s);
|
||||
break;
|
||||
case PICKS:
|
||||
get_ints(r, s, 3);
|
||||
n = get_int(s);
|
||||
break;
|
||||
case USE:
|
||||
case USE_INP:
|
||||
case USE_EDABIT:
|
||||
@@ -392,6 +396,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case EDABIT:
|
||||
case SEDABIT:
|
||||
case WRITEFILESHARE:
|
||||
case CONCATS:
|
||||
num_var_args = get_int(s) - 1;
|
||||
r[0] = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
@@ -930,6 +935,20 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case MOVC:
|
||||
Proc.write_Cp(r[0],Proc.read_Cp(r[1]));
|
||||
break;
|
||||
case CONCATS:
|
||||
{
|
||||
auto& S = Proc.Procp.get_S();
|
||||
auto dest = S.begin() + r[0];
|
||||
for (auto j = start.begin(); j < start.end(); j += 2)
|
||||
{
|
||||
auto source = S.begin() + *(j + 1);
|
||||
assert(dest + *j <= S.end());
|
||||
assert(source + *j <= S.end());
|
||||
for (int k = 0; k < *j; k++)
|
||||
*dest++ = *source++;
|
||||
}
|
||||
return;
|
||||
}
|
||||
case DIVC:
|
||||
Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2]));
|
||||
break;
|
||||
|
||||
@@ -53,8 +53,6 @@ class Machine : public BaseMachine
|
||||
|
||||
public:
|
||||
|
||||
vector<Program> progs;
|
||||
|
||||
Memory<sgf2n> M2;
|
||||
Memory<sint> Mp;
|
||||
Memory<Integer> Mi;
|
||||
@@ -63,10 +61,6 @@ class Machine : public BaseMachine
|
||||
vector<Timer> join_timer;
|
||||
Timer finish_timer;
|
||||
|
||||
bool direct;
|
||||
int opening_sum;
|
||||
bool receive_threads;
|
||||
int max_broadcast;
|
||||
bool use_encryption;
|
||||
bool live_prep;
|
||||
|
||||
|
||||
@@ -55,8 +55,6 @@ template<class sint, class sgf2n>
|
||||
Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
const OnlineOptions opts, int lg2)
|
||||
: my_number(playerNames.my_num()), N(playerNames),
|
||||
direct(opts.direct), opening_sum(opts.opening_sum),
|
||||
receive_threads(opts.receive_threads), max_broadcast(opts.max_broadcast),
|
||||
use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts),
|
||||
external_clients(my_number)
|
||||
{
|
||||
@@ -69,11 +67,6 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (opening_sum < 2)
|
||||
this->opening_sum = N.num_players();
|
||||
if (max_broadcast < 2)
|
||||
this->max_broadcast = N.num_players();
|
||||
|
||||
// Set the prime modulus from command line or program if applicable
|
||||
if (opts.prime)
|
||||
sint::clear::init_field(opts.prime);
|
||||
@@ -102,7 +95,17 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
sint::bit_type::MAC_Check::setup(*P);
|
||||
sgf2n::MAC_Check::setup(*P);
|
||||
|
||||
alphapi = read_generate_write_mac_key<sint>(*P);
|
||||
if (opts.live_prep)
|
||||
alphapi = read_generate_write_mac_key<sint>(*P);
|
||||
else
|
||||
{
|
||||
// check for directory
|
||||
Sub_Data_Files<sint>::check_setup(N);
|
||||
// require existing MAC key
|
||||
if (sint::has_mac)
|
||||
read_mac_key<sint>(N, alphapi);
|
||||
}
|
||||
|
||||
alpha2i = read_generate_write_mac_key<sgf2n>(*P);
|
||||
alphabi = read_generate_write_mac_key<typename
|
||||
sint::bit_type::part_type>(*P);
|
||||
@@ -451,6 +454,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
finish_timer.start();
|
||||
|
||||
// actual usage
|
||||
bool multithread = nthreads > 1;
|
||||
auto res = stop_threads();
|
||||
DataPositions& pos = res.first;
|
||||
|
||||
@@ -479,7 +483,10 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
cerr << "Communication details "
|
||||
"(rounds in parallel threads counted double):" << endl;
|
||||
comm_stats.print();
|
||||
cerr << "CPU time = " << proc_timer.elapsed() << endl;
|
||||
cerr << "CPU time = " << proc_timer.elapsed();
|
||||
if (multithread)
|
||||
cerr << " (overall core time)";
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
print_timers();
|
||||
|
||||
@@ -19,6 +19,28 @@ template<class T>
|
||||
class MemoryPart : public CheckVector<T>
|
||||
{
|
||||
public:
|
||||
template<class U>
|
||||
static void check_index(const vector<U>& M, size_t i)
|
||||
{
|
||||
(void) M, (void) i;
|
||||
#ifndef NO_CHECK_INDEX
|
||||
if (i >= M.size())
|
||||
throw overflow(U::type_string() + " memory", i, M.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
T& operator[](size_t i)
|
||||
{
|
||||
check_index(*this, i);
|
||||
return CheckVector<T>::operator[](i);
|
||||
}
|
||||
|
||||
const T& operator[](size_t i) const
|
||||
{
|
||||
check_index(*this, i);
|
||||
return CheckVector<T>::operator[](i);
|
||||
}
|
||||
|
||||
void minimum_size(size_t size);
|
||||
};
|
||||
|
||||
@@ -40,35 +62,21 @@ class Memory
|
||||
size_t size_c()
|
||||
{ return MC.size(); }
|
||||
|
||||
template<class U>
|
||||
static void check_index(const vector<U>& M, size_t i)
|
||||
{
|
||||
(void) M, (void) i;
|
||||
#ifndef NO_CHECK_INDEX
|
||||
if (i >= M.size())
|
||||
throw overflow(U::type_string() + " memory", i, M.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
const typename T::clear& read_C(size_t i) const
|
||||
{
|
||||
check_index(MC, i);
|
||||
return MC[i];
|
||||
}
|
||||
const T& read_S(size_t i) const
|
||||
{
|
||||
check_index(MS, i);
|
||||
return MS[i];
|
||||
}
|
||||
|
||||
void write_C(size_t i,const typename T::clear& x)
|
||||
{
|
||||
check_index(MC, i);
|
||||
MC[i]=x;
|
||||
}
|
||||
void write_S(size_t i,const T& x)
|
||||
{
|
||||
check_index(MS, i);
|
||||
MS[i]=x;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,13 +12,11 @@
|
||||
#include "Networking/CryptoPlayer.h"
|
||||
|
||||
template<class W>
|
||||
class OfflineMachine : public W
|
||||
class OfflineMachine : public W, BaseMachine
|
||||
{
|
||||
DataPositions usage;
|
||||
BaseMachine machine;
|
||||
Names& playerNames;
|
||||
Player& P;
|
||||
int n_threads;
|
||||
|
||||
template<class T>
|
||||
void generate();
|
||||
@@ -34,6 +32,8 @@ public:
|
||||
|
||||
template<class T, class U>
|
||||
int run();
|
||||
|
||||
const Names& get_N();
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_OFFLINEMACHINE_H_ */
|
||||
|
||||
@@ -18,18 +18,19 @@ OfflineMachine<W>::OfflineMachine(int argc, const char** argv,
|
||||
W(argc, argv, opt, online_opts, V(), nplayers), playerNames(
|
||||
W::playerNames), P(*this->new_player("machine"))
|
||||
{
|
||||
machine.load_schedule(online_opts.progname, false);
|
||||
load_schedule(online_opts.progname, false);
|
||||
Program program(playerNames.num_players());
|
||||
program.parse(machine.bc_filenames[0]);
|
||||
program.parse(bc_filenames[0]);
|
||||
progs.push_back(program);
|
||||
|
||||
if (program.usage_unknown())
|
||||
{
|
||||
cerr << "Preprocessing might be insufficient "
|
||||
cerr << "Preprocessing will be insufficient "
|
||||
<< "due to unknown requirements" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
usage = program.get_offline_data_used();
|
||||
n_threads = machine.nthreads;
|
||||
}
|
||||
|
||||
template<class W>
|
||||
@@ -73,7 +74,7 @@ int OfflineMachine<W>::run()
|
||||
template<class W>
|
||||
int OfflineMachine<W>::buffered_total(size_t required, size_t batch)
|
||||
{
|
||||
return DIV_CEIL(required, batch) * batch + (n_threads - 1) * batch;
|
||||
return DIV_CEIL(required, batch) * batch + (nthreads - 1) * batch;
|
||||
}
|
||||
|
||||
template<class W>
|
||||
@@ -183,4 +184,10 @@ void OfflineMachine<W>::generate()
|
||||
output.Check(P);
|
||||
}
|
||||
|
||||
template<class W>
|
||||
const Names& OfflineMachine<W>::get_N()
|
||||
{
|
||||
return playerNames;
|
||||
}
|
||||
|
||||
#endif /* PROCESSOR_OFFLINEMACHINE_HPP_ */
|
||||
|
||||
@@ -43,6 +43,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
BaseMachine::s().thread_num = num;
|
||||
|
||||
auto& queues = machine.queues[num];
|
||||
auto& opts = machine.opts;
|
||||
queues->next();
|
||||
ThreadQueue::thread_queue = queues;
|
||||
|
||||
@@ -58,7 +59,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
#endif
|
||||
player = new CryptoPlayer(*(tinfo->Nms), id);
|
||||
}
|
||||
else if (!machine.receive_threads or machine.direct)
|
||||
else if (!opts.receive_threads or opts.direct)
|
||||
{
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
cerr << "Using single-threaded receiving" << endl;
|
||||
@@ -80,7 +81,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
typename sgf2n::MAC_Check* MC2;
|
||||
typename sint::MAC_Check* MCp;
|
||||
|
||||
if (machine.direct)
|
||||
if (opts.direct)
|
||||
{
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
cerr << "Using direct communication." << endl;
|
||||
@@ -93,8 +94,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
cerr << "Using indirect communication." << endl;
|
||||
#endif
|
||||
MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), machine.opening_sum, machine.max_broadcast);
|
||||
MCp = new typename sint::MAC_Check(*(tinfo->alphapi), machine.opening_sum, machine.max_broadcast);
|
||||
MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), opts.opening_sum, opts.max_broadcast);
|
||||
MCp = new typename sint::MAC_Check(*(tinfo->alphapi), opts.opening_sum, opts.max_broadcast);
|
||||
}
|
||||
|
||||
// Allocate memory for first program before starting the clock
|
||||
@@ -376,6 +377,10 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
|
||||
{
|
||||
ti.Sub_Main_Func();
|
||||
}
|
||||
catch (setup_error&)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
thread_info<sint, sgf2n>* ti = (thread_info<sint, sgf2n>*)ptr;
|
||||
@@ -393,16 +398,20 @@ void thread_info<sint, sgf2n>::purge_preprocessing(const Names& N, int thread_nu
|
||||
cerr << "Purging preprocessed data because something is wrong" << endl;
|
||||
try
|
||||
{
|
||||
Data_Files<sint, sgf2n> df(N);
|
||||
Data_Files<sint, sgf2n> df(N, thread_num);
|
||||
df.purge();
|
||||
DataPositions pos;
|
||||
Sub_Data_Files<typename sint::bit_type> bit_df(N, pos, thread_num);
|
||||
bit_df.get_part();
|
||||
bit_df.purge();
|
||||
}
|
||||
catch(...)
|
||||
catch(setup_error&)
|
||||
{
|
||||
}
|
||||
catch(exception& e)
|
||||
{
|
||||
cerr << "Purging failed. This might be because preprocessed data is incomplete." << endl
|
||||
<< "SECURITY FAILURE; YOU ARE ON YOUR OWN NOW!" << endl;
|
||||
cerr << "Reason: " << e.what() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,25 +41,28 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir,
|
||||
}
|
||||
|
||||
void PrepBase::print_left(const char* name, size_t n, const string& type_string,
|
||||
size_t used)
|
||||
size_t used, bool large)
|
||||
{
|
||||
if (n > 0 and OnlineOptions::singleton.verbose)
|
||||
cerr << "\t" << n << " " << name << " of " << type_string << " left"
|
||||
<< endl;
|
||||
|
||||
if (n > used / 10)
|
||||
if (n > used / 10 and n >= 64)
|
||||
{
|
||||
cerr << "Significant amount of unused " << name << " of " << type_string
|
||||
<< ". For more accurate benchmarks, "
|
||||
<< "consider reducing the batch size with --batch-size." << endl;
|
||||
cerr
|
||||
<< "Note that some protocols have larger minimum batch sizes."
|
||||
<< endl;
|
||||
<< " distorting the benchmark. ";
|
||||
if (large)
|
||||
cerr << "This protocol has a large minimum batch size, "
|
||||
<< "which makes this unavoidable for small programs.";
|
||||
else
|
||||
cerr << "For more accurate benchmarks, "
|
||||
<< "consider reducing the batch size with --batch-size.";
|
||||
cerr << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
int n_bits, size_t used)
|
||||
int n_bits, size_t used, bool malicious)
|
||||
{
|
||||
if (n > 0 and OnlineOptions::singleton.verbose)
|
||||
{
|
||||
@@ -70,8 +73,15 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
}
|
||||
|
||||
if (n * n_batch > used / 10)
|
||||
{
|
||||
cerr << "Significant amount of unused edaBits of size " << n_bits
|
||||
<< ". For more accurate benchmarks, "
|
||||
<< "consider reducing the batch size with --batch-size "
|
||||
<< "or increasing the bucket size with --bucket-size." << endl;
|
||||
<< ". ";
|
||||
if (malicious)
|
||||
cerr << "This protocol has a large minimum batch size, "
|
||||
<< "which makes this unavoidable for small programs.";
|
||||
else
|
||||
cerr << "For more accurate benchmarks, "
|
||||
<< "consider reducing the batch size with --batch-size.";
|
||||
cerr << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,9 +26,9 @@ public:
|
||||
int my_num, int thread_num = 0);
|
||||
|
||||
static void print_left(const char* name, size_t n,
|
||||
const string& type_string, size_t used);
|
||||
const string& type_string, size_t used, bool large = false);
|
||||
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
int n_bits, size_t used);
|
||||
int n_bits, size_t used, bool malicious);
|
||||
|
||||
TimerWithComm prep_timer;
|
||||
};
|
||||
|
||||
@@ -65,6 +65,8 @@
|
||||
X(PREFIXSUMS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
sint s, \
|
||||
s += *op1++; *dest++ = s) \
|
||||
X(PICKS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1] + r[2]], \
|
||||
*dest++ = *op1; op1 += int(n)) \
|
||||
X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
|
||||
@@ -38,7 +38,7 @@ except:
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except:
|
||||
batch_size = N
|
||||
batch_size = min(N, 128)
|
||||
|
||||
if 'savemem' in program.args:
|
||||
N = batch_size
|
||||
|
||||
@@ -38,7 +38,7 @@ except:
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except:
|
||||
batch_size = N
|
||||
batch_size = min(N, 128)
|
||||
|
||||
if 'savemem' in program.args:
|
||||
N = batch_size
|
||||
|
||||
@@ -38,7 +38,7 @@ except:
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except:
|
||||
batch_size = N
|
||||
batch_size = min(N, 128)
|
||||
|
||||
assert batch_size <= N
|
||||
ml.Layer.back_batch_size = batch_size
|
||||
|
||||
@@ -26,12 +26,14 @@ exec(subprocess.check_output(['Scripts/process-tf.py', program.args[1]]))
|
||||
|
||||
opt = ml.Optimizer()
|
||||
opt.set_layers_with_inputs(layers)
|
||||
layers[0].X.input_from(0)
|
||||
layers[0].X.input_from(0, binary=True)
|
||||
for layer in layers:
|
||||
layer.input_from(0, raw='raw' in program.args)
|
||||
layer.input_from(0, binary=True)
|
||||
|
||||
sint(0).reveal().store_in_mem(0)
|
||||
|
||||
opt.time_layers = 'time_layers' in program.args
|
||||
|
||||
start_timer(1)
|
||||
opt.forward(1, keep_intermediate=False)
|
||||
stop_timer(1)
|
||||
|
||||
@@ -115,7 +115,7 @@ void BrainPrep<T>::buffer_triples()
|
||||
+ to_string(ZProtocol<T>::share_type::clear::N_BITS)
|
||||
+ "-bit integer computation");
|
||||
typedef Rep3Share<gfp2> pShare;
|
||||
auto buffer_size = OnlineOptions::singleton.batch_size;
|
||||
auto buffer_size = BaseMachine::batch_size<T>(DATA_TRIPLE);
|
||||
Player& P = this->protocol->P;
|
||||
vector<array<ZShare<T>, 3>> triples;
|
||||
vector<array<Rep3Share<gfp2>, 3>> check_triples;
|
||||
|
||||
32
Protocols/BufferScope.h
Normal file
32
Protocols/BufferScope.h
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* BufferScope.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_BUFFERSCOPE_H_
|
||||
#define PROTOCOLS_BUFFERSCOPE_H_
|
||||
|
||||
template<class T> class BufferPrep;
|
||||
template<class T> class Preprocessing;
|
||||
|
||||
template<class T>
|
||||
class BufferScope
|
||||
{
|
||||
BufferPrep<T>& prep;
|
||||
int bak;
|
||||
|
||||
public:
|
||||
BufferScope(Preprocessing<T> & prep, int buffer_size) :
|
||||
prep(dynamic_cast<BufferPrep<T>&>(prep))
|
||||
{
|
||||
bak = this->prep.buffer_size;
|
||||
this->prep.buffer_size = buffer_size;
|
||||
}
|
||||
|
||||
~BufferScope()
|
||||
{
|
||||
prep.buffer_size = bak;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_BUFFERSCOPE_H_ */
|
||||
@@ -33,6 +33,8 @@ class ChaiGearPrep : public MaliciousRingPrep<T>
|
||||
void buffer_bits(false_type);
|
||||
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
|
||||
static void basic_setup(Player& P);
|
||||
static void key_setup(Player& P, mac_key_type alphai);
|
||||
static void teardown();
|
||||
|
||||
@@ -33,6 +33,8 @@ class CowGearPrep : public MaliciousRingPrep<T>
|
||||
void buffer_bits(false_type);
|
||||
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
|
||||
static void basic_setup(Player& P);
|
||||
static void key_setup(Player& P, mac_key_type alphai);
|
||||
static void setup(Player& P, mac_key_type alphai);
|
||||
|
||||
@@ -112,9 +112,6 @@ PairwiseGenerator<typename T::clear::FD>& CowGearPrep<T>::get_generator()
|
||||
{
|
||||
auto& machine = *pairwise_machine;
|
||||
typedef typename T::open_type::FD FD;
|
||||
// generate minimal number of items
|
||||
this->buffer_size = min(machine.setup<FD>().alpha.num_slots(),
|
||||
(unsigned)OnlineOptions::singleton.batch_size);
|
||||
pairwise_generator = new PairwiseGenerator<FD>(0, machine, &proc->P);
|
||||
}
|
||||
return *pairwise_generator;
|
||||
|
||||
@@ -6,21 +6,28 @@
|
||||
#ifndef PROTOCOLS_DABITSACRIFICE_H_
|
||||
#define PROTOCOLS_DABITSACRIFICE_H_
|
||||
|
||||
#include "Processor/BaseMachine.h"
|
||||
|
||||
template<class T>
|
||||
class DabitSacrifice
|
||||
{
|
||||
const int S;
|
||||
|
||||
size_t n_masks, n_produced;
|
||||
|
||||
public:
|
||||
DabitSacrifice();
|
||||
~DabitSacrifice();
|
||||
|
||||
int minimum_n_inputs(int n_outputs = 0)
|
||||
{
|
||||
if (n_outputs < 1)
|
||||
n_outputs = OnlineOptions::singleton.batch_size;
|
||||
if (T::clear::N_BITS < 0)
|
||||
// sacrifice uses S^2 random bits
|
||||
n_outputs = max(n_outputs, 10 * S * S);
|
||||
n_outputs = BaseMachine::batch_size<T>(DATA_DABIT,
|
||||
n_outputs, max(n_outputs, 10 * S * S));
|
||||
else
|
||||
n_outputs = BaseMachine::batch_size<T>(DATA_DABIT, n_outputs);
|
||||
assert(n_outputs > 0);
|
||||
return n_outputs + S;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,13 +7,15 @@
|
||||
#define PROTOCOLS_DABITSACRIFICE_HPP_
|
||||
|
||||
#include "DabitSacrifice.h"
|
||||
#include "BufferScope.h"
|
||||
#include "Tools/PointerVector.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
template<class T>
|
||||
DabitSacrifice<T>::DabitSacrifice() :
|
||||
S(OnlineOptions::singleton.security_parameter)
|
||||
S(OnlineOptions::singleton.security_parameter),
|
||||
n_masks(0), n_produced()
|
||||
{
|
||||
}
|
||||
|
||||
@@ -36,10 +38,15 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
|
||||
timer.start();
|
||||
#endif
|
||||
int n = check_dabits.size() - S;
|
||||
n_masks += S;
|
||||
assert(n > 0);
|
||||
GlobalPRNG G(proc.P);
|
||||
typedef typename T::bit_type::part_type BT;
|
||||
vector<T> shares;
|
||||
vector<BT> bit_shares;
|
||||
if (T::clear::N_BITS <= 0)
|
||||
dynamic_cast<BufferPrep<T>&>(proc.DataF).buffer_extra(DATA_BIT,
|
||||
S * (ceil(log2(n)) + S));
|
||||
for (int i = 0; i < S; i++)
|
||||
{
|
||||
dabit<T> to_check;
|
||||
@@ -58,6 +65,7 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
|
||||
T tmp;
|
||||
proc.DataF.get_one(DATA_BIT, tmp);
|
||||
masked += tmp << (1 + j);
|
||||
n_masks++;
|
||||
}
|
||||
shares.push_back(masked);
|
||||
bit_shares.push_back(to_check.second);
|
||||
@@ -84,6 +92,7 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
|
||||
}
|
||||
}
|
||||
dabits.insert(dabits.end(), check_dabits.begin(), check_dabits.begin() + n);
|
||||
n_produced += n;
|
||||
MCBB.Check(proc.P);
|
||||
delete &MCBB;
|
||||
#ifdef VERBOSE_DABIT
|
||||
@@ -92,6 +101,17 @@ void DabitSacrifice<T>::sacrifice_without_bit_check(vector<dabit<T> >& dabits,
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
DabitSacrifice<T>::~DabitSacrifice()
|
||||
{
|
||||
#ifdef DABIT_WASTAGE
|
||||
if (n_produced > 0)
|
||||
{
|
||||
cerr << "daBit wastage: " << float(n_masks) / n_produced << endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void DabitSacrifice<T>::sacrifice_and_check_bits(vector<dabit<T> >& dabits,
|
||||
vector<dabit<T> >& check_dabits, SubProcessor<T>& proc,
|
||||
@@ -113,7 +133,10 @@ void DabitSacrifice<T>::sacrifice_and_check_bits(vector<dabit<T> >& dabits,
|
||||
queues->wrap_up(job);
|
||||
}
|
||||
else
|
||||
{
|
||||
BufferScope<T> scope(proc.DataF, multiplicands.size());
|
||||
protocol.multiply(products, multiplicands, 0, multiplicands.size(), proc);
|
||||
}
|
||||
vector<T> check_for_zero;
|
||||
for (auto& x : to_check)
|
||||
check_for_zero.push_back(x.first - products.next());
|
||||
|
||||
@@ -17,11 +17,13 @@ void DealerPrep<T>::buffer_triples()
|
||||
vector<bool> senders(P.num_players());
|
||||
senders.back() = true;
|
||||
octetStreams os(P), to_receive(P);
|
||||
int buffer_size = BaseMachine::batch_size<T>(DATA_TRIPLE,
|
||||
this->buffer_size);
|
||||
if (this->proc->input.is_dealer())
|
||||
{
|
||||
SeededPRNG G;
|
||||
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
T triples[3];
|
||||
for (int i = 0; i < 2; i++)
|
||||
@@ -41,7 +43,7 @@ void DealerPrep<T>::buffer_triples()
|
||||
else
|
||||
{
|
||||
P.send_receive_all(senders, os, to_receive);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
this->triples.push_back(to_receive.back().get<FixedVec<T, 3>>().get());
|
||||
}
|
||||
}
|
||||
@@ -68,11 +70,12 @@ void DealerPrep<T>::buffer_inverses(true_type)
|
||||
vector<bool> senders(P.num_players());
|
||||
senders.back() = true;
|
||||
octetStreams os(P), to_receive(P);
|
||||
int buffer_size = BaseMachine::batch_size<T>(DATA_INVERSE);
|
||||
if (this->proc->input.is_dealer())
|
||||
{
|
||||
SeededPRNG G;
|
||||
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
T tuple[2];
|
||||
while (tuple[0] == 0)
|
||||
@@ -92,7 +95,7 @@ void DealerPrep<T>::buffer_inverses(true_type)
|
||||
else
|
||||
{
|
||||
P.send_receive_all(senders, os, to_receive);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
this->inverses.push_back(to_receive.back().get<FixedVec<T, 2>>().get());
|
||||
}
|
||||
}
|
||||
@@ -105,11 +108,12 @@ void DealerPrep<T>::buffer_bits()
|
||||
vector<bool> senders(P.num_players());
|
||||
senders.back() = true;
|
||||
octetStreams os(P), to_receive(P);
|
||||
int buffer_size = BaseMachine::batch_size<T>(DATA_BIT);
|
||||
if (this->proc->input.is_dealer())
|
||||
{
|
||||
SeededPRNG G;
|
||||
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
T bit = G.get_bit();
|
||||
make_share(shares.data(), typename T::clear(bit),
|
||||
@@ -123,7 +127,7 @@ void DealerPrep<T>::buffer_bits()
|
||||
else
|
||||
{
|
||||
P.send_receive_all(senders, os, to_receive);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
this->bits.push_back(to_receive.back().get<T>());
|
||||
}
|
||||
}
|
||||
@@ -136,12 +140,13 @@ void DealerPrep<T>::buffer_dabits(ThreadQueues*)
|
||||
vector<bool> senders(P.num_players());
|
||||
senders.back() = true;
|
||||
octetStreams os(P), to_receive(P);
|
||||
int buffer_size = BaseMachine::batch_size<T>(DATA_DABIT);
|
||||
if (this->proc->input.is_dealer())
|
||||
{
|
||||
SeededPRNG G;
|
||||
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
|
||||
vector<GC::SemiSecret> bit_shares(P.num_players() - 1);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
auto bit = G.get_bit();
|
||||
make_share(shares.data(), typename T::clear(bit),
|
||||
@@ -160,7 +165,7 @@ void DealerPrep<T>::buffer_dabits(ThreadQueues*)
|
||||
else
|
||||
{
|
||||
P.send_receive_all(senders, os, to_receive);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
this->dabits.push_back({to_receive.back().get<T>(),
|
||||
to_receive.back().get<typename T::bit_type>()});
|
||||
@@ -200,7 +205,8 @@ void DealerPrep<T>::buffer_edabits(int length, false_type)
|
||||
vector<bool> senders(P.num_players());
|
||||
senders.back() = true;
|
||||
octetStreams os(P), to_receive(P);
|
||||
int n_vecs = OnlineOptions::singleton.batch_size / edabitvec<T>::MAX_SIZE;
|
||||
int n_vecs = DIV_CEIL(BaseMachine::edabit_batch_size<T>(length),
|
||||
edabitvec<T>::MAX_SIZE);
|
||||
auto& buffer = this->edabits[{false, length}];
|
||||
if (this->proc->input.is_dealer())
|
||||
{
|
||||
|
||||
@@ -28,6 +28,8 @@ class HemiMatrixPrep : public BufferPrep<ShareMatrix<T>>
|
||||
HemiMatrixPrep(const HemiMatrixPrep&) = delete;
|
||||
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
|
||||
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep,
|
||||
DataPositions& usage) :
|
||||
super(usage), n_rows(n_rows), n_inner(n_inner),
|
||||
|
||||
@@ -35,6 +35,8 @@ class HemiPrep : public SemiHonestRingPrep<T>
|
||||
SemiPrep<T>& get_two_party_prep();
|
||||
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
|
||||
static void basic_setup(Player& P);
|
||||
static void teardown();
|
||||
|
||||
|
||||
@@ -142,6 +142,8 @@ void HemiPrep<T>::buffer_bits()
|
||||
if (this->proc->P.num_players() == 2)
|
||||
{
|
||||
auto& prep = get_two_party_prep();
|
||||
prep.buffer_size = BaseMachine::batch_size<T>(DATA_BIT,
|
||||
this->buffer_size);
|
||||
prep.buffer_dabits(0);
|
||||
for (auto& x : prep.dabits)
|
||||
this->bits.push_back(x.first);
|
||||
@@ -158,6 +160,8 @@ void HemiPrep<T>::buffer_dabits(ThreadQueues* queues)
|
||||
if (this->proc->P.num_players() == 2)
|
||||
{
|
||||
auto& prep = get_two_party_prep();
|
||||
prep.buffer_size = BaseMachine::batch_size<T>(DATA_DABIT,
|
||||
this->buffer_size);
|
||||
prep.buffer_dabits(queues);
|
||||
this->dabits = prep.dabits;
|
||||
prep.dabits.clear();
|
||||
|
||||
@@ -24,7 +24,7 @@ KeyGenProtocol<X, L>::KeyGenProtocol(Player& P, const FHE_Params& params,
|
||||
int level) :
|
||||
P(P), params(params), fftd(params.FFTD().at(level)), usage(P)
|
||||
{
|
||||
open_type::init_field(params.FFTD().at(level).get_prD().pr);
|
||||
open_type::init_field(params.FFTD().at(level).get_prD().pr, false);
|
||||
typename share_type::mac_key_type alphai;
|
||||
|
||||
auto& batch_size = OnlineOptions::singleton.batch_size;
|
||||
@@ -54,7 +54,8 @@ KeyGenProtocol<X, L>::~KeyGenProtocol()
|
||||
{
|
||||
MC->Check(P);
|
||||
|
||||
usage.print_cost();
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
usage.print_cost();
|
||||
|
||||
delete proc;
|
||||
delete prep;
|
||||
@@ -63,6 +64,7 @@ KeyGenProtocol<X, L>::~KeyGenProtocol()
|
||||
MC->teardown();
|
||||
|
||||
OnlineOptions::singleton.batch_size = backup_batch_size;
|
||||
open_type::reset();
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
|
||||
@@ -12,6 +12,7 @@ using namespace std;
|
||||
#include "Protocols/MAC_Check_Base.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/Coordinator.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
|
||||
/* The MAX number of things we will partially open before running
|
||||
@@ -38,6 +39,8 @@ class TreeSum
|
||||
void add_openings(vector<T>& values, const Player& P, int sum_players,
|
||||
int last_sum_players, int send_player);
|
||||
|
||||
virtual void post_add_process(vector<T>&) {}
|
||||
|
||||
protected:
|
||||
int base_player;
|
||||
int opening_sum;
|
||||
@@ -54,7 +57,9 @@ public:
|
||||
vector<Timer> timers;
|
||||
vector<Timer> player_timers;
|
||||
|
||||
TreeSum(int opening_sum = 10, int max_broadcast = 10, int base_player = 0);
|
||||
TreeSum(int opening_sum = OnlineOptions::singleton.opening_sum,
|
||||
int max_broadcast = OnlineOptions::singleton.max_broadcast,
|
||||
int base_player = 0);
|
||||
virtual ~TreeSum();
|
||||
|
||||
void run(vector<T>& values, const Player& P);
|
||||
@@ -114,7 +119,7 @@ Coordinator* Tree_MAC_Check<U>::coordinator = 0;
|
||||
* SPDZ opening protocol with MAC check (indirect communication)
|
||||
*/
|
||||
template<class U>
|
||||
class MAC_Check_ : public Tree_MAC_Check<U>
|
||||
class MAC_Check_ : public virtual Tree_MAC_Check<U>
|
||||
{
|
||||
public:
|
||||
MAC_Check_(const typename U::mac_key_type::Scalar& ai, int opening_sum = 10,
|
||||
@@ -135,7 +140,7 @@ template<class T> class MascotPrep;
|
||||
* SPDZ2k opening protocol with MAC check
|
||||
*/
|
||||
template<class T, class U, class V, class W>
|
||||
class MAC_Check_Z2k : public Tree_MAC_Check<W>
|
||||
class MAC_Check_Z2k : public virtual Tree_MAC_Check<W>
|
||||
{
|
||||
protected:
|
||||
Preprocessing<W>* prep;
|
||||
@@ -161,12 +166,11 @@ template<class W>
|
||||
using MAC_Check_Z2k_ = MAC_Check_Z2k<typename W::open_type,
|
||||
typename W::mac_key_type, typename W::open_type, W>;
|
||||
|
||||
|
||||
/**
|
||||
* SPDZ opening protocol with MAC check (pairwise communication)
|
||||
*/
|
||||
template<class T>
|
||||
class Direct_MAC_Check: public MAC_Check_<T>
|
||||
class Direct_MAC_Check: public virtual MAC_Check_<T>
|
||||
{
|
||||
typedef MAC_Check_<T> super;
|
||||
|
||||
@@ -186,7 +190,35 @@ public:
|
||||
|
||||
void init_open(const Player& P, int n = 0);
|
||||
void prepare_open(const T& secret, int = -1);
|
||||
void exchange(const Player& P);
|
||||
virtual void exchange(const Player& P);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class Direct_MAC_Check_Z2k: virtual public MAC_Check_Z2k_<T>,
|
||||
virtual public Direct_MAC_Check<T>
|
||||
{
|
||||
public:
|
||||
Direct_MAC_Check_Z2k(const typename T::mac_key_type& ai) :
|
||||
Tree_MAC_Check<T>(ai), MAC_Check_Z2k_<T>(ai), MAC_Check_<T>(ai),
|
||||
Direct_MAC_Check<T>(ai)
|
||||
{
|
||||
}
|
||||
|
||||
void prepare_open(const T& secret, int = -1)
|
||||
{
|
||||
MAC_Check_Z2k_<T>::prepare_open(secret);
|
||||
}
|
||||
|
||||
void exchange(const Player& P)
|
||||
{
|
||||
Direct_MAC_Check<T>::exchange(P);
|
||||
assert(this->WaitingForCheck() > 0);
|
||||
}
|
||||
|
||||
void Check(const Player& P)
|
||||
{
|
||||
MAC_Check_Z2k_<T>::Check(P);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -272,6 +304,7 @@ void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
|
||||
{
|
||||
values[i].add(oss[j], use_lengths ? lengths[i] : -1);
|
||||
}
|
||||
post_add_process(values);
|
||||
MC.timers[SUM].stop();
|
||||
}
|
||||
}
|
||||
@@ -279,6 +312,11 @@ void TreeSum<T>::add_openings(vector<T>& values, const Player& P,
|
||||
template<class T>
|
||||
void TreeSum<T>::start(vector<T>& values, const Player& P)
|
||||
{
|
||||
if (opening_sum < 2)
|
||||
opening_sum = P.num_players();
|
||||
if (max_broadcast < 2)
|
||||
max_broadcast = P.num_players();
|
||||
|
||||
os.reset_write_head();
|
||||
int sum_players = P.num_players();
|
||||
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());
|
||||
|
||||
@@ -98,6 +98,7 @@ void Tree_MAC_Check<U>::init_open(const Player&, int n)
|
||||
template<class U>
|
||||
void Tree_MAC_Check<U>::prepare_open(const U& secret, int)
|
||||
{
|
||||
assert(U::mac_type::invertible);
|
||||
this->values.push_back(secret.get_share());
|
||||
macs.push_back(secret.get_mac());
|
||||
}
|
||||
@@ -344,7 +345,7 @@ Direct_MAC_Check<T>::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai
|
||||
|
||||
template<class T>
|
||||
Direct_MAC_Check<T>::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai) :
|
||||
MAC_Check_<T>(ai)
|
||||
Tree_MAC_Check<T>(ai), MAC_Check_<T>(ai)
|
||||
{
|
||||
open_counter = 0;
|
||||
}
|
||||
@@ -405,6 +406,7 @@ void Direct_MAC_Check<T>::init_open(const Player& P, int n)
|
||||
template<class T>
|
||||
void Direct_MAC_Check<T>::prepare_open(const T& secret, int)
|
||||
{
|
||||
assert(T::mac_type::invertible);
|
||||
this->values.push_back(secret.get_share());
|
||||
this->macs.push_back(secret.get_mac());
|
||||
}
|
||||
|
||||
@@ -77,6 +77,11 @@ class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
|
||||
public virtual SimplerMalRepRingPrep<T>
|
||||
{
|
||||
public:
|
||||
static bool dabits_from_bits()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol)
|
||||
|
||||
@@ -64,7 +64,8 @@ void MalRepRingPrep<T>::buffer_squares()
|
||||
MaliciousRepPrep<prep_type> prep(_);
|
||||
assert(this->proc != 0);
|
||||
prep.init_honest(this->proc->P);
|
||||
prep.buffer_size = this->buffer_size;
|
||||
prep.buffer_size = BaseMachine::batch_size<T>(DATA_SQUARE,
|
||||
this->buffer_size);
|
||||
prep.buffer_squares();
|
||||
for (auto& x : prep.squares)
|
||||
this->squares.push_back({{x[0], x[1]}});
|
||||
|
||||
@@ -35,11 +35,6 @@ public:
|
||||
typedef MalRepRingShare<K + 2, S> SquareToBitShare;
|
||||
typedef MalRepRingPrep<MalRepRingShare> SquarePrep;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "RR";
|
||||
}
|
||||
|
||||
MalRepRingShare()
|
||||
{
|
||||
}
|
||||
|
||||
@@ -53,6 +53,10 @@ public:
|
||||
{
|
||||
return "M" + string(1, T::type_char());
|
||||
}
|
||||
static string type_string()
|
||||
{
|
||||
return "malicious " + super::type_string();
|
||||
}
|
||||
|
||||
MaliciousRep3Share()
|
||||
{
|
||||
|
||||
@@ -76,7 +76,8 @@ void MaliciousRepPrep<T>::buffer_triples()
|
||||
{
|
||||
check_field_size<typename T::open_type>();
|
||||
auto& triples = this->triples;
|
||||
auto buffer_size = this->buffer_size;
|
||||
auto buffer_size = BaseMachine::batch_size<T>(DATA_TRIPLE,
|
||||
this->buffer_size);
|
||||
auto& honest_proc = this->honest_proc;
|
||||
assert(honest_proc != 0);
|
||||
Player& P = honest_proc->P;
|
||||
@@ -140,7 +141,7 @@ void MaliciousRepPrep<T>::buffer_squares()
|
||||
vector<typename T::open_type> opened;
|
||||
vector<array<T, 2>> check_squares;
|
||||
auto& squares = this->squares;
|
||||
auto buffer_size = this->buffer_size;
|
||||
auto buffer_size = BaseMachine::batch_size<T>(DATA_SQUARE, this->buffer_size);
|
||||
auto& honest_prep = this->honest_prep;
|
||||
auto& honest_proc = this->honest_proc;
|
||||
auto& MC = this->MC;
|
||||
@@ -186,7 +187,8 @@ void MaliciousBitOnlyRepPrep<T>::buffer_bits()
|
||||
vector<typename T::open_type> opened;
|
||||
vector<array<T, 2>> check_squares;
|
||||
auto& bits = this->bits;
|
||||
auto buffer_size = this->buffer_size;
|
||||
auto buffer_size = BaseMachine::batch_size<T>(DATA_BIT,
|
||||
this->buffer_size);
|
||||
assert(honest_proc);
|
||||
Player& P = honest_proc->P;
|
||||
honest_prep.buffer_size = buffer_size;
|
||||
|
||||
@@ -32,9 +32,8 @@ void MaliciousDabitOnlyPrep<T>::buffer_dabits(ThreadQueues* queues, false_type,
|
||||
{
|
||||
assert(this->proc != 0);
|
||||
vector<dabit<T>> check_dabits;
|
||||
DabitSacrifice<T> dabit_sacrifice;
|
||||
this->buffer_dabits_without_check(check_dabits,
|
||||
dabit_sacrifice.minimum_n_inputs(), queues);
|
||||
dabit_sacrifice.minimum_n_inputs(this->buffer_size), queues);
|
||||
dabit_sacrifice.sacrifice_and_check_bits(this->dabits, check_dabits,
|
||||
*this->proc, queues);
|
||||
}
|
||||
|
||||
@@ -90,6 +90,8 @@ class MascotPrep : public virtual MaliciousRingPrep<T>,
|
||||
public virtual MascotDabitOnlyPrep<T>
|
||||
{
|
||||
public:
|
||||
static bool bits_from_triples() { return true; }
|
||||
|
||||
MascotPrep(SubProcessor<T>* proc, DataPositions& usage) :
|
||||
BufferPrep<T>(usage), BitPrep<T>(proc, usage),
|
||||
RingPrep<T>(proc, usage),
|
||||
|
||||
@@ -62,8 +62,11 @@ void MascotTriplePrep<T>::buffer_triples()
|
||||
auto& params = this->params;
|
||||
auto& triple_generator = this->triple_generator;
|
||||
params.generateBits = false;
|
||||
triple_generator->set_batch_size(
|
||||
BaseMachine::batch_size<T>(DATA_TRIPLE, this->buffer_size));
|
||||
triple_generator->generate();
|
||||
triple_generator->unlock();
|
||||
triple_generator->set_batch_size(OnlineOptions::singleton.batch_size);
|
||||
assert(triple_generator->uncheckedTriples.size() != 0);
|
||||
for (auto& triple : triple_generator->uncheckedTriples)
|
||||
this->triples.push_back(
|
||||
|
||||
@@ -118,8 +118,7 @@ void Rep3Shuffler<T>::apply(vector<T>& a, size_t n, int unit_size,
|
||||
template<class T>
|
||||
void Rep3Shuffler<T>::del(int handle)
|
||||
{
|
||||
for (int i = 0; i < 2; i++)
|
||||
shuffles.at(handle)[i].clear();
|
||||
shuffles.at(handle) = {};
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -42,6 +42,11 @@ class Rep4RingOnlyPrep : public virtual Rep4RingPrep<T>,
|
||||
}
|
||||
|
||||
public:
|
||||
static bool dabits_from_bits()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static void edabit_sacrifice_buckets(vector<edabit<T>>&, size_t, bool, int,
|
||||
SubProcessor<T>&, int, int, const void* = 0)
|
||||
{
|
||||
|
||||
@@ -46,7 +46,7 @@ void Rep4RingPrep<T>::buffer_inputs(int player)
|
||||
template<class T>
|
||||
void Rep4RingPrep<T>::buffer_triples()
|
||||
{
|
||||
generate_triples(this->triples, OnlineOptions::singleton.batch_size,
|
||||
generate_triples(this->triples, BaseMachine::batch_size<T>(DATA_TRIPLE),
|
||||
this->protocol);
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ void Rep4RingPrep<T>::buffer_bits()
|
||||
auto& protocol = this->proc->protocol;
|
||||
|
||||
vector<typename T::open_type> bits;
|
||||
int batch_size = OnlineOptions::singleton.batch_size;
|
||||
int batch_size = BaseMachine::batch_size<T>(DATA_BIT);
|
||||
bits.reserve(batch_size);
|
||||
for (int i = 0; i < batch_size; i++)
|
||||
bits.push_back(G.get_bit());
|
||||
|
||||
@@ -12,7 +12,7 @@ void RepRingOnlyEdabitPrep<T>::buffer_edabits(int n_bits, ThreadQueues*)
|
||||
{
|
||||
assert(this->proc);
|
||||
int dl = T::bit_type::default_length;
|
||||
int buffer_size = DIV_CEIL(this->buffer_size, dl) * dl;
|
||||
int buffer_size = DIV_CEIL(BaseMachine::edabit_batch_size<T>(n_bits, this->buffer_size), dl) * dl;
|
||||
vector<T> wholes;
|
||||
wholes.resize(buffer_size);
|
||||
Instruction inst;
|
||||
@@ -49,5 +49,5 @@ void RepRingOnlyEdabitPrep<T>::buffer_edabits(int n_bits, ThreadQueues*)
|
||||
SubProcessor<bt> bit_proc(party.MC->get_part_MC(), this->proc->bit_prep, P);
|
||||
bit_adder.multi_add(sums, summands, 0, buffer_size / dl, bit_proc, dl, 0);
|
||||
|
||||
this->push_edabits(this->edabits[{false, n_bits}], wholes, sums, buffer_size);
|
||||
this->push_edabits(this->edabits[{false, n_bits}], wholes, sums);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "Protocols/ShuffleSacrifice.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
#include "edabit.h"
|
||||
#include "DabitSacrifice.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
@@ -36,13 +37,13 @@ class BufferPrep : public Preprocessing<T>
|
||||
|
||||
friend class InScope;
|
||||
|
||||
static const bool homomorphic = false;
|
||||
|
||||
template<int>
|
||||
void buffer_inverses(true_type);
|
||||
template<int>
|
||||
void buffer_inverses(false_type) { throw runtime_error("no inverses"); }
|
||||
|
||||
virtual bool bits_from_dabits() { return false; }
|
||||
|
||||
protected:
|
||||
vector<array<T, 3>> triples;
|
||||
vector<array<T, 2>> squares;
|
||||
@@ -83,8 +84,9 @@ protected:
|
||||
{ throw runtime_error("no personal daBits"); }
|
||||
|
||||
void push_edabits(vector<edabitvec<T>>& edabits,
|
||||
const vector<T>& sums, const vector<vector<typename T::bit_type::part_type>>& bits,
|
||||
int buffer_size);
|
||||
const vector<T>& sums,
|
||||
const vector<vector<typename T::bit_type::part_type>>& bits);
|
||||
|
||||
public:
|
||||
typedef T share_type;
|
||||
|
||||
@@ -103,6 +105,10 @@ public:
|
||||
throw runtime_error("sacrifice not available");
|
||||
}
|
||||
|
||||
static bool bits_from_dabits() { return false; }
|
||||
static bool bits_from_triples() { return false; }
|
||||
static bool dabits_from_bits() { return false; }
|
||||
|
||||
BufferPrep(DataPositions& usage);
|
||||
virtual ~BufferPrep();
|
||||
|
||||
@@ -135,6 +141,8 @@ public:
|
||||
|
||||
SubProcessor<T>* get_proc() { return proc; }
|
||||
void set_proc(SubProcessor<T>* proc) { this->proc = proc; }
|
||||
|
||||
void buffer_extra(Dtype type, int n_items);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -272,7 +280,7 @@ public:
|
||||
void buffer_edabits(int n_bits, false_type)
|
||||
{ this->template buffer_edabits_without_check<0>(n_bits,
|
||||
this->edabits[{false, n_bits}],
|
||||
OnlineOptions::singleton.batch_size); }
|
||||
BaseMachine::edabit_batch_size<T>(n_bits, this->buffer_size)); }
|
||||
template<int>
|
||||
void buffer_edabits(int, true_type)
|
||||
{ throw not_implemented(); }
|
||||
@@ -286,6 +294,8 @@ public:
|
||||
template<class T>
|
||||
class MaliciousDabitOnlyPrep : public virtual RingPrep<T>
|
||||
{
|
||||
DabitSacrifice<T> dabit_sacrifice;
|
||||
|
||||
template<int>
|
||||
void buffer_dabits(ThreadQueues* queues, true_type, false_type);
|
||||
template<int>
|
||||
@@ -312,6 +322,8 @@ class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep<T>
|
||||
{
|
||||
typedef typename T::bit_type::part_type BT;
|
||||
|
||||
DabitSacrifice<T> dabit_sacrifice;
|
||||
|
||||
protected:
|
||||
void buffer_personal_edabits(int n_bits, vector<T>& sums,
|
||||
vector<vector<BT>>& bits, SubProcessor<BT>& proc, int input_player,
|
||||
@@ -343,7 +355,7 @@ public:
|
||||
bool strict, int player, SubProcessor<T>& proc, int begin, int end,
|
||||
const void* supply = 0)
|
||||
{
|
||||
EdabitShuffleSacrifice<T>().edabit_sacrifice_buckets(to_check, n_bits, strict,
|
||||
EdabitShuffleSacrifice<T>(n_bits).edabit_sacrifice_buckets(to_check, strict,
|
||||
player, proc, begin, end, supply);
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user