Rep4, SPDZ-wise, MNIST training.

This commit is contained in:
Marcel Keller
2020-10-28 11:20:52 +11:00
parent 53f9b023dc
commit f42e614399
184 changed files with 5837 additions and 820 deletions

View File

@@ -104,7 +104,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
}
else
{
T::read_or_generate_mac_key(prep_dir, N, mac_key);
T::read_or_generate_mac_key(prep_dir, *P, mac_key);
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
}

View File

@@ -24,8 +24,6 @@
#include <unistd.h>
ostream& EvalRegister::out = cout;
int Register::counter = 0;
void Register::init(int n_parties)

View File

@@ -22,6 +22,7 @@ using namespace std;
#include "Tools/FlexBuffer.h"
#include "Tools/PointerVector.h"
#include "Tools/Bundle.h"
#include "Tools/SwitchableOutput.h"
//#define PAD_TO_8(n) (n+8-n%8)
#define PAD_TO_8(n) (n)
@@ -199,6 +200,7 @@ public:
BlackHole& operator<<(T) { return *this; }
BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; }
void activate(bool) {}
void redirect_to_file(ostream&) {}
};
inline BlackHole& endl(BlackHole& b) { return b; }
inline BlackHole& flush(BlackHole& b) { return b; }
@@ -211,7 +213,6 @@ public:
typedef NoMemory DynamicMemory;
typedef BlackHole out_type;
static BlackHole out;
static const bool actual_inputs = true;
@@ -353,8 +354,7 @@ public:
typedef EvalInputter Input;
typedef ostream& out_type;
static ostream& out;
typedef SwitchableOutput out_type;
static const bool actual_inputs = true;

View File

@@ -1,5 +1,24 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.0 (Oct 28, 2020)
- Rep4: honest-majority four-party computation with malicious security
- SY/SPDZ-wise: honest-majority computation with malicious security based on replicated or Shamir secret sharing
- Training with a sequence of dense layers
- Training and inference for multi-class classification
- Local share conversion for semi-honest protocols based on additive secret sharing modulo a power of two
- edaBit generation based on local share conversion
- Optimize exponentation with local share conversion
- Optimize Shamir pseudo-random secret sharing using a hyper-invertible matrix
- Mathematical functions (exponentation, logarithm, square root, and trigonometric functions) with binary circuits
- Direct construction of fixed-point values from any type, breaking `sfix(x)` where `x` is the integer representation of a fixed-point number. Use `sfix._new(x)` instead.
- Optimized dot product for `sfix`
- Matrix multiplication via operator overloading uses VM-optimized multiplication.
- Fake preprocessing for daBits and edaBits
- Fixed security bug: insufficient randomness in SemiBin random bit generation.
- Fixed security bug: insufficient randomization of FKOS15 inputs.
- Fixed security bug in binary computation with SPDZ(2k).
## 0.1.9 (Aug 24, 2020)
- Streamline inputs to binary circuits
@@ -7,7 +26,7 @@ The changelog explains changes pulled through from the private development repos
- Emulator for arithmetic circuits
- Efficient dot product with Shamir's secret sharing
- Lower memory usage for TensorFlow inference
- This version breaks bytecode compatibilty.
- This version breaks bytecode compatibility.
## 0.1.8 (June 15, 2020)

4
CONFIG
View File

@@ -24,7 +24,9 @@ USE_GF2N_LONG = 1
# AVX/AVX2 is required for replicated binary secret sharing
# BMI2 is used to optimize multiplication modulo a prime
# ADX is used to optimize big integer additions
# delete the second line to compile for a platform that supports everything
ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
ARCH = -march=native
# allow to set compiler in CONFIG.mine
CXX = g++
@@ -60,7 +62,7 @@ else
BOOST = -lboost_thread $(MY_BOOST)
endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror
CPPFLAGS = $(CFLAGS)
LD = $(CXX)

View File

@@ -45,6 +45,7 @@ class bits(Tape.Register, _structure, _bit):
return cls.bit_compose(sum([util.bit_decompose(item, bit_length) for item in items], []))
@classmethod
def bit_compose(cls, bits):
bits = list(bits)
if len(bits) == 1:
return bits[0]
bits = list(bits)
@@ -72,7 +73,7 @@ class bits(Tape.Register, _structure, _bit):
res = [self.bit_type() for i in range(n)]
self.bitdec(self, *res)
else:
res = self.trans([self])
res = self.bit_type.trans([self])
self.decomposed = res
return res + suffix
else:
@@ -83,8 +84,8 @@ class bits(Tape.Register, _structure, _bit):
cbits.conv_cint_vec(a, *res)
return res
@classmethod
def malloc(cls, size):
return Program.prog.malloc(size, cls)
def malloc(cls, size, creator_tape=None):
return Program.prog.malloc(size, cls, creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
@@ -430,6 +431,8 @@ class sbits(bits):
def equal(self, other, n=None):
bits = (~(self + other)).bit_decompose()
return reduce(operator.mul, bits)
def right_shift(self, m, k, security=None, signed=True):
return self.TruncPr(k, m)
def TruncPr(self, k, m, kappa=None):
if k > self.n:
raise Exception('TruncPr overflow: %d > %d' % (k, self.n))
@@ -481,8 +484,8 @@ class sbitvec(_vec):
def get_type(cls, n):
class sbitvecn(cls, _structure):
@staticmethod
def malloc(size):
return sbit.malloc(size * n)
def malloc(size, creator_tape=None):
return sbit.malloc(size * n, creator_tape=creator_tape)
@staticmethod
def n_elements():
return n
@@ -566,7 +569,8 @@ class sbitvec(_vec):
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
v = x.v
self.v = v[:length]
elif elements is not None:
elif elements is not None and not (util.is_constant(elements) and \
elements == 0):
self.v = sbits.trans(elements)
def popcnt(self):
res = sbitint.wallace_tree([[b] for b in self.v])
@@ -606,7 +610,10 @@ class sbitvec(_vec):
return cls.from_vec(other.v)
@property
def size(self):
return self.v[0].n
if not self.v or util.is_constant(self.v[0]):
return 1
else:
return self.v[0].n
@property
def n_bits(self):
return len(self.v)
@@ -725,6 +732,8 @@ class _sbitintbase:
return self.get_type(n).bit_compose(bits)
def round(self, k, m, kappa=None, nearest=None, signed=None):
bits = self.bit_decompose()
if signed:
bits += [bits[-1]] * (k - len(bits))
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
return self.get_type(k - m).compose(res_bits)
def int_div(self, other, bit_length=None):
@@ -781,7 +790,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
@classmethod
def bit_compose(cls, bits):
# truncate and extend bits
bits = bits[:cls.n]
bits = list(bits)[:cls.n]
bits += [0] * (cls.n - len(bits))
return super(sbitint, cls).bit_compose(bits)
def force_bit_decompose(self, n_bits=None):
@@ -801,6 +810,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits)))
product = a * b
res_bits = product.bit_decompose()[m:k]
res_bits += [res_bits[-1]] * (self.n - len(res_bits))
t = self.combo_type(other)
return t.bit_compose(res_bits)
def __mul__(self, other):
@@ -824,6 +834,15 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
else:
res.append([(x & bit) for x in other.bit_decompose(n - i)])
return res
@classmethod
def popcnt_bits(cls, bits):
res = sbitvec.from_vec(bits).popcnt().elements()[0]
res = cls.conv(res)
return res
def pow2(self, k):
l = int(math.ceil(math.log(k, 2)))
bits = [self.equal(i, l) for i in range(k)]
return self.bit_compose(bits)
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
def __add__(self, other):
@@ -867,8 +886,11 @@ class cbitfix(object):
conv = staticmethod(lambda x: x)
load_mem = classmethod(lambda cls, *args: cls(cbits.load_mem(*args)))
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
def __init__(self, value):
self.v = value
@classmethod
def _new(cls, value):
res = cls()
res.v = value
return res
def output(self):
v = self.v
if self.k < v.unit:
@@ -897,10 +919,10 @@ class sbitfix(_fix):
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
def __xor__(self, other):
return type(self)(self.v ^ other.v)
return type(self)._new(self.v ^ other.v)
def __mul__(self, other):
if isinstance(other, sbit):
return type(self)(self.int_type(other * self.v))
return type(self)._new(self.int_type(other * self.v))
elif isinstance(other, sbitfixvec):
return other * self
else:
@@ -911,10 +933,11 @@ class sbitfix(_fix):
def multipliable(other, k, f, size):
class cls(_fix):
int_type = sbitint.get_type(k)
clear_type = cbitfix
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)
sbitfix.set_precision(20, 41)
sbitfix.set_precision(16, 31)
class sbitfixvec(_fix):
int_type = sbitintvec

View File

@@ -220,6 +220,7 @@ class Merger:
else:
self.max_parallel_open = float('inf')
self.counter = defaultdict(lambda: 0)
self.rounds = defaultdict(lambda: 0)
self.dependency_graph(merge_classes)
def do_merge(self, merges_iter):
@@ -271,6 +272,7 @@ class Merger:
merge = merges[i]
t = type(self.instructions[merge[0]])
self.counter[t] += len(merge)
self.rounds[t] += 1
if len(merge) > 10000:
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))

View File

@@ -135,27 +135,37 @@ def Trunc(d, a, k, m, kappa, signed):
mulm(d, t, c[2])
def TruncRing(d, a, k, m, signed):
if program.use_split() == 3:
program.curr_tape.require_bit_length(1)
if program.use_split() in (2, 3):
if signed:
a += (1 << (k - 1))
from Compiler.types import sint
from .GC.types import sbitint
length = int(program.options.ring)
summands = a.split_to_n_summands(length, 3)
summands = a.split_to_n_summands(length, program.use_split())
x = sbitint.wallace_tree_without_finish(summands, True)
if m == 1:
low = x[1][1]
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
sint.conv(x[0][-1])
if program.use_split() == 2:
carries = sbitint.get_carries(*x)
low = carries[m]
high = sint.conv(carries[length])
else:
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
low = sint.conv(mid_carry) + sint.conv(x[0][m])
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
for xx, yy in zip(x[1][m:-1],
x[0][m:-1])))
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
high = top_carry + sint.conv(x[0][-1])
if m == 1:
low = x[1][1]
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
sint.conv(x[0][-1])
else:
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
low = sint.conv(mid_carry) + sint.conv(x[0][m])
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
for xx, yy in zip(x[1][m:-1],
x[0][m:-1])))
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
high = top_carry + sint.conv(x[0][-1])
shifted = sint()
shrsi(shifted, a, m)
res = shifted + sint.conv(low) - (high << (length - m))
if signed:
res -= (1 << (k - m - 1))
else:
a_prime = Mod2mRing(None, a, k, m, signed)
a -= a_prime

View File

@@ -53,6 +53,12 @@ def maskField(a, k, kappa):
@instructions_base.ret_cisc
def EQZ(a, k, kappa):
prog = program.Program.prog
if prog.use_split():
from GC.types import sbitvec
v = sbitvec(a, k).v
bit = util.tree_reduce(operator.and_, (~b for b in v))
return types.sint.conv(bit)
if program.Program.prog.options.ring:
c, r = maskRing(a, k)
else:
@@ -307,16 +313,22 @@ def BitDec(a, k, m, kappa, bits_to_compute=None):
def BitDecRing(a, k, m):
n_shift = int(program.Program.prog.options.ring) - m
assert(n_shift >= 0)
if program.Program.prog.use_dabit:
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
r = types.sint.bit_compose(r)
if program.Program.prog.use_split():
x = a.split_to_two_summands(m)
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
else:
r_bits = [types.sint.get_random_bit() for i in range(m)]
r = types.sint.bit_compose(r_bits)
shifted = ((a - r) << n_shift).reveal()
masked = shifted >> n_shift
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return [types.sint.conv(bit) for bit in bits]
if program.Program.prog.use_dabit:
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
r = types.sint.bit_compose(r)
else:
r_bits = [types.sint.get_random_bit() for i in range(m)]
r = types.sint.bit_compose(r_bits)
shifted = ((a - r) << n_shift).reveal()
masked = shifted >> n_shift
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return [types.sint.conv(bit) for bit in bits]
def BitDecField(a, k, m, kappa, bits_to_compute=None):
r_dprime = types.sint()
@@ -476,22 +488,20 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
def Int2FL(a, gamma, l, kappa):
lam = gamma - 1
s = types.sint()
comparison.LTZ(s, a, gamma, kappa)
z = EQZ(a, gamma, kappa)
a = (1 - 2 * s) * a
a_bits = BitDec(a, lam, lam, kappa)
s = a.less_than(0, gamma, security=kappa)
z = a.equal(0, gamma, security=kappa)
a = s.if_else(-a, a)
a_bits = a.bit_decompose(lam, security=kappa)
a_bits.reverse()
b = PreOR(a_bits, kappa)
t = a * (1 + sum(2**i * (1 - b_i) for i,b_i in enumerate(b)))
p = - (lam - sum(b))
t = a * (1 + a.bit_compose(1 - b_i for b_i in b))
p = a.popcnt_bits(b) - lam
if gamma - 1 > l:
if types.sfloat.round_nearest:
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
p = p + overflow
else:
v = types.sint()
comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False)
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
else:
v = 2**(l-gamma+1) * t
p = (p + gamma - 1 - l) * (1 -z)
@@ -539,6 +549,7 @@ def TruncPrRing(a, k, m, signed=True):
n_ring = int(program.Program.prog.options.ring)
assert n_ring >= k, '%d too large' % k
if k == n_ring:
program.Program.prog.curr_tape.require_bit_length(1)
if program.Program.prog.use_edabit():
a += types.sint.get_edabit(m, True)[0]
else:
@@ -555,7 +566,8 @@ def TruncPrRing(a, k, m, signed=True):
else:
# extra bit to mask overflow
prog = program.Program.prog
if prog.use_edabit() or prog.use_split() == 3:
prog.curr_tape.require_bit_length(1)
if prog.use_edabit() or prog.use_split() > 2:
lower = sint.get_random_int(m)
upper = sint.get_random_int(k - m)
msb = sint.get_random_bit()

View File

@@ -441,6 +441,8 @@ def cisc(function):
program.options.cisc = True
reset_global_vector_size()
program.curr_tape = old_tape
for x, bl in tape.req_bit_length.items():
old_tape.require_bit_length(bl, x)
from Compiler.allocator import Merger
merger = Merger(block, program.options,
tuple(program.to_merge))
@@ -523,25 +525,26 @@ def ret_cisc(function):
def sfix_cisc(function):
from Compiler.types import sfix, sint, cfix, copy_doc
def instruction(res, arg, k, f):
def instruction(res, arg, k, f, *args):
assert k is not None
assert f is not None
old = sfix.k, sfix.f, cfix.k, cfix.f
sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4
res.mov(res, function(sfix._new(arg, k=k, f=f)).v)
res.mov(res, function(sfix._new(arg, k=k, f=f), *args).v)
sfix.k, sfix.f, cfix.k, cfix.f = old
instruction.__name__ = function.__name__
instruction = cisc(instruction)
def wrapper(*args, **kwargs):
if isinstance(args[0], sfix):
assert len(args) == 1
for arg in args[1:]:
assert util.is_constant(arg)
assert not kwargs
assert args[0].size == args[0].v.size
k = args[0].k
f = args[0].f
res = sfix._new(sint(size=args[0].size), k=k, f=f)
instruction(res.v, args[0].v, k, f)
instruction(res.v, args[0].v, k, f, *args[1:])
return res
else:
return function(*args, **kwargs)

View File

@@ -134,6 +134,7 @@ def print_ln_if(cond, ss, *args):
print_str_if(cond, ss + '\n', *args)
def print_str_if(cond, ss, *args):
""" Print string conditionally. See :py:func:`print_ln_if` for details. """
if util.is_constant(cond):
if cond:
print_ln(ss, *args)
@@ -160,7 +161,8 @@ def print_str_if(cond, ss, *args):
def print_ln_to(player, ss, *args):
""" Print line at :py:obj:`player` only. Note that printing is
disabled by default except at player 0.
disabled by default except at player 0. Activate interactive mode
with `-I` to enable it for all players.
:param player: int
:param ss: Python string
@@ -814,8 +816,8 @@ def range_loop(loop_body, start, stop=None, step=None):
if step is None:
step = 1
def loop_fn(i):
loop_body(i)
return i + step
res = loop_body(i)
return util.if_else(res == 0, stop, i + step)
if isinstance(step, int):
if step > 0:
condition = lambda x: x < stop
@@ -840,7 +842,9 @@ def for_range(start, stop=None, step=None):
in Python :py:func:`range`, but they can by any public
integer. Information has to be passed out via container types such
as :py:class:`Compiler.types.Array` or declaring registers as
:py:obj:`global`.
:py:obj:`global`. Note that changing Python data structures such
as lists within the loop is not possible, but the compiler cannot
warn about this.
:param start/stop/step: regint/cint/int
@@ -1057,7 +1061,7 @@ def for_range_opt_multithread(n_threads, n_loops):
"""
return for_range_multithread(n_threads, None, n_loops)
def multithread(n_threads, n_items):
def multithread(n_threads, n_items, max_size=None):
"""
Distribute the computation of :py:obj:`n_items` to
:py:obj:`n_threads` threads, but leave the in-thread repetition up
@@ -1075,8 +1079,19 @@ def multithread(n_threads, n_items):
def f(base, size):
...
"""
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False)
if max_size is None:
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False)
else:
def wrapper(function):
@multithread(n_threads, n_items)
def new_function(base, size):
for i in range(0, size, max_size):
part_base = base + i
part_size = min(max_size, size - i)
function(part_base, part_size)
break_point()
return wrapper
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
thread_mem_req={}, looping=True):
@@ -1563,8 +1578,8 @@ def cint_cint_division(a, b, k, f):
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b < 0)
sign_a = cint(1) - 2 * cint(a < 0)
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = cint(1) - 2 * cint(a.less_than(0, k))
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)
@@ -1632,9 +1647,12 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
f = max((k - nearest) // 2 + 1, f)
assert 2 * f > k - nearest
theta = int(ceil(log(k/3.5) / log(2)))
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()
y = a.extend(2 *k) * w
y = y.round(2*k, f, kappa, nearest, signed=True)

View File

@@ -42,6 +42,7 @@ an example of how to run MP-SPDZ on TensorFlow graphs.
"""
import math
import re
from Compiler import mpc_math, util
from Compiler.types import *
@@ -58,13 +59,13 @@ def log_e(x):
def exp(x):
return mpc_math.pow_fx(math.e, x)
def sanitize(x, raw, lower, upper):
def get_limit(x):
exp_limit = 2 ** (x.k - x.f - 1)
limit = math.log(exp_limit)
if get_program().options.ring:
res = raw
else:
res = (x > limit).if_else(upper, raw)
return math.log(exp_limit)
def sanitize(x, raw, lower, upper):
limit = get_limit(x)
res = (x > limit).if_else(upper, raw)
return (x < -limit).if_else(lower, res)
def sigmoid(x):
@@ -137,10 +138,12 @@ def argmax(x):
return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1])
return tree_reduce(op, enumerate(x))[0]
report_progress = False
def progress(x):
return
print_ln(x)
time()
if report_progress:
print_ln(x)
time()
def set_n_threads(n_threads):
Layer.n_threads = n_threads
@@ -159,6 +162,10 @@ class Tensor(MultiArray):
self.alloc()
return super(Tensor, self).__getitem__(*args)
def assign_vector(self, *args):
self.alloc()
return super(Tensor, self).assign_vector(*args)
class Layer:
n_threads = 1
inputs = []
@@ -190,12 +197,23 @@ class Layer:
class NoVariableLayer(Layer):
input_from = lambda *args, **kwargs: None
class Output(Layer):
nablas = lambda self: ()
reset = lambda self: None
class Output(NoVariableLayer):
""" Fixed-point logistic regression output layer.
:param N: number of examples
:param approx: :py:obj:`False` (default) or parameter for :py:obj:`approx_sigmoid`
"""
n_outputs = 2
@classmethod
def from_args(cls, N, program):
res = cls(N, approx='approx' in program.args)
res.compute_loss = not 'no_loss' in program.args
return res
def __init__(self, N, debug=False, approx=False):
self.N = N
self.X = sfix.Array(N)
@@ -206,9 +224,7 @@ class Output(Layer):
self.debug = debug
self.weights = None
self.approx = approx
nablas = lambda self: ()
reset = lambda self: None
self.compute_loss = True
def divisor(self, divisor, size):
return cfix(1.0 / divisor, size=size)
@@ -224,11 +240,13 @@ class Output(Layer):
x = self.X.get_vector(base, size)
y = self.Y.get(batch.get_vector(base, size))
if self.approx:
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
if self.compute_loss:
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
return
e_x = exp(-x)
self.e_x.assign(e_x, base)
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
if self.compute_loss:
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
self.l.write(sum(lse) * \
self.divisor(N, 1))
@@ -246,13 +264,10 @@ class Output(Layer):
diff = self.eval(size, base) - \
self.Y.get(batch.get_vector(base, size))
assert sfix.f == cfix.f
if self.weights is None:
diff *= self.divisor(N, size)
else:
if self.weights is not None:
assert N == len(self.weights)
diff *= self.weights.get_vector(base, size)
if self.weight_total != 1:
diff *= self.divisor(self.weight_total, size)
assert self.weight_total == N
self.nabla_X.assign(diff, base)
# @for_range_opt(len(diff))
# def _(i):
@@ -271,6 +286,244 @@ class Output(Layer):
self.weights.assign(weights)
self.weight_total = sum(weights)
def average_loss(self, N):
return self.l.reveal()
def reveal_correctness(self, n=None, Y=None, debug=False):
if n is None:
n = self.X.sizes[0]
if Y is None:
Y = self.Y
n_correct = MemValue(0)
n_printed = MemValue(0)
@for_range_opt(n)
def _(i):
truth = Y[i].reveal()
b = self.X[i].reveal()
if debug:
nabla = self.nabla_X[i].reveal()
guess = b > 0
correct = truth == guess
n_correct.iadd(correct)
if debug:
to_print = (1 - correct) * (n_printed < 10)
n_printed.iadd(to_print)
print_ln_if(to_print, '%s: %s %s %s %s',
i, truth, guess, b, nabla)
return n_correct
class MultiOutputBase(NoVariableLayer):
def __init__(self, N, d_out, approx=False, debug=False):
self.X = sfix.Matrix(N, d_out)
self.Y = sint.Matrix(N, d_out)
self.nabla_X = sfix.Matrix(N, d_out)
self.l = MemValue(sfix(-1))
self.losses = sfix.Array(N)
self.approx = None
self.N = N
self.d_out = d_out
self.compute_loss = True
def eval(self, N):
d_out = self.X.sizes[1]
res = sfix.Matrix(N, d_out)
res.assign_vector(self.X.get_part_vector(0, N))
return res
def average_loss(self, N):
return sum(self.losses.get_vector(0, N)).reveal() / N
def reveal_correctness(self, n=None, Y=None, debug=False):
if n is None:
n = self.X.sizes[0]
if Y is None:
Y = self.Y
n_correct = MemValue(0)
n_printed = MemValue(0)
@for_range_opt(n)
def _(i):
a = Y[i].reveal_list()
b = self.X[i].reveal_list()
if debug:
loss = self.losses[i].reveal()
exp = self.get_extra_debugging(i)
nabla = self.nabla_X[i].reveal_list()
truth = argmax(a)
guess = argmax(b)
correct = truth == guess
n_correct.iadd(correct)
if debug:
to_print = (1 - correct) * (n_printed < 10)
n_printed.iadd(to_print)
print_ln_if(to_print, '%s: %s %s %s %s %s %s',
i, truth, guess, loss, b, exp, nabla)
return n_correct
@property
def n_outputs(self):
return self.d_out
def get_extra_debugging(self, i):
return ''
@staticmethod
def from_args(program, N, n_output):
if 'relu_out' in program.args:
res = ReluMultiOutput(N, n_output)
else:
res = MultiOutput(N, n_output, approx='approx' in program.args)
res.cheaper_loss = 'mse' in program.args
res.compute_loss = not 'no_loss' in program.args
return res
class MultiOutput(MultiOutputBase):
"""
Output layer for multi-class classification with softmax and cross entropy.
:param N: number of examples
:param d_out: number of classes
:param approx: use ReLU division instead of softmax for the loss
"""
def __init__(self, N, d_out, approx=False, debug=False):
MultiOutputBase.__init__(self, N, d_out)
self.exp = sfix.Matrix(N, d_out)
self.approx = approx
self.positives = sint.Matrix(N, d_out)
self.relus = sfix.Matrix(N, d_out)
self.cheaper_loss = False
self.debug = debug
self.true_X = sfix.Array(N)
def forward(self, batch):
N = len(batch)
d_out = self.X.sizes[1]
tmp = self.losses
@for_range_opt_multithread(self.n_threads, N)
def _(i):
if self.approx:
positives = self.X[i].get_vector() > (0 if self.cheaper_loss else 0.1)
relus = positives.if_else(self.X[i].get_vector(), 0)
self.positives[i].assign_vector(positives)
self.relus[i].assign_vector(relus)
if self.compute_loss:
if self.cheaper_loss:
s = sum(relus)
tmp[i] = sum((self.Y[batch[i]][j] * s - relus[j]) ** 2
for j in range(d_out)) / s ** 2 * 0.5
else:
div = relus / sum(relus).expand_to_vector(d_out)
self.losses[i] = -sfix.dot_product(
self.Y[batch[i]].get_vector(), log_e(div))
else:
m = util.max(self.X[i])
mv = m.expand_to_vector(d_out)
x = self.X[i].get_vector()
e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0)
self.exp[i].assign_vector(e)
if self.compute_loss:
true_X = sfix.dot_product(self.Y[batch[i]], self.X[i])
tmp[i] = m + log_e(sum(e)) - true_X
self.true_X[i] = true_X
self.l.write(sum(tmp.get_vector(0, N)) / N)
def eval(self, N):
d_out = self.X.sizes[1]
res = sfix.Matrix(N, d_out)
if self.approx:
@for_range_opt_multithread(self.n_threads, N)
def _(i):
relus = (self.X[i].get_vector() > 0).if_else(
self.X[i].get_vector(), 0)
res[i].assign_vector(relus / sum(relus).expand_to_vector(d_out))
return res
@for_range_opt_multithread(self.n_threads, N)
def _(i):
e = exp(self.X[i].get_vector())
res[i].assign_vector(e / sum(e).expand_to_vector(d_out))
return res
def backward(self, batch):
d_out = self.X.sizes[1]
if self.approx:
@for_range_opt_multithread(self.n_threads, len(batch))
def _(i):
if self.cheaper_loss:
s = sum(self.relus[i])
ss = s * s * s
inv = 1 / ss
@for_range_opt(d_out)
def _(j):
res = 0
for k in range(d_out):
relu = self.relus[i][k]
summand = relu - self.Y[batch[i]][k] * s
summand *= (sfix.from_sint(j == k) - relu)
res += summand
fallback = -self.Y[batch[i]][j]
res *= inv
self.nabla_X[i][j] = self.positives[i][j].if_else(res, fallback)
return
relus = self.relus[i].get_vector()
positives = self.positives[i].get_vector()
inv = (1 / sum(relus)).expand_to_vector(d_out)
truths = self.Y[batch[i]].get_vector()
raw = truths / relus - inv
self.nabla_X[i] = -positives.if_else(raw, truths)
self.maybe_debug_backward(batch)
return
@for_range_opt_multithread(self.n_threads, len(batch))
def _(i):
for j in range(d_out):
dividend = self.exp[i][j]
divisor = sum(self.exp[i])
div = (divisor > 0.1).if_else(dividend / divisor, 0)
self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div)
self.maybe_debug_backward(batch)
def maybe_debug_backward(self, batch):
if self.debug:
@for_range(len(batch))
def _(i):
check = 0
for j in range(self.X.sizes[1]):
to_check = self.nabla_X[i][j].reveal()
check += (to_check > len(batch)) + (to_check < -len(batch))
print_ln_if(check, 'X %s', self.X[i].reveal_nested())
print_ln_if(check, 'exp %s', self.exp[i].reveal_nested())
print_ln_if(check, 'nabla X %s',
self.nabla_X[i].reveal_nested())
def get_extra_debugging(self, i):
if self.approx:
return self.relus[i].reveal_list()
else:
return self.exp[i].reveal_list()
class ReluMultiOutput(MultiOutputBase):
"""
Output layer for multi-class classification with back-propagation
based on ReLU division.
:param N: number of examples
:param d_out: number of classes
"""
def forward(self, batch):
self.l.write(999)
def backward(self, batch):
N = len(batch)
d_out = self.X.sizes[1]
relus = sfix.Matrix(N, d_out)
@for_range_opt_multithread(self.n_threads, len(batch))
def _(i):
positives = self.X[i].get_vector() > 0
relus = positives.if_else(self.X[i].get_vector(), 0)
s = sum(relus)
inv = 1 / s
prod = relus * inv
res = prod - self.Y[batch[i]].get_vector()
self.nabla_X[i].assign_vector(res)
class DenseBase(Layer):
thetas = lambda self: (self.W, self.b)
nablas = lambda self: (self.nabla_W, self.nabla_b)
@@ -279,26 +532,20 @@ class DenseBase(Layer):
N = len(batch)
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
assert self.d == 1
if self.d_out == 1:
@multithread(self.n_threads, self.d_in)
def _(base, size):
A = sfix.Matrix(1, self.N, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
mp = A.direct_mul(B, reduce=False,
indices=(regint(0, size=1),
regint.inc(N),
batch.get_vector(),
regint.inc(size, base)))
tmp.assign_vector(mp, base)
else:
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
def _(j, k):
a = [f_schur_Y[i][0][k] for i in range(N)]
b = [self.X[i][0][j] for i in batch]
tmp[j][k] = sfix.unreduced_dot_product(a, b)
@multithread(self.n_threads, self.d_in)
def _(base, size):
A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
mp = B.direct_trans_mul(A, reduce=False,
indices=(regint.inc(size, base),
batch.get_vector(),
regint.inc(N),
regint.inc(self.d_out)))
tmp.assign_part_vector(mp, base)
if self.d_in * self.d_out < 100000:
progress('nabla W (matmul)')
if self.d_in * self.d_out < 200000:
print('reduce at once')
@multithread(self.n_threads, self.d_in * self.d_out)
def _(base, size):
@@ -309,10 +556,46 @@ class DenseBase(Layer):
def _(i):
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N))
for j in range(self.d)) for i in range(self.d_out))
progress('nabla W')
progress('nabla W/b')
self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector()
for k in range(N))
for j in range(self.d)))
progress('nabla b')
if self.debug:
limit = N * self.debug
@for_range_opt(self.d_in)
def _(i):
@for_range_opt(self.d_out)
def _(j):
to_check = self.nabla_W[i][j].reveal()
check = sum(to_check > limit) + sum(to_check < -limit)
@if_(check)
def _():
print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check)
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
for k in range(N)])
print_ln('X %s', [self.X[k][0][i].reveal()
for k in range(N)])
@for_range_opt(self.d_out)
def _(j):
to_check = self.nabla_b[j].reveal()
check = sum(to_check > limit) + sum(to_check < -limit)
@if_(check)
def _():
print_ln('nabla b %s %s: %s', j, len(self.b), to_check)
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
for k in range(N)])
@for_range_opt(len(batch))
def _(i):
to_check = self.nabla_X[i].get_vector().reveal()
check = sum(to_check > limit) + sum(to_check < -limit)
@if_(check)
def _():
print_ln('X %s %s', i, self.X[i].reveal_nested())
print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
class Dense(DenseBase):
""" Fixed-point dense (matrix multiplication) layer.
@@ -321,7 +604,7 @@ class Dense(DenseBase):
:param d_in: input dimension
:param d_out: output dimension
"""
def __init__(self, N, d_in, d_out, d=1, activation='id'):
def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
self.activation = activation
if activation == 'id':
self.f = lambda x: x
@@ -349,15 +632,13 @@ class Dense(DenseBase):
self.f_input = MultiArray([N, d, d_out], sfix)
self.debug = debug
def reset(self):
d_in = self.d_in
d_out = self.d_out
r = math.sqrt(6.0 / (d_in + d_out))
@for_range(d_in)
def _(i):
@for_range(d_out)
def _(j):
self.W[i][j] = sfix.get_random(-r, r)
self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size()))
self.b.assign_all(0)
def input_from(self, player, raw=False):
@@ -372,15 +653,14 @@ class Dense(DenseBase):
prod = MultiArray([N, self.d, self.d_out], sfix)
else:
prod = self.f_input
@multithread(self.n_threads, N)
max_size = program.Program.prog.budget // self.d_out
@multithread(self.n_threads, N, max_size)
def _(base, size):
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
prod.assign_vector(
X_sub.direct_mul(self.W, indices=(batch.get_vector(base, size),
regint.inc(self.d_in),
regint.inc(self.d_in),
regint.inc(self.d_out))),
base)
prod.assign_part_vector(
X_sub.direct_mul(self.W, indices=(
batch.get_vector(base, size), regint.inc(self.d_in),
regint.inc(self.d_in), regint.inc(self.d_out))), base)
if self.input_bias:
if self.d_out == 1:
@@ -389,7 +669,7 @@ class Dense(DenseBase):
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
self.f_input.assign_vector(v, base)
else:
@for_range_opt_multithread(self.n_threads, N)
@for_range_multithread(self.n_threads, 100, N)
def _(i):
v = prod[i].get_vector() + self.b.get_vector()
self.f_input[i].assign_vector(v)
@@ -397,8 +677,24 @@ class Dense(DenseBase):
def forward(self, batch=None):
self.compute_f_input(batch=batch)
self.Y.assign_vector(self.f(
self.f_input.get_part_vector(0, len(batch))))
@multithread(self.n_threads, len(batch), 128)
def _(base, size):
self.Y.assign_part_vector(self.f(
self.f_input.get_part_vector(base, size)), base)
if self.debug:
limit = self.debug
@for_range_opt(len(batch))
def _(i):
@for_range_opt(self.d_out)
def _(j):
to_check = self.Y[i][0][j].reveal()
check = to_check > limit
@if_(check)
def _():
print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
print_ln('X %s', self.X[i].reveal_nested())
print_ln('W %s',
[self.W[k][j].reveal() for k in range(self.d_in)])
def backward(self, compute_nabla_X=True, batch=None):
N = len(batch)
@@ -419,26 +715,31 @@ class Dense(DenseBase):
f_prime_bit = MultiArray([N, d, d_out], sint)
f_schur_Y = MultiArray([N, d, d_out], sfix)
self.compute_f_input()
f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector()))
@multithread(self.n_threads, f_prime_bit.total_size())
def _(base, size):
f_prime_bit.assign_vector(
self.f_prime(self.f_input.get_vector(base, size)), base)
progress('f prime')
@for_range_opt(N)
def _(i):
i = batch[i]
f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])
@multithread(self.n_threads, f_prime_bit.total_size())
def _(base, size):
f_schur_Y.assign_vector(nabla_Y.get_vector(base, size) *
f_prime_bit.get_vector(base, size),
base)
progress('f prime schur Y')
if compute_nabla_X:
@for_range_opt(N)
def _(i):
i = batch[i]
if self.activation == 'id':
nabla_X[i] = nabla_Y[i].mul_trans(W)
else:
nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W)
@multithread(self.n_threads, N)
def _(base, size):
B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
nabla_X.assign_part_vector(
B.direct_mul_trans(W, indices=(regint.inc(size, base),
regint.inc(self.d_out),
regint.inc(self.d_out),
regint.inc(self.d_in))),
base)
progress('nabla X')
@@ -1151,6 +1452,7 @@ class QuantSoftmax(QuantBase, BaseLayer):
class Optimizer:
""" Base class for graphs of layers. """
n_threads = Layer.n_threads
always_shuffle = True
@property
def layers(self):
@@ -1175,6 +1477,14 @@ class Optimizer:
layer.last_used = list(filter(lambda x: x not in used, layer.inputs))
used.update(layer.inputs)
def batch_for(self, layer, batch):
if layer in (self.layers[0], self.layers[-1]):
return batch
else:
batch = regint.Array(len(batch))
batch.assign(regint.inc(len(batch)))
return batch
def forward(self, N=None, batch=None, keep_intermediate=True,
model_from=None):
""" Compute graph.
@@ -1193,7 +1503,7 @@ class Optimizer:
if model_from is not None:
layer.input_from(model_from)
break_point()
layer.forward(batch=batch)
layer.forward(batch=self.batch_for(layer, batch))
break_point()
if not keep_intermediate:
for l in layer.last_used:
@@ -1212,26 +1522,30 @@ class Optimizer:
""" Compute backward propagation. """
for layer in reversed(self.layers):
if len(layer.inputs) == 0:
layer.backward(compute_nabla_X=False, batch=batch)
layer.backward(compute_nabla_X=False,
batch=self.batch_for(layer, batch))
else:
layer.backward(batch=batch)
layer.backward(batch=self.batch_for(layer, batch))
if len(layer.inputs) == 1:
layer.inputs[0].nabla_Y.alloc()
layer.inputs[0].nabla_Y.assign_vector(
layer.nabla_X.get_part_vector(0, len(batch)))
def run(self, batch_size=None):
def run(self, batch_size=None, stop_on_loss=0):
""" Run training.
:param batch_size: batch size (defaults to example size of first layer)
"""
if self.n_epochs == 0:
return
if batch_size is not None:
N = batch_size
else:
N = self.layers[0].N
i = MemValue(0)
i = self.i_epoch
n_iterations = MemValue(0)
@do_while
def _():
@for_range(self.n_epochs)
def _(_):
if self.X_by_label is None:
self.X_by_label = [[None] * self.layers[0].N]
assert len(self.X_by_label) in (1, 2)
@@ -1239,16 +1553,18 @@ class Optimizer:
n = N // len(self.X_by_label)
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
self.X_by_label) / n))
n_iterations.iadd(n_per_epoch)
print('%d runs per epoch' % n_per_epoch)
indices_by_label = []
for label, X in enumerate(self.X_by_label):
indices = regint.Array(n * n_per_epoch)
indices_by_label.append(indices)
indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
indices.shuffle()
if self.always_shuffle or n_per_epoch > 1:
indices.shuffle()
loss_sum = MemValue(sfix(0))
@for_range(n_per_epoch)
def _(j):
n_iterations.iadd(1)
batch = regint.Array(N)
for label, X in enumerate(self.X_by_label):
indices = indices_by_label[label]
@@ -1257,20 +1573,84 @@ class Optimizer:
label * n)
self.forward(batch=batch)
self.backward(batch=batch)
self.update(i)
loss = self.layers[-1].l
self.update(i, batch=batch)
loss_sum.iadd(self.layers[-1].l)
if self.print_loss_reduction:
before = self.layers[-1].average_loss(N)
self.forward(batch=batch)
after = self.layers[-1].average_loss(N)
print_ln('loss reduction in batch %s: %s (%s - %s)', j,
before - after, before, after)
elif self.print_losses:
print_ln('loss in batch %s: %s', j, self.layers[-1].average_loss(N))
if stop_on_loss:
loss = self.layers[-1].average_loss(N)
res = (loss < stop_on_loss) * (loss >= 0)
self.stopped_on_loss.write(1 - res)
return res
if self.report_loss and self.layers[-1].approx != 5:
print_ln('loss after epoch %s: %s', i, loss.reveal())
print_ln('loss in epoch %s: %s', i,
(loss_sum.reveal() * cfix(1 / n_per_epoch)))
else:
print_ln('done with epoch %s', i)
time()
i.iadd(1)
res = (i < self.n_epochs)
res = True
if self.tol > 0:
res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
return res
print_ln('finished after %s epochs and %s iterations', i, n_iterations)
def run_by_args(self, program, n_runs, batch_size, test_X, test_Y):
for arg in program.args:
m = re.match('rate(.*)', arg)
if m:
self.gamma = MemValue(cfix(float(m.group(1))))
if 'nomom' in program.args:
self.momentum = 0
model_input = 'model_input' in program.args
if model_input:
for layer in self.layers:
layer.input_from(0)
else:
self.reset()
@for_range(n_runs)
def _(i):
if not model_input:
start_timer(1)
self.run(batch_size, stop_on_loss=100)
stop_timer(1)
if 'no_acc' in program.args:
return
N = self.layers[0].X.sizes[0]
self.forward(N)
batch = regint.Array(N)
batch.assign_vector(regint.inc(N))
self.layers[-1].backward(batch)
n_correct = self.layers[-1].reveal_correctness(N, debug=True)
print_ln('train_acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / N,
n_correct, N)
training_address = self.layers[0].X.address
self.layers[0].X.address = test_X.address
n_test = len(test_Y)
self.forward(n_test)
self.layers[0].X.address = training_address
n_correct = self.layers[-1].reveal_correctness(n_test, test_Y)
print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / n_test,
n_correct, n_test)
if model_input:
start_timer(1)
self.run(batch_size)
stop_timer(1)
else:
@if_(util.or_op(self.stopped_on_loss, n_correct <
int(n_test // self.layers[-1].n_outputs * 1.1)))
def _():
self.gamma.imul(.5)
self.reset()
print_ln('reset after reducing learning rate to %s',
self.gamma)
class Adam(Optimizer):
def __init__(self, layers, n_epochs):
self.alpha = .001
@@ -1318,7 +1698,7 @@ class SGD(Optimizer):
:param n_epochs: number of epochs for training
:param report_loss: disclose and print loss
"""
def __init__(self, layers, n_epochs, debug=False, report_loss=False):
def __init__(self, layers, n_epochs, debug=False, report_loss=None):
self.momentum = 0.9
self.layers = layers
self.n_epochs = n_epochs
@@ -1330,11 +1710,19 @@ class SGD(Optimizer):
self.thetas.extend(layer.thetas())
for theta in layer.thetas():
self.delta_thetas.append(theta.same_shape())
self.gamma = MemValue(sfix(0.01))
self.gamma = MemValue(cfix(0.01))
self.debug = debug
self.report_loss = report_loss
if report_loss is None:
self.report_loss = layers[-1].compute_loss
else:
self.report_loss = report_loss
self.tol = 0.000
self.X_by_label = None
self.print_update_average = False
self.print_losses = False
self.print_loss_reduction = False
self.i_epoch = MemValue(0)
self.stopped_on_loss = MemValue(0)
def reset(self, X_by_label=None):
""" Reset layer parameters.
@@ -1353,40 +1741,64 @@ class SGD(Optimizer):
y.assign_all(0)
for layer in self.layers:
layer.reset()
self.i_epoch.write(0)
self.stopped_on_loss.write(0)
def update(self, i_epoch):
def update(self, i_epoch, batch):
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
self.delta_thetas):
@multithread(self.n_threads, len(nabla))
@multithread(self.n_threads, nabla.total_size())
def _(base, size):
old = delta_theta.get_vector(base, size)
red_old = self.momentum * old
new = self.gamma * nabla.get_vector(base, size)
rate = self.gamma.expand_to_vector(size)
nabla_vector = nabla.get_vector(base, size)
log_batch_size = math.log(len(batch), 2)
# divide by len(batch) by truncation
# increased rate if len(batch) is not a power of two
pre_trunc = nabla_vector.v * rate.v
k = nabla_vector.k + rate.k
m = rate.f + int(log_batch_size)
v = pre_trunc.round(k, m, signed=True,
nearest=sfix.round_nearest)
new = nabla_vector._new(v)
diff = red_old - new
delta_theta.assign_vector(diff, base)
theta.assign_vector(theta.get_vector(base, size) +
delta_theta.get_vector(base, size), base)
if self.debug:
for x, name in (old, 'old'), (red_old, 'red_old'), \
(new, 'new'), (diff, 'diff'):
x = x.reveal()
print_ln_if((x > 1000) + (x < -1000),
name + ': %s %s %s %s',
*[y.v.reveal() for y in (old, red_old, \
new, diff)])
if self.print_update_average:
vec = abs(delta_theta.get_vector().reveal())
print_ln('update average: %s (%s)',
sum(vec) * cfix(1 / len(vec)), len(vec))
if self.debug:
limit = int(self.debug)
d = delta_theta.get_vector().reveal()
a = cfix.Array(len(d.v))
aa = [cfix.Array(len(d.v)) for i in range(3)]
a = aa[0]
a.assign(d)
@for_range(len(a))
def _(i):
x = a[i]
print_ln_if((x > 1000) + (x < -1000),
'update len=%d' % len(nabla))
print_ln_if((x > limit) + (x < -limit),
'update epoch=%s %s index=%s %s',
i_epoch.read(), str(delta_theta), i, x)
a = aa[1]
a.assign(nabla.get_vector().reveal())
@for_range(len(a))
def _(i):
x = a[i]
print_ln_if((x > 1000) + (x < -1000),
'nabla len=%d' % len(nabla))
print_ln_if((x > len(batch) * limit) + (x < -len(batch) * limit),
'nabla epoch=%s %s index=%s %s',
i_epoch.read(), str(nabla), i, x)
a = aa[2]
a.assign(theta.get_vector().reveal())
@for_range(len(a))
def _(i):
x = a[i]
print_ln_if((x > limit) + (x < -limit),
'theta epoch=%s %s index=%s %s',
i_epoch.read(), str(theta), i, x)
index = regint.get_random(64) % len(a)
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index,
aa[1][index], aa[0][index], aa[2][index])
self.gamma.imul(1 - 10 ** - 6)

View File

@@ -13,6 +13,7 @@ from Compiler import types
from Compiler import comparison
from Compiler import program
from Compiler import instructions_base
from Compiler import library, util
# polynomials as enumerated on Hart's book
##
@@ -33,11 +34,8 @@ p_3508 = [1.00000000000000000000, -0.50000000000000000000,
0.00000000000000000040]
##
# @private
p_1045 = [1.000000077443021686, 0.693147180426163827795756,
0.224022651071017064605384, 0.055504068620466379157744,
0.009618341225880462374977, 0.001332730359281437819329,
0.000155107460590052573978, 0.000014197847399765606711,
0.000001863347724137967076]
p_1045 = [math.log(2) ** i / math.factorial(i) for i in range(12)]
##
# @private
p_2524 = [-2.05466671951, -8.8626599391,
@@ -92,8 +90,8 @@ pi_over_2 = math.radians(90)
#
# @return truncated sint value of x
def trunc(x):
if type(x) is types.sfix:
return floatingpoint.Trunc(x.v, x.k, x.f, x.kappa, signed=True)
if isinstance(x, types._fix):
return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True)
elif type(x) is types.sfloat:
v, p, z, s = floatingpoint.FLRound(x, 0)
#return types.sfloat(v, p, z, s, x.err)
@@ -125,7 +123,7 @@ def load_sint(x, l_type):
# @return the evaluation of the polynomial. return type depends on inputs.
def p_eval(p_c, x):
degree = len(p_c) - 1
if type(x) is types.sfix:
if isinstance(x, types._fix):
# ignore coefficients smaller than precision
for c in reversed(p_c):
if c < 2 ** -(x.f + 1):
@@ -160,10 +158,10 @@ def sTrigSub(x):
y = x - (f) * x.coerce(2 * pi)
# reduction to \pi
b1 = y > pi
w = b1 * ((2 * pi - y) - y) + y
w = b1.if_else(2 * pi - y, y)
# reduction to \pi/2
b2 = w > pi_over_2
w = b2 * ((pi - w) - w) + w
w = b2.if_else(pi - w, w)
# returns scaled angle and boolean flags
return w, b1, b2
@@ -182,9 +180,8 @@ def ssin(w, s):
v = w * (1.0 / pi_over_2)
v_2 = v ** 2
# adjust sign according to the movement in the reduction
b = s * (-2) + 1
# calculate the sin using polynomial evaluation
local_sin = b * v * p_eval(p_3307, v_2)
local_sin = s.if_else(-v, v) * p_eval(p_3307, v_2)
return local_sin
@@ -203,10 +200,10 @@ def scos(w, s):
# calculates the v of the w.
v = w
v_2 = v ** 2
# adjust sign according to the movement in the reduction
b = s * (-2) + 1
# calculate the cos using polynomial evaluation
local_cos = b * p_eval(p_3508, v_2)
tmp = p_eval(p_3508, v_2)
# adjust sign according to the movement in the reduction
local_cos = s.if_else(-tmp, tmp)
return local_cos
@@ -264,11 +261,12 @@ def tan(x):
@types.vectorize
@instructions_base.sfix_cisc
def exp2_fx(a):
def exp2_fx(a, zero_output=False):
"""
Power of two for fixed-point numbers.
:param a: exponent for :math:`2^a` (sfix)
:param zero_output: whether to output zero for very small values. If not, the result will be undefined.
:return: :math:`2^a` if it is within the range. Undefined otherwise
"""
@@ -279,54 +277,95 @@ def exp2_fx(a):
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
n_bits = a.f + n_int_bits
n_shift = int(types.program.options.ring) - a.k
if types.program.use_edabit():
l = sint.get_edabit(a.f, True)
u = sint.get_edabit(a.k - a.f, True)
r_bits = l[1] + u[1]
r = l[0] + (u[0] << a.f)
lower_r = l[0]
if types.program.use_split():
assert not zero_output
from Compiler.GC.types import sbitvec
if types.program.use_split() == 3:
x = a.v.split_to_two_summands(a.k)
bits = types._bitint.carry_lookahead_adder(x[0], x[1],
fewer_inv=False)
# converting MSB first reduces the number of rounds
s = sint.conv(bits[-1])
lower_overflow = sint.conv(x[0][a.f]) + \
sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f])
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
elif types.program.use_split() == 4:
x = list(zip(*a.v.split_to_n_summands(a.k, 4)))
bi = types._bitint
red = bi.wallace_reduction
sums1, carries1 = red(*x[:3], get_carry=False)
sums2, carries2 = red(x[3], sums1, carries1, False)
bits = bi.carry_lookahead_adder(sums2, carries2,
fewer_inv=False)
overflows = bi.full_adder(carries1[a.f], carries2[a.f],
bits[a.f] ^ sums2[a.f] ^ carries2[a.f])
overflows = reversed(list((sint.conv(x)
for x in reversed(overflows))))
lower_overflow = sint.bit_compose(sint.conv(x)
for x in overflows)
s = sint.conv(bits[-1])
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
else:
bits = sbitvec(a.v, a.k)
s = sint.conv(bits[-1])
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
higher_bits = bits[a.f:n_bits]
else:
r_bits = [sint.get_random_bit() for i in range(a.k)]
r = sint.bit_compose(r_bits)
lower_r = sint.bit_compose(r_bits[:a.f])
shifted = ((a.v - r) << n_shift).reveal()
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
r_bits[a.f-1::-1])
lower_masked = sint.bit_compose(masked_bits[:a.f])
lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f))
if types.program.use_edabit():
l = sint.get_edabit(a.f, True)
u = sint.get_edabit(a.k - a.f, True)
r_bits = l[1] + u[1]
r = l[0] + (u[0] << a.f)
lower_r = l[0]
else:
r_bits = [sint.get_random_bit() for i in range(a.k)]
r = sint.bit_compose(r_bits)
lower_r = sint.bit_compose(r_bits[:a.f])
shifted = ((a.v - r) << n_shift).reveal()
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
r_bits[a.f-1::-1])
lower_masked = sint.bit_compose(masked_bits[:a.f])
lower = lower_r + lower_masked - \
(sint.conv(lower_overflow) << (a.f))
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
masked_bits[a.f:n_bits],
carry_in=lower_overflow,
get_carry=True)
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
r_bits[n_bits:-1],
higher_bits[-1])
if zero_output:
# should be for free
highest_bits = r_bits[0].ripple_carry_adder(
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
carry_in=higher_bits[-1])
bits_to_check = [x.bit_xor(y)
for x, y in zip(highest_bits[:-1],
r_bits[n_bits:-1])]
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
bits_to_check))
# sign
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
del higher_bits[-1]
c = types.sfix._new(lower, k=a.k, f=a.f)
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
masked_bits[a.f:n_bits],
carry_in=lower_overflow,
get_carry=True)
assert(len(higher_bits) == n_bits - a.f + 1)
assert(len(higher_bits) == n_bits - a.f)
pow2_bits = [sint.conv(x) for x in higher_bits]
d = floatingpoint.Pow2_from_bits(pow2_bits[:-1])
d = floatingpoint.Pow2_from_bits(pow2_bits)
e = p_eval(p_1045, c)
g = d * e
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
2 ** n_int_bits, signed=False,
nearest=types.sfix.round_nearest),
k=a.k, f=a.f)
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
r_bits[n_bits:-1],
higher_bits[-1])
# should be for free
highest_bits = r_bits[0].ripple_carry_adder(
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
carry_in=higher_bits[-1])
bits_to_check = [x.bit_xor(y)
for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
bits_to_check))
# sign
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
return s.if_else(t.if_else(small_result, 0), g)
if zero_output:
small_result = t.if_else(small_result, 0)
return s.if_else(small_result, g)
else:
assert not zero_output
# obtain absolute value of a
s = a < 0
a = (s * (-2) + 1) * a
a = s.if_else(-a, a)
# isolates fractional part of number
b = trunc(a)
c = a - b
@@ -335,7 +374,7 @@ def exp2_fx(a):
# evaluates fractional part of a in p_1045
e = p_eval(p_1045, c)
g = d * e
return (1 - s) * g + s / g
return s.if_else(1 / g, g)
@types.vectorize
@@ -353,19 +392,20 @@ def log2_fx(x):
:return: (sfix) the value of :math:`\log_2(x)`
"""
if type(x) is types.sfix:
if isinstance(x, types._fix):
# transforms sfix to f*2^n, where f is [o.5,1] bounded
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
p -= x.f
vlen = x.f
v = x._new(v, k=x.k, f=x.f)
else:
d = types.sfloat(x)
v, p, vlen = d.v, d.p, d.vlen
w = x.coerce(1.0 / (2 ** (vlen)))
v *= w
# isolates mantisa of d, now the n can be also substituted by the
# secret shared p from d in the expresion above.
w = x.coerce(1.0 / (2 ** (vlen)))
v = v * w
# polynomials for the log_2 evaluation of f are calculated
P = p_eval(p_2524, v)
Q = p_eval(q_2524, v)
@@ -384,7 +424,7 @@ def pow_fx(x, y):
:param y: (sfix, clear types) secret shared exponent.
:return: :math:`x^y` (sfix)
:return: :math:`x^y` (sfix) if positive and in range
"""
log2_x =0
# obtains log2(x)
@@ -456,9 +496,6 @@ def floor_fx(x):
def MSB(b, k):
# calculation of z
# x in order 0 - k
if (k > types.program.bit_length):
raise OverflowError("The supported bit \
lenght of the application is smaller than k")
x_order = b.bit_decompose(k)
x = [0] * k
@@ -511,9 +548,7 @@ def norm_simplified_SQ(b, k):
w_array[i] = z[2 * i - 1] + z[2 * i]
# w aggregation
w = types.sint(0)
for i in range(k_over_2):
w += (2 ** i) * w_array[i]
w = b.bit_compose(w_array)
# return computed values
#return m_odd, m, w
@@ -538,9 +573,9 @@ def sqrt_simplified_fx(x):
# process to set up the precision and allocate correct 2**f
if x.f % 2 == 1:
m_odd = (1 - 2 * m_odd) + m_odd
w = (w * 2 - w) * (1-m_odd) + w
w = m_odd.if_else(w, 2 * w)
# map number to use sfix format and instantiate the number
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
w = x._new(w << ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
# obtains correct 2 ** (m/2)
w = (w * (2 ** (1/2.0)) - w) * m_odd + w
# produce x/ 2^(m/2)
@@ -739,15 +774,15 @@ def atan(x):
"""
# obtain absolute value of x
s = x < 0
x_abs = (s * (-2) + 1) * x
x_abs = s.if_else(-x, x)
# angle isolation
b = x_abs > 1
v = 1 / x_abs
v = (1 - b) * (x_abs - v) + v
v = b.if_else(v, x_abs)
v_2 =v*v
# range of polynomial coefficients
assert x.k - x.f >= 15
assert x.k - x.f >= 19
P = p_eval(p_5102, v_2)
Q = p_eval(q_5102, v_2)
@@ -756,8 +791,8 @@ def atan(x):
y_pi_over_two = pi_over_2 - y
# sign correction
y = (1 - b) * (y - y_pi_over_two) + y_pi_over_two
y = (1 - s) * (y - (-y)) + (-y)
y = b.if_else(y_pi_over_two, y)
y = s.if_else(-y, y)
return y

View File

@@ -98,6 +98,8 @@ class Program(object):
self.use_dabit = options.mixed
self._edabit = options.edabit
self._split = False
if options.split:
self.use_split(int(options.split))
self._square = False
self._always_raw = False
Program.prog = self
@@ -243,10 +245,12 @@ class Program(object):
""" The basic block that is currently being created. """
return self.curr_tape.active_basicblock
def malloc(self, size, mem_type, reg_type=None):
def malloc(self, size, mem_type, reg_type=None, creator_tape=None):
""" Allocate memory from the top """
if not isinstance(size, int):
raise CompilerError('size must be known at compile time')
if (creator_tape or self.curr_tape) != self.tapes[0]:
raise CompilerError('cannot allocate memory outside main thread')
if size == 0:
return
if isinstance(mem_type, type):
@@ -330,7 +334,9 @@ class Program(object):
if change is None:
return self._split
else:
assert change in (2, 3)
if change and not self.options.ring:
raise CompilerError('splitting only supported for rings')
assert change > 1
self._split = change
def use_square(self, change=None):
@@ -350,8 +356,12 @@ class Program(object):
self.use_trunc_pr = True
if 'split' in self.args or 'split3' in self.args:
self.use_split(3)
if 'split4' in self.args:
self.use_split(4)
if 'raw' in self.args:
self.always_raw(True)
if 'edabit' in self.args:
self.use_edabit(True)
class Tape:
""" A tape contains a list of basic blocks, onto which instructions are added. """
@@ -559,12 +569,14 @@ class Tape:
numrounds = merger.longest_paths_merge()
block.n_rounds = numrounds
block.n_to_merge = len(merger.open_nodes)
if numrounds > 0 and self.program.verbose:
print('Program requires %d rounds of communication' % numrounds)
if merger.counter and self.program.verbose:
print('Block requires', \
', '.join('%d %s' % (y, x.__name__) \
for x, y in list(merger.counter.items())))
if merger.counter and self.program.verbose:
print('Block requires %s rounds' % \
', '.join('%d %s' % (y, x.__name__) \
for x, y in list(merger.rounds.items())))
# free memory
merger = None
if options.dead_code_elimination:

View File

@@ -10,6 +10,8 @@ Registers are used for computation, allocated on an ongoing basis,
and thread-specific. The memory is allocated statically and shared
between threads. This means that memory-based types such as
:py:class:`Array` can be used to transfer information between threads.
Note that creating memory-based types outside the main thread is not
supported.
If viewing this documentation in processed form, many function signatures
appear generic because of the use of decorators. See the source code for the
@@ -74,6 +76,7 @@ from . import instructions
from .util import is_zero, is_one
import operator
from functools import reduce
import re
class ClientMessageType:
@@ -118,17 +121,6 @@ class MPCThread(object):
program.join_tape(self.run_handles.pop(0))
def copy_doc(a, b):
try:
a.__doc__ = b.__doc__
except:
pass
def no_doc(operation):
def wrapper(*args, **kwargs):
return operation(*args, **kwargs)
return wrapper
def copy_doc(a, b):
try:
a.__doc__ = b.__doc__
@@ -316,6 +308,14 @@ class _number(object):
res.iadd(res.value_type.conv(aa[i] * bb[i]))
return res.read()
def __abs__(self):
""" Absolute value. """
return (self < 0).if_else(-self, self)
@staticmethod
def popcnt_bits(bits):
return sum(bits)
class _int(object):
""" Integer functionality. """
@@ -534,11 +534,11 @@ class _register(Tape.Register, _number, _structure):
return sum(b << i for i,b in enumerate(bits))
@classmethod
def malloc(cls, size):
def malloc(cls, size, creator_tape=None):
""" Allocate memory (statically).
:param size: compile-time (int) """
return program.malloc(size, cls)
return program.malloc(size, cls, creator_tape=creator_tape)
@classmethod
def free(cls, addr):
@@ -833,6 +833,20 @@ class cint(_clear, _int):
:param other: cint/regint/int """
return self.coerce_op(other, modc, True)
def less_than(self, other, bit_length):
""" Clear comparison for particular bit length.
:param other: cint/regint/int
:param bit_length: signed bit length of inputs
:return: 0/1 (regint), undefined if inputs outside range """
if bit_length <= 64:
return self < other
else:
diff = self - other
shifted = diff >> (bit_length - 1)
res = regint(shifted & 1)
return res
def __lt__(self, other):
""" Clear 64-bit comparison.
@@ -1732,15 +1746,20 @@ class sint(_secret, _int):
""" Secret random n-bit number according to security model.
:param bits: compile-time integer (int) """
if program.use_split() == 3:
if program.use_edabit():
return sint.get_edabit(bits, True)[0]
elif program.use_split() > 2:
tmp = sint()
randoms(tmp, bits)
x = tmp.split_to_two_summands(bits, True)
overflow = comparison.CarryOutLE(x[1][:-1], x[0][:-1]) + \
sint.conv(x[0][-1])
carry = comparison.CarryOutRawLE(x[1][:bits], x[0][:bits])
if program.use_split() > 3:
from .GC.types import sbitint
x = sbitint.full_adder(carry, x[0][bits], x[1][bits])
overflow = sint.conv(x[1]) * 2 + sint.conv(x[0])
else:
overflow = sint.conv(carry) + sint.conv(x[0][bits])
return tmp - (overflow << bits)
elif program.use_edabit():
return sint.get_edabit(bits, True)[0]
res = sint()
comparison.PRandInt(res, bits)
return res
@@ -1791,7 +1810,7 @@ class sint(_secret, _int):
def bit_decompose_clear(a, n_bits):
return floatingpoint.bits(a, n_bits)
@classmethod
@vectorized_classmethod
def get_raw_input_from(cls, player):
res = cls()
rawinput(player, res)
@@ -1980,7 +1999,7 @@ class sint(_secret, _int):
@vectorize
@read_mem_value
def __rshift__(self, other, bit_length=None, security=None):
def __rshift__(self, other, bit_length=None, security=None, signed=True):
""" Secret right shift.
:param other: secret or public integer (sint/cint/regint/int) """
@@ -1990,7 +2009,7 @@ class sint(_secret, _int):
if other == 0:
return self
res = sint()
comparison.Trunc(res, self, bit_length, other, security, True)
comparison.Trunc(res, self, bit_length, other, security, signed)
return res
elif isinstance(other, sint):
return floatingpoint.Trunc(self, bit_length, other, security)
@@ -2092,6 +2111,15 @@ class sint(_secret, _int):
columns = self.split_to_n_summands(length, n)
return _bitint.wallace_tree_without_finish(columns, get_carry)
@vectorize
def raw_right_shift(self, length):
res = sint()
shrsi(res, self, length)
return res
def raw_mod2m(self, m):
return self - (self.raw_right_shift(m) << m)
@vectorize
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
@@ -2304,6 +2332,14 @@ class _bitint(object):
b.pop(0)
else:
break
carries = cls.get_carries(a, b, fewer_inv=fewer_inv, carry_in=carry_in)
res = lower + cls.sum_from_carries(a, b, carries)
if get_carry:
res += [carries[-1]]
return res
@classmethod
def get_carries(cls, a, b, fewer_inv=False, carry_in=0):
d = [cls.half_adder(ai, bi) for (ai,bi) in zip(a,b)]
carry = floatingpoint.carry
if fewer_inv:
@@ -2314,10 +2350,7 @@ class _bitint(object):
carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1]
else:
carries = []
res = lower + cls.sum_from_carries(a, b, carries)
if get_carry:
res += [carries[-1]]
return res
return carries
@staticmethod
def sum_from_carries(a, b, carries):
@@ -2469,6 +2502,18 @@ class _bitint(object):
def wallace_tree(cls, rows):
return cls.wallace_tree_from_columns([list(x) for x in zip(*rows)])
@classmethod
def wallace_reduction(cls, a, b, c, get_carry=True):
assert len(a) == len(b) == len(c)
tmp = zip(*(cls.full_adder(*x) for x in zip(a, b, c)))
sums, carries = (list(x) for x in tmp)
carries = [0] + carries
if get_carry:
sums += [0]
else:
del carries[-1]
return sums, carries
def __sub__(self, other):
if type(other) == sgf2n:
raise CompilerError('Unclear subtraction')
@@ -2497,7 +2542,7 @@ class _bitint(object):
def __rshift__(self, other):
return self.compose(self.bit_decompose()[other:])
def bit_decompose(self, n_bits=None, *args):
def bit_decompose(self, n_bits=None, security=None):
if self.bits is None:
self.bits = self.force_bit_decompose(self.n_bits)
if n_bits is None:
@@ -2541,14 +2586,16 @@ class _bitint(object):
def __gt__(self, other):
return 1 - (self <= other)
def __eq__(self, other):
def __eq__(self, other, bit_length=None, security=None):
diff = self ^ other
diff_bits = [1 - x for x in diff.bit_decompose()]
diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]]
return floatingpoint.KMul(diff_bits)
def __ne__(self, other):
return 1 - (self == other)
equal = __eq__
def __neg__(self):
return 1 + self.compose(1 ^ b for b in self.bit_decompose())
@@ -2752,9 +2799,9 @@ def parse_type(other, k=None, f=None):
class cfix(_number, _structure):
""" Clear fixed-point number represented as clear integer. """
__slots__ = ['value', 'f', 'k', 'size']
__slots__ = ['value', 'f', 'k']
reg_type = 'c'
scalars = (int, float, regint)
scalars = (int, float, regint, cint)
@classmethod
def set_precision(cls, f, k = None):
""" Set the precision of the integer representation. Note that some
@@ -2779,7 +2826,7 @@ class cfix(_number, _structure):
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls(cint.load_mem(address))
return cls._new(cint.load_mem(address))
@vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
@@ -2787,7 +2834,7 @@ class cfix(_number, _structure):
Sender will have already bit shifted and sent as cints."""
cint_input = cint.read_from_socket(client_id, n)
if n == 1:
return cfix(cint_inputs)
return cfix._new(cint_inputs)
else:
return list(map(cfix, cint_inputs))
@@ -2805,34 +2852,47 @@ class cfix(_number, _structure):
writesocketc(client_id, message_type, *cint_values)
@staticmethod
def malloc(size):
return program.malloc(size, cint)
def malloc(size, creator_tape=None):
return program.malloc(size, cint, creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
@classmethod
def from_int(cls, other):
res = cls()
res.load_int(other)
return res
@classmethod
def _new(cls, other, k=None, f=None):
res = cls(k=k, f=f)
res.v = cint.conv(other)
return res
@staticmethod
def int_rep(v, f):
v = v * (2 ** f)
try:
v = int(round(v))
except TypeError:
pass
return v
@vectorize_init
@read_mem_value
def __init__(self, v=None, k=None, f=None, size=None):
""" :param v: cfix/float/int """
f = self.f if f is None else f
k = self.k if k is None else k
self.f = f
self.k = k
self.size = get_global_vector_size()
if isinstance(v, cint):
self.v = cint(v,size=self.size)
elif isinstance(v, cfix.scalars):
v = v * (2 ** f)
try:
v = int(round(v))
except TypeError:
pass
self.v = cint(v, size=self.size)
if isinstance(v, cfix.scalars):
v = self.int_rep(v, f)
self.v = cint(v, size=size)
elif isinstance(v, cfix):
self.v = v.v
elif isinstance(v, MemValue):
self.v = v
elif v is None:
self.v = cint(0)
else:
@@ -2840,7 +2900,10 @@ class cfix(_number, _structure):
def __iter__(self):
for x in self.v:
yield type(self)(x, self.k, self.f)
yield self._new(x, self.k, self.f)
def __len__(self):
return len(self.v)
@vectorize
def load_int(self, v):
@@ -2863,6 +2926,10 @@ class cfix(_number, _structure):
""" Store in memory by public address. """
self.v.store_in_mem(address)
@property
def size(self):
return self.v.size
def sizeof(self):
return self.size * 4
@@ -2873,7 +2940,7 @@ class cfix(_number, _structure):
:param other: cfix/cint/regint/int """
other = parse_type(other)
if isinstance(other, cfix):
return cfix(self.v + other.v)
return cfix._new(self.v + other.v)
else:
return NotImplemented
@@ -2884,18 +2951,26 @@ class cfix(_number, _structure):
:param other: cfix/cint/regint/int/sint """
if isinstance(other, sint):
return sfix._new(self.v * other, k=self.k, f=self.f)
if isinstance(other, (int, regint, cint)):
return cfix._new(self.v * cint(other), k=self.k, f=self.f)
other = parse_type(other)
if isinstance(other, cfix):
assert self.f == other.f
sgn = cint(1 - 2 * (self.v * other.v < 0))
sgn = cint(1 - 2 * ((self < 0) ^ (other < 0)))
absolute = self.v * other.v * sgn
val = sgn * (absolute >> self.f)
return cfix(val)
return cfix._new(val)
elif isinstance(other, sfix):
return NotImplemented
else:
raise CompilerError('Invalid type %s for cfix.__mul__' % type(other))
def positive_mul(self, other):
assert isinstance(other, float)
assert other >= 0
v = self.v * int(round(other * 2 ** self.f))
return self._new(v >> self.f, k=self.k, f=self.f)
@vectorize
def __sub__(self, other):
""" Clear fixed-point subtraction.
@@ -2903,9 +2978,9 @@ class cfix(_number, _structure):
:param other: cfix/cint/regint/int """
other = parse_type(other)
if isinstance(other, cfix):
return cfix(self.v - other.v)
return cfix._new(self.v - other.v)
elif isinstance(other, sfix):
return sfix(self.v - other.v)
return sfix._new(self.v - other.v)
else:
raise NotImplementedError
@@ -2913,7 +2988,7 @@ class cfix(_number, _structure):
def __neg__(self):
""" Clear fixed-point negation. """
# cfix type always has .v
return cfix(-self.v)
return cfix._new(-self.v)
def __rsub__(self, other):
return -self + other
@@ -2939,7 +3014,8 @@ class cfix(_number, _structure):
""" Clear fixed-point comparison. """
other = parse_type(other)
if isinstance(other, cfix):
return self.v < other.v
assert self.k == other.k
return self.v.less_than(other.v, self.k)
elif isinstance(other, sfix):
if(self.k != other.k or self.f != other.f):
raise TypeError('Incompatible fixed point types in comparison')
@@ -2952,7 +3028,7 @@ class cfix(_number, _structure):
""" Clear fixed-point comparison. """
other = parse_type(other)
if isinstance(other, cfix):
return self.v <= other.v
return 1 - (self > other)
elif isinstance(other, sfix):
return other.v.greater_equal(self.v, self.k, other.kappa)
else:
@@ -2963,7 +3039,7 @@ class cfix(_number, _structure):
""" Clear fixed-point comparison. """
other = parse_type(other)
if isinstance(other, cfix):
return self.v > other.v
return other.__lt__(self)
elif isinstance(other, sfix):
return other.v.less_than(self.v, self.k, other.kappa)
else:
@@ -2974,7 +3050,7 @@ class cfix(_number, _structure):
""" Clear fixed-point comparison. """
other = parse_type(other)
if isinstance(other, cfix):
return self.v >= other.v
return 1 - (self < other)
elif isinstance(other, sfix):
return other.v.less_equal(self.v, self.k, other.kappa)
else:
@@ -3000,9 +3076,10 @@ class cfix(_number, _structure):
""" Clear fixed-point division.
:param other: cfix/cint/regint/int """
other = parse_type(other)
other = parse_type(other, self.k, self.f)
if isinstance(other, cfix):
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
return cfix._new(library.cint_cint_division(
self.v, other.v, self.k, self.f), k=self.k, f=self.f)
elif isinstance(other, sfix):
assert self.k == other.k
assert self.f == other.f
@@ -3016,11 +3093,11 @@ class cfix(_number, _structure):
def print_plain(self):
""" Clear fixed-point output. """
if self.k > 64:
raise CompilerError('Printing of fixed-point numbers not ' +
'implemented for more than 64-bit precision')
tmp = regint()
convmodp(tmp, self.v, bitlength=self.k)
sign = cint(tmp < 0)
sign = (((self.v + (1 << (self.k - 1))) >> self.k) & 1)
else:
tmp = regint()
convmodp(tmp, self.v, bitlength=self.k)
sign = cint(tmp < 0)
abs_v = sign.if_else(-self.v, self.v)
print_float_plain(cint(abs_v), cint(-self.f), \
cint(0), cint(sign), cint(0))
@@ -3064,8 +3141,8 @@ class _single(_number, _structure):
return cls.conv(other)
@classmethod
def malloc(cls, size):
return program.malloc(size, cls.int_type)
def malloc(cls, size, creator_tape=None):
return program.malloc(size, cls.int_type, creator_tape=creator_tape)
@classmethod
def free(cls, addr):
@@ -3192,7 +3269,7 @@ class _single(_number, _structure):
class _fix(_single):
""" Secret fixed point type. """
__slots__ = ['v', 'f', 'k', 'size']
__slots__ = ['v', 'f', 'k']
def set_precision(cls, f, k = None):
cls.f = f
@@ -3200,12 +3277,28 @@ class _fix(_single):
if k is None:
cls.k = 2 * f
else:
if k < f:
raise CompilerError('bit length cannot be less than precision')
cls.k = k
set_precision.__doc__ = cfix.set_precision.__doc__
set_precision = classmethod(set_precision)
@classmethod
def set_precision_from_args(cls, program):
f = None
k = None
for arg in program.args:
m = re.match('f([0-9]+)$', arg)
if m:
f = int(m.group(1))
m = re.match('k([0-9]+)$', arg)
if m:
k = int(m.group(1))
if f is not None:
print ('Setting fixed-point precision to %d/%s' % (f, k))
cls.set_precision(f, k)
cfix.set_precision(f, k)
elif k is not None:
raise CompilerError('need to set fractional precision')
@classmethod
def coerce(cls, other):
if isinstance(other, (_fix, cls.clear_type)):
@@ -3224,13 +3317,13 @@ class _fix(_single):
@classmethod
def _new(cls, other, k=None, f=None):
res = cls(other, k=k, f=f)
res = cls(k=k, f=f)
res.v = cls.int_type.conv(other)
return res
@vectorize_init
def __init__(self, _v=None, k=None, f=None, size=None):
""" :params _v: compile-time value (int/float) """
self.size = get_global_vector_size()
""" :params _v: int/float/regint/cint/sint/sfloat """
if k is None:
k = self.k
else:
@@ -3241,15 +3334,12 @@ class _fix(_single):
self.f = f
assert k is not None
assert f is not None
# warning: don't initialize a sfix from a sint, this is only used in internal methods;
# for external initialization use load_int.
if _v is None:
self.v = self.int_type(0)
elif isinstance(_v, self.int_type):
self.v = _v
self.size = _v.size
self.load_int(_v)
elif isinstance(_v, cfix.scalars):
self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size)
self.v = self.int_type(cfix.int_rep(_v, f=f), size=size)
elif isinstance(_v, self.float_type):
p = (f + _v.p)
b = (p.greater_equal(0, _v.vlen))
@@ -3265,7 +3355,6 @@ class _fix(_single):
if not isinstance(self.v, self.int_type):
raise CompilerError('sfix conversion failure: %s/%s' % (_v, self.v))
@vectorize
def load_int(self, v):
self.v = self.int_type(v) << self.f
@@ -3365,7 +3454,7 @@ class _fix(_single):
class revealed_fix(self.clear_type):
f = self.f
k = self.k
return revealed_fix(val)
return revealed_fix._new(val)
class sfix(_fix):
""" Secret fixed-point number represented as secret integer.
@@ -3410,6 +3499,24 @@ class sfix(_fix):
res = res.reduce_after_mul()
return res
@classmethod
def dot_product(cls, x, y, res_params=None):
""" Secret dot product.
:param x: iterable of appropriate secret type
:param y: iterable of appropriate secret type and same length """
x, y = list(x), list(y)
if res_params is None:
if isinstance(x[0], cls.int_type):
x, y = y, x
if isinstance(y[0], cls.int_type):
return cls._new(cls.int_type.dot_product((xx.v for xx in x), y),
k=x[0].k, f=x[0].f)
return super().dot_product(x, y, res_params)
def expand_to_vector(self, size):
return self._new(self.v.expand_to_vector(size), k=self.k, f=self.f)
def coerce(self, other):
return parse_type(other, k=self.k, f=self.f)
@@ -3426,7 +3533,7 @@ class sfix(_fix):
@staticmethod
def multipliable(v, k, f, size):
return cfix(cint.conv(v, size=size), k, f)
return cfix._new(cint.conv(v, size=size), k, f)
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
@@ -3436,8 +3543,8 @@ class sfix(_fix):
:param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`Compiler.library.print_ln_to`
"""
return personal(player, cfix(self.v.reveal_to(player)._v,
self.k, self.f))
return personal(player, cfix._new(self.v.reveal_to(player)._v,
self.k, self.f))
class unreduced_sfix(_single):
int_type = sint
@@ -3466,10 +3573,9 @@ class unreduced_sfix(_single):
@vectorize
def reduce_after_mul(self):
return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa,
nearest=sfix.round_nearest,
signed=True),
k=self.k // 2, f=self.m)
v = sfix.int_type.round(self.v, self.k, self.m, self.kappa,
nearest=sfix.round_nearest, signed=True)
return sfix._new(v, k=self.k // 2, f=self.m)
sfix.unreduced_type = unreduced_sfix
@@ -3696,8 +3802,9 @@ class sfloat(_number, _structure):
return 4
@classmethod
def malloc(cls, size):
return program.malloc(size * cls.n_elements(), sint)
def malloc(cls, size, creator_tape=None):
return program.malloc(size * cls.n_elements(), sint,
creator_tape=creator_tape)
@classmethod
def is_address_tuple(cls, address):
@@ -4142,12 +4249,14 @@ class Array(object):
self.address = address
self.address_cache = {}
self.debug = debug
self.creator_tape = program.curr_tape
if alloc:
self.alloc()
def alloc(self):
if self.address is None:
self.address = self.value_type.malloc(self.length)
self.address = self.value_type.malloc(self.length,
self.creator_tape)
def delete(self):
self.value_type.free(self.address)
@@ -4371,6 +4480,10 @@ class Array(object):
reveal_nested = reveal_list
def __str__(self):
return '%s array of length %s at %s' % (self.value_type, len(self),
self.address)
sint.dynamic_array = Array
sgf2n.dynamic_array = Array
@@ -4471,6 +4584,12 @@ class SubMultiArray(object):
return self.value_type.load_mem(self.address + base * part_size,
size=size)
def assign_part_vector(self, vector, base=0):
assert self.value_type.n_elements() == 1
part_size = reduce(operator.mul, self.sizes[1:])
assert vector.size <= self.total_size()
vector.store_in_mem(self.address + base * part_size)
def get_addresses(self, *indices):
assert self.value_type.n_elements() == 1
assert len(indices) == len(self.sizes)
@@ -4592,13 +4711,16 @@ class SubMultiArray(object):
t = self.value_type
res_matrix = Matrix(self.sizes[0], other.sizes[1], t)
try:
if max(res_matrix.sizes) > 1000:
raise AttributeError()
A = self.get_vector()
B = other.get_vector()
res_matrix.assign_vector(
self.value_type.matrix_mul(A, B, self.sizes[1],
res_params))
try:
res_matrix.assign_vector(self.direct_mul(other))
except AttributeError:
if max(res_matrix.sizes) > 1000:
raise AttributeError()
A = self.get_vector()
B = other.get_vector()
res_matrix.assign_vector(
self.value_type.matrix_mul(A, B, self.sizes[1],
res_params))
except (AttributeError, AssertionError):
# fallback for sfloat etc.
@library.for_range_opt(self.sizes[0])
@@ -4644,7 +4766,60 @@ class SubMultiArray(object):
self.sizes[0], *other.sizes,
reduce=reduce, indices=indices)
def direct_mul_trans(self, other, reduce=True, indices=None):
"""
Matrix multiplication with the transpose of :py:obj:`other`
in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication)
:return: Matrix as vector of relevant type (row-major)
"""
assert len(self.sizes) == 2
assert len(other.sizes) == 2
if indices is None:
assert self.sizes[1] == other.sizes[1]
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
assert len(indices[1]) == len(indices[2])
indices = list(indices)
indices[3] *= other.sizes[0]
return self.value_type.direct_matrix_mul(
self.address, other.address, None, self.sizes[1], 1,
reduce=reduce, indices=indices)
def direct_trans_mul(self, other, reduce=True, indices=None):
"""
Matrix multiplication with the transpose of :py:obj:`self`
in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication)
:return: Matrix as vector of relevant type (row-major)
"""
assert len(self.sizes) == 2
assert len(other.sizes) == 2
if indices is None:
assert self.sizes[0] == other.sizes[0]
indices = [regint.inc(i) for i in self.sizes[::-1] + other.sizes]
assert len(indices[1]) == len(indices[2])
indices = list(indices)
indices[1] *= self.sizes[1]
return self.value_type.direct_matrix_mul(
self.address, other.address, None, 1, other.sizes[1],
reduce=reduce, indices=indices)
def direct_mul_to_matrix(self, other):
""" Matrix multiplication in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:returns: :py:obj:`Matrix`
"""
res = self.value_type.Matrix(self.sizes[0], other.sizes[1])
res.assign_vector(self.direct_mul(other))
return res
@@ -4756,6 +4931,10 @@ class SubMultiArray(object):
return [f(sizes[1:]) for i in range(sizes[0])]
return f(self.sizes)
def __str__(self):
return '%s multi-array of lengths %s at %s' % (self.value_type,
self.sizes, self.address)
class MultiArray(SubMultiArray):
""" Multidimensional array. """
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
@@ -4990,25 +5169,15 @@ class MemValue(_mem):
return 'MemValue(%s,%d)' % (self.value_type, self.address)
class MemFloat(_mem):
class MemFloat(MemValue):
def __init__(self, *args):
value = sfloat(*args)
self.v = MemValue(value.v)
self.p = MemValue(value.p)
self.z = MemValue(value.z)
self.s = MemValue(value.s)
super().__init__(sfloat(*args))
def write(self, *args):
value = sfloat(*args)
self.v.write(value.v)
self.p.write(value.p)
self.z.write(value.z)
self.s.write(value.s)
super().write(value)
def read(self):
return sfloat(self.v, self.p, self.z, self.s)
class MemFix(_mem):
class MemFix(MemValue):
def __init__(self, *args):
arg_type = type(*args)
if arg_type == sfix:
@@ -5017,22 +5186,10 @@ class MemFix(_mem):
value = cfix(*args)
else:
raise CompilerError('MemFix init argument error')
self.reg_type = value.v.reg_type
self.v = MemValue(value.v)
super().__init__(value)
def write(self, *args):
value = sfix(*args)
self.v.write(value.v)
def reveal(self):
return cfix(self.v.reveal())
def read(self):
val = self.v.read()
if isinstance(val, sint):
return sfix(val)
else:
return cfix(val)
super().write(self.value_type(*args))
def getNamedTupleType(*names):
class NamedTuple(object):

View File

@@ -174,8 +174,17 @@ def is_all_ones(x, n):
else:
return False
def max(x, y):
return if_else(x > y, x, y)
def max(x, y=None):
if y is None:
return tree_reduce(max, x)
else:
return if_else(x > y, x, y)
def min(x, y=None):
if y is None:
return tree_reduce(min, x)
else:
return if_else(x < y, x, y)
def long_one(x):
try:

View File

@@ -5,6 +5,7 @@
#include "ECDSA/P256Element.h"
#include "Tools/mkpath.h"
#include "GC/TinierSecret.h"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/Share.hpp"

View File

@@ -90,8 +90,9 @@ class Offline_Check_Error: public runtime_error
runtime_error("Offline-Check-Error : " + m) {}
};
class mac_fail: public bad_value
{ virtual const char* what() const throw()
{ return "MacCheck Failure"; }
{
public:
mac_fail(string msg = "MacCheck Failure") : bad_value(msg) {}
};
class consistency_check_fail: public exception
{ virtual const char* what() const throw()

View File

@@ -64,7 +64,7 @@ void BitAdder::add(vector<vector<T> >& res,
int n_bits = summands.size();
for (size_t i = begin; i < end; i++)
res[i].resize(n_bits + 1);
res.at(i).resize(n_bits + 1);
size_t n_items = end - begin;

View File

@@ -8,11 +8,12 @@
#include "GC/square64.h"
#include "GC/Processor.hpp"
#include "GC/ShareSecret.hpp"
#include "Processor/Input.hpp"
namespace GC
{
SwitchableOutput FakeSecret::out;
const int FakeSecret::default_length;
void FakeSecret::load_clear(int n, const Integer& x)
@@ -87,6 +88,14 @@ FakeSecret FakeSecret::input(int from, word input, int n_bits)
return input;
}
void FakeSecret::inputbvec(Processor<FakeSecret>& processor,
ProcessorBase& input_processor, const vector<int>& args)
{
Input input;
input.reset_all(*ShareThread<FakeSecret>::s().P);
processor.inputbvec(input, input_processor, args, 0);
}
void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
bool repeat)
{
@@ -96,4 +105,19 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
*this = BitVec(x & y).mask(n);
}
void FakeSecret::my_input(Input& inputter, BitVec value, int n_bits)
{
inputter.add_mine(value, n_bits);
}
void FakeSecret::other_input(Input&, int, int)
{
throw runtime_error("emulation is supposed to be lonely");
}
void FakeSecret::finalize_input(Input& inputter, int from, int n_bits)
{
*this = inputter.finalize(from, n_bits);
}
} /* namespace GC */

View File

@@ -24,6 +24,8 @@
#include <random>
#include <fstream>
class ProcessorBase;
namespace GC
{
@@ -53,7 +55,10 @@ public:
typedef FakeProtocol<FakeSecret> Protocol;
typedef FakeInput<FakeSecret> Input;
typedef SwitchableOutput out_type;
static string type_string() { return "fake secret"; }
static string type_short() { return "emulB"; }
static string phase_name() { return "Faking"; }
static const int default_length = 64;
@@ -62,7 +67,8 @@ public:
static const bool actual_inputs = true;
static SwitchableOutput out;
static const true_type invertible;
static const true_type characteristic_two;
static DataFieldType field_type() { return DATA_GF2; }
@@ -87,8 +93,8 @@ public:
template <class T>
static void inputb(T& processor, ArithmeticProcessor&, const vector<int>& args)
{ processor.input(args); }
template <class T, class U>
static void inputbvec(T&, U&, const vector<int>&) { throw not_implemented(); }
static void inputbvec(Processor<FakeSecret>& processor,
ProcessorBase& input_processor, const vector<int>& args);
template <class T>
static void reveal_inst(T& processor, const vector<int>& args)
{ processor.reveal(args); }
@@ -136,6 +142,14 @@ public:
void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; }
void invert(FakeSecret) { throw not_implemented(); }
void input(istream&, bool) { throw not_implemented(); }
bool operator<(FakeSecret) const { return false; }
void my_input(Input& inputter, BitVec value, int n_bits);
void other_input(Input& inputter, int from, int n_bits = 1);
void finalize_input(Input& inputter, int from, int n_bits);
};
} /* namespace GC */

View File

@@ -47,13 +47,6 @@ template<class T>
void Machine<T>::load_schedule(string progname)
{
BaseMachine::load_schedule(progname);
for (auto i : {1, 0, 0})
{
int n;
inpf >> n;
if (n != i)
throw runtime_error("old schedule format not supported");
}
print_compiler();
}

View File

@@ -7,6 +7,7 @@
#define GC_NOSHARE_H_
#include "Processor/DummyProtocol.h"
#include "BMR/Register.h"
#include "Tools/SwitchableOutput.h"
class InputArgs;
@@ -41,6 +42,11 @@ public:
return 0;
}
static string type_string()
{
return "no";
}
static void fail()
{
throw runtime_error("VM does not support binary circuits");
@@ -93,8 +99,6 @@ public:
static const bool expensive_triples = false;
static const bool is_real = false;
static SwitchableOutput out;
static MC* new_mc(mac_key_type)
{
return new MC;
@@ -130,7 +134,7 @@ public:
NoValue::fail();
}
static void inputb(Processor<NoShare>&, ArithmeticProcessor&, const vector<int>&) { fail(); }
static void inputb(Processor<NoShare>&, const ArithmeticProcessor&, const vector<int>&) { fail(); }
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
static void xors(Processor<NoShare>&, const vector<int>&) { fail(); }
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
@@ -139,6 +143,10 @@ public:
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static void xors(Processor<NoShare>&, vector<int>) { fail(); }
static void ands(Processor<NoShare>&, vector<int>) { fail(); }
static void andrs(Processor<NoShare>&, vector<int>) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
NoShare() {}
@@ -161,8 +169,8 @@ public:
void operator^=(NoShare) { fail(); }
NoShare operator+(const NoShare&) const { fail(); return {}; }
NoShare operator-(NoShare) const { fail(); return 0; }
NoShare operator*(NoValue) const { fail(); return 0; }
NoShare operator-(const NoShare&) const { fail(); return {}; }
NoShare operator*(const NoValue&) const { fail(); return {}; }
NoShare operator+(int) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; }
@@ -172,6 +180,8 @@ public:
NoShare get_bit(int) const { fail(); return {}; }
void invert(int, NoShare) { fail(); }
void input(istream&, bool) { fail(); }
};
} /* namespace GC */

View File

@@ -44,6 +44,8 @@ public:
Timer xor_timer;
typename T::out_type out;
Processor(Machine<T>& machine);
Processor(Memories<T>& memories, Machine<T>* machine = 0);
~Processor();

View File

@@ -301,15 +301,15 @@ void Processor<T>::print_reg(int reg, int n, int size)
bigint output;
for (int i = 0; i < size; i++)
output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i);
T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
print_str(n);
T::out << endl << flush;
out << endl << flush;
}
template <class T>
void Processor<T>::print_reg_plain(Clear& value)
{
T::out << hex << showbase << value << dec << flush;
out << hex << showbase << value << dec << flush;
}
template <class T>
@@ -323,7 +323,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
n_shift = sizeof(value.get()) * 8 - n_bits;
if (n_shift > 63)
n_shift = 0;
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
out << dec << (value.get() << n_shift >> n_shift) << flush;
}
else
{
@@ -334,26 +334,26 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
}
if (tmp >= bigint(1) << (n_bits - 1))
tmp -= bigint(1) << n_bits;
T::out << dec << tmp << flush;
out << dec << tmp << flush;
}
}
template <class T>
void Processor<T>::print_chr(int n)
{
T::out << (char)n << flush;
out << (char)n << flush;
}
template <class T>
void Processor<T>::print_str(int n)
{
T::out << string((char*)&n,sizeof(n)) << flush;
out << string((char*)&n,sizeof(n)) << flush;
}
template <class T>
void Processor<T>::print_float(const vector<int>& args)
{
bigint::output_float(T::out,
bigint::output_float(out,
bigint::get_float(C[args[0]], C[args[1]], C[args[2]], C[args[3]]),
C[args[4]]);
}
@@ -361,7 +361,7 @@ void Processor<T>::print_float(const vector<int>& args)
template <class T>
void Processor<T>::print_float_prec(int n)
{
T::out << setprecision(n);
out << setprecision(n);
}
} /* namespace GC */

26
GC/Rep4Secret.cpp Normal file
View File

@@ -0,0 +1,26 @@
/*
* Rep4Secret.cpp
*
*/
#ifndef GC_REP4SECRET_CPP_
#define GC_REP4SECRET_CPP_
#include "Rep4Secret.h"
#include "ShareSecret.hpp"
#include "ShareThread.hpp"
#include "Protocols/Rep4MC.hpp"
namespace GC
{
void Rep4Secret::load_clear(int n, const Integer& x)
{
this->check_length(n, x);
*this = constant(x, ShareThread<This>::s().P->my_num());
}
}
#endif /* GC_REP4SECRET_CPP_ */

53
GC/Rep4Secret.h Normal file
View File

@@ -0,0 +1,53 @@
/*
* Rep4Secret.h
*
*/
#ifndef GC_REP4SECRET_H_
#define GC_REP4SECRET_H_
#include "ShareSecret.h"
#include "Processor/NoLivePrep.h"
#include "Protocols/Rep4MC.h"
#include "Protocols/Rep4Share.h"
namespace GC
{
class Rep4Secret : public RepSecretBase<Rep4Secret, 3>
{
typedef RepSecretBase<Rep4Secret, 3> super;
typedef Rep4Secret This;
public:
typedef DummyLivePrep<This> LivePrep;
typedef Rep4<This> Protocol;
typedef Rep4MC<This> MC;
typedef MC MAC_Check;
typedef Rep4Input<This> Input;
static const bool expensive_triples = false;
static MC* new_mc(typename super::mac_key_type) { return new MC; }
static This constant(const typename super::clear& constant, int my_num,
typename super::mac_key_type = {})
{
return Rep4Share<typename super::clear>::constant(constant, my_num);
}
Rep4Secret()
{
}
template <class T>
Rep4Secret(const T& other) :
super(other)
{
}
void load_clear(int n, const Integer& x);
};
}
#endif /* GC_REP4SECRET_H_ */

View File

@@ -62,13 +62,13 @@ public:
typedef typename T::Input Input;
typedef typename T::out_type out_type;
static string type_string() { return "evaluation secret"; }
static string phase_name() { return T::name(); }
static int default_length;
static typename T::out_type out;
static const bool needs_ot = false;
static const bool is_real = true;
@@ -170,9 +170,6 @@ public:
template <class T>
int Secret<T>::default_length = 64;
template <class T>
typename T::out_type Secret<T>::out = T::out;
template <class T>
inline ostream& operator<<(ostream& o, Secret<T>& secret)
{

View File

@@ -58,10 +58,11 @@ SemiPrep::~SemiPrep()
void SemiPrep::buffer_bits()
{
auto& thread = Thread<SemiSecret>::s();
word r = thread.secure_prng.get_word();
word r = secure_prng.get_word();
for (size_t i = 0; i < sizeof(word) * 8; i++)
{
this->bits.push_back((r >> i) & 1);
}
}
size_t SemiPrep::data_sent()

View File

@@ -23,6 +23,8 @@ class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret
SemiSecret::TripleGenerator* triple_generator;
MascotParams params;
SeededPRNG secure_prng;
public:
SemiPrep(DataPositions& usage, ShareThread<SemiSecret>& thread);
SemiPrep(DataPositions& usage, bool = true);

View File

@@ -80,8 +80,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
T::out.activate(my_num == 0 or online_opts.interactive);
if (not this->machine.use_encryption and not T::dishonest_majority)
insecure("unencrypted communication");

View File

@@ -44,8 +44,6 @@ public:
static const bool is_real = true;
static const bool actual_inputs = true;
static SwitchableOutput out;
static void store_clear_in_dynamic(Memory<U>& mem,
const vector<ClearWriteAccess>& accesses);
@@ -83,21 +81,26 @@ public:
void other_input(T& inputter, int from, int n_bits = 1);
template<class T>
void finalize_input(T& inputter, int from, int n_bits);
U& operator=(const U&);
};
template<class U>
class ReplicatedSecret : public FixedVec<BitVec, 2>, public ShareSecret<U>
template<class U, int L>
class RepSecretBase : public FixedVec<BitVec, L>, public ShareSecret<U>
{
typedef FixedVec<BitVec, 2> super;
typedef FixedVec<BitVec, L> super;
typedef RepSecretBase This;
public:
typedef U part_type;
typedef U small_type;
typedef U whole_type;
typedef BitVec clear;
typedef BitVec open_type;
typedef BitVec mac_type;
typedef BitVec mac_key_type;
typedef ReplicatedBase Protocol;
typedef NoShare bit_type;
static const int N_BITS = clear::N_BITS;
@@ -109,7 +112,7 @@ public:
static string type_string() { return "replicated secret"; }
static string phase_name() { return "Replicated computation"; }
static const int default_length = 8 * sizeof(typename ReplicatedSecret<U>::value_type);
static const int default_length = N_BITS;
static int threshold(int)
{
@@ -124,9 +127,45 @@ public:
{
}
static void read_or_generate_mac_key(string, const Names&, mac_key_type) {}
static void read_or_generate_mac_key(string, const Player&, mac_key_type)
{
}
static ReplicatedSecret constant(const clear& value, int my_num, mac_key_type)
RepSecretBase()
{
}
template <class T>
RepSecretBase(const T& other) :
super(other)
{
}
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
void xor_(int n, const This& x, const This& y)
{ *this = x ^ y; (void)n; }
This operator&(const Clear& other)
{ return super::operator&(BitVec(other)); }
This lsb()
{ return *this & 1; }
This get_bit(int i)
{ return (*this >> i) & 1; }
};
template<class U>
class ReplicatedSecret : public RepSecretBase<U, 2>
{
typedef RepSecretBase<U, 2> super;
public:
typedef ReplicatedBase Protocol;
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
typename super::mac_key_type)
{
ReplicatedSecret res;
if (my_num < 2)
@@ -140,28 +179,44 @@ public:
void load_clear(int n, const Integer& x);
void bitcom(Memory<U>& S, const vector<int>& regs);
void bitdec(Memory<U>& S, const vector<int>& regs) const;
BitVec local_mul(const ReplicatedSecret& other) const;
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
{ *this = x ^ y; (void)n; }
void reveal(size_t n_bits, Clear& x);
ReplicatedSecret operator&(const Clear& other)
{ return super::operator&(BitVec(other)); }
ReplicatedSecret lsb()
{ return *this & 1; }
ReplicatedSecret get_bit(int i)
{ return (*this >> i) & 1; }
};
class SemiHonestRepPrep;
class SmallRepSecret : public FixedVec<BitVec_<unsigned char>, 2>
{
typedef FixedVec<BitVec_<unsigned char>, 2> super;
typedef SmallRepSecret This;
public:
typedef ReplicatedMC<This> MC;
typedef BitVec_<unsigned char> open_type;
typedef open_type clear;
typedef BitVec mac_key_type;
static MC* new_mc(mac_key_type)
{
return new MC;
}
SmallRepSecret()
{
}
template<class T>
SmallRepSecret(const T& other) :
super(other)
{
}
This lsb() const
{
return *this & 1;
}
};
class SemiHonestRepSecret : public ReplicatedSecret<SemiHonestRepSecret>
{
typedef ReplicatedSecret<SemiHonestRepSecret> super;
@@ -176,7 +231,7 @@ public:
typedef ReplicatedInput<SemiHonestRepSecret> Input;
typedef SemiHonestRepSecret part_type;
typedef SemiHonestRepSecret small_type;
typedef SmallRepSecret small_type;
typedef SemiHonestRepSecret whole_type;
static const bool expensive_triples = false;

View File

@@ -25,14 +25,11 @@
namespace GC
{
template<class U>
const int ReplicatedSecret<U>::N_BITS;
template<class U, int L>
const int RepSecretBase<U, L>::N_BITS;
template<class U>
const int ReplicatedSecret<U>::default_length;
template<class U>
SwitchableOutput ShareSecret<U>::out;
template<class U, int L>
const int RepSecretBase<U, L>::default_length;
template<class U>
void ShareSecret<U>::check_length(int n, const Integer& x)
@@ -59,16 +56,16 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
*this = x;
}
template<class U>
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
template<class U, int L>
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
{
*this = 0;
for (unsigned int i = 0; i < regs.size(); i++)
*this ^= (S[regs[i]] << i);
}
template<class U>
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
template<class U, int L>
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
{
for (unsigned int i = 0; i < regs.size(); i++)
S[regs[i]] = (*this >> i) & 1;
@@ -285,12 +282,11 @@ void ShareSecret<U>::xors(Processor<U>& processor, const vector<int>& args)
ShareThread<U>::s().xors(processor, args);
}
template<class U>
void ReplicatedSecret<U>::trans(Processor<U>& processor,
template<class U, int L>
void RepSecretBase<U, L>::trans(Processor<U>& processor,
int n_outputs, const vector<int>& args)
{
assert(length == 2);
for (int k = 0; k < 2; k++)
for (int k = 0; k < L; k++)
{
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
@@ -330,6 +326,14 @@ void ShareSecret<U>::random_bit()
*this = res;
}
template<class U>
U& GC::ShareSecret<U>::operator=(const U& other)
{
U& real_this = static_cast<U&>(*this);
real_this = other;
return real_this;
}
}
#endif

View File

@@ -54,7 +54,13 @@ void Thread<T>::run()
P = new CryptoPlayer(N, thread_num << 16);
else
P = new PlainPlayer(N, thread_num << 16);
processor.open_input_file(N.my_num(), thread_num);
processor.open_input_file(N.my_num(), thread_num,
master.opts.cmd_private_input_file);
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
processor.setup_redirection(P->my_num(), thread_num, master.opts);
if (processor.stdout_redirect_file.is_open())
processor.out.redirect_to_file(processor.stdout_redirect_file);
done.push(0);
pre_run();

View File

@@ -38,6 +38,7 @@ public:
typedef typename part_type::sacri_type sacri_type;
typedef typename part_type::mac_type mac_type;
typedef typename part_type::mac_share_type mac_share_type;
typedef BitDiagonal Rectangle;
typedef typename T::super check_type;
@@ -152,6 +153,11 @@ public:
reg.output(s, human);
}
void input(istream&, bool)
{
throw not_implemented();
}
template <class U>
void my_input(U& inputter, BitVec value, int n_bits)
{

View File

@@ -129,7 +129,7 @@
X(GLDMC, ) \
X(LDMS, ) \
X(LDMC, ) \
X(PRINTINT, S0.out << I0) \
X(PRINTINT, PROC.out << I0) \
X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \
X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \
X(RUN_TAPE, MACH->run_tapes(EXTRA)) \

View File

@@ -6,7 +6,7 @@
#include "Processor/config.h"
#include "Protocols/Share.h"
#include "GC/TinierSecret.h"
#include "Math/gfp.h"
#include "Math/gfp.hpp"
#include "Player-Online.hpp"

View File

@@ -1,2 +1,3 @@
#include "Rep.hpp"
#include "Protocols/Spdz2kPrep.hpp"
#include "Protocols/RepRingOnlyEdabitPrep.hpp"

View File

@@ -16,16 +16,14 @@
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/FakeShare.hpp"
SwitchableOutput GC::NoShare::out;
int main(int argc, const char** argv)
{
assert(argc > 1);
OnlineOptions online_opts;
Names N(0, 9999, vector<string>({"localhost"}));
Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector<string>({"localhost"}));
ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv);
opt.parse(argc, argv);
opt.syntax = string(argv[0]) + " <progname>";
string progname;
if (opt.firstArgs.size() > 1)
progname = *opt.firstArgs.at(1);
@@ -41,7 +39,16 @@ int main(int argc, const char** argv)
exit(1);
}
switch (ring_opts.R)
#ifdef ROUND_NEAREST_IN_EMULATION
cerr << "Using nearest rounding instead of probabilistic truncation" << endl;
#else
#ifdef RISKY_TRUNCATION_IN_EMULATION
cerr << "Using risky truncation" << endl;
#endif
#endif
int R = ring_opts.ring_size_from_opts_or_schedule(progname);
switch (R)
{
case 64:
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname,
@@ -53,7 +60,27 @@ int main(int argc, const char** argv)
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 256:
Machine<FakeShare<SignedZ2<256>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 192:
Machine<FakeShare<SignedZ2<192>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 384:
Machine<FakeShare<SignedZ2<384>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 512:
Machine<FakeShare<SignedZ2<512>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
default:
cerr << "Not compiled for " << ring_opts.R << "-bit rings" << endl;
cerr << "Not compiled for " << R << "-bit rings" << endl;
}
}

View File

@@ -1,6 +1,6 @@
#include "Player-Online.hpp"
#include "Math/gfp.h"
#include "Math/gfp.hpp"
#include "GC/TinierSecret.h"
int main(int argc, const char** argv)

View File

@@ -13,14 +13,27 @@
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
RingOptions opts(opt, argc, argv);
RingOptions opts(opt, argc, argv, true);
switch (opts.R)
{
case 64:
ReplicatedMachine<PostSacriRepRingShare<64, 40>, PostSacriRepFieldShare<gf2n>>(
argc, argv, opt);
break;
case 72:
switch (opts.S)
{
case 40:
ReplicatedMachine<PostSacriRepRingShare<64, 40>,
PostSacriRepFieldShare<gf2n>>(argc, argv, opt);
break;
case 64:
ReplicatedMachine<PostSacriRepRingShare<64, 64>,
PostSacriRepFieldShare<gf2n>>(argc, argv, opt);
break;
default:
cerr << "Security parameter " << opts.S << " not implemented"
<< endl;
exit(1);
}
break;
case 72:
ReplicatedMachine<PostSacriRepRingShare<72, 40>, PostSacriRepFieldShare<gf2n>>(
argc, argv, opt);
break;

View File

@@ -0,0 +1,38 @@
/*
* rep4-party.cpp
*
*/
#include "Protocols/Rep4Share2k.h"
#include "Protocols/Rep4Share.h"
#include "Protocols/Rep4MC.h"
#include "Protocols/ReplicatedMachine.h"
#include "Math/Z2k.h"
#include "Math/gf2n.h"
#include "Tools/ezOptionParser.h"
#include "GC/Rep4Secret.h"
#include "Processor/RingOptions.h"
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
#include "Protocols/ReplicatedMachine.hpp"
#include "Protocols/Rep4Input.hpp"
#include "Protocols/Rep4Prep.hpp"
#include "Protocols/Rep4MC.hpp"
#include "Protocols/Rep4.hpp"
#include "GC/BitAdder.hpp"
#include "Math/Z2k.hpp"
#include "Rep.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv);
switch (ring_opts.R)
{
#define X(R) case R: ReplicatedMachine<Rep4Share2<R>, Rep4Share<gf2n>>(argc, argv, opt, 4); break;
X(64) X(80) X(88)
default:
cerr << ring_opts.R << "-bit computation not implemented" << endl;
exit(1);
}
}

View File

@@ -24,6 +24,10 @@ int main(int argc, const char** argv)
ReplicatedMachine<Rep3Share2<72>, Rep3Share<gf2n>>(argc, argv,
"replicated-ring", opt);
break;
case 128:
ReplicatedMachine<Rep3Share2<128>, Rep3Share<gf2n>>(argc, argv,
"replicated-ring", opt);
break;
default:
throw runtime_error(to_string(opts.R) + "-bit computation not implemented");
}

View File

@@ -12,6 +12,7 @@
#include "Player-Online.hpp"
#include "Semi.hpp"
#include "GC/ShareSecret.hpp"
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -11,6 +11,7 @@
#include "Networking/Server.h"
#include "Player-Online.hpp"
#include "Math/Z2k.hpp"
int main(int argc, const char** argv)
{

View File

@@ -0,0 +1,41 @@
/*
* sy-rep-field-party.cpp
*
*/
#include "Protocols/SpdzWiseShare.h"
#include "Protocols/MaliciousRep3Share.h"
#include "Protocols/ReplicatedMachine.h"
#include "Protocols/MAC_Check.h"
#include "Protocols/SpdzWiseMC.h"
#include "Protocols/SpdzWisePrep.h"
#include "Protocols/SpdzWiseInput.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Tools/ezOptionParser.h"
#include "Processor/NoLivePrep.h"
#include "GC/MaliciousCcdSecret.h"
#include "Protocols/ReplicatedMachine.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/Share.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/SpdzWise.hpp"
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/RepPrep.hpp"
#include "GC/ThreadMaster.hpp"
#include "Math/gfp.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opts;
ReplicatedMachine<SpdzWiseShare<MaliciousRep3Share<gfp>>,
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opts);
}

View File

@@ -0,0 +1,68 @@
/*
* sy-rep-ring-party.cpp
*
*/
#include "Protocols/ReplicatedMachine.h"
#include "Protocols/SpdzWiseRingShare.h"
#include "Protocols/MaliciousRep3Share.h"
#include "Protocols/SpdzWiseMC.h"
#include "Protocols/SpdzWiseRingPrep.h"
#include "Protocols/SpdzWiseInput.h"
#include "Protocols/MalRepRingPrep.h"
#include "Processor/RingOptions.h"
#include "GC/MaliciousCcdSecret.h"
#include "Protocols/ReplicatedMachine.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/Share.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/SpdzWise.hpp"
#include "Protocols/SpdzWiseRing.hpp"
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Protocols/PostSacrifice.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/MaliciousRepPrep.hpp"
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/RepPrep.hpp"
#include "GC/ThreadMaster.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
RingOptions opts(opt, argc, argv, true);
switch (opts.R)
{
case 64:
switch (opts.S)
{
case 40:
ReplicatedMachine<SpdzWiseRingShare<64, 40>,
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opt);
break;
case 64:
ReplicatedMachine<SpdzWiseRingShare<64, 64>,
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opt);
break;
default:
cerr << "Security parameter " << opts.S << " not implemented"
<< endl;
exit(1);
}
break;
case 72:
ReplicatedMachine<SpdzWiseRingShare<72, 40>,
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opt);
break;
default:
throw runtime_error(
to_string(opts.R) + "-bit computation not implemented");
}
}

View File

@@ -0,0 +1,33 @@
/*
* sy-shamir-party.cpp
*
*/
#include "ShamirMachine.h"
#include "Protocols/ReplicatedMachine.h"
#include "Protocols/SpdzWiseShare.h"
#include "Protocols/MaliciousShamirShare.h"
#include "Protocols/SpdzWiseMC.h"
#include "Protocols/SpdzWiseInput.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "GC/CcdSecret.h"
#include "GC/MaliciousCcdSecret.h"
#include "Protocols/Share.hpp"
#include "Protocols/SpdzWise.hpp"
#include "Protocols/SpdzWisePrep.hpp"
#include "Protocols/SpdzWiseInput.hpp"
#include "Protocols/SpdzWiseShare.hpp"
#include "Machines/ShamirMachine.hpp"
int main(int argc, const char** argv)
{
auto& opts = ShamirOptions::singleton;
ez::ezOptionParser opt;
opts = {opt, argc, argv};
ReplicatedMachine<SpdzWiseShare<MaliciousShamirShare<gfp>>,
SpdzWiseShare<MaliciousShamirShare<gf2n>>>(
argc, argv,
{ }, opt, opts.nparties);
}

View File

@@ -42,7 +42,7 @@ all: arithmetic binary gen_input online offline externalIO bmr doc
doc:
cd doc; $(MAKE) html
arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot
arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy
binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr
ifeq ($(USE_NTL),1)
@@ -77,7 +77,7 @@ overdrive: simple-offline.x pairwise-offline.x cnc-offline.x
rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x
rep-ring: replicated-ring-party.x brain-party.x malicious-rep-ring-party.x ps-rep-ring-party.x Fake-Offline.x
rep-ring: replicated-ring-party.x brain-party.x malicious-rep-ring-party.x ps-rep-ring-party.x rep4-ring-party.x
rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x
@@ -98,10 +98,12 @@ endif
shamir: shamir-party.x malicious-shamir-party.x galois-degree.x
sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC)
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMON) $(BMR) $(GC)
$(AR) -csr $@ $^
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
@@ -184,12 +186,18 @@ hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
static/hemi-party.x: $(FHEOFFLINE)
static/soho-party.x: $(FHEOFFLINE)
static/cowgear-party.x: $(FHEOFFLINE)
static/chaigear-party.x: $(FHEOFFLINE)
mascot-party.x: Machines/SPDZ.o $(OT)
static/mascot-party.x: Machines/SPDZ.o
Player-Online.x: Machines/SPDZ.o $(OT)
mama-party.x: $(OT)
ps-rep-ring-party.x: Protocols/MalRepRingOptions.o
malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
sy-rep-ring-party.x: Protocols/MalRepRingOptions.o
rep4-ring-party.x: GC/Rep4Secret.o
semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o
mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)

View File

@@ -25,6 +25,7 @@ public:
static const int n_bits = sizeof(T) * 8;
static char type_char() { return 'B'; }
static string type_short() { return "B"; }
static DataFieldType field_type() { return DATA_GF2; }
static bool allows(Dtype dtype) { return dtype == DATA_TRIPLE or dtype == DATA_BIT; }
@@ -59,7 +60,7 @@ public:
void pack(octetStream& os) const { os.store_int<sizeof(T)>(this->a); }
void unpack(octetStream& os) { this->a = os.get_int<sizeof(T)>(); }
void pack(octetStream& os, int n) const { os.store_int(this->a, DIV_CEIL(n, 8)); }
void pack(octetStream& os, int n) const { os.store_int(mask(n).a, DIV_CEIL(n, 8)); }
void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); }
static BitVec_ unpack_new(octetStream& os, int n = n_bits)

View File

@@ -156,11 +156,6 @@ public:
return true;
}
bool operator!=(const FixedVec& other) const
{
return not equal(other);
}
bool is_zero()
{
return equal(0);
@@ -170,6 +165,11 @@ public:
return equal(1);
}
bool operator!=(const FixedVec<T, L>& other) const
{
return not equal(other);
}
FixedVec<T, L>operator+(const FixedVec<T, L>& other) const
{
FixedVec<T, L> res;
@@ -291,6 +291,15 @@ public:
return res;
}
T lazy_sum() const
{
assert(L > 1);
T res = v[0].lazy_add(v[1]);
for (int i = 2; i < L; i++)
res = res.lazy_add(v[i]);
return res;
}
FixedVec<T, L> extend_bit() const
{
FixedVec<T, L> res;
@@ -343,13 +352,21 @@ public:
void output(ostream& s, bool human) const
{
for (auto& x : v)
x.output(s, human);
if (human)
s << *this;
else
for (auto& x : v)
x.output(s, human);
}
void input(istream& s, bool human)
{
for (auto& x : v)
x.input(s, human);
for (int i = 0; i < L; i++)
{
if (human and i != 0)
if (s.get() != ',')
throw runtime_error("cannot read vector");
(*this)[i].input(s, human);
}
}
void pack(octetStream& os) const

View File

@@ -17,6 +17,8 @@ class ValueInterface
public:
static const int MAX_EDABITS = 0;
static const false_type characteristic_two;
template<class T>
static void init(bool mont = true) { (void) mont; }
static void init_default(int l) { (void) l; }

View File

@@ -62,13 +62,14 @@ public:
static int t() { return 0; }
static char type_char() { return 'R'; }
static string type_short() { return "R"; }
static string type_string() { return "Z2^" + to_string(int(N_BITS)); }
static DataFieldType field_type() { return DATA_INT; }
static const bool invertible = false;
static const false_type invertible;
template <int L, int M>
template <int L, int M, bool LAZY = false>
static Z2<K> Mul(const Z2<L>& x, const Z2<M>& y);
static void reqbl(int n);
@@ -151,6 +152,9 @@ public:
void add(octetStream& os) { add(os.consume(size())); }
Z2 lazy_add(const Z2& x) const;
Z2 lazy_mul(const Z2& x) const;
Z2& invert();
void invert(const Z2& a) { *this = a; invert(); }
@@ -279,10 +283,17 @@ public:
template<int K>
inline Z2<K> Z2<K>::operator+(const Z2<K>& other) const
{
auto res = lazy_add(other);
res.normalize();
return res;
}
template<int K>
Z2<K> Z2<K>::lazy_add(const Z2<K>& other) const
{
Z2<K> res;
mpn_add_fixed_n<N_WORDS>(res.a, a, other.a);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
@@ -332,12 +343,13 @@ Z2<K>& Z2<K>::operator>>=(int other)
}
template <int K>
template <int L, int M>
template <int L, int M, bool LAZY>
inline Z2<K> Z2<K>::Mul(const Z2<L>& x, const Z2<M>& y)
{
Z2<K> res;
mpn_mul_fixed_<N_WORDS, Z2<L>::N_WORDS, Z2<M>::N_WORDS>(res.a, x.a, y.a);
res.a[N_WORDS - 1] &= UPPER_MASK;
if (not LAZY)
res.normalize();
return res;
}
@@ -348,6 +360,12 @@ inline Z2<(K > L) ? K : L> Z2<K>::operator*(const Z2<L>& other) const
return Z2<(K > L) ? K : L>::Mul(*this, other);
}
template <int K>
inline Z2<K> Z2<K>::lazy_mul(const Z2<K>& other) const
{
return Z2<K>::Mul<K, K, true>(*this, other);
}
template <int K>
Z2<K> Z2<K>::operator<<(int i) const
{

View File

@@ -13,6 +13,8 @@ template<int K>
const int Z2<K>::N_BITS;
template<int K>
const int Z2<K>::N_BYTES;
template<int K>
const false_type Z2<K>::invertible;
template<int K>
void Z2<K>::reqbl(int n)

View File

@@ -83,6 +83,7 @@ class gf2n_short : public ValueInterface
static DataFieldType field_type() { return DATA_GF2N; }
static char type_char() { return '2'; }
static string type_short() { return "2"; }
static string type_string() { return "gf2n"; }
static int size() { return sizeof(a); }
@@ -94,8 +95,8 @@ class gf2n_short : public ValueInterface
static bool allows(Dtype type) { (void) type; return true; }
static const bool invertible = true;
static const bool characteristic_two = true;
static const true_type invertible;
static const true_type characteristic_two;
static gf2n_short cut(int128 x) { return x.get_lower(); }
@@ -163,6 +164,9 @@ class gf2n_short : public ValueInterface
// x * y when one of x,y is a bit
void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; }
gf2n_short lazy_add(const gf2n_short& x) const { return *this + x; }
gf2n_short lazy_mul(const gf2n_short& x) const { return *this * x; }
gf2n_short operator+(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; }
gf2n_short operator*(const gf2n_short& x) const { gf2n_short res; res.mul(*this, x); return res; }
gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; }

View File

@@ -134,6 +134,7 @@ class gf2n_long : public ValueInterface
static DataFieldType field_type() { return DATA_GF2N; }
static char type_char() { return '2'; }
static string type_short() { return "2"; }
static string type_string() { return "gf2n_long"; }
static int size() { return sizeof(a); }
@@ -144,8 +145,8 @@ class gf2n_long : public ValueInterface
static bool allows(Dtype type) { (void) type; return true; }
static const bool invertible = true;
static const bool characteristic_two = true;
static const true_type invertible;
static const true_type characteristic_two;
static gf2n_long cut(int128 x) { return x; }
@@ -216,6 +217,9 @@ class gf2n_long : public ValueInterface
// x * y when one of x,y is a bit
void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; }
gf2n_long lazy_add(const gf2n_long& x) const { return *this + x; }
gf2n_long lazy_mul(const gf2n_long& x) const { return *this * x; }
gf2n_long operator+(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; }
gf2n_long operator*(const gf2n_long& x) const { gf2n_long res; res.mul(*this, x); return res; }
gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; }
@@ -251,6 +255,8 @@ class gf2n_long : public ValueInterface
gf2n_long& operator>>=(int i) { SHR(*this, i); return *this; }
gf2n_long& operator<<=(int i) { SHL(*this, i); return *this; }
bool operator<(gf2n_long) const { return false; }
/* Crap RNG */
void randomize(PRNG& G, int n = -1);
// compatibility with gfp

View File

@@ -88,6 +88,7 @@ class gfp_ : public ValueInterface
static DataFieldType field_type() { return DATA_INT; }
static char type_char() { return 'p'; }
static string type_short() { return "p"; }
static string type_string() { return "gfp"; }
static int size() { return t() * sizeof(mp_limb_t); }
@@ -100,8 +101,7 @@ class gfp_ : public ValueInterface
static void specification(octetStream& os);
static const bool invertible = true;
static const bool characteristic_two = false;
static const true_type invertible;
static gfp_ Mul(gfp_ a, gfp_ b) { return a * b; }
@@ -184,6 +184,9 @@ class gfp_ : public ValueInterface
void mul(const gfp_& x)
{ a.template mul<L>(a,x.a,ZpD); }
gfp_ lazy_add(const gfp_& x) const { return *this + x; }
gfp_ lazy_mul(const gfp_& x) const { return *this * x; }
gfp_ operator+(const gfp_& x) const { gfp_ res; res.add(*this, x); return res; }
gfp_ operator-(const gfp_& x) const { gfp_ res; res.sub(*this, x); return res; }
gfp_ operator*(const gfp_& x) const { gfp_ res; res.mul(*this, x); return res; }
@@ -266,7 +269,7 @@ class gfp_ : public ValueInterface
// Convert representation to and from a bigint number
friend void to_bigint(bigint& ans,const gfp_& x,bool reduce=true)
{ to_bigint(ans,x.a,x.ZpD,reduce); }
{ x.a.template to_bigint<L>(ans, x.ZpD, reduce); }
friend void to_gfp(gfp_& ans,const bigint& x)
{ to_modp(ans.a,x,ans.ZpD); }
};

View File

@@ -9,6 +9,9 @@
#include "Math/bigint.hpp"
#include "Math/Setup.hpp"
template<int X, int L>
const true_type gfp_<X, L>::invertible;
template<int X, int L>
inline void gfp_<X, L>::read_or_generate_setup(string dir,
const OnlineOptions& opts)

View File

@@ -74,6 +74,8 @@ class modp_
// Convert representation to and from a modp number
void to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce=true) const;
template<int M>
void to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce=true) const;
template<int T>
void mul(const modp_& x, const modp_& y, const Zp_Data& ZpD);

View File

@@ -113,6 +113,31 @@ void modp_<L>::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const
}
template<int L>
template<int M>
void modp_<L>::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const
{
assert(M == ZpD.t);
auto& x = *this;
mpz_ptr a = ans.get_mpz_t();
if (a->_mp_alloc < M)
mpz_realloc(a, M);
if (ZpD.montgomery)
{
mp_limb_t one[M];
inline_mpn_zero(one,M);
one[0]=1;
ZpD.Mont_Mult_<M>(a->_mp_d,x.x,one);
}
else
{ inline_mpn_copyi(a->_mp_d,x.x,M); }
a->_mp_size=M;
if (reduce)
while (a->_mp_size>=1 && (a->_mp_d)[a->_mp_size-1]==0)
{ a->_mp_size--; }
}
template<int L>
void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD)
{

View File

@@ -20,7 +20,9 @@ void ssl_error(string side, string pronoun, string other, string server)
<< " failed. Make sure " << pronoun
<< " have the necessary certificate (" << PREP_DIR << server
<< ".pem in the default configuration),"
<< " and run `c_rehash <directory>` on its location." << endl;
<< " and run `c_rehash <directory>` on its location." << endl
<< "Also make sure that it's still valid. Certificates generated "
<< "with `Scripts/setup-ssl.sh` expire after a month." << endl;
}
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :

View File

@@ -90,7 +90,15 @@ struct CommStats
size_t data, rounds;
Timer timer;
CommStats() : data(0), rounds(0) {}
Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; }
Timer& add(const octetStream& os)
{
#ifdef VERBOSE_COMM
cout << "add " << os.get_length() << endl;
#endif
data += os.get_length();
rounds++;
return timer;
}
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
CommStats& operator+=(const CommStats& other);
CommStats& operator-=(const CommStats& other);

View File

@@ -7,6 +7,7 @@
#include <iostream>
#include <sodium.h>
#include <regex>
using namespace std;
BaseMachine* BaseMachine::singleton = 0;
@@ -28,13 +29,14 @@ BaseMachine::BaseMachine() : nthreads(0)
singleton = this;
}
void BaseMachine::load_schedule(string progname)
void BaseMachine::load_schedule(string progname, bool load_bytecode)
{
this->progname = progname;
string fname = "Programs/Schedules/" + progname + ".sch";
#ifdef DEBUG_FILES
cerr << "Opening file " << fname << endl;
#endif
ifstream inpf;
inpf.open(fname);
if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); }
@@ -54,25 +56,35 @@ void BaseMachine::load_schedule(string progname)
string threadname;
for (int i=0; i<nprogs; i++)
{ inpf >> threadname;
string filename = "Programs/Bytecode/" + threadname + ".bc";
if (load_bytecode)
{
string filename = "Programs/Bytecode/" + threadname + ".bc";
#ifdef DEBUG_FILES
cerr << "Loading program " << i << " from " << filename << endl;
cerr << "Loading program " << i << " from " << filename << endl;
#endif
load_program(threadname, filename);
load_program(threadname, filename);
}
}
for (auto i : {1, 0, 0})
{
int n;
inpf >> n;
if (n != i)
throw runtime_error("old schedule format not supported");
}
inpf.get();
getline(inpf, compiler);
inpf.close();
}
void BaseMachine::print_compiler()
{
char compiler[1000];
inpf.get();
inpf.getline(compiler, 1000);
#ifdef VERBOSE
if (compiler[0] != 0)
if (compiler.size() != 0)
cerr << "Compiler: " << compiler << endl;
#endif
inpf.close();
}
void BaseMachine::load_program(string threadname, string filename)
@@ -112,3 +124,20 @@ string BaseMachine::memory_filename(string type_short, int my_number)
{
return PREP_DIR "Memory-" + type_short + "-P" + to_string(my_number);
}
int BaseMachine::ring_size_from_schedule(string progname)
{
assert(not singleton);
BaseMachine machine;
singleton = 0;
machine.load_schedule(progname, false);
smatch m;
regex e("R ([0-9]+)");
regex_search(machine.compiler, m, e);
if (m.size() > 1)
{
return stoi(m[1]);
}
else
return 0;
}

View File

@@ -22,7 +22,7 @@ protected:
std::map<int,Timer> timer;
ifstream inpf;
string compiler;
void print_timers();
@@ -43,10 +43,12 @@ public:
static string memory_filename(string type_short, int my_number);
static int ring_size_from_schedule(string progname);
BaseMachine();
virtual ~BaseMachine() {}
void load_schedule(string progname);
void load_schedule(string progname, bool load_bytecode = true);
void print_compiler();
void time();

View File

@@ -14,13 +14,13 @@ const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 };
void DataPositions::set_num_players(int num_players)
{
files.resize(N_DATA_FIELD_TYPE, vector<long long>(N_DTYPE));
inputs.resize(num_players, vector<long long>(N_DATA_FIELD_TYPE));
files = {};
inputs.resize(num_players, {});
}
void DataPositions::increase(const DataPositions& delta)
{
inputs.resize(max(inputs.size(), delta.inputs.size()), vector<long long>(N_DATA_FIELD_TYPE));
inputs.resize(max(inputs.size(), delta.inputs.size()), {});
for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++)
@@ -39,8 +39,7 @@ void DataPositions::increase(const DataPositions& delta)
DataPositions& DataPositions::operator-=(const DataPositions& delta)
{
inputs.resize(max(inputs.size(), delta.inputs.size()),
vector<long long>(N_DATA_FIELD_TYPE));
inputs.resize(max(inputs.size(), delta.inputs.size()), {});
for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE;
field_type++)
{
@@ -144,3 +143,25 @@ void DataPositions::process_line(long long items_used, const char* name,
cerr << suffix << endl;
}
}
bool DataPositions::empty() const
{
for (auto& x : files)
for (auto& y : x)
if (y)
return false;
for (auto& x : inputs)
for (auto& y : x)
if (y)
return false;
for (auto& x : extended)
if (not x.empty())
return false;
if (not edabits.empty())
return false;
return true;
}

View File

@@ -10,11 +10,14 @@
#include "Processor/InputTuple.h"
#include "Tools/Lock.h"
#include "Networking/Player.h"
#include "Protocols/edabit.h"
#include <fstream>
#include <map>
using namespace std;
template<class T> class dabit;
class DataTag
{
int t[4];
@@ -50,9 +53,9 @@ public:
static const char* field_names[N_DATA_FIELD_TYPE];
static const int tuple_size[N_DTYPE];
vector< vector<long long> > files;
vector< vector<long long> > inputs;
map<DataTag, long long> extended[N_DATA_FIELD_TYPE];
array<array<long long, N_DTYPE>, N_DATA_FIELD_TYPE> files;
vector< array<long long, N_DATA_FIELD_TYPE> > inputs;
array<map<DataTag, long long>, N_DATA_FIELD_TYPE> extended;
map<pair<bool, int>, long long> edabits;
DataPositions(int num_players = 0) { set_num_players(num_players); }
@@ -63,6 +66,7 @@ public:
DataPositions& operator-=(const DataPositions& delta);
DataPositions operator-(const DataPositions& delta) const;
void print_cost() const;
bool empty() const;
};
template<class sint, class sgf2n> class Processor;
@@ -73,9 +77,12 @@ template<class T> class SubProcessor;
template<class T>
class Preprocessing
{
protected:
DataPositions& usage;
protected:
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
map<pair<bool, int>, edabitvec<T>> my_edabits;
void count(Dtype dtype) { usage.files[T::field_type()][dtype]++; }
void count(DataTag tag, int n = 1) { usage.extended[T::field_type()][tag] += n; }
void count_input(int player) { usage.inputs[player][T::field_type()]++; }
@@ -117,10 +124,12 @@ public:
virtual array<T, 3> get_triple(int n_bits);
virtual T get_bit();
virtual void get_dabit(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
virtual void get_edabits(bool, size_t, T*, vector<typename T::bit_type>&,
const vector<int>&)
{ throw runtime_error("no edaBit"); }
virtual void get_dabit(T&, typename T::bit_type&);
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs);
virtual void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
virtual void push_triples(const vector<array<T, 3>>&)
{ throw runtime_error("no pushing"); }
@@ -145,10 +154,17 @@ class Sub_Data_Files : public Preprocessing<T>
vector<BufferOwner<T, T>> input_buffers;
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
map<DataTag, BufferOwner<T, T> > extended;
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
map<int, ifstream*> edabit_buffers;
int my_num,num_players;
const string prep_data_dir;
int thread_num;
Sub_Data_Files<typename T::part_type>* part;
void buffer_edabits_with_queues(bool stric, int n_bits);
public:
static string get_suffix(int thread_num);
@@ -202,6 +218,9 @@ public:
void setup_extended(const DataTag& tag, int tuple_size = 0);
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get_dabit_no_count(T& a, typename T::bit_type& b);
Preprocessing<typename T::part_type>& get_part();
};
template<class sint, class sgf2n>

View File

@@ -3,6 +3,8 @@
#include "Processor/Data_Files.h"
#include "Processor/Processor.h"
#include "Protocols/dabit.h"
#include "Math/Setup.h"
template<class T>
Lock Sub_Data_Files<T>::tuple_lengths_lock;
@@ -54,7 +56,8 @@ template<class T>
Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
const string& prep_data_dir, DataPositions& usage, int thread_num) :
Preprocessing<T>(usage),
my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir)
my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir),
thread_num(thread_num), part(0)
{
#ifdef DEBUG_FILES
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
@@ -72,6 +75,11 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
}
}
sprintf(filename, (prep_data_dir + "%s-%s-P%d%s").c_str(),
DataPositions::dtype_names[DATA_DABIT], (T::type_short()).c_str(), my_num,
suffix.c_str());
dabit_buffer.setup(filename, 1, DataPositions::dtype_names[DATA_DABIT]);
input_buffers.resize(num_players);
for (int i=0; i<num_players; i++)
{
@@ -127,6 +135,14 @@ Sub_Data_Files<T>::~Sub_Data_Files()
for (auto it =
extended.begin(); it != extended.end(); it++)
it->second.close();
dabit_buffer.close();
for (auto& x: edabit_buffers)
{
x.second->close();
delete x.second;
}
if (part != 0)
delete part;
}
template<class T>
@@ -236,4 +252,48 @@ void Sub_Data_Files<T>::get_no_count(vector<T>& S, DataTag tag, const vector<int
extended[tag].input(S[regs[i] + j]);
}
template<class T>
void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
{
dabit<T> tmp;
dabit_buffer.input(tmp);
a = tmp.first;
b = tmp.second;
}
template<class T>
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits)
{
#ifndef INSECURE
throw runtime_error("no secure implementation of reading edaBits from files");
#endif
if (edabit_buffers.find(n_bits) == edabit_buffers.end())
{
string filename = prep_data_dir + "edaBits-" + to_string(n_bits) + "-P"
+ to_string(my_num);
ifstream* f = new ifstream(filename);
if (f->fail())
throw runtime_error("cannot open " + filename);
edabit_buffers[n_bits] = f;
}
auto& buffer = *edabit_buffers[n_bits];
if (buffer.peek() == EOF)
buffer.seekg(0);
edabitvec<T> eb;
eb.input(n_bits, buffer);
this->edabits[{strict, n_bits}].push_back(eb);
if (buffer.fail())
throw runtime_error("error reading edaBits");
}
template<class T>
Preprocessing<typename T::part_type>& Sub_Data_Files<T>::get_part()
{
if (part == 0)
part = new Sub_Data_Files<typename T::part_type>(my_num, num_players,
get_prep_sub_dir<typename T::part_type>(num_players), this->usage,
thread_num);
return *part;
}
#endif

View File

@@ -45,6 +45,9 @@ public:
{
throw not_implemented();
}
void CheckFor(const typename T::open_type&, const vector<T>&, const Player&)
{
}
DummyMC<typename T::part_type>& get_part_MC()
{
@@ -56,6 +59,11 @@ public:
throw not_implemented();
return {};
}
int number()
{
return 0;
}
};
template<class T>
@@ -63,12 +71,17 @@ class DummyProtocol : public ProtocolBase<T>
{
public:
Player& P;
int counter;
static int get_n_relevant_players()
{
throw not_implemented();
}
static void multiply(vector<T>, vector<pair<T, T>>, int, int, SubProcessor<T>)
{
}
DummyProtocol(Player& P) :
P(P)
{
@@ -91,6 +104,9 @@ public:
throw not_implemented();
return {};
}
void check()
{
}
};
template<class T>
@@ -170,6 +186,10 @@ public:
{
(void) proc, (void) MC;
}
template<class T, class U, class W>
NotImplementedInput(const T&, const U&, const W&)
{
}
NotImplementedInput(Player& P)
{
(void) P;
@@ -200,6 +220,12 @@ public:
(void) proc, (void) regs;
throw not_implemented();
}
static void input_mixed(SubProcessor<V>, vector<int>, int, int)
{
}
static void raw_input(SubProcessor<V>, vector<int>, int)
{
}
void reset_all(Player& P)
{
(void) P;
@@ -248,7 +274,7 @@ public:
(void) player, (void) target, (void) source;
throw not_implemented();
}
void stop(int player, int source)
void stop(int player, int source, int)
{
(void) player, (void) source;
}

View File

@@ -20,9 +20,9 @@ class InputBase
{
typedef typename T::clear clear;
protected:
Player* P;
protected:
Buffer<typename T::clear, typename T::clear> buffer;
Timer timer;
@@ -42,6 +42,7 @@ public:
static void finalize(SubProcessor<T>& Proc, int player, const int* params, int size);
InputBase(ArithmeticProcessor* proc = 0);
InputBase(SubProcessor<T>* proc);
virtual ~InputBase();
virtual void reset(int player) = 0;
@@ -56,7 +57,7 @@ public:
virtual T finalize_mine() = 0;
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
T finalize(int player, int n_bits = -1);
virtual T finalize(int player, int n_bits = -1);
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);
};

View File

@@ -23,6 +23,12 @@ InputBase<T>::InputBase(ArithmeticProcessor* proc) :
buffer.setup(&proc->private_input, -1, proc->private_input_filename);
}
template<class T>
InputBase<T>::InputBase(SubProcessor<T>* proc) :
InputBase(proc ? proc->Proc : 0)
{
}
template<class T>
Input<T>::Input(SubProcessor<T>& proc) :
Input(proc, proc.MC)
@@ -92,7 +98,7 @@ void Input<T>::add_mine(const open_type& input, int n_bits)
prep.get_input(share, rr, player);
t = input - rr;
t.pack(this->os[player]);
share += T::constant(t, 0, MC.get_alphai());
share += T::constant(t, player, MC.get_alphai());
this->values_input++;
}
@@ -190,7 +196,7 @@ void Input<T>::finalize_other(int player, T& target,
(void) n_bits;
target = shares[player].next();
t.unpack(o);
target += T::constant(t, 1, MC.get_alphai());
target += T::constant(t, P.my_num(), MC.get_alphai());
}
template<class T>

View File

@@ -330,6 +330,7 @@ struct TempVars {
class BaseInstruction
{
friend class Program;
template<class T> friend class RepRingOnlyEdabitPrep;
protected:
int opcode; // The code

View File

@@ -1217,7 +1217,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]));
break;
case RANDOMS:
Procp.protocol.randoms_inst(Procp, *this);
Procp.protocol.randoms_inst(Procp.get_S(), *this);
return;
case INPUTMASKREG:
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2]));
@@ -1314,14 +1314,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break;
case SHLC:
to_bigint(Proc.temp.aa,Proc.read_Cp(r[2]));
if (Proc.temp.aa > 63)
throw runtime_error("too much left shift");
Proc.get_Cp_ref(r[0]).SHL(Proc.read_Cp(r[1]),Proc.temp.aa);
break;
case SHRC:
to_bigint(Proc.temp.aa,Proc.read_Cp(r[2]));
if (Proc.temp.aa > 63)
throw runtime_error("too much right shift");
Proc.get_Cp_ref(r[0]).SHR(Proc.read_Cp(r[1]),Proc.temp.aa);
break;
case SHLCI:
@@ -1337,8 +1333,8 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n);
break;
case SHRSI:
Proc.get_Sp_ref(r[0]) = Proc.read_Sp(r[1]) >> n;
break;
sint::shrsi(Procp, *this);
return;
case GBITDEC:
for (int j = 0; j < size; j++)
{

View File

@@ -39,35 +39,40 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
sint::clear::read_or_generate_setup(prep_dir_prefix<sint>(), opts);
sint::bit_type::mac_key_type::init_field();
// Initialize gf2n_short for CCD
sint::bit_type::part_type::open_type::init_field();
// make directory for outputs if necessary
mkdir_p(PREP_DIR);
Player* P;
if (use_encryption)
P = new CryptoPlayer(N, 0xF00);
else
P = new PlainPlayer(N, 0xF00);
if (opts.live_prep)
{
auto P = new PlainPlayer(N, 0xF00);
sint::LivePrep::basic_setup(*P);
delete P;
}
sint::read_or_generate_mac_key(prep_dir_prefix<sint>(), N, alphapi);
sgf2n::read_or_generate_mac_key(prep_dir_prefix<sgf2n>(), N, alpha2i);
sint::read_or_generate_mac_key(prep_dir_prefix<sint>(), *P, alphapi);
sgf2n::read_or_generate_mac_key(prep_dir_prefix<sgf2n>(), *P, alpha2i);
sint::bit_type::part_type::read_or_generate_mac_key(
prep_dir_prefix<typename sint::bit_type>(), N, alphabi);
prep_dir_prefix<typename sint::bit_type>(), *P, alphabi);
#ifdef DEBUG_MAC
cerr << "MAC Key p = " << alphapi << endl;
cerr << "MAC Key 2 = " << alpha2i << endl;
#endif
// deactivate output if necessary
sint::bit_type::out.activate(my_number == 0 or opts.interactive);
// for OT-based preprocessing
sint::clear::next::template init<typename sint::clear>(false);
// Initialize the global memory
if (memtype.compare("old")==0)
{
ifstream inpf;
inpf.open(memory_filename(), ios::in | ios::binary);
if (inpf.fail()) { throw file_error(memory_filename()); }
inpf >> M2 >> Mp >> Mi;
@@ -90,16 +95,12 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
if (live_prep
and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot))
{
Player* P;
if (use_encryption)
P = new CryptoPlayer(playerNames, 0xF000);
else
P = new PlainPlayer(playerNames, 0xF000);
for (int i = 0; i < nthreads; i++)
ot_setups.push_back({ *P, true });
delete P;
}
delete P;
/* Set up the threads */
tinfo.resize(nthreads);
threads.resize(nthreads);
@@ -190,7 +191,7 @@ void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
}
catch (bad_cast& e)
{
#ifdef VERBOSE
#ifdef VERBOSE_CENTRAL
cerr << "Problem with central preprocessing" << endl;
#endif
}
@@ -210,7 +211,7 @@ void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
}
catch (bad_cast& e)
{
#ifdef VERBOSE
#ifdef VERBOSE_CENTRAL
cerr << "Problem with central bit triple preprocessing: " << e.what() << endl;
#endif
}
@@ -231,12 +232,14 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
//printf("Running line %d\n",exec);
if (progs[tape_number].usage_unknown())
{
#ifndef INSECURE
if (not opts.live_prep)
{
cerr << "Internally called tape " << tape_number <<
" has unknown offline data usage" << endl;
throw invalid_program();
}
#endif
return DataPositions(N.num_players());
}
else
@@ -263,9 +266,6 @@ void Machine<sint, sgf2n>::run()
proc_timer.start();
timer[0].start();
// legacy
int _;
inpf >> _ >> _ >> _;
// run main tape
pos.increase(run_tape(0, 0, 0));
join_tape(0);

View File

@@ -16,6 +16,13 @@ template<class T>
class NoLivePrep : public Sub_Data_Files<T>
{
public:
static void basic_setup(Player&)
{
}
static void teardown()
{
}
NoLivePrep(SubProcessor<T>* proc, DataPositions& usage) : Sub_Data_Files<T>(0, 0, "", usage, 0)
{
(void) proc;

42
Processor/NoProtocol.h Normal file
View File

@@ -0,0 +1,42 @@
/*
* NoProtocol.h
*
*/
#ifndef PROCESSOR_NOPROTOCOL_H_
#define PROCESSOR_NOPROTOCOL_H_
#include "Protocols/Replicated.h"
template<class T>
class NoProtocol : public ProtocolBase<T>
{
public:
NoProtocol(Player&)
{
}
void init_mul(SubProcessor<T>*)
{
throw not_implemented();
}
typename T::clear prepare_mul(const T&, const T&, int n = -1)
{
(void) n;
throw not_implemented();
}
void exchange()
{
throw not_implemented();
}
T finalize_mul(int n = -1)
{
(void) n;
throw not_implemented();
}
};
#endif /* PROCESSOR_NOPROTOCOL_H_ */

View File

@@ -240,19 +240,6 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
job.pos.increase(Proc.DataF.get_usage());
}
//double elapsed = timeval_diff(&startv, &endv);
//printf("Thread time = %f seconds\n",elapsed/1000000);
//printf("\texec = %d\n",exec); exec++;
//printf("\tMC2.number = %d\n",MC2.number());
//printf("\tMCp.number = %d\n",MCp.number());
// MACCheck
MC2->Check(P);
MCp->Check(P);
//printf("\tMAC checked\n");
P.Check_Broadcast();
//printf("\tBroadcast checked\n");
#ifdef DEBUG_THREADS
printf("\tSignalling I have finished\n");
#endif
@@ -269,6 +256,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
// MACCheck
MC2->Check(P);
MCp->Check(P);
Proc.share_thread.MC->Check(P);
//cout << num << " : Checking broadcast" << endl;
P.Check_Broadcast();

View File

@@ -17,6 +17,7 @@ OnlineOptions::OnlineOptions() : playerno(-1)
live_prep = true;
batch_size = 10000;
memtype = "empty";
bits_from_squares = false;
direct = false;
bucket_size = 3;
cmd_private_input_file = "Player-Data/Input";
@@ -138,6 +139,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"-m", // Flag token.
"--memory" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Compute random bits from squares", // Help description.
"-Q", // Flag token.
"--bits-from-squares" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
@@ -174,6 +184,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
live_prep = opt.get("-L")->isSet;
opt.get("-b")->getInt(batch_size);
opt.get("--memory")->getString(memtype);
bits_from_squares = opt.isSet("-Q");
opt.get("-IF")->getString(cmd_private_input_file);
opt.get("-OF")->getString(cmd_private_output_file);

View File

@@ -22,6 +22,7 @@ public:
std::string progname;
int batch_size;
std::string memtype;
bool bits_from_squares;
bool direct;
int bucket_size;
std::string cmd_private_input_file;

View File

@@ -52,6 +52,11 @@ SubProcessor<T>::~SubProcessor()
if (bit_prep.data_sent())
cerr << "Sent for global bit preprocessing threads: " <<
bit_prep.data_sent() * 1e-6 << " MB" << endl;
if (not bit_usage.empty())
{
cerr << "Mixed-circuit preprocessing cost:" << endl;
bit_usage.print_cost();
}
#endif
}
@@ -82,13 +87,16 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
secure_prng.ReSeed();
shared_prng.SeedGlobally(P);
out.activate(P.my_num() == 0 or machine.opts.interactive);
// only output on party 0 if not interactive
bool output = P.my_num() == 0 or machine.opts.interactive;
out.activate(output);
Procb.out.activate(output);
setup_redirection(P.my_num(), thread_num, opts);
if (!machine.opts.cmd_private_output_file.empty())
if (stdout_redirect_file.is_open())
{
const string stdout_filename = get_parameterized_filename(P.my_num(), thread_num, opts.cmd_private_output_file);
stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out);
out.redirect_to_file(stdout_redirect_file);
Procb.out.redirect_to_file(stdout_redirect_file);
}
}

View File

@@ -0,0 +1,17 @@
/*
* ProcessorBase.cpp
*
*/
#include "ProcessorBase.hpp"
void ProcessorBase::setup_redirection(int my_num, int thread_num,
OnlineOptions& opts)
{
if (not opts.cmd_private_output_file.empty())
{
const string stdout_filename = get_parameterized_filename(my_num,
thread_num, opts.cmd_private_output_file);
stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out);
}
}

View File

@@ -12,6 +12,7 @@
using namespace std;
#include "Tools/ExecutionStats.h"
#include "OnlineOptions.h"
class ProcessorBase
{
@@ -30,6 +31,8 @@ protected:
public:
ExecutionStats stats;
ofstream stdout_redirect_file;
void pushi(long x) { stacki.push(x); }
void popi(long& x) { x = stacki.top(); stacki.pop(); }
@@ -50,6 +53,8 @@ public:
T get_input(bool interactive, const int* params);
template<class T>
T get_input(istream& is, const string& input_filename, const int* params);
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts);
};
#endif /* PROCESSOR_PROCESSORBASE_H_ */

View File

@@ -4,11 +4,13 @@
*/
#include "RingOptions.h"
#include "BaseMachine.h"
#include <iostream>
using namespace std;
RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv)
RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv,
bool security)
{
opt.add(
"64", // Default.
@@ -19,8 +21,37 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv)
"-R", // Flag token.
"--ring" // Flag token.
);
if (security)
opt.add(
"40", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Security parameter (default: 40)", // Help description.
"-S", // Flag token.
"--security" // Flag token.
);
opt.parse(argc, argv);
opt.get("-R")->getInt(R);
if (security)
opt.get("-S")->getInt(S);
else
S = -1;
R_is_set = opt.isSet("-R");
opt.resetArgs();
cerr << "Trying to run " << R << "-bit computation" << endl;
if (R_is_set)
cerr << "Trying to run " << R << "-bit computation" << endl;
if (security)
cerr << "Using security parameter " << S << endl;
}
int RingOptions::ring_size_from_opts_or_schedule(string progname)
{
if (R_is_set)
return R;
int r = BaseMachine::ring_size_from_schedule(progname);
if (r == 0)
r = R;
cerr << "Trying to run " << r << "-bit computation" << endl;
return r;
}

View File

@@ -7,13 +7,21 @@
#define PROCESSOR_RINGOPTIONS_H_
#include "Tools/ezOptionParser.h"
#include <string>
using namespace std;
class RingOptions
{
bool R_is_set;
public:
int R;
int S;
RingOptions(ez::ezOptionParser& opt, int argc, const char** argv);
RingOptions(ez::ezOptionParser& opt, int argc, const char** argv,
bool security = false);
int ring_size_from_opts_or_schedule(string progname);
};
#endif /* PROCESSOR_RINGOPTIONS_H_ */

76
Processor/TruncPrTuple.h Normal file
View File

@@ -0,0 +1,76 @@
/*
* TruncPrTuple.h
*
*/
#ifndef PROCESSOR_TRUNCPRTUPLE_H_
#define PROCESSOR_TRUNCPRTUPLE_H_
#include <vector>
#include <assert.h>
using namespace std;
template<class T>
class TruncPrTuple
{
public:
int dest_base;
int source_base;
int k;
int m;
int n_shift;
TruncPrTuple(const vector<int>& regs, size_t base)
{
dest_base = regs[base];
source_base = regs[base + 1];
k = regs[base + 2];
m = regs[base + 3];
n_shift = T::N_BITS - 1 - k;
assert(m < k);
assert(0 < k);
assert(m < T::N_BITS);
}
T upper(T mask)
{
return (mask << (n_shift + 1)) >> (n_shift + m + 1);
}
T msb(T mask)
{
return (mask << (n_shift)) >> (T::N_BITS - 1);
}
};
template<class T>
class TruncPrTupleWithGap : public TruncPrTuple<T>
{
public:
TruncPrTupleWithGap(const vector<int>& regs, size_t base) :
TruncPrTuple<T>(regs, base)
{
}
T upper(T mask)
{
if (big_gap())
return mask >> this->m;
else
return TruncPrTuple<T>::upper(mask);
}
T msb(T mask)
{
assert(not big_gap());
return TruncPrTuple<T>::msb(mask);
}
bool big_gap()
{
return this->k <= T::N_BITS - 40;
}
};
#endif /* PROCESSOR_TRUNCPRTUPLE_H_ */

View File

@@ -55,6 +55,9 @@
X(MULCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
typename sint::clear op2 = int(n), \
*dest++ = *op1++ * op2) \
X(MULSI, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
typename sint::clear op2 = int(n), \
*dest++ = *op1++ * op2) \
X(SHRCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]], \
*dest++ = *op1++ >> n) \
X(TRIPLE, auto a = &Procp.get_S()[r[0]]; auto b = &Procp.get_S()[r[1]]; \

View File

@@ -21,7 +21,7 @@ def match(db_entry, sample):
from Compiler import util
if n_threads is None:
util.tree_reduce(lambda x, y: x.min(y), (match(db[i], sample) for i in range(n)))
res = util.tree_reduce(lambda x, y: x.min(y), (match(db[i], sample) for i in range(n)))
else:
tmp = sint.Array(n_threads)
@@ -31,4 +31,6 @@ else:
(match(db[base + i], sample)
for i in range(size)))
util.tree_reduce(lambda x, y: x.min(y), tmp)
res = util.tree_reduce(lambda x, y: x.min(y), tmp)
print_ln('result: %s', res.reveal())

View File

@@ -0,0 +1,69 @@
import ml
import math
import re
import util
program.options_from_args()
sfix.set_precision_from_args(program)
n_examples = 11791
n_test = 1991
n_features = 28 ** 2
try:
n_epochs = int(program.args[1])
except:
n_epochs = 100
N = n_examples
batch_size = 128
assert batch_size <= N
try:
ml.set_n_threads(int(program.args[2]))
except:
pass
n_inner = 128
n_dense_layers = None
for arg in program.args:
m = re.match('(.*)dense', arg)
if m:
n_dense_layers = int(m.group(1))
if n_dense_layers == 1:
layers = [ml.Dense(N, n_features, 1, activation='id')]
elif n_dense_layers > 1:
layers = [ml.Dense(N, n_features, n_inner, activation='relu')]
for i in range(n_dense_layers - 2):
layers += [ml.Dense(N, n_inner, n_inner, activation='relu')]
layers += [ml.Dense(N, n_inner, 1, activation='id')]
else:
raise CompilerError('number of dense layers not specified')
layers += [ml.Output.from_args(N, program)]
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
if not ('no_acc' in program.args and 'no_loss' in program.args):
layers[-1].Y.input_from(0)
layers[0].X.input_from(0)
Y.input_from(0)
X.input_from(0)
sgd = ml.SGD(layers, 1)
if 'no_out' in program.args:
del sgd.layers[-1]
if 'forward' in program.args:
sgd.forward(batch=regint.Array(batch_size))
elif 'backward' in program.args:
sgd.backward(batch=regint.Array(batch_size))
elif 'update' in program.args:
sgd.update(0, batch=regint.Array(batch_size))
else:
sgd.run_by_args(program, n_epochs, batch_size, X, Y)

View File

@@ -0,0 +1,99 @@
import ml
import math
#ml.report_progress = True
program.options_from_args()
approx = 3
if 'profile' in program.args:
print('Compiling for profiling')
N = 1000
n_test = 100
elif 'debug' in program.args:
N = 10
n_test = 10
elif 'gisette' in program.args:
print('Compiling for 4/9')
N = 11791
n_test = 1991
else:
N = 12665
n_test = 2115
n_examples = N
n_features = 28 ** 2
try:
n_epochs = int(program.args[1])
except:
n_epochs = 100
try:
batch_size = int(program.args[2])
except:
batch_size = N
assert batch_size <= N
try:
ml.set_n_threads(int(program.args[3]))
except:
pass
if 'debug' in program.args:
n_inner = 10
n_features = 10
else:
n_inner = 128
if 'norelu' in program.args:
activation = 'id'
else:
activation = 'relu'
layers = [ml.Dense(N, n_features, n_inner, activation=activation),
ml.Dense(N, n_inner, n_inner, activation=activation),
ml.Dense(N, n_inner, 1),
ml.Output(N, approx=approx)]
if '2dense' in program.args:
del layers[1]
layers[-1].Y.input_from(0)
layers[0].X.input_from(0)
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)
sgd = ml.SGD(layers, 10, report_loss=True)
sgd.reset()
@for_range(int(math.ceil(n_epochs / 10)))
def _(i):
start_timer(1)
sgd.run(batch_size)
stop_timer(1)
def get_correct(Y, n):
n_correct = regint(0)
for i in range(n):
n_correct += (Y[i].reveal() > 0).bit_xor(
layers[-2].Y[i][0][0][0].reveal() < 0)
return n_correct
sgd.forward(N)
n_correct = get_correct(layers[-1].Y, N)
print_ln('train_acc: %s (%s/%s)', cfix(n_correct) / N, n_correct, N)
training_address = layers[0].X.address
layers[0].X.address = X.address
sgd.forward(n_test)
layers[0].X.address = training_address
n_correct = get_correct(Y, n_test)
print_ln('acc: %s (%s/%s)', cfix(n_correct) / n_test, n_correct, n_test)

View File

@@ -0,0 +1,109 @@
import ml
import math
import re
import util
#ml.report_progress = True
program.options_from_args()
if 'profile' in program.args:
print('Compiling for profiling')
N = 1000
n_test = 100
elif 'debug' in program.args:
N = 100
n_test = 100
else:
N = 60000
n_test = 10000
n_examples = N
n_features = 28 ** 2
try:
n_epochs = int(program.args[1])
except:
n_epochs = 100
try:
batch_size = int(program.args[2])
except:
batch_size = N
assert batch_size <= N
try:
ml.set_n_threads(int(program.args[3]))
except:
pass
n_inner = 128
if 'norelu' in program.args:
activation = 'id'
else:
activation = 'relu'
if 'nearest' in program.args:
sfix.round_nearest = True
if 'double' in program.args:
sfix.set_precision(32, 63)
cfix.set_precision(32, 63)
elif 'triple' in program.args:
sfix.set_precision(48, 91)
cfix.set_precision(48, 91)
elif 'quadruple' in program.args:
sfix.set_precision(64, 127)
cfix.set_precision(64, 127)
elif 'sextuple' in program.args:
sfix.set_precision(96, 191)
cfix.set_precision(96, 191)
elif 'octuple' in program.args:
sfix.set_precision(128, 255)
cfix.set_precision(128, 255)
assert sfix.f * 4 == int(program.options.ring)
debug_ml = ('debug_ml' in program.args) * 2 ** (sfix.f / 2)
if '1dense' in program.args:
layers = [ml.Dense(N, n_features, 10, debug=debug_ml)]
else:
layers = [ml.Dense(N, n_features, n_inner, activation=activation, debug=debug_ml),
ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml),
ml.Dense(N, n_inner, 10, debug=debug_ml)]
layers += [ml.MultiOutput.from_args(program, N, 10)]
layers[-1].cheaper_loss = 'mse' in program.args
if '2dense' in program.args:
del layers[1]
layers[-1].Y.input_from(0)
layers[0].X.input_from(0)
Y = sint.Matrix(n_test, 10)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)
if 'always_acc' in program.args:
n_part_epochs = 1
else:
n_part_epochs = 10
sgd = ml.SGD(layers, n_part_epochs, report_loss=True, debug=debug_ml)
#sgd.print_update_average = True
sgd.print_losses = 'print_losses' in program.args
if 'faster' in program.args:
sgd.gamma = MemValue(cfix(.1))
if 'slower' in program.args:
sgd.gamma = MemValue(cfix(.001))
sgd.run_by_args(program, int(math.ceil(n_epochs / n_part_epochs)), batch_size,
X, Y)

View File

@@ -0,0 +1,59 @@
import ml
program.options_from_args()
approx = 3
if 'gisette' in program.args:
print('Compiling for 4/9')
N = 11791
n_test = 1991
else:
N = 12665
n_test = 2115
n_examples = N
n_features = 28 ** 2
try:
n_epochs = int(program.args[1])
except:
n_epochs = 100
try:
batch_size = int(program.args[2])
except:
batch_size = N
try:
ml.set_n_threads(int(program.args[3]))
except:
pass
layers = [ml.Dense(N, n_features, 1),
ml.Output(N, approx=approx)]
layers[1].Y.input_from(0)
layers[0].X.input_from(0)
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)
sgd = ml.SGD(layers, n_epochs, report_loss=True)
sgd.reset()
start_timer(1)
sgd.run(batch_size)
stop_timer(1)
layers[0].X.assign(X)
sgd.forward(n_test)
n_correct = cfix(0)
for i in range(n_test):
n_correct += Y[i].reveal().bit_xor(layers[0].Y[i][0][0][0].reveal() < 0)
print_ln('acc: %s (%s/%s)', n_correct / n_test, n_correct, n_test)

View File

@@ -37,6 +37,12 @@ if len(program.args) > 2:
n_normal = 49
n_features = 17814
if 'mnist' in program.args:
print('Compiling for MNIST')
n_examples = 2115
n_normal = 980
n_features = 28 ** 2
n_pos = n_examples - n_normal
n_epochs = 1
if len(program.args) > 1:

View File

@@ -15,6 +15,10 @@ class FakeInput : public InputBase<T>
PointerVector<T> results;
public:
FakeInput()
{
}
FakeInput(SubProcessor<T>&, typename T::MAC_Check&)
{
}

View File

@@ -77,6 +77,12 @@ public:
a = bit;
b = bit;
}
void get_one_no_count(Dtype dtype, T& a)
{
assert(dtype == DATA_BIT);
a = G.get_uchar() & 1;
}
};
#endif /* PROTOCOLS_FAKEPREP_H_ */

View File

@@ -7,6 +7,7 @@
#define PROTOCOLS_FAKEPROTOCOL_H_
#include "Replicated.h"
#include "Math/Z2k.h"
template<class T>
class FakeProtocol : public ProtocolBase<T>
@@ -14,6 +15,10 @@ class FakeProtocol : public ProtocolBase<T>
PointerVector<T> results;
SeededPRNG G;
T dot_prod;
T trunc_max;
public:
Player& P;
@@ -21,6 +26,27 @@ public:
{
}
#ifdef VERBOSE
~FakeProtocol()
{
output_trunc_max<0>(T::invertible);
}
template<int>
void output_trunc_max(false_type)
{
if (trunc_max != T())
cerr << "Maximum bit length in truncation: "
<< (bigint(typename T::clear(trunc_max)).numBits() + 1)
<< " (" << trunc_max << ")" << endl;
}
template<int>
void output_trunc_max(true_type)
{
}
#endif
void init_mul(SubProcessor<T>*)
{
results.clear();
@@ -41,6 +67,28 @@ public:
return results.next();
}
void init_dotprod(SubProcessor<T>* proc)
{
init_mul(proc);
dot_prod = {};
}
void prepare_dotprod(const T& x, const T& y)
{
dot_prod += x * y;
}
void next_dotprod()
{
results.push_back(dot_prod);
dot_prod = 0;
}
T finalize_dotprod(int)
{
return finalize_mul();
}
void randoms(T& res, int n_bits)
{
res.randomize_part(G, n_bits);
@@ -52,11 +100,63 @@ public:
}
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc)
{
trunc_pr<0>(regs, size, proc, T::characteristic_two);
}
template<int>
void trunc_pr(const vector<int>&, int, SubProcessor<T>&, true_type)
{
throw not_implemented();
}
template<int>
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc, false_type)
{
for (size_t i = 0; i < regs.size(); i += 4)
for (int l = 0; l < size; l++)
proc.get_S_ref(regs[i] + l) = proc.get_S_ref(regs[i + 1] + l)
>> regs[i + 3];
{
auto& res = proc.get_S_ref(regs[i] + l);
auto& source = proc.get_S_ref(regs[i + 1] + l);
T tmp = source - (T(1) << regs[i + 2] - 1);
tmp = tmp < T() ? (T() - tmp) : tmp;
trunc_max = max(trunc_max, tmp);
#ifdef CHECK_BOUNDS_IN_TRUNC_PR_EMULATION
auto test = (source >> (regs[i + 2]));
if (test != 0)
{
cerr << typename T::clear(source) << " has more than "
<< regs[i + 2]
<< " bits in " << regs[i + 3]
<< "-bit truncation (test value "
<< typename T::clear(test) << ")" << endl;
throw runtime_error("trunc_pr overflow");
}
#endif
int n_shift = regs[i + 3];
#ifdef ROUND_NEAREST_IN_EMULATION
res = source >> n_shift;
if (n_shift > 0)
{
bool overflow = T(source >> (n_shift - 1)).get_bit(0);
res += overflow;
}
#else
#ifdef RISKY_TRUNCATION_IN_EMULATION
T r;
r.randomize(G);
if (source.negative())
res = -T(((-source + r) >> n_shift) - (r >> n_shift));
else
res = ((source + r) >> n_shift) - (r >> n_shift);
#else
T r;
r.randomize_part(G, n_shift - 1);
res = (source + r) >> n_shift;
#endif
#endif
}
}
};

View File

@@ -25,7 +25,8 @@ protected:
public:
int values_opened;
MAC_Check_Base() : values_opened(0) {}
MAC_Check_Base(const typename T::mac_key_type::Scalar& mac_key = { }) :
alphai(mac_key), values_opened(0) {}
virtual ~MAC_Check_Base() {}
virtual void Check(const Player& P) { (void)P; }

View File

@@ -36,19 +36,12 @@ public:
void buffer_bits();
};
// extra class to avoid recursion
template<class T>
class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
public virtual MalRepRingPrep<T>,
class SimplerMalRepRingPrep : public virtual MalRepRingPrep<T>,
public virtual RingOnlyBitsFromSquaresPrep<T>
{
public:
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
void set_protocol(typename T::Protocol& protocol)
{
MaliciousRingPrep<T>::set_protocol(protocol);
}
SimplerMalRepRingPrep(SubProcessor<T>* proc, DataPositions& usage);
void buffer_triples()
{
@@ -72,4 +65,27 @@ public:
}
};
template<class T>
class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
public virtual SimplerMalRepRingPrep<T>
{
public:
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
void set_protocol(typename T::Protocol& protocol)
{
MaliciousRingPrep<T>::set_protocol(protocol);
}
void buffer_squares()
{
MalRepRingPrep<T>::buffer_squares();
}
void buffer_bits()
{
RingOnlyBitsFromSquaresPrep<T>::buffer_bits();
};
};
#endif /* PROTOCOLS_MALREPRINGPREP_H_ */

View File

@@ -27,13 +27,22 @@ RingOnlyBitsFromSquaresPrep<T>::RingOnlyBitsFromSquaresPrep(SubProcessor<T>*,
{
}
template<class T>
SimplerMalRepRingPrep<T>::SimplerMalRepRingPrep(SubProcessor<T>* proc,
DataPositions& usage) :
BufferPrep<T>(usage), MalRepRingPrep<T>(proc, usage),
RingOnlyBitsFromSquaresPrep<T>(proc, usage)
{
}
template<class T>
MalRepRingPrepWithBits<T>::MalRepRingPrepWithBits(SubProcessor<T>* proc,
DataPositions& usage) :
BufferPrep<T>(usage), BitPrep<T>(proc, usage),
RingPrep<T>(proc, usage),
MaliciousRingPrep<T>(proc, usage), MalRepRingPrep<T>(proc, usage),
RingOnlyBitsFromSquaresPrep<T>(proc, usage)
RingOnlyBitsFromSquaresPrep<T>(proc, usage),
SimplerMalRepRingPrep<T>(proc, usage)
{
}
@@ -54,6 +63,7 @@ void MalRepRingPrep<T>::buffer_squares()
MaliciousRepPrep<prep_type> prep(_);
assert(this->proc != 0);
prep.init_honest(this->proc->P);
prep.buffer_size = this->buffer_size;
prep.buffer_squares();
for (auto& x : prep.squares)
this->squares.push_back({{x[0], x[1]}});
@@ -68,6 +78,7 @@ void MalRepRingPrep<T>::simple_buffer_triples()
MaliciousRepPrep<prep_type> prep(_);
assert(this->proc != 0);
prep.init_honest(this->proc->P);
prep.buffer_size = this->buffer_size;
prep.buffer_triples();
for (auto& x : prep.triples)
this->triples.push_back({{x[0], x[1], x[2]}});
@@ -222,7 +233,7 @@ void RingOnlyBitsFromSquaresPrep<T>::buffer_bits()
typename BitShare::SquarePrep prep(0, usage);
SubProcessor<BitShare> bit_proc(MC, prep, proc->P);
prep.set_proc(&bit_proc);
bits_from_square_in_ring(this->bits, OnlineOptions::singleton.batch_size, &prep);
bits_from_square_in_ring(this->bits, this->buffer_size, &prep);
}
template<class T>

View File

@@ -11,6 +11,7 @@
template<class T> class HashMaliciousRepMC;
template<class T> class Beaver;
template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPO;
template<class T> class MaliciousRepPrep;
namespace GC
@@ -22,6 +23,7 @@ template<class T>
class MaliciousRep3Share : public Rep3Share<T>
{
typedef Rep3Share<T> super;
typedef MaliciousRep3Share This;
public:
typedef Beaver<MaliciousRep3Share<T>> Protocol;
@@ -29,11 +31,13 @@ public:
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<MaliciousRep3Share<T>> Input;
typedef ::PrivateOutput<MaliciousRep3Share<T>> PrivateOutput;
typedef MaliciousRepPO<MaliciousRep3Share> PO;
typedef Rep3Share<T> Honest;
typedef MaliciousRepPrepWithBits<MaliciousRep3Share> LivePrep;
typedef MaliciousRepPrep<MaliciousRep3Share> TriplePrep;
typedef MaliciousRep3Share prep_type;
typedef T random_type;
typedef This Scalar;
typedef GC::MaliciousRepSecret bit_type;

View File

@@ -144,7 +144,7 @@ void HashMaliciousRepMC<T>::Check(const Player& P)
P.Broadcast_Receive(os);
for (int i = 0; i < P.num_players(); i++)
if (os[i] != os[P.my_num()])
throw mac_fail();
throw mac_fail("check hash mismatch");
}
}

View File

@@ -0,0 +1,27 @@
/*
* MaliciousRepPO.h
*
*/
#ifndef PROTOCOLS_MALICIOUSREPPO_H_
#define PROTOCOLS_MALICIOUSREPPO_H_
#include "Networking/Player.h"
template<class T>
class MaliciousRepPO
{
Player& P;
octetStream to_send;
octetStream to_receive[2];
public:
MaliciousRepPO(Player& P);
void prepare_sending(const T& secret, int player);
void send(int player);
void receive();
typename T::clear finalize(const T& secret);
};
#endif /* PROTOCOLS_MALICIOUSREPPO_H_ */

Some files were not shown because too many files have changed in this diff Show More