Machine learning functionality, dishonest-majority binary secret sharing.

This commit is contained in:
Marcel Keller
2019-10-11 15:46:31 +11:00
parent 5f0a7ad8e3
commit 7a5195d83c
203 changed files with 6256 additions and 1485 deletions

View File

@@ -1,5 +1,12 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.1.2
- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission
- Binary computation for dishonest majority using secret sharing
- Mathematical functions from [SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA)
- Fixed security bug: CowGear would reuse triples.
## 0.1.1 (Aug 6, 2019)
- ECDSA

View File

@@ -101,12 +101,12 @@ def determine_scope(block, options):
used_from_scope = set()
def find_in_scope(reg, scope):
if scope is None:
return False
elif reg in scope.defined_registers:
return True
else:
return find_in_scope(reg, scope.scope)
while True:
if scope is None:
return False
elif reg in scope.defined_registers:
return True
scope = scope.scope
def read(reg, n):
if last_def[reg] == -1:
@@ -386,7 +386,7 @@ class Merger:
last_print_str = None
last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque()
last_text_input = None
last_text_input = [None, None]
depths = [0] * len(block.instructions)
self.depths = depths
@@ -474,10 +474,14 @@ class Merger:
# will be merged
if isinstance(instr, TextInputInstruction):
if last_text_input is not None and \
type(block.instructions[last_text_input]) is not type(instr):
add_edge(last_text_input, n)
last_text_input = n
if last_text_input[0] is not None:
if instr.merge_id() != \
block.instructions[last_text_input[0]].merge_id():
add_edge(last_text_input[0], n)
last_text_input[1] = last_text_input[0]
elif last_text_input[1] is not None:
add_edge(last_text_input[1], n)
last_text_input[0] = n
if isinstance(instr, merge_classes):
open_nodes.add(n)

View File

@@ -80,6 +80,12 @@ def LTZ(s, a, k, kappa):
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
def LessThanZero(a, k, kappa):
import types
res = types.sint()
LTZ(res, a, k, kappa)
return res
def Trunc(d, a, k, m, kappa, signed):
"""
d = a >> m
@@ -153,6 +159,8 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
k: bit length of a
m: compile-time integer
"""
if m == 0:
return a
if k == int(program.options.ring):
# cannot work with bit length k+1
tmp = TruncRing(None, a, k, m - 1, signed)
@@ -359,7 +367,7 @@ def CarryOutAux(d, a, kappa):
movs(d, a[0][1])
# carry out with carry-in bit c
def CarryOut(res, a, b, c, kappa):
def CarryOut(res, a, b, c=0, kappa=None):
"""
res = last carry bit in addition of a and b
@@ -368,8 +376,9 @@ def CarryOut(res, a, b, c, kappa):
c: initial carry-in bit
"""
k = len(a)
import types
d = [program.curr_block.new_reg('s') for i in range(k)]
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
t = [[types.sint() for i in range(k)] for i in range(4)]
s = [program.curr_block.new_reg('s') for i in range(3)]
for i in range(k):
mulm(t[0][i], b[i], a[i])
@@ -377,12 +386,19 @@ def CarryOut(res, a, b, c, kappa):
addm(t[2][i], b[i], a[i])
subs(t[3][i], t[2][i], t[1][i])
d[i] = [t[3][i], t[0][i]]
mulsi(s[0], d[-1][0], c)
adds(s[1], d[-1][1], s[0])
s[0] = d[-1][0] * c
s[1] = d[-1][1] + s[0]
d[-1][1] = s[1]
CarryOutAux(res, d[::-1], kappa)
def CarryOutLE(a, b, c=0):
""" Little-endian version """
import types
res = types.sint()
CarryOut(res, a[::-1], b[::-1], c)
return res
def BitLTL(res, a, b, kappa):
"""
res = a <? b (logarithmic rounds version)

View File

@@ -47,7 +47,10 @@ COST = { 'modp': defaultdict(lambda: 0,
'bittriple': 0.00004828818388140422,
'bitgf2ntriple': 0.00020716801325875284,
'PreMulC': 2 * 0.00020716801325875284,
})
}),
'all': { 'round': 0,
'inv': 0,
}
}

View File

@@ -325,7 +325,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
def Pow2(a, l, kappa):
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
x = [types.sint() for i in range(m)]
return Pow2_from_bits(t)
def Pow2_from_bits(bits):
m = len(bits)
t = list(bits)
pow2k = [types.cint() for i in range(m)]
for i in range(m):
pow2k[i] = two_power(2**i)
@@ -353,13 +357,20 @@ def B2U_from_Pow2(pow2a, l, kappa):
#print ' '.join(str(b.value) for b in y)
return [1 - y[i] for i in range(l)]
def Trunc(a, l, m, kappa, compute_modulo=False):
def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
""" Oblivious truncation by secret m """
if util.is_constant(m) and not compute_modulo:
# cheaper
res = type(a)(size=a.size)
comparison.Trunc(res, a, l, m, kappa, signed=signed)
return res
if l == 1:
if compute_modulo:
return a * m, 1 + m
else:
return a * (1 - m)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, Pow2(m, l, kappa))
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
@@ -370,8 +381,6 @@ def Trunc(a, l, m, kappa, compute_modulo=False):
x, pow2m = B2U(m, l, kappa)
#assert(pow2m.value == 2**m.value)
#assert(sum(b.value for b in x) == m.value)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, pow2m)
for i in range(l):
bit(r[i])
t1 = two_power(i) * r[i]
@@ -495,17 +504,28 @@ def TruncPrRing(a, k, m, signed=True):
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
else:
from types import sint
# extra bit to mask overflow
r_bits = [sint.get_random_bit() for i in range(k + 1)]
n_shift = n_ring - len(r_bits)
tmp = a + sint.bit_compose(r_bits)
masked = (tmp << n_shift).reveal()
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
res = shifted - sint.bit_compose(r_bits[m:k]) + (overflow << (k - m))
if signed:
a += (1 << (k - 1))
if program.Program.prog.use_trunc_pr:
res = sint()
trunc_pr(res, a, k, m)
else:
# extra bit to mask overflow
r_bits = [sint.get_random_bit() for i in range(k + 1)]
n_shift = n_ring - len(r_bits)
tmp = a + sint.bit_compose(r_bits)
masked = (tmp << n_shift).reveal()
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
res = shifted - sint.bit_compose(r_bits[m:k]) + \
(overflow << (k - m))
if signed:
res -= (1 << (k - m - 1))
return res
def TruncPrField(a, k, m, kappa=None):
if m == 0:
return a
if kappa is None:
kappa = 40
@@ -527,19 +547,24 @@ def SDiv(a, b, l, kappa, round_nearest=False):
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w
y = a * w
y = y.round(2 * l + 1, l, kappa, round_nearest)
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest)
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest,
signed=False)
y = y.round(2 * l + 1, l + 1, kappa, round_nearest, signed=False)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
signed=False)
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest, signed=False)
y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
return y

View File

@@ -894,6 +894,55 @@ class inputfloat(base.TextInputInstruction):
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())
@base.vectorize
class inputmixed(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTMIXED']
field_type = 'modp'
# the following has to match TYPE: (N_DEST, N_PARAM)
types = {
0: (1, 0),
1: (1, 1),
2: (4, 1)
}
type_ids = {
'int': 0,
'fix': 1,
'float': 2
}
def __init__(self, name, *args):
try:
type_id = self.type_ids[name]
except:
pass
super(inputmixed_class, self).__init__(type_id, *args)
@property
def arg_format(self):
for i in self.bases():
t = self.args[i]
yield 'int'
for j in range(self.types[t][0]):
yield 'sw'
for j in range(self.types[t][1]):
yield 'int'
yield 'p'
def bases(self):
i = 0
while i < len(self.args):
yield i
i += sum(self.types[self.args[i]]) + 2
def add_usage(self, req_node):
for i in self.bases():
t = self.args[i]
player = self.args[i + sum(self.types[t]) + 1]
n_dest = self.types[t][0]
req_node.increment((self.field_type, 'input', player), \
n_dest * self.get_size())
@base.gf2n
class startinput(base.RawInputInstruction):
r""" Receive inputs from player $p$. """
@@ -957,6 +1006,11 @@ class print_reg_plain(base.IOInstruction):
code = base.opcodes['PRINTREGPLAIN']
arg_format = ['c']
class cond_print_plain(base.IOInstruction):
r""" Conditionally print the value of a register. """
code = base.opcodes['CONDPRINTPLAIN']
arg_format = ['c', 'c']
class print_int(base.IOInstruction):
r""" Print only the value of register \verb|ci| to stdout. """
__slots__ = []
@@ -1383,6 +1437,9 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
def merge_id(self):
# can merge different sizes
# but not if large
if self.get_size() > 100:
return type(self), self.get_size()
return type(self)
# def expand(self):
@@ -1468,6 +1525,14 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
for reg in self.args[i + 2:i + self.args[i]]:
yield reg
@base.vectorize
class trunc_pr(base.VarArgsInstruction):
""" Probalistic truncation for semi-honest computation """
""" with honest majority """
__slots__ = []
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])
###
### CISC-style instructions
###

View File

@@ -89,6 +89,7 @@ opcodes = dict(
MULS = 0xA6,
MULRS = 0xA7,
DOTPRODS = 0xA8,
TRUNC_PR = 0xA9,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -102,6 +103,7 @@ opcodes = dict(
INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
INPUTMIXED = 0xF2,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
@@ -168,6 +170,7 @@ opcodes = dict(
READFILESHARE = 0xBE,
CONDPRINTSTR = 0xBF,
PRINTFLOATPREC = 0xE0,
CONDPRINTPLAIN = 0xE1,
GBITDEC = 0x184,
GBITCOM = 0x185,
# Secure socket
@@ -767,21 +770,6 @@ class ClearShiftInstruction(ClearImmediate):
### Jumps etc
###
class dummywrite(Instruction):
""" Dummy instruction to create source node in the dependency graph,
preventing read-before-write warnings. """
__slots__ = []
def __init__(self, *args, **kwargs):
self.arg_format = [arg.reg_type + 'w' for arg in args]
super(dummywrite, self).__init__(*args, **kwargs)
def execute(self):
pass
def get_encoding(self):
return []
class JumpInstruction(Instruction):
__slots__ = ['jump_arg']

View File

@@ -5,6 +5,7 @@ from Compiler import instructions,instructions_base,comparison,program,util
import inspect,math
import random
import collections
import operator
def get_program():
return instructions.program
@@ -93,16 +94,25 @@ def print_ln(s='', *args):
print_str(s, *args)
print_char('\n')
def print_ln_if(cond, s):
def print_ln_if(cond, ss, *args):
if util.is_constant(cond):
if cond:
print_ln(s)
print_ln(ss, *args)
else:
s += ' ' * ((len(s) + 3) % 4)
s += '\n'
while s:
cond.print_if(s[:4])
s = s[4:]
subs = ss.split('%s')
assert len(subs) == len(args) + 1
cond = cint.conv(cond)
for i, s in enumerate(subs):
if i != 0:
cond_print_plain(cond, cint.conv(args[i - 1]))
if i < len(args):
s += ' ' * ((-len(s)) % 4)
else:
s += ' ' * ((-len(s) + 3) % 4)
s += '\n'
while s:
cond.print_if(s[:4])
s = s[4:]
def print_float_precision(n):
print_float_prec(n)
@@ -798,19 +808,23 @@ def range_loop(loop_body, start, stop=None, step=None):
lambda x: ((stop - start) / step) * x[0]
def for_range(start, stop=None, step=None):
""" Execute loop bodies consecutively """
def decorator(loop_body):
range_loop(loop_body, start, stop, step)
return loop_body
return decorator
def for_range_parallel(n_parallel, n_loops):
""" Execute up to n_parallel loop bodies in parallel """
return map_reduce_single(n_parallel, n_loops)
def for_range_opt(n_loops):
return map_reduce_single(None, n_loops)
def for_range_opt(n_loops, budget=None):
""" Execute loop bodies in parallel up to an optimization budget """
return map_reduce_single(None, n_loops, budget=budget)
def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
reducer=lambda *x: [], mem_state=None):
reducer=lambda *x: [], mem_state=None, budget=None):
budget = budget or get_program().budget
if not (isinstance(n_parallel, int) or n_parallel is None):
raise CompilerException('Number of parallel executions' \
'must be constant')
@@ -848,14 +862,16 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
r = reducer(mem_state, state)
write_state_to_memory(r)
else:
n_parallel_reg = MemValue(regint(0))
if n_loops == 0:
return
regint.push(0)
parent_block = get_block()
@while_do(lambda x: x + n_parallel_reg <= n_loops, regint(0))
@while_do(lambda x: x + regint.pop() <= n_loops, regint(0))
def _(i):
state = tuplify(initializer())
k = 0
block = get_block()
while k < n_loops and (len(get_block()) < get_program().budget \
while k < n_loops and (len(get_block()) < budget \
or k == 0) \
and block is get_block():
j = i + k
@@ -865,7 +881,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
write_state_to_memory(r)
global n_opt_loops
n_opt_loops = k
n_parallel_reg.write(k)
regint.push(k)
return i + k
my_n_parallel = n_opt_loops
loop_rounds = n_loops / my_n_parallel
@@ -915,12 +931,46 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
return decorator
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
"""
Execute loop bodies in up to n_threads threads,
up to n_parallel in parallel per thread
"""
return map_reduce(n_threads, n_parallel, n_loops, \
lambda *x: [], lambda *x: [], thread_mem_req)
def for_range_opt_multithread(n_threads, n_loops):
"""
Execute loop bodies in up to n_threads threads,
in parallel up to an optimization budget per thread
"""
return for_range_multithread(n_threads, None, n_loops)
def multithread(n_threads, n_items):
"""
Distribute the computation of n_items to n_threads threads,
but leave the in-thread repetition up to the user
"""
if n_threads == 1 or n_items == 1:
return lambda loop_body: loop_body(0, n_items)
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False)
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
thread_mem_req={}):
thread_mem_req={}, looping=True):
n_threads = n_threads or 1
if isinstance(n_loops, list):
split = n_loops
n_loops = reduce(operator.mul, n_loops)
def decorator(loop_body):
def new_body(i):
indices = []
for n in reversed(split):
indices.insert(0, i % n)
i /= n
return loop_body(*indices)
return new_body
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
return lambda loop_body: new_dec(decorator(loop_body))
if n_threads == 1 or n_loops == 1:
dec = map_reduce_single(n_parallel, n_loops, initializer, reducer)
if thread_mem_req:
@@ -937,12 +987,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
state = tuple(initializer())
def f(inc):
base = args[get_arg()][0]
if not looping:
return loop_body(base, thread_rounds + inc)
if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint, \
args[get_arg()].address + 2)
mem_state = Array(len(state), type(state[0]) \
if state else cint, args[get_arg()][1])
base = args[get_arg()][0]
@map_reduce_single(n_parallel, thread_rounds + inc, \
initializer, reducer, mem_state)
def f(i):
@@ -1014,8 +1066,9 @@ def while_loop(loop_body, condition, arg):
pushint(arg if isinstance(arg,regint) else regint(arg))
def loop_fn():
result = loop_body(regint.pop())
cont = condition(result)
pushint(result)
return condition(result)
return cont
if_statement(pre_condition, lambda: do_while(loop_fn))
regint.pop()
@@ -1278,7 +1331,7 @@ def sint_cint_division(a, b, k, f, kappa):
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b < 0)
sign_a = sint(1) - 2 * sint(a < 0)
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)
@@ -1326,7 +1379,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
y = a.extend(2 *k) * w
y = y.round(2*k, f, kappa, nearest, signed=True)
for i in range(theta):
for i in range(theta - 1):
x = x.extend(2 * k)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
x = x * x
@@ -1358,7 +1411,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
# For simplex, we can get rid of computing abs(b)
temp = None
if simplex_flag == False:
temp = b.less_than(0, 2 * k)
temp = comparison.LessThanZero(b, 2 * k, kappa)
elif simplex_flag == True:
temp = cint(0)

814
Compiler/ml.py Normal file
View File

@@ -0,0 +1,814 @@
import mpc_math, math
from Compiler.types import *
from Compiler.types import _unreduced_squant
from Compiler.library import *
def log_e(x):
return mpc_math.log_fx(x, math.e)
def exp(x):
return mpc_math.pow_fx(math.e, x)
def sanitize(x, raw, lower, upper):
exp_limit = 2 ** (x.k - x.f - 1)
limit = math.log(exp_limit)
if get_program().options.ring:
res = raw
else:
res = (x > limit).if_else(upper, raw)
return (x < -limit).if_else(lower, res)
def sigmoid(x):
return sigmoid_from_e_x(x, exp(-x))
def sigmoid_from_e_x(x, e_x):
return sanitize(x, 1 / (1 + e_x), 0, 1)
def sigmoid_prime(x):
sx = sigmoid(x)
return sx * (1 - sx)
def lse_0_from_e_x(x, e_x):
return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)
def lse_0(x):
return lse_0_from_e_x(x, exp(x))
def relu_prime(x):
return (0 <= x)
def relu(x):
return (0 < x).if_else(x, 0)
def progress(x):
return
print_ln(x)
time()
def set_n_threads(n_threads):
Layer.n_threads = n_threads
Optimizer.n_threads = n_threads
class Layer:
n_threads = 1
class Output(Layer):
def __init__(self, N, debug=False):
self.N = N
self.X = sfix.Array(N)
self.Y = sfix.Array(N)
self.nabla_X = sfix.Array(N)
self.l = MemValue(sfix(-1))
self.e_x = sfix.Array(N)
self.debug = debug
self.weights = cint.Array(N)
self.weights.assign_all(1)
self.weight_total = N
nablas = lambda self: ()
thetas = lambda self: ()
reset = lambda self: None
def divisor(self, divisor, size):
return cfix(1.0 / divisor, size=size)
def forward(self, N=None):
N = N or self.N
lse = sfix.Array(N)
@multithread(self.n_threads, N)
def _(base, size):
x = self.X.get_vector(base, size)
y = self.Y.get_vector(base, size)
e_x = exp(-x)
self.e_x.assign(e_x, base)
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
e_x = self.e_x.get_vector(0, N)
self.l.write(sum(lse) * \
self.divisor(self.N, 1))
def backward(self):
@multithread(self.n_threads, self.N)
def _(base, size):
diff = sigmoid_from_e_x(self.X.get_vector(base, size),
self.e_x.get_vector(base, size)) - \
self.Y.get_vector(base, size)
assert sfix.f == cfix.f
diff *= self.weights.get_vector(base, size)
self.nabla_X.assign(diff * self.divisor(self.weight_total, size), \
base)
# @for_range_opt(len(diff))
# def _(i):
# self.nabla_X[i] = self.nabla_X[i] * self.weights[i]
if self.debug:
a = cfix.Array(len(diff))
a.assign(diff.reveal())
@for_range(len(diff))
def _(i):
x = a[i]
print_ln_if((x < -1.001) + (x > 1.001), 'sigmoid')
#print_ln('%s', x)
def set_weights(self, weights):
self.weights.assign(weights)
self.weight_total = sum(weights)
class DenseBase(Layer):
thetas = lambda self: (self.W, self.b)
nablas = lambda self: (self.nabla_W, self.nabla_b)
def backward_params(self, f_schur_Y):
N = self.N
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
def _(j, k):
assert self.d == 1
a = [f_schur_Y[i][0][k] for i in range(N)]
b = [self.X[i][0][j] for i in range(N)]
tmp[j][k] = sfix.unreduced_dot_product(a, b)
if self.d_in * self.d_out < 100000:
print 'reduce at once'
@multithread(self.n_threads, self.d_in * self.d_out)
def _(base, size):
self.nabla_W.assign_vector(
tmp.get_vector(base, size).reduce_after_mul(), base=base)
else:
@for_range_opt(self.d_in)
def _(i):
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N))
for j in range(self.d)) for i in range(self.d_out))
progress('nabla W/b')
class Dense(DenseBase):
def __init__(self, N, d_in, d_out, d=1, activation='id'):
self.activation = activation
if activation == 'id':
self.f = lambda x: x
elif activation == 'relu':
self.f = relu
self.f_prime = relu_prime
elif activation == 'sigmoid':
self.f = sigmoid
self.f_prime = sigmoid_prime
self.N = N
self.d_in = d_in
self.d_out = d_out
self.d = d
self.X = MultiArray([N, d, d_in], sfix)
self.Y = MultiArray([N, d, d_out], sfix)
self.W = sfix.Matrix(d_in, d_out)
self.b = sfix.Array(d_out)
self.reset()
self.nabla_Y = MultiArray([N, d, d_out], sfix)
self.nabla_X = MultiArray([N, d, d_in], sfix)
self.nabla_W = sfix.Matrix(d_in, d_out)
self.nabla_W.assign_all(0)
self.nabla_b = sfix.Array(d_out)
self.f_input = MultiArray([N, d, d_out], sfix)
def reset(self):
d_in = self.d_in
d_out = self.d_out
r = math.sqrt(6.0 / (d_in + d_out))
@for_range(d_in)
def _(i):
@for_range(d_out)
def _(j):
self.W[i][j] = sfix.get_random(-r, r)
self.b.assign_all(0)
def compute_f_input(self):
prod = MultiArray([self.N, self.d, self.d_out], sfix)
@for_range_opt_multithread(self.n_threads, self.N)
def _(i):
self.X[i].plain_mul(self.W, res=prod[i])
@for_range_opt_multithread(self.n_threads, self.N)
def _(i):
@for_range_opt(self.d)
def _(j):
v = prod[i][j].get_vector() + self.b.get_vector()
self.f_input[i][j].assign(v)
progress('f input')
def forward(self):
self.compute_f_input()
self.Y.assign_vector(self.f(self.f_input.get_vector()))
def backward(self, compute_nabla_X=True):
N = self.N
d = self.d
d_out = self.d_out
X = self.X
Y = self.Y
W = self.W
b = self.b
nabla_X = self.nabla_X
nabla_Y = self.nabla_Y
nabla_W = self.nabla_W
nabla_b = self.nabla_b
if self.activation == 'id':
f_schur_Y = nabla_Y
else:
f_prime_bit = MultiArray([N, d, d_out], sint)
f_schur_Y = MultiArray([N, d, d_out], sfix)
self.compute_f_input()
f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector()))
progress('f prime')
@for_range_opt(N)
def _(i):
f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])
progress('f prime schur Y')
if compute_nabla_X:
@for_range_opt(N)
def _(i):
if self.activation == 'id':
nabla_X[i] = nabla_Y[i].mul_trans(W)
else:
nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W)
progress('nabla X')
self.backward_params(f_schur_Y)
class QuantizedDense(DenseBase):
def __init__(self, N, d_in, d_out):
self.N = N
self.d_in = d_in
self.d_out = d_out
self.d = 1
self.H = math.sqrt(1.5 / (d_in + d_out))
self.W = sfix.Matrix(d_in, d_out)
self.nabla_W = self.W.same_shape()
self.T = sint.Matrix(d_in, d_out)
self.b = sfix.Array(d_out)
self.nabla_b = self.b.same_shape()
self.X = MultiArray([N, 1, d_in], sfix)
self.Y = MultiArray([N, 1, d_out], sfix)
self.nabla_Y = self.Y.same_shape()
def reset(self):
@for_range(self.d_in)
def _(i):
@for_range(self.d_out)
def _(j):
self.W[i][j] = sfix.get_random(-1, 1)
self.b.assign_all(0)
def forward(self):
@for_range_opt(self.d_in)
def _(i):
@for_range_opt(self.d_out)
def _(j):
over = self.W[i][j] > 0.5
under = self.W[i][j] < -0.5
self.T[i][j] = over.if_else(1, under.if_else(-1, 0))
over = self.W[i][j] > 1
under = self.W[i][j] < -1
self.W[i][j] = over.if_else(1, under.if_else(-1, self.W[i][j]))
@for_range_opt(self.N)
def _(i):
assert self.d_out == 1
self.Y[i][0][0] = self.b[0] + self.H * sfix._new(
sint.dot_product([self.T[j][0] for j in range(self.d_in)],
[self.X[i][0][j].v for j in range(self.d_in)]))
def backward(self, compute_nabla_X=False):
assert not compute_nabla_X
self.backward_params(self.nabla_Y)
class Dropout:
def __init__(self, N, d1, d2=1):
self.N = N
self.d1 = d1
self.d2 = d2
self.X = MultiArray([N, d1, d2], sfix)
self.Y = MultiArray([N, d1, d2], sfix)
self.nabla_Y = MultiArray([N, d1, d2], sfix)
self.nabla_X = MultiArray([N, d1, d2], sfix)
self.alpha = 0.5
self.B = MultiArray([N, d1, d2], sint)
def forward(self):
assert self.alpha == 0.5
@for_range(self.N)
def _(i):
@for_range(self.d1)
def _(j):
@for_range(self.d2)
def _(k):
self.B[i][j][k] = sint.get_random_bit()
self.Y = self.X.schur(self.B)
def backward(self):
self.nabla_X = self.nabla_Y.schur(self.B)
class QuantBase(object):
n_threads = 1
@staticmethod
def new_squant():
class _(squant):
@classmethod
def get_input_from(cls, player, size=None):
return cls._new(sint.get_input_from(player, size=size))
return _
def __init__(self, input_shape, output_shape):
self.input_shape = input_shape
self.output_shape = output_shape
self.input_squant = self.new_squant()
self.output_squant = self.new_squant()
self.X = MultiArray(input_shape, self.input_squant)
self.Y = MultiArray(output_shape, self.output_squant)
def temp_shape(self):
return [0]
class QuantConvBase(QuantBase):
fewer_rounds = True
temp_weights = None
temp_inputs = None
@classmethod
def init_temp(cls, layers):
size = 0
for layer in layers:
size = max(size, reduce(operator.mul, layer.temp_shape()))
cls.temp_weights = sfix.Array(size)
cls.temp_inputs = sfix.Array(size)
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride):
super(QuantConvBase, self).__init__(input_shape, output_shape)
self.weight_shape = weight_shape
self.bias_shape = bias_shape
self.stride = stride
self.weight_squant = self.new_squant()
self.bias_squant = self.new_squant()
self.weights = MultiArray(weight_shape, self.weight_squant)
self.bias = Array(output_shape[-1], self.bias_squant)
self.unreduced = MultiArray(self.output_shape, sint,
address=self.Y.address)
assert(weight_shape[-1] == input_shape[-1])
assert(bias_shape[0] == output_shape[-1])
assert(len(bias_shape) == 1)
assert(len(input_shape) == 4)
assert(len(output_shape) == 4)
assert(len(weight_shape) == 4)
def input_from(self, player):
for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
self.weights.input_from(player, budget=100000)
self.bias.input_from(player)
print 'WARNING: assuming that bias quantization parameters are correct'
self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params)
def dot_product(self, iv, wv, out_y, out_x, out_c):
bias = self.bias[out_c]
acc = squant.unreduced_dot_product(iv, wv)
acc.v += bias.v
acc.res_params = self.output_squant.params
#self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul()
self.unreduced[0][out_y][out_x][out_c] = acc.v
def reduction(self):
unreduced = self.unreduced
n_summands = self.n_summands()
start_timer(2)
n_outputs = reduce(operator.mul, self.output_shape)
if n_outputs % self.n_threads == 0:
n_per_thread = n_outputs / self.n_threads
@for_range_opt_multithread(self.n_threads, self.n_threads)
def _(i):
res = _unreduced_squant(
sint.load_mem(unreduced.address + i * n_per_thread,
size=n_per_thread),
(self.input_squant.params, self.weight_squant.params),
self.output_squant.params,
n_summands).reduce_after_mul()
res.store_in_mem(self.Y.address + i * n_per_thread)
else:
@for_range_opt_multithread(self.n_threads, self.output_shape[1])
def _(out_y):
self.Y[0][out_y].assign_vector(_unreduced_squant(
unreduced[0][out_y].get_vector(),
(self.input_squant.params, self.weight_squant.params),
self.output_squant.params,
n_summands).reduce_after_mul())
stop_timer(2)
def temp_shape(self):
return list(self.output_shape[1:]) + [self.n_summands()]
def prepare_temp(self):
shape = self.temp_shape()
inputs = MultiArray(shape, self.input_squant,
address=self.temp_inputs)
weights = MultiArray(shape, self.weight_squant,
address=self.temp_weights)
return inputs, weights
class QuantConv2d(QuantConvBase):
def n_summands(self):
_, weights_h, weights_w, _ = self.weight_shape
_, inputs_h, inputs_w, n_channels_in = self.input_shape
return weights_h * weights_w * n_channels_in
def forward(self, N=1):
assert(N == 1)
assert(self.weight_shape[0] == self.output_shape[-1])
_, weights_h, weights_w, _ = self.weight_shape
_, inputs_h, inputs_w, n_channels_in = self.input_shape
_, output_h, output_w, n_channels_out = self.output_shape
stride_h, stride_w = self.stride
padding_h, padding_w = (weights_h // 2, weights_w // 2)
if self.fewer_rounds:
inputs, weights = self.prepare_temp()
@for_range_opt_multithread(self.n_threads,
[output_h, output_w, n_channels_out])
def _(out_y, out_x, out_c):
in_x_origin = (out_x * stride_w) - padding_w
in_y_origin = (out_y * stride_h) - padding_h
iv = []
wv = []
for filter_y in range(weights_h):
in_y = in_y_origin + filter_y
inside_y = (0 <= in_y) * (in_y < inputs_h)
for filter_x in range(weights_w):
in_x = in_x_origin + filter_x
inside_x = (0 <= in_x) * (in_x < inputs_w)
inside = inside_y * inside_x
if inside is 0:
continue
for in_c in range(n_channels_in):
iv += [self.X[0][in_y * inside_y]
[in_x * inside_x][in_c]]
wv += [self.weights[out_c][filter_y][filter_x][in_c]]
wv[-1] *= inside
if self.fewer_rounds:
inputs[out_y][out_x][out_c].assign(iv)
weights[out_y][out_x][out_c].assign(wv)
else:
self.dot_product(iv, wv, out_y, out_x, out_c)
if self.fewer_rounds:
@for_range_opt_multithread(self.n_threads,
list(self.output_shape[1:]))
def _(out_y, out_x, out_c):
self.dot_product(inputs[out_y][out_x][out_c],
weights[out_y][out_x][out_c],
out_y, out_x, out_c)
self.reduction()
class QuantDepthwiseConv2d(QuantConvBase):
def n_summands(self):
_, weights_h, weights_w, _ = self.weight_shape
return weights_h * weights_w
def forward(self, N=1):
assert(N == 1)
assert(self.weight_shape[-1] == self.output_shape[-1])
assert(self.input_shape[-1] == self.output_shape[-1])
_, weights_h, weights_w, _ = self.weight_shape
_, inputs_h, inputs_w, n_channels_in = self.input_shape
_, output_h, output_w, n_channels_out = self.output_shape
stride_h, stride_w = self.stride
padding_h, padding_w = (weights_h // 2, weights_w // 2)
depth_multiplier = 1
if self.fewer_rounds:
inputs, weights = self.prepare_temp()
@for_range_opt_multithread(self.n_threads,
[output_h, output_w, n_channels_in])
def _(out_y, out_x, in_c):
for m in range(depth_multiplier):
oc = m + in_c * depth_multiplier
in_x_origin = (out_x * stride_w) - padding_w
in_y_origin = (out_y * stride_h) - padding_h
iv = []
wv = []
for filter_y in range(weights_h):
for filter_x in range(weights_w):
in_x = in_x_origin + filter_x
in_y = in_y_origin + filter_y
inside = (0 <= in_x) * (in_x < inputs_w) * \
(0 <= in_y) * (in_y < inputs_h)
if inside is 0:
continue
iv += [self.X[0][in_y][in_x][in_c]]
wv += [self.weights[0][filter_y][filter_x][oc]]
wv[-1] *= inside
if self.fewer_rounds:
inputs[out_y][out_x][oc].assign(iv)
weights[out_y][out_x][oc].assign(wv)
else:
self.dot_product(iv, wv, out_y, out_x, oc)
if self.fewer_rounds:
@for_range_opt_multithread(self.n_threads,
list(self.output_shape[1:]))
def _(out_y, out_x, out_c):
self.dot_product(inputs[out_y][out_x][out_c],
weights[out_y][out_x][out_c],
out_y, out_x, out_c)
self.reduction()
class QuantAveragePool2d(QuantBase):
def __init__(self, input_shape, output_shape, filter_size):
super(QuantAveragePool2d, self).__init__(input_shape, output_shape)
self.filter_size = filter_size
def input_from(self, player):
print 'WARNING: assuming that input and output quantization parameters are the same'
for s in self.input_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
def forward(self, N=1):
assert(N == 1)
_, input_h, input_w, n_channels_in = self.input_shape
_, output_h, output_w, n_channels_out = self.output_shape
n = input_h * input_w
print 'divisor: ', n
assert output_h == output_w == 1
assert n_channels_in == n_channels_out
padding_h, padding_w = (0, 0)
stride_h, stride_w = (2, 2)
filter_h, filter_w = self.filter_size
@for_range_opt(output_h)
def _(out_y):
@for_range_opt(output_w)
def _(out_x):
@for_range_opt(n_channels_in)
def _(c):
in_x_origin = (out_x * stride_w) - padding_w
in_y_origin = (out_y * stride_h) - padding_h
fxs = (-in_x_origin).max(0)
#fxe = min(filter_w, input_w - in_x_origin)
fys = (-in_y_origin).max(0)
#fye = min(filter_h, input_h - in_y_origin)
acc = 0
#fc = 0
for i in range(filter_h):
filter_y = fys + i
for j in range(filter_w):
filter_x = fxs + j
in_x = in_x_origin + filter_x
in_y = in_y_origin + filter_y
acc += self.X[0][in_y][in_x][c].v
#fc += 1
logn = int(math.log(n, 2))
acc = (acc + n / 2)
if 2 ** logn == n:
acc = acc.round(self.output_squant.params.k + logn,
logn, nearest=True)
else:
acc = acc.int_div(sint(n),
self.output_squant.params.k + logn)
#acc = min(255, max(0, acc))
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
class QuantReshape(QuantBase):
def __init__(self, input_shape, _, output_shape):
super(QuantReshape, self).__init__(input_shape, output_shape)
def input_from(self, player):
print 'WARNING: assuming that input and output quantization parameters are the same'
_ = self.new_squant()
for s in self.input_squant, _, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
for i in range(2):
sint.get_input_from(player)
def forward(self, N=1):
assert(N == 1)
# reshaping is implicit
self.Y.assign(self.X)
class QuantSoftmax(QuantBase):
def input_from(self, player):
print 'WARNING: assuming that input and output quantization parameters are the same'
for s in self.input_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
def forward(self, N=1):
assert(N == 1)
assert(len(self.input_shape) == 2)
# just print the best
def comp(left, right):
c = left[1].v.greater_than(right[1].v, self.input_squant.params.k)
#print_ln('comp %s %s %s', c.reveal(), left[1].v.reveal(), right[1].v.reveal())
return [c.if_else(x, y) for x, y in zip(left, right)]
print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal())
class Optimizer:
n_threads = Layer.n_threads
def forward(self, N):
for j in range(len(self.layers) - 1):
self.layers[j].forward()
self.layers[j + 1].X.assign(self.layers[j].Y)
self.layers[-1].forward(N)
def backward(self):
for j in range(1, len(self.layers)):
self.layers[-j].backward()
self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X)
self.layers[0].backward(compute_nabla_X=False)
def run(self):
i = MemValue(0)
@do_while
def _():
if self.X_by_label is not None:
N = self.layers[0].N
assert self.layers[-1].N == N
assert N % 2 == 0
n = N / 2
@for_range(n)
def _(i):
self.layers[-1].Y[i] = 0
self.layers[-1].Y[i + n] = 1
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
self.X_by_label) / n))
print '%d runs per epoch' % n_per_epoch
indices_by_label = []
for label, X in enumerate(self.X_by_label):
indices = regint.Array(n * n_per_epoch)
indices_by_label.append(indices)
indices.assign(i % len(X) for i in range(len(indices)))
indices.shuffle()
@for_range(n_per_epoch)
def _(j):
j = MemValue(j)
for label, X in enumerate(self.X_by_label):
indices = indices_by_label[label]
@for_range_multithread(self.n_threads, 1, n)
def _(i):
idx = indices[i + j * n_per_epoch]
self.layers[0].X[i + label * n] = X[idx]
self.forward(None)
self.backward()
self.update(i)
else:
self.forward(None)
self.backward()
self.update(i)
loss = self.layers[-1].l
if self.report_loss:
print_ln('loss after epoch %s: %s', i, loss.reveal())
else:
print_ln('done with epoch %s', i)
time()
i.iadd(1)
res = (i < self.n_epochs)
if self.tol > 0:
res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
return res
print_ln('finished after %s epochs', i)
class Adam(Optimizer):
def __init__(self, layers, n_epochs):
self.alpha = .001
self.beta1 = 0.9
self.beta2 = 0.999
self.epsilon = 10 ** -8
self.n_epochs = n_epochs
self.layers = layers
self.ms = []
self.vs = []
self.gs = []
self.thetas = []
for layer in layers:
for nabla in layer.nablas():
self.gs.append(nabla)
for x in self.ms, self.vs:
x.append(nabla.same_shape())
for theta in layer.thetas():
self.thetas.append(theta)
self.mhat_factors = Array(n_epochs, sfix)
self.vhat_factors = Array(n_epochs, sfix)
for i in range(n_epochs):
for factors, beta in ((self.mhat_factors, self.beta1),
(self.vhat_factors, self.beta2)):
factors[i] = 1. / (1 - beta ** (i + 1))
def update(self, i_epoch):
for m, v, g, theta in zip(self.ms, self.vs, self.gs, self.thetas):
@for_range_opt(len(m))
def _(k):
m[k] = self.beta1 * m[k] + (1 - self.beta1) * g[k]
v[k] = self.beta2 * v[k] + (1 - self.beta2) * g[k] ** 2
mhat = m[k] * self.mhat_factors[i_epoch]
vhat = v[k] * self.vhat_factors[i_epoch]
theta[k] = theta[k] - self.alpha * mhat / \
mpc_math.sqrt(vhat) + self.epsilon
class SGD(Optimizer):
def __init__(self, layers, n_epochs, debug=False, report_loss=False):
self.momentum = 0.9
self.layers = layers
self.n_epochs = n_epochs
self.thetas = []
self.nablas = []
self.delta_thetas = []
for layer in layers:
self.nablas.extend(layer.nablas())
self.thetas.extend(layer.thetas())
for theta in layer.thetas():
self.delta_thetas.append(theta.same_shape())
self.gamma = MemValue(sfix(0.01))
self.debug = debug
self.report_loss = report_loss
self.tol = 0.000
self.X_by_label = None
def reset(self, X_by_label=None):
self.X_by_label = X_by_label
for y in self.delta_thetas:
y.assign_all(0)
for layer in self.layers:
layer.reset()
def update(self, i_epoch):
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
self.delta_thetas):
@for_range_opt_multithread(self.n_threads, len(nabla))
def _(k):
old = delta_theta[k]
if isinstance(old, Array):
old = old.get_vector()
red_old = self.momentum * old
new = self.gamma * nabla[k]
diff = red_old - new
delta_theta[k] = diff
theta[k] = theta[k] + delta_theta[k]
if self.debug:
for x, name in (old, 'old'), (red_old, 'red_old'), \
(new, 'new'), (diff, 'diff'):
x = x.reveal()
print_ln_if((x > 1000) + (x < -1000),
name + ': %s %s %s %s',
*[y.v.reveal() for y in old, red_old, \
new, diff])
if self.debug:
d = delta_theta.get_vector().reveal()
a = cfix.Array(len(d.v))
a.assign(d)
@for_range(len(a))
def _(i):
x = a[i]
print_ln_if((x > 1000) + (x < -1000),
'update len=%d' % len(nabla))
a.assign(nabla.get_vector().reveal())
@for_range(len(a))
def _(i):
x = a[i]
print_ln_if((x > 1000) + (x < -1000),
'nabla len=%d' % len(nabla))
self.gamma.imul(1 - 10 ** - 6)

752
Compiler/mpc_math.py Normal file
View File

@@ -0,0 +1,752 @@
##
# @file
# Arithmetic Module for Complex Math Operations
#
# Implements trigonometric and logarithmic functions.
import math
from Compiler import floatingpoint
from Compiler import types
from Compiler import comparison
from Compiler import program
# polynomials as enumerated on Hart's book
##
# @private
p_3307 = [1.57079632679489000000000, -0.64596409750624600000000,
0.07969262624616700000000, -0.00468175413531868000000,
0.00016044118478735800000, -0.00000359884323520707000,
0.00000005692172920657320, -0.00000000066880348849204,
0.00000000000606691056085, -0.00000000000004375295071,
0.00000000000000025002854]
##
# @private
p_3508 = [1.00000000000000000000, -0.50000000000000000000,
0.04166666666666667129, -0.00138888888888888873,
0.00002480158730158702, -0.00000027557319223933,
0.00000000208767569817, -0.00000000001147074513,
0.00000000000004779454, -0.00000000000000015612,
0.00000000000000000040]
##
# @private
p_1045 = [1.000000077443021686, 0.693147180426163827795756,
0.224022651071017064605384, 0.055504068620466379157744,
0.009618341225880462374977, 0.001332730359281437819329,
0.000155107460590052573978, 0.000014197847399765606711,
0.000001863347724137967076]
##
# @private
p_2524 = [-2.05466671951, -8.8626599391,
+6.10585199015, +4.81147460989]
##
# @private
q_2524 = [+0.353553425277, +4.54517087629,
+6.42784209029, +1]
##
# @private
p_5102 = [+21514.05962602441933193254468, +73597.43380288444240814980706,
+100272.5618306302784970511863, +69439.29750032252337059765503,
+25858.09739719099025716567793, +5038.63918550126655793779119,
+460.1588804635351471161727227, +15.08767735870030987717455528,
+0.07523052818757628444510729539]
##
# @private
q_5102 = [+21514.05962602441933193298234, +80768.78701155924885176713209,
+122892.6789092784776298743322, +97323.20349053555680260434387,
+42868.57652046408093184006664, +10401.13491566890057005103878,
+1289.75056911611097141145955, +68.51937831018968013114024294,
+1]
##
# @private
p_4737 = [-9338.550897341021522505385079, +43722.68009378241623148489754,
-86008.12066370804865047446067, +92190.57592175496843898184959,
-58360.27724533928122075635101, +22081.61324178027161353562222,
-4805.541226761699661564427739, +542.2148323255220943742314911,
-24.94928894422502466205102672, 0.2222361619461131578797029272]
##
# @private
q_4737 =[-9338.550897341021522505384935, +45279.10524333925315190231067,
-92854.24688696401422824346529, +104687.2504366298224257408682,
-70581.74909396877350961227976, +28972.22947326672977624954443,
-7044.002024719172700685571406, +935.7104153502806086331621628,
-56.83369358538071475796209327, 1]
##
# @private
p_4754 = [-6.90859801, +12.85564644, -5.94939208]
##
# @private
q_4754 = [-6.92529156, +14.20305096, -8.27925501, 1]
# all inputs are calcualted in radians hence we need some conversion.
pi = math.radians(180)
pi_over_2 = math.radians(90)
##
# truncates values regardless of the input type. (It always rounds down)
# @param x: coefficient to be truncated.
#
# @return truncated sint value of x
def trunc(x):
if type(x) is types.sfix:
return floatingpoint.Trunc(x.v, x.k, x.f, x.kappa, signed=True)
elif type(x) is types.sfloat:
v, p, z, s = floatingpoint.FLRound(x, 0)
#return types.sfloat(v, p, z, s, x.err)
return types.sfloat(v, p, z, s)
return x
##
# loads integer to fractional type (sint)
# @param x: coefficient to be truncated.
#
# @return returns sfix, sfloat loaded value
def load_sint(x, l_type):
if l_type is types.sfix:
return types.sfix.from_sint(x)
elif l_type is types.sfloat:
return x
return x
##
# evaluates a Polynomial to a given x in a privacy preserving manner.
# Inputs can be of any kind of register, secret or otherwise.
#
# @param p_c: Polynomial coefficients. (Array)
#
# @param x: Value to which the polynomial p_c is evaluated to.(register)
#
# @return the evaluation of the polynomial. return type depends on inputs.
def p_eval(p_c, x):
degree = len(p_c) - 1
if type(x) is types.sfix:
# ignore coefficients smaller than precision
for c in reversed(p_c):
if c < 2 ** -(x.f + 1):
degree -= 1
else:
break
pre_mults = floatingpoint.PreOpL(lambda a,b,_: a * b,
[x] * degree)
local_aggregation = 0
# Evaluation of the Polynomial
for i, pre_mult in zip(p_c[1:], pre_mults):
local_aggregation += pre_mult.mul_no_reduce(x.coerce(i))
return local_aggregation.reduce_after_mul() + p_c[0]
##
# reduces the input to [0,90) and returns whether the reduced value is
# greater than \Pi and greater than Pi over 2
# @param x: value of any type to be reduced to the [0,90) interval
#
# @return w: reduced angle in either fixed or floating point .
#
# @return b1: \{0,1\} value. Returns one when reduction to 2*\pi
# is greater than \pi
#
# @return b2: \{0,1\} value. Returns one when reduction to
# \pi is greater than \pi/2.
def sTrigSub(x):
# reduction to 2* \pi
f = x * (1.0 / (2 * pi))
f = load_sint(trunc(f), type(x))
y = x - (f) * (2 * pi)
# reduction to \pi
b1 = y > pi
w = b1 * ((2 * pi - y) - y) + y
# reduction to \pi/2
b2 = w > pi_over_2
w = b2 * ((pi - w) - w) + w
# returns scaled angle and boolean flags
return w, b1, b2
# kernel method calls -- they are built in a generic way
##
# Kernel sin. Returns the sin of a given angle on the [0, \pi/2) interval and
# adjust the sign in case the angle was reduced on the [0,360) interval
#
# @param w: fractional value for an angle on the [0, \pi) interval.
#
# @return returns the sin of w.
def ssin(w, s):
# calculates the v of w for polynomial evaluation
v = w * (1.0 / pi_over_2)
v_2 = v ** 2
# adjust sign according to the movement in the reduction
b = s * (-2) + 1
# calculate the sin using polynomial evaluation
local_sin = b * v * p_eval(p_3307, v_2)
return local_sin
##
# Kernel cos. Returns the cos of a given angle on the [0.pi/2)
# interval and adjust
# the sign in case the angle was reduced on the [0,360) interval.
#
# @param w: fractional value for an angle on the [0,\pi) interval.
#
# @param s: \{0,1\} value. Corresponding to b2. Returns 1 if the angle
# was reduced from an angle in the [\pi/2,\pi) interval.
#
# @return returns the cos of w (sfix).
def scos(w, s):
# calculates the v of the w.
v = w
v_2 = v ** 2
# adjust sign according to the movement in the reduction
b = s * (-2) + 1
# calculate the cos using polynomial evaluation
local_cos = b * p_eval(p_3508, v_2)
return local_cos
# facade method calls --it is built in a generic way
##
# Returns the sin of any given fractional value.
#
# @param x: fractional input (sfix, sfloat).
#
# @return returns the sin of x (sfix, sfloat).
def sin(x):
# reduces the angle to the [0,\pi/2) interval.
w, b1, b2 = sTrigSub(x)
# returns the sin with sign correction
return ssin(w, b1)
##
# Returns the sin of any given fractional value.
#
# @param x: fractional input (sfix, sfloat).
#
# @return returns the sin of x (sfix, sfloat).
def cos(x):
# reduces the angle to the [0,\pi/2) interval.
w, b1, b2 = sTrigSub(x)
# returns the sin with sign correction
return scos(w, b2)
##
# Returns the tan (sfix, sfloat) of any given fractional value.
#
# @param x: fractional input (sfix, sfloat).
#
# @return returns the tan of x (sifx, sfloat).
def tan(x):
# reduces the angle to the [0,\pi/2) interval.
w, b1, b2 = sTrigSub(x)
# calculates the sin and the cos.
local_sin = ssin(w, b1)
local_cos = scos(w, b2)
# obtains the local tan
local_tan = local_sin/local_cos
return local_tan
##
# Returns the result of 2^a for any unbounded number
# @param a: exponent for 2^a
#
# @return returns the value of 2^a if it is within the range
@types.vectorize
def exp2_fx(a):
if types.program.options.ring:
sint = types.sint
intbitint = types.intbitint
# how many bits to use from integer part
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
n_bits = a.f + n_int_bits
n_shift = int(types.program.options.ring) - a.k
r_bits = [sint.get_random_bit() for i in range(a.k)]
shifted = ((a.v - sint.bit_compose(r_bits)) << n_shift).reveal()
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
lower_overflow = sint()
comparison.CarryOut(lower_overflow, masked_bits[a.f-1::-1],
r_bits[a.f-1::-1])
lower_r = sint.bit_compose(r_bits[:a.f])
lower_masked = sint.bit_compose(masked_bits[:a.f])
lower = lower_r + lower_masked - (lower_overflow << (a.f))
c = types.sfix._new(lower, k=a.k, f=a.f)
higher_bits = intbitint.bit_adder(masked_bits[a.f:n_bits],
r_bits[a.f:n_bits],
carry_in=lower_overflow,
get_carry=True)
d = types.sfix.from_sint(floatingpoint.Pow2_from_bits(higher_bits[:-1]),
k=a.k, f=a.f)
e = p_eval(p_1045, c)
g = d * e
small_result = types.sfix._new(g.v.round(a.k + 1, a.f, signed=False,
nearest=types.sfix.round_nearest),
k=a.k, f=a.f)
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
r_bits[n_bits:-1],
higher_bits[-1])
# should be for free
highest_bits = intbitint.ripple_carry_adder(
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
carry_in=higher_bits[-1])
bits_to_check = [x.bit_xor(y)
for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
t = floatingpoint.KMul(bits_to_check)
# sign
s = masked_bits[-1].bit_xor(r_bits[-1]).bit_xor(carry)
return s.if_else(t.if_else(small_result, 0), g)
else:
# obtain absolute value of a
s = a < 0
a = (s * (-2) + 1) * a
# isolates fractional part of number
b = trunc(a)
c = a - load_sint(b, type(a))
# squares integer part of a
d = load_sint(b.pow2(types.sfix.k - types.sfix.f), type(a))
# evaluates fractional part of a in p_1045
e = p_eval(p_1045, c)
g = d * e
return (1 - s) * g + s * ((types.sfix(1)) / g)
##
# Returns the result of log_2(x) for any unbounded number. This is
# achieved by changing x into f*2^n where f is bounded by[0.5, 1].
# Then the polynomials are used to calculate the log_2 of f,
# which is then just added to n.
#
# @param x: input for log_2 (sfix, sint).
#
# @return returns (sfix) the value of log2(X)
@types.vectorize
def log2_fx(x):
if type(x) is types.sfix:
# transforms sfix to f*2^n, where f is [o.5,1] bounded
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
p -= x.f
vlen = x.f
else:
d = types.sfloat(x)
v, p, vlen = d.v, d.p, d.vlen
# isolates mantisa of d, now the n can be also substituted by the
# secret shared p from d in the expresion above.
v = load_sint(v, type(x))
w = (1.0 / (2 ** (vlen)))
v = v * w
# polynomials for the log_2 evaluation of f are calculated
P = p_eval(p_2524, v)
Q = p_eval(q_2524, v)
# the log is returned by adding the result of the division plus p.
a = P / Q + load_sint(vlen + p, type(x))
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
##
# Returns the value of the expression x^y where both inputs
# are secret shared. It uses log2_fx together with
# exp2_fx to calcualte the expresion 2^{y*log2(x)}.
#
# @param x: (sfix) secret shared base.
#
# @param y: (sfix, clear types) secret shared exponent.
#
# @return returns the value of x^y
def pow_fx(x, y):
log2_x =0
# obtains log2(x)
if (type(x) == int or type(x) == float):
log2_x = math.log(x,2)
else:
log2_x = log2_fx(x)
# obtains y * log2(x)
exp = y * log2_x
# returns 2^(y*log2(x))
return exp2_fx(exp)
##
# Returns the value of the expression log_b(x) where x is
# secret shared. It uses log2_fx to calculate the expression
# logb(2)*log2(x).
#
# @param x:(sfix, sint) secret shared coefficient for log.
#
# @param b:(int) base for log operation.
#
# @return returns (sfix) the value of logb(x).
def log_fx(x, b):
# calculates logb(2)
logb_2 = math.log(2, b)
# returns logb(2) * log2(x)
return logb_2 * log2_fx(x)
##
# Returns the absolute value of a fix point number.
# The method is also applicable to sfloat,
# however, more efficient mechanisms can be devised.
#
# @param x: (sfix)
#
# @return (sfix) unsigned
def abs_fx(x):
s = x < 0
return (1 - 2 * s) * x
##
# Floors the input and stores the value into a sflix register
# @param x: coefficient to be floored.
#
# @return floored sint value of x
def floor_fx(x):
return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x))
### sqrt methods
##
# obtains the most significative bit (MSB)
# of a given input. The size of the vector
# is tuned to the needs of sqrt.
# @param b: number from which you obtain the
# most significative bit.
# @param k: number of bits for which
# an output of size (k+1) if even
# is going to be produced.
# @return z: index array for MSB of size
# k or K+1 if even.
def MSB(b, k):
# calculation of z
# x in order 0 - k
if (k > types.program.bit_length):
raise OverflowError("The supported bit \
lenght of the application is smaller than k")
x_order = b.bit_decompose(k)
x = [0] * k
# x i now inverted
for i in range(k - 1, -1, -1):
x[k - 1 - i] = x_order[i]
# y is inverted for PReOR and then restored
y_order = floatingpoint.PreOR(x)
# y in order (restored in orginal order
y = [0] * k
for i in range(k - 1, -1, -1):
y[k - 1 - i] = y_order[i]
# obtain z
z = [0] * (k + 1 - k % 2)
for i in range(k - 1):
z[i] = y[i] - y[i + 1]
z[k - 1] = y[k - 1]
return z
##
# Similar to norm_SQ, saves rounds by not
# calculating v and c.
#
# @param b: sint input to be normalized.
# @param k: bitsize of the input, by definition
# its value is either sfix.k or program.bit_lengthh
# @return m_odd: the parity of most signficative bit index m
# @return m: index of most significative bit
# @return w: 2^m/2 or 2^ (m-1) /2
def norm_simplified_SQ(b, k):
z = MSB(b, k)
# construct m
#m = types.sint(0)
m_odd = 0
for i in range(k):
#m = m + (i + 1) * z[i]
# determine the parity of the input
if (i % 2 == 0):
m_odd = m_odd + z[i]
# construct w,
k_over_2 = k / 2 + 1
w_array = [0] * (k_over_2)
w_array[0] = z[0]
for i in range(1, k_over_2):
w_array[i] = z[2 * i - 1] + z[2 * i]
# w aggregation
w = types.sint(0)
for i in range(k_over_2):
w += (2 ** i) * w_array[i]
# return computed values
#return m_odd, m, w
return m_odd, None, w
##
# Obtains the sqrt using our custom mechanism
# for any sfix input value.
# no restrictions on the size of f.
#
# @param x: secret shared input from which the sqrt
# is calucalted,
#
# @return g: approximated sqrt
def sqrt_simplified_fx(x):
# fix theta (number of iterations)
theta = max(int(math.ceil(math.log(types.sfix.k))), 6)
# process to use 2^(m/2) approximation
m_odd, m, w = norm_simplified_SQ(x.v, x.k)
# process to set up the precision and allocate correct 2**f
if x.f % 2 == 1:
m_odd = (1 - 2 * m_odd) + m_odd
w = (w * 2 - w) * (1-m_odd) + w
# map number to use sfix format and instantiate the number
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2))
# obtains correct 2 ** (m/2)
w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w
# produce x/ 2^(m/2)
y_0 = types.cfix(1.0) / w
# from this point on it sufices to work sfix-wise
g_0 = (y_0 * x)
h_0 = y_0 * types.cfix(0.5)
gh_0 = g_0 * h_0
## initialization
g = g_0
h = h_0
gh = gh_0
for i in range(1, theta - 2):
r = (3 / 2.0) - gh
g = g * r
h = h * r
gh = g * h
# newton
r = (3 / 2.0) - gh
h = h * r
H = 4 * (h * h)
H = H * x
H = (3) - H
H = h * H
g = H * x
g = g
return g
##
# Calculates the normSQ of a number
# @param x: number from which the norm is going to be extracted
# @param k: bitsize of x
#
# @return c: where c = x*v where c is bounded by 2^{k-1} and 2^k
# @return v: where v = 2^k-m
# @return m: where m = MSB
# @return w: where w = 2^{m/2} if m is oeven and 2^{m-1 / 2} otherwise
def norm_SQ(b, k):
# calculation of z
# x in order 0 - k
z = MSB(b,k)
# now reverse bits of z[i] to generate v
v = types.sint(0)
for i in range(k):
v += (2**(k - i - 1)) * z[i]
c = b * v
# construct m
m = types.sint(0)
for i in range(k):
m = m + (i+1) * z[i]
# construct w, changes from what is on the paper
# and the documentation
k_over_2= k/2+1#int(math.ceil((k/2.0)))+1
w_array = [0]*(k_over_2 )
w_array[0] = z[0]
for i in range(1, k_over_2):
w_array[i] = z[2 * i - 1] + z[2 * i]
w = types.sint(0)
for i in range(k_over_2):
w += (2 ** i) * w_array[i]
# return computed values
return c, v, m, w
##
# Given f and k, returns a linear approximation of 1/x^{1/2}
# escalated by s^f.
# Method only works for sfix inputs. It uses the normSQ.
# the method is an implementation of [Liedel2012]
# @param x: number from which the approximation is caluclated
# @param k: bitsize of x
# @param f: precision of the input f
#
# @return c: Some approximation of (1/x^{1/2} * 2^f) *K
# where K is close to 1
def lin_app_SQ(b, k, f):
alpha = types.cfix((-0.8099868542) * 2 ** (k))
beta = types.cfix(1.787727479 * 2 ** (2 * k))
# obtain normSQ parameters
c, v, m, W = norm_SQ(types.sint(b), k)
# c is now escalated
w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k
# m even or odd determination
m_bit = types.sint()
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False)
m = load_sint(m_bit, types.sfix)
# w times v this way both terms have 2^3k and can be symplified
w = w * v
factor = 1.0 / (2 ** (3.0 * k - 2 * f))
w = w * factor # w escalated to 3k -2 * f
# normalization factor W* 1/2 ^{f/2}
w = w * W * types.cfix(1.0 / (2 ** (f / 2.0)))
# now we need to elminate an additional root of 2 in case m was odd
sqr_2 = types.cfix((2 ** (1 / 2.0)))
w = (1 - m) * w + sqr_2 * w * m
return w
##
# Given bitsize k and precision f, it calulates the square root of x.
# @param x: number from which the norm is going to be extracted
# @param k: bitsize of x.
# @param f: precision of x.
#
# @return g: square root of de-scaled input x
def sqrt_fx(x_l, k, f):
factor = 1.0 / (2.0 ** f)
x = load_sint(x_l, types.sfix) * factor
theta = int(math.ceil(math.log(k/5.4)))
y_0 = lin_app_SQ(x_l,k,f) #cfix(1.0/ (cx ** (1/2.0))) # lin_app_SQ(x_l,5,2)
y_0 = y_0 * factor #*((1.0/(2.0 ** f)))
g_0 = y_0 * x
#g = mpc_math.load_sint(mpc_math.trunc(g_0),types.sfix)
h_0 = y_0 *(0.5)
gh_0 = g_0 * h_0
##initialization
g= g_0
h= h_0
gh =gh_0
for i in range(1,theta-2): #to implement \in [1,\theta-2]
r = (3/2.0) - gh
g = g * r
h = h * r
gh = g * h
# newton
r = (3/2.0) - gh
h = h * r
H = 4 * (h * h)
H = H * x
H = (3) - H
H = h * H
g = H * x
g = g #* (0.5)
return g
##
# Returns the sqrt (sfix) of any given fractional
# value as long as it can be rounded to a integral value
# to 2^f precision.
#
# Note that sqrt only works as long as this inequality is respected:
# 3*k - 2 *f < x.f (x.f by default is 20)
# @param x: fractional input (sfix).
#
# @return returns the aTan of x (sifx).
@types.vectorize
def sqrt(x, k = types.sfix.k, f = types.sfix.f):
if (3 *k -2 * f >= types.sfix.f):
return sqrt_simplified_fx(x)
# raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ")
else:
param = trunc(x *(2 ** (f)))
return sqrt_fx(param ,k ,f)
##
# Returns the aTan (sfix) of any given fractional value.
#
# @param x: fractional input (sfix).
#
# @return returns the aTan of x (sifx).
def atan(x):
# obtain absolute value of x
s = x < 0
x_abs = (s * (-2) + 1) * x
# angle isolation
b = x_abs > 1
v = (types.cfix(1.0) / x_abs)
v = (1 - b) * (x_abs - v) + v
v_2 =v*v
# range of polynomial coefficients
assert x.k - x.f >= 18
P = p_eval(p_5102, v_2)
Q = p_eval(q_5102, v_2)
# padding
y = v * (P / Q)
y_pi_over_two = pi_over_2 - y
# sign correction
y = (1 - b) * (y - y_pi_over_two) + y_pi_over_two
y = (1 - s) * (y - (-y)) + (-y)
return y
##
# Returns the aSin (sfix) of any given fractional value.
#
# @param x: fractional input (sfix). valid interval is -1.0 <= x <= 1
#
# @return returns the aSin of x (sfix).
def asin(x):
# Square x
x_2 = x*x
# trignometric identities
sqrt_l = sqrt(1- (x_2))
x_sqrt_l =x / sqrt_l
return atan(x_sqrt_l)
##
# Returns the aCos (sfix) of any given fractional value.
#
# @param x: fractional input (sfix). -1.0 < x < 1
#
# @return returns the aCos of x (sifx).
def acos(x):
y = asin(x)
return pi_over_2 - y

View File

@@ -44,8 +44,10 @@ class Program(object):
if (param != -1) + sum(x != 0 for x in(options.ring, options.field,
options.binary)) > 1:
raise CompilerError('can only use one out of -p, -B, -R, -F')
self.bit_length = int(options.ring) or int(options.binary) \
or int(options.field)
if options.ring:
self.bit_length = int(options.ring) - 1
else:
self.bit_length = int(options.binary) or int(options.field)
if not self.bit_length:
self.bit_length = BIT_LENGTHS[param]
print 'Default bit length:', self.bit_length
@@ -71,6 +73,7 @@ class Program(object):
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
self.types = {}
self.budget = int(self.options.budget)
self.verbose = False
self.to_merge = [Compiler.instructions.asm_open_class, \
Compiler.instructions.gasm_open_class, \
Compiler.instructions.muls_class, \
@@ -82,10 +85,13 @@ class Program(object):
Compiler.instructions.asm_input_class, \
Compiler.instructions.gasm_input_class,
Compiler.instructions.inputfix_class,
Compiler.instructions.inputfloat_class]
Compiler.instructions.inputfloat_class,
Compiler.instructions.inputmixed_class,
Compiler.instructions.trunc_pr_class]
import Compiler.GC.instructions as gc
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]
self.use_trunc_pr = False
Program.prog = self
self.reset_values()
@@ -452,6 +458,8 @@ class Tape:
else:
self.alloc_pool = defaultdict(set)
self.purged = False
self.n_rounds = 0
self.n_to_merge = 0
def __len__(self):
return len(self.instructions)
@@ -506,6 +514,8 @@ class Tape:
instructions = self.instructions
for inst in instructions:
inst.add_usage(req_node)
req_node.num['all', 'round'] = self.n_rounds
req_node.num['all', 'inv'] = self.n_to_merge
def __str__(self):
return self.name
@@ -530,7 +540,7 @@ class Tape:
self.basicblocks.append(sub)
self.active_basicblock = sub
self.req_node.add_block(sub)
print 'Compiling basic block', sub.name
#print 'Compiling basic block', sub.name
def init_registers(self):
self.reset_registers()
@@ -601,6 +611,8 @@ class Tape:
if len(block.instructions) > 10000:
print 'Merging instructions...'
numrounds = merger.longest_paths_merge()
block.n_rounds = numrounds
block.n_to_merge = len(merger.open_nodes)
if numrounds > 0:
print 'Program requires %d rounds of communication' % numrounds
if merger.counter:
@@ -633,10 +645,11 @@ class Tape:
# allocate registers
reg_counts = self.count_regs()
if not options.noreallocate:
print 'Tape register usage:', dict(reg_counts)
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
print 'Re-allocating...'
if self.program.verbose:
print 'Tape register usage:', dict(reg_counts)
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
print 'Re-allocating...'
allocator = al.StraightlineAllocator(REG_MAX)
def alloc_loop(block):
for reg in sorted(block.used_from_scope,
@@ -661,7 +674,7 @@ class Tape:
self.req_num = self.req_tree.aggregate()
print 'Tape requires', self.req_num
for req,num in sorted(self.req_num.items()):
if num == float('inf'):
if num == float('inf') or num >= 2 ** 32:
num = -1
if req[1] in data_types:
self.basicblocks[-1].instructions.append(
@@ -692,8 +705,9 @@ class Tape:
self.basicblocks[-1].instructions.append(
Compiler.instructions.reqbl(bl,
add_to_prog=False))
print 'Tape requires prime bit length', self.req_bit_length['p']
print 'Tape requires galois bit length', self.req_bit_length['2']
if self.program.verbose:
print 'Tape requires prime bit length', self.req_bit_length['p']
print 'Tape requires galois bit length', self.req_bit_length['2']
@unpurged
def _get_instructions(self):
@@ -783,6 +797,8 @@ class Tape:
return res
__rmul__ = __mul__
def set_all(self, value):
if value == float('inf') and self['all', 'inv'] > 0:
print 'Going to unknown from %s' % self
res = Tape.ReqNum()
for i in self:
res[i] = value
@@ -832,6 +848,15 @@ class Tape:
self.parent = parent
def aggregate(self, name):
res = self.aggregator([node.aggregate() for node in self.nodes])
try:
n_reps = self.aggregator([1])
n_rounds = res['all', 'round']
n_invs = res['all', 'inv']
if (n_invs / n_rounds) * 1000 < n_reps:
print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \
'(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)
except:
pass
return res
def add_node(self, tape, name):
new_node = Tape.ReqNode(name)

View File

@@ -167,6 +167,12 @@ class _number(object):
def pow2(self, bit_length=None, security=None):
return 2**self
def min(self, other):
return (self < other).if_else(self, other)
def max(self, other):
return (self < other).if_else(other, self)
class _int(object):
def if_else(self, a, b):
if hasattr(a, 'for_mux'):
@@ -705,6 +711,10 @@ class regint(_register, _int):
popint(res)
return res
@vectorized_classmethod
def push(cls, value):
pushint(cls.conv(value))
@vectorized_classmethod
def get_random(cls, bit_length):
""" Public insecure randomness """
@@ -781,10 +791,10 @@ class regint(_register, _int):
@vectorize
@read_mem_value
def int_op(self, other, inst, reverse=False):
if isinstance(other, _secret):
try:
other = self.conv(other)
except:
return NotImplemented
elif not isinstance(other, type(self)):
other = type(self)(other)
res = regint()
if reverse:
inst(res, other, self)
@@ -898,6 +908,8 @@ class regint(_register, _int):
def print_reg_plain(self):
print_int(self)
def print_if(self, string):
cint(self).print_if(string)
class _secret(_register):
__slots__ = []
@@ -1121,6 +1133,13 @@ class sint(_secret, _int):
comparison.PRandInt(res, bits)
return res
@vectorized_classmethod
def get_input_from(cls, player):
""" Secret input """
res = cls()
inputmixed('int', res, player)
return res
@classmethod
def get_raw_input_from(cls, player):
res = cls()
@@ -1196,8 +1215,8 @@ class sint(_secret, _int):
@vectorize
def __lt__(self, other, bit_length=None, security=None):
res = sint()
comparison.LTZ(res, self - other, bit_length or program.bit_length +
(not (int(program.options.ring) == program.bit_length)),
comparison.LTZ(res, self - other,
(bit_length or program.bit_length) + 1,
security or program.security)
return res
@@ -1205,8 +1224,8 @@ class sint(_secret, _int):
@vectorize
def __gt__(self, other, bit_length=None, security=None):
res = sint()
comparison.LTZ(res, other - self, bit_length or program.bit_length +
(not (int(program.options.ring) == program.bit_length)),
comparison.LTZ(res, other - self,
(bit_length or program.bit_length) + 1,
security or program.security)
return res
@@ -1304,13 +1323,14 @@ class sint(_secret, _int):
return floatingpoint.BitDec(self, bit_length, bit_length, security)
def TruncMul(self, other, k, m, kappa=None, nearest=False):
return (self * other).round(k, m, kappa, nearest)
return (self * other).round(k, m, kappa, nearest, signed=True)
def TruncPr(self, k, m, kappa=None):
return floatingpoint.TruncPr(self, k, m, kappa)
def TruncPr(self, k, m, kappa=None, signed=True):
return floatingpoint.TruncPr(self, k, m, kappa, signed=signed)
@vectorize
def round(self, k, m, kappa=None, nearest=False, signed=False):
kappa = kappa or program.security
secret = isinstance(m, sint)
if nearest:
if secret:
@@ -1320,7 +1340,7 @@ class sint(_secret, _int):
else:
if secret:
return floatingpoint.Trunc(self, k, m, kappa)
return self.TruncPr(k, m, kappa)
return self.TruncPr(k, m, kappa, signed=signed)
def Norm(self, k, f, kappa=None, simplex_flag=False):
return library.Norm(self, k, f, kappa, simplex_flag)
@@ -1461,23 +1481,25 @@ class _bitint(object):
linear_rounds = False
@classmethod
def bit_adder(cls, a, b):
def bit_adder(cls, a, b, carry_in=0, get_carry=False):
a, b = list(a), list(b)
a += [0] * (len(b) - len(a))
b += [0] * (len(a) - len(b))
return cls.bit_adder_selection(a, b)
return cls.bit_adder_selection(a, b, carry_in=carry_in,
get_carry=get_carry)
@classmethod
def bit_adder_selection(cls, a, b):
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
if cls.log_rounds:
return cls.carry_lookahead_adder(a, b)
return cls.carry_lookahead_adder(a, b, carry_in=carry_in)
elif cls.linear_rounds:
return cls.ripple_carry_adder(a, b)
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
else:
return cls.carry_select_adder(a, b)
return cls.carry_select_adder(a, b, carry_in=carry_in)
@classmethod
def carry_lookahead_adder(cls, a, b, fewer_inv=False):
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
get_carry=False):
lower = []
for (ai,bi) in zip(a,b):
if ai is 0 or bi is 0:
@@ -1493,10 +1515,13 @@ class _bitint(object):
else:
pre_op = floatingpoint.PreOpL
if d:
carries = (0,) + zip(*pre_op(carry, d))[1]
carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1]
else:
carries = []
return lower + cls.sum_from_carries(a, b, carries)
res = lower + cls.sum_from_carries(a, b, carries)
if get_carry:
res += [carries[-1]]
return res
@staticmethod
def sum_from_carries(a, b, carries):
@@ -1504,7 +1529,7 @@ class _bitint(object):
for (ai, bi, carry) in zip(a, b, carries)]
@classmethod
def carry_select_adder(cls, a, b, get_carry=False):
def carry_select_adder(cls, a, b, get_carry=False, carry_in=0):
a += [0] * (len(b) - len(a))
b += [0] * (len(a) - len(b))
n = len(a)
@@ -1524,7 +1549,7 @@ class _bitint(object):
raise Exception('blocks not summing up: %s != %s' % \
(sum(blocks), n))
res = []
carry = 0
carry = carry_in
cin_one = util.long_one(a + b)
for m in blocks:
aa = a[:m]
@@ -1540,7 +1565,8 @@ class _bitint(object):
return res
@classmethod
def ripple_carry_adder(cls, a, b, carry=0):
def ripple_carry_adder(cls, a, b, carry_in=0):
carry = carry_in
res = []
for aa, bb in zip(a, b):
cc, carry = cls.full_adder(aa, bb, carry)
@@ -1760,14 +1786,15 @@ class intbitint(_bitint, sint):
for i in range(len(a))]
@classmethod
def bit_adder_selection(cls, a, b):
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
if cls.linear_rounds:
return cls.ripple_carry_adder(a, b)
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
# experimental cut-off with dead code elimination
elif len(a) < 122 or cls.log_rounds:
return cls.carry_lookahead_adder(a, b)
return cls.carry_lookahead_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
else:
return cls.carry_select_adder(a, b)
return cls.carry_select_adder(a, b, carry_in=carry_in)
class sgf2nint(_bitint, sgf2n):
bin_type = sgf2n
@@ -1904,10 +1931,10 @@ class sgf2nfloat(sgf2n):
sgf2nfloat.set_precision(24, 8)
def parse_type(other):
def parse_type(other, k=None, f=None):
# converts type to cfix/sfix depending on the case
if isinstance(other, cfix.scalars):
return cfix(other)
return cfix(other, k=k, f=f)
elif isinstance(other, cint):
tmp = cfix()
tmp.load_int(other)
@@ -1975,9 +2002,11 @@ class cfix(_number, _structure):
return 1
@vectorize_init
def __init__(self, v=None, size=None):
f = self.f
k = self.k
def __init__(self, v=None, k=None, f=None, size=None):
f = f or self.f
k = k or self.k
self.f = f
self.k = k
self.size = get_global_vector_size()
if isinstance(v, cint):
self.v = cint(v,size=self.size)
@@ -2025,22 +2054,20 @@ class cfix(_number, _structure):
other = parse_type(other)
if isinstance(other, cfix):
return cfix(self.v + other.v)
elif isinstance(other, sfix):
return sfix(self.v + other.v)
else:
raise CompilerError('Invalid type %s for cfix.__add__' % type(other))
return NotImplemented
@vectorize
def mul(self, other):
other = parse_type(other)
if isinstance(other, cfix):
assert self.f == other.f
sgn = cint(1 - 2 * (self.v * other.v < 0))
absolute = self.v * other.v * sgn
val = sgn * (absolute >> self.f)
return cfix(val)
elif isinstance(other, sfix):
res = sfix((self.v * other.v) >> self.f)
return res
return NotImplemented
else:
raise CompilerError('Invalid type %s for cfix.__mul__' % type(other))
@@ -2130,7 +2157,8 @@ class cfix(_number, _structure):
if isinstance(other, cfix):
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
elif isinstance(other, sfix):
return sfix(library.FPDiv(self.v, other.v, self.k, self.f, other.kappa))
return sfix(library.FPDiv(self.v, other.v, self.k, self.f,
other.kappa, nearest=sfix.round_nearest))
else:
raise TypeError('Incompatible fixed point types in division')
@@ -2169,6 +2197,7 @@ class _single(_number, _structure):
return cls._new(cls.int_type.load_mem(address))
@classmethod
@read_mem_value
def conv(cls, other):
if isinstance(other, cls):
return other
@@ -2193,9 +2222,13 @@ class _single(_number, _structure):
@classmethod
def dot_product(cls, x, y, res_params=None):
return cls.unreduced_dot_product(x, y, res_params).reduce_after_mul()
@classmethod
def unreduced_dot_product(cls, x, y, res_params=None):
dp = cls.int_type.dot_product([xx.pre_mul() for xx in x],
[yy.pre_mul() for yy in y])
return x[0].unreduced(dp, y[0], res_params, len(x)).reduce_after_mul()
return x[0].unreduced(dp, y[0], res_params, len(x))
@classmethod
def row_matrix_mul(cls, row, matrix, res_params=None):
@@ -2300,20 +2333,25 @@ class _fix(_single):
@classmethod
def coerce(cls, other):
if isinstance(other, _fix):
if isinstance(other, (_fix, cfix)):
return other
else:
return cls.conv(other)
@classmethod
def from_sint(cls, other):
def from_sint(cls, other, k=None, f=None):
res = cls()
res.f = f or cls.f
res.k = k or cls.k
res.load_int(cls.int_type.conv(other))
return res
@classmethod
def _new(cls, other):
return cls(other)
def _new(cls, other, k=None, f=None):
res = cls(other)
res.k = k or cls.k
res.f = f or cls.f
return res
@vectorize_init
def __init__(self, _v=None, size=None):
@@ -2331,7 +2369,7 @@ class _fix(_single):
self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size)
elif isinstance(_v, self.float_type):
p = (f + _v.p)
b = (p >= 0)
b = (p.greater_equal(0, _v.vlen))
a = b*(_v.v << (p)) + (1-b)*(_v.v >> (-p))
self.v = (1-2*_v.s)*a
elif isinstance(_v, type(self)):
@@ -2355,31 +2393,45 @@ class _fix(_single):
def add(self, other):
other = self.coerce(other)
if isinstance(other, (_fix, cfix)):
return type(self)(self.v + other.v)
return self._new(self.v + other.v, k=self.k, f=self.f)
elif isinstance(other, cfix.scalars):
tmp = cfix(other)
tmp = cfix(other, k=self.k, f=self.f)
return self + tmp
else:
raise CompilerError('Invalid type %s for _fix.__add__' % type(other))
return NotImplemented
@vectorize
def mul(self, other):
if isinstance(other, (sint, cint, regint, int, long)):
return self._new(self.v * other, k=self.k, f=self.f)
elif isinstance(other, float):
if int(other) == other:
return self.mul(int(other))
v = int(round(other * 2 ** self.f))
if v == 0:
return 0
f = self.f
while v % 2 == 0:
f -= 1
v /= 2
k = len(bin(abs(v))) - 1
other = cfix(cint(v))
other.f = f
other.k = k
other = self.coerce(other)
if isinstance(other, _fix):
val = self.v.TruncMul(other.v, self.k * 2, self.f, self.kappa,
if isinstance(other, (_fix, cfix)):
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
self.kappa,
self.round_nearest)
if self.size >= other.size:
return self._new(val)
return self._new(val, k=self.k, f=self.f)
else:
return self.vec._new(val)
elif isinstance(other, cfix):
res = type(self)((self.v * other.v) >> self.f)
return res
return self.vec._new(val, k=self.k, f=self.f)
elif isinstance(other, cfix.scalars):
scalar_fix = cfix(other)
return self * scalar_fix
else:
raise CompilerError('Invalid type %s for _fix.__mul__' % type(other))
return NotImplemented
@vectorize
def __neg__(self):
@@ -2397,6 +2449,7 @@ class _fix(_single):
else:
raise TypeError('Incompatible fixed point types in division')
@vectorize
def __rdiv__(self, other):
return self.coerce(other) / self
@@ -2418,14 +2471,24 @@ class sfix(_fix):
@vectorized_classmethod
def get_input_from(cls, player):
v = cls.int_type()
inputfix(v, cls.f, player)
inputmixed('fix', v, cls.f, player)
return cls._new(v)
@classmethod
def coerce(cls, other):
return parse_type(other)
@vectorized_classmethod
def get_random(cls, lower, upper):
""" Uniform random number around centre of bounds """
""" Range can be smaller """
log_range = int(math.log(upper - lower, 2))
n_bits = log_range + cls.f
average = lower + 0.5 * (upper - lower)
lower = average - 0.5 * 2 ** log_range
return cls._new(cls.int_type.get_random_int(n_bits)) + lower
def coerce(self, other):
return parse_type(other, k=self.k, f=self.f)
def mul_no_reduce(self, other, res_params=None):
assert self.f == other.f
return self.unreduced(self.v * other.v)
def pre_mul(self):
@@ -2434,16 +2497,21 @@ class sfix(_fix):
def unreduced(self, v, other=None, res_params=None, n_summands=1):
return unreduced_sfix(v, self.k * 2, self.f, self.kappa)
class unreduced_sfix(object):
class unreduced_sfix(_single):
int_type = sint
@classmethod
def _new(cls, v):
return cls(v, 2 * sfix.k, sfix.f, sfix.kappa)
def __init__(self, v, k, m, kappa):
self.v = v
self.k = k
self.m = m
self.kappa = kappa
self.size = self.v.size
def __add__(self, other):
if other in (0, 0L):
if other is 0 or other is 0L:
return self
assert self.k == other.k
assert self.m == other.m
@@ -2455,7 +2523,10 @@ class unreduced_sfix(object):
@vectorize
def reduce_after_mul(self):
return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa,
nearest=sfix.round_nearest))
nearest=sfix.round_nearest,
signed=True))
sfix.unreduced_type = unreduced_sfix
# this is for 20 bit decimal precision
# with 40 bitlength of entire number
@@ -2503,7 +2574,7 @@ class squant(_single):
raise CompilerError('%f not quantizable' % value)
self.v = self.int_type(q)
reset_global_vector_size()
elif isinstance(value, type(self)):
elif isinstance(value, squant) and value.params == self.params:
self.v = value.v
else:
raise CompilerError('cannot convert %s to squant' % value)
@@ -2538,7 +2609,7 @@ class squant(_single):
return self.mul_no_reduce(other, res_params).reduce_after_mul()
def mul_no_reduce(self, other, res_params=None):
if isinstance(other, sint):
if isinstance(other, (sint, cint, regint)):
return self._new(other * (self.v - self.Z) + self.Z,
params=self.get_params())
other = self.coerce(other)
@@ -2572,7 +2643,7 @@ class _unreduced_squant(object):
self.res_params = res_params or params[0]
def __add__(self, other):
if other in (0, 0L):
if other is 0 or other is 0L:
return self
assert self.params == other.params
assert self.res_params == other.res_params
@@ -2650,7 +2721,8 @@ class squant_params(object):
int_mult = util.expand(int_mult, size)
tmp = unreduced.v * int_mult + shifted_Z
shifted = tmp.round(self.max_length, n_shift,
squant.kappa, squant.round_nearest)
kappa=squant.kappa, nearest=squant.round_nearest,
signed=True)
if squant.clamp:
length = max(self.k, self.max_length - n_shift) + 1
top = (1 << self.k) - 1
@@ -2715,6 +2787,10 @@ class sfloat(_number, _structure):
else:
return cls(other)
@classmethod
def coerce(cls, other):
return cls.conv(other)
@staticmethod
def convert_float(v, vlen, plen):
if v < 0:
@@ -2747,7 +2823,7 @@ class sfloat(_number, _structure):
p = sint()
z = sint()
s = sint()
inputfloat(v, p, z, s, cls.vlen, player)
inputmixed('float', v, p, z, s, cls.vlen, player)
return cls(v, p, z, s)
@vectorize_init
@@ -2935,7 +3011,7 @@ class sfloat(_number, _structure):
return self + -other
def __rsub__(self, other):
raise NotImplementedError()
return -self + other
def __div__(self, other):
other = self.conv(other)
@@ -3144,17 +3220,18 @@ class Array(object):
for i in range(self.length):
yield self[i]
def assign(self, other):
if isinstance(other, Array):
def loop(i):
self[i] = other[i]
library.range_loop(loop, len(self))
elif isinstance(other, Tape.Register):
if len(other) == self.length:
self[0] = other
else:
raise CompilerError('Length mismatch between array and vector')
else:
def same_shape(self):
return Array(self.length, self.value_type)
def assign(self, other, base=0):
try:
other = other.get_vector()
except:
pass
try:
other.store_in_mem(self.get_address(base))
assert len(self) >= other.size + base
except AttributeError:
for i,j in enumerate(other):
self[i] = j
return self
@@ -3169,22 +3246,42 @@ class Array(object):
self[i] = mem_value
return self
def get_vector(self):
return self.value_type.load_mem(self.address, size=self.length)
def get_vector(self, base=0, size=None):
size = size or self.length
return self.value_type.load_mem(self.get_address(base), size=size)
def get_mem_value(self, index):
return MemValue(self[index], self.get_address(index))
def input_from(self, player, budget=None):
self.assign(self.value_type.get_input_from(player, size=len(self)))
def __add__(self, other):
if other is 0:
return self
assert len(self) == len(other)
return Array.create_from(x + y for x, y in zip(self, other))
return self.get_vector() + other
def __sub__(self, other):
assert len(self) == len(other)
return Array.create_from(x - y for x, y in zip(self, other))
return self.get_vector() - other
def __mul__(self, value):
return Array.create_from(x * value for x in self)
return self.get_vector() * value
def __pow__(self, value):
return self.get_vector() ** value
__radd__ = __add__
__rmul__ = __mul__
def shuffle(self):
@library.for_range(len(self))
def _(i):
j = regint.get_random(64) % (len(self) - i)
tmp = self[i]
self[i] = self[i + j]
self[i + j] = tmp
def reveal(self):
return Array.create_from(x.reveal() for x in self)
@@ -3223,6 +3320,9 @@ class SubMultiArray(object):
self.address, index, debug=self.debug)
return self.sub_cache[key]
def __setitem__(self, index, other):
self[index].assign(other)
def __len__(self):
return self.sizes[0]
@@ -3235,35 +3335,60 @@ class SubMultiArray(object):
def total_size(self):
return reduce(operator.mul, self.sizes) * self.value_type.n_elements()
def get_vector(self):
return self.value_type.load_mem(self.address, size=self.total_size())
def get_vector(self, base=0, size=None):
assert self.value_type.n_elements() == 1
size = size or self.total_size()
return self.value_type.load_mem(self.address + base, size=size)
def assign_vector(self, vector):
assert vector.size == self.total_size()
vector.store_in_mem(self.address)
def assign_vector(self, vector, base=0):
assert self.value_type.n_elements() == 1
assert vector.size <= self.total_size()
vector.store_in_mem(self.address + base)
class MultiArray(SubMultiArray):
def __init__(self, sizes, value_type, debug=None, address=None):
self.array = Array(reduce(operator.mul, sizes), \
value_type, address=address)
SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \
debug=debug)
if len(sizes) < 2:
raise CompilerError('Use Array')
def assign(self, other):
if self.value_type.n_elements() > 1:
assert self.sizes == other.sizes
self.assign_vector(other.get_vector())
class Matrix(MultiArray):
def __init__(self, rows, columns, value_type, debug=None, address=None):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address)
def same_shape(self):
return MultiArray(self.sizes, self.value_type)
def __setitem__(self, index, other):
assert other.size == self.sizes[1]
other.store_in_mem(self[index].address)
def input_from(self, player, budget=None):
@library.for_range_opt(self.sizes[0], budget=budget)
def _(i):
self[i].input_from(player, budget=budget)
def schur(self, other):
assert self.sizes == other.sizes
if len(self.sizes) == 2:
res = Matrix(self.sizes[0], self.sizes[1], self.value_type)
else:
res = MultiArray(self.sizes, self.value_type)
res.assign_vector(self.get_vector() * other.get_vector())
return res
def __add__(self, other):
if other is 0:
return self
assert self.sizes == other.sizes
if len(self.sizes) == 2:
res = Matrix(self.sizes[0], self.sizes[1], self.value_type)
else:
res = MultiArray(self.sizes, self.value_type)
res.assign_vector(self.get_vector() + other.get_vector())
return res
__radd__ = __add__
def iadd(self, other):
assert self.sizes == other.sizes
self.assign_vector(self.get_vector() + other.get_vector())
def __mul__(self, other):
return self.mul(other)
def mul(self, other, res_params=None):
assert len(self.sizes) == 2
if isinstance(other, Array):
assert len(other) == self.sizes[1]
if self.value_type.n_elements() == 1:
@@ -3277,7 +3402,8 @@ class Matrix(MultiArray):
matrix[i][0] = x
res = self * matrix
return Array.create_from(x[0] for x in res)
elif isinstance(other, Matrix):
elif isinstance(other, SubMultiArray):
assert len(other.sizes) == 2
assert other.sizes[0] == self.sizes[1]
if res_params is not None:
class t(self.value_type):
@@ -3287,14 +3413,16 @@ class Matrix(MultiArray):
t = self.value_type
res_matrix = Matrix(self.sizes[0], other.sizes[1], t)
try:
if max(res_matrix.sizes) > 1000:
raise AttributeError()
A = self.get_vector()
B = other.get_vector()
res_matrix.assign_vector(
self.value_type.matrix_mul(A, B, self.sizes[1],
res_params))
except AttributeError:
except (AttributeError, AssertionError):
# fallback for sfloat etc.
@library.for_range(self.sizes[0])
@library.for_range_opt(self.sizes[0])
def _(i):
try:
res_matrix[i] = self.value_type.row_matrix_mul(
@@ -3311,6 +3439,78 @@ class Matrix(MultiArray):
else:
raise NotImplementedError
def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True,
res=None):
assert len(self.sizes) == 2
assert len(other.sizes) == 2
if res is None:
if reduce:
res_matrix = Matrix(n_rows, n_columns, self.value_type)
else:
res_matrix = Matrix(n_rows, n_columns, \
self.value_type.unreduced_type)
else:
res_matrix = res
@library.for_range_opt(n_rows)
def _(i):
@library.for_range_opt(n_columns)
def _(j):
col = column(other, j)
r = row(self, i)
if reduce:
res_matrix[i][j] = self.value_type.dot_product(r, col)
else:
entry = self.value_type.unreduced_dot_product(r, col)
res_matrix[i][j] = entry
return res_matrix
def plain_mul(self, other, res=None):
assert other.sizes[0] == self.sizes[1]
return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \
other.sizes[1], \
lambda x, j: [x[k][j] for k in range(len(x))],
res=res)
def mul_trans(self, other):
assert other.sizes[1] == self.sizes[1]
return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \
other.sizes[0], lambda x, j: x[j])
def trans_mul(self, other, reduce=True, res=None):
assert other.sizes[0] == self.sizes[0]
return self.budget_mul(other, self.sizes[1], \
lambda x, j: [x[k][j] for k in range(len(x))], \
other.sizes[1], \
lambda x, j: [x[k][j] for k in range(len(x))],
reduce=reduce, res=res)
def transpose(self):
assert len(self.sizes) == 2
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
@library.for_range_opt(self.sizes[1])
def _(i):
@library.for_range_opt(self.sizes[0])
def _(j):
res[i][j] = self[j][i]
return res
class MultiArray(SubMultiArray):
def __init__(self, sizes, value_type, debug=None, address=None):
if isinstance(address, Array):
self.array = address
else:
self.array = Array(reduce(operator.mul, sizes), \
value_type, address=address)
SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \
debug=debug)
if len(sizes) < 2:
raise CompilerError('Use Array')
class Matrix(MultiArray):
def __init__(self, rows, columns, value_type, debug=None, address=None):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address)
class VectorArray(object):
def __init__(self, length, value_type, vector_size, address=None):
self.array = Array(length * vector_size, value_type, address)

View File

@@ -1,7 +1,7 @@
This directory contains the code used for the benchmarks by [Dalskov
et al.](https://eprint.iacr.org/2019/889) `*-ecdsa-party.cpp`
contains the high-level programs while the two phases are implemented
`preprocessing.hpp` and `sign.hpp`, respectively.
in `preprocessing.hpp` and `sign.hpp`, respectively.
#### Compilation

View File

@@ -104,7 +104,7 @@ class invalid_program: public exception
class file_error: public exception
{ string filename, ans;
public:
file_error(string m="") : filename(m)
file_error(string m) : filename(m)
{
ans="File Error : ";
ans+=filename;

View File

@@ -36,7 +36,8 @@ bigint SemiHomomorphicNoiseBounds::min_p0()
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q)
{
return 33.1 * (log_q - log2(3.2));
// the constant was updated using Martin Albrecht's LWE estimator in Sep 2019
return 37.8 * (log_q - log2(3.2));
}

View File

@@ -26,7 +26,7 @@ inline FileSacriFactory<T>::FileSacriFactory(const char* type, const Player& P,
if (output_thread)
file1 << "-" << output_thread;
this->inpf.open(file1.str().c_str(),ios::in | ios::binary);
if (this->inpf.fail()) { throw file_error(); }
if (this->inpf.fail()) { throw file_error(file1.str()); }
}
template<class T>
@@ -221,12 +221,12 @@ void Triple_Checking(const Player& P,MAC_Check<gf2n_short>& MC,int nm)
/* Open file for reading in the initial triples */
stringstream file1; file1 << PREP_DIR "Initial-Triples-" << file_completion(dummy) << "-P" << P.my_num();
ifstream inpf(file1.str().c_str(),ios::in | ios::binary);
if (inpf.fail()) { throw file_error(); }
if (inpf.fail()) { throw file_error(file1.str()); }
/* Open file for writing out the final triples */
stringstream file3; file3 << PREP_DIR "Triples-" << file_completion(dummy) << "-P" << P.my_num();
ofstream outf(file3.str().c_str(),ios::out | ios::binary);
if (outf.fail()) { throw file_error(); }
if (outf.fail()) { throw file_error(file3.str()); }
gf2n_short te,t;
Create_Random(t,P);
@@ -444,12 +444,12 @@ void Square_Checking(const Player& P,MAC_Check<gf2n_short>& MC,int ns)
/* Open files for reading in the initial data */
stringstream file1; file1 << PREP_DIR "Initial-Squares-" << file_completion(dummy) << "-P" << P.my_num();
ifstream inpf_s(file1.str().c_str(),ios::in | ios::binary);
if (inpf_s.fail()) { throw file_error(); }
if (inpf_s.fail()) { throw file_error(file1.str()); }
/* Open files for writing out the final data */
stringstream file3; file3 << PREP_DIR "Squares-" << file_completion(dummy) << "-P" << P.my_num();
ofstream outf_s(file3.str().c_str(),ios::out | ios::binary);
if (outf_s.fail()) { throw file_error(); }
if (outf_s.fail()) { throw file_error(file3.str()); }
gf2n_short te,t,t2;
Create_Random(t,P);

View File

@@ -153,7 +153,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
others_ciphertexts.resize(this->sec, pk.get_params());
for (int i = 1; i < P.num_players(); i++)
{
#ifdef VERBOSE
#ifdef VERBOSE_HE
cerr << "Sending proof with " << 1e-9 * ciphertexts.get_length() << "+"
<< 1e-9 * cleartexts.get_length() << " GB" << endl;
#endif
@@ -164,7 +164,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
#ifndef LESS_ALLOC_MORE_MEM
Verifier<FD,S> verifier(proof);
#endif
#ifdef VERBOSE
#ifdef VERBOSE_HE
cerr << "Checking proof of player " << i << endl;
#endif
timers["Verifying"].start();

View File

@@ -7,6 +7,8 @@
#include "GC/Processor.h"
#include "GC/square64.h"
#include "GC/Processor.hpp"
namespace GC
{
@@ -14,7 +16,7 @@ int FakeSecret::default_length = 128;
ostream& FakeSecret::out = cout;
void FakeSecret::load(int n, const Integer& x)
void FakeSecret::load_clear(int n, const Integer& x)
{
if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n))
throw out_of_range("public value too long");

View File

@@ -79,7 +79,7 @@ public:
__uint128_t operator^=(const FakeSecret& other) { return a ^= other.a; }
void load(int n, const Integer& x);
void load_clear(int n, const Integer& x);
template <class T>
void load(int n, const Memory<T>& mem, size_t address) { load(n, mem[address]); }
template <class T>

View File

@@ -93,8 +93,8 @@ unsigned GC::Instruction<T>::get_max_reg(int reg_type) const
offset = 1;
break;
case INPUTB:
skip = 3;
offset = 2;
skip = 4;
offset = 3;
break;
case CONVCBIT:
return BaseInstruction::get_max_reg(INT);

View File

@@ -23,9 +23,6 @@
namespace GC
{
extern template class ReplicatedSecret<SemiHonestRepSecret>;
extern template class ReplicatedSecret<MaliciousRepSecret>;
#define GC_MACHINE(T) \
template class Instruction<T>; \
template class Machine<T>; \
@@ -34,8 +31,4 @@ extern template class ReplicatedSecret<MaliciousRepSecret>;
template class Thread<T>; \
template class ThreadMaster<T>; \
GC_MACHINE(FakeSecret);
GC_MACHINE(SemiHonestRepSecret);
GC_MACHINE(MaliciousRepSecret)
} /* namespace GC */

View File

@@ -1,34 +0,0 @@
/*
* MaliciousRepPrep.h
*
*/
#ifndef GC_MALICIOUSREPPREP_H_
#define GC_MALICIOUSREPPREP_H_
#include "MaliciousRepSecret.h"
#include "Protocols/ReplicatedPrep.h"
namespace GC
{
class MaliciousRepPrep : public BufferPrep<MaliciousRepSecret>
{
ReplicatedBase* protocol;
public:
MaliciousRepPrep(DataPositions& usage);
~MaliciousRepPrep();
void set_protocol(MaliciousRepSecret::Protocol& protocol);
void buffer_triples();
void buffer_bits();
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
};
} /* namespace GC */
#endif /* GC_MALICIOUSREPPREP_H_ */

View File

@@ -6,7 +6,7 @@
#ifndef GC_MALICIOUSREPSECRET_H_
#define GC_MALICIOUSREPSECRET_H_
#include "ReplicatedSecret.h"
#include "ShareSecret.h"
#include "Machine.h"
#include "Protocols/Beaver.h"
#include "Protocols/MaliciousRepMC.h"
@@ -17,7 +17,8 @@ template<class T> class MaliciousRepMC;
namespace GC
{
class MaliciousRepThread;
template<class T> class ShareThread;
template<class T> class RepPrep;
class MaliciousRepSecret : public ReplicatedSecret<MaliciousRepSecret>
{
@@ -30,7 +31,8 @@ public:
typedef MC MAC_Check;
typedef Beaver<MaliciousRepSecret> Protocol;
typedef NotImplementedInput Input;
typedef ReplicatedInput<MaliciousRepSecret> Input;
typedef RepPrep<MaliciousRepSecret> LivePrep;
static MC* new_mc(Machine<MaliciousRepSecret>& machine)
{

View File

@@ -1,71 +0,0 @@
/*
* MalicousRepParty.cpp
*
*/
#include "Protocols/MaliciousRepMC.h"
#include "MaliciousRepThread.h"
#include "Math/Setup.h"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Processor/Data_Files.hpp"
namespace GC
{
thread_local MaliciousRepThread* MaliciousRepThread::singleton = 0;
MaliciousRepThread::MaliciousRepThread(int i,
ThreadMaster<MaliciousRepSecret>& master) :
Thread<MaliciousRepSecret>(i, master), DataF(usage)
{
}
void MaliciousRepThread::pre_run()
{
if (singleton)
throw runtime_error("there can only be one");
singleton = this;
DataF.set_protocol(*protocol);
}
void MaliciousRepThread::post_run()
{
#ifndef INSECURE
cerr << "Removing used pre-processed data" << endl;
DataF.prune();
#endif
}
void MaliciousRepThread::and_(Processor<MaliciousRepSecret>& processor,
const vector<int>& args, bool repeat)
{
assert(P->num_players() == 3);
processor.check_args(args, 4);
protocol->init_mul(DataF, *MC);
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
int left = args[i + 2];
int right = args[i + 3];
MaliciousRepSecret y_ext;
if (repeat)
y_ext = processor.S[right].extend_bit();
else
y_ext = processor.S[right];
protocol->prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits));
}
protocol->exchange();
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
int out = args[i + 1];
processor.S[out] = protocol->finalize_mul().mask(n_bits);
}
}
} /* namespace GC */

View File

@@ -1,48 +0,0 @@
/*
* MalicousRepParty.h
*
*/
#ifndef GC_MALICIOUSREPTHREAD_H_
#define GC_MALICIOUSREPTHREAD_H_
#include "Thread.h"
#include "MaliciousRepSecret.h"
#include "MaliciousRepPrep.h"
#include "Processor/Data_Files.h"
#include <array>
namespace GC
{
class MaliciousRepThread : public Thread<MaliciousRepSecret>
{
static thread_local MaliciousRepThread* singleton;
public:
static MaliciousRepThread& s();
DataPositions usage;
MaliciousRepPrep DataF;
MaliciousRepThread(int i, ThreadMaster<MaliciousRepSecret>& master);
virtual ~MaliciousRepThread() {}
void pre_run();
void post_run();
void and_(Processor<MaliciousRepSecret>& processor, const vector<int>& args, bool repeat);
};
inline MaliciousRepThread& MaliciousRepThread::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
} /* namespace GC */
#endif /* GC_MALICIOUSREPTHREAD_H_ */

46
GC/RepPrep.h Normal file
View File

@@ -0,0 +1,46 @@
/*
* MaliciousRepPrep.h
*
*/
#ifndef GC_REPPREP_H_
#define GC_REPPREP_H_
#include "MaliciousRepSecret.h"
#include "ShiftableTripleBuffer.h"
#include "Protocols/ReplicatedPrep.h"
namespace GC
{
template<class T>
class RepPrep : public BufferPrep<T>, ShiftableTripleBuffer<T>
{
ReplicatedBase* protocol;
public:
RepPrep(DataPositions& usage, Thread<T>& thread);
~RepPrep();
void set_protocol(typename T::Protocol& protocol);
void buffer_triples();
void buffer_bits();
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
void get(Dtype type, T* data)
{
BufferPrep<T>::get(type, data);
}
array<T, 3> get_triple(int n_bits)
{
return ShiftableTripleBuffer<T>::get_triple(n_bits);
}
};
} /* namespace GC */
#endif /* GC_REPPREP_H_ */

View File

@@ -3,8 +3,8 @@
*
*/
#include "MaliciousRepPrep.h"
#include "MaliciousRepThread.h"
#include "RepPrep.h"
#include "ShareThread.h"
#include "Processor/OnlineOptions.h"
#include "Protocols/MalRepRingPrep.hpp"
@@ -15,33 +15,40 @@
namespace GC
{
MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage) :
BufferPrep<MaliciousRepSecret>(usage), protocol(0)
template<class T>
RepPrep<T>::RepPrep(DataPositions& usage, Thread<T>& thread) :
BufferPrep<T>(usage), protocol(0)
{
(void) thread;
}
MaliciousRepPrep::~MaliciousRepPrep()
template<class T>
RepPrep<T>::~RepPrep()
{
if (protocol)
delete protocol;
}
void MaliciousRepPrep::set_protocol(MaliciousRepSecret::Protocol& protocol)
template<class T>
void RepPrep<T>::set_protocol(typename T::Protocol& protocol)
{
this->protocol = new ReplicatedBase(protocol.P);
}
void MaliciousRepPrep::buffer_triples()
template<class T>
void RepPrep<T>::buffer_triples()
{
assert(protocol != 0);
auto MC = MaliciousRepThread::s().new_mc();
shuffle_triple_generation(triples, protocol->P, *MC, 64);
auto MC = ShareThread<T>::s().new_mc();
shuffle_triple_generation(this->triples, protocol->P, *MC, 64);
delete MC;
}
void MaliciousRepPrep::buffer_bits()
template<class T>
void RepPrep<T>::buffer_bits()
{
assert(this->protocol != 0);
assert(this->protocol->P.num_players() == 3);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
{
this->bits.push_back({});

View File

@@ -1,113 +0,0 @@
/*
* ReplicatedParty.cpp
*
*/
#include "ReplicatedParty.h"
#include "Thread.h"
#include "MaliciousRepThread.h"
#include "Networking/Server.h"
#include "Tools/ezOptionParser.h"
#include "Tools/benchmarking.h"
namespace GC
{
template<class T>
ReplicatedParty<T>::ReplicatedParty(int argc, const char** argv) :
ThreadMaster<T>(online_opts), online_opts(opt, argc, argv)
{
opt.add(
"localhost", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Host where party 0 is running (default: localhost)", // Help description.
"-h", // Flag token.
"--hostname" // Flag token.
);
opt.add(
"5000", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Base port number (default: 5000).", // Help description.
"-pn", // Flag token.
"--portnum" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Unencrypted communication.", // Help description.
"-u", // Flag token.
"--unencrypted" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Check opening by communication instead of hashing.", // Help description.
"-c", // Flag token.
"--communication" // Flag token.
);
online_opts.finalize(opt, argc, argv);
this->progname = online_opts.progname;
int my_num = online_opts.playerno;
int pnb;
string hostname;
opt.get("-pn")->getInt(pnb);
opt.get("-h")->getString(hostname);
this->machine.use_encryption = not opt.get("-u")->isSet;
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
T::out.activate(my_num == 0 or online_opts.interactive);
if (not this->machine.use_encryption)
insecure("unencrypted communication");
Server* server = Server::start_networking(this->N, my_num, 3, hostname, pnb);
this->run();
this->machine.write_memory(this->N.my_num());
if (server)
delete server;
}
template<>
Thread<SemiHonestRepSecret>* ReplicatedParty<SemiHonestRepSecret>::new_thread(int i)
{
return ThreadMaster<SemiHonestRepSecret>::new_thread(i);
}
template<>
Thread<MaliciousRepSecret>* ReplicatedParty<MaliciousRepSecret>::new_thread(int i)
{
return new MaliciousRepThread(i, *this);
}
template<>
void ReplicatedParty<SemiHonestRepSecret>::post_run()
{
}
template<>
void ReplicatedParty<MaliciousRepSecret>::post_run()
{
DataPositions usage;
for (auto thread : threads)
usage.increase(((MaliciousRepThread*)thread)->usage);
usage.print_cost();
}
extern template class ReplicatedSecret<SemiHonestRepSecret>;
extern template class ReplicatedSecret<MaliciousRepSecret>;
template class ReplicatedParty<SemiHonestRepSecret>;
template class ReplicatedParty<MaliciousRepSecret>;
}

View File

@@ -1,44 +0,0 @@
/*
* ReplicatedParty.h
*
*/
#ifndef GC_REPLICATEDPARTY_H_
#define GC_REPLICATEDPARTY_H_
#include "Protocols/ReplicatedMC.h"
#include "Protocols/MaliciousRepMC.h"
#include "ReplicatedSecret.h"
#include "Processor.h"
#include "Program.h"
#include "Memory.h"
#include "ThreadMaster.h"
namespace GC
{
template<class T>
class ReplicatedParty : public ThreadMaster<T>
{
ez::ezOptionParser opt;
OnlineOptions online_opts;
public:
static Thread<T>& s();
ReplicatedParty(int argc, const char** argv);
Thread<T>* new_thread(int i);
void post_run();
};
template<class T>
inline Thread<T>& ReplicatedParty<T>::s()
{
return Thread<T>::s();
}
}
#endif /* GC_REPLICATEDPARTY_H_ */

View File

@@ -1,254 +0,0 @@
/*
* ReplicatedSecret.cpp
*
*/
#include "ReplicatedSecret.h"
#include "ReplicatedParty.h"
#include "MaliciousRepSecret.h"
#include "Protocols/MaliciousRepMC.h"
#include "MaliciousRepThread.h"
#include "Thread.h"
#include "square64.h"
#include "Protocols/Share.h"
#include "Protocols/ReplicatedMC.hpp"
#include "Protocols/Replicated.hpp"
namespace GC
{
template<class U>
int ReplicatedSecret<U>::default_length = 8 * sizeof(ReplicatedSecret<U>::value_type);
template<class U>
SwitchableOutput ReplicatedSecret<U>::out;
template<class U>
void ReplicatedSecret<U>::load(int n, const Integer& x)
{
if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n))
throw out_of_range("public value too long");
*this = x;
}
template<class U>
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
template<class U>
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;
}
template<class U>
void ReplicatedSecret<U>::load(vector<ReadAccess<U> >& accesses,
const Memory<U>& mem)
{
for (auto access : accesses)
access.dest = mem[access.address];
}
template<class U>
void ReplicatedSecret<U>::store(Memory<U>& mem,
vector<WriteAccess<U> >& accesses)
{
for (auto access : accesses)
mem[access.address] = access.source;
}
template<class U>
void ReplicatedSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
const vector<ClearWriteAccess>& accesses)
{
for (auto access : accesses)
mem[access.address] = access.value;
}
template<class U>
void ReplicatedSecret<U>::inputb(Processor<U>& processor,
const vector<int>& args)
{
auto& party = ReplicatedParty<U>::s();
party.os.resize(2);
for (auto& o : party.os)
o.reset_write_head();
InputArgList a(args);
bool interactive = party.n_interactive_inputs_from_me(a) > 0;
for (auto x : a)
{
if (x.from == party.P->my_num())
{
auto& res = processor.S[x.dest];
res.prepare_input(party.os, processor.get_input(x.params, interactive), x.n_bits, party.secure_prng);
}
}
if (interactive)
cout << "Thank you" << endl;
for (int i = 0; i < 2; i++)
party.P->pass_around(party.os[i], i + 1);
for (auto x : a)
{
int from = x.from;
int n_bits = x.n_bits;
if (from != party.P->my_num())
{
auto& res = processor.S[x.dest];
res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits);
}
}
}
template<class U>
U ReplicatedSecret<U>::input(Processor<U>& processor, const InputArgs& args)
{
int from = args.from;
int n_bits = args.n_bits;
auto& party = ReplicatedParty<U>::s();
U res;
party.os.resize(2);
for (auto& o : party.os)
o.reset_write_head();
if (from == party.P->my_num())
{
res.prepare_input(party.os, processor.get_input(args.params), n_bits, party.secure_prng);
party.P->send_relative(party.os);
}
else
{
party.P->receive_player(from, party.os[0], true);
res.finalize_input(party, party.os[0], from, n_bits);
}
return res;
}
template<class U>
void ReplicatedSecret<U>::prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng)
{
randomize_to_sum(input, secure_prng);
*this &= get_mask(n_bits);
for (int i = 0; i < 2; i++)
BitVec(get_mask(n_bits) & (*this)[i]).pack(os[i], n_bits);
}
template<class U>
void ReplicatedSecret<U>::finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits)
{
int j = party.P->get_offset(from) == 2;
(*this)[j] = BitVec::unpack_new(o, n_bits);
(*this)[1 - j] = 0;
}
template<class U>
BitVec ReplicatedSecret<U>::local_mul(const ReplicatedSecret& other) const
{
return (*this)[0] * other.sum() + (*this)[1] * other[0];
}
template<class U>
void ReplicatedSecret<U>::and_(int n,
const ReplicatedSecret<U>& x,
const ReplicatedSecret<U>& y, bool repeat)
{
(void)n, (void)x, (void)y, (void)repeat;
throw runtime_error("use static method");
}
template<>
void ReplicatedSecret<SemiHonestRepSecret>::and_(Processor<SemiHonestRepSecret>& processor,
const vector<int>& args, bool repeat)
{
auto& party = Thread<SemiHonestRepSecret>::s();
assert(party.P->num_players() == 3);
processor.check_args(args, 4);
assert(party.protocol != 0);
auto& protocol = *party.protocol;
protocol.init_mul();
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
int left = args[i + 2];
int right = args[i + 3];
MaliciousRepSecret y_ext;
if (repeat)
y_ext = processor.S[right].extend_bit();
else
y_ext = processor.S[right];
protocol.prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits), n_bits);
}
protocol.exchange();
for (size_t i = 0; i < args.size(); i += 4)
processor.S[args[i + 1]] = protocol.finalize_mul(args[i]);
}
template<>
void ReplicatedSecret<MaliciousRepSecret>::and_(
Processor<MaliciousRepSecret>& processor, const vector<int>& args,
bool repeat)
{
MaliciousRepThread::s().and_(processor, args, repeat);
}
template<class U>
void ReplicatedSecret<U>::trans(Processor<U>& processor,
int n_outputs, const vector<int>& args)
{
assert(length == 2);
for (int k = 0; k < 2; k++)
{
square64 square;
for (size_t i = n_outputs; i < args.size(); i++)
square.rows[i - n_outputs] = processor.S[args[i]][k].get();
square.transpose(args.size() - n_outputs, n_outputs);
for (int i = 0; i < n_outputs; i++)
processor.S[args[i]][k] = square.rows[i];
}
}
template<class U>
void ReplicatedSecret<U>::reveal(size_t n_bits, Clear& x)
{
(void) n_bits;
ReplicatedSecret share = *this;
vector<BitVec> opened;
auto& party = ReplicatedParty<U>::s();
party.MC->POpen_Begin(opened, {share}, *party.P);
party.MC->POpen_End(opened, {share}, *party.P);
x = IntBase(opened[0]);
}
template<>
void ReplicatedSecret<SemiHonestRepSecret>::random_bit()
{
auto& party = ReplicatedParty<SemiHonestRepSecret>::s();
*this = party.secure_prng.get_bit();
octetStream o;
(*this)[0].pack(o, 1);
party.P->pass_around(o, 1);
(*this)[1].unpack(o, 1);
}
template<>
void ReplicatedSecret<MaliciousRepSecret>::random_bit()
{
MaliciousRepSecret res;
MaliciousRepThread::s().DataF.get_one(DATA_BIT, res);
*this = res;
}
template class ReplicatedSecret<SemiHonestRepSecret>;
template class ReplicatedSecret<MaliciousRepSecret>;
}

View File

@@ -83,6 +83,8 @@ public:
static typename T::out_type out;
static const bool needs_ot = false;
static T& cast(T& reg) { return *reinterpret_cast<T*>(&reg); }
static const T& cast(const T& reg) { return *reinterpret_cast<const T*>(&reg); }
@@ -98,34 +100,40 @@ public:
static Secret<T> carryless_mult(const Secret<T>& x, const Secret<T>& y);
static void output(T& reg);
template<class U>
static void load(vector< ReadAccess< Secret<T> > >& accesses, const U& mem);
template<class U>
static void store(U& mem, vector< WriteAccess< Secret<T> > >& accesses);
template<class U, class V>
static void load(vector< ReadAccess<V> >& accesses, const U& mem);
template<class U, class V>
static void store(U& mem, vector< WriteAccess<V> >& accesses);
static void andrs(Processor< Secret<T> >& processor, const vector<int>& args)
template<class U>
static void andrs(Processor<U>& processor, const vector<int>& args)
{ T::andrs(processor, args); }
static void ands(Processor< Secret<T> >& processor, const vector<int>& args)
template<class U>
static void ands(Processor<U>& processor, const vector<int>& args)
{ T::ands(processor, args); }
static void inputb(Processor< Secret<T> >& processor, const vector<int>& args)
template<class U>
static void inputb(Processor<U>& processor, const vector<int>& args)
{ T::inputb(processor, args); }
static void trans(Processor<Secret<T> >& processor, int n_inputs, const vector<int>& args);
template<class U>
static void trans(Processor<U>& processor, int n_inputs, const vector<int>& args);
static void convcbit(Integer& dest, const Clear& source) { T::convcbit(dest, source); }
Secret();
Secret(const Integer& x) { *this = x; }
void load(int n, const Integer& x);
void operator=(const Integer& x) { load(default_length, x); }
void load_clear(int n, const Integer& x);
void operator=(const Integer& x) { load_clear(default_length, x); }
void load(int n, const Memory<AuthValue>& mem, size_t address);
Secret<T> operator<<(int i);
Secret<T> operator>>(int i);
void bitcom(Memory< Secret<T> >& S, const vector<int>& regs);
void bitdec(Memory< Secret<T> >& S, const vector<int>& regs) const;
template<class U>
void bitcom(Memory<U>& S, const vector<int>& regs);
template<class U>
void bitdec(Memory<U>& S, const vector<int>& regs) const;
Secret<T> operator+(const Secret<T> x) const;
Secret<T>& operator+=(const Secret<T> x) { *this = *this + x; return *this; }

View File

@@ -102,9 +102,9 @@ void Secret<T>::random_bit()
}
template <class T>
template <class U>
template <class U, class V>
void Secret<T>::store(U& mem,
vector<WriteAccess<Secret<T> > >& accesses)
vector<WriteAccess<V> >& accesses)
{
T::store(mem, accesses);
}
@@ -194,7 +194,7 @@ T& GC::Secret<T>::get_new_reg()
}
template <class T>
void Secret<T>::load(int n, const Integer& x)
void Secret<T>::load_clear(int n, const Integer& x)
{
if ((unsigned)n < 8 * sizeof(x) and abs(x.get()) > (1LL << n))
throw out_of_range("public value too long");
@@ -219,8 +219,8 @@ void Secret<T>::load(int n, const Integer& x)
}
template <class T>
template <class U>
void Secret<T>::load(vector<ReadAccess < Secret<T> > >& accesses, const U& mem)
template <class U, class V>
void Secret<T>::load(vector<ReadAccess <V> >& accesses, const U& mem)
{
for (auto&& access : accesses)
{
@@ -252,7 +252,8 @@ Secret<T> Secret<T>::operator>>(int i)
}
template <class T>
void Secret<T>::bitcom(Memory<Secret>& S, const vector<int>& regs)
template <class U>
void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
{
registers.clear();
for (unsigned int i = 0; i < regs.size(); i++)
@@ -264,7 +265,8 @@ void Secret<T>::bitcom(Memory<Secret>& S, const vector<int>& regs)
}
template <class T>
void Secret<T>::bitdec(Memory<Secret>& S, const vector<int>& regs) const
template <class U>
void Secret<T>::bitdec(Memory<U>& S, const vector<int>& regs) const
{
if (regs.size() > registers.size())
throw out_of_range(
@@ -280,7 +282,8 @@ void Secret<T>::bitdec(Memory<Secret>& S, const vector<int>& regs) const
}
template<class T>
void Secret<T>::trans(Processor<Secret<T> >& processor, int n_outputs,
template<class U>
void Secret<T>::trans(Processor<U>& processor, int n_outputs,
const vector<int>& args)
{
int n_inputs = args.size() - n_outputs;

11
GC/SemiHonestRepPrep.cpp Normal file
View File

@@ -0,0 +1,11 @@
/*
* ReplicatedPrep.cpp
*
*/
#include <GC/SemiHonestRepPrep.h>
namespace GC
{
} /* namespace GC */

28
GC/SemiHonestRepPrep.h Normal file
View File

@@ -0,0 +1,28 @@
/*
* ReplicatedPrep.h
*
*/
#ifndef GC_SEMIHONESTREPPREP_H_
#define GC_SEMIHONESTREPPREP_H_
#include "RepPrep.h"
#include "ShareSecret.h"
namespace GC
{
class SemiHonestRepPrep : public RepPrep<SemiHonestRepSecret>
{
public:
SemiHonestRepPrep(DataPositions& usage, Thread<SemiHonestRepSecret>& thread) :
RepPrep<SemiHonestRepSecret>(usage, thread)
{
}
void buffer_triples() { throw not_implemented(); }
};
} /* namespace GC */
#endif /* GC_SEMIHONESTREPPREP_H_ */

58
GC/SemiPrep.cpp Normal file
View File

@@ -0,0 +1,58 @@
/*
* SemiPrep.cpp
*
*/
#include "SemiPrep.h"
#include "ThreadMaster.h"
#include "OT/NPartyTripleGenerator.h"
#include "OT/BitDiagonal.h"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "OT/NPartyTripleGenerator.hpp"
namespace GC
{
SemiPrep::SemiPrep(DataPositions& usage, Thread<SemiSecret>& thread) :
BufferPrep<SemiSecret>(usage), thread(thread), triple_generator(0)
{
}
void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
{
(void) protocol;
params.set_passive();
triple_generator = new SemiSecret::TripleGenerator(
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
thread.master.N, thread.thread_num, thread.master.opts.batch_size,
1, params, thread.P);
triple_generator->multi_threaded = false;
}
void SemiPrep::buffer_triples()
{
assert(this->triple_generator);
this->triple_generator->generatePlainTriples();
for (auto& x : this->triple_generator->plainTriples)
{
this->triples.push_back({{x[0], x[1], x[2]}});
}
this->triple_generator->unlock();
}
SemiPrep::~SemiPrep()
{
if (triple_generator)
delete triple_generator;
}
void SemiPrep::buffer_bits()
{
word r = thread.secure_prng.get_word();
for (size_t i = 0; i < sizeof(word) * 8; i++)
this->bits.push_back((r >> i) & 1);
}
} /* namespace GC */

51
GC/SemiPrep.h Normal file
View File

@@ -0,0 +1,51 @@
/*
* SemiPrep.h
*
*/
#ifndef GC_SEMIPREP_H_
#define GC_SEMIPREP_H_
#include "Protocols/ReplicatedPrep.h"
#include "OT/TripleMachine.h"
#include "SemiSecret.h"
#include "ShiftableTripleBuffer.h"
template<class T> class Beaver;
namespace GC
{
class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret>
{
Thread<SemiSecret>& thread;
SemiSecret::TripleGenerator* triple_generator;
MascotParams params;
public:
SemiPrep(DataPositions& usage, Thread<SemiSecret>& thread);
~SemiPrep();
void set_protocol(Beaver<SemiSecret>& protocol);
void buffer_triples();
void buffer_bits();
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
void get(Dtype type, SemiSecret* data)
{
BufferPrep<SemiSecret>::get(type, data);
}
array<SemiSecret, 3> get_triple(int n_bits)
{
return ShiftableTripleBuffer<SemiSecret>::get_triple(n_bits);
}
};
} /* namespace GC */
#endif /* GC_SEMIPREP_H_ */

52
GC/SemiSecret.cpp Normal file
View File

@@ -0,0 +1,52 @@
/*
* SemiSecret.cpp
*
*/
#include "GC/ShareParty.h"
#include "SemiSecret.h"
#include "GC/ShareSecret.hpp"
#include "Protocols/MAC_Check_Base.hpp"
namespace GC
{
void SemiSecret::trans(Processor<SemiSecret>& processor, int n_outputs,
const vector<int>& args)
{
square64 square;
for (size_t i = n_outputs; i < args.size(); i++)
square.rows[i - n_outputs] = processor.S[args[i]].get();
square.transpose(args.size() - n_outputs, n_outputs);
for (int i = 0; i < n_outputs; i++)
processor.S[args[i]] = square.rows[i];
}
void SemiSecret::load_clear(int n, const Integer& x)
{
check_length(n, x);
*this = constant(x, Thread<SemiSecret>::s().P->my_num());
}
void SemiSecret::bitcom(Memory<SemiSecret>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
void SemiSecret::bitdec(Memory<SemiSecret>& S,
const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;
}
void SemiSecret::reveal(size_t n_bits, Clear& x)
{
auto& thread = Thread<SemiSecret>::s();
x = thread.MC->POpen(*this, *thread.P).mask(n_bits);
}
} /* namespace GC */

67
GC/SemiSecret.h Normal file
View File

@@ -0,0 +1,67 @@
/*
* SemiSecret.h
*
*/
#ifndef GC_SEMISECRET_H_
#define GC_SEMISECRET_H_
#include "Protocols/SemiMC.h"
#include "Protocols/SemiShare.h"
#include "Processor/DummyProtocol.h"
#include "ShareSecret.h"
template<class T> class Beaver;
namespace GC
{
class SemiPrep;
class SemiSecret : public SemiShare<BitVec>, public ShareSecret<SemiSecret>
{
public:
typedef Memory<SemiSecret> DynamicMemory;
typedef SemiMC<SemiSecret> MC;
typedef Beaver<SemiSecret> Protocol;
typedef MC MAC_Check;
typedef SemiPrep LivePrep;
typedef SemiInput<SemiSecret> Input;
static const int default_length = sizeof(BitVec) * 8;
static string type_string() { return "binary secret"; }
static string phase_name() { return "Binary computation"; }
static MC* new_mc(Machine<SemiSecret>& _) { (void) _; return new MC; }
static void trans(Processor<SemiSecret>& processor, int n_outputs,
const vector<int>& args);
SemiSecret()
{
}
SemiSecret(long other) :
SemiShare<BitVec>(other)
{
}
SemiSecret(const IntBase& other) :
SemiShare<BitVec>(other)
{
}
void load_clear(int n, const Integer& x);
void bitcom(Memory<SemiSecret>& S, const vector<int>& regs);
void bitdec(Memory<SemiSecret>& S, const vector<int>& regs) const;
void xor_(int n, const SemiSecret& x, const SemiSecret& y)
{ *this = BitVec(x ^ y).mask(n); }
void reveal(size_t n_bits, Clear& x);
};
} /* namespace GC */
#endif /* GC_SEMISECRET_H_ */

51
GC/ShareParty.h Normal file
View File

@@ -0,0 +1,51 @@
/*
* ReplicatedParty.h
*
*/
#ifndef GC_SHAREPARTY_H_
#define GC_SHAREPARTY_H_
#include "Protocols/ReplicatedMC.h"
#include "Protocols/MaliciousRepMC.h"
#include "ShareSecret.h"
#include "Processor.h"
#include "Program.h"
#include "Memory.h"
#include "ThreadMaster.h"
namespace GC
{
template<class T>
class ShareParty : public ThreadMaster<T>
{
static ShareParty<T>* singleton;
ez::ezOptionParser opt;
OnlineOptions online_opts;
public:
static ShareParty& s();
typename T::mac_key_type mac_key;
ShareParty(int argc, const char** argv, int default_batch_size = 0);
Thread<T>* new_thread(int i);
void post_run();
};
template<class T>
inline ShareParty<T>& ShareParty<T>::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
}
#endif /* GC_SHAREPARTY_H_ */

137
GC/ShareParty.hpp Normal file
View File

@@ -0,0 +1,137 @@
/*
* ReplicatedParty.cpp
*
*/
#include "ShareParty.h"
#include "Thread.h"
#include "ShareThread.h"
#include "SemiPrep.h"
#include "Networking/Server.h"
#include "Networking/CryptoPlayer.h"
#include "Tools/ezOptionParser.h"
#include "Tools/benchmarking.h"
#include "Tools/NetworkOptions.h"
#include "Protocols/fake-stuff.h"
#include "ShareThread.hpp"
#include "RepPrep.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/fake-stuff.hpp"
namespace GC
{
template<class T>
ShareParty<T>* ShareParty<T>::singleton = 0;
template<class T>
ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
ThreadMaster<T>(online_opts), online_opts(opt, argc, argv,
default_batch_size)
{
if (singleton)
throw runtime_error("there can only be one");
singleton = this;
NetworkOptionsWithNumber network_opts(opt, argc, argv,
T::dishonest_majority ? 2 : 3, T::dishonest_majority);
if (T::dishonest_majority)
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Use encrypted channels.", // Help description.
"-e", // Flag token.
"--encrypted" // Flag token.
);
else
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Unencrypted communication.", // Help description.
"-u", // Flag token.
"--unencrypted" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Check opening by communication instead of hashing.", // Help description.
"-c", // Flag token.
"--communication" // Flag token.
);
online_opts.finalize(opt, argc, argv);
this->progname = online_opts.progname;
int my_num = online_opts.playerno;
if (T::dishonest_majority)
this->machine.use_encryption = opt.get("-e")->isSet;
else
this->machine.use_encryption = not opt.get("-u")->isSet;
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
T::out.activate(my_num == 0 or online_opts.interactive);
if (not this->machine.use_encryption and not T::dishonest_majority)
insecure("unencrypted communication");
Server* server = network_opts.start_networking(this->N, my_num);
if (online_opts.live_prep)
if (T::needs_ot)
{
Player* P;
if (this->machine.use_encryption)
P = new CryptoPlayer(this->N, 0xFFFF);
else
P = new PlainPlayer(this->N, 0xFFFF);
for (int i = 0; i < this->machine.nthreads; i++)
this->machine.ot_setups.push_back({{{*P, true}}});
delete P;
}
try
{
gf2n _;
read_mac_keys(get_prep_dir(network_opts.nplayers, 128, 128), this->N,
this->mac_key, _);
}
catch (exception& e)
{
SeededPRNG G;
this->mac_key.randomize(G);
}
this->run();
this->machine.write_memory(this->N.my_num());
if (server)
delete server;
}
template<class T>
Thread<T>* ShareParty<T>::new_thread(int i)
{
return new ShareThread<T>(i, *this);
}
template<class T>
void ShareParty<T>::post_run()
{
DataPositions usage;
for (auto thread : this->threads)
usage.increase(dynamic_cast<ShareThread<T>*>(thread)->usage);
usage.print_cost();
}
}

View File

@@ -3,8 +3,8 @@
*
*/
#ifndef GC_REPLICATEDSECRET_H_
#define GC_REPLICATEDSECRET_H_
#ifndef GC_SHARESECRET_H_
#define GC_SHARESECRET_H_
#include <vector>
using namespace std;
@@ -18,6 +18,7 @@ using namespace std;
#include "Tools/SwitchableOutput.h"
#include "Protocols/Replicated.h"
#include "Protocols/ReplicatedMC.h"
#include "Processor/DummyProtocol.h"
namespace GC
{
@@ -32,22 +33,9 @@ template <class T>
class Machine;
template<class U>
class ReplicatedSecret : public FixedVec<BitVec, 2>
class ShareSecret
{
typedef FixedVec<BitVec, 2> super;
public:
typedef BitVec clear;
typedef BitVec open_type;
typedef BitVec mac_type;
typedef BitVec mac_key_type;
typedef ReplicatedBase Protocol;
static string type_string() { return "replicated secret"; }
static string phase_name() { return "Replicated computation"; }
static int default_length;
static SwitchableOutput out;
static void store_clear_in_dynamic(Memory<U>& mem,
@@ -63,38 +51,59 @@ public:
static void and_(Processor<U>& processor, const vector<int>& args, bool repeat);
static void inputb(Processor<U>& processor, const vector<int>& args);
static void trans(Processor<U>& processor, int n_outputs,
const vector<int>& args);
static void convcbit(Integer& dest, const Clear& source) { dest = source; }
static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); }
static U input(Processor<U>& processor, const InputArgs& args);
void prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng);
void finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits);
void check_length(int n, const Integer& x);
void random_bit();
};
template<class U>
class ReplicatedSecret : public FixedVec<BitVec, 2>, public ShareSecret<U>
{
typedef FixedVec<BitVec, 2> super;
public:
typedef BitVec clear;
typedef BitVec open_type;
typedef BitVec mac_type;
typedef BitVec mac_key_type;
typedef ReplicatedBase Protocol;
static const int N_BITS = clear::N_BITS;
static const bool dishonest_majority = false;
static const bool needs_ot = false;
static string type_string() { return "replicated secret"; }
static string phase_name() { return "Replicated computation"; }
static int default_length;
static void trans(Processor<U>& processor, int n_outputs,
const vector<int>& args);
ReplicatedSecret() {}
template <class T>
ReplicatedSecret(const T& other) : super(other) {}
void load(int n, const Integer& x);
void load_clear(int n, const Integer& x);
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
{ *this = x ^ y; (void)n; }
void and_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y, bool repeat);
void andrs(int n, const ReplicatedSecret& x, const ReplicatedSecret& y);
BitVec local_mul(const ReplicatedSecret& other) const;
void reveal(size_t n_bits, Clear& x);
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
{ *this = x ^ y; (void)n; }
void random_bit();
void reveal(size_t n_bits, Clear& x);
};
class SemiHonestRepPrep;
class SemiHonestRepSecret : public ReplicatedSecret<SemiHonestRepSecret>
{
@@ -106,6 +115,8 @@ public:
typedef ReplicatedMC<SemiHonestRepSecret> MC;
typedef Replicated<SemiHonestRepSecret> Protocol;
typedef MC MAC_Check;
typedef SemiHonestRepPrep LivePrep;
typedef ReplicatedInput<SemiHonestRepSecret> Input;
static MC* new_mc(Machine<SemiHonestRepSecret>& _) { (void) _; return new MC; }
@@ -116,4 +127,4 @@ public:
}
#endif /* GC_REPLICATEDSECRET_H_ */
#endif /* GC_SHARESECRET_H_ */

168
GC/ShareSecret.hpp Normal file
View File

@@ -0,0 +1,168 @@
/*
* ReplicatedSecret.cpp
*
*/
#include "ShareSecret.h"
#include "MaliciousRepSecret.h"
#include "Protocols/MaliciousRepMC.h"
#include "ShareThread.h"
#include "Thread.h"
#include "square64.h"
#include "Protocols/Share.h"
#include "Protocols/ReplicatedMC.hpp"
#include "Protocols/Beaver.hpp"
#include "ShareParty.h"
#include "ShareThread.hpp"
namespace GC
{
template<class U>
int ReplicatedSecret<U>::default_length = 8 * sizeof(typename ReplicatedSecret<U>::value_type);
template<class U>
SwitchableOutput ShareSecret<U>::out;
template<class U>
void ShareSecret<U>::check_length(int n, const Integer& x)
{
if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n))
throw out_of_range("public value too long");
}
template<class U>
void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
{
this->check_length(n, x);
*this = x;
}
template<class U>
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
template<class U>
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;
}
template<class U>
void ShareSecret<U>::load(vector<ReadAccess<U> >& accesses,
const Memory<U>& mem)
{
for (auto access : accesses)
access.dest = mem[access.address];
}
template<class U>
void ShareSecret<U>::store(Memory<U>& mem,
vector<WriteAccess<U> >& accesses)
{
for (auto access : accesses)
mem[access.address] = access.source;
}
template<class U>
void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
const vector<ClearWriteAccess>& accesses)
{
for (auto access : accesses)
mem[access.address] = access.value;
}
template<class U>
void ShareSecret<U>::inputb(Processor<U>& processor,
const vector<int>& args)
{
auto& party = ShareThread<U>::s();
typename U::Input input(*party.MC, party.DataF, *party.P);
input.reset_all(*party.P);
InputArgList a(args);
bool interactive = party.n_interactive_inputs_from_me(a) > 0;
for (auto x : a)
{
if (x.from == party.P->my_num())
{
input.add_mine(processor.get_input(x.params, interactive), x.n_bits);
}
else
input.add_other(x.from);
}
if (interactive)
cout << "Thank you" << endl;
input.exchange();
for (auto x : a)
{
int from = x.from;
int n_bits = x.n_bits;
auto& res = processor.S[x.dest];
res = input.finalize(from, n_bits).mask(n_bits);
}
}
template<class U>
BitVec ReplicatedSecret<U>::local_mul(const ReplicatedSecret& other) const
{
return (*this)[0] * other.sum() + (*this)[1] * other[0];
}
template<class U>
void ShareSecret<U>::and_(
Processor<U>& processor, const vector<int>& args,
bool repeat)
{
ShareThread<U>::s().and_(processor, args, repeat);
}
template<class U>
void ReplicatedSecret<U>::trans(Processor<U>& processor,
int n_outputs, const vector<int>& args)
{
assert(length == 2);
for (int k = 0; k < 2; k++)
{
square64 square;
for (size_t i = n_outputs; i < args.size(); i++)
square.rows[i - n_outputs] = processor.S[args[i]][k].get();
square.transpose(args.size() - n_outputs, n_outputs);
for (int i = 0; i < n_outputs; i++)
processor.S[args[i]][k] = square.rows[i];
}
}
template<class U>
void ReplicatedSecret<U>::reveal(size_t n_bits, Clear& x)
{
(void) n_bits;
auto& share = *this;
vector<BitVec> opened;
auto& party = ShareThread<U>::s();
party.MC->POpen_Begin(opened, {share}, *party.P);
party.MC->POpen_End(opened, {share}, *party.P);
x = IntBase(opened[0]);
}
template<class U>
void ShareSecret<U>::random_bit()
{
U res;
ShareThread<U>::s().DataF.get_one(DATA_BIT, res);
*this = res;
}
}

55
GC/ShareThread.h Normal file
View File

@@ -0,0 +1,55 @@
/*
* MalicousRepParty.h
*
*/
#ifndef GC_SHARETHREAD_H_
#define GC_SHARETHREAD_H_
#include "Thread.h"
#include "MaliciousRepSecret.h"
#include "RepPrep.h"
#include "SemiHonestRepPrep.h"
#include "Processor/Data_Files.h"
#include "Protocols/ReplicatedInput.h"
#include <array>
namespace GC
{
template<class T>
class ShareThread : public Thread<T>
{
static thread_local ShareThread<T>* singleton;
public:
static ShareThread& s();
DataPositions usage;
Preprocessing<T>& DataF;
ShareThread(int i, ThreadMaster<T>& master);
virtual ~ShareThread();
void pre_run();
void post_run();
void and_(Processor<T>& processor, const vector<int>& args, bool repeat);
};
template<class T>
thread_local ShareThread<T>* ShareThread<T>::singleton = 0;
template<class T>
inline ShareThread<T>& ShareThread<T>::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
} /* namespace GC */
#endif /* GC_SHARETHREAD_H_ */

88
GC/ShareThread.hpp Normal file
View File

@@ -0,0 +1,88 @@
/*
* MalicousRepParty.cpp
*
*/
#ifndef GC_SHARETHREAD_HPP_
#define GC_SHARETHREAD_HPP_
#include <GC/ShareThread.h>
#include "Protocols/MaliciousRepMC.h"
#include "Math/Setup.h"
#include "Processor/Data_Files.hpp"
namespace GC
{
template<class T>
ShareThread<T>::ShareThread(int i,
ThreadMaster<T>& master) :
Thread<T>(i, master), usage(master.N.num_players()), DataF(
master.opts.live_prep ?
*(Preprocessing<T>*) new typename T::LivePrep(usage,
*this) :
*(Preprocessing<T>*) new Sub_Data_Files<T>(master.N,
get_prep_dir(master.N.num_players(), 128, 128),
usage))
{
}
template<class T>
ShareThread<T>::~ShareThread()
{
delete &DataF;
}
template<class T>
void ShareThread<T>::pre_run()
{
if (singleton)
throw runtime_error("there can only be one");
singleton = this;
assert(this->protocol != 0);
DataF.set_protocol(*this->protocol);
}
template<class T>
void ShareThread<T>::post_run()
{
#ifndef INSECURE
cerr << "Removing used pre-processed data" << endl;
DataF.prune();
#endif
}
template<class T>
void ShareThread<T>::and_(Processor<T>& processor,
const vector<int>& args, bool repeat)
{
auto& protocol = this->protocol;
processor.check_args(args, 4);
protocol->init_mul(DataF, *this->MC);
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
int left = args[i + 2];
int right = args[i + 3];
T y_ext;
if (repeat)
y_ext = processor.S[right].extend_bit();
else
y_ext = processor.S[right];
protocol->prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits), n_bits);
}
protocol->exchange();
for (size_t i = 0; i < args.size(); i += 4)
{
int n_bits = args[i];
int out = args[i + 1];
processor.S[out] = protocol->finalize_mul(n_bits);
}
}
} /* namespace GC */
#endif

View File

@@ -0,0 +1,60 @@
/*
* ShiftableTripleBuffer.h
*
*/
#ifndef GC_SHIFTABLETRIPLEBUFFER_H_
#define GC_SHIFTABLETRIPLEBUFFER_H_
#include "Math/FixedVec.h"
#include <assert.h>
namespace GC
{
template<class T>
class ShiftableTripleBuffer
{
FixedVec<T, 3> triple_buffer;
int buffer_left;
virtual void get(Dtype type, T* data) = 0;
public:
ShiftableTripleBuffer() :
buffer_left(0)
{
}
virtual ~ShiftableTripleBuffer() {}
array<T, 3> get_triple(int n_bits)
{
int max_n_bits = T::N_BITS;
assert(n_bits <= max_n_bits);
assert(n_bits > 0);
array<T, 3> res;
if (n_bits <= buffer_left)
{
res = triple_buffer.mask(n_bits).get();
triple_buffer >>= n_bits;
buffer_left -= n_bits;
}
else
{
get(DATA_TRIPLE, res.data());
FixedVec<T, 3> tmp = res;
res = tmp.mask(n_bits).get();
triple_buffer += (tmp >> n_bits) << buffer_left;
buffer_left += max_n_bits - n_bits;
}
return res;
}
};
} /* namespace GC */
#endif /* GC_SHIFTABLETRIPLEBUFFER_H_ */

View File

@@ -39,7 +39,6 @@ public:
Names& N;
Player* P;
PRNG secure_prng;
vector<octetStream> os;
int thread_num;
WaitQueue<ScheduleItem> tape_schedule;

View File

@@ -56,6 +56,11 @@ void ThreadMaster<T>::run()
P = new PlainPlayer(N, 0xff << 24);
machine.load_schedule(progname);
if (T::needs_ot)
for (int i = 0; i < machine.nthreads; i++)
machine.ot_setups.push_back({{*P, true}, {*P, true}});
for (int i = 0; i < machine.nthreads; i++)
threads.push_back(new_thread(i));
for (auto thread : threads)

11
GC/TinyMC.cpp Normal file
View File

@@ -0,0 +1,11 @@
/*
* TinyMC.cpp
*
*/
#include "TinyMC.h"
namespace GC
{
} /* namespace GC */

67
GC/TinyMC.h Normal file
View File

@@ -0,0 +1,67 @@
/*
* TinyMC.h
*
*/
#ifndef GC_TINYMC_H_
#define GC_TINYMC_H_
#include "Protocols/MAC_Check_Base.h"
namespace GC
{
template<class T>
class TinyMC : public MAC_Check_Base<T>
{
typename T::part_type::MAC_Check part_MC;
vector<typename T::part_type::open_type> part_values;
vector<typename T::part_type::super> part_shares;
public:
TinyMC(typename T::mac_key_type mac_key) :
part_MC(mac_key)
{
this->alphai = mac_key;
}
typename T::part_type::MAC_Check& get_part_MC()
{
return part_MC;
}
void POpen_Begin(vector<typename T::open_type>& values, const vector<T>& S,
const Player& P)
{
values.clear();
part_shares.clear();
for (auto& share : S)
for (auto& part : share.get_regs())
part_shares.push_back(part);
part_MC.POpen_Begin(part_values, part_shares, P);
}
void POpen_End(vector<typename T::open_type>& values, const vector<T>& S,
const Player& P)
{
values.clear();
part_MC.POpen_End(part_values, part_shares, P);
int i = 0;
for (auto& share : S)
{
typename T::open_type opened = 0;
for (size_t j = 0; j < share.get_regs().size(); j++)
opened += typename T::open_type(part_values[i++].get_bit(0)) << j;
values.push_back(opened);
}
}
void Check(const Player& P)
{
part_MC.Check(P);
}
};
} /* namespace GC */
#endif /* GC_TINYMC_H_ */

52
GC/TinyPrep.h Normal file
View File

@@ -0,0 +1,52 @@
/*
* TinyPrep.h
*
*/
#ifndef GC_TINYPREP_H_
#define GC_TINYPREP_H_
#include "Thread.h"
#include "OT/TripleMachine.h"
#include "Protocols/Beaver.h"
#include "Protocols/ReplicatedPrep.h"
#include "Protocols/RandomPrep.h"
namespace GC
{
template<class T>
class TinyPrep : public BufferPrep<T>, public RandomPrep<typename T::part_type::super>
{
typedef Share<Z2<1 + T::part_type::s>> res_type;
Thread<T>& thread;
typename T::TripleGenerator* triple_generator;
typename T::part_type::TripleGenerator* input_generator;
MascotParams params;
vector<array<typename T::part_type, 3>> triple_buffer;
public:
TinyPrep(DataPositions& usage, Thread<T>& thread);
~TinyPrep();
void set_protocol(Beaver<T>& protocol);
void buffer_triples();
void buffer_bits();
void buffer_inputs(int player);
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
typename T::part_type::super get_random();
array<T, 3> get_triple(int n_bits);
};
} /* namespace GC */
#endif /* GC_TINYPREP_H_ */

174
GC/TinyPrep.hpp Normal file
View File

@@ -0,0 +1,174 @@
/*
* TinyPrep.cpp
*
*/
#include "TinyPrep.h"
namespace GC
{
template<class T>
TinyPrep<T>::TinyPrep(DataPositions& usage, Thread<T>& thread) :
BufferPrep<T>(usage), thread(thread), triple_generator(0),
input_generator(0)
{
}
template<class T>
TinyPrep<T>::~TinyPrep()
{
if (triple_generator)
delete triple_generator;
if (input_generator)
delete input_generator;
}
template<class T>
void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
{
(void) protocol;
params.generateMACs = true;
params.amplify = false;
params.check = false;
params.set_mac_key(thread.MC->get_alphai());
triple_generator = new typename T::TripleGenerator(
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
thread.master.N, thread.thread_num,
thread.master.opts.batch_size,
1, params, thread.P);
triple_generator->multi_threaded = false;
input_generator = new typename T::part_type::TripleGenerator(
thread.processor.machine.ot_setups.at(thread.thread_num).at(1),
thread.master.N, thread.thread_num,
thread.master.opts.batch_size,
1, params, thread.P);
input_generator->multi_threaded = false;
thread.MC->get_part_MC().set_prep(*this);
}
template<class T>
void TinyPrep<T>::buffer_triples()
{
auto& triple_generator = this->triple_generator;
params.generateBits = false;
vector<array<typename T::part_type::super, 3>> triples;
ShuffleSacrifice<typename T::part_type::super> sacrifice;
while (int(triples.size()) < sacrifice.minimum_n_inputs())
{
triple_generator->generatePlainTriples();
triple_generator->unlock();
assert(triple_generator->plainTriples.size() != 0);
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
triple_generator->valueBits[2].set_portion(i,
triple_generator->plainTriples[i][2]);
triple_generator->run_multipliers({});
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
{
for (int j = 0; j < T::default_length; j++)
{
triples.push_back({});
for (int k = 0; k < 3; k++)
{
auto& share = triples.back()[k];
share.set_share(
triple_generator->plainTriples.at(i).at(k).get_bit(
j));
typename T::part_type::mac_type mac;
mac = thread.MC->get_alphai() * share.get_share();
for (auto& multiplier : triple_generator->ot_multipliers)
mac += multiplier->macs.at(k).at(i * T::default_length + j);
share.set_mac(mac);
}
}
}
}
sacrifice.triple_sacrifice(triples, triples,
*thread.P, thread.MC->get_part_MC());
for (size_t i = 0; i < triples.size() / T::default_length; i++)
{
this->triples.push_back({});
auto& triple = this->triples.back();
for (auto& x : triple)
x.resize_regs(T::default_length);
for (int j = 0; j < T::default_length; j++)
{
auto& source_triple = triples[j + i * T::default_length];
for (int k = 0; k < 3; k++)
triple[k].get_reg(j) = source_triple[k];
}
}
}
template<class T>
void TinyPrep<T>::buffer_bits()
{
auto tmp = BufferPrep<T>::get_random_from_inputs(thread.P->num_players());
for (auto& bit : tmp.get_regs())
{
this->bits.push_back({});
this->bits.back().resize_regs(1);
this->bits.back().get_reg(0) = bit;
}
}
template<class T>
void TinyPrep<T>::buffer_inputs(int player)
{
auto& inputs = this->inputs;
inputs.resize(thread.P->num_players());
assert(this->input_generator);
this->input_generator->generateInputs(player);
for (size_t i = 0; i < this->input_generator->inputs.size() / T::default_length; i++)
{
inputs[player].push_back({});
inputs[player].back().share.resize_regs(T::default_length);
for (int j = 0; j < T::default_length; j++)
{
auto& source_input = this->input_generator->inputs[j
+ i * T::default_length];
inputs[player].back().share.get_reg(j) = res_type(source_input.share);
inputs[player].back().value ^= typename T::open_type(
source_input.value.get_bit(0)) << j;
}
}
}
template<class T>
typename T::part_type::super GC::TinyPrep<T>::get_random()
{
T tmp;
this->get_one(DATA_BIT, tmp);
return tmp.get_reg(0);
}
template<class T>
array<T, 3> TinyPrep<T>::get_triple(int n_bits)
{
assert(n_bits > 0);
while ((unsigned)n_bits > triple_buffer.size())
{
array<T, 3> tmp;
this->get(DATA_TRIPLE, tmp.data());
for (size_t i = 0; i < tmp[0].get_regs().size(); i++)
{
triple_buffer.push_back(
{ {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} });
}
}
array<T, 3> res;
for (int j = 0; j < 3; j++)
res[j].resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
{
for (int j = 0; j < 3; j++)
res[j].get_reg(i) = triple_buffer.back()[j];
triple_buffer.pop_back();
}
return res;
}
} /* namespace GC */

11
GC/TinySecret.cpp Normal file
View File

@@ -0,0 +1,11 @@
/*
* TinySecret.cpp
*
*/
#include "TinySecret.h"
namespace GC
{
} /* namespace GC */

163
GC/TinySecret.h Normal file
View File

@@ -0,0 +1,163 @@
/*
* TinySecret.h
*
*/
#ifndef GC_TINYSECRET_H_
#define GC_TINYSECRET_H_
#include "Secret.h"
#include "TinyShare.h"
#include "ShareParty.h"
#include "OT/Rectangle.h"
#include "OT/BitDiagonal.h"
template<class T> class NPartyTripleGenerator;
template<class T> class OTTripleGenerator;
template<class T> class TinyMultiplier;
namespace GC
{
template<class T> class TinyPrep;
template<class T> class TinyMC;
template<int S>
class TinySecret : public Secret<TinyShare<S>>
{
typedef TinySecret This;
public:
typedef TinyShare<S> part_type;
typedef Secret<part_type> super;
typedef typename part_type::mac_key_type mac_key_type;
typedef BitVec open_type;
typedef BitVec clear;
typedef TinyMC<This> MC;
typedef MC MAC_Check;
typedef Beaver<This> Protocol;
typedef ::Input<This> Input;
typedef TinyPrep<This> LivePrep;
typedef Memory<This> DynamicMemory;
typedef OTTripleGenerator<This> TripleGenerator;
typedef TinyMultiplier<This> Multiplier;
typedef typename part_type::sacri_type sacri_type;
typedef typename part_type::mac_type mac_type;
typedef BitDiagonal Rectangle;
static const bool dishonest_majority = true;
static const bool needs_ot = true;
static const int default_length = 64;
static string type_short()
{
return "T";
}
static DataFieldType field_type()
{
return BitVec::field_type();
}
static int size()
{
return part_type::size() * default_length;
}
static MC* new_mc(Machine<This>& machine)
{
(void) machine;
return new MC(ShareParty<This>::s().mac_key);
}
static void store_clear_in_dynamic(Memory<This>& mem,
const vector<ClearWriteAccess>& accesses)
{
auto& party = ShareThread<This>::s();
for (auto access : accesses)
mem[access.address] = constant(access.value, party.P->my_num(),
{});
}
static This constant(BitVec other, int my_num, mac_key_type alphai)
{
This res;
res.resize_regs(other.length());
for (int i = 0; i < other.length(); i++)
res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai);
return res;
}
TinySecret()
{
}
TinySecret(const super& other) :
super(other)
{
}
void assign(const char* buffer)
{
this->resize_regs(default_length);
for (int i = 0; i < default_length; i++)
this->get_reg(i).assign(buffer + i * part_type::size());
}
This operator-(const This& other) const
{
return *this + other;
}
This operator*(const BitVec& other) const
{
This res = *this;
for (int i = 0; i < super::size(); i++)
if (not other.get_bit(i))
res.get_reg(i) = {};
return res;
}
This extend_bit() const
{
This res;
res.get_regs().resize(BitVec::N_BITS, this->get_reg(0));
return res;
}
This mask(int n_bits) const
{
This res = *this;
res.get_regs().resize(n_bits);
return res;
}
void reveal(size_t n_bits, Clear& x)
{
auto& to_open = *this;
to_open.resize_regs(n_bits);
auto& party = ShareThread<This>::s();
x = party.MC->POpen(to_open, *party.P);
}
void output(ostream& s, bool human = true) const
{
assert(this->get_regs().size() == default_length);
for (auto& reg : this->get_regs())
reg.output(s, human);
}
};
template<int S>
inline TinySecret<S> operator*(const BitVec& clear, const TinySecret<S>& share)
{
return share * clear;
}
} /* namespace GC */
#endif /* GC_TINYSECRET_H_ */

11
GC/TinyShare.cpp Normal file
View File

@@ -0,0 +1,11 @@
/*
* TinyShare.cpp
*
*/
#include "TinyShare.h"
namespace GC
{
} /* namespace GC */

80
GC/TinyShare.h Normal file
View File

@@ -0,0 +1,80 @@
/*
* TinyShare.h
*
*/
#ifndef GC_TINYSHARE_H_
#define GC_TINYSHARE_H_
#include "ShareSecret.h"
#include "ShareParty.h"
#include "Secret.h"
#include "Protocols/Spdz2kShare.h"
#include "Processor/NoLivePrep.h"
namespace GC
{
template<int S> class TinySecret;
template<class T> class ShareThread;
template<int S>
class TinyShare : public Spdz2kShare<1, S>, public ShareSecret<TinySecret<S>>
{
typedef TinyShare This;
public:
typedef Spdz2kShare<1, S> super;
typedef void DynamicMemory;
typedef NoLivePrep<This> LivePrep;
typedef SwitchableOutput out_type;
static string name()
{
return "tiny share";
}
static ShareThread<TinySecret<S>>& get_party()
{
return ShareThread<TinySecret<S>>::s();
}
static This new_reg()
{
return {};
}
TinyShare()
{
}
TinyShare(const typename super::super& other) :
super(other)
{
}
void XOR(const This& a, const This& b)
{
*this = a + b;
}
void public_input(bool input)
{
auto& party = get_party();
*this = super::constant(input, party.P->my_num(),
ShareParty < TinySecret < S >> ::s().mac_key);
}
void random()
{
TinySecret<S> tmp;
get_party().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
};
} /* namespace GC */
#endif /* GC_TINYSHARE_H_ */

View File

@@ -55,7 +55,7 @@
X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \
X(SHRCI, C0 = C1 >> IMM) \
X(SHLCI, C0 = C1 << IMM) \
X(LDBITS, S0.load(R1, IMM)) \
X(LDBITS, S0.load_clear(R1, IMM)) \
X(LDMS, S0 = MSD) \
X(STMS, MSD = S0) \
X(LDMSI, S0 = MSI) \
@@ -67,7 +67,7 @@
X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \
X(STMSDI, PROC.store_dynamic_indirect(EXTRA, MD)) \
X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA, MD)) \
X(CONVSINT, S0.load(IMM, I1)) \
X(CONVSINT, S0.load_clear(IMM, I1)) \
X(CONVCINT, C0 = I1) \
X(CONVCBIT, T::convcbit(I0, C1)) \
X(MOVS, S0 = PS1) \

View File

@@ -25,6 +25,11 @@ Copyright (c) 2018, The University of Bristol, Bar-Ilan University
Please contact mks.keller@gmail.com
The same license as for SPDZ-2 applies.
___________________________________________________________________
SCALE-MAMBA [https://github.com/KULeuven-COSIC/SCALE-MAMBA]
Copyright (c) 2019, The University of Bristol, COSIC-KU Leuven
Please contact nigel.smart@kuleuven.be
See below for the full license.
___________________________________________________________________
University of Bristol : Open Access Software Licence
@@ -46,3 +51,27 @@ Any use of the software for scientific publications or commercial purposes shoul
Enquiries about further applications and development opportunities are welcome. Please contact nigel@cs.bris.ac.uk
___________________________________________________________________
This software incorporates components from the original SPDZ system, as well as various
extensions. It's copyright therefore rests with the following two institutions:
Copyright (c) 2017, The University of Bristol, Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom.
Copyright (c) 2018, COSIC-KU Leuven, Kasteelpark Arenberg 10, bus 2452, B-3001 Leuven-Heverlee, Belgium.
All rights reserved
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Any use of the software for commercial purposes should be reported to the nigel.smart@kuleuven.be
This is for impact and usage monitoring purposes only.
Enquiries about further applications and development opportunities are welcome. Please contact nigel.smart@kuleuven.be

View File

@@ -65,18 +65,6 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
"-ip", // Flag token.
"--ip-file-name" // Flag token.
);
opt.add(
"empty", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Where to obtain memory, new|old|empty (default: empty)\n\t"
"new: copy from Player-Memory-P<i> file\n\t"
"old: reuse previous memory in Memory-P<i>\n\t"
"empty: create new empty memory", // Help description.
"-m", // Flag token.
"--memory" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
@@ -143,14 +131,13 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
"--encrypted" // Flag token.
);
string memtype, hostname, ipFileName;
string hostname, ipFileName;
int lg2, pnbase, opening_sum, max_broadcast;
int my_port;
online_opts.finalize(opt, argc, argv);
opt.get("--portnumbase")->getInt(pnbase);
opt.get("--lg2")->getInt(lg2);
opt.get("--memory")->getString(memtype);
opt.get("--hostname")->getString(hostname);
opt.get("--ip-file-name")->getString(ipFileName);
opt.get("--opening-sum")->getInt(opening_sum);
@@ -192,7 +179,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
try
#endif
{
Machine<T, U>(playerno, playerNames, online_opts.progname, memtype, lg2,
Machine<T, U>(playerno, playerNames, online_opts.progname, online_opts.memtype, lg2,
opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet,
opt.get("--threads")->isSet, max_broadcast,
opt.get("--encrypted")->isSet, online_opts.live_prep,

View File

@@ -5,7 +5,6 @@
#include <OT/TripleMachine.h>
#include "OT/NPartyTripleGenerator.h"
#include "OT/OTMachine.h"
#include "OT/OTTripleSetup.h"
#include "Math/gf2n.h"
#include "Math/Setup.h"
@@ -13,9 +12,11 @@
#include "Tools/ezOptionParser.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
#include "Math/BitVec.h"
#include "Protocols/fake-stuff.hpp"
#include "Math/Z2k.hpp"
#include "OT/NPartyTripleGenerator.hpp"
#include <iostream>
#include <fstream>
@@ -23,24 +24,10 @@ using namespace std;
void* run_ngenerator_thread(void* ptr)
{
((MascotGenerator*)ptr)->generate();
((GeneratorThread*)ptr)->generate();
return 0;
}
MascotParams::MascotParams()
{
generateMACs = true;
amplify = true;
check = true;
generateBits = false;
timerclear(&start);
}
void MascotParams::set_passive()
{
generateMACs = amplify = check = false;
}
TripleMachine::TripleMachine(int argc, const char** argv) :
nConnections(1), bonding(0)
{
@@ -167,9 +154,10 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
}
template<class T>
NPartyTripleGenerator<T>* TripleMachine::new_generator(OTTripleSetup& setup, int i)
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i)
{
return new NPartyTripleGenerator<T>(setup, N[i%nConnections], i, nTriplesPerThread, nloops, *this);
return new typename T::TripleGenerator(setup, N[i % nConnections], i,
nTriplesPerThread, nloops, *this);
}
void TripleMachine::run()
@@ -186,13 +174,13 @@ void TripleMachine::run()
PlainPlayer P(N[0], 0xF000);
OTTripleSetup setup(P, true);
vector<MascotGenerator*> generators(nthreads);
vector<GeneratorThread*> generators(nthreads);
vector<pthread_t> threads(nthreads);
for (int i = 0; i < nthreads; i++)
{
if (primeField)
generators[i] = new_generator<Share<gfp1>>(setup, i);
generators[i] = new_generator<Share<gfp>>(setup, i);
else if (z2k)
{
if (z2k == 32 and z2s == 32)
@@ -270,58 +258,3 @@ void TripleMachine::output_mac_keys()
else
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s);
}
template<> gf2n_long MascotParams::get_mac_key()
{
return mac_key2l;
}
template<> gf2n_short MascotParams::get_mac_key()
{
return mac_key2s;
}
template<> gfp1 MascotParams::get_mac_key()
{
return mac_keyp;
}
template<> Z2<48> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<64> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<32> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> void MascotParams::set_mac_key(gf2n_long key)
{
mac_key2l = key;
}
template<> void MascotParams::set_mac_key(gf2n_short key)
{
mac_key2s = key;
}
template<> void MascotParams::set_mac_key(gfp1 key)
{
mac_keyp = key;
}
template<> void MascotParams::set_mac_key(Z2<64> key)
{
mac_keyz = key;
}
template<> void MascotParams::set_mac_key(Z2<48> key)
{
mac_keyz = key;
}

View File

@@ -3,10 +3,25 @@
*
*/
#include "GC/ReplicatedParty.h"
#include "GC/ShareParty.h"
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/MaliciousRepSecret.h"
#include "GC/Instruction.hpp"
#include "GC/Machine.hpp"
#include "GC/Processor.hpp"
#include "GC/Program.hpp"
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
int main(int argc, const char** argv)
{
GC::ReplicatedParty<GC::MaliciousRepSecret>(argc, argv);
GC::ShareParty<GC::MaliciousRepSecret>(argc, argv);
}

View File

@@ -3,9 +3,24 @@
*
*/
#include "GC/ReplicatedParty.h"
#include "GC/ShareParty.h"
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/Instruction.hpp"
#include "GC/Machine.hpp"
#include "GC/Processor.hpp"
#include "GC/Program.hpp"
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
int main(int argc, const char** argv)
{
GC::ReplicatedParty<GC::SemiHonestRepSecret>(argc, argv);
GC::ShareParty<GC::SemiHonestRepSecret>(argc, argv);
}

View File

@@ -0,0 +1,28 @@
/*
* semi-bin-party.cpp
*
*/
#include "GC/ShareParty.h"
#include "GC/SemiSecret.h"
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/Machine.hpp"
#include "GC/Program.hpp"
#include "GC/Instruction.hpp"
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/Processor.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/ReplicatedInput.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Input.hpp"
int main(int argc, const char** argv)
{
GC::ShareParty<GC::SemiSecret>(argc, argv);
}

View File

@@ -4,6 +4,7 @@
*/
#include "Processor/Machine.h"
#include "Processor/RingOptions.h"
#include "Protocols/Spdz2kShare.h"
#include "Math/gf2n.h"
#include "Networking/Server.h"
@@ -27,13 +28,24 @@ int main(int argc, const char** argv)
int s;
opt.get("-S")->getInt(s);
opt.resetArgs();
RingOptions ring_options(opt, argc, argv);
int k = ring_options.R;
#ifdef VERBOSE
cerr << "Using SPDZ2k with security parameter " << s << endl;
cerr << "Using SPDZ2k with ring length " << k << " and security parameter "
<< s << endl;
#endif
if (s == 64)
return spdz_main<Spdz2kShare<64, 64>, Share<gf2n>>(argc, argv, opt);
else if (s == 48)
return spdz_main<Spdz2kShare<64, 48>, Share<gf2n>>(argc, argv, opt);
#undef Z
#define Z(K, S) \
if (s == S and k == K) \
return spdz_main<Spdz2kShare<K, S>, Share<gf2n>>(argc, argv, opt);
Z(64, 64)
Z(64, 48)
Z(72, 64)
Z(72, 48)
else
throw runtime_error("not compiled for s=" + to_string(s));
throw runtime_error(
"not compiled for k=" + to_string(k) + " and s=" + to_string(s));
}

31
Machines/tiny-party.cpp Normal file
View File

@@ -0,0 +1,31 @@
/*
* tiny-party.cpp
*
*/
#include "GC/TinySecret.h"
#include "GC/ShareParty.h"
#include "GC/TinyMC.h"
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/Instruction.hpp"
#include "GC/Machine.hpp"
#include "GC/Processor.hpp"
#include "GC/Program.hpp"
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MascotPrep.hpp"
int main(int argc, const char** argv)
{
GC::ShareParty<GC::TinySecret<40>>(argc, argv, 1000);
}

View File

@@ -18,7 +18,7 @@ OT_EXE = ot.x ot-offline.x
COMMON = $(MATH) $(TOOLS) $(NETWORK)
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT)
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC) BMR/Key.o
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o
BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(OT)
VM = $(PROCESSOR) $(COMMON)
@@ -35,7 +35,7 @@ DEPS := $(wildcard */*.d)
.SECONDARY: $(OBJS)
all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x mascot-party.x
all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x
ifeq ($(USE_NTL),1)
all: overdrive she-offline cowgear-party.x
@@ -77,7 +77,7 @@ spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Off
tldr:
-echo ARCH = -march=native >> CONFIG.mine
$(MAKE) Player-Online.x
$(MAKE) mascot-party.x
ifeq ($(OS), Darwin)
tldr: mac-setup
@@ -90,7 +90,7 @@ shamir: shamir-party.x malicious-shamir-party.x galois-degree.x
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Machines/S*.cpp)) $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE)
$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Machines/S*.cpp)) $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC)
$(AR) -csr $@ $^
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
@@ -104,47 +104,17 @@ static-dir:
static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp))
Fake-Offline.x: Fake-Offline.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON)
$(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS) $(ECLIB)
Check-Offline.x: Check-Offline.o $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) -o Check-Offline.x $^ $(LDLIBS)
Check-Offline.x: $(PROCESSOR)
Check-Offline-Z2k.x: Check-Offline-Z2k.cpp $(COMMON)
$(CXX) $(CFLAGS) -o Check-Offline-Z2k.x $^ $(LDLIBS)
Server.x: Server.cpp $(COMMON)
$(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS)
Setup.x: Setup.cpp $(COMMON)
$(CXX) $(CFLAGS) Setup.cpp -o Setup.x $(COMMON) $(LDLIBS)
ot.x: $(OT) $(COMMON) OT/OText_main.cpp $(LIBSIMPLEOT)
ot.x: $(OT) $(COMMON) Machines/OText_main.o Machines/OTMachine.o $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
ot-check.x: $(OT) $(COMMON)
$(CXX) $(CFLAGS) -o ot-check.x OT/OutputCheck.cpp $(COMMON) $(LDLIBS)
ot-offline.x: $(OT) $(LIBSIMPLEOT) Machines/TripleMachine.o
ot-bitmatrix.x: $(OT) $(COMMON) OT/BitMatrixTest.cpp
$(CXX) $(CFLAGS) -o ot-bitmatrix.x OT/BitMatrixTest.cpp OT/BitMatrix.o $(COMMON) $(LDLIBS)
ot-offline.x: $(OT) $(COMMON) ot-offline.cpp $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
check-passive.x: $(COMMON) check-passive.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON)
$(CXX) $(CFLAGS) Scripts/gen_input_f2n.cpp -o gen_input_f2n.x $(COMMON) $(LDLIBS)
gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON)
$(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS)
gc-emulate.x: $(GC) $(COMMON) $(PROCESSOR) gc-emulate.cpp $(GC)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
gc-emulate.x: $(PROCESSOR) GC/FakeSecret.o GC/square64.o
bmr-%.x: $(BMR) Machines/bmr-%.cpp $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS)
@@ -155,48 +125,41 @@ bmr-%.x: $(BMR) Machines/bmr-%.cpp $(LIBSIMPLEOT)
bmr-clean:
-rm BMR/*.o BMR/*/*.o GC/*.o
client-setup.x: client-setup.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
ifeq ($(USE_NTL),1)
simple-offline.x: $(COMMON) $(FHEOFFLINE) simple-offline.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
pairwise-offline.x: $(COMMON) $(FHEOFFLINE) pairwise-offline.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
cnc-offline.x: $(COMMON) $(FHEOFFLINE) cnc-offline.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
spdz2-offline.x: $(COMMON) $(FHEOFFLINE) spdz2-offline.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
endif
simple-offline.x: $(FHEOFFLINE)
pairwise-offline.x: $(FHEOFFLINE)
cnc-offline.x: $(FHEOFFLINE)
spdz2-offline.x: $(FHEOFFLINE)
yao-party.x: $(YAO)
yao-clean:
-rm Yao/*.o
galois-degree.x: galois-degree.cpp
galois-degree.x: Utils/galois-degree.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
default-prime-length.x: default-prime-length.cpp
default-prime-length.x: Utils/default-prime-length.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
%.x: Utils/%.o $(COMMON)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
%.x: Machines/%.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
%-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) $(ECLIB)
replicated-bin-party.x: $(GC)
malicious-rep-bin-party.x: $(GC)
replicated-bin-party.x: GC/square64.o
malicious-rep-bin-party.x: GC/square64.o
semi-bin-party.x: $(VM) $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
tiny-party.x: $(OT)
shamir-party.x: Machines/ShamirMachine.o
malicious-shamir-party.x: Machines/ShamirMachine.o
spdz2k-party.x: $(OT)

View File

@@ -8,12 +8,18 @@
#include "Integer.h"
#include "field_types.h"
#include "Square.h"
class BitDiagonal;
class BitVec : public IntBase
{
public:
typedef BitVec Scalar;
typedef BitVec next;
typedef BitDiagonal Square;
static const int n_bits = sizeof(a) * 8;
static char type_char() { return 'B'; }
@@ -32,10 +38,14 @@ public:
BitVec operator/(const BitVec& other) const { (void) other; throw not_implemented(); }
BitVec& operator+=(const BitVec& other) { *this ^= other; return *this; }
BitVec& operator-=(const BitVec& other) { *this ^= other; return *this; }
BitVec extend_bit() const { return -(a & 1); }
BitVec mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; }
template<int t>
void add(octetStream& os) { *this += os.get<BitVec>(); }
void mul(const BitVec& a, const BitVec& b) { *this = a * b; }
void randomize(PRNG& G, int n = n_bits) { IntBase::randomize(G); *this = mask(n); }

View File

@@ -7,6 +7,7 @@
#define MATH_FIXEDVEC_H_
#include <string>
#include <array>
using namespace std;
#include "Tools/octetStream.h"
@@ -21,7 +22,7 @@ template<class T> class Replicated;
template <class T, int L>
class FixedVec
{
T v[L];
array<T, L> v;
public:
typedef T value_type;
@@ -71,6 +72,16 @@ public:
v[i] = other[i];
}
FixedVec<T, L>(const array<T, L>& other)
{
v = other;
}
const array<T, L>& get() const
{
return v;
}
T& operator[](int i)
{
return v[i];

View File

@@ -24,9 +24,12 @@ protected:
long a;
public:
static const int N_BYTES = sizeof(a);
static const int N_BITS = 8 * sizeof(a);
static const int MAX_N_BITS = N_BITS;
static int size() { return sizeof(a); }
static int length() { return N_BITS; }
static string type_string() { return "integer"; }
static void init_default(int lgp) { (void)lgp; }
@@ -39,10 +42,12 @@ public:
long get() const { return a; }
bool get_bit(int i) const { return (a >> i) & 1; }
char* get_ptr() const { return (char*)&a; }
unsigned long debug() const { return a; }
void assign(long x) { *this = x; }
void assign(const char* buffer) { avx_memcpy(&a, buffer, sizeof(a)); }
void assign(const void* buffer) { avx_memcpy(&a, buffer, sizeof(a)); }
void assign_zero() { a = 0; }
void assign_one() { a = 1; }
@@ -50,8 +55,20 @@ public:
bool is_one() const { return a == 1; }
bool is_bit() const { return is_zero() or is_one(); }
long operator>>(const IntBase& other) const { return a >> other.a; }
long operator<<(const IntBase& other) const { return a << other.a; }
long operator>>(const IntBase& other) const
{
if (other.a < N_BITS)
return (unsigned long) a >> other.a;
else
return 0;
}
long operator<<(const IntBase& other) const
{
if (other.a < N_BITS)
return a << other.a;
else
return 0;
}
long operator^(const IntBase& other) const { return a ^ other.a; }
long operator&(const IntBase& other) const { return a & other.a; }

View File

@@ -121,7 +121,7 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2
if (mkdir_p(ss.str().c_str()) == -1)
{
cerr << "mkdir_p(" << ss.str() << ") failed\n";
throw file_error();
throw file_error(ss.str());
}
// Output the data

View File

@@ -4,6 +4,7 @@
*/
#include "Square.h"
#include "BitVec.h"
template<>
void Square<gf2n_short>::to(gf2n_short& result)
@@ -34,3 +35,11 @@ void Square<gfp1>::to(gfp1& result)
mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L);
result.assign((void*) ans);
}
template<>
void Square<BitVec>::to(BitVec& result)
{
result = 0;
for (int i = 0; i < N_ROWS; i++)
result ^= ((rows[i] >> i) & 1) << i;
}

View File

@@ -12,6 +12,8 @@ template<class U>
class Square
{
public:
typedef U RowType;
static const int N_ROWS = U::MAX_N_BITS;
static const int N_ROWS_ALLOCATED = N_ROWS;
static const int N_COLUMNS = N_ROWS;
@@ -21,16 +23,11 @@ public:
U rows[N_ROWS];
template<class T>
Square& sub(const Square& other);
template<class T>
Square& rsub(const Square& other);
template<class T>
Square& sub(const void* other);
template <class T>
void randomize(int row, PRNG& G) { rows[row].randomize(G); }
template <class T>
void conditional_add(BitVector& conditions, Square& other,
int offset);
void to(U& result);

View File

@@ -6,7 +6,6 @@
#include "Math/Square.h"
template<class U>
template<class T>
Square<U>& Square<U>::sub(const Square<U>& other)
{
for (int i = 0; i < U::length(); i++)
@@ -15,7 +14,6 @@ Square<U>& Square<U>::sub(const Square<U>& other)
}
template<class U>
template<class T>
Square<U>& Square<U>::rsub(const Square<U>& other)
{
for (int i = 0; i < U::length(); i++)
@@ -24,7 +22,6 @@ Square<U>& Square<U>::rsub(const Square<U>& other)
}
template<class U>
template<class T>
Square<U>& Square<U>::sub(const void* other)
{
U value;
@@ -35,7 +32,6 @@ Square<U>& Square<U>::sub(const void* other)
}
template<class U>
template<class T>
void Square<U>::conditional_add(BitVector& conditions,
Square<U>& other, int offset)
{

View File

@@ -16,6 +16,8 @@ public:
static void init_default(int l) { (void) l; }
static void read_setup(int nparties, int lg2p, int gf2ndegree);
void normalize() {}
};
#endif /* MATH_VALUEINTERFACE_H_ */

View File

@@ -109,6 +109,7 @@ public:
Z2<K+L> operator*(const Z2<L>& other) const;
Z2<K> operator*(bool other) const { return other ? *this : Z2<K>(); }
Z2<K> operator*(int other) const { return *this * Z2<K>(other); }
Z2<K> operator/(const Z2& other) const { (void) other; throw not_implemented(); }

View File

@@ -352,7 +352,7 @@ void gf2n_short::input(istream& s,bool human)
if (s.peek() == EOF)
{ if (s.tellg() == 0)
{ cout << "IO problem. Empty file?" << endl;
throw file_error();
throw file_error("gf2n_short input");
}
throw end_of_file("gf2n_short");
}

View File

@@ -64,6 +64,7 @@ class gf2n_short
typedef gf2n_short Scalar;
static const int MAX_N_BITS = 64;
static const int N_BYTES = sizeof(a);
static void init_field(int nn);
static int degree() { return n; }

View File

@@ -257,7 +257,7 @@ void gf2n_long::input(istream& s,bool human)
if (s.peek() == EOF)
{ if (s.tellg() == 0)
{ cout << "IO problem. Empty file?" << endl;
throw file_error();
throw file_error("gf2n_long input");
}
throw end_of_file("gf2n_long");
}

View File

@@ -100,6 +100,7 @@ class gf2n_long
typedef ::Square<gf2n_long> Square;
const static int MAX_N_BITS = 128;
const static int N_BYTES = sizeof(a);
typedef gf2n_long Scalar;

View File

@@ -56,6 +56,7 @@ class gfp_
static const int N_LIMBS = L;
static const int MAX_N_BITS = 64 * L;
static const int N_BYTES = sizeof(a);
template<class T>
static void init(bool mont = true)

View File

@@ -256,7 +256,7 @@ void modp_<L>::input(istream& s,const Zp_Data& ZpD,bool human)
if (s.peek() == EOF)
{ if (s.tellg() == 0)
{ cout << "IO problem. Empty file?" << endl;
throw file_error();
throw file_error("modp input");
}
throw end_of_file("modp");
}

View File

@@ -6,8 +6,8 @@
#ifndef MATH_OPERATORS_H_
#define MATH_OPERATORS_H_
template <class T>
T operator*(const bool& x, const T& y) { return x ? y : T(); }
//template <class T>
//T operator*(const bool& x, const T& y) { return x ? y : T(); }
//template <class T>
//T operator*(const T& y, const bool& x) { return x ? y : T(); }
template <class T>

View File

@@ -56,10 +56,23 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
nplayers = 0;
portnum_base = pnb;
string line;
ports.clear();
while (getline(hostsfile, line))
{
if (line.length() > 0 && line.at(0) != '#') {
names.push_back(line);
auto pos = line.find(':');
if (pos == string::npos)
{
names.push_back(line);
ports.push_back(default_port(nplayers));
}
else
{
names.push_back(line.substr(0, pos));
int port;
stringstream(line.substr(pos + 1)) >> port;
ports.push_back(port);
}
nplayers++;
if (nplayers_wanted > 0 and nplayers_wanted == nplayers)
break;
@@ -67,29 +80,18 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
}
if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
throw runtime_error("not enought hosts in HOSTS");
setup_ports();
#ifdef DEBUG_NETWORKING
cerr << "Got list of " << nplayers << " players from file: " << endl;
for (unsigned int i = 0; i < names.size(); i++)
cerr << " " << names[i] << endl;
cerr << " " << names[i] << ":" << ports[i] << endl;
#endif
setup_server();
}
Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
int default_nplayers) :
Names()
int default_nplayers) : Names()
{
NetworkOptions network_opts(opt, argc, argv);
opt.add(
to_string(default_nplayers).c_str(), // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of players", // Help description.
"-N", // Flag token.
"--nparties" // Flag token.
);
NetworkOptionsWithNumber network_opts(opt, argc, argv, default_nplayers, true);
opt.add(
"", // Default.
1, // Required?
@@ -101,9 +103,7 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
);
opt.parse(argc, argv);
opt.get("-p")->getInt(player_no);
opt.get("-N")->getInt(nplayers);
global_server = Server::start_networking(*this, player_no, nplayers,
network_opts.hostname, network_opts.portnum_base);
global_server = network_opts.start_networking(*this, player_no);
}
void Names::setup_ports()
@@ -396,6 +396,9 @@ void MultiPlayer<T>::exchange_no_stats(int other, const octetStream& o, octetStr
void Player::exchange(int other, const octetStream& o, octetStream& to_receive) const
{
#ifdef VERBOSE_COMM
cerr << "Exchanging with " << other << endl;
#endif
TimeScope ts(comm_stats["Exchanging"].add(o));
exchange_no_stats(other, o, to_receive);
sent += o.get_length();
@@ -605,34 +608,34 @@ int RealTwoPartyPlayer::other_player_num() const
return other_player;
}
void RealTwoPartyPlayer::send(octetStream& o)
void RealTwoPartyPlayer::send(octetStream& o) const
{
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
o.Send(socket);
sent += o.get_length();
}
void VirtualTwoPartyPlayer::send(octetStream& o)
void VirtualTwoPartyPlayer::send(octetStream& o) const
{
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
P.send_to_no_stats(other_player, o);
sent += o.get_length();
}
void RealTwoPartyPlayer::receive(octetStream& o)
void RealTwoPartyPlayer::receive(octetStream& o) const
{
TimeScope ts(timer);
o.reset_write_head();
o.Receive(socket);
}
void VirtualTwoPartyPlayer::receive(octetStream& o)
void VirtualTwoPartyPlayer::receive(octetStream& o) const
{
TimeScope ts(timer);
P.receive_player_no_stats(other_player, o);
}
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o)
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
{
{
if (is_server)
@@ -655,7 +658,7 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const
o.exchange(socket, socket);
}
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o)
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
{
TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0]));
sent += o[0].get_length();
@@ -667,11 +670,21 @@ VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) :
{
}
void OffsetPlayer::send_receive_player(vector<octetStream>& o)
void OffsetPlayer::send_receive_player(vector<octetStream>& o) const
{
P.exchange(P.get_player(offset), o[0], o[1]);
}
void TwoPartyPlayer::Broadcast_Receive(vector<octetStream>& o,
bool donthash) const
{
(void) donthash;
vector<octetStream> os(2);
os[0] = o[my_num()];
send_receive_player(os);
o[1 - my_num()] = os[1];
}
CommStats& CommStats::operator +=(const CommStats& other)
{
data += other.data;

View File

@@ -126,9 +126,11 @@ public:
virtual ~PlayerBase();
int my_real_num() const { return player_no; }
virtual int my_num() const = 0;
virtual int num_players() const = 0;
virtual void pass_around(octetStream& o, int offset = 1) const = 0;
virtual void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const = 0;
};
class Player : public PlayerBase
@@ -276,9 +278,10 @@ public:
virtual int my_num() const = 0;
virtual int other_player_num() const = 0;
virtual void send(octetStream& o) = 0;
virtual void receive(octetStream& o) = 0;
virtual void send_receive_player(vector<octetStream>& o) = 0;
virtual void send(octetStream& o) const = 0;
virtual void receive(octetStream& o) const = 0;
virtual void send_receive_player(vector<octetStream>& o) const = 0;
void Broadcast_Receive(vector<octetStream>& o, bool donthash=false) const;
};
class RealTwoPartyPlayer : public TwoPartyPlayer
@@ -295,8 +298,8 @@ public:
RealTwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
~RealTwoPartyPlayer();
void send(octetStream& o);
void receive(octetStream& o);
void send(octetStream& o) const;
void receive(octetStream& o) const;
int other_player_num() const;
int my_num() const { return is_server; }
@@ -305,7 +308,7 @@ public:
/* Send and receive to/from the other player
* - o[0] contains my data, received data put in o[1]
*/
void send_receive_player(vector<octetStream>& o);
void send_receive_player(vector<octetStream>& o) const;
void exchange(octetStream& o) const;
void exchange(int other, octetStream& o) const { (void)other; exchange(o); }
@@ -326,9 +329,9 @@ public:
int other_player_num() const { return other_player; }
int num_players() const { return 2; }
void send(octetStream& o);
void receive(octetStream& o);
void send_receive_player(vector<octetStream>& o);
void send(octetStream& o) const;
void receive(octetStream& o) const;
void send_receive_player(vector<octetStream>& o) const;
void pass_around(octetStream& o, int _ = 1) const { (void)_, (void) o; throw not_implemented(); }
};
@@ -349,16 +352,15 @@ public:
int num_players() const { return 2; }
int get_offset() const { return offset; }
void send(octetStream& o) { P.send_to(P.get_player(offset), o, true); }
void reverse_send(octetStream& o) { P.send_to(P.get_player(-offset), o, true); }
void receive(octetStream& o) { P.receive_player(P.get_player(offset), o, true); }
void send(octetStream& o) const { P.send_to(P.get_player(offset), o, true); }
void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o, true); }
void receive(octetStream& o) const { P.receive_player(P.get_player(offset), o, true); }
void reverse_receive(octetStream& o) { P.receive_player(P.get_player(-offset), o, true); }
void send_receive_player(vector<octetStream>& o);
void send_receive_player(vector<octetStream>& o) const;
void reverse_exchange(octetStream& o) const { P.pass_around(o, P.num_players() - offset); }
void exchange(int other, octetStream& o) const { (void)other; P.pass_around(o, offset); }
void exchange(octetStream& o) const { P.exchange(P.get_player(offset), o); }
void pass_around(octetStream& o, int _ = 1) const { (void)_; P.pass_around(o, offset); }
void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const;
};
#endif

View File

@@ -106,6 +106,8 @@ void BaseOT::exec_base(bool new_receiver_inputs)
receiver_maketable(&receiver);
}
os[0].reset_write_head();
for (i = 0; i < nOT; i += 4)
{
if (ot_role & RECEIVER)
@@ -117,12 +119,24 @@ void BaseOT::exec_base(bool new_receiver_inputs)
cs[j] = receiver_inputs[i + j];
}
receiver_rsgen(&receiver, Rs_pack[0], cs);
os[0].reset_write_head();
os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0]));
receiver_keygen(&receiver, receiver_keys);
// Copy keys to receiver_outputs
for (j = 0; j < 4; j++)
{
for (k = 0; k < AES_BLK_SIZE; k++)
{
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
}
}
}
send_if_ot_receiver(P, os, ot_role);
}
send_if_ot_receiver(P, os, ot_role);
for (i = 0; i < nOT; i += 4)
{
if (ot_role & SENDER)
{
os[1].get_bytes((octet*) Rs_pack[1], len);
@@ -143,18 +157,6 @@ void BaseOT::exec_base(bool new_receiver_inputs)
}
}
}
if (ot_role & RECEIVER)
{
// Copy keys to receiver_outputs
for (j = 0; j < 4; j++)
{
for (k = 0; k < AES_BLK_SIZE; k++)
{
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
}
}
}
#ifdef BASE_OT_DEBUG
for (j = 0; j < 4; j++)
{

19
OT/BitDiagonal.cpp Normal file
View File

@@ -0,0 +1,19 @@
/*
* Diagonal.cpp
*
*/
#include <OT/BitDiagonal.h>
void BitDiagonal::pack(octetStream& os) const
{
for (int i = 0; i < N_ROWS; i++)
os.store_int(rows[i].get_bit(i), 1);
}
void BitDiagonal::unpack(octetStream& os)
{
*this = {};
for (int i = 0; i < N_ROWS; i++)
rows[i] = os.get_int(1) << i;
}

24
OT/BitDiagonal.h Normal file
View File

@@ -0,0 +1,24 @@
/*
* Diagonal.h
*
*/
#ifndef OT_BITDIAGONAL_H_
#define OT_BITDIAGONAL_H_
#include "Math/Square.h"
#include "Math/BitVec.h"
class BitDiagonal : public Square<BitVec>
{
public:
static int size()
{
return 8 * BitVec::size();
}
void pack(octetStream& os) const;
void unpack(octetStream& os);
};
#endif /* OT_BITDIAGONAL_H_ */

View File

@@ -9,9 +9,12 @@
#include "BitMatrix.h"
#include "Rectangle.h"
#include "BitDiagonal.h"
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Z2k.h"
#include "Math/BitVec.h"
#include "GC/TinySecret.h"
#include "OT/Rectangle.hpp"
#include "Math/Z2k.hpp"
@@ -268,25 +271,22 @@ void square128::randomize(PRNG& G)
G.get_octets((octet*)&rows, sizeof(rows));
}
template <>
void square128::randomize<gf2n_long>(int row, PRNG& G)
void square128::randomize(int row, PRNG& G)
{
rows[row] = G.get_doubleword();
}
template<>
void square128::conditional_add<gf2n_long>(BitVector& conditions, square128& other, int offset)
void square128::conditional_add(BitVector& conditions, square128& other, int offset)
{
for (int i = 0; i < 128; i++)
if (conditions.get_bit(128 * offset + i))
rows[i] ^= other.rows[i];
}
template <class T>
void square128::hash_row_wise(MMO& mmo, square128& input)
{
mmo.hashBlockWise<T,128>((octet*)rows, (octet*)input.rows);
mmo.hashBlockWise<gf2n_long,128>((octet*)rows, (octet*)input.rows);
}
template <>
@@ -395,20 +395,17 @@ square128& square128::operator^=(square128& other)
return *this;
}
template<>
square128& square128::add<gf2n_long>(square128& other)
square128& square128::add(square128& other)
{
return *this ^= other;
}
template<>
square128& square128::sub<gf2n_long>(square128& other)
square128& square128::sub(square128& other)
{
return *this ^= other;
}
template<>
square128& square128::rsub<gf2n_long>(square128& other)
square128& square128::rsub(square128& other)
{
return *this ^= other;
}
@@ -421,8 +418,7 @@ square128& square128::operator^=(const __m128i* other)
return *this;
}
template <>
square128& square128::sub<gf2n_long>(const __m128i* other)
square128& square128::sub(const __m128i* other)
{
return *this ^= other;
}
@@ -500,7 +496,7 @@ template <class U>
void Matrix<U>::randomize(int row, PRNG& G)
{
for (size_t i = 0; i < squares.size(); i++)
squares[i].template randomize<gf2n_long>(row, G);
squares[i].randomize(row, G);
}
void BitMatrix::transpose()
@@ -597,44 +593,40 @@ Slice<U>::Slice(U& bm, size_t start, size_t size) :
}
template <class U>
template <class T>
Slice<U>& Slice<U>::rsub(Slice<U>& other)
{
if (bm.squares.size() < other.end)
throw invalid_length();
for (size_t i = other.start; i < other.end; i++)
bm.squares[i].template rsub<T>(other.bm.squares[i]);
bm.squares[i].rsub(other.bm.squares[i]);
return *this;
}
template <class U>
template <class T>
Slice<U>& Slice<U>::sub(BitVector& other, int repeat)
{
if (end * U::PartType::N_COLUMNS > other.size() * repeat)
throw invalid_length(to_string(U::PartType::N_COLUMNS));
for (size_t i = start; i < end; i++)
{
bm.squares[i].template sub<T>(other.get_ptr_to_byte(i / repeat,
bm.squares[i].sub(other.get_ptr_to_byte(i / repeat,
U::PartType::N_ROW_BYTES));
}
return *this;
}
template <class U>
template <class T>
void Slice<U>::randomize(int row, PRNG& G)
{
for (size_t i = start; i < end; i++)
bm.squares[i].template randomize<T>(row, G);
bm.squares[i].randomize(row, G);
}
template <class U>
template <class T>
void Slice<U>::conditional_add(BitVector& conditions, U& other, bool useOffset)
{
for (size_t i = start; i < end; i++)
bm.squares[i].template conditional_add<T>(conditions, other.squares[i], useOffset * i);
bm.squares[i].conditional_add(conditions, other.squares[i], useOffset * i);
}
template <>
@@ -651,7 +643,7 @@ void Slice<U>::print()
cout << "hex / value" << endl;
for (int i = 0; i < 16; i++)
{
cout << int128(bm.squares[0].rows[i]) << " " << T(bm.squares[0].rows[i]) << endl;
cout << T(bm.squares[0].rows[i]) << endl;
}
cout << endl;
}
@@ -671,24 +663,13 @@ void Slice<U>::unpack(octetStream& os)
bm.squares[i].unpack(os);
}
#define M(N,L) Matrix<Rectangle< Z2<N>, Z2<L> > >
#undef XXX
#define XXX(T,N,L) \
template class Matrix<Rectangle< Z2<N>, Z2<L> > >; \
template class Slice<Matrix<Rectangle< Z2<N>, Z2<L> > > >; \
template Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >& Slice< \
Matrix<Rectangle<Z2<N>, Z2<L> > > >::rsub<T>( \
Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >& other); \
template Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >& Slice< \
Matrix<Rectangle<Z2<N>, Z2<L> > > >::sub<T>(BitVector& other, int repeat); \
template void Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >::conditional_add< \
T>(BitVector& conditions, \
Matrix<Rectangle<Z2<N>, Z2<L> > >& other, bool useOffset); \
#undef X
#define X(N,L) \
template void Slice<Matrix<Rectangle< Z2<N>, Z2<L> > > >::randomize<Z2<L> >(int row, PRNG& G); \
XXX(Z2<L>, N, L)
//X(96, 160)
@@ -700,6 +681,11 @@ Y(64, 48)
Y(66, 64)
Y(66, 48)
Y(32, 32)
Y(1, 40)
Y(72, 48)
Y(74, 48)
Y(72, 64)
Y(74, 64)
template class Matrix<square128>;
@@ -710,19 +696,15 @@ template class Slice<BM>; \
XX(BM, gf2n_long)
#define XX(BM, GF) \
template void Slice<BM >::conditional_add<GF>(BitVector& conditions, BM& other, bool useOffset); \
template Slice<BM >& Slice<BM >::rsub<GF>(Slice<BM >& other); \
template Slice<BM >& Slice<BM >::sub<GF>(BitVector& other, int repeat); \
template void Slice<BM >::randomize<GF>(int row, PRNG& G); \
//template void Slice<BM >::print<GF>();
BMS
template class Slice<Matrix<gf2n_short_square>>;
XX(Matrix<gf2n_short_square>, gf2n_short)
#define XXXX(BM, GF) \
template class Slice<BM>; \
XX(BM, GF)
template class Slice<Matrix<Square<gf2n_long>>>;
XX(Matrix<Square<gf2n_long>>, gf2n_long)
template class Slice<Matrix<Square<gfp1>>>;
XX(Matrix<Square<gfp1>>, gfp1)
XXXX(Matrix<gf2n_short_square>, gf2n_short)
XXXX(Matrix<Square<gf2n_long>>, gf2n_long)
XXXX(Matrix<Square<gfp1>>, gfp1)
XXXX(Matrix<BitDiagonal>, BitVec)

View File

@@ -19,7 +19,7 @@
using namespace std;
union square128 {
typedef int128 RowType;
typedef gf2n_long RowType;
const static int N_ROWS = 128;
const static int N_ROWS_ALLOCATED = 128;
@@ -46,24 +46,16 @@ union square128 {
square128& operator^=(BitVector& other);
bool operator==(square128& other);
template <class T>
square128& add(square128& other);
template <class T>
square128& sub(square128& other);
template <class T>
square128& rsub(square128& other);
template <class T>
square128& sub(const __m128i* other);
template <class T>
square128& sub(const void* other) { return sub<T>((__m128i*)other); }
square128& sub(const void* other) { return sub((__m128i*)other); }
void randomize(PRNG& G);
template <class T>
void randomize(int row, PRNG& G);
template <class T>
void conditional_add(BitVector& conditions, square128& other, int offset);
void transpose();
template <class T>
void hash_row_wise(MMO& mmo, square128& input);
template <class T>
void to(T& result);
@@ -173,14 +165,10 @@ class Slice
public:
Slice(U& bm, size_t start, size_t size);
template <class T>
Slice<U>& rsub(Slice<U>& other);
template <class T>
Slice<U>& sub(BitVector& other, int repeat = 1);
template <class T>
void randomize(int row, PRNG& G);
template <class T>
void conditional_add(BitVector& conditions, U& other, bool useOffset = false);
void transpose();

106
OT/MascotParams.cpp Normal file
View File

@@ -0,0 +1,106 @@
/*
* TripleMachine.cpp
*
*/
#include <OT/TripleMachine.h>
#include "OT/NPartyTripleGenerator.h"
#include "OT/OTTripleSetup.h"
#include "Math/gf2n.h"
#include "Math/Setup.h"
#include "Protocols/Spdz2kShare.h"
#include "Tools/ezOptionParser.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
#include "Math/BitVec.h"
#include "Protocols/fake-stuff.hpp"
#include "Math/Z2k.hpp"
#include <iostream>
#include <fstream>
using namespace std;
MascotParams::MascotParams()
{
generateMACs = true;
amplify = true;
check = true;
generateBits = false;
timerclear(&start);
}
void MascotParams::set_passive()
{
generateMACs = amplify = check = false;
}
template<> gf2n_long MascotParams::get_mac_key()
{
return mac_key2l;
}
template<> gf2n_short MascotParams::get_mac_key()
{
return mac_key2s;
}
template<> gfp1 MascotParams::get_mac_key()
{
return mac_keyp;
}
template<> Z2<48> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<64> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<40> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<32> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> BitVec MascotParams::get_mac_key()
{
return 0;
}
template<> void MascotParams::set_mac_key(gf2n_long key)
{
mac_key2l = key;
}
template<> void MascotParams::set_mac_key(gf2n_short key)
{
mac_key2s = key;
}
template<> void MascotParams::set_mac_key(gfp1 key)
{
mac_keyp = key;
}
template<> void MascotParams::set_mac_key(Z2<64> key)
{
mac_keyz = key;
}
template<> void MascotParams::set_mac_key(Z2<48> key)
{
mac_keyz = key;
}
template<> void MascotParams::set_mac_key(Z2<40> key)
{
mac_keyz = key;
}

View File

@@ -24,7 +24,7 @@ class PlainTriple;
template <class T, int N>
using ShareTriple = ShareTriple_<T, T, N>;
class MascotGenerator
class GeneratorThread
{
protected:
pthread_mutex_t mutex;
@@ -37,8 +37,8 @@ public:
bool multi_threaded;
MascotGenerator() : nTriples(0), multi_threaded(true) {}
virtual ~MascotGenerator() {};
GeneratorThread() : nTriples(0), multi_threaded(true) {}
virtual ~GeneratorThread() {};
virtual void generate() = 0;
void lock();
@@ -48,7 +48,7 @@ public:
};
template<class T>
class OTTripleGenerator : public MascotGenerator
class OTTripleGenerator : public GeneratorThread
{
typedef typename T::open_type open_type;
typedef typename T::mac_key_type mac_key_type;
@@ -79,7 +79,7 @@ protected:
public:
// TwoPartyPlayer's for OTs, n-party Player for sacrificing
vector<TwoPartyPlayer*> players;
vector<OTMultiplierMac<sacri_type, open_type>*> ot_multipliers;
vector<typename T::Multiplier*> ot_multipliers;
//vector<OTMachine*> machines;
BitVector baseReceiverInput; // same for every set of OTs
vector< vector< vector<BitVector> > > baseSenderInputs;
@@ -111,6 +111,8 @@ public:
void generatePlainTriples();
void plainTripleRound(int k = 0);
void run_multipliers(MultJob job);
size_t data_sent();
};
@@ -121,8 +123,28 @@ class NPartyTripleGenerator : public OTTripleGenerator<T>
typedef typename T::mac_key_type mac_key_type;
typedef typename T::sacri_type sacri_type;
template <int K, int S>
void generateTriplesZ2k();
virtual void generateTriples() = 0;
virtual void generateBits() = 0;
public:
vector< ShareTriple_<sacri_type, mac_key_type, 2> > uncheckedTriples;
vector<InputTuple<Share<sacri_type>>> inputs;
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
Player* parentPlayer = 0);
virtual ~NPartyTripleGenerator() {}
void generate();
void generateInputs(int player);
};
template<class T>
class MascotTripleGenerator : public NPartyTripleGenerator<T>
{
typedef typename T::open_type open_type;
typedef typename T::mac_key_type mac_key_type;
typedef typename T::sacri_type sacri_type;
void generateTriples();
void generateBits();
@@ -132,21 +154,37 @@ class NPartyTripleGenerator : public OTTripleGenerator<T>
void sacrifice(vector<ShareTriple_<open_type, mac_key_type, 2> >& uncheckedTriples,
typename T::MAC_Check& MC, PRNG& G);
public:
vector<T> bits;
MascotTripleGenerator(OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
Player* parentPlayer = 0);
};
template<class T>
class Spdz2kTripleGenerator : public NPartyTripleGenerator<T>
{
typedef typename T::open_type open_type;
typedef typename T::mac_key_type mac_key_type;
typedef typename T::sacri_type sacri_type;
void generateBits() { throw not_implemented(); }
template<class U>
void sacrificeZ2k(vector<ShareTriple_<sacri_type, mac_key_type, 2> >& uncheckedTriples,
void sacrificeZ2k(
vector<
ShareTriple_<typename T::sacri_type,
typename T::mac_key_type, 2> >& uncheckedTriples,
U& MC, PRNG& G);
public:
vector< ShareTriple_<sacri_type, mac_key_type, 2> > uncheckedTriples;
vector<T> bits;
vector<InputTuple<Share<sacri_type>>> inputs;
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names,
Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
Player* parentPlayer = 0);
void generate();
void generateInputs(int player);
void generateTriples();
};
template<class T>

View File

@@ -1,3 +1,6 @@
#ifndef OT_NPARTYTRIPLGENERATOR_HPP_
#define OT_NPARTYTRIPLGENERATOR_HPP_
#include "NPartyTripleGenerator.h"
#include "OT/OTExtensionWithMatrix.h"
@@ -11,9 +14,11 @@
#include "Tools/Subroutines.h"
#include "Protocols/MAC_Check.h"
#include "Protocols/Spdz2kPrep.h"
#include "GC/SemiSecret.h"
#include "OT/Triple.hpp"
#include "OT/Rectangle.hpp"
#include "OT/OTMultiplier.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.h"
#include "Protocols/MascotPrep.hpp"
@@ -46,6 +51,24 @@ NPartyTripleGenerator<T>::NPartyTripleGenerator(OTTripleSetup& setup,
{
}
template<class T>
MascotTripleGenerator<T>::MascotTripleGenerator(OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, Player* parentPlayer) :
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
machine, parentPlayer)
{
}
template<class T>
Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, Player* parentPlayer) :
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
machine, parentPlayer)
{
}
template<class T>
OTTripleGenerator<T>::OTTripleGenerator(OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
@@ -174,9 +197,9 @@ void NPartyTripleGenerator<T>::generate()
timers["Generator thread"].stop();
if (machine.output)
cout << "Written " << nTriples << " " << T::type_string() << " outputs to " << ss.str() << endl;
#ifdef VERBOSE
#ifdef VERBOSE_OT
else
cout << "Generated " << nTriples << " " << T::type_string() << " outputs" << endl;
cerr << "Generated " << nTriples << " " << T::type_string() << " outputs" << endl;
#endif
}
@@ -251,7 +274,7 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
}
template<>
void NPartyTripleGenerator<Share<gf2n>>::generateBits()
void MascotTripleGenerator<Share<gf2n>>::generateBits()
{
for (int i = 0; i < nparties-1; i++)
ot_multipliers[i]->inbox.push(DATA_BIT);
@@ -288,9 +311,9 @@ void NPartyTripleGenerator<Share<gf2n>>::generateBits()
gf2n r;
for (int j = 0; j < nBitsToCheck; j++)
{
gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key<gf2n>();
gf2n mac_sum = valueBits[0].get_bit(j) ? machine.get_mac_key<gf2n>() : 0;
for (int i = 0; i < nparties-1; i++)
mac_sum += ((MascotMultiplier<gf2n>*)ot_multipliers[i])->macs[0][j];
mac_sum += ot_multipliers[i]->macs[0][j];
bits[j].set_share(valueBits[0].get_bit(j));
bits[j].set_mac(mac_sum);
r.randomize(G);
@@ -310,7 +333,7 @@ void NPartyTripleGenerator<Share<gf2n>>::generateBits()
}
template<>
void NPartyTripleGenerator<Share<gfp1>>::generateBits()
void MascotTripleGenerator<Share<gfp1>>::generateBits()
{
generateTriples();
}
@@ -322,9 +345,12 @@ void NPartyTripleGenerator<T>::generateBits()
}
template<class T>
template<int K, int S>
void NPartyTripleGenerator<T>::generateTriplesZ2k()
void Spdz2kTripleGenerator<T>::generateTriples()
{
const int K = T::k;
const int S = T::s;
auto& uncheckedTriples = this->uncheckedTriples;
auto& timers = this->timers;
auto& machine = this->machine;
auto& nTriplesPerLoop = this->nTriplesPerLoop;
@@ -386,7 +412,7 @@ void NPartyTripleGenerator<T>::generateTriplesZ2k()
timers["Triple computation"].start();
for (int i = 0; i < nparties-1; i++)
{
c += ((Spdz2kMultiplier<K, S>*)ot_multipliers[i])->c_output[j];
c += ot_multipliers[i]->c_output[j];
}
#ifdef DEBUG_SPDZ2K
@@ -433,36 +459,6 @@ void NPartyTripleGenerator<T>::generateTriplesZ2k()
}
}
template<>
void NPartyTripleGenerator<Spdz2kShare<32, 32>>::generateTriples()
{
this->generateTriplesZ2k<32, 32>();
}
template<>
void NPartyTripleGenerator<Spdz2kShare<64, 64>>::generateTriples()
{
this->generateTriplesZ2k<64, 64>();
}
template<>
void NPartyTripleGenerator<Spdz2kShare<64, 48>>::generateTriples()
{
this->generateTriplesZ2k<64, 48>();
}
template<>
void NPartyTripleGenerator<Spdz2kShare<66, 64>>::generateTriples()
{
this->generateTriplesZ2k<66, 64>();
}
template<>
void NPartyTripleGenerator<Spdz2kShare<66, 48>>::generateTriples()
{
this->generateTriplesZ2k<66, 48>();
}
template<class U>
void OTTripleGenerator<U>::generatePlainTriples()
{
@@ -500,13 +496,15 @@ void OTTripleGenerator<U>::plainTripleRound(int k)
for (int j = 0; j < nPreampTriplesPerLoop; j++)
{
T a((char*)valueBits[0].get_ptr() + j * T::size());
T b((char*)valueBits[1].get_ptr() + j / nAmplify * T::size());
T a;
a.assign((char*)valueBits[0].get_ptr() + j * T::size());
T b;
b.assign((char*)valueBits[1].get_ptr() + j / nAmplify * T::size());
T c = a * b;
timers["Triple computation"].start();
for (int i = 0; i < nparties-1; i++)
{
c += dynamic_cast<typename U::Multiplier*>(ot_multipliers[i])->c_output[j];
c += ot_multipliers[i]->c_output[j];
}
timers["Triple computation"].stop();
if (machine.amplify)
@@ -531,7 +529,7 @@ void OTTripleGenerator<U>::plainTripleRound(int k)
}
template<class U>
void NPartyTripleGenerator<U>::generateTriples()
void MascotTripleGenerator<U>::generateTriples()
{
typedef typename U::open_type T;
@@ -547,6 +545,7 @@ void NPartyTripleGenerator<U>::generateTriples()
auto& outputFile = this->outputFile;
auto& field_size = this->field_size;
auto& nPreampTriplesPerLoop = this->nPreampTriplesPerLoop;
auto& uncheckedTriples = this->uncheckedTriples;
for (int i = 0; i < nparties-1; i++)
ot_multipliers[i]->inbox.push(DATA_TRIPLE);
@@ -626,7 +625,7 @@ void NPartyTripleGenerator<U>::generateTriples()
}
template<class T>
void NPartyTripleGenerator<T>::sacrifice(
void MascotTripleGenerator<T>::sacrifice(
vector<ShareTriple_<open_type, mac_key_type, 2> >& uncheckedTriples, typename T::MAC_Check& MC, PRNG& G)
{
auto& machine = this->machine;
@@ -663,7 +662,7 @@ void NPartyTripleGenerator<T>::sacrifice(
template<class W>
template<class U>
void NPartyTripleGenerator<W>::sacrificeZ2k(
void Spdz2kTripleGenerator<W>::sacrificeZ2k(
vector<ShareTriple_<sacri_type, mac_key_type, 2> >& uncheckedTriples, U& MC, PRNG& G)
{
typedef sacri_type T;
@@ -707,7 +706,7 @@ void NPartyTripleGenerator<W>::sacrificeZ2k(
}
if (machine.generateBits)
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
throw not_implemented();
else
if (machine.output)
for (int j = 0; j < nTriplesPerLoop; j++)
@@ -716,7 +715,7 @@ void NPartyTripleGenerator<W>::sacrificeZ2k(
template<>
template<class U, class V, class W, int N>
void NPartyTripleGenerator<Share<gfp1>>::generateBitsFromTriples(
void MascotTripleGenerator<Share<gfp1>>::generateBitsFromTriples(
vector< ShareTriple_<U, V, N> >& triples, W& MC, ofstream& outputFile)
{
vector< Share<gfp1> > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop);
@@ -746,7 +745,7 @@ void NPartyTripleGenerator<Share<gfp1>>::generateBitsFromTriples(
template<class T>
template<class U, class V, class W, int N>
void NPartyTripleGenerator<T>::generateBitsFromTriples(
void MascotTripleGenerator<T>::generateBitsFromTriples(
vector< ShareTriple_<U, V, N> >& triples, W& MC, ofstream& outputFile)
{
throw how_would_that_work();
@@ -786,22 +785,22 @@ void OTTripleGenerator<T>::print_progress(int k)
}
}
void MascotGenerator::lock()
void GeneratorThread::lock()
{
pthread_mutex_lock(&mutex);
}
void MascotGenerator::unlock()
void GeneratorThread::unlock()
{
pthread_mutex_unlock(&mutex);
}
void MascotGenerator::signal()
void GeneratorThread::signal()
{
pthread_cond_signal(&ready);
}
void MascotGenerator::wait()
void GeneratorThread::wait()
{
if (multi_threaded)
pthread_cond_wait(&ready, &mutex);
@@ -821,18 +820,11 @@ void OTTripleGenerator<T>::wait_for_multipliers()
ot_multipliers[i]->outbox.pop();
}
template<class T>
void OTTripleGenerator<T>::run_multipliers(MultJob job)
{
signal_multipliers(job);
wait_for_multipliers();
}
template class NPartyTripleGenerator<Share<gf2n_long>>;
template class NPartyTripleGenerator<Share<gf2n_short>>;
template class NPartyTripleGenerator<Share<gfp1>>;
template class OTTripleGenerator<SemiShare<gf2n>>;
template class OTTripleGenerator<SemiShare<gfp1>>;
template class OTTripleGenerator<Semi2kShare<64>>;
template class OTTripleGenerator<Semi2kShare<72>>;
template class NPartyTripleGenerator<Spdz2kShare<32, 32>>;
template class NPartyTripleGenerator<Spdz2kShare<64, 64>>;
template class NPartyTripleGenerator<Spdz2kShare<64, 48>>;
template class NPartyTripleGenerator<Spdz2kShare<66, 64>>;
template class NPartyTripleGenerator<Spdz2kShare<66, 48>>;
#endif

View File

@@ -8,6 +8,8 @@
#include "Math/gfp.h"
#include "Math/Z2k.h"
#include "Math/gf2nlong.h"
#include "Math/BitVec.h"
#include "GC/TinySecret.h"
#include "OT/Rectangle.hpp"
#include "Math/Z2k.hpp"
@@ -71,7 +73,7 @@ void OTExtensionWithMatrix::transfer(int nOTs,
for (int loop = 0; loop < nloops; loop++)
{
extend<gf2n_long>(nOTs, newReceiverInput);
extend(nOTs, newReceiverInput);
#ifdef OTEXT_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
@@ -97,24 +99,25 @@ void OTCorrelator<U>::resize(int nOTs)
}
// the template is used to denote the field of the hash output
template <class T>
void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput)
{
extend_correlated(nOTs_requested, newReceiverInput);
hash_outputs<T>(nOTs_requested);
hash_outputs(nOTs_requested);
}
void OTExtensionWithMatrix::extend_correlated(BitVector& newReceiverInput)
void OTExtensionWithMatrix::extend_correlated(const BitVector& newReceiverInput)
{
extend_correlated(newReceiverInput.size(), newReceiverInput);
}
void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& newReceiverInput)
void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, const BitVector& newReceiverBits)
{
// if (nOTs % nbaseOTs != 0)
// throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
if (nOTs_requested == 0)
return;
// local copy
auto newReceiverInput = newReceiverBits;
if ((ot_role & RECEIVER) and (size_t)nOTs_requested != newReceiverInput.size())
throw runtime_error("wrong number of choice bits");
int nOTs_requested_rounded = (nOTs_requested + 127) / 128 * 128;
@@ -133,8 +136,8 @@ void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& new
// subloop for first part to interleave communication with computation
for (int start = 0; start < nOTs / 128; start += slice)
{
expand<gf2n_long>(start, slice);
this->correlate<gf2n_long>(start, slice, newReceiverInput, true);
expand(start, slice);
this->correlate(start, slice, newReceiverInput, true);
transpose(start, slice);
}
@@ -164,7 +167,6 @@ void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& new
}
template <class U>
template <class T>
void OTCorrelator<U>::expand(int start, int slice)
{
(void)start, (void)slice;
@@ -180,8 +182,8 @@ void OTCorrelator<U>::expand(int start, int slice)
{
for (int i = 0; i < nbaseOTs; i++)
{
receiverOutputSlice.template randomize<T>(i, G_sender[i][0]);
t1Slice.template randomize<T>(i, G_sender[i][1]);
receiverOutputSlice.randomize(i, G_sender[i][0]);
t1Slice.randomize(i, G_sender[i][1]);
}
}
@@ -189,23 +191,22 @@ void OTCorrelator<U>::expand(int start, int slice)
{
for (int i = 0; i < nbaseOTs; i++)
// randomize base receiver output
senderOutputSlices[0].template randomize<T>(i, G_receiver[i]);
senderOutputSlices[0].randomize(i, G_receiver[i]);
}
}
template <class T>
void OTExtensionWithMatrix::expand_transposed()
{
for (int i = 0; i < nbaseOTs; i++)
{
if (ot_role & RECEIVER)
{
receiverOutputMatrix.squares[i/128].randomize<T>(i % 128, G_sender[i][0]);
t1.squares[i/128].randomize<T>(i % 128, G_sender[i][1]);
receiverOutputMatrix.squares[i/128].randomize(i % 128, G_sender[i][0]);
t1.squares[i/128].randomize(i % 128, G_sender[i][1]);
}
if (ot_role & SENDER)
{
senderOutputMatrices[0].squares[i/128].randomize<T>(i % 128, G_receiver[i]);
senderOutputMatrices[0].squares[i/128].randomize(i % 128, G_receiver[i]);
}
}
}
@@ -224,7 +225,6 @@ void OTCorrelator<U>::setup_for_correlation(BitVector& baseReceiverInput,
}
template <class U>
template <class T>
void OTCorrelator<U>::correlate(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat)
{
@@ -240,8 +240,8 @@ void OTCorrelator<U>::correlate(int start, int slice,
// create correlation
if (ot_role & RECEIVER)
{
t1Slice.template rsub<T>(receiverOutputSlice);
t1Slice.template sub<T>(newReceiverInput, repeat);
t1Slice.rsub(receiverOutputSlice);
t1Slice.sub(newReceiverInput, repeat);
t1Slice.pack(os[0]);
// t1 = receiverOutputMatrix;
@@ -260,7 +260,7 @@ void OTCorrelator<U>::correlate(int start, int slice,
{
// u = t0 + t1 + x
uSlice.unpack(os[1]);
senderOutputSlices[0].template conditional_add<T>(baseReceiverInput, u, !useConstantBase);
senderOutputSlices[0].conditional_add(baseReceiverInput, u, !useConstantBase);
}
#ifdef OTEXT_TIMER
gettimeofday(&commst2, NULL);
@@ -302,13 +302,12 @@ void OTExtensionWithMatrix::transpose(int start, int slice)
/*
* Hash outputs to make into random OT
*/
template <class T>
void OTExtensionWithMatrix::hash_outputs(int nOTs)
{
hash_outputs<T>(nOTs, senderOutputMatrices, receiverOutputMatrix);
hash_outputs(nOTs, senderOutputMatrices, receiverOutputMatrix);
}
template <class T, class V>
template <class V>
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput)
{
//cout << "Hashing... " << flush;
@@ -319,6 +318,7 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
gettimeofday(&startv, NULL);
#endif
typedef typename V::PartType::RowType T;
int n_rows = V::PartType::N_ROWS_ALLOCATED;
int n = (nOTs + n_rows - 1) / n_rows * V::PartType::N_ROWS;
@@ -326,11 +326,6 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
senderOutput[i].resize_vertical(n);
receiverOutput.resize_vertical(n);
if (V::PartType::N_ROW_BYTES != T::size())
throw runtime_error(
"length mismatch for MMO hash: "
+ to_string(V::PartType::N_ROW_BYTES) + " != "
+ to_string(T::size()));
if (nOTs % 8 != 0)
throw runtime_error("number of OTs must be divisible by 8");
@@ -378,7 +373,7 @@ void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output)
output.resize(nTriples);
for (unsigned int j = 0; j < nTriples; j++)
{
receiverOutputMatrix.squares[j].template sub<T>(senderOutputMatrices[0].squares[j]).to(output[j]);
receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]);
}
}
@@ -516,56 +511,36 @@ void OTExtensionWithMatrix::print_pre_expand()
}
template class OTCorrelator<BitMatrix>;
template void OTCorrelator<BitMatrix>::correlate<gf2n_long>(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat);
#define Z(BM,GF) \
template void OTCorrelator<BM>::correlate<GF>(int start, int slice, \
BitVector& newReceiverInput, bool useConstantBase, int repeat); \
template void OTCorrelator<BM>::expand<GF>(int start, int slice); \
template class OTCorrelator<BM>; \
template void OTCorrelator<BM>::reduce_squares<GF>(unsigned int nTriples, \
vector<GF>& output);
template class OTCorrelator<Matrix<gf2n_short_square>>;
Z(Matrix<gf2n_short_square>, gf2n_short)
template class OTCorrelator<Matrix<Square<gf2n_long>>>;
Z(Matrix<Square<gf2n_long>>, gf2n_long)
template class OTCorrelator<Matrix<Square<gfp1>>>;
Z(Matrix<Square<gfp1>>, gfp1)
#define ZZZZ(GF) \
template void OTExtensionWithMatrix::print_post_correlate<GF>( \
BitVector& newReceiverInput, int j, int offset, int sender); \
template void OTExtensionWithMatrix::extend<GF>(int nOTs_requested, \
BitVector& newReceiverInput); \
#define ZZZ(GF, M) \
template void OTExtensionWithMatrix::hash_outputs<GF, M >(int, vector<M >&, M&);
#define MM Matrix<Rectangle<Z2<512>, Z2<160> > >
#define ZZZ(GF, M) Z(M, GF) \
template void OTExtensionWithMatrix::hash_outputs(int, vector<M >&, M&);
ZZZZ(gfp1)
ZZZZ(gf2n_long)
ZZZ(Z2<160>, MM)
ZZZ(gf2n_short, Matrix<gf2n_short_square>)
ZZZ(gf2n_long, Matrix<Square<gf2n_long>>)
ZZZ(gfp1, Matrix<Square<gfp1>>)
ZZZ(BitVec, Matrix<BitDiagonal>)
#undef XX
#define XX(T,U,N,L) \
template class OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >; \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::correlate<T>(int start, int slice, \
BitVector& newReceiverInput, bool useConstantBase, int repeat); \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<U>& output); \
template void OTExtensionWithMatrix::hash_outputs<T, Matrix<Rectangle<Z2<N>, Z2<L> > > >(int, \
template void OTExtensionWithMatrix::hash_outputs(int, \
std::vector<Matrix<Rectangle<Z2<N>, Z2<L> > >, std::allocator<Matrix<Rectangle<Z2<N>, Z2<L> > > > >&, \
Matrix<Rectangle<Z2<N>, Z2<L> > >&);
#undef X
#define X(N,L) \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::expand<Z2<L> >(int start, int slice); \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<Z2kRectangle<N, L> >& output); \
XX(Z2<L>,Z2<N>,N,L)
@@ -579,3 +554,8 @@ Y(64, 48)
Y(66, 64)
Y(66, 48)
Y(32, 32)
Y(1, 40)
Y(72, 48)
Y(74, 48)
Y(72, 64)
Y(74, 64)

View File

@@ -39,12 +39,10 @@ public:
receiverOutputMatrix(matrices[0]), t1(matrices[1]) {}
void resize(int nOTs);
template <class T>
void expand(int start, int slice);
void setup_for_correlation(BitVector& baseReceiverInput,
vector<U>& baseSenderOutputs,
U& baseReceiverOutput);
template <class T>
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
template <class T>
void reduce_squares(unsigned int nTriples, vector<T>& output);
@@ -76,14 +74,12 @@ public:
void seed(vector<BitMatrix>& baseSenderInput,
BitMatrix& baseReceiverOutput);
void transfer(int nOTs, const BitVector& receiverInput);
template <class T>
void extend(int nOTs, BitVector& newReceiverInput);
void extend_correlated(BitVector& newReceiverInput);
void extend_correlated(int nOTs, BitVector& newReceiverInput);
void extend_correlated(const BitVector& newReceiverInput);
void extend_correlated(int nOTs, const BitVector& newReceiverInput);
void transpose(int start, int slice);
template <class T>
void expand_transposed();
template <class T, class V>
template <class V>
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput);
void print(BitVector& newReceiverInput, int i = 0);
@@ -100,7 +96,6 @@ public:
octet* get_sender_output(int choice, int i);
protected:
template <class T>
void hash_outputs(int nOTs);
};

View File

@@ -79,7 +79,7 @@ public:
};
template <class T>
class MascotMultiplier : public OTMultiplier<Share<T>>
class MascotMultiplier : public OTMultiplier<T>
{
OTCorrelator<Matrix<typename T::Square> > auth_ot_ext;
void after_correlation();
@@ -88,13 +88,32 @@ class MascotMultiplier : public OTMultiplier<Share<T>>
const vector<BitVector>& baseReceiverOutput);
public:
vector<T> c_output;
vector<typename T::open_type> c_output;
MascotMultiplier(OTTripleGenerator<Share<T>>& generator, int thread_num);
MascotMultiplier(OTTripleGenerator<T>& generator, int thread_num);
void multiplyForInputs(MultJob job);
};
template <class T>
class TinyMultiplier : public OTMultiplier<T>
{
OTVole<typename T::part_type::sacri_type,
typename T::part_type::mac_key_type> mac_vole;
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
public:
vector<typename T::open_type> c_output;
TinyMultiplier(OTTripleGenerator<T>& generator, int thread_num);
void multiplyForInputs(MultJob job) { (void) job; throw not_implemented(); }
};
template <int K, int S> class Spdz2kShare;
template <int K, int S>

View File

@@ -9,6 +9,7 @@
#include "OT/NPartyTripleGenerator.h"
#include "OT/Rectangle.h"
#include "Math/Z2k.h"
#include "Math/BitVec.h"
#include "Protocols/SemiShare.h"
#include "Protocols/Semi2kShare.h"
#include "Protocols/Spdz2kShare.h"
@@ -37,14 +38,28 @@ OTMultiplier<T>::OTMultiplier(OTTripleGenerator<T>& generator,
}
template<class T>
MascotMultiplier<T>::MascotMultiplier(OTTripleGenerator<Share<T>>& generator,
MascotMultiplier<T>::MascotMultiplier(OTTripleGenerator<T>& generator,
int thread_num) :
OTMultiplier<Share<T>>(generator, thread_num),
OTMultiplier<T>(generator, thread_num),
auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true)
{
c_output.resize(generator.nTriplesPerLoop);
}
template<class T>
TinyMultiplier<T>::TinyMultiplier(OTTripleGenerator<T>& generator,
int thread_num) :
OTMultiplier<T>(generator, thread_num),
mac_vole(
128, 128, 0, 1,
generator.players[thread_num],
{ },
{ },
{ }, BOTH, false)
{
c_output.resize(generator.nTriplesPerLoop);
}
template <int K, int S>
Spdz2kMultiplier<K, S>::Spdz2kMultiplier(OTTripleGenerator<Spdz2kShare<K, S>>& generator, int thread_num) :
OTMultiplier<Spdz2kShare<K, S>>
@@ -75,7 +90,7 @@ template<class T>
void OTMultiplier<T>::multiply()
{
keyBits.set(generator.machine.template get_mac_key<typename T::mac_key_type>());
rot_ext.extend<gf2n_long>(keyBits.size(), keyBits);
rot_ext.extend(keyBits.size(), keyBits);
this->outbox.push({});
senderOutput.resize(keyBits.size());
for (size_t j = 0; j < keyBits.size(); j++)
@@ -123,7 +138,6 @@ void OTMultiplier<T>::multiply()
template<class W>
void OTMultiplier<W>::multiplyForTriples()
{
typedef typename W::open_type T;
typedef typename W::Rectangle X;
// dummy input for OT correlator
@@ -148,13 +162,13 @@ void OTMultiplier<W>::multiplyForTriples()
BitVector aBits = generator.valueBits[0];
//timers["Extension"].start();
rot_ext.extend_correlated(aBits);
rot_ext.hash_outputs<T>(aBits.size(), baseSenderOutputs, baseReceiverOutput);
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
//timers["Extension"].stop();
//timers["Correlation"].start();
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
baseReceiverOutput);
otCorrelator.template correlate<T>(0, generator.nPreampTriplesPerLoop,
otCorrelator.correlate(0, generator.nPreampTriplesPerLoop,
generator.valueBits[1], false, generator.nAmplify);
//timers["Correlation"].stop();
@@ -171,6 +185,14 @@ void MascotMultiplier<T>::init_authenticator(const BitVector& keyBits,
this->auth_ot_ext.init(keyBits, senderOutput, receiverOutput);
}
template<class T>
void TinyMultiplier<T>::init_authenticator(const BitVector& keyBits,
const vector<vector<BitVector> >& senderOutput,
const vector<BitVector>& receiverOutput)
{
mac_vole.init(keyBits, senderOutput, receiverOutput);
}
template <int K, int S>
void Spdz2kMultiplier<K, S>::init_authenticator(const BitVector& keyBits,
const vector< vector<BitVector> >& senderOutput,
@@ -188,9 +210,11 @@ void SemiMultiplier<T>::after_correlation()
this->outbox.push({});
}
template <class T>
void MascotMultiplier<T>::after_correlation()
template <class U>
void MascotMultiplier<U>::after_correlation()
{
typedef typename U::open_type T;
this->auth_ot_ext.resize(
this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS);
this->auth_ot_ext.set_role(BOTH);
@@ -210,8 +234,8 @@ void MascotMultiplier<T>::after_correlation()
int nValues = this->generator.nTriplesPerLoop;
if (this->generator.machine.check && (j % 2 == 0))
nValues *= 2;
this->auth_ot_ext.template expand<T>(0, nValues);
this->auth_ot_ext.template correlate<T>(0, nValues,
this->auth_ot_ext.expand(0, nValues);
this->auth_ot_ext.correlate(0, nValues,
this->generator.valueBits[j], true);
this->auth_ot_ext.reduce_squares(nValues, this->macs[j]);
}
@@ -219,6 +243,29 @@ void MascotMultiplier<T>::after_correlation()
}
}
template <class T>
void TinyMultiplier<T>::after_correlation()
{
this->otCorrelator.reduce_squares(this->generator.nTriplesPerLoop,
this->c_output);
this->outbox.push({});
this->macs.resize(3);
MultJob job;
this->inbox.pop(job);
for (int j = 0; j < 3; j++)
{
int nValues = this->generator.nTriplesPerLoop * T::default_length;
auto& bits = this->generator.valueBits[j];
vector<typename T::part_type::sacri_type> values(nValues);
for (int i = 0; i < nValues; i++)
values[i] = bits.get_bit(i);
mac_vole.evaluate(this->macs[j], values);
}
this->outbox.push(job);
}
template <int K, int S>
void Spdz2kMultiplier<K, S>::after_correlation()
{
@@ -295,9 +342,9 @@ void OTMultiplier<Share<gf2n>>::multiplyForBits()
for (int i = 0; i < generator.nloops; i++)
{
auth_ot_ext.expand<gf2n_long>(0, nBlocks);
auth_ot_ext.expand(0, nBlocks);
inbox.pop(job);
auth_ot_ext.correlate<gf2n_long>(0, nBlocks, generator.valueBits[0], true);
auth_ot_ext.correlate(0, nBlocks, generator.valueBits[0], true);
auth_ot_ext.transpose(0, nBlocks);
for (int j = 0; j < nBits; j++)
@@ -311,8 +358,8 @@ void OTMultiplier<Share<gf2n>>::multiplyForBits()
}
}
template<class T>
void MascotMultiplier<T>::multiplyForInputs(MultJob job)
template<class U>
void MascotMultiplier<U>::multiplyForInputs(MultJob job)
{
assert(job.input);
auto& generator = this->generator;
@@ -320,10 +367,10 @@ void MascotMultiplier<T>::multiplyForInputs(MultJob job)
auth_ot_ext.set_role(mine ? RECEIVER : SENDER);
int nOTs = job.n_inputs * generator.field_size;
auth_ot_ext.resize(nOTs);
auth_ot_ext.template expand<T>(0, job.n_inputs);
auth_ot_ext.expand(0, job.n_inputs);
if (mine)
this->inbox.pop();
auth_ot_ext.template correlate<T>(0, job.n_inputs, generator.valueBits[0], true);
auth_ot_ext.correlate(0, job.n_inputs, generator.valueBits[0], true);
auto& input_macs = this->input_macs;
input_macs.resize(job.n_inputs);
if (mine)
@@ -355,24 +402,3 @@ void OTMultiplier<T>::multiplyForBits()
{
throw runtime_error("bit generation not implemented in this case");
}
template class OTMultiplier<Share<gf2n>>;
template class OTMultiplier<Share<gfp1>>;
template class OTMultiplier<SemiShare<gf2n>>;
template class OTMultiplier<SemiShare<gfp1>>;
template class SemiMultiplier<SemiShare<gf2n>>;
template class SemiMultiplier<SemiShare<gfp1>>;
template class SemiMultiplier<Semi2kShare<64>>;
template class SemiMultiplier<Semi2kShare<72>>;
template class MascotMultiplier<gf2n_long>;
template class MascotMultiplier<gf2n_short>;
template class MascotMultiplier<gfp1>;
#define X(K, S) \
template class Spdz2kMultiplier<K, S>; \
template class OTMultiplier<Spdz2kShare<K, S>>;
X(64, 64)
X(64, 48)
X(66, 64)
X(66, 48)
X(32, 32)

View File

@@ -3,7 +3,6 @@
#include "Networking/Player.h"
#include "OT/BaseOT.h"
#include "OT/OTMachine.h"
#include "Tools/random.h"
#include "Tools/time-func.h"

Some files were not shown because too many files have changed in this diff Show More