mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Machine learning functionality, dishonest-majority binary secret sharing.
This commit is contained in:
@@ -1,5 +1,12 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.1.2
|
||||
|
||||
- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission
|
||||
- Binary computation for dishonest majority using secret sharing
|
||||
- Mathematical functions from [SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA)
|
||||
- Fixed security bug: CowGear would reuse triples.
|
||||
|
||||
## 0.1.1 (Aug 6, 2019)
|
||||
|
||||
- ECDSA
|
||||
|
||||
@@ -101,12 +101,12 @@ def determine_scope(block, options):
|
||||
used_from_scope = set()
|
||||
|
||||
def find_in_scope(reg, scope):
|
||||
if scope is None:
|
||||
return False
|
||||
elif reg in scope.defined_registers:
|
||||
return True
|
||||
else:
|
||||
return find_in_scope(reg, scope.scope)
|
||||
while True:
|
||||
if scope is None:
|
||||
return False
|
||||
elif reg in scope.defined_registers:
|
||||
return True
|
||||
scope = scope.scope
|
||||
|
||||
def read(reg, n):
|
||||
if last_def[reg] == -1:
|
||||
@@ -386,7 +386,7 @@ class Merger:
|
||||
last_print_str = None
|
||||
last = defaultdict(lambda: defaultdict(lambda: None))
|
||||
last_open = deque()
|
||||
last_text_input = None
|
||||
last_text_input = [None, None]
|
||||
|
||||
depths = [0] * len(block.instructions)
|
||||
self.depths = depths
|
||||
@@ -474,10 +474,14 @@ class Merger:
|
||||
|
||||
# will be merged
|
||||
if isinstance(instr, TextInputInstruction):
|
||||
if last_text_input is not None and \
|
||||
type(block.instructions[last_text_input]) is not type(instr):
|
||||
add_edge(last_text_input, n)
|
||||
last_text_input = n
|
||||
if last_text_input[0] is not None:
|
||||
if instr.merge_id() != \
|
||||
block.instructions[last_text_input[0]].merge_id():
|
||||
add_edge(last_text_input[0], n)
|
||||
last_text_input[1] = last_text_input[0]
|
||||
elif last_text_input[1] is not None:
|
||||
add_edge(last_text_input[1], n)
|
||||
last_text_input[0] = n
|
||||
|
||||
if isinstance(instr, merge_classes):
|
||||
open_nodes.add(n)
|
||||
|
||||
@@ -80,6 +80,12 @@ def LTZ(s, a, k, kappa):
|
||||
Trunc(t, a, k, k - 1, kappa, True)
|
||||
subsfi(s, t, 0)
|
||||
|
||||
def LessThanZero(a, k, kappa):
|
||||
import types
|
||||
res = types.sint()
|
||||
LTZ(res, a, k, kappa)
|
||||
return res
|
||||
|
||||
def Trunc(d, a, k, m, kappa, signed):
|
||||
"""
|
||||
d = a >> m
|
||||
@@ -153,6 +159,8 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
k: bit length of a
|
||||
m: compile-time integer
|
||||
"""
|
||||
if m == 0:
|
||||
return a
|
||||
if k == int(program.options.ring):
|
||||
# cannot work with bit length k+1
|
||||
tmp = TruncRing(None, a, k, m - 1, signed)
|
||||
@@ -359,7 +367,7 @@ def CarryOutAux(d, a, kappa):
|
||||
movs(d, a[0][1])
|
||||
|
||||
# carry out with carry-in bit c
|
||||
def CarryOut(res, a, b, c, kappa):
|
||||
def CarryOut(res, a, b, c=0, kappa=None):
|
||||
"""
|
||||
res = last carry bit in addition of a and b
|
||||
|
||||
@@ -368,8 +376,9 @@ def CarryOut(res, a, b, c, kappa):
|
||||
c: initial carry-in bit
|
||||
"""
|
||||
k = len(a)
|
||||
import types
|
||||
d = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
|
||||
t = [[types.sint() for i in range(k)] for i in range(4)]
|
||||
s = [program.curr_block.new_reg('s') for i in range(3)]
|
||||
for i in range(k):
|
||||
mulm(t[0][i], b[i], a[i])
|
||||
@@ -377,12 +386,19 @@ def CarryOut(res, a, b, c, kappa):
|
||||
addm(t[2][i], b[i], a[i])
|
||||
subs(t[3][i], t[2][i], t[1][i])
|
||||
d[i] = [t[3][i], t[0][i]]
|
||||
mulsi(s[0], d[-1][0], c)
|
||||
adds(s[1], d[-1][1], s[0])
|
||||
s[0] = d[-1][0] * c
|
||||
s[1] = d[-1][1] + s[0]
|
||||
d[-1][1] = s[1]
|
||||
|
||||
CarryOutAux(res, d[::-1], kappa)
|
||||
|
||||
def CarryOutLE(a, b, c=0):
|
||||
""" Little-endian version """
|
||||
import types
|
||||
res = types.sint()
|
||||
CarryOut(res, a[::-1], b[::-1], c)
|
||||
return res
|
||||
|
||||
def BitLTL(res, a, b, kappa):
|
||||
"""
|
||||
res = a <? b (logarithmic rounds version)
|
||||
|
||||
@@ -47,7 +47,10 @@ COST = { 'modp': defaultdict(lambda: 0,
|
||||
'bittriple': 0.00004828818388140422,
|
||||
'bitgf2ntriple': 0.00020716801325875284,
|
||||
'PreMulC': 2 * 0.00020716801325875284,
|
||||
})
|
||||
}),
|
||||
'all': { 'round': 0,
|
||||
'inv': 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -325,7 +325,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
def Pow2(a, l, kappa):
|
||||
m = int(ceil(log(l, 2)))
|
||||
t = BitDec(a, m, m, kappa)
|
||||
x = [types.sint() for i in range(m)]
|
||||
return Pow2_from_bits(t)
|
||||
|
||||
def Pow2_from_bits(bits):
|
||||
m = len(bits)
|
||||
t = list(bits)
|
||||
pow2k = [types.cint() for i in range(m)]
|
||||
for i in range(m):
|
||||
pow2k[i] = two_power(2**i)
|
||||
@@ -353,13 +357,20 @@ def B2U_from_Pow2(pow2a, l, kappa):
|
||||
#print ' '.join(str(b.value) for b in y)
|
||||
return [1 - y[i] for i in range(l)]
|
||||
|
||||
def Trunc(a, l, m, kappa, compute_modulo=False):
|
||||
def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
|
||||
""" Oblivious truncation by secret m """
|
||||
if util.is_constant(m) and not compute_modulo:
|
||||
# cheaper
|
||||
res = type(a)(size=a.size)
|
||||
comparison.Trunc(res, a, l, m, kappa, signed=signed)
|
||||
return res
|
||||
if l == 1:
|
||||
if compute_modulo:
|
||||
return a * m, 1 + m
|
||||
else:
|
||||
return a * (1 - m)
|
||||
if program.Program.prog.options.ring and not compute_modulo:
|
||||
return TruncInRing(a, l, Pow2(m, l, kappa))
|
||||
r = [types.sint() for i in range(l)]
|
||||
r_dprime = types.sint(0)
|
||||
r_prime = types.sint(0)
|
||||
@@ -370,8 +381,6 @@ def Trunc(a, l, m, kappa, compute_modulo=False):
|
||||
x, pow2m = B2U(m, l, kappa)
|
||||
#assert(pow2m.value == 2**m.value)
|
||||
#assert(sum(b.value for b in x) == m.value)
|
||||
if program.Program.prog.options.ring and not compute_modulo:
|
||||
return TruncInRing(a, l, pow2m)
|
||||
for i in range(l):
|
||||
bit(r[i])
|
||||
t1 = two_power(i) * r[i]
|
||||
@@ -495,17 +504,28 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
|
||||
else:
|
||||
from types import sint
|
||||
# extra bit to mask overflow
|
||||
r_bits = [sint.get_random_bit() for i in range(k + 1)]
|
||||
n_shift = n_ring - len(r_bits)
|
||||
tmp = a + sint.bit_compose(r_bits)
|
||||
masked = (tmp << n_shift).reveal()
|
||||
shifted = (masked << 1 >> (n_shift + m + 1))
|
||||
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
|
||||
res = shifted - sint.bit_compose(r_bits[m:k]) + (overflow << (k - m))
|
||||
if signed:
|
||||
a += (1 << (k - 1))
|
||||
if program.Program.prog.use_trunc_pr:
|
||||
res = sint()
|
||||
trunc_pr(res, a, k, m)
|
||||
else:
|
||||
# extra bit to mask overflow
|
||||
r_bits = [sint.get_random_bit() for i in range(k + 1)]
|
||||
n_shift = n_ring - len(r_bits)
|
||||
tmp = a + sint.bit_compose(r_bits)
|
||||
masked = (tmp << n_shift).reveal()
|
||||
shifted = (masked << 1 >> (n_shift + m + 1))
|
||||
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
|
||||
res = shifted - sint.bit_compose(r_bits[m:k]) + \
|
||||
(overflow << (k - m))
|
||||
if signed:
|
||||
res -= (1 << (k - m - 1))
|
||||
return res
|
||||
|
||||
def TruncPrField(a, k, m, kappa=None):
|
||||
if m == 0:
|
||||
return a
|
||||
if kappa is None:
|
||||
kappa = 40
|
||||
|
||||
@@ -527,19 +547,24 @@ def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest)
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
|
||||
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
|
||||
for i in range(theta-1):
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
|
||||
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
|
||||
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest)
|
||||
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest,
|
||||
signed=False)
|
||||
y = y.round(2 * l + 1, l + 1, kappa, round_nearest, signed=False)
|
||||
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
|
||||
signed=False)
|
||||
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
|
||||
signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
|
||||
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest, signed=False)
|
||||
y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
|
||||
return y
|
||||
|
||||
|
||||
@@ -894,6 +894,55 @@ class inputfloat(base.TextInputInstruction):
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
4 * self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
class inputmixed(base.TextInputInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['INPUTMIXED']
|
||||
field_type = 'modp'
|
||||
# the following has to match TYPE: (N_DEST, N_PARAM)
|
||||
types = {
|
||||
0: (1, 0),
|
||||
1: (1, 1),
|
||||
2: (4, 1)
|
||||
}
|
||||
type_ids = {
|
||||
'int': 0,
|
||||
'fix': 1,
|
||||
'float': 2
|
||||
}
|
||||
|
||||
def __init__(self, name, *args):
|
||||
try:
|
||||
type_id = self.type_ids[name]
|
||||
except:
|
||||
pass
|
||||
super(inputmixed_class, self).__init__(type_id, *args)
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
for i in self.bases():
|
||||
t = self.args[i]
|
||||
yield 'int'
|
||||
for j in range(self.types[t][0]):
|
||||
yield 'sw'
|
||||
for j in range(self.types[t][1]):
|
||||
yield 'int'
|
||||
yield 'p'
|
||||
|
||||
def bases(self):
|
||||
i = 0
|
||||
while i < len(self.args):
|
||||
yield i
|
||||
i += sum(self.types[self.args[i]]) + 2
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for i in self.bases():
|
||||
t = self.args[i]
|
||||
player = self.args[i + sum(self.types[t]) + 1]
|
||||
n_dest = self.types[t][0]
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
n_dest * self.get_size())
|
||||
|
||||
@base.gf2n
|
||||
class startinput(base.RawInputInstruction):
|
||||
r""" Receive inputs from player $p$. """
|
||||
@@ -957,6 +1006,11 @@ class print_reg_plain(base.IOInstruction):
|
||||
code = base.opcodes['PRINTREGPLAIN']
|
||||
arg_format = ['c']
|
||||
|
||||
class cond_print_plain(base.IOInstruction):
|
||||
r""" Conditionally print the value of a register. """
|
||||
code = base.opcodes['CONDPRINTPLAIN']
|
||||
arg_format = ['c', 'c']
|
||||
|
||||
class print_int(base.IOInstruction):
|
||||
r""" Print only the value of register \verb|ci| to stdout. """
|
||||
__slots__ = []
|
||||
@@ -1383,6 +1437,9 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
|
||||
|
||||
def merge_id(self):
|
||||
# can merge different sizes
|
||||
# but not if large
|
||||
if self.get_size() > 100:
|
||||
return type(self), self.get_size()
|
||||
return type(self)
|
||||
|
||||
# def expand(self):
|
||||
@@ -1468,6 +1525,14 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
for reg in self.args[i + 2:i + self.args[i]]:
|
||||
yield reg
|
||||
|
||||
@base.vectorize
|
||||
class trunc_pr(base.VarArgsInstruction):
|
||||
""" Probalistic truncation for semi-honest computation """
|
||||
""" with honest majority """
|
||||
__slots__ = []
|
||||
code = base.opcodes['TRUNC_PR']
|
||||
arg_format = tools.cycle(['sw','s','int','int'])
|
||||
|
||||
###
|
||||
### CISC-style instructions
|
||||
###
|
||||
|
||||
@@ -89,6 +89,7 @@ opcodes = dict(
|
||||
MULS = 0xA6,
|
||||
MULRS = 0xA7,
|
||||
DOTPRODS = 0xA8,
|
||||
TRUNC_PR = 0xA9,
|
||||
# Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
@@ -102,6 +103,7 @@ opcodes = dict(
|
||||
INPUT = 0x60,
|
||||
INPUTFIX = 0xF0,
|
||||
INPUTFLOAT = 0xF1,
|
||||
INPUTMIXED = 0xF2,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
@@ -168,6 +170,7 @@ opcodes = dict(
|
||||
READFILESHARE = 0xBE,
|
||||
CONDPRINTSTR = 0xBF,
|
||||
PRINTFLOATPREC = 0xE0,
|
||||
CONDPRINTPLAIN = 0xE1,
|
||||
GBITDEC = 0x184,
|
||||
GBITCOM = 0x185,
|
||||
# Secure socket
|
||||
@@ -767,21 +770,6 @@ class ClearShiftInstruction(ClearImmediate):
|
||||
### Jumps etc
|
||||
###
|
||||
|
||||
class dummywrite(Instruction):
|
||||
""" Dummy instruction to create source node in the dependency graph,
|
||||
preventing read-before-write warnings. """
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.arg_format = [arg.reg_type + 'w' for arg in args]
|
||||
super(dummywrite, self).__init__(*args, **kwargs)
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
def get_encoding(self):
|
||||
return []
|
||||
|
||||
class JumpInstruction(Instruction):
|
||||
__slots__ = ['jump_arg']
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from Compiler import instructions,instructions_base,comparison,program,util
|
||||
import inspect,math
|
||||
import random
|
||||
import collections
|
||||
import operator
|
||||
|
||||
def get_program():
|
||||
return instructions.program
|
||||
@@ -93,16 +94,25 @@ def print_ln(s='', *args):
|
||||
print_str(s, *args)
|
||||
print_char('\n')
|
||||
|
||||
def print_ln_if(cond, s):
|
||||
def print_ln_if(cond, ss, *args):
|
||||
if util.is_constant(cond):
|
||||
if cond:
|
||||
print_ln(s)
|
||||
print_ln(ss, *args)
|
||||
else:
|
||||
s += ' ' * ((len(s) + 3) % 4)
|
||||
s += '\n'
|
||||
while s:
|
||||
cond.print_if(s[:4])
|
||||
s = s[4:]
|
||||
subs = ss.split('%s')
|
||||
assert len(subs) == len(args) + 1
|
||||
cond = cint.conv(cond)
|
||||
for i, s in enumerate(subs):
|
||||
if i != 0:
|
||||
cond_print_plain(cond, cint.conv(args[i - 1]))
|
||||
if i < len(args):
|
||||
s += ' ' * ((-len(s)) % 4)
|
||||
else:
|
||||
s += ' ' * ((-len(s) + 3) % 4)
|
||||
s += '\n'
|
||||
while s:
|
||||
cond.print_if(s[:4])
|
||||
s = s[4:]
|
||||
|
||||
def print_float_precision(n):
|
||||
print_float_prec(n)
|
||||
@@ -798,19 +808,23 @@ def range_loop(loop_body, start, stop=None, step=None):
|
||||
lambda x: ((stop - start) / step) * x[0]
|
||||
|
||||
def for_range(start, stop=None, step=None):
|
||||
""" Execute loop bodies consecutively """
|
||||
def decorator(loop_body):
|
||||
range_loop(loop_body, start, stop, step)
|
||||
return loop_body
|
||||
return decorator
|
||||
|
||||
def for_range_parallel(n_parallel, n_loops):
|
||||
""" Execute up to n_parallel loop bodies in parallel """
|
||||
return map_reduce_single(n_parallel, n_loops)
|
||||
|
||||
def for_range_opt(n_loops):
|
||||
return map_reduce_single(None, n_loops)
|
||||
def for_range_opt(n_loops, budget=None):
|
||||
""" Execute loop bodies in parallel up to an optimization budget """
|
||||
return map_reduce_single(None, n_loops, budget=budget)
|
||||
|
||||
def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
reducer=lambda *x: [], mem_state=None):
|
||||
reducer=lambda *x: [], mem_state=None, budget=None):
|
||||
budget = budget or get_program().budget
|
||||
if not (isinstance(n_parallel, int) or n_parallel is None):
|
||||
raise CompilerException('Number of parallel executions' \
|
||||
'must be constant')
|
||||
@@ -848,14 +862,16 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
r = reducer(mem_state, state)
|
||||
write_state_to_memory(r)
|
||||
else:
|
||||
n_parallel_reg = MemValue(regint(0))
|
||||
if n_loops == 0:
|
||||
return
|
||||
regint.push(0)
|
||||
parent_block = get_block()
|
||||
@while_do(lambda x: x + n_parallel_reg <= n_loops, regint(0))
|
||||
@while_do(lambda x: x + regint.pop() <= n_loops, regint(0))
|
||||
def _(i):
|
||||
state = tuplify(initializer())
|
||||
k = 0
|
||||
block = get_block()
|
||||
while k < n_loops and (len(get_block()) < get_program().budget \
|
||||
while k < n_loops and (len(get_block()) < budget \
|
||||
or k == 0) \
|
||||
and block is get_block():
|
||||
j = i + k
|
||||
@@ -865,7 +881,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
write_state_to_memory(r)
|
||||
global n_opt_loops
|
||||
n_opt_loops = k
|
||||
n_parallel_reg.write(k)
|
||||
regint.push(k)
|
||||
return i + k
|
||||
my_n_parallel = n_opt_loops
|
||||
loop_rounds = n_loops / my_n_parallel
|
||||
@@ -915,12 +931,46 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
return decorator
|
||||
|
||||
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
|
||||
"""
|
||||
Execute loop bodies in up to n_threads threads,
|
||||
up to n_parallel in parallel per thread
|
||||
"""
|
||||
return map_reduce(n_threads, n_parallel, n_loops, \
|
||||
lambda *x: [], lambda *x: [], thread_mem_req)
|
||||
|
||||
def for_range_opt_multithread(n_threads, n_loops):
|
||||
"""
|
||||
Execute loop bodies in up to n_threads threads,
|
||||
in parallel up to an optimization budget per thread
|
||||
"""
|
||||
return for_range_multithread(n_threads, None, n_loops)
|
||||
|
||||
def multithread(n_threads, n_items):
|
||||
"""
|
||||
Distribute the computation of n_items to n_threads threads,
|
||||
but leave the in-thread repetition up to the user
|
||||
"""
|
||||
if n_threads == 1 or n_items == 1:
|
||||
return lambda loop_body: loop_body(0, n_items)
|
||||
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
|
||||
reducer=None, looping=False)
|
||||
|
||||
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
thread_mem_req={}):
|
||||
thread_mem_req={}, looping=True):
|
||||
n_threads = n_threads or 1
|
||||
if isinstance(n_loops, list):
|
||||
split = n_loops
|
||||
n_loops = reduce(operator.mul, n_loops)
|
||||
def decorator(loop_body):
|
||||
def new_body(i):
|
||||
indices = []
|
||||
for n in reversed(split):
|
||||
indices.insert(0, i % n)
|
||||
i /= n
|
||||
return loop_body(*indices)
|
||||
return new_body
|
||||
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
|
||||
return lambda loop_body: new_dec(decorator(loop_body))
|
||||
if n_threads == 1 or n_loops == 1:
|
||||
dec = map_reduce_single(n_parallel, n_loops, initializer, reducer)
|
||||
if thread_mem_req:
|
||||
@@ -937,12 +987,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
|
||||
state = tuple(initializer())
|
||||
def f(inc):
|
||||
base = args[get_arg()][0]
|
||||
if not looping:
|
||||
return loop_body(base, thread_rounds + inc)
|
||||
if thread_mem_req:
|
||||
thread_mem = Array(thread_mem_req[regint], regint, \
|
||||
args[get_arg()].address + 2)
|
||||
mem_state = Array(len(state), type(state[0]) \
|
||||
if state else cint, args[get_arg()][1])
|
||||
base = args[get_arg()][0]
|
||||
@map_reduce_single(n_parallel, thread_rounds + inc, \
|
||||
initializer, reducer, mem_state)
|
||||
def f(i):
|
||||
@@ -1014,8 +1066,9 @@ def while_loop(loop_body, condition, arg):
|
||||
pushint(arg if isinstance(arg,regint) else regint(arg))
|
||||
def loop_fn():
|
||||
result = loop_body(regint.pop())
|
||||
cont = condition(result)
|
||||
pushint(result)
|
||||
return condition(result)
|
||||
return cont
|
||||
if_statement(pre_condition, lambda: do_while(loop_fn))
|
||||
regint.pop()
|
||||
|
||||
@@ -1278,7 +1331,7 @@ def sint_cint_division(a, b, k, f, kappa):
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
two = cint(2) * two_power(f)
|
||||
sign_b = cint(1) - 2 * cint(b < 0)
|
||||
sign_a = sint(1) - 2 * sint(a < 0)
|
||||
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
|
||||
absolute_b = b * sign_b
|
||||
absolute_a = a * sign_a
|
||||
w0 = approximate_reciprocal(absolute_b, k, f, theta)
|
||||
@@ -1326,7 +1379,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
y = a.extend(2 *k) * w
|
||||
y = y.round(2*k, f, kappa, nearest, signed=True)
|
||||
|
||||
for i in range(theta):
|
||||
for i in range(theta - 1):
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
|
||||
x = x * x
|
||||
@@ -1358,7 +1411,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
# For simplex, we can get rid of computing abs(b)
|
||||
temp = None
|
||||
if simplex_flag == False:
|
||||
temp = b.less_than(0, 2 * k)
|
||||
temp = comparison.LessThanZero(b, 2 * k, kappa)
|
||||
elif simplex_flag == True:
|
||||
temp = cint(0)
|
||||
|
||||
|
||||
814
Compiler/ml.py
Normal file
814
Compiler/ml.py
Normal file
@@ -0,0 +1,814 @@
|
||||
import mpc_math, math
|
||||
|
||||
from Compiler.types import *
|
||||
from Compiler.types import _unreduced_squant
|
||||
from Compiler.library import *
|
||||
|
||||
def log_e(x):
|
||||
return mpc_math.log_fx(x, math.e)
|
||||
|
||||
def exp(x):
|
||||
return mpc_math.pow_fx(math.e, x)
|
||||
|
||||
def sanitize(x, raw, lower, upper):
|
||||
exp_limit = 2 ** (x.k - x.f - 1)
|
||||
limit = math.log(exp_limit)
|
||||
if get_program().options.ring:
|
||||
res = raw
|
||||
else:
|
||||
res = (x > limit).if_else(upper, raw)
|
||||
return (x < -limit).if_else(lower, res)
|
||||
|
||||
def sigmoid(x):
|
||||
return sigmoid_from_e_x(x, exp(-x))
|
||||
|
||||
def sigmoid_from_e_x(x, e_x):
|
||||
return sanitize(x, 1 / (1 + e_x), 0, 1)
|
||||
|
||||
def sigmoid_prime(x):
|
||||
sx = sigmoid(x)
|
||||
return sx * (1 - sx)
|
||||
|
||||
def lse_0_from_e_x(x, e_x):
|
||||
return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)
|
||||
|
||||
def lse_0(x):
|
||||
return lse_0_from_e_x(x, exp(x))
|
||||
|
||||
def relu_prime(x):
|
||||
return (0 <= x)
|
||||
|
||||
def relu(x):
|
||||
return (0 < x).if_else(x, 0)
|
||||
|
||||
def progress(x):
|
||||
return
|
||||
print_ln(x)
|
||||
time()
|
||||
|
||||
def set_n_threads(n_threads):
|
||||
Layer.n_threads = n_threads
|
||||
Optimizer.n_threads = n_threads
|
||||
|
||||
class Layer:
|
||||
n_threads = 1
|
||||
|
||||
class Output(Layer):
|
||||
def __init__(self, N, debug=False):
|
||||
self.N = N
|
||||
self.X = sfix.Array(N)
|
||||
self.Y = sfix.Array(N)
|
||||
self.nabla_X = sfix.Array(N)
|
||||
self.l = MemValue(sfix(-1))
|
||||
self.e_x = sfix.Array(N)
|
||||
self.debug = debug
|
||||
self.weights = cint.Array(N)
|
||||
self.weights.assign_all(1)
|
||||
self.weight_total = N
|
||||
|
||||
nablas = lambda self: ()
|
||||
thetas = lambda self: ()
|
||||
reset = lambda self: None
|
||||
|
||||
def divisor(self, divisor, size):
|
||||
return cfix(1.0 / divisor, size=size)
|
||||
|
||||
def forward(self, N=None):
|
||||
N = N or self.N
|
||||
lse = sfix.Array(N)
|
||||
@multithread(self.n_threads, N)
|
||||
def _(base, size):
|
||||
x = self.X.get_vector(base, size)
|
||||
y = self.Y.get_vector(base, size)
|
||||
e_x = exp(-x)
|
||||
self.e_x.assign(e_x, base)
|
||||
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
|
||||
e_x = self.e_x.get_vector(0, N)
|
||||
self.l.write(sum(lse) * \
|
||||
self.divisor(self.N, 1))
|
||||
|
||||
def backward(self):
|
||||
@multithread(self.n_threads, self.N)
|
||||
def _(base, size):
|
||||
diff = sigmoid_from_e_x(self.X.get_vector(base, size),
|
||||
self.e_x.get_vector(base, size)) - \
|
||||
self.Y.get_vector(base, size)
|
||||
assert sfix.f == cfix.f
|
||||
diff *= self.weights.get_vector(base, size)
|
||||
self.nabla_X.assign(diff * self.divisor(self.weight_total, size), \
|
||||
base)
|
||||
# @for_range_opt(len(diff))
|
||||
# def _(i):
|
||||
# self.nabla_X[i] = self.nabla_X[i] * self.weights[i]
|
||||
if self.debug:
|
||||
a = cfix.Array(len(diff))
|
||||
a.assign(diff.reveal())
|
||||
@for_range(len(diff))
|
||||
def _(i):
|
||||
x = a[i]
|
||||
print_ln_if((x < -1.001) + (x > 1.001), 'sigmoid')
|
||||
#print_ln('%s', x)
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.weights.assign(weights)
|
||||
self.weight_total = sum(weights)
|
||||
|
||||
class DenseBase(Layer):
|
||||
thetas = lambda self: (self.W, self.b)
|
||||
nablas = lambda self: (self.nabla_W, self.nabla_b)
|
||||
|
||||
def backward_params(self, f_schur_Y):
|
||||
N = self.N
|
||||
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
|
||||
|
||||
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
|
||||
def _(j, k):
|
||||
assert self.d == 1
|
||||
a = [f_schur_Y[i][0][k] for i in range(N)]
|
||||
b = [self.X[i][0][j] for i in range(N)]
|
||||
tmp[j][k] = sfix.unreduced_dot_product(a, b)
|
||||
|
||||
if self.d_in * self.d_out < 100000:
|
||||
print 'reduce at once'
|
||||
@multithread(self.n_threads, self.d_in * self.d_out)
|
||||
def _(base, size):
|
||||
self.nabla_W.assign_vector(
|
||||
tmp.get_vector(base, size).reduce_after_mul(), base=base)
|
||||
else:
|
||||
@for_range_opt(self.d_in)
|
||||
def _(i):
|
||||
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
|
||||
|
||||
self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N))
|
||||
for j in range(self.d)) for i in range(self.d_out))
|
||||
|
||||
progress('nabla W/b')
|
||||
|
||||
class Dense(DenseBase):
|
||||
def __init__(self, N, d_in, d_out, d=1, activation='id'):
|
||||
self.activation = activation
|
||||
if activation == 'id':
|
||||
self.f = lambda x: x
|
||||
elif activation == 'relu':
|
||||
self.f = relu
|
||||
self.f_prime = relu_prime
|
||||
elif activation == 'sigmoid':
|
||||
self.f = sigmoid
|
||||
self.f_prime = sigmoid_prime
|
||||
|
||||
self.N = N
|
||||
self.d_in = d_in
|
||||
self.d_out = d_out
|
||||
self.d = d
|
||||
|
||||
self.X = MultiArray([N, d, d_in], sfix)
|
||||
self.Y = MultiArray([N, d, d_out], sfix)
|
||||
self.W = sfix.Matrix(d_in, d_out)
|
||||
self.b = sfix.Array(d_out)
|
||||
|
||||
self.reset()
|
||||
|
||||
self.nabla_Y = MultiArray([N, d, d_out], sfix)
|
||||
self.nabla_X = MultiArray([N, d, d_in], sfix)
|
||||
self.nabla_W = sfix.Matrix(d_in, d_out)
|
||||
self.nabla_W.assign_all(0)
|
||||
self.nabla_b = sfix.Array(d_out)
|
||||
|
||||
self.f_input = MultiArray([N, d, d_out], sfix)
|
||||
|
||||
def reset(self):
|
||||
d_in = self.d_in
|
||||
d_out = self.d_out
|
||||
r = math.sqrt(6.0 / (d_in + d_out))
|
||||
@for_range(d_in)
|
||||
def _(i):
|
||||
@for_range(d_out)
|
||||
def _(j):
|
||||
self.W[i][j] = sfix.get_random(-r, r)
|
||||
self.b.assign_all(0)
|
||||
|
||||
def compute_f_input(self):
|
||||
prod = MultiArray([self.N, self.d, self.d_out], sfix)
|
||||
@for_range_opt_multithread(self.n_threads, self.N)
|
||||
def _(i):
|
||||
self.X[i].plain_mul(self.W, res=prod[i])
|
||||
|
||||
@for_range_opt_multithread(self.n_threads, self.N)
|
||||
def _(i):
|
||||
@for_range_opt(self.d)
|
||||
def _(j):
|
||||
v = prod[i][j].get_vector() + self.b.get_vector()
|
||||
self.f_input[i][j].assign(v)
|
||||
progress('f input')
|
||||
|
||||
def forward(self):
|
||||
self.compute_f_input()
|
||||
self.Y.assign_vector(self.f(self.f_input.get_vector()))
|
||||
|
||||
def backward(self, compute_nabla_X=True):
|
||||
N = self.N
|
||||
d = self.d
|
||||
d_out = self.d_out
|
||||
X = self.X
|
||||
Y = self.Y
|
||||
W = self.W
|
||||
b = self.b
|
||||
nabla_X = self.nabla_X
|
||||
nabla_Y = self.nabla_Y
|
||||
nabla_W = self.nabla_W
|
||||
nabla_b = self.nabla_b
|
||||
|
||||
if self.activation == 'id':
|
||||
f_schur_Y = nabla_Y
|
||||
else:
|
||||
f_prime_bit = MultiArray([N, d, d_out], sint)
|
||||
f_schur_Y = MultiArray([N, d, d_out], sfix)
|
||||
|
||||
self.compute_f_input()
|
||||
f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector()))
|
||||
|
||||
progress('f prime')
|
||||
|
||||
@for_range_opt(N)
|
||||
def _(i):
|
||||
f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])
|
||||
|
||||
progress('f prime schur Y')
|
||||
|
||||
if compute_nabla_X:
|
||||
@for_range_opt(N)
|
||||
def _(i):
|
||||
if self.activation == 'id':
|
||||
nabla_X[i] = nabla_Y[i].mul_trans(W)
|
||||
else:
|
||||
nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W)
|
||||
|
||||
progress('nabla X')
|
||||
|
||||
self.backward_params(f_schur_Y)
|
||||
|
||||
class QuantizedDense(DenseBase):
|
||||
def __init__(self, N, d_in, d_out):
|
||||
self.N = N
|
||||
self.d_in = d_in
|
||||
self.d_out = d_out
|
||||
self.d = 1
|
||||
self.H = math.sqrt(1.5 / (d_in + d_out))
|
||||
|
||||
self.W = sfix.Matrix(d_in, d_out)
|
||||
self.nabla_W = self.W.same_shape()
|
||||
self.T = sint.Matrix(d_in, d_out)
|
||||
self.b = sfix.Array(d_out)
|
||||
self.nabla_b = self.b.same_shape()
|
||||
|
||||
self.X = MultiArray([N, 1, d_in], sfix)
|
||||
self.Y = MultiArray([N, 1, d_out], sfix)
|
||||
self.nabla_Y = self.Y.same_shape()
|
||||
|
||||
def reset(self):
|
||||
@for_range(self.d_in)
|
||||
def _(i):
|
||||
@for_range(self.d_out)
|
||||
def _(j):
|
||||
self.W[i][j] = sfix.get_random(-1, 1)
|
||||
self.b.assign_all(0)
|
||||
|
||||
def forward(self):
|
||||
@for_range_opt(self.d_in)
|
||||
def _(i):
|
||||
@for_range_opt(self.d_out)
|
||||
def _(j):
|
||||
over = self.W[i][j] > 0.5
|
||||
under = self.W[i][j] < -0.5
|
||||
self.T[i][j] = over.if_else(1, under.if_else(-1, 0))
|
||||
over = self.W[i][j] > 1
|
||||
under = self.W[i][j] < -1
|
||||
self.W[i][j] = over.if_else(1, under.if_else(-1, self.W[i][j]))
|
||||
@for_range_opt(self.N)
|
||||
def _(i):
|
||||
assert self.d_out == 1
|
||||
self.Y[i][0][0] = self.b[0] + self.H * sfix._new(
|
||||
sint.dot_product([self.T[j][0] for j in range(self.d_in)],
|
||||
[self.X[i][0][j].v for j in range(self.d_in)]))
|
||||
|
||||
def backward(self, compute_nabla_X=False):
|
||||
assert not compute_nabla_X
|
||||
self.backward_params(self.nabla_Y)
|
||||
|
||||
class Dropout:
|
||||
def __init__(self, N, d1, d2=1):
|
||||
self.N = N
|
||||
self.d1 = d1
|
||||
self.d2 = d2
|
||||
self.X = MultiArray([N, d1, d2], sfix)
|
||||
self.Y = MultiArray([N, d1, d2], sfix)
|
||||
self.nabla_Y = MultiArray([N, d1, d2], sfix)
|
||||
self.nabla_X = MultiArray([N, d1, d2], sfix)
|
||||
self.alpha = 0.5
|
||||
self.B = MultiArray([N, d1, d2], sint)
|
||||
|
||||
def forward(self):
|
||||
assert self.alpha == 0.5
|
||||
@for_range(self.N)
|
||||
def _(i):
|
||||
@for_range(self.d1)
|
||||
def _(j):
|
||||
@for_range(self.d2)
|
||||
def _(k):
|
||||
self.B[i][j][k] = sint.get_random_bit()
|
||||
self.Y = self.X.schur(self.B)
|
||||
|
||||
def backward(self):
|
||||
self.nabla_X = self.nabla_Y.schur(self.B)
|
||||
|
||||
class QuantBase(object):
|
||||
n_threads = 1
|
||||
|
||||
@staticmethod
|
||||
def new_squant():
|
||||
class _(squant):
|
||||
@classmethod
|
||||
def get_input_from(cls, player, size=None):
|
||||
return cls._new(sint.get_input_from(player, size=size))
|
||||
return _
|
||||
|
||||
def __init__(self, input_shape, output_shape):
|
||||
self.input_shape = input_shape
|
||||
self.output_shape = output_shape
|
||||
|
||||
self.input_squant = self.new_squant()
|
||||
self.output_squant = self.new_squant()
|
||||
|
||||
self.X = MultiArray(input_shape, self.input_squant)
|
||||
self.Y = MultiArray(output_shape, self.output_squant)
|
||||
|
||||
def temp_shape(self):
|
||||
return [0]
|
||||
|
||||
class QuantConvBase(QuantBase):
|
||||
fewer_rounds = True
|
||||
temp_weights = None
|
||||
temp_inputs = None
|
||||
|
||||
@classmethod
|
||||
def init_temp(cls, layers):
|
||||
size = 0
|
||||
for layer in layers:
|
||||
size = max(size, reduce(operator.mul, layer.temp_shape()))
|
||||
cls.temp_weights = sfix.Array(size)
|
||||
cls.temp_inputs = sfix.Array(size)
|
||||
|
||||
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride):
|
||||
super(QuantConvBase, self).__init__(input_shape, output_shape)
|
||||
|
||||
self.weight_shape = weight_shape
|
||||
self.bias_shape = bias_shape
|
||||
self.stride = stride
|
||||
|
||||
self.weight_squant = self.new_squant()
|
||||
self.bias_squant = self.new_squant()
|
||||
|
||||
self.weights = MultiArray(weight_shape, self.weight_squant)
|
||||
self.bias = Array(output_shape[-1], self.bias_squant)
|
||||
|
||||
self.unreduced = MultiArray(self.output_shape, sint,
|
||||
address=self.Y.address)
|
||||
|
||||
assert(weight_shape[-1] == input_shape[-1])
|
||||
assert(bias_shape[0] == output_shape[-1])
|
||||
assert(len(bias_shape) == 1)
|
||||
assert(len(input_shape) == 4)
|
||||
assert(len(output_shape) == 4)
|
||||
assert(len(weight_shape) == 4)
|
||||
|
||||
def input_from(self, player):
|
||||
for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
self.weights.input_from(player, budget=100000)
|
||||
self.bias.input_from(player)
|
||||
print 'WARNING: assuming that bias quantization parameters are correct'
|
||||
|
||||
self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params)
|
||||
|
||||
def dot_product(self, iv, wv, out_y, out_x, out_c):
|
||||
bias = self.bias[out_c]
|
||||
acc = squant.unreduced_dot_product(iv, wv)
|
||||
acc.v += bias.v
|
||||
acc.res_params = self.output_squant.params
|
||||
#self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul()
|
||||
self.unreduced[0][out_y][out_x][out_c] = acc.v
|
||||
|
||||
def reduction(self):
|
||||
unreduced = self.unreduced
|
||||
n_summands = self.n_summands()
|
||||
start_timer(2)
|
||||
n_outputs = reduce(operator.mul, self.output_shape)
|
||||
if n_outputs % self.n_threads == 0:
|
||||
n_per_thread = n_outputs / self.n_threads
|
||||
@for_range_opt_multithread(self.n_threads, self.n_threads)
|
||||
def _(i):
|
||||
res = _unreduced_squant(
|
||||
sint.load_mem(unreduced.address + i * n_per_thread,
|
||||
size=n_per_thread),
|
||||
(self.input_squant.params, self.weight_squant.params),
|
||||
self.output_squant.params,
|
||||
n_summands).reduce_after_mul()
|
||||
res.store_in_mem(self.Y.address + i * n_per_thread)
|
||||
else:
|
||||
@for_range_opt_multithread(self.n_threads, self.output_shape[1])
|
||||
def _(out_y):
|
||||
self.Y[0][out_y].assign_vector(_unreduced_squant(
|
||||
unreduced[0][out_y].get_vector(),
|
||||
(self.input_squant.params, self.weight_squant.params),
|
||||
self.output_squant.params,
|
||||
n_summands).reduce_after_mul())
|
||||
stop_timer(2)
|
||||
|
||||
def temp_shape(self):
|
||||
return list(self.output_shape[1:]) + [self.n_summands()]
|
||||
|
||||
def prepare_temp(self):
|
||||
shape = self.temp_shape()
|
||||
inputs = MultiArray(shape, self.input_squant,
|
||||
address=self.temp_inputs)
|
||||
weights = MultiArray(shape, self.weight_squant,
|
||||
address=self.temp_weights)
|
||||
return inputs, weights
|
||||
|
||||
class QuantConv2d(QuantConvBase):
|
||||
def n_summands(self):
|
||||
_, weights_h, weights_w, _ = self.weight_shape
|
||||
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
||||
return weights_h * weights_w * n_channels_in
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
assert(self.weight_shape[0] == self.output_shape[-1])
|
||||
|
||||
_, weights_h, weights_w, _ = self.weight_shape
|
||||
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
||||
_, output_h, output_w, n_channels_out = self.output_shape
|
||||
|
||||
stride_h, stride_w = self.stride
|
||||
padding_h, padding_w = (weights_h // 2, weights_w // 2)
|
||||
|
||||
if self.fewer_rounds:
|
||||
inputs, weights = self.prepare_temp()
|
||||
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
[output_h, output_w, n_channels_out])
|
||||
def _(out_y, out_x, out_c):
|
||||
in_x_origin = (out_x * stride_w) - padding_w
|
||||
in_y_origin = (out_y * stride_h) - padding_h
|
||||
iv = []
|
||||
wv = []
|
||||
for filter_y in range(weights_h):
|
||||
in_y = in_y_origin + filter_y
|
||||
inside_y = (0 <= in_y) * (in_y < inputs_h)
|
||||
for filter_x in range(weights_w):
|
||||
in_x = in_x_origin + filter_x
|
||||
inside_x = (0 <= in_x) * (in_x < inputs_w)
|
||||
inside = inside_y * inside_x
|
||||
if inside is 0:
|
||||
continue
|
||||
for in_c in range(n_channels_in):
|
||||
iv += [self.X[0][in_y * inside_y]
|
||||
[in_x * inside_x][in_c]]
|
||||
wv += [self.weights[out_c][filter_y][filter_x][in_c]]
|
||||
wv[-1] *= inside
|
||||
if self.fewer_rounds:
|
||||
inputs[out_y][out_x][out_c].assign(iv)
|
||||
weights[out_y][out_x][out_c].assign(wv)
|
||||
else:
|
||||
self.dot_product(iv, wv, out_y, out_x, out_c)
|
||||
|
||||
if self.fewer_rounds:
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
list(self.output_shape[1:]))
|
||||
def _(out_y, out_x, out_c):
|
||||
self.dot_product(inputs[out_y][out_x][out_c],
|
||||
weights[out_y][out_x][out_c],
|
||||
out_y, out_x, out_c)
|
||||
|
||||
self.reduction()
|
||||
|
||||
class QuantDepthwiseConv2d(QuantConvBase):
|
||||
def n_summands(self):
|
||||
_, weights_h, weights_w, _ = self.weight_shape
|
||||
return weights_h * weights_w
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
assert(self.weight_shape[-1] == self.output_shape[-1])
|
||||
assert(self.input_shape[-1] == self.output_shape[-1])
|
||||
|
||||
_, weights_h, weights_w, _ = self.weight_shape
|
||||
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
||||
_, output_h, output_w, n_channels_out = self.output_shape
|
||||
|
||||
stride_h, stride_w = self.stride
|
||||
padding_h, padding_w = (weights_h // 2, weights_w // 2)
|
||||
|
||||
depth_multiplier = 1
|
||||
|
||||
if self.fewer_rounds:
|
||||
inputs, weights = self.prepare_temp()
|
||||
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
[output_h, output_w, n_channels_in])
|
||||
def _(out_y, out_x, in_c):
|
||||
for m in range(depth_multiplier):
|
||||
oc = m + in_c * depth_multiplier
|
||||
in_x_origin = (out_x * stride_w) - padding_w
|
||||
in_y_origin = (out_y * stride_h) - padding_h
|
||||
iv = []
|
||||
wv = []
|
||||
for filter_y in range(weights_h):
|
||||
for filter_x in range(weights_w):
|
||||
in_x = in_x_origin + filter_x
|
||||
in_y = in_y_origin + filter_y
|
||||
inside = (0 <= in_x) * (in_x < inputs_w) * \
|
||||
(0 <= in_y) * (in_y < inputs_h)
|
||||
if inside is 0:
|
||||
continue
|
||||
iv += [self.X[0][in_y][in_x][in_c]]
|
||||
wv += [self.weights[0][filter_y][filter_x][oc]]
|
||||
wv[-1] *= inside
|
||||
if self.fewer_rounds:
|
||||
inputs[out_y][out_x][oc].assign(iv)
|
||||
weights[out_y][out_x][oc].assign(wv)
|
||||
else:
|
||||
self.dot_product(iv, wv, out_y, out_x, oc)
|
||||
|
||||
if self.fewer_rounds:
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
list(self.output_shape[1:]))
|
||||
def _(out_y, out_x, out_c):
|
||||
self.dot_product(inputs[out_y][out_x][out_c],
|
||||
weights[out_y][out_x][out_c],
|
||||
out_y, out_x, out_c)
|
||||
|
||||
self.reduction()
|
||||
|
||||
class QuantAveragePool2d(QuantBase):
|
||||
def __init__(self, input_shape, output_shape, filter_size):
|
||||
super(QuantAveragePool2d, self).__init__(input_shape, output_shape)
|
||||
self.filter_size = filter_size
|
||||
|
||||
def input_from(self, player):
|
||||
print 'WARNING: assuming that input and output quantization parameters are the same'
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
|
||||
_, input_h, input_w, n_channels_in = self.input_shape
|
||||
_, output_h, output_w, n_channels_out = self.output_shape
|
||||
|
||||
n = input_h * input_w
|
||||
print 'divisor: ', n
|
||||
|
||||
assert output_h == output_w == 1
|
||||
assert n_channels_in == n_channels_out
|
||||
|
||||
padding_h, padding_w = (0, 0)
|
||||
stride_h, stride_w = (2, 2)
|
||||
filter_h, filter_w = self.filter_size
|
||||
|
||||
@for_range_opt(output_h)
|
||||
def _(out_y):
|
||||
@for_range_opt(output_w)
|
||||
def _(out_x):
|
||||
@for_range_opt(n_channels_in)
|
||||
def _(c):
|
||||
in_x_origin = (out_x * stride_w) - padding_w
|
||||
in_y_origin = (out_y * stride_h) - padding_h
|
||||
fxs = (-in_x_origin).max(0)
|
||||
#fxe = min(filter_w, input_w - in_x_origin)
|
||||
fys = (-in_y_origin).max(0)
|
||||
#fye = min(filter_h, input_h - in_y_origin)
|
||||
acc = 0
|
||||
#fc = 0
|
||||
for i in range(filter_h):
|
||||
filter_y = fys + i
|
||||
for j in range(filter_w):
|
||||
filter_x = fxs + j
|
||||
in_x = in_x_origin + filter_x
|
||||
in_y = in_y_origin + filter_y
|
||||
acc += self.X[0][in_y][in_x][c].v
|
||||
#fc += 1
|
||||
logn = int(math.log(n, 2))
|
||||
acc = (acc + n / 2)
|
||||
if 2 ** logn == n:
|
||||
acc = acc.round(self.output_squant.params.k + logn,
|
||||
logn, nearest=True)
|
||||
else:
|
||||
acc = acc.int_div(sint(n),
|
||||
self.output_squant.params.k + logn)
|
||||
#acc = min(255, max(0, acc))
|
||||
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
|
||||
|
||||
class QuantReshape(QuantBase):
|
||||
def __init__(self, input_shape, _, output_shape):
|
||||
super(QuantReshape, self).__init__(input_shape, output_shape)
|
||||
|
||||
def input_from(self, player):
|
||||
print 'WARNING: assuming that input and output quantization parameters are the same'
|
||||
_ = self.new_squant()
|
||||
for s in self.input_squant, _, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
for i in range(2):
|
||||
sint.get_input_from(player)
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
# reshaping is implicit
|
||||
self.Y.assign(self.X)
|
||||
|
||||
class QuantSoftmax(QuantBase):
|
||||
def input_from(self, player):
|
||||
print 'WARNING: assuming that input and output quantization parameters are the same'
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
assert(len(self.input_shape) == 2)
|
||||
|
||||
# just print the best
|
||||
def comp(left, right):
|
||||
c = left[1].v.greater_than(right[1].v, self.input_squant.params.k)
|
||||
#print_ln('comp %s %s %s', c.reveal(), left[1].v.reveal(), right[1].v.reveal())
|
||||
return [c.if_else(x, y) for x, y in zip(left, right)]
|
||||
print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal())
|
||||
|
||||
class Optimizer:
|
||||
n_threads = Layer.n_threads
|
||||
|
||||
def forward(self, N):
|
||||
for j in range(len(self.layers) - 1):
|
||||
self.layers[j].forward()
|
||||
self.layers[j + 1].X.assign(self.layers[j].Y)
|
||||
self.layers[-1].forward(N)
|
||||
|
||||
def backward(self):
|
||||
for j in range(1, len(self.layers)):
|
||||
self.layers[-j].backward()
|
||||
self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X)
|
||||
self.layers[0].backward(compute_nabla_X=False)
|
||||
|
||||
def run(self):
|
||||
i = MemValue(0)
|
||||
@do_while
|
||||
def _():
|
||||
if self.X_by_label is not None:
|
||||
N = self.layers[0].N
|
||||
assert self.layers[-1].N == N
|
||||
assert N % 2 == 0
|
||||
n = N / 2
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
self.layers[-1].Y[i] = 0
|
||||
self.layers[-1].Y[i + n] = 1
|
||||
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
|
||||
self.X_by_label) / n))
|
||||
print '%d runs per epoch' % n_per_epoch
|
||||
indices_by_label = []
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = regint.Array(n * n_per_epoch)
|
||||
indices_by_label.append(indices)
|
||||
indices.assign(i % len(X) for i in range(len(indices)))
|
||||
indices.shuffle()
|
||||
@for_range(n_per_epoch)
|
||||
def _(j):
|
||||
j = MemValue(j)
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = indices_by_label[label]
|
||||
@for_range_multithread(self.n_threads, 1, n)
|
||||
def _(i):
|
||||
idx = indices[i + j * n_per_epoch]
|
||||
self.layers[0].X[i + label * n] = X[idx]
|
||||
self.forward(None)
|
||||
self.backward()
|
||||
self.update(i)
|
||||
else:
|
||||
self.forward(None)
|
||||
self.backward()
|
||||
self.update(i)
|
||||
loss = self.layers[-1].l
|
||||
if self.report_loss:
|
||||
print_ln('loss after epoch %s: %s', i, loss.reveal())
|
||||
else:
|
||||
print_ln('done with epoch %s', i)
|
||||
time()
|
||||
i.iadd(1)
|
||||
res = (i < self.n_epochs)
|
||||
if self.tol > 0:
|
||||
res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
|
||||
return res
|
||||
print_ln('finished after %s epochs', i)
|
||||
|
||||
class Adam(Optimizer):
|
||||
def __init__(self, layers, n_epochs):
|
||||
self.alpha = .001
|
||||
self.beta1 = 0.9
|
||||
self.beta2 = 0.999
|
||||
self.epsilon = 10 ** -8
|
||||
self.n_epochs = n_epochs
|
||||
|
||||
self.layers = layers
|
||||
self.ms = []
|
||||
self.vs = []
|
||||
self.gs = []
|
||||
self.thetas = []
|
||||
for layer in layers:
|
||||
for nabla in layer.nablas():
|
||||
self.gs.append(nabla)
|
||||
for x in self.ms, self.vs:
|
||||
x.append(nabla.same_shape())
|
||||
for theta in layer.thetas():
|
||||
self.thetas.append(theta)
|
||||
|
||||
self.mhat_factors = Array(n_epochs, sfix)
|
||||
self.vhat_factors = Array(n_epochs, sfix)
|
||||
|
||||
for i in range(n_epochs):
|
||||
for factors, beta in ((self.mhat_factors, self.beta1),
|
||||
(self.vhat_factors, self.beta2)):
|
||||
factors[i] = 1. / (1 - beta ** (i + 1))
|
||||
|
||||
def update(self, i_epoch):
|
||||
for m, v, g, theta in zip(self.ms, self.vs, self.gs, self.thetas):
|
||||
@for_range_opt(len(m))
|
||||
def _(k):
|
||||
m[k] = self.beta1 * m[k] + (1 - self.beta1) * g[k]
|
||||
v[k] = self.beta2 * v[k] + (1 - self.beta2) * g[k] ** 2
|
||||
mhat = m[k] * self.mhat_factors[i_epoch]
|
||||
vhat = v[k] * self.vhat_factors[i_epoch]
|
||||
theta[k] = theta[k] - self.alpha * mhat / \
|
||||
mpc_math.sqrt(vhat) + self.epsilon
|
||||
|
||||
class SGD(Optimizer):
|
||||
def __init__(self, layers, n_epochs, debug=False, report_loss=False):
|
||||
self.momentum = 0.9
|
||||
self.layers = layers
|
||||
self.n_epochs = n_epochs
|
||||
self.thetas = []
|
||||
self.nablas = []
|
||||
self.delta_thetas = []
|
||||
for layer in layers:
|
||||
self.nablas.extend(layer.nablas())
|
||||
self.thetas.extend(layer.thetas())
|
||||
for theta in layer.thetas():
|
||||
self.delta_thetas.append(theta.same_shape())
|
||||
self.gamma = MemValue(sfix(0.01))
|
||||
self.debug = debug
|
||||
self.report_loss = report_loss
|
||||
self.tol = 0.000
|
||||
self.X_by_label = None
|
||||
|
||||
def reset(self, X_by_label=None):
|
||||
self.X_by_label = X_by_label
|
||||
for y in self.delta_thetas:
|
||||
y.assign_all(0)
|
||||
for layer in self.layers:
|
||||
layer.reset()
|
||||
|
||||
def update(self, i_epoch):
|
||||
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
|
||||
self.delta_thetas):
|
||||
@for_range_opt_multithread(self.n_threads, len(nabla))
|
||||
def _(k):
|
||||
old = delta_theta[k]
|
||||
if isinstance(old, Array):
|
||||
old = old.get_vector()
|
||||
red_old = self.momentum * old
|
||||
new = self.gamma * nabla[k]
|
||||
diff = red_old - new
|
||||
delta_theta[k] = diff
|
||||
theta[k] = theta[k] + delta_theta[k]
|
||||
if self.debug:
|
||||
for x, name in (old, 'old'), (red_old, 'red_old'), \
|
||||
(new, 'new'), (diff, 'diff'):
|
||||
x = x.reveal()
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
name + ': %s %s %s %s',
|
||||
*[y.v.reveal() for y in old, red_old, \
|
||||
new, diff])
|
||||
if self.debug:
|
||||
d = delta_theta.get_vector().reveal()
|
||||
a = cfix.Array(len(d.v))
|
||||
a.assign(d)
|
||||
@for_range(len(a))
|
||||
def _(i):
|
||||
x = a[i]
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
'update len=%d' % len(nabla))
|
||||
a.assign(nabla.get_vector().reveal())
|
||||
@for_range(len(a))
|
||||
def _(i):
|
||||
x = a[i]
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
'nabla len=%d' % len(nabla))
|
||||
self.gamma.imul(1 - 10 ** - 6)
|
||||
752
Compiler/mpc_math.py
Normal file
752
Compiler/mpc_math.py
Normal file
@@ -0,0 +1,752 @@
|
||||
##
|
||||
# @file
|
||||
# Arithmetic Module for Complex Math Operations
|
||||
#
|
||||
# Implements trigonometric and logarithmic functions.
|
||||
|
||||
import math
|
||||
from Compiler import floatingpoint
|
||||
from Compiler import types
|
||||
from Compiler import comparison
|
||||
from Compiler import program
|
||||
# polynomials as enumerated on Hart's book
|
||||
##
|
||||
# @private
|
||||
p_3307 = [1.57079632679489000000000, -0.64596409750624600000000,
|
||||
0.07969262624616700000000, -0.00468175413531868000000,
|
||||
0.00016044118478735800000, -0.00000359884323520707000,
|
||||
0.00000005692172920657320, -0.00000000066880348849204,
|
||||
0.00000000000606691056085, -0.00000000000004375295071,
|
||||
0.00000000000000025002854]
|
||||
##
|
||||
# @private
|
||||
p_3508 = [1.00000000000000000000, -0.50000000000000000000,
|
||||
0.04166666666666667129, -0.00138888888888888873,
|
||||
0.00002480158730158702, -0.00000027557319223933,
|
||||
0.00000000208767569817, -0.00000000001147074513,
|
||||
0.00000000000004779454, -0.00000000000000015612,
|
||||
0.00000000000000000040]
|
||||
##
|
||||
# @private
|
||||
p_1045 = [1.000000077443021686, 0.693147180426163827795756,
|
||||
0.224022651071017064605384, 0.055504068620466379157744,
|
||||
0.009618341225880462374977, 0.001332730359281437819329,
|
||||
0.000155107460590052573978, 0.000014197847399765606711,
|
||||
0.000001863347724137967076]
|
||||
##
|
||||
# @private
|
||||
p_2524 = [-2.05466671951, -8.8626599391,
|
||||
+6.10585199015, +4.81147460989]
|
||||
##
|
||||
# @private
|
||||
q_2524 = [+0.353553425277, +4.54517087629,
|
||||
+6.42784209029, +1]
|
||||
##
|
||||
# @private
|
||||
p_5102 = [+21514.05962602441933193254468, +73597.43380288444240814980706,
|
||||
+100272.5618306302784970511863, +69439.29750032252337059765503,
|
||||
+25858.09739719099025716567793, +5038.63918550126655793779119,
|
||||
+460.1588804635351471161727227, +15.08767735870030987717455528,
|
||||
+0.07523052818757628444510729539]
|
||||
##
|
||||
# @private
|
||||
q_5102 = [+21514.05962602441933193298234, +80768.78701155924885176713209,
|
||||
+122892.6789092784776298743322, +97323.20349053555680260434387,
|
||||
+42868.57652046408093184006664, +10401.13491566890057005103878,
|
||||
+1289.75056911611097141145955, +68.51937831018968013114024294,
|
||||
+1]
|
||||
##
|
||||
# @private
|
||||
p_4737 = [-9338.550897341021522505385079, +43722.68009378241623148489754,
|
||||
-86008.12066370804865047446067, +92190.57592175496843898184959,
|
||||
-58360.27724533928122075635101, +22081.61324178027161353562222,
|
||||
-4805.541226761699661564427739, +542.2148323255220943742314911,
|
||||
-24.94928894422502466205102672, 0.2222361619461131578797029272]
|
||||
##
|
||||
# @private
|
||||
q_4737 =[-9338.550897341021522505384935, +45279.10524333925315190231067,
|
||||
-92854.24688696401422824346529, +104687.2504366298224257408682,
|
||||
-70581.74909396877350961227976, +28972.22947326672977624954443,
|
||||
-7044.002024719172700685571406, +935.7104153502806086331621628,
|
||||
-56.83369358538071475796209327, 1]
|
||||
##
|
||||
# @private
|
||||
p_4754 = [-6.90859801, +12.85564644, -5.94939208]
|
||||
|
||||
##
|
||||
# @private
|
||||
q_4754 = [-6.92529156, +14.20305096, -8.27925501, 1]
|
||||
|
||||
# all inputs are calcualted in radians hence we need some conversion.
|
||||
pi = math.radians(180)
|
||||
pi_over_2 = math.radians(90)
|
||||
|
||||
##
|
||||
# truncates values regardless of the input type. (It always rounds down)
|
||||
# @param x: coefficient to be truncated.
|
||||
#
|
||||
# @return truncated sint value of x
|
||||
def trunc(x):
|
||||
if type(x) is types.sfix:
|
||||
return floatingpoint.Trunc(x.v, x.k, x.f, x.kappa, signed=True)
|
||||
elif type(x) is types.sfloat:
|
||||
v, p, z, s = floatingpoint.FLRound(x, 0)
|
||||
#return types.sfloat(v, p, z, s, x.err)
|
||||
return types.sfloat(v, p, z, s)
|
||||
return x
|
||||
|
||||
|
||||
##
|
||||
# loads integer to fractional type (sint)
|
||||
# @param x: coefficient to be truncated.
|
||||
#
|
||||
# @return returns sfix, sfloat loaded value
|
||||
def load_sint(x, l_type):
|
||||
if l_type is types.sfix:
|
||||
return types.sfix.from_sint(x)
|
||||
elif l_type is types.sfloat:
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
##
|
||||
# evaluates a Polynomial to a given x in a privacy preserving manner.
|
||||
# Inputs can be of any kind of register, secret or otherwise.
|
||||
#
|
||||
# @param p_c: Polynomial coefficients. (Array)
|
||||
#
|
||||
# @param x: Value to which the polynomial p_c is evaluated to.(register)
|
||||
#
|
||||
# @return the evaluation of the polynomial. return type depends on inputs.
|
||||
def p_eval(p_c, x):
|
||||
degree = len(p_c) - 1
|
||||
if type(x) is types.sfix:
|
||||
# ignore coefficients smaller than precision
|
||||
for c in reversed(p_c):
|
||||
if c < 2 ** -(x.f + 1):
|
||||
degree -= 1
|
||||
else:
|
||||
break
|
||||
pre_mults = floatingpoint.PreOpL(lambda a,b,_: a * b,
|
||||
[x] * degree)
|
||||
local_aggregation = 0
|
||||
# Evaluation of the Polynomial
|
||||
for i, pre_mult in zip(p_c[1:], pre_mults):
|
||||
local_aggregation += pre_mult.mul_no_reduce(x.coerce(i))
|
||||
return local_aggregation.reduce_after_mul() + p_c[0]
|
||||
|
||||
|
||||
##
|
||||
# reduces the input to [0,90) and returns whether the reduced value is
|
||||
# greater than \Pi and greater than Pi over 2
|
||||
# @param x: value of any type to be reduced to the [0,90) interval
|
||||
#
|
||||
# @return w: reduced angle in either fixed or floating point .
|
||||
#
|
||||
# @return b1: \{0,1\} value. Returns one when reduction to 2*\pi
|
||||
# is greater than \pi
|
||||
#
|
||||
# @return b2: \{0,1\} value. Returns one when reduction to
|
||||
# \pi is greater than \pi/2.
|
||||
def sTrigSub(x):
|
||||
# reduction to 2* \pi
|
||||
f = x * (1.0 / (2 * pi))
|
||||
f = load_sint(trunc(f), type(x))
|
||||
y = x - (f) * (2 * pi)
|
||||
# reduction to \pi
|
||||
b1 = y > pi
|
||||
w = b1 * ((2 * pi - y) - y) + y
|
||||
# reduction to \pi/2
|
||||
b2 = w > pi_over_2
|
||||
w = b2 * ((pi - w) - w) + w
|
||||
# returns scaled angle and boolean flags
|
||||
return w, b1, b2
|
||||
|
||||
# kernel method calls -- they are built in a generic way
|
||||
|
||||
|
||||
##
|
||||
# Kernel sin. Returns the sin of a given angle on the [0, \pi/2) interval and
|
||||
# adjust the sign in case the angle was reduced on the [0,360) interval
|
||||
#
|
||||
# @param w: fractional value for an angle on the [0, \pi) interval.
|
||||
#
|
||||
# @return returns the sin of w.
|
||||
def ssin(w, s):
|
||||
# calculates the v of w for polynomial evaluation
|
||||
v = w * (1.0 / pi_over_2)
|
||||
v_2 = v ** 2
|
||||
# adjust sign according to the movement in the reduction
|
||||
b = s * (-2) + 1
|
||||
# calculate the sin using polynomial evaluation
|
||||
local_sin = b * v * p_eval(p_3307, v_2)
|
||||
return local_sin
|
||||
|
||||
|
||||
##
|
||||
# Kernel cos. Returns the cos of a given angle on the [0.pi/2)
|
||||
# interval and adjust
|
||||
# the sign in case the angle was reduced on the [0,360) interval.
|
||||
#
|
||||
# @param w: fractional value for an angle on the [0,\pi) interval.
|
||||
#
|
||||
# @param s: \{0,1\} value. Corresponding to b2. Returns 1 if the angle
|
||||
# was reduced from an angle in the [\pi/2,\pi) interval.
|
||||
#
|
||||
# @return returns the cos of w (sfix).
|
||||
def scos(w, s):
|
||||
# calculates the v of the w.
|
||||
v = w
|
||||
v_2 = v ** 2
|
||||
# adjust sign according to the movement in the reduction
|
||||
b = s * (-2) + 1
|
||||
# calculate the cos using polynomial evaluation
|
||||
local_cos = b * p_eval(p_3508, v_2)
|
||||
return local_cos
|
||||
|
||||
|
||||
# facade method calls --it is built in a generic way
|
||||
|
||||
##
|
||||
# Returns the sin of any given fractional value.
|
||||
#
|
||||
# @param x: fractional input (sfix, sfloat).
|
||||
#
|
||||
# @return returns the sin of x (sfix, sfloat).
|
||||
def sin(x):
|
||||
# reduces the angle to the [0,\pi/2) interval.
|
||||
w, b1, b2 = sTrigSub(x)
|
||||
# returns the sin with sign correction
|
||||
return ssin(w, b1)
|
||||
|
||||
|
||||
##
|
||||
# Returns the sin of any given fractional value.
|
||||
#
|
||||
# @param x: fractional input (sfix, sfloat).
|
||||
#
|
||||
# @return returns the sin of x (sfix, sfloat).
|
||||
def cos(x):
|
||||
# reduces the angle to the [0,\pi/2) interval.
|
||||
w, b1, b2 = sTrigSub(x)
|
||||
|
||||
# returns the sin with sign correction
|
||||
return scos(w, b2)
|
||||
|
||||
|
||||
##
|
||||
# Returns the tan (sfix, sfloat) of any given fractional value.
|
||||
#
|
||||
# @param x: fractional input (sfix, sfloat).
|
||||
#
|
||||
# @return returns the tan of x (sifx, sfloat).
|
||||
def tan(x):
|
||||
# reduces the angle to the [0,\pi/2) interval.
|
||||
w, b1, b2 = sTrigSub(x)
|
||||
# calculates the sin and the cos.
|
||||
local_sin = ssin(w, b1)
|
||||
local_cos = scos(w, b2)
|
||||
# obtains the local tan
|
||||
local_tan = local_sin/local_cos
|
||||
return local_tan
|
||||
|
||||
|
||||
##
|
||||
# Returns the result of 2^a for any unbounded number
|
||||
# @param a: exponent for 2^a
|
||||
#
|
||||
# @return returns the value of 2^a if it is within the range
|
||||
@types.vectorize
|
||||
def exp2_fx(a):
|
||||
if types.program.options.ring:
|
||||
sint = types.sint
|
||||
intbitint = types.intbitint
|
||||
# how many bits to use from integer part
|
||||
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
|
||||
n_bits = a.f + n_int_bits
|
||||
n_shift = int(types.program.options.ring) - a.k
|
||||
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
||||
shifted = ((a.v - sint.bit_compose(r_bits)) << n_shift).reveal()
|
||||
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
|
||||
lower_overflow = sint()
|
||||
comparison.CarryOut(lower_overflow, masked_bits[a.f-1::-1],
|
||||
r_bits[a.f-1::-1])
|
||||
lower_r = sint.bit_compose(r_bits[:a.f])
|
||||
lower_masked = sint.bit_compose(masked_bits[:a.f])
|
||||
lower = lower_r + lower_masked - (lower_overflow << (a.f))
|
||||
c = types.sfix._new(lower, k=a.k, f=a.f)
|
||||
higher_bits = intbitint.bit_adder(masked_bits[a.f:n_bits],
|
||||
r_bits[a.f:n_bits],
|
||||
carry_in=lower_overflow,
|
||||
get_carry=True)
|
||||
d = types.sfix.from_sint(floatingpoint.Pow2_from_bits(higher_bits[:-1]),
|
||||
k=a.k, f=a.f)
|
||||
e = p_eval(p_1045, c)
|
||||
g = d * e
|
||||
small_result = types.sfix._new(g.v.round(a.k + 1, a.f, signed=False,
|
||||
nearest=types.sfix.round_nearest),
|
||||
k=a.k, f=a.f)
|
||||
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
|
||||
r_bits[n_bits:-1],
|
||||
higher_bits[-1])
|
||||
# should be for free
|
||||
highest_bits = intbitint.ripple_carry_adder(
|
||||
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
|
||||
carry_in=higher_bits[-1])
|
||||
bits_to_check = [x.bit_xor(y)
|
||||
for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
|
||||
t = floatingpoint.KMul(bits_to_check)
|
||||
# sign
|
||||
s = masked_bits[-1].bit_xor(r_bits[-1]).bit_xor(carry)
|
||||
return s.if_else(t.if_else(small_result, 0), g)
|
||||
else:
|
||||
# obtain absolute value of a
|
||||
s = a < 0
|
||||
a = (s * (-2) + 1) * a
|
||||
# isolates fractional part of number
|
||||
b = trunc(a)
|
||||
c = a - load_sint(b, type(a))
|
||||
# squares integer part of a
|
||||
d = load_sint(b.pow2(types.sfix.k - types.sfix.f), type(a))
|
||||
# evaluates fractional part of a in p_1045
|
||||
e = p_eval(p_1045, c)
|
||||
g = d * e
|
||||
return (1 - s) * g + s * ((types.sfix(1)) / g)
|
||||
|
||||
|
||||
##
|
||||
# Returns the result of log_2(x) for any unbounded number. This is
|
||||
# achieved by changing x into f*2^n where f is bounded by[0.5, 1].
|
||||
# Then the polynomials are used to calculate the log_2 of f,
|
||||
# which is then just added to n.
|
||||
#
|
||||
# @param x: input for log_2 (sfix, sint).
|
||||
#
|
||||
# @return returns (sfix) the value of log2(X)
|
||||
@types.vectorize
|
||||
def log2_fx(x):
|
||||
if type(x) is types.sfix:
|
||||
# transforms sfix to f*2^n, where f is [o.5,1] bounded
|
||||
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
|
||||
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
|
||||
p -= x.f
|
||||
vlen = x.f
|
||||
else:
|
||||
d = types.sfloat(x)
|
||||
v, p, vlen = d.v, d.p, d.vlen
|
||||
# isolates mantisa of d, now the n can be also substituted by the
|
||||
# secret shared p from d in the expresion above.
|
||||
v = load_sint(v, type(x))
|
||||
w = (1.0 / (2 ** (vlen)))
|
||||
v = v * w
|
||||
# polynomials for the log_2 evaluation of f are calculated
|
||||
P = p_eval(p_2524, v)
|
||||
Q = p_eval(q_2524, v)
|
||||
# the log is returned by adding the result of the division plus p.
|
||||
a = P / Q + load_sint(vlen + p, type(x))
|
||||
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
|
||||
|
||||
|
||||
##
|
||||
# Returns the value of the expression x^y where both inputs
|
||||
# are secret shared. It uses log2_fx together with
|
||||
# exp2_fx to calcualte the expresion 2^{y*log2(x)}.
|
||||
#
|
||||
# @param x: (sfix) secret shared base.
|
||||
#
|
||||
# @param y: (sfix, clear types) secret shared exponent.
|
||||
#
|
||||
# @return returns the value of x^y
|
||||
def pow_fx(x, y):
|
||||
log2_x =0
|
||||
# obtains log2(x)
|
||||
if (type(x) == int or type(x) == float):
|
||||
log2_x = math.log(x,2)
|
||||
else:
|
||||
log2_x = log2_fx(x)
|
||||
# obtains y * log2(x)
|
||||
exp = y * log2_x
|
||||
# returns 2^(y*log2(x))
|
||||
return exp2_fx(exp)
|
||||
|
||||
|
||||
##
|
||||
# Returns the value of the expression log_b(x) where x is
|
||||
# secret shared. It uses log2_fx to calculate the expression
|
||||
# logb(2)*log2(x).
|
||||
#
|
||||
# @param x:(sfix, sint) secret shared coefficient for log.
|
||||
#
|
||||
# @param b:(int) base for log operation.
|
||||
#
|
||||
# @return returns (sfix) the value of logb(x).
|
||||
def log_fx(x, b):
|
||||
# calculates logb(2)
|
||||
logb_2 = math.log(2, b)
|
||||
# returns logb(2) * log2(x)
|
||||
return logb_2 * log2_fx(x)
|
||||
|
||||
|
||||
##
|
||||
# Returns the absolute value of a fix point number.
|
||||
# The method is also applicable to sfloat,
|
||||
# however, more efficient mechanisms can be devised.
|
||||
#
|
||||
# @param x: (sfix)
|
||||
#
|
||||
# @return (sfix) unsigned
|
||||
def abs_fx(x):
|
||||
s = x < 0
|
||||
return (1 - 2 * s) * x
|
||||
|
||||
|
||||
##
|
||||
# Floors the input and stores the value into a sflix register
|
||||
# @param x: coefficient to be floored.
|
||||
#
|
||||
# @return floored sint value of x
|
||||
def floor_fx(x):
|
||||
return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x))
|
||||
|
||||
|
||||
### sqrt methods
|
||||
|
||||
|
||||
##
|
||||
# obtains the most significative bit (MSB)
|
||||
# of a given input. The size of the vector
|
||||
# is tuned to the needs of sqrt.
|
||||
# @param b: number from which you obtain the
|
||||
# most significative bit.
|
||||
# @param k: number of bits for which
|
||||
# an output of size (k+1) if even
|
||||
# is going to be produced.
|
||||
# @return z: index array for MSB of size
|
||||
# k or K+1 if even.
|
||||
def MSB(b, k):
|
||||
# calculation of z
|
||||
# x in order 0 - k
|
||||
if (k > types.program.bit_length):
|
||||
raise OverflowError("The supported bit \
|
||||
lenght of the application is smaller than k")
|
||||
|
||||
x_order = b.bit_decompose(k)
|
||||
x = [0] * k
|
||||
# x i now inverted
|
||||
for i in range(k - 1, -1, -1):
|
||||
x[k - 1 - i] = x_order[i]
|
||||
# y is inverted for PReOR and then restored
|
||||
y_order = floatingpoint.PreOR(x)
|
||||
|
||||
# y in order (restored in orginal order
|
||||
y = [0] * k
|
||||
for i in range(k - 1, -1, -1):
|
||||
y[k - 1 - i] = y_order[i]
|
||||
|
||||
# obtain z
|
||||
z = [0] * (k + 1 - k % 2)
|
||||
for i in range(k - 1):
|
||||
z[i] = y[i] - y[i + 1]
|
||||
z[k - 1] = y[k - 1]
|
||||
|
||||
return z
|
||||
|
||||
|
||||
##
|
||||
# Similar to norm_SQ, saves rounds by not
|
||||
# calculating v and c.
|
||||
#
|
||||
# @param b: sint input to be normalized.
|
||||
# @param k: bitsize of the input, by definition
|
||||
# its value is either sfix.k or program.bit_lengthh
|
||||
# @return m_odd: the parity of most signficative bit index m
|
||||
# @return m: index of most significative bit
|
||||
# @return w: 2^m/2 or 2^ (m-1) /2
|
||||
def norm_simplified_SQ(b, k):
|
||||
z = MSB(b, k)
|
||||
# construct m
|
||||
#m = types.sint(0)
|
||||
m_odd = 0
|
||||
for i in range(k):
|
||||
#m = m + (i + 1) * z[i]
|
||||
# determine the parity of the input
|
||||
if (i % 2 == 0):
|
||||
m_odd = m_odd + z[i]
|
||||
|
||||
# construct w,
|
||||
k_over_2 = k / 2 + 1
|
||||
w_array = [0] * (k_over_2)
|
||||
w_array[0] = z[0]
|
||||
for i in range(1, k_over_2):
|
||||
w_array[i] = z[2 * i - 1] + z[2 * i]
|
||||
|
||||
# w aggregation
|
||||
w = types.sint(0)
|
||||
for i in range(k_over_2):
|
||||
w += (2 ** i) * w_array[i]
|
||||
|
||||
# return computed values
|
||||
#return m_odd, m, w
|
||||
return m_odd, None, w
|
||||
|
||||
|
||||
##
|
||||
# Obtains the sqrt using our custom mechanism
|
||||
# for any sfix input value.
|
||||
# no restrictions on the size of f.
|
||||
#
|
||||
# @param x: secret shared input from which the sqrt
|
||||
# is calucalted,
|
||||
#
|
||||
# @return g: approximated sqrt
|
||||
def sqrt_simplified_fx(x):
|
||||
# fix theta (number of iterations)
|
||||
theta = max(int(math.ceil(math.log(types.sfix.k))), 6)
|
||||
|
||||
# process to use 2^(m/2) approximation
|
||||
m_odd, m, w = norm_simplified_SQ(x.v, x.k)
|
||||
# process to set up the precision and allocate correct 2**f
|
||||
if x.f % 2 == 1:
|
||||
m_odd = (1 - 2 * m_odd) + m_odd
|
||||
w = (w * 2 - w) * (1-m_odd) + w
|
||||
# map number to use sfix format and instantiate the number
|
||||
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2))
|
||||
# obtains correct 2 ** (m/2)
|
||||
w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w
|
||||
# produce x/ 2^(m/2)
|
||||
y_0 = types.cfix(1.0) / w
|
||||
|
||||
# from this point on it sufices to work sfix-wise
|
||||
g_0 = (y_0 * x)
|
||||
h_0 = y_0 * types.cfix(0.5)
|
||||
gh_0 = g_0 * h_0
|
||||
|
||||
## initialization
|
||||
g = g_0
|
||||
h = h_0
|
||||
gh = gh_0
|
||||
|
||||
for i in range(1, theta - 2):
|
||||
r = (3 / 2.0) - gh
|
||||
g = g * r
|
||||
h = h * r
|
||||
gh = g * h
|
||||
|
||||
# newton
|
||||
r = (3 / 2.0) - gh
|
||||
h = h * r
|
||||
H = 4 * (h * h)
|
||||
H = H * x
|
||||
H = (3) - H
|
||||
H = h * H
|
||||
g = H * x
|
||||
g = g
|
||||
|
||||
return g
|
||||
|
||||
|
||||
##
|
||||
# Calculates the normSQ of a number
|
||||
# @param x: number from which the norm is going to be extracted
|
||||
# @param k: bitsize of x
|
||||
#
|
||||
# @return c: where c = x*v where c is bounded by 2^{k-1} and 2^k
|
||||
# @return v: where v = 2^k-m
|
||||
# @return m: where m = MSB
|
||||
# @return w: where w = 2^{m/2} if m is oeven and 2^{m-1 / 2} otherwise
|
||||
def norm_SQ(b, k):
|
||||
# calculation of z
|
||||
# x in order 0 - k
|
||||
z = MSB(b,k)
|
||||
# now reverse bits of z[i] to generate v
|
||||
v = types.sint(0)
|
||||
for i in range(k):
|
||||
v += (2**(k - i - 1)) * z[i]
|
||||
c = b * v
|
||||
|
||||
# construct m
|
||||
m = types.sint(0)
|
||||
for i in range(k):
|
||||
m = m + (i+1) * z[i]
|
||||
|
||||
# construct w, changes from what is on the paper
|
||||
# and the documentation
|
||||
k_over_2= k/2+1#int(math.ceil((k/2.0)))+1
|
||||
w_array = [0]*(k_over_2 )
|
||||
w_array[0] = z[0]
|
||||
for i in range(1, k_over_2):
|
||||
w_array[i] = z[2 * i - 1] + z[2 * i]
|
||||
|
||||
w = types.sint(0)
|
||||
for i in range(k_over_2):
|
||||
w += (2 ** i) * w_array[i]
|
||||
|
||||
# return computed values
|
||||
return c, v, m, w
|
||||
|
||||
|
||||
##
|
||||
# Given f and k, returns a linear approximation of 1/x^{1/2}
|
||||
# escalated by s^f.
|
||||
# Method only works for sfix inputs. It uses the normSQ.
|
||||
# the method is an implementation of [Liedel2012]
|
||||
# @param x: number from which the approximation is caluclated
|
||||
# @param k: bitsize of x
|
||||
# @param f: precision of the input f
|
||||
#
|
||||
# @return c: Some approximation of (1/x^{1/2} * 2^f) *K
|
||||
# where K is close to 1
|
||||
def lin_app_SQ(b, k, f):
|
||||
|
||||
alpha = types.cfix((-0.8099868542) * 2 ** (k))
|
||||
beta = types.cfix(1.787727479 * 2 ** (2 * k))
|
||||
|
||||
# obtain normSQ parameters
|
||||
c, v, m, W = norm_SQ(types.sint(b), k)
|
||||
|
||||
# c is now escalated
|
||||
w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k
|
||||
|
||||
|
||||
# m even or odd determination
|
||||
m_bit = types.sint()
|
||||
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False)
|
||||
m = load_sint(m_bit, types.sfix)
|
||||
|
||||
# w times v this way both terms have 2^3k and can be symplified
|
||||
w = w * v
|
||||
factor = 1.0 / (2 ** (3.0 * k - 2 * f))
|
||||
w = w * factor # w escalated to 3k -2 * f
|
||||
# normalization factor W* 1/2 ^{f/2}
|
||||
w = w * W * types.cfix(1.0 / (2 ** (f / 2.0)))
|
||||
# now we need to elminate an additional root of 2 in case m was odd
|
||||
sqr_2 = types.cfix((2 ** (1 / 2.0)))
|
||||
w = (1 - m) * w + sqr_2 * w * m
|
||||
|
||||
return w
|
||||
|
||||
|
||||
##
|
||||
# Given bitsize k and precision f, it calulates the square root of x.
|
||||
# @param x: number from which the norm is going to be extracted
|
||||
# @param k: bitsize of x.
|
||||
# @param f: precision of x.
|
||||
#
|
||||
# @return g: square root of de-scaled input x
|
||||
def sqrt_fx(x_l, k, f):
|
||||
factor = 1.0 / (2.0 ** f)
|
||||
|
||||
x = load_sint(x_l, types.sfix) * factor
|
||||
|
||||
theta = int(math.ceil(math.log(k/5.4)))
|
||||
|
||||
y_0 = lin_app_SQ(x_l,k,f) #cfix(1.0/ (cx ** (1/2.0))) # lin_app_SQ(x_l,5,2)
|
||||
|
||||
y_0 = y_0 * factor #*((1.0/(2.0 ** f)))
|
||||
g_0 = y_0 * x
|
||||
|
||||
|
||||
#g = mpc_math.load_sint(mpc_math.trunc(g_0),types.sfix)
|
||||
h_0 = y_0 *(0.5)
|
||||
gh_0 = g_0 * h_0
|
||||
|
||||
##initialization
|
||||
g= g_0
|
||||
h= h_0
|
||||
gh =gh_0
|
||||
|
||||
for i in range(1,theta-2): #to implement \in [1,\theta-2]
|
||||
r = (3/2.0) - gh
|
||||
g = g * r
|
||||
h = h * r
|
||||
gh = g * h
|
||||
|
||||
# newton
|
||||
r = (3/2.0) - gh
|
||||
h = h * r
|
||||
H = 4 * (h * h)
|
||||
H = H * x
|
||||
H = (3) - H
|
||||
H = h * H
|
||||
g = H * x
|
||||
g = g #* (0.5)
|
||||
|
||||
return g
|
||||
|
||||
##
|
||||
# Returns the sqrt (sfix) of any given fractional
|
||||
# value as long as it can be rounded to a integral value
|
||||
# to 2^f precision.
|
||||
#
|
||||
# Note that sqrt only works as long as this inequality is respected:
|
||||
# 3*k - 2 *f < x.f (x.f by default is 20)
|
||||
# @param x: fractional input (sfix).
|
||||
#
|
||||
# @return returns the aTan of x (sifx).
|
||||
@types.vectorize
|
||||
def sqrt(x, k = types.sfix.k, f = types.sfix.f):
|
||||
|
||||
if (3 *k -2 * f >= types.sfix.f):
|
||||
return sqrt_simplified_fx(x)
|
||||
# raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ")
|
||||
else:
|
||||
param = trunc(x *(2 ** (f)))
|
||||
return sqrt_fx(param ,k ,f)
|
||||
|
||||
|
||||
##
|
||||
# Returns the aTan (sfix) of any given fractional value.
|
||||
#
|
||||
# @param x: fractional input (sfix).
|
||||
#
|
||||
# @return returns the aTan of x (sifx).
|
||||
def atan(x):
|
||||
# obtain absolute value of x
|
||||
s = x < 0
|
||||
x_abs = (s * (-2) + 1) * x
|
||||
# angle isolation
|
||||
b = x_abs > 1
|
||||
v = (types.cfix(1.0) / x_abs)
|
||||
v = (1 - b) * (x_abs - v) + v
|
||||
v_2 =v*v
|
||||
|
||||
# range of polynomial coefficients
|
||||
assert x.k - x.f >= 18
|
||||
P = p_eval(p_5102, v_2)
|
||||
Q = p_eval(q_5102, v_2)
|
||||
|
||||
# padding
|
||||
y = v * (P / Q)
|
||||
y_pi_over_two = pi_over_2 - y
|
||||
|
||||
# sign correction
|
||||
y = (1 - b) * (y - y_pi_over_two) + y_pi_over_two
|
||||
y = (1 - s) * (y - (-y)) + (-y)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
##
|
||||
# Returns the aSin (sfix) of any given fractional value.
|
||||
#
|
||||
# @param x: fractional input (sfix). valid interval is -1.0 <= x <= 1
|
||||
#
|
||||
# @return returns the aSin of x (sfix).
|
||||
def asin(x):
|
||||
# Square x
|
||||
x_2 = x*x
|
||||
# trignometric identities
|
||||
sqrt_l = sqrt(1- (x_2))
|
||||
x_sqrt_l =x / sqrt_l
|
||||
return atan(x_sqrt_l)
|
||||
|
||||
|
||||
##
|
||||
# Returns the aCos (sfix) of any given fractional value.
|
||||
#
|
||||
# @param x: fractional input (sfix). -1.0 < x < 1
|
||||
#
|
||||
# @return returns the aCos of x (sifx).
|
||||
def acos(x):
|
||||
y = asin(x)
|
||||
return pi_over_2 - y
|
||||
@@ -44,8 +44,10 @@ class Program(object):
|
||||
if (param != -1) + sum(x != 0 for x in(options.ring, options.field,
|
||||
options.binary)) > 1:
|
||||
raise CompilerError('can only use one out of -p, -B, -R, -F')
|
||||
self.bit_length = int(options.ring) or int(options.binary) \
|
||||
or int(options.field)
|
||||
if options.ring:
|
||||
self.bit_length = int(options.ring) - 1
|
||||
else:
|
||||
self.bit_length = int(options.binary) or int(options.field)
|
||||
if not self.bit_length:
|
||||
self.bit_length = BIT_LENGTHS[param]
|
||||
print 'Default bit length:', self.bit_length
|
||||
@@ -71,6 +73,7 @@ class Program(object):
|
||||
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
|
||||
self.types = {}
|
||||
self.budget = int(self.options.budget)
|
||||
self.verbose = False
|
||||
self.to_merge = [Compiler.instructions.asm_open_class, \
|
||||
Compiler.instructions.gasm_open_class, \
|
||||
Compiler.instructions.muls_class, \
|
||||
@@ -82,10 +85,13 @@ class Program(object):
|
||||
Compiler.instructions.asm_input_class, \
|
||||
Compiler.instructions.gasm_input_class,
|
||||
Compiler.instructions.inputfix_class,
|
||||
Compiler.instructions.inputfloat_class]
|
||||
Compiler.instructions.inputfloat_class,
|
||||
Compiler.instructions.inputmixed_class,
|
||||
Compiler.instructions.trunc_pr_class]
|
||||
import Compiler.GC.instructions as gc
|
||||
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
|
||||
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]
|
||||
self.use_trunc_pr = False
|
||||
Program.prog = self
|
||||
|
||||
self.reset_values()
|
||||
@@ -452,6 +458,8 @@ class Tape:
|
||||
else:
|
||||
self.alloc_pool = defaultdict(set)
|
||||
self.purged = False
|
||||
self.n_rounds = 0
|
||||
self.n_to_merge = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self.instructions)
|
||||
@@ -506,6 +514,8 @@ class Tape:
|
||||
instructions = self.instructions
|
||||
for inst in instructions:
|
||||
inst.add_usage(req_node)
|
||||
req_node.num['all', 'round'] = self.n_rounds
|
||||
req_node.num['all', 'inv'] = self.n_to_merge
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@@ -530,7 +540,7 @@ class Tape:
|
||||
self.basicblocks.append(sub)
|
||||
self.active_basicblock = sub
|
||||
self.req_node.add_block(sub)
|
||||
print 'Compiling basic block', sub.name
|
||||
#print 'Compiling basic block', sub.name
|
||||
|
||||
def init_registers(self):
|
||||
self.reset_registers()
|
||||
@@ -601,6 +611,8 @@ class Tape:
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Merging instructions...'
|
||||
numrounds = merger.longest_paths_merge()
|
||||
block.n_rounds = numrounds
|
||||
block.n_to_merge = len(merger.open_nodes)
|
||||
if numrounds > 0:
|
||||
print 'Program requires %d rounds of communication' % numrounds
|
||||
if merger.counter:
|
||||
@@ -633,10 +645,11 @@ class Tape:
|
||||
# allocate registers
|
||||
reg_counts = self.count_regs()
|
||||
if not options.noreallocate:
|
||||
print 'Tape register usage:', dict(reg_counts)
|
||||
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
|
||||
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
|
||||
print 'Re-allocating...'
|
||||
if self.program.verbose:
|
||||
print 'Tape register usage:', dict(reg_counts)
|
||||
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
|
||||
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
|
||||
print 'Re-allocating...'
|
||||
allocator = al.StraightlineAllocator(REG_MAX)
|
||||
def alloc_loop(block):
|
||||
for reg in sorted(block.used_from_scope,
|
||||
@@ -661,7 +674,7 @@ class Tape:
|
||||
self.req_num = self.req_tree.aggregate()
|
||||
print 'Tape requires', self.req_num
|
||||
for req,num in sorted(self.req_num.items()):
|
||||
if num == float('inf'):
|
||||
if num == float('inf') or num >= 2 ** 32:
|
||||
num = -1
|
||||
if req[1] in data_types:
|
||||
self.basicblocks[-1].instructions.append(
|
||||
@@ -692,8 +705,9 @@ class Tape:
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.reqbl(bl,
|
||||
add_to_prog=False))
|
||||
print 'Tape requires prime bit length', self.req_bit_length['p']
|
||||
print 'Tape requires galois bit length', self.req_bit_length['2']
|
||||
if self.program.verbose:
|
||||
print 'Tape requires prime bit length', self.req_bit_length['p']
|
||||
print 'Tape requires galois bit length', self.req_bit_length['2']
|
||||
|
||||
@unpurged
|
||||
def _get_instructions(self):
|
||||
@@ -783,6 +797,8 @@ class Tape:
|
||||
return res
|
||||
__rmul__ = __mul__
|
||||
def set_all(self, value):
|
||||
if value == float('inf') and self['all', 'inv'] > 0:
|
||||
print 'Going to unknown from %s' % self
|
||||
res = Tape.ReqNum()
|
||||
for i in self:
|
||||
res[i] = value
|
||||
@@ -832,6 +848,15 @@ class Tape:
|
||||
self.parent = parent
|
||||
def aggregate(self, name):
|
||||
res = self.aggregator([node.aggregate() for node in self.nodes])
|
||||
try:
|
||||
n_reps = self.aggregator([1])
|
||||
n_rounds = res['all', 'round']
|
||||
n_invs = res['all', 'inv']
|
||||
if (n_invs / n_rounds) * 1000 < n_reps:
|
||||
print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \
|
||||
'(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)
|
||||
except:
|
||||
pass
|
||||
return res
|
||||
def add_node(self, tape, name):
|
||||
new_node = Tape.ReqNode(name)
|
||||
|
||||
@@ -167,6 +167,12 @@ class _number(object):
|
||||
def pow2(self, bit_length=None, security=None):
|
||||
return 2**self
|
||||
|
||||
def min(self, other):
|
||||
return (self < other).if_else(self, other)
|
||||
|
||||
def max(self, other):
|
||||
return (self < other).if_else(other, self)
|
||||
|
||||
class _int(object):
|
||||
def if_else(self, a, b):
|
||||
if hasattr(a, 'for_mux'):
|
||||
@@ -705,6 +711,10 @@ class regint(_register, _int):
|
||||
popint(res)
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def push(cls, value):
|
||||
pushint(cls.conv(value))
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_random(cls, bit_length):
|
||||
""" Public insecure randomness """
|
||||
@@ -781,10 +791,10 @@ class regint(_register, _int):
|
||||
@vectorize
|
||||
@read_mem_value
|
||||
def int_op(self, other, inst, reverse=False):
|
||||
if isinstance(other, _secret):
|
||||
try:
|
||||
other = self.conv(other)
|
||||
except:
|
||||
return NotImplemented
|
||||
elif not isinstance(other, type(self)):
|
||||
other = type(self)(other)
|
||||
res = regint()
|
||||
if reverse:
|
||||
inst(res, other, self)
|
||||
@@ -898,6 +908,8 @@ class regint(_register, _int):
|
||||
def print_reg_plain(self):
|
||||
print_int(self)
|
||||
|
||||
def print_if(self, string):
|
||||
cint(self).print_if(string)
|
||||
|
||||
class _secret(_register):
|
||||
__slots__ = []
|
||||
@@ -1121,6 +1133,13 @@ class sint(_secret, _int):
|
||||
comparison.PRandInt(res, bits)
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_input_from(cls, player):
|
||||
""" Secret input """
|
||||
res = cls()
|
||||
inputmixed('int', res, player)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_raw_input_from(cls, player):
|
||||
res = cls()
|
||||
@@ -1196,8 +1215,8 @@ class sint(_secret, _int):
|
||||
@vectorize
|
||||
def __lt__(self, other, bit_length=None, security=None):
|
||||
res = sint()
|
||||
comparison.LTZ(res, self - other, bit_length or program.bit_length +
|
||||
(not (int(program.options.ring) == program.bit_length)),
|
||||
comparison.LTZ(res, self - other,
|
||||
(bit_length or program.bit_length) + 1,
|
||||
security or program.security)
|
||||
return res
|
||||
|
||||
@@ -1205,8 +1224,8 @@ class sint(_secret, _int):
|
||||
@vectorize
|
||||
def __gt__(self, other, bit_length=None, security=None):
|
||||
res = sint()
|
||||
comparison.LTZ(res, other - self, bit_length or program.bit_length +
|
||||
(not (int(program.options.ring) == program.bit_length)),
|
||||
comparison.LTZ(res, other - self,
|
||||
(bit_length or program.bit_length) + 1,
|
||||
security or program.security)
|
||||
return res
|
||||
|
||||
@@ -1304,13 +1323,14 @@ class sint(_secret, _int):
|
||||
return floatingpoint.BitDec(self, bit_length, bit_length, security)
|
||||
|
||||
def TruncMul(self, other, k, m, kappa=None, nearest=False):
|
||||
return (self * other).round(k, m, kappa, nearest)
|
||||
return (self * other).round(k, m, kappa, nearest, signed=True)
|
||||
|
||||
def TruncPr(self, k, m, kappa=None):
|
||||
return floatingpoint.TruncPr(self, k, m, kappa)
|
||||
def TruncPr(self, k, m, kappa=None, signed=True):
|
||||
return floatingpoint.TruncPr(self, k, m, kappa, signed=signed)
|
||||
|
||||
@vectorize
|
||||
def round(self, k, m, kappa=None, nearest=False, signed=False):
|
||||
kappa = kappa or program.security
|
||||
secret = isinstance(m, sint)
|
||||
if nearest:
|
||||
if secret:
|
||||
@@ -1320,7 +1340,7 @@ class sint(_secret, _int):
|
||||
else:
|
||||
if secret:
|
||||
return floatingpoint.Trunc(self, k, m, kappa)
|
||||
return self.TruncPr(k, m, kappa)
|
||||
return self.TruncPr(k, m, kappa, signed=signed)
|
||||
|
||||
def Norm(self, k, f, kappa=None, simplex_flag=False):
|
||||
return library.Norm(self, k, f, kappa, simplex_flag)
|
||||
@@ -1461,23 +1481,25 @@ class _bitint(object):
|
||||
linear_rounds = False
|
||||
|
||||
@classmethod
|
||||
def bit_adder(cls, a, b):
|
||||
def bit_adder(cls, a, b, carry_in=0, get_carry=False):
|
||||
a, b = list(a), list(b)
|
||||
a += [0] * (len(b) - len(a))
|
||||
b += [0] * (len(a) - len(b))
|
||||
return cls.bit_adder_selection(a, b)
|
||||
return cls.bit_adder_selection(a, b, carry_in=carry_in,
|
||||
get_carry=get_carry)
|
||||
|
||||
@classmethod
|
||||
def bit_adder_selection(cls, a, b):
|
||||
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
|
||||
if cls.log_rounds:
|
||||
return cls.carry_lookahead_adder(a, b)
|
||||
return cls.carry_lookahead_adder(a, b, carry_in=carry_in)
|
||||
elif cls.linear_rounds:
|
||||
return cls.ripple_carry_adder(a, b)
|
||||
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
|
||||
else:
|
||||
return cls.carry_select_adder(a, b)
|
||||
return cls.carry_select_adder(a, b, carry_in=carry_in)
|
||||
|
||||
@classmethod
|
||||
def carry_lookahead_adder(cls, a, b, fewer_inv=False):
|
||||
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
|
||||
get_carry=False):
|
||||
lower = []
|
||||
for (ai,bi) in zip(a,b):
|
||||
if ai is 0 or bi is 0:
|
||||
@@ -1493,10 +1515,13 @@ class _bitint(object):
|
||||
else:
|
||||
pre_op = floatingpoint.PreOpL
|
||||
if d:
|
||||
carries = (0,) + zip(*pre_op(carry, d))[1]
|
||||
carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1]
|
||||
else:
|
||||
carries = []
|
||||
return lower + cls.sum_from_carries(a, b, carries)
|
||||
res = lower + cls.sum_from_carries(a, b, carries)
|
||||
if get_carry:
|
||||
res += [carries[-1]]
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def sum_from_carries(a, b, carries):
|
||||
@@ -1504,7 +1529,7 @@ class _bitint(object):
|
||||
for (ai, bi, carry) in zip(a, b, carries)]
|
||||
|
||||
@classmethod
|
||||
def carry_select_adder(cls, a, b, get_carry=False):
|
||||
def carry_select_adder(cls, a, b, get_carry=False, carry_in=0):
|
||||
a += [0] * (len(b) - len(a))
|
||||
b += [0] * (len(a) - len(b))
|
||||
n = len(a)
|
||||
@@ -1524,7 +1549,7 @@ class _bitint(object):
|
||||
raise Exception('blocks not summing up: %s != %s' % \
|
||||
(sum(blocks), n))
|
||||
res = []
|
||||
carry = 0
|
||||
carry = carry_in
|
||||
cin_one = util.long_one(a + b)
|
||||
for m in blocks:
|
||||
aa = a[:m]
|
||||
@@ -1540,7 +1565,8 @@ class _bitint(object):
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def ripple_carry_adder(cls, a, b, carry=0):
|
||||
def ripple_carry_adder(cls, a, b, carry_in=0):
|
||||
carry = carry_in
|
||||
res = []
|
||||
for aa, bb in zip(a, b):
|
||||
cc, carry = cls.full_adder(aa, bb, carry)
|
||||
@@ -1760,14 +1786,15 @@ class intbitint(_bitint, sint):
|
||||
for i in range(len(a))]
|
||||
|
||||
@classmethod
|
||||
def bit_adder_selection(cls, a, b):
|
||||
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
|
||||
if cls.linear_rounds:
|
||||
return cls.ripple_carry_adder(a, b)
|
||||
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
|
||||
# experimental cut-off with dead code elimination
|
||||
elif len(a) < 122 or cls.log_rounds:
|
||||
return cls.carry_lookahead_adder(a, b)
|
||||
return cls.carry_lookahead_adder(a, b, carry_in=carry_in,
|
||||
get_carry=get_carry)
|
||||
else:
|
||||
return cls.carry_select_adder(a, b)
|
||||
return cls.carry_select_adder(a, b, carry_in=carry_in)
|
||||
|
||||
class sgf2nint(_bitint, sgf2n):
|
||||
bin_type = sgf2n
|
||||
@@ -1904,10 +1931,10 @@ class sgf2nfloat(sgf2n):
|
||||
|
||||
sgf2nfloat.set_precision(24, 8)
|
||||
|
||||
def parse_type(other):
|
||||
def parse_type(other, k=None, f=None):
|
||||
# converts type to cfix/sfix depending on the case
|
||||
if isinstance(other, cfix.scalars):
|
||||
return cfix(other)
|
||||
return cfix(other, k=k, f=f)
|
||||
elif isinstance(other, cint):
|
||||
tmp = cfix()
|
||||
tmp.load_int(other)
|
||||
@@ -1975,9 +2002,11 @@ class cfix(_number, _structure):
|
||||
return 1
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, v=None, size=None):
|
||||
f = self.f
|
||||
k = self.k
|
||||
def __init__(self, v=None, k=None, f=None, size=None):
|
||||
f = f or self.f
|
||||
k = k or self.k
|
||||
self.f = f
|
||||
self.k = k
|
||||
self.size = get_global_vector_size()
|
||||
if isinstance(v, cint):
|
||||
self.v = cint(v,size=self.size)
|
||||
@@ -2025,22 +2054,20 @@ class cfix(_number, _structure):
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return cfix(self.v + other.v)
|
||||
elif isinstance(other, sfix):
|
||||
return sfix(self.v + other.v)
|
||||
else:
|
||||
raise CompilerError('Invalid type %s for cfix.__add__' % type(other))
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def mul(self, other):
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
assert self.f == other.f
|
||||
sgn = cint(1 - 2 * (self.v * other.v < 0))
|
||||
absolute = self.v * other.v * sgn
|
||||
val = sgn * (absolute >> self.f)
|
||||
return cfix(val)
|
||||
elif isinstance(other, sfix):
|
||||
res = sfix((self.v * other.v) >> self.f)
|
||||
return res
|
||||
return NotImplemented
|
||||
else:
|
||||
raise CompilerError('Invalid type %s for cfix.__mul__' % type(other))
|
||||
|
||||
@@ -2130,7 +2157,8 @@ class cfix(_number, _structure):
|
||||
if isinstance(other, cfix):
|
||||
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
|
||||
elif isinstance(other, sfix):
|
||||
return sfix(library.FPDiv(self.v, other.v, self.k, self.f, other.kappa))
|
||||
return sfix(library.FPDiv(self.v, other.v, self.k, self.f,
|
||||
other.kappa, nearest=sfix.round_nearest))
|
||||
else:
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
|
||||
@@ -2169,6 +2197,7 @@ class _single(_number, _structure):
|
||||
return cls._new(cls.int_type.load_mem(address))
|
||||
|
||||
@classmethod
|
||||
@read_mem_value
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cls):
|
||||
return other
|
||||
@@ -2193,9 +2222,13 @@ class _single(_number, _structure):
|
||||
|
||||
@classmethod
|
||||
def dot_product(cls, x, y, res_params=None):
|
||||
return cls.unreduced_dot_product(x, y, res_params).reduce_after_mul()
|
||||
|
||||
@classmethod
|
||||
def unreduced_dot_product(cls, x, y, res_params=None):
|
||||
dp = cls.int_type.dot_product([xx.pre_mul() for xx in x],
|
||||
[yy.pre_mul() for yy in y])
|
||||
return x[0].unreduced(dp, y[0], res_params, len(x)).reduce_after_mul()
|
||||
return x[0].unreduced(dp, y[0], res_params, len(x))
|
||||
|
||||
@classmethod
|
||||
def row_matrix_mul(cls, row, matrix, res_params=None):
|
||||
@@ -2300,20 +2333,25 @@ class _fix(_single):
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
if isinstance(other, _fix):
|
||||
if isinstance(other, (_fix, cfix)):
|
||||
return other
|
||||
else:
|
||||
return cls.conv(other)
|
||||
|
||||
@classmethod
|
||||
def from_sint(cls, other):
|
||||
def from_sint(cls, other, k=None, f=None):
|
||||
res = cls()
|
||||
res.f = f or cls.f
|
||||
res.k = k or cls.k
|
||||
res.load_int(cls.int_type.conv(other))
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def _new(cls, other):
|
||||
return cls(other)
|
||||
def _new(cls, other, k=None, f=None):
|
||||
res = cls(other)
|
||||
res.k = k or cls.k
|
||||
res.f = f or cls.f
|
||||
return res
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, _v=None, size=None):
|
||||
@@ -2331,7 +2369,7 @@ class _fix(_single):
|
||||
self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size)
|
||||
elif isinstance(_v, self.float_type):
|
||||
p = (f + _v.p)
|
||||
b = (p >= 0)
|
||||
b = (p.greater_equal(0, _v.vlen))
|
||||
a = b*(_v.v << (p)) + (1-b)*(_v.v >> (-p))
|
||||
self.v = (1-2*_v.s)*a
|
||||
elif isinstance(_v, type(self)):
|
||||
@@ -2355,31 +2393,45 @@ class _fix(_single):
|
||||
def add(self, other):
|
||||
other = self.coerce(other)
|
||||
if isinstance(other, (_fix, cfix)):
|
||||
return type(self)(self.v + other.v)
|
||||
return self._new(self.v + other.v, k=self.k, f=self.f)
|
||||
elif isinstance(other, cfix.scalars):
|
||||
tmp = cfix(other)
|
||||
tmp = cfix(other, k=self.k, f=self.f)
|
||||
return self + tmp
|
||||
else:
|
||||
raise CompilerError('Invalid type %s for _fix.__add__' % type(other))
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def mul(self, other):
|
||||
if isinstance(other, (sint, cint, regint, int, long)):
|
||||
return self._new(self.v * other, k=self.k, f=self.f)
|
||||
elif isinstance(other, float):
|
||||
if int(other) == other:
|
||||
return self.mul(int(other))
|
||||
v = int(round(other * 2 ** self.f))
|
||||
if v == 0:
|
||||
return 0
|
||||
f = self.f
|
||||
while v % 2 == 0:
|
||||
f -= 1
|
||||
v /= 2
|
||||
k = len(bin(abs(v))) - 1
|
||||
other = cfix(cint(v))
|
||||
other.f = f
|
||||
other.k = k
|
||||
other = self.coerce(other)
|
||||
if isinstance(other, _fix):
|
||||
val = self.v.TruncMul(other.v, self.k * 2, self.f, self.kappa,
|
||||
if isinstance(other, (_fix, cfix)):
|
||||
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
|
||||
self.kappa,
|
||||
self.round_nearest)
|
||||
if self.size >= other.size:
|
||||
return self._new(val)
|
||||
return self._new(val, k=self.k, f=self.f)
|
||||
else:
|
||||
return self.vec._new(val)
|
||||
elif isinstance(other, cfix):
|
||||
res = type(self)((self.v * other.v) >> self.f)
|
||||
return res
|
||||
return self.vec._new(val, k=self.k, f=self.f)
|
||||
elif isinstance(other, cfix.scalars):
|
||||
scalar_fix = cfix(other)
|
||||
return self * scalar_fix
|
||||
else:
|
||||
raise CompilerError('Invalid type %s for _fix.__mul__' % type(other))
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def __neg__(self):
|
||||
@@ -2397,6 +2449,7 @@ class _fix(_single):
|
||||
else:
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
|
||||
@vectorize
|
||||
def __rdiv__(self, other):
|
||||
return self.coerce(other) / self
|
||||
|
||||
@@ -2418,14 +2471,24 @@ class sfix(_fix):
|
||||
@vectorized_classmethod
|
||||
def get_input_from(cls, player):
|
||||
v = cls.int_type()
|
||||
inputfix(v, cls.f, player)
|
||||
inputmixed('fix', v, cls.f, player)
|
||||
return cls._new(v)
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
return parse_type(other)
|
||||
@vectorized_classmethod
|
||||
def get_random(cls, lower, upper):
|
||||
""" Uniform random number around centre of bounds """
|
||||
""" Range can be smaller """
|
||||
log_range = int(math.log(upper - lower, 2))
|
||||
n_bits = log_range + cls.f
|
||||
average = lower + 0.5 * (upper - lower)
|
||||
lower = average - 0.5 * 2 ** log_range
|
||||
return cls._new(cls.int_type.get_random_int(n_bits)) + lower
|
||||
|
||||
def coerce(self, other):
|
||||
return parse_type(other, k=self.k, f=self.f)
|
||||
|
||||
def mul_no_reduce(self, other, res_params=None):
|
||||
assert self.f == other.f
|
||||
return self.unreduced(self.v * other.v)
|
||||
|
||||
def pre_mul(self):
|
||||
@@ -2434,16 +2497,21 @@ class sfix(_fix):
|
||||
def unreduced(self, v, other=None, res_params=None, n_summands=1):
|
||||
return unreduced_sfix(v, self.k * 2, self.f, self.kappa)
|
||||
|
||||
class unreduced_sfix(object):
|
||||
class unreduced_sfix(_single):
|
||||
int_type = sint
|
||||
|
||||
@classmethod
|
||||
def _new(cls, v):
|
||||
return cls(v, 2 * sfix.k, sfix.f, sfix.kappa)
|
||||
|
||||
def __init__(self, v, k, m, kappa):
|
||||
self.v = v
|
||||
self.k = k
|
||||
self.m = m
|
||||
self.kappa = kappa
|
||||
self.size = self.v.size
|
||||
|
||||
def __add__(self, other):
|
||||
if other in (0, 0L):
|
||||
if other is 0 or other is 0L:
|
||||
return self
|
||||
assert self.k == other.k
|
||||
assert self.m == other.m
|
||||
@@ -2455,7 +2523,10 @@ class unreduced_sfix(object):
|
||||
@vectorize
|
||||
def reduce_after_mul(self):
|
||||
return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa,
|
||||
nearest=sfix.round_nearest))
|
||||
nearest=sfix.round_nearest,
|
||||
signed=True))
|
||||
|
||||
sfix.unreduced_type = unreduced_sfix
|
||||
|
||||
# this is for 20 bit decimal precision
|
||||
# with 40 bitlength of entire number
|
||||
@@ -2503,7 +2574,7 @@ class squant(_single):
|
||||
raise CompilerError('%f not quantizable' % value)
|
||||
self.v = self.int_type(q)
|
||||
reset_global_vector_size()
|
||||
elif isinstance(value, type(self)):
|
||||
elif isinstance(value, squant) and value.params == self.params:
|
||||
self.v = value.v
|
||||
else:
|
||||
raise CompilerError('cannot convert %s to squant' % value)
|
||||
@@ -2538,7 +2609,7 @@ class squant(_single):
|
||||
return self.mul_no_reduce(other, res_params).reduce_after_mul()
|
||||
|
||||
def mul_no_reduce(self, other, res_params=None):
|
||||
if isinstance(other, sint):
|
||||
if isinstance(other, (sint, cint, regint)):
|
||||
return self._new(other * (self.v - self.Z) + self.Z,
|
||||
params=self.get_params())
|
||||
other = self.coerce(other)
|
||||
@@ -2572,7 +2643,7 @@ class _unreduced_squant(object):
|
||||
self.res_params = res_params or params[0]
|
||||
|
||||
def __add__(self, other):
|
||||
if other in (0, 0L):
|
||||
if other is 0 or other is 0L:
|
||||
return self
|
||||
assert self.params == other.params
|
||||
assert self.res_params == other.res_params
|
||||
@@ -2650,7 +2721,8 @@ class squant_params(object):
|
||||
int_mult = util.expand(int_mult, size)
|
||||
tmp = unreduced.v * int_mult + shifted_Z
|
||||
shifted = tmp.round(self.max_length, n_shift,
|
||||
squant.kappa, squant.round_nearest)
|
||||
kappa=squant.kappa, nearest=squant.round_nearest,
|
||||
signed=True)
|
||||
if squant.clamp:
|
||||
length = max(self.k, self.max_length - n_shift) + 1
|
||||
top = (1 << self.k) - 1
|
||||
@@ -2715,6 +2787,10 @@ class sfloat(_number, _structure):
|
||||
else:
|
||||
return cls(other)
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
return cls.conv(other)
|
||||
|
||||
@staticmethod
|
||||
def convert_float(v, vlen, plen):
|
||||
if v < 0:
|
||||
@@ -2747,7 +2823,7 @@ class sfloat(_number, _structure):
|
||||
p = sint()
|
||||
z = sint()
|
||||
s = sint()
|
||||
inputfloat(v, p, z, s, cls.vlen, player)
|
||||
inputmixed('float', v, p, z, s, cls.vlen, player)
|
||||
return cls(v, p, z, s)
|
||||
|
||||
@vectorize_init
|
||||
@@ -2935,7 +3011,7 @@ class sfloat(_number, _structure):
|
||||
return self + -other
|
||||
|
||||
def __rsub__(self, other):
|
||||
raise NotImplementedError()
|
||||
return -self + other
|
||||
|
||||
def __div__(self, other):
|
||||
other = self.conv(other)
|
||||
@@ -3144,17 +3220,18 @@ class Array(object):
|
||||
for i in range(self.length):
|
||||
yield self[i]
|
||||
|
||||
def assign(self, other):
|
||||
if isinstance(other, Array):
|
||||
def loop(i):
|
||||
self[i] = other[i]
|
||||
library.range_loop(loop, len(self))
|
||||
elif isinstance(other, Tape.Register):
|
||||
if len(other) == self.length:
|
||||
self[0] = other
|
||||
else:
|
||||
raise CompilerError('Length mismatch between array and vector')
|
||||
else:
|
||||
def same_shape(self):
|
||||
return Array(self.length, self.value_type)
|
||||
|
||||
def assign(self, other, base=0):
|
||||
try:
|
||||
other = other.get_vector()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
other.store_in_mem(self.get_address(base))
|
||||
assert len(self) >= other.size + base
|
||||
except AttributeError:
|
||||
for i,j in enumerate(other):
|
||||
self[i] = j
|
||||
return self
|
||||
@@ -3169,22 +3246,42 @@ class Array(object):
|
||||
self[i] = mem_value
|
||||
return self
|
||||
|
||||
def get_vector(self):
|
||||
return self.value_type.load_mem(self.address, size=self.length)
|
||||
def get_vector(self, base=0, size=None):
|
||||
size = size or self.length
|
||||
return self.value_type.load_mem(self.get_address(base), size=size)
|
||||
|
||||
def get_mem_value(self, index):
|
||||
return MemValue(self[index], self.get_address(index))
|
||||
|
||||
def input_from(self, player, budget=None):
|
||||
self.assign(self.value_type.get_input_from(player, size=len(self)))
|
||||
|
||||
def __add__(self, other):
|
||||
if other is 0:
|
||||
return self
|
||||
assert len(self) == len(other)
|
||||
return Array.create_from(x + y for x, y in zip(self, other))
|
||||
return self.get_vector() + other
|
||||
|
||||
def __sub__(self, other):
|
||||
assert len(self) == len(other)
|
||||
return Array.create_from(x - y for x, y in zip(self, other))
|
||||
return self.get_vector() - other
|
||||
|
||||
def __mul__(self, value):
|
||||
return Array.create_from(x * value for x in self)
|
||||
return self.get_vector() * value
|
||||
|
||||
def __pow__(self, value):
|
||||
return self.get_vector() ** value
|
||||
|
||||
__radd__ = __add__
|
||||
__rmul__ = __mul__
|
||||
|
||||
def shuffle(self):
|
||||
@library.for_range(len(self))
|
||||
def _(i):
|
||||
j = regint.get_random(64) % (len(self) - i)
|
||||
tmp = self[i]
|
||||
self[i] = self[i + j]
|
||||
self[i + j] = tmp
|
||||
|
||||
def reveal(self):
|
||||
return Array.create_from(x.reveal() for x in self)
|
||||
@@ -3223,6 +3320,9 @@ class SubMultiArray(object):
|
||||
self.address, index, debug=self.debug)
|
||||
return self.sub_cache[key]
|
||||
|
||||
def __setitem__(self, index, other):
|
||||
self[index].assign(other)
|
||||
|
||||
def __len__(self):
|
||||
return self.sizes[0]
|
||||
|
||||
@@ -3235,35 +3335,60 @@ class SubMultiArray(object):
|
||||
def total_size(self):
|
||||
return reduce(operator.mul, self.sizes) * self.value_type.n_elements()
|
||||
|
||||
def get_vector(self):
|
||||
return self.value_type.load_mem(self.address, size=self.total_size())
|
||||
def get_vector(self, base=0, size=None):
|
||||
assert self.value_type.n_elements() == 1
|
||||
size = size or self.total_size()
|
||||
return self.value_type.load_mem(self.address + base, size=size)
|
||||
|
||||
def assign_vector(self, vector):
|
||||
assert vector.size == self.total_size()
|
||||
vector.store_in_mem(self.address)
|
||||
def assign_vector(self, vector, base=0):
|
||||
assert self.value_type.n_elements() == 1
|
||||
assert vector.size <= self.total_size()
|
||||
vector.store_in_mem(self.address + base)
|
||||
|
||||
class MultiArray(SubMultiArray):
|
||||
def __init__(self, sizes, value_type, debug=None, address=None):
|
||||
self.array = Array(reduce(operator.mul, sizes), \
|
||||
value_type, address=address)
|
||||
SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \
|
||||
debug=debug)
|
||||
if len(sizes) < 2:
|
||||
raise CompilerError('Use Array')
|
||||
def assign(self, other):
|
||||
if self.value_type.n_elements() > 1:
|
||||
assert self.sizes == other.sizes
|
||||
self.assign_vector(other.get_vector())
|
||||
|
||||
class Matrix(MultiArray):
|
||||
def __init__(self, rows, columns, value_type, debug=None, address=None):
|
||||
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
|
||||
address=address)
|
||||
def same_shape(self):
|
||||
return MultiArray(self.sizes, self.value_type)
|
||||
|
||||
def __setitem__(self, index, other):
|
||||
assert other.size == self.sizes[1]
|
||||
other.store_in_mem(self[index].address)
|
||||
def input_from(self, player, budget=None):
|
||||
@library.for_range_opt(self.sizes[0], budget=budget)
|
||||
def _(i):
|
||||
self[i].input_from(player, budget=budget)
|
||||
|
||||
def schur(self, other):
|
||||
assert self.sizes == other.sizes
|
||||
if len(self.sizes) == 2:
|
||||
res = Matrix(self.sizes[0], self.sizes[1], self.value_type)
|
||||
else:
|
||||
res = MultiArray(self.sizes, self.value_type)
|
||||
res.assign_vector(self.get_vector() * other.get_vector())
|
||||
return res
|
||||
|
||||
def __add__(self, other):
|
||||
if other is 0:
|
||||
return self
|
||||
assert self.sizes == other.sizes
|
||||
if len(self.sizes) == 2:
|
||||
res = Matrix(self.sizes[0], self.sizes[1], self.value_type)
|
||||
else:
|
||||
res = MultiArray(self.sizes, self.value_type)
|
||||
res.assign_vector(self.get_vector() + other.get_vector())
|
||||
return res
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def iadd(self, other):
|
||||
assert self.sizes == other.sizes
|
||||
self.assign_vector(self.get_vector() + other.get_vector())
|
||||
|
||||
def __mul__(self, other):
|
||||
return self.mul(other)
|
||||
|
||||
def mul(self, other, res_params=None):
|
||||
assert len(self.sizes) == 2
|
||||
if isinstance(other, Array):
|
||||
assert len(other) == self.sizes[1]
|
||||
if self.value_type.n_elements() == 1:
|
||||
@@ -3277,7 +3402,8 @@ class Matrix(MultiArray):
|
||||
matrix[i][0] = x
|
||||
res = self * matrix
|
||||
return Array.create_from(x[0] for x in res)
|
||||
elif isinstance(other, Matrix):
|
||||
elif isinstance(other, SubMultiArray):
|
||||
assert len(other.sizes) == 2
|
||||
assert other.sizes[0] == self.sizes[1]
|
||||
if res_params is not None:
|
||||
class t(self.value_type):
|
||||
@@ -3287,14 +3413,16 @@ class Matrix(MultiArray):
|
||||
t = self.value_type
|
||||
res_matrix = Matrix(self.sizes[0], other.sizes[1], t)
|
||||
try:
|
||||
if max(res_matrix.sizes) > 1000:
|
||||
raise AttributeError()
|
||||
A = self.get_vector()
|
||||
B = other.get_vector()
|
||||
res_matrix.assign_vector(
|
||||
self.value_type.matrix_mul(A, B, self.sizes[1],
|
||||
res_params))
|
||||
except AttributeError:
|
||||
except (AttributeError, AssertionError):
|
||||
# fallback for sfloat etc.
|
||||
@library.for_range(self.sizes[0])
|
||||
@library.for_range_opt(self.sizes[0])
|
||||
def _(i):
|
||||
try:
|
||||
res_matrix[i] = self.value_type.row_matrix_mul(
|
||||
@@ -3311,6 +3439,78 @@ class Matrix(MultiArray):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True,
|
||||
res=None):
|
||||
assert len(self.sizes) == 2
|
||||
assert len(other.sizes) == 2
|
||||
if res is None:
|
||||
if reduce:
|
||||
res_matrix = Matrix(n_rows, n_columns, self.value_type)
|
||||
else:
|
||||
res_matrix = Matrix(n_rows, n_columns, \
|
||||
self.value_type.unreduced_type)
|
||||
else:
|
||||
res_matrix = res
|
||||
@library.for_range_opt(n_rows)
|
||||
def _(i):
|
||||
@library.for_range_opt(n_columns)
|
||||
def _(j):
|
||||
col = column(other, j)
|
||||
r = row(self, i)
|
||||
if reduce:
|
||||
res_matrix[i][j] = self.value_type.dot_product(r, col)
|
||||
else:
|
||||
entry = self.value_type.unreduced_dot_product(r, col)
|
||||
res_matrix[i][j] = entry
|
||||
return res_matrix
|
||||
|
||||
def plain_mul(self, other, res=None):
|
||||
assert other.sizes[0] == self.sizes[1]
|
||||
return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \
|
||||
other.sizes[1], \
|
||||
lambda x, j: [x[k][j] for k in range(len(x))],
|
||||
res=res)
|
||||
|
||||
def mul_trans(self, other):
|
||||
assert other.sizes[1] == self.sizes[1]
|
||||
return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \
|
||||
other.sizes[0], lambda x, j: x[j])
|
||||
|
||||
def trans_mul(self, other, reduce=True, res=None):
|
||||
assert other.sizes[0] == self.sizes[0]
|
||||
return self.budget_mul(other, self.sizes[1], \
|
||||
lambda x, j: [x[k][j] for k in range(len(x))], \
|
||||
other.sizes[1], \
|
||||
lambda x, j: [x[k][j] for k in range(len(x))],
|
||||
reduce=reduce, res=res)
|
||||
|
||||
def transpose(self):
|
||||
assert len(self.sizes) == 2
|
||||
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
|
||||
@library.for_range_opt(self.sizes[1])
|
||||
def _(i):
|
||||
@library.for_range_opt(self.sizes[0])
|
||||
def _(j):
|
||||
res[i][j] = self[j][i]
|
||||
return res
|
||||
|
||||
class MultiArray(SubMultiArray):
|
||||
def __init__(self, sizes, value_type, debug=None, address=None):
|
||||
if isinstance(address, Array):
|
||||
self.array = address
|
||||
else:
|
||||
self.array = Array(reduce(operator.mul, sizes), \
|
||||
value_type, address=address)
|
||||
SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \
|
||||
debug=debug)
|
||||
if len(sizes) < 2:
|
||||
raise CompilerError('Use Array')
|
||||
|
||||
class Matrix(MultiArray):
|
||||
def __init__(self, rows, columns, value_type, debug=None, address=None):
|
||||
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
|
||||
address=address)
|
||||
|
||||
class VectorArray(object):
|
||||
def __init__(self, length, value_type, vector_size, address=None):
|
||||
self.array = Array(length * vector_size, value_type, address)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
This directory contains the code used for the benchmarks by [Dalskov
|
||||
et al.](https://eprint.iacr.org/2019/889) `*-ecdsa-party.cpp`
|
||||
contains the high-level programs while the two phases are implemented
|
||||
`preprocessing.hpp` and `sign.hpp`, respectively.
|
||||
in `preprocessing.hpp` and `sign.hpp`, respectively.
|
||||
|
||||
#### Compilation
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class invalid_program: public exception
|
||||
class file_error: public exception
|
||||
{ string filename, ans;
|
||||
public:
|
||||
file_error(string m="") : filename(m)
|
||||
file_error(string m) : filename(m)
|
||||
{
|
||||
ans="File Error : ";
|
||||
ans+=filename;
|
||||
|
||||
@@ -36,7 +36,8 @@ bigint SemiHomomorphicNoiseBounds::min_p0()
|
||||
|
||||
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q)
|
||||
{
|
||||
return 33.1 * (log_q - log2(3.2));
|
||||
// the constant was updated using Martin Albrecht's LWE estimator in Sep 2019
|
||||
return 37.8 * (log_q - log2(3.2));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ inline FileSacriFactory<T>::FileSacriFactory(const char* type, const Player& P,
|
||||
if (output_thread)
|
||||
file1 << "-" << output_thread;
|
||||
this->inpf.open(file1.str().c_str(),ios::in | ios::binary);
|
||||
if (this->inpf.fail()) { throw file_error(); }
|
||||
if (this->inpf.fail()) { throw file_error(file1.str()); }
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -221,12 +221,12 @@ void Triple_Checking(const Player& P,MAC_Check<gf2n_short>& MC,int nm)
|
||||
/* Open file for reading in the initial triples */
|
||||
stringstream file1; file1 << PREP_DIR "Initial-Triples-" << file_completion(dummy) << "-P" << P.my_num();
|
||||
ifstream inpf(file1.str().c_str(),ios::in | ios::binary);
|
||||
if (inpf.fail()) { throw file_error(); }
|
||||
if (inpf.fail()) { throw file_error(file1.str()); }
|
||||
|
||||
/* Open file for writing out the final triples */
|
||||
stringstream file3; file3 << PREP_DIR "Triples-" << file_completion(dummy) << "-P" << P.my_num();
|
||||
ofstream outf(file3.str().c_str(),ios::out | ios::binary);
|
||||
if (outf.fail()) { throw file_error(); }
|
||||
if (outf.fail()) { throw file_error(file3.str()); }
|
||||
|
||||
gf2n_short te,t;
|
||||
Create_Random(t,P);
|
||||
@@ -444,12 +444,12 @@ void Square_Checking(const Player& P,MAC_Check<gf2n_short>& MC,int ns)
|
||||
/* Open files for reading in the initial data */
|
||||
stringstream file1; file1 << PREP_DIR "Initial-Squares-" << file_completion(dummy) << "-P" << P.my_num();
|
||||
ifstream inpf_s(file1.str().c_str(),ios::in | ios::binary);
|
||||
if (inpf_s.fail()) { throw file_error(); }
|
||||
if (inpf_s.fail()) { throw file_error(file1.str()); }
|
||||
|
||||
/* Open files for writing out the final data */
|
||||
stringstream file3; file3 << PREP_DIR "Squares-" << file_completion(dummy) << "-P" << P.my_num();
|
||||
ofstream outf_s(file3.str().c_str(),ios::out | ios::binary);
|
||||
if (outf_s.fail()) { throw file_error(); }
|
||||
if (outf_s.fail()) { throw file_error(file3.str()); }
|
||||
|
||||
gf2n_short te,t,t2;
|
||||
Create_Random(t,P);
|
||||
|
||||
@@ -153,7 +153,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
|
||||
others_ciphertexts.resize(this->sec, pk.get_params());
|
||||
for (int i = 1; i < P.num_players(); i++)
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
#ifdef VERBOSE_HE
|
||||
cerr << "Sending proof with " << 1e-9 * ciphertexts.get_length() << "+"
|
||||
<< 1e-9 * cleartexts.get_length() << " GB" << endl;
|
||||
#endif
|
||||
@@ -164,7 +164,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
|
||||
#ifndef LESS_ALLOC_MORE_MEM
|
||||
Verifier<FD,S> verifier(proof);
|
||||
#endif
|
||||
#ifdef VERBOSE
|
||||
#ifdef VERBOSE_HE
|
||||
cerr << "Checking proof of player " << i << endl;
|
||||
#endif
|
||||
timers["Verifying"].start();
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "GC/Processor.h"
|
||||
#include "GC/square64.h"
|
||||
|
||||
#include "GC/Processor.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
@@ -14,7 +16,7 @@ int FakeSecret::default_length = 128;
|
||||
|
||||
ostream& FakeSecret::out = cout;
|
||||
|
||||
void FakeSecret::load(int n, const Integer& x)
|
||||
void FakeSecret::load_clear(int n, const Integer& x)
|
||||
{
|
||||
if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n))
|
||||
throw out_of_range("public value too long");
|
||||
|
||||
@@ -79,7 +79,7 @@ public:
|
||||
|
||||
__uint128_t operator^=(const FakeSecret& other) { return a ^= other.a; }
|
||||
|
||||
void load(int n, const Integer& x);
|
||||
void load_clear(int n, const Integer& x);
|
||||
template <class T>
|
||||
void load(int n, const Memory<T>& mem, size_t address) { load(n, mem[address]); }
|
||||
template <class T>
|
||||
|
||||
@@ -93,8 +93,8 @@ unsigned GC::Instruction<T>::get_max_reg(int reg_type) const
|
||||
offset = 1;
|
||||
break;
|
||||
case INPUTB:
|
||||
skip = 3;
|
||||
offset = 2;
|
||||
skip = 4;
|
||||
offset = 3;
|
||||
break;
|
||||
case CONVCBIT:
|
||||
return BaseInstruction::get_max_reg(INT);
|
||||
|
||||
@@ -23,9 +23,6 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
extern template class ReplicatedSecret<SemiHonestRepSecret>;
|
||||
extern template class ReplicatedSecret<MaliciousRepSecret>;
|
||||
|
||||
#define GC_MACHINE(T) \
|
||||
template class Instruction<T>; \
|
||||
template class Machine<T>; \
|
||||
@@ -34,8 +31,4 @@ extern template class ReplicatedSecret<MaliciousRepSecret>;
|
||||
template class Thread<T>; \
|
||||
template class ThreadMaster<T>; \
|
||||
|
||||
GC_MACHINE(FakeSecret);
|
||||
GC_MACHINE(SemiHonestRepSecret);
|
||||
GC_MACHINE(MaliciousRepSecret)
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
/*
|
||||
* MaliciousRepPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_MALICIOUSREPPREP_H_
|
||||
#define GC_MALICIOUSREPPREP_H_
|
||||
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class MaliciousRepPrep : public BufferPrep<MaliciousRepSecret>
|
||||
{
|
||||
ReplicatedBase* protocol;
|
||||
|
||||
public:
|
||||
MaliciousRepPrep(DataPositions& usage);
|
||||
~MaliciousRepPrep();
|
||||
|
||||
void set_protocol(MaliciousRepSecret::Protocol& protocol);
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_bits();
|
||||
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_MALICIOUSREPPREP_H_ */
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef GC_MALICIOUSREPSECRET_H_
|
||||
#define GC_MALICIOUSREPSECRET_H_
|
||||
|
||||
#include "ReplicatedSecret.h"
|
||||
#include "ShareSecret.h"
|
||||
#include "Machine.h"
|
||||
#include "Protocols/Beaver.h"
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
@@ -17,7 +17,8 @@ template<class T> class MaliciousRepMC;
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class MaliciousRepThread;
|
||||
template<class T> class ShareThread;
|
||||
template<class T> class RepPrep;
|
||||
|
||||
class MaliciousRepSecret : public ReplicatedSecret<MaliciousRepSecret>
|
||||
{
|
||||
@@ -30,7 +31,8 @@ public:
|
||||
typedef MC MAC_Check;
|
||||
|
||||
typedef Beaver<MaliciousRepSecret> Protocol;
|
||||
typedef NotImplementedInput Input;
|
||||
typedef ReplicatedInput<MaliciousRepSecret> Input;
|
||||
typedef RepPrep<MaliciousRepSecret> LivePrep;
|
||||
|
||||
static MC* new_mc(Machine<MaliciousRepSecret>& machine)
|
||||
{
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
/*
|
||||
* MalicousRepParty.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
#include "MaliciousRepThread.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
thread_local MaliciousRepThread* MaliciousRepThread::singleton = 0;
|
||||
|
||||
MaliciousRepThread::MaliciousRepThread(int i,
|
||||
ThreadMaster<MaliciousRepSecret>& master) :
|
||||
Thread<MaliciousRepSecret>(i, master), DataF(usage)
|
||||
{
|
||||
}
|
||||
|
||||
void MaliciousRepThread::pre_run()
|
||||
{
|
||||
if (singleton)
|
||||
throw runtime_error("there can only be one");
|
||||
singleton = this;
|
||||
DataF.set_protocol(*protocol);
|
||||
}
|
||||
|
||||
void MaliciousRepThread::post_run()
|
||||
{
|
||||
#ifndef INSECURE
|
||||
cerr << "Removing used pre-processed data" << endl;
|
||||
DataF.prune();
|
||||
#endif
|
||||
}
|
||||
|
||||
void MaliciousRepThread::and_(Processor<MaliciousRepSecret>& processor,
|
||||
const vector<int>& args, bool repeat)
|
||||
{
|
||||
assert(P->num_players() == 3);
|
||||
processor.check_args(args, 4);
|
||||
protocol->init_mul(DataF, *MC);
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
int left = args[i + 2];
|
||||
int right = args[i + 3];
|
||||
MaliciousRepSecret y_ext;
|
||||
if (repeat)
|
||||
y_ext = processor.S[right].extend_bit();
|
||||
else
|
||||
y_ext = processor.S[right];
|
||||
protocol->prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits));
|
||||
}
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
int out = args[i + 1];
|
||||
processor.S[out] = protocol->finalize_mul().mask(n_bits);
|
||||
}
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
@@ -1,48 +0,0 @@
|
||||
/*
|
||||
* MalicousRepParty.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_MALICIOUSREPTHREAD_H_
|
||||
#define GC_MALICIOUSREPTHREAD_H_
|
||||
|
||||
#include "Thread.h"
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "MaliciousRepPrep.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class MaliciousRepThread : public Thread<MaliciousRepSecret>
|
||||
{
|
||||
static thread_local MaliciousRepThread* singleton;
|
||||
|
||||
public:
|
||||
static MaliciousRepThread& s();
|
||||
|
||||
DataPositions usage;
|
||||
MaliciousRepPrep DataF;
|
||||
|
||||
MaliciousRepThread(int i, ThreadMaster<MaliciousRepSecret>& master);
|
||||
virtual ~MaliciousRepThread() {}
|
||||
|
||||
void pre_run();
|
||||
void post_run();
|
||||
|
||||
void and_(Processor<MaliciousRepSecret>& processor, const vector<int>& args, bool repeat);
|
||||
};
|
||||
|
||||
inline MaliciousRepThread& MaliciousRepThread::s()
|
||||
{
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_MALICIOUSREPTHREAD_H_ */
|
||||
46
GC/RepPrep.h
Normal file
46
GC/RepPrep.h
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* MaliciousRepPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_REPPREP_H_
|
||||
#define GC_REPPREP_H_
|
||||
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "ShiftableTripleBuffer.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class RepPrep : public BufferPrep<T>, ShiftableTripleBuffer<T>
|
||||
{
|
||||
ReplicatedBase* protocol;
|
||||
|
||||
public:
|
||||
RepPrep(DataPositions& usage, Thread<T>& thread);
|
||||
~RepPrep();
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol);
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_bits();
|
||||
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
|
||||
void get(Dtype type, T* data)
|
||||
{
|
||||
BufferPrep<T>::get(type, data);
|
||||
}
|
||||
|
||||
array<T, 3> get_triple(int n_bits)
|
||||
{
|
||||
return ShiftableTripleBuffer<T>::get_triple(n_bits);
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_REPPREP_H_ */
|
||||
@@ -3,8 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "MaliciousRepPrep.h"
|
||||
#include "MaliciousRepThread.h"
|
||||
#include "RepPrep.h"
|
||||
#include "ShareThread.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
@@ -15,33 +15,40 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage) :
|
||||
BufferPrep<MaliciousRepSecret>(usage), protocol(0)
|
||||
template<class T>
|
||||
RepPrep<T>::RepPrep(DataPositions& usage, Thread<T>& thread) :
|
||||
BufferPrep<T>(usage), protocol(0)
|
||||
{
|
||||
(void) thread;
|
||||
}
|
||||
|
||||
MaliciousRepPrep::~MaliciousRepPrep()
|
||||
template<class T>
|
||||
RepPrep<T>::~RepPrep()
|
||||
{
|
||||
if (protocol)
|
||||
delete protocol;
|
||||
}
|
||||
|
||||
void MaliciousRepPrep::set_protocol(MaliciousRepSecret::Protocol& protocol)
|
||||
template<class T>
|
||||
void RepPrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
this->protocol = new ReplicatedBase(protocol.P);
|
||||
}
|
||||
|
||||
void MaliciousRepPrep::buffer_triples()
|
||||
template<class T>
|
||||
void RepPrep<T>::buffer_triples()
|
||||
{
|
||||
assert(protocol != 0);
|
||||
auto MC = MaliciousRepThread::s().new_mc();
|
||||
shuffle_triple_generation(triples, protocol->P, *MC, 64);
|
||||
auto MC = ShareThread<T>::s().new_mc();
|
||||
shuffle_triple_generation(this->triples, protocol->P, *MC, 64);
|
||||
delete MC;
|
||||
}
|
||||
|
||||
void MaliciousRepPrep::buffer_bits()
|
||||
template<class T>
|
||||
void RepPrep<T>::buffer_bits()
|
||||
{
|
||||
assert(this->protocol != 0);
|
||||
assert(this->protocol->P.num_players() == 3);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
{
|
||||
this->bits.push_back({});
|
||||
@@ -1,113 +0,0 @@
|
||||
/*
|
||||
* ReplicatedParty.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "ReplicatedParty.h"
|
||||
#include "Thread.h"
|
||||
#include "MaliciousRepThread.h"
|
||||
#include "Networking/Server.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Tools/benchmarking.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
ReplicatedParty<T>::ReplicatedParty(int argc, const char** argv) :
|
||||
ThreadMaster<T>(online_opts), online_opts(opt, argc, argv)
|
||||
{
|
||||
opt.add(
|
||||
"localhost", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Host where party 0 is running (default: localhost)", // Help description.
|
||||
"-h", // Flag token.
|
||||
"--hostname" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"5000", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Base port number (default: 5000).", // Help description.
|
||||
"-pn", // Flag token.
|
||||
"--portnum" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Unencrypted communication.", // Help description.
|
||||
"-u", // Flag token.
|
||||
"--unencrypted" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Check opening by communication instead of hashing.", // Help description.
|
||||
"-c", // Flag token.
|
||||
"--communication" // Flag token.
|
||||
);
|
||||
online_opts.finalize(opt, argc, argv);
|
||||
this->progname = online_opts.progname;
|
||||
int my_num = online_opts.playerno;
|
||||
int pnb;
|
||||
string hostname;
|
||||
opt.get("-pn")->getInt(pnb);
|
||||
opt.get("-h")->getString(hostname);
|
||||
this->machine.use_encryption = not opt.get("-u")->isSet;
|
||||
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
|
||||
|
||||
T::out.activate(my_num == 0 or online_opts.interactive);
|
||||
|
||||
if (not this->machine.use_encryption)
|
||||
insecure("unencrypted communication");
|
||||
|
||||
Server* server = Server::start_networking(this->N, my_num, 3, hostname, pnb);
|
||||
|
||||
this->run();
|
||||
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
}
|
||||
|
||||
template<>
|
||||
Thread<SemiHonestRepSecret>* ReplicatedParty<SemiHonestRepSecret>::new_thread(int i)
|
||||
{
|
||||
return ThreadMaster<SemiHonestRepSecret>::new_thread(i);
|
||||
}
|
||||
|
||||
template<>
|
||||
Thread<MaliciousRepSecret>* ReplicatedParty<MaliciousRepSecret>::new_thread(int i)
|
||||
{
|
||||
return new MaliciousRepThread(i, *this);
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedParty<SemiHonestRepSecret>::post_run()
|
||||
{
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedParty<MaliciousRepSecret>::post_run()
|
||||
{
|
||||
DataPositions usage;
|
||||
for (auto thread : threads)
|
||||
usage.increase(((MaliciousRepThread*)thread)->usage);
|
||||
usage.print_cost();
|
||||
}
|
||||
|
||||
extern template class ReplicatedSecret<SemiHonestRepSecret>;
|
||||
extern template class ReplicatedSecret<MaliciousRepSecret>;
|
||||
|
||||
template class ReplicatedParty<SemiHonestRepSecret>;
|
||||
template class ReplicatedParty<MaliciousRepSecret>;
|
||||
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
/*
|
||||
* ReplicatedParty.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_REPLICATEDPARTY_H_
|
||||
#define GC_REPLICATEDPARTY_H_
|
||||
|
||||
#include "Protocols/ReplicatedMC.h"
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
#include "ReplicatedSecret.h"
|
||||
#include "Processor.h"
|
||||
#include "Program.h"
|
||||
#include "Memory.h"
|
||||
#include "ThreadMaster.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class ReplicatedParty : public ThreadMaster<T>
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
OnlineOptions online_opts;
|
||||
|
||||
public:
|
||||
static Thread<T>& s();
|
||||
|
||||
ReplicatedParty(int argc, const char** argv);
|
||||
|
||||
Thread<T>* new_thread(int i);
|
||||
|
||||
void post_run();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
inline Thread<T>& ReplicatedParty<T>::s()
|
||||
{
|
||||
return Thread<T>::s();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_REPLICATEDPARTY_H_ */
|
||||
@@ -1,254 +0,0 @@
|
||||
/*
|
||||
* ReplicatedSecret.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "ReplicatedSecret.h"
|
||||
#include "ReplicatedParty.h"
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
#include "MaliciousRepThread.h"
|
||||
#include "Thread.h"
|
||||
#include "square64.h"
|
||||
|
||||
#include "Protocols/Share.h"
|
||||
|
||||
#include "Protocols/ReplicatedMC.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class U>
|
||||
int ReplicatedSecret<U>::default_length = 8 * sizeof(ReplicatedSecret<U>::value_type);
|
||||
|
||||
template<class U>
|
||||
SwitchableOutput ReplicatedSecret<U>::out;
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::load(int n, const Integer& x)
|
||||
{
|
||||
if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n))
|
||||
throw out_of_range("public value too long");
|
||||
*this = x;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
*this ^= (S[regs[i]] << i);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::load(vector<ReadAccess<U> >& accesses,
|
||||
const Memory<U>& mem)
|
||||
{
|
||||
for (auto access : accesses)
|
||||
access.dest = mem[access.address];
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::store(Memory<U>& mem,
|
||||
vector<WriteAccess<U> >& accesses)
|
||||
{
|
||||
for (auto access : accesses)
|
||||
mem[access.address] = access.source;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
|
||||
const vector<ClearWriteAccess>& accesses)
|
||||
{
|
||||
for (auto access : accesses)
|
||||
mem[access.address] = access.value;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::inputb(Processor<U>& processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
auto& party = ReplicatedParty<U>::s();
|
||||
party.os.resize(2);
|
||||
for (auto& o : party.os)
|
||||
o.reset_write_head();
|
||||
|
||||
InputArgList a(args);
|
||||
bool interactive = party.n_interactive_inputs_from_me(a) > 0;
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
if (x.from == party.P->my_num())
|
||||
{
|
||||
auto& res = processor.S[x.dest];
|
||||
res.prepare_input(party.os, processor.get_input(x.params, interactive), x.n_bits, party.secure_prng);
|
||||
}
|
||||
}
|
||||
|
||||
if (interactive)
|
||||
cout << "Thank you" << endl;
|
||||
|
||||
for (int i = 0; i < 2; i++)
|
||||
party.P->pass_around(party.os[i], i + 1);
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
int from = x.from;
|
||||
int n_bits = x.n_bits;
|
||||
if (from != party.P->my_num())
|
||||
{
|
||||
auto& res = processor.S[x.dest];
|
||||
res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class U>
|
||||
U ReplicatedSecret<U>::input(Processor<U>& processor, const InputArgs& args)
|
||||
{
|
||||
int from = args.from;
|
||||
int n_bits = args.n_bits;
|
||||
auto& party = ReplicatedParty<U>::s();
|
||||
U res;
|
||||
party.os.resize(2);
|
||||
for (auto& o : party.os)
|
||||
o.reset_write_head();
|
||||
if (from == party.P->my_num())
|
||||
{
|
||||
res.prepare_input(party.os, processor.get_input(args.params), n_bits, party.secure_prng);
|
||||
party.P->send_relative(party.os);
|
||||
}
|
||||
else
|
||||
{
|
||||
party.P->receive_player(from, party.os[0], true);
|
||||
res.finalize_input(party, party.os[0], from, n_bits);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng)
|
||||
{
|
||||
randomize_to_sum(input, secure_prng);
|
||||
*this &= get_mask(n_bits);
|
||||
for (int i = 0; i < 2; i++)
|
||||
BitVec(get_mask(n_bits) & (*this)[i]).pack(os[i], n_bits);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits)
|
||||
{
|
||||
int j = party.P->get_offset(from) == 2;
|
||||
(*this)[j] = BitVec::unpack_new(o, n_bits);
|
||||
(*this)[1 - j] = 0;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
BitVec ReplicatedSecret<U>::local_mul(const ReplicatedSecret& other) const
|
||||
{
|
||||
return (*this)[0] * other.sum() + (*this)[1] * other[0];
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::and_(int n,
|
||||
const ReplicatedSecret<U>& x,
|
||||
const ReplicatedSecret<U>& y, bool repeat)
|
||||
{
|
||||
(void)n, (void)x, (void)y, (void)repeat;
|
||||
throw runtime_error("use static method");
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedSecret<SemiHonestRepSecret>::and_(Processor<SemiHonestRepSecret>& processor,
|
||||
const vector<int>& args, bool repeat)
|
||||
{
|
||||
auto& party = Thread<SemiHonestRepSecret>::s();
|
||||
assert(party.P->num_players() == 3);
|
||||
processor.check_args(args, 4);
|
||||
assert(party.protocol != 0);
|
||||
auto& protocol = *party.protocol;
|
||||
protocol.init_mul();
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
int left = args[i + 2];
|
||||
int right = args[i + 3];
|
||||
MaliciousRepSecret y_ext;
|
||||
if (repeat)
|
||||
y_ext = processor.S[right].extend_bit();
|
||||
else
|
||||
y_ext = processor.S[right];
|
||||
protocol.prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits), n_bits);
|
||||
}
|
||||
protocol.exchange();
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
processor.S[args[i + 1]] = protocol.finalize_mul(args[i]);
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedSecret<MaliciousRepSecret>::and_(
|
||||
Processor<MaliciousRepSecret>& processor, const vector<int>& args,
|
||||
bool repeat)
|
||||
{
|
||||
MaliciousRepThread::s().and_(processor, args, repeat);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::trans(Processor<U>& processor,
|
||||
int n_outputs, const vector<int>& args)
|
||||
{
|
||||
assert(length == 2);
|
||||
for (int k = 0; k < 2; k++)
|
||||
{
|
||||
square64 square;
|
||||
for (size_t i = n_outputs; i < args.size(); i++)
|
||||
square.rows[i - n_outputs] = processor.S[args[i]][k].get();
|
||||
square.transpose(args.size() - n_outputs, n_outputs);
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
processor.S[args[i]][k] = square.rows[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::reveal(size_t n_bits, Clear& x)
|
||||
{
|
||||
(void) n_bits;
|
||||
ReplicatedSecret share = *this;
|
||||
vector<BitVec> opened;
|
||||
auto& party = ReplicatedParty<U>::s();
|
||||
party.MC->POpen_Begin(opened, {share}, *party.P);
|
||||
party.MC->POpen_End(opened, {share}, *party.P);
|
||||
x = IntBase(opened[0]);
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedSecret<SemiHonestRepSecret>::random_bit()
|
||||
{
|
||||
auto& party = ReplicatedParty<SemiHonestRepSecret>::s();
|
||||
*this = party.secure_prng.get_bit();
|
||||
octetStream o;
|
||||
(*this)[0].pack(o, 1);
|
||||
party.P->pass_around(o, 1);
|
||||
(*this)[1].unpack(o, 1);
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedSecret<MaliciousRepSecret>::random_bit()
|
||||
{
|
||||
MaliciousRepSecret res;
|
||||
MaliciousRepThread::s().DataF.get_one(DATA_BIT, res);
|
||||
*this = res;
|
||||
}
|
||||
|
||||
template class ReplicatedSecret<SemiHonestRepSecret>;
|
||||
template class ReplicatedSecret<MaliciousRepSecret>;
|
||||
|
||||
}
|
||||
32
GC/Secret.h
32
GC/Secret.h
@@ -83,6 +83,8 @@ public:
|
||||
|
||||
static typename T::out_type out;
|
||||
|
||||
static const bool needs_ot = false;
|
||||
|
||||
static T& cast(T& reg) { return *reinterpret_cast<T*>(®); }
|
||||
static const T& cast(const T& reg) { return *reinterpret_cast<const T*>(®); }
|
||||
|
||||
@@ -98,34 +100,40 @@ public:
|
||||
static Secret<T> carryless_mult(const Secret<T>& x, const Secret<T>& y);
|
||||
static void output(T& reg);
|
||||
|
||||
template<class U>
|
||||
static void load(vector< ReadAccess< Secret<T> > >& accesses, const U& mem);
|
||||
template<class U>
|
||||
static void store(U& mem, vector< WriteAccess< Secret<T> > >& accesses);
|
||||
template<class U, class V>
|
||||
static void load(vector< ReadAccess<V> >& accesses, const U& mem);
|
||||
template<class U, class V>
|
||||
static void store(U& mem, vector< WriteAccess<V> >& accesses);
|
||||
|
||||
static void andrs(Processor< Secret<T> >& processor, const vector<int>& args)
|
||||
template<class U>
|
||||
static void andrs(Processor<U>& processor, const vector<int>& args)
|
||||
{ T::andrs(processor, args); }
|
||||
static void ands(Processor< Secret<T> >& processor, const vector<int>& args)
|
||||
template<class U>
|
||||
static void ands(Processor<U>& processor, const vector<int>& args)
|
||||
{ T::ands(processor, args); }
|
||||
static void inputb(Processor< Secret<T> >& processor, const vector<int>& args)
|
||||
template<class U>
|
||||
static void inputb(Processor<U>& processor, const vector<int>& args)
|
||||
{ T::inputb(processor, args); }
|
||||
|
||||
static void trans(Processor<Secret<T> >& processor, int n_inputs, const vector<int>& args);
|
||||
template<class U>
|
||||
static void trans(Processor<U>& processor, int n_inputs, const vector<int>& args);
|
||||
|
||||
static void convcbit(Integer& dest, const Clear& source) { T::convcbit(dest, source); }
|
||||
|
||||
Secret();
|
||||
Secret(const Integer& x) { *this = x; }
|
||||
|
||||
void load(int n, const Integer& x);
|
||||
void operator=(const Integer& x) { load(default_length, x); }
|
||||
void load_clear(int n, const Integer& x);
|
||||
void operator=(const Integer& x) { load_clear(default_length, x); }
|
||||
void load(int n, const Memory<AuthValue>& mem, size_t address);
|
||||
|
||||
Secret<T> operator<<(int i);
|
||||
Secret<T> operator>>(int i);
|
||||
|
||||
void bitcom(Memory< Secret<T> >& S, const vector<int>& regs);
|
||||
void bitdec(Memory< Secret<T> >& S, const vector<int>& regs) const;
|
||||
template<class U>
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
template<class U>
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
|
||||
Secret<T> operator+(const Secret<T> x) const;
|
||||
Secret<T>& operator+=(const Secret<T> x) { *this = *this + x; return *this; }
|
||||
|
||||
@@ -102,9 +102,9 @@ void Secret<T>::random_bit()
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
template <class U, class V>
|
||||
void Secret<T>::store(U& mem,
|
||||
vector<WriteAccess<Secret<T> > >& accesses)
|
||||
vector<WriteAccess<V> >& accesses)
|
||||
{
|
||||
T::store(mem, accesses);
|
||||
}
|
||||
@@ -194,7 +194,7 @@ T& GC::Secret<T>::get_new_reg()
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Secret<T>::load(int n, const Integer& x)
|
||||
void Secret<T>::load_clear(int n, const Integer& x)
|
||||
{
|
||||
if ((unsigned)n < 8 * sizeof(x) and abs(x.get()) > (1LL << n))
|
||||
throw out_of_range("public value too long");
|
||||
@@ -219,8 +219,8 @@ void Secret<T>::load(int n, const Integer& x)
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
void Secret<T>::load(vector<ReadAccess < Secret<T> > >& accesses, const U& mem)
|
||||
template <class U, class V>
|
||||
void Secret<T>::load(vector<ReadAccess <V> >& accesses, const U& mem)
|
||||
{
|
||||
for (auto&& access : accesses)
|
||||
{
|
||||
@@ -252,7 +252,8 @@ Secret<T> Secret<T>::operator>>(int i)
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Secret<T>::bitcom(Memory<Secret>& S, const vector<int>& regs)
|
||||
template <class U>
|
||||
void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
{
|
||||
registers.clear();
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -264,7 +265,8 @@ void Secret<T>::bitcom(Memory<Secret>& S, const vector<int>& regs)
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Secret<T>::bitdec(Memory<Secret>& S, const vector<int>& regs) const
|
||||
template <class U>
|
||||
void Secret<T>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
if (regs.size() > registers.size())
|
||||
throw out_of_range(
|
||||
@@ -280,7 +282,8 @@ void Secret<T>::bitdec(Memory<Secret>& S, const vector<int>& regs) const
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Secret<T>::trans(Processor<Secret<T> >& processor, int n_outputs,
|
||||
template<class U>
|
||||
void Secret<T>::trans(Processor<U>& processor, int n_outputs,
|
||||
const vector<int>& args)
|
||||
{
|
||||
int n_inputs = args.size() - n_outputs;
|
||||
|
||||
11
GC/SemiHonestRepPrep.cpp
Normal file
11
GC/SemiHonestRepPrep.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
/*
|
||||
* ReplicatedPrep.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <GC/SemiHonestRepPrep.h>
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
28
GC/SemiHonestRepPrep.h
Normal file
28
GC/SemiHonestRepPrep.h
Normal file
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* ReplicatedPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SEMIHONESTREPPREP_H_
|
||||
#define GC_SEMIHONESTREPPREP_H_
|
||||
|
||||
#include "RepPrep.h"
|
||||
#include "ShareSecret.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class SemiHonestRepPrep : public RepPrep<SemiHonestRepSecret>
|
||||
{
|
||||
public:
|
||||
SemiHonestRepPrep(DataPositions& usage, Thread<SemiHonestRepSecret>& thread) :
|
||||
RepPrep<SemiHonestRepSecret>(usage, thread)
|
||||
{
|
||||
}
|
||||
|
||||
void buffer_triples() { throw not_implemented(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_SEMIHONESTREPPREP_H_ */
|
||||
58
GC/SemiPrep.cpp
Normal file
58
GC/SemiPrep.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
/*
|
||||
* SemiPrep.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "SemiPrep.h"
|
||||
#include "ThreadMaster.h"
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/BitDiagonal.h"
|
||||
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "OT/NPartyTripleGenerator.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
SemiPrep::SemiPrep(DataPositions& usage, Thread<SemiSecret>& thread) :
|
||||
BufferPrep<SemiSecret>(usage), thread(thread), triple_generator(0)
|
||||
{
|
||||
}
|
||||
|
||||
void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
|
||||
{
|
||||
(void) protocol;
|
||||
params.set_passive();
|
||||
triple_generator = new SemiSecret::TripleGenerator(
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
|
||||
thread.master.N, thread.thread_num, thread.master.opts.batch_size,
|
||||
1, params, thread.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
}
|
||||
|
||||
void SemiPrep::buffer_triples()
|
||||
{
|
||||
assert(this->triple_generator);
|
||||
this->triple_generator->generatePlainTriples();
|
||||
for (auto& x : this->triple_generator->plainTriples)
|
||||
{
|
||||
this->triples.push_back({{x[0], x[1], x[2]}});
|
||||
}
|
||||
this->triple_generator->unlock();
|
||||
}
|
||||
|
||||
SemiPrep::~SemiPrep()
|
||||
{
|
||||
if (triple_generator)
|
||||
delete triple_generator;
|
||||
}
|
||||
|
||||
void SemiPrep::buffer_bits()
|
||||
{
|
||||
word r = thread.secure_prng.get_word();
|
||||
for (size_t i = 0; i < sizeof(word) * 8; i++)
|
||||
this->bits.push_back((r >> i) & 1);
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
51
GC/SemiPrep.h
Normal file
51
GC/SemiPrep.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* SemiPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SEMIPREP_H_
|
||||
#define GC_SEMIPREP_H_
|
||||
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "OT/TripleMachine.h"
|
||||
#include "SemiSecret.h"
|
||||
#include "ShiftableTripleBuffer.h"
|
||||
|
||||
template<class T> class Beaver;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret>
|
||||
{
|
||||
Thread<SemiSecret>& thread;
|
||||
|
||||
SemiSecret::TripleGenerator* triple_generator;
|
||||
MascotParams params;
|
||||
|
||||
public:
|
||||
SemiPrep(DataPositions& usage, Thread<SemiSecret>& thread);
|
||||
~SemiPrep();
|
||||
|
||||
void set_protocol(Beaver<SemiSecret>& protocol);
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_bits();
|
||||
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
|
||||
void get(Dtype type, SemiSecret* data)
|
||||
{
|
||||
BufferPrep<SemiSecret>::get(type, data);
|
||||
}
|
||||
|
||||
array<SemiSecret, 3> get_triple(int n_bits)
|
||||
{
|
||||
return ShiftableTripleBuffer<SemiSecret>::get_triple(n_bits);
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_SEMIPREP_H_ */
|
||||
52
GC/SemiSecret.cpp
Normal file
52
GC/SemiSecret.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* SemiSecret.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GC/ShareParty.h"
|
||||
#include "SemiSecret.h"
|
||||
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
void SemiSecret::trans(Processor<SemiSecret>& processor, int n_outputs,
|
||||
const vector<int>& args)
|
||||
{
|
||||
square64 square;
|
||||
for (size_t i = n_outputs; i < args.size(); i++)
|
||||
square.rows[i - n_outputs] = processor.S[args[i]].get();
|
||||
square.transpose(args.size() - n_outputs, n_outputs);
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
processor.S[args[i]] = square.rows[i];
|
||||
}
|
||||
|
||||
void SemiSecret::load_clear(int n, const Integer& x)
|
||||
{
|
||||
check_length(n, x);
|
||||
*this = constant(x, Thread<SemiSecret>::s().P->my_num());
|
||||
}
|
||||
|
||||
void SemiSecret::bitcom(Memory<SemiSecret>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
*this ^= (S[regs[i]] << i);
|
||||
}
|
||||
|
||||
void SemiSecret::bitdec(Memory<SemiSecret>& S,
|
||||
const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
}
|
||||
|
||||
void SemiSecret::reveal(size_t n_bits, Clear& x)
|
||||
{
|
||||
auto& thread = Thread<SemiSecret>::s();
|
||||
x = thread.MC->POpen(*this, *thread.P).mask(n_bits);
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
67
GC/SemiSecret.h
Normal file
67
GC/SemiSecret.h
Normal file
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* SemiSecret.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SEMISECRET_H_
|
||||
#define GC_SEMISECRET_H_
|
||||
|
||||
#include "Protocols/SemiMC.h"
|
||||
#include "Protocols/SemiShare.h"
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "ShareSecret.h"
|
||||
|
||||
template<class T> class Beaver;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class SemiPrep;
|
||||
|
||||
class SemiSecret : public SemiShare<BitVec>, public ShareSecret<SemiSecret>
|
||||
{
|
||||
public:
|
||||
typedef Memory<SemiSecret> DynamicMemory;
|
||||
|
||||
typedef SemiMC<SemiSecret> MC;
|
||||
typedef Beaver<SemiSecret> Protocol;
|
||||
typedef MC MAC_Check;
|
||||
typedef SemiPrep LivePrep;
|
||||
typedef SemiInput<SemiSecret> Input;
|
||||
|
||||
static const int default_length = sizeof(BitVec) * 8;
|
||||
|
||||
static string type_string() { return "binary secret"; }
|
||||
static string phase_name() { return "Binary computation"; }
|
||||
|
||||
static MC* new_mc(Machine<SemiSecret>& _) { (void) _; return new MC; }
|
||||
|
||||
static void trans(Processor<SemiSecret>& processor, int n_outputs,
|
||||
const vector<int>& args);
|
||||
|
||||
SemiSecret()
|
||||
{
|
||||
}
|
||||
SemiSecret(long other) :
|
||||
SemiShare<BitVec>(other)
|
||||
{
|
||||
}
|
||||
SemiSecret(const IntBase& other) :
|
||||
SemiShare<BitVec>(other)
|
||||
{
|
||||
}
|
||||
|
||||
void load_clear(int n, const Integer& x);
|
||||
|
||||
void bitcom(Memory<SemiSecret>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<SemiSecret>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const SemiSecret& x, const SemiSecret& y)
|
||||
{ *this = BitVec(x ^ y).mask(n); }
|
||||
|
||||
void reveal(size_t n_bits, Clear& x);
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_SEMISECRET_H_ */
|
||||
51
GC/ShareParty.h
Normal file
51
GC/ShareParty.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* ReplicatedParty.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SHAREPARTY_H_
|
||||
#define GC_SHAREPARTY_H_
|
||||
|
||||
#include "Protocols/ReplicatedMC.h"
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
#include "ShareSecret.h"
|
||||
#include "Processor.h"
|
||||
#include "Program.h"
|
||||
#include "Memory.h"
|
||||
#include "ThreadMaster.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class ShareParty : public ThreadMaster<T>
|
||||
{
|
||||
static ShareParty<T>* singleton;
|
||||
|
||||
ez::ezOptionParser opt;
|
||||
OnlineOptions online_opts;
|
||||
|
||||
public:
|
||||
static ShareParty& s();
|
||||
|
||||
typename T::mac_key_type mac_key;
|
||||
|
||||
ShareParty(int argc, const char** argv, int default_batch_size = 0);
|
||||
|
||||
Thread<T>* new_thread(int i);
|
||||
|
||||
void post_run();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
inline ShareParty<T>& ShareParty<T>::s()
|
||||
{
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_SHAREPARTY_H_ */
|
||||
137
GC/ShareParty.hpp
Normal file
137
GC/ShareParty.hpp
Normal file
@@ -0,0 +1,137 @@
|
||||
/*
|
||||
* ReplicatedParty.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "ShareParty.h"
|
||||
|
||||
#include "Thread.h"
|
||||
#include "ShareThread.h"
|
||||
#include "SemiPrep.h"
|
||||
#include "Networking/Server.h"
|
||||
#include "Networking/CryptoPlayer.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Tools/benchmarking.h"
|
||||
#include "Tools/NetworkOptions.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
|
||||
#include "ShareThread.hpp"
|
||||
#include "RepPrep.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
ShareParty<T>* ShareParty<T>::singleton = 0;
|
||||
|
||||
template<class T>
|
||||
ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
|
||||
ThreadMaster<T>(online_opts), online_opts(opt, argc, argv,
|
||||
default_batch_size)
|
||||
{
|
||||
if (singleton)
|
||||
throw runtime_error("there can only be one");
|
||||
singleton = this;
|
||||
|
||||
NetworkOptionsWithNumber network_opts(opt, argc, argv,
|
||||
T::dishonest_majority ? 2 : 3, T::dishonest_majority);
|
||||
if (T::dishonest_majority)
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Use encrypted channels.", // Help description.
|
||||
"-e", // Flag token.
|
||||
"--encrypted" // Flag token.
|
||||
);
|
||||
else
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Unencrypted communication.", // Help description.
|
||||
"-u", // Flag token.
|
||||
"--unencrypted" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Check opening by communication instead of hashing.", // Help description.
|
||||
"-c", // Flag token.
|
||||
"--communication" // Flag token.
|
||||
);
|
||||
online_opts.finalize(opt, argc, argv);
|
||||
this->progname = online_opts.progname;
|
||||
int my_num = online_opts.playerno;
|
||||
|
||||
if (T::dishonest_majority)
|
||||
this->machine.use_encryption = opt.get("-e")->isSet;
|
||||
else
|
||||
this->machine.use_encryption = not opt.get("-u")->isSet;
|
||||
|
||||
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
|
||||
|
||||
T::out.activate(my_num == 0 or online_opts.interactive);
|
||||
|
||||
if (not this->machine.use_encryption and not T::dishonest_majority)
|
||||
insecure("unencrypted communication");
|
||||
|
||||
Server* server = network_opts.start_networking(this->N, my_num);
|
||||
|
||||
if (online_opts.live_prep)
|
||||
if (T::needs_ot)
|
||||
{
|
||||
Player* P;
|
||||
if (this->machine.use_encryption)
|
||||
P = new CryptoPlayer(this->N, 0xFFFF);
|
||||
else
|
||||
P = new PlainPlayer(this->N, 0xFFFF);
|
||||
for (int i = 0; i < this->machine.nthreads; i++)
|
||||
this->machine.ot_setups.push_back({{{*P, true}}});
|
||||
delete P;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
gf2n _;
|
||||
read_mac_keys(get_prep_dir(network_opts.nplayers, 128, 128), this->N,
|
||||
this->mac_key, _);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
SeededPRNG G;
|
||||
this->mac_key.randomize(G);
|
||||
}
|
||||
|
||||
this->run();
|
||||
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Thread<T>* ShareParty<T>::new_thread(int i)
|
||||
{
|
||||
return new ShareThread<T>(i, *this);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ShareParty<T>::post_run()
|
||||
{
|
||||
DataPositions usage;
|
||||
for (auto thread : this->threads)
|
||||
usage.increase(dynamic_cast<ShareThread<T>*>(thread)->usage);
|
||||
usage.print_cost();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,8 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_REPLICATEDSECRET_H_
|
||||
#define GC_REPLICATEDSECRET_H_
|
||||
#ifndef GC_SHARESECRET_H_
|
||||
#define GC_SHARESECRET_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
@@ -18,6 +18,7 @@ using namespace std;
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
#include "Protocols/Replicated.h"
|
||||
#include "Protocols/ReplicatedMC.h"
|
||||
#include "Processor/DummyProtocol.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -32,22 +33,9 @@ template <class T>
|
||||
class Machine;
|
||||
|
||||
template<class U>
|
||||
class ReplicatedSecret : public FixedVec<BitVec, 2>
|
||||
class ShareSecret
|
||||
{
|
||||
typedef FixedVec<BitVec, 2> super;
|
||||
|
||||
public:
|
||||
typedef BitVec clear;
|
||||
typedef BitVec open_type;
|
||||
typedef BitVec mac_type;
|
||||
typedef BitVec mac_key_type;
|
||||
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
static string type_string() { return "replicated secret"; }
|
||||
static string phase_name() { return "Replicated computation"; }
|
||||
|
||||
static int default_length;
|
||||
static SwitchableOutput out;
|
||||
|
||||
static void store_clear_in_dynamic(Memory<U>& mem,
|
||||
@@ -63,38 +51,59 @@ public:
|
||||
static void and_(Processor<U>& processor, const vector<int>& args, bool repeat);
|
||||
static void inputb(Processor<U>& processor, const vector<int>& args);
|
||||
|
||||
static void trans(Processor<U>& processor, int n_outputs,
|
||||
const vector<int>& args);
|
||||
|
||||
static void convcbit(Integer& dest, const Clear& source) { dest = source; }
|
||||
|
||||
static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); }
|
||||
|
||||
static U input(Processor<U>& processor, const InputArgs& args);
|
||||
void prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng);
|
||||
void finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits);
|
||||
void check_length(int n, const Integer& x);
|
||||
|
||||
void random_bit();
|
||||
};
|
||||
|
||||
template<class U>
|
||||
class ReplicatedSecret : public FixedVec<BitVec, 2>, public ShareSecret<U>
|
||||
{
|
||||
typedef FixedVec<BitVec, 2> super;
|
||||
|
||||
public:
|
||||
typedef BitVec clear;
|
||||
typedef BitVec open_type;
|
||||
typedef BitVec mac_type;
|
||||
typedef BitVec mac_key_type;
|
||||
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
static const int N_BITS = clear::N_BITS;
|
||||
|
||||
static const bool dishonest_majority = false;
|
||||
static const bool needs_ot = false;
|
||||
|
||||
static string type_string() { return "replicated secret"; }
|
||||
static string phase_name() { return "Replicated computation"; }
|
||||
|
||||
static int default_length;
|
||||
|
||||
static void trans(Processor<U>& processor, int n_outputs,
|
||||
const vector<int>& args);
|
||||
|
||||
ReplicatedSecret() {}
|
||||
template <class T>
|
||||
ReplicatedSecret(const T& other) : super(other) {}
|
||||
|
||||
void load(int n, const Integer& x);
|
||||
void load_clear(int n, const Integer& x);
|
||||
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
|
||||
{ *this = x ^ y; (void)n; }
|
||||
void and_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y, bool repeat);
|
||||
void andrs(int n, const ReplicatedSecret& x, const ReplicatedSecret& y);
|
||||
|
||||
BitVec local_mul(const ReplicatedSecret& other) const;
|
||||
|
||||
void reveal(size_t n_bits, Clear& x);
|
||||
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
|
||||
{ *this = x ^ y; (void)n; }
|
||||
|
||||
void random_bit();
|
||||
void reveal(size_t n_bits, Clear& x);
|
||||
};
|
||||
|
||||
class SemiHonestRepPrep;
|
||||
|
||||
class SemiHonestRepSecret : public ReplicatedSecret<SemiHonestRepSecret>
|
||||
{
|
||||
@@ -106,6 +115,8 @@ public:
|
||||
typedef ReplicatedMC<SemiHonestRepSecret> MC;
|
||||
typedef Replicated<SemiHonestRepSecret> Protocol;
|
||||
typedef MC MAC_Check;
|
||||
typedef SemiHonestRepPrep LivePrep;
|
||||
typedef ReplicatedInput<SemiHonestRepSecret> Input;
|
||||
|
||||
static MC* new_mc(Machine<SemiHonestRepSecret>& _) { (void) _; return new MC; }
|
||||
|
||||
@@ -116,4 +127,4 @@ public:
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_REPLICATEDSECRET_H_ */
|
||||
#endif /* GC_SHARESECRET_H_ */
|
||||
168
GC/ShareSecret.hpp
Normal file
168
GC/ShareSecret.hpp
Normal file
@@ -0,0 +1,168 @@
|
||||
/*
|
||||
* ReplicatedSecret.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "ShareSecret.h"
|
||||
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
#include "ShareThread.h"
|
||||
#include "Thread.h"
|
||||
#include "square64.h"
|
||||
|
||||
#include "Protocols/Share.h"
|
||||
|
||||
#include "Protocols/ReplicatedMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "ShareParty.h"
|
||||
#include "ShareThread.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class U>
|
||||
int ReplicatedSecret<U>::default_length = 8 * sizeof(typename ReplicatedSecret<U>::value_type);
|
||||
|
||||
template<class U>
|
||||
SwitchableOutput ShareSecret<U>::out;
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::check_length(int n, const Integer& x)
|
||||
{
|
||||
if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n))
|
||||
throw out_of_range("public value too long");
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
|
||||
{
|
||||
this->check_length(n, x);
|
||||
*this = x;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
*this ^= (S[regs[i]] << i);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::load(vector<ReadAccess<U> >& accesses,
|
||||
const Memory<U>& mem)
|
||||
{
|
||||
for (auto access : accesses)
|
||||
access.dest = mem[access.address];
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::store(Memory<U>& mem,
|
||||
vector<WriteAccess<U> >& accesses)
|
||||
{
|
||||
for (auto access : accesses)
|
||||
mem[access.address] = access.source;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
|
||||
const vector<ClearWriteAccess>& accesses)
|
||||
{
|
||||
for (auto access : accesses)
|
||||
mem[access.address] = access.value;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
auto& party = ShareThread<U>::s();
|
||||
typename U::Input input(*party.MC, party.DataF, *party.P);
|
||||
input.reset_all(*party.P);
|
||||
|
||||
InputArgList a(args);
|
||||
bool interactive = party.n_interactive_inputs_from_me(a) > 0;
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
if (x.from == party.P->my_num())
|
||||
{
|
||||
input.add_mine(processor.get_input(x.params, interactive), x.n_bits);
|
||||
}
|
||||
else
|
||||
input.add_other(x.from);
|
||||
}
|
||||
|
||||
if (interactive)
|
||||
cout << "Thank you" << endl;
|
||||
|
||||
input.exchange();
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
int from = x.from;
|
||||
int n_bits = x.n_bits;
|
||||
auto& res = processor.S[x.dest];
|
||||
res = input.finalize(from, n_bits).mask(n_bits);
|
||||
}
|
||||
}
|
||||
|
||||
template<class U>
|
||||
BitVec ReplicatedSecret<U>::local_mul(const ReplicatedSecret& other) const
|
||||
{
|
||||
return (*this)[0] * other.sum() + (*this)[1] * other[0];
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::and_(
|
||||
Processor<U>& processor, const vector<int>& args,
|
||||
bool repeat)
|
||||
{
|
||||
ShareThread<U>::s().and_(processor, args, repeat);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::trans(Processor<U>& processor,
|
||||
int n_outputs, const vector<int>& args)
|
||||
{
|
||||
assert(length == 2);
|
||||
for (int k = 0; k < 2; k++)
|
||||
{
|
||||
square64 square;
|
||||
for (size_t i = n_outputs; i < args.size(); i++)
|
||||
square.rows[i - n_outputs] = processor.S[args[i]][k].get();
|
||||
square.transpose(args.size() - n_outputs, n_outputs);
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
processor.S[args[i]][k] = square.rows[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::reveal(size_t n_bits, Clear& x)
|
||||
{
|
||||
(void) n_bits;
|
||||
auto& share = *this;
|
||||
vector<BitVec> opened;
|
||||
auto& party = ShareThread<U>::s();
|
||||
party.MC->POpen_Begin(opened, {share}, *party.P);
|
||||
party.MC->POpen_End(opened, {share}, *party.P);
|
||||
x = IntBase(opened[0]);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::random_bit()
|
||||
{
|
||||
U res;
|
||||
ShareThread<U>::s().DataF.get_one(DATA_BIT, res);
|
||||
*this = res;
|
||||
}
|
||||
|
||||
}
|
||||
55
GC/ShareThread.h
Normal file
55
GC/ShareThread.h
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* MalicousRepParty.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SHARETHREAD_H_
|
||||
#define GC_SHARETHREAD_H_
|
||||
|
||||
#include "Thread.h"
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "RepPrep.h"
|
||||
#include "SemiHonestRepPrep.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Protocols/ReplicatedInput.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class ShareThread : public Thread<T>
|
||||
{
|
||||
static thread_local ShareThread<T>* singleton;
|
||||
|
||||
public:
|
||||
static ShareThread& s();
|
||||
|
||||
DataPositions usage;
|
||||
Preprocessing<T>& DataF;
|
||||
|
||||
ShareThread(int i, ThreadMaster<T>& master);
|
||||
virtual ~ShareThread();
|
||||
|
||||
void pre_run();
|
||||
void post_run();
|
||||
|
||||
void and_(Processor<T>& processor, const vector<int>& args, bool repeat);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
thread_local ShareThread<T>* ShareThread<T>::singleton = 0;
|
||||
|
||||
template<class T>
|
||||
inline ShareThread<T>& ShareThread<T>::s()
|
||||
{
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_SHARETHREAD_H_ */
|
||||
88
GC/ShareThread.hpp
Normal file
88
GC/ShareThread.hpp
Normal file
@@ -0,0 +1,88 @@
|
||||
/*
|
||||
* MalicousRepParty.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SHARETHREAD_HPP_
|
||||
#define GC_SHARETHREAD_HPP_
|
||||
|
||||
#include <GC/ShareThread.h>
|
||||
#include "Protocols/MaliciousRepMC.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
#include "Processor/Data_Files.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
ShareThread<T>::ShareThread(int i,
|
||||
ThreadMaster<T>& master) :
|
||||
Thread<T>(i, master), usage(master.N.num_players()), DataF(
|
||||
master.opts.live_prep ?
|
||||
*(Preprocessing<T>*) new typename T::LivePrep(usage,
|
||||
*this) :
|
||||
*(Preprocessing<T>*) new Sub_Data_Files<T>(master.N,
|
||||
get_prep_dir(master.N.num_players(), 128, 128),
|
||||
usage))
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ShareThread<T>::~ShareThread()
|
||||
{
|
||||
delete &DataF;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ShareThread<T>::pre_run()
|
||||
{
|
||||
if (singleton)
|
||||
throw runtime_error("there can only be one");
|
||||
singleton = this;
|
||||
assert(this->protocol != 0);
|
||||
DataF.set_protocol(*this->protocol);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ShareThread<T>::post_run()
|
||||
{
|
||||
#ifndef INSECURE
|
||||
cerr << "Removing used pre-processed data" << endl;
|
||||
DataF.prune();
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void ShareThread<T>::and_(Processor<T>& processor,
|
||||
const vector<int>& args, bool repeat)
|
||||
{
|
||||
auto& protocol = this->protocol;
|
||||
processor.check_args(args, 4);
|
||||
protocol->init_mul(DataF, *this->MC);
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
int left = args[i + 2];
|
||||
int right = args[i + 3];
|
||||
T y_ext;
|
||||
if (repeat)
|
||||
y_ext = processor.S[right].extend_bit();
|
||||
else
|
||||
y_ext = processor.S[right];
|
||||
protocol->prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits), n_bits);
|
||||
}
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
int out = args[i + 1];
|
||||
processor.S[out] = protocol->finalize_mul(n_bits);
|
||||
}
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
60
GC/ShiftableTripleBuffer.h
Normal file
60
GC/ShiftableTripleBuffer.h
Normal file
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
* ShiftableTripleBuffer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_SHIFTABLETRIPLEBUFFER_H_
|
||||
#define GC_SHIFTABLETRIPLEBUFFER_H_
|
||||
|
||||
#include "Math/FixedVec.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class ShiftableTripleBuffer
|
||||
{
|
||||
FixedVec<T, 3> triple_buffer;
|
||||
int buffer_left;
|
||||
|
||||
virtual void get(Dtype type, T* data) = 0;
|
||||
|
||||
public:
|
||||
ShiftableTripleBuffer() :
|
||||
buffer_left(0)
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~ShiftableTripleBuffer() {}
|
||||
|
||||
array<T, 3> get_triple(int n_bits)
|
||||
{
|
||||
int max_n_bits = T::N_BITS;
|
||||
assert(n_bits <= max_n_bits);
|
||||
assert(n_bits > 0);
|
||||
array<T, 3> res;
|
||||
|
||||
if (n_bits <= buffer_left)
|
||||
{
|
||||
res = triple_buffer.mask(n_bits).get();
|
||||
triple_buffer >>= n_bits;
|
||||
buffer_left -= n_bits;
|
||||
}
|
||||
else
|
||||
{
|
||||
get(DATA_TRIPLE, res.data());
|
||||
FixedVec<T, 3> tmp = res;
|
||||
res = tmp.mask(n_bits).get();
|
||||
triple_buffer += (tmp >> n_bits) << buffer_left;
|
||||
buffer_left += max_n_bits - n_bits;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_SHIFTABLETRIPLEBUFFER_H_ */
|
||||
@@ -39,7 +39,6 @@ public:
|
||||
Names& N;
|
||||
Player* P;
|
||||
PRNG secure_prng;
|
||||
vector<octetStream> os;
|
||||
|
||||
int thread_num;
|
||||
WaitQueue<ScheduleItem> tape_schedule;
|
||||
|
||||
@@ -56,6 +56,11 @@ void ThreadMaster<T>::run()
|
||||
P = new PlainPlayer(N, 0xff << 24);
|
||||
|
||||
machine.load_schedule(progname);
|
||||
|
||||
if (T::needs_ot)
|
||||
for (int i = 0; i < machine.nthreads; i++)
|
||||
machine.ot_setups.push_back({{*P, true}, {*P, true}});
|
||||
|
||||
for (int i = 0; i < machine.nthreads; i++)
|
||||
threads.push_back(new_thread(i));
|
||||
for (auto thread : threads)
|
||||
|
||||
11
GC/TinyMC.cpp
Normal file
11
GC/TinyMC.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
/*
|
||||
* TinyMC.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinyMC.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
67
GC/TinyMC.h
Normal file
67
GC/TinyMC.h
Normal file
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* TinyMC.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_TINYMC_H_
|
||||
#define GC_TINYMC_H_
|
||||
|
||||
#include "Protocols/MAC_Check_Base.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class TinyMC : public MAC_Check_Base<T>
|
||||
{
|
||||
typename T::part_type::MAC_Check part_MC;
|
||||
vector<typename T::part_type::open_type> part_values;
|
||||
vector<typename T::part_type::super> part_shares;
|
||||
|
||||
public:
|
||||
TinyMC(typename T::mac_key_type mac_key) :
|
||||
part_MC(mac_key)
|
||||
{
|
||||
this->alphai = mac_key;
|
||||
}
|
||||
|
||||
typename T::part_type::MAC_Check& get_part_MC()
|
||||
{
|
||||
return part_MC;
|
||||
}
|
||||
|
||||
void POpen_Begin(vector<typename T::open_type>& values, const vector<T>& S,
|
||||
const Player& P)
|
||||
{
|
||||
values.clear();
|
||||
part_shares.clear();
|
||||
for (auto& share : S)
|
||||
for (auto& part : share.get_regs())
|
||||
part_shares.push_back(part);
|
||||
part_MC.POpen_Begin(part_values, part_shares, P);
|
||||
}
|
||||
|
||||
void POpen_End(vector<typename T::open_type>& values, const vector<T>& S,
|
||||
const Player& P)
|
||||
{
|
||||
values.clear();
|
||||
part_MC.POpen_End(part_values, part_shares, P);
|
||||
int i = 0;
|
||||
for (auto& share : S)
|
||||
{
|
||||
typename T::open_type opened = 0;
|
||||
for (size_t j = 0; j < share.get_regs().size(); j++)
|
||||
opened += typename T::open_type(part_values[i++].get_bit(0)) << j;
|
||||
values.push_back(opened);
|
||||
}
|
||||
}
|
||||
|
||||
void Check(const Player& P)
|
||||
{
|
||||
part_MC.Check(P);
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_TINYMC_H_ */
|
||||
52
GC/TinyPrep.h
Normal file
52
GC/TinyPrep.h
Normal file
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* TinyPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_TINYPREP_H_
|
||||
#define GC_TINYPREP_H_
|
||||
|
||||
#include "Thread.h"
|
||||
#include "OT/TripleMachine.h"
|
||||
#include "Protocols/Beaver.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "Protocols/RandomPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class TinyPrep : public BufferPrep<T>, public RandomPrep<typename T::part_type::super>
|
||||
{
|
||||
typedef Share<Z2<1 + T::part_type::s>> res_type;
|
||||
|
||||
Thread<T>& thread;
|
||||
|
||||
typename T::TripleGenerator* triple_generator;
|
||||
typename T::part_type::TripleGenerator* input_generator;
|
||||
MascotParams params;
|
||||
|
||||
vector<array<typename T::part_type, 3>> triple_buffer;
|
||||
|
||||
public:
|
||||
TinyPrep(DataPositions& usage, Thread<T>& thread);
|
||||
~TinyPrep();
|
||||
|
||||
void set_protocol(Beaver<T>& protocol);
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_bits();
|
||||
|
||||
void buffer_inputs(int player);
|
||||
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
|
||||
typename T::part_type::super get_random();
|
||||
|
||||
array<T, 3> get_triple(int n_bits);
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_TINYPREP_H_ */
|
||||
174
GC/TinyPrep.hpp
Normal file
174
GC/TinyPrep.hpp
Normal file
@@ -0,0 +1,174 @@
|
||||
/*
|
||||
* TinyPrep.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinyPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
TinyPrep<T>::TinyPrep(DataPositions& usage, Thread<T>& thread) :
|
||||
BufferPrep<T>(usage), thread(thread), triple_generator(0),
|
||||
input_generator(0)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinyPrep<T>::~TinyPrep()
|
||||
{
|
||||
if (triple_generator)
|
||||
delete triple_generator;
|
||||
if (input_generator)
|
||||
delete input_generator;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
|
||||
{
|
||||
(void) protocol;
|
||||
params.generateMACs = true;
|
||||
params.amplify = false;
|
||||
params.check = false;
|
||||
params.set_mac_key(thread.MC->get_alphai());
|
||||
triple_generator = new typename T::TripleGenerator(
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
|
||||
thread.master.N, thread.thread_num,
|
||||
thread.master.opts.batch_size,
|
||||
1, params, thread.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
input_generator = new typename T::part_type::TripleGenerator(
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).at(1),
|
||||
thread.master.N, thread.thread_num,
|
||||
thread.master.opts.batch_size,
|
||||
1, params, thread.P);
|
||||
input_generator->multi_threaded = false;
|
||||
thread.MC->get_part_MC().set_prep(*this);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::buffer_triples()
|
||||
{
|
||||
auto& triple_generator = this->triple_generator;
|
||||
params.generateBits = false;
|
||||
vector<array<typename T::part_type::super, 3>> triples;
|
||||
ShuffleSacrifice<typename T::part_type::super> sacrifice;
|
||||
while (int(triples.size()) < sacrifice.minimum_n_inputs())
|
||||
{
|
||||
triple_generator->generatePlainTriples();
|
||||
triple_generator->unlock();
|
||||
assert(triple_generator->plainTriples.size() != 0);
|
||||
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
|
||||
triple_generator->valueBits[2].set_portion(i,
|
||||
triple_generator->plainTriples[i][2]);
|
||||
triple_generator->run_multipliers({});
|
||||
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
|
||||
{
|
||||
for (int j = 0; j < T::default_length; j++)
|
||||
{
|
||||
triples.push_back({});
|
||||
for (int k = 0; k < 3; k++)
|
||||
{
|
||||
auto& share = triples.back()[k];
|
||||
share.set_share(
|
||||
triple_generator->plainTriples.at(i).at(k).get_bit(
|
||||
j));
|
||||
typename T::part_type::mac_type mac;
|
||||
mac = thread.MC->get_alphai() * share.get_share();
|
||||
for (auto& multiplier : triple_generator->ot_multipliers)
|
||||
mac += multiplier->macs.at(k).at(i * T::default_length + j);
|
||||
share.set_mac(mac);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
sacrifice.triple_sacrifice(triples, triples,
|
||||
*thread.P, thread.MC->get_part_MC());
|
||||
for (size_t i = 0; i < triples.size() / T::default_length; i++)
|
||||
{
|
||||
this->triples.push_back({});
|
||||
auto& triple = this->triples.back();
|
||||
for (auto& x : triple)
|
||||
x.resize_regs(T::default_length);
|
||||
for (int j = 0; j < T::default_length; j++)
|
||||
{
|
||||
auto& source_triple = triples[j + i * T::default_length];
|
||||
for (int k = 0; k < 3; k++)
|
||||
triple[k].get_reg(j) = source_triple[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::buffer_bits()
|
||||
{
|
||||
auto tmp = BufferPrep<T>::get_random_from_inputs(thread.P->num_players());
|
||||
for (auto& bit : tmp.get_regs())
|
||||
{
|
||||
this->bits.push_back({});
|
||||
this->bits.back().resize_regs(1);
|
||||
this->bits.back().get_reg(0) = bit;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::buffer_inputs(int player)
|
||||
{
|
||||
auto& inputs = this->inputs;
|
||||
inputs.resize(thread.P->num_players());
|
||||
assert(this->input_generator);
|
||||
this->input_generator->generateInputs(player);
|
||||
for (size_t i = 0; i < this->input_generator->inputs.size() / T::default_length; i++)
|
||||
{
|
||||
inputs[player].push_back({});
|
||||
inputs[player].back().share.resize_regs(T::default_length);
|
||||
for (int j = 0; j < T::default_length; j++)
|
||||
{
|
||||
auto& source_input = this->input_generator->inputs[j
|
||||
+ i * T::default_length];
|
||||
inputs[player].back().share.get_reg(j) = res_type(source_input.share);
|
||||
inputs[player].back().value ^= typename T::open_type(
|
||||
source_input.value.get_bit(0)) << j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
typename T::part_type::super GC::TinyPrep<T>::get_random()
|
||||
{
|
||||
T tmp;
|
||||
this->get_one(DATA_BIT, tmp);
|
||||
return tmp.get_reg(0);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
array<T, 3> TinyPrep<T>::get_triple(int n_bits)
|
||||
{
|
||||
assert(n_bits > 0);
|
||||
while ((unsigned)n_bits > triple_buffer.size())
|
||||
{
|
||||
array<T, 3> tmp;
|
||||
this->get(DATA_TRIPLE, tmp.data());
|
||||
for (size_t i = 0; i < tmp[0].get_regs().size(); i++)
|
||||
{
|
||||
triple_buffer.push_back(
|
||||
{ {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} });
|
||||
}
|
||||
}
|
||||
|
||||
array<T, 3> res;
|
||||
for (int j = 0; j < 3; j++)
|
||||
res[j].resize_regs(n_bits);
|
||||
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
{
|
||||
for (int j = 0; j < 3; j++)
|
||||
res[j].get_reg(i) = triple_buffer.back()[j];
|
||||
triple_buffer.pop_back();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
11
GC/TinySecret.cpp
Normal file
11
GC/TinySecret.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
/*
|
||||
* TinySecret.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinySecret.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
163
GC/TinySecret.h
Normal file
163
GC/TinySecret.h
Normal file
@@ -0,0 +1,163 @@
|
||||
/*
|
||||
* TinySecret.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_TINYSECRET_H_
|
||||
#define GC_TINYSECRET_H_
|
||||
|
||||
#include "Secret.h"
|
||||
#include "TinyShare.h"
|
||||
#include "ShareParty.h"
|
||||
#include "OT/Rectangle.h"
|
||||
#include "OT/BitDiagonal.h"
|
||||
|
||||
template<class T> class NPartyTripleGenerator;
|
||||
template<class T> class OTTripleGenerator;
|
||||
template<class T> class TinyMultiplier;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class TinyPrep;
|
||||
template<class T> class TinyMC;
|
||||
|
||||
template<int S>
|
||||
class TinySecret : public Secret<TinyShare<S>>
|
||||
{
|
||||
typedef TinySecret This;
|
||||
|
||||
public:
|
||||
typedef TinyShare<S> part_type;
|
||||
typedef Secret<part_type> super;
|
||||
|
||||
typedef typename part_type::mac_key_type mac_key_type;
|
||||
|
||||
typedef BitVec open_type;
|
||||
typedef BitVec clear;
|
||||
|
||||
typedef TinyMC<This> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef Beaver<This> Protocol;
|
||||
typedef ::Input<This> Input;
|
||||
typedef TinyPrep<This> LivePrep;
|
||||
typedef Memory<This> DynamicMemory;
|
||||
|
||||
typedef OTTripleGenerator<This> TripleGenerator;
|
||||
typedef TinyMultiplier<This> Multiplier;
|
||||
typedef typename part_type::sacri_type sacri_type;
|
||||
typedef typename part_type::mac_type mac_type;
|
||||
typedef BitDiagonal Rectangle;
|
||||
|
||||
static const bool dishonest_majority = true;
|
||||
static const bool needs_ot = true;
|
||||
|
||||
static const int default_length = 64;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "T";
|
||||
}
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
return BitVec::field_type();
|
||||
}
|
||||
|
||||
static int size()
|
||||
{
|
||||
return part_type::size() * default_length;
|
||||
}
|
||||
|
||||
static MC* new_mc(Machine<This>& machine)
|
||||
{
|
||||
(void) machine;
|
||||
return new MC(ShareParty<This>::s().mac_key);
|
||||
}
|
||||
|
||||
static void store_clear_in_dynamic(Memory<This>& mem,
|
||||
const vector<ClearWriteAccess>& accesses)
|
||||
{
|
||||
auto& party = ShareThread<This>::s();
|
||||
for (auto access : accesses)
|
||||
mem[access.address] = constant(access.value, party.P->my_num(),
|
||||
{});
|
||||
}
|
||||
|
||||
static This constant(BitVec other, int my_num, mac_key_type alphai)
|
||||
{
|
||||
This res;
|
||||
res.resize_regs(other.length());
|
||||
for (int i = 0; i < other.length(); i++)
|
||||
res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai);
|
||||
return res;
|
||||
}
|
||||
|
||||
TinySecret()
|
||||
{
|
||||
}
|
||||
TinySecret(const super& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
void assign(const char* buffer)
|
||||
{
|
||||
this->resize_regs(default_length);
|
||||
for (int i = 0; i < default_length; i++)
|
||||
this->get_reg(i).assign(buffer + i * part_type::size());
|
||||
}
|
||||
|
||||
This operator-(const This& other) const
|
||||
{
|
||||
return *this + other;
|
||||
}
|
||||
|
||||
This operator*(const BitVec& other) const
|
||||
{
|
||||
This res = *this;
|
||||
for (int i = 0; i < super::size(); i++)
|
||||
if (not other.get_bit(i))
|
||||
res.get_reg(i) = {};
|
||||
return res;
|
||||
}
|
||||
|
||||
This extend_bit() const
|
||||
{
|
||||
This res;
|
||||
res.get_regs().resize(BitVec::N_BITS, this->get_reg(0));
|
||||
return res;
|
||||
}
|
||||
|
||||
This mask(int n_bits) const
|
||||
{
|
||||
This res = *this;
|
||||
res.get_regs().resize(n_bits);
|
||||
return res;
|
||||
}
|
||||
|
||||
void reveal(size_t n_bits, Clear& x)
|
||||
{
|
||||
auto& to_open = *this;
|
||||
to_open.resize_regs(n_bits);
|
||||
auto& party = ShareThread<This>::s();
|
||||
x = party.MC->POpen(to_open, *party.P);
|
||||
}
|
||||
|
||||
void output(ostream& s, bool human = true) const
|
||||
{
|
||||
assert(this->get_regs().size() == default_length);
|
||||
for (auto& reg : this->get_regs())
|
||||
reg.output(s, human);
|
||||
}
|
||||
};
|
||||
|
||||
template<int S>
|
||||
inline TinySecret<S> operator*(const BitVec& clear, const TinySecret<S>& share)
|
||||
{
|
||||
return share * clear;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_TINYSECRET_H_ */
|
||||
11
GC/TinyShare.cpp
Normal file
11
GC/TinyShare.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
/*
|
||||
* TinyShare.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinyShare.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
80
GC/TinyShare.h
Normal file
80
GC/TinyShare.h
Normal file
@@ -0,0 +1,80 @@
|
||||
/*
|
||||
* TinyShare.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_TINYSHARE_H_
|
||||
#define GC_TINYSHARE_H_
|
||||
|
||||
#include "ShareSecret.h"
|
||||
#include "ShareParty.h"
|
||||
#include "Secret.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<int S> class TinySecret;
|
||||
template<class T> class ShareThread;
|
||||
|
||||
template<int S>
|
||||
class TinyShare : public Spdz2kShare<1, S>, public ShareSecret<TinySecret<S>>
|
||||
{
|
||||
typedef TinyShare This;
|
||||
|
||||
public:
|
||||
typedef Spdz2kShare<1, S> super;
|
||||
|
||||
typedef void DynamicMemory;
|
||||
|
||||
typedef NoLivePrep<This> LivePrep;
|
||||
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "tiny share";
|
||||
}
|
||||
|
||||
static ShareThread<TinySecret<S>>& get_party()
|
||||
{
|
||||
return ShareThread<TinySecret<S>>::s();
|
||||
}
|
||||
|
||||
static This new_reg()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
TinyShare()
|
||||
{
|
||||
}
|
||||
TinyShare(const typename super::super& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
void XOR(const This& a, const This& b)
|
||||
{
|
||||
*this = a + b;
|
||||
}
|
||||
|
||||
void public_input(bool input)
|
||||
{
|
||||
auto& party = get_party();
|
||||
*this = super::constant(input, party.P->my_num(),
|
||||
ShareParty < TinySecret < S >> ::s().mac_key);
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
TinySecret<S> tmp;
|
||||
get_party().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_TINYSHARE_H_ */
|
||||
@@ -55,7 +55,7 @@
|
||||
X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \
|
||||
X(SHRCI, C0 = C1 >> IMM) \
|
||||
X(SHLCI, C0 = C1 << IMM) \
|
||||
X(LDBITS, S0.load(R1, IMM)) \
|
||||
X(LDBITS, S0.load_clear(R1, IMM)) \
|
||||
X(LDMS, S0 = MSD) \
|
||||
X(STMS, MSD = S0) \
|
||||
X(LDMSI, S0 = MSI) \
|
||||
@@ -67,7 +67,7 @@
|
||||
X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \
|
||||
X(STMSDI, PROC.store_dynamic_indirect(EXTRA, MD)) \
|
||||
X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA, MD)) \
|
||||
X(CONVSINT, S0.load(IMM, I1)) \
|
||||
X(CONVSINT, S0.load_clear(IMM, I1)) \
|
||||
X(CONVCINT, C0 = I1) \
|
||||
X(CONVCBIT, T::convcbit(I0, C1)) \
|
||||
X(MOVS, S0 = PS1) \
|
||||
|
||||
29
License.txt
29
License.txt
@@ -25,6 +25,11 @@ Copyright (c) 2018, The University of Bristol, Bar-Ilan University
|
||||
Please contact mks.keller@gmail.com
|
||||
The same license as for SPDZ-2 applies.
|
||||
___________________________________________________________________
|
||||
SCALE-MAMBA [https://github.com/KULeuven-COSIC/SCALE-MAMBA]
|
||||
Copyright (c) 2019, The University of Bristol, COSIC-KU Leuven
|
||||
Please contact nigel.smart@kuleuven.be
|
||||
See below for the full license.
|
||||
___________________________________________________________________
|
||||
|
||||
|
||||
University of Bristol : Open Access Software Licence
|
||||
@@ -46,3 +51,27 @@ Any use of the software for scientific publications or commercial purposes shoul
|
||||
|
||||
Enquiries about further applications and development opportunities are welcome. Please contact nigel@cs.bris.ac.uk
|
||||
|
||||
___________________________________________________________________
|
||||
|
||||
|
||||
This software incorporates components from the original SPDZ system, as well as various
|
||||
extensions. It's copyright therefore rests with the following two institutions:
|
||||
|
||||
Copyright (c) 2017, The University of Bristol, Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom.
|
||||
Copyright (c) 2018, COSIC-KU Leuven, Kasteelpark Arenberg 10, bus 2452, B-3001 Leuven-Heverlee, Belgium.
|
||||
|
||||
All rights reserved
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Any use of the software for commercial purposes should be reported to the nigel.smart@kuleuven.be
|
||||
This is for impact and usage monitoring purposes only.
|
||||
|
||||
Enquiries about further applications and development opportunities are welcome. Please contact nigel.smart@kuleuven.be
|
||||
|
||||
@@ -65,18 +65,6 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
|
||||
"-ip", // Flag token.
|
||||
"--ip-file-name" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"empty", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Where to obtain memory, new|old|empty (default: empty)\n\t"
|
||||
"new: copy from Player-Memory-P<i> file\n\t"
|
||||
"old: reuse previous memory in Memory-P<i>\n\t"
|
||||
"empty: create new empty memory", // Help description.
|
||||
"-m", // Flag token.
|
||||
"--memory" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
@@ -143,14 +131,13 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
|
||||
"--encrypted" // Flag token.
|
||||
);
|
||||
|
||||
string memtype, hostname, ipFileName;
|
||||
string hostname, ipFileName;
|
||||
int lg2, pnbase, opening_sum, max_broadcast;
|
||||
int my_port;
|
||||
|
||||
online_opts.finalize(opt, argc, argv);
|
||||
opt.get("--portnumbase")->getInt(pnbase);
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
opt.get("--memory")->getString(memtype);
|
||||
opt.get("--hostname")->getString(hostname);
|
||||
opt.get("--ip-file-name")->getString(ipFileName);
|
||||
opt.get("--opening-sum")->getInt(opening_sum);
|
||||
@@ -192,7 +179,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
|
||||
try
|
||||
#endif
|
||||
{
|
||||
Machine<T, U>(playerno, playerNames, online_opts.progname, memtype, lg2,
|
||||
Machine<T, U>(playerno, playerNames, online_opts.progname, online_opts.memtype, lg2,
|
||||
opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet,
|
||||
opt.get("--threads")->isSet, max_broadcast,
|
||||
opt.get("--encrypted")->isSet, online_opts.live_prep,
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include <OT/TripleMachine.h>
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/OTMachine.h"
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/Setup.h"
|
||||
@@ -13,9 +12,11 @@
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "Math/BitVec.h"
|
||||
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "OT/NPartyTripleGenerator.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
@@ -23,24 +24,10 @@ using namespace std;
|
||||
|
||||
void* run_ngenerator_thread(void* ptr)
|
||||
{
|
||||
((MascotGenerator*)ptr)->generate();
|
||||
((GeneratorThread*)ptr)->generate();
|
||||
return 0;
|
||||
}
|
||||
|
||||
MascotParams::MascotParams()
|
||||
{
|
||||
generateMACs = true;
|
||||
amplify = true;
|
||||
check = true;
|
||||
generateBits = false;
|
||||
timerclear(&start);
|
||||
}
|
||||
|
||||
void MascotParams::set_passive()
|
||||
{
|
||||
generateMACs = amplify = check = false;
|
||||
}
|
||||
|
||||
TripleMachine::TripleMachine(int argc, const char** argv) :
|
||||
nConnections(1), bonding(0)
|
||||
{
|
||||
@@ -167,9 +154,10 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
|
||||
}
|
||||
|
||||
template<class T>
|
||||
NPartyTripleGenerator<T>* TripleMachine::new_generator(OTTripleSetup& setup, int i)
|
||||
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i)
|
||||
{
|
||||
return new NPartyTripleGenerator<T>(setup, N[i%nConnections], i, nTriplesPerThread, nloops, *this);
|
||||
return new typename T::TripleGenerator(setup, N[i % nConnections], i,
|
||||
nTriplesPerThread, nloops, *this);
|
||||
}
|
||||
|
||||
void TripleMachine::run()
|
||||
@@ -186,13 +174,13 @@ void TripleMachine::run()
|
||||
PlainPlayer P(N[0], 0xF000);
|
||||
OTTripleSetup setup(P, true);
|
||||
|
||||
vector<MascotGenerator*> generators(nthreads);
|
||||
vector<GeneratorThread*> generators(nthreads);
|
||||
vector<pthread_t> threads(nthreads);
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
if (primeField)
|
||||
generators[i] = new_generator<Share<gfp1>>(setup, i);
|
||||
generators[i] = new_generator<Share<gfp>>(setup, i);
|
||||
else if (z2k)
|
||||
{
|
||||
if (z2k == 32 and z2s == 32)
|
||||
@@ -270,58 +258,3 @@ void TripleMachine::output_mac_keys()
|
||||
else
|
||||
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s);
|
||||
}
|
||||
|
||||
template<> gf2n_long MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_key2l;
|
||||
}
|
||||
|
||||
template<> gf2n_short MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_key2s;
|
||||
}
|
||||
|
||||
template<> gfp1 MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyp;
|
||||
}
|
||||
|
||||
template<> Z2<48> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<64> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<32> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gf2n_long key)
|
||||
{
|
||||
mac_key2l = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gf2n_short key)
|
||||
{
|
||||
mac_key2s = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gfp1 key)
|
||||
{
|
||||
mac_keyp = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<64> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<48> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
@@ -3,10 +3,25 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GC/ReplicatedParty.h"
|
||||
#include "GC/ShareParty.h"
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/MaliciousRepSecret.h"
|
||||
|
||||
#include "GC/Instruction.hpp"
|
||||
#include "GC/Machine.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
#include "GC/Program.hpp"
|
||||
#include "GC/Thread.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
GC::ReplicatedParty<GC::MaliciousRepSecret>(argc, argv);
|
||||
GC::ShareParty<GC::MaliciousRepSecret>(argc, argv);
|
||||
}
|
||||
|
||||
@@ -3,9 +3,24 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GC/ReplicatedParty.h"
|
||||
#include "GC/ShareParty.h"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/Instruction.hpp"
|
||||
#include "GC/Machine.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
#include "GC/Program.hpp"
|
||||
#include "GC/Thread.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
GC::ReplicatedParty<GC::SemiHonestRepSecret>(argc, argv);
|
||||
GC::ShareParty<GC::SemiHonestRepSecret>(argc, argv);
|
||||
}
|
||||
|
||||
28
Machines/semi-bin-party.cpp
Normal file
28
Machines/semi-bin-party.cpp
Normal file
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* semi-bin-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GC/ShareParty.h"
|
||||
#include "GC/SemiSecret.h"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
|
||||
#include "GC/Machine.hpp"
|
||||
#include "GC/Program.hpp"
|
||||
#include "GC/Instruction.hpp"
|
||||
#include "GC/Thread.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
#include "Protocols/ReplicatedInput.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
GC::ShareParty<GC::SemiSecret>(argc, argv);
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "Processor/Machine.h"
|
||||
#include "Processor/RingOptions.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Networking/Server.h"
|
||||
@@ -27,13 +28,24 @@ int main(int argc, const char** argv)
|
||||
int s;
|
||||
opt.get("-S")->getInt(s);
|
||||
opt.resetArgs();
|
||||
RingOptions ring_options(opt, argc, argv);
|
||||
int k = ring_options.R;
|
||||
#ifdef VERBOSE
|
||||
cerr << "Using SPDZ2k with security parameter " << s << endl;
|
||||
cerr << "Using SPDZ2k with ring length " << k << " and security parameter "
|
||||
<< s << endl;
|
||||
#endif
|
||||
if (s == 64)
|
||||
return spdz_main<Spdz2kShare<64, 64>, Share<gf2n>>(argc, argv, opt);
|
||||
else if (s == 48)
|
||||
return spdz_main<Spdz2kShare<64, 48>, Share<gf2n>>(argc, argv, opt);
|
||||
|
||||
#undef Z
|
||||
#define Z(K, S) \
|
||||
if (s == S and k == K) \
|
||||
return spdz_main<Spdz2kShare<K, S>, Share<gf2n>>(argc, argv, opt);
|
||||
|
||||
Z(64, 64)
|
||||
Z(64, 48)
|
||||
Z(72, 64)
|
||||
Z(72, 48)
|
||||
|
||||
else
|
||||
throw runtime_error("not compiled for s=" + to_string(s));
|
||||
throw runtime_error(
|
||||
"not compiled for k=" + to_string(k) + " and s=" + to_string(s));
|
||||
}
|
||||
|
||||
31
Machines/tiny-party.cpp
Normal file
31
Machines/tiny-party.cpp
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* tiny-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GC/TinySecret.h"
|
||||
#include "GC/ShareParty.h"
|
||||
#include "GC/TinyMC.h"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/Instruction.hpp"
|
||||
#include "GC/Machine.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
#include "GC/Program.hpp"
|
||||
#include "GC/Thread.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
GC::ShareParty<GC::TinySecret<40>>(argc, argv, 1000);
|
||||
}
|
||||
79
Makefile
79
Makefile
@@ -18,7 +18,7 @@ OT_EXE = ot.x ot-offline.x
|
||||
|
||||
COMMON = $(MATH) $(TOOLS) $(NETWORK)
|
||||
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT)
|
||||
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC) BMR/Key.o
|
||||
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o
|
||||
BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(OT)
|
||||
VM = $(PROCESSOR) $(COMMON)
|
||||
|
||||
@@ -35,7 +35,7 @@ DEPS := $(wildcard */*.d)
|
||||
.SECONDARY: $(OBJS)
|
||||
|
||||
|
||||
all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x mascot-party.x
|
||||
all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x
|
||||
|
||||
ifeq ($(USE_NTL),1)
|
||||
all: overdrive she-offline cowgear-party.x
|
||||
@@ -77,7 +77,7 @@ spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Off
|
||||
|
||||
tldr:
|
||||
-echo ARCH = -march=native >> CONFIG.mine
|
||||
$(MAKE) Player-Online.x
|
||||
$(MAKE) mascot-party.x
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
tldr: mac-setup
|
||||
@@ -90,7 +90,7 @@ shamir: shamir-party.x malicious-shamir-party.x galois-degree.x
|
||||
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
|
||||
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
|
||||
|
||||
$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Machines/S*.cpp)) $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE)
|
||||
$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Machines/S*.cpp)) $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC)
|
||||
$(AR) -csr $@ $^
|
||||
|
||||
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
|
||||
@@ -104,47 +104,17 @@ static-dir:
|
||||
|
||||
static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp))
|
||||
|
||||
Fake-Offline.x: Fake-Offline.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON)
|
||||
$(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS) $(ECLIB)
|
||||
|
||||
Check-Offline.x: Check-Offline.o $(COMMON) $(PROCESSOR)
|
||||
$(CXX) $(CFLAGS) -o Check-Offline.x $^ $(LDLIBS)
|
||||
Check-Offline.x: $(PROCESSOR)
|
||||
|
||||
Check-Offline-Z2k.x: Check-Offline-Z2k.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o Check-Offline-Z2k.x $^ $(LDLIBS)
|
||||
|
||||
Server.x: Server.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS)
|
||||
|
||||
Setup.x: Setup.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Setup.cpp -o Setup.x $(COMMON) $(LDLIBS)
|
||||
|
||||
ot.x: $(OT) $(COMMON) OT/OText_main.cpp $(LIBSIMPLEOT)
|
||||
ot.x: $(OT) $(COMMON) Machines/OText_main.o Machines/OTMachine.o $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
ot-check.x: $(OT) $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o ot-check.x OT/OutputCheck.cpp $(COMMON) $(LDLIBS)
|
||||
ot-offline.x: $(OT) $(LIBSIMPLEOT) Machines/TripleMachine.o
|
||||
|
||||
ot-bitmatrix.x: $(OT) $(COMMON) OT/BitMatrixTest.cpp
|
||||
$(CXX) $(CFLAGS) -o ot-bitmatrix.x OT/BitMatrixTest.cpp OT/BitMatrix.o $(COMMON) $(LDLIBS)
|
||||
|
||||
ot-offline.x: $(OT) $(COMMON) ot-offline.cpp $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
check-passive.x: $(COMMON) check-passive.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Scripts/gen_input_f2n.cpp -o gen_input_f2n.x $(COMMON) $(LDLIBS)
|
||||
|
||||
gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS)
|
||||
|
||||
gc-emulate.x: $(GC) $(COMMON) $(PROCESSOR) gc-emulate.cpp $(GC)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
gc-emulate.x: $(PROCESSOR) GC/FakeSecret.o GC/square64.o
|
||||
|
||||
bmr-%.x: $(BMR) Machines/bmr-%.cpp $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS)
|
||||
@@ -155,48 +125,41 @@ bmr-%.x: $(BMR) Machines/bmr-%.cpp $(LIBSIMPLEOT)
|
||||
bmr-clean:
|
||||
-rm BMR/*.o BMR/*/*.o GC/*.o
|
||||
|
||||
client-setup.x: client-setup.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
ifeq ($(USE_NTL),1)
|
||||
simple-offline.x: $(COMMON) $(FHEOFFLINE) simple-offline.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
pairwise-offline.x: $(COMMON) $(FHEOFFLINE) pairwise-offline.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
cnc-offline.x: $(COMMON) $(FHEOFFLINE) cnc-offline.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
spdz2-offline.x: $(COMMON) $(FHEOFFLINE) spdz2-offline.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
endif
|
||||
simple-offline.x: $(FHEOFFLINE)
|
||||
pairwise-offline.x: $(FHEOFFLINE)
|
||||
cnc-offline.x: $(FHEOFFLINE)
|
||||
spdz2-offline.x: $(FHEOFFLINE)
|
||||
|
||||
yao-party.x: $(YAO)
|
||||
|
||||
yao-clean:
|
||||
-rm Yao/*.o
|
||||
|
||||
galois-degree.x: galois-degree.cpp
|
||||
galois-degree.x: Utils/galois-degree.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
default-prime-length.x: default-prime-length.cpp
|
||||
default-prime-length.x: Utils/default-prime-length.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
%.x: Utils/%.o $(COMMON)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
|
||||
|
||||
%.x: Machines/%.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
|
||||
|
||||
%-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) $(ECLIB)
|
||||
|
||||
replicated-bin-party.x: $(GC)
|
||||
malicious-rep-bin-party.x: $(GC)
|
||||
replicated-bin-party.x: GC/square64.o
|
||||
malicious-rep-bin-party.x: GC/square64.o
|
||||
semi-bin-party.x: $(VM) $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
tiny-party.x: $(OT)
|
||||
shamir-party.x: Machines/ShamirMachine.o
|
||||
malicious-shamir-party.x: Machines/ShamirMachine.o
|
||||
spdz2k-party.x: $(OT)
|
||||
|
||||
@@ -8,12 +8,18 @@
|
||||
|
||||
#include "Integer.h"
|
||||
#include "field_types.h"
|
||||
#include "Square.h"
|
||||
|
||||
class BitDiagonal;
|
||||
|
||||
class BitVec : public IntBase
|
||||
{
|
||||
public:
|
||||
typedef BitVec Scalar;
|
||||
|
||||
typedef BitVec next;
|
||||
typedef BitDiagonal Square;
|
||||
|
||||
static const int n_bits = sizeof(a) * 8;
|
||||
|
||||
static char type_char() { return 'B'; }
|
||||
@@ -32,10 +38,14 @@ public:
|
||||
BitVec operator/(const BitVec& other) const { (void) other; throw not_implemented(); }
|
||||
|
||||
BitVec& operator+=(const BitVec& other) { *this ^= other; return *this; }
|
||||
BitVec& operator-=(const BitVec& other) { *this ^= other; return *this; }
|
||||
|
||||
BitVec extend_bit() const { return -(a & 1); }
|
||||
BitVec mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; }
|
||||
|
||||
template<int t>
|
||||
void add(octetStream& os) { *this += os.get<BitVec>(); }
|
||||
|
||||
void mul(const BitVec& a, const BitVec& b) { *this = a * b; }
|
||||
|
||||
void randomize(PRNG& G, int n = n_bits) { IntBase::randomize(G); *this = mask(n); }
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define MATH_FIXEDVEC_H_
|
||||
|
||||
#include <string>
|
||||
#include <array>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/octetStream.h"
|
||||
@@ -21,7 +22,7 @@ template<class T> class Replicated;
|
||||
template <class T, int L>
|
||||
class FixedVec
|
||||
{
|
||||
T v[L];
|
||||
array<T, L> v;
|
||||
|
||||
public:
|
||||
typedef T value_type;
|
||||
@@ -71,6 +72,16 @@ public:
|
||||
v[i] = other[i];
|
||||
}
|
||||
|
||||
FixedVec<T, L>(const array<T, L>& other)
|
||||
{
|
||||
v = other;
|
||||
}
|
||||
|
||||
const array<T, L>& get() const
|
||||
{
|
||||
return v;
|
||||
}
|
||||
|
||||
T& operator[](int i)
|
||||
{
|
||||
return v[i];
|
||||
|
||||
@@ -24,9 +24,12 @@ protected:
|
||||
long a;
|
||||
|
||||
public:
|
||||
static const int N_BYTES = sizeof(a);
|
||||
static const int N_BITS = 8 * sizeof(a);
|
||||
static const int MAX_N_BITS = N_BITS;
|
||||
|
||||
static int size() { return sizeof(a); }
|
||||
static int length() { return N_BITS; }
|
||||
static string type_string() { return "integer"; }
|
||||
|
||||
static void init_default(int lgp) { (void)lgp; }
|
||||
@@ -39,10 +42,12 @@ public:
|
||||
long get() const { return a; }
|
||||
bool get_bit(int i) const { return (a >> i) & 1; }
|
||||
|
||||
char* get_ptr() const { return (char*)&a; }
|
||||
|
||||
unsigned long debug() const { return a; }
|
||||
|
||||
void assign(long x) { *this = x; }
|
||||
void assign(const char* buffer) { avx_memcpy(&a, buffer, sizeof(a)); }
|
||||
void assign(const void* buffer) { avx_memcpy(&a, buffer, sizeof(a)); }
|
||||
void assign_zero() { a = 0; }
|
||||
void assign_one() { a = 1; }
|
||||
|
||||
@@ -50,8 +55,20 @@ public:
|
||||
bool is_one() const { return a == 1; }
|
||||
bool is_bit() const { return is_zero() or is_one(); }
|
||||
|
||||
long operator>>(const IntBase& other) const { return a >> other.a; }
|
||||
long operator<<(const IntBase& other) const { return a << other.a; }
|
||||
long operator>>(const IntBase& other) const
|
||||
{
|
||||
if (other.a < N_BITS)
|
||||
return (unsigned long) a >> other.a;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
long operator<<(const IntBase& other) const
|
||||
{
|
||||
if (other.a < N_BITS)
|
||||
return a << other.a;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
long operator^(const IntBase& other) const { return a ^ other.a; }
|
||||
long operator&(const IntBase& other) const { return a & other.a; }
|
||||
|
||||
@@ -121,7 +121,7 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2
|
||||
if (mkdir_p(ss.str().c_str()) == -1)
|
||||
{
|
||||
cerr << "mkdir_p(" << ss.str() << ") failed\n";
|
||||
throw file_error();
|
||||
throw file_error(ss.str());
|
||||
}
|
||||
|
||||
// Output the data
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "Square.h"
|
||||
#include "BitVec.h"
|
||||
|
||||
template<>
|
||||
void Square<gf2n_short>::to(gf2n_short& result)
|
||||
@@ -34,3 +35,11 @@ void Square<gfp1>::to(gfp1& result)
|
||||
mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L);
|
||||
result.assign((void*) ans);
|
||||
}
|
||||
|
||||
template<>
|
||||
void Square<BitVec>::to(BitVec& result)
|
||||
{
|
||||
result = 0;
|
||||
for (int i = 0; i < N_ROWS; i++)
|
||||
result ^= ((rows[i] >> i) & 1) << i;
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ template<class U>
|
||||
class Square
|
||||
{
|
||||
public:
|
||||
typedef U RowType;
|
||||
|
||||
static const int N_ROWS = U::MAX_N_BITS;
|
||||
static const int N_ROWS_ALLOCATED = N_ROWS;
|
||||
static const int N_COLUMNS = N_ROWS;
|
||||
@@ -21,16 +23,11 @@ public:
|
||||
|
||||
U rows[N_ROWS];
|
||||
|
||||
template<class T>
|
||||
Square& sub(const Square& other);
|
||||
template<class T>
|
||||
Square& rsub(const Square& other);
|
||||
template<class T>
|
||||
Square& sub(const void* other);
|
||||
|
||||
template <class T>
|
||||
void randomize(int row, PRNG& G) { rows[row].randomize(G); }
|
||||
template <class T>
|
||||
void conditional_add(BitVector& conditions, Square& other,
|
||||
int offset);
|
||||
void to(U& result);
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "Math/Square.h"
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
Square<U>& Square<U>::sub(const Square<U>& other)
|
||||
{
|
||||
for (int i = 0; i < U::length(); i++)
|
||||
@@ -15,7 +14,6 @@ Square<U>& Square<U>::sub(const Square<U>& other)
|
||||
}
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
Square<U>& Square<U>::rsub(const Square<U>& other)
|
||||
{
|
||||
for (int i = 0; i < U::length(); i++)
|
||||
@@ -24,7 +22,6 @@ Square<U>& Square<U>::rsub(const Square<U>& other)
|
||||
}
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
Square<U>& Square<U>::sub(const void* other)
|
||||
{
|
||||
U value;
|
||||
@@ -35,7 +32,6 @@ Square<U>& Square<U>::sub(const void* other)
|
||||
}
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
void Square<U>::conditional_add(BitVector& conditions,
|
||||
Square<U>& other, int offset)
|
||||
{
|
||||
|
||||
@@ -16,6 +16,8 @@ public:
|
||||
static void init_default(int l) { (void) l; }
|
||||
|
||||
static void read_setup(int nparties, int lg2p, int gf2ndegree);
|
||||
|
||||
void normalize() {}
|
||||
};
|
||||
|
||||
#endif /* MATH_VALUEINTERFACE_H_ */
|
||||
|
||||
@@ -109,6 +109,7 @@ public:
|
||||
Z2<K+L> operator*(const Z2<L>& other) const;
|
||||
|
||||
Z2<K> operator*(bool other) const { return other ? *this : Z2<K>(); }
|
||||
Z2<K> operator*(int other) const { return *this * Z2<K>(other); }
|
||||
|
||||
Z2<K> operator/(const Z2& other) const { (void) other; throw not_implemented(); }
|
||||
|
||||
|
||||
@@ -352,7 +352,7 @@ void gf2n_short::input(istream& s,bool human)
|
||||
if (s.peek() == EOF)
|
||||
{ if (s.tellg() == 0)
|
||||
{ cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
throw file_error("gf2n_short input");
|
||||
}
|
||||
throw end_of_file("gf2n_short");
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ class gf2n_short
|
||||
typedef gf2n_short Scalar;
|
||||
|
||||
static const int MAX_N_BITS = 64;
|
||||
static const int N_BYTES = sizeof(a);
|
||||
|
||||
static void init_field(int nn);
|
||||
static int degree() { return n; }
|
||||
|
||||
@@ -257,7 +257,7 @@ void gf2n_long::input(istream& s,bool human)
|
||||
if (s.peek() == EOF)
|
||||
{ if (s.tellg() == 0)
|
||||
{ cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
throw file_error("gf2n_long input");
|
||||
}
|
||||
throw end_of_file("gf2n_long");
|
||||
}
|
||||
|
||||
@@ -100,6 +100,7 @@ class gf2n_long
|
||||
typedef ::Square<gf2n_long> Square;
|
||||
|
||||
const static int MAX_N_BITS = 128;
|
||||
const static int N_BYTES = sizeof(a);
|
||||
|
||||
typedef gf2n_long Scalar;
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ class gfp_
|
||||
|
||||
static const int N_LIMBS = L;
|
||||
static const int MAX_N_BITS = 64 * L;
|
||||
static const int N_BYTES = sizeof(a);
|
||||
|
||||
template<class T>
|
||||
static void init(bool mont = true)
|
||||
|
||||
@@ -256,7 +256,7 @@ void modp_<L>::input(istream& s,const Zp_Data& ZpD,bool human)
|
||||
if (s.peek() == EOF)
|
||||
{ if (s.tellg() == 0)
|
||||
{ cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
throw file_error("modp input");
|
||||
}
|
||||
throw end_of_file("modp");
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#ifndef MATH_OPERATORS_H_
|
||||
#define MATH_OPERATORS_H_
|
||||
|
||||
template <class T>
|
||||
T operator*(const bool& x, const T& y) { return x ? y : T(); }
|
||||
//template <class T>
|
||||
//T operator*(const bool& x, const T& y) { return x ? y : T(); }
|
||||
//template <class T>
|
||||
//T operator*(const T& y, const bool& x) { return x ? y : T(); }
|
||||
template <class T>
|
||||
|
||||
@@ -56,10 +56,23 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
|
||||
nplayers = 0;
|
||||
portnum_base = pnb;
|
||||
string line;
|
||||
ports.clear();
|
||||
while (getline(hostsfile, line))
|
||||
{
|
||||
if (line.length() > 0 && line.at(0) != '#') {
|
||||
names.push_back(line);
|
||||
auto pos = line.find(':');
|
||||
if (pos == string::npos)
|
||||
{
|
||||
names.push_back(line);
|
||||
ports.push_back(default_port(nplayers));
|
||||
}
|
||||
else
|
||||
{
|
||||
names.push_back(line.substr(0, pos));
|
||||
int port;
|
||||
stringstream(line.substr(pos + 1)) >> port;
|
||||
ports.push_back(port);
|
||||
}
|
||||
nplayers++;
|
||||
if (nplayers_wanted > 0 and nplayers_wanted == nplayers)
|
||||
break;
|
||||
@@ -67,29 +80,18 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
|
||||
}
|
||||
if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
|
||||
throw runtime_error("not enought hosts in HOSTS");
|
||||
setup_ports();
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Got list of " << nplayers << " players from file: " << endl;
|
||||
for (unsigned int i = 0; i < names.size(); i++)
|
||||
cerr << " " << names[i] << endl;
|
||||
cerr << " " << names[i] << ":" << ports[i] << endl;
|
||||
#endif
|
||||
setup_server();
|
||||
}
|
||||
|
||||
Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
int default_nplayers) :
|
||||
Names()
|
||||
int default_nplayers) : Names()
|
||||
{
|
||||
NetworkOptions network_opts(opt, argc, argv);
|
||||
opt.add(
|
||||
to_string(default_nplayers).c_str(), // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of players", // Help description.
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
NetworkOptionsWithNumber network_opts(opt, argc, argv, default_nplayers, true);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
1, // Required?
|
||||
@@ -101,9 +103,7 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
opt.get("-p")->getInt(player_no);
|
||||
opt.get("-N")->getInt(nplayers);
|
||||
global_server = Server::start_networking(*this, player_no, nplayers,
|
||||
network_opts.hostname, network_opts.portnum_base);
|
||||
global_server = network_opts.start_networking(*this, player_no);
|
||||
}
|
||||
|
||||
void Names::setup_ports()
|
||||
@@ -396,6 +396,9 @@ void MultiPlayer<T>::exchange_no_stats(int other, const octetStream& o, octetStr
|
||||
|
||||
void Player::exchange(int other, const octetStream& o, octetStream& to_receive) const
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "Exchanging with " << other << endl;
|
||||
#endif
|
||||
TimeScope ts(comm_stats["Exchanging"].add(o));
|
||||
exchange_no_stats(other, o, to_receive);
|
||||
sent += o.get_length();
|
||||
@@ -605,34 +608,34 @@ int RealTwoPartyPlayer::other_player_num() const
|
||||
return other_player;
|
||||
}
|
||||
|
||||
void RealTwoPartyPlayer::send(octetStream& o)
|
||||
void RealTwoPartyPlayer::send(octetStream& o) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
|
||||
o.Send(socket);
|
||||
sent += o.get_length();
|
||||
}
|
||||
|
||||
void VirtualTwoPartyPlayer::send(octetStream& o)
|
||||
void VirtualTwoPartyPlayer::send(octetStream& o) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
|
||||
P.send_to_no_stats(other_player, o);
|
||||
sent += o.get_length();
|
||||
}
|
||||
|
||||
void RealTwoPartyPlayer::receive(octetStream& o)
|
||||
void RealTwoPartyPlayer::receive(octetStream& o) const
|
||||
{
|
||||
TimeScope ts(timer);
|
||||
o.reset_write_head();
|
||||
o.Receive(socket);
|
||||
}
|
||||
|
||||
void VirtualTwoPartyPlayer::receive(octetStream& o)
|
||||
void VirtualTwoPartyPlayer::receive(octetStream& o) const
|
||||
{
|
||||
TimeScope ts(timer);
|
||||
P.receive_player_no_stats(other_player, o);
|
||||
}
|
||||
|
||||
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o)
|
||||
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
|
||||
{
|
||||
{
|
||||
if (is_server)
|
||||
@@ -655,7 +658,7 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const
|
||||
o.exchange(socket, socket);
|
||||
}
|
||||
|
||||
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o)
|
||||
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0]));
|
||||
sent += o[0].get_length();
|
||||
@@ -667,11 +670,21 @@ VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) :
|
||||
{
|
||||
}
|
||||
|
||||
void OffsetPlayer::send_receive_player(vector<octetStream>& o)
|
||||
void OffsetPlayer::send_receive_player(vector<octetStream>& o) const
|
||||
{
|
||||
P.exchange(P.get_player(offset), o[0], o[1]);
|
||||
}
|
||||
|
||||
void TwoPartyPlayer::Broadcast_Receive(vector<octetStream>& o,
|
||||
bool donthash) const
|
||||
{
|
||||
(void) donthash;
|
||||
vector<octetStream> os(2);
|
||||
os[0] = o[my_num()];
|
||||
send_receive_player(os);
|
||||
o[1 - my_num()] = os[1];
|
||||
}
|
||||
|
||||
CommStats& CommStats::operator +=(const CommStats& other)
|
||||
{
|
||||
data += other.data;
|
||||
|
||||
@@ -126,9 +126,11 @@ public:
|
||||
virtual ~PlayerBase();
|
||||
|
||||
int my_real_num() const { return player_no; }
|
||||
virtual int my_num() const = 0;
|
||||
virtual int num_players() const = 0;
|
||||
|
||||
virtual void pass_around(octetStream& o, int offset = 1) const = 0;
|
||||
virtual void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const = 0;
|
||||
};
|
||||
|
||||
class Player : public PlayerBase
|
||||
@@ -276,9 +278,10 @@ public:
|
||||
virtual int my_num() const = 0;
|
||||
virtual int other_player_num() const = 0;
|
||||
|
||||
virtual void send(octetStream& o) = 0;
|
||||
virtual void receive(octetStream& o) = 0;
|
||||
virtual void send_receive_player(vector<octetStream>& o) = 0;
|
||||
virtual void send(octetStream& o) const = 0;
|
||||
virtual void receive(octetStream& o) const = 0;
|
||||
virtual void send_receive_player(vector<octetStream>& o) const = 0;
|
||||
void Broadcast_Receive(vector<octetStream>& o, bool donthash=false) const;
|
||||
};
|
||||
|
||||
class RealTwoPartyPlayer : public TwoPartyPlayer
|
||||
@@ -295,8 +298,8 @@ public:
|
||||
RealTwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
|
||||
~RealTwoPartyPlayer();
|
||||
|
||||
void send(octetStream& o);
|
||||
void receive(octetStream& o);
|
||||
void send(octetStream& o) const;
|
||||
void receive(octetStream& o) const;
|
||||
|
||||
int other_player_num() const;
|
||||
int my_num() const { return is_server; }
|
||||
@@ -305,7 +308,7 @@ public:
|
||||
/* Send and receive to/from the other player
|
||||
* - o[0] contains my data, received data put in o[1]
|
||||
*/
|
||||
void send_receive_player(vector<octetStream>& o);
|
||||
void send_receive_player(vector<octetStream>& o) const;
|
||||
|
||||
void exchange(octetStream& o) const;
|
||||
void exchange(int other, octetStream& o) const { (void)other; exchange(o); }
|
||||
@@ -326,9 +329,9 @@ public:
|
||||
int other_player_num() const { return other_player; }
|
||||
int num_players() const { return 2; }
|
||||
|
||||
void send(octetStream& o);
|
||||
void receive(octetStream& o);
|
||||
void send_receive_player(vector<octetStream>& o);
|
||||
void send(octetStream& o) const;
|
||||
void receive(octetStream& o) const;
|
||||
void send_receive_player(vector<octetStream>& o) const;
|
||||
|
||||
void pass_around(octetStream& o, int _ = 1) const { (void)_, (void) o; throw not_implemented(); }
|
||||
};
|
||||
@@ -349,16 +352,15 @@ public:
|
||||
int num_players() const { return 2; }
|
||||
int get_offset() const { return offset; }
|
||||
|
||||
void send(octetStream& o) { P.send_to(P.get_player(offset), o, true); }
|
||||
void reverse_send(octetStream& o) { P.send_to(P.get_player(-offset), o, true); }
|
||||
void receive(octetStream& o) { P.receive_player(P.get_player(offset), o, true); }
|
||||
void send(octetStream& o) const { P.send_to(P.get_player(offset), o, true); }
|
||||
void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o, true); }
|
||||
void receive(octetStream& o) const { P.receive_player(P.get_player(offset), o, true); }
|
||||
void reverse_receive(octetStream& o) { P.receive_player(P.get_player(-offset), o, true); }
|
||||
void send_receive_player(vector<octetStream>& o);
|
||||
void send_receive_player(vector<octetStream>& o) const;
|
||||
|
||||
void reverse_exchange(octetStream& o) const { P.pass_around(o, P.num_players() - offset); }
|
||||
void exchange(int other, octetStream& o) const { (void)other; P.pass_around(o, offset); }
|
||||
void exchange(octetStream& o) const { P.exchange(P.get_player(offset), o); }
|
||||
void pass_around(octetStream& o, int _ = 1) const { (void)_; P.pass_around(o, offset); }
|
||||
void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -106,6 +106,8 @@ void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
receiver_maketable(&receiver);
|
||||
}
|
||||
|
||||
os[0].reset_write_head();
|
||||
|
||||
for (i = 0; i < nOT; i += 4)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
@@ -117,12 +119,24 @@ void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
cs[j] = receiver_inputs[i + j];
|
||||
}
|
||||
receiver_rsgen(&receiver, Rs_pack[0], cs);
|
||||
os[0].reset_write_head();
|
||||
os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0]));
|
||||
receiver_keygen(&receiver, receiver_keys);
|
||||
|
||||
// Copy keys to receiver_outputs
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
for (k = 0; k < AES_BLK_SIZE; k++)
|
||||
{
|
||||
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
send_if_ot_receiver(P, os, ot_role);
|
||||
}
|
||||
|
||||
send_if_ot_receiver(P, os, ot_role);
|
||||
|
||||
for (i = 0; i < nOT; i += 4)
|
||||
{
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
os[1].get_bytes((octet*) Rs_pack[1], len);
|
||||
@@ -143,18 +157,6 @@ void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
// Copy keys to receiver_outputs
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
for (k = 0; k < AES_BLK_SIZE; k++)
|
||||
{
|
||||
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef BASE_OT_DEBUG
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
|
||||
19
OT/BitDiagonal.cpp
Normal file
19
OT/BitDiagonal.cpp
Normal file
@@ -0,0 +1,19 @@
|
||||
/*
|
||||
* Diagonal.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <OT/BitDiagonal.h>
|
||||
|
||||
void BitDiagonal::pack(octetStream& os) const
|
||||
{
|
||||
for (int i = 0; i < N_ROWS; i++)
|
||||
os.store_int(rows[i].get_bit(i), 1);
|
||||
}
|
||||
|
||||
void BitDiagonal::unpack(octetStream& os)
|
||||
{
|
||||
*this = {};
|
||||
for (int i = 0; i < N_ROWS; i++)
|
||||
rows[i] = os.get_int(1) << i;
|
||||
}
|
||||
24
OT/BitDiagonal.h
Normal file
24
OT/BitDiagonal.h
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Diagonal.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_BITDIAGONAL_H_
|
||||
#define OT_BITDIAGONAL_H_
|
||||
|
||||
#include "Math/Square.h"
|
||||
#include "Math/BitVec.h"
|
||||
|
||||
class BitDiagonal : public Square<BitVec>
|
||||
{
|
||||
public:
|
||||
static int size()
|
||||
{
|
||||
return 8 * BitVec::size();
|
||||
}
|
||||
|
||||
void pack(octetStream& os) const;
|
||||
void unpack(octetStream& os);
|
||||
};
|
||||
|
||||
#endif /* OT_BITDIAGONAL_H_ */
|
||||
@@ -9,9 +9,12 @@
|
||||
|
||||
#include "BitMatrix.h"
|
||||
#include "Rectangle.h"
|
||||
#include "BitDiagonal.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Math/BitVec.h"
|
||||
#include "GC/TinySecret.h"
|
||||
|
||||
#include "OT/Rectangle.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
@@ -268,25 +271,22 @@ void square128::randomize(PRNG& G)
|
||||
G.get_octets((octet*)&rows, sizeof(rows));
|
||||
}
|
||||
|
||||
template <>
|
||||
void square128::randomize<gf2n_long>(int row, PRNG& G)
|
||||
void square128::randomize(int row, PRNG& G)
|
||||
{
|
||||
rows[row] = G.get_doubleword();
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void square128::conditional_add<gf2n_long>(BitVector& conditions, square128& other, int offset)
|
||||
void square128::conditional_add(BitVector& conditions, square128& other, int offset)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
if (conditions.get_bit(128 * offset + i))
|
||||
rows[i] ^= other.rows[i];
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void square128::hash_row_wise(MMO& mmo, square128& input)
|
||||
{
|
||||
mmo.hashBlockWise<T,128>((octet*)rows, (octet*)input.rows);
|
||||
mmo.hashBlockWise<gf2n_long,128>((octet*)rows, (octet*)input.rows);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -395,20 +395,17 @@ square128& square128::operator^=(square128& other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::add<gf2n_long>(square128& other)
|
||||
square128& square128::add(square128& other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::sub<gf2n_long>(square128& other)
|
||||
square128& square128::sub(square128& other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::rsub<gf2n_long>(square128& other)
|
||||
square128& square128::rsub(square128& other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
@@ -421,8 +418,7 @@ square128& square128::operator^=(const __m128i* other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
square128& square128::sub<gf2n_long>(const __m128i* other)
|
||||
square128& square128::sub(const __m128i* other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
@@ -500,7 +496,7 @@ template <class U>
|
||||
void Matrix<U>::randomize(int row, PRNG& G)
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].template randomize<gf2n_long>(row, G);
|
||||
squares[i].randomize(row, G);
|
||||
}
|
||||
|
||||
void BitMatrix::transpose()
|
||||
@@ -597,44 +593,40 @@ Slice<U>::Slice(U& bm, size_t start, size_t size) :
|
||||
}
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
Slice<U>& Slice<U>::rsub(Slice<U>& other)
|
||||
{
|
||||
if (bm.squares.size() < other.end)
|
||||
throw invalid_length();
|
||||
for (size_t i = other.start; i < other.end; i++)
|
||||
bm.squares[i].template rsub<T>(other.bm.squares[i]);
|
||||
bm.squares[i].rsub(other.bm.squares[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
Slice<U>& Slice<U>::sub(BitVector& other, int repeat)
|
||||
{
|
||||
if (end * U::PartType::N_COLUMNS > other.size() * repeat)
|
||||
throw invalid_length(to_string(U::PartType::N_COLUMNS));
|
||||
for (size_t i = start; i < end; i++)
|
||||
{
|
||||
bm.squares[i].template sub<T>(other.get_ptr_to_byte(i / repeat,
|
||||
bm.squares[i].sub(other.get_ptr_to_byte(i / repeat,
|
||||
U::PartType::N_ROW_BYTES));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
void Slice<U>::randomize(int row, PRNG& G)
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].template randomize<T>(row, G);
|
||||
bm.squares[i].randomize(row, G);
|
||||
}
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
void Slice<U>::conditional_add(BitVector& conditions, U& other, bool useOffset)
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].template conditional_add<T>(conditions, other.squares[i], useOffset * i);
|
||||
bm.squares[i].conditional_add(conditions, other.squares[i], useOffset * i);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -651,7 +643,7 @@ void Slice<U>::print()
|
||||
cout << "hex / value" << endl;
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
cout << int128(bm.squares[0].rows[i]) << " " << T(bm.squares[0].rows[i]) << endl;
|
||||
cout << T(bm.squares[0].rows[i]) << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
@@ -671,24 +663,13 @@ void Slice<U>::unpack(octetStream& os)
|
||||
bm.squares[i].unpack(os);
|
||||
}
|
||||
|
||||
#define M(N,L) Matrix<Rectangle< Z2<N>, Z2<L> > >
|
||||
|
||||
#undef XXX
|
||||
#define XXX(T,N,L) \
|
||||
template class Matrix<Rectangle< Z2<N>, Z2<L> > >; \
|
||||
template class Slice<Matrix<Rectangle< Z2<N>, Z2<L> > > >; \
|
||||
template Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >& Slice< \
|
||||
Matrix<Rectangle<Z2<N>, Z2<L> > > >::rsub<T>( \
|
||||
Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >& other); \
|
||||
template Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >& Slice< \
|
||||
Matrix<Rectangle<Z2<N>, Z2<L> > > >::sub<T>(BitVector& other, int repeat); \
|
||||
template void Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >::conditional_add< \
|
||||
T>(BitVector& conditions, \
|
||||
Matrix<Rectangle<Z2<N>, Z2<L> > >& other, bool useOffset); \
|
||||
|
||||
#undef X
|
||||
#define X(N,L) \
|
||||
template void Slice<Matrix<Rectangle< Z2<N>, Z2<L> > > >::randomize<Z2<L> >(int row, PRNG& G); \
|
||||
XXX(Z2<L>, N, L)
|
||||
|
||||
//X(96, 160)
|
||||
@@ -700,6 +681,11 @@ Y(64, 48)
|
||||
Y(66, 64)
|
||||
Y(66, 48)
|
||||
Y(32, 32)
|
||||
Y(1, 40)
|
||||
Y(72, 48)
|
||||
Y(74, 48)
|
||||
Y(72, 64)
|
||||
Y(74, 64)
|
||||
|
||||
template class Matrix<square128>;
|
||||
|
||||
@@ -710,19 +696,15 @@ template class Slice<BM>; \
|
||||
XX(BM, gf2n_long)
|
||||
|
||||
#define XX(BM, GF) \
|
||||
template void Slice<BM >::conditional_add<GF>(BitVector& conditions, BM& other, bool useOffset); \
|
||||
template Slice<BM >& Slice<BM >::rsub<GF>(Slice<BM >& other); \
|
||||
template Slice<BM >& Slice<BM >::sub<GF>(BitVector& other, int repeat); \
|
||||
template void Slice<BM >::randomize<GF>(int row, PRNG& G); \
|
||||
//template void Slice<BM >::print<GF>();
|
||||
|
||||
BMS
|
||||
|
||||
template class Slice<Matrix<gf2n_short_square>>;
|
||||
XX(Matrix<gf2n_short_square>, gf2n_short)
|
||||
#define XXXX(BM, GF) \
|
||||
template class Slice<BM>; \
|
||||
XX(BM, GF)
|
||||
|
||||
template class Slice<Matrix<Square<gf2n_long>>>;
|
||||
XX(Matrix<Square<gf2n_long>>, gf2n_long)
|
||||
|
||||
template class Slice<Matrix<Square<gfp1>>>;
|
||||
XX(Matrix<Square<gfp1>>, gfp1)
|
||||
XXXX(Matrix<gf2n_short_square>, gf2n_short)
|
||||
XXXX(Matrix<Square<gf2n_long>>, gf2n_long)
|
||||
XXXX(Matrix<Square<gfp1>>, gfp1)
|
||||
XXXX(Matrix<BitDiagonal>, BitVec)
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
using namespace std;
|
||||
|
||||
union square128 {
|
||||
typedef int128 RowType;
|
||||
typedef gf2n_long RowType;
|
||||
|
||||
const static int N_ROWS = 128;
|
||||
const static int N_ROWS_ALLOCATED = 128;
|
||||
@@ -46,24 +46,16 @@ union square128 {
|
||||
square128& operator^=(BitVector& other);
|
||||
bool operator==(square128& other);
|
||||
|
||||
template <class T>
|
||||
square128& add(square128& other);
|
||||
template <class T>
|
||||
square128& sub(square128& other);
|
||||
template <class T>
|
||||
square128& rsub(square128& other);
|
||||
template <class T>
|
||||
square128& sub(const __m128i* other);
|
||||
template <class T>
|
||||
square128& sub(const void* other) { return sub<T>((__m128i*)other); }
|
||||
square128& sub(const void* other) { return sub((__m128i*)other); }
|
||||
|
||||
void randomize(PRNG& G);
|
||||
template <class T>
|
||||
void randomize(int row, PRNG& G);
|
||||
template <class T>
|
||||
void conditional_add(BitVector& conditions, square128& other, int offset);
|
||||
void transpose();
|
||||
template <class T>
|
||||
void hash_row_wise(MMO& mmo, square128& input);
|
||||
template <class T>
|
||||
void to(T& result);
|
||||
@@ -173,14 +165,10 @@ class Slice
|
||||
public:
|
||||
Slice(U& bm, size_t start, size_t size);
|
||||
|
||||
template <class T>
|
||||
Slice<U>& rsub(Slice<U>& other);
|
||||
template <class T>
|
||||
Slice<U>& sub(BitVector& other, int repeat = 1);
|
||||
|
||||
template <class T>
|
||||
void randomize(int row, PRNG& G);
|
||||
template <class T>
|
||||
void conditional_add(BitVector& conditions, U& other, bool useOffset = false);
|
||||
void transpose();
|
||||
|
||||
|
||||
106
OT/MascotParams.cpp
Normal file
106
OT/MascotParams.cpp
Normal file
@@ -0,0 +1,106 @@
|
||||
/*
|
||||
* TripleMachine.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <OT/TripleMachine.h>
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "Math/BitVec.h"
|
||||
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
MascotParams::MascotParams()
|
||||
{
|
||||
generateMACs = true;
|
||||
amplify = true;
|
||||
check = true;
|
||||
generateBits = false;
|
||||
timerclear(&start);
|
||||
}
|
||||
|
||||
void MascotParams::set_passive()
|
||||
{
|
||||
generateMACs = amplify = check = false;
|
||||
}
|
||||
|
||||
template<> gf2n_long MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_key2l;
|
||||
}
|
||||
|
||||
template<> gf2n_short MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_key2s;
|
||||
}
|
||||
|
||||
template<> gfp1 MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyp;
|
||||
}
|
||||
|
||||
template<> Z2<48> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<64> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<40> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<32> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> BitVec MascotParams::get_mac_key()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gf2n_long key)
|
||||
{
|
||||
mac_key2l = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gf2n_short key)
|
||||
{
|
||||
mac_key2s = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gfp1 key)
|
||||
{
|
||||
mac_keyp = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<64> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<48> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<40> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
@@ -24,7 +24,7 @@ class PlainTriple;
|
||||
template <class T, int N>
|
||||
using ShareTriple = ShareTriple_<T, T, N>;
|
||||
|
||||
class MascotGenerator
|
||||
class GeneratorThread
|
||||
{
|
||||
protected:
|
||||
pthread_mutex_t mutex;
|
||||
@@ -37,8 +37,8 @@ public:
|
||||
|
||||
bool multi_threaded;
|
||||
|
||||
MascotGenerator() : nTriples(0), multi_threaded(true) {}
|
||||
virtual ~MascotGenerator() {};
|
||||
GeneratorThread() : nTriples(0), multi_threaded(true) {}
|
||||
virtual ~GeneratorThread() {};
|
||||
virtual void generate() = 0;
|
||||
|
||||
void lock();
|
||||
@@ -48,7 +48,7 @@ public:
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class OTTripleGenerator : public MascotGenerator
|
||||
class OTTripleGenerator : public GeneratorThread
|
||||
{
|
||||
typedef typename T::open_type open_type;
|
||||
typedef typename T::mac_key_type mac_key_type;
|
||||
@@ -79,7 +79,7 @@ protected:
|
||||
public:
|
||||
// TwoPartyPlayer's for OTs, n-party Player for sacrificing
|
||||
vector<TwoPartyPlayer*> players;
|
||||
vector<OTMultiplierMac<sacri_type, open_type>*> ot_multipliers;
|
||||
vector<typename T::Multiplier*> ot_multipliers;
|
||||
//vector<OTMachine*> machines;
|
||||
BitVector baseReceiverInput; // same for every set of OTs
|
||||
vector< vector< vector<BitVector> > > baseSenderInputs;
|
||||
@@ -111,6 +111,8 @@ public:
|
||||
void generatePlainTriples();
|
||||
void plainTripleRound(int k = 0);
|
||||
|
||||
void run_multipliers(MultJob job);
|
||||
|
||||
size_t data_sent();
|
||||
};
|
||||
|
||||
@@ -121,8 +123,28 @@ class NPartyTripleGenerator : public OTTripleGenerator<T>
|
||||
typedef typename T::mac_key_type mac_key_type;
|
||||
typedef typename T::sacri_type sacri_type;
|
||||
|
||||
template <int K, int S>
|
||||
void generateTriplesZ2k();
|
||||
virtual void generateTriples() = 0;
|
||||
virtual void generateBits() = 0;
|
||||
|
||||
public:
|
||||
vector< ShareTriple_<sacri_type, mac_key_type, 2> > uncheckedTriples;
|
||||
vector<InputTuple<Share<sacri_type>>> inputs;
|
||||
|
||||
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
Player* parentPlayer = 0);
|
||||
virtual ~NPartyTripleGenerator() {}
|
||||
|
||||
void generate();
|
||||
void generateInputs(int player);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class MascotTripleGenerator : public NPartyTripleGenerator<T>
|
||||
{
|
||||
typedef typename T::open_type open_type;
|
||||
typedef typename T::mac_key_type mac_key_type;
|
||||
typedef typename T::sacri_type sacri_type;
|
||||
|
||||
void generateTriples();
|
||||
void generateBits();
|
||||
@@ -132,21 +154,37 @@ class NPartyTripleGenerator : public OTTripleGenerator<T>
|
||||
|
||||
void sacrifice(vector<ShareTriple_<open_type, mac_key_type, 2> >& uncheckedTriples,
|
||||
typename T::MAC_Check& MC, PRNG& G);
|
||||
|
||||
public:
|
||||
vector<T> bits;
|
||||
|
||||
MascotTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
Player* parentPlayer = 0);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class Spdz2kTripleGenerator : public NPartyTripleGenerator<T>
|
||||
{
|
||||
typedef typename T::open_type open_type;
|
||||
typedef typename T::mac_key_type mac_key_type;
|
||||
typedef typename T::sacri_type sacri_type;
|
||||
|
||||
void generateBits() { throw not_implemented(); }
|
||||
|
||||
template<class U>
|
||||
void sacrificeZ2k(vector<ShareTriple_<sacri_type, mac_key_type, 2> >& uncheckedTriples,
|
||||
void sacrificeZ2k(
|
||||
vector<
|
||||
ShareTriple_<typename T::sacri_type,
|
||||
typename T::mac_key_type, 2> >& uncheckedTriples,
|
||||
U& MC, PRNG& G);
|
||||
|
||||
public:
|
||||
vector< ShareTriple_<sacri_type, mac_key_type, 2> > uncheckedTriples;
|
||||
vector<T> bits;
|
||||
vector<InputTuple<Share<sacri_type>>> inputs;
|
||||
|
||||
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
Player* parentPlayer = 0);
|
||||
|
||||
void generate();
|
||||
void generateInputs(int player);
|
||||
void generateTriples();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#ifndef OT_NPARTYTRIPLGENERATOR_HPP_
|
||||
#define OT_NPARTYTRIPLGENERATOR_HPP_
|
||||
|
||||
#include "NPartyTripleGenerator.h"
|
||||
|
||||
#include "OT/OTExtensionWithMatrix.h"
|
||||
@@ -11,9 +14,11 @@
|
||||
#include "Tools/Subroutines.h"
|
||||
#include "Protocols/MAC_Check.h"
|
||||
#include "Protocols/Spdz2kPrep.h"
|
||||
#include "GC/SemiSecret.h"
|
||||
|
||||
#include "OT/Triple.hpp"
|
||||
#include "OT/Rectangle.hpp"
|
||||
#include "OT/OTMultiplier.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiMC.h"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
@@ -46,6 +51,24 @@ NPartyTripleGenerator<T>::NPartyTripleGenerator(OTTripleSetup& setup,
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
MascotTripleGenerator<T>::MascotTripleGenerator(OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, Player* parentPlayer) :
|
||||
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
|
||||
machine, parentPlayer)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, Player* parentPlayer) :
|
||||
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
|
||||
machine, parentPlayer)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
OTTripleGenerator<T>::OTTripleGenerator(OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
@@ -174,9 +197,9 @@ void NPartyTripleGenerator<T>::generate()
|
||||
timers["Generator thread"].stop();
|
||||
if (machine.output)
|
||||
cout << "Written " << nTriples << " " << T::type_string() << " outputs to " << ss.str() << endl;
|
||||
#ifdef VERBOSE
|
||||
#ifdef VERBOSE_OT
|
||||
else
|
||||
cout << "Generated " << nTriples << " " << T::type_string() << " outputs" << endl;
|
||||
cerr << "Generated " << nTriples << " " << T::type_string() << " outputs" << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -251,7 +274,7 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Share<gf2n>>::generateBits()
|
||||
void MascotTripleGenerator<Share<gf2n>>::generateBits()
|
||||
{
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
ot_multipliers[i]->inbox.push(DATA_BIT);
|
||||
@@ -288,9 +311,9 @@ void NPartyTripleGenerator<Share<gf2n>>::generateBits()
|
||||
gf2n r;
|
||||
for (int j = 0; j < nBitsToCheck; j++)
|
||||
{
|
||||
gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key<gf2n>();
|
||||
gf2n mac_sum = valueBits[0].get_bit(j) ? machine.get_mac_key<gf2n>() : 0;
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
mac_sum += ((MascotMultiplier<gf2n>*)ot_multipliers[i])->macs[0][j];
|
||||
mac_sum += ot_multipliers[i]->macs[0][j];
|
||||
bits[j].set_share(valueBits[0].get_bit(j));
|
||||
bits[j].set_mac(mac_sum);
|
||||
r.randomize(G);
|
||||
@@ -310,7 +333,7 @@ void NPartyTripleGenerator<Share<gf2n>>::generateBits()
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Share<gfp1>>::generateBits()
|
||||
void MascotTripleGenerator<Share<gfp1>>::generateBits()
|
||||
{
|
||||
generateTriples();
|
||||
}
|
||||
@@ -322,9 +345,12 @@ void NPartyTripleGenerator<T>::generateBits()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<int K, int S>
|
||||
void NPartyTripleGenerator<T>::generateTriplesZ2k()
|
||||
void Spdz2kTripleGenerator<T>::generateTriples()
|
||||
{
|
||||
const int K = T::k;
|
||||
const int S = T::s;
|
||||
auto& uncheckedTriples = this->uncheckedTriples;
|
||||
|
||||
auto& timers = this->timers;
|
||||
auto& machine = this->machine;
|
||||
auto& nTriplesPerLoop = this->nTriplesPerLoop;
|
||||
@@ -386,7 +412,7 @@ void NPartyTripleGenerator<T>::generateTriplesZ2k()
|
||||
timers["Triple computation"].start();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
{
|
||||
c += ((Spdz2kMultiplier<K, S>*)ot_multipliers[i])->c_output[j];
|
||||
c += ot_multipliers[i]->c_output[j];
|
||||
}
|
||||
|
||||
#ifdef DEBUG_SPDZ2K
|
||||
@@ -433,36 +459,6 @@ void NPartyTripleGenerator<T>::generateTriplesZ2k()
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Spdz2kShare<32, 32>>::generateTriples()
|
||||
{
|
||||
this->generateTriplesZ2k<32, 32>();
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Spdz2kShare<64, 64>>::generateTriples()
|
||||
{
|
||||
this->generateTriplesZ2k<64, 64>();
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Spdz2kShare<64, 48>>::generateTriples()
|
||||
{
|
||||
this->generateTriplesZ2k<64, 48>();
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Spdz2kShare<66, 64>>::generateTriples()
|
||||
{
|
||||
this->generateTriplesZ2k<66, 64>();
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator<Spdz2kShare<66, 48>>::generateTriples()
|
||||
{
|
||||
this->generateTriplesZ2k<66, 48>();
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void OTTripleGenerator<U>::generatePlainTriples()
|
||||
{
|
||||
@@ -500,13 +496,15 @@ void OTTripleGenerator<U>::plainTripleRound(int k)
|
||||
|
||||
for (int j = 0; j < nPreampTriplesPerLoop; j++)
|
||||
{
|
||||
T a((char*)valueBits[0].get_ptr() + j * T::size());
|
||||
T b((char*)valueBits[1].get_ptr() + j / nAmplify * T::size());
|
||||
T a;
|
||||
a.assign((char*)valueBits[0].get_ptr() + j * T::size());
|
||||
T b;
|
||||
b.assign((char*)valueBits[1].get_ptr() + j / nAmplify * T::size());
|
||||
T c = a * b;
|
||||
timers["Triple computation"].start();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
{
|
||||
c += dynamic_cast<typename U::Multiplier*>(ot_multipliers[i])->c_output[j];
|
||||
c += ot_multipliers[i]->c_output[j];
|
||||
}
|
||||
timers["Triple computation"].stop();
|
||||
if (machine.amplify)
|
||||
@@ -531,7 +529,7 @@ void OTTripleGenerator<U>::plainTripleRound(int k)
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void NPartyTripleGenerator<U>::generateTriples()
|
||||
void MascotTripleGenerator<U>::generateTriples()
|
||||
{
|
||||
typedef typename U::open_type T;
|
||||
|
||||
@@ -547,6 +545,7 @@ void NPartyTripleGenerator<U>::generateTriples()
|
||||
auto& outputFile = this->outputFile;
|
||||
auto& field_size = this->field_size;
|
||||
auto& nPreampTriplesPerLoop = this->nPreampTriplesPerLoop;
|
||||
auto& uncheckedTriples = this->uncheckedTriples;
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
ot_multipliers[i]->inbox.push(DATA_TRIPLE);
|
||||
@@ -626,7 +625,7 @@ void NPartyTripleGenerator<U>::generateTriples()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void NPartyTripleGenerator<T>::sacrifice(
|
||||
void MascotTripleGenerator<T>::sacrifice(
|
||||
vector<ShareTriple_<open_type, mac_key_type, 2> >& uncheckedTriples, typename T::MAC_Check& MC, PRNG& G)
|
||||
{
|
||||
auto& machine = this->machine;
|
||||
@@ -663,7 +662,7 @@ void NPartyTripleGenerator<T>::sacrifice(
|
||||
|
||||
template<class W>
|
||||
template<class U>
|
||||
void NPartyTripleGenerator<W>::sacrificeZ2k(
|
||||
void Spdz2kTripleGenerator<W>::sacrificeZ2k(
|
||||
vector<ShareTriple_<sacri_type, mac_key_type, 2> >& uncheckedTriples, U& MC, PRNG& G)
|
||||
{
|
||||
typedef sacri_type T;
|
||||
@@ -707,7 +706,7 @@ void NPartyTripleGenerator<W>::sacrificeZ2k(
|
||||
}
|
||||
|
||||
if (machine.generateBits)
|
||||
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
|
||||
throw not_implemented();
|
||||
else
|
||||
if (machine.output)
|
||||
for (int j = 0; j < nTriplesPerLoop; j++)
|
||||
@@ -716,7 +715,7 @@ void NPartyTripleGenerator<W>::sacrificeZ2k(
|
||||
|
||||
template<>
|
||||
template<class U, class V, class W, int N>
|
||||
void NPartyTripleGenerator<Share<gfp1>>::generateBitsFromTriples(
|
||||
void MascotTripleGenerator<Share<gfp1>>::generateBitsFromTriples(
|
||||
vector< ShareTriple_<U, V, N> >& triples, W& MC, ofstream& outputFile)
|
||||
{
|
||||
vector< Share<gfp1> > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop);
|
||||
@@ -746,7 +745,7 @@ void NPartyTripleGenerator<Share<gfp1>>::generateBitsFromTriples(
|
||||
|
||||
template<class T>
|
||||
template<class U, class V, class W, int N>
|
||||
void NPartyTripleGenerator<T>::generateBitsFromTriples(
|
||||
void MascotTripleGenerator<T>::generateBitsFromTriples(
|
||||
vector< ShareTriple_<U, V, N> >& triples, W& MC, ofstream& outputFile)
|
||||
{
|
||||
throw how_would_that_work();
|
||||
@@ -786,22 +785,22 @@ void OTTripleGenerator<T>::print_progress(int k)
|
||||
}
|
||||
}
|
||||
|
||||
void MascotGenerator::lock()
|
||||
void GeneratorThread::lock()
|
||||
{
|
||||
pthread_mutex_lock(&mutex);
|
||||
}
|
||||
|
||||
void MascotGenerator::unlock()
|
||||
void GeneratorThread::unlock()
|
||||
{
|
||||
pthread_mutex_unlock(&mutex);
|
||||
}
|
||||
|
||||
void MascotGenerator::signal()
|
||||
void GeneratorThread::signal()
|
||||
{
|
||||
pthread_cond_signal(&ready);
|
||||
}
|
||||
|
||||
void MascotGenerator::wait()
|
||||
void GeneratorThread::wait()
|
||||
{
|
||||
if (multi_threaded)
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
@@ -821,18 +820,11 @@ void OTTripleGenerator<T>::wait_for_multipliers()
|
||||
ot_multipliers[i]->outbox.pop();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void OTTripleGenerator<T>::run_multipliers(MultJob job)
|
||||
{
|
||||
signal_multipliers(job);
|
||||
wait_for_multipliers();
|
||||
}
|
||||
|
||||
template class NPartyTripleGenerator<Share<gf2n_long>>;
|
||||
template class NPartyTripleGenerator<Share<gf2n_short>>;
|
||||
template class NPartyTripleGenerator<Share<gfp1>>;
|
||||
|
||||
template class OTTripleGenerator<SemiShare<gf2n>>;
|
||||
template class OTTripleGenerator<SemiShare<gfp1>>;
|
||||
template class OTTripleGenerator<Semi2kShare<64>>;
|
||||
template class OTTripleGenerator<Semi2kShare<72>>;
|
||||
|
||||
template class NPartyTripleGenerator<Spdz2kShare<32, 32>>;
|
||||
template class NPartyTripleGenerator<Spdz2kShare<64, 64>>;
|
||||
template class NPartyTripleGenerator<Spdz2kShare<64, 48>>;
|
||||
template class NPartyTripleGenerator<Spdz2kShare<66, 64>>;
|
||||
template class NPartyTripleGenerator<Spdz2kShare<66, 48>>;
|
||||
#endif
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Math/gf2nlong.h"
|
||||
#include "Math/BitVec.h"
|
||||
#include "GC/TinySecret.h"
|
||||
|
||||
#include "OT/Rectangle.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
@@ -71,7 +73,7 @@ void OTExtensionWithMatrix::transfer(int nOTs,
|
||||
|
||||
for (int loop = 0; loop < nloops; loop++)
|
||||
{
|
||||
extend<gf2n_long>(nOTs, newReceiverInput);
|
||||
extend(nOTs, newReceiverInput);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&totalendv, NULL);
|
||||
double elapsed = timeval_diff(&totalstartv, &totalendv);
|
||||
@@ -97,24 +99,25 @@ void OTCorrelator<U>::resize(int nOTs)
|
||||
}
|
||||
|
||||
// the template is used to denote the field of the hash output
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput)
|
||||
{
|
||||
extend_correlated(nOTs_requested, newReceiverInput);
|
||||
hash_outputs<T>(nOTs_requested);
|
||||
hash_outputs(nOTs_requested);
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::extend_correlated(BitVector& newReceiverInput)
|
||||
void OTExtensionWithMatrix::extend_correlated(const BitVector& newReceiverInput)
|
||||
{
|
||||
extend_correlated(newReceiverInput.size(), newReceiverInput);
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& newReceiverInput)
|
||||
void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, const BitVector& newReceiverBits)
|
||||
{
|
||||
// if (nOTs % nbaseOTs != 0)
|
||||
// throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
|
||||
if (nOTs_requested == 0)
|
||||
return;
|
||||
// local copy
|
||||
auto newReceiverInput = newReceiverBits;
|
||||
if ((ot_role & RECEIVER) and (size_t)nOTs_requested != newReceiverInput.size())
|
||||
throw runtime_error("wrong number of choice bits");
|
||||
int nOTs_requested_rounded = (nOTs_requested + 127) / 128 * 128;
|
||||
@@ -133,8 +136,8 @@ void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& new
|
||||
// subloop for first part to interleave communication with computation
|
||||
for (int start = 0; start < nOTs / 128; start += slice)
|
||||
{
|
||||
expand<gf2n_long>(start, slice);
|
||||
this->correlate<gf2n_long>(start, slice, newReceiverInput, true);
|
||||
expand(start, slice);
|
||||
this->correlate(start, slice, newReceiverInput, true);
|
||||
transpose(start, slice);
|
||||
}
|
||||
|
||||
@@ -164,7 +167,6 @@ void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& new
|
||||
}
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
void OTCorrelator<U>::expand(int start, int slice)
|
||||
{
|
||||
(void)start, (void)slice;
|
||||
@@ -180,8 +182,8 @@ void OTCorrelator<U>::expand(int start, int slice)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
receiverOutputSlice.template randomize<T>(i, G_sender[i][0]);
|
||||
t1Slice.template randomize<T>(i, G_sender[i][1]);
|
||||
receiverOutputSlice.randomize(i, G_sender[i][0]);
|
||||
t1Slice.randomize(i, G_sender[i][1]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,23 +191,22 @@ void OTCorrelator<U>::expand(int start, int slice)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
// randomize base receiver output
|
||||
senderOutputSlices[0].template randomize<T>(i, G_receiver[i]);
|
||||
senderOutputSlices[0].randomize(i, G_receiver[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::expand_transposed()
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
receiverOutputMatrix.squares[i/128].randomize<T>(i % 128, G_sender[i][0]);
|
||||
t1.squares[i/128].randomize<T>(i % 128, G_sender[i][1]);
|
||||
receiverOutputMatrix.squares[i/128].randomize(i % 128, G_sender[i][0]);
|
||||
t1.squares[i/128].randomize(i % 128, G_sender[i][1]);
|
||||
}
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
senderOutputMatrices[0].squares[i/128].randomize<T>(i % 128, G_receiver[i]);
|
||||
senderOutputMatrices[0].squares[i/128].randomize(i % 128, G_receiver[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -224,7 +225,6 @@ void OTCorrelator<U>::setup_for_correlation(BitVector& baseReceiverInput,
|
||||
}
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
void OTCorrelator<U>::correlate(int start, int slice,
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat)
|
||||
{
|
||||
@@ -240,8 +240,8 @@ void OTCorrelator<U>::correlate(int start, int slice,
|
||||
// create correlation
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
t1Slice.template rsub<T>(receiverOutputSlice);
|
||||
t1Slice.template sub<T>(newReceiverInput, repeat);
|
||||
t1Slice.rsub(receiverOutputSlice);
|
||||
t1Slice.sub(newReceiverInput, repeat);
|
||||
t1Slice.pack(os[0]);
|
||||
|
||||
// t1 = receiverOutputMatrix;
|
||||
@@ -260,7 +260,7 @@ void OTCorrelator<U>::correlate(int start, int slice,
|
||||
{
|
||||
// u = t0 + t1 + x
|
||||
uSlice.unpack(os[1]);
|
||||
senderOutputSlices[0].template conditional_add<T>(baseReceiverInput, u, !useConstantBase);
|
||||
senderOutputSlices[0].conditional_add(baseReceiverInput, u, !useConstantBase);
|
||||
}
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&commst2, NULL);
|
||||
@@ -302,13 +302,12 @@ void OTExtensionWithMatrix::transpose(int start, int slice)
|
||||
/*
|
||||
* Hash outputs to make into random OT
|
||||
*/
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::hash_outputs(int nOTs)
|
||||
{
|
||||
hash_outputs<T>(nOTs, senderOutputMatrices, receiverOutputMatrix);
|
||||
hash_outputs(nOTs, senderOutputMatrices, receiverOutputMatrix);
|
||||
}
|
||||
|
||||
template <class T, class V>
|
||||
template <class V>
|
||||
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput)
|
||||
{
|
||||
//cout << "Hashing... " << flush;
|
||||
@@ -319,6 +318,7 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
|
||||
typedef typename V::PartType::RowType T;
|
||||
|
||||
int n_rows = V::PartType::N_ROWS_ALLOCATED;
|
||||
int n = (nOTs + n_rows - 1) / n_rows * V::PartType::N_ROWS;
|
||||
@@ -326,11 +326,6 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
|
||||
senderOutput[i].resize_vertical(n);
|
||||
receiverOutput.resize_vertical(n);
|
||||
|
||||
if (V::PartType::N_ROW_BYTES != T::size())
|
||||
throw runtime_error(
|
||||
"length mismatch for MMO hash: "
|
||||
+ to_string(V::PartType::N_ROW_BYTES) + " != "
|
||||
+ to_string(T::size()));
|
||||
if (nOTs % 8 != 0)
|
||||
throw runtime_error("number of OTs must be divisible by 8");
|
||||
|
||||
@@ -378,7 +373,7 @@ void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output)
|
||||
output.resize(nTriples);
|
||||
for (unsigned int j = 0; j < nTriples; j++)
|
||||
{
|
||||
receiverOutputMatrix.squares[j].template sub<T>(senderOutputMatrices[0].squares[j]).to(output[j]);
|
||||
receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,56 +511,36 @@ void OTExtensionWithMatrix::print_pre_expand()
|
||||
}
|
||||
|
||||
template class OTCorrelator<BitMatrix>;
|
||||
template void OTCorrelator<BitMatrix>::correlate<gf2n_long>(int start, int slice,
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat);
|
||||
|
||||
#define Z(BM,GF) \
|
||||
template void OTCorrelator<BM>::correlate<GF>(int start, int slice, \
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat); \
|
||||
template void OTCorrelator<BM>::expand<GF>(int start, int slice); \
|
||||
template class OTCorrelator<BM>; \
|
||||
template void OTCorrelator<BM>::reduce_squares<GF>(unsigned int nTriples, \
|
||||
vector<GF>& output);
|
||||
|
||||
template class OTCorrelator<Matrix<gf2n_short_square>>;
|
||||
Z(Matrix<gf2n_short_square>, gf2n_short)
|
||||
|
||||
template class OTCorrelator<Matrix<Square<gf2n_long>>>;
|
||||
Z(Matrix<Square<gf2n_long>>, gf2n_long)
|
||||
|
||||
template class OTCorrelator<Matrix<Square<gfp1>>>;
|
||||
Z(Matrix<Square<gfp1>>, gfp1)
|
||||
|
||||
#define ZZZZ(GF) \
|
||||
template void OTExtensionWithMatrix::print_post_correlate<GF>( \
|
||||
BitVector& newReceiverInput, int j, int offset, int sender); \
|
||||
template void OTExtensionWithMatrix::extend<GF>(int nOTs_requested, \
|
||||
BitVector& newReceiverInput); \
|
||||
|
||||
#define ZZZ(GF, M) \
|
||||
template void OTExtensionWithMatrix::hash_outputs<GF, M >(int, vector<M >&, M&);
|
||||
#define MM Matrix<Rectangle<Z2<512>, Z2<160> > >
|
||||
#define ZZZ(GF, M) Z(M, GF) \
|
||||
template void OTExtensionWithMatrix::hash_outputs(int, vector<M >&, M&);
|
||||
|
||||
ZZZZ(gfp1)
|
||||
ZZZZ(gf2n_long)
|
||||
ZZZ(Z2<160>, MM)
|
||||
ZZZ(gf2n_short, Matrix<gf2n_short_square>)
|
||||
ZZZ(gf2n_long, Matrix<Square<gf2n_long>>)
|
||||
ZZZ(gfp1, Matrix<Square<gfp1>>)
|
||||
ZZZ(BitVec, Matrix<BitDiagonal>)
|
||||
|
||||
#undef XX
|
||||
#define XX(T,U,N,L) \
|
||||
template class OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >; \
|
||||
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::correlate<T>(int start, int slice, \
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat); \
|
||||
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
|
||||
vector<U>& output); \
|
||||
template void OTExtensionWithMatrix::hash_outputs<T, Matrix<Rectangle<Z2<N>, Z2<L> > > >(int, \
|
||||
template void OTExtensionWithMatrix::hash_outputs(int, \
|
||||
std::vector<Matrix<Rectangle<Z2<N>, Z2<L> > >, std::allocator<Matrix<Rectangle<Z2<N>, Z2<L> > > > >&, \
|
||||
Matrix<Rectangle<Z2<N>, Z2<L> > >&);
|
||||
|
||||
#undef X
|
||||
#define X(N,L) \
|
||||
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::expand<Z2<L> >(int start, int slice); \
|
||||
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
|
||||
vector<Z2kRectangle<N, L> >& output); \
|
||||
XX(Z2<L>,Z2<N>,N,L)
|
||||
@@ -579,3 +554,8 @@ Y(64, 48)
|
||||
Y(66, 64)
|
||||
Y(66, 48)
|
||||
Y(32, 32)
|
||||
Y(1, 40)
|
||||
Y(72, 48)
|
||||
Y(74, 48)
|
||||
Y(72, 64)
|
||||
Y(74, 64)
|
||||
|
||||
@@ -39,12 +39,10 @@ public:
|
||||
receiverOutputMatrix(matrices[0]), t1(matrices[1]) {}
|
||||
|
||||
void resize(int nOTs);
|
||||
template <class T>
|
||||
void expand(int start, int slice);
|
||||
void setup_for_correlation(BitVector& baseReceiverInput,
|
||||
vector<U>& baseSenderOutputs,
|
||||
U& baseReceiverOutput);
|
||||
template <class T>
|
||||
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
|
||||
template <class T>
|
||||
void reduce_squares(unsigned int nTriples, vector<T>& output);
|
||||
@@ -76,14 +74,12 @@ public:
|
||||
void seed(vector<BitMatrix>& baseSenderInput,
|
||||
BitMatrix& baseReceiverOutput);
|
||||
void transfer(int nOTs, const BitVector& receiverInput);
|
||||
template <class T>
|
||||
void extend(int nOTs, BitVector& newReceiverInput);
|
||||
void extend_correlated(BitVector& newReceiverInput);
|
||||
void extend_correlated(int nOTs, BitVector& newReceiverInput);
|
||||
void extend_correlated(const BitVector& newReceiverInput);
|
||||
void extend_correlated(int nOTs, const BitVector& newReceiverInput);
|
||||
void transpose(int start, int slice);
|
||||
template <class T>
|
||||
void expand_transposed();
|
||||
template <class T, class V>
|
||||
template <class V>
|
||||
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput);
|
||||
|
||||
void print(BitVector& newReceiverInput, int i = 0);
|
||||
@@ -100,7 +96,6 @@ public:
|
||||
octet* get_sender_output(int choice, int i);
|
||||
|
||||
protected:
|
||||
template <class T>
|
||||
void hash_outputs(int nOTs);
|
||||
};
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ public:
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class MascotMultiplier : public OTMultiplier<Share<T>>
|
||||
class MascotMultiplier : public OTMultiplier<T>
|
||||
{
|
||||
OTCorrelator<Matrix<typename T::Square> > auth_ot_ext;
|
||||
void after_correlation();
|
||||
@@ -88,13 +88,32 @@ class MascotMultiplier : public OTMultiplier<Share<T>>
|
||||
const vector<BitVector>& baseReceiverOutput);
|
||||
|
||||
public:
|
||||
vector<T> c_output;
|
||||
vector<typename T::open_type> c_output;
|
||||
|
||||
MascotMultiplier(OTTripleGenerator<Share<T>>& generator, int thread_num);
|
||||
MascotMultiplier(OTTripleGenerator<T>& generator, int thread_num);
|
||||
|
||||
void multiplyForInputs(MultJob job);
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class TinyMultiplier : public OTMultiplier<T>
|
||||
{
|
||||
OTVole<typename T::part_type::sacri_type,
|
||||
typename T::part_type::mac_key_type> mac_vole;
|
||||
|
||||
void after_correlation();
|
||||
void init_authenticator(const BitVector& baseReceiverInput,
|
||||
const vector< vector<BitVector> >& baseSenderInput,
|
||||
const vector<BitVector>& baseReceiverOutput);
|
||||
|
||||
public:
|
||||
vector<typename T::open_type> c_output;
|
||||
|
||||
TinyMultiplier(OTTripleGenerator<T>& generator, int thread_num);
|
||||
|
||||
void multiplyForInputs(MultJob job) { (void) job; throw not_implemented(); }
|
||||
};
|
||||
|
||||
template <int K, int S> class Spdz2kShare;
|
||||
|
||||
template <int K, int S>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/Rectangle.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Math/BitVec.h"
|
||||
#include "Protocols/SemiShare.h"
|
||||
#include "Protocols/Semi2kShare.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
@@ -37,14 +38,28 @@ OTMultiplier<T>::OTMultiplier(OTTripleGenerator<T>& generator,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
MascotMultiplier<T>::MascotMultiplier(OTTripleGenerator<Share<T>>& generator,
|
||||
MascotMultiplier<T>::MascotMultiplier(OTTripleGenerator<T>& generator,
|
||||
int thread_num) :
|
||||
OTMultiplier<Share<T>>(generator, thread_num),
|
||||
OTMultiplier<T>(generator, thread_num),
|
||||
auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true)
|
||||
{
|
||||
c_output.resize(generator.nTriplesPerLoop);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinyMultiplier<T>::TinyMultiplier(OTTripleGenerator<T>& generator,
|
||||
int thread_num) :
|
||||
OTMultiplier<T>(generator, thread_num),
|
||||
mac_vole(
|
||||
128, 128, 0, 1,
|
||||
generator.players[thread_num],
|
||||
{ },
|
||||
{ },
|
||||
{ }, BOTH, false)
|
||||
{
|
||||
c_output.resize(generator.nTriplesPerLoop);
|
||||
}
|
||||
|
||||
template <int K, int S>
|
||||
Spdz2kMultiplier<K, S>::Spdz2kMultiplier(OTTripleGenerator<Spdz2kShare<K, S>>& generator, int thread_num) :
|
||||
OTMultiplier<Spdz2kShare<K, S>>
|
||||
@@ -75,7 +90,7 @@ template<class T>
|
||||
void OTMultiplier<T>::multiply()
|
||||
{
|
||||
keyBits.set(generator.machine.template get_mac_key<typename T::mac_key_type>());
|
||||
rot_ext.extend<gf2n_long>(keyBits.size(), keyBits);
|
||||
rot_ext.extend(keyBits.size(), keyBits);
|
||||
this->outbox.push({});
|
||||
senderOutput.resize(keyBits.size());
|
||||
for (size_t j = 0; j < keyBits.size(); j++)
|
||||
@@ -123,7 +138,6 @@ void OTMultiplier<T>::multiply()
|
||||
template<class W>
|
||||
void OTMultiplier<W>::multiplyForTriples()
|
||||
{
|
||||
typedef typename W::open_type T;
|
||||
typedef typename W::Rectangle X;
|
||||
|
||||
// dummy input for OT correlator
|
||||
@@ -148,13 +162,13 @@ void OTMultiplier<W>::multiplyForTriples()
|
||||
BitVector aBits = generator.valueBits[0];
|
||||
//timers["Extension"].start();
|
||||
rot_ext.extend_correlated(aBits);
|
||||
rot_ext.hash_outputs<T>(aBits.size(), baseSenderOutputs, baseReceiverOutput);
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
|
||||
//timers["Extension"].stop();
|
||||
|
||||
//timers["Correlation"].start();
|
||||
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
|
||||
baseReceiverOutput);
|
||||
otCorrelator.template correlate<T>(0, generator.nPreampTriplesPerLoop,
|
||||
otCorrelator.correlate(0, generator.nPreampTriplesPerLoop,
|
||||
generator.valueBits[1], false, generator.nAmplify);
|
||||
//timers["Correlation"].stop();
|
||||
|
||||
@@ -171,6 +185,14 @@ void MascotMultiplier<T>::init_authenticator(const BitVector& keyBits,
|
||||
this->auth_ot_ext.init(keyBits, senderOutput, receiverOutput);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyMultiplier<T>::init_authenticator(const BitVector& keyBits,
|
||||
const vector<vector<BitVector> >& senderOutput,
|
||||
const vector<BitVector>& receiverOutput)
|
||||
{
|
||||
mac_vole.init(keyBits, senderOutput, receiverOutput);
|
||||
}
|
||||
|
||||
template <int K, int S>
|
||||
void Spdz2kMultiplier<K, S>::init_authenticator(const BitVector& keyBits,
|
||||
const vector< vector<BitVector> >& senderOutput,
|
||||
@@ -188,9 +210,11 @@ void SemiMultiplier<T>::after_correlation()
|
||||
this->outbox.push({});
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void MascotMultiplier<T>::after_correlation()
|
||||
template <class U>
|
||||
void MascotMultiplier<U>::after_correlation()
|
||||
{
|
||||
typedef typename U::open_type T;
|
||||
|
||||
this->auth_ot_ext.resize(
|
||||
this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS);
|
||||
this->auth_ot_ext.set_role(BOTH);
|
||||
@@ -210,8 +234,8 @@ void MascotMultiplier<T>::after_correlation()
|
||||
int nValues = this->generator.nTriplesPerLoop;
|
||||
if (this->generator.machine.check && (j % 2 == 0))
|
||||
nValues *= 2;
|
||||
this->auth_ot_ext.template expand<T>(0, nValues);
|
||||
this->auth_ot_ext.template correlate<T>(0, nValues,
|
||||
this->auth_ot_ext.expand(0, nValues);
|
||||
this->auth_ot_ext.correlate(0, nValues,
|
||||
this->generator.valueBits[j], true);
|
||||
this->auth_ot_ext.reduce_squares(nValues, this->macs[j]);
|
||||
}
|
||||
@@ -219,6 +243,29 @@ void MascotMultiplier<T>::after_correlation()
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void TinyMultiplier<T>::after_correlation()
|
||||
{
|
||||
this->otCorrelator.reduce_squares(this->generator.nTriplesPerLoop,
|
||||
this->c_output);
|
||||
|
||||
this->outbox.push({});
|
||||
|
||||
this->macs.resize(3);
|
||||
MultJob job;
|
||||
this->inbox.pop(job);
|
||||
for (int j = 0; j < 3; j++)
|
||||
{
|
||||
int nValues = this->generator.nTriplesPerLoop * T::default_length;
|
||||
auto& bits = this->generator.valueBits[j];
|
||||
vector<typename T::part_type::sacri_type> values(nValues);
|
||||
for (int i = 0; i < nValues; i++)
|
||||
values[i] = bits.get_bit(i);
|
||||
mac_vole.evaluate(this->macs[j], values);
|
||||
}
|
||||
this->outbox.push(job);
|
||||
}
|
||||
|
||||
template <int K, int S>
|
||||
void Spdz2kMultiplier<K, S>::after_correlation()
|
||||
{
|
||||
@@ -295,9 +342,9 @@ void OTMultiplier<Share<gf2n>>::multiplyForBits()
|
||||
|
||||
for (int i = 0; i < generator.nloops; i++)
|
||||
{
|
||||
auth_ot_ext.expand<gf2n_long>(0, nBlocks);
|
||||
auth_ot_ext.expand(0, nBlocks);
|
||||
inbox.pop(job);
|
||||
auth_ot_ext.correlate<gf2n_long>(0, nBlocks, generator.valueBits[0], true);
|
||||
auth_ot_ext.correlate(0, nBlocks, generator.valueBits[0], true);
|
||||
auth_ot_ext.transpose(0, nBlocks);
|
||||
|
||||
for (int j = 0; j < nBits; j++)
|
||||
@@ -311,8 +358,8 @@ void OTMultiplier<Share<gf2n>>::multiplyForBits()
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void MascotMultiplier<T>::multiplyForInputs(MultJob job)
|
||||
template<class U>
|
||||
void MascotMultiplier<U>::multiplyForInputs(MultJob job)
|
||||
{
|
||||
assert(job.input);
|
||||
auto& generator = this->generator;
|
||||
@@ -320,10 +367,10 @@ void MascotMultiplier<T>::multiplyForInputs(MultJob job)
|
||||
auth_ot_ext.set_role(mine ? RECEIVER : SENDER);
|
||||
int nOTs = job.n_inputs * generator.field_size;
|
||||
auth_ot_ext.resize(nOTs);
|
||||
auth_ot_ext.template expand<T>(0, job.n_inputs);
|
||||
auth_ot_ext.expand(0, job.n_inputs);
|
||||
if (mine)
|
||||
this->inbox.pop();
|
||||
auth_ot_ext.template correlate<T>(0, job.n_inputs, generator.valueBits[0], true);
|
||||
auth_ot_ext.correlate(0, job.n_inputs, generator.valueBits[0], true);
|
||||
auto& input_macs = this->input_macs;
|
||||
input_macs.resize(job.n_inputs);
|
||||
if (mine)
|
||||
@@ -355,24 +402,3 @@ void OTMultiplier<T>::multiplyForBits()
|
||||
{
|
||||
throw runtime_error("bit generation not implemented in this case");
|
||||
}
|
||||
|
||||
template class OTMultiplier<Share<gf2n>>;
|
||||
template class OTMultiplier<Share<gfp1>>;
|
||||
template class OTMultiplier<SemiShare<gf2n>>;
|
||||
template class OTMultiplier<SemiShare<gfp1>>;
|
||||
template class SemiMultiplier<SemiShare<gf2n>>;
|
||||
template class SemiMultiplier<SemiShare<gfp1>>;
|
||||
template class SemiMultiplier<Semi2kShare<64>>;
|
||||
template class SemiMultiplier<Semi2kShare<72>>;
|
||||
template class MascotMultiplier<gf2n_long>;
|
||||
template class MascotMultiplier<gf2n_short>;
|
||||
template class MascotMultiplier<gfp1>;
|
||||
|
||||
#define X(K, S) \
|
||||
template class Spdz2kMultiplier<K, S>; \
|
||||
template class OTMultiplier<Spdz2kShare<K, S>>;
|
||||
X(64, 64)
|
||||
X(64, 48)
|
||||
X(66, 64)
|
||||
X(66, 48)
|
||||
X(32, 32)
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "OT/BaseOT.h"
|
||||
#include "OT/OTMachine.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user