Secure shuffling.

This commit is contained in:
Marcel Keller
2022-05-27 14:19:33 +02:00
parent 2dad77ba32
commit 5ab8c702dd
108 changed files with 2227 additions and 542 deletions

View File

@@ -259,7 +259,6 @@ ProgramParty::~ProgramParty()
reset();
if (P)
{
cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl;
delete P;
}
delete[] eval_threads;

View File

@@ -28,7 +28,7 @@ RealProgramParty<T>* RealProgramParty<T>::singleton = 0;
template<class T>
RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
garble_processor(garble_machine), dummy_proc({{}, 0})
garble_processor(garble_machine), dummy_proc({}, 0)
{
assert(singleton == 0);
singleton = this;
@@ -157,6 +157,9 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
MC->Check(*P);
data_sent = P->total_comm().sent;
if (online_opts.verbose)
P->total_comm().print();
this->machine.write_memory(this->N.my_num());
}

View File

@@ -1,5 +1,14 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.2 (Mai 27, 2022)
- Secure shuffling
- O(n log n) radix sorting
- Documented BGV encryption interface
- Optimized matrix multiplication in dealer protocol
- Fixed security bug in homomorphic encryption parameter generation
- Fixed Security bug in Temi matrix multiplication
## 0.3.1 (Apr 19, 2022)
- Protocol in dealer model

View File

@@ -382,7 +382,6 @@ class sbits(bits):
reg_type = 'sb'
is_clear = False
clear_type = cbits
default_type = cbits
load_inst = (inst.ldmsbi, inst.ldmsb)
store_inst = (inst.stmsbi, inst.stmsb)
bitdec = inst.bitdecs
@@ -404,6 +403,9 @@ class sbits(bits):
else:
return sbits.get_type(n)(value)
@staticmethod
def _new(value):
return value
@staticmethod
def get_random_bit():
res = sbit()
inst.bitb(res)
@@ -909,6 +911,7 @@ class cbit(bit, cbits):
sbits.bit_type = sbit
cbits.bit_type = cbit
sbit.clear_type = cbit
sbits.default_type = sbits
class bitsBlock(oram.Block):
value_type = sbits

View File

@@ -17,6 +17,7 @@ right order.
import itertools
import operator
import math
from . import tools
from random import randint
from functools import reduce
@@ -2406,6 +2407,70 @@ class trunc_pr(base.VarArgsInstruction):
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])
@base.gf2n
class secshuffle(base.VectorInstruction, base.DataInstruction):
""" Secure shuffling.
:param: destination (sint)
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['SECSHUFFLE']
arg_format = ['sw','s','int']
def __init__(self, *args, **kwargs):
super(secshuffle_class, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', 0), float('inf'))
class gensecshuffle(base.DataInstruction):
""" Generate secure shuffle to bit used several times.
:param: destination (regint)
:param: size (int)
"""
__slots__ = []
code = base.opcodes['GENSECSHUFFLE']
arg_format = ['ciw','int']
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', 0), float('inf'))
class applyshuffle(base.VectorInstruction, base.DataInstruction):
""" Generate secure shuffle to bit used several times.
:param: destination (sint)
:param: source (sint)
:param: number of elements to be treated as one (int)
:param: handle (regint)
:param: reverse (0/1)
"""
__slots__ = []
code = base.opcodes['APPLYSHUFFLE']
arg_format = ['sw','s','int','ci','int']
def __init__(self, *args, **kwargs):
super(applyshuffle, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
def add_usage(self, req_node):
req_node.increment((self.field_type, 'triple', 0), float('inf'))
class delshuffle(base.Instruction):
""" Delete secure shuffle.
:param: handle (regint)
"""
code = base.opcodes['DELSHUFFLE']
arg_format = ['ci']
class check(base.Instruction):
"""
Force MAC check in current thread and all idle thread if current

View File

@@ -106,6 +106,11 @@ opcodes = dict(
CONV2DS = 0xAC,
CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
# Shuffling
SECSHUFFLE = 0xFA,
GENSECSHUFFLE = 0xFB,
APPLYSHUFFLE = 0xFC,
DELSHUFFLE = 0xFD,
# Data access
TRIPLE = 0x50,
BIT = 0x51,

View File

@@ -348,7 +348,7 @@ class Entry(object):
def __len__(self):
return 2 + len(self.x)
def __repr__(self):
return '{empty=%s}' % self.is_empty if self.is_empty \
return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \
else '{%s: %s}' % (self.v, self.x)
def __add__(self, other):
try:
@@ -466,12 +466,14 @@ class AbstractORAM(object):
def get_array(size, t, *args, **kwargs):
return t.dynamic_array(size, t, *args, **kwargs)
def read(self, index):
return self._read(self.value_type.hard_conv(index))
res = self._read(self.index_type.hard_conv(index))
res = [self.value_type._new(x) for x in res]
return res
def write(self, index, value):
value = util.tuplify(value)
value = [self.value_type.conv(x) for x in value]
new_value = [self.value_type.get_type(length).hard_conv(v) \
for length,v in zip(self.entry_size, value \
if isinstance(value, (tuple, list)) \
else (value,))]
for length,v in zip(self.entry_size, value)]
return self._write(self.index_type.hard_conv(index), *new_value)
def access(self, index, new_value, write, new_empty=False):
return self._access(self.index_type.hard_conv(index),
@@ -795,7 +797,8 @@ class RefTrivialORAM(EndRecursiveEviction):
for i,value in enumerate(values):
index = MemValue(self.value_type.hard_conv(i))
new_value = [MemValue(self.value_type.hard_conv(v)) \
for v in (value if isinstance(value, (tuple, list)) \
for v in (value if isinstance(
value, (tuple, list, Array)) \
else (value,))]
self.ram[i] = Entry(index, new_value, value_type=self.value_type)
@@ -986,7 +989,8 @@ class List(EndRecursiveEviction):
for i,value in enumerate(values):
index = self.value_type.hard_conv(i)
new_value = [self.value_type.hard_conv(v) \
for v in (value if isinstance(value, (tuple, list)) \
for v in (value if isinstance(
value, (tuple, list, Array)) \
else (value,))]
self.__setitem__(index, new_value)
def __repr__(self):
@@ -1062,11 +1066,12 @@ class TreeORAM(AbstractORAM):
stop_timer(1)
start_timer()
self.root = RefBucket(1, self)
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
self.index = self.index_structure(size, self.D, self.index_type,
init_rounds, True)
self.read_value = Array(self.value_length, value_type)
self.read_value = Array(self.value_length, value_type.default_type)
self.read_non_empty = MemValue(self.value_type.bit_type(0))
self.state = MemValue(self.value_type(0))
self.state = MemValue(self.value_type.default_type(0))
@method_block
def add_to_root(self, state, is_empty, v, *x):
if len(x) != self.value_length:
@@ -1106,10 +1111,10 @@ class TreeORAM(AbstractORAM):
self.evict_bucket(RefBucket(p_bucket2, self), d)
@method_block
def read_and_renew_index(self, u):
l_star = random_block(self.D, self.value_type)
l_star = random_block(self.D, self.index_type)
if use_insecure_randomness:
new_path = regint.get_random(self.D)
l_star = self.value_type(new_path)
l_star = self.index_type(new_path)
self.state.write(l_star)
return self.index.update(u, l_star, evict=False).reveal()
@method_block
@@ -1120,7 +1125,7 @@ class TreeORAM(AbstractORAM):
parallel = get_parallel(self.index_size, *self.internal_value_type())
@map_sum(get_n_threads_for_tree(self.size), parallel, levels, \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type] * self.value_length)
[self.value_type.default_type] * self.value_length)
def process(level):
b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level))
bucket = RefBucket(b_index, self)
@@ -1142,9 +1147,9 @@ class TreeORAM(AbstractORAM):
Program.prog.curr_tape.start_new_basicblock()
crash()
def internal_value_type(self):
return self.value_type, self.value_length + 1
return self.value_type.default_type, self.value_length + 1
def internal_entry_size(self):
return self.value_type, [self.D] + list(self.entry_size)
return self.value_type.default_type, [self.D] + list(self.entry_size)
def n_buckets(self):
return 2**(self.D+1)
@method_block
@@ -1176,8 +1181,9 @@ class TreeORAM(AbstractORAM):
#print 'pre-add', self
maybe_start_timer(4)
self.add_to_root(state, entry.empty(), \
self.value_type(entry.v.read()), \
*(self.value_type(i.read()) for i in entry.x))
self.index_type(entry.v.read()), \
*(self.value_type.default_type(i.read())
for i in entry.x))
maybe_stop_timer(4)
#print 'pre-evict', self
if evict:
@@ -1228,21 +1234,27 @@ class TreeORAM(AbstractORAM):
raise CompilerError('Batch initialization only possible with sint.')
depth = log2(m)
leaves = [0] * m
entries = [0] * m
indexed_values = [0] * m
leaves = self.value_type.Array(m)
indexed_values = \
self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1)
# assign indices 0, ..., m-1
for i,value in enumerate(values):
@for_range(m)
def _(i):
value = values[i]
index = MemValue(self.value_type.hard_conv(i))
new_value = [MemValue(self.value_type.hard_conv(v)) \
for v in (value if isinstance(value, (tuple, list)) \
else (value,))]
indexed_values[i] = [index] + new_value
entries = sint.Matrix(self.bucket_size * 2 ** self.D,
len(Entry(0, list(indexed_values[0]), False)))
# assign leaves
for i,index_value in enumerate(indexed_values):
@for_range(len(indexed_values))
def _(i):
index_value = list(indexed_values[i])
leaves[i] = random_block(self.D, self.value_type)
index = index_value[0]
@@ -1252,18 +1264,20 @@ class TreeORAM(AbstractORAM):
# save unsorted leaves for position map
unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves]
permutation.sort(leaves, comp=permutation.normal_comparator)
leaves.sort()
bucket_sz = 0
# B[i] = (pos, leaf, "last in bucket" flag) for i-th entry
B = [[0]*3 for i in range(m)]
B = sint.Matrix(m, 3)
B[0] = [0, leaves[0], 0]
B[-1] = [None, None, sint(1)]
s = 0
s = MemValue(sint(0))
for i in range(1, m):
@for_range_opt(m - 1)
def _(j):
i = j + 1
eq = leaves[i].equal(leaves[i-1])
s = (s + eq) * eq
s.write((s + eq) * eq)
B[i][0] = s
B[i][1] = leaves[i]
B[i-1][2] = 1 - eq
@@ -1271,7 +1285,7 @@ class TreeORAM(AbstractORAM):
#last_in_bucket[i-1] = 1 - eq
# shuffle
permutation.shuffle(B, value_type=sint)
B.secure_shuffle()
#cint(0).print_reg('shuf')
sz = MemValue(0) #cint(0)
@@ -1279,7 +1293,8 @@ class TreeORAM(AbstractORAM):
empty_positions = Array(nleaves, self.value_type)
empty_leaves = Array(nleaves, self.value_type)
for i in range(m):
@for_range(m)
def _(i):
if_then(reveal(B[i][2]))
#if B[i][2] == 1:
#cint(i).print_reg('last')
@@ -1291,12 +1306,13 @@ class TreeORAM(AbstractORAM):
empty_positions[szval] = B[i][0] #pos[i][0]
#empty_positions[szval].reveal().print_reg('ps0')
empty_leaves[szval] = B[i][1] #pos[i][1]
sz += 1
sz.iadd(1)
end_if()
pos_bits = []
pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2)
for i in range(nleaves):
@for_range_opt(nleaves)
def _(i):
leaf = empty_leaves[i]
# split into 2 if bucket size can't fit into one field elem
if self.bucket_size + Program.prog.security > 128:
@@ -1315,46 +1331,39 @@ class TreeORAM(AbstractORAM):
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
else:
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
pos_bits += [[b, leaf] for b in bucket_bits]
assert len(bucket_bits) == self.bucket_size
for j, b in enumerate(bucket_bits):
pos_bits[i * self.bucket_size + j] = [b, leaf]
# sort to get empty positions first
permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator)
pos_bits.sort(n_bits=1)
# now assign positions to empty entries
empty_entries = [0] * (self.bucket_size*2**self.D - m)
for i in range(self.bucket_size*2**self.D - m):
@for_range(len(entries) - m)
def _(i):
vtype, vlength = self.internal_value_type()
leaf = vtype(pos_bits[i][1])
# set leaf in empty entry for assigning after shuffle
value = tuple([leaf] + [vtype(0) for j in range(vlength)])
value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)])
entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype)
empty_entries[i] = entry
entries[m + i] = entry
# now shuffle, reveal positions and place entries
entries = entries + empty_entries
while len(entries) & (len(entries)-1) != 0:
entries.append(None)
permutation.shuffle(entries, value_type=sint)
entries = [entry for entry in entries if entry is not None]
clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries]
entries.secure_shuffle()
clear_leaves = Array.create_from(
Entry(entries.get_columns()).x[0].reveal())
Program.prog.curr_tape.start_new_basicblock()
bucket_sizes = Array(2**self.D, regint)
for i in range(2**self.D):
bucket_sizes[i] = 0
k = 0
for entry,leaf in zip(entries, clear_leaves):
leaf = leaf.read()
k += 1
# for some reason leaf_buckets is in bit-reversed order
bits = bit_decompose(leaf, self.D)
rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1]))
bucket = RefBucket(rev_leaf + (1 << self.D), self)
# hack: 1*entry ensures MemValues are converted to sints
bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry
@for_range_opt(len(entries))
def _(k):
leaf = clear_leaves[k]
bucket = RefBucket(leaf + (1 << self.D), self)
bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k])
bucket_sizes[leaf] += 1
self.index.batch_init([leaf.read() for leaf in unsorted_leaves])
@@ -1599,16 +1608,20 @@ class PackedIndexStructure(object):
def batch_init(self, values):
""" Initialize m values with indices 0, ..., m-1 """
m = len(values)
n_entries = max(1, m/self.entries_per_block)
new_values = [0] * n_entries
n_entries = max(1, m//self.entries_per_block)
new_values = sint.Matrix(n_entries, self.elements_per_block)
values = Array.create_from(values)
for i in range(n_entries):
@for_range(n_entries)
def _(i):
block = [0] * self.elements_per_block
for j in range(self.elements_per_block):
base = i * self.entries_per_block + j * self.entries_per_element
for k in range(self.entries_per_element):
if base + k < m:
block[j] += values[base + k] << (k * self.entry_size)
@if_(base + k < m)
def _():
block[j] += \
values[base + k] << (k * sum(self.entry_size))
new_values[i] = block
@@ -1667,7 +1680,8 @@ def OptimalORAM(size,*args,**kwargs):
experiments.
:param size: number of elements
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn`
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
:py:class:`sfix`
"""
if optimal_threshold is None:
if n_threads == 1:
@@ -1784,7 +1798,7 @@ def test_batch_init(oram_type, N):
oram = oram_type(N, value_type)
print('initialized')
print_reg(cint(0), 'init')
oram.batch_init([value_type(i) for i in range(N)])
oram.batch_init(Array.create_from(sint(regint.inc(N))))
print_reg(cint(0), 'done')
@for_range(N)
def f(i):

View File

@@ -111,24 +111,6 @@ def bucket_size_sorter(x, y):
return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z])
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
""" Simulate secure shuffling with Waksman network for 2 players.
Returns the network switching config so it may be re-used later. """
n = len(x)
if n & (n-1) != 0:
raise CompilerError('shuffle requires n a power of 2')
if config is None:
config = permutation.configure_waksman(permutation.random_perm(n))
for i,c in enumerate(config):
config[i] = [value_type(b) for b in c]
permutation.waksman(x, config, reverse=reverse)
permutation.waksman(x, config, reverse=reverse)
return config
def LT(a, b):
a_bits = bit_decompose(a)
b_bits = bit_decompose(b)
@@ -472,10 +454,15 @@ class PathORAM(TreeORAM):
print_ln()
# shuffle entries and levels
while len(merged_entries) & (len(merged_entries)-1) != 0:
merged_entries.append(None) #self.root.bucket.empty_entry(False))
permutation.rec_shuffle(merged_entries, value_type=self.value_type)
merged_entries = [e for e in merged_entries if e is not None]
flat = []
for x in merged_entries:
flat += list(x[0]) + [x[1]]
flat = self.value_type(flat)
assert len(flat) % len(merged_entries) == 0
l = len(flat) // len(merged_entries)
shuffled = flat.secure_shuffle(l)
merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]]
for i in range(len(shuffled) // l)]
# need to copy entries/levels to memory for re-positioning
entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)

View File

@@ -10,16 +10,6 @@ if '_Array' not in dir():
from Compiler.program import Program
_Array = Array
SORT_BITS = []
insecure_random = Random(0)
def predefined_comparator(x, y):
""" Assumes SORT_BITS is populated with the required sorting network bits """
if predefined_comparator.sort_bits_iter is None:
predefined_comparator.sort_bits_iter = iter(SORT_BITS)
return next(predefined_comparator.sort_bits_iter)
predefined_comparator.sort_bits_iter = None
def list_comparator(x, y):
""" Uses the first element in the list for comparison """
return x[0] < y[0]
@@ -37,10 +27,6 @@ def bitwise_comparator(x, y):
def cond_swap_bit(x,y, b):
""" swap if b == 1 """
if x is None:
return y, None
elif y is None:
return x, None
if isinstance(x, list):
t = [(xi - yi) * b for xi,yi in zip(x, y)]
return [xi - ti for xi,ti in zip(x, t)], \
@@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator):
else:
raise CompilerError('Length of list must be power of two')
def merge(a, b, comp):
""" General length merge (pads to power of 2) """
while len(a) & (len(a)-1) != 0:
a.append(None)
while len(b) & (len(b)-1) != 0:
b.append(None)
if len(a) < len(b):
a += [None] * (len(b) - len(a))
elif len(b) < len(a):
b += [None] * (len(b) - len(b))
t = a + b
odd_even_merge(t, comp)
for i,v in enumerate(t[::]):
if v is None:
t.remove(None)
return t
def sort(a, comp):
""" Pads to power of 2, sorts, removes padding """
length = len(a)
@@ -112,47 +81,12 @@ def sort(a, comp):
odd_even_merge_sort(a, comp)
del a[length:]
def recursive_merge(a, comp):
""" Recursively merge a list of sorted lists (initially sorted by size) """
if len(a) == 1:
return
# merge smallest two lists, place result in correct position, recurse
t = merge(a[0], a[1], comp)
del a[0]
del a[0]
added = False
for i,c in enumerate(a):
if len(c) >= len(t):
a.insert(i, t)
added = True
break
if not added:
a.append(t)
recursive_merge(a, comp)
# The following functionality for shuffling isn't used any more as it
# has been moved to the virtual machine. The code has been kept for
# reference.
def random_perm(n):
""" Generate a random permutation of length n
WARNING: randomness fixed at compile-time, this is NOT secure
"""
if not Program.prog.options.insecure:
raise CompilerError('no secure implementation of Waksman permution, '
'use --insecure to activate')
a = list(range(n))
for i in range(n-1, 0, -1):
j = insecure_random.randint(0, i)
t = a[i]
a[i] = a[j]
a[j] = t
return a
def inverse(perm):
inv = [None] * len(perm)
for i, p in enumerate(perm):
inv[p] = i
return inv
def configure_waksman(perm):
def configure_waksman(perm, n_iter=[0]):
top = n_iter == [0]
n = len(perm)
if n == 2:
return [(perm[0], perm[0])]
@@ -175,6 +109,7 @@ def configure_waksman(perm):
via = 0
j0 = j
while True:
n_iter[0] += 1
#print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2)
i = inv_perm[j]
@@ -209,8 +144,11 @@ def configure_waksman(perm):
assert sorted(p0) == list(range(n//2))
assert sorted(p1) == list(range(n//2))
p0_config = configure_waksman(p0)
p1_config = configure_waksman(p1)
p0_config = configure_waksman(p0, n_iter)
p1_config = configure_waksman(p1, n_iter)
if top:
print(n_iter[0], 'iterations for Waksman')
assert O[0] == 0, 'not a Waksman network'
return [I + O] + [a+b for a,b in zip(p0_config, p1_config)]
def waksman(a, config, depth=0, start=0, reverse=False):
@@ -358,23 +296,10 @@ def iter_waksman(a, config, reverse=False):
# nblocks /= 2
# depth -= 1
def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
n = len(x)
if n & (n-1) != 0:
raise CompilerError('shuffle requires n a power of 2')
if config is None:
config = configure_waksman(random_perm(n))
for i,c in enumerate(config):
config[i] = [value_type.bit_type(b) for b in c]
waksman(x, config, reverse=reverse)
waksman(x, config, reverse=reverse)
def config_shuffle(n, value_type):
""" Compute config for oblivious shuffling.
Take mod 2 for active sec. """
perm = random_perm(n)
def config_from_perm(perm, value_type):
n = len(perm)
assert(list(sorted(perm))) == list(range(n))
if n & (n-1) != 0:
# pad permutation to power of 2
m = 2**int(math.ceil(math.log(n, 2)))
@@ -394,103 +319,3 @@ def config_shuffle(n, value_type):
for j,b in enumerate(c):
config[i * len(perm) + j] = b
return config
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
""" Simulate secure shuffling with Waksman network for 2 players.
WARNING: This is not a properly secure implementation but has roughly the right complexity.
Returns the network switching config so it may be re-used later. """
n = len(x)
m = 2**int(math.ceil(math.log(n, 2)))
assert n == m, 'only working for powers of two'
if config is None:
config = config_shuffle(n, value_type)
if isinstance(x, list):
if isinstance(x[0], list):
length = len(x[0])
assert len(x) == length
for i in range(length):
xi = Array(m, value_type.reg_type)
for j in range(n):
xi[j] = x[j][i]
for j in range(n, m):
xi[j] = value_type(0)
iter_waksman(xi, config, reverse=reverse)
iter_waksman(xi, config, reverse=reverse)
for j, y in enumerate(xi):
x[j][i] = y
else:
xa = Array(m, value_type.reg_type)
for i in range(n):
xa[i] = x[i]
for i in range(n, m):
xa[i] = value_type(0)
iter_waksman(xa, config, reverse=reverse)
iter_waksman(xa, config, reverse=reverse)
x[:] = xa
elif isinstance(x, Array):
if len(x) != m and config is None:
raise CompilerError('Non-power of 2 Array input not yet supported')
iter_waksman(x, config, reverse=reverse)
iter_waksman(x, config, reverse=reverse)
else:
raise CompilerError('Invalid type for shuffle:', type(x))
return config
def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None):
""" Shuffle a list of ORAM entries.
Randomly permutes the first "perm_size" entries, leaving the rest (empty
entry padding) in the same position. """
n = len(x)
l = len(x[0])
if n & (n-1) != 0:
raise CompilerError('Entries must be padded to power of two length.')
if perm_size is None:
perm_size = n
xarrays = [Array(n, value_type.reg_type) for i in range(l)]
for i in range(n):
for j,value in enumerate(x[i]):
if isinstance(value, MemValue):
xarrays[j][i] = value.read()
else:
xarrays[j][i] = value
if config is None:
config = config_shuffle(perm_size, value_type)
for xi in xarrays:
shuffle(xi, config, value_type, reverse)
for i in range(n):
x[i] = entry_cls(xarrays[j][i] for j in range(l))
return config
def sort_zeroes(bits, x, n_ones, value_type):
""" Return Array of values in "x" where the corresponding bit in "bits" is
a 0.
The total number of zeroes in "bits" must be known.
"bits" and "x" must be Arrays. """
config = config_shuffle(len(x), value_type)
shuffle(bits, config=config, value_type=value_type)
shuffle(x, config=config, value_type=value_type)
result = Array(n_ones, value_type.reg_type)
sz = MemValue(0)
last_x = MemValue(value_type(0))
#for i,b in enumerate(bits):
#if_then(b.reveal() == 0)
#result[sz.read()] = x[i]
#sz += 1
#end_if()
@for_range(len(bits))
def f(i):
found = (bits[i].reveal() == 0)
szval = sz.read()
result[szval] = last_x + (x[i] - last_x) * found
sz.write(sz + found)
last_x.write(result[szval])
return result

54
Compiler/sorting.py Normal file
View File

@@ -0,0 +1,54 @@
import itertools
from Compiler import types, library, instructions
def dest_comp(B):
Bt = B.transpose()
Bt_flat = Bt.get_vector()
St_flat = Bt.value_type.Array(len(Bt_flat))
St_flat.assign(Bt_flat)
@library.for_range(len(St_flat) - 1)
def _(i):
St_flat[i + 1] = St_flat[i + 1] + St_flat[i]
Tt_flat = Bt.get_vector() * St_flat.get_vector()
Tt = types.Matrix(*Bt.sizes, B.value_type)
Tt.assign_vector(Tt_flat)
return sum(Tt) - 1
def reveal_sort(k, D, reverse=False):
assert len(k) == len(D)
library.break_point()
shuffle = types.sint.get_secure_shuffle(len(k))
k_prime = k.get_vector().secure_permute(shuffle).reveal()
idx = types.Array.create_from(k_prime)
if reverse:
D.assign_vector(D.get_slice_vector(idx))
library.break_point()
D.secure_permute(shuffle, reverse=True)
else:
D.secure_permute(shuffle)
library.break_point()
v = D.get_vector()
D.assign_slice_vector(idx, v)
library.break_point()
instructions.delshuffle(shuffle)
def radix_sort(k, D, n_bits=None, signed=True):
assert len(k) == len(D)
bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits))
if signed and len(bs) > 1:
bs[-1][:] = bs[-1][:].bit_not()
B = types.sint.Matrix(len(k), 2)
h = types.Array.create_from(types.sint(types.regint.inc(len(k))))
@library.for_range(len(bs))
def _(i):
b = bs[i]
B.set_column(0, 1 - b.get_vector())
B.set_column(1, b.get_vector())
c = types.Array.create_from(dest_comp(B))
reveal_sort(c, h, reverse=False)
@library.if_e(i < len(bs) - 1)
def _():
reveal_sort(h, bs[i + 1], reverse=True)
@library.else_
def _():
reveal_sort(h, D, reverse=True)

View File

@@ -1937,6 +1937,11 @@ class _secret(_register, _secret_structure):
matmuls(res, A, B, n_rows, n, n_cols)
return res
@staticmethod
def _new(self):
# mirror sfix
return self
@no_doc
def __init__(self, reg_type, val=None, size=None):
if isinstance(val, self.clear_type):
@@ -2093,6 +2098,12 @@ class _secret(_register, _secret_structure):
else:
return self * self
@set_instruction_type
def secure_shuffle(self, unit_size=1):
res = type(self)(size=self.size)
secshuffle(res, self, unit_size)
return res
@set_instruction_type
@vectorize
def reveal(self):
@@ -2741,6 +2752,17 @@ class sint(_secret, _int):
return w
@staticmethod
def get_secure_shuffle(n):
res = regint()
gensecshuffle(res, n)
return res
def secure_permute(self, shuffle, unit_size=1, reverse=False):
res = sint(size=self.size)
applyshuffle(res, self, unit_size, shuffle, reverse)
return res
class sintbit(sint):
""" :py:class:`sint` holding a bit, supporting binary operations
(``&, |, ^``). """
@@ -4291,6 +4313,10 @@ class _fix(_single):
k = self.k
return revealed_fix._new(val)
def bit_decompose(self, n_bits=None):
""" Bit decomposition. """
return self.v.bit_decompose(n_bits or self.k)
class sfix(_fix):
""" Secret fixed-point number represented as secret integer, by
multiplying with ``2^f`` and then rounding. See :py:class:`sint`
@@ -4312,6 +4338,8 @@ class sfix(_fix):
int_type = sint
bit_type = sintbit
clear_type = cfix
get_type = staticmethod(lambda n: sint)
default_type = sint
@vectorized_classmethod
def get_input_from(cls, player):
@@ -4385,6 +4413,10 @@ class sfix(_fix):
def coerce(self, other):
return parse_type(other, k=self.k, f=self.f)
def hard_conv_me(self, cls):
assert cls == sint
return self.v
def mul_no_reduce(self, other, res_params=None):
assert self.f == other.f
assert self.k == other.k
@@ -4409,6 +4441,14 @@ class sfix(_fix):
return personal(player, cfix._new(self.v.reveal_to(player)._v,
self.k, self.f))
def secure_shuffle(self, *args, **kwargs):
return self._new(self.v.secure_shuffle(*args, **kwargs),
k=self.k, f=self.f)
def secure_permute(self, *args, **kwargs):
return self._new(self.v.secure_permute(*args, **kwargs),
k=self.k, f=self.f)
class unreduced_sfix(_single):
int_type = sint
@@ -5395,13 +5435,21 @@ class Array(_vectorizable):
regint.inc(len(indices), self.address, 0) + indices,
size=len(indices))
def get_slice_vector(self, slice):
def get_slice_addresses(self, slice):
assert self.value_type.n_elements() == 1
assert len(slice) <= self.total_size()
base = regint.inc(len(slice), slice.address, 1, 1)
inc = regint.inc(len(slice), 0, 1, 1, 1)
inc = regint.inc(len(slice), self.address, 1, 1, 1)
addresses = slice.value_type.load_mem(base) + inc
return self.value_type.load_mem(self.address + addresses)
return addresses
def get_slice_vector(self, slice):
addresses = self.get_slice_addresses(slice)
return self.value_type.load_mem(addresses)
def assign_slice_vector(self, slice, vector):
addresses = self.get_slice_addresses(slice)
vector.store_in_mem(addresses)
def expand_to_vector(self, index, size):
""" Create vector from single entry.
@@ -5514,6 +5562,14 @@ class Array(_vectorizable):
""" Insecure shuffle in place. """
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
def secure_shuffle(self):
""" Secure shuffle in place according to the security model. """
self.assign_vector(self.get_vector().secure_shuffle())
def secure_permute(self, *args, **kwargs):
""" Secure permutate in place according to the security model. """
self.assign_vector(self.get_vector().secure_permute(*args, **kwargs))
def randomize(self, *args):
""" Randomize according to data type. """
self.assign_vector(self.value_type.get_random(*args, size=len(self)))
@@ -5570,15 +5626,26 @@ class Array(_vectorizable):
"""
return personal(player, self.create_from(self[:].reveal_to(player)._v))
def sort(self, n_threads=None):
def sort(self, n_threads=None, batcher=False, n_bits=None):
"""
Sort in place using Batchers' odd-even merge mergesort
with complexity :math:`O(n (\log n)^2)`.
Sort in place using radix sort with complexity :math:`O(n \log
n)` for :py:class:`sint` and :py:class:`sfix`, and Batcher's
odd-even mergesort with :math:`O(n (\log n)^2)` for
:py:class:`sfloat`.
:param n_threads: number of threads to use (single thread by
default)
default), need to use Batcher's algorithm for several threads
:param batcher: use Batcher's odd-even mergesort in any case
:param n_bits: number of bits in keys (default: global bit length)
"""
library.loopy_odd_even_merge_sort(self, n_threads=n_threads)
if batcher or self.value_type.n_elements() > 1:
library.loopy_odd_even_merge_sort(self, n_threads=n_threads)
else:
if n_threads or 1 > 1:
raise CompilerError('multi-threaded sorting only implemented '
'with Batcher\'s odd-even mergesort')
import sorting
sorting.radix_sort(self, self, n_bits=n_bits)
def Array(self, size):
# compatibility with registers
@@ -5619,6 +5686,8 @@ class SubMultiArray(_vectorizable):
:return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise"""
if isinstance(index, slice) and index == slice(None):
return self.get_vector()
if isinstance(index, int) and index < 0:
index += self.sizes[0]
key = program.curr_block, str(index)
if key not in self.sub_cache:
if util.is_constant(index) and \
@@ -5673,6 +5742,10 @@ class SubMultiArray(_vectorizable):
def total_size(self):
return reduce(operator.mul, self.sizes) * self.value_type.n_elements()
def part_size(self):
return reduce(operator.mul, self.sizes[1:]) * \
self.value_type.n_elements()
def get_vector(self, base=0, size=None):
""" Return vector with content. Not implemented for floating-point.
@@ -5731,13 +5804,21 @@ class SubMultiArray(_vectorizable):
:param slice: regint array
"""
addresses = self.get_slice_addresses(slice)
return self.value_type.load_mem(self.address + addresses)
def assign_slice_vector(self, slice, vector):
addresses = self.get_slice_addresses(slice)
vector.store_in_mem(self.address + addresses)
def get_slice_addresses(self, slice):
assert self.value_type.n_elements() == 1
part_size = reduce(operator.mul, self.sizes[1:])
assert len(slice) * part_size <= self.total_size()
base = regint.inc(len(slice) * part_size, slice.address, 1, part_size)
inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size)
addresses = slice.value_type.load_mem(base) * part_size + inc
return self.value_type.load_mem(self.address + addresses)
return addresses
def get_addresses(self, *indices):
assert self.value_type.n_elements() == 1
@@ -6218,6 +6299,31 @@ class SubMultiArray(_vectorizable):
n = self.sizes[0]
return self.array.get(regint.inc(n, 0, n + 1))
def secure_shuffle(self):
""" Securely shuffle rows (first index). """
self.assign_vector(self.get_vector().secure_shuffle(self.part_size()))
def secure_permute(self, permutation, reverse=False):
""" Securely permute rows (first index). """
self.assign_vector(self.get_vector().secure_permute(
permutation, self.part_size(), reverse))
def sort(self, key_indices=None, n_bits=None):
""" Sort sub-arrays (different first index) in place.
:param key_indices: indices to sorting keys, for example
``(1, 2)`` to sort three-dimensional array ``a`` by keys
``a[*][1][2]``. Default is ``(0, ..., 0)`` of correct length.
:param n_bits: number of bits in keys (default: global bit length)
"""
if key_indices is None:
key_indices = (0,) * (len(self.sizes) - 1)
key_indices = (None,) + util.tuplify(key_indices)
import sorting
keys = self.get_vector_by_indices(*key_indices)
sorting.radix_sort(keys, self, n_bits=n_bits)
def randomize(self, *args):
""" Randomize according to data type. """
if self.total_size() < program.options.budget:
@@ -6334,6 +6440,18 @@ class Matrix(MultiArray):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address)
@staticmethod
def create_from(rows):
rows = list(rows)
if isinstance(rows[0], (list, tuple)):
t = type(rows[0][0])
else:
t = type(rows[0])
res = Matrix(len(rows), len(rows[0]), t)
for i in range(len(rows)):
res[i].assign(rows[i])
return res
def get_column(self, index):
""" Get column as vector.
@@ -6344,6 +6462,9 @@ class Matrix(MultiArray):
self.sizes[1])
return self.value_type.load_mem(addresses)
def get_columns(self):
return (self.get_column(i) for i in range(self.sizes[1]))
def get_column_by_row_indices(self, rows, column):
assert self.value_type.n_elements() == 1
addresses = rows * self.sizes[1] + \

View File

@@ -47,17 +47,8 @@ inline void receive(client_socket* socket, octet* data, size_t len)
#else
typedef ssl_ctx client_ctx;
typedef ssl_socket client_socket;
class client_socket : public ssl_socket
{
public:
client_socket(boost::asio::io_service& io_service,
boost::asio::ssl::context& ctx, int plaintext_socket, string other,
string me, bool client) :
ssl_socket(io_service, ctx, plaintext_socket, other, me, client)
{
}
};
#endif
/**

View File

@@ -58,7 +58,8 @@ public:
{
if (this->size() != y.size())
throw out_of_range("vector length mismatch");
for (unsigned int i = 0; i < this->size(); i++)
size_t n = this->size();
for (unsigned int i = 0; i < n; i++)
(*this)[i] += y[i];
return *this;
}
@@ -67,9 +68,11 @@ public:
{
if (this->size() != y.size())
throw out_of_range("vector length mismatch");
AddableVector<T> res(y.size());
for (unsigned int i = 0; i < this->size(); i++)
res[i] = (*this)[i] - y[i];
AddableVector<T> res;
res.reserve(y.size());
size_t n = this->size();
for (unsigned int i = 0; i < n; i++)
res.push_back((*this)[i] - y[i]);
return res;
}

View File

@@ -31,6 +31,12 @@ word check_pk_id(word a, word b)
}
void Ciphertext::Scale()
{
Scale(params->get_plaintext_modulus());
}
void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1)
{
if (c0.params!=c1.params) { throw params_mismatch(); }
@@ -115,9 +121,28 @@ void Ciphertext::add(octetStream& os)
*this += tmp;
}
void Ciphertext::rerandomize(const FHE_PK& pk)
{
Rq_Element tmp(*params);
SeededPRNG G;
vector<FFT_Data::S> r(params->FFTD()[0].m());
bigint p = pk.p();
assert(p != 0);
for (auto& x : r)
{
G.get<FFT_Data::S>(x, params->p0().numBits() - p.numBits() - 1);
x *= p;
}
tmp.from(r, 0);
Scale();
cc0 += tmp;
auto zero = pk.encrypt(*params);
zero.Scale(pk.p());
*this += zero;
}
template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans,const Plaintext<gfp,PPData,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans,const Plaintext<gf2n_short,P2Data,int>& a,const Ciphertext& c);
template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a,
const Ciphertext& c);

View File

@@ -15,6 +15,12 @@ template<class T,class FD,class S> void mul(Ciphertext& ans,const Ciphertext& c,
void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1);
void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk);
/**
* BGV ciphertext.
* The class allows adding two ciphertexts as well as adding a plaintext and
* a ciphertext via operator overloading. The multiplication of two ciphertexts
* requires the public key and thus needs a separate function.
*/
class Ciphertext
{
Rq_Element cc0,cc1;
@@ -54,6 +60,7 @@ class Ciphertext
// Scale down an element from level 1 to level 0, if at level 0 do nothing
void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); }
void Scale();
// Throws error if ans,c0,c1 etc have different params settings
// - Thus programmer needs to ensure this rather than this being done
@@ -90,6 +97,12 @@ class Ciphertext
template <class FD>
Ciphertext& operator*=(const Plaintext_<FD>& other) { ::mul(*this, *this, other); return *this; }
/**
* Ciphertext multiplication.
* @param pk public key
* @param x second ciphertext
* @returns product ciphertext
*/
Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const
{ Ciphertext res(*params); ::mul(res, *this, x, pk); return res; }
@@ -98,14 +111,18 @@ class Ciphertext
return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this};
}
/// Re-randomize for circuit privacy.
void rerandomize(const FHE_PK& pk);
int level() const { return cc0.level(); }
// pack/unpack (like IO) also assume params are known and already set
// correctly
/// Append to buffer
void pack(octetStream& o) const
{ cc0.pack(o); cc1.pack(o); o.store(pk_id); }
void unpack(octetStream& o)
{ cc0.unpack(o); cc1.unpack(o); o.get(pk_id); }
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o)
{ cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); }
void output(ostream& s) const
{ cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); }

View File

@@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag(
{
auto& c = products.at(i);
for (int j = 0; j < n_matrices; j++)
{
res.at(j).entries.init();
for (size_t k = 0; k < n_rows; k++)
res.at(j)[{k, i}] = c.element(j * n_rows + k);
}
}
return res;
}

View File

@@ -7,6 +7,11 @@
FFT_Data::FFT_Data() :
twop(-1)
{
}
void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
{
R=Rg;

View File

@@ -50,7 +50,7 @@ class FFT_Data
void pack(octetStream& o) const;
void unpack(octetStream& o);
FFT_Data() { ; }
FFT_Data();
FFT_Data(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }

View File

@@ -12,6 +12,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p())
{
}
FHE_SK::FHE_SK(const FHE_Params& pms) :
FHE_SK(pms, pms.get_plaintext_modulus())
{
}
FHE_SK& FHE_SK::operator+=(const FHE_SK& c)
{
@@ -38,6 +43,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G)
}
FHE_PK::FHE_PK(const FHE_Params& pms) :
FHE_PK(pms, pms.get_plaintext_modulus())
{
}
Rq_Element FHE_PK::sample_secret_key(PRNG& G)
{
Rq_Element sk = FHE_SK(*this).s();
@@ -179,32 +189,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext<typename FD::T, FD, typename FD::S>&
template<class FD>
Ciphertext FHE_PK::encrypt(
const Plaintext<typename FD::T, FD, typename FD::S>& mess) const
{
return encrypt(Rq_Element(*params, mess));
}
Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const
{
Random_Coins rc(*params);
PRNG G;
G.ReSeed();
rc.generate(G);
return encrypt(mess, rc);
Ciphertext res(*params);
quasi_encrypt(res, mess, rc);
return res;
}
template<class T, class FD, class S>
void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (T::characteristic_two ^ (pr == 2))
throw pr_mismatch();
Rq_Element ans = quasi_decrypt(c);
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
}
Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
ans.change_rep(polynomial);
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
return ans;
}
Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c)
{
return decrypt(c, params->get_plaintext_field_data<FFT_Data>());
}
template<class FD>
Plaintext<typename FD::T, FD, typename FD::S> FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD)
{
@@ -299,12 +328,12 @@ void FHE_PK::unpack(octetStream& o)
o.consume((octet*) tag, 8);
if (memcmp(tag, "PKPKPKPK", 8))
throw runtime_error("invalid serialization of public key");
a0.unpack(o);
b0.unpack(o);
a0.unpack(o, *params);
b0.unpack(o, *params);
if (params->n_mults() > 0)
{
Sw_a.unpack(o);
Sw_b.unpack(o);
Sw_a.unpack(o, *params);
Sw_b.unpack(o, *params);
}
pr.unpack(o);
}
@@ -322,7 +351,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const
return false;
}
void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
const bigint& pr) const
{
@@ -345,8 +373,6 @@ void FHE_SK::check(const FHE_PK& pk, const FD& FieldD)
throw runtime_error("incorrect key pair");
}
void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
{
if (this->pr != pr)
@@ -361,6 +387,24 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
}
}
bigint FHE_SK::get_noise(const Ciphertext& c)
{
sk.lower_level();
Ciphertext cc = c;
if (cc.level())
cc.Scale();
Rq_Element tmp = quasi_decrypt(cc);
bigint res;
bigint q = tmp.get_modulus();
bigint half_q = q / 2;
for (auto& x : tmp.to_vec_bigint())
{
// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " ";
res = max(res, x > half_q ? x - q : x);
}
return res;
}
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,

View File

@@ -12,6 +12,10 @@
class FHE_PK;
class Ciphertext;
/**
* BGV secret key.
* The class allows addition.
*/
class FHE_SK
{
Rq_Element sk;
@@ -29,6 +33,8 @@ class FHE_SK
// secret key always on lower level
void assign(const Rq_Element& s) { sk=s; sk.lower_level(); }
FHE_SK(const FHE_Params& pms);
FHE_SK(const FHE_Params& pms, const bigint& p)
: sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; }
@@ -38,8 +44,11 @@ class FHE_SK
const Rq_Element& s() const { return sk; }
/// Append to buffer
void pack(octetStream& os) const { sk.pack(os); pr.pack(os); }
void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); }
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); }
// Assumes Ring and prime of mess have already been set correctly
// Ciphertext c must be at level 0 or an error occurs
@@ -50,9 +59,14 @@ class FHE_SK
template <class FD>
Plaintext<typename FD::T, FD, typename FD::S> decrypt(const Ciphertext& c, const FD& FieldD);
/// Decryption for cleartexts modulo prime
Plaintext_<FFT_Data> decrypt(const Ciphertext& c);
template <class FD>
void decrypt_any(Plaintext_<FD>& mess, const Ciphertext& c);
Rq_Element quasi_decrypt(const Ciphertext& c) const;
// Three stage procedure for Distributed Decryption
// - First stage produces my shares
// - Second stage adds in another players shares, do this once for each other player
@@ -62,7 +76,6 @@ class FHE_SK
void dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_number,int num_players) const;
void dist_decrypt_2(vector<bigint>& vv,const vector<bigint>& vv1) const;
friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
/* Add secret keys
@@ -82,10 +95,15 @@ class FHE_SK
template<class FD>
void check(const FHE_PK& pk, const FD& FieldD);
bigint get_noise(const Ciphertext& c);
friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; }
};
/**
* BGV public key.
*/
class FHE_PK
{
Rq_Element a0,b0;
@@ -104,8 +122,10 @@ class FHE_PK
)
{ a0=a; b0=b; Sw_a=sa; Sw_b=sb; }
FHE_PK(const FHE_Params& pms, const bigint& p = 0)
FHE_PK(const FHE_Params& pms);
FHE_PK(const FHE_Params& pms, const bigint& p)
: a0(pms.FFTD(),evaluation,evaluation),
b0(pms.FFTD(),evaluation,evaluation),
Sw_a(pms.FFTD(),evaluation,evaluation),
@@ -143,8 +163,11 @@ class FHE_PK
template <class FD>
Ciphertext encrypt(const Plaintext<typename FD::T, FD, typename FD::S>& mess, const Random_Coins& rc) const;
/// Encryption
template <class FD>
Ciphertext encrypt(const Plaintext<typename FD::T, FD, typename FD::S>& mess) const;
Ciphertext encrypt(const Rq_Element& mess) const;
friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
@@ -156,8 +179,10 @@ class FHE_PK
void check_noise(const FHE_SK& sk) const;
void check_noise(const Rq_Element& x, bool check_modulo = false) const;
// params setting is done out of these IO/pack/unpack functions
/// Append to buffer
void pack(octetStream& o) const;
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o);
bool operator!=(const FHE_PK& x) const;
@@ -170,21 +195,39 @@ class FHE_PK
void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G);
/**
* BGV key pair
*/
class FHE_KeyPair
{
public:
/// Public key
FHE_PK pk;
/// Secret key
FHE_SK sk;
FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) :
FHE_KeyPair(const FHE_Params& params, const bigint& pr) :
pk(params, pr), sk(params, pr)
{
}
/// Initialization
FHE_KeyPair(const FHE_Params& params) :
pk(params), sk(params)
{
}
void generate(PRNG& G)
{
KeyGen(pk, sk, G);
}
/// Generate fresh keys
void generate()
{
SeededPRNG G;
generate(G);
}
};
template <class S>

View File

@@ -1,5 +1,6 @@
#include "FHE_Params.h"
#include "NTL-Subs.h"
#include "FHE/Ring_Element.h"
#include "Tools/Exceptions.h"
#include "Protocols/HemiOptions.h"
@@ -67,6 +68,7 @@ void FHE_Params::pack(octetStream& o) const
Bval.pack(o);
o.store(sec_p);
o.store(matrix_dim);
fd.pack(o);
}
void FHE_Params::unpack(octetStream& o)
@@ -80,6 +82,7 @@ void FHE_Params::unpack(octetStream& o)
Bval.unpack(o);
o.get(sec_p);
o.get(matrix_dim);
fd.unpack(o);
}
bool FHE_Params::operator!=(const FHE_Params& other) const
@@ -92,3 +95,37 @@ bool FHE_Params::operator!=(const FHE_Params& other) const
else
return false;
}
void FHE_Params::basic_generation_mod_prime(int plaintext_length)
{
if (n_mults() == 0)
generate_semi_setup(plaintext_length, 0, *this, fd, false);
else
{
Parameters parameters(1, plaintext_length, 0);
parameters.generate_setup(*this, fd);
}
}
template<>
const FFT_Data& FHE_Params::get_plaintext_field_data() const
{
return fd;
}
template<>
const P2Data& FHE_Params::get_plaintext_field_data() const
{
throw not_implemented();
}
template<>
const PPData& FHE_Params::get_plaintext_field_data() const
{
throw not_implemented();
}
bigint FHE_Params::get_plaintext_modulus() const
{
return fd.get_prime();
}

View File

@@ -15,6 +15,9 @@
#include "Tools/random.h"
#include "Protocols/config.h"
/**
* Cryptosystem parameters
*/
class FHE_Params
{
protected:
@@ -29,8 +32,15 @@ class FHE_Params
bigint Bval;
int matrix_dim;
FFT_Data fd;
public:
/**
* Initialization.
* @param n_mults number of ciphertext multiplications (0/1)
* @param drown_sec parameter for function privacy (default 40)
*/
FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY);
int n_mults() const { return FFTData.size() - 1; }
@@ -59,10 +69,24 @@ class FHE_Params
int phi_m() const { return FFTData[0].phi_m(); }
const Ring& get_ring() { return FFTData[0].get_R(); }
/// Append to buffer
void pack(octetStream& o) const;
/// Read from buffer
void unpack(octetStream& o);
bool operator!=(const FHE_Params& other) const;
/**
* Generate parameter for computation modulo a prime
* @param plaintext_length bit length of prime
*/
void basic_generation_mod_prime(int plaintext_length);
template<class FD>
const FD& get_plaintext_field_data() const;
bigint get_plaintext_modulus() const;
};
#endif

View File

@@ -107,10 +107,12 @@ int generate_semi_setup(int plaintext_length, int sec,
int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up)
{
#ifdef VERBOSE
cout << "Need ciphertext modulus of length " << lgp0;
if (params.n_mults() > 0)
cout << "+" << lgp1;
cout << " and " << phi_N(m) << " slots" << endl;
#endif
int extra_slack = 0;
if (round_up)
@@ -125,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1,
}
extra_slack = i - 1;
lgp0 += extra_slack;
#ifdef VERBOSE
cout << "Rounding up to " << lgp0 << ", giving extra slack of "
<< extra_slack << " bits" << endl;
#endif
}
Ring R;
@@ -148,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1,
int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
bool round_up, FHE_Params& params)
{
(void) lg2pi, (void) n;
#ifdef VERBOSE
if (n >= 2 and n <= 10)
cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2]
<< ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl;
cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl;
cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl;
#endif
int extra_slack = 0;
if (round_up)
@@ -171,11 +179,15 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
extra_slack = 2 * i;
lg2p0 += i;
lg2p1 += i;
#ifdef VERBOSE
cout << "Rounding up to " << lg2p0 << "+" << lg2p1
<< ", giving extra slack of " << extra_slack << " bits" << endl;
#endif
}
#ifdef VERBOSE
cout << "Total length: " << lg2p0 + lg2p1 << endl;
#endif
return extra_slack;
}
@@ -215,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p,
{
double phi_m_bound =
NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1);
#ifdef VERBOSE
cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl;
#endif
if (phi_N(m) < phi_m_bound)
{
int old_m = m;
(void) old_m;
m = 2 << int(ceil(log2(phi_m_bound)));
#ifdef VERBOSE
cout << "m = " << old_m << " too small, increasing it to " << m << endl;
#endif
generate_prime(p, numBits(p), m);
}
else
@@ -244,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p,
void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
const string& i, const bigint& pr0)
{
(void) i;
if (lg2pr==0) { throw invalid_params(); }
bigint step=m;
@@ -260,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
assert(numBits(pr) == lg2pr);
}
#ifdef VERBOSE
cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl;
cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl;
#endif
assert(pr % m == 1);
assert(pr % p == 1);
assert(numBits(pr) == lg2pr);
cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl;
}
/*
@@ -626,6 +650,9 @@ void char_2_dimension(int& m, int& lg2)
case 16:
m = 4369;
break;
case 15:
m = 4681;
break;
case 12:
m = 4095;
break;

View File

@@ -167,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1)
bigint NoiseBounds::min_p1()
{
return drown * B_KS + 1;
return max(bigint(drown * B_KS), bigint((phi_m * p) << 10));
}
bigint NoiseBounds::opt_p1()
@@ -181,8 +181,10 @@ bigint NoiseBounds::opt_p1()
// solve
mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a);
bigint res = ceil(s);
#ifdef VERBOSE
cout << "Optimal p1 vs minimal: " << numBits(res) << "/"
<< numBits(min_p1()) << endl;
#endif
return res;
}
@@ -194,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1)
{
min_p0 *= 2;
min_p1 *= 2;
#ifdef VERBOSE
cout << "increasing lengths: " << numBits(min_p0) << "/"
<< numBits(min_p1) << endl;
#endif
}
lg2p1 = numBits(min_p1);
lg2p0 = numBits(min_p0);

View File

@@ -42,6 +42,8 @@ public:
bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); }
static double min_phi_m(int log_q, double sigma);
static double min_phi_m(int log_q, const FHE_Params& params);
bigint get_B_clean() { return B_clean; }
};
// as per ePrint 2012:642 for slack = 0

View File

@@ -55,13 +55,13 @@ void P2Data::check_dimensions() const
// cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl;
if (A.size() != Ai.size())
throw runtime_error("forward and backward mapping dimensions mismatch");
if (A.size() != A[0].size())
if (A.size() != A.at(0).size())
throw runtime_error("forward mapping not square");
if (Ai.size() != Ai[0].size())
if (Ai.size() != Ai.at(0).size())
throw runtime_error("backward mapping not square");
if ((int)A[0].size() != slots * gf2n_short::degree())
if ((int)A.at(0).size() != slots * gf2n_short::degree())
throw runtime_error(
"mapping dimension incorrect: " + to_string(A[0].size())
"mapping dimension incorrect: " + to_string(A.at(0).size())
+ " != " + to_string(slots) + " * "
+ to_string(gf2n_short::degree()));
}

View File

@@ -11,10 +11,43 @@
template<class T, class FD, class S>
Plaintext<T, FD, S>::Plaintext(const FHE_Params& params) :
Plaintext(params.get_plaintext_field_data<FD>(), Both)
{
}
template<class T, class FD, class S>
unsigned int Plaintext<T, FD, S>::num_slots() const
{
return (*Field_Data).phi_m();
}
template<class T, class FD, class S>
int Plaintext<T, FD, S>::degree() const
{
return (*Field_Data).phi_m();
}
template<>
unsigned int Plaintext<gf2n_short,P2Data,int>::num_slots() const
{
return (*Field_Data).num_slots();
}
template<>
int Plaintext<gf2n_short,P2Data,int>::degree() const
{
return (*Field_Data).degree();
}
template<>
void Plaintext<gfp, FFT_Data, bigint>::from(const Generator<bigint>& source) const
{
b.resize(degree);
b.resize(degree());
for (auto& x : b)
{
source.get(bigint::tmp);
@@ -31,7 +64,7 @@ void Plaintext<gfp,FFT_Data,bigint>::from_poly() const
Ring_Element e(*Field_Data,polynomial);
e.from(b);
e.change_rep(evaluation);
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
a[i] = gfp(e.get_element(i), e.get_FFTD().get_prD());
type=Both;
@@ -60,7 +93,7 @@ void Plaintext<gfp,PPData,bigint>::from_poly() const
for (unsigned int i=0; i<aa.size(); i++)
{ to_modp(aa[i], bigint::tmp = b[i], (*Field_Data).prData); }
(*Field_Data).to_eval(aa);
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<aa.size(); i++)
a[i] = {aa[i], Field_Data->get_prD()};
type=Both;
@@ -90,7 +123,7 @@ template<>
void Plaintext<gf2n_short,P2Data,int>::from_poly() const
{
if (type!=Polynomial) { return; }
a.resize(n_slots);
a.resize(num_slots());
(*Field_Data).backward(a,b);
type=Both;
}
@@ -106,34 +139,13 @@ void Plaintext<gf2n_short,P2Data,int>::to_poly() const
template<>
void Plaintext<gfp,FFT_Data,bigint>::set_sizes()
{ n_slots = (*Field_Data).phi_m();
degree = n_slots;
}
template<>
void Plaintext<gfp,PPData,bigint>::set_sizes()
{ n_slots = (*Field_Data).phi_m();
degree = n_slots;
}
template<>
void Plaintext<gf2n_short,P2Data,int>::set_sizes()
{ n_slots = (*Field_Data).num_slots();
degree = (*Field_Data).degree();
}
template<class T, class FD, class S>
void Plaintext<T, FD, S>::allocate(PT_Type type) const
{
if (type != Evaluation)
b.resize(degree);
b.resize(degree());
if (type != Polynomial)
a.resize(n_slots);
a.resize(num_slots());
this->type = type;
}
@@ -141,7 +153,7 @@ void Plaintext<T, FD, S>::allocate(PT_Type type) const
template<class T, class FD, class S>
void Plaintext<T, FD, S>::allocate_slots(const bigint& value)
{
b.resize(degree);
b.resize(degree());
for (auto& x : b)
x.allocate_slots(value);
}
@@ -236,7 +248,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
type=Polynomial;
break;
case Diagonal:
a.resize(n_slots);
a.resize(num_slots());
a[0].randomize(G);
for (unsigned int i=1; i<a.size(); i++)
{ a[i]=a[0]; }
@@ -244,7 +256,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
break;
default:
// Gen a plaintext with 0/1 in each slot
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{
if (G.get_bit())
@@ -272,7 +284,7 @@ void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t)
b[0].generateUniform(G, n_bits, false);
}
else
for (int i = 0; i < n_slots; i++)
for (size_t i = 0; i < num_slots(); i++)
b[i].generateUniform(G, n_bits, false);
break;
default:
@@ -288,7 +300,7 @@ void Plaintext<T,FD,S>::assign_zero(PT_Type t)
allocate();
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].assign_zero(); }
}
@@ -306,7 +318,7 @@ void Plaintext<T,FD,S>::assign_one(PT_Type t)
allocate();
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].assign_one(); }
}
@@ -359,7 +371,7 @@ void add(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] + y.a[i]); }
}
@@ -387,7 +399,7 @@ void add(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] + y.a[i]); }
}
@@ -418,7 +430,7 @@ void add(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,i
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i].add(x.a[i],y.a[i]); }
}
@@ -446,7 +458,7 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i]= (x.a[i] - y.a[i]); }
}
@@ -478,7 +490,7 @@ void sub(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] - y.a[i]); }
}
@@ -510,7 +522,7 @@ void sub(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,i
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.n_slots);
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i].sub(x.a[i],y.a[i]); }
}
@@ -545,7 +557,7 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
{
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].negate(); }
}
@@ -565,7 +577,7 @@ void Plaintext<gfp,PPData,bigint>::negate()
{
if (type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].negate(); }
}
@@ -607,7 +619,7 @@ bool Plaintext<T,FD,S>::equals(const Plaintext& x) const
if (type!=Polynomial and x.type!=Polynomial)
{
a.resize(n_slots);
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ if (!(a[i] == x.a[i])) { return false; } }
}
@@ -671,9 +683,9 @@ void Plaintext<T,FD,S>::unpack(octetStream& o)
unsigned int size;
o.get(size);
allocate();
if (size != b.size())
if (size != b.size() and size != 0)
throw length_error("unexpected length received");
for (unsigned int i = 0; i < b.size(); i++)
for (unsigned int i = 0; i < size; i++)
b[i] = o.get<S>();
}

View File

@@ -18,6 +18,7 @@
*/
#include "FHE/Generator.h"
#include "FHE/FFT_Data.h"
#include "Math/fixint.h"
#include <vector>
@@ -25,6 +26,8 @@ using namespace std;
class FHE_PK;
class Rq_Element;
class FHE_Params;
class FFT_Data;
template<class T> class AddableVector;
// Forward declaration as apparently this is needed for friends in templates
@@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits };
enum PT_Type { Polynomial, Evaluation, Both };
/**
* BGV plaintext.
* Use ``Plaintext_mod_prime`` instead of filling in the templates.
* The plaintext is held in one of the two representations or both,
* polynomial and evaluation. The latter is the one allowing element-wise
* multiplication over a vector.
* Plaintexts can be added, subtracted, and multiplied via operator overloading.
*/
template<class T,class FD,class _>
class Plaintext
{
typedef typename FD::poly_type S;
int n_slots;
int degree;
mutable vector<T> a; // The thing in evaluation/FFT form
mutable vector<S> b; // Now in polynomial form
@@ -58,33 +67,47 @@ class Plaintext
const FD *Field_Data;
void set_sizes();
int degree() const;
public:
const FD& get_field() const { return *Field_Data; }
unsigned int num_slots() const { return n_slots; }
/// Number of slots
unsigned int num_slots() const;
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
{ Field_Data=&FieldD; allocate(type); }
Plaintext(const FD& FieldD, const Rq_Element& other);
/// Initialization
Plaintext(const FHE_Params& params);
void allocate(PT_Type type) const;
void allocate() const { allocate(type); }
void allocate_slots(const bigint& value);
int get_min_alloc();
// Access evaluation representation
/**
* Read slot.
* @param i slot number
* @returns slot content
*/
T element(int i) const
{ if (type==Polynomial)
{ from_poly(); }
return a[i];
}
/**
* Write to slot
* @param i slot number
* @param e new slot content
*/
void set_element(int i,const T& e)
{ if (type==Polynomial)
{ throw not_implemented(); }
a.resize(n_slots);
a.resize(num_slots());
a[i]=e;
type=Evaluation;
}
@@ -171,10 +194,10 @@ class Plaintext
bool is_diagonal() const;
/* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already
*/
/// Append to buffer
void pack(octetStream& o) const;
/// Read from buffer. Assumes parameters are set correctly
void unpack(octetStream& o);
size_t report_size(ReportType type);
@@ -185,4 +208,6 @@ class Plaintext
template <class FD>
using Plaintext_ = Plaintext<typename FD::T, FD, typename FD::S>;
typedef Plaintext_<FFT_Data> Plaintext_mod_prime;
#endif

View File

@@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o)
o.get(pi_inv);
o.get(poly);
}
else
else if (mm != 0)
init(*this, mm);
}

View File

@@ -87,7 +87,6 @@ void Ring_Element::negate()
void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
if (a.element.empty())
{
@@ -100,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
return;
}
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (&ans == &a)
{
ans += b;

View File

@@ -5,7 +5,7 @@
#include "Math/modp.hpp"
Rq_Element::Rq_Element(const FHE_PK& pk) :
Rq_Element(pk.get_params().FFTD())
Rq_Element(pk.get_params().FFTD(), evaluation, evaluation)
{
}
@@ -347,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const
return sz;
}
void Rq_Element::unpack(octetStream& o, const FHE_Params& params)
{
set_data(params.FFTD());
unpack(o);
}
void Rq_Element::print_first_non_zero() const
{
vector<bigint> v = to_vec_bigint();

View File

@@ -69,8 +69,9 @@ protected:
a({b0}), lev(n_mults()) {}
template<class T, class FD, class S>
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext) :
Rq_Element(params)
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext,
RepType r0 = polynomial, RepType r1 = polynomial) :
Rq_Element(params, r0, r1)
{
from(plaintext.get_iterator());
}
@@ -159,6 +160,9 @@ protected:
void pack(octetStream& o) const;
void unpack(octetStream& o);
// without prior initialization
void unpack(octetStream& o, const FHE_Params& params);
void output(ostream& s) const;
void input(istream& s);

View File

@@ -57,7 +57,7 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
template <class FD>
void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
OT_ROLE role, int n_summands)
OT_ROLE role, int)
{
o.reset_write_head();
@@ -67,20 +67,10 @@ void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
G.ReSeed();
timers["Mask randomization"].start();
product_share.randomize(G);
bigint B = 6 * machine.setup<FD>().params.get_R();
B *= machine.setup<FD>().FieldD.get_prime();
B <<= machine.setup<FD>().params.secp();
// slack
B *= NonInteractiveProof::slack(machine.sec,
machine.setup<FD>().params.phi_m());
B <<= machine.extra_slack;
B *= n_summands;
rc.generateUniform(G, 0, B, B);
mask = c;
mask.rerandomize(other_pk);
timers["Mask randomization"].stop();
timers["Encryption"].start();
other_pk.encrypt(mask, product_share, rc);
timers["Encryption"].stop();
mask += c;
mask += product_share;
mask.pack(o);
res -= product_share;
}

View File

@@ -75,6 +75,8 @@ void secure_init(T& setup, Player& P, U& machine,
+ OnlineOptions::singleton.prime.get_str() + "-"
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
string reason;
try
{
ifstream file(filename);
@@ -82,12 +84,30 @@ void secure_init(T& setup, Player& P, U& machine,
os.input(file);
os.get(machine.extra_slack);
setup.unpack(os);
}
catch (exception& e)
{
reason = e.what();
}
try
{
setup.check(P, machine);
}
catch (...)
catch (exception& e)
{
cout << "Finding parameters for security " << sec << " and field size ~2^"
<< plaintext_length << endl;
reason = e.what();
}
if (not reason.empty())
{
if (OnlineOptions::singleton.verbose)
cerr << "Generating parameters for security " << sec
<< " and field size ~2^" << plaintext_length
<< " because no suitable material "
"from a previous run was found (" << reason << ")"
<< endl;
setup = {};
setup.generate(P, machine, plaintext_length, sec);
setup.check(P, machine);
octetStream os;

View File

@@ -50,11 +50,6 @@ public:
return "no";
}
static string type_short()
{
return "no";
}
static DataFieldType field_type()
{
throw not_implemented();
@@ -66,7 +61,7 @@ public:
static void fail()
{
throw runtime_error("VM does not support binary circuits");
throw runtime_error("functionality not available");
}
NoValue() {}
@@ -143,6 +138,11 @@ public:
return 0;
}
static int length()
{
return 0;
}
static void fail()
{
NoValue::fail();

View File

@@ -5,6 +5,7 @@
#include "Protocols/DealerShare.h"
#include "Protocols/DealerInput.h"
#include "Protocols/Dealer.h"
#include "Processor/RingMachine.hpp"
#include "Processor/Machine.hpp"
@@ -12,6 +13,7 @@
#include "Protocols/DealerPrep.hpp"
#include "Protocols/DealerInput.hpp"
#include "Protocols/DealerMC.hpp"
#include "Protocols/DealerMatrixPrep.hpp"
#include "Protocols/Beaver.hpp"
#include "Semi.hpp"
#include "GC/DealerPrep.h"

View File

@@ -21,5 +21,5 @@ using MamaShare_ = MamaShare<T, N_MAMA_MACS>;
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
DishonestMajorityFieldMachine<MamaShare_, Share>(argc, argv, opt);
DishonestMajorityFieldMachine<MamaShare_, MamaShare_>(argc, argv, opt);
}

View File

@@ -244,6 +244,7 @@ paper-example.x: $(VM) $(OT) $(FHEOFFLINE)
binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o
mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o
l2h-example.x: $(VM) $(OT) Machines/Tinier.o
he-example.x: $(FHEOFFLINE)
mascot-offline.x: $(VM) $(TINIER)
cowgear-offline.x: $(TINIER) $(FHEOFFLINE)
static/rep-bmr-party.x: $(BMR)

View File

@@ -24,7 +24,12 @@ public:
typedef T value_type;
typedef FixedVec Scalar;
static const int length = L;
static const int vector_length = L;
static int length()
{
return L * T::length();
}
static int size()
{

View File

@@ -136,7 +136,7 @@ void write_online_setup(string dirname, const bigint& p)
if (mkdir_p(ss.str().c_str()) == -1)
{
cerr << "mkdir_p(" << ss.str() << ") failed\n";
throw file_error(ss.str());
throw file_error("cannot create " + dirname);
}
// Output the data
@@ -167,6 +167,6 @@ string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
res += "-" + to_string(log2mod);
res += "/";
if (mkdir_p(res.c_str()) < 0)
throw file_error(res);
throw file_error("cannot create " + res);
return res;
}

View File

@@ -439,6 +439,12 @@ void Z2<K>::randomize(PRNG& G, int n)
template<int K>
void Z2<K>::randomize_part(PRNG& G, int n)
{
if (n >= N_BITS)
{
randomize(G);
return;
}
*this = {};
G.get_octets((octet*)a, DIV_CEIL(n, 8));
a[DIV_CEIL(n, 64) - 1] &= mp_limb_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS);

View File

@@ -67,7 +67,10 @@ Z2<K>::Z2(const IntBase<T>& x) :
template<int K>
bool Z2<K>::get_bit(int i) const
{
return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS));
if (i < N_BITS)
return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS));
else
return false;
}
template<int K>

View File

@@ -174,7 +174,8 @@ void Zp_Data::unpack(octetStream& o)
int m;
o.get(m);
montgomery = m;
init(pr, m);
if (pr != 0)
init(pr, m);
}
bool Zp_Data::operator!=(const Zp_Data& other) const

View File

@@ -44,6 +44,19 @@ int fields_2[num_2_fields][4] =
{ 128, 7, 2, 1 },
};
template<class U>
string gf2n_<U>::options()
{
string res = to_string(fields_2[0][0]);
for (int i = 1; i < num_2_fields; i++)
{
int n = fields_2[i][0];
if (n <= MAX_N_BITS)
res += ", " + to_string(n);
}
return res;
}
template<class U>
void gf2n_<U>::init_tables()
{
@@ -113,7 +126,7 @@ void gf2n_<U>::init_field(int nn)
if (j==-1)
{
throw gf2n_not_supported(nn);
throw gf2n_not_supported(nn, options());
}
n=nn;

View File

@@ -86,6 +86,8 @@ protected:
static bool allows(Dtype type) { (void) type; return true; }
static string options();
static const true_type invertible;
static const true_type characteristic_two;
@@ -154,6 +156,8 @@ protected:
gf2n_ operator*(int x) const { return *this * gf2n_(x); }
gf2n_ invert() const;
gf2n_ operator-() const { return *this; }
void negate() { return; }
/* Bitwise Ops */

View File

@@ -107,6 +107,12 @@ public:
a = other.get();
}
template<int K>
gfpvar_(const Z2<K>& other) :
gfpvar_(bigint(other))
{
}
void assign(const void* buffer);
void assign_zero();

View File

@@ -50,17 +50,12 @@ public:
void Broadcast_Receive_no_stats(vector<octetStream>& os) const
{
vector<octetStream> to_send(P.num_players(), os[P.my_num()]);
vector<vector<bool>> channels(P.num_players(),
vector<bool>(P.num_players(), true));
for (auto& x: channels)
x.back() = false;
channels.back() = vector<bool>(P.num_players(), false);
vector<octetStream> to_receive;
P.send_receive_all(channels, to_send, to_receive);
for (int i = 0; i < P.num_players() - 1; i++)
if (i != P.my_num())
os[i] = to_receive[i];
vector<bool> senders(P.num_players(), true), receivers(P.num_players(),
true);
senders.back() = false;
receivers.back() = false;
P.partial_broadcast(senders, receivers, os);
os.resize(num_players());
}
};

View File

@@ -212,8 +212,8 @@ void CryptoPlayer::partial_broadcast(const vector<bool>& my_senders,
for (int offset = 1; offset < num_players(); offset++)
{
int other = get_player(offset);
bool receive = my_senders[other];
if (my_receivers[other])
bool receive = my_senders.at(other);
if (my_receivers.at(other))
{
this->senders[other]->request(os[my_num()]);
sent += os[my_num()].get_length();

View File

@@ -811,14 +811,6 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const
return res;
}
size_t NamedCommStats::total_data()
{
size_t res = 0;
for (auto& x : *this)
res += x.second.data;
return res;
}
void NamedCommStats::print(bool newline)
{
for (auto it = begin(); it != end(); it++)

View File

@@ -157,7 +157,6 @@ public:
NamedCommStats& operator+=(const NamedCommStats& other);
NamedCommStats operator+(const NamedCommStats& other) const;
NamedCommStats operator-(const NamedCommStats& other) const;
size_t total_data();
void print(bool newline = false);
void reset();
#ifdef VERBOSE_COMM

View File

@@ -230,7 +230,7 @@ void Sub_Data_Files<T>::prune()
my_input_buffers.prune();
for (int j = 0; j < num_players; j++)
input_buffers[j].prune();
for (auto it : extended)
for (auto& it : extended)
it.second.prune();
dabit_buffer.prune();
if (part != 0)

View File

@@ -293,7 +293,7 @@ int InputBase<T>::get_player(SubProcessor<T>& Proc, int arg, bool player_from_re
if (player_from_reg)
{
assert(Proc.Proc);
auto res = Proc.Proc->read_Ci(arg);
auto res = Proc.Proc->sync_Ci(arg);
if (res >= Proc.P.num_players())
throw runtime_error("player id too large: " + to_string(res));
return res;

View File

@@ -13,6 +13,7 @@ using namespace std;
template<class sint, class sgf2n> class Machine;
template<class sint, class sgf2n> class Processor;
template<class T> class SubProcessor;
class ArithmeticProcessor;
class SwitchableOutput;
@@ -107,6 +108,11 @@ enum
CONV2DS = 0xAC,
CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
// Shuffling
SECSHUFFLE = 0xFA,
GENSECSHUFFLE = 0xFB,
APPLYSHUFFLE = 0xFC,
DELSHUFFLE = 0xFD,
// Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -250,6 +256,7 @@ enum
GMULS = 0x1A6,
GMULRS = 0x1A7,
GDOTPRODS = 0x1A8,
GSECSHUFFLE = 0x1FA,
// Data access
GTRIPLE = 0x150,
GBIT = 0x151,
@@ -388,6 +395,9 @@ public:
template<class T>
void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0,
T* nan = 0) const;
template<class T>
typename T::clear sanitize(SubProcessor<T>& proc, int reg) const;
};
#endif

View File

@@ -157,6 +157,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case LISTEN:
case CLOSECLIENTCONNECTION:
case CRASH:
case DELSHUFFLE:
r[0]=get_int(s);
break;
// instructions with 2 registers + 1 integer operand
@@ -203,6 +204,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case DIGESTC:
case INPUTMASK:
case GINPUTMASK:
case SECSHUFFLE:
case GSECSHUFFLE:
get_ints(r, s, 2);
n = get_int(s);
break;
@@ -230,6 +233,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case CONDPRINTSTR:
case CONDPRINTSTRB:
case RANDOMS:
case GENSECSHUFFLE:
r[0]=get_int(s);
n = get_int(s);
break;
@@ -269,6 +273,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
// instructions with 5 register operands
case PRINTFLOATPLAIN:
case PRINTFLOATPLAINB:
case APPLYSHUFFLE:
get_vector(5, start, s);
break;
case INCINT:
@@ -558,6 +563,7 @@ int BaseInstruction::get_reg_type() const
case CONVCBITVEC:
case INTOUTPUT:
case ACCEPTCLIENTCONNECTION:
case GENSECSHUFFLE:
return INT;
case PREP:
case GPREP:
@@ -835,11 +841,13 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
{
for (int i = 0; i < size; i++)
Proc.write_Ci(r[0] + i,
Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get());
Proc.sync(
Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get()));
}
else if (n <= 64)
for (int i = 0; i < size; i++)
Proc.write_Ci(r[0] + i, Integer(Proc.read_Cp(r[1] + i), n).get());
Proc.write_Ci(r[0] + i,
Proc.sync(Integer(Proc.read_Cp(r[1] + i), n).get()));
else
throw Processor_Error(to_string(n) + "-bit conversion impossible; "
"integer registers only have 64 bits");
@@ -856,40 +864,32 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
n++;
break;
case LDMCI:
Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1])));
Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.sync_Ci(r[1])));
break;
case STMC:
Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0]));
n++;
break;
case STMCI:
Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0]));
Proc.machine.Mp.write_C(Proc.sync_Ci(r[1]), Proc.read_Cp(r[0]));
break;
case MOVC:
Proc.write_Cp(r[0],Proc.read_Cp(r[1]));
break;
case DIVC:
if (Proc.read_Cp(r[2]).is_zero())
throw Processor_Error("Division by zero from register");
Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / Proc.read_Cp(r[2]));
Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2]));
break;
case GDIVC:
if (Proc.read_C2(r[2]).is_zero())
throw Processor_Error("Division by zero from register");
Proc.write_C2(r[0], Proc.read_C2(r[1]) / Proc.read_C2(r[2]));
Proc.write_C2(r[0], Proc.read_C2(r[1]) / sanitize(Proc.Proc2, r[2]));
break;
case FLOORDIVC:
if (Proc.read_Cp(r[2]).is_zero())
throw Processor_Error("Division by zero from register");
Proc.temp.aa.from_signed(Proc.read_Cp(r[1]));
Proc.temp.aa2.from_signed(Proc.read_Cp(r[2]));
Proc.temp.aa2.from_signed(sanitize(Proc.Procp, r[2]));
Proc.write_Cp(r[0], bigint(Proc.temp.aa / Proc.temp.aa2));
break;
case MODC:
if (Proc.read_Cp(r[2]).is_zero())
throw Processor_Error("Modulo by zero from register");
to_bigint(Proc.temp.aa, Proc.read_Cp(r[1]));
to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2]));
to_bigint(Proc.temp.aa2, sanitize(Proc.Procp, r[2]));
mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t());
Proc.temp.ansp.convert_destroy(Proc.temp.aa);
Proc.write_Cp(r[0],Proc.temp.ansp);
@@ -948,7 +948,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Procp.protocol.randoms_inst(Procp.get_S(), *this);
return;
case INPUTMASKREG:
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2]));
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.sync_Ci(r[2]));
Proc.write_Cp(r[1], Proc.temp.rrp);
break;
case INPUTMASK:
@@ -1034,7 +1034,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
return;
case MATMULSM:
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
Proc.read_Ci(r[1]), Proc.read_Ci(r[2]));
Proc.sync_Ci(r[1]), Proc.sync_Ci(r[2]));
return;
case CONV2DS:
Proc.Procp.protocol.conv2ds(Proc.Procp, *this);
@@ -1042,6 +1042,21 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case TRUNC_PR:
Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp);
return;
case SECSHUFFLE:
Proc.Procp.secure_shuffle(*this);
return;
case GSECSHUFFLE:
Proc.Proc2.secure_shuffle(*this);
return;
case GENSECSHUFFLE:
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this));
return;
case APPLYSHUFFLE:
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)));
return;
case DELSHUFFLE:
Proc.Procp.delete_shuffle(Proc.read_Ci(r[0]));
return;
case CHECK:
{
CheckJob job;
@@ -1056,14 +1071,14 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.PC += (signed int) n;
break;
case JMPI:
Proc.PC += (signed int) Proc.read_Ci(r[0]);
Proc.PC += (signed int) Proc.sync_Ci(r[0]);
break;
case JMPNZ:
if (Proc.read_Ci(r[0]) != 0)
if (Proc.sync_Ci(r[0]) != 0)
{ Proc.PC += (signed int) n; }
break;
case JMPEQZ:
if (Proc.read_Ci(r[0]) == 0)
if (Proc.sync_Ci(r[0]) == 0)
{ Proc.PC += (signed int) n; }
break;
case PRINTREG:
@@ -1123,7 +1138,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.machine.join_tape(r[0]);
break;
case CRASH:
if (Proc.read_Ci(r[0]))
if (Proc.sync_Ci(r[0]))
throw crash_requested();
break;
case STARTGRIND:
@@ -1146,7 +1161,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
// ***
case LISTEN:
// listen for connections at port number n
Proc.external_clients.start_listening(Proc.read_Ci(r[0]));
Proc.external_clients.start_listening(Proc.sync_Ci(r[0]));
break;
case ACCEPTCLIENTCONNECTION:
{
@@ -1335,4 +1350,15 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c
out << "]";
}
template<class T>
typename T::clear Instruction::sanitize(SubProcessor<T>& proc, int reg) const
{
if (not T::real_shares(proc.P))
return 1;
auto& res = proc.get_C_ref(reg);
if (res.is_zero())
throw Processor_Error("Division by zero from register");
return res;
}
#endif

View File

@@ -30,7 +30,7 @@ void Machine<sint, sgf2n>::init_binary_domains(int security_parameter, int lg2)
if (not is_same<typename sgf2n::mac_key_type, GC::NoValue>())
{
if (sgf2n::clear::degree() < security_parameter)
if (sgf2n::mac_key_type::length() < security_parameter)
{
cerr << "Security parameter needs to be at most n in GF(2^n)."
<< endl;
@@ -469,7 +469,10 @@ void Machine<sint, sgf2n>::run(const string& progname)
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << my_number << ")" << endl;
<< " rounds (party " << my_number;
if (threads.size() > 1)
cerr << "; rounds counted double due to multi-threading";
cerr << ")" << endl;
auto& P = *this->P;
Bundle<octetStream> bundle(P);

View File

@@ -36,7 +36,9 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
("Bit length of GF(2^n) field (default: " + to_string(V::default_degree()) + ")").c_str(), // Help description.
("Bit length of GF(2^n) field (default: "
+ to_string(V::default_degree()) + "; options are "
+ V::options() + ")").c_str(), // Help description.
"-lg2", // Flag token.
"--lg2" // Flag token.
);

View File

@@ -20,6 +20,7 @@
#include "Tools/CheckVector.h"
#include "GC/Processor.h"
#include "GC/ShareThread.h"
#include "Protocols/SecureShuffle.h"
class Program;
@@ -31,6 +32,8 @@ class SubProcessor
DataPositions bit_usage;
SecureShuffle<T> shuffler;
void resize(size_t size) { C.resize(size); S.resize(size); }
template<class sint, class sgf2n> friend class Processor;
@@ -70,6 +73,11 @@ public:
size_t b);
void conv2ds(const Instruction& instruction);
void secure_shuffle(const Instruction& instruction);
size_t generate_secure_shuffle(const Instruction& instruction);
void apply_shuffle(const Instruction& instruction, int handle);
void delete_shuffle(int handle);
void input_personal(const vector<int>& args);
void send_personal(const vector<int>& args);
void private_output(const vector<int>& args);
@@ -127,6 +135,10 @@ public:
ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num),
sent(0), rounds(0), opts(opts) {}
virtual ~ArithmeticProcessor()
{
}
bool use_stdin()
{
return thread_num == 0 and opts.interactive;
@@ -146,6 +158,11 @@ public:
CheckVector<long>& get_Ci()
{ return Ci; }
virtual long sync_Ci(size_t) const
{
throw not_implemented();
}
void shuffle(const Instruction& instruction);
void bitdecint(const Instruction& instruction);
};
@@ -241,6 +258,10 @@ class Processor : public ArithmeticProcessor
cint get_inverse2(unsigned m);
// synchronize in asymmetric protocols
long sync_Ci(size_t i) const;
long sync(long x) const;
private:
template<class T> friend class SPDZ;

View File

@@ -9,6 +9,7 @@
#include "Processor/ProcessorBase.hpp"
#include "GC/Processor.hpp"
#include "GC/ShareThread.hpp"
#include "Protocols/SecureShuffle.hpp"
#include <sodium.h>
#include <string>
@@ -23,6 +24,7 @@ SubProcessor<T>::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check&
template <class T>
SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
Preprocessing<T>& DataF, Player& P, ArithmeticProcessor* Proc) :
shuffler(*this),
Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC),
bit_prep(bit_usage)
{
@@ -340,6 +342,9 @@ void Processor<sint, sgf2n>::read_socket_private(int client_id,
// Tolerent to no file if no shares yet persisted.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector<int>& data_registers) {
if (not sint::real_shares(P))
return;
string filename;
filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data";
@@ -370,6 +375,9 @@ template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_shares_to_file(long start_pos,
const vector<int>& data_registers)
{
if (not sint::real_shares(P))
return;
string filename = binary_file_io.filename(P.my_num());
unsigned int size = data_registers.size();
@@ -633,6 +641,33 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
}
}
template<class T>
void SubProcessor<T>::secure_shuffle(const Instruction& instruction)
{
SecureShuffle<T>(S, instruction.get_size(), instruction.get_n(),
instruction.get_r(0), instruction.get_r(1), *this);
}
template<class T>
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction)
{
return shuffler.generate(instruction.get_n());
}
template<class T>
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle)
{
shuffler.apply(S, instruction.get_size(), instruction.get_start()[2],
instruction.get_start()[0], instruction.get_start()[1], handle,
instruction.get_start()[4]);
}
template<class T>
void SubProcessor<T>::delete_shuffle(int handle)
{
shuffler.del(handle);
}
template<class T>
void SubProcessor<T>::input_personal(const vector<int>& args)
{
@@ -690,4 +725,25 @@ typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
return inverses2m[m];
}
template<class sint, class sgf2n>
long Processor<sint, sgf2n>::sync_Ci(size_t i) const
{
return sync(read_Ci(i));
}
template<class sint, class sgf2n>
long Processor<sint, sgf2n>::sync(long x) const
{
if (not sint::symmetric)
{
// send number to dealer
if (P.my_num() == 0)
P.send_long(P.num_players() - 1, x);
if (not sint::real_shares(P))
return P.receive_long(0);
}
return x;
}
#endif

View File

@@ -50,7 +50,10 @@ RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
case L: \
machine.template run<U<L>, V<gf2n>>(); \
break;
X(64) X(72) X(128) X(192)
X(64)
#ifndef FEWER_RINGS
X(72) X(128) X(192)
#endif
#ifdef RING_SIZE
X(RING_SIZE)
#endif

View File

@@ -0,0 +1,50 @@
# example code for graph with vertices 0,1,2 and with following weights
# 0 -> 1: 5
# 0 -> 2: 20
# 1 -> 2: 10
# output should be the following
# from 0 to 0 at cost 0 via vertex 0
# from 0 to 1 at cost 5 via vertex 0
# from 0 to 2 at cost 15 via vertex 1
from oram import OptimalORAM
from dijkstra import dijkstra
# structure for edges
# contains tuples of form (neighbor, cost, last neighbor bit)
edges = OptimalORAM(4, # number of edges
entry_size=(2, # enough bits for vertices
5, # enough bits for costs
1) # always one
)
# first edge from vertex 0
edges[0] = (1, 5, 0)
# second and last edge from vertex 0
edges[1] = (2, 20, 1)
# edge from vertex 1
edges[2] = (2, 10, 1)
# dummy edge from vertex 2 to itself
edges[3] = (2, 0, 1)
# structure assigning edge list indices to vertices
e_index = OptimalORAM(3, # number vertices
entry_size=2) # enough bits for edge indices
# edges from 0 start at 0
e_index[0] = 0
# edges from 1 start at 2
e_index[1] = 2
# edges from 2 start at 3
e_index[2] = 3
source = sint(0)
res = dijkstra(source, edges, e_index, OptimalORAM)
@for_range(res.size)
def _(i):
import util
print_ln('from %s to %s at cost %s via vertex %s', source.reveal(), i,
res[i][0].reveal(), res[i][1].reveal())

View File

@@ -1,9 +0,0 @@
import dijkstra
from path_oram import OptimalORAM
n = 1000
dist = dijkstra.test_dijkstra_on_cycle(n, OptimalORAM)
for i in range(n):
print_ln('%s: %s', i, dist[i][0].reveal())

36
Protocols/Dealer.h Normal file
View File

@@ -0,0 +1,36 @@
/*
* Dealer.h
*
*/
#ifndef PROTOCOLS_DEALER_H_
#define PROTOCOLS_DEALER_H_
#include "Beaver.h"
template<class T>
class Dealer : public Beaver<T>
{
SeededPRNG G;
public:
Dealer(Player& P) :
Beaver<T>(P)
{
}
T get_random()
{
if (T::real_shares(this->P))
return G.get<T>();
else
return {};
}
vector<int> get_relevant_players()
{
return vector<int>(1, this->P.num_players() - 1);
}
};
#endif /* PROTOCOLS_DEALER_H_ */

View File

@@ -24,6 +24,7 @@ public:
DealerInput(SubProcessor<T>& proc, typename T::MAC_Check&);
DealerInput(typename T::MAC_Check&, Preprocessing<T>&, Player& P);
DealerInput(Player& P);
DealerInput(SubProcessor<T>*, Player& P);
~DealerInput();
bool is_dealer(int player = -1);

View File

@@ -10,7 +10,7 @@
template<class T>
DealerInput<T>::DealerInput(SubProcessor<T>& proc, typename T::MAC_Check&) :
DealerInput(proc.P)
DealerInput(&proc, proc.P)
{
}
@@ -23,6 +23,13 @@ DealerInput<T>::DealerInput(typename T::MAC_Check&, Preprocessing<T>&,
template<class T>
DealerInput<T>::DealerInput(Player& P) :
DealerInput(0, P)
{
}
template<class T>
DealerInput<T>::DealerInput(SubProcessor<T>* proc, Player& P) :
InputBase<T>(proc),
P(P), to_send(P), shares(P.num_players()), from_dealer(false),
sub_player(P)
{
@@ -68,8 +75,8 @@ void DealerInput<T>::add_mine(const typename T::open_type& input,
if (is_dealer())
{
make_share(shares.data(), input, P.num_players() - 1, 0, G);
for (int i = 1; i < P.num_players(); i++)
shares.at(i - 1).pack(to_send[i]);
for (int i = 0; i < P.num_players() - 1; i++)
shares.at(i).pack(to_send[i]);
from_dealer = true;
}
else

View File

@@ -25,6 +25,7 @@ public:
void prepare_open(const T& secret);
void exchange(const Player& P);
typename T::open_type finalize_raw();
array<typename T::open_type*, 2> finalize_several(int n);
DealerMC& get_part_MC()
{

View File

@@ -73,4 +73,11 @@ typename T::open_type DealerMC<T>::finalize_raw()
return {};
}
template<class T>
array<typename T::open_type*, 2> DealerMC<T>::finalize_several(int n)
{
assert(sub_player);
return internal.finalize_several(n);
}
#endif /* PROTOCOLS_DEALERMC_HPP_ */

View File

@@ -0,0 +1,32 @@
/*
* DealerMatrixPrep.h
*
*/
#ifndef PROTOCOLS_DEALERMATRIXPREP_H_
#define PROTOCOLS_DEALERMATRIXPREP_H_
#include "ShareMatrix.h"
template<class T>
class DealerMatrixPrep : public BufferPrep<ShareMatrix<T>>
{
typedef BufferPrep<ShareMatrix<T>> super;
typedef typename T::LivePrep LivePrep;
int n_rows, n_inner, n_cols;
LivePrep* prep;
public:
DealerMatrixPrep(int n_rows, int n_inner, int n_cols,
typename T::LivePrep&, DataPositions& usage);
void set_protocol(typename ShareMatrix<T>::Protocol&)
{
}
void buffer_triples();
};
#endif /* PROTOCOLS_DEALERMATRIXPREP_H_ */

View File

@@ -0,0 +1,87 @@
/*
* DealerMatrixPrep.hpp
*
*/
#include "DealerMatrixPrep.h"
template<class T>
DealerMatrixPrep<T>::DealerMatrixPrep(int n_rows, int n_inner, int n_cols,
typename T::LivePrep& prep, DataPositions& usage) :
super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols),
prep(&prep)
{
}
template<class T>
void append_shares(vector<octetStream>& os,
ValueMatrix<typename T::clear>& M, PRNG& G)
{
size_t n = os.size();
for (auto& value : M.entries)
{
T sum;
for (size_t i = 0; i < n - 2; i++)
{
auto share = G.get<T>();
sum += share;
share.pack(os[i]);
}
(value - sum).pack(os[n - 2]);
}
}
template<class T>
ShareMatrix<T> receive_shares(octetStream& o, int n, int m)
{
ShareMatrix<T> res(n, m);
for (size_t i = 0; i < res.entries.size(); i++)
res.entries.v.push_back(o.get<T>());
return res;
}
template<class T>
void DealerMatrixPrep<T>::buffer_triples()
{
assert(this->prep);
assert(this->prep->proc);
auto& P = this->prep->proc->P;
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
int batch_size = 100;
if (not T::real_shares(P))
{
SeededPRNG G;
ValueMatrix<typename T::clear> A(n_rows, n_inner), B(n_inner, n_cols),
C(n_rows, n_cols);
for (int i = 0; i < P.num_players() - 1; i++)
os[i].reserve(
batch_size * T::size()
* (A.entries.size() + B.entries.size()
+ C.entries.size()));
for (int i = 0; i < batch_size; i++)
{
A.randomize(G);
B.randomize(G);
C = A * B;
append_shares<T>(os, A, G);
append_shares<T>(os, B, G);
append_shares<T>(os, C, G);
this->triples.push_back({{{n_rows, n_inner}, {n_inner, n_cols},
{n_rows, n_cols}}});
}
P.send_receive_all(senders, os, to_receive);
}
else
{
P.send_receive_all(senders, os, to_receive);
for (int i = 0; i < batch_size; i++)
{
auto& o = to_receive.back();
this->triples.push_back({{receive_shares<T>(o, n_rows, n_inner),
receive_shares<T>(o, n_inner, n_cols),
receive_shares<T>(o, n_rows, n_cols)}});
}
}
}

View File

@@ -11,6 +11,13 @@
template<class T>
class DealerPrep : virtual public BitPrep<T>
{
friend class DealerMatrixPrep<T>;
template<int = 0>
void buffer_inverses(true_type);
template<int = 0>
void buffer_inverses(false_type);
template<int = 0>
void buffer_edabits(int n_bits, true_type);
template<int = 0>
@@ -23,8 +30,14 @@ public:
}
void buffer_triples();
void buffer_inverses();
void buffer_bits();
void buffer_inputs(int player)
{
this->buffer_inputs_as_usual(player, this->proc);
}
void buffer_dabits(ThreadQueues* = 0);
void buffer_edabits(int n_bits, ThreadQueues*);
void buffer_sedabits(int n_bits, ThreadQueues*);

View File

@@ -45,6 +45,57 @@ void DealerPrep<T>::buffer_triples()
}
}
template<class T>
void DealerPrep<T>::buffer_inverses()
{
buffer_inverses(T::invertible);
}
template<class T>
template<int>
void DealerPrep<T>::buffer_inverses(false_type)
{
throw not_implemented();
}
template<class T>
template<int>
void DealerPrep<T>::buffer_inverses(true_type)
{
assert(this->proc);
auto& P = this->proc->P;
vector<bool> senders(P.num_players());
senders.back() = true;
octetStreams os(P), to_receive(P);
if (this->proc->input.is_dealer())
{
SeededPRNG G;
vector<SemiShare<typename T::clear>> shares(P.num_players() - 1);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
{
T tuple[2];
while (tuple[0] == 0)
tuple[0] = G.get<T>();
tuple[1] = tuple[0].invert();
for (auto& value : tuple)
{
make_share(shares.data(), typename T::clear(value),
P.num_players() - 1, 0, G);
for (int i = 1; i < P.num_players(); i++)
shares.at(i - 1).pack(os[i - 1]);
}
this->inverses.push_back({});
}
P.send_receive_all(senders, os, to_receive);
}
else
{
P.send_receive_all(senders, os, to_receive);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
this->inverses.push_back(to_receive.back().get<FixedVec<T, 2>>().get());
}
}
template<class T>
void DealerPrep<T>::buffer_bits()
{

View File

@@ -13,12 +13,16 @@ template<class T> class DealerPrep;
template<class T> class DealerInput;
template<class T> class DealerMC;
template<class T> class DirectDealerMC;
template<class T> class DealerMatrixPrep;
template<class T> class Hemi;
namespace GC
{
class DealerSecret;
}
template<class T> class Dealer;
template<class T>
class DealerShare : public SemiShare<T>
{
@@ -30,22 +34,26 @@ public:
typedef DealerMC<This> MAC_Check;
typedef DirectDealerMC<This> Direct_MC;
typedef Beaver<This> Protocol;
typedef Hemi<This> Protocol;
typedef DealerInput<This> Input;
typedef DealerPrep<This> LivePrep;
typedef ::PrivateOutput<This> PrivateOutput;
typedef DealerMatrixPrep<This> MatrixPrep;
typedef Dealer<This> BasicProtocol;
static false_type dishonest_majority;
const static bool needs_ot = false;
const static bool symmetric = false;
static string type_short()
{
return "DD" + string(1, T::type_char());
}
static int threshold(int)
static bool real_shares(const Player& P)
{
throw runtime_error("undefined threshold");
return P.my_num() != P.num_players() - 1;
}
static This constant(const T& other, int my_num,

View File

@@ -33,6 +33,7 @@ public:
static const bool has_trunc_pr = true;
static const bool dishonest_majority = false;
static const bool malicious = false;
static string type_short()
{

View File

@@ -13,22 +13,24 @@
* Matrix multiplication optimized with semi-homomorphic encryption
*/
template<class T>
class Hemi : public Semi<T>
class Hemi : public T::BasicProtocol
{
map<array<int, 3>, HemiMatrixPrep<T>*> matrix_preps;
map<array<int, 3>, typename T::MatrixPrep*> matrix_preps;
DataPositions matrix_usage;
MatrixMC<T> mc;
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
SubProcessor<T>& processor);
public:
Hemi(Player& P) :
Semi<T>(P)
T::BasicProtocol(P)
{
}
~Hemi();
HemiMatrixPrep<T>& get_matrix_prep(const array<int, 3>& dimensions,
typename T::MatrixPrep& get_matrix_prep(const array<int, 3>& dimensions,
SubProcessor<T>& processor);
void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,

View File

@@ -21,12 +21,12 @@ Hemi<T>::~Hemi()
}
template<class T>
HemiMatrixPrep<T>& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
typename T::MatrixPrep& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
SubProcessor<T>& processor)
{
if (matrix_preps.find(dims) == matrix_preps.end())
matrix_preps.insert({dims,
new HemiMatrixPrep<T>(dims[0], dims[1], dims[2],
new typename T::MatrixPrep(dims[0], dims[1], dims[2],
dynamic_cast<typename T::LivePrep&>(processor.DataF),
matrix_usage)});
return *matrix_preps.at(dims);
@@ -52,22 +52,27 @@ void Hemi<T>::matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
ShareMatrix<T> A(dim[0], dim[1]), B(dim[1], dim[2]);
for (int k = 0; k < dim[1]; k++)
if (not T::real_shares(processor.P))
{
for (int i = 0; i < dim[0]; i++)
matrix_multiply(A, B, processor);
return;
}
for (int i = 0; i < dim[0]; i++)
for (int k = 0; k < dim[1]; k++)
{
auto kk = Proc->get_Ci().at(dim[4] + k);
auto ii = Proc->get_Ci().at(dim[3] + i);
A[{i, k}] = source.at(a + ii * dim[7] + kk);
A.entries.v.push_back(source.at(a + ii * dim[7] + kk));
}
for (int k = 0; k < dim[1]; k++)
for (int j = 0; j < dim[2]; j++)
{
auto jj = Proc->get_Ci().at(dim[6] + j);
auto ll = Proc->get_Ci().at(dim[5] + k);
B[{k, j}] = source.at(b + ll * dim[8] + jj);
B.entries.v.push_back(source.at(b + ll * dim[8] + jj));
}
}
auto res = matrix_multiply(A, B, processor);
@@ -94,13 +99,16 @@ ShareMatrix<T> Hemi<T>::matrix_multiply(const ShareMatrix<T>& A,
subdim[1] = min(max_inner, A.n_cols - i);
subdim[2] = min(max_cols, B.n_cols - j);
auto& prep = get_matrix_prep(subdim, processor);
MatrixMC<T> mc;
beaver.init(prep, mc);
beaver.init_mul();
beaver.prepare_mul(A.from(0, i, subdim.data()),
B.from(i, j, subdim.data() + 1));
beaver.exchange();
C.add_from_col(j, beaver.finalize_mul());
bool for_real = T::real_shares(processor.P);
beaver.prepare_mul(A.from(0, i, subdim.data(), for_real),
B.from(i, j, subdim.data() + 1, for_real));
if (for_real)
{
beaver.exchange();
C.add_from_col(j, beaver.finalize_mul());
}
}
}
@@ -150,6 +158,15 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
array<int, 3> dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}});
ShareMatrix<T> A(dim[0], dim[1]), B(dim[1], dim[2]);
if (not T::real_shares(processor.P))
{
matrix_multiply(A, B, processor);
return;
}
A.entries.init();
B.entries.init();
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in;

View File

@@ -10,6 +10,7 @@
template<class T> class HemiPrep;
template<class T> class Hemi;
template<class T> class HemiMatrixPrep;
template<class T>
class HemiShare : public SemiShare<T>
@@ -26,6 +27,9 @@ public:
typedef typename conditional<T::prime_field, Hemi<This>, Beaver<This>>::type Protocol;
typedef HemiPrep<This> LivePrep;
typedef HemiMatrixPrep<This> MatrixPrep;
typedef Semi<This> BasicProtocol;
static const bool needs_ot = false;
static const bool local_mul = true;
static true_type triple_matmul;

View File

@@ -298,7 +298,8 @@ void TreeSum<T>::start(vector<T>& values, const Player& P)
{
// send from the root player
os.reset_write_head();
for (unsigned int i=0; i<values.size(); i++)
size_t n = values.size();
for (unsigned int i=0; i<n; i++)
{ values[i].pack(os); }
timers[BCAST].start();
for (int i = 1; i < max_broadcast && i < P.num_players(); i++)

View File

@@ -330,10 +330,13 @@ Direct_MAC_Check<T>::~Direct_MAC_Check() {
template<class T>
void direct_add_openings(vector<T>& values, const PlayerBase& P, vector<octetStream>& os)
{
for (unsigned int i=0; i<values.size(); i++)
for (int j=0; j<P.num_players(); j++)
if (j!=P.my_num())
values[i].add(os.at(j));
size_t n = P.num_players();
size_t me = P.my_num();
assert(os.size() == n);
for (auto& value : values)
for (size_t j = 0; j < n; j++)
if (j != me)
value += os[j].get<T>();
}
template<class T>

View File

@@ -13,6 +13,7 @@ using namespace std;
#include "Tools/PointerVector.h"
template<class T> class Preprocessing;
template<class T> class MatrixMC;
/**
* Abstract base class for opening protocols
@@ -20,6 +21,8 @@ template<class T> class Preprocessing;
template<class T>
class MAC_Check_Base
{
friend class MatrixMC<T>;
protected:
/* MAC Share */
typename T::mac_key_type::Scalar alphai;
@@ -59,6 +62,7 @@ public:
/// Get next opened value
virtual typename T::clear finalize_open();
virtual typename T::open_type finalize_raw();
array<typename T::open_type*, 2> finalize_several(size_t n);
/// Check whether all ``shares`` are ``value``
virtual void CheckFor(const typename T::open_type& value, const vector<T>& shares, const Player& P);

View File

@@ -70,6 +70,13 @@ typename T::open_type MAC_Check_Base<T>::finalize_raw()
return values.next();
}
template<class T>
array<typename T::open_type*, 2> MAC_Check_Base<T>::finalize_several(size_t n)
{
assert(values.left() >= n);
return {{values.skip(0), values.skip(n)}};
}
template<class T>
void MAC_Check_Base<T>::CheckFor(const typename T::open_type& value,
const vector<T>& shares, const Player& P)

View File

@@ -42,8 +42,12 @@ public:
typedef GC::MaliciousRepSecret bit_type;
// indicate security relevance of field size
typedef T mac_key_type;
const static bool expensive = true;
static const bool has_trunc_pr = false;
static const bool malicious = true;
static string type_short()
{

View File

@@ -160,7 +160,7 @@ template<class T>
void CommMaliciousRepMC<T>::POpen_Begin(vector<typename T::clear>& values,
const vector<T>& S, const Player& P)
{
assert(T::length == 2);
assert(T::vector_length == 2);
(void)values;
os.resize(2);
for (auto& o : os)

View File

@@ -45,6 +45,8 @@ public:
typedef GC::MaliciousCcdSecret<gf2n_short> bit_type;
#endif
static const bool malicious = true;
static string type_short()
{
return "M" + super::type_short();

View File

@@ -122,6 +122,7 @@ public:
const static bool expensive = false;
const static bool variable_players = false;
static const bool has_trunc_pr = true;
static const bool malicious = false;
static string type_short()
{

View File

@@ -37,6 +37,8 @@ public:
typedef GC::Rep4Secret bit_type;
static const bool malicious = true;
static string type_short()
{
return "R4" + string(1, T::type_char());

View File

@@ -121,6 +121,8 @@ public:
virtual void cisc(SubProcessor<T>&, const Instruction&)
{ throw runtime_error("CISC instructions not implemented"); }
virtual vector<int> get_relevant_players();
};
/**
@@ -146,7 +148,7 @@ public:
static void assign(T& share, const typename T::clear& value, int my_num)
{
assert(T::length == 2);
assert(T::vector_length == 2);
share.assign_zero();
if (my_num < 2)
share[my_num] = value;

View File

@@ -28,7 +28,7 @@ ProtocolBase<T>::ProtocolBase() :
template<class T>
Replicated<T>::Replicated(Player& P) : ReplicatedBase(P)
{
assert(T::length == 2);
assert(T::vector_length == 2);
}
template<class T>
@@ -152,6 +152,16 @@ T ProtocolBase<T>::get_random()
return res;
}
template<class T>
vector<int> ProtocolBase<T>::get_relevant_players()
{
vector<int> res;
int n = dynamic_cast<typename T::Protocol&>(*this).P.num_players();
for (int i = 0; i < T::threshold(n) + 1; i++)
res.push_back(i);
return res;
}
template<class T>
void Replicated<T>::init_mul()
{

View File

@@ -71,7 +71,7 @@ public:
ReplicatedInput(SubProcessor<T>* proc, Player& P) :
PrepLessInput<T>(proc), proc(proc), P(P), protocol(P)
{
assert(T::length == 2);
assert(T::vector_length == 2);
expect.resize(P.num_players());
this->reset_all(P);
}

View File

@@ -28,7 +28,7 @@ void ReplicatedMC<T>::POpen_Begin(vector<typename T::open_type>&,
template<class T>
void ReplicatedMC<T>::prepare(const vector<T>& S)
{
assert(T::length == 2);
assert(T::vector_length == 2);
o.reset_write_head();
to_send.reset_write_head();
to_send.reserve(S.size() * T::value_type::size());

53
Protocols/SecureShuffle.h Normal file
View File

@@ -0,0 +1,53 @@
/*
* SecureShuffle.h
*
*/
#ifndef PROTOCOLS_SECURESHUFFLE_H_
#define PROTOCOLS_SECURESHUFFLE_H_
#include <vector>
using namespace std;
template<class T> class SubProcessor;
template<class T>
class SecureShuffle
{
SubProcessor<T>& proc;
vector<T> to_shuffle;
vector<vector<T>> config;
vector<T> tmp;
int unit_size;
vector<vector<vector<vector<T>>>> shuffles;
size_t n_shuffle;
bool exact;
void player_round(int config_player);
void generate(int config_player, int n_shuffle);
void waksman(vector<T>& a, int depth, int start);
void cond_swap(T& x, T& y, const T& b);
void iter_waksman(bool reverse = false);
void waksman_round(int size, bool inwards, bool reverse);
void pre(vector<T>& a, size_t n, size_t input_base);
void post(vector<T>& a, size_t n, size_t input_base);
public:
SecureShuffle(vector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, SubProcessor<T>& proc);
SecureShuffle(SubProcessor<T>& proc);
int generate(int n_shuffle);
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int handle, bool reverse);
void del(int handle);
};
#endif /* PROTOCOLS_SECURESHUFFLE_H_ */

328
Protocols/SecureShuffle.hpp Normal file
View File

@@ -0,0 +1,328 @@
/*
* SecureShuffle.hpp
*
*/
#ifndef PROTOCOLS_SECURESHUFFLE_HPP_
#define PROTOCOLS_SECURESHUFFLE_HPP_
#include "SecureShuffle.h"
#include "Tools/Waksman.h"
#include <math.h>
#include <algorithm>
template<class T>
SecureShuffle<T>::SecureShuffle(SubProcessor<T>& proc) :
proc(proc), unit_size(0), n_shuffle(0), exact(false)
{
}
template<class T>
SecureShuffle<T>::SecureShuffle(vector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, SubProcessor<T>& proc) :
proc(proc), unit_size(unit_size)
{
pre(a, n, input_base);
for (auto i : proc.protocol.get_relevant_players())
player_round(i);
post(a, n, output_base);
}
template<class T>
void SecureShuffle<T>::apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int handle, bool reverse)
{
this->unit_size = unit_size;
pre(a, n, input_base);
auto& shuffle = shuffles.at(handle);
assert(shuffle.size() == proc.protocol.get_relevant_players().size());
if (reverse)
for (auto it = shuffle.end(); it > shuffle.begin(); it--)
{
this->config = *(it - 1);
iter_waksman(reverse);
}
else
for (auto& config : shuffle)
{
this->config = config;
iter_waksman(reverse);
}
post(a, n, output_base);
}
template<class T>
void SecureShuffle<T>::del(int handle)
{
shuffles.at(handle).clear();
}
template<class T>
void SecureShuffle<T>::pre(vector<T>& a, size_t n, size_t input_base)
{
n_shuffle = n / unit_size;
assert(unit_size * n_shuffle == n);
size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle))));
exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious;
to_shuffle.clear();
if (exact)
{
to_shuffle.resize(n_shuffle_pow2 * unit_size);
for (size_t i = 0; i < n; i++)
to_shuffle[i] = a[input_base + i];
}
else
{
// sorting power of two elements together with indicator bits
to_shuffle.resize((unit_size + 1) << int(ceil(log2(n_shuffle))));
for (size_t i = 0; i < n_shuffle; i++)
{
for (int j = 0; j < unit_size; j++)
to_shuffle[i * (unit_size + 1) + j] = a[input_base
+ i * unit_size + j];
to_shuffle[i * (unit_size + 1) + unit_size] = T::constant(1,
proc.P.my_num(), proc.MC.get_alphai());
}
this->unit_size++;
}
}
template<class T>
void SecureShuffle<T>::post(vector<T>& a, size_t n, size_t output_base)
{
if (exact)
for (size_t i = 0; i < n; i++)
a[output_base + i] = to_shuffle[i];
else
{
auto& MC = proc.MC;
MC.init_open(proc.P);
int shuffle_unit_size = this->unit_size;
int unit_size = shuffle_unit_size - 1;
for (size_t i = 0; i < to_shuffle.size() / shuffle_unit_size; i++)
MC.prepare_open(to_shuffle.at((i + 1) * shuffle_unit_size - 1));
MC.exchange(proc.P);
size_t i_shuffle = 0;
for (size_t i = 0; i < n_shuffle; i++)
{
auto bit = MC.finalize_open();
if (bit == 1)
{
// only output real elements
for (int j = 0; j < unit_size; j++)
a.at(output_base + i_shuffle * unit_size + j) =
to_shuffle.at(i * shuffle_unit_size + j);
i_shuffle++;
}
}
if (i_shuffle != n_shuffle)
throw runtime_error("incorrect shuffle");
}
}
template<class T>
void SecureShuffle<T>::player_round(int config_player)
{
generate(config_player, n_shuffle);
iter_waksman();
}
template<class T>
int SecureShuffle<T>::generate(int n_shuffle)
{
int res = shuffles.size();
shuffles.push_back({});
auto& shuffle = shuffles.back();
for (auto i : proc.protocol.get_relevant_players())
{
generate(i, n_shuffle);
shuffle.push_back(config);
}
return res;
}
template<class T>
void SecureShuffle<T>::generate(int config_player, int n)
{
auto& P = proc.P;
auto& input = proc.input;
input.reset_all(P);
int n_pow2 = 1 << int(ceil(log2(n)));
Waksman waksman(n_pow2);
if (P.my_num() == config_player)
{
vector<int> perm;
int shuffle_size = n;
for (int j = 0; j < n_pow2; j++)
perm.push_back(j);
SeededPRNG G;
for (int i = 0; i < shuffle_size; i++)
{
int j = G.get_uint(shuffle_size - i);
swap(perm[i], perm[i + j]);
}
auto config_bits = waksman.configure(perm);
for (size_t i = 0; i < config_bits.size(); i++)
{
auto& x = config_bits[i];
for (size_t j = 0; j < x.size(); j++)
if (waksman.matters(i, j))
input.add_mine(int(x[j]));
else
assert(x[j] == 0);
}
}
else
for (size_t i = 0; i < waksman.n_bits(); i++)
input.add_other(config_player);
input.exchange();
config.clear();
typename T::Protocol checker(P);
checker.init(proc.DataF, proc.MC);
checker.init_dotprod();
auto one = T::constant(1, P.my_num(), proc.MC.get_alphai());
for (size_t i = 0; i < waksman.n_rounds(); i++)
{
config.push_back({});
for (int j = 0; j < n_pow2; j++)
{
if (waksman.matters(i, j))
{
config.back().push_back(input.finalize(config_player));
if (T::malicious)
checker.prepare_dotprod(config.back().back(),
one - config.back().back());
}
else
config.back().push_back({});
}
}
if (T::malicious)
{
checker.next_dotprod();
checker.exchange();
assert(
typename T::clear(
proc.MC.open(checker.finalize_dotprod(waksman.n_bits()),
P)) == 0);
checker.check();
}
}
template<class T>
void SecureShuffle<T>::waksman(vector<T>& a, int depth, int start)
{
int n = a.size();
if (n == 2)
{
cond_swap(a[0], a[1], config.at(depth).at(start));
return;
}
vector<T> a0(n / 2), a1(n / 2);
for (int i = 0; i < n / 2; i++)
{
a0.at(i) = a.at(2 * i);
a1.at(i) = a.at(2 * i + 1);
cond_swap(a0[i], a1[i], config.at(depth).at(i + start + n / 2));
}
waksman(a0, depth + 1, start);
waksman(a1, depth + 1, start + n / 2);
for (int i = 0; i < n / 2; i++)
{
a.at(2 * i) = a0.at(i);
a.at(2 * i + 1) = a1.at(i);
cond_swap(a[2 * i], a[2 * i + 1], config.at(depth).at(i + start));
}
}
template<class T>
void SecureShuffle<T>::cond_swap(T& x, T& y, const T& b)
{
auto diff = proc.protocol.mul(x - y, b);
x -= diff;
y += diff;
}
template<class T>
void SecureShuffle<T>::iter_waksman(bool reverse)
{
int n = to_shuffle.size() / unit_size;
for (int depth = 0; depth < log2(n); depth++)
waksman_round(depth, true, reverse);
for (int depth = log2(n) - 2; depth >= 0; depth--)
waksman_round(depth, false, reverse);
}
template<class T>
void SecureShuffle<T>::waksman_round(int depth, bool inwards, bool reverse)
{
int n = to_shuffle.size() / unit_size;
assert((int) config.at(depth).size() == n);
int nblocks = 1 << depth;
int size = n / (2 * nblocks);
bool outwards = !inwards;
proc.protocol.init_mul();
vector<array<int, 5>> indices;
indices.reserve(n / 2);
Waksman waksman(n);
for (int k = 0; k < n / 2; k++)
{
int j = k % size;
int i = k / size;
int base = 2 * i * size;
int in1 = base + j + j * inwards;
int in2 = in1 + inwards + size * outwards;
int out1 = base + j + j * outwards;
int out2 = out1 + outwards + size * inwards;
int i_bit = base + j + size * (outwards ^ reverse);
bool run = waksman.matters(depth, i_bit);
if (run)
{
for (int l = 0; l < unit_size; l++)
proc.protocol.prepare_mul(config.at(depth).at(i_bit),
to_shuffle.at(in1 * unit_size + l)
- to_shuffle.at(in2 * unit_size + l));
}
indices.push_back({{in1, in2, out1, out2, run}});
}
proc.protocol.exchange();
tmp.resize(to_shuffle.size());
for (int k = 0; k < n / 2; k++)
{
auto idx = indices.at(k);
for (int l = 0; l < unit_size; l++)
{
T diff;
if (idx[4])
diff = proc.protocol.finalize_mul();
tmp.at(idx[2] * unit_size + l) = to_shuffle.at(
idx[0] * unit_size + l) - diff;
tmp.at(idx[3] * unit_size + l) = to_shuffle.at(
idx[1] * unit_size + l) + diff;
}
}
swap(tmp, to_shuffle);
}
#endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */

View File

@@ -78,6 +78,7 @@ public:
const static bool variable_players = true;
const static bool expensive = false;
static const bool has_trunc_pr = true;
static const bool malicious = false;
static string type_short() { return "D" + string(1, T::type_char()); }

View File

@@ -49,6 +49,7 @@ public:
const static bool dishonest_majority = false;
const static bool variable_players = true;
const static bool expensive = false;
const static bool malicious = true;
static string type_short()
{

View File

@@ -56,6 +56,7 @@ class Share_ : public ShareInterface
const static bool dishonest_majority = T::dishonest_majority;
const static bool variable_players = T::variable_players;
const static bool has_mac = true;
static const bool malicious = true;
static int size()
{ return T::size() + V::size(); }

View File

@@ -40,12 +40,17 @@ public:
static const bool has_trunc_pr = false;
static const bool has_split = false;
static const bool has_mac = false;
static const bool malicious = false;
static const false_type triple_matmul;
const static bool symmetric = true;
static const int default_length = 1;
static string type_short() { return "undef"; }
static string type_short() { throw runtime_error("don't call this"); }
static bool real_shares(const Player&) { return true; }
template<class T, class U>
static void split(vector<U>, vector<int>, int, T*, int,
@@ -63,6 +68,8 @@ public:
template<class T, class U>
static void generate_mac_key(T&, U&) {}
static int threshold(int) { throw runtime_error("undefined threshold"); }
};
#endif /* PROTOCOLS_SHAREINTERFACE_H_ */

View File

@@ -14,6 +14,124 @@ using namespace std;
template<class T> class MatrixMC;
template<class T>
class NonInitVector
{
template<class U> friend class NonInitVector;
size_t size_;
public:
AddableVector<T> v;
NonInitVector(size_t size) :
size_(size)
{
v.reserve(size);
}
template<class U>
NonInitVector(const NonInitVector<U>& other) :
size_(other.size()), v(other.v)
{
}
size_t size() const
{
return size_;
}
void init()
{
v.resize(size_);
}
void check() const
{
#ifdef DEBUG_MATRIX
assert(not v.empty());
#endif
}
typename vector<T>::iterator begin()
{
check();
return v.begin();
}
typename vector<T>::iterator end()
{
check();
return v.end();
}
T& at(size_t index)
{
check();
return v.at(index);
}
const T& at(size_t index) const
{
#ifdef DEBUG_MATRIX
assert(index < size());
#endif
return (*this)[index];
}
T& operator[](size_t index)
{
check();
return v[index];
}
const T& operator[](size_t index) const
{
check();
return v[index];
}
NonInitVector operator-(const NonInitVector& other) const
{
assert(size() == other.size());
NonInitVector res(size());
if (other.v.empty())
return *this;
else if (v.empty())
{
res.init();
res.v = res.v - other.v;
}
else
res.v = v - other.v;
return res;
}
NonInitVector& operator+=(const NonInitVector& other)
{
assert(size() == other.size());
if (not other.v.empty())
{
if (v.empty())
*this = other;
else
v += other.v;
}
return *this;
}
bool operator!=(const NonInitVector& other) const
{
return v != other.v;
}
void randomize(PRNG& G)
{
v.clear();
for (size_t i = 0; i < size(); i++)
v.push_back(G.get<T>());
}
};
template<class T>
class ValueMatrix : public ValueInterface
{
@@ -21,7 +139,7 @@ class ValueMatrix : public ValueInterface
public:
int n_rows, n_cols;
AddableVector<T> entries;
NonInitVector<T> entries;
static DataFieldType field_type()
{
@@ -48,15 +166,19 @@ public:
T& operator[](const pair<int, int>& indices)
{
#ifdef DEBUG_MATRIX
assert(indices.first < n_rows);
assert(indices.second < n_cols);
#endif
return entries.at(indices.first * n_cols + indices.second);
}
const T& operator[](const pair<int, int>& indices) const
{
#ifdef DEBUG_MATRIX
assert(indices.first < n_rows);
assert(indices.second < n_cols);
#endif
return entries.at(indices.first * n_cols + indices.second);
}
@@ -80,6 +202,9 @@ public:
{
assert(n_cols == other.n_rows);
This res(n_rows, other.n_cols);
if (entries.v.empty() or other.entries.v.empty())
return res;
res.entries.init();
for (int i = 0; i < n_rows; i++)
for (int j = 0; j < other.n_cols; j++)
for (int k = 0; k < n_cols; k++)
@@ -103,9 +228,9 @@ public:
ValueMatrix transpose() const
{
ValueMatrix res(this->n_cols, this->n_rows);
for (int i = 0; i < this->n_rows; i++)
for (int j = 0; j < this->n_cols; j++)
res[{j, i}] = (*this)[{i, j}];
for (int j = 0; j < this->n_cols; j++)
for (int i = 0; i < this->n_rows; i++)
res.entries.v.push_back((*this)[{i, j}]);
return res;
}
@@ -139,7 +264,7 @@ public:
{
This res(other.n_rows, other.n_cols);
for (size_t i = 0; i < other.entries.size(); i++)
res.entries[i] = T::constant(other.entries[i], my_num, key);
res.entries.v.push_back(T::constant(other.entries[i], my_num, key));
res.check();
return res;
}
@@ -167,24 +292,29 @@ public:
ShareMatrix from_col(int start, int size) const
{
ShareMatrix res(this->n_rows, min(size, this->n_cols - start));
res.entries.clear();
for (int i = 0; i < res.n_rows; i++)
for (int j = 0; j < res.n_cols; j++)
res[{i, j}] = (*this)[{i, start + j}];
res.entries.v.push_back((*this)[{i, start + j}]);
return res;
}
ShareMatrix from(int start_row, int start_col, int* sizes) const
ShareMatrix from(int start_row, int start_col, int* sizes, bool for_real =
true) const
{
ShareMatrix res(min(sizes[0], this->n_rows - start_row),
min(sizes[1], this->n_cols - start_col));
if (not for_real)
return res;
for (int i = 0; i < res.n_rows; i++)
for (int j = 0; j < res.n_cols; j++)
res[{i, j}] = (*this)[{start_row + i, start_col + j}];
res.entries.v.push_back((*this)[{start_row + i, start_col + j}]);
return res;
}
void add_from_col(int start, const ShareMatrix& other)
{
this->entries.init();
for (int i = 0; i < this->n_rows; i++)
for (int j = 0; j < other.n_cols; j++)
(*this)[{i, start + j}] += other[{i, j}];
@@ -197,6 +327,9 @@ ShareMatrix<T> operator*(const ValueMatrix<typename T::clear>& a,
{
assert(a.n_cols == b.n_rows);
ShareMatrix<T> res(a.n_rows, b.n_cols);
if (a.entries.v.empty() or b.entries.v.empty())
return res;
res.entries.init();
for (int i = 0; i < a.n_rows; i++)
for (int j = 0; j < b.n_cols; j++)
for (int k = 0; k < a.n_cols; k++)
@@ -208,9 +341,22 @@ ShareMatrix<T> operator*(const ValueMatrix<typename T::clear>& a,
template<class T>
class MatrixMC : public MAC_Check_Base<ShareMatrix<T>>
{
typename T::MAC_Check inner;
typename T::MAC_Check& inner;
public:
MatrixMC() :
inner(
*(OnlineOptions::singleton.direct ?
new typename T::Direct_MC :
new typename T::MAC_Check))
{
}
~MatrixMC()
{
delete &inner;
}
void exchange(const Player& P)
{
inner.init_open(P);
@@ -224,8 +370,15 @@ public:
for (auto& share : this->secrets)
{
this->values.push_back({share.n_rows, share.n_cols});
for (auto& entry : this->values.back().entries)
entry = inner.finalize_open();
if (share.entries.v.empty())
for (size_t i = 0; i < share.entries.size(); i++)
inner.finalize_open();
else
{
auto range = inner.finalize_several(share.entries.size());
auto& v = this->values.back().entries.v;
v.insert(v.begin(), range[0], range[1]);
}
}
}
};

View File

@@ -25,6 +25,9 @@ public:
typedef typename conditional<T::prime_field, Hemi<This>, Beaver<This>>::type Protocol;
typedef TemiPrep<This> LivePrep;
typedef HemiMatrixPrep<This> MatrixPrep;
typedef Semi<This> BasicProtocol;
static const bool needs_ot = false;
static const bool local_mul = false;

View File

@@ -130,7 +130,6 @@ template<class T, class U>
void make_share(DealerShare<T>* Sa, const T& a, int N, const U&, PRNG& G)
{
make_share((SemiShare<T>*) Sa, a, N - 1, U(), G);
Sa[N - 1] = {};
}
template<class T, class U, class V>
@@ -273,6 +272,11 @@ inline string mac_filename(string directory, int playerno)
+ to_string(playerno);
}
template <>
inline void write_mac_key(const string&, int, int, GC::NoValue)
{
}
template <class U>
void write_mac_key(const string& directory, int i, int nplayers, U key)
{
@@ -301,6 +305,11 @@ void read_mac_key(const string& directory, const Names& N, T& key)
read_mac_key(directory, N.my_num(), N.num_players(), key);
}
template <>
inline void read_mac_key(const string&, int, int, GC::NoValue&)
{
}
template <class U>
void read_mac_key(const string& directory, int player_num, int nplayers, U& key)
{
@@ -367,7 +376,7 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P,
}
template <class U>
void read_global_mac_key(const string& directory, int nparties, U& key, false_type)
void read_global_mac_key(const string& directory, int nparties, U& key)
{
U pp;
key.assign_zero();
@@ -383,17 +392,11 @@ void read_global_mac_key(const string& directory, int nparties, U& key, false_ty
cout << "Final Keys : " << key << endl;
}
template <class U>
void read_global_mac_key(const string&, int, U&, true_type)
template <>
inline void read_global_mac_key(const string&, int, GC::NoValue&)
{
}
template <class U>
void read_global_mac_key(const string& directory, int nparties, U& key)
{
read_global_mac_key(directory, nparties, key, is_same<U, GC::NoValue>());
}
template <class T>
T reconstruct(vector<T>& shares)
{
@@ -579,14 +582,14 @@ void plain_edabits(vector<typename T::clear>& as,
as.resize(max_size);
bs.clear();
bs.resize(length);
bigint value;
Z2<T::clear::MAX_EDABITS> value;
for (int j = 0; j < max_size; j++)
{
if (not zero)
G.get_bigint(value, length, true);
value.randomize_part(G, length);
as[j] = value;
for (int k = 0; k < length; k++)
bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j;
bs[k] ^= BitVec(value.get_bit(k)) << j;
}
}

View File

@@ -101,8 +101,9 @@ The following table lists all protocols that are fully supported.
| Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) |
| Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A |
| Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) |
| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) |
| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep3 / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) |
| Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) |
| Malicious, honest supermajority | [Rep4](#honest-majority) | [Rep4](#honest-majority) | [Rep4](#honest-majority) | N/A |
| Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A |
Modulo prime and modulo 2^k are the two settings that allow
@@ -280,6 +281,8 @@ compute the preprocessing time for a particular computation.
- Python 3.5 or later
- NTL library for homomorphic encryption (optional; tested with NTL 10.5)
- If using macOS, Sierra or later
- Windows/VirtualBox: see [this
issue](https://github.com/data61/MP-SPDZ/issues/557) for a discussion
#### Compilation

View File

@@ -84,7 +84,9 @@ not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& fil
{
}
gf2n_not_supported::gf2n_not_supported(int n) :
runtime_error("GF(2^" + to_string(n) + ") not supported")
gf2n_not_supported::gf2n_not_supported(int n, string options) :
runtime_error(
"GF(2^" + to_string(n) + ") not supported"
+ (options.empty() ? "" : ", options are " + options))
{
}

View File

@@ -281,7 +281,7 @@ public:
class gf2n_not_supported : public runtime_error
{
public:
gf2n_not_supported(int n);
gf2n_not_supported(int n, string options = "");
};
#endif

Some files were not shown because too many files have changed in this diff Show More