mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
Rep4, SPDZ-wise, MNIST training.
This commit is contained in:
@@ -104,7 +104,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
}
|
||||
else
|
||||
{
|
||||
T::read_or_generate_mac_key(prep_dir, N, mac_key);
|
||||
T::read_or_generate_mac_key(prep_dir, *P, mac_key);
|
||||
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
ostream& EvalRegister::out = cout;
|
||||
|
||||
int Register::counter = 0;
|
||||
|
||||
void Register::init(int n_parties)
|
||||
|
||||
@@ -22,6 +22,7 @@ using namespace std;
|
||||
#include "Tools/FlexBuffer.h"
|
||||
#include "Tools/PointerVector.h"
|
||||
#include "Tools/Bundle.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
|
||||
//#define PAD_TO_8(n) (n+8-n%8)
|
||||
#define PAD_TO_8(n) (n)
|
||||
@@ -199,6 +200,7 @@ public:
|
||||
BlackHole& operator<<(T) { return *this; }
|
||||
BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; }
|
||||
void activate(bool) {}
|
||||
void redirect_to_file(ostream&) {}
|
||||
};
|
||||
inline BlackHole& endl(BlackHole& b) { return b; }
|
||||
inline BlackHole& flush(BlackHole& b) { return b; }
|
||||
@@ -211,7 +213,6 @@ public:
|
||||
typedef NoMemory DynamicMemory;
|
||||
|
||||
typedef BlackHole out_type;
|
||||
static BlackHole out;
|
||||
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
@@ -353,8 +354,7 @@ public:
|
||||
|
||||
typedef EvalInputter Input;
|
||||
|
||||
typedef ostream& out_type;
|
||||
static ostream& out;
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
|
||||
21
CHANGELOG.md
21
CHANGELOG.md
@@ -1,5 +1,24 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.2.0 (Oct 28, 2020)
|
||||
|
||||
- Rep4: honest-majority four-party computation with malicious security
|
||||
- SY/SPDZ-wise: honest-majority computation with malicious security based on replicated or Shamir secret sharing
|
||||
- Training with a sequence of dense layers
|
||||
- Training and inference for multi-class classification
|
||||
- Local share conversion for semi-honest protocols based on additive secret sharing modulo a power of two
|
||||
- edaBit generation based on local share conversion
|
||||
- Optimize exponentation with local share conversion
|
||||
- Optimize Shamir pseudo-random secret sharing using a hyper-invertible matrix
|
||||
- Mathematical functions (exponentation, logarithm, square root, and trigonometric functions) with binary circuits
|
||||
- Direct construction of fixed-point values from any type, breaking `sfix(x)` where `x` is the integer representation of a fixed-point number. Use `sfix._new(x)` instead.
|
||||
- Optimized dot product for `sfix`
|
||||
- Matrix multiplication via operator overloading uses VM-optimized multiplication.
|
||||
- Fake preprocessing for daBits and edaBits
|
||||
- Fixed security bug: insufficient randomness in SemiBin random bit generation.
|
||||
- Fixed security bug: insufficient randomization of FKOS15 inputs.
|
||||
- Fixed security bug in binary computation with SPDZ(2k).
|
||||
|
||||
## 0.1.9 (Aug 24, 2020)
|
||||
|
||||
- Streamline inputs to binary circuits
|
||||
@@ -7,7 +26,7 @@ The changelog explains changes pulled through from the private development repos
|
||||
- Emulator for arithmetic circuits
|
||||
- Efficient dot product with Shamir's secret sharing
|
||||
- Lower memory usage for TensorFlow inference
|
||||
- This version breaks bytecode compatibilty.
|
||||
- This version breaks bytecode compatibility.
|
||||
|
||||
## 0.1.8 (June 15, 2020)
|
||||
|
||||
|
||||
4
CONFIG
4
CONFIG
@@ -24,7 +24,9 @@ USE_GF2N_LONG = 1
|
||||
# AVX/AVX2 is required for replicated binary secret sharing
|
||||
# BMI2 is used to optimize multiplication modulo a prime
|
||||
# ADX is used to optimize big integer additions
|
||||
# delete the second line to compile for a platform that supports everything
|
||||
ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
|
||||
ARCH = -march=native
|
||||
|
||||
# allow to set compiler in CONFIG.mine
|
||||
CXX = g++
|
||||
@@ -60,7 +62,7 @@ else
|
||||
BOOST = -lboost_thread $(MY_BOOST)
|
||||
endif
|
||||
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = $(CXX)
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return cls.bit_compose(sum([util.bit_decompose(item, bit_length) for item in items], []))
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
bits = list(bits)
|
||||
if len(bits) == 1:
|
||||
return bits[0]
|
||||
bits = list(bits)
|
||||
@@ -72,7 +73,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
res = [self.bit_type() for i in range(n)]
|
||||
self.bitdec(self, *res)
|
||||
else:
|
||||
res = self.trans([self])
|
||||
res = self.bit_type.trans([self])
|
||||
self.decomposed = res
|
||||
return res + suffix
|
||||
else:
|
||||
@@ -83,8 +84,8 @@ class bits(Tape.Register, _structure, _bit):
|
||||
cbits.conv_cint_vec(a, *res)
|
||||
return res
|
||||
@classmethod
|
||||
def malloc(cls, size):
|
||||
return Program.prog.malloc(size, cls)
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
return Program.prog.malloc(size, cls, creator_tape=creator_tape)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
@@ -430,6 +431,8 @@ class sbits(bits):
|
||||
def equal(self, other, n=None):
|
||||
bits = (~(self + other)).bit_decompose()
|
||||
return reduce(operator.mul, bits)
|
||||
def right_shift(self, m, k, security=None, signed=True):
|
||||
return self.TruncPr(k, m)
|
||||
def TruncPr(self, k, m, kappa=None):
|
||||
if k > self.n:
|
||||
raise Exception('TruncPr overflow: %d > %d' % (k, self.n))
|
||||
@@ -481,8 +484,8 @@ class sbitvec(_vec):
|
||||
def get_type(cls, n):
|
||||
class sbitvecn(cls, _structure):
|
||||
@staticmethod
|
||||
def malloc(size):
|
||||
return sbit.malloc(size * n)
|
||||
def malloc(size, creator_tape=None):
|
||||
return sbit.malloc(size * n, creator_tape=creator_tape)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return n
|
||||
@@ -566,7 +569,8 @@ class sbitvec(_vec):
|
||||
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
|
||||
v = x.v
|
||||
self.v = v[:length]
|
||||
elif elements is not None:
|
||||
elif elements is not None and not (util.is_constant(elements) and \
|
||||
elements == 0):
|
||||
self.v = sbits.trans(elements)
|
||||
def popcnt(self):
|
||||
res = sbitint.wallace_tree([[b] for b in self.v])
|
||||
@@ -606,7 +610,10 @@ class sbitvec(_vec):
|
||||
return cls.from_vec(other.v)
|
||||
@property
|
||||
def size(self):
|
||||
return self.v[0].n
|
||||
if not self.v or util.is_constant(self.v[0]):
|
||||
return 1
|
||||
else:
|
||||
return self.v[0].n
|
||||
@property
|
||||
def n_bits(self):
|
||||
return len(self.v)
|
||||
@@ -725,6 +732,8 @@ class _sbitintbase:
|
||||
return self.get_type(n).bit_compose(bits)
|
||||
def round(self, k, m, kappa=None, nearest=None, signed=None):
|
||||
bits = self.bit_decompose()
|
||||
if signed:
|
||||
bits += [bits[-1]] * (k - len(bits))
|
||||
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
|
||||
return self.get_type(k - m).compose(res_bits)
|
||||
def int_div(self, other, bit_length=None):
|
||||
@@ -781,7 +790,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
# truncate and extend bits
|
||||
bits = bits[:cls.n]
|
||||
bits = list(bits)[:cls.n]
|
||||
bits += [0] * (cls.n - len(bits))
|
||||
return super(sbitint, cls).bit_compose(bits)
|
||||
def force_bit_decompose(self, n_bits=None):
|
||||
@@ -801,6 +810,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits)))
|
||||
product = a * b
|
||||
res_bits = product.bit_decompose()[m:k]
|
||||
res_bits += [res_bits[-1]] * (self.n - len(res_bits))
|
||||
t = self.combo_type(other)
|
||||
return t.bit_compose(res_bits)
|
||||
def __mul__(self, other):
|
||||
@@ -824,6 +834,15 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
else:
|
||||
res.append([(x & bit) for x in other.bit_decompose(n - i)])
|
||||
return res
|
||||
@classmethod
|
||||
def popcnt_bits(cls, bits):
|
||||
res = sbitvec.from_vec(bits).popcnt().elements()[0]
|
||||
res = cls.conv(res)
|
||||
return res
|
||||
def pow2(self, k):
|
||||
l = int(math.ceil(math.log(k, 2)))
|
||||
bits = [self.equal(i, l) for i in range(k)]
|
||||
return self.bit_compose(bits)
|
||||
|
||||
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
|
||||
def __add__(self, other):
|
||||
@@ -867,8 +886,11 @@ class cbitfix(object):
|
||||
conv = staticmethod(lambda x: x)
|
||||
load_mem = classmethod(lambda cls, *args: cls(cbits.load_mem(*args)))
|
||||
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
|
||||
def __init__(self, value):
|
||||
self.v = value
|
||||
@classmethod
|
||||
def _new(cls, value):
|
||||
res = cls()
|
||||
res.v = value
|
||||
return res
|
||||
def output(self):
|
||||
v = self.v
|
||||
if self.k < v.unit:
|
||||
@@ -897,10 +919,10 @@ class sbitfix(_fix):
|
||||
inst.inputb(player, cls.k, cls.f, v)
|
||||
return cls._new(v)
|
||||
def __xor__(self, other):
|
||||
return type(self)(self.v ^ other.v)
|
||||
return type(self)._new(self.v ^ other.v)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbit):
|
||||
return type(self)(self.int_type(other * self.v))
|
||||
return type(self)._new(self.int_type(other * self.v))
|
||||
elif isinstance(other, sbitfixvec):
|
||||
return other * self
|
||||
else:
|
||||
@@ -911,10 +933,11 @@ class sbitfix(_fix):
|
||||
def multipliable(other, k, f, size):
|
||||
class cls(_fix):
|
||||
int_type = sbitint.get_type(k)
|
||||
clear_type = cbitfix
|
||||
cls.set_precision(f, k)
|
||||
return cls._new(cls.int_type(other), k, f)
|
||||
|
||||
sbitfix.set_precision(20, 41)
|
||||
sbitfix.set_precision(16, 31)
|
||||
|
||||
class sbitfixvec(_fix):
|
||||
int_type = sbitintvec
|
||||
|
||||
@@ -220,6 +220,7 @@ class Merger:
|
||||
else:
|
||||
self.max_parallel_open = float('inf')
|
||||
self.counter = defaultdict(lambda: 0)
|
||||
self.rounds = defaultdict(lambda: 0)
|
||||
self.dependency_graph(merge_classes)
|
||||
|
||||
def do_merge(self, merges_iter):
|
||||
@@ -271,6 +272,7 @@ class Merger:
|
||||
merge = merges[i]
|
||||
t = type(self.instructions[merge[0]])
|
||||
self.counter[t] += len(merge)
|
||||
self.rounds[t] += 1
|
||||
if len(merge) > 10000:
|
||||
print('Merging %d %s in round %d/%d' % \
|
||||
(len(merge), t.__name__, i, len(merges)))
|
||||
|
||||
@@ -135,27 +135,37 @@ def Trunc(d, a, k, m, kappa, signed):
|
||||
mulm(d, t, c[2])
|
||||
|
||||
def TruncRing(d, a, k, m, signed):
|
||||
if program.use_split() == 3:
|
||||
program.curr_tape.require_bit_length(1)
|
||||
if program.use_split() in (2, 3):
|
||||
if signed:
|
||||
a += (1 << (k - 1))
|
||||
from Compiler.types import sint
|
||||
from .GC.types import sbitint
|
||||
length = int(program.options.ring)
|
||||
summands = a.split_to_n_summands(length, 3)
|
||||
summands = a.split_to_n_summands(length, program.use_split())
|
||||
x = sbitint.wallace_tree_without_finish(summands, True)
|
||||
if m == 1:
|
||||
low = x[1][1]
|
||||
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
|
||||
sint.conv(x[0][-1])
|
||||
if program.use_split() == 2:
|
||||
carries = sbitint.get_carries(*x)
|
||||
low = carries[m]
|
||||
high = sint.conv(carries[length])
|
||||
else:
|
||||
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
|
||||
low = sint.conv(mid_carry) + sint.conv(x[0][m])
|
||||
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
|
||||
for xx, yy in zip(x[1][m:-1],
|
||||
x[0][m:-1])))
|
||||
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
|
||||
high = top_carry + sint.conv(x[0][-1])
|
||||
if m == 1:
|
||||
low = x[1][1]
|
||||
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
|
||||
sint.conv(x[0][-1])
|
||||
else:
|
||||
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
|
||||
low = sint.conv(mid_carry) + sint.conv(x[0][m])
|
||||
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
|
||||
for xx, yy in zip(x[1][m:-1],
|
||||
x[0][m:-1])))
|
||||
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
|
||||
high = top_carry + sint.conv(x[0][-1])
|
||||
shifted = sint()
|
||||
shrsi(shifted, a, m)
|
||||
res = shifted + sint.conv(low) - (high << (length - m))
|
||||
if signed:
|
||||
res -= (1 << (k - m - 1))
|
||||
else:
|
||||
a_prime = Mod2mRing(None, a, k, m, signed)
|
||||
a -= a_prime
|
||||
|
||||
@@ -53,6 +53,12 @@ def maskField(a, k, kappa):
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def EQZ(a, k, kappa):
|
||||
prog = program.Program.prog
|
||||
if prog.use_split():
|
||||
from GC.types import sbitvec
|
||||
v = sbitvec(a, k).v
|
||||
bit = util.tree_reduce(operator.and_, (~b for b in v))
|
||||
return types.sint.conv(bit)
|
||||
if program.Program.prog.options.ring:
|
||||
c, r = maskRing(a, k)
|
||||
else:
|
||||
@@ -307,16 +313,22 @@ def BitDec(a, k, m, kappa, bits_to_compute=None):
|
||||
def BitDecRing(a, k, m):
|
||||
n_shift = int(program.Program.prog.options.ring) - m
|
||||
assert(n_shift >= 0)
|
||||
if program.Program.prog.use_dabit:
|
||||
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
|
||||
r = types.sint.bit_compose(r)
|
||||
if program.Program.prog.use_split():
|
||||
x = a.split_to_two_summands(m)
|
||||
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
|
||||
else:
|
||||
r_bits = [types.sint.get_random_bit() for i in range(m)]
|
||||
r = types.sint.bit_compose(r_bits)
|
||||
shifted = ((a - r) << n_shift).reveal()
|
||||
masked = shifted >> n_shift
|
||||
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
|
||||
return [types.sint.conv(bit) for bit in bits]
|
||||
if program.Program.prog.use_dabit:
|
||||
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
|
||||
r = types.sint.bit_compose(r)
|
||||
else:
|
||||
r_bits = [types.sint.get_random_bit() for i in range(m)]
|
||||
r = types.sint.bit_compose(r_bits)
|
||||
shifted = ((a - r) << n_shift).reveal()
|
||||
masked = shifted >> n_shift
|
||||
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
|
||||
return [types.sint.conv(bit) for bit in bits]
|
||||
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
r_dprime = types.sint()
|
||||
@@ -476,22 +488,20 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
|
||||
def Int2FL(a, gamma, l, kappa):
|
||||
lam = gamma - 1
|
||||
s = types.sint()
|
||||
comparison.LTZ(s, a, gamma, kappa)
|
||||
z = EQZ(a, gamma, kappa)
|
||||
a = (1 - 2 * s) * a
|
||||
a_bits = BitDec(a, lam, lam, kappa)
|
||||
s = a.less_than(0, gamma, security=kappa)
|
||||
z = a.equal(0, gamma, security=kappa)
|
||||
a = s.if_else(-a, a)
|
||||
a_bits = a.bit_decompose(lam, security=kappa)
|
||||
a_bits.reverse()
|
||||
b = PreOR(a_bits, kappa)
|
||||
t = a * (1 + sum(2**i * (1 - b_i) for i,b_i in enumerate(b)))
|
||||
p = - (lam - sum(b))
|
||||
t = a * (1 + a.bit_compose(1 - b_i for b_i in b))
|
||||
p = a.popcnt_bits(b) - lam
|
||||
if gamma - 1 > l:
|
||||
if types.sfloat.round_nearest:
|
||||
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
|
||||
p = p + overflow
|
||||
else:
|
||||
v = types.sint()
|
||||
comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False)
|
||||
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
|
||||
else:
|
||||
v = 2**(l-gamma+1) * t
|
||||
p = (p + gamma - 1 - l) * (1 -z)
|
||||
@@ -539,6 +549,7 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
n_ring = int(program.Program.prog.options.ring)
|
||||
assert n_ring >= k, '%d too large' % k
|
||||
if k == n_ring:
|
||||
program.Program.prog.curr_tape.require_bit_length(1)
|
||||
if program.Program.prog.use_edabit():
|
||||
a += types.sint.get_edabit(m, True)[0]
|
||||
else:
|
||||
@@ -555,7 +566,8 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
else:
|
||||
# extra bit to mask overflow
|
||||
prog = program.Program.prog
|
||||
if prog.use_edabit() or prog.use_split() == 3:
|
||||
prog.curr_tape.require_bit_length(1)
|
||||
if prog.use_edabit() or prog.use_split() > 2:
|
||||
lower = sint.get_random_int(m)
|
||||
upper = sint.get_random_int(k - m)
|
||||
msb = sint.get_random_bit()
|
||||
|
||||
@@ -441,6 +441,8 @@ def cisc(function):
|
||||
program.options.cisc = True
|
||||
reset_global_vector_size()
|
||||
program.curr_tape = old_tape
|
||||
for x, bl in tape.req_bit_length.items():
|
||||
old_tape.require_bit_length(bl, x)
|
||||
from Compiler.allocator import Merger
|
||||
merger = Merger(block, program.options,
|
||||
tuple(program.to_merge))
|
||||
@@ -523,25 +525,26 @@ def ret_cisc(function):
|
||||
|
||||
def sfix_cisc(function):
|
||||
from Compiler.types import sfix, sint, cfix, copy_doc
|
||||
def instruction(res, arg, k, f):
|
||||
def instruction(res, arg, k, f, *args):
|
||||
assert k is not None
|
||||
assert f is not None
|
||||
old = sfix.k, sfix.f, cfix.k, cfix.f
|
||||
sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4
|
||||
res.mov(res, function(sfix._new(arg, k=k, f=f)).v)
|
||||
res.mov(res, function(sfix._new(arg, k=k, f=f), *args).v)
|
||||
sfix.k, sfix.f, cfix.k, cfix.f = old
|
||||
instruction.__name__ = function.__name__
|
||||
instruction = cisc(instruction)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if isinstance(args[0], sfix):
|
||||
assert len(args) == 1
|
||||
for arg in args[1:]:
|
||||
assert util.is_constant(arg)
|
||||
assert not kwargs
|
||||
assert args[0].size == args[0].v.size
|
||||
k = args[0].k
|
||||
f = args[0].f
|
||||
res = sfix._new(sint(size=args[0].size), k=k, f=f)
|
||||
instruction(res.v, args[0].v, k, f)
|
||||
instruction(res.v, args[0].v, k, f, *args[1:])
|
||||
return res
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
|
||||
@@ -134,6 +134,7 @@ def print_ln_if(cond, ss, *args):
|
||||
print_str_if(cond, ss + '\n', *args)
|
||||
|
||||
def print_str_if(cond, ss, *args):
|
||||
""" Print string conditionally. See :py:func:`print_ln_if` for details. """
|
||||
if util.is_constant(cond):
|
||||
if cond:
|
||||
print_ln(ss, *args)
|
||||
@@ -160,7 +161,8 @@ def print_str_if(cond, ss, *args):
|
||||
|
||||
def print_ln_to(player, ss, *args):
|
||||
""" Print line at :py:obj:`player` only. Note that printing is
|
||||
disabled by default except at player 0.
|
||||
disabled by default except at player 0. Activate interactive mode
|
||||
with `-I` to enable it for all players.
|
||||
|
||||
:param player: int
|
||||
:param ss: Python string
|
||||
@@ -814,8 +816,8 @@ def range_loop(loop_body, start, stop=None, step=None):
|
||||
if step is None:
|
||||
step = 1
|
||||
def loop_fn(i):
|
||||
loop_body(i)
|
||||
return i + step
|
||||
res = loop_body(i)
|
||||
return util.if_else(res == 0, stop, i + step)
|
||||
if isinstance(step, int):
|
||||
if step > 0:
|
||||
condition = lambda x: x < stop
|
||||
@@ -840,7 +842,9 @@ def for_range(start, stop=None, step=None):
|
||||
in Python :py:func:`range`, but they can by any public
|
||||
integer. Information has to be passed out via container types such
|
||||
as :py:class:`Compiler.types.Array` or declaring registers as
|
||||
:py:obj:`global`.
|
||||
:py:obj:`global`. Note that changing Python data structures such
|
||||
as lists within the loop is not possible, but the compiler cannot
|
||||
warn about this.
|
||||
|
||||
:param start/stop/step: regint/cint/int
|
||||
|
||||
@@ -1057,7 +1061,7 @@ def for_range_opt_multithread(n_threads, n_loops):
|
||||
"""
|
||||
return for_range_multithread(n_threads, None, n_loops)
|
||||
|
||||
def multithread(n_threads, n_items):
|
||||
def multithread(n_threads, n_items, max_size=None):
|
||||
"""
|
||||
Distribute the computation of :py:obj:`n_items` to
|
||||
:py:obj:`n_threads` threads, but leave the in-thread repetition up
|
||||
@@ -1075,8 +1079,19 @@ def multithread(n_threads, n_items):
|
||||
def f(base, size):
|
||||
...
|
||||
"""
|
||||
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
|
||||
reducer=None, looping=False)
|
||||
if max_size is None:
|
||||
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
|
||||
reducer=None, looping=False)
|
||||
else:
|
||||
def wrapper(function):
|
||||
@multithread(n_threads, n_items)
|
||||
def new_function(base, size):
|
||||
for i in range(0, size, max_size):
|
||||
part_base = base + i
|
||||
part_size = min(max_size, size - i)
|
||||
function(part_base, part_size)
|
||||
break_point()
|
||||
return wrapper
|
||||
|
||||
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
thread_mem_req={}, looping=True):
|
||||
@@ -1563,8 +1578,8 @@ def cint_cint_division(a, b, k, f):
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
two = cint(2) * two_power(f)
|
||||
|
||||
sign_b = cint(1) - 2 * cint(b < 0)
|
||||
sign_a = cint(1) - 2 * cint(a < 0)
|
||||
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
|
||||
sign_a = cint(1) - 2 * cint(a.less_than(0, k))
|
||||
absolute_b = b * sign_b
|
||||
absolute_a = a * sign_a
|
||||
w0 = approximate_reciprocal(absolute_b, k, f, theta)
|
||||
@@ -1632,9 +1647,12 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
f = max((k - nearest) // 2 + 1, f)
|
||||
assert 2 * f > k - nearest
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
|
||||
base.set_global_vector_size(b.size)
|
||||
alpha = b.get_type(2 * k).two_power(2*f)
|
||||
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
|
||||
x = alpha - b.extend(2 * k) * w
|
||||
base.reset_global_vector_size()
|
||||
|
||||
y = a.extend(2 *k) * w
|
||||
y = y.round(2*k, f, kappa, nearest, signed=True)
|
||||
|
||||
618
Compiler/ml.py
618
Compiler/ml.py
@@ -42,6 +42,7 @@ an example of how to run MP-SPDZ on TensorFlow graphs.
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from Compiler import mpc_math, util
|
||||
from Compiler.types import *
|
||||
@@ -58,13 +59,13 @@ def log_e(x):
|
||||
def exp(x):
|
||||
return mpc_math.pow_fx(math.e, x)
|
||||
|
||||
def sanitize(x, raw, lower, upper):
|
||||
def get_limit(x):
|
||||
exp_limit = 2 ** (x.k - x.f - 1)
|
||||
limit = math.log(exp_limit)
|
||||
if get_program().options.ring:
|
||||
res = raw
|
||||
else:
|
||||
res = (x > limit).if_else(upper, raw)
|
||||
return math.log(exp_limit)
|
||||
|
||||
def sanitize(x, raw, lower, upper):
|
||||
limit = get_limit(x)
|
||||
res = (x > limit).if_else(upper, raw)
|
||||
return (x < -limit).if_else(lower, res)
|
||||
|
||||
def sigmoid(x):
|
||||
@@ -137,10 +138,12 @@ def argmax(x):
|
||||
return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1])
|
||||
return tree_reduce(op, enumerate(x))[0]
|
||||
|
||||
report_progress = False
|
||||
|
||||
def progress(x):
|
||||
return
|
||||
print_ln(x)
|
||||
time()
|
||||
if report_progress:
|
||||
print_ln(x)
|
||||
time()
|
||||
|
||||
def set_n_threads(n_threads):
|
||||
Layer.n_threads = n_threads
|
||||
@@ -159,6 +162,10 @@ class Tensor(MultiArray):
|
||||
self.alloc()
|
||||
return super(Tensor, self).__getitem__(*args)
|
||||
|
||||
def assign_vector(self, *args):
|
||||
self.alloc()
|
||||
return super(Tensor, self).assign_vector(*args)
|
||||
|
||||
class Layer:
|
||||
n_threads = 1
|
||||
inputs = []
|
||||
@@ -190,12 +197,23 @@ class Layer:
|
||||
class NoVariableLayer(Layer):
|
||||
input_from = lambda *args, **kwargs: None
|
||||
|
||||
class Output(Layer):
|
||||
nablas = lambda self: ()
|
||||
reset = lambda self: None
|
||||
|
||||
class Output(NoVariableLayer):
|
||||
""" Fixed-point logistic regression output layer.
|
||||
|
||||
:param N: number of examples
|
||||
:param approx: :py:obj:`False` (default) or parameter for :py:obj:`approx_sigmoid`
|
||||
"""
|
||||
n_outputs = 2
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, N, program):
|
||||
res = cls(N, approx='approx' in program.args)
|
||||
res.compute_loss = not 'no_loss' in program.args
|
||||
return res
|
||||
|
||||
def __init__(self, N, debug=False, approx=False):
|
||||
self.N = N
|
||||
self.X = sfix.Array(N)
|
||||
@@ -206,9 +224,7 @@ class Output(Layer):
|
||||
self.debug = debug
|
||||
self.weights = None
|
||||
self.approx = approx
|
||||
|
||||
nablas = lambda self: ()
|
||||
reset = lambda self: None
|
||||
self.compute_loss = True
|
||||
|
||||
def divisor(self, divisor, size):
|
||||
return cfix(1.0 / divisor, size=size)
|
||||
@@ -224,11 +240,13 @@ class Output(Layer):
|
||||
x = self.X.get_vector(base, size)
|
||||
y = self.Y.get(batch.get_vector(base, size))
|
||||
if self.approx:
|
||||
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
|
||||
if self.compute_loss:
|
||||
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
|
||||
return
|
||||
e_x = exp(-x)
|
||||
self.e_x.assign(e_x, base)
|
||||
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
|
||||
if self.compute_loss:
|
||||
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
|
||||
self.l.write(sum(lse) * \
|
||||
self.divisor(N, 1))
|
||||
|
||||
@@ -246,13 +264,10 @@ class Output(Layer):
|
||||
diff = self.eval(size, base) - \
|
||||
self.Y.get(batch.get_vector(base, size))
|
||||
assert sfix.f == cfix.f
|
||||
if self.weights is None:
|
||||
diff *= self.divisor(N, size)
|
||||
else:
|
||||
if self.weights is not None:
|
||||
assert N == len(self.weights)
|
||||
diff *= self.weights.get_vector(base, size)
|
||||
if self.weight_total != 1:
|
||||
diff *= self.divisor(self.weight_total, size)
|
||||
assert self.weight_total == N
|
||||
self.nabla_X.assign(diff, base)
|
||||
# @for_range_opt(len(diff))
|
||||
# def _(i):
|
||||
@@ -271,6 +286,244 @@ class Output(Layer):
|
||||
self.weights.assign(weights)
|
||||
self.weight_total = sum(weights)
|
||||
|
||||
def average_loss(self, N):
|
||||
return self.l.reveal()
|
||||
|
||||
def reveal_correctness(self, n=None, Y=None, debug=False):
|
||||
if n is None:
|
||||
n = self.X.sizes[0]
|
||||
if Y is None:
|
||||
Y = self.Y
|
||||
n_correct = MemValue(0)
|
||||
n_printed = MemValue(0)
|
||||
@for_range_opt(n)
|
||||
def _(i):
|
||||
truth = Y[i].reveal()
|
||||
b = self.X[i].reveal()
|
||||
if debug:
|
||||
nabla = self.nabla_X[i].reveal()
|
||||
guess = b > 0
|
||||
correct = truth == guess
|
||||
n_correct.iadd(correct)
|
||||
if debug:
|
||||
to_print = (1 - correct) * (n_printed < 10)
|
||||
n_printed.iadd(to_print)
|
||||
print_ln_if(to_print, '%s: %s %s %s %s',
|
||||
i, truth, guess, b, nabla)
|
||||
return n_correct
|
||||
|
||||
class MultiOutputBase(NoVariableLayer):
|
||||
def __init__(self, N, d_out, approx=False, debug=False):
|
||||
self.X = sfix.Matrix(N, d_out)
|
||||
self.Y = sint.Matrix(N, d_out)
|
||||
self.nabla_X = sfix.Matrix(N, d_out)
|
||||
self.l = MemValue(sfix(-1))
|
||||
self.losses = sfix.Array(N)
|
||||
self.approx = None
|
||||
self.N = N
|
||||
self.d_out = d_out
|
||||
self.compute_loss = True
|
||||
|
||||
def eval(self, N):
|
||||
d_out = self.X.sizes[1]
|
||||
res = sfix.Matrix(N, d_out)
|
||||
res.assign_vector(self.X.get_part_vector(0, N))
|
||||
return res
|
||||
|
||||
def average_loss(self, N):
|
||||
return sum(self.losses.get_vector(0, N)).reveal() / N
|
||||
|
||||
def reveal_correctness(self, n=None, Y=None, debug=False):
|
||||
if n is None:
|
||||
n = self.X.sizes[0]
|
||||
if Y is None:
|
||||
Y = self.Y
|
||||
n_correct = MemValue(0)
|
||||
n_printed = MemValue(0)
|
||||
@for_range_opt(n)
|
||||
def _(i):
|
||||
a = Y[i].reveal_list()
|
||||
b = self.X[i].reveal_list()
|
||||
if debug:
|
||||
loss = self.losses[i].reveal()
|
||||
exp = self.get_extra_debugging(i)
|
||||
nabla = self.nabla_X[i].reveal_list()
|
||||
truth = argmax(a)
|
||||
guess = argmax(b)
|
||||
correct = truth == guess
|
||||
n_correct.iadd(correct)
|
||||
if debug:
|
||||
to_print = (1 - correct) * (n_printed < 10)
|
||||
n_printed.iadd(to_print)
|
||||
print_ln_if(to_print, '%s: %s %s %s %s %s %s',
|
||||
i, truth, guess, loss, b, exp, nabla)
|
||||
return n_correct
|
||||
|
||||
@property
|
||||
def n_outputs(self):
|
||||
return self.d_out
|
||||
|
||||
def get_extra_debugging(self, i):
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def from_args(program, N, n_output):
|
||||
if 'relu_out' in program.args:
|
||||
res = ReluMultiOutput(N, n_output)
|
||||
else:
|
||||
res = MultiOutput(N, n_output, approx='approx' in program.args)
|
||||
res.cheaper_loss = 'mse' in program.args
|
||||
res.compute_loss = not 'no_loss' in program.args
|
||||
return res
|
||||
|
||||
class MultiOutput(MultiOutputBase):
|
||||
"""
|
||||
Output layer for multi-class classification with softmax and cross entropy.
|
||||
|
||||
:param N: number of examples
|
||||
:param d_out: number of classes
|
||||
:param approx: use ReLU division instead of softmax for the loss
|
||||
"""
|
||||
def __init__(self, N, d_out, approx=False, debug=False):
|
||||
MultiOutputBase.__init__(self, N, d_out)
|
||||
self.exp = sfix.Matrix(N, d_out)
|
||||
self.approx = approx
|
||||
self.positives = sint.Matrix(N, d_out)
|
||||
self.relus = sfix.Matrix(N, d_out)
|
||||
self.cheaper_loss = False
|
||||
self.debug = debug
|
||||
self.true_X = sfix.Array(N)
|
||||
|
||||
def forward(self, batch):
|
||||
N = len(batch)
|
||||
d_out = self.X.sizes[1]
|
||||
tmp = self.losses
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
def _(i):
|
||||
if self.approx:
|
||||
positives = self.X[i].get_vector() > (0 if self.cheaper_loss else 0.1)
|
||||
relus = positives.if_else(self.X[i].get_vector(), 0)
|
||||
self.positives[i].assign_vector(positives)
|
||||
self.relus[i].assign_vector(relus)
|
||||
if self.compute_loss:
|
||||
if self.cheaper_loss:
|
||||
s = sum(relus)
|
||||
tmp[i] = sum((self.Y[batch[i]][j] * s - relus[j]) ** 2
|
||||
for j in range(d_out)) / s ** 2 * 0.5
|
||||
else:
|
||||
div = relus / sum(relus).expand_to_vector(d_out)
|
||||
self.losses[i] = -sfix.dot_product(
|
||||
self.Y[batch[i]].get_vector(), log_e(div))
|
||||
else:
|
||||
m = util.max(self.X[i])
|
||||
mv = m.expand_to_vector(d_out)
|
||||
x = self.X[i].get_vector()
|
||||
e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0)
|
||||
self.exp[i].assign_vector(e)
|
||||
if self.compute_loss:
|
||||
true_X = sfix.dot_product(self.Y[batch[i]], self.X[i])
|
||||
tmp[i] = m + log_e(sum(e)) - true_X
|
||||
self.true_X[i] = true_X
|
||||
self.l.write(sum(tmp.get_vector(0, N)) / N)
|
||||
|
||||
def eval(self, N):
|
||||
d_out = self.X.sizes[1]
|
||||
res = sfix.Matrix(N, d_out)
|
||||
if self.approx:
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
def _(i):
|
||||
relus = (self.X[i].get_vector() > 0).if_else(
|
||||
self.X[i].get_vector(), 0)
|
||||
res[i].assign_vector(relus / sum(relus).expand_to_vector(d_out))
|
||||
return res
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
def _(i):
|
||||
e = exp(self.X[i].get_vector())
|
||||
res[i].assign_vector(e / sum(e).expand_to_vector(d_out))
|
||||
return res
|
||||
|
||||
def backward(self, batch):
|
||||
d_out = self.X.sizes[1]
|
||||
if self.approx:
|
||||
@for_range_opt_multithread(self.n_threads, len(batch))
|
||||
def _(i):
|
||||
if self.cheaper_loss:
|
||||
s = sum(self.relus[i])
|
||||
ss = s * s * s
|
||||
inv = 1 / ss
|
||||
@for_range_opt(d_out)
|
||||
def _(j):
|
||||
res = 0
|
||||
for k in range(d_out):
|
||||
relu = self.relus[i][k]
|
||||
summand = relu - self.Y[batch[i]][k] * s
|
||||
summand *= (sfix.from_sint(j == k) - relu)
|
||||
res += summand
|
||||
fallback = -self.Y[batch[i]][j]
|
||||
res *= inv
|
||||
self.nabla_X[i][j] = self.positives[i][j].if_else(res, fallback)
|
||||
return
|
||||
relus = self.relus[i].get_vector()
|
||||
positives = self.positives[i].get_vector()
|
||||
inv = (1 / sum(relus)).expand_to_vector(d_out)
|
||||
truths = self.Y[batch[i]].get_vector()
|
||||
raw = truths / relus - inv
|
||||
self.nabla_X[i] = -positives.if_else(raw, truths)
|
||||
self.maybe_debug_backward(batch)
|
||||
return
|
||||
@for_range_opt_multithread(self.n_threads, len(batch))
|
||||
def _(i):
|
||||
for j in range(d_out):
|
||||
dividend = self.exp[i][j]
|
||||
divisor = sum(self.exp[i])
|
||||
div = (divisor > 0.1).if_else(dividend / divisor, 0)
|
||||
self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div)
|
||||
self.maybe_debug_backward(batch)
|
||||
|
||||
def maybe_debug_backward(self, batch):
|
||||
if self.debug:
|
||||
@for_range(len(batch))
|
||||
def _(i):
|
||||
check = 0
|
||||
for j in range(self.X.sizes[1]):
|
||||
to_check = self.nabla_X[i][j].reveal()
|
||||
check += (to_check > len(batch)) + (to_check < -len(batch))
|
||||
print_ln_if(check, 'X %s', self.X[i].reveal_nested())
|
||||
print_ln_if(check, 'exp %s', self.exp[i].reveal_nested())
|
||||
print_ln_if(check, 'nabla X %s',
|
||||
self.nabla_X[i].reveal_nested())
|
||||
|
||||
def get_extra_debugging(self, i):
|
||||
if self.approx:
|
||||
return self.relus[i].reveal_list()
|
||||
else:
|
||||
return self.exp[i].reveal_list()
|
||||
|
||||
class ReluMultiOutput(MultiOutputBase):
|
||||
"""
|
||||
Output layer for multi-class classification with back-propagation
|
||||
based on ReLU division.
|
||||
|
||||
:param N: number of examples
|
||||
:param d_out: number of classes
|
||||
"""
|
||||
def forward(self, batch):
|
||||
self.l.write(999)
|
||||
|
||||
def backward(self, batch):
|
||||
N = len(batch)
|
||||
d_out = self.X.sizes[1]
|
||||
relus = sfix.Matrix(N, d_out)
|
||||
@for_range_opt_multithread(self.n_threads, len(batch))
|
||||
def _(i):
|
||||
positives = self.X[i].get_vector() > 0
|
||||
relus = positives.if_else(self.X[i].get_vector(), 0)
|
||||
s = sum(relus)
|
||||
inv = 1 / s
|
||||
prod = relus * inv
|
||||
res = prod - self.Y[batch[i]].get_vector()
|
||||
self.nabla_X[i].assign_vector(res)
|
||||
|
||||
class DenseBase(Layer):
|
||||
thetas = lambda self: (self.W, self.b)
|
||||
nablas = lambda self: (self.nabla_W, self.nabla_b)
|
||||
@@ -279,26 +532,20 @@ class DenseBase(Layer):
|
||||
N = len(batch)
|
||||
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
|
||||
|
||||
assert self.d == 1
|
||||
if self.d_out == 1:
|
||||
@multithread(self.n_threads, self.d_in)
|
||||
def _(base, size):
|
||||
A = sfix.Matrix(1, self.N, address=f_schur_Y.address)
|
||||
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
mp = A.direct_mul(B, reduce=False,
|
||||
indices=(regint(0, size=1),
|
||||
regint.inc(N),
|
||||
batch.get_vector(),
|
||||
regint.inc(size, base)))
|
||||
tmp.assign_vector(mp, base)
|
||||
else:
|
||||
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
|
||||
def _(j, k):
|
||||
a = [f_schur_Y[i][0][k] for i in range(N)]
|
||||
b = [self.X[i][0][j] for i in batch]
|
||||
tmp[j][k] = sfix.unreduced_dot_product(a, b)
|
||||
@multithread(self.n_threads, self.d_in)
|
||||
def _(base, size):
|
||||
A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address)
|
||||
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
mp = B.direct_trans_mul(A, reduce=False,
|
||||
indices=(regint.inc(size, base),
|
||||
batch.get_vector(),
|
||||
regint.inc(N),
|
||||
regint.inc(self.d_out)))
|
||||
tmp.assign_part_vector(mp, base)
|
||||
|
||||
if self.d_in * self.d_out < 100000:
|
||||
progress('nabla W (matmul)')
|
||||
|
||||
if self.d_in * self.d_out < 200000:
|
||||
print('reduce at once')
|
||||
@multithread(self.n_threads, self.d_in * self.d_out)
|
||||
def _(base, size):
|
||||
@@ -309,10 +556,46 @@ class DenseBase(Layer):
|
||||
def _(i):
|
||||
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
|
||||
|
||||
self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N))
|
||||
for j in range(self.d)) for i in range(self.d_out))
|
||||
progress('nabla W')
|
||||
|
||||
progress('nabla W/b')
|
||||
self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector()
|
||||
for k in range(N))
|
||||
for j in range(self.d)))
|
||||
|
||||
progress('nabla b')
|
||||
|
||||
if self.debug:
|
||||
limit = N * self.debug
|
||||
@for_range_opt(self.d_in)
|
||||
def _(i):
|
||||
@for_range_opt(self.d_out)
|
||||
def _(j):
|
||||
to_check = self.nabla_W[i][j].reveal()
|
||||
check = sum(to_check > limit) + sum(to_check < -limit)
|
||||
@if_(check)
|
||||
def _():
|
||||
print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check)
|
||||
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
|
||||
for k in range(N)])
|
||||
print_ln('X %s', [self.X[k][0][i].reveal()
|
||||
for k in range(N)])
|
||||
@for_range_opt(self.d_out)
|
||||
def _(j):
|
||||
to_check = self.nabla_b[j].reveal()
|
||||
check = sum(to_check > limit) + sum(to_check < -limit)
|
||||
@if_(check)
|
||||
def _():
|
||||
print_ln('nabla b %s %s: %s', j, len(self.b), to_check)
|
||||
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
|
||||
for k in range(N)])
|
||||
@for_range_opt(len(batch))
|
||||
def _(i):
|
||||
to_check = self.nabla_X[i].get_vector().reveal()
|
||||
check = sum(to_check > limit) + sum(to_check < -limit)
|
||||
@if_(check)
|
||||
def _():
|
||||
print_ln('X %s %s', i, self.X[i].reveal_nested())
|
||||
print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
|
||||
|
||||
class Dense(DenseBase):
|
||||
""" Fixed-point dense (matrix multiplication) layer.
|
||||
@@ -321,7 +604,7 @@ class Dense(DenseBase):
|
||||
:param d_in: input dimension
|
||||
:param d_out: output dimension
|
||||
"""
|
||||
def __init__(self, N, d_in, d_out, d=1, activation='id'):
|
||||
def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
|
||||
self.activation = activation
|
||||
if activation == 'id':
|
||||
self.f = lambda x: x
|
||||
@@ -349,15 +632,13 @@ class Dense(DenseBase):
|
||||
|
||||
self.f_input = MultiArray([N, d, d_out], sfix)
|
||||
|
||||
self.debug = debug
|
||||
|
||||
def reset(self):
|
||||
d_in = self.d_in
|
||||
d_out = self.d_out
|
||||
r = math.sqrt(6.0 / (d_in + d_out))
|
||||
@for_range(d_in)
|
||||
def _(i):
|
||||
@for_range(d_out)
|
||||
def _(j):
|
||||
self.W[i][j] = sfix.get_random(-r, r)
|
||||
self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size()))
|
||||
self.b.assign_all(0)
|
||||
|
||||
def input_from(self, player, raw=False):
|
||||
@@ -372,15 +653,14 @@ class Dense(DenseBase):
|
||||
prod = MultiArray([N, self.d, self.d_out], sfix)
|
||||
else:
|
||||
prod = self.f_input
|
||||
@multithread(self.n_threads, N)
|
||||
max_size = program.Program.prog.budget // self.d_out
|
||||
@multithread(self.n_threads, N, max_size)
|
||||
def _(base, size):
|
||||
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
prod.assign_vector(
|
||||
X_sub.direct_mul(self.W, indices=(batch.get_vector(base, size),
|
||||
regint.inc(self.d_in),
|
||||
regint.inc(self.d_in),
|
||||
regint.inc(self.d_out))),
|
||||
base)
|
||||
prod.assign_part_vector(
|
||||
X_sub.direct_mul(self.W, indices=(
|
||||
batch.get_vector(base, size), regint.inc(self.d_in),
|
||||
regint.inc(self.d_in), regint.inc(self.d_out))), base)
|
||||
|
||||
if self.input_bias:
|
||||
if self.d_out == 1:
|
||||
@@ -389,7 +669,7 @@ class Dense(DenseBase):
|
||||
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
|
||||
self.f_input.assign_vector(v, base)
|
||||
else:
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
@for_range_multithread(self.n_threads, 100, N)
|
||||
def _(i):
|
||||
v = prod[i].get_vector() + self.b.get_vector()
|
||||
self.f_input[i].assign_vector(v)
|
||||
@@ -397,8 +677,24 @@ class Dense(DenseBase):
|
||||
|
||||
def forward(self, batch=None):
|
||||
self.compute_f_input(batch=batch)
|
||||
self.Y.assign_vector(self.f(
|
||||
self.f_input.get_part_vector(0, len(batch))))
|
||||
@multithread(self.n_threads, len(batch), 128)
|
||||
def _(base, size):
|
||||
self.Y.assign_part_vector(self.f(
|
||||
self.f_input.get_part_vector(base, size)), base)
|
||||
if self.debug:
|
||||
limit = self.debug
|
||||
@for_range_opt(len(batch))
|
||||
def _(i):
|
||||
@for_range_opt(self.d_out)
|
||||
def _(j):
|
||||
to_check = self.Y[i][0][j].reveal()
|
||||
check = to_check > limit
|
||||
@if_(check)
|
||||
def _():
|
||||
print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
|
||||
print_ln('X %s', self.X[i].reveal_nested())
|
||||
print_ln('W %s',
|
||||
[self.W[k][j].reveal() for k in range(self.d_in)])
|
||||
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
N = len(batch)
|
||||
@@ -419,26 +715,31 @@ class Dense(DenseBase):
|
||||
f_prime_bit = MultiArray([N, d, d_out], sint)
|
||||
f_schur_Y = MultiArray([N, d, d_out], sfix)
|
||||
|
||||
self.compute_f_input()
|
||||
f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector()))
|
||||
@multithread(self.n_threads, f_prime_bit.total_size())
|
||||
def _(base, size):
|
||||
f_prime_bit.assign_vector(
|
||||
self.f_prime(self.f_input.get_vector(base, size)), base)
|
||||
|
||||
progress('f prime')
|
||||
|
||||
@for_range_opt(N)
|
||||
def _(i):
|
||||
i = batch[i]
|
||||
f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])
|
||||
@multithread(self.n_threads, f_prime_bit.total_size())
|
||||
def _(base, size):
|
||||
f_schur_Y.assign_vector(nabla_Y.get_vector(base, size) *
|
||||
f_prime_bit.get_vector(base, size),
|
||||
base)
|
||||
|
||||
progress('f prime schur Y')
|
||||
|
||||
if compute_nabla_X:
|
||||
@for_range_opt(N)
|
||||
def _(i):
|
||||
i = batch[i]
|
||||
if self.activation == 'id':
|
||||
nabla_X[i] = nabla_Y[i].mul_trans(W)
|
||||
else:
|
||||
nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W)
|
||||
@multithread(self.n_threads, N)
|
||||
def _(base, size):
|
||||
B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
|
||||
nabla_X.assign_part_vector(
|
||||
B.direct_mul_trans(W, indices=(regint.inc(size, base),
|
||||
regint.inc(self.d_out),
|
||||
regint.inc(self.d_out),
|
||||
regint.inc(self.d_in))),
|
||||
base)
|
||||
|
||||
progress('nabla X')
|
||||
|
||||
@@ -1151,6 +1452,7 @@ class QuantSoftmax(QuantBase, BaseLayer):
|
||||
class Optimizer:
|
||||
""" Base class for graphs of layers. """
|
||||
n_threads = Layer.n_threads
|
||||
always_shuffle = True
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
@@ -1175,6 +1477,14 @@ class Optimizer:
|
||||
layer.last_used = list(filter(lambda x: x not in used, layer.inputs))
|
||||
used.update(layer.inputs)
|
||||
|
||||
def batch_for(self, layer, batch):
|
||||
if layer in (self.layers[0], self.layers[-1]):
|
||||
return batch
|
||||
else:
|
||||
batch = regint.Array(len(batch))
|
||||
batch.assign(regint.inc(len(batch)))
|
||||
return batch
|
||||
|
||||
def forward(self, N=None, batch=None, keep_intermediate=True,
|
||||
model_from=None):
|
||||
""" Compute graph.
|
||||
@@ -1193,7 +1503,7 @@ class Optimizer:
|
||||
if model_from is not None:
|
||||
layer.input_from(model_from)
|
||||
break_point()
|
||||
layer.forward(batch=batch)
|
||||
layer.forward(batch=self.batch_for(layer, batch))
|
||||
break_point()
|
||||
if not keep_intermediate:
|
||||
for l in layer.last_used:
|
||||
@@ -1212,26 +1522,30 @@ class Optimizer:
|
||||
""" Compute backward propagation. """
|
||||
for layer in reversed(self.layers):
|
||||
if len(layer.inputs) == 0:
|
||||
layer.backward(compute_nabla_X=False, batch=batch)
|
||||
layer.backward(compute_nabla_X=False,
|
||||
batch=self.batch_for(layer, batch))
|
||||
else:
|
||||
layer.backward(batch=batch)
|
||||
layer.backward(batch=self.batch_for(layer, batch))
|
||||
if len(layer.inputs) == 1:
|
||||
layer.inputs[0].nabla_Y.alloc()
|
||||
layer.inputs[0].nabla_Y.assign_vector(
|
||||
layer.nabla_X.get_part_vector(0, len(batch)))
|
||||
|
||||
def run(self, batch_size=None):
|
||||
def run(self, batch_size=None, stop_on_loss=0):
|
||||
""" Run training.
|
||||
|
||||
:param batch_size: batch size (defaults to example size of first layer)
|
||||
"""
|
||||
if self.n_epochs == 0:
|
||||
return
|
||||
if batch_size is not None:
|
||||
N = batch_size
|
||||
else:
|
||||
N = self.layers[0].N
|
||||
i = MemValue(0)
|
||||
i = self.i_epoch
|
||||
n_iterations = MemValue(0)
|
||||
@do_while
|
||||
def _():
|
||||
@for_range(self.n_epochs)
|
||||
def _(_):
|
||||
if self.X_by_label is None:
|
||||
self.X_by_label = [[None] * self.layers[0].N]
|
||||
assert len(self.X_by_label) in (1, 2)
|
||||
@@ -1239,16 +1553,18 @@ class Optimizer:
|
||||
n = N // len(self.X_by_label)
|
||||
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
|
||||
self.X_by_label) / n))
|
||||
n_iterations.iadd(n_per_epoch)
|
||||
print('%d runs per epoch' % n_per_epoch)
|
||||
indices_by_label = []
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = regint.Array(n * n_per_epoch)
|
||||
indices_by_label.append(indices)
|
||||
indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
|
||||
indices.shuffle()
|
||||
if self.always_shuffle or n_per_epoch > 1:
|
||||
indices.shuffle()
|
||||
loss_sum = MemValue(sfix(0))
|
||||
@for_range(n_per_epoch)
|
||||
def _(j):
|
||||
n_iterations.iadd(1)
|
||||
batch = regint.Array(N)
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = indices_by_label[label]
|
||||
@@ -1257,20 +1573,84 @@ class Optimizer:
|
||||
label * n)
|
||||
self.forward(batch=batch)
|
||||
self.backward(batch=batch)
|
||||
self.update(i)
|
||||
loss = self.layers[-1].l
|
||||
self.update(i, batch=batch)
|
||||
loss_sum.iadd(self.layers[-1].l)
|
||||
if self.print_loss_reduction:
|
||||
before = self.layers[-1].average_loss(N)
|
||||
self.forward(batch=batch)
|
||||
after = self.layers[-1].average_loss(N)
|
||||
print_ln('loss reduction in batch %s: %s (%s - %s)', j,
|
||||
before - after, before, after)
|
||||
elif self.print_losses:
|
||||
print_ln('loss in batch %s: %s', j, self.layers[-1].average_loss(N))
|
||||
if stop_on_loss:
|
||||
loss = self.layers[-1].average_loss(N)
|
||||
res = (loss < stop_on_loss) * (loss >= 0)
|
||||
self.stopped_on_loss.write(1 - res)
|
||||
return res
|
||||
if self.report_loss and self.layers[-1].approx != 5:
|
||||
print_ln('loss after epoch %s: %s', i, loss.reveal())
|
||||
print_ln('loss in epoch %s: %s', i,
|
||||
(loss_sum.reveal() * cfix(1 / n_per_epoch)))
|
||||
else:
|
||||
print_ln('done with epoch %s', i)
|
||||
time()
|
||||
i.iadd(1)
|
||||
res = (i < self.n_epochs)
|
||||
res = True
|
||||
if self.tol > 0:
|
||||
res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
|
||||
return res
|
||||
print_ln('finished after %s epochs and %s iterations', i, n_iterations)
|
||||
|
||||
def run_by_args(self, program, n_runs, batch_size, test_X, test_Y):
|
||||
for arg in program.args:
|
||||
m = re.match('rate(.*)', arg)
|
||||
if m:
|
||||
self.gamma = MemValue(cfix(float(m.group(1))))
|
||||
if 'nomom' in program.args:
|
||||
self.momentum = 0
|
||||
model_input = 'model_input' in program.args
|
||||
if model_input:
|
||||
for layer in self.layers:
|
||||
layer.input_from(0)
|
||||
else:
|
||||
self.reset()
|
||||
@for_range(n_runs)
|
||||
def _(i):
|
||||
if not model_input:
|
||||
start_timer(1)
|
||||
self.run(batch_size, stop_on_loss=100)
|
||||
stop_timer(1)
|
||||
if 'no_acc' in program.args:
|
||||
return
|
||||
N = self.layers[0].X.sizes[0]
|
||||
self.forward(N)
|
||||
batch = regint.Array(N)
|
||||
batch.assign_vector(regint.inc(N))
|
||||
self.layers[-1].backward(batch)
|
||||
n_correct = self.layers[-1].reveal_correctness(N, debug=True)
|
||||
print_ln('train_acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / N,
|
||||
n_correct, N)
|
||||
training_address = self.layers[0].X.address
|
||||
self.layers[0].X.address = test_X.address
|
||||
n_test = len(test_Y)
|
||||
self.forward(n_test)
|
||||
self.layers[0].X.address = training_address
|
||||
n_correct = self.layers[-1].reveal_correctness(n_test, test_Y)
|
||||
print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / n_test,
|
||||
n_correct, n_test)
|
||||
if model_input:
|
||||
start_timer(1)
|
||||
self.run(batch_size)
|
||||
stop_timer(1)
|
||||
else:
|
||||
@if_(util.or_op(self.stopped_on_loss, n_correct <
|
||||
int(n_test // self.layers[-1].n_outputs * 1.1)))
|
||||
def _():
|
||||
self.gamma.imul(.5)
|
||||
self.reset()
|
||||
print_ln('reset after reducing learning rate to %s',
|
||||
self.gamma)
|
||||
|
||||
class Adam(Optimizer):
|
||||
def __init__(self, layers, n_epochs):
|
||||
self.alpha = .001
|
||||
@@ -1318,7 +1698,7 @@ class SGD(Optimizer):
|
||||
:param n_epochs: number of epochs for training
|
||||
:param report_loss: disclose and print loss
|
||||
"""
|
||||
def __init__(self, layers, n_epochs, debug=False, report_loss=False):
|
||||
def __init__(self, layers, n_epochs, debug=False, report_loss=None):
|
||||
self.momentum = 0.9
|
||||
self.layers = layers
|
||||
self.n_epochs = n_epochs
|
||||
@@ -1330,11 +1710,19 @@ class SGD(Optimizer):
|
||||
self.thetas.extend(layer.thetas())
|
||||
for theta in layer.thetas():
|
||||
self.delta_thetas.append(theta.same_shape())
|
||||
self.gamma = MemValue(sfix(0.01))
|
||||
self.gamma = MemValue(cfix(0.01))
|
||||
self.debug = debug
|
||||
self.report_loss = report_loss
|
||||
if report_loss is None:
|
||||
self.report_loss = layers[-1].compute_loss
|
||||
else:
|
||||
self.report_loss = report_loss
|
||||
self.tol = 0.000
|
||||
self.X_by_label = None
|
||||
self.print_update_average = False
|
||||
self.print_losses = False
|
||||
self.print_loss_reduction = False
|
||||
self.i_epoch = MemValue(0)
|
||||
self.stopped_on_loss = MemValue(0)
|
||||
|
||||
def reset(self, X_by_label=None):
|
||||
""" Reset layer parameters.
|
||||
@@ -1353,40 +1741,64 @@ class SGD(Optimizer):
|
||||
y.assign_all(0)
|
||||
for layer in self.layers:
|
||||
layer.reset()
|
||||
self.i_epoch.write(0)
|
||||
self.stopped_on_loss.write(0)
|
||||
|
||||
def update(self, i_epoch):
|
||||
def update(self, i_epoch, batch):
|
||||
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
|
||||
self.delta_thetas):
|
||||
@multithread(self.n_threads, len(nabla))
|
||||
@multithread(self.n_threads, nabla.total_size())
|
||||
def _(base, size):
|
||||
old = delta_theta.get_vector(base, size)
|
||||
red_old = self.momentum * old
|
||||
new = self.gamma * nabla.get_vector(base, size)
|
||||
rate = self.gamma.expand_to_vector(size)
|
||||
nabla_vector = nabla.get_vector(base, size)
|
||||
log_batch_size = math.log(len(batch), 2)
|
||||
# divide by len(batch) by truncation
|
||||
# increased rate if len(batch) is not a power of two
|
||||
pre_trunc = nabla_vector.v * rate.v
|
||||
k = nabla_vector.k + rate.k
|
||||
m = rate.f + int(log_batch_size)
|
||||
v = pre_trunc.round(k, m, signed=True,
|
||||
nearest=sfix.round_nearest)
|
||||
new = nabla_vector._new(v)
|
||||
diff = red_old - new
|
||||
delta_theta.assign_vector(diff, base)
|
||||
theta.assign_vector(theta.get_vector(base, size) +
|
||||
delta_theta.get_vector(base, size), base)
|
||||
if self.debug:
|
||||
for x, name in (old, 'old'), (red_old, 'red_old'), \
|
||||
(new, 'new'), (diff, 'diff'):
|
||||
x = x.reveal()
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
name + ': %s %s %s %s',
|
||||
*[y.v.reveal() for y in (old, red_old, \
|
||||
new, diff)])
|
||||
if self.print_update_average:
|
||||
vec = abs(delta_theta.get_vector().reveal())
|
||||
print_ln('update average: %s (%s)',
|
||||
sum(vec) * cfix(1 / len(vec)), len(vec))
|
||||
if self.debug:
|
||||
limit = int(self.debug)
|
||||
d = delta_theta.get_vector().reveal()
|
||||
a = cfix.Array(len(d.v))
|
||||
aa = [cfix.Array(len(d.v)) for i in range(3)]
|
||||
a = aa[0]
|
||||
a.assign(d)
|
||||
@for_range(len(a))
|
||||
def _(i):
|
||||
x = a[i]
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
'update len=%d' % len(nabla))
|
||||
print_ln_if((x > limit) + (x < -limit),
|
||||
'update epoch=%s %s index=%s %s',
|
||||
i_epoch.read(), str(delta_theta), i, x)
|
||||
a = aa[1]
|
||||
a.assign(nabla.get_vector().reveal())
|
||||
@for_range(len(a))
|
||||
def _(i):
|
||||
x = a[i]
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
'nabla len=%d' % len(nabla))
|
||||
print_ln_if((x > len(batch) * limit) + (x < -len(batch) * limit),
|
||||
'nabla epoch=%s %s index=%s %s',
|
||||
i_epoch.read(), str(nabla), i, x)
|
||||
a = aa[2]
|
||||
a.assign(theta.get_vector().reveal())
|
||||
@for_range(len(a))
|
||||
def _(i):
|
||||
x = a[i]
|
||||
print_ln_if((x > limit) + (x < -limit),
|
||||
'theta epoch=%s %s index=%s %s',
|
||||
i_epoch.read(), str(theta), i, x)
|
||||
index = regint.get_random(64) % len(a)
|
||||
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index,
|
||||
aa[1][index], aa[0][index], aa[2][index])
|
||||
self.gamma.imul(1 - 10 ** - 6)
|
||||
|
||||
@@ -13,6 +13,7 @@ from Compiler import types
|
||||
from Compiler import comparison
|
||||
from Compiler import program
|
||||
from Compiler import instructions_base
|
||||
from Compiler import library, util
|
||||
|
||||
# polynomials as enumerated on Hart's book
|
||||
##
|
||||
@@ -33,11 +34,8 @@ p_3508 = [1.00000000000000000000, -0.50000000000000000000,
|
||||
0.00000000000000000040]
|
||||
##
|
||||
# @private
|
||||
p_1045 = [1.000000077443021686, 0.693147180426163827795756,
|
||||
0.224022651071017064605384, 0.055504068620466379157744,
|
||||
0.009618341225880462374977, 0.001332730359281437819329,
|
||||
0.000155107460590052573978, 0.000014197847399765606711,
|
||||
0.000001863347724137967076]
|
||||
p_1045 = [math.log(2) ** i / math.factorial(i) for i in range(12)]
|
||||
|
||||
##
|
||||
# @private
|
||||
p_2524 = [-2.05466671951, -8.8626599391,
|
||||
@@ -92,8 +90,8 @@ pi_over_2 = math.radians(90)
|
||||
#
|
||||
# @return truncated sint value of x
|
||||
def trunc(x):
|
||||
if type(x) is types.sfix:
|
||||
return floatingpoint.Trunc(x.v, x.k, x.f, x.kappa, signed=True)
|
||||
if isinstance(x, types._fix):
|
||||
return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True)
|
||||
elif type(x) is types.sfloat:
|
||||
v, p, z, s = floatingpoint.FLRound(x, 0)
|
||||
#return types.sfloat(v, p, z, s, x.err)
|
||||
@@ -125,7 +123,7 @@ def load_sint(x, l_type):
|
||||
# @return the evaluation of the polynomial. return type depends on inputs.
|
||||
def p_eval(p_c, x):
|
||||
degree = len(p_c) - 1
|
||||
if type(x) is types.sfix:
|
||||
if isinstance(x, types._fix):
|
||||
# ignore coefficients smaller than precision
|
||||
for c in reversed(p_c):
|
||||
if c < 2 ** -(x.f + 1):
|
||||
@@ -160,10 +158,10 @@ def sTrigSub(x):
|
||||
y = x - (f) * x.coerce(2 * pi)
|
||||
# reduction to \pi
|
||||
b1 = y > pi
|
||||
w = b1 * ((2 * pi - y) - y) + y
|
||||
w = b1.if_else(2 * pi - y, y)
|
||||
# reduction to \pi/2
|
||||
b2 = w > pi_over_2
|
||||
w = b2 * ((pi - w) - w) + w
|
||||
w = b2.if_else(pi - w, w)
|
||||
# returns scaled angle and boolean flags
|
||||
return w, b1, b2
|
||||
|
||||
@@ -182,9 +180,8 @@ def ssin(w, s):
|
||||
v = w * (1.0 / pi_over_2)
|
||||
v_2 = v ** 2
|
||||
# adjust sign according to the movement in the reduction
|
||||
b = s * (-2) + 1
|
||||
# calculate the sin using polynomial evaluation
|
||||
local_sin = b * v * p_eval(p_3307, v_2)
|
||||
local_sin = s.if_else(-v, v) * p_eval(p_3307, v_2)
|
||||
return local_sin
|
||||
|
||||
|
||||
@@ -203,10 +200,10 @@ def scos(w, s):
|
||||
# calculates the v of the w.
|
||||
v = w
|
||||
v_2 = v ** 2
|
||||
# adjust sign according to the movement in the reduction
|
||||
b = s * (-2) + 1
|
||||
# calculate the cos using polynomial evaluation
|
||||
local_cos = b * p_eval(p_3508, v_2)
|
||||
tmp = p_eval(p_3508, v_2)
|
||||
# adjust sign according to the movement in the reduction
|
||||
local_cos = s.if_else(-tmp, tmp)
|
||||
return local_cos
|
||||
|
||||
|
||||
@@ -264,11 +261,12 @@ def tan(x):
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def exp2_fx(a):
|
||||
def exp2_fx(a, zero_output=False):
|
||||
"""
|
||||
Power of two for fixed-point numbers.
|
||||
|
||||
:param a: exponent for :math:`2^a` (sfix)
|
||||
:param zero_output: whether to output zero for very small values. If not, the result will be undefined.
|
||||
|
||||
:return: :math:`2^a` if it is within the range. Undefined otherwise
|
||||
"""
|
||||
@@ -279,54 +277,95 @@ def exp2_fx(a):
|
||||
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
|
||||
n_bits = a.f + n_int_bits
|
||||
n_shift = int(types.program.options.ring) - a.k
|
||||
if types.program.use_edabit():
|
||||
l = sint.get_edabit(a.f, True)
|
||||
u = sint.get_edabit(a.k - a.f, True)
|
||||
r_bits = l[1] + u[1]
|
||||
r = l[0] + (u[0] << a.f)
|
||||
lower_r = l[0]
|
||||
if types.program.use_split():
|
||||
assert not zero_output
|
||||
from Compiler.GC.types import sbitvec
|
||||
if types.program.use_split() == 3:
|
||||
x = a.v.split_to_two_summands(a.k)
|
||||
bits = types._bitint.carry_lookahead_adder(x[0], x[1],
|
||||
fewer_inv=False)
|
||||
# converting MSB first reduces the number of rounds
|
||||
s = sint.conv(bits[-1])
|
||||
lower_overflow = sint.conv(x[0][a.f]) + \
|
||||
sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f])
|
||||
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
|
||||
elif types.program.use_split() == 4:
|
||||
x = list(zip(*a.v.split_to_n_summands(a.k, 4)))
|
||||
bi = types._bitint
|
||||
red = bi.wallace_reduction
|
||||
sums1, carries1 = red(*x[:3], get_carry=False)
|
||||
sums2, carries2 = red(x[3], sums1, carries1, False)
|
||||
bits = bi.carry_lookahead_adder(sums2, carries2,
|
||||
fewer_inv=False)
|
||||
overflows = bi.full_adder(carries1[a.f], carries2[a.f],
|
||||
bits[a.f] ^ sums2[a.f] ^ carries2[a.f])
|
||||
overflows = reversed(list((sint.conv(x)
|
||||
for x in reversed(overflows))))
|
||||
lower_overflow = sint.bit_compose(sint.conv(x)
|
||||
for x in overflows)
|
||||
s = sint.conv(bits[-1])
|
||||
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
|
||||
else:
|
||||
bits = sbitvec(a.v, a.k)
|
||||
s = sint.conv(bits[-1])
|
||||
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
|
||||
higher_bits = bits[a.f:n_bits]
|
||||
else:
|
||||
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
||||
r = sint.bit_compose(r_bits)
|
||||
lower_r = sint.bit_compose(r_bits[:a.f])
|
||||
shifted = ((a.v - r) << n_shift).reveal()
|
||||
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
|
||||
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
|
||||
r_bits[a.f-1::-1])
|
||||
lower_masked = sint.bit_compose(masked_bits[:a.f])
|
||||
lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f))
|
||||
if types.program.use_edabit():
|
||||
l = sint.get_edabit(a.f, True)
|
||||
u = sint.get_edabit(a.k - a.f, True)
|
||||
r_bits = l[1] + u[1]
|
||||
r = l[0] + (u[0] << a.f)
|
||||
lower_r = l[0]
|
||||
else:
|
||||
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
||||
r = sint.bit_compose(r_bits)
|
||||
lower_r = sint.bit_compose(r_bits[:a.f])
|
||||
shifted = ((a.v - r) << n_shift).reveal()
|
||||
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
|
||||
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
|
||||
r_bits[a.f-1::-1])
|
||||
lower_masked = sint.bit_compose(masked_bits[:a.f])
|
||||
lower = lower_r + lower_masked - \
|
||||
(sint.conv(lower_overflow) << (a.f))
|
||||
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
|
||||
masked_bits[a.f:n_bits],
|
||||
carry_in=lower_overflow,
|
||||
get_carry=True)
|
||||
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
|
||||
r_bits[n_bits:-1],
|
||||
higher_bits[-1])
|
||||
if zero_output:
|
||||
# should be for free
|
||||
highest_bits = r_bits[0].ripple_carry_adder(
|
||||
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
|
||||
carry_in=higher_bits[-1])
|
||||
bits_to_check = [x.bit_xor(y)
|
||||
for x, y in zip(highest_bits[:-1],
|
||||
r_bits[n_bits:-1])]
|
||||
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
|
||||
bits_to_check))
|
||||
# sign
|
||||
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
|
||||
del higher_bits[-1]
|
||||
c = types.sfix._new(lower, k=a.k, f=a.f)
|
||||
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
|
||||
masked_bits[a.f:n_bits],
|
||||
carry_in=lower_overflow,
|
||||
get_carry=True)
|
||||
assert(len(higher_bits) == n_bits - a.f + 1)
|
||||
assert(len(higher_bits) == n_bits - a.f)
|
||||
pow2_bits = [sint.conv(x) for x in higher_bits]
|
||||
d = floatingpoint.Pow2_from_bits(pow2_bits[:-1])
|
||||
d = floatingpoint.Pow2_from_bits(pow2_bits)
|
||||
e = p_eval(p_1045, c)
|
||||
g = d * e
|
||||
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
|
||||
2 ** n_int_bits, signed=False,
|
||||
nearest=types.sfix.round_nearest),
|
||||
k=a.k, f=a.f)
|
||||
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
|
||||
r_bits[n_bits:-1],
|
||||
higher_bits[-1])
|
||||
# should be for free
|
||||
highest_bits = r_bits[0].ripple_carry_adder(
|
||||
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
|
||||
carry_in=higher_bits[-1])
|
||||
bits_to_check = [x.bit_xor(y)
|
||||
for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
|
||||
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
|
||||
bits_to_check))
|
||||
# sign
|
||||
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
|
||||
return s.if_else(t.if_else(small_result, 0), g)
|
||||
if zero_output:
|
||||
small_result = t.if_else(small_result, 0)
|
||||
return s.if_else(small_result, g)
|
||||
else:
|
||||
assert not zero_output
|
||||
# obtain absolute value of a
|
||||
s = a < 0
|
||||
a = (s * (-2) + 1) * a
|
||||
a = s.if_else(-a, a)
|
||||
# isolates fractional part of number
|
||||
b = trunc(a)
|
||||
c = a - b
|
||||
@@ -335,7 +374,7 @@ def exp2_fx(a):
|
||||
# evaluates fractional part of a in p_1045
|
||||
e = p_eval(p_1045, c)
|
||||
g = d * e
|
||||
return (1 - s) * g + s / g
|
||||
return s.if_else(1 / g, g)
|
||||
|
||||
|
||||
@types.vectorize
|
||||
@@ -353,19 +392,20 @@ def log2_fx(x):
|
||||
:return: (sfix) the value of :math:`\log_2(x)`
|
||||
|
||||
"""
|
||||
if type(x) is types.sfix:
|
||||
if isinstance(x, types._fix):
|
||||
# transforms sfix to f*2^n, where f is [o.5,1] bounded
|
||||
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
|
||||
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
|
||||
p -= x.f
|
||||
vlen = x.f
|
||||
v = x._new(v, k=x.k, f=x.f)
|
||||
else:
|
||||
d = types.sfloat(x)
|
||||
v, p, vlen = d.v, d.p, d.vlen
|
||||
w = x.coerce(1.0 / (2 ** (vlen)))
|
||||
v *= w
|
||||
# isolates mantisa of d, now the n can be also substituted by the
|
||||
# secret shared p from d in the expresion above.
|
||||
w = x.coerce(1.0 / (2 ** (vlen)))
|
||||
v = v * w
|
||||
# polynomials for the log_2 evaluation of f are calculated
|
||||
P = p_eval(p_2524, v)
|
||||
Q = p_eval(q_2524, v)
|
||||
@@ -384,7 +424,7 @@ def pow_fx(x, y):
|
||||
|
||||
:param y: (sfix, clear types) secret shared exponent.
|
||||
|
||||
:return: :math:`x^y` (sfix)
|
||||
:return: :math:`x^y` (sfix) if positive and in range
|
||||
"""
|
||||
log2_x =0
|
||||
# obtains log2(x)
|
||||
@@ -456,9 +496,6 @@ def floor_fx(x):
|
||||
def MSB(b, k):
|
||||
# calculation of z
|
||||
# x in order 0 - k
|
||||
if (k > types.program.bit_length):
|
||||
raise OverflowError("The supported bit \
|
||||
lenght of the application is smaller than k")
|
||||
|
||||
x_order = b.bit_decompose(k)
|
||||
x = [0] * k
|
||||
@@ -511,9 +548,7 @@ def norm_simplified_SQ(b, k):
|
||||
w_array[i] = z[2 * i - 1] + z[2 * i]
|
||||
|
||||
# w aggregation
|
||||
w = types.sint(0)
|
||||
for i in range(k_over_2):
|
||||
w += (2 ** i) * w_array[i]
|
||||
w = b.bit_compose(w_array)
|
||||
|
||||
# return computed values
|
||||
#return m_odd, m, w
|
||||
@@ -538,9 +573,9 @@ def sqrt_simplified_fx(x):
|
||||
# process to set up the precision and allocate correct 2**f
|
||||
if x.f % 2 == 1:
|
||||
m_odd = (1 - 2 * m_odd) + m_odd
|
||||
w = (w * 2 - w) * (1-m_odd) + w
|
||||
w = m_odd.if_else(w, 2 * w)
|
||||
# map number to use sfix format and instantiate the number
|
||||
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
|
||||
w = x._new(w << ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
|
||||
# obtains correct 2 ** (m/2)
|
||||
w = (w * (2 ** (1/2.0)) - w) * m_odd + w
|
||||
# produce x/ 2^(m/2)
|
||||
@@ -739,15 +774,15 @@ def atan(x):
|
||||
"""
|
||||
# obtain absolute value of x
|
||||
s = x < 0
|
||||
x_abs = (s * (-2) + 1) * x
|
||||
x_abs = s.if_else(-x, x)
|
||||
# angle isolation
|
||||
b = x_abs > 1
|
||||
v = 1 / x_abs
|
||||
v = (1 - b) * (x_abs - v) + v
|
||||
v = b.if_else(v, x_abs)
|
||||
v_2 =v*v
|
||||
|
||||
# range of polynomial coefficients
|
||||
assert x.k - x.f >= 15
|
||||
assert x.k - x.f >= 19
|
||||
P = p_eval(p_5102, v_2)
|
||||
Q = p_eval(q_5102, v_2)
|
||||
|
||||
@@ -756,8 +791,8 @@ def atan(x):
|
||||
y_pi_over_two = pi_over_2 - y
|
||||
|
||||
# sign correction
|
||||
y = (1 - b) * (y - y_pi_over_two) + y_pi_over_two
|
||||
y = (1 - s) * (y - (-y)) + (-y)
|
||||
y = b.if_else(y_pi_over_two, y)
|
||||
y = s.if_else(-y, y)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
@@ -98,6 +98,8 @@ class Program(object):
|
||||
self.use_dabit = options.mixed
|
||||
self._edabit = options.edabit
|
||||
self._split = False
|
||||
if options.split:
|
||||
self.use_split(int(options.split))
|
||||
self._square = False
|
||||
self._always_raw = False
|
||||
Program.prog = self
|
||||
@@ -243,10 +245,12 @@ class Program(object):
|
||||
""" The basic block that is currently being created. """
|
||||
return self.curr_tape.active_basicblock
|
||||
|
||||
def malloc(self, size, mem_type, reg_type=None):
|
||||
def malloc(self, size, mem_type, reg_type=None, creator_tape=None):
|
||||
""" Allocate memory from the top """
|
||||
if not isinstance(size, int):
|
||||
raise CompilerError('size must be known at compile time')
|
||||
if (creator_tape or self.curr_tape) != self.tapes[0]:
|
||||
raise CompilerError('cannot allocate memory outside main thread')
|
||||
if size == 0:
|
||||
return
|
||||
if isinstance(mem_type, type):
|
||||
@@ -330,7 +334,9 @@ class Program(object):
|
||||
if change is None:
|
||||
return self._split
|
||||
else:
|
||||
assert change in (2, 3)
|
||||
if change and not self.options.ring:
|
||||
raise CompilerError('splitting only supported for rings')
|
||||
assert change > 1
|
||||
self._split = change
|
||||
|
||||
def use_square(self, change=None):
|
||||
@@ -350,8 +356,12 @@ class Program(object):
|
||||
self.use_trunc_pr = True
|
||||
if 'split' in self.args or 'split3' in self.args:
|
||||
self.use_split(3)
|
||||
if 'split4' in self.args:
|
||||
self.use_split(4)
|
||||
if 'raw' in self.args:
|
||||
self.always_raw(True)
|
||||
if 'edabit' in self.args:
|
||||
self.use_edabit(True)
|
||||
|
||||
class Tape:
|
||||
""" A tape contains a list of basic blocks, onto which instructions are added. """
|
||||
@@ -559,12 +569,14 @@ class Tape:
|
||||
numrounds = merger.longest_paths_merge()
|
||||
block.n_rounds = numrounds
|
||||
block.n_to_merge = len(merger.open_nodes)
|
||||
if numrounds > 0 and self.program.verbose:
|
||||
print('Program requires %d rounds of communication' % numrounds)
|
||||
if merger.counter and self.program.verbose:
|
||||
print('Block requires', \
|
||||
', '.join('%d %s' % (y, x.__name__) \
|
||||
for x, y in list(merger.counter.items())))
|
||||
if merger.counter and self.program.verbose:
|
||||
print('Block requires %s rounds' % \
|
||||
', '.join('%d %s' % (y, x.__name__) \
|
||||
for x, y in list(merger.rounds.items())))
|
||||
# free memory
|
||||
merger = None
|
||||
if options.dead_code_elimination:
|
||||
|
||||
@@ -10,6 +10,8 @@ Registers are used for computation, allocated on an ongoing basis,
|
||||
and thread-specific. The memory is allocated statically and shared
|
||||
between threads. This means that memory-based types such as
|
||||
:py:class:`Array` can be used to transfer information between threads.
|
||||
Note that creating memory-based types outside the main thread is not
|
||||
supported.
|
||||
|
||||
If viewing this documentation in processed form, many function signatures
|
||||
appear generic because of the use of decorators. See the source code for the
|
||||
@@ -74,6 +76,7 @@ from . import instructions
|
||||
from .util import is_zero, is_one
|
||||
import operator
|
||||
from functools import reduce
|
||||
import re
|
||||
|
||||
|
||||
class ClientMessageType:
|
||||
@@ -118,17 +121,6 @@ class MPCThread(object):
|
||||
program.join_tape(self.run_handles.pop(0))
|
||||
|
||||
|
||||
def copy_doc(a, b):
|
||||
try:
|
||||
a.__doc__ = b.__doc__
|
||||
except:
|
||||
pass
|
||||
|
||||
def no_doc(operation):
|
||||
def wrapper(*args, **kwargs):
|
||||
return operation(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def copy_doc(a, b):
|
||||
try:
|
||||
a.__doc__ = b.__doc__
|
||||
@@ -316,6 +308,14 @@ class _number(object):
|
||||
res.iadd(res.value_type.conv(aa[i] * bb[i]))
|
||||
return res.read()
|
||||
|
||||
def __abs__(self):
|
||||
""" Absolute value. """
|
||||
return (self < 0).if_else(-self, self)
|
||||
|
||||
@staticmethod
|
||||
def popcnt_bits(bits):
|
||||
return sum(bits)
|
||||
|
||||
class _int(object):
|
||||
""" Integer functionality. """
|
||||
|
||||
@@ -534,11 +534,11 @@ class _register(Tape.Register, _number, _structure):
|
||||
return sum(b << i for i,b in enumerate(bits))
|
||||
|
||||
@classmethod
|
||||
def malloc(cls, size):
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
""" Allocate memory (statically).
|
||||
|
||||
:param size: compile-time (int) """
|
||||
return program.malloc(size, cls)
|
||||
return program.malloc(size, cls, creator_tape=creator_tape)
|
||||
|
||||
@classmethod
|
||||
def free(cls, addr):
|
||||
@@ -833,6 +833,20 @@ class cint(_clear, _int):
|
||||
:param other: cint/regint/int """
|
||||
return self.coerce_op(other, modc, True)
|
||||
|
||||
def less_than(self, other, bit_length):
|
||||
""" Clear comparison for particular bit length.
|
||||
|
||||
:param other: cint/regint/int
|
||||
:param bit_length: signed bit length of inputs
|
||||
:return: 0/1 (regint), undefined if inputs outside range """
|
||||
if bit_length <= 64:
|
||||
return self < other
|
||||
else:
|
||||
diff = self - other
|
||||
shifted = diff >> (bit_length - 1)
|
||||
res = regint(shifted & 1)
|
||||
return res
|
||||
|
||||
def __lt__(self, other):
|
||||
""" Clear 64-bit comparison.
|
||||
|
||||
@@ -1732,15 +1746,20 @@ class sint(_secret, _int):
|
||||
""" Secret random n-bit number according to security model.
|
||||
|
||||
:param bits: compile-time integer (int) """
|
||||
if program.use_split() == 3:
|
||||
if program.use_edabit():
|
||||
return sint.get_edabit(bits, True)[0]
|
||||
elif program.use_split() > 2:
|
||||
tmp = sint()
|
||||
randoms(tmp, bits)
|
||||
x = tmp.split_to_two_summands(bits, True)
|
||||
overflow = comparison.CarryOutLE(x[1][:-1], x[0][:-1]) + \
|
||||
sint.conv(x[0][-1])
|
||||
carry = comparison.CarryOutRawLE(x[1][:bits], x[0][:bits])
|
||||
if program.use_split() > 3:
|
||||
from .GC.types import sbitint
|
||||
x = sbitint.full_adder(carry, x[0][bits], x[1][bits])
|
||||
overflow = sint.conv(x[1]) * 2 + sint.conv(x[0])
|
||||
else:
|
||||
overflow = sint.conv(carry) + sint.conv(x[0][bits])
|
||||
return tmp - (overflow << bits)
|
||||
elif program.use_edabit():
|
||||
return sint.get_edabit(bits, True)[0]
|
||||
res = sint()
|
||||
comparison.PRandInt(res, bits)
|
||||
return res
|
||||
@@ -1791,7 +1810,7 @@ class sint(_secret, _int):
|
||||
def bit_decompose_clear(a, n_bits):
|
||||
return floatingpoint.bits(a, n_bits)
|
||||
|
||||
@classmethod
|
||||
@vectorized_classmethod
|
||||
def get_raw_input_from(cls, player):
|
||||
res = cls()
|
||||
rawinput(player, res)
|
||||
@@ -1980,7 +1999,7 @@ class sint(_secret, _int):
|
||||
|
||||
@vectorize
|
||||
@read_mem_value
|
||||
def __rshift__(self, other, bit_length=None, security=None):
|
||||
def __rshift__(self, other, bit_length=None, security=None, signed=True):
|
||||
""" Secret right shift.
|
||||
|
||||
:param other: secret or public integer (sint/cint/regint/int) """
|
||||
@@ -1990,7 +2009,7 @@ class sint(_secret, _int):
|
||||
if other == 0:
|
||||
return self
|
||||
res = sint()
|
||||
comparison.Trunc(res, self, bit_length, other, security, True)
|
||||
comparison.Trunc(res, self, bit_length, other, security, signed)
|
||||
return res
|
||||
elif isinstance(other, sint):
|
||||
return floatingpoint.Trunc(self, bit_length, other, security)
|
||||
@@ -2092,6 +2111,15 @@ class sint(_secret, _int):
|
||||
columns = self.split_to_n_summands(length, n)
|
||||
return _bitint.wallace_tree_without_finish(columns, get_carry)
|
||||
|
||||
@vectorize
|
||||
def raw_right_shift(self, length):
|
||||
res = sint()
|
||||
shrsi(res, self, length)
|
||||
return res
|
||||
|
||||
def raw_mod2m(self, m):
|
||||
return self - (self.raw_right_shift(m) << m)
|
||||
|
||||
@vectorize
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
@@ -2304,6 +2332,14 @@ class _bitint(object):
|
||||
b.pop(0)
|
||||
else:
|
||||
break
|
||||
carries = cls.get_carries(a, b, fewer_inv=fewer_inv, carry_in=carry_in)
|
||||
res = lower + cls.sum_from_carries(a, b, carries)
|
||||
if get_carry:
|
||||
res += [carries[-1]]
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_carries(cls, a, b, fewer_inv=False, carry_in=0):
|
||||
d = [cls.half_adder(ai, bi) for (ai,bi) in zip(a,b)]
|
||||
carry = floatingpoint.carry
|
||||
if fewer_inv:
|
||||
@@ -2314,10 +2350,7 @@ class _bitint(object):
|
||||
carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1]
|
||||
else:
|
||||
carries = []
|
||||
res = lower + cls.sum_from_carries(a, b, carries)
|
||||
if get_carry:
|
||||
res += [carries[-1]]
|
||||
return res
|
||||
return carries
|
||||
|
||||
@staticmethod
|
||||
def sum_from_carries(a, b, carries):
|
||||
@@ -2469,6 +2502,18 @@ class _bitint(object):
|
||||
def wallace_tree(cls, rows):
|
||||
return cls.wallace_tree_from_columns([list(x) for x in zip(*rows)])
|
||||
|
||||
@classmethod
|
||||
def wallace_reduction(cls, a, b, c, get_carry=True):
|
||||
assert len(a) == len(b) == len(c)
|
||||
tmp = zip(*(cls.full_adder(*x) for x in zip(a, b, c)))
|
||||
sums, carries = (list(x) for x in tmp)
|
||||
carries = [0] + carries
|
||||
if get_carry:
|
||||
sums += [0]
|
||||
else:
|
||||
del carries[-1]
|
||||
return sums, carries
|
||||
|
||||
def __sub__(self, other):
|
||||
if type(other) == sgf2n:
|
||||
raise CompilerError('Unclear subtraction')
|
||||
@@ -2497,7 +2542,7 @@ class _bitint(object):
|
||||
def __rshift__(self, other):
|
||||
return self.compose(self.bit_decompose()[other:])
|
||||
|
||||
def bit_decompose(self, n_bits=None, *args):
|
||||
def bit_decompose(self, n_bits=None, security=None):
|
||||
if self.bits is None:
|
||||
self.bits = self.force_bit_decompose(self.n_bits)
|
||||
if n_bits is None:
|
||||
@@ -2541,14 +2586,16 @@ class _bitint(object):
|
||||
def __gt__(self, other):
|
||||
return 1 - (self <= other)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other, bit_length=None, security=None):
|
||||
diff = self ^ other
|
||||
diff_bits = [1 - x for x in diff.bit_decompose()]
|
||||
diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]]
|
||||
return floatingpoint.KMul(diff_bits)
|
||||
|
||||
def __ne__(self, other):
|
||||
return 1 - (self == other)
|
||||
|
||||
equal = __eq__
|
||||
|
||||
def __neg__(self):
|
||||
return 1 + self.compose(1 ^ b for b in self.bit_decompose())
|
||||
|
||||
@@ -2752,9 +2799,9 @@ def parse_type(other, k=None, f=None):
|
||||
|
||||
class cfix(_number, _structure):
|
||||
""" Clear fixed-point number represented as clear integer. """
|
||||
__slots__ = ['value', 'f', 'k', 'size']
|
||||
__slots__ = ['value', 'f', 'k']
|
||||
reg_type = 'c'
|
||||
scalars = (int, float, regint)
|
||||
scalars = (int, float, regint, cint)
|
||||
@classmethod
|
||||
def set_precision(cls, f, k = None):
|
||||
""" Set the precision of the integer representation. Note that some
|
||||
@@ -2779,7 +2826,7 @@ class cfix(_number, _structure):
|
||||
@vectorized_classmethod
|
||||
def load_mem(cls, address, mem_type=None):
|
||||
""" Load from memory by public address. """
|
||||
return cls(cint.load_mem(address))
|
||||
return cls._new(cint.load_mem(address))
|
||||
|
||||
@vectorized_classmethod
|
||||
def read_from_socket(cls, client_id, n=1):
|
||||
@@ -2787,7 +2834,7 @@ class cfix(_number, _structure):
|
||||
Sender will have already bit shifted and sent as cints."""
|
||||
cint_input = cint.read_from_socket(client_id, n)
|
||||
if n == 1:
|
||||
return cfix(cint_inputs)
|
||||
return cfix._new(cint_inputs)
|
||||
else:
|
||||
return list(map(cfix, cint_inputs))
|
||||
|
||||
@@ -2805,34 +2852,47 @@ class cfix(_number, _structure):
|
||||
writesocketc(client_id, message_type, *cint_values)
|
||||
|
||||
@staticmethod
|
||||
def malloc(size):
|
||||
return program.malloc(size, cint)
|
||||
def malloc(size, creator_tape=None):
|
||||
return program.malloc(size, cint, creator_tape=creator_tape)
|
||||
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, other):
|
||||
res = cls()
|
||||
res.load_int(other)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def _new(cls, other, k=None, f=None):
|
||||
res = cls(k=k, f=f)
|
||||
res.v = cint.conv(other)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def int_rep(v, f):
|
||||
v = v * (2 ** f)
|
||||
try:
|
||||
v = int(round(v))
|
||||
except TypeError:
|
||||
pass
|
||||
return v
|
||||
|
||||
@vectorize_init
|
||||
@read_mem_value
|
||||
def __init__(self, v=None, k=None, f=None, size=None):
|
||||
""" :param v: cfix/float/int """
|
||||
f = self.f if f is None else f
|
||||
k = self.k if k is None else k
|
||||
self.f = f
|
||||
self.k = k
|
||||
self.size = get_global_vector_size()
|
||||
if isinstance(v, cint):
|
||||
self.v = cint(v,size=self.size)
|
||||
elif isinstance(v, cfix.scalars):
|
||||
v = v * (2 ** f)
|
||||
try:
|
||||
v = int(round(v))
|
||||
except TypeError:
|
||||
pass
|
||||
self.v = cint(v, size=self.size)
|
||||
if isinstance(v, cfix.scalars):
|
||||
v = self.int_rep(v, f)
|
||||
self.v = cint(v, size=size)
|
||||
elif isinstance(v, cfix):
|
||||
self.v = v.v
|
||||
elif isinstance(v, MemValue):
|
||||
self.v = v
|
||||
elif v is None:
|
||||
self.v = cint(0)
|
||||
else:
|
||||
@@ -2840,7 +2900,10 @@ class cfix(_number, _structure):
|
||||
|
||||
def __iter__(self):
|
||||
for x in self.v:
|
||||
yield type(self)(x, self.k, self.f)
|
||||
yield self._new(x, self.k, self.f)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.v)
|
||||
|
||||
@vectorize
|
||||
def load_int(self, v):
|
||||
@@ -2863,6 +2926,10 @@ class cfix(_number, _structure):
|
||||
""" Store in memory by public address. """
|
||||
self.v.store_in_mem(address)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return self.v.size
|
||||
|
||||
def sizeof(self):
|
||||
return self.size * 4
|
||||
|
||||
@@ -2873,7 +2940,7 @@ class cfix(_number, _structure):
|
||||
:param other: cfix/cint/regint/int """
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return cfix(self.v + other.v)
|
||||
return cfix._new(self.v + other.v)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
@@ -2884,18 +2951,26 @@ class cfix(_number, _structure):
|
||||
:param other: cfix/cint/regint/int/sint """
|
||||
if isinstance(other, sint):
|
||||
return sfix._new(self.v * other, k=self.k, f=self.f)
|
||||
if isinstance(other, (int, regint, cint)):
|
||||
return cfix._new(self.v * cint(other), k=self.k, f=self.f)
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
assert self.f == other.f
|
||||
sgn = cint(1 - 2 * (self.v * other.v < 0))
|
||||
sgn = cint(1 - 2 * ((self < 0) ^ (other < 0)))
|
||||
absolute = self.v * other.v * sgn
|
||||
val = sgn * (absolute >> self.f)
|
||||
return cfix(val)
|
||||
return cfix._new(val)
|
||||
elif isinstance(other, sfix):
|
||||
return NotImplemented
|
||||
else:
|
||||
raise CompilerError('Invalid type %s for cfix.__mul__' % type(other))
|
||||
|
||||
|
||||
def positive_mul(self, other):
|
||||
assert isinstance(other, float)
|
||||
assert other >= 0
|
||||
v = self.v * int(round(other * 2 ** self.f))
|
||||
return self._new(v >> self.f, k=self.k, f=self.f)
|
||||
|
||||
@vectorize
|
||||
def __sub__(self, other):
|
||||
""" Clear fixed-point subtraction.
|
||||
@@ -2903,9 +2978,9 @@ class cfix(_number, _structure):
|
||||
:param other: cfix/cint/regint/int """
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return cfix(self.v - other.v)
|
||||
return cfix._new(self.v - other.v)
|
||||
elif isinstance(other, sfix):
|
||||
return sfix(self.v - other.v)
|
||||
return sfix._new(self.v - other.v)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2913,7 +2988,7 @@ class cfix(_number, _structure):
|
||||
def __neg__(self):
|
||||
""" Clear fixed-point negation. """
|
||||
# cfix type always has .v
|
||||
return cfix(-self.v)
|
||||
return cfix._new(-self.v)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return -self + other
|
||||
@@ -2939,7 +3014,8 @@ class cfix(_number, _structure):
|
||||
""" Clear fixed-point comparison. """
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return self.v < other.v
|
||||
assert self.k == other.k
|
||||
return self.v.less_than(other.v, self.k)
|
||||
elif isinstance(other, sfix):
|
||||
if(self.k != other.k or self.f != other.f):
|
||||
raise TypeError('Incompatible fixed point types in comparison')
|
||||
@@ -2952,7 +3028,7 @@ class cfix(_number, _structure):
|
||||
""" Clear fixed-point comparison. """
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return self.v <= other.v
|
||||
return 1 - (self > other)
|
||||
elif isinstance(other, sfix):
|
||||
return other.v.greater_equal(self.v, self.k, other.kappa)
|
||||
else:
|
||||
@@ -2963,7 +3039,7 @@ class cfix(_number, _structure):
|
||||
""" Clear fixed-point comparison. """
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return self.v > other.v
|
||||
return other.__lt__(self)
|
||||
elif isinstance(other, sfix):
|
||||
return other.v.less_than(self.v, self.k, other.kappa)
|
||||
else:
|
||||
@@ -2974,7 +3050,7 @@ class cfix(_number, _structure):
|
||||
""" Clear fixed-point comparison. """
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return self.v >= other.v
|
||||
return 1 - (self < other)
|
||||
elif isinstance(other, sfix):
|
||||
return other.v.less_equal(self.v, self.k, other.kappa)
|
||||
else:
|
||||
@@ -3000,9 +3076,10 @@ class cfix(_number, _structure):
|
||||
""" Clear fixed-point division.
|
||||
|
||||
:param other: cfix/cint/regint/int """
|
||||
other = parse_type(other)
|
||||
other = parse_type(other, self.k, self.f)
|
||||
if isinstance(other, cfix):
|
||||
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
|
||||
return cfix._new(library.cint_cint_division(
|
||||
self.v, other.v, self.k, self.f), k=self.k, f=self.f)
|
||||
elif isinstance(other, sfix):
|
||||
assert self.k == other.k
|
||||
assert self.f == other.f
|
||||
@@ -3016,11 +3093,11 @@ class cfix(_number, _structure):
|
||||
def print_plain(self):
|
||||
""" Clear fixed-point output. """
|
||||
if self.k > 64:
|
||||
raise CompilerError('Printing of fixed-point numbers not ' +
|
||||
'implemented for more than 64-bit precision')
|
||||
tmp = regint()
|
||||
convmodp(tmp, self.v, bitlength=self.k)
|
||||
sign = cint(tmp < 0)
|
||||
sign = (((self.v + (1 << (self.k - 1))) >> self.k) & 1)
|
||||
else:
|
||||
tmp = regint()
|
||||
convmodp(tmp, self.v, bitlength=self.k)
|
||||
sign = cint(tmp < 0)
|
||||
abs_v = sign.if_else(-self.v, self.v)
|
||||
print_float_plain(cint(abs_v), cint(-self.f), \
|
||||
cint(0), cint(sign), cint(0))
|
||||
@@ -3064,8 +3141,8 @@ class _single(_number, _structure):
|
||||
return cls.conv(other)
|
||||
|
||||
@classmethod
|
||||
def malloc(cls, size):
|
||||
return program.malloc(size, cls.int_type)
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
return program.malloc(size, cls.int_type, creator_tape=creator_tape)
|
||||
|
||||
@classmethod
|
||||
def free(cls, addr):
|
||||
@@ -3192,7 +3269,7 @@ class _single(_number, _structure):
|
||||
|
||||
class _fix(_single):
|
||||
""" Secret fixed point type. """
|
||||
__slots__ = ['v', 'f', 'k', 'size']
|
||||
__slots__ = ['v', 'f', 'k']
|
||||
|
||||
def set_precision(cls, f, k = None):
|
||||
cls.f = f
|
||||
@@ -3200,12 +3277,28 @@ class _fix(_single):
|
||||
if k is None:
|
||||
cls.k = 2 * f
|
||||
else:
|
||||
if k < f:
|
||||
raise CompilerError('bit length cannot be less than precision')
|
||||
cls.k = k
|
||||
set_precision.__doc__ = cfix.set_precision.__doc__
|
||||
set_precision = classmethod(set_precision)
|
||||
|
||||
@classmethod
|
||||
def set_precision_from_args(cls, program):
|
||||
f = None
|
||||
k = None
|
||||
for arg in program.args:
|
||||
m = re.match('f([0-9]+)$', arg)
|
||||
if m:
|
||||
f = int(m.group(1))
|
||||
m = re.match('k([0-9]+)$', arg)
|
||||
if m:
|
||||
k = int(m.group(1))
|
||||
if f is not None:
|
||||
print ('Setting fixed-point precision to %d/%s' % (f, k))
|
||||
cls.set_precision(f, k)
|
||||
cfix.set_precision(f, k)
|
||||
elif k is not None:
|
||||
raise CompilerError('need to set fractional precision')
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
if isinstance(other, (_fix, cls.clear_type)):
|
||||
@@ -3224,13 +3317,13 @@ class _fix(_single):
|
||||
|
||||
@classmethod
|
||||
def _new(cls, other, k=None, f=None):
|
||||
res = cls(other, k=k, f=f)
|
||||
res = cls(k=k, f=f)
|
||||
res.v = cls.int_type.conv(other)
|
||||
return res
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, _v=None, k=None, f=None, size=None):
|
||||
""" :params _v: compile-time value (int/float) """
|
||||
self.size = get_global_vector_size()
|
||||
""" :params _v: int/float/regint/cint/sint/sfloat """
|
||||
if k is None:
|
||||
k = self.k
|
||||
else:
|
||||
@@ -3241,15 +3334,12 @@ class _fix(_single):
|
||||
self.f = f
|
||||
assert k is not None
|
||||
assert f is not None
|
||||
# warning: don't initialize a sfix from a sint, this is only used in internal methods;
|
||||
# for external initialization use load_int.
|
||||
if _v is None:
|
||||
self.v = self.int_type(0)
|
||||
elif isinstance(_v, self.int_type):
|
||||
self.v = _v
|
||||
self.size = _v.size
|
||||
self.load_int(_v)
|
||||
elif isinstance(_v, cfix.scalars):
|
||||
self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size)
|
||||
self.v = self.int_type(cfix.int_rep(_v, f=f), size=size)
|
||||
elif isinstance(_v, self.float_type):
|
||||
p = (f + _v.p)
|
||||
b = (p.greater_equal(0, _v.vlen))
|
||||
@@ -3265,7 +3355,6 @@ class _fix(_single):
|
||||
if not isinstance(self.v, self.int_type):
|
||||
raise CompilerError('sfix conversion failure: %s/%s' % (_v, self.v))
|
||||
|
||||
@vectorize
|
||||
def load_int(self, v):
|
||||
self.v = self.int_type(v) << self.f
|
||||
|
||||
@@ -3365,7 +3454,7 @@ class _fix(_single):
|
||||
class revealed_fix(self.clear_type):
|
||||
f = self.f
|
||||
k = self.k
|
||||
return revealed_fix(val)
|
||||
return revealed_fix._new(val)
|
||||
|
||||
class sfix(_fix):
|
||||
""" Secret fixed-point number represented as secret integer.
|
||||
@@ -3410,6 +3499,24 @@ class sfix(_fix):
|
||||
res = res.reduce_after_mul()
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def dot_product(cls, x, y, res_params=None):
|
||||
""" Secret dot product.
|
||||
|
||||
:param x: iterable of appropriate secret type
|
||||
:param y: iterable of appropriate secret type and same length """
|
||||
x, y = list(x), list(y)
|
||||
if res_params is None:
|
||||
if isinstance(x[0], cls.int_type):
|
||||
x, y = y, x
|
||||
if isinstance(y[0], cls.int_type):
|
||||
return cls._new(cls.int_type.dot_product((xx.v for xx in x), y),
|
||||
k=x[0].k, f=x[0].f)
|
||||
return super().dot_product(x, y, res_params)
|
||||
|
||||
def expand_to_vector(self, size):
|
||||
return self._new(self.v.expand_to_vector(size), k=self.k, f=self.f)
|
||||
|
||||
def coerce(self, other):
|
||||
return parse_type(other, k=self.k, f=self.f)
|
||||
|
||||
@@ -3426,7 +3533,7 @@ class sfix(_fix):
|
||||
|
||||
@staticmethod
|
||||
def multipliable(v, k, f, size):
|
||||
return cfix(cint.conv(v, size=size), k, f)
|
||||
return cfix._new(cint.conv(v, size=size), k, f)
|
||||
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
@@ -3436,8 +3543,8 @@ class sfix(_fix):
|
||||
:param player: public integer (int/regint/cint)
|
||||
:returns: value to be used with :py:func:`Compiler.library.print_ln_to`
|
||||
"""
|
||||
return personal(player, cfix(self.v.reveal_to(player)._v,
|
||||
self.k, self.f))
|
||||
return personal(player, cfix._new(self.v.reveal_to(player)._v,
|
||||
self.k, self.f))
|
||||
|
||||
class unreduced_sfix(_single):
|
||||
int_type = sint
|
||||
@@ -3466,10 +3573,9 @@ class unreduced_sfix(_single):
|
||||
|
||||
@vectorize
|
||||
def reduce_after_mul(self):
|
||||
return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa,
|
||||
nearest=sfix.round_nearest,
|
||||
signed=True),
|
||||
k=self.k // 2, f=self.m)
|
||||
v = sfix.int_type.round(self.v, self.k, self.m, self.kappa,
|
||||
nearest=sfix.round_nearest, signed=True)
|
||||
return sfix._new(v, k=self.k // 2, f=self.m)
|
||||
|
||||
sfix.unreduced_type = unreduced_sfix
|
||||
|
||||
@@ -3696,8 +3802,9 @@ class sfloat(_number, _structure):
|
||||
return 4
|
||||
|
||||
@classmethod
|
||||
def malloc(cls, size):
|
||||
return program.malloc(size * cls.n_elements(), sint)
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
return program.malloc(size * cls.n_elements(), sint,
|
||||
creator_tape=creator_tape)
|
||||
|
||||
@classmethod
|
||||
def is_address_tuple(cls, address):
|
||||
@@ -4142,12 +4249,14 @@ class Array(object):
|
||||
self.address = address
|
||||
self.address_cache = {}
|
||||
self.debug = debug
|
||||
self.creator_tape = program.curr_tape
|
||||
if alloc:
|
||||
self.alloc()
|
||||
|
||||
def alloc(self):
|
||||
if self.address is None:
|
||||
self.address = self.value_type.malloc(self.length)
|
||||
self.address = self.value_type.malloc(self.length,
|
||||
self.creator_tape)
|
||||
|
||||
def delete(self):
|
||||
self.value_type.free(self.address)
|
||||
@@ -4371,6 +4480,10 @@ class Array(object):
|
||||
|
||||
reveal_nested = reveal_list
|
||||
|
||||
def __str__(self):
|
||||
return '%s array of length %s at %s' % (self.value_type, len(self),
|
||||
self.address)
|
||||
|
||||
sint.dynamic_array = Array
|
||||
sgf2n.dynamic_array = Array
|
||||
|
||||
@@ -4471,6 +4584,12 @@ class SubMultiArray(object):
|
||||
return self.value_type.load_mem(self.address + base * part_size,
|
||||
size=size)
|
||||
|
||||
def assign_part_vector(self, vector, base=0):
|
||||
assert self.value_type.n_elements() == 1
|
||||
part_size = reduce(operator.mul, self.sizes[1:])
|
||||
assert vector.size <= self.total_size()
|
||||
vector.store_in_mem(self.address + base * part_size)
|
||||
|
||||
def get_addresses(self, *indices):
|
||||
assert self.value_type.n_elements() == 1
|
||||
assert len(indices) == len(self.sizes)
|
||||
@@ -4592,13 +4711,16 @@ class SubMultiArray(object):
|
||||
t = self.value_type
|
||||
res_matrix = Matrix(self.sizes[0], other.sizes[1], t)
|
||||
try:
|
||||
if max(res_matrix.sizes) > 1000:
|
||||
raise AttributeError()
|
||||
A = self.get_vector()
|
||||
B = other.get_vector()
|
||||
res_matrix.assign_vector(
|
||||
self.value_type.matrix_mul(A, B, self.sizes[1],
|
||||
res_params))
|
||||
try:
|
||||
res_matrix.assign_vector(self.direct_mul(other))
|
||||
except AttributeError:
|
||||
if max(res_matrix.sizes) > 1000:
|
||||
raise AttributeError()
|
||||
A = self.get_vector()
|
||||
B = other.get_vector()
|
||||
res_matrix.assign_vector(
|
||||
self.value_type.matrix_mul(A, B, self.sizes[1],
|
||||
res_params))
|
||||
except (AttributeError, AssertionError):
|
||||
# fallback for sfloat etc.
|
||||
@library.for_range_opt(self.sizes[0])
|
||||
@@ -4644,7 +4766,60 @@ class SubMultiArray(object):
|
||||
self.sizes[0], *other.sizes,
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
def direct_mul_trans(self, other, reduce=True, indices=None):
|
||||
"""
|
||||
Matrix multiplication with the transpose of :py:obj:`other`
|
||||
in the virtual machine.
|
||||
|
||||
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication)
|
||||
:return: Matrix as vector of relevant type (row-major)
|
||||
|
||||
"""
|
||||
assert len(self.sizes) == 2
|
||||
assert len(other.sizes) == 2
|
||||
if indices is None:
|
||||
assert self.sizes[1] == other.sizes[1]
|
||||
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
|
||||
assert len(indices[1]) == len(indices[2])
|
||||
indices = list(indices)
|
||||
indices[3] *= other.sizes[0]
|
||||
return self.value_type.direct_matrix_mul(
|
||||
self.address, other.address, None, self.sizes[1], 1,
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
def direct_trans_mul(self, other, reduce=True, indices=None):
|
||||
"""
|
||||
Matrix multiplication with the transpose of :py:obj:`self`
|
||||
in the virtual machine.
|
||||
|
||||
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication)
|
||||
:return: Matrix as vector of relevant type (row-major)
|
||||
|
||||
"""
|
||||
assert len(self.sizes) == 2
|
||||
assert len(other.sizes) == 2
|
||||
if indices is None:
|
||||
assert self.sizes[0] == other.sizes[0]
|
||||
indices = [regint.inc(i) for i in self.sizes[::-1] + other.sizes]
|
||||
assert len(indices[1]) == len(indices[2])
|
||||
indices = list(indices)
|
||||
indices[1] *= self.sizes[1]
|
||||
return self.value_type.direct_matrix_mul(
|
||||
self.address, other.address, None, 1, other.sizes[1],
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
def direct_mul_to_matrix(self, other):
|
||||
""" Matrix multiplication in the virtual machine.
|
||||
|
||||
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:returns: :py:obj:`Matrix`
|
||||
|
||||
"""
|
||||
res = self.value_type.Matrix(self.sizes[0], other.sizes[1])
|
||||
res.assign_vector(self.direct_mul(other))
|
||||
return res
|
||||
@@ -4756,6 +4931,10 @@ class SubMultiArray(object):
|
||||
return [f(sizes[1:]) for i in range(sizes[0])]
|
||||
return f(self.sizes)
|
||||
|
||||
def __str__(self):
|
||||
return '%s multi-array of lengths %s at %s' % (self.value_type,
|
||||
self.sizes, self.address)
|
||||
|
||||
class MultiArray(SubMultiArray):
|
||||
""" Multidimensional array. """
|
||||
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
|
||||
@@ -4990,25 +5169,15 @@ class MemValue(_mem):
|
||||
return 'MemValue(%s,%d)' % (self.value_type, self.address)
|
||||
|
||||
|
||||
class MemFloat(_mem):
|
||||
class MemFloat(MemValue):
|
||||
def __init__(self, *args):
|
||||
value = sfloat(*args)
|
||||
self.v = MemValue(value.v)
|
||||
self.p = MemValue(value.p)
|
||||
self.z = MemValue(value.z)
|
||||
self.s = MemValue(value.s)
|
||||
super().__init__(sfloat(*args))
|
||||
|
||||
def write(self, *args):
|
||||
value = sfloat(*args)
|
||||
self.v.write(value.v)
|
||||
self.p.write(value.p)
|
||||
self.z.write(value.z)
|
||||
self.s.write(value.s)
|
||||
super().write(value)
|
||||
|
||||
def read(self):
|
||||
return sfloat(self.v, self.p, self.z, self.s)
|
||||
|
||||
class MemFix(_mem):
|
||||
class MemFix(MemValue):
|
||||
def __init__(self, *args):
|
||||
arg_type = type(*args)
|
||||
if arg_type == sfix:
|
||||
@@ -5017,22 +5186,10 @@ class MemFix(_mem):
|
||||
value = cfix(*args)
|
||||
else:
|
||||
raise CompilerError('MemFix init argument error')
|
||||
self.reg_type = value.v.reg_type
|
||||
self.v = MemValue(value.v)
|
||||
super().__init__(value)
|
||||
|
||||
def write(self, *args):
|
||||
value = sfix(*args)
|
||||
self.v.write(value.v)
|
||||
|
||||
def reveal(self):
|
||||
return cfix(self.v.reveal())
|
||||
|
||||
def read(self):
|
||||
val = self.v.read()
|
||||
if isinstance(val, sint):
|
||||
return sfix(val)
|
||||
else:
|
||||
return cfix(val)
|
||||
super().write(self.value_type(*args))
|
||||
|
||||
def getNamedTupleType(*names):
|
||||
class NamedTuple(object):
|
||||
|
||||
@@ -174,8 +174,17 @@ def is_all_ones(x, n):
|
||||
else:
|
||||
return False
|
||||
|
||||
def max(x, y):
|
||||
return if_else(x > y, x, y)
|
||||
def max(x, y=None):
|
||||
if y is None:
|
||||
return tree_reduce(max, x)
|
||||
else:
|
||||
return if_else(x > y, x, y)
|
||||
|
||||
def min(x, y=None):
|
||||
if y is None:
|
||||
return tree_reduce(min, x)
|
||||
else:
|
||||
return if_else(x < y, x, y)
|
||||
|
||||
def long_one(x):
|
||||
try:
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ECDSA/P256Element.h"
|
||||
#include "Tools/mkpath.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
@@ -90,8 +90,9 @@ class Offline_Check_Error: public runtime_error
|
||||
runtime_error("Offline-Check-Error : " + m) {}
|
||||
};
|
||||
class mac_fail: public bad_value
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "MacCheck Failure"; }
|
||||
{
|
||||
public:
|
||||
mac_fail(string msg = "MacCheck Failure") : bad_value(msg) {}
|
||||
};
|
||||
class consistency_check_fail: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
|
||||
@@ -64,7 +64,7 @@ void BitAdder::add(vector<vector<T> >& res,
|
||||
|
||||
int n_bits = summands.size();
|
||||
for (size_t i = begin; i < end; i++)
|
||||
res[i].resize(n_bits + 1);
|
||||
res.at(i).resize(n_bits + 1);
|
||||
|
||||
size_t n_items = end - begin;
|
||||
|
||||
|
||||
@@ -8,11 +8,12 @@
|
||||
#include "GC/square64.h"
|
||||
|
||||
#include "GC/Processor.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
SwitchableOutput FakeSecret::out;
|
||||
const int FakeSecret::default_length;
|
||||
|
||||
void FakeSecret::load_clear(int n, const Integer& x)
|
||||
@@ -87,6 +88,14 @@ FakeSecret FakeSecret::input(int from, word input, int n_bits)
|
||||
return input;
|
||||
}
|
||||
|
||||
void FakeSecret::inputbvec(Processor<FakeSecret>& processor,
|
||||
ProcessorBase& input_processor, const vector<int>& args)
|
||||
{
|
||||
Input input;
|
||||
input.reset_all(*ShareThread<FakeSecret>::s().P);
|
||||
processor.inputbvec(input, input_processor, args, 0);
|
||||
}
|
||||
|
||||
void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
|
||||
bool repeat)
|
||||
{
|
||||
@@ -96,4 +105,19 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
|
||||
*this = BitVec(x & y).mask(n);
|
||||
}
|
||||
|
||||
void FakeSecret::my_input(Input& inputter, BitVec value, int n_bits)
|
||||
{
|
||||
inputter.add_mine(value, n_bits);
|
||||
}
|
||||
|
||||
void FakeSecret::other_input(Input&, int, int)
|
||||
{
|
||||
throw runtime_error("emulation is supposed to be lonely");
|
||||
}
|
||||
|
||||
void FakeSecret::finalize_input(Input& inputter, int from, int n_bits)
|
||||
{
|
||||
*this = inputter.finalize(from, n_bits);
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
#include <random>
|
||||
#include <fstream>
|
||||
|
||||
class ProcessorBase;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
@@ -53,7 +55,10 @@ public:
|
||||
typedef FakeProtocol<FakeSecret> Protocol;
|
||||
typedef FakeInput<FakeSecret> Input;
|
||||
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
static string type_string() { return "fake secret"; }
|
||||
static string type_short() { return "emulB"; }
|
||||
static string phase_name() { return "Faking"; }
|
||||
|
||||
static const int default_length = 64;
|
||||
@@ -62,7 +67,8 @@ public:
|
||||
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
static SwitchableOutput out;
|
||||
static const true_type invertible;
|
||||
static const true_type characteristic_two;
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2; }
|
||||
|
||||
@@ -87,8 +93,8 @@ public:
|
||||
template <class T>
|
||||
static void inputb(T& processor, ArithmeticProcessor&, const vector<int>& args)
|
||||
{ processor.input(args); }
|
||||
template <class T, class U>
|
||||
static void inputbvec(T&, U&, const vector<int>&) { throw not_implemented(); }
|
||||
static void inputbvec(Processor<FakeSecret>& processor,
|
||||
ProcessorBase& input_processor, const vector<int>& args);
|
||||
template <class T>
|
||||
static void reveal_inst(T& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
@@ -136,6 +142,14 @@ public:
|
||||
void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; }
|
||||
|
||||
void invert(FakeSecret) { throw not_implemented(); }
|
||||
|
||||
void input(istream&, bool) { throw not_implemented(); }
|
||||
|
||||
bool operator<(FakeSecret) const { return false; }
|
||||
|
||||
void my_input(Input& inputter, BitVec value, int n_bits);
|
||||
void other_input(Input& inputter, int from, int n_bits = 1);
|
||||
void finalize_input(Input& inputter, int from, int n_bits);
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -47,13 +47,6 @@ template<class T>
|
||||
void Machine<T>::load_schedule(string progname)
|
||||
{
|
||||
BaseMachine::load_schedule(progname);
|
||||
for (auto i : {1, 0, 0})
|
||||
{
|
||||
int n;
|
||||
inpf >> n;
|
||||
if (n != i)
|
||||
throw runtime_error("old schedule format not supported");
|
||||
}
|
||||
print_compiler();
|
||||
}
|
||||
|
||||
|
||||
20
GC/NoShare.h
20
GC/NoShare.h
@@ -7,6 +7,7 @@
|
||||
#define GC_NOSHARE_H_
|
||||
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "BMR/Register.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
|
||||
class InputArgs;
|
||||
@@ -41,6 +42,11 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
static string type_string()
|
||||
{
|
||||
return "no";
|
||||
}
|
||||
|
||||
static void fail()
|
||||
{
|
||||
throw runtime_error("VM does not support binary circuits");
|
||||
@@ -93,8 +99,6 @@ public:
|
||||
static const bool expensive_triples = false;
|
||||
static const bool is_real = false;
|
||||
|
||||
static SwitchableOutput out;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
{
|
||||
return new MC;
|
||||
@@ -130,7 +134,7 @@ public:
|
||||
NoValue::fail();
|
||||
}
|
||||
|
||||
static void inputb(Processor<NoShare>&, ArithmeticProcessor&, const vector<int>&) { fail(); }
|
||||
static void inputb(Processor<NoShare>&, const ArithmeticProcessor&, const vector<int>&) { fail(); }
|
||||
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void xors(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
@@ -139,6 +143,10 @@ public:
|
||||
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static void xors(Processor<NoShare>&, vector<int>) { fail(); }
|
||||
static void ands(Processor<NoShare>&, vector<int>) { fail(); }
|
||||
static void andrs(Processor<NoShare>&, vector<int>) { fail(); }
|
||||
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
|
||||
|
||||
NoShare() {}
|
||||
@@ -161,8 +169,8 @@ public:
|
||||
void operator^=(NoShare) { fail(); }
|
||||
|
||||
NoShare operator+(const NoShare&) const { fail(); return {}; }
|
||||
NoShare operator-(NoShare) const { fail(); return 0; }
|
||||
NoShare operator*(NoValue) const { fail(); return 0; }
|
||||
NoShare operator-(const NoShare&) const { fail(); return {}; }
|
||||
NoShare operator*(const NoValue&) const { fail(); return {}; }
|
||||
|
||||
NoShare operator+(int) const { fail(); return {}; }
|
||||
NoShare operator&(int) const { fail(); return {}; }
|
||||
@@ -172,6 +180,8 @@ public:
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
|
||||
void invert(int, NoShare) { fail(); }
|
||||
|
||||
void input(istream&, bool) { fail(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -44,6 +44,8 @@ public:
|
||||
|
||||
Timer xor_timer;
|
||||
|
||||
typename T::out_type out;
|
||||
|
||||
Processor(Machine<T>& machine);
|
||||
Processor(Memories<T>& memories, Machine<T>* machine = 0);
|
||||
~Processor();
|
||||
|
||||
@@ -301,15 +301,15 @@ void Processor<T>::print_reg(int reg, int n, int size)
|
||||
bigint output;
|
||||
for (int i = 0; i < size; i++)
|
||||
output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i);
|
||||
T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
|
||||
out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
|
||||
print_str(n);
|
||||
T::out << endl << flush;
|
||||
out << endl << flush;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_reg_plain(Clear& value)
|
||||
{
|
||||
T::out << hex << showbase << value << dec << flush;
|
||||
out << hex << showbase << value << dec << flush;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -323,7 +323,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
|
||||
n_shift = sizeof(value.get()) * 8 - n_bits;
|
||||
if (n_shift > 63)
|
||||
n_shift = 0;
|
||||
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
|
||||
out << dec << (value.get() << n_shift >> n_shift) << flush;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -334,26 +334,26 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
|
||||
}
|
||||
if (tmp >= bigint(1) << (n_bits - 1))
|
||||
tmp -= bigint(1) << n_bits;
|
||||
T::out << dec << tmp << flush;
|
||||
out << dec << tmp << flush;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_chr(int n)
|
||||
{
|
||||
T::out << (char)n << flush;
|
||||
out << (char)n << flush;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_str(int n)
|
||||
{
|
||||
T::out << string((char*)&n,sizeof(n)) << flush;
|
||||
out << string((char*)&n,sizeof(n)) << flush;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_float(const vector<int>& args)
|
||||
{
|
||||
bigint::output_float(T::out,
|
||||
bigint::output_float(out,
|
||||
bigint::get_float(C[args[0]], C[args[1]], C[args[2]], C[args[3]]),
|
||||
C[args[4]]);
|
||||
}
|
||||
@@ -361,7 +361,7 @@ void Processor<T>::print_float(const vector<int>& args)
|
||||
template <class T>
|
||||
void Processor<T>::print_float_prec(int n)
|
||||
{
|
||||
T::out << setprecision(n);
|
||||
out << setprecision(n);
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
26
GC/Rep4Secret.cpp
Normal file
26
GC/Rep4Secret.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Rep4Secret.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_REP4SECRET_CPP_
|
||||
#define GC_REP4SECRET_CPP_
|
||||
|
||||
#include "Rep4Secret.h"
|
||||
|
||||
#include "ShareSecret.hpp"
|
||||
#include "ShareThread.hpp"
|
||||
#include "Protocols/Rep4MC.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
void Rep4Secret::load_clear(int n, const Integer& x)
|
||||
{
|
||||
this->check_length(n, x);
|
||||
*this = constant(x, ShareThread<This>::s().P->my_num());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_REP4SECRET_CPP_ */
|
||||
53
GC/Rep4Secret.h
Normal file
53
GC/Rep4Secret.h
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Rep4Secret.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_REP4SECRET_H_
|
||||
#define GC_REP4SECRET_H_
|
||||
|
||||
#include "ShareSecret.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
#include "Protocols/Rep4MC.h"
|
||||
#include "Protocols/Rep4Share.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class Rep4Secret : public RepSecretBase<Rep4Secret, 3>
|
||||
{
|
||||
typedef RepSecretBase<Rep4Secret, 3> super;
|
||||
typedef Rep4Secret This;
|
||||
|
||||
public:
|
||||
typedef DummyLivePrep<This> LivePrep;
|
||||
typedef Rep4<This> Protocol;
|
||||
typedef Rep4MC<This> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef Rep4Input<This> Input;
|
||||
|
||||
static const bool expensive_triples = false;
|
||||
|
||||
static MC* new_mc(typename super::mac_key_type) { return new MC; }
|
||||
|
||||
static This constant(const typename super::clear& constant, int my_num,
|
||||
typename super::mac_key_type = {})
|
||||
{
|
||||
return Rep4Share<typename super::clear>::constant(constant, my_num);
|
||||
}
|
||||
|
||||
Rep4Secret()
|
||||
{
|
||||
}
|
||||
template <class T>
|
||||
Rep4Secret(const T& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
void load_clear(int n, const Integer& x);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_REP4SECRET_H_ */
|
||||
@@ -62,13 +62,13 @@ public:
|
||||
|
||||
typedef typename T::Input Input;
|
||||
|
||||
typedef typename T::out_type out_type;
|
||||
|
||||
static string type_string() { return "evaluation secret"; }
|
||||
static string phase_name() { return T::name(); }
|
||||
|
||||
static int default_length;
|
||||
|
||||
static typename T::out_type out;
|
||||
|
||||
static const bool needs_ot = false;
|
||||
|
||||
static const bool is_real = true;
|
||||
@@ -170,9 +170,6 @@ public:
|
||||
template <class T>
|
||||
int Secret<T>::default_length = 64;
|
||||
|
||||
template <class T>
|
||||
typename T::out_type Secret<T>::out = T::out;
|
||||
|
||||
template <class T>
|
||||
inline ostream& operator<<(ostream& o, Secret<T>& secret)
|
||||
{
|
||||
|
||||
@@ -58,10 +58,11 @@ SemiPrep::~SemiPrep()
|
||||
|
||||
void SemiPrep::buffer_bits()
|
||||
{
|
||||
auto& thread = Thread<SemiSecret>::s();
|
||||
word r = thread.secure_prng.get_word();
|
||||
word r = secure_prng.get_word();
|
||||
for (size_t i = 0; i < sizeof(word) * 8; i++)
|
||||
{
|
||||
this->bits.push_back((r >> i) & 1);
|
||||
}
|
||||
}
|
||||
|
||||
size_t SemiPrep::data_sent()
|
||||
|
||||
@@ -23,6 +23,8 @@ class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret
|
||||
SemiSecret::TripleGenerator* triple_generator;
|
||||
MascotParams params;
|
||||
|
||||
SeededPRNG secure_prng;
|
||||
|
||||
public:
|
||||
SemiPrep(DataPositions& usage, ShareThread<SemiSecret>& thread);
|
||||
SemiPrep(DataPositions& usage, bool = true);
|
||||
|
||||
@@ -80,8 +80,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
|
||||
|
||||
this->machine.more_comm_less_comp = opt.get("-c")->isSet;
|
||||
|
||||
T::out.activate(my_num == 0 or online_opts.interactive);
|
||||
|
||||
if (not this->machine.use_encryption and not T::dishonest_majority)
|
||||
insecure("unencrypted communication");
|
||||
|
||||
|
||||
107
GC/ShareSecret.h
107
GC/ShareSecret.h
@@ -44,8 +44,6 @@ public:
|
||||
static const bool is_real = true;
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
static SwitchableOutput out;
|
||||
|
||||
static void store_clear_in_dynamic(Memory<U>& mem,
|
||||
const vector<ClearWriteAccess>& accesses);
|
||||
|
||||
@@ -83,21 +81,26 @@ public:
|
||||
void other_input(T& inputter, int from, int n_bits = 1);
|
||||
template<class T>
|
||||
void finalize_input(T& inputter, int from, int n_bits);
|
||||
|
||||
U& operator=(const U&);
|
||||
};
|
||||
|
||||
template<class U>
|
||||
class ReplicatedSecret : public FixedVec<BitVec, 2>, public ShareSecret<U>
|
||||
template<class U, int L>
|
||||
class RepSecretBase : public FixedVec<BitVec, L>, public ShareSecret<U>
|
||||
{
|
||||
typedef FixedVec<BitVec, 2> super;
|
||||
typedef FixedVec<BitVec, L> super;
|
||||
typedef RepSecretBase This;
|
||||
|
||||
public:
|
||||
typedef U part_type;
|
||||
typedef U small_type;
|
||||
typedef U whole_type;
|
||||
|
||||
typedef BitVec clear;
|
||||
typedef BitVec open_type;
|
||||
typedef BitVec mac_type;
|
||||
typedef BitVec mac_key_type;
|
||||
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
static const int N_BITS = clear::N_BITS;
|
||||
@@ -109,7 +112,7 @@ public:
|
||||
static string type_string() { return "replicated secret"; }
|
||||
static string phase_name() { return "Replicated computation"; }
|
||||
|
||||
static const int default_length = 8 * sizeof(typename ReplicatedSecret<U>::value_type);
|
||||
static const int default_length = N_BITS;
|
||||
|
||||
static int threshold(int)
|
||||
{
|
||||
@@ -124,9 +127,45 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
static void read_or_generate_mac_key(string, const Names&, mac_key_type) {}
|
||||
static void read_or_generate_mac_key(string, const Player&, mac_key_type)
|
||||
{
|
||||
}
|
||||
|
||||
static ReplicatedSecret constant(const clear& value, int my_num, mac_key_type)
|
||||
RepSecretBase()
|
||||
{
|
||||
}
|
||||
template <class T>
|
||||
RepSecretBase(const T& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const This& x, const This& y)
|
||||
{ *this = x ^ y; (void)n; }
|
||||
|
||||
This operator&(const Clear& other)
|
||||
{ return super::operator&(BitVec(other)); }
|
||||
|
||||
This lsb()
|
||||
{ return *this & 1; }
|
||||
|
||||
This get_bit(int i)
|
||||
{ return (*this >> i) & 1; }
|
||||
};
|
||||
|
||||
template<class U>
|
||||
class ReplicatedSecret : public RepSecretBase<U, 2>
|
||||
{
|
||||
typedef RepSecretBase<U, 2> super;
|
||||
|
||||
public:
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
|
||||
typename super::mac_key_type)
|
||||
{
|
||||
ReplicatedSecret res;
|
||||
if (my_num < 2)
|
||||
@@ -140,28 +179,44 @@ public:
|
||||
|
||||
void load_clear(int n, const Integer& x);
|
||||
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
|
||||
BitVec local_mul(const ReplicatedSecret& other) const;
|
||||
|
||||
void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y)
|
||||
{ *this = x ^ y; (void)n; }
|
||||
|
||||
void reveal(size_t n_bits, Clear& x);
|
||||
|
||||
ReplicatedSecret operator&(const Clear& other)
|
||||
{ return super::operator&(BitVec(other)); }
|
||||
|
||||
ReplicatedSecret lsb()
|
||||
{ return *this & 1; }
|
||||
|
||||
ReplicatedSecret get_bit(int i)
|
||||
{ return (*this >> i) & 1; }
|
||||
};
|
||||
|
||||
class SemiHonestRepPrep;
|
||||
|
||||
class SmallRepSecret : public FixedVec<BitVec_<unsigned char>, 2>
|
||||
{
|
||||
typedef FixedVec<BitVec_<unsigned char>, 2> super;
|
||||
typedef SmallRepSecret This;
|
||||
|
||||
public:
|
||||
typedef ReplicatedMC<This> MC;
|
||||
typedef BitVec_<unsigned char> open_type;
|
||||
typedef open_type clear;
|
||||
typedef BitVec mac_key_type;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
{
|
||||
return new MC;
|
||||
}
|
||||
|
||||
SmallRepSecret()
|
||||
{
|
||||
}
|
||||
template<class T>
|
||||
SmallRepSecret(const T& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
This lsb() const
|
||||
{
|
||||
return *this & 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SemiHonestRepSecret : public ReplicatedSecret<SemiHonestRepSecret>
|
||||
{
|
||||
typedef ReplicatedSecret<SemiHonestRepSecret> super;
|
||||
@@ -176,7 +231,7 @@ public:
|
||||
typedef ReplicatedInput<SemiHonestRepSecret> Input;
|
||||
|
||||
typedef SemiHonestRepSecret part_type;
|
||||
typedef SemiHonestRepSecret small_type;
|
||||
typedef SmallRepSecret small_type;
|
||||
typedef SemiHonestRepSecret whole_type;
|
||||
|
||||
static const bool expensive_triples = false;
|
||||
|
||||
@@ -25,14 +25,11 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class U>
|
||||
const int ReplicatedSecret<U>::N_BITS;
|
||||
template<class U, int L>
|
||||
const int RepSecretBase<U, L>::N_BITS;
|
||||
|
||||
template<class U>
|
||||
const int ReplicatedSecret<U>::default_length;
|
||||
|
||||
template<class U>
|
||||
SwitchableOutput ShareSecret<U>::out;
|
||||
template<class U, int L>
|
||||
const int RepSecretBase<U, L>::default_length;
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::check_length(int n, const Integer& x)
|
||||
@@ -59,16 +56,16 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
|
||||
*this = x;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
*this ^= (S[regs[i]] << i);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
@@ -285,12 +282,11 @@ void ShareSecret<U>::xors(Processor<U>& processor, const vector<int>& args)
|
||||
ShareThread<U>::s().xors(processor, args);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ReplicatedSecret<U>::trans(Processor<U>& processor,
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::trans(Processor<U>& processor,
|
||||
int n_outputs, const vector<int>& args)
|
||||
{
|
||||
assert(length == 2);
|
||||
for (int k = 0; k < 2; k++)
|
||||
for (int k = 0; k < L; k++)
|
||||
{
|
||||
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
|
||||
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
|
||||
@@ -330,6 +326,14 @@ void ShareSecret<U>::random_bit()
|
||||
*this = res;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
U& GC::ShareSecret<U>::operator=(const U& other)
|
||||
{
|
||||
U& real_this = static_cast<U&>(*this);
|
||||
real_this = other;
|
||||
return real_this;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -54,7 +54,13 @@ void Thread<T>::run()
|
||||
P = new CryptoPlayer(N, thread_num << 16);
|
||||
else
|
||||
P = new PlainPlayer(N, thread_num << 16);
|
||||
processor.open_input_file(N.my_num(), thread_num);
|
||||
processor.open_input_file(N.my_num(), thread_num,
|
||||
master.opts.cmd_private_input_file);
|
||||
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
|
||||
processor.setup_redirection(P->my_num(), thread_num, master.opts);
|
||||
if (processor.stdout_redirect_file.is_open())
|
||||
processor.out.redirect_to_file(processor.stdout_redirect_file);
|
||||
|
||||
done.push(0);
|
||||
pre_run();
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ public:
|
||||
|
||||
typedef typename part_type::sacri_type sacri_type;
|
||||
typedef typename part_type::mac_type mac_type;
|
||||
typedef typename part_type::mac_share_type mac_share_type;
|
||||
typedef BitDiagonal Rectangle;
|
||||
|
||||
typedef typename T::super check_type;
|
||||
@@ -152,6 +153,11 @@ public:
|
||||
reg.output(s, human);
|
||||
}
|
||||
|
||||
void input(istream&, bool)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void my_input(U& inputter, BitVec value, int n_bits)
|
||||
{
|
||||
|
||||
@@ -129,7 +129,7 @@
|
||||
X(GLDMC, ) \
|
||||
X(LDMS, ) \
|
||||
X(LDMC, ) \
|
||||
X(PRINTINT, S0.out << I0) \
|
||||
X(PRINTINT, PROC.out << I0) \
|
||||
X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \
|
||||
X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \
|
||||
X(RUN_TAPE, MACH->run_tapes(EXTRA)) \
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "Processor/config.h"
|
||||
#include "Protocols/Share.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
#include "Player-Online.hpp"
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
#include "Rep.hpp"
|
||||
#include "Protocols/Spdz2kPrep.hpp"
|
||||
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
|
||||
|
||||
@@ -16,16 +16,14 @@
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/FakeShare.hpp"
|
||||
|
||||
SwitchableOutput GC::NoShare::out;
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
assert(argc > 1);
|
||||
OnlineOptions online_opts;
|
||||
Names N(0, 9999, vector<string>({"localhost"}));
|
||||
Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector<string>({"localhost"}));
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions ring_opts(opt, argc, argv);
|
||||
opt.parse(argc, argv);
|
||||
opt.syntax = string(argv[0]) + " <progname>";
|
||||
string progname;
|
||||
if (opt.firstArgs.size() > 1)
|
||||
progname = *opt.firstArgs.at(1);
|
||||
@@ -41,7 +39,16 @@ int main(int argc, const char** argv)
|
||||
exit(1);
|
||||
}
|
||||
|
||||
switch (ring_opts.R)
|
||||
#ifdef ROUND_NEAREST_IN_EMULATION
|
||||
cerr << "Using nearest rounding instead of probabilistic truncation" << endl;
|
||||
#else
|
||||
#ifdef RISKY_TRUNCATION_IN_EMULATION
|
||||
cerr << "Using risky truncation" << endl;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
int R = ring_opts.ring_size_from_opts_or_schedule(progname);
|
||||
switch (R)
|
||||
{
|
||||
case 64:
|
||||
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname,
|
||||
@@ -53,7 +60,27 @@ int main(int argc, const char** argv)
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 256:
|
||||
Machine<FakeShare<SignedZ2<256>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 192:
|
||||
Machine<FakeShare<SignedZ2<192>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 384:
|
||||
Machine<FakeShare<SignedZ2<384>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 512:
|
||||
Machine<FakeShare<SignedZ2<512>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
default:
|
||||
cerr << "Not compiled for " << ring_opts.R << "-bit rings" << endl;
|
||||
cerr << "Not compiled for " << R << "-bit rings" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
#include "Player-Online.hpp"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gfp.hpp"
|
||||
#include "GC/TinierSecret.h"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
|
||||
@@ -13,14 +13,27 @@
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions opts(opt, argc, argv);
|
||||
RingOptions opts(opt, argc, argv, true);
|
||||
switch (opts.R)
|
||||
{
|
||||
case 64:
|
||||
ReplicatedMachine<PostSacriRepRingShare<64, 40>, PostSacriRepFieldShare<gf2n>>(
|
||||
argc, argv, opt);
|
||||
break;
|
||||
case 72:
|
||||
switch (opts.S)
|
||||
{
|
||||
case 40:
|
||||
ReplicatedMachine<PostSacriRepRingShare<64, 40>,
|
||||
PostSacriRepFieldShare<gf2n>>(argc, argv, opt);
|
||||
break;
|
||||
case 64:
|
||||
ReplicatedMachine<PostSacriRepRingShare<64, 64>,
|
||||
PostSacriRepFieldShare<gf2n>>(argc, argv, opt);
|
||||
break;
|
||||
default:
|
||||
cerr << "Security parameter " << opts.S << " not implemented"
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
break;
|
||||
case 72:
|
||||
ReplicatedMachine<PostSacriRepRingShare<72, 40>, PostSacriRepFieldShare<gf2n>>(
|
||||
argc, argv, opt);
|
||||
break;
|
||||
|
||||
38
Machines/rep4-ring-party.cpp
Normal file
38
Machines/rep4-ring-party.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
* rep4-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/Rep4Share2k.h"
|
||||
#include "Protocols/Rep4Share.h"
|
||||
#include "Protocols/Rep4MC.h"
|
||||
#include "Protocols/ReplicatedMachine.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "GC/Rep4Secret.h"
|
||||
#include "Processor/RingOptions.h"
|
||||
|
||||
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
|
||||
#include "Protocols/ReplicatedMachine.hpp"
|
||||
#include "Protocols/Rep4Input.hpp"
|
||||
#include "Protocols/Rep4Prep.hpp"
|
||||
#include "Protocols/Rep4MC.hpp"
|
||||
#include "Protocols/Rep4.hpp"
|
||||
#include "GC/BitAdder.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Rep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions ring_opts(opt, argc, argv);
|
||||
switch (ring_opts.R)
|
||||
{
|
||||
#define X(R) case R: ReplicatedMachine<Rep4Share2<R>, Rep4Share<gf2n>>(argc, argv, opt, 4); break;
|
||||
X(64) X(80) X(88)
|
||||
default:
|
||||
cerr << ring_opts.R << "-bit computation not implemented" << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,10 @@ int main(int argc, const char** argv)
|
||||
ReplicatedMachine<Rep3Share2<72>, Rep3Share<gf2n>>(argc, argv,
|
||||
"replicated-ring", opt);
|
||||
break;
|
||||
case 128:
|
||||
ReplicatedMachine<Rep3Share2<128>, Rep3Share<gf2n>>(argc, argv,
|
||||
"replicated-ring", opt);
|
||||
break;
|
||||
default:
|
||||
throw runtime_error(to_string(opts.R) + "-bit computation not implemented");
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "Player-Online.hpp"
|
||||
#include "Semi.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "Networking/Server.h"
|
||||
|
||||
#include "Player-Online.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
41
Machines/sy-rep-field-party.cpp
Normal file
41
Machines/sy-rep-field-party.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* sy-rep-field-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/SpdzWiseShare.h"
|
||||
#include "Protocols/MaliciousRep3Share.h"
|
||||
#include "Protocols/ReplicatedMachine.h"
|
||||
#include "Protocols/MAC_Check.h"
|
||||
#include "Protocols/SpdzWiseMC.h"
|
||||
#include "Protocols/SpdzWisePrep.h"
|
||||
#include "Protocols/SpdzWiseInput.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
|
||||
#include "Protocols/ReplicatedMachine.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/SpdzWise.hpp"
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/RepPrep.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opts;
|
||||
ReplicatedMachine<SpdzWiseShare<MaliciousRep3Share<gfp>>,
|
||||
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opts);
|
||||
}
|
||||
68
Machines/sy-rep-ring-party.cpp
Normal file
68
Machines/sy-rep-ring-party.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* sy-rep-ring-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/ReplicatedMachine.h"
|
||||
#include "Protocols/SpdzWiseRingShare.h"
|
||||
#include "Protocols/MaliciousRep3Share.h"
|
||||
#include "Protocols/SpdzWiseMC.h"
|
||||
#include "Protocols/SpdzWiseRingPrep.h"
|
||||
#include "Protocols/SpdzWiseInput.h"
|
||||
#include "Protocols/MalRepRingPrep.h"
|
||||
#include "Processor/RingOptions.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
|
||||
#include "Protocols/ReplicatedMachine.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/SpdzWise.hpp"
|
||||
#include "Protocols/SpdzWiseRing.hpp"
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Protocols/PostSacrifice.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "Protocols/MaliciousRepPrep.hpp"
|
||||
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/RepPrep.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions opts(opt, argc, argv, true);
|
||||
switch (opts.R)
|
||||
{
|
||||
case 64:
|
||||
switch (opts.S)
|
||||
{
|
||||
case 40:
|
||||
ReplicatedMachine<SpdzWiseRingShare<64, 40>,
|
||||
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opt);
|
||||
break;
|
||||
case 64:
|
||||
ReplicatedMachine<SpdzWiseRingShare<64, 64>,
|
||||
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opt);
|
||||
break;
|
||||
default:
|
||||
cerr << "Security parameter " << opts.S << " not implemented"
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
break;
|
||||
case 72:
|
||||
ReplicatedMachine<SpdzWiseRingShare<72, 40>,
|
||||
SpdzWiseShare<MaliciousRep3Share<gf2n>>>(argc, argv, opt);
|
||||
break;
|
||||
default:
|
||||
throw runtime_error(
|
||||
to_string(opts.R) + "-bit computation not implemented");
|
||||
}
|
||||
}
|
||||
33
Machines/sy-shamir-party.cpp
Normal file
33
Machines/sy-shamir-party.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* sy-shamir-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "ShamirMachine.h"
|
||||
#include "Protocols/ReplicatedMachine.h"
|
||||
#include "Protocols/SpdzWiseShare.h"
|
||||
#include "Protocols/MaliciousShamirShare.h"
|
||||
#include "Protocols/SpdzWiseMC.h"
|
||||
#include "Protocols/SpdzWiseInput.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "GC/CcdSecret.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/SpdzWise.hpp"
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
auto& opts = ShamirOptions::singleton;
|
||||
ez::ezOptionParser opt;
|
||||
opts = {opt, argc, argv};
|
||||
ReplicatedMachine<SpdzWiseShare<MaliciousShamirShare<gfp>>,
|
||||
SpdzWiseShare<MaliciousShamirShare<gf2n>>>(
|
||||
argc, argv,
|
||||
{ }, opt, opts.nparties);
|
||||
}
|
||||
14
Makefile
14
Makefile
@@ -42,7 +42,7 @@ all: arithmetic binary gen_input online offline externalIO bmr doc
|
||||
doc:
|
||||
cd doc; $(MAKE) html
|
||||
|
||||
arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot
|
||||
arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy
|
||||
binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr
|
||||
|
||||
ifeq ($(USE_NTL),1)
|
||||
@@ -77,7 +77,7 @@ overdrive: simple-offline.x pairwise-offline.x cnc-offline.x
|
||||
|
||||
rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x
|
||||
|
||||
rep-ring: replicated-ring-party.x brain-party.x malicious-rep-ring-party.x ps-rep-ring-party.x Fake-Offline.x
|
||||
rep-ring: replicated-ring-party.x brain-party.x malicious-rep-ring-party.x ps-rep-ring-party.x rep4-ring-party.x
|
||||
|
||||
rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x
|
||||
|
||||
@@ -98,10 +98,12 @@ endif
|
||||
|
||||
shamir: shamir-party.x malicious-shamir-party.x galois-degree.x
|
||||
|
||||
sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x
|
||||
|
||||
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
|
||||
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
|
||||
|
||||
$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC)
|
||||
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMON) $(BMR) $(GC)
|
||||
$(AR) -csr $@ $^
|
||||
|
||||
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
|
||||
@@ -184,12 +186,18 @@ hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
|
||||
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
|
||||
static/hemi-party.x: $(FHEOFFLINE)
|
||||
static/soho-party.x: $(FHEOFFLINE)
|
||||
static/cowgear-party.x: $(FHEOFFLINE)
|
||||
static/chaigear-party.x: $(FHEOFFLINE)
|
||||
mascot-party.x: Machines/SPDZ.o $(OT)
|
||||
static/mascot-party.x: Machines/SPDZ.o
|
||||
Player-Online.x: Machines/SPDZ.o $(OT)
|
||||
mama-party.x: $(OT)
|
||||
ps-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
sy-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
rep4-ring-party.x: GC/Rep4Secret.o
|
||||
semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o
|
||||
mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
|
||||
fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
|
||||
|
||||
@@ -25,6 +25,7 @@ public:
|
||||
static const int n_bits = sizeof(T) * 8;
|
||||
|
||||
static char type_char() { return 'B'; }
|
||||
static string type_short() { return "B"; }
|
||||
static DataFieldType field_type() { return DATA_GF2; }
|
||||
|
||||
static bool allows(Dtype dtype) { return dtype == DATA_TRIPLE or dtype == DATA_BIT; }
|
||||
@@ -59,7 +60,7 @@ public:
|
||||
void pack(octetStream& os) const { os.store_int<sizeof(T)>(this->a); }
|
||||
void unpack(octetStream& os) { this->a = os.get_int<sizeof(T)>(); }
|
||||
|
||||
void pack(octetStream& os, int n) const { os.store_int(this->a, DIV_CEIL(n, 8)); }
|
||||
void pack(octetStream& os, int n) const { os.store_int(mask(n).a, DIV_CEIL(n, 8)); }
|
||||
void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); }
|
||||
|
||||
static BitVec_ unpack_new(octetStream& os, int n = n_bits)
|
||||
|
||||
@@ -156,11 +156,6 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool operator!=(const FixedVec& other) const
|
||||
{
|
||||
return not equal(other);
|
||||
}
|
||||
|
||||
bool is_zero()
|
||||
{
|
||||
return equal(0);
|
||||
@@ -170,6 +165,11 @@ public:
|
||||
return equal(1);
|
||||
}
|
||||
|
||||
bool operator!=(const FixedVec<T, L>& other) const
|
||||
{
|
||||
return not equal(other);
|
||||
}
|
||||
|
||||
FixedVec<T, L>operator+(const FixedVec<T, L>& other) const
|
||||
{
|
||||
FixedVec<T, L> res;
|
||||
@@ -291,6 +291,15 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
T lazy_sum() const
|
||||
{
|
||||
assert(L > 1);
|
||||
T res = v[0].lazy_add(v[1]);
|
||||
for (int i = 2; i < L; i++)
|
||||
res = res.lazy_add(v[i]);
|
||||
return res;
|
||||
}
|
||||
|
||||
FixedVec<T, L> extend_bit() const
|
||||
{
|
||||
FixedVec<T, L> res;
|
||||
@@ -343,13 +352,21 @@ public:
|
||||
|
||||
void output(ostream& s, bool human) const
|
||||
{
|
||||
for (auto& x : v)
|
||||
x.output(s, human);
|
||||
if (human)
|
||||
s << *this;
|
||||
else
|
||||
for (auto& x : v)
|
||||
x.output(s, human);
|
||||
}
|
||||
void input(istream& s, bool human)
|
||||
{
|
||||
for (auto& x : v)
|
||||
x.input(s, human);
|
||||
for (int i = 0; i < L; i++)
|
||||
{
|
||||
if (human and i != 0)
|
||||
if (s.get() != ',')
|
||||
throw runtime_error("cannot read vector");
|
||||
(*this)[i].input(s, human);
|
||||
}
|
||||
}
|
||||
|
||||
void pack(octetStream& os) const
|
||||
|
||||
@@ -17,6 +17,8 @@ class ValueInterface
|
||||
public:
|
||||
static const int MAX_EDABITS = 0;
|
||||
|
||||
static const false_type characteristic_two;
|
||||
|
||||
template<class T>
|
||||
static void init(bool mont = true) { (void) mont; }
|
||||
static void init_default(int l) { (void) l; }
|
||||
|
||||
28
Math/Z2k.h
28
Math/Z2k.h
@@ -62,13 +62,14 @@ public:
|
||||
static int t() { return 0; }
|
||||
|
||||
static char type_char() { return 'R'; }
|
||||
static string type_short() { return "R"; }
|
||||
static string type_string() { return "Z2^" + to_string(int(N_BITS)); }
|
||||
|
||||
static DataFieldType field_type() { return DATA_INT; }
|
||||
|
||||
static const bool invertible = false;
|
||||
static const false_type invertible;
|
||||
|
||||
template <int L, int M>
|
||||
template <int L, int M, bool LAZY = false>
|
||||
static Z2<K> Mul(const Z2<L>& x, const Z2<M>& y);
|
||||
|
||||
static void reqbl(int n);
|
||||
@@ -151,6 +152,9 @@ public:
|
||||
|
||||
void add(octetStream& os) { add(os.consume(size())); }
|
||||
|
||||
Z2 lazy_add(const Z2& x) const;
|
||||
Z2 lazy_mul(const Z2& x) const;
|
||||
|
||||
Z2& invert();
|
||||
void invert(const Z2& a) { *this = a; invert(); }
|
||||
|
||||
@@ -279,10 +283,17 @@ public:
|
||||
|
||||
template<int K>
|
||||
inline Z2<K> Z2<K>::operator+(const Z2<K>& other) const
|
||||
{
|
||||
auto res = lazy_add(other);
|
||||
res.normalize();
|
||||
return res;
|
||||
}
|
||||
|
||||
template<int K>
|
||||
Z2<K> Z2<K>::lazy_add(const Z2<K>& other) const
|
||||
{
|
||||
Z2<K> res;
|
||||
mpn_add_fixed_n<N_WORDS>(res.a, a, other.a);
|
||||
res.a[N_WORDS - 1] &= UPPER_MASK;
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -332,12 +343,13 @@ Z2<K>& Z2<K>::operator>>=(int other)
|
||||
}
|
||||
|
||||
template <int K>
|
||||
template <int L, int M>
|
||||
template <int L, int M, bool LAZY>
|
||||
inline Z2<K> Z2<K>::Mul(const Z2<L>& x, const Z2<M>& y)
|
||||
{
|
||||
Z2<K> res;
|
||||
mpn_mul_fixed_<N_WORDS, Z2<L>::N_WORDS, Z2<M>::N_WORDS>(res.a, x.a, y.a);
|
||||
res.a[N_WORDS - 1] &= UPPER_MASK;
|
||||
if (not LAZY)
|
||||
res.normalize();
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -348,6 +360,12 @@ inline Z2<(K > L) ? K : L> Z2<K>::operator*(const Z2<L>& other) const
|
||||
return Z2<(K > L) ? K : L>::Mul(*this, other);
|
||||
}
|
||||
|
||||
template <int K>
|
||||
inline Z2<K> Z2<K>::lazy_mul(const Z2<K>& other) const
|
||||
{
|
||||
return Z2<K>::Mul<K, K, true>(*this, other);
|
||||
}
|
||||
|
||||
template <int K>
|
||||
Z2<K> Z2<K>::operator<<(int i) const
|
||||
{
|
||||
|
||||
@@ -13,6 +13,8 @@ template<int K>
|
||||
const int Z2<K>::N_BITS;
|
||||
template<int K>
|
||||
const int Z2<K>::N_BYTES;
|
||||
template<int K>
|
||||
const false_type Z2<K>::invertible;
|
||||
|
||||
template<int K>
|
||||
void Z2<K>::reqbl(int n)
|
||||
|
||||
@@ -83,6 +83,7 @@ class gf2n_short : public ValueInterface
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2N; }
|
||||
static char type_char() { return '2'; }
|
||||
static string type_short() { return "2"; }
|
||||
static string type_string() { return "gf2n"; }
|
||||
|
||||
static int size() { return sizeof(a); }
|
||||
@@ -94,8 +95,8 @@ class gf2n_short : public ValueInterface
|
||||
|
||||
static bool allows(Dtype type) { (void) type; return true; }
|
||||
|
||||
static const bool invertible = true;
|
||||
static const bool characteristic_two = true;
|
||||
static const true_type invertible;
|
||||
static const true_type characteristic_two;
|
||||
|
||||
static gf2n_short cut(int128 x) { return x.get_lower(); }
|
||||
|
||||
@@ -163,6 +164,9 @@ class gf2n_short : public ValueInterface
|
||||
// x * y when one of x,y is a bit
|
||||
void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; }
|
||||
|
||||
gf2n_short lazy_add(const gf2n_short& x) const { return *this + x; }
|
||||
gf2n_short lazy_mul(const gf2n_short& x) const { return *this * x; }
|
||||
|
||||
gf2n_short operator+(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; }
|
||||
gf2n_short operator*(const gf2n_short& x) const { gf2n_short res; res.mul(*this, x); return res; }
|
||||
gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; }
|
||||
|
||||
@@ -134,6 +134,7 @@ class gf2n_long : public ValueInterface
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2N; }
|
||||
static char type_char() { return '2'; }
|
||||
static string type_short() { return "2"; }
|
||||
static string type_string() { return "gf2n_long"; }
|
||||
|
||||
static int size() { return sizeof(a); }
|
||||
@@ -144,8 +145,8 @@ class gf2n_long : public ValueInterface
|
||||
|
||||
static bool allows(Dtype type) { (void) type; return true; }
|
||||
|
||||
static const bool invertible = true;
|
||||
static const bool characteristic_two = true;
|
||||
static const true_type invertible;
|
||||
static const true_type characteristic_two;
|
||||
|
||||
static gf2n_long cut(int128 x) { return x; }
|
||||
|
||||
@@ -216,6 +217,9 @@ class gf2n_long : public ValueInterface
|
||||
// x * y when one of x,y is a bit
|
||||
void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; }
|
||||
|
||||
gf2n_long lazy_add(const gf2n_long& x) const { return *this + x; }
|
||||
gf2n_long lazy_mul(const gf2n_long& x) const { return *this * x; }
|
||||
|
||||
gf2n_long operator+(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; }
|
||||
gf2n_long operator*(const gf2n_long& x) const { gf2n_long res; res.mul(*this, x); return res; }
|
||||
gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; }
|
||||
@@ -251,6 +255,8 @@ class gf2n_long : public ValueInterface
|
||||
gf2n_long& operator>>=(int i) { SHR(*this, i); return *this; }
|
||||
gf2n_long& operator<<=(int i) { SHL(*this, i); return *this; }
|
||||
|
||||
bool operator<(gf2n_long) const { return false; }
|
||||
|
||||
/* Crap RNG */
|
||||
void randomize(PRNG& G, int n = -1);
|
||||
// compatibility with gfp
|
||||
|
||||
@@ -88,6 +88,7 @@ class gfp_ : public ValueInterface
|
||||
|
||||
static DataFieldType field_type() { return DATA_INT; }
|
||||
static char type_char() { return 'p'; }
|
||||
static string type_short() { return "p"; }
|
||||
static string type_string() { return "gfp"; }
|
||||
|
||||
static int size() { return t() * sizeof(mp_limb_t); }
|
||||
@@ -100,8 +101,7 @@ class gfp_ : public ValueInterface
|
||||
|
||||
static void specification(octetStream& os);
|
||||
|
||||
static const bool invertible = true;
|
||||
static const bool characteristic_two = false;
|
||||
static const true_type invertible;
|
||||
|
||||
static gfp_ Mul(gfp_ a, gfp_ b) { return a * b; }
|
||||
|
||||
@@ -184,6 +184,9 @@ class gfp_ : public ValueInterface
|
||||
void mul(const gfp_& x)
|
||||
{ a.template mul<L>(a,x.a,ZpD); }
|
||||
|
||||
gfp_ lazy_add(const gfp_& x) const { return *this + x; }
|
||||
gfp_ lazy_mul(const gfp_& x) const { return *this * x; }
|
||||
|
||||
gfp_ operator+(const gfp_& x) const { gfp_ res; res.add(*this, x); return res; }
|
||||
gfp_ operator-(const gfp_& x) const { gfp_ res; res.sub(*this, x); return res; }
|
||||
gfp_ operator*(const gfp_& x) const { gfp_ res; res.mul(*this, x); return res; }
|
||||
@@ -266,7 +269,7 @@ class gfp_ : public ValueInterface
|
||||
|
||||
// Convert representation to and from a bigint number
|
||||
friend void to_bigint(bigint& ans,const gfp_& x,bool reduce=true)
|
||||
{ to_bigint(ans,x.a,x.ZpD,reduce); }
|
||||
{ x.a.template to_bigint<L>(ans, x.ZpD, reduce); }
|
||||
friend void to_gfp(gfp_& ans,const bigint& x)
|
||||
{ to_modp(ans.a,x,ans.ZpD); }
|
||||
};
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
#include "Math/bigint.hpp"
|
||||
#include "Math/Setup.hpp"
|
||||
|
||||
template<int X, int L>
|
||||
const true_type gfp_<X, L>::invertible;
|
||||
|
||||
template<int X, int L>
|
||||
inline void gfp_<X, L>::read_or_generate_setup(string dir,
|
||||
const OnlineOptions& opts)
|
||||
|
||||
@@ -74,6 +74,8 @@ class modp_
|
||||
|
||||
// Convert representation to and from a modp number
|
||||
void to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce=true) const;
|
||||
template<int M>
|
||||
void to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce=true) const;
|
||||
|
||||
template<int T>
|
||||
void mul(const modp_& x, const modp_& y, const Zp_Data& ZpD);
|
||||
|
||||
@@ -113,6 +113,31 @@ void modp_<L>::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const
|
||||
}
|
||||
|
||||
|
||||
template<int L>
|
||||
template<int M>
|
||||
void modp_<L>::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const
|
||||
{
|
||||
assert(M == ZpD.t);
|
||||
auto& x = *this;
|
||||
mpz_ptr a = ans.get_mpz_t();
|
||||
if (a->_mp_alloc < M)
|
||||
mpz_realloc(a, M);
|
||||
if (ZpD.montgomery)
|
||||
{
|
||||
mp_limb_t one[M];
|
||||
inline_mpn_zero(one,M);
|
||||
one[0]=1;
|
||||
ZpD.Mont_Mult_<M>(a->_mp_d,x.x,one);
|
||||
}
|
||||
else
|
||||
{ inline_mpn_copyi(a->_mp_d,x.x,M); }
|
||||
a->_mp_size=M;
|
||||
if (reduce)
|
||||
while (a->_mp_size>=1 && (a->_mp_d)[a->_mp_size-1]==0)
|
||||
{ a->_mp_size--; }
|
||||
}
|
||||
|
||||
|
||||
template<int L>
|
||||
void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD)
|
||||
{
|
||||
|
||||
@@ -20,7 +20,9 @@ void ssl_error(string side, string pronoun, string other, string server)
|
||||
<< " failed. Make sure " << pronoun
|
||||
<< " have the necessary certificate (" << PREP_DIR << server
|
||||
<< ".pem in the default configuration),"
|
||||
<< " and run `c_rehash <directory>` on its location." << endl;
|
||||
<< " and run `c_rehash <directory>` on its location." << endl
|
||||
<< "Also make sure that it's still valid. Certificates generated "
|
||||
<< "with `Scripts/setup-ssl.sh` expire after a month." << endl;
|
||||
}
|
||||
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
|
||||
@@ -90,7 +90,15 @@ struct CommStats
|
||||
size_t data, rounds;
|
||||
Timer timer;
|
||||
CommStats() : data(0), rounds(0) {}
|
||||
Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; }
|
||||
Timer& add(const octetStream& os)
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cout << "add " << os.get_length() << endl;
|
||||
#endif
|
||||
data += os.get_length();
|
||||
rounds++;
|
||||
return timer;
|
||||
}
|
||||
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
|
||||
CommStats& operator+=(const CommStats& other);
|
||||
CommStats& operator-=(const CommStats& other);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <sodium.h>
|
||||
#include <regex>
|
||||
using namespace std;
|
||||
|
||||
BaseMachine* BaseMachine::singleton = 0;
|
||||
@@ -28,13 +29,14 @@ BaseMachine::BaseMachine() : nthreads(0)
|
||||
singleton = this;
|
||||
}
|
||||
|
||||
void BaseMachine::load_schedule(string progname)
|
||||
void BaseMachine::load_schedule(string progname, bool load_bytecode)
|
||||
{
|
||||
this->progname = progname;
|
||||
string fname = "Programs/Schedules/" + progname + ".sch";
|
||||
#ifdef DEBUG_FILES
|
||||
cerr << "Opening file " << fname << endl;
|
||||
#endif
|
||||
ifstream inpf;
|
||||
inpf.open(fname);
|
||||
if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); }
|
||||
|
||||
@@ -54,25 +56,35 @@ void BaseMachine::load_schedule(string progname)
|
||||
string threadname;
|
||||
for (int i=0; i<nprogs; i++)
|
||||
{ inpf >> threadname;
|
||||
string filename = "Programs/Bytecode/" + threadname + ".bc";
|
||||
if (load_bytecode)
|
||||
{
|
||||
string filename = "Programs/Bytecode/" + threadname + ".bc";
|
||||
#ifdef DEBUG_FILES
|
||||
cerr << "Loading program " << i << " from " << filename << endl;
|
||||
cerr << "Loading program " << i << " from " << filename << endl;
|
||||
#endif
|
||||
load_program(threadname, filename);
|
||||
load_program(threadname, filename);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto i : {1, 0, 0})
|
||||
{
|
||||
int n;
|
||||
inpf >> n;
|
||||
if (n != i)
|
||||
throw runtime_error("old schedule format not supported");
|
||||
}
|
||||
|
||||
inpf.get();
|
||||
getline(inpf, compiler);
|
||||
inpf.close();
|
||||
}
|
||||
|
||||
void BaseMachine::print_compiler()
|
||||
{
|
||||
|
||||
char compiler[1000];
|
||||
inpf.get();
|
||||
inpf.getline(compiler, 1000);
|
||||
#ifdef VERBOSE
|
||||
if (compiler[0] != 0)
|
||||
if (compiler.size() != 0)
|
||||
cerr << "Compiler: " << compiler << endl;
|
||||
#endif
|
||||
inpf.close();
|
||||
}
|
||||
|
||||
void BaseMachine::load_program(string threadname, string filename)
|
||||
@@ -112,3 +124,20 @@ string BaseMachine::memory_filename(string type_short, int my_number)
|
||||
{
|
||||
return PREP_DIR "Memory-" + type_short + "-P" + to_string(my_number);
|
||||
}
|
||||
|
||||
int BaseMachine::ring_size_from_schedule(string progname)
|
||||
{
|
||||
assert(not singleton);
|
||||
BaseMachine machine;
|
||||
singleton = 0;
|
||||
machine.load_schedule(progname, false);
|
||||
smatch m;
|
||||
regex e("R ([0-9]+)");
|
||||
regex_search(machine.compiler, m, e);
|
||||
if (m.size() > 1)
|
||||
{
|
||||
return stoi(m[1]);
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ protected:
|
||||
|
||||
std::map<int,Timer> timer;
|
||||
|
||||
ifstream inpf;
|
||||
string compiler;
|
||||
|
||||
void print_timers();
|
||||
|
||||
@@ -43,10 +43,12 @@ public:
|
||||
|
||||
static string memory_filename(string type_short, int my_number);
|
||||
|
||||
static int ring_size_from_schedule(string progname);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
|
||||
void load_schedule(string progname);
|
||||
void load_schedule(string progname, bool load_bytecode = true);
|
||||
void print_compiler();
|
||||
|
||||
void time();
|
||||
|
||||
@@ -14,13 +14,13 @@ const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 };
|
||||
|
||||
void DataPositions::set_num_players(int num_players)
|
||||
{
|
||||
files.resize(N_DATA_FIELD_TYPE, vector<long long>(N_DTYPE));
|
||||
inputs.resize(num_players, vector<long long>(N_DATA_FIELD_TYPE));
|
||||
files = {};
|
||||
inputs.resize(num_players, {});
|
||||
}
|
||||
|
||||
void DataPositions::increase(const DataPositions& delta)
|
||||
{
|
||||
inputs.resize(max(inputs.size(), delta.inputs.size()), vector<long long>(N_DATA_FIELD_TYPE));
|
||||
inputs.resize(max(inputs.size(), delta.inputs.size()), {});
|
||||
for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
@@ -39,8 +39,7 @@ void DataPositions::increase(const DataPositions& delta)
|
||||
|
||||
DataPositions& DataPositions::operator-=(const DataPositions& delta)
|
||||
{
|
||||
inputs.resize(max(inputs.size(), delta.inputs.size()),
|
||||
vector<long long>(N_DATA_FIELD_TYPE));
|
||||
inputs.resize(max(inputs.size(), delta.inputs.size()), {});
|
||||
for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE;
|
||||
field_type++)
|
||||
{
|
||||
@@ -144,3 +143,25 @@ void DataPositions::process_line(long long items_used, const char* name,
|
||||
cerr << suffix << endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool DataPositions::empty() const
|
||||
{
|
||||
for (auto& x : files)
|
||||
for (auto& y : x)
|
||||
if (y)
|
||||
return false;
|
||||
|
||||
for (auto& x : inputs)
|
||||
for (auto& y : x)
|
||||
if (y)
|
||||
return false;
|
||||
|
||||
for (auto& x : extended)
|
||||
if (not x.empty())
|
||||
return false;
|
||||
|
||||
if (not edabits.empty())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -10,11 +10,14 @@
|
||||
#include "Processor/InputTuple.h"
|
||||
#include "Tools/Lock.h"
|
||||
#include "Networking/Player.h"
|
||||
#include "Protocols/edabit.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
using namespace std;
|
||||
|
||||
template<class T> class dabit;
|
||||
|
||||
class DataTag
|
||||
{
|
||||
int t[4];
|
||||
@@ -50,9 +53,9 @@ public:
|
||||
static const char* field_names[N_DATA_FIELD_TYPE];
|
||||
static const int tuple_size[N_DTYPE];
|
||||
|
||||
vector< vector<long long> > files;
|
||||
vector< vector<long long> > inputs;
|
||||
map<DataTag, long long> extended[N_DATA_FIELD_TYPE];
|
||||
array<array<long long, N_DTYPE>, N_DATA_FIELD_TYPE> files;
|
||||
vector< array<long long, N_DATA_FIELD_TYPE> > inputs;
|
||||
array<map<DataTag, long long>, N_DATA_FIELD_TYPE> extended;
|
||||
map<pair<bool, int>, long long> edabits;
|
||||
|
||||
DataPositions(int num_players = 0) { set_num_players(num_players); }
|
||||
@@ -63,6 +66,7 @@ public:
|
||||
DataPositions& operator-=(const DataPositions& delta);
|
||||
DataPositions operator-(const DataPositions& delta) const;
|
||||
void print_cost() const;
|
||||
bool empty() const;
|
||||
};
|
||||
|
||||
template<class sint, class sgf2n> class Processor;
|
||||
@@ -73,9 +77,12 @@ template<class T> class SubProcessor;
|
||||
template<class T>
|
||||
class Preprocessing
|
||||
{
|
||||
protected:
|
||||
DataPositions& usage;
|
||||
|
||||
protected:
|
||||
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
|
||||
map<pair<bool, int>, edabitvec<T>> my_edabits;
|
||||
|
||||
void count(Dtype dtype) { usage.files[T::field_type()][dtype]++; }
|
||||
void count(DataTag tag, int n = 1) { usage.extended[T::field_type()][tag] += n; }
|
||||
void count_input(int player) { usage.inputs[player][T::field_type()]++; }
|
||||
@@ -117,10 +124,12 @@ public:
|
||||
|
||||
virtual array<T, 3> get_triple(int n_bits);
|
||||
virtual T get_bit();
|
||||
virtual void get_dabit(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
|
||||
virtual void get_edabits(bool, size_t, T*, vector<typename T::bit_type>&,
|
||||
const vector<int>&)
|
||||
{ throw runtime_error("no edaBit"); }
|
||||
virtual void get_dabit(T&, typename T::bit_type&);
|
||||
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
|
||||
virtual void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs);
|
||||
virtual void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
|
||||
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
|
||||
|
||||
virtual void push_triples(const vector<array<T, 3>>&)
|
||||
{ throw runtime_error("no pushing"); }
|
||||
@@ -145,10 +154,17 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
vector<BufferOwner<T, T>> input_buffers;
|
||||
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
|
||||
map<DataTag, BufferOwner<T, T> > extended;
|
||||
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
|
||||
map<int, ifstream*> edabit_buffers;
|
||||
|
||||
int my_num,num_players;
|
||||
|
||||
const string prep_data_dir;
|
||||
int thread_num;
|
||||
|
||||
Sub_Data_Files<typename T::part_type>* part;
|
||||
|
||||
void buffer_edabits_with_queues(bool stric, int n_bits);
|
||||
|
||||
public:
|
||||
static string get_suffix(int thread_num);
|
||||
@@ -202,6 +218,9 @@ public:
|
||||
|
||||
void setup_extended(const DataTag& tag, int tuple_size = 0);
|
||||
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get_dabit_no_count(T& a, typename T::bit_type& b);
|
||||
|
||||
Preprocessing<typename T::part_type>& get_part();
|
||||
};
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Processor/Processor.h"
|
||||
#include "Protocols/dabit.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
template<class T>
|
||||
Lock Sub_Data_Files<T>::tuple_lengths_lock;
|
||||
@@ -54,7 +56,8 @@ template<class T>
|
||||
Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
const string& prep_data_dir, DataPositions& usage, int thread_num) :
|
||||
Preprocessing<T>(usage),
|
||||
my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir)
|
||||
my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir),
|
||||
thread_num(thread_num), part(0)
|
||||
{
|
||||
#ifdef DEBUG_FILES
|
||||
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
|
||||
@@ -72,6 +75,11 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
}
|
||||
}
|
||||
|
||||
sprintf(filename, (prep_data_dir + "%s-%s-P%d%s").c_str(),
|
||||
DataPositions::dtype_names[DATA_DABIT], (T::type_short()).c_str(), my_num,
|
||||
suffix.c_str());
|
||||
dabit_buffer.setup(filename, 1, DataPositions::dtype_names[DATA_DABIT]);
|
||||
|
||||
input_buffers.resize(num_players);
|
||||
for (int i=0; i<num_players; i++)
|
||||
{
|
||||
@@ -127,6 +135,14 @@ Sub_Data_Files<T>::~Sub_Data_Files()
|
||||
for (auto it =
|
||||
extended.begin(); it != extended.end(); it++)
|
||||
it->second.close();
|
||||
dabit_buffer.close();
|
||||
for (auto& x: edabit_buffers)
|
||||
{
|
||||
x.second->close();
|
||||
delete x.second;
|
||||
}
|
||||
if (part != 0)
|
||||
delete part;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -236,4 +252,48 @@ void Sub_Data_Files<T>::get_no_count(vector<T>& S, DataTag tag, const vector<int
|
||||
extended[tag].input(S[regs[i] + j]);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
|
||||
{
|
||||
dabit<T> tmp;
|
||||
dabit_buffer.input(tmp);
|
||||
a = tmp.first;
|
||||
b = tmp.second;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits)
|
||||
{
|
||||
#ifndef INSECURE
|
||||
throw runtime_error("no secure implementation of reading edaBits from files");
|
||||
#endif
|
||||
if (edabit_buffers.find(n_bits) == edabit_buffers.end())
|
||||
{
|
||||
string filename = prep_data_dir + "edaBits-" + to_string(n_bits) + "-P"
|
||||
+ to_string(my_num);
|
||||
ifstream* f = new ifstream(filename);
|
||||
if (f->fail())
|
||||
throw runtime_error("cannot open " + filename);
|
||||
edabit_buffers[n_bits] = f;
|
||||
}
|
||||
auto& buffer = *edabit_buffers[n_bits];
|
||||
if (buffer.peek() == EOF)
|
||||
buffer.seekg(0);
|
||||
edabitvec<T> eb;
|
||||
eb.input(n_bits, buffer);
|
||||
this->edabits[{strict, n_bits}].push_back(eb);
|
||||
if (buffer.fail())
|
||||
throw runtime_error("error reading edaBits");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Preprocessing<typename T::part_type>& Sub_Data_Files<T>::get_part()
|
||||
{
|
||||
if (part == 0)
|
||||
part = new Sub_Data_Files<typename T::part_type>(my_num, num_players,
|
||||
get_prep_sub_dir<typename T::part_type>(num_players), this->usage,
|
||||
thread_num);
|
||||
return *part;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -45,6 +45,9 @@ public:
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
void CheckFor(const typename T::open_type&, const vector<T>&, const Player&)
|
||||
{
|
||||
}
|
||||
|
||||
DummyMC<typename T::part_type>& get_part_MC()
|
||||
{
|
||||
@@ -56,6 +59,11 @@ public:
|
||||
throw not_implemented();
|
||||
return {};
|
||||
}
|
||||
|
||||
int number()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -63,12 +71,17 @@ class DummyProtocol : public ProtocolBase<T>
|
||||
{
|
||||
public:
|
||||
Player& P;
|
||||
int counter;
|
||||
|
||||
static int get_n_relevant_players()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
static void multiply(vector<T>, vector<pair<T, T>>, int, int, SubProcessor<T>)
|
||||
{
|
||||
}
|
||||
|
||||
DummyProtocol(Player& P) :
|
||||
P(P)
|
||||
{
|
||||
@@ -91,6 +104,9 @@ public:
|
||||
throw not_implemented();
|
||||
return {};
|
||||
}
|
||||
void check()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -170,6 +186,10 @@ public:
|
||||
{
|
||||
(void) proc, (void) MC;
|
||||
}
|
||||
template<class T, class U, class W>
|
||||
NotImplementedInput(const T&, const U&, const W&)
|
||||
{
|
||||
}
|
||||
NotImplementedInput(Player& P)
|
||||
{
|
||||
(void) P;
|
||||
@@ -200,6 +220,12 @@ public:
|
||||
(void) proc, (void) regs;
|
||||
throw not_implemented();
|
||||
}
|
||||
static void input_mixed(SubProcessor<V>, vector<int>, int, int)
|
||||
{
|
||||
}
|
||||
static void raw_input(SubProcessor<V>, vector<int>, int)
|
||||
{
|
||||
}
|
||||
void reset_all(Player& P)
|
||||
{
|
||||
(void) P;
|
||||
@@ -248,7 +274,7 @@ public:
|
||||
(void) player, (void) target, (void) source;
|
||||
throw not_implemented();
|
||||
}
|
||||
void stop(int player, int source)
|
||||
void stop(int player, int source, int)
|
||||
{
|
||||
(void) player, (void) source;
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ class InputBase
|
||||
{
|
||||
typedef typename T::clear clear;
|
||||
|
||||
protected:
|
||||
Player* P;
|
||||
|
||||
protected:
|
||||
Buffer<typename T::clear, typename T::clear> buffer;
|
||||
Timer timer;
|
||||
|
||||
@@ -42,6 +42,7 @@ public:
|
||||
static void finalize(SubProcessor<T>& Proc, int player, const int* params, int size);
|
||||
|
||||
InputBase(ArithmeticProcessor* proc = 0);
|
||||
InputBase(SubProcessor<T>* proc);
|
||||
virtual ~InputBase();
|
||||
|
||||
virtual void reset(int player) = 0;
|
||||
@@ -56,7 +57,7 @@ public:
|
||||
|
||||
virtual T finalize_mine() = 0;
|
||||
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
|
||||
T finalize(int player, int n_bits = -1);
|
||||
virtual T finalize(int player, int n_bits = -1);
|
||||
|
||||
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);
|
||||
};
|
||||
|
||||
@@ -23,6 +23,12 @@ InputBase<T>::InputBase(ArithmeticProcessor* proc) :
|
||||
buffer.setup(&proc->private_input, -1, proc->private_input_filename);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
InputBase<T>::InputBase(SubProcessor<T>* proc) :
|
||||
InputBase(proc ? proc->Proc : 0)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Input<T>::Input(SubProcessor<T>& proc) :
|
||||
Input(proc, proc.MC)
|
||||
@@ -92,7 +98,7 @@ void Input<T>::add_mine(const open_type& input, int n_bits)
|
||||
prep.get_input(share, rr, player);
|
||||
t = input - rr;
|
||||
t.pack(this->os[player]);
|
||||
share += T::constant(t, 0, MC.get_alphai());
|
||||
share += T::constant(t, player, MC.get_alphai());
|
||||
this->values_input++;
|
||||
}
|
||||
|
||||
@@ -190,7 +196,7 @@ void Input<T>::finalize_other(int player, T& target,
|
||||
(void) n_bits;
|
||||
target = shares[player].next();
|
||||
t.unpack(o);
|
||||
target += T::constant(t, 1, MC.get_alphai());
|
||||
target += T::constant(t, P.my_num(), MC.get_alphai());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -330,6 +330,7 @@ struct TempVars {
|
||||
class BaseInstruction
|
||||
{
|
||||
friend class Program;
|
||||
template<class T> friend class RepRingOnlyEdabitPrep;
|
||||
|
||||
protected:
|
||||
int opcode; // The code
|
||||
|
||||
@@ -1217,7 +1217,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]));
|
||||
break;
|
||||
case RANDOMS:
|
||||
Procp.protocol.randoms_inst(Procp, *this);
|
||||
Procp.protocol.randoms_inst(Procp.get_S(), *this);
|
||||
return;
|
||||
case INPUTMASKREG:
|
||||
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2]));
|
||||
@@ -1314,14 +1314,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
break;
|
||||
case SHLC:
|
||||
to_bigint(Proc.temp.aa,Proc.read_Cp(r[2]));
|
||||
if (Proc.temp.aa > 63)
|
||||
throw runtime_error("too much left shift");
|
||||
Proc.get_Cp_ref(r[0]).SHL(Proc.read_Cp(r[1]),Proc.temp.aa);
|
||||
break;
|
||||
case SHRC:
|
||||
to_bigint(Proc.temp.aa,Proc.read_Cp(r[2]));
|
||||
if (Proc.temp.aa > 63)
|
||||
throw runtime_error("too much right shift");
|
||||
Proc.get_Cp_ref(r[0]).SHR(Proc.read_Cp(r[1]),Proc.temp.aa);
|
||||
break;
|
||||
case SHLCI:
|
||||
@@ -1337,8 +1333,8 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n);
|
||||
break;
|
||||
case SHRSI:
|
||||
Proc.get_Sp_ref(r[0]) = Proc.read_Sp(r[1]) >> n;
|
||||
break;
|
||||
sint::shrsi(Procp, *this);
|
||||
return;
|
||||
case GBITDEC:
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
|
||||
@@ -39,35 +39,40 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
sint::clear::read_or_generate_setup(prep_dir_prefix<sint>(), opts);
|
||||
sint::bit_type::mac_key_type::init_field();
|
||||
|
||||
// Initialize gf2n_short for CCD
|
||||
sint::bit_type::part_type::open_type::init_field();
|
||||
|
||||
// make directory for outputs if necessary
|
||||
mkdir_p(PREP_DIR);
|
||||
|
||||
Player* P;
|
||||
if (use_encryption)
|
||||
P = new CryptoPlayer(N, 0xF00);
|
||||
else
|
||||
P = new PlainPlayer(N, 0xF00);
|
||||
|
||||
if (opts.live_prep)
|
||||
{
|
||||
auto P = new PlainPlayer(N, 0xF00);
|
||||
sint::LivePrep::basic_setup(*P);
|
||||
delete P;
|
||||
}
|
||||
|
||||
sint::read_or_generate_mac_key(prep_dir_prefix<sint>(), N, alphapi);
|
||||
sgf2n::read_or_generate_mac_key(prep_dir_prefix<sgf2n>(), N, alpha2i);
|
||||
sint::read_or_generate_mac_key(prep_dir_prefix<sint>(), *P, alphapi);
|
||||
sgf2n::read_or_generate_mac_key(prep_dir_prefix<sgf2n>(), *P, alpha2i);
|
||||
sint::bit_type::part_type::read_or_generate_mac_key(
|
||||
prep_dir_prefix<typename sint::bit_type>(), N, alphabi);
|
||||
prep_dir_prefix<typename sint::bit_type>(), *P, alphabi);
|
||||
|
||||
#ifdef DEBUG_MAC
|
||||
cerr << "MAC Key p = " << alphapi << endl;
|
||||
cerr << "MAC Key 2 = " << alpha2i << endl;
|
||||
#endif
|
||||
|
||||
// deactivate output if necessary
|
||||
sint::bit_type::out.activate(my_number == 0 or opts.interactive);
|
||||
|
||||
// for OT-based preprocessing
|
||||
sint::clear::next::template init<typename sint::clear>(false);
|
||||
|
||||
// Initialize the global memory
|
||||
if (memtype.compare("old")==0)
|
||||
{
|
||||
ifstream inpf;
|
||||
inpf.open(memory_filename(), ios::in | ios::binary);
|
||||
if (inpf.fail()) { throw file_error(memory_filename()); }
|
||||
inpf >> M2 >> Mp >> Mi;
|
||||
@@ -90,16 +95,12 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
if (live_prep
|
||||
and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot))
|
||||
{
|
||||
Player* P;
|
||||
if (use_encryption)
|
||||
P = new CryptoPlayer(playerNames, 0xF000);
|
||||
else
|
||||
P = new PlainPlayer(playerNames, 0xF000);
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
ot_setups.push_back({ *P, true });
|
||||
delete P;
|
||||
}
|
||||
|
||||
delete P;
|
||||
|
||||
/* Set up the threads */
|
||||
tinfo.resize(nthreads);
|
||||
threads.resize(nthreads);
|
||||
@@ -190,7 +191,7 @@ void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
|
||||
}
|
||||
catch (bad_cast& e)
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
#ifdef VERBOSE_CENTRAL
|
||||
cerr << "Problem with central preprocessing" << endl;
|
||||
#endif
|
||||
}
|
||||
@@ -210,7 +211,7 @@ void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
|
||||
}
|
||||
catch (bad_cast& e)
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
#ifdef VERBOSE_CENTRAL
|
||||
cerr << "Problem with central bit triple preprocessing: " << e.what() << endl;
|
||||
#endif
|
||||
}
|
||||
@@ -231,12 +232,14 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
//printf("Running line %d\n",exec);
|
||||
if (progs[tape_number].usage_unknown())
|
||||
{
|
||||
#ifndef INSECURE
|
||||
if (not opts.live_prep)
|
||||
{
|
||||
cerr << "Internally called tape " << tape_number <<
|
||||
" has unknown offline data usage" << endl;
|
||||
throw invalid_program();
|
||||
}
|
||||
#endif
|
||||
return DataPositions(N.num_players());
|
||||
}
|
||||
else
|
||||
@@ -263,9 +266,6 @@ void Machine<sint, sgf2n>::run()
|
||||
proc_timer.start();
|
||||
timer[0].start();
|
||||
|
||||
// legacy
|
||||
int _;
|
||||
inpf >> _ >> _ >> _;
|
||||
// run main tape
|
||||
pos.increase(run_tape(0, 0, 0));
|
||||
join_tape(0);
|
||||
|
||||
@@ -16,6 +16,13 @@ template<class T>
|
||||
class NoLivePrep : public Sub_Data_Files<T>
|
||||
{
|
||||
public:
|
||||
static void basic_setup(Player&)
|
||||
{
|
||||
}
|
||||
static void teardown()
|
||||
{
|
||||
}
|
||||
|
||||
NoLivePrep(SubProcessor<T>* proc, DataPositions& usage) : Sub_Data_Files<T>(0, 0, "", usage, 0)
|
||||
{
|
||||
(void) proc;
|
||||
|
||||
42
Processor/NoProtocol.h
Normal file
42
Processor/NoProtocol.h
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* NoProtocol.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_NOPROTOCOL_H_
|
||||
#define PROCESSOR_NOPROTOCOL_H_
|
||||
|
||||
#include "Protocols/Replicated.h"
|
||||
|
||||
template<class T>
|
||||
class NoProtocol : public ProtocolBase<T>
|
||||
{
|
||||
public:
|
||||
NoProtocol(Player&)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void init_mul(SubProcessor<T>*)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
typename T::clear prepare_mul(const T&, const T&, int n = -1)
|
||||
{
|
||||
(void) n;
|
||||
throw not_implemented();
|
||||
}
|
||||
void exchange()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
T finalize_mul(int n = -1)
|
||||
{
|
||||
(void) n;
|
||||
throw not_implemented();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
#endif /* PROCESSOR_NOPROTOCOL_H_ */
|
||||
@@ -240,19 +240,6 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
job.pos.increase(Proc.DataF.get_usage());
|
||||
}
|
||||
|
||||
//double elapsed = timeval_diff(&startv, &endv);
|
||||
//printf("Thread time = %f seconds\n",elapsed/1000000);
|
||||
//printf("\texec = %d\n",exec); exec++;
|
||||
//printf("\tMC2.number = %d\n",MC2.number());
|
||||
//printf("\tMCp.number = %d\n",MCp.number());
|
||||
|
||||
// MACCheck
|
||||
MC2->Check(P);
|
||||
MCp->Check(P);
|
||||
//printf("\tMAC checked\n");
|
||||
P.Check_Broadcast();
|
||||
//printf("\tBroadcast checked\n");
|
||||
|
||||
#ifdef DEBUG_THREADS
|
||||
printf("\tSignalling I have finished\n");
|
||||
#endif
|
||||
@@ -269,6 +256,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
// MACCheck
|
||||
MC2->Check(P);
|
||||
MCp->Check(P);
|
||||
Proc.share_thread.MC->Check(P);
|
||||
|
||||
//cout << num << " : Checking broadcast" << endl;
|
||||
P.Check_Broadcast();
|
||||
|
||||
@@ -17,6 +17,7 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
live_prep = true;
|
||||
batch_size = 10000;
|
||||
memtype = "empty";
|
||||
bits_from_squares = false;
|
||||
direct = false;
|
||||
bucket_size = 3;
|
||||
cmd_private_input_file = "Player-Data/Input";
|
||||
@@ -138,6 +139,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-m", // Flag token.
|
||||
"--memory" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Compute random bits from squares", // Help description.
|
||||
"-Q", // Flag token.
|
||||
"--bits-from-squares" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
@@ -174,6 +184,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
live_prep = opt.get("-L")->isSet;
|
||||
opt.get("-b")->getInt(batch_size);
|
||||
opt.get("--memory")->getString(memtype);
|
||||
bits_from_squares = opt.isSet("-Q");
|
||||
|
||||
opt.get("-IF")->getString(cmd_private_input_file);
|
||||
opt.get("-OF")->getString(cmd_private_output_file);
|
||||
|
||||
@@ -22,6 +22,7 @@ public:
|
||||
std::string progname;
|
||||
int batch_size;
|
||||
std::string memtype;
|
||||
bool bits_from_squares;
|
||||
bool direct;
|
||||
int bucket_size;
|
||||
std::string cmd_private_input_file;
|
||||
|
||||
@@ -52,6 +52,11 @@ SubProcessor<T>::~SubProcessor()
|
||||
if (bit_prep.data_sent())
|
||||
cerr << "Sent for global bit preprocessing threads: " <<
|
||||
bit_prep.data_sent() * 1e-6 << " MB" << endl;
|
||||
if (not bit_usage.empty())
|
||||
{
|
||||
cerr << "Mixed-circuit preprocessing cost:" << endl;
|
||||
bit_usage.print_cost();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -82,13 +87,16 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
secure_prng.ReSeed();
|
||||
shared_prng.SeedGlobally(P);
|
||||
|
||||
out.activate(P.my_num() == 0 or machine.opts.interactive);
|
||||
// only output on party 0 if not interactive
|
||||
bool output = P.my_num() == 0 or machine.opts.interactive;
|
||||
out.activate(output);
|
||||
Procb.out.activate(output);
|
||||
setup_redirection(P.my_num(), thread_num, opts);
|
||||
|
||||
if (!machine.opts.cmd_private_output_file.empty())
|
||||
if (stdout_redirect_file.is_open())
|
||||
{
|
||||
const string stdout_filename = get_parameterized_filename(P.my_num(), thread_num, opts.cmd_private_output_file);
|
||||
stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out);
|
||||
out.redirect_to_file(stdout_redirect_file);
|
||||
Procb.out.redirect_to_file(stdout_redirect_file);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
17
Processor/ProcessorBase.cpp
Normal file
17
Processor/ProcessorBase.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
/*
|
||||
* ProcessorBase.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "ProcessorBase.hpp"
|
||||
|
||||
void ProcessorBase::setup_redirection(int my_num, int thread_num,
|
||||
OnlineOptions& opts)
|
||||
{
|
||||
if (not opts.cmd_private_output_file.empty())
|
||||
{
|
||||
const string stdout_filename = get_parameterized_filename(my_num,
|
||||
thread_num, opts.cmd_private_output_file);
|
||||
stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out);
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/ExecutionStats.h"
|
||||
#include "OnlineOptions.h"
|
||||
|
||||
class ProcessorBase
|
||||
{
|
||||
@@ -30,6 +31,8 @@ protected:
|
||||
public:
|
||||
ExecutionStats stats;
|
||||
|
||||
ofstream stdout_redirect_file;
|
||||
|
||||
void pushi(long x) { stacki.push(x); }
|
||||
void popi(long& x) { x = stacki.top(); stacki.pop(); }
|
||||
|
||||
@@ -50,6 +53,8 @@ public:
|
||||
T get_input(bool interactive, const int* params);
|
||||
template<class T>
|
||||
T get_input(istream& is, const string& input_filename, const int* params);
|
||||
|
||||
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PROCESSORBASE_H_ */
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
*/
|
||||
|
||||
#include "RingOptions.h"
|
||||
#include "BaseMachine.h"
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv)
|
||||
RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
bool security)
|
||||
{
|
||||
opt.add(
|
||||
"64", // Default.
|
||||
@@ -19,8 +21,37 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv)
|
||||
"-R", // Flag token.
|
||||
"--ring" // Flag token.
|
||||
);
|
||||
if (security)
|
||||
opt.add(
|
||||
"40", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Security parameter (default: 40)", // Help description.
|
||||
"-S", // Flag token.
|
||||
"--security" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
opt.get("-R")->getInt(R);
|
||||
if (security)
|
||||
opt.get("-S")->getInt(S);
|
||||
else
|
||||
S = -1;
|
||||
R_is_set = opt.isSet("-R");
|
||||
opt.resetArgs();
|
||||
cerr << "Trying to run " << R << "-bit computation" << endl;
|
||||
if (R_is_set)
|
||||
cerr << "Trying to run " << R << "-bit computation" << endl;
|
||||
if (security)
|
||||
cerr << "Using security parameter " << S << endl;
|
||||
}
|
||||
|
||||
int RingOptions::ring_size_from_opts_or_schedule(string progname)
|
||||
{
|
||||
if (R_is_set)
|
||||
return R;
|
||||
int r = BaseMachine::ring_size_from_schedule(progname);
|
||||
if (r == 0)
|
||||
r = R;
|
||||
cerr << "Trying to run " << r << "-bit computation" << endl;
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -7,13 +7,21 @@
|
||||
#define PROCESSOR_RINGOPTIONS_H_
|
||||
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include <string>
|
||||
using namespace std;
|
||||
|
||||
class RingOptions
|
||||
{
|
||||
bool R_is_set;
|
||||
|
||||
public:
|
||||
int R;
|
||||
int S;
|
||||
|
||||
RingOptions(ez::ezOptionParser& opt, int argc, const char** argv);
|
||||
RingOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
bool security = false);
|
||||
|
||||
int ring_size_from_opts_or_schedule(string progname);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_RINGOPTIONS_H_ */
|
||||
|
||||
76
Processor/TruncPrTuple.h
Normal file
76
Processor/TruncPrTuple.h
Normal file
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
* TruncPrTuple.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_TRUNCPRTUPLE_H_
|
||||
#define PROCESSOR_TRUNCPRTUPLE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <assert.h>
|
||||
using namespace std;
|
||||
|
||||
template<class T>
|
||||
class TruncPrTuple
|
||||
{
|
||||
public:
|
||||
int dest_base;
|
||||
int source_base;
|
||||
int k;
|
||||
int m;
|
||||
int n_shift;
|
||||
|
||||
TruncPrTuple(const vector<int>& regs, size_t base)
|
||||
{
|
||||
dest_base = regs[base];
|
||||
source_base = regs[base + 1];
|
||||
k = regs[base + 2];
|
||||
m = regs[base + 3];
|
||||
n_shift = T::N_BITS - 1 - k;
|
||||
assert(m < k);
|
||||
assert(0 < k);
|
||||
assert(m < T::N_BITS);
|
||||
}
|
||||
|
||||
T upper(T mask)
|
||||
{
|
||||
return (mask << (n_shift + 1)) >> (n_shift + m + 1);
|
||||
}
|
||||
|
||||
T msb(T mask)
|
||||
{
|
||||
return (mask << (n_shift)) >> (T::N_BITS - 1);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class TruncPrTupleWithGap : public TruncPrTuple<T>
|
||||
{
|
||||
public:
|
||||
TruncPrTupleWithGap(const vector<int>& regs, size_t base) :
|
||||
TruncPrTuple<T>(regs, base)
|
||||
{
|
||||
}
|
||||
|
||||
T upper(T mask)
|
||||
{
|
||||
if (big_gap())
|
||||
return mask >> this->m;
|
||||
else
|
||||
return TruncPrTuple<T>::upper(mask);
|
||||
}
|
||||
|
||||
T msb(T mask)
|
||||
{
|
||||
assert(not big_gap());
|
||||
return TruncPrTuple<T>::msb(mask);
|
||||
}
|
||||
|
||||
bool big_gap()
|
||||
{
|
||||
return this->k <= T::N_BITS - 40;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_TRUNCPRTUPLE_H_ */
|
||||
@@ -55,6 +55,9 @@
|
||||
X(MULCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
typename sint::clear op2 = int(n), \
|
||||
*dest++ = *op1++ * op2) \
|
||||
X(MULSI, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
typename sint::clear op2 = int(n), \
|
||||
*dest++ = *op1++ * op2) \
|
||||
X(SHRCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]], \
|
||||
*dest++ = *op1++ >> n) \
|
||||
X(TRIPLE, auto a = &Procp.get_S()[r[0]]; auto b = &Procp.get_S()[r[1]]; \
|
||||
|
||||
@@ -21,7 +21,7 @@ def match(db_entry, sample):
|
||||
from Compiler import util
|
||||
|
||||
if n_threads is None:
|
||||
util.tree_reduce(lambda x, y: x.min(y), (match(db[i], sample) for i in range(n)))
|
||||
res = util.tree_reduce(lambda x, y: x.min(y), (match(db[i], sample) for i in range(n)))
|
||||
else:
|
||||
tmp = sint.Array(n_threads)
|
||||
|
||||
@@ -31,4 +31,6 @@ else:
|
||||
(match(db[base + i], sample)
|
||||
for i in range(size)))
|
||||
|
||||
util.tree_reduce(lambda x, y: x.min(y), tmp)
|
||||
res = util.tree_reduce(lambda x, y: x.min(y), tmp)
|
||||
|
||||
print_ln('result: %s', res.reveal())
|
||||
|
||||
69
Programs/Source/mnist_49.mpc
Normal file
69
Programs/Source/mnist_49.mpc
Normal file
@@ -0,0 +1,69 @@
|
||||
import ml
|
||||
import math
|
||||
import re
|
||||
import util
|
||||
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program)
|
||||
|
||||
n_examples = 11791
|
||||
n_test = 1991
|
||||
n_features = 28 ** 2
|
||||
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except:
|
||||
n_epochs = 100
|
||||
|
||||
N = n_examples
|
||||
batch_size = 128
|
||||
|
||||
assert batch_size <= N
|
||||
|
||||
try:
|
||||
ml.set_n_threads(int(program.args[2]))
|
||||
except:
|
||||
pass
|
||||
|
||||
n_inner = 128
|
||||
|
||||
n_dense_layers = None
|
||||
for arg in program.args:
|
||||
m = re.match('(.*)dense', arg)
|
||||
if m:
|
||||
n_dense_layers = int(m.group(1))
|
||||
|
||||
if n_dense_layers == 1:
|
||||
layers = [ml.Dense(N, n_features, 1, activation='id')]
|
||||
elif n_dense_layers > 1:
|
||||
layers = [ml.Dense(N, n_features, n_inner, activation='relu')]
|
||||
for i in range(n_dense_layers - 2):
|
||||
layers += [ml.Dense(N, n_inner, n_inner, activation='relu')]
|
||||
layers += [ml.Dense(N, n_inner, 1, activation='id')]
|
||||
else:
|
||||
raise CompilerError('number of dense layers not specified')
|
||||
|
||||
layers += [ml.Output.from_args(N, program)]
|
||||
|
||||
Y = sint.Array(n_test)
|
||||
X = sfix.Matrix(n_test, n_features)
|
||||
|
||||
if not ('no_acc' in program.args and 'no_loss' in program.args):
|
||||
layers[-1].Y.input_from(0)
|
||||
layers[0].X.input_from(0)
|
||||
Y.input_from(0)
|
||||
X.input_from(0)
|
||||
|
||||
sgd = ml.SGD(layers, 1)
|
||||
|
||||
if 'no_out' in program.args:
|
||||
del sgd.layers[-1]
|
||||
|
||||
if 'forward' in program.args:
|
||||
sgd.forward(batch=regint.Array(batch_size))
|
||||
elif 'backward' in program.args:
|
||||
sgd.backward(batch=regint.Array(batch_size))
|
||||
elif 'update' in program.args:
|
||||
sgd.update(0, batch=regint.Array(batch_size))
|
||||
else:
|
||||
sgd.run_by_args(program, n_epochs, batch_size, X, Y)
|
||||
99
Programs/Source/mnist_A.mpc
Normal file
99
Programs/Source/mnist_A.mpc
Normal file
@@ -0,0 +1,99 @@
|
||||
import ml
|
||||
import math
|
||||
|
||||
#ml.report_progress = True
|
||||
|
||||
program.options_from_args()
|
||||
|
||||
approx = 3
|
||||
|
||||
if 'profile' in program.args:
|
||||
print('Compiling for profiling')
|
||||
N = 1000
|
||||
n_test = 100
|
||||
elif 'debug' in program.args:
|
||||
N = 10
|
||||
n_test = 10
|
||||
elif 'gisette' in program.args:
|
||||
print('Compiling for 4/9')
|
||||
N = 11791
|
||||
n_test = 1991
|
||||
else:
|
||||
N = 12665
|
||||
n_test = 2115
|
||||
|
||||
n_examples = N
|
||||
n_features = 28 ** 2
|
||||
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except:
|
||||
n_epochs = 100
|
||||
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except:
|
||||
batch_size = N
|
||||
|
||||
assert batch_size <= N
|
||||
|
||||
try:
|
||||
ml.set_n_threads(int(program.args[3]))
|
||||
except:
|
||||
pass
|
||||
|
||||
if 'debug' in program.args:
|
||||
n_inner = 10
|
||||
n_features = 10
|
||||
else:
|
||||
n_inner = 128
|
||||
|
||||
if 'norelu' in program.args:
|
||||
activation = 'id'
|
||||
else:
|
||||
activation = 'relu'
|
||||
|
||||
layers = [ml.Dense(N, n_features, n_inner, activation=activation),
|
||||
ml.Dense(N, n_inner, n_inner, activation=activation),
|
||||
ml.Dense(N, n_inner, 1),
|
||||
ml.Output(N, approx=approx)]
|
||||
|
||||
if '2dense' in program.args:
|
||||
del layers[1]
|
||||
|
||||
layers[-1].Y.input_from(0)
|
||||
layers[0].X.input_from(0)
|
||||
|
||||
Y = sint.Array(n_test)
|
||||
X = sfix.Matrix(n_test, n_features)
|
||||
Y.input_from(0)
|
||||
X.input_from(0)
|
||||
|
||||
sgd = ml.SGD(layers, 10, report_loss=True)
|
||||
sgd.reset()
|
||||
|
||||
@for_range(int(math.ceil(n_epochs / 10)))
|
||||
def _(i):
|
||||
start_timer(1)
|
||||
sgd.run(batch_size)
|
||||
stop_timer(1)
|
||||
|
||||
def get_correct(Y, n):
|
||||
n_correct = regint(0)
|
||||
for i in range(n):
|
||||
n_correct += (Y[i].reveal() > 0).bit_xor(
|
||||
layers[-2].Y[i][0][0][0].reveal() < 0)
|
||||
return n_correct
|
||||
|
||||
sgd.forward(N)
|
||||
|
||||
n_correct = get_correct(layers[-1].Y, N)
|
||||
print_ln('train_acc: %s (%s/%s)', cfix(n_correct) / N, n_correct, N)
|
||||
|
||||
training_address = layers[0].X.address
|
||||
layers[0].X.address = X.address
|
||||
sgd.forward(n_test)
|
||||
layers[0].X.address = training_address
|
||||
|
||||
n_correct = get_correct(Y, n_test)
|
||||
print_ln('acc: %s (%s/%s)', cfix(n_correct) / n_test, n_correct, n_test)
|
||||
109
Programs/Source/mnist_full_A.mpc
Normal file
109
Programs/Source/mnist_full_A.mpc
Normal file
@@ -0,0 +1,109 @@
|
||||
import ml
|
||||
import math
|
||||
import re
|
||||
import util
|
||||
|
||||
#ml.report_progress = True
|
||||
|
||||
program.options_from_args()
|
||||
|
||||
if 'profile' in program.args:
|
||||
print('Compiling for profiling')
|
||||
N = 1000
|
||||
n_test = 100
|
||||
elif 'debug' in program.args:
|
||||
N = 100
|
||||
n_test = 100
|
||||
else:
|
||||
N = 60000
|
||||
n_test = 10000
|
||||
|
||||
n_examples = N
|
||||
n_features = 28 ** 2
|
||||
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except:
|
||||
n_epochs = 100
|
||||
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except:
|
||||
batch_size = N
|
||||
|
||||
assert batch_size <= N
|
||||
|
||||
try:
|
||||
ml.set_n_threads(int(program.args[3]))
|
||||
except:
|
||||
pass
|
||||
|
||||
n_inner = 128
|
||||
|
||||
if 'norelu' in program.args:
|
||||
activation = 'id'
|
||||
else:
|
||||
activation = 'relu'
|
||||
|
||||
if 'nearest' in program.args:
|
||||
sfix.round_nearest = True
|
||||
|
||||
if 'double' in program.args:
|
||||
sfix.set_precision(32, 63)
|
||||
cfix.set_precision(32, 63)
|
||||
elif 'triple' in program.args:
|
||||
sfix.set_precision(48, 91)
|
||||
cfix.set_precision(48, 91)
|
||||
elif 'quadruple' in program.args:
|
||||
sfix.set_precision(64, 127)
|
||||
cfix.set_precision(64, 127)
|
||||
elif 'sextuple' in program.args:
|
||||
sfix.set_precision(96, 191)
|
||||
cfix.set_precision(96, 191)
|
||||
elif 'octuple' in program.args:
|
||||
sfix.set_precision(128, 255)
|
||||
cfix.set_precision(128, 255)
|
||||
|
||||
assert sfix.f * 4 == int(program.options.ring)
|
||||
|
||||
debug_ml = ('debug_ml' in program.args) * 2 ** (sfix.f / 2)
|
||||
|
||||
if '1dense' in program.args:
|
||||
layers = [ml.Dense(N, n_features, 10, debug=debug_ml)]
|
||||
else:
|
||||
layers = [ml.Dense(N, n_features, n_inner, activation=activation, debug=debug_ml),
|
||||
ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml),
|
||||
ml.Dense(N, n_inner, 10, debug=debug_ml)]
|
||||
|
||||
layers += [ml.MultiOutput.from_args(program, N, 10)]
|
||||
|
||||
layers[-1].cheaper_loss = 'mse' in program.args
|
||||
|
||||
if '2dense' in program.args:
|
||||
del layers[1]
|
||||
|
||||
layers[-1].Y.input_from(0)
|
||||
layers[0].X.input_from(0)
|
||||
|
||||
Y = sint.Matrix(n_test, 10)
|
||||
X = sfix.Matrix(n_test, n_features)
|
||||
Y.input_from(0)
|
||||
X.input_from(0)
|
||||
|
||||
if 'always_acc' in program.args:
|
||||
n_part_epochs = 1
|
||||
else:
|
||||
n_part_epochs = 10
|
||||
|
||||
sgd = ml.SGD(layers, n_part_epochs, report_loss=True, debug=debug_ml)
|
||||
#sgd.print_update_average = True
|
||||
sgd.print_losses = 'print_losses' in program.args
|
||||
|
||||
if 'faster' in program.args:
|
||||
sgd.gamma = MemValue(cfix(.1))
|
||||
|
||||
if 'slower' in program.args:
|
||||
sgd.gamma = MemValue(cfix(.001))
|
||||
|
||||
sgd.run_by_args(program, int(math.ceil(n_epochs / n_part_epochs)), batch_size,
|
||||
X, Y)
|
||||
59
Programs/Source/mnist_logreg.mpc
Normal file
59
Programs/Source/mnist_logreg.mpc
Normal file
@@ -0,0 +1,59 @@
|
||||
import ml
|
||||
|
||||
program.options_from_args()
|
||||
|
||||
approx = 3
|
||||
|
||||
if 'gisette' in program.args:
|
||||
print('Compiling for 4/9')
|
||||
N = 11791
|
||||
n_test = 1991
|
||||
else:
|
||||
N = 12665
|
||||
n_test = 2115
|
||||
|
||||
n_examples = N
|
||||
n_features = 28 ** 2
|
||||
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except:
|
||||
n_epochs = 100
|
||||
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except:
|
||||
batch_size = N
|
||||
|
||||
try:
|
||||
ml.set_n_threads(int(program.args[3]))
|
||||
except:
|
||||
pass
|
||||
|
||||
layers = [ml.Dense(N, n_features, 1),
|
||||
ml.Output(N, approx=approx)]
|
||||
|
||||
layers[1].Y.input_from(0)
|
||||
layers[0].X.input_from(0)
|
||||
|
||||
Y = sint.Array(n_test)
|
||||
X = sfix.Matrix(n_test, n_features)
|
||||
Y.input_from(0)
|
||||
X.input_from(0)
|
||||
|
||||
sgd = ml.SGD(layers, n_epochs, report_loss=True)
|
||||
sgd.reset()
|
||||
|
||||
start_timer(1)
|
||||
sgd.run(batch_size)
|
||||
stop_timer(1)
|
||||
|
||||
layers[0].X.assign(X)
|
||||
sgd.forward(n_test)
|
||||
|
||||
n_correct = cfix(0)
|
||||
|
||||
for i in range(n_test):
|
||||
n_correct += Y[i].reveal().bit_xor(layers[0].Y[i][0][0][0].reveal() < 0)
|
||||
|
||||
print_ln('acc: %s (%s/%s)', n_correct / n_test, n_correct, n_test)
|
||||
@@ -37,6 +37,12 @@ if len(program.args) > 2:
|
||||
n_normal = 49
|
||||
n_features = 17814
|
||||
|
||||
if 'mnist' in program.args:
|
||||
print('Compiling for MNIST')
|
||||
n_examples = 2115
|
||||
n_normal = 980
|
||||
n_features = 28 ** 2
|
||||
|
||||
n_pos = n_examples - n_normal
|
||||
n_epochs = 1
|
||||
if len(program.args) > 1:
|
||||
|
||||
@@ -15,6 +15,10 @@ class FakeInput : public InputBase<T>
|
||||
PointerVector<T> results;
|
||||
|
||||
public:
|
||||
FakeInput()
|
||||
{
|
||||
}
|
||||
|
||||
FakeInput(SubProcessor<T>&, typename T::MAC_Check&)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -77,6 +77,12 @@ public:
|
||||
a = bit;
|
||||
b = bit;
|
||||
}
|
||||
|
||||
void get_one_no_count(Dtype dtype, T& a)
|
||||
{
|
||||
assert(dtype == DATA_BIT);
|
||||
a = G.get_uchar() & 1;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_FAKEPREP_H_ */
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define PROTOCOLS_FAKEPROTOCOL_H_
|
||||
|
||||
#include "Replicated.h"
|
||||
#include "Math/Z2k.h"
|
||||
|
||||
template<class T>
|
||||
class FakeProtocol : public ProtocolBase<T>
|
||||
@@ -14,6 +15,10 @@ class FakeProtocol : public ProtocolBase<T>
|
||||
PointerVector<T> results;
|
||||
SeededPRNG G;
|
||||
|
||||
T dot_prod;
|
||||
|
||||
T trunc_max;
|
||||
|
||||
public:
|
||||
Player& P;
|
||||
|
||||
@@ -21,6 +26,27 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
~FakeProtocol()
|
||||
{
|
||||
output_trunc_max<0>(T::invertible);
|
||||
}
|
||||
|
||||
template<int>
|
||||
void output_trunc_max(false_type)
|
||||
{
|
||||
if (trunc_max != T())
|
||||
cerr << "Maximum bit length in truncation: "
|
||||
<< (bigint(typename T::clear(trunc_max)).numBits() + 1)
|
||||
<< " (" << trunc_max << ")" << endl;
|
||||
}
|
||||
|
||||
template<int>
|
||||
void output_trunc_max(true_type)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
void init_mul(SubProcessor<T>*)
|
||||
{
|
||||
results.clear();
|
||||
@@ -41,6 +67,28 @@ public:
|
||||
return results.next();
|
||||
}
|
||||
|
||||
void init_dotprod(SubProcessor<T>* proc)
|
||||
{
|
||||
init_mul(proc);
|
||||
dot_prod = {};
|
||||
}
|
||||
|
||||
void prepare_dotprod(const T& x, const T& y)
|
||||
{
|
||||
dot_prod += x * y;
|
||||
}
|
||||
|
||||
void next_dotprod()
|
||||
{
|
||||
results.push_back(dot_prod);
|
||||
dot_prod = 0;
|
||||
}
|
||||
|
||||
T finalize_dotprod(int)
|
||||
{
|
||||
return finalize_mul();
|
||||
}
|
||||
|
||||
void randoms(T& res, int n_bits)
|
||||
{
|
||||
res.randomize_part(G, n_bits);
|
||||
@@ -52,11 +100,63 @@ public:
|
||||
}
|
||||
|
||||
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc)
|
||||
{
|
||||
trunc_pr<0>(regs, size, proc, T::characteristic_two);
|
||||
}
|
||||
|
||||
template<int>
|
||||
void trunc_pr(const vector<int>&, int, SubProcessor<T>&, true_type)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
template<int>
|
||||
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc, false_type)
|
||||
{
|
||||
for (size_t i = 0; i < regs.size(); i += 4)
|
||||
for (int l = 0; l < size; l++)
|
||||
proc.get_S_ref(regs[i] + l) = proc.get_S_ref(regs[i + 1] + l)
|
||||
>> regs[i + 3];
|
||||
{
|
||||
auto& res = proc.get_S_ref(regs[i] + l);
|
||||
auto& source = proc.get_S_ref(regs[i + 1] + l);
|
||||
T tmp = source - (T(1) << regs[i + 2] - 1);
|
||||
tmp = tmp < T() ? (T() - tmp) : tmp;
|
||||
trunc_max = max(trunc_max, tmp);
|
||||
#ifdef CHECK_BOUNDS_IN_TRUNC_PR_EMULATION
|
||||
auto test = (source >> (regs[i + 2]));
|
||||
if (test != 0)
|
||||
{
|
||||
cerr << typename T::clear(source) << " has more than "
|
||||
<< regs[i + 2]
|
||||
<< " bits in " << regs[i + 3]
|
||||
<< "-bit truncation (test value "
|
||||
<< typename T::clear(test) << ")" << endl;
|
||||
throw runtime_error("trunc_pr overflow");
|
||||
}
|
||||
#endif
|
||||
int n_shift = regs[i + 3];
|
||||
#ifdef ROUND_NEAREST_IN_EMULATION
|
||||
res = source >> n_shift;
|
||||
if (n_shift > 0)
|
||||
{
|
||||
bool overflow = T(source >> (n_shift - 1)).get_bit(0);
|
||||
res += overflow;
|
||||
}
|
||||
#else
|
||||
#ifdef RISKY_TRUNCATION_IN_EMULATION
|
||||
T r;
|
||||
r.randomize(G);
|
||||
|
||||
if (source.negative())
|
||||
res = -T(((-source + r) >> n_shift) - (r >> n_shift));
|
||||
else
|
||||
res = ((source + r) >> n_shift) - (r >> n_shift);
|
||||
#else
|
||||
T r;
|
||||
r.randomize_part(G, n_shift - 1);
|
||||
res = (source + r) >> n_shift;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ protected:
|
||||
public:
|
||||
int values_opened;
|
||||
|
||||
MAC_Check_Base() : values_opened(0) {}
|
||||
MAC_Check_Base(const typename T::mac_key_type::Scalar& mac_key = { }) :
|
||||
alphai(mac_key), values_opened(0) {}
|
||||
virtual ~MAC_Check_Base() {}
|
||||
|
||||
virtual void Check(const Player& P) { (void)P; }
|
||||
|
||||
@@ -36,19 +36,12 @@ public:
|
||||
void buffer_bits();
|
||||
};
|
||||
|
||||
// extra class to avoid recursion
|
||||
template<class T>
|
||||
class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
|
||||
public virtual MalRepRingPrep<T>,
|
||||
class SimplerMalRepRingPrep : public virtual MalRepRingPrep<T>,
|
||||
public virtual RingOnlyBitsFromSquaresPrep<T>
|
||||
{
|
||||
public:
|
||||
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
MaliciousRingPrep<T>::set_protocol(protocol);
|
||||
}
|
||||
SimplerMalRepRingPrep(SubProcessor<T>* proc, DataPositions& usage);
|
||||
|
||||
void buffer_triples()
|
||||
{
|
||||
@@ -72,4 +65,27 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
|
||||
public virtual SimplerMalRepRingPrep<T>
|
||||
{
|
||||
public:
|
||||
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
MaliciousRingPrep<T>::set_protocol(protocol);
|
||||
}
|
||||
|
||||
void buffer_squares()
|
||||
{
|
||||
MalRepRingPrep<T>::buffer_squares();
|
||||
}
|
||||
|
||||
void buffer_bits()
|
||||
{
|
||||
RingOnlyBitsFromSquaresPrep<T>::buffer_bits();
|
||||
};
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_MALREPRINGPREP_H_ */
|
||||
|
||||
@@ -27,13 +27,22 @@ RingOnlyBitsFromSquaresPrep<T>::RingOnlyBitsFromSquaresPrep(SubProcessor<T>*,
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
SimplerMalRepRingPrep<T>::SimplerMalRepRingPrep(SubProcessor<T>* proc,
|
||||
DataPositions& usage) :
|
||||
BufferPrep<T>(usage), MalRepRingPrep<T>(proc, usage),
|
||||
RingOnlyBitsFromSquaresPrep<T>(proc, usage)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
MalRepRingPrepWithBits<T>::MalRepRingPrepWithBits(SubProcessor<T>* proc,
|
||||
DataPositions& usage) :
|
||||
BufferPrep<T>(usage), BitPrep<T>(proc, usage),
|
||||
RingPrep<T>(proc, usage),
|
||||
MaliciousRingPrep<T>(proc, usage), MalRepRingPrep<T>(proc, usage),
|
||||
RingOnlyBitsFromSquaresPrep<T>(proc, usage)
|
||||
RingOnlyBitsFromSquaresPrep<T>(proc, usage),
|
||||
SimplerMalRepRingPrep<T>(proc, usage)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -54,6 +63,7 @@ void MalRepRingPrep<T>::buffer_squares()
|
||||
MaliciousRepPrep<prep_type> prep(_);
|
||||
assert(this->proc != 0);
|
||||
prep.init_honest(this->proc->P);
|
||||
prep.buffer_size = this->buffer_size;
|
||||
prep.buffer_squares();
|
||||
for (auto& x : prep.squares)
|
||||
this->squares.push_back({{x[0], x[1]}});
|
||||
@@ -68,6 +78,7 @@ void MalRepRingPrep<T>::simple_buffer_triples()
|
||||
MaliciousRepPrep<prep_type> prep(_);
|
||||
assert(this->proc != 0);
|
||||
prep.init_honest(this->proc->P);
|
||||
prep.buffer_size = this->buffer_size;
|
||||
prep.buffer_triples();
|
||||
for (auto& x : prep.triples)
|
||||
this->triples.push_back({{x[0], x[1], x[2]}});
|
||||
@@ -222,7 +233,7 @@ void RingOnlyBitsFromSquaresPrep<T>::buffer_bits()
|
||||
typename BitShare::SquarePrep prep(0, usage);
|
||||
SubProcessor<BitShare> bit_proc(MC, prep, proc->P);
|
||||
prep.set_proc(&bit_proc);
|
||||
bits_from_square_in_ring(this->bits, OnlineOptions::singleton.batch_size, &prep);
|
||||
bits_from_square_in_ring(this->bits, this->buffer_size, &prep);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
template<class T> class HashMaliciousRepMC;
|
||||
template<class T> class Beaver;
|
||||
template<class T> class MaliciousRepPrepWithBits;
|
||||
template<class T> class MaliciousRepPO;
|
||||
template<class T> class MaliciousRepPrep;
|
||||
|
||||
namespace GC
|
||||
@@ -22,6 +23,7 @@ template<class T>
|
||||
class MaliciousRep3Share : public Rep3Share<T>
|
||||
{
|
||||
typedef Rep3Share<T> super;
|
||||
typedef MaliciousRep3Share This;
|
||||
|
||||
public:
|
||||
typedef Beaver<MaliciousRep3Share<T>> Protocol;
|
||||
@@ -29,11 +31,13 @@ public:
|
||||
typedef MAC_Check Direct_MC;
|
||||
typedef ReplicatedInput<MaliciousRep3Share<T>> Input;
|
||||
typedef ::PrivateOutput<MaliciousRep3Share<T>> PrivateOutput;
|
||||
typedef MaliciousRepPO<MaliciousRep3Share> PO;
|
||||
typedef Rep3Share<T> Honest;
|
||||
typedef MaliciousRepPrepWithBits<MaliciousRep3Share> LivePrep;
|
||||
typedef MaliciousRepPrep<MaliciousRep3Share> TriplePrep;
|
||||
typedef MaliciousRep3Share prep_type;
|
||||
typedef T random_type;
|
||||
typedef This Scalar;
|
||||
|
||||
typedef GC::MaliciousRepSecret bit_type;
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ void HashMaliciousRepMC<T>::Check(const Player& P)
|
||||
P.Broadcast_Receive(os);
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
if (os[i] != os[P.my_num()])
|
||||
throw mac_fail();
|
||||
throw mac_fail("check hash mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
27
Protocols/MaliciousRepPO.h
Normal file
27
Protocols/MaliciousRepPO.h
Normal file
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* MaliciousRepPO.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_MALICIOUSREPPO_H_
|
||||
#define PROTOCOLS_MALICIOUSREPPO_H_
|
||||
|
||||
#include "Networking/Player.h"
|
||||
|
||||
template<class T>
|
||||
class MaliciousRepPO
|
||||
{
|
||||
Player& P;
|
||||
octetStream to_send;
|
||||
octetStream to_receive[2];
|
||||
|
||||
public:
|
||||
MaliciousRepPO(Player& P);
|
||||
|
||||
void prepare_sending(const T& secret, int player);
|
||||
void send(int player);
|
||||
void receive();
|
||||
typename T::clear finalize(const T& secret);
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_MALICIOUSREPPO_H_ */
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user