Convolutional neural network training.

This commit is contained in:
Marcel Keller
2021-07-02 15:49:23 +10:00
parent f35447df51
commit 99c0549e72
208 changed files with 3615 additions and 1865 deletions

View File

@@ -167,7 +167,7 @@ void Node::Broadcast2(SendBuffer& msg) {
}
void Node::_identify() {
char* msg = id_msg;
char msg[strlen(ID_HDR)+sizeof(_id)];
memcpy(msg, ID_HDR, strlen(ID_HDR));
memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id));
//printf("Node:: identifying myself:\n");

View File

@@ -78,8 +78,6 @@ private:
std::map<struct sockaddr_in*,int> _clientsmap;
bool* _clients_connected;
NodeUpdatable* _updatable;
char id_msg[strlen(ID_HDR)+sizeof(_id)];
};
#endif /* NETWORK_NODE_H_ */

View File

@@ -1,5 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.5 (Jul 2, 2021)
- Training of convolutional neural networks
- Bit decomposition using edaBits
- Ability to force MAC checks from high-level code
- Ability to close client connection from high-level code
- Binary operators for comparison results
- Faster compilation for emulation
- More documentation
- Fixed security bug: insufficient LowGear secret key randomness
- Fixed security bug: skewed random bit generation
## 0.2.4 (Apr 19, 2021)
- ARM support

View File

@@ -117,7 +117,7 @@ class xorm(NonVectorInstruction):
code = opcodes['XORM']
arg_format = ['int','sbw','sb','cb']
class xorcb(NonVectorInstruction):
class xorcb(BinaryVectorInstruction):
""" Bitwise XOR of two single clear bit registers.
:param: result (cbit)
@@ -125,7 +125,7 @@ class xorcb(NonVectorInstruction):
:param: operand (cbit)
"""
code = opcodes['XORCB']
arg_format = ['cbw','cb','cb']
arg_format = ['int','cbw','cb','cb']
class xorcbi(NonVectorInstruction):
""" Bitwise XOR of single clear bit register and immediate.

View File

@@ -36,6 +36,7 @@ class bits(Tape.Register, _structure, _bit):
class bitsn(cls):
n = length
cls.types[length] = bitsn
bitsn.clear_type = cbits.get_type(length)
bitsn.__name__ = cls.__name__ + str(length)
return cls.types[length]
@classmethod
@@ -115,7 +116,11 @@ class bits(Tape.Register, _structure, _bit):
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
@classmethod
def new(cls, value=None, n=None):
return cls.get_type(n)(value)
def __init__(self, value=None, n=None, size=None):
assert n == self.n or n is None
if size != 1 and size is not None:
raise Exception('invalid size for bit type: %s' % size)
self.n = n or self.n
@@ -125,7 +130,7 @@ class bits(Tape.Register, _structure, _bit):
if value is not None:
self.load_other(value)
def copy(self):
return type(self)(n=instructions_base.get_global_vector_size())
return type(self).new(n=instructions_base.get_global_vector_size())
def set_length(self, n):
if n > self.n:
raise Exception('too long: %d/%d' % (n, self.n))
@@ -154,6 +159,8 @@ class bits(Tape.Register, _structure, _bit):
bits = other.bit_decompose()
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
other = self.bit_compose(bits)
assert(isinstance(other, type(self)))
assert(other.n == self.n)
self.load_other(other)
except:
raise CompilerError('cannot convert %s/%s from %s to %s' % \
@@ -176,6 +183,16 @@ class bits(Tape.Register, _structure, _bit):
res.i = i
res.program = self.program
return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
@@ -202,14 +219,16 @@ class cbits(bits):
inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op):
if isinstance(other, cbits):
res = cbits(n=max(self.n, other.n))
res = cbits.get_type(max(self.n, other.n))()
c_inst(res, self, other)
return res
elif isinstance(other, sbits):
return NotImplemented
else:
if util.is_constant(other):
if other >= 2**31 or other < -2**31:
return op(self, cbits(other))
res = cbits(n=max(self.n, len(bin(other)) - 2))
res = cbits.get_type(max(self.n, len(bin(other)) - 2))()
ci_inst(res, self, other)
return res
else:
@@ -221,8 +240,14 @@ class cbits(bits):
def __xor__(self, other):
if isinstance(other, (sbits, sbitvec)):
return NotImplemented
elif isinstance(other, cbits):
res = cbits.get_type(max(self.n, other.n))()
assert res.size == self.size
assert res.size == other.size
inst.xorcb(res.n, res, self, other)
return res
else:
self.clear_op(other, inst.xorcb, inst.xorcbi, operator.xor)
return self.clear_op(other, None, inst.xorcbi, operator.xor)
__radd__ = __add__
__rxor__ = __xor__
def __mul__(self, other):
@@ -230,17 +255,18 @@ class cbits(bits):
return NotImplemented
else:
try:
res = cbits(n=min(self.max_length, self.n+util.int_len(other)))
res = cbits.get_type(min(self.max_length,
self.n+util.int_len(other)))()
inst.mulcbi(res, self, other)
return res
except TypeError:
return NotImplemented
def __rshift__(self, other):
res = cbits(n=self.n-other)
res = cbits.new(n=self.n-other)
inst.shrcbi(res, self, other)
return res
def __lshift__(self, other):
res = cbits(n=self.n+other)
res = cbits.get_type(self.n+other)()
inst.shlcbi(res, self, other)
return res
def print_reg(self, desc=''):
@@ -504,16 +530,6 @@ class sbits(bits):
res = [cls.new(n=len(rows)) for i in range(n_columns)]
inst.trans(len(res), *(res + rows))
return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
@staticmethod
def bit_adder(*args, **kwargs):
return sbitint.bit_adder(*args, **kwargs)
@@ -610,7 +626,7 @@ class sbitvec(_vec):
elif isinstance(other, (list, tuple)):
self.v = self.bit_extend(sbitvec(other).v, n)
else:
self.v = sbits(other, n=n).bit_decompose(n)
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
@classmethod
def load_mem(cls, address):
@@ -630,6 +646,8 @@ class sbitvec(_vec):
for i in range(n):
v[i].store_in_mem(address + i)
def reveal(self):
if len(self) > cbits.unit:
return self.elements()[0].reveal()
revealed = [cbit() for i in range(len(self))]
for i in range(len(self)):
try:
@@ -784,15 +802,23 @@ class bit(object):
def result_conv(x, y):
try:
def f(res):
try:
return t.conv(res)
except:
return res
if util.is_constant(x):
if util.is_constant(y):
return lambda x: x
else:
return type(y).conv
t = type(y)
return f
if util.is_constant(y):
return type(x).conv
t = type(x)
return f
if type(x) is type(y):
return type(x).conv
t = type(x)
return f
except AttributeError:
pass
return lambda x: x
@@ -807,13 +833,19 @@ class sbit(bit, sbits):
This will output 5.
"""
return result_conv(x, y)(self * (x ^ y) ^ y)
assert self.n == 1
diff = x ^ y
if isinstance(diff, cbits):
return result_conv(x, y)(self & (diff) ^ y)
else:
return result_conv(x, y)(self * (diff) ^ y)
class cbit(bit, cbits):
pass
sbits.bit_type = sbit
cbits.bit_type = cbit
sbit.clear_type = cbit
class bitsBlock(oram.Block):
value_type = sbits
@@ -881,7 +913,7 @@ class _sbitintbase:
return self.get_type(k - m).compose(res_bits)
def int_div(self, other, bit_length=None):
k = bit_length or max(self.n, other.n)
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
return (library.IntDiv(self.cast(k), other.cast(k), k) >> k).cast(k)
def Norm(self, k, f, kappa=None, simplex_flag=False):
absolute_val = abs(self)
#next 2 lines actually compute the SufOR for little indian encoding
@@ -1100,7 +1132,8 @@ class cbitfix(object):
bits = self.v.bit_decompose(self.k)
sign = bits[-1]
v += (sign << (self.k)) * -1
inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0), cbits(0))
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))
class sbitfix(_fix):
""" Secret signed integer in one binary register.

View File

@@ -123,7 +123,7 @@ class StraightlineAllocator:
for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x)
free[reg.reg_type, base.size].add(self.alloc[base])
free[reg.reg_type, base.size].append(self.alloc[base])
if inst.is_vec() and base.vector:
self.defined[base] = inst
for i in base.vector:
@@ -604,4 +604,4 @@ class RegintOptimizer:
elif op == 1:
instructions[i] = None
inst.args[0].link(inst.args[1])
instructions[:] = filter(lambda x: x is not None, instructions)
instructions[:] = list(filter(lambda x: x is not None, instructions))

View File

@@ -127,7 +127,7 @@ def sha3_256(x):
from circuit import sha3_256
a = sbitvec.from_vec([])
b = sbitvec(sint(0xcc), 8)
b = sbitvec(sint(0xcc), 8, 8)
for x in a, b:
sha3_256(x).elements()[0].reveal().print_reg()

View File

@@ -73,6 +73,7 @@ def require_ring_size(k, op):
if int(program.options.ring) < k:
raise CompilerError('ring size too small for %s, compile '
'with \'-R %d\' or more' % (op, k))
program.curr_tape.require_bit_length(k)
@instructions_base.cisc
def LTZ(s, a, k, kappa):

View File

@@ -55,7 +55,7 @@ def EQZ(a, k, kappa):
from GC.types import sbitvec
v = sbitvec(a, k).v
bit = util.tree_reduce(operator.and_, (~b for b in v))
return types.sint.conv(bit)
return types.sintbit.conv(bit)
prog.non_linear.check_security(kappa)
return prog.non_linear.eqz(a, k)
@@ -263,16 +263,17 @@ def BitAdd(a, b, bits_to_compute=None):
def BitDec(a, k, m, kappa, bits_to_compute=None):
return program.Program.prog.non_linear.bit_dec(a, k, m)
def BitDecRing(a, k, m):
def BitDecRingRaw(a, k, m):
n_shift = int(program.Program.prog.options.ring) - m
assert(n_shift >= 0)
if program.Program.prog.use_split():
x = a.split_to_two_summands(m)
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
return bits[:m]
else:
if program.Program.prog.use_dabit:
if program.Program.prog.use_edabit():
r, r_bits = types.sint.get_edabit(m, strict=False)
elif program.Program.prog.use_dabit:
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
r = types.sint.bit_compose(r)
else:
@@ -281,7 +282,12 @@ def BitDecRing(a, k, m):
shifted = ((a - r) << n_shift).reveal()
masked = shifted >> n_shift
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return [types.sint.conv(bit) for bit in bits]
return bits
def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
r_dprime = types.sint()
@@ -429,7 +435,7 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
s = (1 - overflow) * t + overflow * t / 2
return s, overflow
def Int2FL(a, gamma, l, kappa):
def Int2FL(a, gamma, l, kappa=None):
lam = gamma - 1
s = a.less_than(0, gamma, security=kappa)
z = a.equal(0, gamma, security=kappa)
@@ -598,13 +604,13 @@ def SDiv_mono(a, b, l, kappa):
# Unconditionally Secure Constant-Rounds Multi-party Computation
# for Equality, Comparison, Bits and Exponentiation
def BITLT(a, b, bit_length):
sint = types.sint
e = [sint(0)]*bit_length
g = [sint(0)]*bit_length
h = [sint(0)]*bit_length
from .types import sint, regint, longint, cint
e = [None]*bit_length
g = [None]*bit_length
h = [None]*bit_length
for i in range(bit_length):
# Compute the XOR (reverse order of e for PreOpL)
e[bit_length-i-1] = a[i].bit_xor(b[i])
e[bit_length-i-1] = util.bit_xor(a[i], b[i])
f = PreOpL(or_op, e)
g[bit_length-1] = f[0]
for i in range(bit_length-1):
@@ -612,7 +618,7 @@ def BITLT(a, b, bit_length):
g[i] = f[bit_length-i-1]-f[bit_length-i-2]
ans = 0
for i in range(bit_length):
h[i] = g[i]*b[i]
h[i] = g[i].bit_and(b[i])
ans = ans + h[i]
return ans
@@ -620,9 +626,9 @@ def BITLT(a, b, bit_length):
# - From the paper
# Multiparty Computation for Interval, Equality, and Comparison without
# Bit-Decomposition Protocol
def BitDecFull(a):
def BitDecFull(a, maybe_mixed=False):
from .library import get_program, do_while, if_, break_point
from .types import sint, regint, longint
from .types import sint, regint, longint, cint
p = get_program().prime
assert p
bit_length = p.bit_length()
@@ -631,9 +637,16 @@ def BitDecFull(a):
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
# no need for exact randomness generation
# if modulo a power of two is close enough
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
if logp != bit_length:
bbits += [sint(0, size=a.size)]
if get_program().use_edabit():
b, bbits = sint.get_edabit(logp, True, size=a.size)
if logp != bit_length:
from .GC.types import sbits
bbits += [sbits.get_type(a.size)(0)]
else:
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
b = sint.bit_compose(bbits)
if logp != bit_length:
bbits += [sint(0, size=a.size)]
else:
bbits = [sint(size=a.size) for i in range(bit_length)]
tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)]
@@ -653,15 +666,21 @@ def BitDecFull(a):
for j in range(a.size):
for i in range(bit_length):
movs(bbits[i][j], tbits[j][i])
b = sint.bit_compose(bbits)
b = sint.bit_compose(bbits)
c = (a-b).reveal()
t = (p-c).bit_decompose(bit_length)
cmodp = c
t = bbits[0].bit_decompose_clear(p - c, bit_length)
c = longint(c, bit_length)
czero = (c==0)
q = 1-BITLT( bbits, t, bit_length)
fbar=((1<<bit_length)+c-p).bit_decompose(bit_length)
fbard = c.bit_decompose(bit_length)
g = [(fbar[i] - fbard[i]) * q + fbard[i] for i in range(bit_length)]
h = BitAdd(bbits, g)
abits = [(1 - czero) * h[i] + czero * bbits[i] for i in range(bit_length)]
return abits
q = bbits[0].long_one() - BITLT(bbits, t, bit_length)
fbar = [bbits[0].clear_type.conv(cint(x))
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
h = bbits[0].bit_adder(bbits, g)
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
for i in range(bit_length)]
if maybe_mixed:
return abits
else:
return [sint.conv(bit) for bit in abits]

View File

@@ -1,4 +1,5 @@
import heapq
import collections
from Compiler.exceptions import *
class GraphError(CompilerError):
@@ -23,7 +24,7 @@ class SparseDiGraph(object):
self.n = max_nodes
# each node contains list of default attributes, followed by outoing edges
self.nodes = [list(self.default_attributes.values()) for i in range(self.n)]
self.succ = [set() for i in range(self.n)]
self.succ = [collections.OrderedDict() for i in range(self.n)]
self.pred = [set() for i in range(self.n)]
self.weights = {}
@@ -32,7 +33,7 @@ class SparseDiGraph(object):
def __getitem__(self, i):
""" Get list of the neighbours of node i """
return self.succ[i]
return self.succ[i].keys()
def __iter__(self):
pass #return iter(self.nodes)
@@ -68,7 +69,7 @@ class SparseDiGraph(object):
self.pred[v].remove(i)
#del self.weights[(i,v)]
for v in pred:
self.succ[v].remove(i)
del self.succ[v][i]
#del self.weights[(v,i)]
#self.nodes[v].remove(i)
self.pred[i] = []
@@ -77,7 +78,7 @@ class SparseDiGraph(object):
def add_edge(self, i, j, weight=1):
if j not in self[i]:
self.pred[j].add(i)
self.succ[i].add(j)
self.succ[i][j] = None
self.weights[(i,j)] = weight
def add_edges_from(self, tuples):
@@ -89,7 +90,7 @@ class SparseDiGraph(object):
self.add_edge(edge[0], edge[1])
def remove_edge(self, i, j):
self.succ[i].remove(j)
del self.succ[i][j]
self.pred[j].remove(i)
del self.weights[(i,j)]

View File

@@ -2219,22 +2219,23 @@ class conv2ds(base.DataInstruction):
:param: number of channels (int)
:param: padding height (int)
:param: padding width (int)
:param: batch size (int)
"""
code = base.opcodes['CONV2DS']
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
'int','int','int']
'int','int','int','int']
data_type = 'triple'
is_vec = lambda self: True
def __init__(self, *args, **kwargs):
super(conv2ds, self).__init__(*args, **kwargs)
assert args[0].size == args[3] * args[4]
assert args[1].size == args[5] * args[6] * args[11]
assert args[0].size == args[3] * args[4] * args[14]
assert args[1].size == args[5] * args[6] * args[11] * args[14]
assert args[2].size == args[7] * args[8] * args[11]
def get_repeat(self):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
self.args[11]
self.args[11] * self.args[14]
@base.vectorize
class trunc_pr(base.VarArgsInstruction):
@@ -2250,6 +2251,15 @@ class trunc_pr(base.VarArgsInstruction):
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])
class check(base.Instruction):
"""
Force MAC check in current thread and all idle thread if current
thread is the main thread.
"""
__slots__ = []
code = base.opcodes['CHECK']
arg_format = []
###
### CISC-style instructions
###
@@ -2289,47 +2299,5 @@ class lts(base.CISC):
subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
@base.vectorize
class g2muls(base.CISC):
r""" Secret GF(2) multiplication """
__slots__ = []
arg_format = ['sgw','sg','sg']
def expand(self):
s = [program.curr_block.new_reg('sg') for i in range(9)]
c = [program.curr_block.new_reg('cg') for i in range(3)]
gbittriple(s[0], s[1], s[2])
gsubs(s[3], self.args[1], s[0])
gsubs(s[4], self.args[2], s[1])
gasm_open(c[0], s[3])
gasm_open(c[1], s[4])
gmulbitm(s[5], s[1], c[0])
gmulbitm(s[6], s[0], c[1])
gmulbitc(c[2], c[0], c[1])
gadds(s[7], s[2], s[5])
gadds(s[8], s[7], s[6])
gaddm(self.args[0], s[8], c[2])
#@base.vectorize
#class gmulbits(base.CISC):
# r""" Secret $GF(2^n) \times GF(2)$ multiplication """
# __slots__ = []
# arg_format = ['sgw','sg','sg']
#
# def expand(self):
# s = [program.curr_block.new_reg('s') for i in range(9)]
# c = [program.curr_block.new_reg('c') for i in range(3)]
# g2ntriple(s[0], s[1], s[2])
# subs(s[3], self.args[1], s[0])
# subs(s[4], self.args[2], s[1])
# startopen(s[3], s[4])
# stopopen(c[0], c[1])
# mulm(s[5], s[1], c[0])
# mulm(s[6], s[0], c[1])
# mulc(c[2], c[0], c[1])
# adds(s[7], s[2], s[5])
# adds(s[8], s[7], s[6])
# addm(self.args[0], s[8], c[2])
# hack for circular dependency
from Compiler import comparison

View File

@@ -18,6 +18,8 @@ from Compiler import program
### MUST also be changed. (+ the documentation)
###
opcodes = dict(
# Emulation
CISC = 0,
# Load/store
LDI = 0x1,
LDSI = 0x2,
@@ -98,6 +100,7 @@ opcodes = dict(
MATMULS = 0xAA,
MATMULSM = 0xAB,
CONV2DS = 0xAC,
CHECK = 0xAF,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -409,7 +412,7 @@ def cisc(function):
program.curr_block.instructions.append(self)
def get_def(self):
return [self.args[0]]
return [call[0][0] for call in self.calls]
def get_used(self):
return self.used
@@ -423,6 +426,7 @@ def cisc(function):
def merge(self, other):
self.calls += other.calls
self.used += other.used
def get_size(self):
return self.args[0].size
@@ -470,7 +474,9 @@ def cisc(function):
inst.copy(size, subs)
reset_global_vector_size()
def expand_merged(self):
def expand_merged(self, skip):
if function.__name__ in skip:
return [self], 0
tape = program.curr_tape
block = tape.BasicBlock(tape, None, None)
tape.active_basicblock = block
@@ -496,10 +502,38 @@ def cisc(function):
reg.mov(reg, new_regs[0].get_vector(base, reg.size))
reset_global_vector_size()
base += reg.size
return block.instructions
return block.instructions, self.n_rounds - 1
def expanded_rounds(self):
return self.n_rounds - 1
def add_usage(self, *args):
pass
def get_bytes(self):
assert len(self.kwargs) < 2
res = int_to_bytes(opcodes['CISC'])
res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
name = self.function.__name__
String.check(name)
res += String.encode(name)
for call in self.calls:
assert not call[1]
res += int_to_bytes(len(call[0]) + 2)
res += int_to_bytes(call[0][0].size)
for arg in call[0]:
res += self.arg_to_bytes(arg)
return bytearray(res)
@classmethod
def arg_to_bytes(self, arg):
if arg is None:
return int_to_bytes(0)
try:
return int_to_bytes(arg.i)
except:
return int_to_bytes(arg)
def __str__(self):
return self.function.__name__ + ' ' + ', '.join(
str(x) for x in itertools.chain(call[0] for call in self.calls))
MergeCISC.__name__ = function.__name__
@@ -804,11 +838,8 @@ class Instruction(object):
else:
return self.args
def expand_merged(self):
return [self]
def expanded_rounds(self):
return 0
def expand_merged(self, skip):
return [self], 0
def get_new_args(self, size, subs):
new_args = []

View File

@@ -170,7 +170,7 @@ def print_ln_to(player, ss, *args):
Example::
print_ln_to(player, 'output for %s: %s', x.reveal_to(player))
print_ln_to(player, 'output for %s: %s', player, x.reveal_to(player))
"""
cond = player == get_player_id()
new_args = []
@@ -293,7 +293,9 @@ class Function:
self.compile_args = compile_args
def __call__(self, *args):
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
get_reg_type = lambda x: regint if isinstance(x, int) else type(x)
from .types import _types
get_reg_type = lambda x: \
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
if len(args) not in self.type_args:
# first call
type_args = collections.defaultdict(list)
@@ -324,7 +326,8 @@ class Function:
j = 0
for i_arg in type_args[reg_type]:
if get_reg_type(args[i_arg]) != reg_type:
raise CompilerError('type mismatch')
raise CompilerError('type mismatch: "%s" not of type "%s"' %
(args[i_arg], reg_type))
store_in_mem(args[i_arg], bases[reg_type] + j)
j += util.mem_size(reg_type)
return self.on_call(base, bases)
@@ -371,7 +374,7 @@ class FunctionBlock(Function):
parent_node = get_tape().req_node
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
block = get_tape().active_basicblock
block.alloc_pool = defaultdict(set)
block.alloc_pool = defaultdict(list)
del parent_node.children[-1]
self.node = get_tape().req_node
if get_program().verbose:
@@ -763,22 +766,34 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
@function_block
def step(l):
l = MemValue(l)
@for_range_opt_multithread(n_threads, len(a) // k)
m = 2 ** int(math.ceil(math.log(len(a), 2)))
@for_range_opt_multithread(n_threads, m // k)
def _(i):
n_inner = l // k
j = i % n_inner
i //= n_inner
base = i*l + j
step = l//k
def swap(base, step):
if m == len(a):
a[base], a[base + step] = \
cond_swap(a[base], a[base + step])
else:
# ignore values outside range
go = base + step < len(a)
x = a.maybe_get(go, base)
y = a.maybe_get(go, base + step)
tmp = cond_swap(x, y)
for i, idx in enumerate((base, base + step)):
a.maybe_set(go, idx, tmp[i])
if k == 2:
a[base], a[base+step] = \
cond_swap(a[base], a[base+step])
swap(base, step)
else:
@for_range_opt(n_innermost)
def f(i):
m1 = step + i * 2 * step
m2 = m1 + base
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
swap(m2, step)
steps[key] = step
steps[key](l)
@@ -870,6 +885,8 @@ def for_range_parallel(n_parallel, n_loops):
"""
Decorator to execute a loop :py:obj:`n_loops` up to
:py:obj:`n_parallel` loop bodies in parallel.
Using any other control flow instruction inside the loop breaks
the optimization.
:param n_parallel: compile-time (int)
:param n_loops: regint/cint/int
@@ -887,9 +904,12 @@ def for_range_parallel(n_parallel, n_loops):
def for_range_opt(n_loops, budget=None):
""" Execute loop bodies in parallel up to an optimization budget.
This prevents excessive loop unrolling. The budget is respected
even with nested loops. Note that optimization is rather
even with nested loops. Note that the optimization is rather
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
using :py:func:`for_range_parallel` in this case.
Using further control flow constructions inside other than
:py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the
optimization.
:param n_loops: int/regint/cint
:param budget: number of instructions after which to start optimization (default is 100,000)
@@ -1082,18 +1102,19 @@ def multithread(n_threads, n_items=None, max_size=None):
"""
if n_items is None:
n_items = n_threads
if max_size is None:
if max_size is None or n_items <= max_size:
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False)
else:
def wrapper(function):
@multithread(n_threads, n_items)
def new_function(base, size):
for i in range(0, size, max_size):
part_base = base + i
part_size = min(max_size, size - i)
function(part_base, part_size)
break_point()
@for_range(size // max_size)
def _(i):
function(base + i * max_size, max_size)
rem = size % max_size
if rem:
function(base + size - rem, rem)
return wrapper
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
@@ -1200,6 +1221,23 @@ def map_sum(n_threads, n_parallel, n_loops, n_items, value_types):
return tuple(a + b for a,b in zip(x,y))
return map_reduce(n_threads, n_parallel, n_loops, initializer, summer)
def tree_reduce_multithread(n_threads, function, vector):
inputs = vector.Array(len(vector))
inputs.assign_vector(vector)
outputs = vector.Array(len(vector) // 2)
left = len(vector)
while left > 1:
@multithread(n_threads, left // 2)
def _(base, size):
outputs.assign_vector(
function(inputs.get_vector(2 * base, size),
inputs.get_vector(2 * base + size, size)), base)
inputs.assign_vector(outputs.get_vector(0, left // 2))
if left % 2 == 1:
inputs[left // 2] = inputs[left - 1]
left = (left + 1) // 2
return inputs[0]
def foreach_enumerate(a):
""" Run-time loop over public data. This uses
``Player-Data/Public-Input/<progname>``. Example:
@@ -1511,6 +1549,15 @@ def break_point(name=''):
"""
get_tape().start_new_basicblock(name=name)
def check_point():
"""
Force MAC checks in current thread and all idle threads if the
current thread is the main thread. This implies a break point.
"""
break_point('pre-check')
check()
break_point('post-check')
# Fixed point ops
from math import ceil, log
@@ -1566,6 +1613,9 @@ def cint_cint_division(a, b, k, f):
# theta can be replaced with something smaller
# for safety we assume that is the same theta from previous GS method
if get_program().options.ring:
assert 2 * f < int(get_program().options.ring)
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
@@ -1579,9 +1629,11 @@ def cint_cint_division(a, b, k, f):
B = absolute_b
W = w0
for i in range(1, theta):
A = (A * W) >> f
B = (B * W) >> f
corr = cint(1) << (f - 1)
for i in range(theta):
A = (A * W + corr) >> f
B = (B * W + corr) >> f
W = two - B
return (sign_a * sign_b) * A
@@ -1592,7 +1644,7 @@ def sint_cint_division(a, b, k, f, kappa):
"""
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b < 0)
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
absolute_b = b * sign_b
absolute_a = a * sign_a
@@ -1652,7 +1704,8 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
return y
def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
"""
Approximate reciprocal of [b]:
Given [b], compute [1/b]
@@ -1662,7 +1715,7 @@ def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
d = alpha - 2 * c
w = d * v
w = w.round(2 * k, 2 * (k - f), kappa, nearest, signed=True)
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
# now w * 2 ^ {-f} should be an initial approximation of 1/b
return w
@@ -1674,7 +1727,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
# For simplex, we can get rid of computing abs(b)
temp = None
if simplex_flag == False:
temp = comparison.LessThanZero(b, 2 * k, kappa)
temp = comparison.LessThanZero(b, k, kappa)
elif simplex_flag == True:
temp = cint(0)
@@ -1682,7 +1735,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
absolute_val = sign * b
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k, kappa)[::-1]
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
suffixes = PreOR(bits, kappa)[::-1]
z = [0] * k
@@ -1690,10 +1743,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
z[i] = suffixes[i] - suffixes[i+1]
z[k - 1] = suffixes[k-1]
#doing complicated stuff to compute v = 2^{k-m}
acc = cint(0)
for i in range(k):
acc += two_power(k-i-1) * z[i]
acc = sint.bit_compose(reversed(z))
part_reciprocal = absolute_val * acc
signed_acc = sign * acc

File diff suppressed because it is too large Load Diff

View File

@@ -846,3 +846,54 @@ def acos(x):
"""
y = asin(x)
return pi_over_2 - y
def tanh(x):
"""
Hyperbolic tangent. For efficiency, accuracy is diminished
around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and
:math:`f` denote the fixed-point parameters.
"""
limit = math.log(2 ** (x.k - x.f - 2)) / 2
s = x < -limit
t = x > limit
y = pow_fx(math.e, 2 * x)
return s.if_else(-1, t.if_else(1, (y - 1) / (y + 1)))
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
def Sep(x):
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
t = x.v * (1 + x.v.bit_compose(b_i.bit_not() for b_i in b[-2 * x.f + 1:]))
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
b += [b[0].long_one()]
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
def SqrtComp(z, old=False):
f = types.sfix.f
k = len(z)
if isinstance(z[0], types.sint):
return types.sfix._new(sum(z[i] * types.cfix(
2 ** (-(i - f + 1) / 2)).v for i in range(k)))
k_prime = k // 2
f_prime = f // 2
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
c0 = types.sfix(2 ** (f / 2 + 1))
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
if old:
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
else:
b = util.tree_reduce(lambda x, y: x.bit_xor(y), z[::2])
return types.sint.conv(b).if_else(c1, c0) * tmp
@types.vectorize
def InvertSqrt(x, old=False):
"""
Reciprocal square root approximation by `Lu et al.
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
"""
u, z = Sep(x)
c = 3.14736 + u * (4.63887 * u - 5.77789)
return c * SqrtComp(z, old=old)

View File

@@ -44,7 +44,7 @@ class Masking(NonLinear):
d = [None]*k
for i,b in enumerate(r[0].bit_decompose_clear(c, k)):
d[i] = r[i].bit_xor(b)
return 1 - types.sint.conv(self.kor(d))
return 1 - types.sintbit.conv(self.kor(d))
class Prime(Masking):
""" Non-linear functionality modulo a prime with statistical masking. """
@@ -71,8 +71,11 @@ class Prime(Masking):
def _trunc_pr(self, a, k, m, signed=None):
return TruncPrField(a, k, m, self.kappa)
def bit_dec(self, a, k, m):
return BitDecField(a, k, m, self.kappa)
def bit_dec(self, a, k, m, maybe_mixed=False):
if maybe_mixed:
return BitDecFieldRaw(a, k, m, self.kappa)
else:
return BitDecField(a, k, m, self.kappa)
def kor(self, d):
return KOR(d, self.kappa)
@@ -85,7 +88,7 @@ class KnownPrime(NonLinear):
def _mod2m(self, a, k, m, signed):
if signed:
a += cint(1) << (k - 1)
return sint.bit_compose(self.bit_dec(a, k, k)[:m])
return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
def _trunc_pr(self, a, k, m, signed):
# nearest truncation
@@ -96,14 +99,14 @@ class KnownPrime(NonLinear):
if signed:
a += cint(1) << (k - 1)
k += 1
res = sint.bit_compose(self.bit_dec(a, k, k)[m:])
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
if signed:
res -= cint(1) << (k - m - 2)
return res
def bit_dec(self, a, k, m):
def bit_dec(self, a, k, m, maybe_mixed=False):
assert k < self.prime.bit_length()
bits = BitDecFull(a)
bits = BitDecFull(a, maybe_mixed=maybe_mixed)
if len(bits) < m:
raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
return bits[:m]
@@ -111,7 +114,7 @@ class KnownPrime(NonLinear):
def eqz(self, a, k):
# always signed
a += two_power(k)
return 1 - KORL(self.bit_dec(a, k, k))
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time.
@@ -130,8 +133,11 @@ class Ring(Masking):
def _trunc_pr(self, a, k, m, signed):
return TruncPrRing(a, k, m, signed=signed)
def bit_dec(self, a, k, m):
return BitDecRing(a, k, m)
def bit_dec(self, a, k, m, maybe_mixed=False):
if maybe_mixed:
return BitDecRingRaw(a, k, m)
else:
return BitDecRing(a, k, m)
def kor(self, d):
return KORL(d)

View File

@@ -28,9 +28,7 @@ data_types = dict(
square = 1,
bit = 2,
inverse = 3,
bittriple = 4,
bitgf2ntriple = 5,
dabit = 6,
dabit = 4,
)
field_types = dict(
@@ -62,6 +60,7 @@ class defaults:
asmoutfile = None
stop = False
insecure = False
keep_cisc = False
class Program(object):
""" A program consists of a list of tapes representing the whole
@@ -80,14 +79,14 @@ class Program(object):
self.init_names(args)
self._security = 40
self.prime = None
self.tapes = []
if sum(x != 0 for x in(options.ring, options.field,
options.binary)) > 1:
raise CompilerError('can only use one out of -B, -R, -F')
if options.prime and (options.ring or options.binary):
raise CompilerError('can only use one out of -B, -R, -p')
if options.ring:
self.bit_length = int(options.ring) - 1
self.non_linear = Ring(int(options.ring))
self.set_ring_size(int(options.ring))
else:
self.bit_length = int(options.binary) or int(options.field)
if options.prime:
@@ -108,7 +107,6 @@ class Program(object):
if self.verbose:
print('Galois length:', self.galois_length)
self.tape_counter = 0
self.tapes = []
self._curr_tape = None
self.DEBUG = options.debug
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
@@ -204,6 +202,16 @@ class Program(object):
for arg in args[1:])
self.progname = progname
def set_ring_size(self, ring_size):
from .non_linear import Ring
for tape in self.tapes:
prev = tape.req_bit_length['p']
if prev and prev != ring_size:
raise CompilerError('cannot have different ring sizes')
self.bit_length = ring_size - 1
self.non_linear = Ring(ring_size)
self.options.ring = str(ring_size)
def new_tape(self, function, args=[], name=None, single_thread=False):
"""
Create a new tape from a function. See
@@ -414,7 +422,7 @@ class Program(object):
self.curr_tape.start_new_basicblock(None, 'memory-usage')
# reset register counter to 0
self.curr_tape.init_registers()
for mem_type,size in list(self.allocated_mem.items()):
for mem_type,size in sorted(self.allocated_mem.items()):
if size:
#print "Memory of type '%s' of size %d" % (mem_type, size)
if mem_type in self.types:
@@ -488,7 +496,7 @@ class Program(object):
else:
if change and not self.options.ring:
raise CompilerError('splitting only supported for rings')
assert change > 1
assert change > 1 or change == False
self._split = change
def use_square(self, change=None):
@@ -575,7 +583,7 @@ class Tape:
scope.children.append(self)
self.alloc_pool = scope.alloc_pool
else:
self.alloc_pool = defaultdict(set)
self.alloc_pool = defaultdict(list)
self.purged = False
self.n_rounds = 0
self.n_to_merge = 0
@@ -647,9 +655,14 @@ class Tape:
def expand_cisc(self):
new_instructions = []
if self.parent.program.options.keep_cisc:
skip = ['LTZ', 'Trunc']
else:
skip = []
for inst in self.instructions:
new_instructions.extend(inst.expand_merged())
self.n_rounds += inst.expanded_rounds()
new_inst, n_rounds = inst.expand_merged(skip)
new_instructions.extend(new_inst)
self.n_rounds += n_rounds
self.instructions = new_instructions
def __str__(self):
@@ -774,7 +787,10 @@ class Tape:
# allocate registers
reg_counts = self.count_regs()
if not options.noreallocate:
if options.noreallocate:
if self.program.verbose:
print('Tape register usage:', dict(reg_counts))
else:
if self.program.verbose:
print('Tape register usage before re-allocation:',
dict(reg_counts))
@@ -1071,7 +1087,7 @@ class Tape:
if size is None:
size = Compiler.instructions_base.get_global_vector_size()
if size is not None and size > self.maximum_size:
raise CompilerError('vector too large')
raise CompilerError('vector too large: %d' % size)
self.size = size
self.vectorbase = self
self.relative_i = 0

View File

@@ -591,12 +591,12 @@ class _register(Tape.Register, _number, _structure):
def prep_res(cls, other):
return cls()
@staticmethod
def bit_compose(bits):
@classmethod
def bit_compose(cls, bits):
""" Compose value from bits.
:param bits: iterable of any type implementing left shift """
return sum(b << i for i,b in enumerate(bits))
return sum(cls.conv(b) << i for i,b in enumerate(bits))
@classmethod
def malloc(cls, size, creator_tape=None):
@@ -840,6 +840,7 @@ class cint(_clear, _int):
def in_immediate_range(value):
return value < 2**31 and value >= -2**31
@vectorize_init
def __init__(self, val=None, size=None):
"""
:param val: initialization (cint/regint/int/cgf2n or list thereof)
@@ -1119,12 +1120,6 @@ class cgf2n(_clear, _gf2n):
elif chunk:
sum += chunk
def __mul__(self, other):
""" Clear :math:`\mathrm{GF}(2^n)` multiplication.
:param other: cgf2n/regint/int """
return super(cgf2n, self).__mul__(other)
def __neg__(self):
""" Identity. """
return self
@@ -1209,7 +1204,9 @@ class regint(_register, _int):
def get_random(cls, bit_length):
""" Public insecure randomness.
:param bit_length: number of bits (int) """
:param bit_length: number of bits (int)
:param size: vector size (int, default 1)
"""
if isinstance(bit_length, int):
bit_length = regint(bit_length)
res = cls()
@@ -1582,7 +1579,9 @@ class _secret(_register):
def get_input_from(cls, player):
""" Secret input from player.
:param player: public (regint/cint/int) """
:param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
res = cls()
asm_input(res, player)
return res
@@ -1592,7 +1591,9 @@ class _secret(_register):
def get_random_triple(cls):
""" Secret random triple according to security model.
:return: :math:`(a, b, ab)` """
:return: :math:`(a, b, ab)`
:param size: vector size (int, default 1)
"""
res = (cls(), cls(), cls())
triple(*res)
return res
@@ -1602,7 +1603,9 @@ class _secret(_register):
def get_random_bit(cls):
""" Secret random bit according to security model.
:return: 0/1 50-50 """
:return: 0/1 50-50
:param size: vector size (int, default 1)
"""
res = cls()
bit(res)
return res
@@ -1612,7 +1615,9 @@ class _secret(_register):
def get_random_square(cls):
""" Secret random square according to security model.
:return: :math:`(a, a^2)` """
:return: :math:`(a, a^2)`
:param size: vector size (int, default 1)
"""
res = (cls(), cls())
square(*res)
return res
@@ -1622,7 +1627,9 @@ class _secret(_register):
def get_random_inverse(cls):
""" Secret random inverse tuple according to security model.
:return: :math:`(a, a^{-1})` """
:return: :math:`(a, a^{-1})`
:param size: vector size (int, default 1)
"""
res = (cls(), cls())
inverse(*res)
return res
@@ -1717,16 +1724,51 @@ class _secret(_register):
else:
self.load_clear(self.clear_type(val))
@classmethod
def bit_compose(cls, bits):
""" Compose value from bits.
:param bits: iterable of any type convertible to sint """
from Compiler.GC.types import sbits, sbitintvec
bits = list(bits)
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
if program.use_edabit():
mask = cls.get_edabit(len(bits), strict=True, size=bits[0].n)
else:
tmp = sint(size=bits[0].n)
randoms(tmp, len(bits))
n_overflow_bits = min(program.use_split().bit_length(),
int(program.options.ring) - len(bits))
mask_bits = tmp.bit_decompose(len(bits) + n_overflow_bits,
maybe_mixed=True)
if n_overflow_bits:
overflow = sint.bit_compose(
sint.conv(x) for x in mask_bits[-n_overflow_bits:])
mask = tmp - (overflow << len(bits)), \
mask_bits[:-n_overflow_bits]
else:
mask = tmp, mask_bits
t = sbitintvec.get_type(len(bits) + 1)
masked = t.from_vec(mask[1] + [0]) + t.from_vec(bits + [0])
overflow = masked.v[-1]
masked = cls.bit_compose(x.reveal().to_regint_by_bit() for x in masked.v[:-1])
return masked - mask[0] + (cls(overflow) << len(bits))
else:
return super(_secret, cls).bit_compose(bits)
@set_instruction_type
@read_mem_value
@vectorize
def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False):
cls = self.__class__
res = self.prep_res(other)
cls = type(res)
if isinstance(other, regint):
other = res.clear_type(other)
if isinstance(other, cls):
s_inst(res, self, other)
if reverse:
s_inst(res, other, self)
else:
s_inst(res, self, other)
elif isinstance(other, res.clear_type):
if reverse:
m_inst(res, other, self)
@@ -1861,10 +1903,12 @@ class sint(_secret, _int):
def get_random_int(cls, bits):
""" Secret random n-bit number according to security model.
:param bits: compile-time integer (int) """
:param bits: compile-time integer (int)
:param size: vector size (int, default 1)
"""
if program.use_edabit():
return sint.get_edabit(bits, True)[0]
elif program.use_split() > 2:
elif program.use_split() > 2 and program.use_split() < 5:
tmp = sint()
randoms(tmp, bits)
x = tmp.split_to_two_summands(bits, True)
@@ -1882,7 +1926,10 @@ class sint(_secret, _int):
@vectorized_classmethod
def get_random(cls):
""" Secret random ring element according to security model. """
""" Secret random ring element according to security model.
:param size: vector size (int, default 1)
"""
res = sint()
randomfulls(res)
return res
@@ -1891,7 +1938,9 @@ class sint(_secret, _int):
def get_input_from(cls, player):
""" Secret input.
:param player: public (regint/cint/int) """
:param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
res = cls()
inputmixed('int', res, player)
return res
@@ -1915,7 +1964,7 @@ class sint(_secret, _int):
else:
a = [sint.get_random_bit() for i in range(n_bits)]
return sint.bit_compose(a), a
program.curr_tape.require_bit_length(n_bits)
program.curr_tape.require_bit_length(n_bits - 1)
whole = cls()
size = get_global_vector_size()
from Compiler.GC.types import sbits, sbitvec
@@ -1931,6 +1980,7 @@ class sint(_secret, _int):
return 1
@staticmethod
@vectorize
def bit_decompose_clear(a, n_bits):
return floatingpoint.bits(a, n_bits)
@@ -2055,7 +2105,7 @@ class sint(_secret, _int):
:param other: sint/cint/regint/int
:return: 0/1 (sint) """
res = sint()
res = sintbit()
comparison.LTZ(res, self - other,
(bit_length or program.bit_length) + 1,
security or program.security)
@@ -2064,7 +2114,7 @@ class sint(_secret, _int):
@read_mem_value
@vectorize
def __gt__(self, other, bit_length=None, security=None):
res = sint()
res = sintbit()
comparison.LTZ(res, other - self,
(bit_length or program.bit_length) + 1,
security or program.security)
@@ -2185,13 +2235,14 @@ class sint(_secret, _int):
return floatingpoint.Trunc(other, program.bit_length, self, program.security)
@vectorize
def bit_decompose(self, bit_length=None, security=None):
def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False):
""" Secret bit decomposition. """
if bit_length == 0:
return []
bit_length = bit_length or program.bit_length
security = security or program.security
return floatingpoint.BitDec(self, bit_length, bit_length, security)
assert program.security == security or program.security
return program.non_linear.bit_dec(self, bit_length, bit_length,
maybe_mixed)
def TruncMul(self, other, k, m, kappa=None, nearest=False):
return (self * other).round(k, m, kappa, nearest, signed=True)
@@ -2249,6 +2300,7 @@ class sint(_secret, _int):
return floatingpoint.two_power(n)
def split_to_n_summands(self, length, n):
comparison.require_ring_size(length, 'splitting')
from .GC.types import sbits
from .GC.instructions import split
columns = [[sbits.get_type(self.size)()
@@ -2274,7 +2326,9 @@ class sint(_secret, _int):
@vectorize
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Result potentially written to ``Player-Data/Private-Output-P<player>.``
Result potentially written to
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint):
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
@@ -2288,6 +2342,65 @@ class sint(_secret, _int):
else:
return super(sint, self).reveal_to(player)
class sintbit(sint):
@classmethod
def prep_res(cls, other):
return sint()
def load_other(self, other):
if isinstance(other, sint):
movs(self, other)
else:
super(sintbit, self).load_other(other)
@vectorize
def __and__(self, other):
if isinstance(other, sintbit):
res = sintbit()
muls(res, self, other)
return res
elif util.is_zero(other):
return 0
elif util.is_one(other):
return self
else:
return NotImplemented
@vectorize
def __or__(self, other):
if isinstance(other, sintbit):
res = sintbit()
adds(res, self, other - self * other)
return res
elif util.is_zero(other):
return self
elif util.is_one(other):
return 1
else:
return NotImplemented
@vectorize
def __xor__(self, other):
if isinstance(other, sintbit):
res = sintbit()
adds(res, self, other - 2 * self * other)
return res
elif util.is_zero(other):
return self
elif util.is_one(other):
return 1
else:
return NotImplemented
@vectorize
def __rsub__(self, other):
if util.is_one(other):
res = sintbit()
subsfi(res, self, 1)
return res
else:
return super(sintbit, self).__rsub__(other)
class sgf2n(_secret, _gf2n):
""" Secret :math:`\mathrm{GF}(2^n)` value. """
__slots__ = []
@@ -2437,10 +2550,11 @@ class sgf2n(_secret, _gf2n):
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
for t in (sint, sgf2n):
t.bit_type = t
t.basic_type = t
t.default_type = t
sint.bit_type = sintbit
sgf2n.bit_type = sgf2n
class _bitint(object):
bits = None
@@ -3046,14 +3160,17 @@ class cfix(_number, _structure):
@staticmethod
def int_rep(v, f, k=None):
if isinstance(v, regint):
v = cint(v)
res = v * (2 ** f)
try:
res = int(round(res))
if k and abs(res) >= 2 ** k:
if k and res >= 2 ** (k - 1) or res < -2 ** (k - 1):
limit = 2 ** (k - f - 1)
raise CompilerError(
'Value out of fixed-point range (maximum %d). '
'Value out of fixed-point range [-%d, %d). '
'Use `sfix.set_precision(f, k)` with k being at least f+%d'
% (2 ** (k - f), math.ceil(math.log(abs(v), 2)) + 1))
% (limit, limit, res.bit_length() - f + 1))
except TypeError:
pass
return res
@@ -3268,6 +3385,14 @@ class cfix(_number, _structure):
else:
raise TypeError('Incompatible fixed point types in division')
@vectorize
def __rtruediv__(self, other):
""" Fixed-point division.
:param other: sfix/sint/cfix/cint/regint/int """
other = parse_type(other, self.k, self.f)
return other / self
def print_plain(self):
""" Clear fixed-point output. """
print_float_plain(cint.conv(self.v), cint(-self.f), \
@@ -3468,7 +3593,7 @@ class _fix(_single):
set_precision = classmethod(set_precision)
@classmethod
def set_precision_from_args(cls, program):
def set_precision_from_args(cls, program, adapt_ring=False):
f = None
k = None
for arg in program.args:
@@ -3484,6 +3609,15 @@ class _fix(_single):
cfix.set_precision(f, k)
elif k is not None:
raise CompilerError('need to set fractional precision')
if 'nearest' in program.args:
print('Nearest rounding instead of proabilistic '
'for fixed-point computation')
cls.round_nearest = True
if adapt_ring and program.options.ring:
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
if need != int(program.options.ring):
print('Changing computation modulus to 2^%d' % need)
program.set_ring_size(need)
@classmethod
def coerce(cls, other):
@@ -3609,11 +3743,14 @@ class _fix(_single):
:param other: sfix/cfix/sint/cint/regint/int """
if util.is_constant_float(other):
assert other != 0
other_length = self.f + math.ceil(math.log(abs(other), 2))
if other_length >= self.k:
factor = 2 ** (self.k - other_length - 1)
log = math.ceil(math.log(abs(other), 2))
other_length = self.f + log
if other_length >= self.k - 1:
factor = 2 ** (self.k - other_length - 2)
self *= factor
other *= factor
if 2 ** log == other:
return self * 2 ** -log
other = self.coerce(other)
assert self.k == other.k
assert self.f == other.f
@@ -3660,7 +3797,9 @@ class sfix(_fix):
def get_input_from(cls, player):
""" Secret fixed-point input.
:param player: public (regint/cint/int) """
:param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
cls.int_type.require_bit_length(cls.k)
v = cls.int_type()
inputmixed('fix', v, cls.f, player)
@@ -3677,6 +3816,7 @@ class sfix(_fix):
:param lower: float
:param upper: float
:param size: vector size (int, default 1)
"""
log_range = int(math.log(upper - lower, 2))
n_bits = log_range + cls.f
@@ -3732,7 +3872,8 @@ class sfix(_fix):
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Raw representation possibly written to
``Player-Data/Private-Output-P<player>.``
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
@@ -4066,7 +4207,9 @@ class sfloat(_number, _structure):
def get_input_from(cls, player):
""" Secret floating-point input.
:param player: public (regint/cint/int) """
:param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
v = sint()
p = sint()
z = sint()
@@ -4444,6 +4587,7 @@ class Array(object):
self.address_cache = {}
self.debug = debug
self.creator_tape = program.curr_tape
self.sink = None
if alloc:
self.alloc()
@@ -4514,6 +4658,17 @@ class Array(object):
return
self._store(value, self.get_address(index))
def maybe_get(self, condition, index):
return condition * self[condition * index]
def maybe_set(self, condition, index, value):
if self.sink is None:
self.sink = self.value_type.Array(1)
addresses = (condition.if_else(x, y) for x, y in
zip(util.tuplify(self.get_address(index)),
util.tuplify(self.sink.get_address(0))))
self._store(value, util.untuplify(tuple(addresses)))
# the following two are useful for compile-time lengths
# and thus differ from the usual Python syntax
def get_range(self, start, size):
@@ -4590,11 +4745,22 @@ class Array(object):
get_part_vector = get_vector
def get_part(self, base, size):
return Array(size, self.value_type, self.get_address(base))
def get(self, indices):
return self.value_type.load_mem(
regint.inc(len(indices), self.address, 0) + indices,
size=len(indices))
def get_slice_vector(self, slice):
assert self.value_type.n_elements() == 1
assert len(slice) <= self.total_size()
base = regint.inc(len(slice), slice.address, 1, 1)
inc = regint.inc(len(slice), 0, 1, 1, 1)
addresses = slice.value_type.load_mem(base) + inc
return self.value_type.load_mem(self.address + addresses)
def expand_to_vector(self, index, size):
assert self.value_type.n_elements() == 1
address = regint(size=size)
@@ -4641,6 +4807,12 @@ class Array(object):
:param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() * value
def __truediv__(self, value):
""" Vector division.
:param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() / value
def __pow__(self, value):
""" Vector power-of computation.
@@ -4674,6 +4846,16 @@ class Array(object):
reveal_nested = reveal_list
def sort(self, n_threads=None):
"""
Sort in place using Batchers' odd-even merge mergesort
with complexity :math:`O(n (\log n)^2)`.
:param n_threads: number of threads to use (single thread by
default)
"""
library.loopy_odd_even_merge_sort(self, n_threads=n_threads)
def __str__(self):
return '%s array of length %s at %s' % (self.value_type, len(self),
self.address)
@@ -4784,6 +4966,15 @@ class SubMultiArray(object):
assert vector.size <= self.total_size()
vector.store_in_mem(self.address + base * part_size)
def get_slice_vector(self, slice):
assert self.value_type.n_elements() == 1
part_size = reduce(operator.mul, self.sizes[1:])
assert len(slice) * part_size <= self.total_size()
base = regint.inc(len(slice) * part_size, slice.address, 1, part_size)
inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size)
addresses = slice.value_type.load_mem(base) * part_size + inc
return self.value_type.load_mem(self.address + addresses)
def get_addresses(self, *indices):
assert self.value_type.n_elements() == 1
assert len(indices) == len(self.sizes)
@@ -4816,6 +5007,10 @@ class SubMultiArray(object):
""" :return: new multidimensional array with same shape and basic type """
return MultiArray(self.sizes, self.value_type)
def get_part(self, start, size):
return MultiArray([size] + list(self.sizes[1:]), self.value_type,
address=self[start].address)
def input_from(self, player, budget=None, raw=False):
""" Fill with inputs from player if supported by type.
@@ -4978,7 +5173,7 @@ class SubMultiArray(object):
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
assert len(indices[1]) == len(indices[2])
indices = list(indices)
indices[3] *= other.sizes[0]
indices[3] *= other.sizes[1]
return self.value_type.direct_matrix_mul(
self.address, other.address, None, self.sizes[1], 1,
reduce=reduce, indices=indices)

View File

@@ -195,7 +195,7 @@ def is_all_ones(x, n):
else:
return False
def max(x, y=None):
def max(x, y=None, n_threads=None):
if y is None:
return tree_reduce(max, x)
else:

View File

@@ -7,6 +7,7 @@
#include "Networking/CryptoPlayer.h"
#include "Math/gfp.h"
#include "ECDSA/P256Element.h"
#include "GC/VectorInput.h"
#include "ECDSA/preprocessing.hpp"
#include "ECDSA/sign.hpp"
@@ -20,6 +21,8 @@
#include "Protocols/MascotPrep.hpp"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/CcdPrep.hpp"
#include "OT/NPartyTripleGenerator.hpp"
#include <assert.h>

View File

@@ -5,6 +5,7 @@
#include "GC/TinierSecret.h"
#include "GC/TinyMC.h"
#include "GC/VectorInput.h"
#include "Protocols/Share.hpp"
#include "Protocols/MAC_Check.hpp"

View File

@@ -19,6 +19,8 @@
#include "Processor/Data_Files.hpp"
#include "Processor/Input.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/CcdPrep.hpp"
#include <assert.h>

View File

@@ -13,7 +13,6 @@
#include "Protocols/MaliciousShamirShare.h"
#include "Protocols/Rep3Share.h"
#include "GC/TinierSecret.h"
#include "GC/TinierPrep.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/TinyMC.h"
@@ -128,16 +127,4 @@ void check(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
MC.Check(P);
}
template<>
void ReplicatedPrep<Rep3Share<P256Element::Scalar>>::buffer_bits()
{
throw not_implemented();
}
template<>
void ReplicatedPrep<ShamirShare<P256Element::Scalar>>::buffer_bits()
{
throw not_implemented();
}
#endif /* ECDSA_PREPROCESSING_HPP_ */

View File

@@ -149,11 +149,6 @@ public:
return res;
}
bool is_binary() const
{
throw not_implemented();
}
size_t report_size(ReportType type)
{
size_t res = 4;

View File

@@ -6,6 +6,7 @@
#include "AddableVector.h"
#include "Rq_Element.h"
#include "FHE_Keys.h"
#include "P2Data.h"
template<class T>
AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
@@ -33,7 +34,3 @@ AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
}
return res;
}
template
AddableVector<Int_Random_Coins::rand_type> AddableVector<
Int_Random_Coins::rand_type>::mul_by_X_i(int j, const FHE_PK& pk) const;

View File

@@ -23,8 +23,6 @@ class Ciphertext
word pk_id;
public:
static string type_string() { return "ciphertext"; }
static int t() { return 0; }
static int size() { return 0; }
const FHE_Params& get_params() const { return *params; }
@@ -41,8 +39,6 @@ class Ciphertext
set(a0, a1, C.get_pk_id());
}
~Ciphertext() { ; }
// Rely on default copy assignment/constructor
word get_pk_id() const { return pk_id; }

View File

@@ -32,52 +32,6 @@ int DiscreteGauss::sample(PRNG &G, int stretch) const
void RandomVectors::set(int nn,int hh,double R)
{
n=nn;
h=hh;
DG.set(R);
}
void RandomVectors::set_n(int nn)
{
n = nn;
}
vector<bigint> RandomVectors::sample_Gauss(PRNG& G, int stretch) const
{
vector<bigint> ans(n);
for (int i=0; i<n; i++)
{ ans[i]=DG.sample(G, stretch); }
return ans;
}
vector<bigint> RandomVectors::sample_Hwt(PRNG& G) const
{
if (h > n/2 or h <= 0) { return sample_Gauss(G); }
vector<bigint> ans(n);
for (int i=0; i<n; i++) { ans[i]=0; }
int cnt=0,j=0;
unsigned char ch=0;
while (cnt<h)
{ unsigned int i=G.get_uint()%n;
if (ans[i]==0)
{ cnt++;
if (j==0)
{ j=8;
ch=G.get_uchar();
}
int v=ch&1; j--;
if (v==0) { ans[i]=-1; }
else { ans[i]=1; }
}
}
return ans;
}
int sample_half(PRNG& G)
{
@@ -91,36 +45,6 @@ int sample_half(PRNG& G)
}
vector<bigint> RandomVectors::sample_Half(PRNG& G) const
{
vector<bigint> ans(n);
for (int i=0; i<n; i++)
ans[i] = sample_half(G);
return ans;
}
vector<bigint> RandomVectors::sample_Uniform(PRNG& G,const bigint& B) const
{
vector<bigint> ans(n);
bigint v;
for (int i=0; i<n; i++)
{ G.get_bigint(v, numBits(B));
int bit=G.get_uint()&1;
if (bit==0) { ans[i]=v; }
else { ans[i]=-v; }
}
return ans;
}
bool RandomVectors::operator!=(const RandomVectors& other) const
{
if (n != other.n or h != other.h or DG != other.DG)
return true;
else
return false;
}
bool DiscreteGauss::operator!=(const DiscreteGauss& other) const
{
if (other.NewHopeB != NewHopeB)

View File

@@ -25,7 +25,6 @@ class DiscreteGauss
void unpack(octetStream& o) { o.unserialize(NewHopeB); }
DiscreteGauss(double R) { set(R); }
~DiscreteGauss() { ; }
// Rely on default copy constructor/assignment
@@ -36,50 +35,6 @@ class DiscreteGauss
bool operator!=(const DiscreteGauss& other) const;
};
/* Sample from integer lattice of dimension n
* with standard deviation R
*/
class RandomVectors
{
int n,h;
DiscreteGauss DG; // This generates the main distribution
public:
void set(int nn,int hh,double R); // R is input STANDARD DEVIATION
void set_n(int nn);
void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); }
void unpack(octetStream& o)
{ o.get(n); o.get(h); DG.unpack(o); }
RandomVectors(int h, double R) : RandomVectors(0, h, R) {}
RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); }
~RandomVectors() { ; }
// Rely on default copy constructor/assignment
double get_R() const { return DG.get_R(); }
DiscreteGauss get_DG() const { return DG; }
int get_h() const { return h; }
// Sample from Discrete Gauss distribution
vector<bigint> sample_Gauss(PRNG& G, int stretch = 1) const;
// Next samples from Hwt distribution unless hwt>n/2 in which
// case it uses Gauss
vector<bigint> sample_Hwt(PRNG& G) const;
// Sample from {-1,0,1} with Pr(-1)=Pr(1)=1/4 and Pr(0)=1/2
vector<bigint> sample_Half(PRNG& G) const;
// Sample from (-B,0,B) with uniform prob
vector<bigint> sample_Uniform(PRNG& G,const bigint& B) const;
bool operator!=(const RandomVectors& other) const;
};
template<class T>
class RandomGenerator : public Generator<T>
{
@@ -103,7 +58,7 @@ public:
void get(T& x) const { this->G.get(x, n_bits, positive); }
};
template<class T>
template<class T = bigint>
class GaussianGenerator : public RandomGenerator<T>
{
DiscreteGauss DG;

View File

@@ -1,6 +1,7 @@
#include "FHE/FFT.h"
#include "Math/Zp_Data.h"
#include "Processor/BaseMachine.h"
#include "Math/modp.hpp"
@@ -115,17 +116,38 @@ void FFT_Iter(vector<T>& ioput, int n, const T& root, const P& PrD)
*/
void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
{
FFT_Iter(ioput, n, root, PrD, false);
}
void FFT_Iter2(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD)
{
FFT_Iter(ioput, n, roots, PrD, false);
}
void FFT_Iter(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD,
bool start_with_one)
{
vector<modp> roots(n + 1);
assignOne(roots[0], PrD);
for (int i = 1; i < n + 1; i++)
Mul(roots[i], roots[i - 1], root, PrD);
FFT_Iter(ioput, n, roots, PrD, start_with_one);
}
void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD, bool start_with_one)
{
assert(roots.size() > size_t(n));
int i, j, m;
modp t;
// Bit-reversal of input
for( i = j = 0; i < n; ++i )
{
if( j >= i )
{
t = ioput[i];
ioput[i] = ioput[j];
ioput[j] = t;
swap(ioput[i], ioput[j]);
}
m = n / 2;
@@ -136,27 +158,38 @@ void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
}
j += m;
}
modp u, alpha, alpha2;
m = 0; j = 0; i = 0;
// Do the transform
vector<modp> alpha2;
alpha2.reserve(n / 2);
for (int s = 1; s < n; s = 2*s)
{
m = 2*s;
Power(alpha, root, n/m, PrD);
alpha2 = alpha;
Mul(alpha, alpha, alpha, PrD);
for (int j = 0; j < m/2; ++j)
alpha2.clear();
if (start_with_one)
{
//root = root_table[(2*j+1)*n/m];
for (int k = j; k < n; k += m)
{
Mul(t, alpha2, ioput[k + m/2], PrD);
u = ioput[k];
Add(ioput[k], u, t, PrD);
Sub(ioput[k + m/2], u, t, PrD);
}
Mul(alpha2, alpha2, alpha, PrD);
for (int j = 0; j < m / 2; j++)
alpha2.push_back(roots[j * n / m]);
}
else
{
for (int j = 0; j < m / 2; j++)
alpha2.push_back(roots.at((j * 2 + 1) * (n / m)));
}
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{
auto& queues = BaseMachine::s().queues;
FftJob job(ioput, alpha2, m, PrD);
int start = queues.distribute(job, n / 2);
for (int i = start; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
queues.wrap_up(job);
}
else
for (int i = 0; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
}
}

View File

@@ -30,8 +30,29 @@ void FFT2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD);
template <class T,class P>
void FFT_Iter(vector<T>& a,int N,const T& theta,const P& PrD);
void FFT_Iter(vector<modp>& a, int N, const modp& theta, const Zp_Data& PrD,
bool start_with_one = true);
void FFT_Iter2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD);
// variants with precomputed roots
void FFT_Iter(vector<modp>& a, int N, const vector<modp>& theta,
const Zp_Data& PrD, bool start_with_one = true);
void FFT_Iter2(vector<modp>& a, int N, const vector<modp>& theta,
const Zp_Data& PrD);
inline void FFT_Iter2_body(vector<modp>& ioput, const vector<modp>& alpha2, int i,
int m, const Zp_Data& PrD)
{
int j = i % (m / 2);
int kk = i / (m / 2);
int k = j + kk * m;
modp t, u;
Mul(t, alpha2[j], ioput[k + m / 2], PrD);
u = ioput[k];
Add(ioput[k], u, t, PrD);
Sub(ioput[k + m / 2], u, t, PrD);
}
/* BFFT perform FFT and inverse FFT mod PrD for non power of two cyclotomics.
* The modulus in PrD (contained in FFT_Data) must be set up

View File

@@ -6,24 +6,6 @@
#include "Math/modp.hpp"
void FFT_Data::assign(const FFT_Data& FFTD)
{
prData=FFTD.prData;
R=FFTD.R;
root=FFTD.root;
twop=FFTD.twop;
two_root=FFTD.two_root;
powers=FFTD.powers;
powers_i=FFTD.powers_i;
b=FFTD.b;
iphi=FFTD.iphi;
}
void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
{
@@ -49,6 +31,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
Inv(root[1],root[0],PrD);
to_modp(iphi,Rg.phi_m(),PrD);
Inv(iphi,iphi,PrD);
compute_roots(Rg.m());
}
}
else
@@ -57,6 +40,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
{ throw invalid_params(); }
root[0]=Find_Primitive_Root_2m(Rg.m(),Rg.Phi(),PrD);
Inv(root[1],root[0],PrD);
compute_roots(2 * Rg.m());
int ptwop=twop; if (twop<0) { ptwop=-twop; }
@@ -97,6 +81,14 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
}
}
void FFT_Data::compute_roots(int n)
{
roots.resize(n + 1);
assignOne(roots[0], prData);
for (int i = 1; i < n + 1; i++)
Mul(roots[i], roots[i - 1], root[0], prData);
}
void FFT_Data::hash(octetStream& o) const
{
@@ -111,6 +103,7 @@ void FFT_Data::pack(octetStream& o) const
R.pack(o);
prData.pack(o);
o.store(root);
o.store(roots);
o.store(twop);
o.store(two_root);
o.store(b);
@@ -125,6 +118,7 @@ void FFT_Data::unpack(octetStream& o)
R.unpack(o);
prData.unpack(o);
o.get(root);
o.get(roots);
o.get(twop);
o.get(two_root);
o.get(b);
@@ -133,7 +127,6 @@ void FFT_Data::unpack(octetStream& o)
o.get(powers_i);
}
bool FFT_Data::operator!=(const FFT_Data& other) const
{
if (R != other.R or prData != other.prData or root != other.root

View File

@@ -19,6 +19,7 @@ class FFT_Data
Zp_Data prData;
vector<modp> root; // 2m'th Root of Unity mod pr and it's inverse
vector<modp> roots; // precomputed powers of root
// When twop is equal to zero, m is a power of two
// When twop is positive it is equal to 2^e where 2^e>2*m and 2^e divides p-1
@@ -34,6 +35,8 @@ class FFT_Data
modp iphi; // 1/phi_m mod pr
vector< vector<modp> > powers,powers_i;
void compute_roots(int n);
public:
typedef gfp T;
typedef bigint S;
@@ -47,17 +50,9 @@ class FFT_Data
void pack(octetStream& o) const;
void unpack(octetStream& o);
void assign(const FFT_Data& FFTD);
FFT_Data() { ; }
FFT_Data(const FFT_Data& FFTD)
{ assign(FFTD); }
FFT_Data(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }
FFT_Data& operator=(const FFT_Data& FFTD)
{ if (this!=&FFTD) { assign(FFTD); }
return *this;
}
const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; }
@@ -72,6 +67,7 @@ class FFT_Data
int get_twop() const { return twop; }
modp get_root(int i) const { return root[i]; }
modp get_iphi() const { return iphi; }
const vector<modp>& get_roots() const { return roots; }
const Ring& get_R() const { return R; }

View File

@@ -42,7 +42,7 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
{
Rq_Element sk = FHE_SK(*this).s();
// Generate the secret key
sk.from_vec((*params).sampleHwt(G));
sk.from(GaussianGenerator<bigint>(params->get_DG(), G));
return sk;
}
@@ -55,7 +55,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
// b0=a0*s+p*e0
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
e0.from_vec((*PK.params).sampleGaussian(G, noise_boost));
e0.from(GaussianGenerator<bigint>(params->get_DG(), G, noise_boost));
mul(PK.b0,PK.a0,sk);
mul(e0,e0,PK.pr);
add(PK.b0,PK.b0,e0);
@@ -72,7 +72,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
// bs=as*s+p*es
Rq_Element es((*PK.params).FFTD(),evaluation,evaluation);
es.from_vec((*PK.params).sampleGaussian(G, noise_boost));
es.from(GaussianGenerator<bigint>(params->get_DG(), G, noise_boost));
mul(PK.Sw_b,PK.Sw_a,sk);
mul(es,es,PK.pr);
add(PK.Sw_b,PK.Sw_b,es);
@@ -120,13 +120,14 @@ void FHE_PK::check_noise(const Rq_Element& x, bool check_modulo) const
}
template<>
template<class T, class FD, class S>
void FHE_PK::encrypt(Ciphertext& c,
const Plaintext<gfp,FFT_Data,bigint>& mess,const Random_Coins& rc) const
const Plaintext<T, FD, S>& mess,const Random_Coins& rc) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); }
if (T::characteristic_two ^ (pr == 2))
throw pr_mismatch();
Rq_Element mm((*params).FFTD(),polynomial,polynomial);
mm.from(mess.get_iterator());
@@ -134,35 +135,6 @@ void FHE_PK::encrypt(Ciphertext& c,
quasi_encrypt(c,mm,rc);
}
template<>
void FHE_PK::encrypt(Ciphertext& c,
const Plaintext<gfp,PPData,bigint>& mess,const Random_Coins& rc) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); }
mess.to_poly();
encrypt(c, mess.get_poly(), rc);
}
template<>
void FHE_PK::encrypt(Ciphertext& c,
const Plaintext<gf2n_short,P2Data,int>& mess,const Random_Coins& rc) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (&rc.get_params()!=params) { throw params_mismatch(); }
if (pr!=2) { throw pr_mismatch(); }
mess.to_poly();
encrypt(c, mess.get_poly(), rc);
}
void FHE_PK::quasi_encrypt(Ciphertext& c,
const Rq_Element& mess,const Random_Coins& rc) const
{
@@ -212,42 +184,12 @@ Ciphertext FHE_PK::encrypt(
}
template<>
void FHE_SK::decrypt(Plaintext<gfp,FFT_Data,bigint>& mess,const Ciphertext& c) const
template<class T, class FD, class S>
void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); }
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
ans.change_rep(polynomial);
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
}
template<>
void FHE_SK::decrypt(Plaintext<gfp,PPData,bigint>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (pr==2) { throw pr_mismatch(); }
Rq_Element ans;
mul(ans,c.c1(),sk);
sub(ans,c.c0(),ans);
mess.set_poly_mod(ans.to_vec_bigint(),ans.get_modulus());
}
template<>
void FHE_SK::decrypt(Plaintext<gf2n_short,P2Data,int>& mess,const Ciphertext& c) const
{
if (&c.get_params()!=params) { throw params_mismatch(); }
if (pr!=2) { throw pr_mismatch(); }
if (T::characteristic_two ^ (pr == 2))
throw pr_mismatch();
Rq_Element ans;

View File

@@ -3,14 +3,6 @@
#include "FHE/Ring_Element.h"
#include "Tools/Exceptions.h"
void FHE_Params::set(const Ring& R,
const vector<bigint>& primes,double r,int hwt)
{
set(R, primes);
Chi.set(R.phi_m(),hwt,r);
}
void FHE_Params::set(const Ring& R,
const vector<bigint>& primes)
{
@@ -20,7 +12,6 @@ void FHE_Params::set(const Ring& R,
for (size_t i = 0; i < FFTData.size(); i++)
FFTData[i].init(R,primes[i]);
Chi.set_n(R.phi_m());
set_sec(40);
}

View File

@@ -21,7 +21,7 @@ class FHE_Params
vector<FFT_Data> FFTData;
// Random generator for Multivariate Gaussian Distribution etc
RandomVectors Chi;
mutable DiscreteGauss Chi;
// Data for distributed decryption
int sec_p;
@@ -29,27 +29,17 @@ class FHE_Params
public:
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(-1, 0.7), sec_p(-1) {}
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
int n_mults() const { return FFTData.size() - 1; }
// Rely on default copy assignment/constructor (not that they should
// ever be needed)
void set(const Ring& R,const vector<bigint>& primes,double r,int hwt);
void set(const Ring& R,const vector<bigint>& primes);
void set(const vector<bigint>& primes);
void set_sec(int sec);
vector<bigint> sampleGaussian(PRNG& G, int noise_boost = 1) const
{ return Chi.sample_Gauss(G, noise_boost); }
vector<bigint> sampleHwt(PRNG& G) const
{ return Chi.sample_Hwt(G); }
vector<bigint> sampleHalf(PRNG& G) const
{ return Chi.sample_Half(G); }
vector<bigint> sampleUniform(PRNG& G,const bigint& Bd) const
{ return Chi.sample_Uniform(G,Bd); }
const vector<FFT_Data>& FFTD() const { return FFTData; }
const bigint& p0() const { return FFTData[0].get_prime(); }
@@ -59,9 +49,8 @@ class FHE_Params
int secp() const { return sec_p; }
const bigint& B() const { return Bval; }
double get_R() const { return Chi.get_R(); }
void set_R(double R) const { return Chi.get_DG().set(R); }
DiscreteGauss get_DG() const { return Chi.get_DG(); }
int get_h() const { return Chi.get_h(); }
void set_R(double R) const { return Chi.set(R); }
DiscreteGauss get_DG() const { return Chi; }
int phi_m() const { return FFTData[0].phi_m(); }
const Ring& get_ring() { return FFTData[0].get_R(); }

View File

@@ -52,10 +52,12 @@ int generate_semi_setup(int plaintext_length, int sec,
bigint p;
generate_prime(p, lgp, m);
int lgp0, lgp1;
FHE_Params tmp_params;
while (true)
{
tmp_params = params;
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
bigint p1 = 2 * p * m, p0 = p;
while (nb.min_p0(params.n_mults() > 0, p1) > p0)
{
@@ -75,6 +77,7 @@ int generate_semi_setup(int plaintext_length, int sec,
}
}
params = tmp_params;
int extra_slack = common_semi_setup(params, m, p, lgp0, lgp1, round_up);
FTD.init(params.get_ring(), p);

View File

@@ -13,29 +13,24 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
const FHE_Params& params) :
p(p), phi_m(phi_m), n(n), sec(sec),
slack(numBits(Proof::slack(slack_param, sec, phi_m))),
sigma(params.get_R()), h(params.get_h())
sigma(params.get_R())
{
if (sigma <= 0)
this->sigma = sigma = FHE_Params().get_R();
#ifdef VERBOSE
cerr << "Standard deviation: " << this->sigma << endl;
#endif
if (h > 0)
h += extra_h * sec;
else if (extra_h)
if (extra_h)
{
sigma *= 1.4;
params.set_R(params.get_R() * 1.4);
}
#ifdef VERBOSE
cerr << "Standard deviation: " << this->sigma << endl;
#endif
produce_epsilon_constants();
// according to documentation of SCALE-MAMBA 1.7
// excluding a factor of n because we don't always add up n ciphertexts
if (h > 0)
V_s = sqrt(h);
else
V_s = sigma * sqrt(phi_m);
V_s = sigma * sqrt(phi_m);
B_clean = (bigint(phi_m) << (sec + 1)) * p
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s);
// unify parameters by taking maximum over TopGear or not

View File

@@ -22,7 +22,6 @@ protected:
const int sec;
int slack;
mpf_class sigma;
int h;
bigint B_clean;
bigint B_scale;

View File

@@ -5,14 +5,6 @@
void PPData::assign(const PPData& PPD)
{
R=PPD.R;
prData=PPD.prData;
root=PPD.root;
}
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
{
R=Rg;

View File

@@ -27,17 +27,9 @@ class PPData
void init(const Ring& Rg,const Zp_Data& PrD);
void assign(const PPData& PPD);
PPData() { ; }
PPData(const PPData& PPD)
{ assign(PPD); }
PPData(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }
PPData& operator=(const PPData& PPD)
{ if (this!=&PPD) { assign(PPD); }
return *this;
}
const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; }

View File

@@ -5,6 +5,7 @@
#include "FHE/P2Data.h"
#include "FHE/Rq_Element.h"
#include "FHE_Keys.h"
#include "FHE/AddableVector.hpp"
#include "Math/Z2k.hpp"
#include "Math/modp.hpp"
@@ -258,37 +259,9 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
}
template<>
void Plaintext_<FFT_Data>::randomize(PRNG& G, bigint B, bool Diag, bool binary, PT_Type t)
{
if (Diag or binary)
throw not_implemented();
if (B == 0)
throw runtime_error("cannot randomize modulo 0");
allocate(t);
switch (t)
{
case Polynomial:
rand_poly(b, G, B, false);
break;
case Evaluation:
for (int i = 0; i < n_slots; i++)
a[i] = G.randomBnd(B);
break;
default:
throw runtime_error("wrong type for randomization with bound");
break;
}
}
template<class T,class FD,class S>
void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, bool binary, PT_Type t)
void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t)
{
if (binary)
throw not_implemented();
allocate(t);
switch(t)
{
@@ -614,10 +587,11 @@ void Plaintext<gf2n_short,P2Data,int>::negate()
template<class T, class FD, class S>
Rq_Element Plaintext<T, FD, S>::mul_by_X_i(int i, const FHE_PK& pk) const
template<class T, class FD, class _>
AddableVector<typename FD::poly_type> Plaintext<T, FD, _>::mul_by_X_i(int i,
const FHE_PK& pk) const
{
return Rq_Element(pk.get_params(), *this).mul_by_X_i(i);
return AddableVector<S>(get_poly()).mul_by_X_i(i, pk);
}

View File

@@ -25,6 +25,7 @@ using namespace std;
class FHE_PK;
class Rq_Element;
template<class T> class AddableVector;
// Forward declaration as apparently this is needed for friends in templates
template<class T,class FD,class S> class Plaintext;
@@ -64,13 +65,6 @@ class Plaintext
const FD& get_field() const { return *Field_Data; }
unsigned int num_slots() const { return n_slots; }
void assign(const Plaintext& p)
{ Field_Data=p.Field_Data;
a=p.a; b=p.b; type=p.type;
n_slots = p.n_slots;
degree = p.degree;
}
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
@@ -142,8 +136,7 @@ class Plaintext
void to_poly() const;
void randomize(PRNG& G,condition cond=Full);
void randomize(PRNG& G, bigint B, bool Diag=false, bool binary=false, PT_Type type=Polynomial);
void randomize(PRNG& G, int n_bits, bool Diag=false, bool binary=false, PT_Type type=Polynomial);
void randomize(PRNG& G, int n_bits, bool Diag=false, PT_Type type=Polynomial);
void assign_zero(PT_Type t = Evaluation);
void assign_one(PT_Type t = Evaluation);
@@ -171,13 +164,12 @@ class Plaintext
void negate();
Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const;
AddableVector<S> mul_by_X_i(int i, const FHE_PK& pk) const;
bool equals(const Plaintext& x) const;
bool operator!=(const Plaintext& x) { return !equals(x); }
bool is_diagonal() const;
bool is_binary() const { throw not_implemented(); }
/* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already

View File

@@ -52,8 +52,6 @@ class Random_Coins
{ params=&p; }
Random_Coins(const FHE_PK& pk);
~Random_Coins() { ; }
// Rely on default copy assignment/constructor

View File

@@ -33,17 +33,36 @@ Ring_Element::Ring_Element(const FFT_Data& fftd,RepType r)
}
void Ring_Element::prepare(const Ring_Element& other)
{
assert(this != &other);
FFTD = other.FFTD;
rep = other.rep;
prepare_push();
}
void Ring_Element::prepare_push()
{
element.clear();
element.reserve(FFTD->phi_m());
}
void Ring_Element::allocate()
{
element.resize(FFTD->phi_m());
}
void Ring_Element::assign_zero()
{
element.resize((*FFTD).phi_m());
for (int i=0; i<(*FFTD).phi_m(); i++)
{ assignZero(element[i],(*FFTD).get_prD()); }
element.clear();
}
void Ring_Element::assign_one()
{
element.resize((*FFTD).phi_m());
allocate();
modp fill;
if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); }
else { assignOne(fill,(*FFTD).get_prD()); }
@@ -56,6 +75,9 @@ void Ring_Element::assign_one()
void Ring_Element::negate()
{
if (element.empty())
return;
for (int i=0; i<(*FFTD).phi_m(); i++)
{ Negate(element[i],element[i],(*FFTD).get_prD()); }
}
@@ -66,20 +88,58 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
ans.partial_assign(a);
if (a.element.empty())
{
ans = b;
return;
}
else if (b.element.empty())
{
ans = a;
return;
}
if (&ans == &a)
{
ans += b;
return;
}
else if (&ans == &b)
{
ans += a;
return;
}
ans.prepare(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Add(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
ans.element.push_back(a.element[i].add(b.element[i], a.FFTD->get_prD()));
}
void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
ans.partial_assign(a);
if (a.element.empty())
{
ans = b;
ans.negate();
return;
}
else if (b.element.empty())
{
ans = a;
return;
}
if (&ans == &a)
{
ans -= b;
return;
}
ans.prepare(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Sub(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
ans.element.push_back(a.element[i].sub(b.element[i], a.FFTD->get_prD()));
}
@@ -88,13 +148,29 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
{
if (a.rep!=b.rep) { throw rep_mismatch(); }
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
ans.partial_assign(a);
if (ans.rep==evaluation)
{ // In evaluation representation, so we can just multiply componentwise
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Mul(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
if (a.element.empty() or b.element.empty())
{
ans = Ring_Element(*a.FFTD, a.rep);
return;
}
else if ((*ans.FFTD).get_twop()!=0)
if (a.rep==evaluation)
{ // In evaluation representation, so we can just multiply componentwise
if (&ans == &a)
{
ans *= b;
return;
}
else if (&ans == &b)
{
ans *= a;
return;
}
ans.prepare(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
ans.element.push_back(a.element[i].mul(b.element[i], a.FFTD->get_prD()));
}
else if ((*a.FFTD).get_twop()!=0)
{ // This is the case where m is not a power of two
// Here we have to do a poly mult followed by a reduction
@@ -116,11 +192,13 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
// Now apply reduction, assumes Ring.poly is monic
reduce(aa, 2*(*a.FFTD).phi_m(), (*a.FFTD).phi_m(), *a.FFTD);
// Now stick into answer
ans.partial_assign(a);
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ ans.element[i]=aa[i]; }
}
else if ((*ans.FFTD).get_twop()==0)
else if ((*a.FFTD).get_twop()==0)
{ // m a power of two case
ans.partial_assign(a);
Ring_Element aa(*ans.FFTD,ans.rep);
modp temp;
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
@@ -143,31 +221,89 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
{
ans.partial_assign(a);
if (&ans == &a)
{
ans *= b;
return;
}
ans.prepare(a);
if (a.element.empty())
return;
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
{ Mul(ans.element[i],a.element[i],b,(*a.FFTD).get_prD()); }
ans.element.push_back(a.element[i].mul(b, a.FFTD->get_prD()));
}
Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD == other.FFTD);
assert(rep == other.rep);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].add(other.element[i], FFTD->get_prD());
return *this;
}
Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD == other.FFTD);
assert(rep == other.rep);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].sub(other.element[i], FFTD->get_prD());
return *this;
}
Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
{
assert(element.size() == other.element.size());
assert(FFTD == other.FFTD);
assert(rep == other.rep);
assert(rep == evaluation);
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].mul(other.element[i], FFTD->get_prD());
return *this;
}
Ring_Element& Ring_Element::operator *=(const modp& other)
{
for (size_t i = 0; i < element.size(); i++)
element[i] = element[i].mul(other, FFTD->get_prD());
return *this;
}
Ring_Element Ring_Element::mul_by_X_i(int j) const
{
Ring_Element ans;
ans.prepare(*this);
if (element.empty())
return ans;
auto& a = *this;
ans.partial_assign(a);
if (ans.rep == evaluation)
{
modp xj, xj2;
Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD());
Sqr(xj2, xj, (*a.FFTD).get_prD());
ans.prepare_push();
modp tmp;
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
{
Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD());
Mul(tmp, a.element[i], xj, (*a.FFTD).get_prD());
ans.element.push_back(tmp);
Mul(xj, xj, xj2, (*a.FFTD).get_prD());
}
}
else
{
Ring_Element aa(*ans.FFTD, ans.rep);
aa.allocate();
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
{
int k= j + i, s= 1;
@@ -193,6 +329,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const
void Ring_Element::randomize(PRNG& G,bool Diag)
{
allocate();
if (Diag==false)
{ for (int i=0; i<(*FFTD).phi_m(); i++)
{ element[i].randomize(G,(*FFTD).get_prD()); }
@@ -213,12 +350,18 @@ void Ring_Element::randomize(PRNG& G,bool Diag)
void Ring_Element::change_rep(RepType r)
{
if (element.empty())
{
rep = r;
return;
}
if (rep==r) { return; }
if (r==evaluation)
{ rep=evaluation;
if ((*FFTD).get_twop()==0)
{ // m a power of two variant
FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_root(0),(*FFTD).get_prD());
FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_roots(),(*FFTD).get_prD());
}
else
{ // Non m power of two variant and FFT enabled
@@ -258,6 +401,11 @@ void Ring_Element::change_rep(RepType r)
bool Ring_Element::equals(const Ring_Element& a) const
{
if (element.empty() and a.element.empty())
return true;
else if (element.empty() or a.element.empty())
throw not_implemented();
if (rep!=a.rep) { throw rep_mismatch(); }
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
for (int i=0; i<(*FFTD).phi_m(); i++)
@@ -266,34 +414,11 @@ bool Ring_Element::equals(const Ring_Element& a) const
}
void Ring_Element::from_vec(const vector<bigint>& v)
{
RepType t=rep;
rep=polynomial;
bigint tmp;
for (int i=0; i<(*FFTD).phi_m(); i++)
{
tmp = v[i];
element[i].convert_destroy(tmp, FFTD->get_prD());
}
change_rep(t);
// cout << "RE:from_vec<bigint>:: " << *this << endl;
}
void Ring_Element::from_vec(const vector<int>& v)
{
RepType t=rep;
rep=polynomial;
for (int i=0; i<(*FFTD).phi_m(); i++)
{ to_modp(element[i],v[i],(*FFTD).get_prD()); }
change_rep(t);
// cout << "RE:from_vec<int>:: " << *this << endl;
}
ConversionIterator Ring_Element::get_iterator() const
{
if (rep != polynomial)
throw runtime_error("simple iterator only available in polynomial represention");
assert(not element.empty());
return {element, (*FFTD).get_prD()};
}
@@ -318,6 +443,9 @@ vector<bigint> Ring_Element::to_vec_bigint() const
void Ring_Element::to_vec_bigint(vector<bigint>& v) const
{
v.resize(FFTD->phi_m());
if (element.empty())
return;
if (rep==polynomial)
{ for (int i=0; i<(*FFTD).phi_m(); i++)
{ to_bigint(v[i],element[i],(*FFTD).get_prD()); }
@@ -336,11 +464,10 @@ void Ring_Element::to_vec_bigint(vector<bigint>& v) const
modp Ring_Element::get_constant() const
{
if (rep==polynomial)
{ return element[0]; }
Ring_Element a=*this;
a.change_rep(polynomial);
return a.element[0];
if (element.empty())
return {};
else
return element[0];
}
@@ -364,9 +491,14 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
+ to_string(ZpD.pr_bit_length));
unsigned int length;
o.get(length);
v.resize(length);
v.clear();
v.reserve(length);
modp tmp;
for (unsigned int i=0; i<length; i++)
{ v[i].unpack(o,ZpD); }
{
tmp.unpack(o,ZpD);
v.push_back(tmp);
}
}
@@ -398,7 +530,7 @@ void Ring_Element::check_rep()
void Ring_Element::check_size() const
{
if ((int)element.size() != FFTD->phi_m())
if (not element.empty() and (int)element.size() != FFTD->phi_m())
throw runtime_error("invalid element size");
}

View File

@@ -41,12 +41,6 @@ class Ring_Element
vector<modp> element;
// Define a copy
void assign(const Ring_Element& e)
{ rep=e.rep; FFTD=e.FFTD;
element=e.element;
}
public:
// Used to basically make sure *this is able to cope
@@ -57,6 +51,10 @@ class Ring_Element
element.resize((*FFTD).phi_m());
}
void prepare(const Ring_Element& e);
void prepare_push();
void allocate();
void set_data(const FFT_Data& prd) { FFTD=&prd; }
const FFT_Data& get_FFTD() const { return *FFTD; }
const Zp_Data& get_prD() const { return (*FFTD).get_prD(); }
@@ -80,19 +78,6 @@ class Ring_Element
element.push_back(x);
}
// Copy Constructor
Ring_Element(const Ring_Element& e)
{ assign(e); }
// Destructor
~Ring_Element() { ; }
// Copy Assignment
Ring_Element& operator=(const Ring_Element& e)
{ if (this!=&e) { assign(e); }
return *this;
}
/* Functional Operators */
void negate();
friend void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b);
@@ -102,6 +87,11 @@ class Ring_Element
Ring_Element mul_by_X_i(int i) const;
Ring_Element& operator+=(const Ring_Element& other);
Ring_Element& operator-=(const Ring_Element& other);
Ring_Element& operator*=(const Ring_Element& other);
Ring_Element& operator*=(const modp& other);
void randomize(PRNG& G,bool Diag=false);
bool equals(const Ring_Element& a) const;
@@ -112,8 +102,6 @@ class Ring_Element
// Converting to and from a vector of bigint/int's
// I/O is assumed to be in poly rep, so from_vec it internally alters
// the representation to the current representation
void from_vec(const vector<bigint>& v);
void from_vec(const vector<int>& v);
vector<bigint> to_vec_bigint() const;
void to_vec_bigint(vector<bigint>& v) const;
@@ -136,8 +124,18 @@ class Ring_Element
// This gets the constant term of the poly rep as a modp element
modp get_constant() const;
modp get_element(int i) const { return element[i]; }
void set_element(int i,const modp& a) { element[i]=a; }
modp get_element(int i) const
{
if (element.empty())
return {};
else
return element[i];
}
void set_element(int i,const modp& a)
{
allocate();
element[i] = a;
}
/* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already
@@ -164,7 +162,11 @@ class RingWriteIterator : public WriteConversionIterator
public:
RingWriteIterator(Ring_Element& element) :
WriteConversionIterator(element.element, element.FFTD->get_prD()),
element(element), rep(element.rep) { element.rep = polynomial; }
element(element), rep(element.rep)
{
element.rep = polynomial;
element.allocate();
}
~RingWriteIterator() { element.change_rep(rep); }
};
@@ -175,7 +177,11 @@ class RingReadIterator : public ConversionIterator
public:
RingReadIterator(const Ring_Element& element) :
ConversionIterator(this->element.element, element.FFTD->get_prD()),
element(element) { this->element.change_rep(polynomial); }
element(element)
{
this->element.change_rep(polynomial);
this->element.allocate();
}
};
@@ -189,10 +195,13 @@ void Ring_Element::from(const Generator<T>& generator)
RepType t=rep;
rep=polynomial;
T tmp;
modp tmp2;
prepare_push();
for (int i=0; i<(*FFTD).phi_m(); i++)
{
generator.get(tmp);
element[i].convert_destroy(tmp, (*FFTD).get_prD());
tmp2.convert_destroy(tmp, (*FFTD).get_prD());
element.push_back(tmp2);
}
change_rep(t);
}

View File

@@ -48,15 +48,6 @@ void Rq_Element::partial_assign(const Rq_Element& other)
{
lev=other.lev;
a.resize(other.a.size());
for (size_t i = 0; i < a.size(); i++)
a[i].partial_assign(other.a[i]);
}
void Rq_Element::assign(const Rq_Element& other)
{
partial_assign(other);
for (int i=0; i<=lev; ++i)
a[i] = other.a[i];
}
void Rq_Element::negate()
@@ -134,20 +125,6 @@ bool Rq_Element::equals(const Rq_Element& other) const
}
void Rq_Element::from_vec(const vector<bigint>& v,int level)
{
set_level(level);
for (int i=0;i<=lev;++i)
a[i].from_vec(v);
}
void Rq_Element::from_vec(const vector<int>& v,int level)
{
set_level(level);
for (int i=0;i<=lev;++i)
a[i].from_vec(v);
}
vector<bigint> Rq_Element::to_vec_bigint() const
{
vector<bigint> v;

View File

@@ -44,7 +44,6 @@ protected:
void assign_zero(const vector<FFT_Data>& prd);
void assign_zero();
void assign_one();
void assign(const Rq_Element& e);
void partial_assign(const Rq_Element& e);
// Must be careful not to call by mistake
@@ -85,10 +84,6 @@ protected:
a[1] = Ring_Element(prd[1], r, b1);
}
// Destructor
~Rq_Element()
{ ; }
const Ring_Element& get(int i) const { return a[i]; }
/* Functional Operators */
@@ -131,8 +126,6 @@ protected:
void partial_assign(const Rq_Element& a, const Rq_Element& b);
// Converting to and from a vector of bigint's Again I/O is in poly rep
void from_vec(const vector<bigint>& v,int level=-1);
void from_vec(const vector<int>& v,int level=-1);
vector<bigint> to_vec_bigint() const;
void to_vec_bigint(vector<bigint>& v) const;

View File

@@ -49,7 +49,8 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
if (not error.empty())
{
cerr << "Running secrets generation because " << error << endl;
cerr << "Running secrets generation because no suitable material "
"from a previous run was found (" << error << ")" << endl;
setup.key_and_mac_generation(P, machine, num_runs, V());
ofstream output(filename);

View File

@@ -109,11 +109,11 @@ DistKeyGen::DistKeyGen(const FHE_Params& params, const bigint& p) :
*/
void DistKeyGen::Gen_Random_Data(PRNG& G)
{
secret.from_vec(params.sampleHwt(G));
secret.from(GaussianGenerator<bigint>(params.get_DG(), G));
rc1.generate(G);
rc2.generate(G);
a.randomize(G);
e.from_vec(params.sampleGaussian(G));
e.from(GaussianGenerator<bigint>(params.get_DG(), G));
}
DistKeyGen& DistKeyGen::operator+=(const DistKeyGen& other)

View File

@@ -45,7 +45,7 @@ template <class FD>
void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role)
{
octetStream o;
o.reset_write_head();
if (role & SENDER)
{

View File

@@ -36,6 +36,8 @@ class Multiplier
size_t volatile_capacity;
MemoryUsage memory_usage;
octetStream o;
public:
Multiplier(int offset, PairwiseGenerator<FD>& generator);
Multiplier(int offset, PairwiseMachine& machine, Player& P,

View File

@@ -1,32 +0,0 @@
/*
* Player-Offline.h
*
*/
#ifndef FHEOFFLINE_PLAYER_OFFLINE_H_
#define FHEOFFLINE_PLAYER_OFFLINE_H_
class thread_info
{
public:
int thread_num;
int covert;
Names* Nms;
FHE_PK* pk_p;
FHE_PK* pk_2;
FHE_SK* sk_p;
FHE_SK* sk_2;
Ciphertext *calphap;
Ciphertext *calpha2;
gfp *alphapi;
gf2n_short *alpha2i;
FFT_Data *FTD;
P2Data *P2D;
int nm2,nmp,nb2,nbp,ni2,nip,ns2,nsp,nvp;
bool skip_2() { return nm2 + ni2 + nb2 + ns2 == 0; }
};
#endif /* FHEOFFLINE_PLAYER_OFFLINE_H_ */

View File

@@ -589,7 +589,7 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
P.receive_player(j, cleartexts);
C.resize(personal_EC.machine->sec, pk.get_params());
Verifier<FD>(personal_EC.proof, FieldD).NIZKPoK(C, ciphertexts,
cleartexts, pk, false);
cleartexts, pk);
}
inputs[j].clear();

View File

@@ -88,6 +88,7 @@ public:
bool Proof::check_bounds(T& z, X& t, int i) const
{
(void)i;
unsigned int j,k;
// Check Bound 1 and Bound 2
@@ -99,9 +100,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const
auto& te = z[j];
if (plain_checker.outside(te, dist))
{
#ifdef VERBOSE
cout << "Fail on Check 1 " << i << " " << j << endl;
cout << te << " " << plain_check << endl;
cout << tau << " " << sec << " " << n_proofs << endl;
#endif
return false;
}
}
@@ -113,9 +116,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const
auto& te = coeffs.at(j);
if (rand_checker.outside(te, dist))
{
#ifdef VERBOSE
cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl;
cout << te << " " << rand_check << endl;
cout << rho << " " << sec << " " << n_proofs << endl;
#endif
return false;
}
}

View File

@@ -6,6 +6,7 @@
#include "Tools/random.h"
#include "Math/Z2k.hpp"
#include "Math/modp.hpp"
#include "FHE/AddableVector.hpp"
template <class FD, class U>
@@ -28,7 +29,7 @@ Prover<FD,U>::Prover(Proof& proof, const FD& FieldD) :
template <class FD, class U>
void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts,
const AddableVector<Ciphertext>& c,
const FHE_PK& pk, bool binary)
const FHE_PK& pk)
{
size_t allocate = 3 * c.size() * c[0].report_size(USED);
ciphertexts.resize_precise(allocate);
@@ -51,7 +52,7 @@ void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts,
// AE.randomize(Diag,binary);
// rd=RandPoly(phim,bd<<1);
// y[i]=AE.plaintext()+pr*rd;
y[i].randomize(G, P.B_plain_length, P.get_diagonal(), binary);
y[i].randomize(G, P.B_plain_length, P.get_diagonal());
if (P.get_diagonal())
assert(y[i].is_diagonal());
s[i].resize(3, P.phim);
@@ -114,8 +115,7 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
const FHE_PK& pk,
const AddableVector<Ciphertext>& c,
const vector<U>& x,
const Proof::Randomness& r,
bool binary)
const Proof::Randomness& r)
{
// AElement<T> AE;
// for (i=0; i<P.sec; i++)
@@ -130,13 +130,15 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
int cnt=0;
while (!ok)
{ cnt++;
Stage_1(P,ciphertexts,c,pk,binary);
Stage_1(P,ciphertexts,c,pk);
P.set_challenge(ciphertexts);
// Check check whether we are OK, or whether we should abort
ok = Stage_2(P,cleartexts,x,r,pk);
}
#ifdef VERBOSE
if (cnt > 1)
cout << "\t\tNumber iterations of prover = " << cnt << endl;
#endif
return report_size(CAPACITY) + volatile_memory;
}

View File

@@ -24,8 +24,7 @@ public:
Prover(Proof& proof, const FD& FieldD);
void Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector<Ciphertext>& c,
const FHE_PK& pk,
bool binary = false);
const FHE_PK& pk);
bool Stage_2(Proof& P, octetStream& cleartexts,
const vector<U>& x,
@@ -40,8 +39,7 @@ public:
const FHE_PK& pk,
const AddableVector<Ciphertext>& c,
const vector<U>& x,
const Proof::Randomness& r,
bool binary=false);
const Proof::Randomness& r);
size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res);

View File

@@ -11,6 +11,7 @@
#include "Protocols/MAC_Check.h"
#include "Protocols/MAC_Check.hpp"
#include "Math/modp.hpp"
template<class T, class FD, class S>
SimpleEncCommitBase<T, FD, S>::SimpleEncCommitBase(const MachineBase& machine) :
@@ -63,7 +64,10 @@ SimpleEncCommitFactory<FD>::SimpleEncCommitFactory(const FHE_PK& pk,
template <class FD>
SimpleEncCommitFactory<FD>::~SimpleEncCommitFactory()
{
cout << "EncCommit called " << n_calls << " times" << endl;
#ifdef VERBOSE_HE
if (n_calls > 0)
cout << "EncCommit called " << n_calls << " times" << endl;
#endif
}
template<class FD>
@@ -131,7 +135,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
Prover<FD, Plaintext_<FD> > prover(proof, FTD);
#endif
size_t prover_memory = prover.NIZKPoK(proof, ciphertexts, cleartexts,
pk, c, m, r, false);
pk, c, m, r);
timers["Proving"].stop();
if (proof.top_gear)
@@ -192,7 +196,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
#endif
timers["Verifying"].start();
verifier.NIZKPoK(others_ciphertexts, ciphertexts,
cleartexts, get_pk_for_verification(i), false);
cleartexts, get_pk_for_verification(i));
timers["Verifying"].stop();
add_ciphertexts(others_ciphertexts, i);
this->memory_usage.update("verifier", verifier.report_size(CAPACITY));
@@ -251,7 +255,7 @@ void SummingEncCommit<FD>::create_more()
#endif
this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof);
this->timers["Stage 1 of proof"].start();
prover.Stage_1(proof, ciphertexts, this->c, this->pk, false);
prover.Stage_1(proof, ciphertexts, this->c, this->pk);
this->timers["Stage 1 of proof"].stop();
this->c.unpack(ciphertexts, this->pk);
@@ -291,8 +295,10 @@ void SummingEncCommit<FD>::create_more()
for (int i = 1; i < P.num_players(); i++)
{
#ifdef VERBOSE_HE
cout << "Sending cleartexts with " << 1e-9 * cleartexts.get_length()
<< " GB in round " << i << endl;
#endif
TimeScope(this->timers["Exchanging cleartexts"]);
P.pass_around(cleartexts);
preimages.add(cleartexts);
@@ -312,7 +318,7 @@ void SummingEncCommit<FD>::create_more()
Verifier<FD> verifier(proof);
#endif
verifier.Stage_2(this->c, ciphertexts, cleartexts,
this->pk, false);
this->pk);
this->timers["Verifying"].stop();
this->cnt = proof.U - 1;

View File

@@ -25,7 +25,10 @@ bool Check_Decoding(const Plaintext<T,FD,S>& AE,bool Diag)
// return false;
// }
if (Diag && !AE.is_diagonal())
{ cout << "Fail Check 5 " << endl;
{
#ifdef VERBOSE
cout << "Fail Check 5 " << endl;
#endif
return false;
}
return true;
@@ -62,7 +65,7 @@ template <class FD>
void Verifier<FD>::Stage_2(
AddableVector<Ciphertext>& c,octetStream& ciphertexts,
octetStream& cleartexts,
const FHE_PK& pk,bool binary)
const FHE_PK& pk)
{
unsigned int i, V;
@@ -90,18 +93,19 @@ void Verifier<FD>::Stage_2(
rc.assign(t[0], t[1], t[2]);
pk.encrypt(d2,z,rc);
if (!(d1 == d2))
{ cout << "Fail Check 6 " << i << endl;
{
#ifdef VERBOSE
cout << "Fail Check 6 " << i << endl;
#endif
throw runtime_error("ciphertexts don't match");
}
if (!Check_Decoding(z,P.get_diagonal(),FieldD))
{ cout << "\tCheck : " << i << endl;
{
#ifdef VERBOSE
cout << "\tCheck : " << i << endl;
#endif
throw runtime_error("cleartext isn't diagonal");
}
if (binary && !z.is_binary())
{
cout << "Not binary " << i << endl;
throw runtime_error("cleartext isn't binary");
}
}
}
@@ -112,17 +116,15 @@ void Verifier<FD>::Stage_2(
template <class FD>
void Verifier<FD>::NIZKPoK(AddableVector<Ciphertext>& c,
octetStream& ciphertexts, octetStream& cleartexts,
const FHE_PK& pk,
bool binary)
const FHE_PK& pk)
{
P.set_challenge(ciphertexts);
Stage_2(c,ciphertexts,cleartexts,pk,binary);
Stage_2(c,ciphertexts,cleartexts,pk);
if (P.top_gear)
{
assert(not P.get_diagonal());
assert(not binary);
c += c;
}
}

View File

@@ -21,14 +21,14 @@ public:
void Stage_2(
AddableVector<Ciphertext>& c, octetStream& ciphertexts,
octetStream& cleartexts,const FHE_PK& pk,bool binary=false);
octetStream& cleartexts,const FHE_PK& pk);
/* This is the non-interactive version using the ROM
- Creates space for all output values
- Diag flag mirrors that in Prover
*/
void NIZKPoK(AddableVector<Ciphertext>& c,octetStream& ciphertexts,octetStream& cleartexts,
const FHE_PK& pk,bool binary=false);
const FHE_PK& pk);
size_t report_size(ReportType type) { return z.report_size(type) + t.report_size(type); }
};

View File

@@ -19,13 +19,13 @@ template<class T>
class CcdPrep : public BufferPrep<T>
{
typename T::part_type::LivePrep part_prep;
typename T::part_type::MAC_Check part_MC;
SubProcessor<typename T::part_type>* part_proc;
ShareThread<T>& thread;
public:
CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
BufferPrep<T>(usage), part_prep(usage), part_proc(0), thread(thread)
BufferPrep<T>(usage), part_prep(usage, thread), part_proc(0),
thread(thread)
{
}
@@ -34,17 +34,9 @@ public:
{
}
~CcdPrep()
{
if (part_proc)
delete part_proc;
}
~CcdPrep();
void set_protocol(typename T::Protocol& protocol)
{
part_proc = new SubProcessor<typename T::part_type>(part_MC,
part_prep, protocol.get_part().P);
}
void set_protocol(typename T::Protocol& protocol);
Preprocessing<typename T::part_type>& get_part()
{
@@ -53,7 +45,16 @@ public:
void buffer_triples()
{
throw not_implemented();
assert(part_proc);
this->triples.push_back({});
for (auto& x : this->triples.back())
x.resize_regs(T::default_length);
for (int i = 0; i < T::default_length; i++)
{
auto triple = part_prep.get_triple(1);
for (int j = 0; j < 3; j++)
this->triples.back()[j].get_bit(j) = triple[j];
}
}
void buffer_bits()
@@ -72,6 +73,25 @@ public:
{
throw not_implemented();
}
void buffer_inputs(int player)
{
this->inputs[player].push_back({});
this->inputs[player].back().share.resize_regs(T::default_length);
for (int i = 0; i < T::default_length; i++)
{
typename T::part_type::open_type tmp;
part_prep.get_input(this->inputs[player].back().share.get_reg(i),
tmp, player);
this->inputs[player].back().value ^=
(typename T::clear(tmp.get_bit(0)) << i);
}
}
size_t data_sent()
{
return part_prep.data_sent();
}
};
} /* namespace GC */

33
GC/CcdPrep.hpp Normal file
View File

@@ -0,0 +1,33 @@
/*
* CcdPrep.hpp
*
*/
#ifndef GC_CCDPREP_HPP_
#define GC_CCDPREP_HPP_
#include "CcdPrep.h"
#include "Processor/Processor.hpp"
namespace GC
{
template<class T>
CcdPrep<T>::~CcdPrep()
{
if (part_proc)
delete part_proc;
}
template<class T>
void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
{
assert(thread.MC);
part_proc = new SubProcessor<typename T::part_type>(
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);
}
}
#endif /* GC_CCDPREP_HPP_ */

View File

@@ -34,11 +34,6 @@ public:
static const int default_length = 1;
static DataFieldType field_type()
{
return DATA_GF2;
}
static string name()
{
return "CCD";

View File

@@ -70,8 +70,6 @@ public:
static const true_type invertible;
static const true_type characteristic_two;
static DataFieldType field_type() { return DATA_GF2; }
static MC* new_mc(mac_key_type key) { return new MC(key); }
static void store_clear_in_dynamic(Memory<DynamicType>& mem,

View File

@@ -39,11 +39,6 @@ public:
static const int default_length = 1;
static DataFieldType field_type()
{
return DATA_GF2;
}
static string name()
{
return "Malicious CCD";

View File

@@ -54,6 +54,11 @@ public:
return "no";
}
static DataFieldType field_type()
{
throw not_implemented();
}
static void fail()
{
throw runtime_error("VM does not support binary circuits");
@@ -101,16 +106,10 @@ public:
typedef NoValue clear;
typedef NoValue mac_key_type;
typedef NoShare bit_type;
typedef NoShare part_type;
typedef NoShare small_type;
typedef BlackHole out_type;
static const int default_length = 1;
static const bool needs_ot = false;
static const bool expensive_triples = false;
static const bool is_real = false;
static MC* new_mc(mac_key_type)
@@ -118,21 +117,6 @@ public:
return new MC;
}
template<class T>
static void generate_mac_key(mac_key_type, T)
{
}
static DataFieldType field_type()
{
throw not_implemented();
}
static string type_short()
{
return "";
}
static string type_string()
{
return "no";
@@ -155,7 +139,6 @@ public:
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); }
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
@@ -166,11 +149,8 @@ public:
void load_clear(Integer, Integer) { fail(); }
void random_bit() { fail(); }
void and_(int, NoShare&, NoShare&, bool) { fail(); }
void xor_(int, NoShare&, NoShare&) { fail(); }
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
void reveal(Integer, Integer) { fail(); }
void assign(const char*) { fail(); }
@@ -183,13 +163,11 @@ public:
NoShare operator-(const NoShare&) const { fail(); return {}; }
NoShare operator*(const NoValue&) const { fail(); return {}; }
NoShare operator+(int) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; }
NoShare operator>>(int) const { fail(); return {}; }
NoShare& operator+=(const NoShare&) { fail(); return *this; }
NoShare lsb() const { fail(); return {}; }
NoShare get_bit(int) const { fail(); return {}; }
void invert(int, NoShare) { fail(); }

View File

@@ -88,7 +88,7 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
input.reset_all(P);
for (size_t i = begin; i < end; i++)
{
typename T::clear x[2];
typename T::open_type x[2];
for (int j = 0; j < 2; j++)
this->get_input(triples[i][j], x[j], input_player);
if (P.my_num() == input_player)

View File

@@ -84,6 +84,7 @@ public:
void xors(const vector<int>& args);
void xors(const vector<int>& args, size_t start, size_t end);
void xorc(const ::BaseInstruction& instruction);
void nots(const ::BaseInstruction& instruction);
void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat);

View File

@@ -18,6 +18,7 @@ using namespace std;
#include "GC/Machine.hpp"
#include "Processor/ProcessorBase.hpp"
#include "Processor/IntInput.hpp"
#include "Math/bigint.hpp"
namespace GC
@@ -82,8 +83,12 @@ U GC::Processor<T>::get_long_input(const int* params,
{
if (not T::actual_inputs)
return {};
U res = input_proc.get_input<FixInput_<U>>(interactive,
&params[1]).items[0];
U res;
if (params[1] == 0)
res = input_proc.get_input<IntInput<U>>(interactive, 0).items[0];
else
res = input_proc.get_input<FixInput_<U>>(interactive,
&params[1]).items[0];
int n_bits = *params;
check_input(res, n_bits);
return res;
@@ -229,6 +234,18 @@ void Processor<T>::xors(const vector<int>& args, size_t start, size_t end)
}
}
template<class T>
void Processor<T>::xorc(const ::BaseInstruction& instruction)
{
int total = instruction.get_n();
for (int i = 0; i < DIV_CEIL(total, T::default_length); i++)
{
int n = min(T::default_length, total - i * T::default_length);
C[instruction.get_r(0) + i] = BitVec(C[instruction.get_r(1) + i]).mask(n)
^ BitVec(C[instruction.get_r(2) + i]).mask(n);
}
}
template<class T>
void Processor<T>::nots(const ::BaseInstruction& instruction)
{

View File

@@ -7,7 +7,6 @@
#define GC_REP4SECRET_H_
#include "ShareSecret.h"
#include "Processor/NoLivePrep.h"
#include "Protocols/Rep4MC.h"
#include "Protocols/Rep4Share.h"

View File

@@ -1,11 +0,0 @@
/*
* ReplicatedPrep.cpp
*
*/
#include <GC/SemiHonestRepPrep.h>
namespace GC
{
} /* namespace GC */

View File

@@ -119,7 +119,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
try
{
read_mac_key(get_prep_sub_dir<T>(PREP_DIR, network_opts.nplayers), this->N,
read_mac_key(
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
this->N,
this->mac_key);
}
catch (exception& e)

View File

@@ -38,6 +38,8 @@ template<class U>
class ShareSecret
{
public:
typedef U whole_type;
typedef Memory<U> DynamicMemory;
typedef SwitchableOutput out_type;

View File

@@ -21,6 +21,7 @@
#include "ShareParty.h"
#include "ShareThread.hpp"
#include "Thread.hpp"
#include "VectorProtocol.hpp"
namespace GC
{

View File

@@ -29,7 +29,8 @@ ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, DataPositions&
*static_cast<Preprocessing<T>*>(new typename T::LivePrep(
usage, *this)) :
*static_cast<Preprocessing<T>*>(new BitPrepFiles<T>(N,
get_prep_sub_dir<T>(PREP_DIR, N.num_players()), usage)))
get_prep_sub_dir<T>(PREP_DIR, N.num_players()),
usage, BaseMachine::thread_num)))
{
}

View File

@@ -1,37 +0,0 @@
/*
* TinierPrep.h
*
*/
#ifndef GC_TINIERPREP_H_
#define GC_TINIERPREP_H_
#include "TinyPrep.h"
namespace GC
{
template<class T>
class TinierPrep : public TinyPrep<T>
{
public:
TinierPrep(DataPositions& usage, ShareThread<T>& thread,
bool amplify = true) :
TinyPrep<T>(usage, thread, amplify)
{
}
TinierPrep(SubProcessor<T>*, DataPositions& usage) :
TinierPrep(usage, ShareThread<T>::s())
{
}
void buffer_inputs(int player)
{
this->buffer_inputs_(player, this->triple_generator);
}
};
}
#endif /* GC_TINIERPREP_H_ */

View File

@@ -15,6 +15,9 @@ namespace GC
{
template<class T> class TinierPrep;
template<class T> class VectorProtocol;
template<class T> class CcdPrep;
template<class T> class VectorInput;
template<class T>
class TinierSecret : public VectorSecret<TinierShare<T>>
@@ -25,9 +28,9 @@ class TinierSecret : public VectorSecret<TinierShare<T>>
public:
typedef TinyMC<This> MC;
typedef MC MAC_Check;
typedef Beaver<This> Protocol;
typedef ::Input<This> Input;
typedef TinierPrep<This> LivePrep;
typedef VectorProtocol<This> Protocol;
typedef VectorInput<This> Input;
typedef CcdPrep<This> LivePrep;
typedef Memory<This> DynamicMemory;
typedef NPartyTripleGenerator<This> TripleGenerator;

View File

@@ -9,12 +9,12 @@
#include "Processor/DummyProtocol.h"
#include "Protocols/Share.h"
#include "Math/Bit.h"
#include "TinierSharePrep.h"
namespace GC
{
template<class T> class TinierSecret;
template<class T> class TinierSharePrep;
template<class T>
class TinierShare: public Share_<SemiShare<Bit>, SemiShare<T>>,
@@ -55,6 +55,11 @@ public:
return "Tinier";
}
static string type_short()
{
return "TT";
}
static ShareThread<TinierSecret<T>>& get_party()
{
return ShareThread<TinierSecret<T>>::s();
@@ -103,9 +108,7 @@ public:
void random()
{
TinierSecret<T> tmp;
get_party().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
*this = get_party().DataF.get_part().get_bit();
}
This lsb() const

View File

@@ -21,18 +21,26 @@ template<class T>
class TinierSharePrep : public PersonalPrep<T>
{
typename T::TripleGenerator* triple_generator;
typename T::whole_type::TripleGenerator* real_triple_generator;
MascotParams params;
TinierPrep<TinierSecret<typename T::mac_key_type>> whole_prep;
typedef typename T::whole_type secret_type;
ShareThread<secret_type>& thread;
void buffer_triples();
void buffer_squares() { throw not_implemented(); }
void buffer_bits() { throw not_implemented(); }
void buffer_bits();
void buffer_inverses() { throw not_implemented(); }
void buffer_inputs(int player);
void buffer_secret_triples();
void init_real(Player& P);
public:
TinierSharePrep(DataPositions& usage, ShareThread<secret_type>& thread,
int input_player = PersonalPrep<T>::SECURE);
TinierSharePrep(DataPositions& usage, int input_player =
PersonalPrep<T>::SECURE);
TinierSharePrep(SubProcessor<T>*, DataPositions& usage);

View File

@@ -15,10 +15,16 @@ namespace GC
template<class T>
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage, int input_player) :
TinierSharePrep<T>(usage, ShareThread<secret_type>::s(), input_player)
{
}
template<class T>
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage,
ShareThread<secret_type>& thread, int input_player) :
PersonalPrep<T>(usage, input_player), triple_generator(0),
whole_prep(usage,
ShareThread<TinierSecret<typename T::mac_key_type>>::s(),
input_player == PersonalPrep<T>::SECURE)
real_triple_generator(0),
thread(thread)
{
}
@@ -33,6 +39,8 @@ TinierSharePrep<T>::~TinierSharePrep()
{
if (triple_generator)
delete triple_generator;
if (real_triple_generator)
delete real_triple_generator;
}
template<class T>
@@ -44,15 +52,14 @@ void TinierSharePrep<T>::set_protocol(typename T::Protocol& protocol)
params.generateMACs = true;
params.amplify = false;
params.check = false;
auto& thread = ShareThread<TinierSecret<typename T::mac_key_type>>::s();
auto& thread = ShareThread<typename T::whole_type>::s();
triple_generator = new typename T::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
OnlineOptions::singleton.batch_size
* TinierSecret<typename T::mac_key_type>::default_length, 1,
OnlineOptions::singleton.batch_size, 1,
params, thread.MC->get_alphai(), &protocol.P);
triple_generator->multi_threaded = false;
this->inputs.resize(thread.P->num_players());
whole_prep.init(*thread.P);
init_real(protocol.P);
}
template<class T>
@@ -63,12 +70,8 @@ void TinierSharePrep<T>::buffer_triples()
this->buffer_personal_triples();
return;
}
array<TinierSecret<typename T::mac_key_type>, 3> whole;
whole_prep.get(DATA_TRIPLE, whole.data());
for (size_t i = 0; i < whole[0].get_regs().size(); i++)
this->triples.push_back(
{{ whole[0].get_reg(i), whole[1].get_reg(i), whole[2].get_reg(i) }});
else
buffer_secret_triples();
}
template<class T>
@@ -81,12 +84,21 @@ void TinierSharePrep<T>::buffer_inputs(int player)
inputs.at(player).push_back(x);
}
template<class T>
void GC::TinierSharePrep<T>::buffer_bits()
{
this->bits.push_back(
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
}
template<class T>
size_t TinierSharePrep<T>::data_sent()
{
size_t res = whole_prep.data_sent();
size_t res = 0;
if (triple_generator)
res += triple_generator->data_sent();
if (real_triple_generator)
res += real_triple_generator->data_sent();
return res;
}

View File

@@ -1,11 +0,0 @@
/*
* TinyMC.cpp
*
*/
#include "TinyMC.h"
namespace GC
{
} /* namespace GC */

View File

@@ -14,7 +14,7 @@ namespace GC
template<class T>
class TinyMC : public MAC_Check_Base<T>
{
typename T::check_type::MAC_Check part_MC;
typename T::part_type::MAC_Check part_MC;
PointerVector<int> sizes;
public:

View File

@@ -1,71 +0,0 @@
/*
* TinyPrep.h
*
*/
#ifndef GC_TINYPREP_H_
#define GC_TINYPREP_H_
#include "Thread.h"
#include "OT/MascotParams.h"
#include "Protocols/Beaver.h"
#include "Protocols/ReplicatedPrep.h"
namespace GC
{
template<class T>
class TinyPrep : public BufferPrep<T>
{
protected:
ShareThread<T>& thread;
typename T::TripleGenerator* triple_generator;
MascotParams params;
vector<array<typename T::part_type, 3>> triple_buffer;
const bool amplify;
public:
TinyPrep(DataPositions& usage, ShareThread<T>& thread, bool amplify = true);
~TinyPrep();
void set_protocol(Beaver<T>& protocol);
void init(Player& P);
void buffer_triples();
void buffer_bits();
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
void buffer_inputs_(int player, typename T::InputGenerator* input_generator);
array<T, 3> get_triple_no_count(int n_bits);
size_t data_sent();
};
template<class T>
class TinyOnlyPrep : public TinyPrep<T>
{
typename T::part_type::TripleGenerator* input_generator;
public:
TinyOnlyPrep(DataPositions& usage, ShareThread<T>& thread);
~TinyOnlyPrep();
void set_protocol(Beaver<T>& protocol);
void buffer_inputs(int player)
{
this->buffer_inputs_(player, input_generator);
}
size_t data_sent();
};
} /* namespace GC */
#endif /* GC_TINYPREP_H_ */

View File

@@ -3,7 +3,7 @@
*
*/
#include "TinyPrep.h"
#include "TinierSharePrep.h"
#include "Protocols/MascotPrep.hpp"
@@ -11,78 +11,26 @@ namespace GC
{
template<class T>
TinyPrep<T>::TinyPrep(DataPositions& usage, ShareThread<T>& thread,
bool amplify) :
BufferPrep<T>(usage), thread(thread), triple_generator(0),
amplify(amplify)
void TinierSharePrep<T>::init_real(Player& P)
{
}
template<class T>
TinyOnlyPrep<T>::TinyOnlyPrep(DataPositions& usage, ShareThread<T>& thread) :
TinyPrep<T>(usage, thread), input_generator(0)
{
}
template<class T>
TinyPrep<T>::~TinyPrep()
{
if (triple_generator)
delete triple_generator;
}
template<class T>
TinyOnlyPrep<T>::~TinyOnlyPrep()
{
if (input_generator)
delete input_generator;
}
template<class T>
void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
{
init(protocol.P);
}
template<class T>
void TinyPrep<T>::init(Player& P)
{
params.generateMACs = true;
params.amplify = false;
params.check = false;
auto& thread = ShareThread<T>::s();
triple_generator = new typename T::TripleGenerator(
assert(real_triple_generator == 0);
real_triple_generator = new typename T::whole_type::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), P.N, -1,
OnlineOptions::singleton.batch_size, 1, params,
thread.MC->get_alphai(), &P);
triple_generator->multi_threaded = false;
real_triple_generator->multi_threaded = false;
}
template<class T>
void TinyOnlyPrep<T>::set_protocol(Beaver<T>& protocol)
void TinierSharePrep<T>::buffer_secret_triples()
{
TinyPrep<T>::set_protocol(protocol);
input_generator = new typename T::part_type::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
OnlineOptions::singleton.batch_size, 1, this->params,
this->thread.MC->get_alphai(), &protocol.P);
input_generator->multi_threaded = false;
}
template<class T>
void TinyPrep<T>::buffer_triples()
{
auto& triple_generator = this->triple_generator;
auto& triple_generator = real_triple_generator;
assert(triple_generator != 0);
params.generateBits = false;
vector<array<typename T::check_type, 3>> triples;
TripleShuffleSacrifice<typename T::check_type> sacrifice;
vector<array<T, 3>> triples;
TripleShuffleSacrifice<T> sacrifice;
size_t required;
if (amplify)
required = sacrifice.minimum_n_inputs_with_combining();
else
required = sacrifice.minimum_n_inputs();
required = sacrifice.minimum_n_inputs_with_combining();
while (triples.size() < required)
{
triple_generator->generatePlainTriples();
@@ -92,9 +40,11 @@ void TinyPrep<T>::buffer_triples()
triple_generator->valueBits[2].set_portion(i,
triple_generator->plainTriples[i][2]);
triple_generator->run_multipliers({});
assert(triple_generator->plainTriples.size() != 0);
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
{
for (int j = 0; j < T::default_length; j++)
int dl = secret_type::default_length;
for (int j = 0; j < dl; j++)
{
triples.push_back({});
for (int k = 0; k < 3; k++)
@@ -103,10 +53,10 @@ void TinyPrep<T>::buffer_triples()
share.set_share(
triple_generator->plainTriples.at(i).at(k).get_bit(
j));
typename T::part_type::mac_type mac;
typename T::mac_type mac;
mac = thread.MC->get_alphai() * share.get_share();
for (auto& multiplier : triple_generator->ot_multipliers)
mac += multiplier->macs.at(k).at(i * T::default_length + j);
mac += multiplier->macs.at(k).at(i * dl + j);
share.set_mac(mac);
}
}
@@ -114,104 +64,10 @@ void TinyPrep<T>::buffer_triples()
}
sacrifice.triple_sacrifice(triples, triples,
*thread.P, thread.MC->get_part_MC());
if (amplify)
sacrifice.triple_combine(triples, triples, *thread.P,
thread.MC->get_part_MC());
for (size_t i = 0; i < triples.size() / T::default_length; i++)
{
this->triples.push_back({});
auto& triple = this->triples.back();
for (auto& x : triple)
x.resize_regs(T::default_length);
for (int j = 0; j < T::default_length; j++)
{
auto& source_triple = triples[j + i * T::default_length];
for (int k = 0; k < 3; k++)
triple[k].get_reg(j) = source_triple[k];
}
}
}
template<class T>
void TinyPrep<T>::buffer_bits()
{
auto tmp = BufferPrep<T>::get_random_from_inputs(thread.P->num_players());
for (auto& bit : tmp.get_regs())
{
this->bits.push_back({});
this->bits.back().resize_regs(1);
this->bits.back().get_reg(0) = bit;
}
}
template<class T>
void TinyPrep<T>::buffer_inputs_(int player, typename T::InputGenerator* input_generator)
{
auto& inputs = this->inputs;
inputs.resize(this->thread.P->num_players());
assert(input_generator);
input_generator->generateInputs(player);
assert(input_generator->inputs.size() >= T::default_length);
for (size_t i = 0; i < input_generator->inputs.size() / T::default_length; i++)
{
inputs[player].push_back({});
inputs[player].back().share.resize_regs(T::default_length);
for (int j = 0; j < T::default_length; j++)
{
auto& source_input = input_generator->inputs[j
+ i * T::default_length];
inputs[player].back().share.get_reg(j) = source_input.share;
inputs[player].back().value ^= typename T::open_type(
source_input.value.get_bit(0)) << j;
}
}
}
template<class T>
array<T, 3> TinyPrep<T>::get_triple_no_count(int n_bits)
{
assert(n_bits > 0);
while ((unsigned)n_bits > triple_buffer.size())
{
array<T, 3> tmp;
this->get(DATA_TRIPLE, tmp.data());
for (size_t i = 0; i < tmp[0].get_regs().size(); i++)
{
triple_buffer.push_back(
{ {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} });
}
}
array<T, 3> res;
for (int j = 0; j < 3; j++)
res[j].resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
{
for (int j = 0; j < 3; j++)
res[j].get_reg(i) = triple_buffer.back()[j];
triple_buffer.pop_back();
}
return res;
}
template<class T>
size_t TinyPrep<T>::data_sent()
{
size_t res = 0;
if (triple_generator)
res += triple_generator->data_sent();
return res;
}
template<class T>
size_t TinyOnlyPrep<T>::data_sent()
{
auto res = TinyPrep<T>::data_sent();
if (input_generator)
res += input_generator->data_sent();
return res;
sacrifice.triple_combine(triples, triples, *thread.P,
thread.MC->get_part_MC());
for (auto& triple : triples)
this->triples.push_back(triple);
}
} /* namespace GC */

View File

@@ -1,11 +0,0 @@
/*
* TinySecret.cpp
*
*/
#include "TinySecret.h"
namespace GC
{
} /* namespace GC */

View File

@@ -21,6 +21,9 @@ namespace GC
template<class T> class TinyOnlyPrep;
template<class T> class TinyMC;
template<class T> class VectorProtocol;
template<class T> class VectorInput;
template<class T> class CcdPrep;
template<class T>
class VectorSecret : public Secret<T>
@@ -50,11 +53,6 @@ public:
static const int default_length = 64;
static DataFieldType field_type()
{
return BitVec::field_type();
}
static int size()
{
return part_type::size() * default_length;
@@ -166,9 +164,9 @@ public:
}
template <class U>
void other_input(U& inputter, int from, int)
void other_input(U& inputter, int from, int n_bits)
{
inputter.add_other(from);
inputter.add_other(from, n_bits);
}
template <class U>
@@ -187,9 +185,9 @@ class TinySecret : public VectorSecret<TinyShare<S>>
public:
typedef TinyMC<This> MC;
typedef MC MAC_Check;
typedef Beaver<This> Protocol;
typedef ::Input<This> Input;
typedef TinyOnlyPrep<This> LivePrep;
typedef VectorProtocol<This> Protocol;
typedef VectorInput<This> Input;
typedef CcdPrep<This> LivePrep;
typedef Memory<This> DynamicMemory;
typedef OTTripleGenerator<This> TripleGenerator;

View File

@@ -1,11 +0,0 @@
/*
* TinyShare.cpp
*
*/
#include "TinyShare.h"
namespace GC
{
} /* namespace GC */

View File

@@ -10,13 +10,14 @@
#include "ShareParty.h"
#include "Secret.h"
#include "Protocols/Spdz2kShare.h"
#include "Processor/NoLivePrep.h"
namespace GC
{
template<int S> class TinySecret;
template<class T> class ShareThread;
template<class T> class TinierSharePrep;
template<int S>
class TinyShare : public Spdz2kShare<1, S>, public ShareSecret<TinySecret<S>>
@@ -28,12 +29,18 @@ public:
typedef void DynamicMemory;
typedef NoLivePrep<This> LivePrep;
typedef Beaver<This> Protocol;
typedef MAC_Check_Z2k_<This> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ::Input<This> Input;
typedef TinierSharePrep<This> LivePrep;
typedef SwitchableOutput out_type;
typedef This small_type;
typedef NoShare bit_type;
static string name()
{
return "tiny share";

View File

@@ -18,14 +18,14 @@ class VectorInput : public InputBase<T>
deque<int> input_lengths;
public:
VectorInput(typename T::MAC_Check&, Preprocessing<T>&, Player& P) :
part_input(0, P)
VectorInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) :
part_input(MC.get_part_MC(), prep.get_part(), P)
{
part_input.reset_all(P);
}
VectorInput(SubProcessor<T>& proc, typename T::MAC_Check&) :
VectorInput(proc.MC, proc.DataF, proc.P)
part_input(proc.MC, proc.DataF, proc.P)
{
}
@@ -41,8 +41,10 @@ public:
input_lengths.push_back(n_bits);
}
void add_other(int)
void add_other(int player, int n_bits)
{
for (int i = 0; i < n_bits; i++)
part_input.add_other(player);
}
void send_mine()

View File

@@ -17,6 +17,8 @@ class VectorProtocol : public ProtocolBase<T>
typename T::part_type::Protocol part_protocol;
public:
Player& P;
VectorProtocol(Player& P);
void init_mul(SubProcessor<T>* proc);

View File

@@ -3,6 +3,9 @@
*
*/
#ifndef GC_VECTORPROTOCOL_HPP_
#define GC_VECTORPROTOCOL_HPP_
#include "VectorProtocol.h"
namespace GC
@@ -10,7 +13,7 @@ namespace GC
template<class T>
VectorProtocol<T>::VectorProtocol(Player& P) :
part_protocol(P)
part_protocol(P), P(P)
{
}
@@ -54,3 +57,5 @@ T VectorProtocol<T>::finalize_mul(int n)
}
} /* namespace GC */
#endif

View File

@@ -40,7 +40,7 @@
#define BIT_INSTRUCTIONS \
X(XORS, T::xors(PROC, EXTRA)) \
X(XORCB, C0.xor_(PC1, PC2)) \
X(XORCB, processor.xorc(instruction)) \
X(XORCBI, C0.xor_(PC1, IMM)) \
X(NOTS, processor.nots(INST)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \

View File

@@ -20,13 +20,14 @@
#include "GC/TinierSecret.h"
#include "GC/TinyMC.h"
#include "GC/TinierPrep.h"
#include "GC/VectorInput.h"
#include "GC/ShareParty.hpp"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "Math/gfp.hpp"

View File

@@ -8,9 +8,8 @@
#include "GC/TinySecret.h"
#include "GC/TinyMC.h"
#include "GC/TinyPrep.h"
#include "GC/TinierPrep.h"
#include "GC/TinierSecret.h"
#include "GC/VectorInput.h"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
@@ -29,3 +28,4 @@
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"

View File

@@ -30,6 +30,7 @@
#include "GC/ShareSecret.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Math/gfp.hpp"
ShamirOptions ShamirOptions::singleton;

View File

@@ -12,6 +12,7 @@
#include "GC/ShareSecret.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Machines/ShamirMachine.hpp"
int main(int argc, const char** argv)

View File

@@ -12,7 +12,6 @@
#include "FHE/NTL-Subs.h"
#include "GC/TinierSecret.h"
#include "GC/TinierPrep.h"
#include "GC/TinyMC.h"
#include "SPDZ.hpp"

View File

@@ -18,12 +18,14 @@
int main(int argc, const char** argv)
{
OnlineOptions online_opts;
Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector<string>({"localhost"}));
OnlineOptions& online_opts = OnlineOptions::singleton;
Names N;
ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv);
online_opts = {opt, argc, argv};
opt.parse(argc, argv);
opt.syntax = string(argv[0]) + " <progname>";
string progname;
if (opt.firstArgs.size() > 1)
progname = *opt.firstArgs.at(1);
@@ -50,36 +52,14 @@ int main(int argc, const char** argv)
int R = ring_opts.ring_size_from_opts_or_schedule(progname);
switch (R)
{
case 64:
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 128:
Machine<FakeShare<SignedZ2<128>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 256:
Machine<FakeShare<SignedZ2<256>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 192:
Machine<FakeShare<SignedZ2<192>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 384:
Machine<FakeShare<SignedZ2<384>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
break;
case 512:
Machine<FakeShare<SignedZ2<512>>, FakeShare<gf2n>>(0, N, progname,
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
online_opts.live_prep, online_opts).run();
#define X(L) \
case L: \
Machine<FakeShare<SignedZ2<L>>, FakeShare<gf2n>>(0, N, progname, \
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, false, \
online_opts.live_prep, online_opts).run(); \
break;
X(64) X(128) X(256) X(192) X(384) X(512)
#undef X
default:
cerr << "Not compiled for " << R << "-bit rings" << endl;
}

View File

@@ -12,6 +12,7 @@
#include "GC/ShareSecret.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "GC/CcdPrep.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/MalRep.hpp"

23
Machines/no-party.cpp Normal file
View File

@@ -0,0 +1,23 @@
/*
* no-party.cpp
*
*/
#include "Protocols/NoShare.h"
#include "Processor/OnlineMachine.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/Replicated.hpp"
#include "Math/gfp.hpp"
#include "Math/Z2k.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
OnlineOptions::singleton = {opt, argc, argv};
OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton);
OnlineOptions::singleton.finalize(opt, argc, argv);
machine.start_networking();
// use primes of length 65 to 128 for arithmetic computation
machine.run<NoShare<gfp_<0, 2>>, NoShare<gf2n>>();
}

View File

@@ -12,7 +12,6 @@
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Tools/ezOptionParser.h"
#include "Processor/NoLivePrep.h"
#include "GC/MaliciousCcdSecret.h"
#include "Processor/FieldMachine.hpp"

Some files were not shown because too many files have changed in this diff Show More