mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-07 20:53:55 -05:00
Convolutional neural network training.
This commit is contained in:
@@ -167,7 +167,7 @@ void Node::Broadcast2(SendBuffer& msg) {
|
||||
}
|
||||
|
||||
void Node::_identify() {
|
||||
char* msg = id_msg;
|
||||
char msg[strlen(ID_HDR)+sizeof(_id)];
|
||||
memcpy(msg, ID_HDR, strlen(ID_HDR));
|
||||
memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id));
|
||||
//printf("Node:: identifying myself:\n");
|
||||
|
||||
@@ -78,8 +78,6 @@ private:
|
||||
std::map<struct sockaddr_in*,int> _clientsmap;
|
||||
bool* _clients_connected;
|
||||
NodeUpdatable* _updatable;
|
||||
|
||||
char id_msg[strlen(ID_HDR)+sizeof(_id)];
|
||||
};
|
||||
|
||||
#endif /* NETWORK_NODE_H_ */
|
||||
|
||||
12
CHANGELOG.md
12
CHANGELOG.md
@@ -1,5 +1,17 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.2.5 (Jul 2, 2021)
|
||||
|
||||
- Training of convolutional neural networks
|
||||
- Bit decomposition using edaBits
|
||||
- Ability to force MAC checks from high-level code
|
||||
- Ability to close client connection from high-level code
|
||||
- Binary operators for comparison results
|
||||
- Faster compilation for emulation
|
||||
- More documentation
|
||||
- Fixed security bug: insufficient LowGear secret key randomness
|
||||
- Fixed security bug: skewed random bit generation
|
||||
|
||||
## 0.2.4 (Apr 19, 2021)
|
||||
|
||||
- ARM support
|
||||
|
||||
@@ -117,7 +117,7 @@ class xorm(NonVectorInstruction):
|
||||
code = opcodes['XORM']
|
||||
arg_format = ['int','sbw','sb','cb']
|
||||
|
||||
class xorcb(NonVectorInstruction):
|
||||
class xorcb(BinaryVectorInstruction):
|
||||
""" Bitwise XOR of two single clear bit registers.
|
||||
|
||||
:param: result (cbit)
|
||||
@@ -125,7 +125,7 @@ class xorcb(NonVectorInstruction):
|
||||
:param: operand (cbit)
|
||||
"""
|
||||
code = opcodes['XORCB']
|
||||
arg_format = ['cbw','cb','cb']
|
||||
arg_format = ['int','cbw','cb','cb']
|
||||
|
||||
class xorcbi(NonVectorInstruction):
|
||||
""" Bitwise XOR of single clear bit register and immediate.
|
||||
|
||||
@@ -36,6 +36,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
class bitsn(cls):
|
||||
n = length
|
||||
cls.types[length] = bitsn
|
||||
bitsn.clear_type = cbits.get_type(length)
|
||||
bitsn.__name__ = cls.__name__ + str(length)
|
||||
return cls.types[length]
|
||||
@classmethod
|
||||
@@ -115,7 +116,11 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return res
|
||||
def store_in_mem(self, address):
|
||||
self.store_inst[isinstance(address, int)](self, address)
|
||||
@classmethod
|
||||
def new(cls, value=None, n=None):
|
||||
return cls.get_type(n)(value)
|
||||
def __init__(self, value=None, n=None, size=None):
|
||||
assert n == self.n or n is None
|
||||
if size != 1 and size is not None:
|
||||
raise Exception('invalid size for bit type: %s' % size)
|
||||
self.n = n or self.n
|
||||
@@ -125,7 +130,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
if value is not None:
|
||||
self.load_other(value)
|
||||
def copy(self):
|
||||
return type(self)(n=instructions_base.get_global_vector_size())
|
||||
return type(self).new(n=instructions_base.get_global_vector_size())
|
||||
def set_length(self, n):
|
||||
if n > self.n:
|
||||
raise Exception('too long: %d/%d' % (n, self.n))
|
||||
@@ -154,6 +159,8 @@ class bits(Tape.Register, _structure, _bit):
|
||||
bits = other.bit_decompose()
|
||||
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
|
||||
other = self.bit_compose(bits)
|
||||
assert(isinstance(other, type(self)))
|
||||
assert(other.n == self.n)
|
||||
self.load_other(other)
|
||||
except:
|
||||
raise CompilerError('cannot convert %s/%s from %s to %s' % \
|
||||
@@ -176,6 +183,16 @@ class bits(Tape.Register, _structure, _bit):
|
||||
res.i = i
|
||||
res.program = self.program
|
||||
return res
|
||||
def if_else(self, x, y):
|
||||
"""
|
||||
Vectorized oblivious selection::
|
||||
|
||||
sb32 = sbits.get_type(32)
|
||||
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
|
||||
|
||||
This will output 1.
|
||||
"""
|
||||
return result_conv(x, y)(self & (x ^ y) ^ y)
|
||||
|
||||
class cbits(bits):
|
||||
""" Clear bits register. Helper type with limited functionality. """
|
||||
@@ -202,14 +219,16 @@ class cbits(bits):
|
||||
inst.stmsdci(self, cbits.conv(address))
|
||||
def clear_op(self, other, c_inst, ci_inst, op):
|
||||
if isinstance(other, cbits):
|
||||
res = cbits(n=max(self.n, other.n))
|
||||
res = cbits.get_type(max(self.n, other.n))()
|
||||
c_inst(res, self, other)
|
||||
return res
|
||||
elif isinstance(other, sbits):
|
||||
return NotImplemented
|
||||
else:
|
||||
if util.is_constant(other):
|
||||
if other >= 2**31 or other < -2**31:
|
||||
return op(self, cbits(other))
|
||||
res = cbits(n=max(self.n, len(bin(other)) - 2))
|
||||
res = cbits.get_type(max(self.n, len(bin(other)) - 2))()
|
||||
ci_inst(res, self, other)
|
||||
return res
|
||||
else:
|
||||
@@ -221,8 +240,14 @@ class cbits(bits):
|
||||
def __xor__(self, other):
|
||||
if isinstance(other, (sbits, sbitvec)):
|
||||
return NotImplemented
|
||||
elif isinstance(other, cbits):
|
||||
res = cbits.get_type(max(self.n, other.n))()
|
||||
assert res.size == self.size
|
||||
assert res.size == other.size
|
||||
inst.xorcb(res.n, res, self, other)
|
||||
return res
|
||||
else:
|
||||
self.clear_op(other, inst.xorcb, inst.xorcbi, operator.xor)
|
||||
return self.clear_op(other, None, inst.xorcbi, operator.xor)
|
||||
__radd__ = __add__
|
||||
__rxor__ = __xor__
|
||||
def __mul__(self, other):
|
||||
@@ -230,17 +255,18 @@ class cbits(bits):
|
||||
return NotImplemented
|
||||
else:
|
||||
try:
|
||||
res = cbits(n=min(self.max_length, self.n+util.int_len(other)))
|
||||
res = cbits.get_type(min(self.max_length,
|
||||
self.n+util.int_len(other)))()
|
||||
inst.mulcbi(res, self, other)
|
||||
return res
|
||||
except TypeError:
|
||||
return NotImplemented
|
||||
def __rshift__(self, other):
|
||||
res = cbits(n=self.n-other)
|
||||
res = cbits.new(n=self.n-other)
|
||||
inst.shrcbi(res, self, other)
|
||||
return res
|
||||
def __lshift__(self, other):
|
||||
res = cbits(n=self.n+other)
|
||||
res = cbits.get_type(self.n+other)()
|
||||
inst.shlcbi(res, self, other)
|
||||
return res
|
||||
def print_reg(self, desc=''):
|
||||
@@ -504,16 +530,6 @@ class sbits(bits):
|
||||
res = [cls.new(n=len(rows)) for i in range(n_columns)]
|
||||
inst.trans(len(res), *(res + rows))
|
||||
return res
|
||||
def if_else(self, x, y):
|
||||
"""
|
||||
Vectorized oblivious selection::
|
||||
|
||||
sb32 = sbits.get_type(32)
|
||||
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
|
||||
|
||||
This will output 1.
|
||||
"""
|
||||
return result_conv(x, y)(self & (x ^ y) ^ y)
|
||||
@staticmethod
|
||||
def bit_adder(*args, **kwargs):
|
||||
return sbitint.bit_adder(*args, **kwargs)
|
||||
@@ -610,7 +626,7 @@ class sbitvec(_vec):
|
||||
elif isinstance(other, (list, tuple)):
|
||||
self.v = self.bit_extend(sbitvec(other).v, n)
|
||||
else:
|
||||
self.v = sbits(other, n=n).bit_decompose(n)
|
||||
self.v = sbits.get_type(n)(other).bit_decompose()
|
||||
assert len(self.v) == n
|
||||
@classmethod
|
||||
def load_mem(cls, address):
|
||||
@@ -630,6 +646,8 @@ class sbitvec(_vec):
|
||||
for i in range(n):
|
||||
v[i].store_in_mem(address + i)
|
||||
def reveal(self):
|
||||
if len(self) > cbits.unit:
|
||||
return self.elements()[0].reveal()
|
||||
revealed = [cbit() for i in range(len(self))]
|
||||
for i in range(len(self)):
|
||||
try:
|
||||
@@ -784,15 +802,23 @@ class bit(object):
|
||||
|
||||
def result_conv(x, y):
|
||||
try:
|
||||
def f(res):
|
||||
try:
|
||||
return t.conv(res)
|
||||
except:
|
||||
return res
|
||||
if util.is_constant(x):
|
||||
if util.is_constant(y):
|
||||
return lambda x: x
|
||||
else:
|
||||
return type(y).conv
|
||||
t = type(y)
|
||||
return f
|
||||
if util.is_constant(y):
|
||||
return type(x).conv
|
||||
t = type(x)
|
||||
return f
|
||||
if type(x) is type(y):
|
||||
return type(x).conv
|
||||
t = type(x)
|
||||
return f
|
||||
except AttributeError:
|
||||
pass
|
||||
return lambda x: x
|
||||
@@ -807,13 +833,19 @@ class sbit(bit, sbits):
|
||||
|
||||
This will output 5.
|
||||
"""
|
||||
return result_conv(x, y)(self * (x ^ y) ^ y)
|
||||
assert self.n == 1
|
||||
diff = x ^ y
|
||||
if isinstance(diff, cbits):
|
||||
return result_conv(x, y)(self & (diff) ^ y)
|
||||
else:
|
||||
return result_conv(x, y)(self * (diff) ^ y)
|
||||
|
||||
class cbit(bit, cbits):
|
||||
pass
|
||||
|
||||
sbits.bit_type = sbit
|
||||
cbits.bit_type = cbit
|
||||
sbit.clear_type = cbit
|
||||
|
||||
class bitsBlock(oram.Block):
|
||||
value_type = sbits
|
||||
@@ -881,7 +913,7 @@ class _sbitintbase:
|
||||
return self.get_type(k - m).compose(res_bits)
|
||||
def int_div(self, other, bit_length=None):
|
||||
k = bit_length or max(self.n, other.n)
|
||||
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
|
||||
return (library.IntDiv(self.cast(k), other.cast(k), k) >> k).cast(k)
|
||||
def Norm(self, k, f, kappa=None, simplex_flag=False):
|
||||
absolute_val = abs(self)
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
@@ -1100,7 +1132,8 @@ class cbitfix(object):
|
||||
bits = self.v.bit_decompose(self.k)
|
||||
sign = bits[-1]
|
||||
v += (sign << (self.k)) * -1
|
||||
inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0), cbits(0))
|
||||
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
|
||||
cbits(0), cbits(0))
|
||||
|
||||
class sbitfix(_fix):
|
||||
""" Secret signed integer in one binary register.
|
||||
|
||||
@@ -123,7 +123,7 @@ class StraightlineAllocator:
|
||||
for x in itertools.chain(dup.duplicates, base.duplicates):
|
||||
to_check.add(x)
|
||||
|
||||
free[reg.reg_type, base.size].add(self.alloc[base])
|
||||
free[reg.reg_type, base.size].append(self.alloc[base])
|
||||
if inst.is_vec() and base.vector:
|
||||
self.defined[base] = inst
|
||||
for i in base.vector:
|
||||
@@ -604,4 +604,4 @@ class RegintOptimizer:
|
||||
elif op == 1:
|
||||
instructions[i] = None
|
||||
inst.args[0].link(inst.args[1])
|
||||
instructions[:] = filter(lambda x: x is not None, instructions)
|
||||
instructions[:] = list(filter(lambda x: x is not None, instructions))
|
||||
|
||||
@@ -127,7 +127,7 @@ def sha3_256(x):
|
||||
|
||||
from circuit import sha3_256
|
||||
a = sbitvec.from_vec([])
|
||||
b = sbitvec(sint(0xcc), 8)
|
||||
b = sbitvec(sint(0xcc), 8, 8)
|
||||
for x in a, b:
|
||||
sha3_256(x).elements()[0].reveal().print_reg()
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ def require_ring_size(k, op):
|
||||
if int(program.options.ring) < k:
|
||||
raise CompilerError('ring size too small for %s, compile '
|
||||
'with \'-R %d\' or more' % (op, k))
|
||||
program.curr_tape.require_bit_length(k)
|
||||
|
||||
@instructions_base.cisc
|
||||
def LTZ(s, a, k, kappa):
|
||||
|
||||
@@ -55,7 +55,7 @@ def EQZ(a, k, kappa):
|
||||
from GC.types import sbitvec
|
||||
v = sbitvec(a, k).v
|
||||
bit = util.tree_reduce(operator.and_, (~b for b in v))
|
||||
return types.sint.conv(bit)
|
||||
return types.sintbit.conv(bit)
|
||||
prog.non_linear.check_security(kappa)
|
||||
return prog.non_linear.eqz(a, k)
|
||||
|
||||
@@ -263,16 +263,17 @@ def BitAdd(a, b, bits_to_compute=None):
|
||||
def BitDec(a, k, m, kappa, bits_to_compute=None):
|
||||
return program.Program.prog.non_linear.bit_dec(a, k, m)
|
||||
|
||||
def BitDecRing(a, k, m):
|
||||
def BitDecRingRaw(a, k, m):
|
||||
n_shift = int(program.Program.prog.options.ring) - m
|
||||
assert(n_shift >= 0)
|
||||
if program.Program.prog.use_split():
|
||||
x = a.split_to_two_summands(m)
|
||||
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
|
||||
return bits[:m]
|
||||
else:
|
||||
if program.Program.prog.use_dabit:
|
||||
if program.Program.prog.use_edabit():
|
||||
r, r_bits = types.sint.get_edabit(m, strict=False)
|
||||
elif program.Program.prog.use_dabit:
|
||||
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
|
||||
r = types.sint.bit_compose(r)
|
||||
else:
|
||||
@@ -281,7 +282,12 @@ def BitDecRing(a, k, m):
|
||||
shifted = ((a - r) << n_shift).reveal()
|
||||
masked = shifted >> n_shift
|
||||
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
|
||||
return [types.sint.conv(bit) for bit in bits]
|
||||
return bits
|
||||
|
||||
def BitDecRing(a, k, m):
|
||||
bits = BitDecRingRaw(a, k, m)
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
r_dprime = types.sint()
|
||||
@@ -429,7 +435,7 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
s = (1 - overflow) * t + overflow * t / 2
|
||||
return s, overflow
|
||||
|
||||
def Int2FL(a, gamma, l, kappa):
|
||||
def Int2FL(a, gamma, l, kappa=None):
|
||||
lam = gamma - 1
|
||||
s = a.less_than(0, gamma, security=kappa)
|
||||
z = a.equal(0, gamma, security=kappa)
|
||||
@@ -598,13 +604,13 @@ def SDiv_mono(a, b, l, kappa):
|
||||
# Unconditionally Secure Constant-Rounds Multi-party Computation
|
||||
# for Equality, Comparison, Bits and Exponentiation
|
||||
def BITLT(a, b, bit_length):
|
||||
sint = types.sint
|
||||
e = [sint(0)]*bit_length
|
||||
g = [sint(0)]*bit_length
|
||||
h = [sint(0)]*bit_length
|
||||
from .types import sint, regint, longint, cint
|
||||
e = [None]*bit_length
|
||||
g = [None]*bit_length
|
||||
h = [None]*bit_length
|
||||
for i in range(bit_length):
|
||||
# Compute the XOR (reverse order of e for PreOpL)
|
||||
e[bit_length-i-1] = a[i].bit_xor(b[i])
|
||||
e[bit_length-i-1] = util.bit_xor(a[i], b[i])
|
||||
f = PreOpL(or_op, e)
|
||||
g[bit_length-1] = f[0]
|
||||
for i in range(bit_length-1):
|
||||
@@ -612,7 +618,7 @@ def BITLT(a, b, bit_length):
|
||||
g[i] = f[bit_length-i-1]-f[bit_length-i-2]
|
||||
ans = 0
|
||||
for i in range(bit_length):
|
||||
h[i] = g[i]*b[i]
|
||||
h[i] = g[i].bit_and(b[i])
|
||||
ans = ans + h[i]
|
||||
return ans
|
||||
|
||||
@@ -620,9 +626,9 @@ def BITLT(a, b, bit_length):
|
||||
# - From the paper
|
||||
# Multiparty Computation for Interval, Equality, and Comparison without
|
||||
# Bit-Decomposition Protocol
|
||||
def BitDecFull(a):
|
||||
def BitDecFull(a, maybe_mixed=False):
|
||||
from .library import get_program, do_while, if_, break_point
|
||||
from .types import sint, regint, longint
|
||||
from .types import sint, regint, longint, cint
|
||||
p = get_program().prime
|
||||
assert p
|
||||
bit_length = p.bit_length()
|
||||
@@ -631,9 +637,16 @@ def BitDecFull(a):
|
||||
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
|
||||
# no need for exact randomness generation
|
||||
# if modulo a power of two is close enough
|
||||
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
|
||||
if logp != bit_length:
|
||||
bbits += [sint(0, size=a.size)]
|
||||
if get_program().use_edabit():
|
||||
b, bbits = sint.get_edabit(logp, True, size=a.size)
|
||||
if logp != bit_length:
|
||||
from .GC.types import sbits
|
||||
bbits += [sbits.get_type(a.size)(0)]
|
||||
else:
|
||||
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
|
||||
b = sint.bit_compose(bbits)
|
||||
if logp != bit_length:
|
||||
bbits += [sint(0, size=a.size)]
|
||||
else:
|
||||
bbits = [sint(size=a.size) for i in range(bit_length)]
|
||||
tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)]
|
||||
@@ -653,15 +666,21 @@ def BitDecFull(a):
|
||||
for j in range(a.size):
|
||||
for i in range(bit_length):
|
||||
movs(bbits[i][j], tbits[j][i])
|
||||
b = sint.bit_compose(bbits)
|
||||
b = sint.bit_compose(bbits)
|
||||
c = (a-b).reveal()
|
||||
t = (p-c).bit_decompose(bit_length)
|
||||
cmodp = c
|
||||
t = bbits[0].bit_decompose_clear(p - c, bit_length)
|
||||
c = longint(c, bit_length)
|
||||
czero = (c==0)
|
||||
q = 1-BITLT( bbits, t, bit_length)
|
||||
fbar=((1<<bit_length)+c-p).bit_decompose(bit_length)
|
||||
fbard = c.bit_decompose(bit_length)
|
||||
g = [(fbar[i] - fbard[i]) * q + fbard[i] for i in range(bit_length)]
|
||||
h = BitAdd(bbits, g)
|
||||
abits = [(1 - czero) * h[i] + czero * bbits[i] for i in range(bit_length)]
|
||||
return abits
|
||||
q = bbits[0].long_one() - BITLT(bbits, t, bit_length)
|
||||
fbar = [bbits[0].clear_type.conv(cint(x))
|
||||
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
|
||||
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
|
||||
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
|
||||
h = bbits[0].bit_adder(bbits, g)
|
||||
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
|
||||
for i in range(bit_length)]
|
||||
if maybe_mixed:
|
||||
return abits
|
||||
else:
|
||||
return [sint.conv(bit) for bit in abits]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import heapq
|
||||
import collections
|
||||
from Compiler.exceptions import *
|
||||
|
||||
class GraphError(CompilerError):
|
||||
@@ -23,7 +24,7 @@ class SparseDiGraph(object):
|
||||
self.n = max_nodes
|
||||
# each node contains list of default attributes, followed by outoing edges
|
||||
self.nodes = [list(self.default_attributes.values()) for i in range(self.n)]
|
||||
self.succ = [set() for i in range(self.n)]
|
||||
self.succ = [collections.OrderedDict() for i in range(self.n)]
|
||||
self.pred = [set() for i in range(self.n)]
|
||||
self.weights = {}
|
||||
|
||||
@@ -32,7 +33,7 @@ class SparseDiGraph(object):
|
||||
|
||||
def __getitem__(self, i):
|
||||
""" Get list of the neighbours of node i """
|
||||
return self.succ[i]
|
||||
return self.succ[i].keys()
|
||||
|
||||
def __iter__(self):
|
||||
pass #return iter(self.nodes)
|
||||
@@ -68,7 +69,7 @@ class SparseDiGraph(object):
|
||||
self.pred[v].remove(i)
|
||||
#del self.weights[(i,v)]
|
||||
for v in pred:
|
||||
self.succ[v].remove(i)
|
||||
del self.succ[v][i]
|
||||
#del self.weights[(v,i)]
|
||||
#self.nodes[v].remove(i)
|
||||
self.pred[i] = []
|
||||
@@ -77,7 +78,7 @@ class SparseDiGraph(object):
|
||||
def add_edge(self, i, j, weight=1):
|
||||
if j not in self[i]:
|
||||
self.pred[j].add(i)
|
||||
self.succ[i].add(j)
|
||||
self.succ[i][j] = None
|
||||
self.weights[(i,j)] = weight
|
||||
|
||||
def add_edges_from(self, tuples):
|
||||
@@ -89,7 +90,7 @@ class SparseDiGraph(object):
|
||||
self.add_edge(edge[0], edge[1])
|
||||
|
||||
def remove_edge(self, i, j):
|
||||
self.succ[i].remove(j)
|
||||
del self.succ[i][j]
|
||||
self.pred[j].remove(i)
|
||||
del self.weights[(i,j)]
|
||||
|
||||
|
||||
@@ -2219,22 +2219,23 @@ class conv2ds(base.DataInstruction):
|
||||
:param: number of channels (int)
|
||||
:param: padding height (int)
|
||||
:param: padding width (int)
|
||||
:param: batch size (int)
|
||||
"""
|
||||
code = base.opcodes['CONV2DS']
|
||||
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
|
||||
'int','int','int']
|
||||
'int','int','int','int']
|
||||
data_type = 'triple'
|
||||
is_vec = lambda self: True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(conv2ds, self).__init__(*args, **kwargs)
|
||||
assert args[0].size == args[3] * args[4]
|
||||
assert args[1].size == args[5] * args[6] * args[11]
|
||||
assert args[0].size == args[3] * args[4] * args[14]
|
||||
assert args[1].size == args[5] * args[6] * args[11] * args[14]
|
||||
assert args[2].size == args[7] * args[8] * args[11]
|
||||
|
||||
def get_repeat(self):
|
||||
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
|
||||
self.args[11]
|
||||
self.args[11] * self.args[14]
|
||||
|
||||
@base.vectorize
|
||||
class trunc_pr(base.VarArgsInstruction):
|
||||
@@ -2250,6 +2251,15 @@ class trunc_pr(base.VarArgsInstruction):
|
||||
code = base.opcodes['TRUNC_PR']
|
||||
arg_format = tools.cycle(['sw','s','int','int'])
|
||||
|
||||
class check(base.Instruction):
|
||||
"""
|
||||
Force MAC check in current thread and all idle thread if current
|
||||
thread is the main thread.
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['CHECK']
|
||||
arg_format = []
|
||||
|
||||
###
|
||||
### CISC-style instructions
|
||||
###
|
||||
@@ -2289,47 +2299,5 @@ class lts(base.CISC):
|
||||
subs(a, self.args[1], self.args[2])
|
||||
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
|
||||
|
||||
@base.vectorize
|
||||
class g2muls(base.CISC):
|
||||
r""" Secret GF(2) multiplication """
|
||||
__slots__ = []
|
||||
arg_format = ['sgw','sg','sg']
|
||||
|
||||
def expand(self):
|
||||
s = [program.curr_block.new_reg('sg') for i in range(9)]
|
||||
c = [program.curr_block.new_reg('cg') for i in range(3)]
|
||||
gbittriple(s[0], s[1], s[2])
|
||||
gsubs(s[3], self.args[1], s[0])
|
||||
gsubs(s[4], self.args[2], s[1])
|
||||
gasm_open(c[0], s[3])
|
||||
gasm_open(c[1], s[4])
|
||||
gmulbitm(s[5], s[1], c[0])
|
||||
gmulbitm(s[6], s[0], c[1])
|
||||
gmulbitc(c[2], c[0], c[1])
|
||||
gadds(s[7], s[2], s[5])
|
||||
gadds(s[8], s[7], s[6])
|
||||
gaddm(self.args[0], s[8], c[2])
|
||||
|
||||
#@base.vectorize
|
||||
#class gmulbits(base.CISC):
|
||||
# r""" Secret $GF(2^n) \times GF(2)$ multiplication """
|
||||
# __slots__ = []
|
||||
# arg_format = ['sgw','sg','sg']
|
||||
#
|
||||
# def expand(self):
|
||||
# s = [program.curr_block.new_reg('s') for i in range(9)]
|
||||
# c = [program.curr_block.new_reg('c') for i in range(3)]
|
||||
# g2ntriple(s[0], s[1], s[2])
|
||||
# subs(s[3], self.args[1], s[0])
|
||||
# subs(s[4], self.args[2], s[1])
|
||||
# startopen(s[3], s[4])
|
||||
# stopopen(c[0], c[1])
|
||||
# mulm(s[5], s[1], c[0])
|
||||
# mulm(s[6], s[0], c[1])
|
||||
# mulc(c[2], c[0], c[1])
|
||||
# adds(s[7], s[2], s[5])
|
||||
# adds(s[8], s[7], s[6])
|
||||
# addm(self.args[0], s[8], c[2])
|
||||
|
||||
# hack for circular dependency
|
||||
from Compiler import comparison
|
||||
|
||||
@@ -18,6 +18,8 @@ from Compiler import program
|
||||
### MUST also be changed. (+ the documentation)
|
||||
###
|
||||
opcodes = dict(
|
||||
# Emulation
|
||||
CISC = 0,
|
||||
# Load/store
|
||||
LDI = 0x1,
|
||||
LDSI = 0x2,
|
||||
@@ -98,6 +100,7 @@ opcodes = dict(
|
||||
MATMULS = 0xAA,
|
||||
MATMULSM = 0xAB,
|
||||
CONV2DS = 0xAC,
|
||||
CHECK = 0xAF,
|
||||
# Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
@@ -409,7 +412,7 @@ def cisc(function):
|
||||
program.curr_block.instructions.append(self)
|
||||
|
||||
def get_def(self):
|
||||
return [self.args[0]]
|
||||
return [call[0][0] for call in self.calls]
|
||||
|
||||
def get_used(self):
|
||||
return self.used
|
||||
@@ -423,6 +426,7 @@ def cisc(function):
|
||||
|
||||
def merge(self, other):
|
||||
self.calls += other.calls
|
||||
self.used += other.used
|
||||
|
||||
def get_size(self):
|
||||
return self.args[0].size
|
||||
@@ -470,7 +474,9 @@ def cisc(function):
|
||||
inst.copy(size, subs)
|
||||
reset_global_vector_size()
|
||||
|
||||
def expand_merged(self):
|
||||
def expand_merged(self, skip):
|
||||
if function.__name__ in skip:
|
||||
return [self], 0
|
||||
tape = program.curr_tape
|
||||
block = tape.BasicBlock(tape, None, None)
|
||||
tape.active_basicblock = block
|
||||
@@ -496,10 +502,38 @@ def cisc(function):
|
||||
reg.mov(reg, new_regs[0].get_vector(base, reg.size))
|
||||
reset_global_vector_size()
|
||||
base += reg.size
|
||||
return block.instructions
|
||||
return block.instructions, self.n_rounds - 1
|
||||
|
||||
def expanded_rounds(self):
|
||||
return self.n_rounds - 1
|
||||
def add_usage(self, *args):
|
||||
pass
|
||||
|
||||
def get_bytes(self):
|
||||
assert len(self.kwargs) < 2
|
||||
res = int_to_bytes(opcodes['CISC'])
|
||||
res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
|
||||
name = self.function.__name__
|
||||
String.check(name)
|
||||
res += String.encode(name)
|
||||
for call in self.calls:
|
||||
assert not call[1]
|
||||
res += int_to_bytes(len(call[0]) + 2)
|
||||
res += int_to_bytes(call[0][0].size)
|
||||
for arg in call[0]:
|
||||
res += self.arg_to_bytes(arg)
|
||||
return bytearray(res)
|
||||
|
||||
@classmethod
|
||||
def arg_to_bytes(self, arg):
|
||||
if arg is None:
|
||||
return int_to_bytes(0)
|
||||
try:
|
||||
return int_to_bytes(arg.i)
|
||||
except:
|
||||
return int_to_bytes(arg)
|
||||
|
||||
def __str__(self):
|
||||
return self.function.__name__ + ' ' + ', '.join(
|
||||
str(x) for x in itertools.chain(call[0] for call in self.calls))
|
||||
|
||||
MergeCISC.__name__ = function.__name__
|
||||
|
||||
@@ -804,11 +838,8 @@ class Instruction(object):
|
||||
else:
|
||||
return self.args
|
||||
|
||||
def expand_merged(self):
|
||||
return [self]
|
||||
|
||||
def expanded_rounds(self):
|
||||
return 0
|
||||
def expand_merged(self, skip):
|
||||
return [self], 0
|
||||
|
||||
def get_new_args(self, size, subs):
|
||||
new_args = []
|
||||
|
||||
@@ -170,7 +170,7 @@ def print_ln_to(player, ss, *args):
|
||||
|
||||
Example::
|
||||
|
||||
print_ln_to(player, 'output for %s: %s', x.reveal_to(player))
|
||||
print_ln_to(player, 'output for %s: %s', player, x.reveal_to(player))
|
||||
"""
|
||||
cond = player == get_player_id()
|
||||
new_args = []
|
||||
@@ -293,7 +293,9 @@ class Function:
|
||||
self.compile_args = compile_args
|
||||
def __call__(self, *args):
|
||||
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
|
||||
get_reg_type = lambda x: regint if isinstance(x, int) else type(x)
|
||||
from .types import _types
|
||||
get_reg_type = lambda x: \
|
||||
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
|
||||
if len(args) not in self.type_args:
|
||||
# first call
|
||||
type_args = collections.defaultdict(list)
|
||||
@@ -324,7 +326,8 @@ class Function:
|
||||
j = 0
|
||||
for i_arg in type_args[reg_type]:
|
||||
if get_reg_type(args[i_arg]) != reg_type:
|
||||
raise CompilerError('type mismatch')
|
||||
raise CompilerError('type mismatch: "%s" not of type "%s"' %
|
||||
(args[i_arg], reg_type))
|
||||
store_in_mem(args[i_arg], bases[reg_type] + j)
|
||||
j += util.mem_size(reg_type)
|
||||
return self.on_call(base, bases)
|
||||
@@ -371,7 +374,7 @@ class FunctionBlock(Function):
|
||||
parent_node = get_tape().req_node
|
||||
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
|
||||
block = get_tape().active_basicblock
|
||||
block.alloc_pool = defaultdict(set)
|
||||
block.alloc_pool = defaultdict(list)
|
||||
del parent_node.children[-1]
|
||||
self.node = get_tape().req_node
|
||||
if get_program().verbose:
|
||||
@@ -763,22 +766,34 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
|
||||
@function_block
|
||||
def step(l):
|
||||
l = MemValue(l)
|
||||
@for_range_opt_multithread(n_threads, len(a) // k)
|
||||
m = 2 ** int(math.ceil(math.log(len(a), 2)))
|
||||
@for_range_opt_multithread(n_threads, m // k)
|
||||
def _(i):
|
||||
n_inner = l // k
|
||||
j = i % n_inner
|
||||
i //= n_inner
|
||||
base = i*l + j
|
||||
step = l//k
|
||||
def swap(base, step):
|
||||
if m == len(a):
|
||||
a[base], a[base + step] = \
|
||||
cond_swap(a[base], a[base + step])
|
||||
else:
|
||||
# ignore values outside range
|
||||
go = base + step < len(a)
|
||||
x = a.maybe_get(go, base)
|
||||
y = a.maybe_get(go, base + step)
|
||||
tmp = cond_swap(x, y)
|
||||
for i, idx in enumerate((base, base + step)):
|
||||
a.maybe_set(go, idx, tmp[i])
|
||||
if k == 2:
|
||||
a[base], a[base+step] = \
|
||||
cond_swap(a[base], a[base+step])
|
||||
swap(base, step)
|
||||
else:
|
||||
@for_range_opt(n_innermost)
|
||||
def f(i):
|
||||
m1 = step + i * 2 * step
|
||||
m2 = m1 + base
|
||||
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
|
||||
swap(m2, step)
|
||||
steps[key] = step
|
||||
steps[key](l)
|
||||
|
||||
@@ -870,6 +885,8 @@ def for_range_parallel(n_parallel, n_loops):
|
||||
"""
|
||||
Decorator to execute a loop :py:obj:`n_loops` up to
|
||||
:py:obj:`n_parallel` loop bodies in parallel.
|
||||
Using any other control flow instruction inside the loop breaks
|
||||
the optimization.
|
||||
|
||||
:param n_parallel: compile-time (int)
|
||||
:param n_loops: regint/cint/int
|
||||
@@ -887,9 +904,12 @@ def for_range_parallel(n_parallel, n_loops):
|
||||
def for_range_opt(n_loops, budget=None):
|
||||
""" Execute loop bodies in parallel up to an optimization budget.
|
||||
This prevents excessive loop unrolling. The budget is respected
|
||||
even with nested loops. Note that optimization is rather
|
||||
even with nested loops. Note that the optimization is rather
|
||||
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
|
||||
using :py:func:`for_range_parallel` in this case.
|
||||
Using further control flow constructions inside other than
|
||||
:py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the
|
||||
optimization.
|
||||
|
||||
:param n_loops: int/regint/cint
|
||||
:param budget: number of instructions after which to start optimization (default is 100,000)
|
||||
@@ -1082,18 +1102,19 @@ def multithread(n_threads, n_items=None, max_size=None):
|
||||
"""
|
||||
if n_items is None:
|
||||
n_items = n_threads
|
||||
if max_size is None:
|
||||
if max_size is None or n_items <= max_size:
|
||||
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
|
||||
reducer=None, looping=False)
|
||||
else:
|
||||
def wrapper(function):
|
||||
@multithread(n_threads, n_items)
|
||||
def new_function(base, size):
|
||||
for i in range(0, size, max_size):
|
||||
part_base = base + i
|
||||
part_size = min(max_size, size - i)
|
||||
function(part_base, part_size)
|
||||
break_point()
|
||||
@for_range(size // max_size)
|
||||
def _(i):
|
||||
function(base + i * max_size, max_size)
|
||||
rem = size % max_size
|
||||
if rem:
|
||||
function(base + size - rem, rem)
|
||||
return wrapper
|
||||
|
||||
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
@@ -1200,6 +1221,23 @@ def map_sum(n_threads, n_parallel, n_loops, n_items, value_types):
|
||||
return tuple(a + b for a,b in zip(x,y))
|
||||
return map_reduce(n_threads, n_parallel, n_loops, initializer, summer)
|
||||
|
||||
def tree_reduce_multithread(n_threads, function, vector):
|
||||
inputs = vector.Array(len(vector))
|
||||
inputs.assign_vector(vector)
|
||||
outputs = vector.Array(len(vector) // 2)
|
||||
left = len(vector)
|
||||
while left > 1:
|
||||
@multithread(n_threads, left // 2)
|
||||
def _(base, size):
|
||||
outputs.assign_vector(
|
||||
function(inputs.get_vector(2 * base, size),
|
||||
inputs.get_vector(2 * base + size, size)), base)
|
||||
inputs.assign_vector(outputs.get_vector(0, left // 2))
|
||||
if left % 2 == 1:
|
||||
inputs[left // 2] = inputs[left - 1]
|
||||
left = (left + 1) // 2
|
||||
return inputs[0]
|
||||
|
||||
def foreach_enumerate(a):
|
||||
""" Run-time loop over public data. This uses
|
||||
``Player-Data/Public-Input/<progname>``. Example:
|
||||
@@ -1511,6 +1549,15 @@ def break_point(name=''):
|
||||
"""
|
||||
get_tape().start_new_basicblock(name=name)
|
||||
|
||||
def check_point():
|
||||
"""
|
||||
Force MAC checks in current thread and all idle threads if the
|
||||
current thread is the main thread. This implies a break point.
|
||||
"""
|
||||
break_point('pre-check')
|
||||
check()
|
||||
break_point('post-check')
|
||||
|
||||
# Fixed point ops
|
||||
|
||||
from math import ceil, log
|
||||
@@ -1566,6 +1613,9 @@ def cint_cint_division(a, b, k, f):
|
||||
# theta can be replaced with something smaller
|
||||
# for safety we assume that is the same theta from previous GS method
|
||||
|
||||
if get_program().options.ring:
|
||||
assert 2 * f < int(get_program().options.ring)
|
||||
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
two = cint(2) * two_power(f)
|
||||
|
||||
@@ -1579,9 +1629,11 @@ def cint_cint_division(a, b, k, f):
|
||||
B = absolute_b
|
||||
W = w0
|
||||
|
||||
for i in range(1, theta):
|
||||
A = (A * W) >> f
|
||||
B = (B * W) >> f
|
||||
corr = cint(1) << (f - 1)
|
||||
|
||||
for i in range(theta):
|
||||
A = (A * W + corr) >> f
|
||||
B = (B * W + corr) >> f
|
||||
W = two - B
|
||||
return (sign_a * sign_b) * A
|
||||
|
||||
@@ -1592,7 +1644,7 @@ def sint_cint_division(a, b, k, f, kappa):
|
||||
"""
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
two = cint(2) * two_power(f)
|
||||
sign_b = cint(1) - 2 * cint(b < 0)
|
||||
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
|
||||
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
|
||||
absolute_b = b * sign_b
|
||||
absolute_a = a * sign_a
|
||||
@@ -1652,7 +1704,8 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
|
||||
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
return y
|
||||
def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
|
||||
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Approximate reciprocal of [b]:
|
||||
Given [b], compute [1/b]
|
||||
@@ -1662,7 +1715,7 @@ def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
|
||||
d = alpha - 2 * c
|
||||
w = d * v
|
||||
w = w.round(2 * k, 2 * (k - f), kappa, nearest, signed=True)
|
||||
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
|
||||
# now w * 2 ^ {-f} should be an initial approximation of 1/b
|
||||
return w
|
||||
|
||||
@@ -1674,7 +1727,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
# For simplex, we can get rid of computing abs(b)
|
||||
temp = None
|
||||
if simplex_flag == False:
|
||||
temp = comparison.LessThanZero(b, 2 * k, kappa)
|
||||
temp = comparison.LessThanZero(b, k, kappa)
|
||||
elif simplex_flag == True:
|
||||
temp = cint(0)
|
||||
|
||||
@@ -1682,7 +1735,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
absolute_val = sign * b
|
||||
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
bits = absolute_val.bit_decompose(k, kappa)[::-1]
|
||||
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
|
||||
suffixes = PreOR(bits, kappa)[::-1]
|
||||
|
||||
z = [0] * k
|
||||
@@ -1690,10 +1743,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
z[i] = suffixes[i] - suffixes[i+1]
|
||||
z[k - 1] = suffixes[k-1]
|
||||
|
||||
#doing complicated stuff to compute v = 2^{k-m}
|
||||
acc = cint(0)
|
||||
for i in range(k):
|
||||
acc += two_power(k-i-1) * z[i]
|
||||
acc = sint.bit_compose(reversed(z))
|
||||
|
||||
part_reciprocal = absolute_val * acc
|
||||
signed_acc = sign * acc
|
||||
|
||||
749
Compiler/ml.py
749
Compiler/ml.py
File diff suppressed because it is too large
Load Diff
@@ -846,3 +846,54 @@ def acos(x):
|
||||
"""
|
||||
y = asin(x)
|
||||
return pi_over_2 - y
|
||||
|
||||
|
||||
def tanh(x):
|
||||
"""
|
||||
Hyperbolic tangent. For efficiency, accuracy is diminished
|
||||
around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and
|
||||
:math:`f` denote the fixed-point parameters.
|
||||
"""
|
||||
limit = math.log(2 ** (x.k - x.f - 2)) / 2
|
||||
s = x < -limit
|
||||
t = x > limit
|
||||
y = pow_fx(math.e, 2 * x)
|
||||
return s.if_else(-1, t.if_else(1, (y - 1) / (y + 1)))
|
||||
|
||||
|
||||
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
|
||||
|
||||
def Sep(x):
|
||||
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
|
||||
t = x.v * (1 + x.v.bit_compose(b_i.bit_not() for b_i in b[-2 * x.f + 1:]))
|
||||
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
|
||||
b += [b[0].long_one()]
|
||||
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
|
||||
|
||||
def SqrtComp(z, old=False):
|
||||
f = types.sfix.f
|
||||
k = len(z)
|
||||
if isinstance(z[0], types.sint):
|
||||
return types.sfix._new(sum(z[i] * types.cfix(
|
||||
2 ** (-(i - f + 1) / 2)).v for i in range(k)))
|
||||
k_prime = k // 2
|
||||
f_prime = f // 2
|
||||
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
|
||||
c0 = types.sfix(2 ** (f / 2 + 1))
|
||||
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
|
||||
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
|
||||
if old:
|
||||
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
|
||||
else:
|
||||
b = util.tree_reduce(lambda x, y: x.bit_xor(y), z[::2])
|
||||
return types.sint.conv(b).if_else(c1, c0) * tmp
|
||||
|
||||
@types.vectorize
|
||||
def InvertSqrt(x, old=False):
|
||||
"""
|
||||
Reciprocal square root approximation by `Lu et al.
|
||||
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
|
||||
"""
|
||||
u, z = Sep(x)
|
||||
c = 3.14736 + u * (4.63887 * u - 5.77789)
|
||||
return c * SqrtComp(z, old=old)
|
||||
|
||||
@@ -44,7 +44,7 @@ class Masking(NonLinear):
|
||||
d = [None]*k
|
||||
for i,b in enumerate(r[0].bit_decompose_clear(c, k)):
|
||||
d[i] = r[i].bit_xor(b)
|
||||
return 1 - types.sint.conv(self.kor(d))
|
||||
return 1 - types.sintbit.conv(self.kor(d))
|
||||
|
||||
class Prime(Masking):
|
||||
""" Non-linear functionality modulo a prime with statistical masking. """
|
||||
@@ -71,8 +71,11 @@ class Prime(Masking):
|
||||
def _trunc_pr(self, a, k, m, signed=None):
|
||||
return TruncPrField(a, k, m, self.kappa)
|
||||
|
||||
def bit_dec(self, a, k, m):
|
||||
return BitDecField(a, k, m, self.kappa)
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
if maybe_mixed:
|
||||
return BitDecFieldRaw(a, k, m, self.kappa)
|
||||
else:
|
||||
return BitDecField(a, k, m, self.kappa)
|
||||
|
||||
def kor(self, d):
|
||||
return KOR(d, self.kappa)
|
||||
@@ -85,7 +88,7 @@ class KnownPrime(NonLinear):
|
||||
def _mod2m(self, a, k, m, signed):
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
return sint.bit_compose(self.bit_dec(a, k, k)[:m])
|
||||
return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
|
||||
|
||||
def _trunc_pr(self, a, k, m, signed):
|
||||
# nearest truncation
|
||||
@@ -96,14 +99,14 @@ class KnownPrime(NonLinear):
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
k += 1
|
||||
res = sint.bit_compose(self.bit_dec(a, k, k)[m:])
|
||||
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
|
||||
if signed:
|
||||
res -= cint(1) << (k - m - 2)
|
||||
return res
|
||||
|
||||
def bit_dec(self, a, k, m):
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
assert k < self.prime.bit_length()
|
||||
bits = BitDecFull(a)
|
||||
bits = BitDecFull(a, maybe_mixed=maybe_mixed)
|
||||
if len(bits) < m:
|
||||
raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
|
||||
return bits[:m]
|
||||
@@ -111,7 +114,7 @@ class KnownPrime(NonLinear):
|
||||
def eqz(self, a, k):
|
||||
# always signed
|
||||
a += two_power(k)
|
||||
return 1 - KORL(self.bit_dec(a, k, k))
|
||||
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
|
||||
|
||||
class Ring(Masking):
|
||||
""" Non-linear functionality modulo a power of two known at compile time.
|
||||
@@ -130,8 +133,11 @@ class Ring(Masking):
|
||||
def _trunc_pr(self, a, k, m, signed):
|
||||
return TruncPrRing(a, k, m, signed=signed)
|
||||
|
||||
def bit_dec(self, a, k, m):
|
||||
return BitDecRing(a, k, m)
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
if maybe_mixed:
|
||||
return BitDecRingRaw(a, k, m)
|
||||
else:
|
||||
return BitDecRing(a, k, m)
|
||||
|
||||
def kor(self, d):
|
||||
return KORL(d)
|
||||
|
||||
@@ -28,9 +28,7 @@ data_types = dict(
|
||||
square = 1,
|
||||
bit = 2,
|
||||
inverse = 3,
|
||||
bittriple = 4,
|
||||
bitgf2ntriple = 5,
|
||||
dabit = 6,
|
||||
dabit = 4,
|
||||
)
|
||||
|
||||
field_types = dict(
|
||||
@@ -62,6 +60,7 @@ class defaults:
|
||||
asmoutfile = None
|
||||
stop = False
|
||||
insecure = False
|
||||
keep_cisc = False
|
||||
|
||||
class Program(object):
|
||||
""" A program consists of a list of tapes representing the whole
|
||||
@@ -80,14 +79,14 @@ class Program(object):
|
||||
self.init_names(args)
|
||||
self._security = 40
|
||||
self.prime = None
|
||||
self.tapes = []
|
||||
if sum(x != 0 for x in(options.ring, options.field,
|
||||
options.binary)) > 1:
|
||||
raise CompilerError('can only use one out of -B, -R, -F')
|
||||
if options.prime and (options.ring or options.binary):
|
||||
raise CompilerError('can only use one out of -B, -R, -p')
|
||||
if options.ring:
|
||||
self.bit_length = int(options.ring) - 1
|
||||
self.non_linear = Ring(int(options.ring))
|
||||
self.set_ring_size(int(options.ring))
|
||||
else:
|
||||
self.bit_length = int(options.binary) or int(options.field)
|
||||
if options.prime:
|
||||
@@ -108,7 +107,6 @@ class Program(object):
|
||||
if self.verbose:
|
||||
print('Galois length:', self.galois_length)
|
||||
self.tape_counter = 0
|
||||
self.tapes = []
|
||||
self._curr_tape = None
|
||||
self.DEBUG = options.debug
|
||||
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
|
||||
@@ -204,6 +202,16 @@ class Program(object):
|
||||
for arg in args[1:])
|
||||
self.progname = progname
|
||||
|
||||
def set_ring_size(self, ring_size):
|
||||
from .non_linear import Ring
|
||||
for tape in self.tapes:
|
||||
prev = tape.req_bit_length['p']
|
||||
if prev and prev != ring_size:
|
||||
raise CompilerError('cannot have different ring sizes')
|
||||
self.bit_length = ring_size - 1
|
||||
self.non_linear = Ring(ring_size)
|
||||
self.options.ring = str(ring_size)
|
||||
|
||||
def new_tape(self, function, args=[], name=None, single_thread=False):
|
||||
"""
|
||||
Create a new tape from a function. See
|
||||
@@ -414,7 +422,7 @@ class Program(object):
|
||||
self.curr_tape.start_new_basicblock(None, 'memory-usage')
|
||||
# reset register counter to 0
|
||||
self.curr_tape.init_registers()
|
||||
for mem_type,size in list(self.allocated_mem.items()):
|
||||
for mem_type,size in sorted(self.allocated_mem.items()):
|
||||
if size:
|
||||
#print "Memory of type '%s' of size %d" % (mem_type, size)
|
||||
if mem_type in self.types:
|
||||
@@ -488,7 +496,7 @@ class Program(object):
|
||||
else:
|
||||
if change and not self.options.ring:
|
||||
raise CompilerError('splitting only supported for rings')
|
||||
assert change > 1
|
||||
assert change > 1 or change == False
|
||||
self._split = change
|
||||
|
||||
def use_square(self, change=None):
|
||||
@@ -575,7 +583,7 @@ class Tape:
|
||||
scope.children.append(self)
|
||||
self.alloc_pool = scope.alloc_pool
|
||||
else:
|
||||
self.alloc_pool = defaultdict(set)
|
||||
self.alloc_pool = defaultdict(list)
|
||||
self.purged = False
|
||||
self.n_rounds = 0
|
||||
self.n_to_merge = 0
|
||||
@@ -647,9 +655,14 @@ class Tape:
|
||||
|
||||
def expand_cisc(self):
|
||||
new_instructions = []
|
||||
if self.parent.program.options.keep_cisc:
|
||||
skip = ['LTZ', 'Trunc']
|
||||
else:
|
||||
skip = []
|
||||
for inst in self.instructions:
|
||||
new_instructions.extend(inst.expand_merged())
|
||||
self.n_rounds += inst.expanded_rounds()
|
||||
new_inst, n_rounds = inst.expand_merged(skip)
|
||||
new_instructions.extend(new_inst)
|
||||
self.n_rounds += n_rounds
|
||||
self.instructions = new_instructions
|
||||
|
||||
def __str__(self):
|
||||
@@ -774,7 +787,10 @@ class Tape:
|
||||
|
||||
# allocate registers
|
||||
reg_counts = self.count_regs()
|
||||
if not options.noreallocate:
|
||||
if options.noreallocate:
|
||||
if self.program.verbose:
|
||||
print('Tape register usage:', dict(reg_counts))
|
||||
else:
|
||||
if self.program.verbose:
|
||||
print('Tape register usage before re-allocation:',
|
||||
dict(reg_counts))
|
||||
@@ -1071,7 +1087,7 @@ class Tape:
|
||||
if size is None:
|
||||
size = Compiler.instructions_base.get_global_vector_size()
|
||||
if size is not None and size > self.maximum_size:
|
||||
raise CompilerError('vector too large')
|
||||
raise CompilerError('vector too large: %d' % size)
|
||||
self.size = size
|
||||
self.vectorbase = self
|
||||
self.relative_i = 0
|
||||
|
||||
@@ -591,12 +591,12 @@ class _register(Tape.Register, _number, _structure):
|
||||
def prep_res(cls, other):
|
||||
return cls()
|
||||
|
||||
@staticmethod
|
||||
def bit_compose(bits):
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
""" Compose value from bits.
|
||||
|
||||
:param bits: iterable of any type implementing left shift """
|
||||
return sum(b << i for i,b in enumerate(bits))
|
||||
return sum(cls.conv(b) << i for i,b in enumerate(bits))
|
||||
|
||||
@classmethod
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
@@ -840,6 +840,7 @@ class cint(_clear, _int):
|
||||
def in_immediate_range(value):
|
||||
return value < 2**31 and value >= -2**31
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, val=None, size=None):
|
||||
"""
|
||||
:param val: initialization (cint/regint/int/cgf2n or list thereof)
|
||||
@@ -1119,12 +1120,6 @@ class cgf2n(_clear, _gf2n):
|
||||
elif chunk:
|
||||
sum += chunk
|
||||
|
||||
def __mul__(self, other):
|
||||
""" Clear :math:`\mathrm{GF}(2^n)` multiplication.
|
||||
|
||||
:param other: cgf2n/regint/int """
|
||||
return super(cgf2n, self).__mul__(other)
|
||||
|
||||
def __neg__(self):
|
||||
""" Identity. """
|
||||
return self
|
||||
@@ -1209,7 +1204,9 @@ class regint(_register, _int):
|
||||
def get_random(cls, bit_length):
|
||||
""" Public insecure randomness.
|
||||
|
||||
:param bit_length: number of bits (int) """
|
||||
:param bit_length: number of bits (int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
if isinstance(bit_length, int):
|
||||
bit_length = regint(bit_length)
|
||||
res = cls()
|
||||
@@ -1582,7 +1579,9 @@ class _secret(_register):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret input from player.
|
||||
|
||||
:param player: public (regint/cint/int) """
|
||||
:param player: public (regint/cint/int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = cls()
|
||||
asm_input(res, player)
|
||||
return res
|
||||
@@ -1592,7 +1591,9 @@ class _secret(_register):
|
||||
def get_random_triple(cls):
|
||||
""" Secret random triple according to security model.
|
||||
|
||||
:return: :math:`(a, b, ab)` """
|
||||
:return: :math:`(a, b, ab)`
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = (cls(), cls(), cls())
|
||||
triple(*res)
|
||||
return res
|
||||
@@ -1602,7 +1603,9 @@ class _secret(_register):
|
||||
def get_random_bit(cls):
|
||||
""" Secret random bit according to security model.
|
||||
|
||||
:return: 0/1 50-50 """
|
||||
:return: 0/1 50-50
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = cls()
|
||||
bit(res)
|
||||
return res
|
||||
@@ -1612,7 +1615,9 @@ class _secret(_register):
|
||||
def get_random_square(cls):
|
||||
""" Secret random square according to security model.
|
||||
|
||||
:return: :math:`(a, a^2)` """
|
||||
:return: :math:`(a, a^2)`
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = (cls(), cls())
|
||||
square(*res)
|
||||
return res
|
||||
@@ -1622,7 +1627,9 @@ class _secret(_register):
|
||||
def get_random_inverse(cls):
|
||||
""" Secret random inverse tuple according to security model.
|
||||
|
||||
:return: :math:`(a, a^{-1})` """
|
||||
:return: :math:`(a, a^{-1})`
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = (cls(), cls())
|
||||
inverse(*res)
|
||||
return res
|
||||
@@ -1717,16 +1724,51 @@ class _secret(_register):
|
||||
else:
|
||||
self.load_clear(self.clear_type(val))
|
||||
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
""" Compose value from bits.
|
||||
|
||||
:param bits: iterable of any type convertible to sint """
|
||||
from Compiler.GC.types import sbits, sbitintvec
|
||||
bits = list(bits)
|
||||
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
|
||||
if program.use_edabit():
|
||||
mask = cls.get_edabit(len(bits), strict=True, size=bits[0].n)
|
||||
else:
|
||||
tmp = sint(size=bits[0].n)
|
||||
randoms(tmp, len(bits))
|
||||
n_overflow_bits = min(program.use_split().bit_length(),
|
||||
int(program.options.ring) - len(bits))
|
||||
mask_bits = tmp.bit_decompose(len(bits) + n_overflow_bits,
|
||||
maybe_mixed=True)
|
||||
if n_overflow_bits:
|
||||
overflow = sint.bit_compose(
|
||||
sint.conv(x) for x in mask_bits[-n_overflow_bits:])
|
||||
mask = tmp - (overflow << len(bits)), \
|
||||
mask_bits[:-n_overflow_bits]
|
||||
else:
|
||||
mask = tmp, mask_bits
|
||||
t = sbitintvec.get_type(len(bits) + 1)
|
||||
masked = t.from_vec(mask[1] + [0]) + t.from_vec(bits + [0])
|
||||
overflow = masked.v[-1]
|
||||
masked = cls.bit_compose(x.reveal().to_regint_by_bit() for x in masked.v[:-1])
|
||||
return masked - mask[0] + (cls(overflow) << len(bits))
|
||||
else:
|
||||
return super(_secret, cls).bit_compose(bits)
|
||||
|
||||
@set_instruction_type
|
||||
@read_mem_value
|
||||
@vectorize
|
||||
def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False):
|
||||
cls = self.__class__
|
||||
res = self.prep_res(other)
|
||||
cls = type(res)
|
||||
if isinstance(other, regint):
|
||||
other = res.clear_type(other)
|
||||
if isinstance(other, cls):
|
||||
s_inst(res, self, other)
|
||||
if reverse:
|
||||
s_inst(res, other, self)
|
||||
else:
|
||||
s_inst(res, self, other)
|
||||
elif isinstance(other, res.clear_type):
|
||||
if reverse:
|
||||
m_inst(res, other, self)
|
||||
@@ -1861,10 +1903,12 @@ class sint(_secret, _int):
|
||||
def get_random_int(cls, bits):
|
||||
""" Secret random n-bit number according to security model.
|
||||
|
||||
:param bits: compile-time integer (int) """
|
||||
:param bits: compile-time integer (int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
if program.use_edabit():
|
||||
return sint.get_edabit(bits, True)[0]
|
||||
elif program.use_split() > 2:
|
||||
elif program.use_split() > 2 and program.use_split() < 5:
|
||||
tmp = sint()
|
||||
randoms(tmp, bits)
|
||||
x = tmp.split_to_two_summands(bits, True)
|
||||
@@ -1882,7 +1926,10 @@ class sint(_secret, _int):
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_random(cls):
|
||||
""" Secret random ring element according to security model. """
|
||||
""" Secret random ring element according to security model.
|
||||
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = sint()
|
||||
randomfulls(res)
|
||||
return res
|
||||
@@ -1891,7 +1938,9 @@ class sint(_secret, _int):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret input.
|
||||
|
||||
:param player: public (regint/cint/int) """
|
||||
:param player: public (regint/cint/int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
res = cls()
|
||||
inputmixed('int', res, player)
|
||||
return res
|
||||
@@ -1915,7 +1964,7 @@ class sint(_secret, _int):
|
||||
else:
|
||||
a = [sint.get_random_bit() for i in range(n_bits)]
|
||||
return sint.bit_compose(a), a
|
||||
program.curr_tape.require_bit_length(n_bits)
|
||||
program.curr_tape.require_bit_length(n_bits - 1)
|
||||
whole = cls()
|
||||
size = get_global_vector_size()
|
||||
from Compiler.GC.types import sbits, sbitvec
|
||||
@@ -1931,6 +1980,7 @@ class sint(_secret, _int):
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
@vectorize
|
||||
def bit_decompose_clear(a, n_bits):
|
||||
return floatingpoint.bits(a, n_bits)
|
||||
|
||||
@@ -2055,7 +2105,7 @@ class sint(_secret, _int):
|
||||
|
||||
:param other: sint/cint/regint/int
|
||||
:return: 0/1 (sint) """
|
||||
res = sint()
|
||||
res = sintbit()
|
||||
comparison.LTZ(res, self - other,
|
||||
(bit_length or program.bit_length) + 1,
|
||||
security or program.security)
|
||||
@@ -2064,7 +2114,7 @@ class sint(_secret, _int):
|
||||
@read_mem_value
|
||||
@vectorize
|
||||
def __gt__(self, other, bit_length=None, security=None):
|
||||
res = sint()
|
||||
res = sintbit()
|
||||
comparison.LTZ(res, other - self,
|
||||
(bit_length or program.bit_length) + 1,
|
||||
security or program.security)
|
||||
@@ -2185,13 +2235,14 @@ class sint(_secret, _int):
|
||||
return floatingpoint.Trunc(other, program.bit_length, self, program.security)
|
||||
|
||||
@vectorize
|
||||
def bit_decompose(self, bit_length=None, security=None):
|
||||
def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False):
|
||||
""" Secret bit decomposition. """
|
||||
if bit_length == 0:
|
||||
return []
|
||||
bit_length = bit_length or program.bit_length
|
||||
security = security or program.security
|
||||
return floatingpoint.BitDec(self, bit_length, bit_length, security)
|
||||
assert program.security == security or program.security
|
||||
return program.non_linear.bit_dec(self, bit_length, bit_length,
|
||||
maybe_mixed)
|
||||
|
||||
def TruncMul(self, other, k, m, kappa=None, nearest=False):
|
||||
return (self * other).round(k, m, kappa, nearest, signed=True)
|
||||
@@ -2249,6 +2300,7 @@ class sint(_secret, _int):
|
||||
return floatingpoint.two_power(n)
|
||||
|
||||
def split_to_n_summands(self, length, n):
|
||||
comparison.require_ring_size(length, 'splitting')
|
||||
from .GC.types import sbits
|
||||
from .GC.instructions import split
|
||||
columns = [[sbits.get_type(self.size)()
|
||||
@@ -2274,7 +2326,9 @@ class sint(_secret, _int):
|
||||
@vectorize
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
Result potentially written to ``Player-Data/Private-Output-P<player>.``
|
||||
Result potentially written to
|
||||
``Player-Data/Private-Output-P<player>``, but not if
|
||||
:py:obj:`player` is a :py:class:`regint`.
|
||||
|
||||
:param player: public integer (int/regint/cint):
|
||||
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
|
||||
@@ -2288,6 +2342,65 @@ class sint(_secret, _int):
|
||||
else:
|
||||
return super(sint, self).reveal_to(player)
|
||||
|
||||
class sintbit(sint):
|
||||
@classmethod
|
||||
def prep_res(cls, other):
|
||||
return sint()
|
||||
|
||||
def load_other(self, other):
|
||||
if isinstance(other, sint):
|
||||
movs(self, other)
|
||||
else:
|
||||
super(sintbit, self).load_other(other)
|
||||
|
||||
@vectorize
|
||||
def __and__(self, other):
|
||||
if isinstance(other, sintbit):
|
||||
res = sintbit()
|
||||
muls(res, self, other)
|
||||
return res
|
||||
elif util.is_zero(other):
|
||||
return 0
|
||||
elif util.is_one(other):
|
||||
return self
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def __or__(self, other):
|
||||
if isinstance(other, sintbit):
|
||||
res = sintbit()
|
||||
adds(res, self, other - self * other)
|
||||
return res
|
||||
elif util.is_zero(other):
|
||||
return self
|
||||
elif util.is_one(other):
|
||||
return 1
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def __xor__(self, other):
|
||||
if isinstance(other, sintbit):
|
||||
res = sintbit()
|
||||
adds(res, self, other - 2 * self * other)
|
||||
return res
|
||||
elif util.is_zero(other):
|
||||
return self
|
||||
elif util.is_one(other):
|
||||
return 1
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def __rsub__(self, other):
|
||||
if util.is_one(other):
|
||||
res = sintbit()
|
||||
subsfi(res, self, 1)
|
||||
return res
|
||||
else:
|
||||
return super(sintbit, self).__rsub__(other)
|
||||
|
||||
class sgf2n(_secret, _gf2n):
|
||||
""" Secret :math:`\mathrm{GF}(2^n)` value. """
|
||||
__slots__ = []
|
||||
@@ -2437,10 +2550,11 @@ class sgf2n(_secret, _gf2n):
|
||||
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
|
||||
|
||||
for t in (sint, sgf2n):
|
||||
t.bit_type = t
|
||||
t.basic_type = t
|
||||
t.default_type = t
|
||||
|
||||
sint.bit_type = sintbit
|
||||
sgf2n.bit_type = sgf2n
|
||||
|
||||
class _bitint(object):
|
||||
bits = None
|
||||
@@ -3046,14 +3160,17 @@ class cfix(_number, _structure):
|
||||
|
||||
@staticmethod
|
||||
def int_rep(v, f, k=None):
|
||||
if isinstance(v, regint):
|
||||
v = cint(v)
|
||||
res = v * (2 ** f)
|
||||
try:
|
||||
res = int(round(res))
|
||||
if k and abs(res) >= 2 ** k:
|
||||
if k and res >= 2 ** (k - 1) or res < -2 ** (k - 1):
|
||||
limit = 2 ** (k - f - 1)
|
||||
raise CompilerError(
|
||||
'Value out of fixed-point range (maximum %d). '
|
||||
'Value out of fixed-point range [-%d, %d). '
|
||||
'Use `sfix.set_precision(f, k)` with k being at least f+%d'
|
||||
% (2 ** (k - f), math.ceil(math.log(abs(v), 2)) + 1))
|
||||
% (limit, limit, res.bit_length() - f + 1))
|
||||
except TypeError:
|
||||
pass
|
||||
return res
|
||||
@@ -3268,6 +3385,14 @@ class cfix(_number, _structure):
|
||||
else:
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
|
||||
@vectorize
|
||||
def __rtruediv__(self, other):
|
||||
""" Fixed-point division.
|
||||
|
||||
:param other: sfix/sint/cfix/cint/regint/int """
|
||||
other = parse_type(other, self.k, self.f)
|
||||
return other / self
|
||||
|
||||
def print_plain(self):
|
||||
""" Clear fixed-point output. """
|
||||
print_float_plain(cint.conv(self.v), cint(-self.f), \
|
||||
@@ -3468,7 +3593,7 @@ class _fix(_single):
|
||||
set_precision = classmethod(set_precision)
|
||||
|
||||
@classmethod
|
||||
def set_precision_from_args(cls, program):
|
||||
def set_precision_from_args(cls, program, adapt_ring=False):
|
||||
f = None
|
||||
k = None
|
||||
for arg in program.args:
|
||||
@@ -3484,6 +3609,15 @@ class _fix(_single):
|
||||
cfix.set_precision(f, k)
|
||||
elif k is not None:
|
||||
raise CompilerError('need to set fractional precision')
|
||||
if 'nearest' in program.args:
|
||||
print('Nearest rounding instead of proabilistic '
|
||||
'for fixed-point computation')
|
||||
cls.round_nearest = True
|
||||
if adapt_ring and program.options.ring:
|
||||
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
|
||||
if need != int(program.options.ring):
|
||||
print('Changing computation modulus to 2^%d' % need)
|
||||
program.set_ring_size(need)
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
@@ -3609,11 +3743,14 @@ class _fix(_single):
|
||||
:param other: sfix/cfix/sint/cint/regint/int """
|
||||
if util.is_constant_float(other):
|
||||
assert other != 0
|
||||
other_length = self.f + math.ceil(math.log(abs(other), 2))
|
||||
if other_length >= self.k:
|
||||
factor = 2 ** (self.k - other_length - 1)
|
||||
log = math.ceil(math.log(abs(other), 2))
|
||||
other_length = self.f + log
|
||||
if other_length >= self.k - 1:
|
||||
factor = 2 ** (self.k - other_length - 2)
|
||||
self *= factor
|
||||
other *= factor
|
||||
if 2 ** log == other:
|
||||
return self * 2 ** -log
|
||||
other = self.coerce(other)
|
||||
assert self.k == other.k
|
||||
assert self.f == other.f
|
||||
@@ -3660,7 +3797,9 @@ class sfix(_fix):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret fixed-point input.
|
||||
|
||||
:param player: public (regint/cint/int) """
|
||||
:param player: public (regint/cint/int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
cls.int_type.require_bit_length(cls.k)
|
||||
v = cls.int_type()
|
||||
inputmixed('fix', v, cls.f, player)
|
||||
@@ -3677,6 +3816,7 @@ class sfix(_fix):
|
||||
|
||||
:param lower: float
|
||||
:param upper: float
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
log_range = int(math.log(upper - lower, 2))
|
||||
n_bits = log_range + cls.f
|
||||
@@ -3732,7 +3872,8 @@ class sfix(_fix):
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
Raw representation possibly written to
|
||||
``Player-Data/Private-Output-P<player>.``
|
||||
``Player-Data/Private-Output-P<player>``, but not if
|
||||
:py:obj:`player` is a :py:class:`regint`.
|
||||
|
||||
:param player: public integer (int/regint/cint)
|
||||
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
|
||||
@@ -4066,7 +4207,9 @@ class sfloat(_number, _structure):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret floating-point input.
|
||||
|
||||
:param player: public (regint/cint/int) """
|
||||
:param player: public (regint/cint/int)
|
||||
:param size: vector size (int, default 1)
|
||||
"""
|
||||
v = sint()
|
||||
p = sint()
|
||||
z = sint()
|
||||
@@ -4444,6 +4587,7 @@ class Array(object):
|
||||
self.address_cache = {}
|
||||
self.debug = debug
|
||||
self.creator_tape = program.curr_tape
|
||||
self.sink = None
|
||||
if alloc:
|
||||
self.alloc()
|
||||
|
||||
@@ -4514,6 +4658,17 @@ class Array(object):
|
||||
return
|
||||
self._store(value, self.get_address(index))
|
||||
|
||||
def maybe_get(self, condition, index):
|
||||
return condition * self[condition * index]
|
||||
|
||||
def maybe_set(self, condition, index, value):
|
||||
if self.sink is None:
|
||||
self.sink = self.value_type.Array(1)
|
||||
addresses = (condition.if_else(x, y) for x, y in
|
||||
zip(util.tuplify(self.get_address(index)),
|
||||
util.tuplify(self.sink.get_address(0))))
|
||||
self._store(value, util.untuplify(tuple(addresses)))
|
||||
|
||||
# the following two are useful for compile-time lengths
|
||||
# and thus differ from the usual Python syntax
|
||||
def get_range(self, start, size):
|
||||
@@ -4590,11 +4745,22 @@ class Array(object):
|
||||
|
||||
get_part_vector = get_vector
|
||||
|
||||
def get_part(self, base, size):
|
||||
return Array(size, self.value_type, self.get_address(base))
|
||||
|
||||
def get(self, indices):
|
||||
return self.value_type.load_mem(
|
||||
regint.inc(len(indices), self.address, 0) + indices,
|
||||
size=len(indices))
|
||||
|
||||
def get_slice_vector(self, slice):
|
||||
assert self.value_type.n_elements() == 1
|
||||
assert len(slice) <= self.total_size()
|
||||
base = regint.inc(len(slice), slice.address, 1, 1)
|
||||
inc = regint.inc(len(slice), 0, 1, 1, 1)
|
||||
addresses = slice.value_type.load_mem(base) + inc
|
||||
return self.value_type.load_mem(self.address + addresses)
|
||||
|
||||
def expand_to_vector(self, index, size):
|
||||
assert self.value_type.n_elements() == 1
|
||||
address = regint(size=size)
|
||||
@@ -4641,6 +4807,12 @@ class Array(object):
|
||||
:param other: vector or container of same length and type that supports operations with type of this array """
|
||||
return self.get_vector() * value
|
||||
|
||||
def __truediv__(self, value):
|
||||
""" Vector division.
|
||||
|
||||
:param other: vector or container of same length and type that supports operations with type of this array """
|
||||
return self.get_vector() / value
|
||||
|
||||
def __pow__(self, value):
|
||||
""" Vector power-of computation.
|
||||
|
||||
@@ -4674,6 +4846,16 @@ class Array(object):
|
||||
|
||||
reveal_nested = reveal_list
|
||||
|
||||
def sort(self, n_threads=None):
|
||||
"""
|
||||
Sort in place using Batchers' odd-even merge mergesort
|
||||
with complexity :math:`O(n (\log n)^2)`.
|
||||
|
||||
:param n_threads: number of threads to use (single thread by
|
||||
default)
|
||||
"""
|
||||
library.loopy_odd_even_merge_sort(self, n_threads=n_threads)
|
||||
|
||||
def __str__(self):
|
||||
return '%s array of length %s at %s' % (self.value_type, len(self),
|
||||
self.address)
|
||||
@@ -4784,6 +4966,15 @@ class SubMultiArray(object):
|
||||
assert vector.size <= self.total_size()
|
||||
vector.store_in_mem(self.address + base * part_size)
|
||||
|
||||
def get_slice_vector(self, slice):
|
||||
assert self.value_type.n_elements() == 1
|
||||
part_size = reduce(operator.mul, self.sizes[1:])
|
||||
assert len(slice) * part_size <= self.total_size()
|
||||
base = regint.inc(len(slice) * part_size, slice.address, 1, part_size)
|
||||
inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size)
|
||||
addresses = slice.value_type.load_mem(base) * part_size + inc
|
||||
return self.value_type.load_mem(self.address + addresses)
|
||||
|
||||
def get_addresses(self, *indices):
|
||||
assert self.value_type.n_elements() == 1
|
||||
assert len(indices) == len(self.sizes)
|
||||
@@ -4816,6 +5007,10 @@ class SubMultiArray(object):
|
||||
""" :return: new multidimensional array with same shape and basic type """
|
||||
return MultiArray(self.sizes, self.value_type)
|
||||
|
||||
def get_part(self, start, size):
|
||||
return MultiArray([size] + list(self.sizes[1:]), self.value_type,
|
||||
address=self[start].address)
|
||||
|
||||
def input_from(self, player, budget=None, raw=False):
|
||||
""" Fill with inputs from player if supported by type.
|
||||
|
||||
@@ -4978,7 +5173,7 @@ class SubMultiArray(object):
|
||||
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
|
||||
assert len(indices[1]) == len(indices[2])
|
||||
indices = list(indices)
|
||||
indices[3] *= other.sizes[0]
|
||||
indices[3] *= other.sizes[1]
|
||||
return self.value_type.direct_matrix_mul(
|
||||
self.address, other.address, None, self.sizes[1], 1,
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
@@ -195,7 +195,7 @@ def is_all_ones(x, n):
|
||||
else:
|
||||
return False
|
||||
|
||||
def max(x, y=None):
|
||||
def max(x, y=None, n_threads=None):
|
||||
if y is None:
|
||||
return tree_reduce(max, x)
|
||||
else:
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "Networking/CryptoPlayer.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "ECDSA/P256Element.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "ECDSA/preprocessing.hpp"
|
||||
#include "ECDSA/sign.hpp"
|
||||
@@ -20,6 +21,8 @@
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "OT/NPartyTripleGenerator.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include "Protocols/MaliciousShamirShare.h"
|
||||
#include "Protocols/Rep3Share.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinierPrep.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
|
||||
@@ -128,16 +127,4 @@ void check(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
MC.Check(P);
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedPrep<Rep3Share<P256Element::Scalar>>::buffer_bits()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
template<>
|
||||
void ReplicatedPrep<ShamirShare<P256Element::Scalar>>::buffer_bits()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
#endif /* ECDSA_PREPROCESSING_HPP_ */
|
||||
|
||||
@@ -149,11 +149,6 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
bool is_binary() const
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
size_t report_size(ReportType type)
|
||||
{
|
||||
size_t res = 4;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "AddableVector.h"
|
||||
#include "Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
#include "P2Data.h"
|
||||
|
||||
template<class T>
|
||||
AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
|
||||
@@ -33,7 +34,3 @@ AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template
|
||||
AddableVector<Int_Random_Coins::rand_type> AddableVector<
|
||||
Int_Random_Coins::rand_type>::mul_by_X_i(int j, const FHE_PK& pk) const;
|
||||
@@ -23,8 +23,6 @@ class Ciphertext
|
||||
word pk_id;
|
||||
|
||||
public:
|
||||
static string type_string() { return "ciphertext"; }
|
||||
static int t() { return 0; }
|
||||
static int size() { return 0; }
|
||||
|
||||
const FHE_Params& get_params() const { return *params; }
|
||||
@@ -41,8 +39,6 @@ class Ciphertext
|
||||
set(a0, a1, C.get_pk_id());
|
||||
}
|
||||
|
||||
~Ciphertext() { ; }
|
||||
|
||||
// Rely on default copy assignment/constructor
|
||||
|
||||
word get_pk_id() const { return pk_id; }
|
||||
|
||||
@@ -32,52 +32,6 @@ int DiscreteGauss::sample(PRNG &G, int stretch) const
|
||||
|
||||
|
||||
|
||||
void RandomVectors::set(int nn,int hh,double R)
|
||||
{
|
||||
n=nn;
|
||||
h=hh;
|
||||
DG.set(R);
|
||||
}
|
||||
|
||||
void RandomVectors::set_n(int nn)
|
||||
{
|
||||
n = nn;
|
||||
}
|
||||
|
||||
vector<bigint> RandomVectors::sample_Gauss(PRNG& G, int stretch) const
|
||||
{
|
||||
vector<bigint> ans(n);
|
||||
for (int i=0; i<n; i++)
|
||||
{ ans[i]=DG.sample(G, stretch); }
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
vector<bigint> RandomVectors::sample_Hwt(PRNG& G) const
|
||||
{
|
||||
if (h > n/2 or h <= 0) { return sample_Gauss(G); }
|
||||
vector<bigint> ans(n);
|
||||
for (int i=0; i<n; i++) { ans[i]=0; }
|
||||
int cnt=0,j=0;
|
||||
unsigned char ch=0;
|
||||
while (cnt<h)
|
||||
{ unsigned int i=G.get_uint()%n;
|
||||
if (ans[i]==0)
|
||||
{ cnt++;
|
||||
if (j==0)
|
||||
{ j=8;
|
||||
ch=G.get_uchar();
|
||||
}
|
||||
int v=ch&1; j--;
|
||||
if (v==0) { ans[i]=-1; }
|
||||
else { ans[i]=1; }
|
||||
}
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
int sample_half(PRNG& G)
|
||||
{
|
||||
@@ -91,36 +45,6 @@ int sample_half(PRNG& G)
|
||||
}
|
||||
|
||||
|
||||
vector<bigint> RandomVectors::sample_Half(PRNG& G) const
|
||||
{
|
||||
vector<bigint> ans(n);
|
||||
for (int i=0; i<n; i++)
|
||||
ans[i] = sample_half(G);
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
vector<bigint> RandomVectors::sample_Uniform(PRNG& G,const bigint& B) const
|
||||
{
|
||||
vector<bigint> ans(n);
|
||||
bigint v;
|
||||
for (int i=0; i<n; i++)
|
||||
{ G.get_bigint(v, numBits(B));
|
||||
int bit=G.get_uint()&1;
|
||||
if (bit==0) { ans[i]=v; }
|
||||
else { ans[i]=-v; }
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
bool RandomVectors::operator!=(const RandomVectors& other) const
|
||||
{
|
||||
if (n != other.n or h != other.h or DG != other.DG)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
bool DiscreteGauss::operator!=(const DiscreteGauss& other) const
|
||||
{
|
||||
if (other.NewHopeB != NewHopeB)
|
||||
|
||||
@@ -25,7 +25,6 @@ class DiscreteGauss
|
||||
void unpack(octetStream& o) { o.unserialize(NewHopeB); }
|
||||
|
||||
DiscreteGauss(double R) { set(R); }
|
||||
~DiscreteGauss() { ; }
|
||||
|
||||
// Rely on default copy constructor/assignment
|
||||
|
||||
@@ -36,50 +35,6 @@ class DiscreteGauss
|
||||
bool operator!=(const DiscreteGauss& other) const;
|
||||
};
|
||||
|
||||
|
||||
/* Sample from integer lattice of dimension n
|
||||
* with standard deviation R
|
||||
*/
|
||||
class RandomVectors
|
||||
{
|
||||
int n,h;
|
||||
DiscreteGauss DG; // This generates the main distribution
|
||||
|
||||
public:
|
||||
|
||||
void set(int nn,int hh,double R); // R is input STANDARD DEVIATION
|
||||
void set_n(int nn);
|
||||
|
||||
void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); }
|
||||
void unpack(octetStream& o)
|
||||
{ o.get(n); o.get(h); DG.unpack(o); }
|
||||
|
||||
RandomVectors(int h, double R) : RandomVectors(0, h, R) {}
|
||||
RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); }
|
||||
~RandomVectors() { ; }
|
||||
|
||||
// Rely on default copy constructor/assignment
|
||||
|
||||
double get_R() const { return DG.get_R(); }
|
||||
DiscreteGauss get_DG() const { return DG; }
|
||||
int get_h() const { return h; }
|
||||
|
||||
// Sample from Discrete Gauss distribution
|
||||
vector<bigint> sample_Gauss(PRNG& G, int stretch = 1) const;
|
||||
|
||||
// Next samples from Hwt distribution unless hwt>n/2 in which
|
||||
// case it uses Gauss
|
||||
vector<bigint> sample_Hwt(PRNG& G) const;
|
||||
|
||||
// Sample from {-1,0,1} with Pr(-1)=Pr(1)=1/4 and Pr(0)=1/2
|
||||
vector<bigint> sample_Half(PRNG& G) const;
|
||||
|
||||
// Sample from (-B,0,B) with uniform prob
|
||||
vector<bigint> sample_Uniform(PRNG& G,const bigint& B) const;
|
||||
|
||||
bool operator!=(const RandomVectors& other) const;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class RandomGenerator : public Generator<T>
|
||||
{
|
||||
@@ -103,7 +58,7 @@ public:
|
||||
void get(T& x) const { this->G.get(x, n_bits, positive); }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
template<class T = bigint>
|
||||
class GaussianGenerator : public RandomGenerator<T>
|
||||
{
|
||||
DiscreteGauss DG;
|
||||
|
||||
69
FHE/FFT.cpp
69
FHE/FFT.cpp
@@ -1,6 +1,7 @@
|
||||
|
||||
#include "FHE/FFT.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
#include "Processor/BaseMachine.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
@@ -115,17 +116,38 @@ void FFT_Iter(vector<T>& ioput, int n, const T& root, const P& PrD)
|
||||
*/
|
||||
void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
|
||||
{
|
||||
FFT_Iter(ioput, n, root, PrD, false);
|
||||
}
|
||||
|
||||
void FFT_Iter2(vector<modp>& ioput, int n, const vector<modp>& roots,
|
||||
const Zp_Data& PrD)
|
||||
{
|
||||
FFT_Iter(ioput, n, roots, PrD, false);
|
||||
}
|
||||
|
||||
void FFT_Iter(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD,
|
||||
bool start_with_one)
|
||||
{
|
||||
vector<modp> roots(n + 1);
|
||||
assignOne(roots[0], PrD);
|
||||
for (int i = 1; i < n + 1; i++)
|
||||
Mul(roots[i], roots[i - 1], root, PrD);
|
||||
FFT_Iter(ioput, n, roots, PrD, start_with_one);
|
||||
}
|
||||
|
||||
void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
|
||||
const Zp_Data& PrD, bool start_with_one)
|
||||
{
|
||||
assert(roots.size() > size_t(n));
|
||||
|
||||
int i, j, m;
|
||||
modp t;
|
||||
|
||||
// Bit-reversal of input
|
||||
for( i = j = 0; i < n; ++i )
|
||||
{
|
||||
if( j >= i )
|
||||
{
|
||||
t = ioput[i];
|
||||
ioput[i] = ioput[j];
|
||||
ioput[j] = t;
|
||||
swap(ioput[i], ioput[j]);
|
||||
}
|
||||
m = n / 2;
|
||||
|
||||
@@ -136,27 +158,38 @@ void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
|
||||
}
|
||||
j += m;
|
||||
}
|
||||
modp u, alpha, alpha2;
|
||||
m = 0; j = 0; i = 0;
|
||||
// Do the transform
|
||||
vector<modp> alpha2;
|
||||
alpha2.reserve(n / 2);
|
||||
for (int s = 1; s < n; s = 2*s)
|
||||
{
|
||||
m = 2*s;
|
||||
Power(alpha, root, n/m, PrD);
|
||||
alpha2 = alpha;
|
||||
Mul(alpha, alpha, alpha, PrD);
|
||||
for (int j = 0; j < m/2; ++j)
|
||||
|
||||
alpha2.clear();
|
||||
if (start_with_one)
|
||||
{
|
||||
//root = root_table[(2*j+1)*n/m];
|
||||
for (int k = j; k < n; k += m)
|
||||
{
|
||||
Mul(t, alpha2, ioput[k + m/2], PrD);
|
||||
u = ioput[k];
|
||||
Add(ioput[k], u, t, PrD);
|
||||
Sub(ioput[k + m/2], u, t, PrD);
|
||||
}
|
||||
Mul(alpha2, alpha2, alpha, PrD);
|
||||
for (int j = 0; j < m / 2; j++)
|
||||
alpha2.push_back(roots[j * n / m]);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < m / 2; j++)
|
||||
alpha2.push_back(roots.at((j * 2 + 1) * (n / m)));
|
||||
}
|
||||
|
||||
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
|
||||
{
|
||||
auto& queues = BaseMachine::s().queues;
|
||||
FftJob job(ioput, alpha2, m, PrD);
|
||||
int start = queues.distribute(job, n / 2);
|
||||
for (int i = start; i < n / 2; i++)
|
||||
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
|
||||
queues.wrap_up(job);
|
||||
}
|
||||
else
|
||||
for (int i = 0; i < n / 2; i++)
|
||||
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
21
FHE/FFT.h
21
FHE/FFT.h
@@ -30,8 +30,29 @@ void FFT2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD);
|
||||
template <class T,class P>
|
||||
void FFT_Iter(vector<T>& a,int N,const T& theta,const P& PrD);
|
||||
|
||||
void FFT_Iter(vector<modp>& a, int N, const modp& theta, const Zp_Data& PrD,
|
||||
bool start_with_one = true);
|
||||
void FFT_Iter2(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD);
|
||||
|
||||
// variants with precomputed roots
|
||||
|
||||
void FFT_Iter(vector<modp>& a, int N, const vector<modp>& theta,
|
||||
const Zp_Data& PrD, bool start_with_one = true);
|
||||
void FFT_Iter2(vector<modp>& a, int N, const vector<modp>& theta,
|
||||
const Zp_Data& PrD);
|
||||
|
||||
inline void FFT_Iter2_body(vector<modp>& ioput, const vector<modp>& alpha2, int i,
|
||||
int m, const Zp_Data& PrD)
|
||||
{
|
||||
int j = i % (m / 2);
|
||||
int kk = i / (m / 2);
|
||||
int k = j + kk * m;
|
||||
modp t, u;
|
||||
Mul(t, alpha2[j], ioput[k + m / 2], PrD);
|
||||
u = ioput[k];
|
||||
Add(ioput[k], u, t, PrD);
|
||||
Sub(ioput[k + m / 2], u, t, PrD);
|
||||
}
|
||||
|
||||
/* BFFT perform FFT and inverse FFT mod PrD for non power of two cyclotomics.
|
||||
* The modulus in PrD (contained in FFT_Data) must be set up
|
||||
|
||||
@@ -6,24 +6,6 @@
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
|
||||
void FFT_Data::assign(const FFT_Data& FFTD)
|
||||
{
|
||||
prData=FFTD.prData;
|
||||
R=FFTD.R;
|
||||
|
||||
root=FFTD.root;
|
||||
twop=FFTD.twop;
|
||||
|
||||
two_root=FFTD.two_root;
|
||||
powers=FFTD.powers;
|
||||
powers_i=FFTD.powers_i;
|
||||
b=FFTD.b;
|
||||
|
||||
iphi=FFTD.iphi;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
{
|
||||
@@ -49,6 +31,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
Inv(root[1],root[0],PrD);
|
||||
to_modp(iphi,Rg.phi_m(),PrD);
|
||||
Inv(iphi,iphi,PrD);
|
||||
compute_roots(Rg.m());
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -57,6 +40,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
{ throw invalid_params(); }
|
||||
root[0]=Find_Primitive_Root_2m(Rg.m(),Rg.Phi(),PrD);
|
||||
Inv(root[1],root[0],PrD);
|
||||
compute_roots(2 * Rg.m());
|
||||
|
||||
int ptwop=twop; if (twop<0) { ptwop=-twop; }
|
||||
|
||||
@@ -97,6 +81,14 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
}
|
||||
}
|
||||
|
||||
void FFT_Data::compute_roots(int n)
|
||||
{
|
||||
roots.resize(n + 1);
|
||||
assignOne(roots[0], prData);
|
||||
for (int i = 1; i < n + 1; i++)
|
||||
Mul(roots[i], roots[i - 1], root[0], prData);
|
||||
}
|
||||
|
||||
|
||||
void FFT_Data::hash(octetStream& o) const
|
||||
{
|
||||
@@ -111,6 +103,7 @@ void FFT_Data::pack(octetStream& o) const
|
||||
R.pack(o);
|
||||
prData.pack(o);
|
||||
o.store(root);
|
||||
o.store(roots);
|
||||
o.store(twop);
|
||||
o.store(two_root);
|
||||
o.store(b);
|
||||
@@ -125,6 +118,7 @@ void FFT_Data::unpack(octetStream& o)
|
||||
R.unpack(o);
|
||||
prData.unpack(o);
|
||||
o.get(root);
|
||||
o.get(roots);
|
||||
o.get(twop);
|
||||
o.get(two_root);
|
||||
o.get(b);
|
||||
@@ -133,7 +127,6 @@ void FFT_Data::unpack(octetStream& o)
|
||||
o.get(powers_i);
|
||||
}
|
||||
|
||||
|
||||
bool FFT_Data::operator!=(const FFT_Data& other) const
|
||||
{
|
||||
if (R != other.R or prData != other.prData or root != other.root
|
||||
|
||||
@@ -19,6 +19,7 @@ class FFT_Data
|
||||
Zp_Data prData;
|
||||
|
||||
vector<modp> root; // 2m'th Root of Unity mod pr and it's inverse
|
||||
vector<modp> roots; // precomputed powers of root
|
||||
|
||||
// When twop is equal to zero, m is a power of two
|
||||
// When twop is positive it is equal to 2^e where 2^e>2*m and 2^e divides p-1
|
||||
@@ -34,6 +35,8 @@ class FFT_Data
|
||||
modp iphi; // 1/phi_m mod pr
|
||||
vector< vector<modp> > powers,powers_i;
|
||||
|
||||
void compute_roots(int n);
|
||||
|
||||
public:
|
||||
typedef gfp T;
|
||||
typedef bigint S;
|
||||
@@ -47,17 +50,9 @@ class FFT_Data
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
void assign(const FFT_Data& FFTD);
|
||||
|
||||
FFT_Data() { ; }
|
||||
FFT_Data(const FFT_Data& FFTD)
|
||||
{ assign(FFTD); }
|
||||
FFT_Data(const Ring& Rg,const Zp_Data& PrD)
|
||||
{ init(Rg,PrD); }
|
||||
FFT_Data& operator=(const FFT_Data& FFTD)
|
||||
{ if (this!=&FFTD) { assign(FFTD); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Zp_Data& get_prD() const { return prData; }
|
||||
const bigint& get_prime() const { return prData.pr; }
|
||||
@@ -72,6 +67,7 @@ class FFT_Data
|
||||
int get_twop() const { return twop; }
|
||||
modp get_root(int i) const { return root[i]; }
|
||||
modp get_iphi() const { return iphi; }
|
||||
const vector<modp>& get_roots() const { return roots; }
|
||||
|
||||
const Ring& get_R() const { return R; }
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
|
||||
{
|
||||
Rq_Element sk = FHE_SK(*this).s();
|
||||
// Generate the secret key
|
||||
sk.from_vec((*params).sampleHwt(G));
|
||||
sk.from(GaussianGenerator<bigint>(params->get_DG(), G));
|
||||
return sk;
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
|
||||
// b0=a0*s+p*e0
|
||||
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
|
||||
e0.from_vec((*PK.params).sampleGaussian(G, noise_boost));
|
||||
e0.from(GaussianGenerator<bigint>(params->get_DG(), G, noise_boost));
|
||||
mul(PK.b0,PK.a0,sk);
|
||||
mul(e0,e0,PK.pr);
|
||||
add(PK.b0,PK.b0,e0);
|
||||
@@ -72,7 +72,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
|
||||
// bs=as*s+p*es
|
||||
Rq_Element es((*PK.params).FFTD(),evaluation,evaluation);
|
||||
es.from_vec((*PK.params).sampleGaussian(G, noise_boost));
|
||||
es.from(GaussianGenerator<bigint>(params->get_DG(), G, noise_boost));
|
||||
mul(PK.Sw_b,PK.Sw_a,sk);
|
||||
mul(es,es,PK.pr);
|
||||
add(PK.Sw_b,PK.Sw_b,es);
|
||||
@@ -120,13 +120,14 @@ void FHE_PK::check_noise(const Rq_Element& x, bool check_modulo) const
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
template<class T, class FD, class S>
|
||||
void FHE_PK::encrypt(Ciphertext& c,
|
||||
const Plaintext<gfp,FFT_Data,bigint>& mess,const Random_Coins& rc) const
|
||||
const Plaintext<T, FD, S>& mess,const Random_Coins& rc) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (&rc.get_params()!=params) { throw params_mismatch(); }
|
||||
if (pr==2) { throw pr_mismatch(); }
|
||||
if (T::characteristic_two ^ (pr == 2))
|
||||
throw pr_mismatch();
|
||||
|
||||
Rq_Element mm((*params).FFTD(),polynomial,polynomial);
|
||||
mm.from(mess.get_iterator());
|
||||
@@ -134,35 +135,6 @@ void FHE_PK::encrypt(Ciphertext& c,
|
||||
quasi_encrypt(c,mm,rc);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void FHE_PK::encrypt(Ciphertext& c,
|
||||
const Plaintext<gfp,PPData,bigint>& mess,const Random_Coins& rc) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (&rc.get_params()!=params) { throw params_mismatch(); }
|
||||
if (pr==2) { throw pr_mismatch(); }
|
||||
|
||||
mess.to_poly();
|
||||
encrypt(c, mess.get_poly(), rc);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void FHE_PK::encrypt(Ciphertext& c,
|
||||
const Plaintext<gf2n_short,P2Data,int>& mess,const Random_Coins& rc) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (&rc.get_params()!=params) { throw params_mismatch(); }
|
||||
if (pr!=2) { throw pr_mismatch(); }
|
||||
|
||||
mess.to_poly();
|
||||
encrypt(c, mess.get_poly(), rc);
|
||||
}
|
||||
|
||||
void FHE_PK::quasi_encrypt(Ciphertext& c,
|
||||
const Rq_Element& mess,const Random_Coins& rc) const
|
||||
{
|
||||
@@ -212,42 +184,12 @@ Ciphertext FHE_PK::encrypt(
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void FHE_SK::decrypt(Plaintext<gfp,FFT_Data,bigint>& mess,const Ciphertext& c) const
|
||||
template<class T, class FD, class S>
|
||||
void FHE_SK::decrypt(Plaintext<T,FD,S>& mess,const Ciphertext& c) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (pr==2) { throw pr_mismatch(); }
|
||||
|
||||
Rq_Element ans;
|
||||
|
||||
mul(ans,c.c1(),sk);
|
||||
sub(ans,c.c0(),ans);
|
||||
ans.change_rep(polynomial);
|
||||
mess.set_poly_mod(ans.get_iterator(), ans.get_modulus());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void FHE_SK::decrypt(Plaintext<gfp,PPData,bigint>& mess,const Ciphertext& c) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (pr==2) { throw pr_mismatch(); }
|
||||
|
||||
Rq_Element ans;
|
||||
|
||||
mul(ans,c.c1(),sk);
|
||||
sub(ans,c.c0(),ans);
|
||||
mess.set_poly_mod(ans.to_vec_bigint(),ans.get_modulus());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void FHE_SK::decrypt(Plaintext<gf2n_short,P2Data,int>& mess,const Ciphertext& c) const
|
||||
{
|
||||
if (&c.get_params()!=params) { throw params_mismatch(); }
|
||||
if (pr!=2) { throw pr_mismatch(); }
|
||||
if (T::characteristic_two ^ (pr == 2))
|
||||
throw pr_mismatch();
|
||||
|
||||
Rq_Element ans;
|
||||
|
||||
|
||||
@@ -3,14 +3,6 @@
|
||||
#include "FHE/Ring_Element.h"
|
||||
#include "Tools/Exceptions.h"
|
||||
|
||||
void FHE_Params::set(const Ring& R,
|
||||
const vector<bigint>& primes,double r,int hwt)
|
||||
{
|
||||
set(R, primes);
|
||||
|
||||
Chi.set(R.phi_m(),hwt,r);
|
||||
}
|
||||
|
||||
void FHE_Params::set(const Ring& R,
|
||||
const vector<bigint>& primes)
|
||||
{
|
||||
@@ -20,7 +12,6 @@ void FHE_Params::set(const Ring& R,
|
||||
for (size_t i = 0; i < FFTData.size(); i++)
|
||||
FFTData[i].init(R,primes[i]);
|
||||
|
||||
Chi.set_n(R.phi_m());
|
||||
set_sec(40);
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class FHE_Params
|
||||
vector<FFT_Data> FFTData;
|
||||
|
||||
// Random generator for Multivariate Gaussian Distribution etc
|
||||
RandomVectors Chi;
|
||||
mutable DiscreteGauss Chi;
|
||||
|
||||
// Data for distributed decryption
|
||||
int sec_p;
|
||||
@@ -29,27 +29,17 @@ class FHE_Params
|
||||
|
||||
public:
|
||||
|
||||
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(-1, 0.7), sec_p(-1) {}
|
||||
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
|
||||
|
||||
int n_mults() const { return FFTData.size() - 1; }
|
||||
|
||||
// Rely on default copy assignment/constructor (not that they should
|
||||
// ever be needed)
|
||||
|
||||
void set(const Ring& R,const vector<bigint>& primes,double r,int hwt);
|
||||
void set(const Ring& R,const vector<bigint>& primes);
|
||||
void set(const vector<bigint>& primes);
|
||||
void set_sec(int sec);
|
||||
|
||||
vector<bigint> sampleGaussian(PRNG& G, int noise_boost = 1) const
|
||||
{ return Chi.sample_Gauss(G, noise_boost); }
|
||||
vector<bigint> sampleHwt(PRNG& G) const
|
||||
{ return Chi.sample_Hwt(G); }
|
||||
vector<bigint> sampleHalf(PRNG& G) const
|
||||
{ return Chi.sample_Half(G); }
|
||||
vector<bigint> sampleUniform(PRNG& G,const bigint& Bd) const
|
||||
{ return Chi.sample_Uniform(G,Bd); }
|
||||
|
||||
const vector<FFT_Data>& FFTD() const { return FFTData; }
|
||||
|
||||
const bigint& p0() const { return FFTData[0].get_prime(); }
|
||||
@@ -59,9 +49,8 @@ class FHE_Params
|
||||
int secp() const { return sec_p; }
|
||||
const bigint& B() const { return Bval; }
|
||||
double get_R() const { return Chi.get_R(); }
|
||||
void set_R(double R) const { return Chi.get_DG().set(R); }
|
||||
DiscreteGauss get_DG() const { return Chi.get_DG(); }
|
||||
int get_h() const { return Chi.get_h(); }
|
||||
void set_R(double R) const { return Chi.set(R); }
|
||||
DiscreteGauss get_DG() const { return Chi; }
|
||||
|
||||
int phi_m() const { return FFTData[0].phi_m(); }
|
||||
const Ring& get_ring() { return FFTData[0].get_R(); }
|
||||
|
||||
@@ -52,10 +52,12 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
bigint p;
|
||||
generate_prime(p, lgp, m);
|
||||
int lgp0, lgp1;
|
||||
FHE_Params tmp_params;
|
||||
while (true)
|
||||
{
|
||||
tmp_params = params;
|
||||
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
|
||||
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
|
||||
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
|
||||
bigint p1 = 2 * p * m, p0 = p;
|
||||
while (nb.min_p0(params.n_mults() > 0, p1) > p0)
|
||||
{
|
||||
@@ -75,6 +77,7 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
}
|
||||
}
|
||||
|
||||
params = tmp_params;
|
||||
int extra_slack = common_semi_setup(params, m, p, lgp0, lgp1, round_up);
|
||||
|
||||
FTD.init(params.get_ring(), p);
|
||||
|
||||
@@ -13,29 +13,24 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
const FHE_Params& params) :
|
||||
p(p), phi_m(phi_m), n(n), sec(sec),
|
||||
slack(numBits(Proof::slack(slack_param, sec, phi_m))),
|
||||
sigma(params.get_R()), h(params.get_h())
|
||||
sigma(params.get_R())
|
||||
{
|
||||
if (sigma <= 0)
|
||||
this->sigma = sigma = FHE_Params().get_R();
|
||||
#ifdef VERBOSE
|
||||
cerr << "Standard deviation: " << this->sigma << endl;
|
||||
#endif
|
||||
if (h > 0)
|
||||
h += extra_h * sec;
|
||||
else if (extra_h)
|
||||
if (extra_h)
|
||||
{
|
||||
sigma *= 1.4;
|
||||
params.set_R(params.get_R() * 1.4);
|
||||
}
|
||||
#ifdef VERBOSE
|
||||
cerr << "Standard deviation: " << this->sigma << endl;
|
||||
#endif
|
||||
|
||||
produce_epsilon_constants();
|
||||
|
||||
// according to documentation of SCALE-MAMBA 1.7
|
||||
// excluding a factor of n because we don't always add up n ciphertexts
|
||||
if (h > 0)
|
||||
V_s = sqrt(h);
|
||||
else
|
||||
V_s = sigma * sqrt(phi_m);
|
||||
V_s = sigma * sqrt(phi_m);
|
||||
B_clean = (bigint(phi_m) << (sec + 1)) * p
|
||||
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s);
|
||||
// unify parameters by taking maximum over TopGear or not
|
||||
|
||||
@@ -22,7 +22,6 @@ protected:
|
||||
const int sec;
|
||||
int slack;
|
||||
mpf_class sigma;
|
||||
int h;
|
||||
|
||||
bigint B_clean;
|
||||
bigint B_scale;
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
|
||||
|
||||
void PPData::assign(const PPData& PPD)
|
||||
{
|
||||
R=PPD.R;
|
||||
prData=PPD.prData;
|
||||
root=PPD.root;
|
||||
}
|
||||
|
||||
|
||||
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
{
|
||||
R=Rg;
|
||||
|
||||
@@ -27,17 +27,9 @@ class PPData
|
||||
|
||||
void init(const Ring& Rg,const Zp_Data& PrD);
|
||||
|
||||
void assign(const PPData& PPD);
|
||||
|
||||
PPData() { ; }
|
||||
PPData(const PPData& PPD)
|
||||
{ assign(PPD); }
|
||||
PPData(const Ring& Rg,const Zp_Data& PrD)
|
||||
{ init(Rg,PrD); }
|
||||
PPData& operator=(const PPData& PPD)
|
||||
{ if (this!=&PPD) { assign(PPD); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Zp_Data& get_prD() const { return prData; }
|
||||
const bigint& get_prime() const { return prData.pr; }
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHE/Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
#include "FHE/AddableVector.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
@@ -258,37 +259,9 @@ void Plaintext<T,FD,S>::randomize(PRNG& G,condition cond)
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext_<FFT_Data>::randomize(PRNG& G, bigint B, bool Diag, bool binary, PT_Type t)
|
||||
{
|
||||
if (Diag or binary)
|
||||
throw not_implemented();
|
||||
if (B == 0)
|
||||
throw runtime_error("cannot randomize modulo 0");
|
||||
|
||||
allocate(t);
|
||||
switch (t)
|
||||
{
|
||||
case Polynomial:
|
||||
rand_poly(b, G, B, false);
|
||||
break;
|
||||
case Evaluation:
|
||||
for (int i = 0; i < n_slots; i++)
|
||||
a[i] = G.randomBnd(B);
|
||||
break;
|
||||
default:
|
||||
throw runtime_error("wrong type for randomization with bound");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<class T,class FD,class S>
|
||||
void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, bool binary, PT_Type t)
|
||||
void Plaintext<T,FD,S>::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t)
|
||||
{
|
||||
if (binary)
|
||||
throw not_implemented();
|
||||
|
||||
allocate(t);
|
||||
switch(t)
|
||||
{
|
||||
@@ -614,10 +587,11 @@ void Plaintext<gf2n_short,P2Data,int>::negate()
|
||||
|
||||
|
||||
|
||||
template<class T, class FD, class S>
|
||||
Rq_Element Plaintext<T, FD, S>::mul_by_X_i(int i, const FHE_PK& pk) const
|
||||
template<class T, class FD, class _>
|
||||
AddableVector<typename FD::poly_type> Plaintext<T, FD, _>::mul_by_X_i(int i,
|
||||
const FHE_PK& pk) const
|
||||
{
|
||||
return Rq_Element(pk.get_params(), *this).mul_by_X_i(i);
|
||||
return AddableVector<S>(get_poly()).mul_by_X_i(i, pk);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ using namespace std;
|
||||
|
||||
class FHE_PK;
|
||||
class Rq_Element;
|
||||
template<class T> class AddableVector;
|
||||
|
||||
// Forward declaration as apparently this is needed for friends in templates
|
||||
template<class T,class FD,class S> class Plaintext;
|
||||
@@ -64,13 +65,6 @@ class Plaintext
|
||||
const FD& get_field() const { return *Field_Data; }
|
||||
unsigned int num_slots() const { return n_slots; }
|
||||
|
||||
void assign(const Plaintext& p)
|
||||
{ Field_Data=p.Field_Data;
|
||||
a=p.a; b=p.b; type=p.type;
|
||||
n_slots = p.n_slots;
|
||||
degree = p.degree;
|
||||
}
|
||||
|
||||
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
|
||||
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
|
||||
|
||||
@@ -142,8 +136,7 @@ class Plaintext
|
||||
void to_poly() const;
|
||||
|
||||
void randomize(PRNG& G,condition cond=Full);
|
||||
void randomize(PRNG& G, bigint B, bool Diag=false, bool binary=false, PT_Type type=Polynomial);
|
||||
void randomize(PRNG& G, int n_bits, bool Diag=false, bool binary=false, PT_Type type=Polynomial);
|
||||
void randomize(PRNG& G, int n_bits, bool Diag=false, PT_Type type=Polynomial);
|
||||
|
||||
void assign_zero(PT_Type t = Evaluation);
|
||||
void assign_one(PT_Type t = Evaluation);
|
||||
@@ -171,13 +164,12 @@ class Plaintext
|
||||
|
||||
void negate();
|
||||
|
||||
Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const;
|
||||
AddableVector<S> mul_by_X_i(int i, const FHE_PK& pk) const;
|
||||
|
||||
bool equals(const Plaintext& x) const;
|
||||
bool operator!=(const Plaintext& x) { return !equals(x); }
|
||||
|
||||
bool is_diagonal() const;
|
||||
bool is_binary() const { throw not_implemented(); }
|
||||
|
||||
/* Pack and unpack into an octetStream
|
||||
* For unpack we assume the FFTD has been assigned correctly already
|
||||
|
||||
@@ -52,8 +52,6 @@ class Random_Coins
|
||||
{ params=&p; }
|
||||
|
||||
Random_Coins(const FHE_PK& pk);
|
||||
|
||||
~Random_Coins() { ; }
|
||||
|
||||
// Rely on default copy assignment/constructor
|
||||
|
||||
|
||||
@@ -33,17 +33,36 @@ Ring_Element::Ring_Element(const FFT_Data& fftd,RepType r)
|
||||
}
|
||||
|
||||
|
||||
void Ring_Element::prepare(const Ring_Element& other)
|
||||
{
|
||||
assert(this != &other);
|
||||
FFTD = other.FFTD;
|
||||
rep = other.rep;
|
||||
prepare_push();
|
||||
}
|
||||
|
||||
void Ring_Element::prepare_push()
|
||||
{
|
||||
element.clear();
|
||||
element.reserve(FFTD->phi_m());
|
||||
}
|
||||
|
||||
|
||||
void Ring_Element::allocate()
|
||||
{
|
||||
element.resize(FFTD->phi_m());
|
||||
}
|
||||
|
||||
|
||||
void Ring_Element::assign_zero()
|
||||
{
|
||||
element.resize((*FFTD).phi_m());
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ assignZero(element[i],(*FFTD).get_prD()); }
|
||||
element.clear();
|
||||
}
|
||||
|
||||
|
||||
void Ring_Element::assign_one()
|
||||
{
|
||||
element.resize((*FFTD).phi_m());
|
||||
allocate();
|
||||
modp fill;
|
||||
if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); }
|
||||
else { assignOne(fill,(*FFTD).get_prD()); }
|
||||
@@ -56,6 +75,9 @@ void Ring_Element::assign_one()
|
||||
|
||||
void Ring_Element::negate()
|
||||
{
|
||||
if (element.empty())
|
||||
return;
|
||||
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ Negate(element[i],element[i],(*FFTD).get_prD()); }
|
||||
}
|
||||
@@ -66,20 +88,58 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
ans.partial_assign(a);
|
||||
if (a.element.empty())
|
||||
{
|
||||
ans = b;
|
||||
return;
|
||||
}
|
||||
else if (b.element.empty())
|
||||
{
|
||||
ans = a;
|
||||
return;
|
||||
}
|
||||
|
||||
if (&ans == &a)
|
||||
{
|
||||
ans += b;
|
||||
return;
|
||||
}
|
||||
else if (&ans == &b)
|
||||
{
|
||||
ans += a;
|
||||
return;
|
||||
}
|
||||
|
||||
ans.prepare(a);
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
{ Add(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
|
||||
ans.element.push_back(a.element[i].add(b.element[i], a.FFTD->get_prD()));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
ans.partial_assign(a);
|
||||
if (a.element.empty())
|
||||
{
|
||||
ans = b;
|
||||
ans.negate();
|
||||
return;
|
||||
}
|
||||
else if (b.element.empty())
|
||||
{
|
||||
ans = a;
|
||||
return;
|
||||
}
|
||||
|
||||
if (&ans == &a)
|
||||
{
|
||||
ans -= b;
|
||||
return;
|
||||
}
|
||||
|
||||
ans.prepare(a);
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
{ Sub(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
|
||||
ans.element.push_back(a.element[i].sub(b.element[i], a.FFTD->get_prD()));
|
||||
}
|
||||
|
||||
|
||||
@@ -88,13 +148,29 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
{
|
||||
if (a.rep!=b.rep) { throw rep_mismatch(); }
|
||||
if (a.FFTD!=b.FFTD) { throw pr_mismatch(); }
|
||||
ans.partial_assign(a);
|
||||
if (ans.rep==evaluation)
|
||||
{ // In evaluation representation, so we can just multiply componentwise
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
{ Mul(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); }
|
||||
if (a.element.empty() or b.element.empty())
|
||||
{
|
||||
ans = Ring_Element(*a.FFTD, a.rep);
|
||||
return;
|
||||
}
|
||||
else if ((*ans.FFTD).get_twop()!=0)
|
||||
|
||||
if (a.rep==evaluation)
|
||||
{ // In evaluation representation, so we can just multiply componentwise
|
||||
if (&ans == &a)
|
||||
{
|
||||
ans *= b;
|
||||
return;
|
||||
}
|
||||
else if (&ans == &b)
|
||||
{
|
||||
ans *= a;
|
||||
return;
|
||||
}
|
||||
ans.prepare(a);
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
ans.element.push_back(a.element[i].mul(b.element[i], a.FFTD->get_prD()));
|
||||
}
|
||||
else if ((*a.FFTD).get_twop()!=0)
|
||||
{ // This is the case where m is not a power of two
|
||||
|
||||
// Here we have to do a poly mult followed by a reduction
|
||||
@@ -116,11 +192,13 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
// Now apply reduction, assumes Ring.poly is monic
|
||||
reduce(aa, 2*(*a.FFTD).phi_m(), (*a.FFTD).phi_m(), *a.FFTD);
|
||||
// Now stick into answer
|
||||
ans.partial_assign(a);
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
{ ans.element[i]=aa[i]; }
|
||||
}
|
||||
else if ((*ans.FFTD).get_twop()==0)
|
||||
else if ((*a.FFTD).get_twop()==0)
|
||||
{ // m a power of two case
|
||||
ans.partial_assign(a);
|
||||
Ring_Element aa(*ans.FFTD,ans.rep);
|
||||
modp temp;
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
@@ -143,31 +221,89 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b)
|
||||
|
||||
void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
|
||||
{
|
||||
ans.partial_assign(a);
|
||||
if (&ans == &a)
|
||||
{
|
||||
ans *= b;
|
||||
return;
|
||||
}
|
||||
|
||||
ans.prepare(a);
|
||||
if (a.element.empty())
|
||||
return;
|
||||
|
||||
for (int i=0; i<(*ans.FFTD).phi_m(); i++)
|
||||
{ Mul(ans.element[i],a.element[i],b,(*a.FFTD).get_prD()); }
|
||||
ans.element.push_back(a.element[i].mul(b, a.FFTD->get_prD()));
|
||||
}
|
||||
|
||||
|
||||
Ring_Element& Ring_Element::operator +=(const Ring_Element& other)
|
||||
{
|
||||
assert(element.size() == other.element.size());
|
||||
assert(FFTD == other.FFTD);
|
||||
assert(rep == other.rep);
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
element[i] = element[i].add(other.element[i], FFTD->get_prD());
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
Ring_Element& Ring_Element::operator -=(const Ring_Element& other)
|
||||
{
|
||||
assert(element.size() == other.element.size());
|
||||
assert(FFTD == other.FFTD);
|
||||
assert(rep == other.rep);
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
element[i] = element[i].sub(other.element[i], FFTD->get_prD());
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
Ring_Element& Ring_Element::operator *=(const Ring_Element& other)
|
||||
{
|
||||
assert(element.size() == other.element.size());
|
||||
assert(FFTD == other.FFTD);
|
||||
assert(rep == other.rep);
|
||||
assert(rep == evaluation);
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
element[i] = element[i].mul(other.element[i], FFTD->get_prD());
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
Ring_Element& Ring_Element::operator *=(const modp& other)
|
||||
{
|
||||
for (size_t i = 0; i < element.size(); i++)
|
||||
element[i] = element[i].mul(other, FFTD->get_prD());
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
Ring_Element Ring_Element::mul_by_X_i(int j) const
|
||||
{
|
||||
Ring_Element ans;
|
||||
ans.prepare(*this);
|
||||
if (element.empty())
|
||||
return ans;
|
||||
|
||||
auto& a = *this;
|
||||
ans.partial_assign(a);
|
||||
if (ans.rep == evaluation)
|
||||
{
|
||||
modp xj, xj2;
|
||||
Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD());
|
||||
Sqr(xj2, xj, (*a.FFTD).get_prD());
|
||||
ans.prepare_push();
|
||||
modp tmp;
|
||||
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
|
||||
{
|
||||
Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD());
|
||||
Mul(tmp, a.element[i], xj, (*a.FFTD).get_prD());
|
||||
ans.element.push_back(tmp);
|
||||
Mul(xj, xj, xj2, (*a.FFTD).get_prD());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Ring_Element aa(*ans.FFTD, ans.rep);
|
||||
aa.allocate();
|
||||
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
|
||||
{
|
||||
int k= j + i, s= 1;
|
||||
@@ -193,6 +329,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const
|
||||
|
||||
void Ring_Element::randomize(PRNG& G,bool Diag)
|
||||
{
|
||||
allocate();
|
||||
if (Diag==false)
|
||||
{ for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ element[i].randomize(G,(*FFTD).get_prD()); }
|
||||
@@ -213,12 +350,18 @@ void Ring_Element::randomize(PRNG& G,bool Diag)
|
||||
|
||||
void Ring_Element::change_rep(RepType r)
|
||||
{
|
||||
if (element.empty())
|
||||
{
|
||||
rep = r;
|
||||
return;
|
||||
}
|
||||
|
||||
if (rep==r) { return; }
|
||||
if (r==evaluation)
|
||||
{ rep=evaluation;
|
||||
if ((*FFTD).get_twop()==0)
|
||||
{ // m a power of two variant
|
||||
FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_root(0),(*FFTD).get_prD());
|
||||
FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_roots(),(*FFTD).get_prD());
|
||||
}
|
||||
else
|
||||
{ // Non m power of two variant and FFT enabled
|
||||
@@ -258,6 +401,11 @@ void Ring_Element::change_rep(RepType r)
|
||||
|
||||
bool Ring_Element::equals(const Ring_Element& a) const
|
||||
{
|
||||
if (element.empty() and a.element.empty())
|
||||
return true;
|
||||
else if (element.empty() or a.element.empty())
|
||||
throw not_implemented();
|
||||
|
||||
if (rep!=a.rep) { throw rep_mismatch(); }
|
||||
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
@@ -266,34 +414,11 @@ bool Ring_Element::equals(const Ring_Element& a) const
|
||||
}
|
||||
|
||||
|
||||
void Ring_Element::from_vec(const vector<bigint>& v)
|
||||
{
|
||||
RepType t=rep;
|
||||
rep=polynomial;
|
||||
bigint tmp;
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{
|
||||
tmp = v[i];
|
||||
element[i].convert_destroy(tmp, FFTD->get_prD());
|
||||
}
|
||||
change_rep(t);
|
||||
// cout << "RE:from_vec<bigint>:: " << *this << endl;
|
||||
}
|
||||
|
||||
void Ring_Element::from_vec(const vector<int>& v)
|
||||
{
|
||||
RepType t=rep;
|
||||
rep=polynomial;
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ to_modp(element[i],v[i],(*FFTD).get_prD()); }
|
||||
change_rep(t);
|
||||
// cout << "RE:from_vec<int>:: " << *this << endl;
|
||||
}
|
||||
|
||||
ConversionIterator Ring_Element::get_iterator() const
|
||||
{
|
||||
if (rep != polynomial)
|
||||
throw runtime_error("simple iterator only available in polynomial represention");
|
||||
assert(not element.empty());
|
||||
return {element, (*FFTD).get_prD()};
|
||||
}
|
||||
|
||||
@@ -318,6 +443,9 @@ vector<bigint> Ring_Element::to_vec_bigint() const
|
||||
void Ring_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
{
|
||||
v.resize(FFTD->phi_m());
|
||||
if (element.empty())
|
||||
return;
|
||||
|
||||
if (rep==polynomial)
|
||||
{ for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ to_bigint(v[i],element[i],(*FFTD).get_prD()); }
|
||||
@@ -336,11 +464,10 @@ void Ring_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
|
||||
modp Ring_Element::get_constant() const
|
||||
{
|
||||
if (rep==polynomial)
|
||||
{ return element[0]; }
|
||||
Ring_Element a=*this;
|
||||
a.change_rep(polynomial);
|
||||
return a.element[0];
|
||||
if (element.empty())
|
||||
return {};
|
||||
else
|
||||
return element[0];
|
||||
}
|
||||
|
||||
|
||||
@@ -364,9 +491,14 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
|
||||
+ to_string(ZpD.pr_bit_length));
|
||||
unsigned int length;
|
||||
o.get(length);
|
||||
v.resize(length);
|
||||
v.clear();
|
||||
v.reserve(length);
|
||||
modp tmp;
|
||||
for (unsigned int i=0; i<length; i++)
|
||||
{ v[i].unpack(o,ZpD); }
|
||||
{
|
||||
tmp.unpack(o,ZpD);
|
||||
v.push_back(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -398,7 +530,7 @@ void Ring_Element::check_rep()
|
||||
|
||||
void Ring_Element::check_size() const
|
||||
{
|
||||
if ((int)element.size() != FFTD->phi_m())
|
||||
if (not element.empty() and (int)element.size() != FFTD->phi_m())
|
||||
throw runtime_error("invalid element size");
|
||||
}
|
||||
|
||||
|
||||
@@ -41,12 +41,6 @@ class Ring_Element
|
||||
|
||||
vector<modp> element;
|
||||
|
||||
// Define a copy
|
||||
void assign(const Ring_Element& e)
|
||||
{ rep=e.rep; FFTD=e.FFTD;
|
||||
element=e.element;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
// Used to basically make sure *this is able to cope
|
||||
@@ -57,6 +51,10 @@ class Ring_Element
|
||||
element.resize((*FFTD).phi_m());
|
||||
}
|
||||
|
||||
void prepare(const Ring_Element& e);
|
||||
void prepare_push();
|
||||
void allocate();
|
||||
|
||||
void set_data(const FFT_Data& prd) { FFTD=&prd; }
|
||||
const FFT_Data& get_FFTD() const { return *FFTD; }
|
||||
const Zp_Data& get_prD() const { return (*FFTD).get_prD(); }
|
||||
@@ -80,19 +78,6 @@ class Ring_Element
|
||||
element.push_back(x);
|
||||
}
|
||||
|
||||
// Copy Constructor
|
||||
Ring_Element(const Ring_Element& e)
|
||||
{ assign(e); }
|
||||
|
||||
// Destructor
|
||||
~Ring_Element() { ; }
|
||||
|
||||
// Copy Assignment
|
||||
Ring_Element& operator=(const Ring_Element& e)
|
||||
{ if (this!=&e) { assign(e); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
/* Functional Operators */
|
||||
void negate();
|
||||
friend void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b);
|
||||
@@ -102,6 +87,11 @@ class Ring_Element
|
||||
|
||||
Ring_Element mul_by_X_i(int i) const;
|
||||
|
||||
Ring_Element& operator+=(const Ring_Element& other);
|
||||
Ring_Element& operator-=(const Ring_Element& other);
|
||||
Ring_Element& operator*=(const Ring_Element& other);
|
||||
Ring_Element& operator*=(const modp& other);
|
||||
|
||||
void randomize(PRNG& G,bool Diag=false);
|
||||
|
||||
bool equals(const Ring_Element& a) const;
|
||||
@@ -112,8 +102,6 @@ class Ring_Element
|
||||
// Converting to and from a vector of bigint/int's
|
||||
// I/O is assumed to be in poly rep, so from_vec it internally alters
|
||||
// the representation to the current representation
|
||||
void from_vec(const vector<bigint>& v);
|
||||
void from_vec(const vector<int>& v);
|
||||
vector<bigint> to_vec_bigint() const;
|
||||
void to_vec_bigint(vector<bigint>& v) const;
|
||||
|
||||
@@ -136,8 +124,18 @@ class Ring_Element
|
||||
|
||||
// This gets the constant term of the poly rep as a modp element
|
||||
modp get_constant() const;
|
||||
modp get_element(int i) const { return element[i]; }
|
||||
void set_element(int i,const modp& a) { element[i]=a; }
|
||||
modp get_element(int i) const
|
||||
{
|
||||
if (element.empty())
|
||||
return {};
|
||||
else
|
||||
return element[i];
|
||||
}
|
||||
void set_element(int i,const modp& a)
|
||||
{
|
||||
allocate();
|
||||
element[i] = a;
|
||||
}
|
||||
|
||||
/* Pack and unpack into an octetStream
|
||||
* For unpack we assume the FFTD has been assigned correctly already
|
||||
@@ -164,7 +162,11 @@ class RingWriteIterator : public WriteConversionIterator
|
||||
public:
|
||||
RingWriteIterator(Ring_Element& element) :
|
||||
WriteConversionIterator(element.element, element.FFTD->get_prD()),
|
||||
element(element), rep(element.rep) { element.rep = polynomial; }
|
||||
element(element), rep(element.rep)
|
||||
{
|
||||
element.rep = polynomial;
|
||||
element.allocate();
|
||||
}
|
||||
~RingWriteIterator() { element.change_rep(rep); }
|
||||
};
|
||||
|
||||
@@ -175,7 +177,11 @@ class RingReadIterator : public ConversionIterator
|
||||
public:
|
||||
RingReadIterator(const Ring_Element& element) :
|
||||
ConversionIterator(this->element.element, element.FFTD->get_prD()),
|
||||
element(element) { this->element.change_rep(polynomial); }
|
||||
element(element)
|
||||
{
|
||||
this->element.change_rep(polynomial);
|
||||
this->element.allocate();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -189,10 +195,13 @@ void Ring_Element::from(const Generator<T>& generator)
|
||||
RepType t=rep;
|
||||
rep=polynomial;
|
||||
T tmp;
|
||||
modp tmp2;
|
||||
prepare_push();
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{
|
||||
generator.get(tmp);
|
||||
element[i].convert_destroy(tmp, (*FFTD).get_prD());
|
||||
tmp2.convert_destroy(tmp, (*FFTD).get_prD());
|
||||
element.push_back(tmp2);
|
||||
}
|
||||
change_rep(t);
|
||||
}
|
||||
|
||||
@@ -48,15 +48,6 @@ void Rq_Element::partial_assign(const Rq_Element& other)
|
||||
{
|
||||
lev=other.lev;
|
||||
a.resize(other.a.size());
|
||||
for (size_t i = 0; i < a.size(); i++)
|
||||
a[i].partial_assign(other.a[i]);
|
||||
}
|
||||
|
||||
void Rq_Element::assign(const Rq_Element& other)
|
||||
{
|
||||
partial_assign(other);
|
||||
for (int i=0; i<=lev; ++i)
|
||||
a[i] = other.a[i];
|
||||
}
|
||||
|
||||
void Rq_Element::negate()
|
||||
@@ -134,20 +125,6 @@ bool Rq_Element::equals(const Rq_Element& other) const
|
||||
}
|
||||
|
||||
|
||||
void Rq_Element::from_vec(const vector<bigint>& v,int level)
|
||||
{
|
||||
set_level(level);
|
||||
for (int i=0;i<=lev;++i)
|
||||
a[i].from_vec(v);
|
||||
}
|
||||
|
||||
void Rq_Element::from_vec(const vector<int>& v,int level)
|
||||
{
|
||||
set_level(level);
|
||||
for (int i=0;i<=lev;++i)
|
||||
a[i].from_vec(v);
|
||||
}
|
||||
|
||||
vector<bigint> Rq_Element::to_vec_bigint() const
|
||||
{
|
||||
vector<bigint> v;
|
||||
|
||||
@@ -44,7 +44,6 @@ protected:
|
||||
void assign_zero(const vector<FFT_Data>& prd);
|
||||
void assign_zero();
|
||||
void assign_one();
|
||||
void assign(const Rq_Element& e);
|
||||
void partial_assign(const Rq_Element& e);
|
||||
|
||||
// Must be careful not to call by mistake
|
||||
@@ -85,10 +84,6 @@ protected:
|
||||
a[1] = Ring_Element(prd[1], r, b1);
|
||||
}
|
||||
|
||||
// Destructor
|
||||
~Rq_Element()
|
||||
{ ; }
|
||||
|
||||
const Ring_Element& get(int i) const { return a[i]; }
|
||||
|
||||
/* Functional Operators */
|
||||
@@ -131,8 +126,6 @@ protected:
|
||||
void partial_assign(const Rq_Element& a, const Rq_Element& b);
|
||||
|
||||
// Converting to and from a vector of bigint's Again I/O is in poly rep
|
||||
void from_vec(const vector<bigint>& v,int level=-1);
|
||||
void from_vec(const vector<int>& v,int level=-1);
|
||||
vector<bigint> to_vec_bigint() const;
|
||||
void to_vec_bigint(vector<bigint>& v) const;
|
||||
|
||||
|
||||
@@ -49,7 +49,8 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine,
|
||||
|
||||
if (not error.empty())
|
||||
{
|
||||
cerr << "Running secrets generation because " << error << endl;
|
||||
cerr << "Running secrets generation because no suitable material "
|
||||
"from a previous run was found (" << error << ")" << endl;
|
||||
setup.key_and_mac_generation(P, machine, num_runs, V());
|
||||
|
||||
ofstream output(filename);
|
||||
|
||||
@@ -109,11 +109,11 @@ DistKeyGen::DistKeyGen(const FHE_Params& params, const bigint& p) :
|
||||
*/
|
||||
void DistKeyGen::Gen_Random_Data(PRNG& G)
|
||||
{
|
||||
secret.from_vec(params.sampleHwt(G));
|
||||
secret.from(GaussianGenerator<bigint>(params.get_DG(), G));
|
||||
rc1.generate(G);
|
||||
rc2.generate(G);
|
||||
a.randomize(G);
|
||||
e.from_vec(params.sampleGaussian(G));
|
||||
e.from(GaussianGenerator<bigint>(params.get_DG(), G));
|
||||
}
|
||||
|
||||
DistKeyGen& DistKeyGen::operator+=(const DistKeyGen& other)
|
||||
|
||||
@@ -45,7 +45,7 @@ template <class FD>
|
||||
void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role)
|
||||
{
|
||||
octetStream o;
|
||||
o.reset_write_head();
|
||||
|
||||
if (role & SENDER)
|
||||
{
|
||||
|
||||
@@ -36,6 +36,8 @@ class Multiplier
|
||||
size_t volatile_capacity;
|
||||
MemoryUsage memory_usage;
|
||||
|
||||
octetStream o;
|
||||
|
||||
public:
|
||||
Multiplier(int offset, PairwiseGenerator<FD>& generator);
|
||||
Multiplier(int offset, PairwiseMachine& machine, Player& P,
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
/*
|
||||
* Player-Offline.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef FHEOFFLINE_PLAYER_OFFLINE_H_
|
||||
#define FHEOFFLINE_PLAYER_OFFLINE_H_
|
||||
|
||||
class thread_info
|
||||
{
|
||||
public:
|
||||
|
||||
int thread_num;
|
||||
int covert;
|
||||
Names* Nms;
|
||||
FHE_PK* pk_p;
|
||||
FHE_PK* pk_2;
|
||||
FHE_SK* sk_p;
|
||||
FHE_SK* sk_2;
|
||||
Ciphertext *calphap;
|
||||
Ciphertext *calpha2;
|
||||
gfp *alphapi;
|
||||
gf2n_short *alpha2i;
|
||||
|
||||
FFT_Data *FTD;
|
||||
P2Data *P2D;
|
||||
|
||||
int nm2,nmp,nb2,nbp,ni2,nip,ns2,nsp,nvp;
|
||||
bool skip_2() { return nm2 + ni2 + nb2 + ns2 == 0; }
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_PLAYER_OFFLINE_H_ */
|
||||
@@ -589,7 +589,7 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
|
||||
P.receive_player(j, cleartexts);
|
||||
C.resize(personal_EC.machine->sec, pk.get_params());
|
||||
Verifier<FD>(personal_EC.proof, FieldD).NIZKPoK(C, ciphertexts,
|
||||
cleartexts, pk, false);
|
||||
cleartexts, pk);
|
||||
}
|
||||
|
||||
inputs[j].clear();
|
||||
|
||||
@@ -88,6 +88,7 @@ public:
|
||||
|
||||
bool Proof::check_bounds(T& z, X& t, int i) const
|
||||
{
|
||||
(void)i;
|
||||
unsigned int j,k;
|
||||
|
||||
// Check Bound 1 and Bound 2
|
||||
@@ -99,9 +100,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const
|
||||
auto& te = z[j];
|
||||
if (plain_checker.outside(te, dist))
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "Fail on Check 1 " << i << " " << j << endl;
|
||||
cout << te << " " << plain_check << endl;
|
||||
cout << tau << " " << sec << " " << n_proofs << endl;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -113,9 +116,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const
|
||||
auto& te = coeffs.at(j);
|
||||
if (rand_checker.outside(te, dist))
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl;
|
||||
cout << te << " " << rand_check << endl;
|
||||
cout << rho << " " << sec << " " << n_proofs << endl;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "Tools/random.h"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Math/modp.hpp"
|
||||
#include "FHE/AddableVector.hpp"
|
||||
|
||||
|
||||
template <class FD, class U>
|
||||
@@ -28,7 +29,7 @@ Prover<FD,U>::Prover(Proof& proof, const FD& FieldD) :
|
||||
template <class FD, class U>
|
||||
void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts,
|
||||
const AddableVector<Ciphertext>& c,
|
||||
const FHE_PK& pk, bool binary)
|
||||
const FHE_PK& pk)
|
||||
{
|
||||
size_t allocate = 3 * c.size() * c[0].report_size(USED);
|
||||
ciphertexts.resize_precise(allocate);
|
||||
@@ -51,7 +52,7 @@ void Prover<FD,U>::Stage_1(const Proof& P, octetStream& ciphertexts,
|
||||
// AE.randomize(Diag,binary);
|
||||
// rd=RandPoly(phim,bd<<1);
|
||||
// y[i]=AE.plaintext()+pr*rd;
|
||||
y[i].randomize(G, P.B_plain_length, P.get_diagonal(), binary);
|
||||
y[i].randomize(G, P.B_plain_length, P.get_diagonal());
|
||||
if (P.get_diagonal())
|
||||
assert(y[i].is_diagonal());
|
||||
s[i].resize(3, P.phim);
|
||||
@@ -114,8 +115,7 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
|
||||
const FHE_PK& pk,
|
||||
const AddableVector<Ciphertext>& c,
|
||||
const vector<U>& x,
|
||||
const Proof::Randomness& r,
|
||||
bool binary)
|
||||
const Proof::Randomness& r)
|
||||
{
|
||||
// AElement<T> AE;
|
||||
// for (i=0; i<P.sec; i++)
|
||||
@@ -130,13 +130,15 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
|
||||
int cnt=0;
|
||||
while (!ok)
|
||||
{ cnt++;
|
||||
Stage_1(P,ciphertexts,c,pk,binary);
|
||||
Stage_1(P,ciphertexts,c,pk);
|
||||
P.set_challenge(ciphertexts);
|
||||
// Check check whether we are OK, or whether we should abort
|
||||
ok = Stage_2(P,cleartexts,x,r,pk);
|
||||
}
|
||||
#ifdef VERBOSE
|
||||
if (cnt > 1)
|
||||
cout << "\t\tNumber iterations of prover = " << cnt << endl;
|
||||
#endif
|
||||
return report_size(CAPACITY) + volatile_memory;
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,7 @@ public:
|
||||
Prover(Proof& proof, const FD& FieldD);
|
||||
|
||||
void Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector<Ciphertext>& c,
|
||||
const FHE_PK& pk,
|
||||
bool binary = false);
|
||||
const FHE_PK& pk);
|
||||
|
||||
bool Stage_2(Proof& P, octetStream& cleartexts,
|
||||
const vector<U>& x,
|
||||
@@ -40,8 +39,7 @@ public:
|
||||
const FHE_PK& pk,
|
||||
const AddableVector<Ciphertext>& c,
|
||||
const vector<U>& x,
|
||||
const Proof::Randomness& r,
|
||||
bool binary=false);
|
||||
const Proof::Randomness& r);
|
||||
|
||||
size_t report_size(ReportType type);
|
||||
void report_size(ReportType type, MemoryUsage& res);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "Protocols/MAC_Check.h"
|
||||
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
template<class T, class FD, class S>
|
||||
SimpleEncCommitBase<T, FD, S>::SimpleEncCommitBase(const MachineBase& machine) :
|
||||
@@ -63,7 +64,10 @@ SimpleEncCommitFactory<FD>::SimpleEncCommitFactory(const FHE_PK& pk,
|
||||
template <class FD>
|
||||
SimpleEncCommitFactory<FD>::~SimpleEncCommitFactory()
|
||||
{
|
||||
cout << "EncCommit called " << n_calls << " times" << endl;
|
||||
#ifdef VERBOSE_HE
|
||||
if (n_calls > 0)
|
||||
cout << "EncCommit called " << n_calls << " times" << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
@@ -131,7 +135,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
|
||||
Prover<FD, Plaintext_<FD> > prover(proof, FTD);
|
||||
#endif
|
||||
size_t prover_memory = prover.NIZKPoK(proof, ciphertexts, cleartexts,
|
||||
pk, c, m, r, false);
|
||||
pk, c, m, r);
|
||||
timers["Proving"].stop();
|
||||
|
||||
if (proof.top_gear)
|
||||
@@ -192,7 +196,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
|
||||
#endif
|
||||
timers["Verifying"].start();
|
||||
verifier.NIZKPoK(others_ciphertexts, ciphertexts,
|
||||
cleartexts, get_pk_for_verification(i), false);
|
||||
cleartexts, get_pk_for_verification(i));
|
||||
timers["Verifying"].stop();
|
||||
add_ciphertexts(others_ciphertexts, i);
|
||||
this->memory_usage.update("verifier", verifier.report_size(CAPACITY));
|
||||
@@ -251,7 +255,7 @@ void SummingEncCommit<FD>::create_more()
|
||||
#endif
|
||||
this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof);
|
||||
this->timers["Stage 1 of proof"].start();
|
||||
prover.Stage_1(proof, ciphertexts, this->c, this->pk, false);
|
||||
prover.Stage_1(proof, ciphertexts, this->c, this->pk);
|
||||
this->timers["Stage 1 of proof"].stop();
|
||||
|
||||
this->c.unpack(ciphertexts, this->pk);
|
||||
@@ -291,8 +295,10 @@ void SummingEncCommit<FD>::create_more()
|
||||
|
||||
for (int i = 1; i < P.num_players(); i++)
|
||||
{
|
||||
#ifdef VERBOSE_HE
|
||||
cout << "Sending cleartexts with " << 1e-9 * cleartexts.get_length()
|
||||
<< " GB in round " << i << endl;
|
||||
#endif
|
||||
TimeScope(this->timers["Exchanging cleartexts"]);
|
||||
P.pass_around(cleartexts);
|
||||
preimages.add(cleartexts);
|
||||
@@ -312,7 +318,7 @@ void SummingEncCommit<FD>::create_more()
|
||||
Verifier<FD> verifier(proof);
|
||||
#endif
|
||||
verifier.Stage_2(this->c, ciphertexts, cleartexts,
|
||||
this->pk, false);
|
||||
this->pk);
|
||||
this->timers["Verifying"].stop();
|
||||
this->cnt = proof.U - 1;
|
||||
|
||||
|
||||
@@ -25,7 +25,10 @@ bool Check_Decoding(const Plaintext<T,FD,S>& AE,bool Diag)
|
||||
// return false;
|
||||
// }
|
||||
if (Diag && !AE.is_diagonal())
|
||||
{ cout << "Fail Check 5 " << endl;
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "Fail Check 5 " << endl;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@@ -62,7 +65,7 @@ template <class FD>
|
||||
void Verifier<FD>::Stage_2(
|
||||
AddableVector<Ciphertext>& c,octetStream& ciphertexts,
|
||||
octetStream& cleartexts,
|
||||
const FHE_PK& pk,bool binary)
|
||||
const FHE_PK& pk)
|
||||
{
|
||||
unsigned int i, V;
|
||||
|
||||
@@ -90,18 +93,19 @@ void Verifier<FD>::Stage_2(
|
||||
rc.assign(t[0], t[1], t[2]);
|
||||
pk.encrypt(d2,z,rc);
|
||||
if (!(d1 == d2))
|
||||
{ cout << "Fail Check 6 " << i << endl;
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "Fail Check 6 " << i << endl;
|
||||
#endif
|
||||
throw runtime_error("ciphertexts don't match");
|
||||
}
|
||||
if (!Check_Decoding(z,P.get_diagonal(),FieldD))
|
||||
{ cout << "\tCheck : " << i << endl;
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
cout << "\tCheck : " << i << endl;
|
||||
#endif
|
||||
throw runtime_error("cleartext isn't diagonal");
|
||||
}
|
||||
if (binary && !z.is_binary())
|
||||
{
|
||||
cout << "Not binary " << i << endl;
|
||||
throw runtime_error("cleartext isn't binary");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,17 +116,15 @@ void Verifier<FD>::Stage_2(
|
||||
template <class FD>
|
||||
void Verifier<FD>::NIZKPoK(AddableVector<Ciphertext>& c,
|
||||
octetStream& ciphertexts, octetStream& cleartexts,
|
||||
const FHE_PK& pk,
|
||||
bool binary)
|
||||
const FHE_PK& pk)
|
||||
{
|
||||
P.set_challenge(ciphertexts);
|
||||
|
||||
Stage_2(c,ciphertexts,cleartexts,pk,binary);
|
||||
Stage_2(c,ciphertexts,cleartexts,pk);
|
||||
|
||||
if (P.top_gear)
|
||||
{
|
||||
assert(not P.get_diagonal());
|
||||
assert(not binary);
|
||||
c += c;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,14 +21,14 @@ public:
|
||||
|
||||
void Stage_2(
|
||||
AddableVector<Ciphertext>& c, octetStream& ciphertexts,
|
||||
octetStream& cleartexts,const FHE_PK& pk,bool binary=false);
|
||||
octetStream& cleartexts,const FHE_PK& pk);
|
||||
|
||||
/* This is the non-interactive version using the ROM
|
||||
- Creates space for all output values
|
||||
- Diag flag mirrors that in Prover
|
||||
*/
|
||||
void NIZKPoK(AddableVector<Ciphertext>& c,octetStream& ciphertexts,octetStream& cleartexts,
|
||||
const FHE_PK& pk,bool binary=false);
|
||||
const FHE_PK& pk);
|
||||
|
||||
size_t report_size(ReportType type) { return z.report_size(type) + t.report_size(type); }
|
||||
};
|
||||
|
||||
46
GC/CcdPrep.h
46
GC/CcdPrep.h
@@ -19,13 +19,13 @@ template<class T>
|
||||
class CcdPrep : public BufferPrep<T>
|
||||
{
|
||||
typename T::part_type::LivePrep part_prep;
|
||||
typename T::part_type::MAC_Check part_MC;
|
||||
SubProcessor<typename T::part_type>* part_proc;
|
||||
ShareThread<T>& thread;
|
||||
|
||||
public:
|
||||
CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
|
||||
BufferPrep<T>(usage), part_prep(usage), part_proc(0), thread(thread)
|
||||
BufferPrep<T>(usage), part_prep(usage, thread), part_proc(0),
|
||||
thread(thread)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -34,17 +34,9 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
~CcdPrep()
|
||||
{
|
||||
if (part_proc)
|
||||
delete part_proc;
|
||||
}
|
||||
~CcdPrep();
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
part_proc = new SubProcessor<typename T::part_type>(part_MC,
|
||||
part_prep, protocol.get_part().P);
|
||||
}
|
||||
void set_protocol(typename T::Protocol& protocol);
|
||||
|
||||
Preprocessing<typename T::part_type>& get_part()
|
||||
{
|
||||
@@ -53,7 +45,16 @@ public:
|
||||
|
||||
void buffer_triples()
|
||||
{
|
||||
throw not_implemented();
|
||||
assert(part_proc);
|
||||
this->triples.push_back({});
|
||||
for (auto& x : this->triples.back())
|
||||
x.resize_regs(T::default_length);
|
||||
for (int i = 0; i < T::default_length; i++)
|
||||
{
|
||||
auto triple = part_prep.get_triple(1);
|
||||
for (int j = 0; j < 3; j++)
|
||||
this->triples.back()[j].get_bit(j) = triple[j];
|
||||
}
|
||||
}
|
||||
|
||||
void buffer_bits()
|
||||
@@ -72,6 +73,25 @@ public:
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
void buffer_inputs(int player)
|
||||
{
|
||||
this->inputs[player].push_back({});
|
||||
this->inputs[player].back().share.resize_regs(T::default_length);
|
||||
for (int i = 0; i < T::default_length; i++)
|
||||
{
|
||||
typename T::part_type::open_type tmp;
|
||||
part_prep.get_input(this->inputs[player].back().share.get_reg(i),
|
||||
tmp, player);
|
||||
this->inputs[player].back().value ^=
|
||||
(typename T::clear(tmp.get_bit(0)) << i);
|
||||
}
|
||||
}
|
||||
|
||||
size_t data_sent()
|
||||
{
|
||||
return part_prep.data_sent();
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
33
GC/CcdPrep.hpp
Normal file
33
GC/CcdPrep.hpp
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* CcdPrep.hpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_CCDPREP_HPP_
|
||||
#define GC_CCDPREP_HPP_
|
||||
|
||||
#include "CcdPrep.h"
|
||||
|
||||
#include "Processor/Processor.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
CcdPrep<T>::~CcdPrep()
|
||||
{
|
||||
if (part_proc)
|
||||
delete part_proc;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
assert(thread.MC);
|
||||
part_proc = new SubProcessor<typename T::part_type>(
|
||||
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_CCDPREP_HPP_ */
|
||||
@@ -34,11 +34,6 @@ public:
|
||||
|
||||
static const int default_length = 1;
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
return DATA_GF2;
|
||||
}
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "CCD";
|
||||
|
||||
@@ -70,8 +70,6 @@ public:
|
||||
static const true_type invertible;
|
||||
static const true_type characteristic_two;
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2; }
|
||||
|
||||
static MC* new_mc(mac_key_type key) { return new MC(key); }
|
||||
|
||||
static void store_clear_in_dynamic(Memory<DynamicType>& mem,
|
||||
|
||||
@@ -39,11 +39,6 @@ public:
|
||||
|
||||
static const int default_length = 1;
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
return DATA_GF2;
|
||||
}
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "Malicious CCD";
|
||||
|
||||
32
GC/NoShare.h
32
GC/NoShare.h
@@ -54,6 +54,11 @@ public:
|
||||
return "no";
|
||||
}
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
static void fail()
|
||||
{
|
||||
throw runtime_error("VM does not support binary circuits");
|
||||
@@ -101,16 +106,10 @@ public:
|
||||
typedef NoValue clear;
|
||||
typedef NoValue mac_key_type;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
typedef NoShare part_type;
|
||||
typedef NoShare small_type;
|
||||
|
||||
typedef BlackHole out_type;
|
||||
|
||||
static const int default_length = 1;
|
||||
|
||||
static const bool needs_ot = false;
|
||||
static const bool expensive_triples = false;
|
||||
static const bool is_real = false;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
@@ -118,21 +117,6 @@ public:
|
||||
return new MC;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
static void generate_mac_key(mac_key_type, T)
|
||||
{
|
||||
}
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "";
|
||||
}
|
||||
|
||||
static string type_string()
|
||||
{
|
||||
return "no";
|
||||
@@ -155,7 +139,6 @@ public:
|
||||
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
|
||||
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
|
||||
@@ -166,11 +149,8 @@ public:
|
||||
|
||||
void load_clear(Integer, Integer) { fail(); }
|
||||
void random_bit() { fail(); }
|
||||
void and_(int, NoShare&, NoShare&, bool) { fail(); }
|
||||
void xor_(int, NoShare&, NoShare&) { fail(); }
|
||||
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void reveal(Integer, Integer) { fail(); }
|
||||
|
||||
void assign(const char*) { fail(); }
|
||||
|
||||
@@ -183,13 +163,11 @@ public:
|
||||
NoShare operator-(const NoShare&) const { fail(); return {}; }
|
||||
NoShare operator*(const NoValue&) const { fail(); return {}; }
|
||||
|
||||
NoShare operator+(int) const { fail(); return {}; }
|
||||
NoShare operator&(int) const { fail(); return {}; }
|
||||
NoShare operator>>(int) const { fail(); return {}; }
|
||||
|
||||
NoShare& operator+=(const NoShare&) { fail(); return *this; }
|
||||
|
||||
NoShare lsb() const { fail(); return {}; }
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
|
||||
void invert(int, NoShare) { fail(); }
|
||||
|
||||
@@ -88,7 +88,7 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
|
||||
input.reset_all(P);
|
||||
for (size_t i = begin; i < end; i++)
|
||||
{
|
||||
typename T::clear x[2];
|
||||
typename T::open_type x[2];
|
||||
for (int j = 0; j < 2; j++)
|
||||
this->get_input(triples[i][j], x[j], input_player);
|
||||
if (P.my_num() == input_player)
|
||||
|
||||
@@ -84,6 +84,7 @@ public:
|
||||
|
||||
void xors(const vector<int>& args);
|
||||
void xors(const vector<int>& args, size_t start, size_t end);
|
||||
void xorc(const ::BaseInstruction& instruction);
|
||||
void nots(const ::BaseInstruction& instruction);
|
||||
void andm(const ::BaseInstruction& instruction);
|
||||
void and_(const vector<int>& args, bool repeat);
|
||||
|
||||
@@ -18,6 +18,7 @@ using namespace std;
|
||||
|
||||
#include "GC/Machine.hpp"
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "Processor/IntInput.hpp"
|
||||
#include "Math/bigint.hpp"
|
||||
|
||||
namespace GC
|
||||
@@ -82,8 +83,12 @@ U GC::Processor<T>::get_long_input(const int* params,
|
||||
{
|
||||
if (not T::actual_inputs)
|
||||
return {};
|
||||
U res = input_proc.get_input<FixInput_<U>>(interactive,
|
||||
¶ms[1]).items[0];
|
||||
U res;
|
||||
if (params[1] == 0)
|
||||
res = input_proc.get_input<IntInput<U>>(interactive, 0).items[0];
|
||||
else
|
||||
res = input_proc.get_input<FixInput_<U>>(interactive,
|
||||
¶ms[1]).items[0];
|
||||
int n_bits = *params;
|
||||
check_input(res, n_bits);
|
||||
return res;
|
||||
@@ -229,6 +234,18 @@ void Processor<T>::xors(const vector<int>& args, size_t start, size_t end)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::xorc(const ::BaseInstruction& instruction)
|
||||
{
|
||||
int total = instruction.get_n();
|
||||
for (int i = 0; i < DIV_CEIL(total, T::default_length); i++)
|
||||
{
|
||||
int n = min(T::default_length, total - i * T::default_length);
|
||||
C[instruction.get_r(0) + i] = BitVec(C[instruction.get_r(1) + i]).mask(n)
|
||||
^ BitVec(C[instruction.get_r(2) + i]).mask(n);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::nots(const ::BaseInstruction& instruction)
|
||||
{
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#define GC_REP4SECRET_H_
|
||||
|
||||
#include "ShareSecret.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
#include "Protocols/Rep4MC.h"
|
||||
#include "Protocols/Rep4Share.h"
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
/*
|
||||
* ReplicatedPrep.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <GC/SemiHonestRepPrep.h>
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
@@ -119,7 +119,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
|
||||
try
|
||||
{
|
||||
read_mac_key(get_prep_sub_dir<T>(PREP_DIR, network_opts.nplayers), this->N,
|
||||
read_mac_key(
|
||||
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
|
||||
this->N,
|
||||
this->mac_key);
|
||||
}
|
||||
catch (exception& e)
|
||||
|
||||
@@ -38,6 +38,8 @@ template<class U>
|
||||
class ShareSecret
|
||||
{
|
||||
public:
|
||||
typedef U whole_type;
|
||||
|
||||
typedef Memory<U> DynamicMemory;
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ShareParty.h"
|
||||
#include "ShareThread.hpp"
|
||||
#include "Thread.hpp"
|
||||
#include "VectorProtocol.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
@@ -29,7 +29,8 @@ ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, DataPositions&
|
||||
*static_cast<Preprocessing<T>*>(new typename T::LivePrep(
|
||||
usage, *this)) :
|
||||
*static_cast<Preprocessing<T>*>(new BitPrepFiles<T>(N,
|
||||
get_prep_sub_dir<T>(PREP_DIR, N.num_players()), usage)))
|
||||
get_prep_sub_dir<T>(PREP_DIR, N.num_players()),
|
||||
usage, BaseMachine::thread_num)))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
/*
|
||||
* TinierPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_TINIERPREP_H_
|
||||
#define GC_TINIERPREP_H_
|
||||
|
||||
#include "TinyPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class TinierPrep : public TinyPrep<T>
|
||||
{
|
||||
public:
|
||||
TinierPrep(DataPositions& usage, ShareThread<T>& thread,
|
||||
bool amplify = true) :
|
||||
TinyPrep<T>(usage, thread, amplify)
|
||||
{
|
||||
}
|
||||
|
||||
TinierPrep(SubProcessor<T>*, DataPositions& usage) :
|
||||
TinierPrep(usage, ShareThread<T>::s())
|
||||
{
|
||||
}
|
||||
|
||||
void buffer_inputs(int player)
|
||||
{
|
||||
this->buffer_inputs_(player, this->triple_generator);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_TINIERPREP_H_ */
|
||||
@@ -15,6 +15,9 @@ namespace GC
|
||||
{
|
||||
|
||||
template<class T> class TinierPrep;
|
||||
template<class T> class VectorProtocol;
|
||||
template<class T> class CcdPrep;
|
||||
template<class T> class VectorInput;
|
||||
|
||||
template<class T>
|
||||
class TinierSecret : public VectorSecret<TinierShare<T>>
|
||||
@@ -25,9 +28,9 @@ class TinierSecret : public VectorSecret<TinierShare<T>>
|
||||
public:
|
||||
typedef TinyMC<This> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef Beaver<This> Protocol;
|
||||
typedef ::Input<This> Input;
|
||||
typedef TinierPrep<This> LivePrep;
|
||||
typedef VectorProtocol<This> Protocol;
|
||||
typedef VectorInput<This> Input;
|
||||
typedef CcdPrep<This> LivePrep;
|
||||
typedef Memory<This> DynamicMemory;
|
||||
|
||||
typedef NPartyTripleGenerator<This> TripleGenerator;
|
||||
|
||||
@@ -9,12 +9,12 @@
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "Protocols/Share.h"
|
||||
#include "Math/Bit.h"
|
||||
#include "TinierSharePrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class TinierSecret;
|
||||
template<class T> class TinierSharePrep;
|
||||
|
||||
template<class T>
|
||||
class TinierShare: public Share_<SemiShare<Bit>, SemiShare<T>>,
|
||||
@@ -55,6 +55,11 @@ public:
|
||||
return "Tinier";
|
||||
}
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "TT";
|
||||
}
|
||||
|
||||
static ShareThread<TinierSecret<T>>& get_party()
|
||||
{
|
||||
return ShareThread<TinierSecret<T>>::s();
|
||||
@@ -103,9 +108,7 @@ public:
|
||||
|
||||
void random()
|
||||
{
|
||||
TinierSecret<T> tmp;
|
||||
get_party().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
*this = get_party().DataF.get_part().get_bit();
|
||||
}
|
||||
|
||||
This lsb() const
|
||||
|
||||
@@ -21,18 +21,26 @@ template<class T>
|
||||
class TinierSharePrep : public PersonalPrep<T>
|
||||
{
|
||||
typename T::TripleGenerator* triple_generator;
|
||||
typename T::whole_type::TripleGenerator* real_triple_generator;
|
||||
MascotParams params;
|
||||
|
||||
TinierPrep<TinierSecret<typename T::mac_key_type>> whole_prep;
|
||||
typedef typename T::whole_type secret_type;
|
||||
ShareThread<secret_type>& thread;
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_bits() { throw not_implemented(); }
|
||||
void buffer_bits();
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
|
||||
void buffer_inputs(int player);
|
||||
|
||||
void buffer_secret_triples();
|
||||
|
||||
void init_real(Player& P);
|
||||
|
||||
public:
|
||||
TinierSharePrep(DataPositions& usage, ShareThread<secret_type>& thread,
|
||||
int input_player = PersonalPrep<T>::SECURE);
|
||||
TinierSharePrep(DataPositions& usage, int input_player =
|
||||
PersonalPrep<T>::SECURE);
|
||||
TinierSharePrep(SubProcessor<T>*, DataPositions& usage);
|
||||
|
||||
@@ -15,10 +15,16 @@ namespace GC
|
||||
|
||||
template<class T>
|
||||
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage, int input_player) :
|
||||
TinierSharePrep<T>(usage, ShareThread<secret_type>::s(), input_player)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage,
|
||||
ShareThread<secret_type>& thread, int input_player) :
|
||||
PersonalPrep<T>(usage, input_player), triple_generator(0),
|
||||
whole_prep(usage,
|
||||
ShareThread<TinierSecret<typename T::mac_key_type>>::s(),
|
||||
input_player == PersonalPrep<T>::SECURE)
|
||||
real_triple_generator(0),
|
||||
thread(thread)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -33,6 +39,8 @@ TinierSharePrep<T>::~TinierSharePrep()
|
||||
{
|
||||
if (triple_generator)
|
||||
delete triple_generator;
|
||||
if (real_triple_generator)
|
||||
delete real_triple_generator;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -44,15 +52,14 @@ void TinierSharePrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
params.generateMACs = true;
|
||||
params.amplify = false;
|
||||
params.check = false;
|
||||
auto& thread = ShareThread<TinierSecret<typename T::mac_key_type>>::s();
|
||||
auto& thread = ShareThread<typename T::whole_type>::s();
|
||||
triple_generator = new typename T::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
|
||||
OnlineOptions::singleton.batch_size
|
||||
* TinierSecret<typename T::mac_key_type>::default_length, 1,
|
||||
OnlineOptions::singleton.batch_size, 1,
|
||||
params, thread.MC->get_alphai(), &protocol.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
this->inputs.resize(thread.P->num_players());
|
||||
whole_prep.init(*thread.P);
|
||||
init_real(protocol.P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -63,12 +70,8 @@ void TinierSharePrep<T>::buffer_triples()
|
||||
this->buffer_personal_triples();
|
||||
return;
|
||||
}
|
||||
|
||||
array<TinierSecret<typename T::mac_key_type>, 3> whole;
|
||||
whole_prep.get(DATA_TRIPLE, whole.data());
|
||||
for (size_t i = 0; i < whole[0].get_regs().size(); i++)
|
||||
this->triples.push_back(
|
||||
{{ whole[0].get_reg(i), whole[1].get_reg(i), whole[2].get_reg(i) }});
|
||||
else
|
||||
buffer_secret_triples();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -81,12 +84,21 @@ void TinierSharePrep<T>::buffer_inputs(int player)
|
||||
inputs.at(player).push_back(x);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::TinierSharePrep<T>::buffer_bits()
|
||||
{
|
||||
this->bits.push_back(
|
||||
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t TinierSharePrep<T>::data_sent()
|
||||
{
|
||||
size_t res = whole_prep.data_sent();
|
||||
size_t res = 0;
|
||||
if (triple_generator)
|
||||
res += triple_generator->data_sent();
|
||||
if (real_triple_generator)
|
||||
res += real_triple_generator->data_sent();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
/*
|
||||
* TinyMC.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinyMC.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
@@ -14,7 +14,7 @@ namespace GC
|
||||
template<class T>
|
||||
class TinyMC : public MAC_Check_Base<T>
|
||||
{
|
||||
typename T::check_type::MAC_Check part_MC;
|
||||
typename T::part_type::MAC_Check part_MC;
|
||||
PointerVector<int> sizes;
|
||||
|
||||
public:
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
/*
|
||||
* TinyPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_TINYPREP_H_
|
||||
#define GC_TINYPREP_H_
|
||||
|
||||
#include "Thread.h"
|
||||
#include "OT/MascotParams.h"
|
||||
#include "Protocols/Beaver.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class TinyPrep : public BufferPrep<T>
|
||||
{
|
||||
protected:
|
||||
ShareThread<T>& thread;
|
||||
|
||||
typename T::TripleGenerator* triple_generator;
|
||||
MascotParams params;
|
||||
|
||||
vector<array<typename T::part_type, 3>> triple_buffer;
|
||||
|
||||
const bool amplify;
|
||||
|
||||
public:
|
||||
TinyPrep(DataPositions& usage, ShareThread<T>& thread, bool amplify = true);
|
||||
~TinyPrep();
|
||||
|
||||
void set_protocol(Beaver<T>& protocol);
|
||||
void init(Player& P);
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_bits();
|
||||
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
|
||||
void buffer_inputs_(int player, typename T::InputGenerator* input_generator);
|
||||
|
||||
array<T, 3> get_triple_no_count(int n_bits);
|
||||
|
||||
size_t data_sent();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class TinyOnlyPrep : public TinyPrep<T>
|
||||
{
|
||||
typename T::part_type::TripleGenerator* input_generator;
|
||||
|
||||
public:
|
||||
TinyOnlyPrep(DataPositions& usage, ShareThread<T>& thread);
|
||||
~TinyOnlyPrep();
|
||||
|
||||
void set_protocol(Beaver<T>& protocol);
|
||||
|
||||
void buffer_inputs(int player)
|
||||
{
|
||||
this->buffer_inputs_(player, input_generator);
|
||||
}
|
||||
|
||||
size_t data_sent();
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_TINYPREP_H_ */
|
||||
182
GC/TinyPrep.hpp
182
GC/TinyPrep.hpp
@@ -3,7 +3,7 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinyPrep.h"
|
||||
#include "TinierSharePrep.h"
|
||||
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
|
||||
@@ -11,78 +11,26 @@ namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
TinyPrep<T>::TinyPrep(DataPositions& usage, ShareThread<T>& thread,
|
||||
bool amplify) :
|
||||
BufferPrep<T>(usage), thread(thread), triple_generator(0),
|
||||
amplify(amplify)
|
||||
void TinierSharePrep<T>::init_real(Player& P)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinyOnlyPrep<T>::TinyOnlyPrep(DataPositions& usage, ShareThread<T>& thread) :
|
||||
TinyPrep<T>(usage, thread), input_generator(0)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinyPrep<T>::~TinyPrep()
|
||||
{
|
||||
if (triple_generator)
|
||||
delete triple_generator;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinyOnlyPrep<T>::~TinyOnlyPrep()
|
||||
{
|
||||
if (input_generator)
|
||||
delete input_generator;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
|
||||
{
|
||||
init(protocol.P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::init(Player& P)
|
||||
{
|
||||
params.generateMACs = true;
|
||||
params.amplify = false;
|
||||
params.check = false;
|
||||
auto& thread = ShareThread<T>::s();
|
||||
triple_generator = new typename T::TripleGenerator(
|
||||
assert(real_triple_generator == 0);
|
||||
real_triple_generator = new typename T::whole_type::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(), P.N, -1,
|
||||
OnlineOptions::singleton.batch_size, 1, params,
|
||||
thread.MC->get_alphai(), &P);
|
||||
triple_generator->multi_threaded = false;
|
||||
real_triple_generator->multi_threaded = false;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyOnlyPrep<T>::set_protocol(Beaver<T>& protocol)
|
||||
void TinierSharePrep<T>::buffer_secret_triples()
|
||||
{
|
||||
TinyPrep<T>::set_protocol(protocol);
|
||||
input_generator = new typename T::part_type::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
|
||||
OnlineOptions::singleton.batch_size, 1, this->params,
|
||||
this->thread.MC->get_alphai(), &protocol.P);
|
||||
input_generator->multi_threaded = false;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::buffer_triples()
|
||||
{
|
||||
auto& triple_generator = this->triple_generator;
|
||||
auto& triple_generator = real_triple_generator;
|
||||
assert(triple_generator != 0);
|
||||
params.generateBits = false;
|
||||
vector<array<typename T::check_type, 3>> triples;
|
||||
TripleShuffleSacrifice<typename T::check_type> sacrifice;
|
||||
vector<array<T, 3>> triples;
|
||||
TripleShuffleSacrifice<T> sacrifice;
|
||||
size_t required;
|
||||
if (amplify)
|
||||
required = sacrifice.minimum_n_inputs_with_combining();
|
||||
else
|
||||
required = sacrifice.minimum_n_inputs();
|
||||
required = sacrifice.minimum_n_inputs_with_combining();
|
||||
while (triples.size() < required)
|
||||
{
|
||||
triple_generator->generatePlainTriples();
|
||||
@@ -92,9 +40,11 @@ void TinyPrep<T>::buffer_triples()
|
||||
triple_generator->valueBits[2].set_portion(i,
|
||||
triple_generator->plainTriples[i][2]);
|
||||
triple_generator->run_multipliers({});
|
||||
assert(triple_generator->plainTriples.size() != 0);
|
||||
for (size_t i = 0; i < triple_generator->plainTriples.size(); i++)
|
||||
{
|
||||
for (int j = 0; j < T::default_length; j++)
|
||||
int dl = secret_type::default_length;
|
||||
for (int j = 0; j < dl; j++)
|
||||
{
|
||||
triples.push_back({});
|
||||
for (int k = 0; k < 3; k++)
|
||||
@@ -103,10 +53,10 @@ void TinyPrep<T>::buffer_triples()
|
||||
share.set_share(
|
||||
triple_generator->plainTriples.at(i).at(k).get_bit(
|
||||
j));
|
||||
typename T::part_type::mac_type mac;
|
||||
typename T::mac_type mac;
|
||||
mac = thread.MC->get_alphai() * share.get_share();
|
||||
for (auto& multiplier : triple_generator->ot_multipliers)
|
||||
mac += multiplier->macs.at(k).at(i * T::default_length + j);
|
||||
mac += multiplier->macs.at(k).at(i * dl + j);
|
||||
share.set_mac(mac);
|
||||
}
|
||||
}
|
||||
@@ -114,104 +64,10 @@ void TinyPrep<T>::buffer_triples()
|
||||
}
|
||||
sacrifice.triple_sacrifice(triples, triples,
|
||||
*thread.P, thread.MC->get_part_MC());
|
||||
if (amplify)
|
||||
sacrifice.triple_combine(triples, triples, *thread.P,
|
||||
thread.MC->get_part_MC());
|
||||
for (size_t i = 0; i < triples.size() / T::default_length; i++)
|
||||
{
|
||||
this->triples.push_back({});
|
||||
auto& triple = this->triples.back();
|
||||
for (auto& x : triple)
|
||||
x.resize_regs(T::default_length);
|
||||
for (int j = 0; j < T::default_length; j++)
|
||||
{
|
||||
auto& source_triple = triples[j + i * T::default_length];
|
||||
for (int k = 0; k < 3; k++)
|
||||
triple[k].get_reg(j) = source_triple[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::buffer_bits()
|
||||
{
|
||||
auto tmp = BufferPrep<T>::get_random_from_inputs(thread.P->num_players());
|
||||
for (auto& bit : tmp.get_regs())
|
||||
{
|
||||
this->bits.push_back({});
|
||||
this->bits.back().resize_regs(1);
|
||||
this->bits.back().get_reg(0) = bit;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void TinyPrep<T>::buffer_inputs_(int player, typename T::InputGenerator* input_generator)
|
||||
{
|
||||
auto& inputs = this->inputs;
|
||||
inputs.resize(this->thread.P->num_players());
|
||||
assert(input_generator);
|
||||
input_generator->generateInputs(player);
|
||||
assert(input_generator->inputs.size() >= T::default_length);
|
||||
for (size_t i = 0; i < input_generator->inputs.size() / T::default_length; i++)
|
||||
{
|
||||
inputs[player].push_back({});
|
||||
inputs[player].back().share.resize_regs(T::default_length);
|
||||
for (int j = 0; j < T::default_length; j++)
|
||||
{
|
||||
auto& source_input = input_generator->inputs[j
|
||||
+ i * T::default_length];
|
||||
inputs[player].back().share.get_reg(j) = source_input.share;
|
||||
inputs[player].back().value ^= typename T::open_type(
|
||||
source_input.value.get_bit(0)) << j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
array<T, 3> TinyPrep<T>::get_triple_no_count(int n_bits)
|
||||
{
|
||||
assert(n_bits > 0);
|
||||
while ((unsigned)n_bits > triple_buffer.size())
|
||||
{
|
||||
array<T, 3> tmp;
|
||||
this->get(DATA_TRIPLE, tmp.data());
|
||||
for (size_t i = 0; i < tmp[0].get_regs().size(); i++)
|
||||
{
|
||||
triple_buffer.push_back(
|
||||
{ {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} });
|
||||
}
|
||||
}
|
||||
|
||||
array<T, 3> res;
|
||||
for (int j = 0; j < 3; j++)
|
||||
res[j].resize_regs(n_bits);
|
||||
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
{
|
||||
for (int j = 0; j < 3; j++)
|
||||
res[j].get_reg(i) = triple_buffer.back()[j];
|
||||
triple_buffer.pop_back();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t TinyPrep<T>::data_sent()
|
||||
{
|
||||
size_t res = 0;
|
||||
if (triple_generator)
|
||||
res += triple_generator->data_sent();
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t TinyOnlyPrep<T>::data_sent()
|
||||
{
|
||||
auto res = TinyPrep<T>::data_sent();
|
||||
if (input_generator)
|
||||
res += input_generator->data_sent();
|
||||
return res;
|
||||
sacrifice.triple_combine(triples, triples, *thread.P,
|
||||
thread.MC->get_part_MC());
|
||||
for (auto& triple : triples)
|
||||
this->triples.push_back(triple);
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
/*
|
||||
* TinySecret.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinySecret.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
@@ -21,6 +21,9 @@ namespace GC
|
||||
|
||||
template<class T> class TinyOnlyPrep;
|
||||
template<class T> class TinyMC;
|
||||
template<class T> class VectorProtocol;
|
||||
template<class T> class VectorInput;
|
||||
template<class T> class CcdPrep;
|
||||
|
||||
template<class T>
|
||||
class VectorSecret : public Secret<T>
|
||||
@@ -50,11 +53,6 @@ public:
|
||||
|
||||
static const int default_length = 64;
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
return BitVec::field_type();
|
||||
}
|
||||
|
||||
static int size()
|
||||
{
|
||||
return part_type::size() * default_length;
|
||||
@@ -166,9 +164,9 @@ public:
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void other_input(U& inputter, int from, int)
|
||||
void other_input(U& inputter, int from, int n_bits)
|
||||
{
|
||||
inputter.add_other(from);
|
||||
inputter.add_other(from, n_bits);
|
||||
}
|
||||
|
||||
template <class U>
|
||||
@@ -187,9 +185,9 @@ class TinySecret : public VectorSecret<TinyShare<S>>
|
||||
public:
|
||||
typedef TinyMC<This> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef Beaver<This> Protocol;
|
||||
typedef ::Input<This> Input;
|
||||
typedef TinyOnlyPrep<This> LivePrep;
|
||||
typedef VectorProtocol<This> Protocol;
|
||||
typedef VectorInput<This> Input;
|
||||
typedef CcdPrep<This> LivePrep;
|
||||
typedef Memory<This> DynamicMemory;
|
||||
|
||||
typedef OTTripleGenerator<This> TripleGenerator;
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
/*
|
||||
* TinyShare.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TinyShare.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
} /* namespace GC */
|
||||
@@ -10,13 +10,14 @@
|
||||
#include "ShareParty.h"
|
||||
#include "Secret.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<int S> class TinySecret;
|
||||
template<class T> class ShareThread;
|
||||
template<class T> class TinierSharePrep;
|
||||
|
||||
template<int S>
|
||||
class TinyShare : public Spdz2kShare<1, S>, public ShareSecret<TinySecret<S>>
|
||||
@@ -28,12 +29,18 @@ public:
|
||||
|
||||
typedef void DynamicMemory;
|
||||
|
||||
typedef NoLivePrep<This> LivePrep;
|
||||
typedef Beaver<This> Protocol;
|
||||
typedef MAC_Check_Z2k_<This> MAC_Check;
|
||||
typedef MAC_Check Direct_MC;
|
||||
typedef ::Input<This> Input;
|
||||
typedef TinierSharePrep<This> LivePrep;
|
||||
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
typedef This small_type;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "tiny share";
|
||||
|
||||
@@ -18,14 +18,14 @@ class VectorInput : public InputBase<T>
|
||||
deque<int> input_lengths;
|
||||
|
||||
public:
|
||||
VectorInput(typename T::MAC_Check&, Preprocessing<T>&, Player& P) :
|
||||
part_input(0, P)
|
||||
VectorInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) :
|
||||
part_input(MC.get_part_MC(), prep.get_part(), P)
|
||||
{
|
||||
part_input.reset_all(P);
|
||||
}
|
||||
|
||||
VectorInput(SubProcessor<T>& proc, typename T::MAC_Check&) :
|
||||
VectorInput(proc.MC, proc.DataF, proc.P)
|
||||
part_input(proc.MC, proc.DataF, proc.P)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -41,8 +41,10 @@ public:
|
||||
input_lengths.push_back(n_bits);
|
||||
}
|
||||
|
||||
void add_other(int)
|
||||
void add_other(int player, int n_bits)
|
||||
{
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
part_input.add_other(player);
|
||||
}
|
||||
|
||||
void send_mine()
|
||||
|
||||
@@ -17,6 +17,8 @@ class VectorProtocol : public ProtocolBase<T>
|
||||
typename T::part_type::Protocol part_protocol;
|
||||
|
||||
public:
|
||||
Player& P;
|
||||
|
||||
VectorProtocol(Player& P);
|
||||
|
||||
void init_mul(SubProcessor<T>* proc);
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_VECTORPROTOCOL_HPP_
|
||||
#define GC_VECTORPROTOCOL_HPP_
|
||||
|
||||
#include "VectorProtocol.h"
|
||||
|
||||
namespace GC
|
||||
@@ -10,7 +13,7 @@ namespace GC
|
||||
|
||||
template<class T>
|
||||
VectorProtocol<T>::VectorProtocol(Player& P) :
|
||||
part_protocol(P)
|
||||
part_protocol(P), P(P)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -54,3 +57,5 @@ T VectorProtocol<T>::finalize_mul(int n)
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
|
||||
#define BIT_INSTRUCTIONS \
|
||||
X(XORS, T::xors(PROC, EXTRA)) \
|
||||
X(XORCB, C0.xor_(PC1, PC2)) \
|
||||
X(XORCB, processor.xorc(instruction)) \
|
||||
X(XORCBI, C0.xor_(PC1, IMM)) \
|
||||
X(NOTS, processor.nots(INST)) \
|
||||
X(ANDRS, T::andrs(PROC, EXTRA)) \
|
||||
|
||||
@@ -20,13 +20,14 @@
|
||||
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/TinierPrep.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
|
||||
@@ -8,9 +8,8 @@
|
||||
|
||||
#include "GC/TinySecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/TinyPrep.h"
|
||||
#include "GC/TinierPrep.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
@@ -29,3 +28,4 @@
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
ShamirOptions ShamirOptions::singleton;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "FHE/NTL-Subs.h"
|
||||
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinierPrep.h"
|
||||
#include "GC/TinyMC.h"
|
||||
|
||||
#include "SPDZ.hpp"
|
||||
|
||||
@@ -18,12 +18,14 @@
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
OnlineOptions online_opts;
|
||||
Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector<string>({"localhost"}));
|
||||
OnlineOptions& online_opts = OnlineOptions::singleton;
|
||||
Names N;
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions ring_opts(opt, argc, argv);
|
||||
online_opts = {opt, argc, argv};
|
||||
opt.parse(argc, argv);
|
||||
opt.syntax = string(argv[0]) + " <progname>";
|
||||
|
||||
string progname;
|
||||
if (opt.firstArgs.size() > 1)
|
||||
progname = *opt.firstArgs.at(1);
|
||||
@@ -50,36 +52,14 @@ int main(int argc, const char** argv)
|
||||
int R = ring_opts.ring_size_from_opts_or_schedule(progname);
|
||||
switch (R)
|
||||
{
|
||||
case 64:
|
||||
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 128:
|
||||
Machine<FakeShare<SignedZ2<128>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 256:
|
||||
Machine<FakeShare<SignedZ2<256>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 192:
|
||||
Machine<FakeShare<SignedZ2<192>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 384:
|
||||
Machine<FakeShare<SignedZ2<384>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 512:
|
||||
Machine<FakeShare<SignedZ2<512>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
#define X(L) \
|
||||
case L: \
|
||||
Machine<FakeShare<SignedZ2<L>>, FakeShare<gf2n>>(0, N, progname, \
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, false, \
|
||||
online_opts.live_prep, online_opts).run(); \
|
||||
break;
|
||||
X(64) X(128) X(256) X(192) X(384) X(512)
|
||||
#undef X
|
||||
default:
|
||||
cerr << "Not compiled for " << R << "-bit rings" << endl;
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Machines/MalRep.hpp"
|
||||
|
||||
|
||||
23
Machines/no-party.cpp
Normal file
23
Machines/no-party.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
/*
|
||||
* no-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/NoShare.h"
|
||||
|
||||
#include "Processor/OnlineMachine.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Math/gfp.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
OnlineOptions::singleton = {opt, argc, argv};
|
||||
OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton);
|
||||
OnlineOptions::singleton.finalize(opt, argc, argv);
|
||||
machine.start_networking();
|
||||
// use primes of length 65 to 128 for arithmetic computation
|
||||
machine.run<NoShare<gfp_<0, 2>>, NoShare<gf2n>>();
|
||||
}
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
|
||||
#include "Processor/FieldMachine.hpp"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user