Convolutional neural network training.

This commit is contained in:
Marcel Keller
2021-07-02 15:49:23 +10:00
parent f35447df51
commit 99c0549e72
208 changed files with 3615 additions and 1865 deletions

View File

@@ -167,7 +167,7 @@ void Node::Broadcast2(SendBuffer& msg) {
} }
void Node::_identify() { void Node::_identify() {
char* msg = id_msg; char msg[strlen(ID_HDR)+sizeof(_id)];
memcpy(msg, ID_HDR, strlen(ID_HDR)); memcpy(msg, ID_HDR, strlen(ID_HDR));
memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id)); memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id));
//printf("Node:: identifying myself:\n"); //printf("Node:: identifying myself:\n");

View File

@@ -78,8 +78,6 @@ private:
std::map<struct sockaddr_in*,int> _clientsmap; std::map<struct sockaddr_in*,int> _clientsmap;
bool* _clients_connected; bool* _clients_connected;
NodeUpdatable* _updatable; NodeUpdatable* _updatable;
char id_msg[strlen(ID_HDR)+sizeof(_id)];
}; };
#endif /* NETWORK_NODE_H_ */ #endif /* NETWORK_NODE_H_ */

View File

@@ -1,5 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.5 (Jul 2, 2021)
- Training of convolutional neural networks
- Bit decomposition using edaBits
- Ability to force MAC checks from high-level code
- Ability to close client connection from high-level code
- Binary operators for comparison results
- Faster compilation for emulation
- More documentation
- Fixed security bug: insufficient LowGear secret key randomness
- Fixed security bug: skewed random bit generation
## 0.2.4 (Apr 19, 2021) ## 0.2.4 (Apr 19, 2021)
- ARM support - ARM support

View File

@@ -117,7 +117,7 @@ class xorm(NonVectorInstruction):
code = opcodes['XORM'] code = opcodes['XORM']
arg_format = ['int','sbw','sb','cb'] arg_format = ['int','sbw','sb','cb']
class xorcb(NonVectorInstruction): class xorcb(BinaryVectorInstruction):
""" Bitwise XOR of two single clear bit registers. """ Bitwise XOR of two single clear bit registers.
:param: result (cbit) :param: result (cbit)
@@ -125,7 +125,7 @@ class xorcb(NonVectorInstruction):
:param: operand (cbit) :param: operand (cbit)
""" """
code = opcodes['XORCB'] code = opcodes['XORCB']
arg_format = ['cbw','cb','cb'] arg_format = ['int','cbw','cb','cb']
class xorcbi(NonVectorInstruction): class xorcbi(NonVectorInstruction):
""" Bitwise XOR of single clear bit register and immediate. """ Bitwise XOR of single clear bit register and immediate.

View File

@@ -36,6 +36,7 @@ class bits(Tape.Register, _structure, _bit):
class bitsn(cls): class bitsn(cls):
n = length n = length
cls.types[length] = bitsn cls.types[length] = bitsn
bitsn.clear_type = cbits.get_type(length)
bitsn.__name__ = cls.__name__ + str(length) bitsn.__name__ = cls.__name__ + str(length)
return cls.types[length] return cls.types[length]
@classmethod @classmethod
@@ -115,7 +116,11 @@ class bits(Tape.Register, _structure, _bit):
return res return res
def store_in_mem(self, address): def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address) self.store_inst[isinstance(address, int)](self, address)
@classmethod
def new(cls, value=None, n=None):
return cls.get_type(n)(value)
def __init__(self, value=None, n=None, size=None): def __init__(self, value=None, n=None, size=None):
assert n == self.n or n is None
if size != 1 and size is not None: if size != 1 and size is not None:
raise Exception('invalid size for bit type: %s' % size) raise Exception('invalid size for bit type: %s' % size)
self.n = n or self.n self.n = n or self.n
@@ -125,7 +130,7 @@ class bits(Tape.Register, _structure, _bit):
if value is not None: if value is not None:
self.load_other(value) self.load_other(value)
def copy(self): def copy(self):
return type(self)(n=instructions_base.get_global_vector_size()) return type(self).new(n=instructions_base.get_global_vector_size())
def set_length(self, n): def set_length(self, n):
if n > self.n: if n > self.n:
raise Exception('too long: %d/%d' % (n, self.n)) raise Exception('too long: %d/%d' % (n, self.n))
@@ -154,6 +159,8 @@ class bits(Tape.Register, _structure, _bit):
bits = other.bit_decompose() bits = other.bit_decompose()
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits)) bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
other = self.bit_compose(bits) other = self.bit_compose(bits)
assert(isinstance(other, type(self)))
assert(other.n == self.n)
self.load_other(other) self.load_other(other)
except: except:
raise CompilerError('cannot convert %s/%s from %s to %s' % \ raise CompilerError('cannot convert %s/%s from %s to %s' % \
@@ -176,6 +183,16 @@ class bits(Tape.Register, _structure, _bit):
res.i = i res.i = i
res.program = self.program res.program = self.program
return res return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
class cbits(bits): class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """ """ Clear bits register. Helper type with limited functionality. """
@@ -202,14 +219,16 @@ class cbits(bits):
inst.stmsdci(self, cbits.conv(address)) inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op): def clear_op(self, other, c_inst, ci_inst, op):
if isinstance(other, cbits): if isinstance(other, cbits):
res = cbits(n=max(self.n, other.n)) res = cbits.get_type(max(self.n, other.n))()
c_inst(res, self, other) c_inst(res, self, other)
return res return res
elif isinstance(other, sbits):
return NotImplemented
else: else:
if util.is_constant(other): if util.is_constant(other):
if other >= 2**31 or other < -2**31: if other >= 2**31 or other < -2**31:
return op(self, cbits(other)) return op(self, cbits(other))
res = cbits(n=max(self.n, len(bin(other)) - 2)) res = cbits.get_type(max(self.n, len(bin(other)) - 2))()
ci_inst(res, self, other) ci_inst(res, self, other)
return res return res
else: else:
@@ -221,8 +240,14 @@ class cbits(bits):
def __xor__(self, other): def __xor__(self, other):
if isinstance(other, (sbits, sbitvec)): if isinstance(other, (sbits, sbitvec)):
return NotImplemented return NotImplemented
elif isinstance(other, cbits):
res = cbits.get_type(max(self.n, other.n))()
assert res.size == self.size
assert res.size == other.size
inst.xorcb(res.n, res, self, other)
return res
else: else:
self.clear_op(other, inst.xorcb, inst.xorcbi, operator.xor) return self.clear_op(other, None, inst.xorcbi, operator.xor)
__radd__ = __add__ __radd__ = __add__
__rxor__ = __xor__ __rxor__ = __xor__
def __mul__(self, other): def __mul__(self, other):
@@ -230,17 +255,18 @@ class cbits(bits):
return NotImplemented return NotImplemented
else: else:
try: try:
res = cbits(n=min(self.max_length, self.n+util.int_len(other))) res = cbits.get_type(min(self.max_length,
self.n+util.int_len(other)))()
inst.mulcbi(res, self, other) inst.mulcbi(res, self, other)
return res return res
except TypeError: except TypeError:
return NotImplemented return NotImplemented
def __rshift__(self, other): def __rshift__(self, other):
res = cbits(n=self.n-other) res = cbits.new(n=self.n-other)
inst.shrcbi(res, self, other) inst.shrcbi(res, self, other)
return res return res
def __lshift__(self, other): def __lshift__(self, other):
res = cbits(n=self.n+other) res = cbits.get_type(self.n+other)()
inst.shlcbi(res, self, other) inst.shlcbi(res, self, other)
return res return res
def print_reg(self, desc=''): def print_reg(self, desc=''):
@@ -504,16 +530,6 @@ class sbits(bits):
res = [cls.new(n=len(rows)) for i in range(n_columns)] res = [cls.new(n=len(rows)) for i in range(n_columns)]
inst.trans(len(res), *(res + rows)) inst.trans(len(res), *(res + rows))
return res return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
@staticmethod @staticmethod
def bit_adder(*args, **kwargs): def bit_adder(*args, **kwargs):
return sbitint.bit_adder(*args, **kwargs) return sbitint.bit_adder(*args, **kwargs)
@@ -610,7 +626,7 @@ class sbitvec(_vec):
elif isinstance(other, (list, tuple)): elif isinstance(other, (list, tuple)):
self.v = self.bit_extend(sbitvec(other).v, n) self.v = self.bit_extend(sbitvec(other).v, n)
else: else:
self.v = sbits(other, n=n).bit_decompose(n) self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n assert len(self.v) == n
@classmethod @classmethod
def load_mem(cls, address): def load_mem(cls, address):
@@ -630,6 +646,8 @@ class sbitvec(_vec):
for i in range(n): for i in range(n):
v[i].store_in_mem(address + i) v[i].store_in_mem(address + i)
def reveal(self): def reveal(self):
if len(self) > cbits.unit:
return self.elements()[0].reveal()
revealed = [cbit() for i in range(len(self))] revealed = [cbit() for i in range(len(self))]
for i in range(len(self)): for i in range(len(self)):
try: try:
@@ -784,15 +802,23 @@ class bit(object):
def result_conv(x, y): def result_conv(x, y):
try: try:
def f(res):
try:
return t.conv(res)
except:
return res
if util.is_constant(x): if util.is_constant(x):
if util.is_constant(y): if util.is_constant(y):
return lambda x: x return lambda x: x
else: else:
return type(y).conv t = type(y)
return f
if util.is_constant(y): if util.is_constant(y):
return type(x).conv t = type(x)
return f
if type(x) is type(y): if type(x) is type(y):
return type(x).conv t = type(x)
return f
except AttributeError: except AttributeError:
pass pass
return lambda x: x return lambda x: x
@@ -807,13 +833,19 @@ class sbit(bit, sbits):
This will output 5. This will output 5.
""" """
return result_conv(x, y)(self * (x ^ y) ^ y) assert self.n == 1
diff = x ^ y
if isinstance(diff, cbits):
return result_conv(x, y)(self & (diff) ^ y)
else:
return result_conv(x, y)(self * (diff) ^ y)
class cbit(bit, cbits): class cbit(bit, cbits):
pass pass
sbits.bit_type = sbit sbits.bit_type = sbit
cbits.bit_type = cbit cbits.bit_type = cbit
sbit.clear_type = cbit
class bitsBlock(oram.Block): class bitsBlock(oram.Block):
value_type = sbits value_type = sbits
@@ -881,7 +913,7 @@ class _sbitintbase:
return self.get_type(k - m).compose(res_bits) return self.get_type(k - m).compose(res_bits)
def int_div(self, other, bit_length=None): def int_div(self, other, bit_length=None):
k = bit_length or max(self.n, other.n) k = bit_length or max(self.n, other.n)
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k) return (library.IntDiv(self.cast(k), other.cast(k), k) >> k).cast(k)
def Norm(self, k, f, kappa=None, simplex_flag=False): def Norm(self, k, f, kappa=None, simplex_flag=False):
absolute_val = abs(self) absolute_val = abs(self)
#next 2 lines actually compute the SufOR for little indian encoding #next 2 lines actually compute the SufOR for little indian encoding
@@ -1100,7 +1132,8 @@ class cbitfix(object):
bits = self.v.bit_decompose(self.k) bits = self.v.bit_decompose(self.k)
sign = bits[-1] sign = bits[-1]
v += (sign << (self.k)) * -1 v += (sign << (self.k)) * -1
inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0), cbits(0)) inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))
class sbitfix(_fix): class sbitfix(_fix):
""" Secret signed integer in one binary register. """ Secret signed integer in one binary register.

View File

@@ -123,7 +123,7 @@ class StraightlineAllocator:
for x in itertools.chain(dup.duplicates, base.duplicates): for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x) to_check.add(x)
free[reg.reg_type, base.size].add(self.alloc[base]) free[reg.reg_type, base.size].append(self.alloc[base])
if inst.is_vec() and base.vector: if inst.is_vec() and base.vector:
self.defined[base] = inst self.defined[base] = inst
for i in base.vector: for i in base.vector:
@@ -604,4 +604,4 @@ class RegintOptimizer:
elif op == 1: elif op == 1:
instructions[i] = None instructions[i] = None
inst.args[0].link(inst.args[1]) inst.args[0].link(inst.args[1])
instructions[:] = filter(lambda x: x is not None, instructions) instructions[:] = list(filter(lambda x: x is not None, instructions))

View File

@@ -127,7 +127,7 @@ def sha3_256(x):
from circuit import sha3_256 from circuit import sha3_256
a = sbitvec.from_vec([]) a = sbitvec.from_vec([])
b = sbitvec(sint(0xcc), 8) b = sbitvec(sint(0xcc), 8, 8)
for x in a, b: for x in a, b:
sha3_256(x).elements()[0].reveal().print_reg() sha3_256(x).elements()[0].reveal().print_reg()

View File

@@ -73,6 +73,7 @@ def require_ring_size(k, op):
if int(program.options.ring) < k: if int(program.options.ring) < k:
raise CompilerError('ring size too small for %s, compile ' raise CompilerError('ring size too small for %s, compile '
'with \'-R %d\' or more' % (op, k)) 'with \'-R %d\' or more' % (op, k))
program.curr_tape.require_bit_length(k)
@instructions_base.cisc @instructions_base.cisc
def LTZ(s, a, k, kappa): def LTZ(s, a, k, kappa):

View File

@@ -55,7 +55,7 @@ def EQZ(a, k, kappa):
from GC.types import sbitvec from GC.types import sbitvec
v = sbitvec(a, k).v v = sbitvec(a, k).v
bit = util.tree_reduce(operator.and_, (~b for b in v)) bit = util.tree_reduce(operator.and_, (~b for b in v))
return types.sint.conv(bit) return types.sintbit.conv(bit)
prog.non_linear.check_security(kappa) prog.non_linear.check_security(kappa)
return prog.non_linear.eqz(a, k) return prog.non_linear.eqz(a, k)
@@ -263,16 +263,17 @@ def BitAdd(a, b, bits_to_compute=None):
def BitDec(a, k, m, kappa, bits_to_compute=None): def BitDec(a, k, m, kappa, bits_to_compute=None):
return program.Program.prog.non_linear.bit_dec(a, k, m) return program.Program.prog.non_linear.bit_dec(a, k, m)
def BitDecRing(a, k, m): def BitDecRingRaw(a, k, m):
n_shift = int(program.Program.prog.options.ring) - m n_shift = int(program.Program.prog.options.ring) - m
assert(n_shift >= 0) assert(n_shift >= 0)
if program.Program.prog.use_split(): if program.Program.prog.use_split():
x = a.split_to_two_summands(m) x = a.split_to_two_summands(m)
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False) bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
# reversing to reduce number of rounds return bits[:m]
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
else: else:
if program.Program.prog.use_dabit: if program.Program.prog.use_edabit():
r, r_bits = types.sint.get_edabit(m, strict=False)
elif program.Program.prog.use_dabit:
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
r = types.sint.bit_compose(r) r = types.sint.bit_compose(r)
else: else:
@@ -281,7 +282,12 @@ def BitDecRing(a, k, m):
shifted = ((a - r) << n_shift).reveal() shifted = ((a - r) << n_shift).reveal()
masked = shifted >> n_shift masked = shifted >> n_shift
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return [types.sint.conv(bit) for bit in bits] return bits
def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
r_dprime = types.sint() r_dprime = types.sint()
@@ -429,7 +435,7 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
s = (1 - overflow) * t + overflow * t / 2 s = (1 - overflow) * t + overflow * t / 2
return s, overflow return s, overflow
def Int2FL(a, gamma, l, kappa): def Int2FL(a, gamma, l, kappa=None):
lam = gamma - 1 lam = gamma - 1
s = a.less_than(0, gamma, security=kappa) s = a.less_than(0, gamma, security=kappa)
z = a.equal(0, gamma, security=kappa) z = a.equal(0, gamma, security=kappa)
@@ -598,13 +604,13 @@ def SDiv_mono(a, b, l, kappa):
# Unconditionally Secure Constant-Rounds Multi-party Computation # Unconditionally Secure Constant-Rounds Multi-party Computation
# for Equality, Comparison, Bits and Exponentiation # for Equality, Comparison, Bits and Exponentiation
def BITLT(a, b, bit_length): def BITLT(a, b, bit_length):
sint = types.sint from .types import sint, regint, longint, cint
e = [sint(0)]*bit_length e = [None]*bit_length
g = [sint(0)]*bit_length g = [None]*bit_length
h = [sint(0)]*bit_length h = [None]*bit_length
for i in range(bit_length): for i in range(bit_length):
# Compute the XOR (reverse order of e for PreOpL) # Compute the XOR (reverse order of e for PreOpL)
e[bit_length-i-1] = a[i].bit_xor(b[i]) e[bit_length-i-1] = util.bit_xor(a[i], b[i])
f = PreOpL(or_op, e) f = PreOpL(or_op, e)
g[bit_length-1] = f[0] g[bit_length-1] = f[0]
for i in range(bit_length-1): for i in range(bit_length-1):
@@ -612,7 +618,7 @@ def BITLT(a, b, bit_length):
g[i] = f[bit_length-i-1]-f[bit_length-i-2] g[i] = f[bit_length-i-1]-f[bit_length-i-2]
ans = 0 ans = 0
for i in range(bit_length): for i in range(bit_length):
h[i] = g[i]*b[i] h[i] = g[i].bit_and(b[i])
ans = ans + h[i] ans = ans + h[i]
return ans return ans
@@ -620,9 +626,9 @@ def BITLT(a, b, bit_length):
# - From the paper # - From the paper
# Multiparty Computation for Interval, Equality, and Comparison without # Multiparty Computation for Interval, Equality, and Comparison without
# Bit-Decomposition Protocol # Bit-Decomposition Protocol
def BitDecFull(a): def BitDecFull(a, maybe_mixed=False):
from .library import get_program, do_while, if_, break_point from .library import get_program, do_while, if_, break_point
from .types import sint, regint, longint from .types import sint, regint, longint, cint
p = get_program().prime p = get_program().prime
assert p assert p
bit_length = p.bit_length() bit_length = p.bit_length()
@@ -631,9 +637,16 @@ def BitDecFull(a):
# inspired by Rabbit (https://eprint.iacr.org/2021/119) # inspired by Rabbit (https://eprint.iacr.org/2021/119)
# no need for exact randomness generation # no need for exact randomness generation
# if modulo a power of two is close enough # if modulo a power of two is close enough
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)] if get_program().use_edabit():
if logp != bit_length: b, bbits = sint.get_edabit(logp, True, size=a.size)
bbits += [sint(0, size=a.size)] if logp != bit_length:
from .GC.types import sbits
bbits += [sbits.get_type(a.size)(0)]
else:
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
b = sint.bit_compose(bbits)
if logp != bit_length:
bbits += [sint(0, size=a.size)]
else: else:
bbits = [sint(size=a.size) for i in range(bit_length)] bbits = [sint(size=a.size) for i in range(bit_length)]
tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)] tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)]
@@ -653,15 +666,21 @@ def BitDecFull(a):
for j in range(a.size): for j in range(a.size):
for i in range(bit_length): for i in range(bit_length):
movs(bbits[i][j], tbits[j][i]) movs(bbits[i][j], tbits[j][i])
b = sint.bit_compose(bbits) b = sint.bit_compose(bbits)
c = (a-b).reveal() c = (a-b).reveal()
t = (p-c).bit_decompose(bit_length) cmodp = c
t = bbits[0].bit_decompose_clear(p - c, bit_length)
c = longint(c, bit_length) c = longint(c, bit_length)
czero = (c==0) czero = (c==0)
q = 1-BITLT( bbits, t, bit_length) q = bbits[0].long_one() - BITLT(bbits, t, bit_length)
fbar=((1<<bit_length)+c-p).bit_decompose(bit_length) fbar = [bbits[0].clear_type.conv(cint(x))
fbard = c.bit_decompose(bit_length) for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
g = [(fbar[i] - fbard[i]) * q + fbard[i] for i in range(bit_length)] fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
h = BitAdd(bbits, g) g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
abits = [(1 - czero) * h[i] + czero * bbits[i] for i in range(bit_length)] h = bbits[0].bit_adder(bbits, g)
return abits abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
for i in range(bit_length)]
if maybe_mixed:
return abits
else:
return [sint.conv(bit) for bit in abits]

View File

@@ -1,4 +1,5 @@
import heapq import heapq
import collections
from Compiler.exceptions import * from Compiler.exceptions import *
class GraphError(CompilerError): class GraphError(CompilerError):
@@ -23,7 +24,7 @@ class SparseDiGraph(object):
self.n = max_nodes self.n = max_nodes
# each node contains list of default attributes, followed by outoing edges # each node contains list of default attributes, followed by outoing edges
self.nodes = [list(self.default_attributes.values()) for i in range(self.n)] self.nodes = [list(self.default_attributes.values()) for i in range(self.n)]
self.succ = [set() for i in range(self.n)] self.succ = [collections.OrderedDict() for i in range(self.n)]
self.pred = [set() for i in range(self.n)] self.pred = [set() for i in range(self.n)]
self.weights = {} self.weights = {}
@@ -32,7 +33,7 @@ class SparseDiGraph(object):
def __getitem__(self, i): def __getitem__(self, i):
""" Get list of the neighbours of node i """ """ Get list of the neighbours of node i """
return self.succ[i] return self.succ[i].keys()
def __iter__(self): def __iter__(self):
pass #return iter(self.nodes) pass #return iter(self.nodes)
@@ -68,7 +69,7 @@ class SparseDiGraph(object):
self.pred[v].remove(i) self.pred[v].remove(i)
#del self.weights[(i,v)] #del self.weights[(i,v)]
for v in pred: for v in pred:
self.succ[v].remove(i) del self.succ[v][i]
#del self.weights[(v,i)] #del self.weights[(v,i)]
#self.nodes[v].remove(i) #self.nodes[v].remove(i)
self.pred[i] = [] self.pred[i] = []
@@ -77,7 +78,7 @@ class SparseDiGraph(object):
def add_edge(self, i, j, weight=1): def add_edge(self, i, j, weight=1):
if j not in self[i]: if j not in self[i]:
self.pred[j].add(i) self.pred[j].add(i)
self.succ[i].add(j) self.succ[i][j] = None
self.weights[(i,j)] = weight self.weights[(i,j)] = weight
def add_edges_from(self, tuples): def add_edges_from(self, tuples):
@@ -89,7 +90,7 @@ class SparseDiGraph(object):
self.add_edge(edge[0], edge[1]) self.add_edge(edge[0], edge[1])
def remove_edge(self, i, j): def remove_edge(self, i, j):
self.succ[i].remove(j) del self.succ[i][j]
self.pred[j].remove(i) self.pred[j].remove(i)
del self.weights[(i,j)] del self.weights[(i,j)]

View File

@@ -2219,22 +2219,23 @@ class conv2ds(base.DataInstruction):
:param: number of channels (int) :param: number of channels (int)
:param: padding height (int) :param: padding height (int)
:param: padding width (int) :param: padding width (int)
:param: batch size (int)
""" """
code = base.opcodes['CONV2DS'] code = base.opcodes['CONV2DS']
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int', arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
'int','int','int'] 'int','int','int','int']
data_type = 'triple' data_type = 'triple'
is_vec = lambda self: True is_vec = lambda self: True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(conv2ds, self).__init__(*args, **kwargs) super(conv2ds, self).__init__(*args, **kwargs)
assert args[0].size == args[3] * args[4] assert args[0].size == args[3] * args[4] * args[14]
assert args[1].size == args[5] * args[6] * args[11] assert args[1].size == args[5] * args[6] * args[11] * args[14]
assert args[2].size == args[7] * args[8] * args[11] assert args[2].size == args[7] * args[8] * args[11]
def get_repeat(self): def get_repeat(self):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \ return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
self.args[11] self.args[11] * self.args[14]
@base.vectorize @base.vectorize
class trunc_pr(base.VarArgsInstruction): class trunc_pr(base.VarArgsInstruction):
@@ -2250,6 +2251,15 @@ class trunc_pr(base.VarArgsInstruction):
code = base.opcodes['TRUNC_PR'] code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int']) arg_format = tools.cycle(['sw','s','int','int'])
class check(base.Instruction):
"""
Force MAC check in current thread and all idle thread if current
thread is the main thread.
"""
__slots__ = []
code = base.opcodes['CHECK']
arg_format = []
### ###
### CISC-style instructions ### CISC-style instructions
### ###
@@ -2289,47 +2299,5 @@ class lts(base.CISC):
subs(a, self.args[1], self.args[2]) subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
@base.vectorize
class g2muls(base.CISC):
r""" Secret GF(2) multiplication """
__slots__ = []
arg_format = ['sgw','sg','sg']
def expand(self):
s = [program.curr_block.new_reg('sg') for i in range(9)]
c = [program.curr_block.new_reg('cg') for i in range(3)]
gbittriple(s[0], s[1], s[2])
gsubs(s[3], self.args[1], s[0])
gsubs(s[4], self.args[2], s[1])
gasm_open(c[0], s[3])
gasm_open(c[1], s[4])
gmulbitm(s[5], s[1], c[0])
gmulbitm(s[6], s[0], c[1])
gmulbitc(c[2], c[0], c[1])
gadds(s[7], s[2], s[5])
gadds(s[8], s[7], s[6])
gaddm(self.args[0], s[8], c[2])
#@base.vectorize
#class gmulbits(base.CISC):
# r""" Secret $GF(2^n) \times GF(2)$ multiplication """
# __slots__ = []
# arg_format = ['sgw','sg','sg']
#
# def expand(self):
# s = [program.curr_block.new_reg('s') for i in range(9)]
# c = [program.curr_block.new_reg('c') for i in range(3)]
# g2ntriple(s[0], s[1], s[2])
# subs(s[3], self.args[1], s[0])
# subs(s[4], self.args[2], s[1])
# startopen(s[3], s[4])
# stopopen(c[0], c[1])
# mulm(s[5], s[1], c[0])
# mulm(s[6], s[0], c[1])
# mulc(c[2], c[0], c[1])
# adds(s[7], s[2], s[5])
# adds(s[8], s[7], s[6])
# addm(self.args[0], s[8], c[2])
# hack for circular dependency # hack for circular dependency
from Compiler import comparison from Compiler import comparison

View File

@@ -18,6 +18,8 @@ from Compiler import program
### MUST also be changed. (+ the documentation) ### MUST also be changed. (+ the documentation)
### ###
opcodes = dict( opcodes = dict(
# Emulation
CISC = 0,
# Load/store # Load/store
LDI = 0x1, LDI = 0x1,
LDSI = 0x2, LDSI = 0x2,
@@ -98,6 +100,7 @@ opcodes = dict(
MATMULS = 0xAA, MATMULS = 0xAA,
MATMULSM = 0xAB, MATMULSM = 0xAB,
CONV2DS = 0xAC, CONV2DS = 0xAC,
CHECK = 0xAF,
# Data access # Data access
TRIPLE = 0x50, TRIPLE = 0x50,
BIT = 0x51, BIT = 0x51,
@@ -409,7 +412,7 @@ def cisc(function):
program.curr_block.instructions.append(self) program.curr_block.instructions.append(self)
def get_def(self): def get_def(self):
return [self.args[0]] return [call[0][0] for call in self.calls]
def get_used(self): def get_used(self):
return self.used return self.used
@@ -423,6 +426,7 @@ def cisc(function):
def merge(self, other): def merge(self, other):
self.calls += other.calls self.calls += other.calls
self.used += other.used
def get_size(self): def get_size(self):
return self.args[0].size return self.args[0].size
@@ -470,7 +474,9 @@ def cisc(function):
inst.copy(size, subs) inst.copy(size, subs)
reset_global_vector_size() reset_global_vector_size()
def expand_merged(self): def expand_merged(self, skip):
if function.__name__ in skip:
return [self], 0
tape = program.curr_tape tape = program.curr_tape
block = tape.BasicBlock(tape, None, None) block = tape.BasicBlock(tape, None, None)
tape.active_basicblock = block tape.active_basicblock = block
@@ -496,10 +502,38 @@ def cisc(function):
reg.mov(reg, new_regs[0].get_vector(base, reg.size)) reg.mov(reg, new_regs[0].get_vector(base, reg.size))
reset_global_vector_size() reset_global_vector_size()
base += reg.size base += reg.size
return block.instructions return block.instructions, self.n_rounds - 1
def expanded_rounds(self): def add_usage(self, *args):
return self.n_rounds - 1 pass
def get_bytes(self):
assert len(self.kwargs) < 2
res = int_to_bytes(opcodes['CISC'])
res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
name = self.function.__name__
String.check(name)
res += String.encode(name)
for call in self.calls:
assert not call[1]
res += int_to_bytes(len(call[0]) + 2)
res += int_to_bytes(call[0][0].size)
for arg in call[0]:
res += self.arg_to_bytes(arg)
return bytearray(res)
@classmethod
def arg_to_bytes(self, arg):
if arg is None:
return int_to_bytes(0)
try:
return int_to_bytes(arg.i)
except:
return int_to_bytes(arg)
def __str__(self):
return self.function.__name__ + ' ' + ', '.join(
str(x) for x in itertools.chain(call[0] for call in self.calls))
MergeCISC.__name__ = function.__name__ MergeCISC.__name__ = function.__name__
@@ -804,11 +838,8 @@ class Instruction(object):
else: else:
return self.args return self.args
def expand_merged(self): def expand_merged(self, skip):
return [self] return [self], 0
def expanded_rounds(self):
return 0
def get_new_args(self, size, subs): def get_new_args(self, size, subs):
new_args = [] new_args = []

View File

@@ -170,7 +170,7 @@ def print_ln_to(player, ss, *args):
Example:: Example::
print_ln_to(player, 'output for %s: %s', x.reveal_to(player)) print_ln_to(player, 'output for %s: %s', player, x.reveal_to(player))
""" """
cond = player == get_player_id() cond = player == get_player_id()
new_args = [] new_args = []
@@ -293,7 +293,9 @@ class Function:
self.compile_args = compile_args self.compile_args = compile_args
def __call__(self, *args): def __call__(self, *args):
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args) args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
get_reg_type = lambda x: regint if isinstance(x, int) else type(x) from .types import _types
get_reg_type = lambda x: \
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
if len(args) not in self.type_args: if len(args) not in self.type_args:
# first call # first call
type_args = collections.defaultdict(list) type_args = collections.defaultdict(list)
@@ -324,7 +326,8 @@ class Function:
j = 0 j = 0
for i_arg in type_args[reg_type]: for i_arg in type_args[reg_type]:
if get_reg_type(args[i_arg]) != reg_type: if get_reg_type(args[i_arg]) != reg_type:
raise CompilerError('type mismatch') raise CompilerError('type mismatch: "%s" not of type "%s"' %
(args[i_arg], reg_type))
store_in_mem(args[i_arg], bases[reg_type] + j) store_in_mem(args[i_arg], bases[reg_type] + j)
j += util.mem_size(reg_type) j += util.mem_size(reg_type)
return self.on_call(base, bases) return self.on_call(base, bases)
@@ -371,7 +374,7 @@ class FunctionBlock(Function):
parent_node = get_tape().req_node parent_node = get_tape().req_node
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
block = get_tape().active_basicblock block = get_tape().active_basicblock
block.alloc_pool = defaultdict(set) block.alloc_pool = defaultdict(list)
del parent_node.children[-1] del parent_node.children[-1]
self.node = get_tape().req_node self.node = get_tape().req_node
if get_program().verbose: if get_program().verbose:
@@ -763,22 +766,34 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
@function_block @function_block
def step(l): def step(l):
l = MemValue(l) l = MemValue(l)
@for_range_opt_multithread(n_threads, len(a) // k) m = 2 ** int(math.ceil(math.log(len(a), 2)))
@for_range_opt_multithread(n_threads, m // k)
def _(i): def _(i):
n_inner = l // k n_inner = l // k
j = i % n_inner j = i % n_inner
i //= n_inner i //= n_inner
base = i*l + j base = i*l + j
step = l//k step = l//k
def swap(base, step):
if m == len(a):
a[base], a[base + step] = \
cond_swap(a[base], a[base + step])
else:
# ignore values outside range
go = base + step < len(a)
x = a.maybe_get(go, base)
y = a.maybe_get(go, base + step)
tmp = cond_swap(x, y)
for i, idx in enumerate((base, base + step)):
a.maybe_set(go, idx, tmp[i])
if k == 2: if k == 2:
a[base], a[base+step] = \ swap(base, step)
cond_swap(a[base], a[base+step])
else: else:
@for_range_opt(n_innermost) @for_range_opt(n_innermost)
def f(i): def f(i):
m1 = step + i * 2 * step m1 = step + i * 2 * step
m2 = m1 + base m2 = m1 + base
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step]) swap(m2, step)
steps[key] = step steps[key] = step
steps[key](l) steps[key](l)
@@ -870,6 +885,8 @@ def for_range_parallel(n_parallel, n_loops):
""" """
Decorator to execute a loop :py:obj:`n_loops` up to Decorator to execute a loop :py:obj:`n_loops` up to
:py:obj:`n_parallel` loop bodies in parallel. :py:obj:`n_parallel` loop bodies in parallel.
Using any other control flow instruction inside the loop breaks
the optimization.
:param n_parallel: compile-time (int) :param n_parallel: compile-time (int)
:param n_loops: regint/cint/int :param n_loops: regint/cint/int
@@ -887,9 +904,12 @@ def for_range_parallel(n_parallel, n_loops):
def for_range_opt(n_loops, budget=None): def for_range_opt(n_loops, budget=None):
""" Execute loop bodies in parallel up to an optimization budget. """ Execute loop bodies in parallel up to an optimization budget.
This prevents excessive loop unrolling. The budget is respected This prevents excessive loop unrolling. The budget is respected
even with nested loops. Note that optimization is rather even with nested loops. Note that the optimization is rather
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
using :py:func:`for_range_parallel` in this case. using :py:func:`for_range_parallel` in this case.
Using further control flow constructions inside other than
:py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the
optimization.
:param n_loops: int/regint/cint :param n_loops: int/regint/cint
:param budget: number of instructions after which to start optimization (default is 100,000) :param budget: number of instructions after which to start optimization (default is 100,000)
@@ -1082,18 +1102,19 @@ def multithread(n_threads, n_items=None, max_size=None):
""" """
if n_items is None: if n_items is None:
n_items = n_threads n_items = n_threads
if max_size is None: if max_size is None or n_items <= max_size:
return map_reduce(n_threads, None, n_items, initializer=lambda: [], return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False) reducer=None, looping=False)
else: else:
def wrapper(function): def wrapper(function):
@multithread(n_threads, n_items) @multithread(n_threads, n_items)
def new_function(base, size): def new_function(base, size):
for i in range(0, size, max_size): @for_range(size // max_size)
part_base = base + i def _(i):
part_size = min(max_size, size - i) function(base + i * max_size, max_size)
function(part_base, part_size) rem = size % max_size
break_point() if rem:
function(base + size - rem, rem)
return wrapper return wrapper
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
@@ -1200,6 +1221,23 @@ def map_sum(n_threads, n_parallel, n_loops, n_items, value_types):
return tuple(a + b for a,b in zip(x,y)) return tuple(a + b for a,b in zip(x,y))
return map_reduce(n_threads, n_parallel, n_loops, initializer, summer) return map_reduce(n_threads, n_parallel, n_loops, initializer, summer)
def tree_reduce_multithread(n_threads, function, vector):
inputs = vector.Array(len(vector))
inputs.assign_vector(vector)
outputs = vector.Array(len(vector) // 2)
left = len(vector)
while left > 1:
@multithread(n_threads, left // 2)
def _(base, size):
outputs.assign_vector(
function(inputs.get_vector(2 * base, size),
inputs.get_vector(2 * base + size, size)), base)
inputs.assign_vector(outputs.get_vector(0, left // 2))
if left % 2 == 1:
inputs[left // 2] = inputs[left - 1]
left = (left + 1) // 2
return inputs[0]
def foreach_enumerate(a): def foreach_enumerate(a):
""" Run-time loop over public data. This uses """ Run-time loop over public data. This uses
``Player-Data/Public-Input/<progname>``. Example: ``Player-Data/Public-Input/<progname>``. Example:
@@ -1511,6 +1549,15 @@ def break_point(name=''):
""" """
get_tape().start_new_basicblock(name=name) get_tape().start_new_basicblock(name=name)
def check_point():
"""
Force MAC checks in current thread and all idle threads if the
current thread is the main thread. This implies a break point.
"""
break_point('pre-check')
check()
break_point('post-check')
# Fixed point ops # Fixed point ops
from math import ceil, log from math import ceil, log
@@ -1566,6 +1613,9 @@ def cint_cint_division(a, b, k, f):
# theta can be replaced with something smaller # theta can be replaced with something smaller
# for safety we assume that is the same theta from previous GS method # for safety we assume that is the same theta from previous GS method
if get_program().options.ring:
assert 2 * f < int(get_program().options.ring)
theta = int(ceil(log(k/3.5) / log(2))) theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f) two = cint(2) * two_power(f)
@@ -1579,9 +1629,11 @@ def cint_cint_division(a, b, k, f):
B = absolute_b B = absolute_b
W = w0 W = w0
for i in range(1, theta): corr = cint(1) << (f - 1)
A = (A * W) >> f
B = (B * W) >> f for i in range(theta):
A = (A * W + corr) >> f
B = (B * W + corr) >> f
W = two - B W = two - B
return (sign_a * sign_b) * A return (sign_a * sign_b) * A
@@ -1592,7 +1644,7 @@ def sint_cint_division(a, b, k, f, kappa):
""" """
theta = int(ceil(log(k/3.5) / log(2))) theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f) two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b < 0) sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa) sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
absolute_b = b * sign_b absolute_b = b * sign_b
absolute_a = a * sign_a absolute_a = a * sign_a
@@ -1652,7 +1704,8 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
y = y.extend(2 * k) * (alpha + x).extend(2 * k) y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
return y return y
def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
""" """
Approximate reciprocal of [b]: Approximate reciprocal of [b]:
Given [b], compute [1/b] Given [b], compute [1/b]
@@ -1662,7 +1715,7 @@ def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
#v should be 2**{k - m} where m is the length of the bitwise repr of [b] #v should be 2**{k - m} where m is the length of the bitwise repr of [b]
d = alpha - 2 * c d = alpha - 2 * c
w = d * v w = d * v
w = w.round(2 * k, 2 * (k - f), kappa, nearest, signed=True) w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
# now w * 2 ^ {-f} should be an initial approximation of 1/b # now w * 2 ^ {-f} should be an initial approximation of 1/b
return w return w
@@ -1674,7 +1727,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
# For simplex, we can get rid of computing abs(b) # For simplex, we can get rid of computing abs(b)
temp = None temp = None
if simplex_flag == False: if simplex_flag == False:
temp = comparison.LessThanZero(b, 2 * k, kappa) temp = comparison.LessThanZero(b, k, kappa)
elif simplex_flag == True: elif simplex_flag == True:
temp = cint(0) temp = cint(0)
@@ -1682,7 +1735,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
absolute_val = sign * b absolute_val = sign * b
#next 2 lines actually compute the SufOR for little indian encoding #next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k, kappa)[::-1] bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
suffixes = PreOR(bits, kappa)[::-1] suffixes = PreOR(bits, kappa)[::-1]
z = [0] * k z = [0] * k
@@ -1690,10 +1743,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
z[i] = suffixes[i] - suffixes[i+1] z[i] = suffixes[i] - suffixes[i+1]
z[k - 1] = suffixes[k-1] z[k - 1] = suffixes[k-1]
#doing complicated stuff to compute v = 2^{k-m} acc = sint.bit_compose(reversed(z))
acc = cint(0)
for i in range(k):
acc += two_power(k-i-1) * z[i]
part_reciprocal = absolute_val * acc part_reciprocal = absolute_val * acc
signed_acc = sign * acc signed_acc = sign * acc

File diff suppressed because it is too large Load Diff

View File

@@ -846,3 +846,54 @@ def acos(x):
""" """
y = asin(x) y = asin(x)
return pi_over_2 - y return pi_over_2 - y
def tanh(x):
"""
Hyperbolic tangent. For efficiency, accuracy is diminished
around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and
:math:`f` denote the fixed-point parameters.
"""
limit = math.log(2 ** (x.k - x.f - 2)) / 2
s = x < -limit
t = x > limit
y = pow_fx(math.e, 2 * x)
return s.if_else(-1, t.if_else(1, (y - 1) / (y + 1)))
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
def Sep(x):
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
t = x.v * (1 + x.v.bit_compose(b_i.bit_not() for b_i in b[-2 * x.f + 1:]))
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
b += [b[0].long_one()]
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
def SqrtComp(z, old=False):
f = types.sfix.f
k = len(z)
if isinstance(z[0], types.sint):
return types.sfix._new(sum(z[i] * types.cfix(
2 ** (-(i - f + 1) / 2)).v for i in range(k)))
k_prime = k // 2
f_prime = f // 2
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
c0 = types.sfix(2 ** (f / 2 + 1))
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
if old:
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
else:
b = util.tree_reduce(lambda x, y: x.bit_xor(y), z[::2])
return types.sint.conv(b).if_else(c1, c0) * tmp
@types.vectorize
def InvertSqrt(x, old=False):
"""
Reciprocal square root approximation by `Lu et al.
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
"""
u, z = Sep(x)
c = 3.14736 + u * (4.63887 * u - 5.77789)
return c * SqrtComp(z, old=old)

View File

@@ -44,7 +44,7 @@ class Masking(NonLinear):
d = [None]*k d = [None]*k
for i,b in enumerate(r[0].bit_decompose_clear(c, k)): for i,b in enumerate(r[0].bit_decompose_clear(c, k)):
d[i] = r[i].bit_xor(b) d[i] = r[i].bit_xor(b)
return 1 - types.sint.conv(self.kor(d)) return 1 - types.sintbit.conv(self.kor(d))
class Prime(Masking): class Prime(Masking):
""" Non-linear functionality modulo a prime with statistical masking. """ """ Non-linear functionality modulo a prime with statistical masking. """
@@ -71,8 +71,11 @@ class Prime(Masking):
def _trunc_pr(self, a, k, m, signed=None): def _trunc_pr(self, a, k, m, signed=None):
return TruncPrField(a, k, m, self.kappa) return TruncPrField(a, k, m, self.kappa)
def bit_dec(self, a, k, m): def bit_dec(self, a, k, m, maybe_mixed=False):
return BitDecField(a, k, m, self.kappa) if maybe_mixed:
return BitDecFieldRaw(a, k, m, self.kappa)
else:
return BitDecField(a, k, m, self.kappa)
def kor(self, d): def kor(self, d):
return KOR(d, self.kappa) return KOR(d, self.kappa)
@@ -85,7 +88,7 @@ class KnownPrime(NonLinear):
def _mod2m(self, a, k, m, signed): def _mod2m(self, a, k, m, signed):
if signed: if signed:
a += cint(1) << (k - 1) a += cint(1) << (k - 1)
return sint.bit_compose(self.bit_dec(a, k, k)[:m]) return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
def _trunc_pr(self, a, k, m, signed): def _trunc_pr(self, a, k, m, signed):
# nearest truncation # nearest truncation
@@ -96,14 +99,14 @@ class KnownPrime(NonLinear):
if signed: if signed:
a += cint(1) << (k - 1) a += cint(1) << (k - 1)
k += 1 k += 1
res = sint.bit_compose(self.bit_dec(a, k, k)[m:]) res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
if signed: if signed:
res -= cint(1) << (k - m - 2) res -= cint(1) << (k - m - 2)
return res return res
def bit_dec(self, a, k, m): def bit_dec(self, a, k, m, maybe_mixed=False):
assert k < self.prime.bit_length() assert k < self.prime.bit_length()
bits = BitDecFull(a) bits = BitDecFull(a, maybe_mixed=maybe_mixed)
if len(bits) < m: if len(bits) < m:
raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
return bits[:m] return bits[:m]
@@ -111,7 +114,7 @@ class KnownPrime(NonLinear):
def eqz(self, a, k): def eqz(self, a, k):
# always signed # always signed
a += two_power(k) a += two_power(k)
return 1 - KORL(self.bit_dec(a, k, k)) return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
class Ring(Masking): class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time. """ Non-linear functionality modulo a power of two known at compile time.
@@ -130,8 +133,11 @@ class Ring(Masking):
def _trunc_pr(self, a, k, m, signed): def _trunc_pr(self, a, k, m, signed):
return TruncPrRing(a, k, m, signed=signed) return TruncPrRing(a, k, m, signed=signed)
def bit_dec(self, a, k, m): def bit_dec(self, a, k, m, maybe_mixed=False):
return BitDecRing(a, k, m) if maybe_mixed:
return BitDecRingRaw(a, k, m)
else:
return BitDecRing(a, k, m)
def kor(self, d): def kor(self, d):
return KORL(d) return KORL(d)

View File

@@ -28,9 +28,7 @@ data_types = dict(
square = 1, square = 1,
bit = 2, bit = 2,
inverse = 3, inverse = 3,
bittriple = 4, dabit = 4,
bitgf2ntriple = 5,
dabit = 6,
) )
field_types = dict( field_types = dict(
@@ -62,6 +60,7 @@ class defaults:
asmoutfile = None asmoutfile = None
stop = False stop = False
insecure = False insecure = False
keep_cisc = False
class Program(object): class Program(object):
""" A program consists of a list of tapes representing the whole """ A program consists of a list of tapes representing the whole
@@ -80,14 +79,14 @@ class Program(object):
self.init_names(args) self.init_names(args)
self._security = 40 self._security = 40
self.prime = None self.prime = None
self.tapes = []
if sum(x != 0 for x in(options.ring, options.field, if sum(x != 0 for x in(options.ring, options.field,
options.binary)) > 1: options.binary)) > 1:
raise CompilerError('can only use one out of -B, -R, -F') raise CompilerError('can only use one out of -B, -R, -F')
if options.prime and (options.ring or options.binary): if options.prime and (options.ring or options.binary):
raise CompilerError('can only use one out of -B, -R, -p') raise CompilerError('can only use one out of -B, -R, -p')
if options.ring: if options.ring:
self.bit_length = int(options.ring) - 1 self.set_ring_size(int(options.ring))
self.non_linear = Ring(int(options.ring))
else: else:
self.bit_length = int(options.binary) or int(options.field) self.bit_length = int(options.binary) or int(options.field)
if options.prime: if options.prime:
@@ -108,7 +107,6 @@ class Program(object):
if self.verbose: if self.verbose:
print('Galois length:', self.galois_length) print('Galois length:', self.galois_length)
self.tape_counter = 0 self.tape_counter = 0
self.tapes = []
self._curr_tape = None self._curr_tape = None
self.DEBUG = options.debug self.DEBUG = options.debug
self.allocated_mem = RegType.create_dict(lambda: USER_MEM) self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
@@ -204,6 +202,16 @@ class Program(object):
for arg in args[1:]) for arg in args[1:])
self.progname = progname self.progname = progname
def set_ring_size(self, ring_size):
from .non_linear import Ring
for tape in self.tapes:
prev = tape.req_bit_length['p']
if prev and prev != ring_size:
raise CompilerError('cannot have different ring sizes')
self.bit_length = ring_size - 1
self.non_linear = Ring(ring_size)
self.options.ring = str(ring_size)
def new_tape(self, function, args=[], name=None, single_thread=False): def new_tape(self, function, args=[], name=None, single_thread=False):
""" """
Create a new tape from a function. See Create a new tape from a function. See
@@ -414,7 +422,7 @@ class Program(object):
self.curr_tape.start_new_basicblock(None, 'memory-usage') self.curr_tape.start_new_basicblock(None, 'memory-usage')
# reset register counter to 0 # reset register counter to 0
self.curr_tape.init_registers() self.curr_tape.init_registers()
for mem_type,size in list(self.allocated_mem.items()): for mem_type,size in sorted(self.allocated_mem.items()):
if size: if size:
#print "Memory of type '%s' of size %d" % (mem_type, size) #print "Memory of type '%s' of size %d" % (mem_type, size)
if mem_type in self.types: if mem_type in self.types:
@@ -488,7 +496,7 @@ class Program(object):
else: else:
if change and not self.options.ring: if change and not self.options.ring:
raise CompilerError('splitting only supported for rings') raise CompilerError('splitting only supported for rings')
assert change > 1 assert change > 1 or change == False
self._split = change self._split = change
def use_square(self, change=None): def use_square(self, change=None):
@@ -575,7 +583,7 @@ class Tape:
scope.children.append(self) scope.children.append(self)
self.alloc_pool = scope.alloc_pool self.alloc_pool = scope.alloc_pool
else: else:
self.alloc_pool = defaultdict(set) self.alloc_pool = defaultdict(list)
self.purged = False self.purged = False
self.n_rounds = 0 self.n_rounds = 0
self.n_to_merge = 0 self.n_to_merge = 0
@@ -647,9 +655,14 @@ class Tape:
def expand_cisc(self): def expand_cisc(self):
new_instructions = [] new_instructions = []
if self.parent.program.options.keep_cisc:
skip = ['LTZ', 'Trunc']
else:
skip = []
for inst in self.instructions: for inst in self.instructions:
new_instructions.extend(inst.expand_merged()) new_inst, n_rounds = inst.expand_merged(skip)
self.n_rounds += inst.expanded_rounds() new_instructions.extend(new_inst)
self.n_rounds += n_rounds
self.instructions = new_instructions self.instructions = new_instructions
def __str__(self): def __str__(self):
@@ -774,7 +787,10 @@ class Tape:
# allocate registers # allocate registers
reg_counts = self.count_regs() reg_counts = self.count_regs()
if not options.noreallocate: if options.noreallocate:
if self.program.verbose:
print('Tape register usage:', dict(reg_counts))
else:
if self.program.verbose: if self.program.verbose:
print('Tape register usage before re-allocation:', print('Tape register usage before re-allocation:',
dict(reg_counts)) dict(reg_counts))
@@ -1071,7 +1087,7 @@ class Tape:
if size is None: if size is None:
size = Compiler.instructions_base.get_global_vector_size() size = Compiler.instructions_base.get_global_vector_size()
if size is not None and size > self.maximum_size: if size is not None and size > self.maximum_size:
raise CompilerError('vector too large') raise CompilerError('vector too large: %d' % size)
self.size = size self.size = size
self.vectorbase = self self.vectorbase = self
self.relative_i = 0 self.relative_i = 0

View File

@@ -591,12 +591,12 @@ class _register(Tape.Register, _number, _structure):
def prep_res(cls, other): def prep_res(cls, other):
return cls() return cls()
@staticmethod @classmethod
def bit_compose(bits): def bit_compose(cls, bits):
""" Compose value from bits. """ Compose value from bits.
:param bits: iterable of any type implementing left shift """ :param bits: iterable of any type implementing left shift """
return sum(b << i for i,b in enumerate(bits)) return sum(cls.conv(b) << i for i,b in enumerate(bits))
@classmethod @classmethod
def malloc(cls, size, creator_tape=None): def malloc(cls, size, creator_tape=None):
@@ -840,6 +840,7 @@ class cint(_clear, _int):
def in_immediate_range(value): def in_immediate_range(value):
return value < 2**31 and value >= -2**31 return value < 2**31 and value >= -2**31
@vectorize_init
def __init__(self, val=None, size=None): def __init__(self, val=None, size=None):
""" """
:param val: initialization (cint/regint/int/cgf2n or list thereof) :param val: initialization (cint/regint/int/cgf2n or list thereof)
@@ -1119,12 +1120,6 @@ class cgf2n(_clear, _gf2n):
elif chunk: elif chunk:
sum += chunk sum += chunk
def __mul__(self, other):
""" Clear :math:`\mathrm{GF}(2^n)` multiplication.
:param other: cgf2n/regint/int """
return super(cgf2n, self).__mul__(other)
def __neg__(self): def __neg__(self):
""" Identity. """ """ Identity. """
return self return self
@@ -1209,7 +1204,9 @@ class regint(_register, _int):
def get_random(cls, bit_length): def get_random(cls, bit_length):
""" Public insecure randomness. """ Public insecure randomness.
:param bit_length: number of bits (int) """ :param bit_length: number of bits (int)
:param size: vector size (int, default 1)
"""
if isinstance(bit_length, int): if isinstance(bit_length, int):
bit_length = regint(bit_length) bit_length = regint(bit_length)
res = cls() res = cls()
@@ -1582,7 +1579,9 @@ class _secret(_register):
def get_input_from(cls, player): def get_input_from(cls, player):
""" Secret input from player. """ Secret input from player.
:param player: public (regint/cint/int) """ :param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
res = cls() res = cls()
asm_input(res, player) asm_input(res, player)
return res return res
@@ -1592,7 +1591,9 @@ class _secret(_register):
def get_random_triple(cls): def get_random_triple(cls):
""" Secret random triple according to security model. """ Secret random triple according to security model.
:return: :math:`(a, b, ab)` """ :return: :math:`(a, b, ab)`
:param size: vector size (int, default 1)
"""
res = (cls(), cls(), cls()) res = (cls(), cls(), cls())
triple(*res) triple(*res)
return res return res
@@ -1602,7 +1603,9 @@ class _secret(_register):
def get_random_bit(cls): def get_random_bit(cls):
""" Secret random bit according to security model. """ Secret random bit according to security model.
:return: 0/1 50-50 """ :return: 0/1 50-50
:param size: vector size (int, default 1)
"""
res = cls() res = cls()
bit(res) bit(res)
return res return res
@@ -1612,7 +1615,9 @@ class _secret(_register):
def get_random_square(cls): def get_random_square(cls):
""" Secret random square according to security model. """ Secret random square according to security model.
:return: :math:`(a, a^2)` """ :return: :math:`(a, a^2)`
:param size: vector size (int, default 1)
"""
res = (cls(), cls()) res = (cls(), cls())
square(*res) square(*res)
return res return res
@@ -1622,7 +1627,9 @@ class _secret(_register):
def get_random_inverse(cls): def get_random_inverse(cls):
""" Secret random inverse tuple according to security model. """ Secret random inverse tuple according to security model.
:return: :math:`(a, a^{-1})` """ :return: :math:`(a, a^{-1})`
:param size: vector size (int, default 1)
"""
res = (cls(), cls()) res = (cls(), cls())
inverse(*res) inverse(*res)
return res return res
@@ -1717,16 +1724,51 @@ class _secret(_register):
else: else:
self.load_clear(self.clear_type(val)) self.load_clear(self.clear_type(val))
@classmethod
def bit_compose(cls, bits):
""" Compose value from bits.
:param bits: iterable of any type convertible to sint """
from Compiler.GC.types import sbits, sbitintvec
bits = list(bits)
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
if program.use_edabit():
mask = cls.get_edabit(len(bits), strict=True, size=bits[0].n)
else:
tmp = sint(size=bits[0].n)
randoms(tmp, len(bits))
n_overflow_bits = min(program.use_split().bit_length(),
int(program.options.ring) - len(bits))
mask_bits = tmp.bit_decompose(len(bits) + n_overflow_bits,
maybe_mixed=True)
if n_overflow_bits:
overflow = sint.bit_compose(
sint.conv(x) for x in mask_bits[-n_overflow_bits:])
mask = tmp - (overflow << len(bits)), \
mask_bits[:-n_overflow_bits]
else:
mask = tmp, mask_bits
t = sbitintvec.get_type(len(bits) + 1)
masked = t.from_vec(mask[1] + [0]) + t.from_vec(bits + [0])
overflow = masked.v[-1]
masked = cls.bit_compose(x.reveal().to_regint_by_bit() for x in masked.v[:-1])
return masked - mask[0] + (cls(overflow) << len(bits))
else:
return super(_secret, cls).bit_compose(bits)
@set_instruction_type @set_instruction_type
@read_mem_value @read_mem_value
@vectorize @vectorize
def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False): def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False):
cls = self.__class__
res = self.prep_res(other) res = self.prep_res(other)
cls = type(res)
if isinstance(other, regint): if isinstance(other, regint):
other = res.clear_type(other) other = res.clear_type(other)
if isinstance(other, cls): if isinstance(other, cls):
s_inst(res, self, other) if reverse:
s_inst(res, other, self)
else:
s_inst(res, self, other)
elif isinstance(other, res.clear_type): elif isinstance(other, res.clear_type):
if reverse: if reverse:
m_inst(res, other, self) m_inst(res, other, self)
@@ -1861,10 +1903,12 @@ class sint(_secret, _int):
def get_random_int(cls, bits): def get_random_int(cls, bits):
""" Secret random n-bit number according to security model. """ Secret random n-bit number according to security model.
:param bits: compile-time integer (int) """ :param bits: compile-time integer (int)
:param size: vector size (int, default 1)
"""
if program.use_edabit(): if program.use_edabit():
return sint.get_edabit(bits, True)[0] return sint.get_edabit(bits, True)[0]
elif program.use_split() > 2: elif program.use_split() > 2 and program.use_split() < 5:
tmp = sint() tmp = sint()
randoms(tmp, bits) randoms(tmp, bits)
x = tmp.split_to_two_summands(bits, True) x = tmp.split_to_two_summands(bits, True)
@@ -1882,7 +1926,10 @@ class sint(_secret, _int):
@vectorized_classmethod @vectorized_classmethod
def get_random(cls): def get_random(cls):
""" Secret random ring element according to security model. """ """ Secret random ring element according to security model.
:param size: vector size (int, default 1)
"""
res = sint() res = sint()
randomfulls(res) randomfulls(res)
return res return res
@@ -1891,7 +1938,9 @@ class sint(_secret, _int):
def get_input_from(cls, player): def get_input_from(cls, player):
""" Secret input. """ Secret input.
:param player: public (regint/cint/int) """ :param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
res = cls() res = cls()
inputmixed('int', res, player) inputmixed('int', res, player)
return res return res
@@ -1915,7 +1964,7 @@ class sint(_secret, _int):
else: else:
a = [sint.get_random_bit() for i in range(n_bits)] a = [sint.get_random_bit() for i in range(n_bits)]
return sint.bit_compose(a), a return sint.bit_compose(a), a
program.curr_tape.require_bit_length(n_bits) program.curr_tape.require_bit_length(n_bits - 1)
whole = cls() whole = cls()
size = get_global_vector_size() size = get_global_vector_size()
from Compiler.GC.types import sbits, sbitvec from Compiler.GC.types import sbits, sbitvec
@@ -1931,6 +1980,7 @@ class sint(_secret, _int):
return 1 return 1
@staticmethod @staticmethod
@vectorize
def bit_decompose_clear(a, n_bits): def bit_decompose_clear(a, n_bits):
return floatingpoint.bits(a, n_bits) return floatingpoint.bits(a, n_bits)
@@ -2055,7 +2105,7 @@ class sint(_secret, _int):
:param other: sint/cint/regint/int :param other: sint/cint/regint/int
:return: 0/1 (sint) """ :return: 0/1 (sint) """
res = sint() res = sintbit()
comparison.LTZ(res, self - other, comparison.LTZ(res, self - other,
(bit_length or program.bit_length) + 1, (bit_length or program.bit_length) + 1,
security or program.security) security or program.security)
@@ -2064,7 +2114,7 @@ class sint(_secret, _int):
@read_mem_value @read_mem_value
@vectorize @vectorize
def __gt__(self, other, bit_length=None, security=None): def __gt__(self, other, bit_length=None, security=None):
res = sint() res = sintbit()
comparison.LTZ(res, other - self, comparison.LTZ(res, other - self,
(bit_length or program.bit_length) + 1, (bit_length or program.bit_length) + 1,
security or program.security) security or program.security)
@@ -2185,13 +2235,14 @@ class sint(_secret, _int):
return floatingpoint.Trunc(other, program.bit_length, self, program.security) return floatingpoint.Trunc(other, program.bit_length, self, program.security)
@vectorize @vectorize
def bit_decompose(self, bit_length=None, security=None): def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False):
""" Secret bit decomposition. """ """ Secret bit decomposition. """
if bit_length == 0: if bit_length == 0:
return [] return []
bit_length = bit_length or program.bit_length bit_length = bit_length or program.bit_length
security = security or program.security assert program.security == security or program.security
return floatingpoint.BitDec(self, bit_length, bit_length, security) return program.non_linear.bit_dec(self, bit_length, bit_length,
maybe_mixed)
def TruncMul(self, other, k, m, kappa=None, nearest=False): def TruncMul(self, other, k, m, kappa=None, nearest=False):
return (self * other).round(k, m, kappa, nearest, signed=True) return (self * other).round(k, m, kappa, nearest, signed=True)
@@ -2249,6 +2300,7 @@ class sint(_secret, _int):
return floatingpoint.two_power(n) return floatingpoint.two_power(n)
def split_to_n_summands(self, length, n): def split_to_n_summands(self, length, n):
comparison.require_ring_size(length, 'splitting')
from .GC.types import sbits from .GC.types import sbits
from .GC.instructions import split from .GC.instructions import split
columns = [[sbits.get_type(self.size)() columns = [[sbits.get_type(self.size)()
@@ -2274,7 +2326,9 @@ class sint(_secret, _int):
@vectorize @vectorize
def reveal_to(self, player): def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`. """ Reveal secret value to :py:obj:`player`.
Result potentially written to ``Player-Data/Private-Output-P<player>.`` Result potentially written to
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint): :param player: public integer (int/regint/cint):
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to` :returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
@@ -2288,6 +2342,65 @@ class sint(_secret, _int):
else: else:
return super(sint, self).reveal_to(player) return super(sint, self).reveal_to(player)
class sintbit(sint):
@classmethod
def prep_res(cls, other):
return sint()
def load_other(self, other):
if isinstance(other, sint):
movs(self, other)
else:
super(sintbit, self).load_other(other)
@vectorize
def __and__(self, other):
if isinstance(other, sintbit):
res = sintbit()
muls(res, self, other)
return res
elif util.is_zero(other):
return 0
elif util.is_one(other):
return self
else:
return NotImplemented
@vectorize
def __or__(self, other):
if isinstance(other, sintbit):
res = sintbit()
adds(res, self, other - self * other)
return res
elif util.is_zero(other):
return self
elif util.is_one(other):
return 1
else:
return NotImplemented
@vectorize
def __xor__(self, other):
if isinstance(other, sintbit):
res = sintbit()
adds(res, self, other - 2 * self * other)
return res
elif util.is_zero(other):
return self
elif util.is_one(other):
return 1
else:
return NotImplemented
@vectorize
def __rsub__(self, other):
if util.is_one(other):
res = sintbit()
subsfi(res, self, 1)
return res
else:
return super(sintbit, self).__rsub__(other)
class sgf2n(_secret, _gf2n): class sgf2n(_secret, _gf2n):
""" Secret :math:`\mathrm{GF}(2^n)` value. """ """ Secret :math:`\mathrm{GF}(2^n)` value. """
__slots__ = [] __slots__ = []
@@ -2437,10 +2550,11 @@ class sgf2n(_secret, _gf2n):
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)] return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
for t in (sint, sgf2n): for t in (sint, sgf2n):
t.bit_type = t
t.basic_type = t t.basic_type = t
t.default_type = t t.default_type = t
sint.bit_type = sintbit
sgf2n.bit_type = sgf2n
class _bitint(object): class _bitint(object):
bits = None bits = None
@@ -3046,14 +3160,17 @@ class cfix(_number, _structure):
@staticmethod @staticmethod
def int_rep(v, f, k=None): def int_rep(v, f, k=None):
if isinstance(v, regint):
v = cint(v)
res = v * (2 ** f) res = v * (2 ** f)
try: try:
res = int(round(res)) res = int(round(res))
if k and abs(res) >= 2 ** k: if k and res >= 2 ** (k - 1) or res < -2 ** (k - 1):
limit = 2 ** (k - f - 1)
raise CompilerError( raise CompilerError(
'Value out of fixed-point range (maximum %d). ' 'Value out of fixed-point range [-%d, %d). '
'Use `sfix.set_precision(f, k)` with k being at least f+%d' 'Use `sfix.set_precision(f, k)` with k being at least f+%d'
% (2 ** (k - f), math.ceil(math.log(abs(v), 2)) + 1)) % (limit, limit, res.bit_length() - f + 1))
except TypeError: except TypeError:
pass pass
return res return res
@@ -3268,6 +3385,14 @@ class cfix(_number, _structure):
else: else:
raise TypeError('Incompatible fixed point types in division') raise TypeError('Incompatible fixed point types in division')
@vectorize
def __rtruediv__(self, other):
""" Fixed-point division.
:param other: sfix/sint/cfix/cint/regint/int """
other = parse_type(other, self.k, self.f)
return other / self
def print_plain(self): def print_plain(self):
""" Clear fixed-point output. """ """ Clear fixed-point output. """
print_float_plain(cint.conv(self.v), cint(-self.f), \ print_float_plain(cint.conv(self.v), cint(-self.f), \
@@ -3468,7 +3593,7 @@ class _fix(_single):
set_precision = classmethod(set_precision) set_precision = classmethod(set_precision)
@classmethod @classmethod
def set_precision_from_args(cls, program): def set_precision_from_args(cls, program, adapt_ring=False):
f = None f = None
k = None k = None
for arg in program.args: for arg in program.args:
@@ -3484,6 +3609,15 @@ class _fix(_single):
cfix.set_precision(f, k) cfix.set_precision(f, k)
elif k is not None: elif k is not None:
raise CompilerError('need to set fractional precision') raise CompilerError('need to set fractional precision')
if 'nearest' in program.args:
print('Nearest rounding instead of proabilistic '
'for fixed-point computation')
cls.round_nearest = True
if adapt_ring and program.options.ring:
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
if need != int(program.options.ring):
print('Changing computation modulus to 2^%d' % need)
program.set_ring_size(need)
@classmethod @classmethod
def coerce(cls, other): def coerce(cls, other):
@@ -3609,11 +3743,14 @@ class _fix(_single):
:param other: sfix/cfix/sint/cint/regint/int """ :param other: sfix/cfix/sint/cint/regint/int """
if util.is_constant_float(other): if util.is_constant_float(other):
assert other != 0 assert other != 0
other_length = self.f + math.ceil(math.log(abs(other), 2)) log = math.ceil(math.log(abs(other), 2))
if other_length >= self.k: other_length = self.f + log
factor = 2 ** (self.k - other_length - 1) if other_length >= self.k - 1:
factor = 2 ** (self.k - other_length - 2)
self *= factor self *= factor
other *= factor other *= factor
if 2 ** log == other:
return self * 2 ** -log
other = self.coerce(other) other = self.coerce(other)
assert self.k == other.k assert self.k == other.k
assert self.f == other.f assert self.f == other.f
@@ -3660,7 +3797,9 @@ class sfix(_fix):
def get_input_from(cls, player): def get_input_from(cls, player):
""" Secret fixed-point input. """ Secret fixed-point input.
:param player: public (regint/cint/int) """ :param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
cls.int_type.require_bit_length(cls.k) cls.int_type.require_bit_length(cls.k)
v = cls.int_type() v = cls.int_type()
inputmixed('fix', v, cls.f, player) inputmixed('fix', v, cls.f, player)
@@ -3677,6 +3816,7 @@ class sfix(_fix):
:param lower: float :param lower: float
:param upper: float :param upper: float
:param size: vector size (int, default 1)
""" """
log_range = int(math.log(upper - lower, 2)) log_range = int(math.log(upper - lower, 2))
n_bits = log_range + cls.f n_bits = log_range + cls.f
@@ -3732,7 +3872,8 @@ class sfix(_fix):
def reveal_to(self, player): def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`. """ Reveal secret value to :py:obj:`player`.
Raw representation possibly written to Raw representation possibly written to
``Player-Data/Private-Output-P<player>.`` ``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint) :param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to` :returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
@@ -4066,7 +4207,9 @@ class sfloat(_number, _structure):
def get_input_from(cls, player): def get_input_from(cls, player):
""" Secret floating-point input. """ Secret floating-point input.
:param player: public (regint/cint/int) """ :param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
v = sint() v = sint()
p = sint() p = sint()
z = sint() z = sint()
@@ -4444,6 +4587,7 @@ class Array(object):
self.address_cache = {} self.address_cache = {}
self.debug = debug self.debug = debug
self.creator_tape = program.curr_tape self.creator_tape = program.curr_tape
self.sink = None
if alloc: if alloc:
self.alloc() self.alloc()
@@ -4514,6 +4658,17 @@ class Array(object):
return return
self._store(value, self.get_address(index)) self._store(value, self.get_address(index))
def maybe_get(self, condition, index):
return condition * self[condition * index]
def maybe_set(self, condition, index, value):
if self.sink is None:
self.sink = self.value_type.Array(1)
addresses = (condition.if_else(x, y) for x, y in
zip(util.tuplify(self.get_address(index)),
util.tuplify(self.sink.get_address(0))))
self._store(value, util.untuplify(tuple(addresses)))
# the following two are useful for compile-time lengths # the following two are useful for compile-time lengths
# and thus differ from the usual Python syntax # and thus differ from the usual Python syntax
def get_range(self, start, size): def get_range(self, start, size):
@@ -4590,11 +4745,22 @@ class Array(object):
get_part_vector = get_vector get_part_vector = get_vector
def get_part(self, base, size):
return Array(size, self.value_type, self.get_address(base))
def get(self, indices): def get(self, indices):
return self.value_type.load_mem( return self.value_type.load_mem(
regint.inc(len(indices), self.address, 0) + indices, regint.inc(len(indices), self.address, 0) + indices,
size=len(indices)) size=len(indices))
def get_slice_vector(self, slice):
assert self.value_type.n_elements() == 1
assert len(slice) <= self.total_size()
base = regint.inc(len(slice), slice.address, 1, 1)
inc = regint.inc(len(slice), 0, 1, 1, 1)
addresses = slice.value_type.load_mem(base) + inc
return self.value_type.load_mem(self.address + addresses)
def expand_to_vector(self, index, size): def expand_to_vector(self, index, size):
assert self.value_type.n_elements() == 1 assert self.value_type.n_elements() == 1
address = regint(size=size) address = regint(size=size)
@@ -4641,6 +4807,12 @@ class Array(object):
:param other: vector or container of same length and type that supports operations with type of this array """ :param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() * value return self.get_vector() * value
def __truediv__(self, value):
""" Vector division.
:param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() / value
def __pow__(self, value): def __pow__(self, value):
""" Vector power-of computation. """ Vector power-of computation.
@@ -4674,6 +4846,16 @@ class Array(object):
reveal_nested = reveal_list reveal_nested = reveal_list
def sort(self, n_threads=None):
"""
Sort in place using Batchers' odd-even merge mergesort
with complexity :math:`O(n (\log n)^2)`.
:param n_threads: number of threads to use (single thread by
default)
"""
library.loopy_odd_even_merge_sort(self, n_threads=n_threads)
def __str__(self): def __str__(self):
return '%s array of length %s at %s' % (self.value_type, len(self), return '%s array of length %s at %s' % (self.value_type, len(self),
self.address) self.address)
@@ -4784,6 +4966,15 @@ class SubMultiArray(object):
assert vector.size <= self.total_size() assert vector.size <= self.total_size()
vector.store_in_mem(self.address + base * part_size) vector.store_in_mem(self.address + base * part_size)
def get_slice_vector(self, slice):
assert self.value_type.n_elements() == 1
part_size = reduce(operator.mul, self.sizes[1:])
assert len(slice) * part_size <= self.total_size()
base = regint.inc(len(slice) * part_size, slice.address, 1, part_size)
inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size)
addresses = slice.value_type.load_mem(base) * part_size + inc
return self.value_type.load_mem(self.address + addresses)
def get_addresses(self, *indices): def get_addresses(self, *indices):
assert self.value_type.n_elements() == 1 assert self.value_type.n_elements() == 1
assert len(indices) == len(self.sizes) assert len(indices) == len(self.sizes)
@@ -4816,6 +5007,10 @@ class SubMultiArray(object):
""" :return: new multidimensional array with same shape and basic type """ """ :return: new multidimensional array with same shape and basic type """
return MultiArray(self.sizes, self.value_type) return MultiArray(self.sizes, self.value_type)
def get_part(self, start, size):
return MultiArray([size] + list(self.sizes[1:]), self.value_type,
address=self[start].address)
def input_from(self, player, budget=None, raw=False): def input_from(self, player, budget=None, raw=False):
""" Fill with inputs from player if supported by type. """ Fill with inputs from player if supported by type.
@@ -4978,7 +5173,7 @@ class SubMultiArray(object):
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
assert len(indices[1]) == len(indices[2]) assert len(indices[1]) == len(indices[2])
indices = list(indices) indices = list(indices)
indices[3] *= other.sizes[0] indices[3] *= other.sizes[1]
return self.value_type.direct_matrix_mul( return self.value_type.direct_matrix_mul(
self.address, other.address, None, self.sizes[1], 1, self.address, other.address, None, self.sizes[1], 1,
reduce=reduce, indices=indices) reduce=reduce, indices=indices)

View File

@@ -195,7 +195,7 @@ def is_all_ones(x, n):
else: else:
return False return False
def max(x, y=None): def max(x, y=None, n_threads=None):
if y is None: if y is None:
return tree_reduce(max, x) return tree_reduce(max, x)
else: else:

View File

@@ -7,6 +7,7 @@
#include "Networking/CryptoPlayer.h" #include "Networking/CryptoPlayer.h"
#include "Math/gfp.h" #include "Math/gfp.h"
#include "ECDSA/P256Element.h" #include "ECDSA/P256Element.h"
#include "GC/VectorInput.h"
#include "ECDSA/preprocessing.hpp" #include "ECDSA/preprocessing.hpp"
#include "ECDSA/sign.hpp" #include "ECDSA/sign.hpp"
@@ -20,6 +21,8 @@
#include "Protocols/MascotPrep.hpp" #include "Protocols/MascotPrep.hpp"
#include "GC/Secret.hpp" #include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp" #include "GC/TinyPrep.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/CcdPrep.hpp"
#include "OT/NPartyTripleGenerator.hpp" #include "OT/NPartyTripleGenerator.hpp"
#include <assert.h> #include <assert.h>

View File

@@ -5,6 +5,7 @@
#include "GC/TinierSecret.h" #include "GC/TinierSecret.h"
#include "GC/TinyMC.h" #include "GC/TinyMC.h"
#include "GC/VectorInput.h"
#include "Protocols/Share.hpp" #include "Protocols/Share.hpp"
#include "Protocols/MAC_Check.hpp" #include "Protocols/MAC_Check.hpp"

View File

@@ -19,6 +19,8 @@
#include "Processor/Data_Files.hpp" #include "Processor/Data_Files.hpp"
#include "Processor/Input.hpp" #include "Processor/Input.hpp"
#include "GC/TinyPrep.hpp" #include "GC/TinyPrep.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/CcdPrep.hpp"
#include <assert.h> #include <assert.h>

View File

@@ -13,7 +13,6 @@
#include "Protocols/MaliciousShamirShare.h" #include "Protocols/MaliciousShamirShare.h"
#include "Protocols/Rep3Share.h" #include "Protocols/Rep3Share.h"
#include "GC/TinierSecret.h" #include "GC/TinierSecret.h"
#include "GC/TinierPrep.h"
#include "GC/MaliciousCcdSecret.h" #include "GC/MaliciousCcdSecret.h"
#include "GC/TinyMC.h" #include "GC/TinyMC.h"
@@ -128,16 +127,4 @@ void check(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
MC.Check(P); MC.Check(P);
} }
template<>
void ReplicatedPrep<Rep3Share<P256Element::Scalar>>::buffer_bits()
{
throw not_implemented();
}
template<>
void ReplicatedPrep<ShamirShare<P256Element::Scalar>>::buffer_bits()
{
throw not_implemented();
}
#endif /* ECDSA_PREPROCESSING_HPP_ */ #endif /* ECDSA_PREPROCESSING_HPP_ */

View File

@@ -149,11 +149,6 @@ public:
return res; return res;
} }
bool is_binary() const
{
throw not_implemented();
}
size_t report_size(ReportType type) size_t report_size(ReportType type)
{ {
size_t res = 4; size_t res = 4;

View File

@@ -6,6 +6,7 @@
#include "AddableVector.h" #include "AddableVector.h"
#include "Rq_Element.h" #include "Rq_Element.h"
#include "FHE_Keys.h" #include "FHE_Keys.h"
#include "P2Data.h"
template<class T> template<class T>
AddableVector<T> AddableVector<T>::mul_by_X_i(int j, AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
@@ -33,7 +34,3 @@ AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
} }
return res; return res;
} }
template
AddableVector<Int_Random_Coins::rand_type> AddableVector<
Int_Random_Coins::rand_type>::mul_by_X_i(int j, const FHE_PK& pk) const;

View File

@@ -23,8 +23,6 @@ class Ciphertext
word pk_id; word pk_id;
public: public:
static string type_string() { return "ciphertext"; }
static int t() { return 0; }
static int size() { return 0; } static int size() { return 0; }
const FHE_Params& get_params() const { return *params; } const FHE_Params& get_params() const { return *params; }
@@ -41,8 +39,6 @@ class Ciphertext
set(a0, a1, C.get_pk_id()); set(a0, a1, C.get_pk_id());
} }
~Ciphertext() { ; }
// Rely on default copy assignment/constructor // Rely on default copy assignment/constructor
word get_pk_id() const { return pk_id; } word get_pk_id() const { return pk_id; }

View File

@@ -32,52 +32,6 @@ int DiscreteGauss::sample(PRNG &G, int stretch) const
void RandomVectors::set(int nn,int hh,double R)
{
n=nn;
h=hh;
DG.set(R);
}
void RandomVectors::set_n(int nn)
{
n = nn;
}
vector<bigint> RandomVectors::sample_Gauss(PRNG& G, int stretch) const
{
vector<bigint> ans(n);
for (int i=0; i<n; i++)
{ ans[i]=DG.sample(G, stretch); }
return ans;
}
vector<bigint> RandomVectors::sample_Hwt(PRNG& G) const
{
if (h > n/2 or h <= 0) { return sample_Gauss(G); }
vector<bigint> ans(n);
for (int i=0; i<n; i++) { ans[i]=0; }
int cnt=0,j=0;
unsigned char ch=0;
while (cnt<h)
{ unsigned int i=G.get_uint()%n;
if (ans[i]==0)
{ cnt++;
if (j==0)
{ j=8;
ch=G.get_uchar();
}
int v=ch&1; j--;
if (v==0) { ans[i]=-1; }
else { ans[i]=1; }
}
}
return ans;
}
int sample_half(PRNG& G) int sample_half(PRNG& G)
{ {
@@ -91,36 +45,6 @@ int sample_half(PRNG& G)
} }
vector<bigint> RandomVectors::sample_Half(PRNG& G) const
{
vector<bigint> ans(n);
for (int i=0; i<n; i++)
ans[i] = sample_half(G);
return ans;
}
vector<bigint> RandomVectors::sample_Uniform(PRNG& G,const bigint& B) const
{
vector<bigint> ans(n);
bigint v;
for (int i=0; i<n; i++)
{ G.get_bigint(v, numBits(B));
int bit=G.get_uint()&1;
if (bit==0) { ans[i]=v; }
else { ans[i]=-v; }
}
return ans;
}
bool RandomVectors::operator!=(const RandomVectors& other) const
{
if (n != other.n or h != other.h or DG != other.DG)
return true;
else
return false;
}
bool DiscreteGauss::operator!=(const DiscreteGauss& other) const bool DiscreteGauss::operator!=(const DiscreteGauss& other) const
{ {
if (other.NewHopeB != NewHopeB) if (other.NewHopeB != NewHopeB)

View File

@@ -25,7 +25,6 @@ class DiscreteGauss
void unpack(octetStream& o) { o.unserialize(NewHopeB); } void unpack(octetStream& o) { o.unserialize(NewHopeB); }
DiscreteGauss(double R) { set(R); } DiscreteGauss(double R) { set(R); }
~DiscreteGauss() { ; }
// Rely on default copy constructor/assignment // Rely on default copy constructor/assignment
@@ -36,50 +35,6 @@ class DiscreteGauss
bool operator!=(const DiscreteGauss& other) const; bool operator!=(const DiscreteGauss& other) const;
}; };
/* Sample from integer lattice of dimension n
* with standard deviation R
*/
class RandomVectors
{
int n,h;
DiscreteGauss DG; // This generates the main distribution
public:
void set(int nn,int hh,double R); // R is input STANDARD DEVIATION
void set_n(int nn);
void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); }
void unpack(octetStream& o)
{ o.get(n); o.get(h); DG.unpack(o); }
RandomVectors(int h, double R) : RandomVectors(0, h, R) {}
RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); }
~RandomVectors() { ; }
// Rely on default copy constructor/assignment
double get_R() const { return DG.get_R(); }
DiscreteGauss get_DG() const { return DG; }
int get_h() const { return h; }
// Sample from Discrete Gauss distribution
vector<bigint> sample_Gauss(PRNG& G, int stretch = 1) const;
// Next samples from Hwt distribution unless hwt>n/2 in which
// case it uses Gauss
vector<bigint> sample_Hwt(PRNG& G) const;
// Sample from {-1,0,1} with Pr(-1)=Pr(1)=1/4 and Pr(0)=1/2
vector<bigint> sample_Half(PRNG& G) const;
// Sample from (-B,0,B) with uniform prob
vector<bigint> sample_Uniform(PRNG& G,const bigint& B) const;
bool operator!=(const RandomVectors& other) const;
};
template<class T> template<class T>
class RandomGenerator : public Generator<T> class RandomGenerator : public Generator<T>
{ {
@@ -103,7 +58,7 @@ public:
void get(T& x) const { this->G.get(x, n_bits, positive); } void get(T& x) const { this->G.get(x, n_bits, positive); }
}; };
template<class T> template<class T = bigint>
class GaussianGenerator : public RandomGenerator<T> class GaussianGenerator : public RandomGenerator<T>
{ {
DiscreteGauss DG; DiscreteGauss DG;

View File

@@ -1,6 +1,7 @@
#include "FHE/FFT.h" #include "FHE/FFT.h"
#include "Math/Zp_Data.h" #include "Math/Zp_Data.h"
#include "Processor/BaseMachine.h"
#include "Math/modp.hpp" #include "Math/modp.hpp"
@@ -115,17 +116,38 @@ void FFT_Iter(vector<T>& ioput, int n, const T& root, const P& PrD)
*/ */
void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD) void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
{ {
FFT_Iter(ioput, n, root, PrD, false);
}
void FFT_Iter2(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD)
{
FFT_Iter(ioput, n, roots, PrD, false);
}
void FFT_Iter(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD,
bool start_with_one)
{
vector<modp> roots(n + 1);
assignOne(roots[0], PrD);
for (int i = 1; i < n + 1; i++)
Mul(roots[i], roots[i - 1], root, PrD);
FFT_Iter(ioput, n, roots, PrD, start_with_one);
}
void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD, bool start_with_one)
{
assert(roots.size() > size_t(n));
int i, j, m; int i, j, m;
modp t;
// Bit-reversal of input // Bit-reversal of input
for( i = j = 0; i < n; ++i ) for( i = j = 0; i < n; ++i )
{ {
if( j >= i ) if( j >= i )
{ {
t = ioput[i]; swap(ioput[i], ioput[j]);
ioput[i] = ioput[j];
ioput[j] = t;
} }
m = n / 2; m = n / 2;
@@ -136,27 +158,38 @@ void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
} }
j += m; j += m;
} }
modp u, alpha, alpha2;
m = 0; j = 0; i = 0; m = 0; j = 0; i = 0;
// Do the transform // Do the transform
vector<modp> alpha2;
alpha2.reserve(n / 2);
for (int s = 1; s < n; s = 2*s) for (int s = 1; s < n; s = 2*s)
{ {
m = 2*s; m = 2*s;
Power(alpha, root, n/m, PrD);
alpha2 = alpha; alpha2.clear();
Mul(alpha, alpha, alpha, PrD); if (start_with_one)
for (int j = 0; j < m/2; ++j)
{ {
//root = root_table[(2*j+1)*n/m]; for (int j = 0; j < m / 2; j++)
for (int k = j; k < n; k += m) alpha2.push_back(roots[j * n / m]);
{
Mul(t, alpha2, ioput[k + m/2], PrD);
u = ioput[k];
Add(ioput[k], u, t, PrD);
Sub(ioput[k + m/2], u, t, PrD);
}
Mul(alpha2, alpha2, alpha, PrD);
} }
else
{
for (int j = 0; j < m / 2; j++)
alpha2.push_back(roots.at((j * 2 + 1) * (n / m)));
}
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{
auto& queues = BaseMachine::s().queues;
FftJob job(ioput, alpha2, m, PrD);
int start = queues.distribute(job, n / 2);
for (int i = start; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
queues.wrap_up(job);
}
else
for (int i = 0; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
} }
} }

View File

@@ -30,8 +30,29 @@ void FFT2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD);
template <class T,class P> template <class T,class P>
void FFT_Iter(vector<T>& a,int N,const T& theta,const P& PrD); void FFT_Iter(vector<T>& a,int N,const T& theta,const P& PrD);
void FFT_Iter(vector<modp>& a, int N, const modp& theta, const Zp_Data& PrD,
bool start_with_one = true);
void FFT_Iter2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD); void FFT_Iter2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD);
// variants with precomputed roots
void FFT_Iter(vector<modp>& a, int N, const vector<modp>& theta,
const Zp_Data& PrD, bool start_with_one = true);
void FFT_Iter2(vector<modp>& a, int N, const vector<modp>& theta,
const Zp_Data& PrD);
inline void FFT_Iter2_body(vector<modp>& ioput, const vector<modp>& alpha2, int i,
int m, const Zp_Data& PrD)
{
int j = i % (m / 2);
int kk = i / (m / 2);
int k = j + kk * m;
modp t, u;
Mul(t, alpha2[j], ioput[k + m / 2], PrD);
u = ioput[k];
Add(ioput[k], u, t, PrD);
Sub(ioput[k + m / 2], u, t, PrD);
}
/* BFFT perform FFT and inverse FFT mod PrD for non power of two cyclotomics. /* BFFT perform FFT and inverse FFT mod PrD for non power of two cyclotomics.
* The modulus in PrD (contained in FFT_Data) must be set up * The modulus in PrD (contained in FFT_Data) must be set up

View File

@@ -6,24 +6,6 @@
#include "Math/modp.hpp" #include "Math/modp.hpp"
void FFT_Data::assign(const FFT_Data& FFTD)
{
prData=FFTD.prData;
R=FFTD.R;
root=FFTD.root;
twop=FFTD.twop;
two_root=FFTD.two_root;
powers=FFTD.powers;
powers_i=FFTD.powers_i;
b=FFTD.b;
iphi=FFTD.iphi;
}
void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
{ {
@@ -49,6 +31,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
Inv(root[1],root[0],PrD); Inv(root[1],root[0],PrD);
to_modp(iphi,Rg.phi_m(),PrD); to_modp(iphi,Rg.phi_m(),PrD);
Inv(iphi,iphi,PrD); Inv(iphi,iphi,PrD);
compute_roots(Rg.m());
} }
} }
else else
@@ -57,6 +40,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
{ throw invalid_params(); } { throw invalid_params(); }
root[0]=Find_Primitive_Root_2m(Rg.m(),Rg.Phi(),PrD); root[0]=Find_Primitive_Root_2m(Rg.m(),Rg.Phi(),PrD);
Inv(root[1],root[0],PrD); Inv(root[1],root[0],PrD);
compute_roots(2 * Rg.m());
int ptwop=twop; if (twop<0) { ptwop=-twop; } int ptwop=twop; if (twop<0) { ptwop=-twop; }
@@ -97,6 +81,14 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
} }
} }
void FFT_Data::compute_roots(int n)
{
roots.resize(n + 1);
assignOne(roots[0], prData);
for (int i = 1; i < n + 1; i++)
Mul(roots[i], roots[i - 1], root[0], prData);
}
void FFT_Data::hash(octetStream& o) const void FFT_Data::hash(octetStream& o) const
{ {
@@ -111,6 +103,7 @@ void FFT_Data::pack(octetStream& o) const
R.pack(o); R.pack(o);
prData.pack(o); prData.pack(o);
o.store(root); o.store(root);
o.store(roots);
o.store(twop); o.store(twop);
o.store(two_root); o.store(two_root);
o.store(b); o.store(b);
@@ -125,6 +118,7 @@ void FFT_Data::unpack(octetStream& o)
R.unpack(o); R.unpack(o);
prData.unpack(o); prData.unpack(o);
o.get(root); o.get(root);
o.get(roots);
o.get(twop); o.get(twop);
o.get(two_root); o.get(two_root);
o.get(b); o.get(b);
@@ -133,7 +127,6 @@ void FFT_Data::unpack(octetStream& o)
o.get(powers_i); o.get(powers_i);
} }
bool FFT_Data::operator!=(const FFT_Data& other) const bool FFT_Data::operator!=(const FFT_Data& other) const
{ {
if (R != other.R or prData != other.prData or root != other.root if (R != other.R or prData != other.prData or root != other.root

View File

@@ -19,6 +19,7 @@ class FFT_Data
Zp_Data prData; Zp_Data prData;
vector<modp> root; // 2m'th Root of Unity mod pr and it's inverse vector<modp> root; // 2m'th Root of Unity mod pr and it's inverse
vector<modp> roots; // precomputed powers of root
// When twop is equal to zero, m is a power of two // When twop is equal to zero, m is a power of two
// When twop is positive it is equal to 2^e where 2^e>2*m and 2^e divides p-1 // When twop is positive it is equal to 2^e where 2^e>2*m and 2^e divides p-1
@@ -34,6 +35,8 @@ class FFT_Data
modp iphi; // 1/phi_m mod pr modp iphi; // 1/phi_m mod pr
vector< vector<modp> > powers,powers_i; vector< vector<modp> > powers,powers_i;
void compute_roots(int n);
public: public:
typedef gfp T; typedef gfp T;
typedef bigint S; typedef bigint S;
@@ -47,17 +50,9 @@ class FFT_Data
void pack(octetStream& o) const; void pack(octetStream& o) const;
void unpack(octetStream& o); void unpack(octetStream& o);
void assign(const FFT_Data& FFTD);
FFT_Data() { ; } FFT_Data() { ; }
FFT_Data(const FFT_Data& FFTD)
{ assign(FFTD); }
FFT_Data(const Ring& Rg,const Zp_Data& PrD) FFT_Data(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); } { init(Rg,PrD); }
FFT_Data& operator=(const FFT_Data& FFTD)
{ if (this!=&FFTD) { assign(FFTD); }
return *this;
}
const Zp_Data& get_prD() const { return prData; } const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; } const bigint& get_prime() const { return prData.pr; }
@@ -72,6 +67,7 @@ class FFT_Data
int get_twop() const { return twop; } int get_twop() const { return twop; }
modp get_root(int i) const { return root[i]; } modp get_root(int i) const { return root[i]; }
modp get_iphi() const { return iphi; } modp get_iphi() const { return iphi; }
const vector<modp>& get_roots() const { return roots; }
const Ring& get_R() const { return R; } const Ring& get_R() const { return R; }

View File

@@ -42,7 +42,7 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
{ {
Rq_Element sk = FHE_SK(*this).s(); Rq_Element sk = FHE_SK(*this).s();
// Generate the secret key // Generate the secret key
sk.from_vec((*params).sampleHwt(G)); sk.from(GaussianGenerator<bigint>(params->get_DG(), G));
return sk; return sk;
} }
@@ -55,7 +55,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
// b0=a0*s+p*e0 // b0=a0*s+p*e0
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
e0.from_vec((*PK.params).sampleGaussian(G, noise_boost)); e0.from(GaussianGenerator<bigint>(params->get_DG(), G, noise_boost));
mul(PK.b0,PK.a0,sk); mul(PK.b0,PK.a0,sk);
mul(e0,e0,PK.pr); mul(e0,e0,PK.pr);
add(PK.b0,PK.b0,e0); add(PK.b0,PK.b0,e0);
@@ -72,7 +72,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
// bs=as*s+p*es // bs=as*s+p*es
Rq_Element es((*PK.params).FFTD(),evaluation,evaluation); Rq_Element es((*PK.params).FFTD(),evaluation,evaluation);
es.from_vec((*PK.params).sampleGaussian(G, noise_boost)); es.from(GaussianGenerator<bigint>(params->get_DG(), G, noise_boost));
mul(PK.Sw_b,PK.Sw_a,sk); mul(PK.Sw_b,PK.Sw_a,sk);
mul(es,es,PK.pr); mul(es,es,PK.pr);
add(PK.Sw_b,PK.Sw_b,es); add(PK.Sw_b,PK.Sw_b,es);
@@ -120,13 +120,14 @@ void FHE_PK::check_noise(const Rq_Element& x, bool check_modulo) const
} }
template<> template<class T, class FD, class S>
void FHE_PK::encrypt(Ciphertext& c, void FHE_PK::encrypt(Ciphertext& c,
const Plaintext<gfp,FFT_Data,bigint>& mess,const Random_Coins& rc) const const Plaintext<T, FD, S>& mess,const Random_Coins& rc) const
{ {
if (&c.get_params()!=params) { throw params_mismatch(); } if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); } if (&rc.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); } if (T::characteristic_two ^ (pr == 2))
throw pr_mismatch();
Rq_Element mm((*params).FFTD(),polynomial,polynomial); Rq_Element mm((*params).FFTD(),polynomial,polynomial);
mm.from(mess.get_iterator()); mm.from(mess.get_iterator());
@@ -134,35 +135,6 @@ void FHE_PK::encrypt(Ciphertext& c,
quasi_encrypt(c,mm,rc); quasi_encrypt(c,mm,rc);
} }
template<>
void FHE_PK::encrypt(Ciphertext& c,
const Plaintext<gfp,PPData,bigint>& mess,const Random_Coins& rc) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); }
mess.to_poly();
encrypt(c, mess.get_poly(), rc);
}
template<>
void FHE_PK::encrypt(Ciphertext& c,
const Plaintext<gf2n_short,P2Data,int>& mess,const Random_Coins& rc) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); }
if (pr!=2) { throw pr_mismatch(); }
mess.to_poly();
encrypt(c, mess.get_poly(), rc);
}
void FHE_PK::quasi_encrypt(Ciphertext& c, void FHE_PK::quasi_encrypt(Ciphertext& c,
const Rq_Element& mess,const Random_Coins& rc) const const Rq_Element& mess,const Random_Coins& rc) const
{ {
@@ -212,42 +184,12 @@ Ciphertext FHE_PK::encrypt(
} }
template<> template<class T, class FD, class S>
void FHE_SK::decrypt(Plaintext<gfp,FFT_Data,bigint>& mess,const Ciphertext& c) const void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
{ {
if (&c.get_params()!=params) { throw params_mismatch(); } if (&c.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); } if (T::characteristic_two ^ (pr == 2))
throw pr_mismatch();
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
ans.change_rep(polynomial);
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
}
template<>
void FHE_SK::decrypt(Plaintext<gfp,PPData,bigint>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); }
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
mess.set_poly_mod(ans.to_vec_bigint(),ans.get_modulus());
}
template<>
void FHE_SK::decrypt(Plaintext<gf2n_short,P2Data,int>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (pr!=2) { throw pr_mismatch(); }
Rq_Element ans; Rq_Element ans;

View File

@@ -3,14 +3,6 @@
#include "FHE/Ring_Element.h" #include "FHE/Ring_Element.h"
#include "Tools/Exceptions.h" #include "Tools/Exceptions.h"
void FHE_Params::set(const Ring& R,
const vector<bigint>& primes,double r,int hwt)
{
set(R, primes);
Chi.set(R.phi_m(),hwt,r);
}
void FHE_Params::set(const Ring& R, void FHE_Params::set(const Ring& R,
const vector<bigint>& primes) const vector<bigint>& primes)
{ {
@@ -20,7 +12,6 @@ void FHE_Params::set(const Ring& R,
for (size_t i = 0; i < FFTData.size(); i++) for (size_t i = 0; i < FFTData.size(); i++)
FFTData[i].init(R,primes[i]); FFTData[i].init(R,primes[i]);
Chi.set_n(R.phi_m());
set_sec(40); set_sec(40);
} }

View File

@@ -21,7 +21,7 @@ class FHE_Params
vector<FFT_Data> FFTData; vector<FFT_Data> FFTData;
// Random generator for Multivariate Gaussian Distribution etc // Random generator for Multivariate Gaussian Distribution etc
RandomVectors Chi; mutable DiscreteGauss Chi;
// Data for distributed decryption // Data for distributed decryption
int sec_p; int sec_p;
@@ -29,27 +29,17 @@ class FHE_Params
public: public:
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(-1, 0.7), sec_p(-1) {} FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
int n_mults() const { return FFTData.size() - 1; } int n_mults() const { return FFTData.size() - 1; }
// Rely on default copy assignment/constructor (not that they should // Rely on default copy assignment/constructor (not that they should
// ever be needed) // ever be needed)
void set(const Ring& R,const vector<bigint>& primes,double r,int hwt);
void set(const Ring& R,const vector<bigint>& primes); void set(const Ring& R,const vector<bigint>& primes);
void set(const vector<bigint>& primes); void set(const vector<bigint>& primes);
void set_sec(int sec); void set_sec(int sec);
vector<bigint> sampleGaussian(PRNG& G, int noise_boost = 1) const
{ return Chi.sample_Gauss(G, noise_boost); }
vector<bigint> sampleHwt(PRNG& G) const
{ return Chi.sample_Hwt(G); }
vector<bigint> sampleHalf(PRNG& G) const
{ return Chi.sample_Half(G); }
vector<bigint> sampleUniform(PRNG& G,const bigint& Bd) const
{ return Chi.sample_Uniform(G,Bd); }
const vector<FFT_Data>& FFTD() const { return FFTData; } const vector<FFT_Data>& FFTD() const { return FFTData; }
const bigint& p0() const { return FFTData[0].get_prime(); } const bigint& p0() const { return FFTData[0].get_prime(); }
@@ -59,9 +49,8 @@ class FHE_Params
int secp() const { return sec_p; } int secp() const { return sec_p; }
const bigint& B() const { return Bval; } const bigint& B() const { return Bval; }
double get_R() const { return Chi.get_R(); } double get_R() const { return Chi.get_R(); }
void set_R(double R) const { return Chi.get_DG().set(R); } void set_R(double R) const { return Chi.set(R); }
DiscreteGauss get_DG() const { return Chi.get_DG(); } DiscreteGauss get_DG() const { return Chi; }
int get_h() const { return Chi.get_h(); }
int phi_m() const { return FFTData[0].phi_m(); } int phi_m() const { return FFTData[0].phi_m(); }
const Ring& get_ring() { return FFTData[0].get_R(); } const Ring& get_ring() { return FFTData[0].get_R(); }

View File

@@ -52,10 +52,12 @@ int generate_semi_setup(int plaintext_length, int sec,
bigint p; bigint p;
generate_prime(p, lgp, m); generate_prime(p, lgp, m);
int lgp0, lgp1; int lgp0, lgp1;
FHE_Params tmp_params;
while (true) while (true)
{ {
tmp_params = params;
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
bigint p1 = 2 * p * m, p0 = p; bigint p1 = 2 * p * m, p0 = p;
while (nb.min_p0(params.n_mults() > 0, p1) > p0) while (nb.min_p0(params.n_mults() > 0, p1) > p0)
{ {
@@ -75,6 +77,7 @@ int generate_semi_setup(int plaintext_length, int sec,
} }
} }
params = tmp_params;
int extra_slack = common_semi_setup(params, m, p, lgp0, lgp1, round_up); int extra_slack = common_semi_setup(params, m, p, lgp0, lgp1, round_up);
FTD.init(params.get_ring(), p); FTD.init(params.get_ring(), p);

View File

@@ -13,29 +13,24 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
const FHE_Params& params) : const FHE_Params& params) :
p(p), phi_m(phi_m), n(n), sec(sec), p(p), phi_m(phi_m), n(n), sec(sec),
slack(numBits(Proof::slack(slack_param, sec, phi_m))), slack(numBits(Proof::slack(slack_param, sec, phi_m))),
sigma(params.get_R()), h(params.get_h()) sigma(params.get_R())
{ {
if (sigma <= 0) if (sigma <= 0)
this->sigma = sigma = FHE_Params().get_R(); this->sigma = sigma = FHE_Params().get_R();
#ifdef VERBOSE if (extra_h)
cerr << "Standard deviation: " << this->sigma << endl;
#endif
if (h > 0)
h += extra_h * sec;
else if (extra_h)
{ {
sigma *= 1.4; sigma *= 1.4;
params.set_R(params.get_R() * 1.4); params.set_R(params.get_R() * 1.4);
} }
#ifdef VERBOSE
cerr << "Standard deviation: " << this->sigma << endl;
#endif
produce_epsilon_constants(); produce_epsilon_constants();
// according to documentation of SCALE-MAMBA 1.7 // according to documentation of SCALE-MAMBA 1.7
// excluding a factor of n because we don't always add up n ciphertexts // excluding a factor of n because we don't always add up n ciphertexts
if (h > 0) V_s = sigma * sqrt(phi_m);
V_s = sqrt(h);
else
V_s = sigma * sqrt(phi_m);
B_clean = (bigint(phi_m) << (sec + 1)) * p B_clean = (bigint(phi_m) << (sec + 1)) * p
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s); * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s);
// unify parameters by taking maximum over TopGear or not // unify parameters by taking maximum over TopGear or not

View File

@@ -22,7 +22,6 @@ protected:
const int sec; const int sec;
int slack; int slack;
mpf_class sigma; mpf_class sigma;
int h;
bigint B_clean; bigint B_clean;
bigint B_scale; bigint B_scale;

View File

@@ -5,14 +5,6 @@
void PPData::assign(const PPData& PPD)
{
R=PPD.R;
prData=PPD.prData;
root=PPD.root;
}
void PPData::init(const Ring& Rg,const Zp_Data& PrD) void PPData::init(const Ring& Rg,const Zp_Data& PrD)
{ {
R=Rg; R=Rg;

View File

@@ -27,17 +27,9 @@ class PPData
void init(const Ring& Rg,const Zp_Data& PrD); void init(const Ring& Rg,const Zp_Data& PrD);
void assign(const PPData& PPD);
PPData() { ; } PPData() { ; }
PPData(const PPData& PPD)
{ assign(PPD); }
PPData(const Ring& Rg,const Zp_Data& PrD) PPData(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); } { init(Rg,PrD); }
PPData& operator=(const PPData& PPD)
{ if (this!=&PPD) { assign(PPD); }
return *this;
}
const Zp_Data& get_prD() const { return prData; } const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; } const bigint& get_prime() const { return prData.pr; }

View File

@@ -5,6 +5,7 @@
#include "FHE/P2Data.h" #include "FHE/P2Data.h"
#include "FHE/Rq_Element.h" #include "FHE/Rq_Element.h"
#include "FHE_Keys.h" #include "FHE_Keys.h"
#include "FHE/AddableVector.hpp"
#include "Math/Z2k.hpp" #include "Math/Z2k.hpp"
#include "Math/modp.hpp" #include "Math/modp.hpp"
@@ -258,37 +259,9 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
} }
template<>
void Plaintext_<FFT_Data>::randomize(PRNG& G, bigint B, bool Diag, bool binary, PT_Type t)
{
if (Diag or binary)
throw not_implemented();
if (B == 0)
throw runtime_error("cannot randomize modulo 0");
allocate(t);
switch (t)
{
case Polynomial:
rand_poly(b, G, B, false);
break;
case Evaluation:
for (int i = 0; i < n_slots; i++)
a[i] = G.randomBnd(B);
break;
default:
throw runtime_error("wrong type for randomization with bound");
break;
}
}
template<class T,class FD,class S> template<class T,class FD,class S>
void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, bool binary, PT_Type t) void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t)
{ {
if (binary)
throw not_implemented();
allocate(t); allocate(t);
switch(t) switch(t)
{ {
@@ -614,10 +587,11 @@ void Plaintext<gf2n_short,P2Data,int>::negate()
template<class T, class FD, class S> template<class T, class FD, class _>
Rq_Element Plaintext<T, FD, S>::mul_by_X_i(int i, const FHE_PK& pk) const AddableVector<typename FD::poly_type> Plaintext<T, FD, _>::mul_by_X_i(int i,
const FHE_PK& pk) const
{ {
return Rq_Element(pk.get_params(), *this).mul_by_X_i(i); return AddableVector<S>(get_poly()).mul_by_X_i(i, pk);
} }

View File

@@ -25,6 +25,7 @@ using namespace std;
class FHE_PK; class FHE_PK;
class Rq_Element; class Rq_Element;
template<class T> class AddableVector;
// Forward declaration as apparently this is needed for friends in templates // Forward declaration as apparently this is needed for friends in templates
template<class T,class FD,class S> class Plaintext; template<class T,class FD,class S> class Plaintext;
@@ -64,13 +65,6 @@ class Plaintext
const FD& get_field() const { return *Field_Data; } const FD& get_field() const { return *Field_Data; }
unsigned int num_slots() const { return n_slots; } unsigned int num_slots() const { return n_slots; }
void assign(const Plaintext& p)
{ Field_Data=p.Field_Data;
a=p.a; b=p.b; type=p.type;
n_slots = p.n_slots;
degree = p.degree;
}
Plaintext(const FD& FieldD, PT_Type type = Polynomial) Plaintext(const FD& FieldD, PT_Type type = Polynomial)
{ Field_Data=&FieldD; set_sizes(); allocate(type); } { Field_Data=&FieldD; set_sizes(); allocate(type); }
@@ -142,8 +136,7 @@ class Plaintext
void to_poly() const; void to_poly() const;
void randomize(PRNG& G,condition cond=Full); void randomize(PRNG& G,condition cond=Full);
void randomize(PRNG& G, bigint B, bool Diag=false, bool binary=false, PT_Type type=Polynomial); void randomize(PRNG& G, int n_bits, bool Diag=false, PT_Type type=Polynomial);
void randomize(PRNG& G, int n_bits, bool Diag=false, bool binary=false, PT_Type type=Polynomial);
void assign_zero(PT_Type t = Evaluation); void assign_zero(PT_Type t = Evaluation);
void assign_one(PT_Type t = Evaluation); void assign_one(PT_Type t = Evaluation);
@@ -171,13 +164,12 @@ class Plaintext
void negate(); void negate();
Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const; AddableVector<S> mul_by_X_i(int i, const FHE_PK& pk) const;
bool equals(const Plaintext& x) const; bool equals(const Plaintext& x) const;
bool operator!=(const Plaintext& x) { return !equals(x); } bool operator!=(const Plaintext& x) { return !equals(x); }
bool is_diagonal() const; bool is_diagonal() const;
bool is_binary() const { throw not_implemented(); }
/* Pack and unpack into an octetStream /* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already * For unpack we assume the FFTD has been assigned correctly already

View File

@@ -52,8 +52,6 @@ class Random_Coins
{ params=&p; } { params=&p; }
Random_Coins(const FHE_PK& pk); Random_Coins(const FHE_PK& pk);
~Random_Coins() { ; }
// Rely on default copy assignment/constructor // Rely on default copy assignment/constructor

View File

@@ -33,17 +33,36 @@ Ring_Element::Ring_Element(const FFT_Data& fftd,RepType r)
} }
void Ring_Element::prepare(const Ring_Element& other)
{
assert(this != &other);
FFTD = other.FFTD;
rep = other.rep;
prepare_push();
}
void Ring_Element::prepare_push()
{
element.clear();
element.reserve(FFTD->phi_m());
}
void Ring_Element::allocate()
{
element.resize(FFTD->phi_m());
}
void Ring_Element::assign_zero() void Ring_Element::assign_zero()
{ {
element.resize((*FFTD).phi_m()); element.clear();
for (int i=0; i<(*FFTD).phi_m(); i++)
{ assignZero(element[i],(*FFTD).get_prD()); }
} }
void Ring_Element::assign_one() void Ring_Element::assign_one()
{ {
element.resize((*FFTD).phi_m()); allocate();
modp fill; modp fill;
if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); } if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); }
else { assignOne(fill,(*FFTD).get_prD()); } else { assignOne(fill,(*FFTD).get_prD()); }
@@ -56,6 +75,9 @@ void Ring_Element::assign_one()
void Ring_Element::negate() void Ring_Element::negate()
{ {
if (element.empty())
return;
for (int i=0; i<(*FFTD).phi_m(); i++) for (int i=0; i<(*FFTD).phi_m(); i++)
{ Negate(element[i],element[i],(*FFTD).get_prD()); } { Negate(element[i],element[i],(*FFTD).get_prD()); }
} }
@@ -66,20 +88,58 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{ {
if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
ans.partial_assign(a); if (a.element.empty())
{
ans = b;
return;
}
else if (b.element.empty())
{
ans = a;
return;
}
if (&ans == &a)
{
ans += b;
return;
}
else if (&ans == &b)
{
ans += a;
return;
}
ans.prepare(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++) for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Add(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); } ans.element.push_back(a.element[i].add(b.element[i], a.FFTD->get_prD()));
} }
void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{ {
if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
ans.partial_assign(a); if (a.element.empty())
{
ans = b;
ans.negate();
return;
}
else if (b.element.empty())
{
ans = a;
return;
}
if (&ans == &a)
{
ans -= b;
return;
}
ans.prepare(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++) for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Sub(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); } ans.element.push_back(a.element[i].sub(b.element[i], a.FFTD->get_prD()));
} }
@@ -88,13 +148,29 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{ {
if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
ans.partial_assign(a); if (a.element.empty() or b.element.empty())
if (ans.rep==evaluation) {
{ // In evaluation representation, so we can just multiply componentwise ans = Ring_Element(*a.FFTD, a.rep);
for (int i=0; i<(*ans.FFTD).phi_m(); i++) return;
{ Mul(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
} }
else if ((*ans.FFTD).get_twop()!=0)
if (a.rep==evaluation)
{ // In evaluation representation, so we can just multiply componentwise
if (&ans == &a)
{
ans *= b;
return;
}
else if (&ans == &b)
{
ans *= a;
return;
}
ans.prepare(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
ans.element.push_back(a.element[i].mul(b.element[i], a.FFTD->get_prD()));
}
else if ((*a.FFTD).get_twop()!=0)
{ // This is the case where m is not a power of two { // This is the case where m is not a power of two
// Here we have to do a poly mult followed by a reduction // Here we have to do a poly mult followed by a reduction
@@ -116,11 +192,13 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
// Now apply reduction, assumes Ring.poly is monic // Now apply reduction, assumes Ring.poly is monic
reduce(aa, 2*(*a.FFTD).phi_m(), (*a.FFTD).phi_m(), *a.FFTD); reduce(aa, 2*(*a.FFTD).phi_m(), (*a.FFTD).phi_m(), *a.FFTD);
// Now stick into answer // Now stick into answer
ans.partial_assign(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++) for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ ans.element[i]=aa[i]; } { ans.element[i]=aa[i]; }
} }
else if ((*ans.FFTD).get_twop()==0) else if ((*a.FFTD).get_twop()==0)
{ // m a power of two case { // m a power of two case
ans.partial_assign(a);
Ring_Element aa(*ans.FFTD,ans.rep); Ring_Element aa(*ans.FFTD,ans.rep);
modp temp; modp temp;
for (int i=0; i<(*ans.FFTD).phi_m(); i++) for (int i=0; i<(*ans.FFTD).phi_m(); i++)
@@ -143,31 +221,89 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
void mul(Ring_Element& ans,const Ring_Element& a,const modp& b) void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
{ {
ans.partial_assign(a); if (&ans == &a)
{
ans *= b;
return;
}
ans.prepare(a);
if (a.element.empty())
return;
for (int i=0; i<(*ans.FFTD).phi_m(); i++) for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Mul(ans.element[i],a.element[i],b,(*a.FFTD).get_prD()); } ans.element.push_back(a.element[i].mul(b, a.FFTD->get_prD()));
}
Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD == other.FFTD);
assert(rep == other.rep);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].add(other.element[i], FFTD->get_prD());
return *this;
}
Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD == other.FFTD);
assert(rep == other.rep);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].sub(other.element[i], FFTD->get_prD());
return *this;
}
Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD == other.FFTD);
assert(rep == other.rep);
assert(rep == evaluation);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].mul(other.element[i], FFTD->get_prD());
return *this;
}
Ring_Element& Ring_Element::operator *=(const modp& other)
{
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].mul(other, FFTD->get_prD());
return *this;
} }
Ring_Element Ring_Element::mul_by_X_i(int j) const Ring_Element Ring_Element::mul_by_X_i(int j) const
{ {
Ring_Element ans; Ring_Element ans;
ans.prepare(*this);
if (element.empty())
return ans;
auto& a = *this; auto& a = *this;
ans.partial_assign(a);
if (ans.rep == evaluation) if (ans.rep == evaluation)
{ {
modp xj, xj2; modp xj, xj2;
Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD()); Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD());
Sqr(xj2, xj, (*a.FFTD).get_prD()); Sqr(xj2, xj, (*a.FFTD).get_prD());
ans.prepare_push();
modp tmp;
for (int i= 0; i < (*ans.FFTD).phi_m(); i++) for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
{ {
Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD()); Mul(tmp, a.element[i], xj, (*a.FFTD).get_prD());
ans.element.push_back(tmp);
Mul(xj, xj, xj2, (*a.FFTD).get_prD()); Mul(xj, xj, xj2, (*a.FFTD).get_prD());
} }
} }
else else
{ {
Ring_Element aa(*ans.FFTD, ans.rep); Ring_Element aa(*ans.FFTD, ans.rep);
aa.allocate();
for (int i= 0; i < (*ans.FFTD).phi_m(); i++) for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
{ {
int k= j + i, s= 1; int k= j + i, s= 1;
@@ -193,6 +329,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const
void Ring_Element::randomize(PRNG& G,bool Diag) void Ring_Element::randomize(PRNG& G,bool Diag)
{ {
allocate();
if (Diag==false) if (Diag==false)
{ for (int i=0; i<(*FFTD).phi_m(); i++) { for (int i=0; i<(*FFTD).phi_m(); i++)
{ element[i].randomize(G,(*FFTD).get_prD()); } { element[i].randomize(G,(*FFTD).get_prD()); }
@@ -213,12 +350,18 @@ void Ring_Element::randomize(PRNG& G,bool Diag)
void Ring_Element::change_rep(RepType r) void Ring_Element::change_rep(RepType r)
{ {
if (element.empty())
{
rep = r;
return;
}
if (rep==r) { return; } if (rep==r) { return; }
if (r==evaluation) if (r==evaluation)
{ rep=evaluation; { rep=evaluation;
if ((*FFTD).get_twop()==0) if ((*FFTD).get_twop()==0)
{ // m a power of two variant { // m a power of two variant
FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_root(0),(*FFTD).get_prD()); FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_roots(),(*FFTD).get_prD());
} }
else else
{ // Non m power of two variant and FFT enabled { // Non m power of two variant and FFT enabled
@@ -258,6 +401,11 @@ void Ring_Element::change_rep(RepType r)
bool Ring_Element::equals(const Ring_Element& a) const bool Ring_Element::equals(const Ring_Element& a) const
{ {
if (element.empty() and a.element.empty())
return true;
else if (element.empty() or a.element.empty())
throw not_implemented();
if (rep!=a.rep) { throw rep_mismatch(); } if (rep!=a.rep) { throw rep_mismatch(); }
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
for (int i=0; i<(*FFTD).phi_m(); i++) for (int i=0; i<(*FFTD).phi_m(); i++)
@@ -266,34 +414,11 @@ bool Ring_Element::equals(const Ring_Element& a) const
} }
void Ring_Element::from_vec(const vector<bigint>& v)
{
RepType t=rep;
rep=polynomial;
bigint tmp;
for (int i=0; i<(*FFTD).phi_m(); i++)
{
tmp = v[i];
element[i].convert_destroy(tmp, FFTD->get_prD());
}
change_rep(t);
// cout << "RE:from_vec<bigint>:: " << *this << endl;
}
void Ring_Element::from_vec(const vector<int>& v)
{
RepType t=rep;
rep=polynomial;
for (int i=0; i<(*FFTD).phi_m(); i++)
{ to_modp(element[i],v[i],(*FFTD).get_prD()); }
change_rep(t);
// cout << "RE:from_vec<int>:: " << *this << endl;
}
ConversionIterator Ring_Element::get_iterator() const ConversionIterator Ring_Element::get_iterator() const
{ {
if (rep != polynomial) if (rep != polynomial)
throw runtime_error("simple iterator only available in polynomial represention"); throw runtime_error("simple iterator only available in polynomial represention");
assert(not element.empty());
return {element, (*FFTD).get_prD()}; return {element, (*FFTD).get_prD()};
} }
@@ -318,6 +443,9 @@ vector<bigint> Ring_Element::to_vec_bigint() const
void Ring_Element::to_vec_bigint(vector<bigint>& v) const void Ring_Element::to_vec_bigint(vector<bigint>& v) const
{ {
v.resize(FFTD->phi_m()); v.resize(FFTD->phi_m());
if (element.empty())
return;
if (rep==polynomial) if (rep==polynomial)
{ for (int i=0; i<(*FFTD).phi_m(); i++) { for (int i=0; i<(*FFTD).phi_m(); i++)
{ to_bigint(v[i],element[i],(*FFTD).get_prD()); } { to_bigint(v[i],element[i],(*FFTD).get_prD()); }
@@ -336,11 +464,10 @@ void Ring_Element::to_vec_bigint(vector<bigint>& v) const
modp Ring_Element::get_constant() const modp Ring_Element::get_constant() const
{ {
if (rep==polynomial) if (element.empty())
{ return element[0]; } return {};
Ring_Element a=*this; else
a.change_rep(polynomial); return element[0];
return a.element[0];
} }
@@ -364,9 +491,14 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
+ to_string(ZpD.pr_bit_length)); + to_string(ZpD.pr_bit_length));
unsigned int length; unsigned int length;
o.get(length); o.get(length);
v.resize(length); v.clear();
v.reserve(length);
modp tmp;
for (unsigned int i=0; i<length; i++) for (unsigned int i=0; i<length; i++)
{ v[i].unpack(o,ZpD); } {
tmp.unpack(o,ZpD);
v.push_back(tmp);
}
} }
@@ -398,7 +530,7 @@ void Ring_Element::check_rep()
void Ring_Element::check_size() const void Ring_Element::check_size() const
{ {
if ((int)element.size() != FFTD->phi_m()) if (not element.empty() and (int)element.size() != FFTD->phi_m())
throw runtime_error("invalid element size"); throw runtime_error("invalid element size");
} }

View File

@@ -41,12 +41,6 @@ class Ring_Element
vector<modp> element; vector<modp> element;
// Define a copy
void assign(const Ring_Element& e)
{ rep=e.rep; FFTD=e.FFTD;
element=e.element;
}
public: public:
// Used to basically make sure *this is able to cope // Used to basically make sure *this is able to cope
@@ -57,6 +51,10 @@ class Ring_Element
element.resize((*FFTD).phi_m()); element.resize((*FFTD).phi_m());
} }
void prepare(const Ring_Element& e);
void prepare_push();
void allocate();
void set_data(const FFT_Data& prd) { FFTD=&prd; } void set_data(const FFT_Data& prd) { FFTD=&prd; }
const FFT_Data& get_FFTD() const { return *FFTD; } const FFT_Data& get_FFTD() const { return *FFTD; }
const Zp_Data& get_prD() const { return (*FFTD).get_prD(); } const Zp_Data& get_prD() const { return (*FFTD).get_prD(); }
@@ -80,19 +78,6 @@ class Ring_Element
element.push_back(x); element.push_back(x);
} }
// Copy Constructor
Ring_Element(const Ring_Element& e)
{ assign(e); }
// Destructor
~Ring_Element() { ; }
// Copy Assignment
Ring_Element& operator=(const Ring_Element& e)
{ if (this!=&e) { assign(e); }
return *this;
}
/* Functional Operators */ /* Functional Operators */
void negate(); void negate();
friend void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b); friend void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b);
@@ -102,6 +87,11 @@ class Ring_Element
Ring_Element mul_by_X_i(int i) const; Ring_Element mul_by_X_i(int i) const;
Ring_Element& operator+=(const Ring_Element& other);
Ring_Element& operator-=(const Ring_Element& other);
Ring_Element& operator*=(const Ring_Element& other);
Ring_Element& operator*=(const modp& other);
void randomize(PRNG& G,bool Diag=false); void randomize(PRNG& G,bool Diag=false);
bool equals(const Ring_Element& a) const; bool equals(const Ring_Element& a) const;
@@ -112,8 +102,6 @@ class Ring_Element
// Converting to and from a vector of bigint/int's // Converting to and from a vector of bigint/int's
// I/O is assumed to be in poly rep, so from_vec it internally alters // I/O is assumed to be in poly rep, so from_vec it internally alters
// the representation to the current representation // the representation to the current representation
void from_vec(const vector<bigint>& v);
void from_vec(const vector<int>& v);
vector<bigint> to_vec_bigint() const; vector<bigint> to_vec_bigint() const;
void to_vec_bigint(vector<bigint>& v) const; void to_vec_bigint(vector<bigint>& v) const;
@@ -136,8 +124,18 @@ class Ring_Element
// This gets the constant term of the poly rep as a modp element // This gets the constant term of the poly rep as a modp element
modp get_constant() const; modp get_constant() const;
modp get_element(int i) const { return element[i]; } modp get_element(int i) const
void set_element(int i,const modp& a) { element[i]=a; } {
if (element.empty())
return {};
else
return element[i];
}
void set_element(int i,const modp& a)
{
allocate();
element[i] = a;
}
/* Pack and unpack into an octetStream /* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already * For unpack we assume the FFTD has been assigned correctly already
@@ -164,7 +162,11 @@ class RingWriteIterator : public WriteConversionIterator
public: public:
RingWriteIterator(Ring_Element& element) : RingWriteIterator(Ring_Element& element) :
WriteConversionIterator(element.element, element.FFTD->get_prD()), WriteConversionIterator(element.element, element.FFTD->get_prD()),
element(element), rep(element.rep) { element.rep = polynomial; } element(element), rep(element.rep)
{
element.rep = polynomial;
element.allocate();
}
~RingWriteIterator() { element.change_rep(rep); } ~RingWriteIterator() { element.change_rep(rep); }
}; };
@@ -175,7 +177,11 @@ class RingReadIterator : public ConversionIterator
public: public:
RingReadIterator(const Ring_Element& element) : RingReadIterator(const Ring_Element& element) :
ConversionIterator(this->element.element, element.FFTD->get_prD()), ConversionIterator(this->element.element, element.FFTD->get_prD()),
element(element) { this->element.change_rep(polynomial); } element(element)
{
this->element.change_rep(polynomial);
this->element.allocate();
}
}; };
@@ -189,10 +195,13 @@ void Ring_Element::from(const Generator<T>& generator)
RepType t=rep; RepType t=rep;
rep=polynomial; rep=polynomial;
T tmp; T tmp;
modp tmp2;
prepare_push();
for (int i=0; i<(*FFTD).phi_m(); i++) for (int i=0; i<(*FFTD).phi_m(); i++)
{ {
generator.get(tmp); generator.get(tmp);
element[i].convert_destroy(tmp, (*FFTD).get_prD()); tmp2.convert_destroy(tmp, (*FFTD).get_prD());
element.push_back(tmp2);
} }
change_rep(t); change_rep(t);
} }

View File

@@ -48,15 +48,6 @@ void Rq_Element::partial_assign(const Rq_Element& other)
{ {
lev=other.lev; lev=other.lev;
a.resize(other.a.size()); a.resize(other.a.size());
for (size_t i = 0; i < a.size(); i++)
a[i].partial_assign(other.a[i]);
}
void Rq_Element::assign(const Rq_Element& other)
{
partial_assign(other);
for (int i=0; i<=lev; ++i)
a[i] = other.a[i];
} }
void Rq_Element::negate() void Rq_Element::negate()
@@ -134,20 +125,6 @@ bool Rq_Element::equals(const Rq_Element& other) const
} }
void Rq_Element::from_vec(const vector<bigint>& v,int level)
{
set_level(level);
for (int i=0;i<=lev;++i)
a[i].from_vec(v);
}
void Rq_Element::from_vec(const vector<int>& v,int level)
{
set_level(level);
for (int i=0;i<=lev;++i)
a[i].from_vec(v);
}
vector<bigint> Rq_Element::to_vec_bigint() const vector<bigint> Rq_Element::to_vec_bigint() const
{ {
vector<bigint> v; vector<bigint> v;

View File

@@ -44,7 +44,6 @@ protected:
void assign_zero(const vector<FFT_Data>& prd); void assign_zero(const vector<FFT_Data>& prd);
void assign_zero(); void assign_zero();
void assign_one(); void assign_one();
void assign(const Rq_Element& e);
void partial_assign(const Rq_Element& e); void partial_assign(const Rq_Element& e);
// Must be careful not to call by mistake // Must be careful not to call by mistake
@@ -85,10 +84,6 @@ protected:
a[1] = Ring_Element(prd[1], r, b1); a[1] = Ring_Element(prd[1], r, b1);
} }
// Destructor
~Rq_Element()
{ ; }
const Ring_Element& get(int i) const { return a[i]; } const Ring_Element& get(int i) const { return a[i]; }
/* Functional Operators */ /* Functional Operators */
@@ -131,8 +126,6 @@ protected:
void partial_assign(const Rq_Element& a, const Rq_Element& b); void partial_assign(const Rq_Element& a, const Rq_Element& b);
// Converting to and from a vector of bigint's Again I/O is in poly rep // Converting to and from a vector of bigint's Again I/O is in poly rep
void from_vec(const vector<bigint>& v,int level=-1);
void from_vec(const vector<int>& v,int level=-1);
vector<bigint> to_vec_bigint() const; vector<bigint> to_vec_bigint() const;
void to_vec_bigint(vector<bigint>& v) const; void to_vec_bigint(vector<bigint>& v) const;

View File

@@ -49,7 +49,8 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
if (not error.empty()) if (not error.empty())
{ {
cerr << "Running secrets generation because " << error << endl; cerr << "Running secrets generation because no suitable material "
"from a previous run was found (" << error << ")" << endl;
setup.key_and_mac_generation(P, machine, num_runs, V()); setup.key_and_mac_generation(P, machine, num_runs, V());
ofstream output(filename); ofstream output(filename);

View File

@@ -109,11 +109,11 @@ DistKeyGen::DistKeyGen(const FHE_Params& params, const bigint& p) :
*/ */
void DistKeyGen::Gen_Random_Data(PRNG& G) void DistKeyGen::Gen_Random_Data(PRNG& G)
{ {
secret.from_vec(params.sampleHwt(G)); secret.from(GaussianGenerator<bigint>(params.get_DG(), G));
rc1.generate(G); rc1.generate(G);
rc2.generate(G); rc2.generate(G);
a.randomize(G); a.randomize(G);
e.from_vec(params.sampleGaussian(G)); e.from(GaussianGenerator<bigint>(params.get_DG(), G));
} }
DistKeyGen& DistKeyGen::operator+=(const DistKeyGen& other) DistKeyGen& DistKeyGen::operator+=(const DistKeyGen& other)

View File

@@ -45,7 +45,7 @@ template <class FD>
void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res, void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role) const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role)
{ {
octetStream o; o.reset_write_head();
if (role & SENDER) if (role & SENDER)
{ {

View File

@@ -36,6 +36,8 @@ class Multiplier
size_t volatile_capacity; size_t volatile_capacity;
MemoryUsage memory_usage; MemoryUsage memory_usage;
octetStream o;
public: public:
Multiplier(int offset, PairwiseGenerator<FD>& generator); Multiplier(int offset, PairwiseGenerator<FD>& generator);
Multiplier(int offset, PairwiseMachine& machine, Player& P, Multiplier(int offset, PairwiseMachine& machine, Player& P,

View File

@@ -1,32 +0,0 @@
/*
* Player-Offline.h
*
*/
#ifndef FHEOFFLINE_PLAYER_OFFLINE_H_
#define FHEOFFLINE_PLAYER_OFFLINE_H_
class thread_info
{
public:
int thread_num;
int covert;
Names* Nms;
FHE_PK* pk_p;
FHE_PK* pk_2;
FHE_SK* sk_p;
FHE_SK* sk_2;
Ciphertext *calphap;
Ciphertext *calpha2;
gfp *alphapi;
gf2n_short *alpha2i;
FFT_Data *FTD;
P2Data *P2D;
int nm2,nmp,nb2,nbp,ni2,nip,ns2,nsp,nvp;
bool skip_2() { return nm2 + ni2 + nb2 + ns2 == 0; }
};
#endif /* FHEOFFLINE_PLAYER_OFFLINE_H_ */

View File

@@ -589,7 +589,7 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
P.receive_player(j, cleartexts); P.receive_player(j, cleartexts);
C.resize(personal_EC.machine->sec, pk.get_params()); C.resize(personal_EC.machine->sec, pk.get_params());
Verifier<FD>(personal_EC.proof, FieldD).NIZKPoK(C, ciphertexts, Verifier<FD>(personal_EC.proof, FieldD).NIZKPoK(C, ciphertexts,
cleartexts, pk, false); cleartexts, pk);
} }
inputs[j].clear(); inputs[j].clear();

View File

@@ -88,6 +88,7 @@ public:
bool Proof::check_bounds(T& z, X& t, int i) const bool Proof::check_bounds(T& z, X& t, int i) const
{ {
(void)i;
unsigned int j,k; unsigned int j,k;
// Check Bound 1 and Bound 2 // Check Bound 1 and Bound 2
@@ -99,9 +100,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const
auto& te = z[j]; auto& te = z[j];
if (plain_checker.outside(te, dist)) if (plain_checker.outside(te, dist))
{ {
#ifdef VERBOSE
cout << "Fail on Check 1 " << i << " " << j << endl; cout << "Fail on Check 1 " << i << " " << j << endl;
cout << te << " " << plain_check << endl; cout << te << " " << plain_check << endl;
cout << tau << " " << sec << " " << n_proofs << endl; cout << tau << " " << sec << " " << n_proofs << endl;
#endif
return false; return false;
} }
} }
@@ -113,9 +116,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const
auto& te = coeffs.at(j); auto& te = coeffs.at(j);
if (rand_checker.outside(te, dist)) if (rand_checker.outside(te, dist))
{ {
#ifdef VERBOSE
cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl; cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl;
cout << te << " " << rand_check << endl; cout << te << " " << rand_check << endl;
cout << rho << " " << sec << " " << n_proofs << endl; cout << rho << " " << sec << " " << n_proofs << endl;
#endif
return false; return false;
} }
} }

View File

@@ -6,6 +6,7 @@
#include "Tools/random.h" #include "Tools/random.h"
#include "Math/Z2k.hpp" #include "Math/Z2k.hpp"
#include "Math/modp.hpp" #include "Math/modp.hpp"
#include "FHE/AddableVector.hpp"
template <class FD, class U> template <class FD, class U>
@@ -28,7 +29,7 @@ Prover<FD,U>::Prover(Proof& proof, const FD& FieldD) :
template <class FD, class U> template <class FD, class U>
void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts, void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts,
const AddableVector<Ciphertext>& c, const AddableVector<Ciphertext>& c,
const FHE_PK& pk, bool binary) const FHE_PK& pk)
{ {
size_t allocate = 3 * c.size() * c[0].report_size(USED); size_t allocate = 3 * c.size() * c[0].report_size(USED);
ciphertexts.resize_precise(allocate); ciphertexts.resize_precise(allocate);
@@ -51,7 +52,7 @@ void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts,
// AE.randomize(Diag,binary); // AE.randomize(Diag,binary);
// rd=RandPoly(phim,bd<<1); // rd=RandPoly(phim,bd<<1);
// y[i]=AE.plaintext()+pr*rd; // y[i]=AE.plaintext()+pr*rd;
y[i].randomize(G, P.B_plain_length, P.get_diagonal(), binary); y[i].randomize(G, P.B_plain_length, P.get_diagonal());
if (P.get_diagonal()) if (P.get_diagonal())
assert(y[i].is_diagonal()); assert(y[i].is_diagonal());
s[i].resize(3, P.phim); s[i].resize(3, P.phim);
@@ -114,8 +115,7 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
const FHE_PK& pk, const FHE_PK& pk,
const AddableVector<Ciphertext>& c, const AddableVector<Ciphertext>& c,
const vector<U>& x, const vector<U>& x,
const Proof::Randomness& r, const Proof::Randomness& r)
bool binary)
{ {
// AElement<T> AE; // AElement<T> AE;
// for (i=0; i<P.sec; i++) // for (i=0; i<P.sec; i++)
@@ -130,13 +130,15 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
int cnt=0; int cnt=0;
while (!ok) while (!ok)
{ cnt++; { cnt++;
Stage_1(P,ciphertexts,c,pk,binary); Stage_1(P,ciphertexts,c,pk);
P.set_challenge(ciphertexts); P.set_challenge(ciphertexts);
// Check check whether we are OK, or whether we should abort // Check check whether we are OK, or whether we should abort
ok = Stage_2(P,cleartexts,x,r,pk); ok = Stage_2(P,cleartexts,x,r,pk);
} }
#ifdef VERBOSE
if (cnt > 1) if (cnt > 1)
cout << "\t\tNumber iterations of prover = " << cnt << endl; cout << "\t\tNumber iterations of prover = " << cnt << endl;
#endif
return report_size(CAPACITY) + volatile_memory; return report_size(CAPACITY) + volatile_memory;
} }

View File

@@ -24,8 +24,7 @@ public:
Prover(Proof& proof, const FD& FieldD); Prover(Proof& proof, const FD& FieldD);
void Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector<Ciphertext>& c, void Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector<Ciphertext>& c,
const FHE_PK& pk, const FHE_PK& pk);
bool binary = false);
bool Stage_2(Proof& P, octetStream& cleartexts, bool Stage_2(Proof& P, octetStream& cleartexts,
const vector<U>& x, const vector<U>& x,
@@ -40,8 +39,7 @@ public:
const FHE_PK& pk, const FHE_PK& pk,
const AddableVector<Ciphertext>& c, const AddableVector<Ciphertext>& c,
const vector<U>& x, const vector<U>& x,
const Proof::Randomness& r, const Proof::Randomness& r);
bool binary=false);
size_t report_size(ReportType type); size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res); void report_size(ReportType type, MemoryUsage& res);

View File

@@ -11,6 +11,7 @@
#include "Protocols/MAC_Check.h" #include "Protocols/MAC_Check.h"
#include "Protocols/MAC_Check.hpp" #include "Protocols/MAC_Check.hpp"
#include "Math/modp.hpp"
template<class T, class FD, class S> template<class T, class FD, class S>
SimpleEncCommitBase<T, FD, S>::SimpleEncCommitBase(const MachineBase& machine) : SimpleEncCommitBase<T, FD, S>::SimpleEncCommitBase(const MachineBase& machine) :
@@ -63,7 +64,10 @@ SimpleEncCommitFactory<FD>::SimpleEncCommitFactory(const FHE_PK& pk,
template <class FD> template <class FD>
SimpleEncCommitFactory<FD>::~SimpleEncCommitFactory() SimpleEncCommitFactory<FD>::~SimpleEncCommitFactory()
{ {
cout << "EncCommit called " << n_calls << " times" << endl; #ifdef VERBOSE_HE
if (n_calls > 0)
cout << "EncCommit called " << n_calls << " times" << endl;
#endif
} }
template<class FD> template<class FD>
@@ -131,7 +135,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
Prover<FD, Plaintext_<FD> > prover(proof, FTD); Prover<FD, Plaintext_<FD> > prover(proof, FTD);
#endif #endif
size_t prover_memory = prover.NIZKPoK(proof, ciphertexts, cleartexts, size_t prover_memory = prover.NIZKPoK(proof, ciphertexts, cleartexts,
pk, c, m, r, false); pk, c, m, r);
timers["Proving"].stop(); timers["Proving"].stop();
if (proof.top_gear) if (proof.top_gear)
@@ -192,7 +196,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
#endif #endif
timers["Verifying"].start(); timers["Verifying"].start();
verifier.NIZKPoK(others_ciphertexts, ciphertexts, verifier.NIZKPoK(others_ciphertexts, ciphertexts,
cleartexts, get_pk_for_verification(i), false); cleartexts, get_pk_for_verification(i));
timers["Verifying"].stop(); timers["Verifying"].stop();
add_ciphertexts(others_ciphertexts, i); add_ciphertexts(others_ciphertexts, i);
this->memory_usage.update("verifier", verifier.report_size(CAPACITY)); this->memory_usage.update("verifier", verifier.report_size(CAPACITY));
@@ -251,7 +255,7 @@ void SummingEncCommit<FD>::create_more()
#endif #endif
this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof); this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof);
this->timers["Stage 1 of proof"].start(); this->timers["Stage 1 of proof"].start();
prover.Stage_1(proof, ciphertexts, this->c, this->pk, false); prover.Stage_1(proof, ciphertexts, this->c, this->pk);
this->timers["Stage 1 of proof"].stop(); this->timers["Stage 1 of proof"].stop();
this->c.unpack(ciphertexts, this->pk); this->c.unpack(ciphertexts, this->pk);
@@ -291,8 +295,10 @@ void SummingEncCommit<FD>::create_more()
for (int i = 1; i < P.num_players(); i++) for (int i = 1; i < P.num_players(); i++)
{ {
#ifdef VERBOSE_HE
cout << "Sending cleartexts with " << 1e-9 * cleartexts.get_length() cout << "Sending cleartexts with " << 1e-9 * cleartexts.get_length()
<< " GB in round " << i << endl; << " GB in round " << i << endl;
#endif
TimeScope(this->timers["Exchanging cleartexts"]); TimeScope(this->timers["Exchanging cleartexts"]);
P.pass_around(cleartexts); P.pass_around(cleartexts);
preimages.add(cleartexts); preimages.add(cleartexts);
@@ -312,7 +318,7 @@ void SummingEncCommit<FD>::create_more()
Verifier<FD> verifier(proof); Verifier<FD> verifier(proof);
#endif #endif
verifier.Stage_2(this->c, ciphertexts, cleartexts, verifier.Stage_2(this->c, ciphertexts, cleartexts,
this->pk, false); this->pk);
this->timers["Verifying"].stop(); this->timers["Verifying"].stop();
this->cnt = proof.U - 1; this->cnt = proof.U - 1;

View File

@@ -25,7 +25,10 @@ bool Check_Decoding(const Plaintext<T,FD,S>& AE,bool Diag)
// return false; // return false;
// } // }
if (Diag && !AE.is_diagonal()) if (Diag && !AE.is_diagonal())
{ cout << "Fail Check 5 " << endl; {
#ifdef VERBOSE
cout << "Fail Check 5 " << endl;
#endif
return false; return false;
} }
return true; return true;
@@ -62,7 +65,7 @@ template <class FD>
void Verifier<FD>::Stage_2( void Verifier<FD>::Stage_2(
AddableVector<Ciphertext>& c,octetStream& ciphertexts, AddableVector<Ciphertext>& c,octetStream& ciphertexts,
octetStream& cleartexts, octetStream& cleartexts,
const FHE_PK& pk,bool binary) const FHE_PK& pk)
{ {
unsigned int i, V; unsigned int i, V;
@@ -90,18 +93,19 @@ void Verifier<FD>::Stage_2(
rc.assign(t[0], t[1], t[2]); rc.assign(t[0], t[1], t[2]);
pk.encrypt(d2,z,rc); pk.encrypt(d2,z,rc);
if (!(d1 == d2)) if (!(d1 == d2))
{ cout << "Fail Check 6 " << i << endl; {
#ifdef VERBOSE
cout << "Fail Check 6 " << i << endl;
#endif
throw runtime_error("ciphertexts don't match"); throw runtime_error("ciphertexts don't match");
} }
if (!Check_Decoding(z,P.get_diagonal(),FieldD)) if (!Check_Decoding(z,P.get_diagonal(),FieldD))
{ cout << "\tCheck : " << i << endl; {
#ifdef VERBOSE
cout << "\tCheck : " << i << endl;
#endif
throw runtime_error("cleartext isn't diagonal"); throw runtime_error("cleartext isn't diagonal");
} }
if (binary && !z.is_binary())
{
cout << "Not binary " << i << endl;
throw runtime_error("cleartext isn't binary");
}
} }
} }
@@ -112,17 +116,15 @@ void Verifier<FD>::Stage_2(
template <class FD> template <class FD>
void Verifier<FD>::NIZKPoK(AddableVector<Ciphertext>& c, void Verifier<FD>::NIZKPoK(AddableVector<Ciphertext>& c,
octetStream& ciphertexts, octetStream& cleartexts, octetStream& ciphertexts, octetStream& cleartexts,
const FHE_PK& pk, const FHE_PK& pk)
bool binary)
{ {
P.set_challenge(ciphertexts); P.set_challenge(ciphertexts);
Stage_2(c,ciphertexts,cleartexts,pk,binary); Stage_2(c,ciphertexts,cleartexts,pk);
if (P.top_gear) if (P.top_gear)
{ {
assert(not P.get_diagonal()); assert(not P.get_diagonal());
assert(not binary);
c += c; c += c;
} }
} }

View File

@@ -21,14 +21,14 @@ public:
void Stage_2( void Stage_2(
AddableVector<Ciphertext>& c, octetStream& ciphertexts, AddableVector<Ciphertext>& c, octetStream& ciphertexts,
octetStream& cleartexts,const FHE_PK& pk,bool binary=false); octetStream& cleartexts,const FHE_PK& pk);
/* This is the non-interactive version using the ROM /* This is the non-interactive version using the ROM
- Creates space for all output values - Creates space for all output values
- Diag flag mirrors that in Prover - Diag flag mirrors that in Prover
*/ */
void NIZKPoK(AddableVector<Ciphertext>& c,octetStream& ciphertexts,octetStream& cleartexts, void NIZKPoK(AddableVector<Ciphertext>& c,octetStream& ciphertexts,octetStream& cleartexts,
const FHE_PK& pk,bool binary=false); const FHE_PK& pk);
size_t report_size(ReportType type) { return z.report_size(type) + t.report_size(type); } size_t report_size(ReportType type) { return z.report_size(type) + t.report_size(type); }
}; };

View File

@@ -19,13 +19,13 @@ template<class T>
class CcdPrep : public BufferPrep<T> class CcdPrep : public BufferPrep<T>
{ {
typename T::part_type::LivePrep part_prep; typename T::part_type::LivePrep part_prep;
typename T::part_type::MAC_Check part_MC;
SubProcessor<typename T::part_type>* part_proc; SubProcessor<typename T::part_type>* part_proc;
ShareThread<T>& thread; ShareThread<T>& thread;
public: public:
CcdPrep(DataPositions& usage, ShareThread<T>& thread) : CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
BufferPrep<T>(usage), part_prep(usage), part_proc(0), thread(thread) BufferPrep<T>(usage), part_prep(usage, thread), part_proc(0),
thread(thread)
{ {
} }
@@ -34,17 +34,9 @@ public:
{ {
} }
~CcdPrep() ~CcdPrep();
{
if (part_proc)
delete part_proc;
}
void set_protocol(typename T::Protocol& protocol) void set_protocol(typename T::Protocol& protocol);
{
part_proc = new SubProcessor<typename T::part_type>(part_MC,
part_prep, protocol.get_part().P);
}
Preprocessing<typename T::part_type>& get_part() Preprocessing<typename T::part_type>& get_part()
{ {
@@ -53,7 +45,16 @@ public:
void buffer_triples() void buffer_triples()
{ {
throw not_implemented(); assert(part_proc);
this->triples.push_back({});
for (auto& x : this->triples.back())
x.resize_regs(T::default_length);
for (int i = 0; i < T::default_length; i++)
{
auto triple = part_prep.get_triple(1);
for (int j = 0; j < 3; j++)
this->triples.back()[j].get_bit(j) = triple[j];
}
} }
void buffer_bits() void buffer_bits()
@@ -72,6 +73,25 @@ public:
{ {
throw not_implemented(); throw not_implemented();
} }
void buffer_inputs(int player)
{
this->inputs[player].push_back({});
this->inputs[player].back().share.resize_regs(T::default_length);
for (int i = 0; i < T::default_length; i++)
{
typename T::part_type::open_type tmp;
part_prep.get_input(this->inputs[player].back().share.get_reg(i),
tmp, player);
this->inputs[player].back().value ^=
(typename T::clear(tmp.get_bit(0)) << i);
}
}
size_t data_sent()
{
return part_prep.data_sent();
}
}; };
} /* namespace GC */ } /* namespace GC */

33
GC/CcdPrep.hpp Normal file
View File

@@ -0,0 +1,33 @@
/*
* CcdPrep.hpp
*
*/
#ifndef GC_CCDPREP_HPP_
#define GC_CCDPREP_HPP_
#include "CcdPrep.h"
#include "Processor/Processor.hpp"
namespace GC
{
template<class T>
CcdPrep<T>::~CcdPrep()
{
if (part_proc)
delete part_proc;
}
template<class T>
void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
{
assert(thread.MC);
part_proc = new SubProcessor<typename T::part_type>(
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);
}
}
#endif /* GC_CCDPREP_HPP_ */

View File

@@ -34,11 +34,6 @@ public:
static const int default_length = 1; static const int default_length = 1;
static DataFieldType field_type()
{
return DATA_GF2;
}
static string name() static string name()
{ {
return "CCD"; return "CCD";

View File

@@ -70,8 +70,6 @@ public:
static const true_type invertible; static const true_type invertible;
static const true_type characteristic_two; static const true_type characteristic_two;
static DataFieldType field_type() { return DATA_GF2; }
static MC* new_mc(mac_key_type key) { return new MC(key); } static MC* new_mc(mac_key_type key) { return new MC(key); }
static void store_clear_in_dynamic(Memory<DynamicType>& mem, static void store_clear_in_dynamic(Memory<DynamicType>& mem,

View File

@@ -39,11 +39,6 @@ public:
static const int default_length = 1; static const int default_length = 1;
static DataFieldType field_type()
{
return DATA_GF2;
}
static string name() static string name()
{ {
return "Malicious CCD"; return "Malicious CCD";

View File

@@ -54,6 +54,11 @@ public:
return "no"; return "no";
} }
static DataFieldType field_type()
{
throw not_implemented();
}
static void fail() static void fail()
{ {
throw runtime_error("VM does not support binary circuits"); throw runtime_error("VM does not support binary circuits");
@@ -101,16 +106,10 @@ public:
typedef NoValue clear; typedef NoValue clear;
typedef NoValue mac_key_type; typedef NoValue mac_key_type;
typedef NoShare bit_type;
typedef NoShare part_type;
typedef NoShare small_type; typedef NoShare small_type;
typedef BlackHole out_type; typedef BlackHole out_type;
static const int default_length = 1;
static const bool needs_ot = false;
static const bool expensive_triples = false;
static const bool is_real = false; static const bool is_real = false;
static MC* new_mc(mac_key_type) static MC* new_mc(mac_key_type)
@@ -118,21 +117,6 @@ public:
return new MC; return new MC;
} }
template<class T>
static void generate_mac_key(mac_key_type, T)
{
}
static DataFieldType field_type()
{
throw not_implemented();
}
static string type_short()
{
return "";
}
static string type_string() static string type_string()
{ {
return "no"; return "no";
@@ -155,7 +139,6 @@ public:
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); } static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); } static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); }
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); } static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; } static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
@@ -166,11 +149,8 @@ public:
void load_clear(Integer, Integer) { fail(); } void load_clear(Integer, Integer) { fail(); }
void random_bit() { fail(); } void random_bit() { fail(); }
void and_(int, NoShare&, NoShare&, bool) { fail(); }
void xor_(int, NoShare&, NoShare&) { fail(); }
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); } void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); } void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
void reveal(Integer, Integer) { fail(); }
void assign(const char*) { fail(); } void assign(const char*) { fail(); }
@@ -183,13 +163,11 @@ public:
NoShare operator-(const NoShare&) const { fail(); return {}; } NoShare operator-(const NoShare&) const { fail(); return {}; }
NoShare operator*(const NoValue&) const { fail(); return {}; } NoShare operator*(const NoValue&) const { fail(); return {}; }
NoShare operator+(int) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; } NoShare operator&(int) const { fail(); return {}; }
NoShare operator>>(int) const { fail(); return {}; } NoShare operator>>(int) const { fail(); return {}; }
NoShare& operator+=(const NoShare&) { fail(); return *this; } NoShare& operator+=(const NoShare&) { fail(); return *this; }
NoShare lsb() const { fail(); return {}; }
NoShare get_bit(int) const { fail(); return {}; } NoShare get_bit(int) const { fail(); return {}; }
void invert(int, NoShare) { fail(); } void invert(int, NoShare) { fail(); }

View File

@@ -88,7 +88,7 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
input.reset_all(P); input.reset_all(P);
for (size_t i = begin; i < end; i++) for (size_t i = begin; i < end; i++)
{ {
typename T::clear x[2]; typename T::open_type x[2];
for (int j = 0; j < 2; j++) for (int j = 0; j < 2; j++)
this->get_input(triples[i][j], x[j], input_player); this->get_input(triples[i][j], x[j], input_player);
if (P.my_num() == input_player) if (P.my_num() == input_player)

View File

@@ -84,6 +84,7 @@ public:
void xors(const vector<int>& args); void xors(const vector<int>& args);
void xors(const vector<int>& args, size_t start, size_t end); void xors(const vector<int>& args, size_t start, size_t end);
void xorc(const ::BaseInstruction& instruction);
void nots(const ::BaseInstruction& instruction); void nots(const ::BaseInstruction& instruction);
void andm(const ::BaseInstruction& instruction); void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat); void and_(const vector<int>& args, bool repeat);

View File

@@ -18,6 +18,7 @@ using namespace std;
#include "GC/Machine.hpp" #include "GC/Machine.hpp"
#include "Processor/ProcessorBase.hpp" #include "Processor/ProcessorBase.hpp"
#include "Processor/IntInput.hpp"
#include "Math/bigint.hpp" #include "Math/bigint.hpp"
namespace GC namespace GC
@@ -82,8 +83,12 @@ U GC::Processor<T>::get_long_input(const int* params,
{ {
if (not T::actual_inputs) if (not T::actual_inputs)
return {}; return {};
U res = input_proc.get_input<FixInput_<U>>(interactive, U res;
&params[1]).items[0]; if (params[1] == 0)
res = input_proc.get_input<IntInput<U>>(interactive, 0).items[0];
else
res = input_proc.get_input<FixInput_<U>>(interactive,
&params[1]).items[0];
int n_bits = *params; int n_bits = *params;
check_input(res, n_bits); check_input(res, n_bits);
return res; return res;
@@ -229,6 +234,18 @@ void Processor<T>::xors(const vector<int>& args, size_t start, size_t end)
} }
} }
template<class T>
void Processor<T>::xorc(const ::BaseInstruction& instruction)
{
int total = instruction.get_n();
for (int i = 0; i < DIV_CEIL(total, T::default_length); i++)
{
int n = min(T::default_length, total - i * T::default_length);
C[instruction.get_r(0) + i] = BitVec(C[instruction.get_r(1) + i]).mask(n)
^ BitVec(C[instruction.get_r(2) + i]).mask(n);
}
}
template<class T> template<class T>
void Processor<T>::nots(const ::BaseInstruction& instruction) void Processor<T>::nots(const ::BaseInstruction& instruction)
{ {

View File

@@ -7,7 +7,6 @@
#define GC_REP4SECRET_H_ #define GC_REP4SECRET_H_
#include "ShareSecret.h" #include "ShareSecret.h"
#include "Processor/NoLivePrep.h"
#include "Protocols/Rep4MC.h" #include "Protocols/Rep4MC.h"
#include "Protocols/Rep4Share.h" #include "Protocols/Rep4Share.h"

View File

@@ -1,11 +0,0 @@
/*
* ReplicatedPrep.cpp
*
*/
#include <GC/SemiHonestRepPrep.h>
namespace GC
{
} /* namespace GC */

View File

@@ -119,7 +119,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
try try
{ {
read_mac_key(get_prep_sub_dir<T>(PREP_DIR, network_opts.nplayers), this->N, read_mac_key(
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
this->N,
this->mac_key); this->mac_key);
} }
catch (exception& e) catch (exception& e)

View File

@@ -38,6 +38,8 @@ template<class U>
class ShareSecret class ShareSecret
{ {
public: public:
typedef U whole_type;
typedef Memory<U> DynamicMemory; typedef Memory<U> DynamicMemory;
typedef SwitchableOutput out_type; typedef SwitchableOutput out_type;

View File

@@ -21,6 +21,7 @@
#include "ShareParty.h" #include "ShareParty.h"
#include "ShareThread.hpp" #include "ShareThread.hpp"
#include "Thread.hpp" #include "Thread.hpp"
#include "VectorProtocol.hpp"
namespace GC namespace GC
{ {

View File

@@ -29,7 +29,8 @@ ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, DataPositions&
*static_cast<Preprocessing<T>*>(new typename T::LivePrep( *static_cast<Preprocessing<T>*>(new typename T::LivePrep(
usage, *this)) : usage, *this)) :
*static_cast<Preprocessing<T>*>(new BitPrepFiles<T>(N, *static_cast<Preprocessing<T>*>(new BitPrepFiles<T>(N,
get_prep_sub_dir<T>(PREP_DIR, N.num_players()), usage))) get_prep_sub_dir<T>(PREP_DIR, N.num_players()),
usage, BaseMachine::thread_num)))
{ {
} }

View File

@@ -1,37 +0,0 @@
/*
* TinierPrep.h
*
*/
#ifndef GC_TINIERPREP_H_
#define GC_TINIERPREP_H_
#include "TinyPrep.h"
namespace GC
{
template<class T>
class TinierPrep : public TinyPrep<T>
{
public:
TinierPrep(DataPositions& usage, ShareThread<T>& thread,
bool amplify = true) :
TinyPrep<T>(usage, thread, amplify)
{
}
TinierPrep(SubProcessor<T>*, DataPositions& usage) :
TinierPrep(usage, ShareThread<T>::s())
{
}
void buffer_inputs(int player)
{
this->buffer_inputs_(player, this->triple_generator);
}
};
}
#endif /* GC_TINIERPREP_H_ */

View File

@@ -15,6 +15,9 @@ namespace GC
{ {
template<class T> class TinierPrep; template<class T> class TinierPrep;
template<class T> class VectorProtocol;
template<class T> class CcdPrep;
template<class T> class VectorInput;
template<class T> template<class T>
class TinierSecret : public VectorSecret<TinierShare<T>> class TinierSecret : public VectorSecret<TinierShare<T>>
@@ -25,9 +28,9 @@ class TinierSecret : public VectorSecret<TinierShare<T>>
public: public:
typedef TinyMC<This> MC; typedef TinyMC<This> MC;
typedef MC MAC_Check; typedef MC MAC_Check;
typedef Beaver<This> Protocol; typedef VectorProtocol<This> Protocol;
typedef ::Input<This> Input; typedef VectorInput<This> Input;
typedef TinierPrep<This> LivePrep; typedef CcdPrep<This> LivePrep;
typedef Memory<This> DynamicMemory; typedef Memory<This> DynamicMemory;
typedef NPartyTripleGenerator<This> TripleGenerator; typedef NPartyTripleGenerator<This> TripleGenerator;

View File

@@ -9,12 +9,12 @@
#include "Processor/DummyProtocol.h" #include "Processor/DummyProtocol.h"
#include "Protocols/Share.h" #include "Protocols/Share.h"
#include "Math/Bit.h" #include "Math/Bit.h"
#include "TinierSharePrep.h"
namespace GC namespace GC
{ {
template<class T> class TinierSecret; template<class T> class TinierSecret;
template<class T> class TinierSharePrep;
template<class T> template<class T>
class TinierShare: public Share_<SemiShare<Bit>, SemiShare<T>>, class TinierShare: public Share_<SemiShare<Bit>, SemiShare<T>>,
@@ -55,6 +55,11 @@ public:
return "Tinier"; return "Tinier";
} }
static string type_short()
{
return "TT";
}
static ShareThread<TinierSecret<T>>& get_party() static ShareThread<TinierSecret<T>>& get_party()
{ {
return ShareThread<TinierSecret<T>>::s(); return ShareThread<TinierSecret<T>>::s();
@@ -103,9 +108,7 @@ public:
void random() void random()
{ {
TinierSecret<T> tmp; *this = get_party().DataF.get_part().get_bit();
get_party().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
} }
This lsb() const This lsb() const

View File

@@ -21,18 +21,26 @@ template<class T>
class TinierSharePrep : public PersonalPrep<T> class TinierSharePrep : public PersonalPrep<T>
{ {
typename T::TripleGenerator* triple_generator; typename T::TripleGenerator* triple_generator;
typename T::whole_type::TripleGenerator* real_triple_generator;
MascotParams params; MascotParams params;
TinierPrep<TinierSecret<typename T::mac_key_type>> whole_prep; typedef typename T::whole_type secret_type;
ShareThread<secret_type>& thread;
void buffer_triples(); void buffer_triples();
void buffer_squares() { throw not_implemented(); } void buffer_squares() { throw not_implemented(); }
void buffer_bits() { throw not_implemented(); } void buffer_bits();
void buffer_inverses() { throw not_implemented(); } void buffer_inverses() { throw not_implemented(); }
void buffer_inputs(int player); void buffer_inputs(int player);
void buffer_secret_triples();
void init_real(Player& P);
public: public:
TinierSharePrep(DataPositions& usage, ShareThread<secret_type>& thread,
int input_player = PersonalPrep<T>::SECURE);
TinierSharePrep(DataPositions& usage, int input_player = TinierSharePrep(DataPositions& usage, int input_player =
PersonalPrep<T>::SECURE); PersonalPrep<T>::SECURE);
TinierSharePrep(SubProcessor<T>*, DataPositions& usage); TinierSharePrep(SubProcessor<T>*, DataPositions& usage);

View File

@@ -15,10 +15,16 @@ namespace GC
template<class T> template<class T>
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage, int input_player) : TinierSharePrep<T>::TinierSharePrep(DataPositions& usage, int input_player) :
TinierSharePrep<T>(usage, ShareThread<secret_type>::s(), input_player)
{
}
template<class T>
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage,
ShareThread<secret_type>& thread, int input_player) :
PersonalPrep<T>(usage, input_player), triple_generator(0), PersonalPrep<T>(usage, input_player), triple_generator(0),
whole_prep(usage, real_triple_generator(0),
ShareThread<TinierSecret<typename T::mac_key_type>>::s(), thread(thread)
input_player == PersonalPrep<T>::SECURE)
{ {
} }
@@ -33,6 +39,8 @@ TinierSharePrep<T>::~TinierSharePrep()
{ {
if (triple_generator) if (triple_generator)
delete triple_generator; delete triple_generator;
if (real_triple_generator)
delete real_triple_generator;
} }
template<class T> template<class T>
@@ -44,15 +52,14 @@ void TinierSharePrep<T>::set_protocol(typename T::Protocol& protocol)
params.generateMACs = true; params.generateMACs = true;
params.amplify = false; params.amplify = false;
params.check = false; params.check = false;
auto& thread = ShareThread<TinierSecret<typename T::mac_key_type>>::s(); auto& thread = ShareThread<typename T::whole_type>::s();
triple_generator = new typename T::TripleGenerator( triple_generator = new typename T::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
OnlineOptions::singleton.batch_size OnlineOptions::singleton.batch_size, 1,
* TinierSecret<typename T::mac_key_type>::default_length, 1,
params, thread.MC->get_alphai(), &protocol.P); params, thread.MC->get_alphai(), &protocol.P);
triple_generator->multi_threaded = false; triple_generator->multi_threaded = false;
this->inputs.resize(thread.P->num_players()); this->inputs.resize(thread.P->num_players());
whole_prep.init(*thread.P); init_real(protocol.P);
} }
template<class T> template<class T>
@@ -63,12 +70,8 @@ void TinierSharePrep<T>::buffer_triples()
this->buffer_personal_triples(); this->buffer_personal_triples();
return; return;
} }
else
array<TinierSecret<typename T::mac_key_type>, 3> whole; buffer_secret_triples();
whole_prep.get(DATA_TRIPLE, whole.data());
for (size_t i = 0; i < whole[0].get_regs().size(); i++)
this->triples.push_back(
{{ whole[0].get_reg(i), whole[1].get_reg(i), whole[2].get_reg(i) }});
} }
template<class T> template<class T>
@@ -81,12 +84,21 @@ void TinierSharePrep<T>::buffer_inputs(int player)
inputs.at(player).push_back(x); inputs.at(player).push_back(x);
} }
template<class T>
void GC::TinierSharePrep<T>::buffer_bits()
{
this->bits.push_back(
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
}
template<class T> template<class T>
size_t TinierSharePrep<T>::data_sent() size_t TinierSharePrep<T>::data_sent()
{ {
size_t res = whole_prep.data_sent(); size_t res = 0;
if (triple_generator) if (triple_generator)
res += triple_generator->data_sent(); res += triple_generator->data_sent();
if (real_triple_generator)
res += real_triple_generator->data_sent();
return res; return res;
} }

View File

@@ -1,11 +0,0 @@
/*
* TinyMC.cpp
*
*/
#include "TinyMC.h"
namespace GC
{
} /* namespace GC */

View File

@@ -14,7 +14,7 @@ namespace GC
template<class T> template<class T>
class TinyMC : public MAC_Check_Base<T> class TinyMC : public MAC_Check_Base<T>
{ {
typename T::check_type::MAC_Check part_MC; typename T::part_type::MAC_Check part_MC;
PointerVector<int> sizes; PointerVector<int> sizes;
public: public:

View File

@@ -1,71 +0,0 @@
/*
* TinyPrep.h
*
*/
#ifndef GC_TINYPREP_H_
#define GC_TINYPREP_H_
#include "Thread.h"
#include "OT/MascotParams.h"
#include "Protocols/Beaver.h"
#include "Protocols/ReplicatedPrep.h"
namespace GC
{
template<class T>
class TinyPrep : public BufferPrep<T>
{
protected:
ShareThread<T>& thread;
typename T::TripleGenerator* triple_generator;
MascotParams params;
vector<array<typename T::part_type, 3>> triple_buffer;
const bool amplify;
public:
TinyPrep(DataPositions& usage, ShareThread<T>& thread, bool amplify = true);
~TinyPrep();
void set_protocol(Beaver<T>& protocol);
void init(Player& P);
void buffer_triples();
void buffer_bits();
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
void buffer_inputs_(int player, typename T::InputGenerator* input_generator);
array<T, 3> get_triple_no_count(int n_bits);
size_t data_sent();
};
template<class T>
class TinyOnlyPrep : public TinyPrep<T>
{
typename T::part_type::TripleGenerator* input_generator;
public:
TinyOnlyPrep(DataPositions& usage, ShareThread<T>& thread);
~TinyOnlyPrep();
void set_protocol(Beaver<T>& protocol);
void buffer_inputs(int player)
{
this->buffer_inputs_(player, input_generator);
}
size_t data_sent();
};
} /* namespace GC */
#endif /* GC_TINYPREP_H_ */

View File

@@ -3,7 +3,7 @@
* *
*/ */
#include "TinyPrep.h" #include "TinierSharePrep.h"
#include "Protocols/MascotPrep.hpp" #include "Protocols/MascotPrep.hpp"
@@ -11,78 +11,26 @@ namespace GC
{ {
template<class T> template<class T>
TinyPrep<T>::TinyPrep(DataPositions& usage, ShareThread<T>& thread, void TinierSharePrep<T>::init_real(Player& P)
bool amplify) :
BufferPrep<T>(usage), thread(thread), triple_generator(0),
amplify(amplify)
{ {
assert(real_triple_generator == 0);
} real_triple_generator = new typename T::whole_type::TripleGenerator(
template<class T>
TinyOnlyPrep<T>::TinyOnlyPrep(DataPositions& usage, ShareThread<T>& thread) :
TinyPrep<T>(usage, thread), input_generator(0)
{
}
template<class T>
TinyPrep<T>::~TinyPrep()
{
if (triple_generator)
delete triple_generator;
}
template<class T>
TinyOnlyPrep<T>::~TinyOnlyPrep()
{
if (input_generator)
delete input_generator;
}
template<class T>
void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
{
init(protocol.P);
}
template<class T>
void TinyPrep<T>::init(Player& P)
{
params.generateMACs = true;
params.amplify = false;
params.check = false;
auto& thread = ShareThread<T>::s();
triple_generator = new typename T::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), P.N, -1, BaseMachine::s().fresh_ot_setup(), P.N, -1,
OnlineOptions::singleton.batch_size, 1, params, OnlineOptions::singleton.batch_size, 1, params,
thread.MC->get_alphai(), &P); thread.MC->get_alphai(), &P);
triple_generator->multi_threaded = false; real_triple_generator->multi_threaded = false;
} }
template<class T> template<class T>
void TinyOnlyPrep<T>::set_protocol(Beaver<T>& protocol) void TinierSharePrep<T>::buffer_secret_triples()
{ {
TinyPrep<T>::set_protocol(protocol); auto& triple_generator = real_triple_generator;
input_generator = new typename T::part_type::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
OnlineOptions::singleton.batch_size, 1, this->params,
this->thread.MC->get_alphai(), &protocol.P);
input_generator->multi_threaded = false;
}
template<class T>
void TinyPrep<T>::buffer_triples()
{
auto& triple_generator = this->triple_generator;
assert(triple_generator != 0); assert(triple_generator != 0);
params.generateBits = false; params.generateBits = false;
vector<array<typename T::check_type, 3>> triples; vector<array<T, 3>> triples;
TripleShuffleSacrifice<typename T::check_type> sacrifice; TripleShuffleSacrifice<T> sacrifice;
size_t required; size_t required;
if (amplify) required = sacrifice.minimum_n_inputs_with_combining();
required = sacrifice.minimum_n_inputs_with_combining();
else
required = sacrifice.minimum_n_inputs();
while (triples.size() < required) while (triples.size() < required)
{ {
triple_generator->generatePlainTriples(); triple_generator->generatePlainTriples();
@@ -92,9 +40,11 @@ void TinyPrep<T>::buffer_triples()
triple_generator->valueBits[2].set_portion(i, triple_generator->valueBits[2].set_portion(i,
triple_generator->plainTriples[i][2]); triple_generator->plainTriples[i][2]);
triple_generator->run_multipliers({}); triple_generator->run_multipliers({});
assert(triple_generator->plainTriples.size() != 0);
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++) for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
{ {
for (int j = 0; j < T::default_length; j++) int dl = secret_type::default_length;
for (int j = 0; j < dl; j++)
{ {
triples.push_back({}); triples.push_back({});
for (int k = 0; k < 3; k++) for (int k = 0; k < 3; k++)
@@ -103,10 +53,10 @@ void TinyPrep<T>::buffer_triples()
share.set_share( share.set_share(
triple_generator->plainTriples.at(i).at(k).get_bit( triple_generator->plainTriples.at(i).at(k).get_bit(
j)); j));
typename T::part_type::mac_type mac; typename T::mac_type mac;
mac = thread.MC->get_alphai() * share.get_share(); mac = thread.MC->get_alphai() * share.get_share();
for (auto& multiplier : triple_generator->ot_multipliers) for (auto& multiplier : triple_generator->ot_multipliers)
mac += multiplier->macs.at(k).at(i * T::default_length + j); mac += multiplier->macs.at(k).at(i * dl + j);
share.set_mac(mac); share.set_mac(mac);
} }
} }
@@ -114,104 +64,10 @@ void TinyPrep<T>::buffer_triples()
} }
sacrifice.triple_sacrifice(triples, triples, sacrifice.triple_sacrifice(triples, triples,
*thread.P, thread.MC->get_part_MC()); *thread.P, thread.MC->get_part_MC());
if (amplify) sacrifice.triple_combine(triples, triples, *thread.P,
sacrifice.triple_combine(triples, triples, *thread.P, thread.MC->get_part_MC());
thread.MC->get_part_MC()); for (auto& triple : triples)
for (size_t i = 0; i < triples.size() / T::default_length; i++) this->triples.push_back(triple);
{
this->triples.push_back({});
auto& triple = this->triples.back();
for (auto& x : triple)
x.resize_regs(T::default_length);
for (int j = 0; j < T::default_length; j++)
{
auto& source_triple = triples[j + i * T::default_length];
for (int k = 0; k < 3; k++)
triple[k].get_reg(j) = source_triple[k];
}
}
}
template<class T>
void TinyPrep<T>::buffer_bits()
{
auto tmp = BufferPrep<T>::get_random_from_inputs(thread.P->num_players());
for (auto& bit : tmp.get_regs())
{
this->bits.push_back({});
this->bits.back().resize_regs(1);
this->bits.back().get_reg(0) = bit;
}
}
template<class T>
void TinyPrep<T>::buffer_inputs_(int player, typename T::InputGenerator* input_generator)
{
auto& inputs = this->inputs;
inputs.resize(this->thread.P->num_players());
assert(input_generator);
input_generator->generateInputs(player);
assert(input_generator->inputs.size() >= T::default_length);
for (size_t i = 0; i < input_generator->inputs.size() / T::default_length; i++)
{
inputs[player].push_back({});
inputs[player].back().share.resize_regs(T::default_length);
for (int j = 0; j < T::default_length; j++)
{
auto& source_input = input_generator->inputs[j
+ i * T::default_length];
inputs[player].back().share.get_reg(j) = source_input.share;
inputs[player].back().value ^= typename T::open_type(
source_input.value.get_bit(0)) << j;
}
}
}
template<class T>
array<T, 3> TinyPrep<T>::get_triple_no_count(int n_bits)
{
assert(n_bits > 0);
while ((unsigned)n_bits > triple_buffer.size())
{
array<T, 3> tmp;
this->get(DATA_TRIPLE, tmp.data());
for (size_t i = 0; i < tmp[0].get_regs().size(); i++)
{
triple_buffer.push_back(
{ {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} });
}
}
array<T, 3> res;
for (int j = 0; j < 3; j++)
res[j].resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
{
for (int j = 0; j < 3; j++)
res[j].get_reg(i) = triple_buffer.back()[j];
triple_buffer.pop_back();
}
return res;
}
template<class T>
size_t TinyPrep<T>::data_sent()
{
size_t res = 0;
if (triple_generator)
res += triple_generator->data_sent();
return res;
}
template<class T>
size_t TinyOnlyPrep<T>::data_sent()
{
auto res = TinyPrep<T>::data_sent();
if (input_generator)
res += input_generator->data_sent();
return res;
} }
} /* namespace GC */ } /* namespace GC */

View File

@@ -1,11 +0,0 @@
/*
* TinySecret.cpp
*
*/
#include "TinySecret.h"
namespace GC
{
} /* namespace GC */

View File

@@ -21,6 +21,9 @@ namespace GC
template<class T> class TinyOnlyPrep; template<class T> class TinyOnlyPrep;
template<class T> class TinyMC; template<class T> class TinyMC;
template<class T> class VectorProtocol;
template<class T> class VectorInput;
template<class T> class CcdPrep;
template<class T> template<class T>
class VectorSecret : public Secret<T> class VectorSecret : public Secret<T>
@@ -50,11 +53,6 @@ public:
static const int default_length = 64; static const int default_length = 64;
static DataFieldType field_type()
{
return BitVec::field_type();
}
static int size() static int size()
{ {
return part_type::size() * default_length; return part_type::size() * default_length;
@@ -166,9 +164,9 @@ public:
} }
template <class U> template <class U>
void other_input(U& inputter, int from, int) void other_input(U& inputter, int from, int n_bits)
{ {
inputter.add_other(from); inputter.add_other(from, n_bits);
} }
template <class U> template <class U>
@@ -187,9 +185,9 @@ class TinySecret : public VectorSecret<TinyShare<S>>
public: public:
typedef TinyMC<This> MC; typedef TinyMC<This> MC;
typedef MC MAC_Check; typedef MC MAC_Check;
typedef Beaver<This> Protocol; typedef VectorProtocol<This> Protocol;
typedef ::Input<This> Input; typedef VectorInput<This> Input;
typedef TinyOnlyPrep<This> LivePrep; typedef CcdPrep<This> LivePrep;
typedef Memory<This> DynamicMemory; typedef Memory<This> DynamicMemory;
typedef OTTripleGenerator<This> TripleGenerator; typedef OTTripleGenerator<This> TripleGenerator;

View File

@@ -1,11 +0,0 @@
/*
* TinyShare.cpp
*
*/
#include "TinyShare.h"
namespace GC
{
} /* namespace GC */

View File

@@ -10,13 +10,14 @@
#include "ShareParty.h" #include "ShareParty.h"
#include "Secret.h" #include "Secret.h"
#include "Protocols/Spdz2kShare.h" #include "Protocols/Spdz2kShare.h"
#include "Processor/NoLivePrep.h"
namespace GC namespace GC
{ {
template<int S> class TinySecret; template<int S> class TinySecret;
template<class T> class ShareThread; template<class T> class ShareThread;
template<class T> class TinierSharePrep;
template<int S> template<int S>
class TinyShare : public Spdz2kShare<1, S>, public ShareSecret<TinySecret<S>> class TinyShare : public Spdz2kShare<1, S>, public ShareSecret<TinySecret<S>>
@@ -28,12 +29,18 @@ public:
typedef void DynamicMemory; typedef void DynamicMemory;
typedef NoLivePrep<This> LivePrep; typedef Beaver<This> Protocol;
typedef MAC_Check_Z2k_<This> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ::Input<This> Input;
typedef TinierSharePrep<This> LivePrep;
typedef SwitchableOutput out_type; typedef SwitchableOutput out_type;
typedef This small_type; typedef This small_type;
typedef NoShare bit_type;
static string name() static string name()
{ {
return "tiny share"; return "tiny share";

View File

@@ -18,14 +18,14 @@ class VectorInput : public InputBase<T>
deque<int> input_lengths; deque<int> input_lengths;
public: public:
VectorInput(typename T::MAC_Check&, Preprocessing<T>&, Player& P) : VectorInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) :
part_input(0, P) part_input(MC.get_part_MC(), prep.get_part(), P)
{ {
part_input.reset_all(P); part_input.reset_all(P);
} }
VectorInput(SubProcessor<T>& proc, typename T::MAC_Check&) : VectorInput(SubProcessor<T>& proc, typename T::MAC_Check&) :
VectorInput(proc.MC, proc.DataF, proc.P) part_input(proc.MC, proc.DataF, proc.P)
{ {
} }
@@ -41,8 +41,10 @@ public:
input_lengths.push_back(n_bits); input_lengths.push_back(n_bits);
} }
void add_other(int) void add_other(int player, int n_bits)
{ {
for (int i = 0; i < n_bits; i++)
part_input.add_other(player);
} }
void send_mine() void send_mine()

View File

@@ -17,6 +17,8 @@ class VectorProtocol : public ProtocolBase<T>
typename T::part_type::Protocol part_protocol; typename T::part_type::Protocol part_protocol;
public: public:
Player& P;
VectorProtocol(Player& P); VectorProtocol(Player& P);
void init_mul(SubProcessor<T>* proc); void init_mul(SubProcessor<T>* proc);

View File

@@ -3,6 +3,9 @@
* *
*/ */
#ifndef GC_VECTORPROTOCOL_HPP_
#define GC_VECTORPROTOCOL_HPP_
#include "VectorProtocol.h" #include "VectorProtocol.h"
namespace GC namespace GC
@@ -10,7 +13,7 @@ namespace GC
template<class T> template<class T>
VectorProtocol<T>::VectorProtocol(Player& P) : VectorProtocol<T>::VectorProtocol(Player& P) :
part_protocol(P) part_protocol(P), P(P)
{ {
} }
@@ -54,3 +57,5 @@ T VectorProtocol<T>::finalize_mul(int n)
} }
} /* namespace GC */ } /* namespace GC */
#endif

View File

@@ -40,7 +40,7 @@
#define BIT_INSTRUCTIONS \ #define BIT_INSTRUCTIONS \
X(XORS, T::xors(PROC, EXTRA)) \ X(XORS, T::xors(PROC, EXTRA)) \
X(XORCB, C0.xor_(PC1, PC2)) \ X(XORCB, processor.xorc(instruction)) \
X(XORCBI, C0.xor_(PC1, IMM)) \ X(XORCBI, C0.xor_(PC1, IMM)) \
X(NOTS, processor.nots(INST)) \ X(NOTS, processor.nots(INST)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \

View File

@@ -20,13 +20,14 @@
#include "GC/TinierSecret.h" #include "GC/TinierSecret.h"
#include "GC/TinyMC.h" #include "GC/TinyMC.h"
#include "GC/TinierPrep.h" #include "GC/VectorInput.h"
#include "GC/ShareParty.hpp" #include "GC/ShareParty.hpp"
#include "GC/Secret.hpp" #include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp" #include "GC/TinyPrep.hpp"
#include "GC/ShareSecret.hpp" #include "GC/ShareSecret.hpp"
#include "GC/TinierSharePrep.hpp" #include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "Math/gfp.hpp" #include "Math/gfp.hpp"

View File

@@ -8,9 +8,8 @@
#include "GC/TinySecret.h" #include "GC/TinySecret.h"
#include "GC/TinyMC.h" #include "GC/TinyMC.h"
#include "GC/TinyPrep.h"
#include "GC/TinierPrep.h"
#include "GC/TinierSecret.h" #include "GC/TinierSecret.h"
#include "GC/VectorInput.h"
#include "Processor/Data_Files.hpp" #include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp" #include "Processor/Instruction.hpp"
@@ -29,3 +28,4 @@
#include "GC/Secret.hpp" #include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp" #include "GC/TinyPrep.hpp"
#include "GC/TinierSharePrep.hpp" #include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"

View File

@@ -30,6 +30,7 @@
#include "GC/ShareSecret.hpp" #include "GC/ShareSecret.hpp"
#include "GC/VectorProtocol.hpp" #include "GC/VectorProtocol.hpp"
#include "GC/Secret.hpp" #include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Math/gfp.hpp" #include "Math/gfp.hpp"
ShamirOptions ShamirOptions::singleton; ShamirOptions ShamirOptions::singleton;

View File

@@ -12,6 +12,7 @@
#include "GC/ShareSecret.hpp" #include "GC/ShareSecret.hpp"
#include "GC/ThreadMaster.hpp" #include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp" #include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Machines/ShamirMachine.hpp" #include "Machines/ShamirMachine.hpp"
int main(int argc, const char** argv) int main(int argc, const char** argv)

View File

@@ -12,7 +12,6 @@
#include "FHE/NTL-Subs.h" #include "FHE/NTL-Subs.h"
#include "GC/TinierSecret.h" #include "GC/TinierSecret.h"
#include "GC/TinierPrep.h"
#include "GC/TinyMC.h" #include "GC/TinyMC.h"
#include "SPDZ.hpp" #include "SPDZ.hpp"

View File

@@ -18,12 +18,14 @@
int main(int argc, const char** argv) int main(int argc, const char** argv)
{ {
OnlineOptions online_opts; OnlineOptions& online_opts = OnlineOptions::singleton;
Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector<string>({"localhost"})); Names N;
ez::ezOptionParser opt; ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv); RingOptions ring_opts(opt, argc, argv);
online_opts = {opt, argc, argv};
opt.parse(argc, argv); opt.parse(argc, argv);
opt.syntax = string(argv[0]) + " <progname>"; opt.syntax = string(argv[0]) + " <progname>";
string progname; string progname;
if (opt.firstArgs.size() > 1) if (opt.firstArgs.size() > 1)
progname = *opt.firstArgs.at(1); progname = *opt.firstArgs.at(1);
@@ -50,36 +52,14 @@ int main(int argc, const char** argv)
int R = ring_opts.ring_size_from_opts_or_schedule(progname); int R = ring_opts.ring_size_from_opts_or_schedule(progname);
switch (R) switch (R)
{ {
case 64: #define X(L) \
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname, case L: \
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, Machine<FakeShare<SignedZ2<L>>, FakeShare<gf2n>>(0, N, progname, \
online_opts.live_prep, online_opts).run(); online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, false, \
break; online_opts.live_prep, online_opts).run(); \
case 128:
Machine<FakeShare<SignedZ2<128>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 256:
Machine<FakeShare<SignedZ2<256>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 192:
Machine<FakeShare<SignedZ2<192>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 384:
Machine<FakeShare<SignedZ2<384>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 512:
Machine<FakeShare<SignedZ2<512>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break; break;
X(64) X(128) X(256) X(192) X(384) X(512)
#undef X
default: default:
cerr << "Not compiled for " << R << "-bit rings" << endl; cerr << "Not compiled for " << R << "-bit rings" << endl;
} }

View File

@@ -12,6 +12,7 @@
#include "GC/ShareSecret.hpp" #include "GC/ShareSecret.hpp"
#include "GC/ThreadMaster.hpp" #include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp" #include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Machines/ShamirMachine.hpp" #include "Machines/ShamirMachine.hpp"
#include "Machines/MalRep.hpp" #include "Machines/MalRep.hpp"

23
Machines/no-party.cpp Normal file
View File

@@ -0,0 +1,23 @@
/*
* no-party.cpp
*
*/
#include "Protocols/NoShare.h"
#include "Processor/OnlineMachine.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/Replicated.hpp"
#include "Math/gfp.hpp"
#include "Math/Z2k.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
OnlineOptions::singleton = {opt, argc, argv};
OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton);
OnlineOptions::singleton.finalize(opt, argc, argv);
machine.start_networking();
// use primes of length 65 to 128 for arithmetic computation
machine.run<NoShare<gfp_<0, 2>>, NoShare<gf2n>>();
}

View File

@@ -12,7 +12,6 @@
#include "Math/gfp.h" #include "Math/gfp.h"
#include "Math/gf2n.h" #include "Math/gf2n.h"
#include "Tools/ezOptionParser.h" #include "Tools/ezOptionParser.h"
#include "Processor/NoLivePrep.h"
#include "GC/MaliciousCcdSecret.h" #include "GC/MaliciousCcdSecret.h"
#include "Processor/FieldMachine.hpp" #include "Processor/FieldMachine.hpp"

Some files were not shown because too many files have changed in this diff Show More