edaBits, ChaiGear, TopGear, CCD.

This commit is contained in:
Marcel Keller
2020-03-20 20:30:29 +11:00
parent 7d44986d99
commit 92a3fb0184
285 changed files with 8251 additions and 1467 deletions

View File

@@ -233,6 +233,8 @@ public:
template <class T>
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
template <class T>
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
template <class T>
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
template <class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)

View File

@@ -11,6 +11,8 @@
#include <vector>
using namespace std;
#include "Tools/CheckVector.h"
typedef unsigned long wire_id_t;
typedef unsigned long gate_id_t;
typedef unsigned int party_id_t;
@@ -37,20 +39,4 @@ public:
bool call(bool left, bool right) { return rep[2 * left + right]; }
};
template <class T>
class CheckVector : public vector<T>
{
public:
CheckVector() : vector<T>() {}
CheckVector(size_t size) : vector<T>(size) {}
CheckVector(size_t size, const T& def) : vector<T>(size, def) {}
#ifdef CHECK_SIZE
T& operator[](size_t i) { return this->at(i); }
const T& operator[](size_t i) const { return this->at(i); }
#else
T& at(size_t i) { return (*this)[i]; }
const T& at(size_t i) const { return (*this)[i]; }
#endif
};
#endif /* CIRCUIT_INC_COMMON_H_ */

View File

@@ -1,5 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.1.5 (Mar 20, 2020)
- Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338)
- Optimized daBits
- Optimized logistic regression
- Faster compilation of repetitive code (compiler option `-C`)
- ChaiGear: [HighGear](https://eprint.iacr.org/2017/1230) with covert key generation
- [TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs
- Binary computation based on Shamir secret sharing
- Fixed security bug: Prove correctness of ciphertexts in input tuple generation
- Fixed security bug: Missing check in MASCOT bit generation and various binary computations
## 0.1.4 (Dec 23, 2019)
- Mixed circuit computation with secret sharing

View File

@@ -36,6 +36,8 @@ opcodes = dict(
STMSBI = 0x243,
MOVSB = 0x244,
INPUTB = 0x246,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
XORCBI = 0x210,
BITDECC = 0x211,
CONVCINT = 0x213,
@@ -49,15 +51,23 @@ opcodes = dict(
MULCBI = 0x21c,
SHRCBI = 0x21d,
SHLCBI = 0x21e,
CONVCINTVEC = 0x21f,
PRINTREGSIGNED = 0x220,
PRINTREGB = 0x221,
PRINTREGPLAINB = 0x222,
PRINTFLOATPLAINB = 0x223,
CONDPRINTSTRB = 0x224,
CONVCBIT = 0x230,
CONVCBITVEC = 0x231,
)
class xors(base.Instruction):
class BinaryVectorInstruction(base.Instruction):
is_vec = lambda self: True
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))
class xors(BinaryVectorInstruction):
code = opcodes['XORS']
arg_format = tools.cycle(['int','sbw','sb','sb'])
@@ -73,15 +83,21 @@ class xorcbi(base.Instruction):
code = opcodes['XORCBI']
arg_format = ['cbw','cb','int']
class andrs(base.Instruction):
class andrs(BinaryVectorInstruction):
code = opcodes['ANDRS']
arg_format = tools.cycle(['int','sbw','sb','sb'])
class ands(base.Instruction):
def add_usage(self, req_node):
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
class ands(BinaryVectorInstruction):
code = opcodes['ANDS']
arg_format = tools.cycle(['int','sbw','sb','sb'])
class andm(base.Instruction):
def add_usage(self, req_node):
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
class andm(BinaryVectorInstruction):
code = opcodes['ANDM']
arg_format = ['int','sbw','sb','cb']
@@ -181,6 +197,31 @@ class convcbit(base.Instruction):
code = opcodes['CONVCBIT']
arg_format = ['ciw','cb']
@base.vectorize
class convcintvec(base.Instruction):
code = opcodes['CONVCINTVEC']
arg_format = tools.chain(['c'], tools.cycle(['cbw']))
class convcbitvec(BinaryVectorInstruction):
code = opcodes['CONVCBITVEC']
arg_format = ['int','ciw','cb']
def __init__(self, *args):
super(convcbitvec, self).__init__(*args)
assert(args[2].n == args[0])
args[1].set_size(args[0])
class convcbit2s(BinaryVectorInstruction):
code = opcodes['CONVCBIT2S']
arg_format = ['int','sbw','cb']
@base.vectorize
class split(base.Instruction):
code = opcodes['SPLIT']
arg_format = tools.chain(['int','s'], tools.cycle(['sbw']))
def __init__(self, *args, **kwargs):
super(split_class, self).__init__(*args, **kwargs)
assert (len(args) - 2) % args[0] == 0
class movsb(base.Instruction):
code = opcodes['MOVSB']
arg_format = ['sbw','sb']
@@ -196,9 +237,9 @@ class bitb(base.Instruction):
code = opcodes['BITB']
arg_format = ['sbw']
class reveal(base.Instruction):
class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable):
code = opcodes['REVEAL']
arg_format = ['int','cbw','sb']
arg_format = tools.cycle(['int','cbw','sb'])
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
__slots__ = []

View File

@@ -1,15 +1,17 @@
from Compiler.types import MemValue, read_mem_value, regint, Array, cint
from Compiler.types import _bitint, _number, _fix, _structure, _bit
from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint
from Compiler.program import Tape, Program
from Compiler.exceptions import *
from Compiler import util, oram, floatingpoint, library
from Compiler import instructions_base
import Compiler.GC.instructions as inst
import operator
import math
from functools import reduce
class bits(Tape.Register, _structure, _bit):
n = 40
size = 1
unit = 64
PreOp = staticmethod(floatingpoint.PreOpN)
decomposed = None
@staticmethod
@@ -19,9 +21,7 @@ class bits(Tape.Register, _structure, _bit):
[1 - x for x in l])]
@classmethod
def get_type(cls, length):
if length is None:
return cls
elif length == 1:
if length == 1:
return cls.bit_type
if length not in cls.types:
class bitsn(cls):
@@ -65,6 +65,11 @@ class bits(Tape.Register, _structure, _bit):
return res + suffix
else:
return self.decomposed[:n] + suffix
@staticmethod
def bit_decompose_clear(a, n_bits):
res = [cbits.get_type(a.size)() for i in range(n_bits)]
cbits.conv_cint_vec(a, *res)
return res
@classmethod
def malloc(cls, size):
return Program.prog.malloc(size, cls)
@@ -87,41 +92,61 @@ class bits(Tape.Register, _structure, _bit):
def __init__(self, value=None, n=None, size=None):
if size != 1 and size is not None:
raise Exception('invalid size for bit type: %s' % size)
Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape)
self.set_length(n or self.n)
self.n = n or self.n
size = math.ceil(self.n / self.unit) if self.n != None else None
Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape,
size=size)
if value is not None:
self.load_other(value)
def copy(self):
return type(self)(n=instructions_base.get_global_vector_size())
def set_length(self, n):
if n > self.max_length:
print(self.max_length)
raise Exception('too long: %d' % n)
self.n = n
def set_size(self, size):
pass
def load_other(self, other):
if isinstance(other, cint):
size = other.size
other = sum(x << i for i, x in enumerate(other))
other = other.to_regint(size)
if isinstance(other, int):
assert(self.n == other.size)
self.conv_regint_by_bit(self.n, self, other.to_regint(1))
elif isinstance(other, int):
self.set_length(self.n or util.int_len(other))
self.load_int(other)
elif isinstance(other, regint):
assert(other.size == 1)
self.conv_regint(self.n, self, other)
assert(other.size == math.ceil(self.n / self.unit))
for i, (x, y) in enumerate(zip(self, other)):
self.conv_regint(min(self.unit, self.n - i * self.unit), x, y)
elif isinstance(self, type(other)) or isinstance(other, type(self)):
self.mov(self, other)
assert(self.n == other.n)
for i in range(math.ceil(self.n / self.unit)):
self.mov(self[i], other[i])
else:
try:
other = self.bit_compose(other.bit_decompose())
self.mov(self, other)
self.load_other(other)
except:
raise CompilerError('cannot convert from %s to %s' % \
(type(other), type(self)))
def long_one(self):
return 2**self.n - 1
return 2**self.n - 1 if self.n != None else None
def __repr__(self):
return '%s(%d/%d)' % \
(super(bits, self).__repr__(), self.n, type(self).n)
if self.n != None:
suffix = '%d' % self.n
if type(self).n != None and type(self).n != self.n:
suffice += '/%d' % type(self).n
else:
suffix = 'undef'
return '%s(%s)' % (super(bits, self).__repr__(), suffix)
__str__ = __repr__
def _new_by_number(self, i, size=1):
assert(size == 1)
n = min(self.unit, self.n - (i - self.i) * self.unit)
res = self.get_type(n)()
res.i = i
res.program = self.program
return res
class cbits(bits):
max_length = 64
@@ -131,6 +156,12 @@ class cbits(bits):
store_inst = (None, inst.stmcb)
bitdec = inst.bitdecc
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
conv_cint_vec = inst.convcintvec
@classmethod
def conv_regint_by_bit(cls, n, res, other):
assert n == res.n
assert n == other.size
cls.conv_cint_vec(cint(other, size=other.size), res)
types = {}
def load_int(self, value):
self.load_other(regint(value))
@@ -187,6 +218,13 @@ class cbits(bits):
if self.n > 64:
raise CompilerError('too many bits')
inst.convcbit(dest, self)
def to_regint_by_bit(self):
if self.n != None:
res = regint(size=self.n)
else:
res = regint()
inst.convcbitvec(self.n, res, self)
return res
class sbits(bits):
max_length = 128
@@ -199,6 +237,11 @@ class sbits(bits):
bitdec = inst.bitdecs
bitcom = inst.bitcoms
conv_regint = inst.convsint
@classmethod
def conv_regint_by_bit(cls, n, res, other):
tmp = cbits.get_type(n)()
tmp.conv_regint_by_bit(n, tmp, other)
res.load_other(tmp)
mov = inst.movsb
types = {}
def __init__(self, *args, **kwargs):
@@ -250,18 +293,26 @@ class sbits(bits):
self.mov(self, lower + (upper << 64))
else:
raise NotImplementedError('more than 128 bits wanted')
def load_other(self, other):
if isinstance(other, cbits) and self.n == other.n:
inst.convcbit2s(self.n, self, other)
else:
super(sbits, self).load_other(other)
@read_mem_value
def __add__(self, other):
if isinstance(other, int):
if isinstance(other, int) or other is None:
return self.xor_int(other)
else:
if not isinstance(other, sbits):
other = self.conv(other)
n = min(self.n, other.n)
if self.n is None or other.n is None:
assert self.n == other.n
n = None
else:
n = min(self.n, other.n)
res = self.new(n=n)
inst.xors(n, res, self, other)
max_n = max(self.n, other.n)
if max_n > n:
if self.n != None and max(self.n, other.n) > n:
if self.n > n:
longer = self
else:
@@ -293,17 +344,13 @@ class sbits(bits):
return res
except AttributeError:
return NotImplemented
@read_mem_value
def __rmul__(self, other):
if isinstance(other, cbits):
return other * self
else:
return self.mul_int(other)
__rmul__ = __mul__
@read_mem_value
def __and__(self, other):
if util.is_zero(other):
return 0
elif util.is_all_ones(other, self.n):
elif util.is_all_ones(other, self.n) or \
(other is None and self.n == None):
return self
res = self.new(n=self.n)
if not isinstance(other, sbits):
@@ -314,6 +361,7 @@ class sbits(bits):
assert(self.n == other.n)
inst.ands(self.n, res, self, other)
return res
__rand__ = __and__
def xor_int(self, other):
if other == 0:
return self
@@ -326,6 +374,7 @@ class sbits(bits):
for x,y in zip(self_bits, other_bits)] \
+ extra_bits)
def mul_int(self, other):
assert(util.is_constant(other))
if other == 0:
return 0
elif other == 1:
@@ -344,14 +393,19 @@ class sbits(bits):
# res = type(self)(n=self.n)
# inst.nots(res, self)
# return res
one = self.new(value=self.long_one(), n=self.n)
if self.n == None or self.n > self.unit:
one = self.get_type(self.n)()
self.conv_regint_by_bit(self.n, one, regint(1, size=self.n))
else:
one = self.new(value=self.long_one(), n=self.n)
return self + one
def __neg__(self):
return self
def reveal(self):
if self.n > self.clear_type.max_length:
raise Exception('too long to reveal')
res = self.clear_type(n=self.n)
if self.n == None or \
self.n > max(self.max_length, self.clear_type.max_length):
assert(self.unit == self.clear_type.unit)
res = self.clear_type.get_type(self.n)()
inst.reveal(self.n, res, self)
return res
def equal(self, other, n=None):
@@ -395,8 +449,11 @@ class sbits(bits):
@staticmethod
def bit_adder(*args, **kwargs):
return sbitint.bit_adder(*args, **kwargs)
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
class sbitvec(object):
class sbitvec(_vec):
@classmethod
def get_type(cls, n):
return cls
@@ -414,8 +471,27 @@ class sbitvec(object):
def from_matrix(cls, matrix):
# any number of rows, limited number of columns
return cls.combine(cls(row) for row in matrix)
def __init__(self, elements=None):
if elements is not None:
def __init__(self, elements=None, length=None):
if length:
assert isinstance(elements, sint)
if Program.prog.use_split():
n = Program.prog.use_split()
columns = [[sbits.get_type(elements.size)()
for i in range(n)] for i in range(length)]
inst.split(n, elements, *sum(columns, []))
x = sbitint.wallace_tree_without_finish(columns, False)
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
else:
assert Program.prog.options.ring
l = int(Program.prog.options.ring)
r, r_bits = sint.get_edabit(length, size=elements.size)
c = ((elements - r) << (l - length)).reveal()
c >>= l - length
cb = [(c >> i) for i in range(length)]
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
v = x.v
self.v = v[:length]
elif elements is not None:
self.v = sbits.trans(elements)
def popcnt(self):
res = sbitint.wallace_tree([[b] for b in self.v])
@@ -426,8 +502,15 @@ class sbitvec(object):
if stop is None:
start, stop = stop, start
return sbits.trans(self.v[start:stop])
def coerce(self, other):
if isinstance(other, cint):
size = other.size
return (other.get_vector(base, min(64, size - base)) \
for base in range(0, size, 64))
return other
def __xor__(self, other):
return self.from_vec(x ^ y for x, y in zip(self.v, other.v))
other = self.coerce(other)
return self.from_vec(x ^ y for x, y in zip(self.v, other))
def __and__(self, other):
return self.from_vec(x & y for x, y in zip(self.v, other.v))
def if_else(self, x, y):
@@ -453,6 +536,26 @@ class sbitvec(object):
def bit_decompose(self):
return self.v
bit_compose = from_vec
def reveal(self):
assert len(self) == 1
return self.v[0].reveal()
def long_one(self):
return [x.long_one() for x in self.v]
def __rsub__(self, other):
return self.from_vec(y - x for x, y in zip(self.v, other))
def half_adder(self, other):
other = self.coerce(other)
res = zip(*(x.half_adder(y) for x, y in zip(self.v, other)))
return (self.from_vec(x) for x in res)
def __mul__(self, other):
if isinstance(other, int):
return self.from_vec(x * other for x in self.v)
def __add__(self, other):
return self.from_vec(x + y for x, y in zip(self.v, other))
def bit_and(self, other):
return self & other
def bit_xor(self, other):
return self ^ other
class bit(object):
n = 1

View File

@@ -16,12 +16,13 @@ from functools import reduce
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
It is based on the precondition that every register is only defined once."""
def __init__(self, n):
def __init__(self, n, program):
self.alloc = dict_by_id()
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
self.defined = dict_by_id()
self.dealloc = set_by_id()
self.n = n
self.program = program
def alloc_reg(self, reg, free):
base = reg.vectorbase
@@ -76,7 +77,8 @@ class StraightlineAllocator:
# unused register
self.alloc_reg(j, alloc_pool)
unused_regs.append(j)
if unused_regs and len(unused_regs) == len(list(i.get_def())):
if unused_regs and len(unused_regs) == len(list(i.get_def())) and \
self.program.verbose:
# only report if all assigned registers are unused
print("Register(s) %s never used, assigned by '%s' in %s" % \
(unused_regs,i,format_trace(i.caller)))
@@ -175,37 +177,8 @@ class Merger:
except StopIteration:
return mergecount, None
def expand_vector_args(inst):
if inst.is_vec():
for arg in inst.args:
arg.create_vector_elements()
res = sum(list(zip(*inst.args)), ())
return list(res)
else:
return inst.args
for i in merges_iter:
if isinstance(instructions[n], startinput_class):
instructions[n].args[1] += instructions[i].args[1]
elif isinstance(instructions[n], (stopinput, gstopinput)):
if instructions[n].get_size() != instructions[i].get_size():
raise NotImplemented()
else:
instructions[n].args += instructions[i].args[1:]
else:
if instructions[n].get_size() != instructions[i].get_size():
# merge as non-vector instruction
instructions[n].args = expand_vector_args(instructions[n]) + \
expand_vector_args(instructions[i])
if instructions[n].is_vec():
instructions[n].size = 1
else:
instructions[n].args += instructions[i].args
# join arg_formats if not special iterators
# if not isinstance(instructions[n].arg_format, (itertools.repeat, itertools.cycle)) and \
# not isinstance(instructions[i].arg_format, (itertools.repeat, itertools.cycle)):
# instructions[n].arg_format += instructions[i].arg_format
instructions[n].merge(instructions[i])
instructions[i] = None
self.merge_nodes(n, i)
mergecount += 1
@@ -343,7 +316,7 @@ class Merger:
merge = merges[i]
t = type(self.instructions[merge[0]])
self.counter[t] += len(merge)
if len(merge) > 1000:
if len(merge) > 10000:
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))
self.do_merge(merge)
@@ -504,7 +477,8 @@ class Merger:
next_available_depth[type(instr), d] = depth
round_type[depth] = instr.merge_id()
parallel_open[depth] += len(instr.args) * instr.get_size()
if int(options.max_parallel_open) > 0:
parallel_open[depth] += len(instr.args) * instr.get_size()
depths[n] = depth
if isinstance(instr, ReadMemoryInstruction):
@@ -557,8 +531,9 @@ class Merger:
print("Processed dependency of %d/%d instructions at" % \
(n, len(block.instructions)), time.asctime())
if len(open_nodes) > 1000:
print("Program has %d %s instructions" % (len(open_nodes), merge_classes))
if len(open_nodes) > 1000 and self.block.parent.program.verbose:
print("Basic block has %d %s instructions" %
(len(open_nodes), merge_classes))
def merge_nodes(self, i, j):
""" Merge node j into i, removing node j """
@@ -608,7 +583,7 @@ class Merger:
eliminate(list(G.pred[i])[0])
eliminate(i)
count += 2
if count > 0:
if count > 0 and self.block.parent.program.verbose:
print('Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats)))

View File

@@ -69,13 +69,31 @@ def divide_by_two(res, x, m=1):
inv2m(tmp, m)
mulc(res, x, tmp)
@instructions_base.cisc
def LTZ(s, a, k, kappa):
"""
s = (a ?< 0)
k: bit length of a
"""
from .types import sint
from .types import sint, _bitint
from .GC.types import sbitvec
if program.use_split():
movs(s, sint.conv(sbitvec(a, k).v[-1]))
return
elif program.options.ring:
from . import floatingpoint
assert(int(program.options.ring) >= k)
m = k - 1
shift = int(program.options.ring) - k
r_prime, r_bin = MaskingBitsInRing(k)
tmp = a - r_prime
c_prime = (tmp << shift).reveal() >> shift
a = r_bin[0].bit_decompose_clear(c_prime, m)
b = r_bin[:m]
u = CarryOutRaw(a[::-1], b[::-1])
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
return
t = sint()
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
@@ -86,6 +104,7 @@ def LessThanZero(a, k, kappa):
LTZ(res, a, k, kappa)
return res
@instructions_base.cisc
def Trunc(d, a, k, m, kappa, signed):
"""
d = a >> m
@@ -120,7 +139,7 @@ def TruncRing(d, a, k, m, signed):
movs(d, res)
return res
def TruncZeroes(a, k, m, signed):
def TruncZeros(a, k, m, signed):
if program.options.ring:
return TruncLeakyInRing(a, k, m, signed)
else:
@@ -139,9 +158,8 @@ def TruncLeakyInRing(a, k, m, signed):
from .types import sint, intbitint, cint, cgf2n
n_bits = k - m
n_shift = int(program.options.ring) - n_bits
if program.use_dabit and n_bits > 1:
r, r_bits = zip(*(sint.get_dabit() for i in range(n_bits)))
r = sint.bit_compose(r)
if n_bits > 1:
r, r_bits = MaskingBitsInRing(n_bits, True)
else:
r_bits = [sint.get_random_bit() for i in range(n_bits)]
r = sint.bit_compose(r_bits)
@@ -150,7 +168,7 @@ def TruncLeakyInRing(a, k, m, signed):
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
masked = shifted >> n_shift
u = sint()
BitLTL(u, masked, r_bits, 0)
BitLTL(u, masked, r_bits[:n_bits], 0)
res = (u << n_bits) + masked - r
if signed:
res -= (1 << (n_bits - 1))
@@ -174,6 +192,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed)
return res
@instructions_base.cisc
def Mod2m(a_prime, a, k, m, kappa, signed):
"""
a_prime = a % 2^m
@@ -199,16 +218,11 @@ def Mod2mRing(a_prime, a, k, m, signed):
assert(int(program.options.ring) >= k)
from Compiler.types import sint, intbitint, cint
shift = int(program.options.ring) - m
if program.use_dabit:
r, r_bin = zip(*(sint.get_dabit() for i in range(m)))
else:
r = [sint.get_random_bit() for i in range(m)]
r_bin = r
r_prime = sint.bit_compose(r)
r_prime, r_bin = MaskingBitsInRing(m, True)
tmp = a + r_prime
c_prime = (tmp << shift).reveal() >> shift
u = sint()
BitLTL(u, c_prime, r_bin, 0)
BitLTL(u, c_prime, r_bin[:m], 0)
res = (u << m) + c_prime - r_prime
if a_prime is not None:
movs(a_prime, res)
@@ -247,19 +261,35 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
adds(a_prime, t[5], t[4])
return r_dprime, r_prime, c, c_prime, u, t, c2k1
def MaskingBitsInRing(m, strict=False):
from Compiler.types import sint
if program.use_edabit():
return sint.get_edabit(m, strict)
elif program.use_dabit:
r, r_bin = zip(*(sint.get_dabit() for i in range(m)))
else:
r = [sint.get_random_bit() for i in range(m)]
r_bin = r
return sint.bit_compose(r), r_bin
def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
"""
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
r_prime = random secret integer in range [0, 2^m - 1]
b = array containing bits of r_prime
"""
program.curr_tape.require_bit_length(k + kappa)
from .types import sint
if program.use_edabit() and m > 1:
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
tmp, b[:] = sint.get_edabit(m, True)
movs(r_prime, tmp)
return
t = [[program.curr_block.new_reg('s') for j in range(2)] for i in range(m)]
t[0][1] = b[-1]
PRandInt(r_dprime, k + kappa - m)
# r_dprime is always multiplied by 2^m
program.curr_tape.require_bit_length(k + kappa)
if use_dabit and program.use_dabit and m > 1:
from .types import sint
r, b[:] = zip(*(sint.get_dabit() for i in range(m)))
r = sint.bit_compose(r)
movs(r_prime, r)
@@ -389,17 +419,25 @@ def CarryOut(res, a, b, c=0, kappa=None):
b: array of secret bits (same length as a)
c: initial carry-in bit
"""
from .types import sint
movs(res, sint.conv(CarryOutRaw(a, b, c)))
def CarryOutRaw(a, b, c=0):
assert len(a) == len(b)
k = len(a)
from . import types
d = [program.curr_block.new_reg('s') for i in range(k)]
s = [program.curr_block.new_reg('s') for i in range(3)]
for i in range(k):
d[i] = list(b[i].half_adder(a[i]))
s[0] = d[-1][0] * c
s[0] = d[-1][0].bit_and(c)
s[1] = d[-1][1] + s[0]
d[-1][1] = s[1]
movs(res, types.sint.conv(CarryOutAux(d[::-1], kappa)))
return CarryOutAux(d[::-1], None)
def CarryOutRawLE(a, b, c=0):
""" Little-endian version """
return CarryOutRaw(a[::-1], b[::-1], c)
def CarryOutLE(a, b, c=0):
""" Little-endian version """
@@ -416,13 +454,12 @@ def BitLTL(res, a, b, kappa):
b: array of secret bits (same length as a)
"""
k = len(b)
from . import floatingpoint
a_bits = floatingpoint.bits(a, k)
a_bits = b[0].bit_decompose_clear(a, k)
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
t = [program.curr_block.new_reg('s') for i in range(1)]
for i in range(len(b)):
s[0][i] = b[0].long_one() - b[i]
CarryOut(t[0], a_bits[::-1], s[0][::-1], 1, kappa)
CarryOut(t[0], a_bits[::-1], s[0][::-1], b[0].long_one(), kappa)
subsfi(res, t[0], 1)
return a_bits, s[0]

View File

@@ -71,9 +71,10 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
if prog.main_thread_running:
prog.update_req(prog.curr_tape)
print('Program requires:', repr(prog.req_num))
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
print('Memory size:', dict(prog.allocated_mem))
if prog.verbose:
print('Program requires:', repr(prog.req_num))
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
print('Memory size:', dict(prog.allocated_mem))
# finalize the memory
prog.finalize_memory()

View File

@@ -17,22 +17,16 @@ P_VALUES = { 32: 2147565569, \
P_VALUES[-1] = P_VALUES[128]
BIT_LENGTHS = { -1: 32,
BIT_LENGTHS = { -1: 64,
32: 16,
64: 16,
128: 64,
256: 64,
512: 64 }
STAT_SEC = { -1: 6,
32: 6,
64: 30,
128: 40,
256: 40,
512: 40 }
COST = { 'modp': defaultdict(lambda: 0,
COST = defaultdict(lambda: defaultdict(lambda: 0),
{ 'modp': defaultdict(lambda: 0,
{ 'triple': 0.00020652622883106154,
'square': 0.00020652622883106154,
'bit': 0.00020652622883106154,
@@ -51,7 +45,7 @@ COST = { 'modp': defaultdict(lambda: 0,
'all': { 'round': 0,
'inv': 0,
}
}
})
try:

View File

@@ -4,6 +4,7 @@ from . import types
from . import comparison
from . import program
from . import util
from . import instructions_base
##
## Helper functions for floating point arithmetic
@@ -50,13 +51,14 @@ def maskField(a, k, kappa):
asm_open(c, a + two_power(k) * r_dprime + r_prime)# + 2**(k-1))
return c, r
@instructions_base.ret_cisc
def EQZ(a, k, kappa):
if program.Program.prog.options.ring:
c, r = maskRing(a, k)
else:
c, r = maskField(a, k, kappa)
d = [None]*k
for i,b in enumerate(bits(c, k)):
for i,b in enumerate(r[0].bit_decompose_clear(c, k)):
d[i] = r[i].bit_xor(b)
return 1 - types.sint.conv(KOR(d, kappa))
@@ -299,6 +301,7 @@ def BitDec(a, k, m, kappa, bits_to_compute=None):
def BitDecRing(a, k, m):
n_shift = int(program.Program.prog.options.ring) - m
assert(n_shift >= 0)
if program.Program.prog.use_dabit:
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
r = types.sint.bit_compose(r)
@@ -328,10 +331,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
print('BitDec assertion failed')
print('a =', a.value)
print('a mod 2^%d =' % k, (a.value % 2**k))
res = r[0].bit_adder(r, list(bits(c,m)))
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
return [types.sint.conv(bit) for bit in res]
@instructions_base.ret_cisc
def Pow2(a, l, kappa):
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
@@ -361,10 +365,16 @@ def B2U_from_Pow2(pow2a, l, kappa):
for i in range(l):
bit(r[i])
r_bits = r
comparison.PRandInt(t, kappa)
asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l)))
comparison.program.curr_tape.require_bit_length(l + kappa)
c = list(bits(c, l))
if program.Program.prog.options.ring:
n_shift = int(program.Program.prog.options.ring) - l
assert n_shift > 0
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift
else:
comparison.PRandInt(t, kappa)
asm_open(c, pow2a + two_power(l) * t +
sum(two_power(i) * r[i] for i in range(l)))
comparison.program.curr_tape.require_bit_length(l + kappa)
c = list(r_bits[0].bit_decompose_clear(c, l))
x = [r_bits[i].bit_xor(c[i]) for i in range(l)]
#print ' '.join(str(b.value) for b in x)
y = PreOR(x, kappa)
@@ -402,10 +412,14 @@ def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
r_prime += t2
r_dprime += t1 - t2
#assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
comparison.PRandInt(rk, kappa)
r_dprime += two_power(l) * rk
#assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
asm_open(c, a + r_dprime + r_prime)
if program.Program.prog.options.ring:
n_shift = int(program.Program.prog.options.ring) - l
c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift
else:
comparison.PRandInt(rk, kappa)
r_dprime += two_power(l) * rk
#assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
asm_open(c, a + r_dprime + r_prime)
for i in range(1,l):
ci[i] = c % two_power(i)
#assert(ci[i].value == c.value % 2**i)
@@ -439,6 +453,12 @@ def TruncInRing(to_shift, l, pow2m):
bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l))
return types.sint.bit_compose(reversed(bits))
def SplitInRing(a, l, m):
pow2m = Pow2(m, l, None)
upper = TruncInRing(a, l, pow2m)
lower = a - upper * pow2m
return lower, upper, pow2m
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
@@ -496,6 +516,7 @@ def FLRound(x, mode):
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
return v, p, z, s
@instructions_base.ret_cisc
def TruncPr(a, k, m, kappa=None, signed=True):
""" Probabilistic truncation [a/2^m + u]
where Pr[u = 1] = (a % 2^m) / 2^m
@@ -513,8 +534,11 @@ def TruncPrRing(a, k, m, signed=True):
n_ring = int(program.Program.prog.options.ring)
assert n_ring >= k, '%d too large' % k
if k == n_ring:
for i in range(m):
a += types.sint.get_random_bit() << i
if program.Program.prog.use_edabit():
a += types.sint.get_edabit(m, True)[0]
else:
for i in range(m):
a += types.sint.get_random_bit() << i
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
else:
from .types import sint
@@ -525,13 +549,22 @@ def TruncPrRing(a, k, m, signed=True):
trunc_pr(res, a, k, m)
else:
# extra bit to mask overflow
r_bits = [sint.get_random_bit() for i in range(k + 1)]
n_shift = n_ring - len(r_bits)
tmp = a + sint.bit_compose(r_bits)
if program.Program.prog.use_edabit():
lower = sint.get_edabit(m, True)[0]
upper = sint.get_edabit(k - m, True)[0]
msb = sint.get_random_bit()
r = (msb << k) + (upper << m) + lower
else:
r_bits = [sint.get_random_bit() for i in range(k + 1)]
r = sint.bit_compose(r_bits)
upper = sint.bit_compose(r_bits[m:k])
msb = r_bits[-1]
n_shift = n_ring - (k + 1)
tmp = a + r
masked = (tmp << n_shift).reveal()
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
res = shifted - sint.bit_compose(r_bits[m:k]) + \
overflow = msb.bit_xor(masked >> (n_ring - 1))
res = shifted - upper + \
(overflow << (k - m))
if signed:
res -= (1 << (k - m - 1))
@@ -555,6 +588,7 @@ def TruncPrField(a, k, m, kappa=None):
d = (a - a_prime) / two_to_m
return d
@instructions_base.ret_cisc
def SDiv(a, b, l, kappa, round_nearest=False):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
@@ -564,7 +598,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest,
@@ -576,7 +610,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest, signed=False)
y = y.round(2 * l + 1, l - 1, kappa, round_nearest)

View File

@@ -11,8 +11,10 @@ documentation
"""
import itertools
import operator
from . import tools
from random import randint
from functools import reduce
from Compiler.config import *
from Compiler.exceptions import *
import Compiler.instructions_base as base
@@ -318,6 +320,11 @@ class use_inp(base.Instruction):
code = base.opcodes['USE_INP']
arg_format = ['int','int','int']
class use_edabit(base.Instruction):
r""" edaBit usage. """
code = base.opcodes['USE_EDABIT']
arg_format = ['int','int','int']
class run_tape(base.Instruction):
r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """
code = base.opcodes['RUN_TAPE']
@@ -808,7 +815,29 @@ class dabit(base.DataInstruction):
code = base.opcodes['DABIT']
arg_format = ['sw', 'sbw']
field_type = 'modp'
data_type = 'bit'
data_type = 'dabit'
@base.vectorize
class edabit(base.Instruction):
""" edaBit """
__slots__ = []
code = base.opcodes['EDABIT']
arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment(('edabit', len(self.args) - 1), self.get_size())
@base.vectorize
class sedabit(base.Instruction):
""" strict edaBit """
__slots__ = []
code = base.opcodes['SEDABIT']
arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment(('sedabit', len(self.args) - 1), self.get_size())
@base.gf2n
@base.vectorize
@@ -988,7 +1017,19 @@ class startinput(base.RawInputInstruction):
req_node.increment((self.field_type, 'input', self.args[0]), \
self.args[1])
class stopinput(base.RawInputInstruction):
def merge(self, other):
self.args[1] += other.args[1]
class StopInputInstruction(base.RawInputInstruction):
__slots__ = []
def merge(self, other):
if self.get_size() != other.get_size():
raise NotImplemented()
else:
self.args += other.args[1:]
class stopinput(StopInputInstruction):
r""" Receive inputs from player $p$ and put in registers. """
__slots__ = []
code = base.opcodes['STOPINPUT']
@@ -997,7 +1038,7 @@ class stopinput(base.RawInputInstruction):
def has_var_args(self):
return True
class gstopinput(base.RawInputInstruction):
class gstopinput(StopInputInstruction):
r""" Receive inputs from player $p$ and put in registers. """
__slots__ = []
code = 0x100 + base.opcodes['STOPINPUT']
@@ -1322,6 +1363,26 @@ class bitdecint(base.Instruction):
code = base.opcodes['BITDECINT']
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
class incint(base.VectorInstruction):
__slots__ = []
code = base.opcodes['INCINT']
arg_format = ['ciw', 'ci', 'i', 'i', 'i']
def __init__(self, *args, **kwargs):
assert len(args[1]) == 1
if len(args) == 3:
args = list(args) + [1, len(args[0])]
super(incint, self).__init__(*args, **kwargs)
class shuffle(base.VectorInstruction):
__slots__ = []
code = base.opcodes['SHUFFLE']
arg_format = ['ciw','ci']
def __init__(self, *args, **kwargs):
super(shuffle, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
###
### Clear comparison instructions
###
@@ -1429,6 +1490,9 @@ class convmodp(base.Instruction):
code = base.opcodes['CONVMODP']
arg_format = ['ciw', 'c', 'int']
def __init__(self, *args, **kwargs):
if len(args) == len(self.arg_format):
super(convmodp_class, self).__init__(*args)
return
bitlength = kwargs.get('bitlength')
bitlength = program.bit_length if bitlength is None else bitlength
if bitlength > 64:
@@ -1472,7 +1536,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
def merge_id(self):
# can merge different sizes
# but not if large
if self.get_size() > 100:
if self.get_size() is None or self.get_size() > 100:
return type(self), self.get_size()
return type(self)
@@ -1561,6 +1625,31 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
for reg in self.args[i + 2:i + self.args[i]]:
yield reg
class matmul_base(base.DataInstruction):
data_type = 'triple'
is_vec = lambda self: True
def get_repeat(self):
return reduce(operator.mul, self.args[3:6])
class matmuls(matmul_base):
""" Secret matrix multiplication """
code = base.opcodes['MATMULS']
arg_format = ['sw','s','s','int','int','int']
class matmulsm(matmul_base):
""" Secret matrix multiplication reading directly from memory """
code = base.opcodes['MATMULSM']
arg_format = ['sw','ci','ci','int','int','int','ci','ci','ci','ci',
'int','int']
def __init__(self, *args, **kwargs):
matmul_base.__init__(self, *args, **kwargs)
for i in range(2):
assert args[6 + i].size == args[3 + i]
for i in range(2):
assert args[8 + i].size == args[4 + i]
@base.vectorize
class trunc_pr(base.VarArgsInstruction):
""" Probalistic truncation for semi-honest computation """

View File

@@ -8,6 +8,7 @@ from Compiler.exceptions import *
from Compiler.config import *
from Compiler import util
from Compiler import tools
from Compiler import program
###
@@ -59,6 +60,7 @@ opcodes = dict(
NPLAYERS = 0xE2,
THRESHOLD = 0xE3,
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -93,6 +95,8 @@ opcodes = dict(
MULRS = 0xA7,
DOTPRODS = 0xA8,
TRUNC_PR = 0xA9,
MATMULS = 0xAA,
MATMULSM = 0xAB,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -103,6 +107,8 @@ opcodes = dict(
INPUTMASK = 0x56,
PREP = 0x57,
DABIT = 0x58,
EDABIT = 0x59,
SEDABIT = 0x5A,
# Input
INPUT = 0x60,
INPUTFIX = 0xF0,
@@ -153,6 +159,8 @@ opcodes = dict(
MULINT = 0x9D,
DIVINT = 0x9E,
PRINTINT = 0x9F,
INCINT = 0xD1,
SHUFFLE = 0xD2,
# Conversion
CONVINT = 0xC0,
CONVMODP = 0xC1,
@@ -235,13 +243,17 @@ def vectorize(instruction, global_dict=None):
def __init__(self, size, *args, **kwargs):
self.size = size
super(Vectorized_Instruction, self).__init__(*args, **kwargs)
for arg,f in zip(self.args, self.arg_format):
if issubclass(ArgFormats[f], RegisterArgFormat):
arg.set_size(size)
if not kwargs.get('copying', False):
for arg,f in zip(self.args, self.arg_format):
if issubclass(ArgFormats[f], RegisterArgFormat):
arg.set_size(size)
def get_code(self):
return (self.size << 10) + self.code
return instruction.get_code(self, self.get_size())
def get_pre_arg(self):
return "%d, " % self.size
try:
return "%d, " % self.size
except:
return "{undef}, "
def is_vec(self):
return True
def get_size(self):
@@ -250,6 +262,9 @@ def vectorize(instruction, global_dict=None):
set_global_vector_size(self.size)
super(Vectorized_Instruction, self).expand()
reset_global_vector_size()
def copy(self, size, subs):
return type(self)(size, *self.get_new_args(size, subs),
copying=True)
@functools.wraps(instruction)
def maybe_vectorized_instruction(*args, **kwargs):
@@ -360,6 +375,169 @@ def gf2n(instruction):
return maybe_gf2n_instruction
#return instruction
class Mergeable:
pass
def cisc(function):
class MergeCISC(Mergeable):
instructions = {}
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.calls = [(args, kwargs)]
self.params = []
self.used = []
for arg in self.args[1:]:
if isinstance(arg, program.curr_tape.Register):
self.used.append(arg)
self.params.append(type(arg))
else:
self.params.append(arg)
self.function = function
program.curr_block.instructions.append(self)
def get_def(self):
return [self.args[0]]
def get_used(self):
return self.used
def is_vec(self):
return True
def merge_id(self):
return self.function, tuple(self.params), \
tuple(sorted(self.kwargs.items()))
def merge(self, other):
self.calls += other.calls
def get_size(self):
return self.args[0].size
def new_instructions(self, size, regs):
if self.merge_id() not in self.instructions:
from Compiler.program import Tape
tape = Tape(self.function.__name__, program)
old_tape = program.curr_tape
program.curr_tape = tape
block = tape.BasicBlock(tape, None, None)
tape.active_basicblock = block
set_global_vector_size(None)
args = []
for arg in self.args:
try:
args.append(type(arg)(size=None))
except:
args.append(arg)
program.options.cisc = False
self.function(*args, **self.kwargs)
program.options.cisc = True
reset_global_vector_size()
program.curr_tape = old_tape
from Compiler.allocator import Merger
merger = Merger(block, program.options,
tuple(program.to_merge))
args[0].can_eliminate = False
merger.eliminate_dead_code()
assert int(program.options.max_parallel_open) == 0, \
'merging restriction not compatible with ' \
'mergeable CISC instructions'
merger.longest_paths_merge()
filtered = filter(lambda x: x is not None, block.instructions)
self.instructions[self.merge_id()] = list(filtered), args
template, args = self.instructions[self.merge_id()]
subs = util.dict_by_id()
for arg, reg in zip(args, regs):
subs[arg] = reg
set_global_vector_size(size)
for inst in template:
inst.copy(size, subs)
reset_global_vector_size()
def expand_merged(self):
tape = program.curr_tape
block = tape.BasicBlock(tape, None, None)
tape.active_basicblock = block
size = sum(call[0][0].size for call in self.calls)
new_regs = []
for arg in self.args:
try:
new_regs.append(type(arg)(size=size))
except:
break
base = 0
for call in self.calls:
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
set_global_vector_size(reg.size)
reg.mov(new_reg.get_vector(base, reg.size), reg)
reset_global_vector_size()
base += reg.size
self.new_instructions(size, new_regs)
base = 0
for call in self.calls:
reg = call[0][0]
set_global_vector_size(reg.size)
reg.mov(reg, new_regs[0].get_vector(base, reg.size))
reset_global_vector_size()
base += reg.size
return block.instructions
MergeCISC.__name__ = function.__name__
def wrapper(*args, **kwargs):
if program.options.cisc:
return MergeCISC(*args, **kwargs)
else:
return function(*args, **kwargs)
return wrapper
def ret_cisc(function):
def instruction(res, *args, **kwargs):
res.mov(res, function(*args, **kwargs))
instruction.__name__ = function.__name__
instruction = cisc(instruction)
def wrapper(*args, **kwargs):
if not program.options.cisc:
return function(*args, **kwargs)
from Compiler import types
if isinstance(args[0], types._clear):
res_type = type(args[1])
else:
res_type = type(args[0])
res = res_type(size=args[0].size)
instruction(res, *args, **kwargs)
return res
return wrapper
def sfix_cisc(function):
from Compiler.types import sfix, sint, cfix, copy_doc
def instruction(res, arg, k, f):
assert k is not None
assert f is not None
old = sfix.k, sfix.f, cfix.k, cfix.f
sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4
res.mov(res, function(sfix._new(arg, k=k, f=f)).v)
sfix.k, sfix.f, cfix.k, cfix.f = old
instruction.__name__ = function.__name__
instruction = cisc(instruction)
def wrapper(*args, **kwargs):
if isinstance(args[0], sfix):
assert len(args) == 1
assert not kwargs
assert args[0].size == args[0].v.size
k = args[0].k
f = args[0].f
res = sfix._new(sint(size=args[0].size), k=k, f=f)
instruction(res.v, args[0].v, k, f)
return res
else:
return function(*args, **kwargs)
copy_doc(wrapper, function)
return wrapper
class RegType(object):
""" enum-like static class for Register types """
@@ -381,6 +559,8 @@ class RegType(object):
return res
class ArgFormat(object):
is_reg = False
@classmethod
def check(cls, arg):
return NotImplemented
@@ -390,11 +570,13 @@ class ArgFormat(object):
return NotImplemented
class RegisterArgFormat(ArgFormat):
is_reg = True
@classmethod
def check(cls, arg):
if not isinstance(arg, program.curr_tape.Register):
raise ArgumentError(arg, 'Invalid register argument')
if arg.i > REG_MAX:
if arg.i > REG_MAX and arg.i != float('inf'):
raise ArgumentError(arg, 'Register index too large')
if arg.program != program.curr_tape:
raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
@@ -425,7 +607,7 @@ class ClearIntAF(RegisterArgFormat):
class IntArgFormat(ArgFormat):
@classmethod
def check(cls, arg):
if not isinstance(arg, int):
if not isinstance(arg, int) and not arg is None:
raise ArgumentError(arg, 'Expected an integer-valued argument')
@classmethod
@@ -487,7 +669,7 @@ ArgFormats = {
}
def format_str_is_reg(format_str):
return issubclass(ArgFormats[format_str], RegisterArgFormat)
return ArgFormats[format_str].is_reg
def format_str_is_writeable(format_str):
return format_str_is_reg(format_str) and format_str[-1] == 'w'
@@ -504,7 +686,8 @@ class Instruction(object):
def __init__(self, *args, **kwargs):
""" Create an instruction and append it to the program list. """
self.args = list(args)
self.check_args()
if not kwargs.get('copying', False):
self.check_args()
if not program.FIRST_PASS:
if kwargs.get('add_to_prog', True):
program.curr_block.instructions.append(self)
@@ -519,9 +702,9 @@ class Instruction(object):
if Instruction.count % 100000 == 0:
print("Compiled %d lines at" % self.__class__.count, time.asctime())
def get_code(self):
return self.code
def get_code(self, prefix=0):
return (prefix << 10) + self.code
def get_encoding(self):
enc = int_to_bytes(self.get_code())
# add the number of registers if instruction flagged as has var args
@@ -540,12 +723,12 @@ class Instruction(object):
def check_args(self):
""" Check the args match up with that specified in arg_format """
for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)):
if arg is None:
if not isinstance(self.arg_format, (list, tuple)):
break # end of optional arguments
else:
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
try:
if len(self.args) != len(self.arg_format):
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
except TypeError:
pass
for n,(arg,f) in enumerate(zip(self.args, self.arg_format)):
try:
ArgFormats[f].check(arg)
except ArgumentError as e:
@@ -589,6 +772,42 @@ class Instruction(object):
def merge_id(self):
return type(self), self.get_size()
def merge(self, other):
if self.get_size() != other.get_size():
# merge as non-vector instruction
self.args = self.expand_vector_args() + other.expand_vector_args()
if self.is_vec():
self.size = 1
else:
self.args += other.args
def expand_vector_args(self):
if self.is_vec():
for arg in self.args:
arg.create_vector_elements()
res = sum(list(zip(*self.args)), ())
return list(res)
else:
return self.args
def expand_merged(self):
return [self]
def get_new_args(self, size, subs):
new_args = []
for arg, f in zip(self.args, self.arg_format):
if arg in subs:
new_args.append(subs[arg])
elif arg is None:
new_args.append(size)
else:
if format_str_is_writeable(f):
new_args.append(arg.copy())
subs[arg] = new_args[-1]
else:
new_args.append(arg)
return new_args
# String version of instruction attempting to replicate encoded version
def __str__(self):
@@ -606,6 +825,13 @@ class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
class VectorInstruction(Instruction):
__slots__ = []
is_vec = lambda self: True
def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0]))
###
### Basic arithmetic
###

View File

@@ -22,7 +22,7 @@ def get_block():
def vectorize(function):
def vectorized_function(*args, **kwargs):
if len(args) > 0 and isinstance(args[0], program.Tape.Register):
if len(args) > 0 and 'size' in dir(args[0]):
instructions_base.set_global_vector_size(args[0].size)
res = function(*args, **kwargs)
instructions_base.reset_global_vector_size()
@@ -88,7 +88,7 @@ def print_str(s, *args):
raise CompilerError('Cannot print secret value:', args[i])
elif isinstance(val, cfloat):
val.print_float_plain()
elif isinstance(val, list):
elif isinstance(val, (list, tuple, Array)):
print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val)
else:
try:
@@ -362,9 +362,14 @@ class Function:
class FunctionTape(Function):
# not thread-safe
def __init__(self, function, name=None, compile_args=[],
single_thread=False):
Function.__init__(self, function, name, compile_args)
self.single_thread = single_thread
def on_first_call(self, wrapped_function):
self.thread = MPCThread(wrapped_function, self.name,
args=self.compile_args)
args=self.compile_args,
single_thread=self.single_thread)
def on_call(self, base, bases):
return FunctionTapeCall(self.thread, base, bases)
@@ -376,6 +381,9 @@ def function_tape_with_compile_args(*args):
return FunctionTape(function, compile_args=args)
return wrapper
def single_thread_function_tape(function):
return FunctionTape(function, single_thread=True)
def memorize(x):
if isinstance(x, (tuple, list)):
return tuple(memorize(i) for i in x)
@@ -397,13 +405,15 @@ class FunctionBlock(Function):
block.alloc_pool = defaultdict(set)
del parent_node.children[-1]
self.node = get_tape().req_node
print('Compiling function', self.name)
if get_program().verbose:
print('Compiling function', self.name)
result = wrapped_function(*self.compile_args)
if result is not None:
self.result = memorize(result)
else:
self.result = None
print('Done compiling function', self.name)
if get_program().verbose:
print('Done compiling function', self.name)
p_return_address = get_tape().program.malloc(1, 'ci')
get_tape().function_basicblocks[block] = p_return_address
return_address = regint.load_mem(p_return_address)
@@ -528,7 +538,7 @@ def chunky_odd_even_merge_sort(a):
a[m], a[m+step] = cond_swap(a[m], a[m+step])
for i in range(len(a)):
a[i].store_in_mem(i * a[i].sizeof())
chunk = MPCThread(round, 'sort-%d-%d' % (l,k))
chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True)
chunk.start()
chunk.join()
#round()
@@ -541,7 +551,6 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
a_base = instructions.program.malloc(n, 's')
for i,j in enumerate(a):
store_in_mem(j, a_base + i)
instructions.program.restart_main_thread()
else:
a_base = a
tmp_base = instructions.program.malloc(n, 's')
@@ -657,7 +666,6 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
run_postproc()
if isinstance(a, list):
instructions.program.restart_main_thread()
for i in range(n):
a[i] = load_secret_mem(a_base + i)
instructions.program.free(a_base, 's')
@@ -669,7 +677,6 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
a_base = instructions.program.malloc(n, 's')
for i,j in enumerate(a):
store_in_mem(j, a_base + i)
instructions.program.restart_main_thread()
else:
a_base = a
tmp_base = instructions.program.malloc(n, 's')
@@ -764,7 +771,6 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
range_loop(outer, n // l)
if isinstance(a, list):
instructions.program.restart_main_thread()
for i in range(n):
a[i] = load_secret_mem(a_base + i)
instructions.program.free(a_base, 's')
@@ -772,30 +778,39 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
instructions.program.free(tmp_i, 'ci')
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32):
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
n_threads=None):
steps = {}
l = sorted_length
while l < len(a):
l *= 2
k = 1
while k < l:
k *= 2
n_outer = len(a) // l
n_inner = l // k
n_innermost = 1 if k == 2 else k // 2 - 1
@for_range_parallel(n_parallel // n_innermost // n_inner, n_outer)
def loop(i):
@for_range_parallel(n_parallel // n_innermost, n_inner)
def inner(j):
base = i*l + j
step = l//k
if k == 2:
a[base], a[base+step] = cond_swap(a[base], a[base+step])
else:
@for_range_parallel(n_parallel, n_innermost)
def f(i):
m1 = step + i * 2 * step
m2 = m1 + base
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
key = k
if key not in steps:
@function_block
def step(l):
l = MemValue(l)
@for_range_opt_multithread(n_threads, len(a) // k)
def _(i):
n_inner = l // k
j = i % n_inner
i //= n_inner
base = i*l + j
step = l//k
if k == 2:
a[base], a[base+step] = \
cond_swap(a[base], a[base+step])
else:
@for_range_opt(n_innermost)
def f(i):
m1 = step + i * 2 * step
m2 = m1 + base
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
steps[key] = step
steps[key](l)
def mergesort(A):
B = Array(len(A), sint)
@@ -899,11 +914,12 @@ def for_range_parallel(n_parallel, n_loops):
def for_range_opt(n_loops, budget=None):
""" Execute loop bodies in parallel up to an optimization budget.
This prevents excessive loop unrolling. The budget is respected
even with nested loops.
even with nested loops. Note that optimization is rather
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
using :py:func:`for_range_parallel` in this case.
:param n_loops: compile-time (int)
:param n_loops: int/regint/cint
:param budget: number of instructions after which to start optimization (default is 100,000)
:type: compile-time (int)
Example:
@@ -912,6 +928,7 @@ def for_range_opt(n_loops, budget=None):
@for_range_opt(n)
def _(i):
...
"""
return map_reduce_single(None, n_loops, budget=budget)
@@ -929,6 +946,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
else:
# use Arrays for multithread version
use_array = True
if not util.is_constant(n_loops):
budget //= 10
def decorator(loop_body):
my_n_parallel = n_parallel
if isinstance(n_parallel, int):
@@ -955,17 +974,18 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
r = reducer(mem_state, state)
write_state_to_memory(r)
else:
if n_loops == 0:
if is_zero(n_loops):
return
regint.push(0)
n_opt_loops_reg = regint(0)
n_opt_loops_inst = get_block().instructions[-1]
parent_block = get_block()
@while_do(lambda x: x + regint.pop() <= n_loops, regint(0))
@while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0))
def _(i):
state = tuplify(initializer())
k = 0
block = get_block()
while k < n_loops and (len(get_block()) < budget \
or k == 0) \
while (not util.is_constant(n_loops) or k < n_loops) \
and (len(get_block()) < budget or k == 0) \
and block is get_block():
j = i + k
state = reducer(tuplify(loop_body(j)), state)
@@ -974,13 +994,13 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
write_state_to_memory(r)
global n_opt_loops
n_opt_loops = k
regint.push(k)
n_opt_loops_inst.args[1] = k
return i + k
my_n_parallel = n_opt_loops
loop_rounds = n_loops // my_n_parallel
blocks = get_tape().basicblocks
n_to_merge = 5
if loop_rounds == 1 and parent_block is blocks[-n_to_merge]:
if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]:
# merge blocks started by if and do_while
def exit_elimination(block):
if block.exit_condition is not None:
@@ -996,14 +1016,15 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
for block in blocks[-n_to_merge + 1:]:
merged.instructions += block.instructions
exit_elimination(block)
block.purge()
block.purge(retain_usage=False)
del blocks[-n_to_merge + 1:]
del get_tape().req_node.children[-1]
merged.children = []
get_tape().active_basicblock = merged
else:
req_node = get_tape().req_node.children[-1].nodes[0]
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
if util.is_constant(loop_rounds):
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
if isinstance(n_loops, int):
state = mem_state
for j in range(loop_rounds * my_n_parallel, n_loops):
@@ -1040,7 +1061,9 @@ def for_range_opt_multithread(n_threads, n_loops):
"""
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
threads, in parallel up to an optimization budget per thread
similar to :py:func:`for_range_opt`.
similar to :py:func:`for_range_opt`. Note that optimization is rather
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
using :py:func:`for_range_multithread` in this case.
:param n_threads: compile-time (int)
:param n_loops: regint/cint/int
@@ -1089,7 +1112,7 @@ def multithread(n_threads, n_items):
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
thread_mem_req={}, looping=True):
n_threads = n_threads or 1
assert(n_threads != 0)
if isinstance(n_loops, list):
split = n_loops
n_loops = reduce(operator.mul, n_loops)
@@ -1103,9 +1126,22 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
return new_body
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
return lambda loop_body: new_dec(decorator(loop_body))
n_loops = MemValue.if_necessary(n_loops)
if n_threads == None or util.is_one(n_loops):
if not looping:
return lambda loop_body: loop_body(0, n_loops)
dec = map_reduce_single(n_parallel, n_loops, initializer, reducer)
if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint)
return lambda loop_body: dec(lambda i: loop_body(i, thread_mem))
else:
return dec
def decorator(loop_body):
thread_rounds = n_loops // n_threads
remainder = n_loops % n_threads
thread_rounds = MemValue.if_necessary(n_loops // n_threads)
if util.is_constant(thread_rounds):
remainder = n_loops % n_threads
else:
remainder = 0
for t in thread_mem_req:
if t != regint:
raise CompilerError('Not implemented for other than regint')
@@ -1113,6 +1149,11 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
state = tuple(initializer())
def f(inc):
base = args[get_arg()][0]
if not util.is_constant(thread_rounds):
i = base / thread_rounds
overhang = n_loops % n_threads
inc = i < overhang
base += inc.if_else(i, overhang)
if not looping:
return loop_body(base, thread_rounds + inc)
if thread_mem_req:
@@ -1129,7 +1170,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
return loop_body(base + i)
prog = get_program()
threads = []
if thread_rounds:
if not util.is_zero(thread_rounds):
tape = prog.new_tape(f, (0,), 'multithread')
for i in range(n_threads - remainder):
mem_state = make_array(initializer())
@@ -1465,6 +1506,15 @@ def get_player_id():
playerid(res._v)
return res
def break_point(name=''):
"""
Insert break point. This makes sure that all following code
will be executed after preceding code.
:param name: Name for identification (optional)
"""
get_tape().start_new_basicblock(name=name)
# Fixed point ops
from math import ceil, log
@@ -1590,6 +1640,7 @@ def IntDiv(a, b, k, kappa=None):
return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k,
kappa, nearest=True)
@instructions_base.ret_cisc
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
"""
Goldschmidt method as presented in Catrina10,

View File

@@ -1,9 +1,37 @@
import mpc_math, math
"""
This module contains machine learning functionality. It is work in
progress, so you must expect things to change. The most tested
functionality is logistic regression. It can be run as follows::
sgd = ml.SGD([ml.Dense(n_examples, n_features, 1),
ml.Output(n_examples, approx=True)], n_epochs,
report_loss=True)
sgd.layers[0].X.input_from(0)
sgd.layers[1].Y.input_from(1)
sgd.reset()
sgd.run()
This loads measurements from party 0 and labels (0/1) from party
1. After running, the model is stored in :py:obj:`sgd.layers[0].W` and
:py:obj:`sgd.layers[1].b`. The :py:obj:`approx` parameter determines
whether to use an approximate sigmoid function. Inference can be run as
follows::
data = sfix.Matrix(n_test, n_features)
data.input_from(0)
res = sgd.eval(data)
print_ln('Results: %s', [x.reveal() for x in res])
"""
import math
from Compiler import mpc_math
from Compiler.types import *
from Compiler.types import _unreduced_squant
from Compiler.library import *
from Compiler.util import is_zero
from Compiler.util import is_zero, tree_reduce
from Compiler.comparison import CarryOutRawLE
from Compiler.GC.types import sbitint
from functools import reduce
def log_e(x):
@@ -31,6 +59,29 @@ def sigmoid_prime(x):
sx = sigmoid(x)
return sx * (1 - sx)
@vectorize
def approx_sigmoid(x):
if approx_sigmoid.special and \
get_program().options.ring and get_program().use_edabit():
l = int(get_program().options.ring)
r, r_bits = sint.get_edabit(x.k, False)
c = ((x.v - r) << (l - x.k)).reveal() >> (l - x.k)
c_bits = c.bit_decompose(x.k)
lower_overflow = CarryOutRawLE(c_bits[:x.f - 1], r_bits[:x.f - 1])
higher_bits = sbitint.bit_adder(c_bits[x.f - 1:], r_bits[x.f - 1:],
lower_overflow)
sign = higher_bits[-1]
higher_bits.pop(-1)
aa = sign & ~util.tree_reduce(operator.and_, higher_bits)
bb = ~sign & ~util.tree_reduce(operator.and_, [~x for x in higher_bits])
a, b = (sint.conv(x) for x in (aa, bb))
else:
a = x < -0.5
b = x > 0.5
return a.if_else(0, b.if_else(1, 0.5 + x))
approx_sigmoid.special = False
def lse_0_from_e_x(x, e_x):
return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)
@@ -56,7 +107,7 @@ class Layer:
n_threads = 1
class Output(Layer):
def __init__(self, N, debug=False):
def __init__(self, N, debug=False, approx=False):
self.N = N
self.X = sfix.Array(N)
self.Y = sfix.Array(N)
@@ -64,9 +115,8 @@ class Output(Layer):
self.l = MemValue(sfix(-1))
self.e_x = sfix.Array(N)
self.debug = debug
self.weights = cint.Array(N)
self.weights.assign_all(1)
self.weight_total = N
self.weights = None
self.approx = approx
nablas = lambda self: ()
thetas = lambda self: ()
@@ -75,30 +125,45 @@ class Output(Layer):
def divisor(self, divisor, size):
return cfix(1.0 / divisor, size=size)
def forward(self, N=None):
N = N or self.N
def forward(self, batch):
if self.approx:
self.l.write(999)
return
N = len(batch)
lse = sfix.Array(N)
@multithread(self.n_threads, N)
def _(base, size):
x = self.X.get_vector(base, size)
y = self.Y.get_vector(base, size)
y = self.Y.get(batch.get_vector(base, size))
e_x = exp(-x)
self.e_x.assign(e_x, base)
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
e_x = self.e_x.get_vector(0, N)
self.l.write(sum(lse) * \
self.divisor(self.N, 1))
self.divisor(N, 1))
def backward(self):
@multithread(self.n_threads, self.N)
def eval(self, size, base=0):
if self.approx:
return approx_sigmoid(self.X.get_vector(base, size))
else:
return sigmoid_from_e_x(self.X.get_vector(base, size),
self.e_x.get_vector(base, size))
def backward(self, batch):
N = len(batch)
@multithread(self.n_threads, N)
def _(base, size):
diff = sigmoid_from_e_x(self.X.get_vector(base, size),
self.e_x.get_vector(base, size)) - \
self.Y.get_vector(base, size)
diff = self.eval(size, base) - \
self.Y.get(batch.get_vector(base, size))
assert sfix.f == cfix.f
diff *= self.weights.get_vector(base, size)
self.nabla_X.assign(diff * self.divisor(self.weight_total, size), \
base)
if self.weights is None:
diff *= self.divisor(N, size)
else:
assert N == len(self.weights)
diff *= self.weights.get_vector(base, size)
if self.weight_total != 1:
diff *= self.divisor(self.weight_total, size)
self.nabla_X.assign(diff, base)
# @for_range_opt(len(diff))
# def _(i):
# self.nabla_X[i] = self.nabla_X[i] * self.weights[i]
@@ -112,6 +177,7 @@ class Output(Layer):
#print_ln('%s', x)
def set_weights(self, weights):
self.weights = cfix.Array(len(weights))
self.weights.assign(weights)
self.weight_total = sum(weights)
@@ -119,16 +185,28 @@ class DenseBase(Layer):
thetas = lambda self: (self.W, self.b)
nablas = lambda self: (self.nabla_W, self.nabla_b)
def backward_params(self, f_schur_Y):
N = self.N
def backward_params(self, f_schur_Y, batch):
N = len(batch)
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
def _(j, k):
assert self.d == 1
a = [f_schur_Y[i][0][k] for i in range(N)]
b = [self.X[i][0][j] for i in range(N)]
tmp[j][k] = sfix.unreduced_dot_product(a, b)
assert self.d == 1
if self.d_out == 1:
@multithread(self.n_threads, self.d_in)
def _(base, size):
A = sfix.Matrix(1, self.N, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
mp = A.direct_mul(B, reduce=False,
indices=(regint(0, size=1),
regint.inc(N),
batch.get_vector(),
regint.inc(size, base)))
tmp.assign_vector(mp, base)
else:
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
def _(j, k):
a = [f_schur_Y[i][0][k] for i in range(N)]
b = [self.X[i][0][j] for i in batch]
tmp[j][k] = sfix.unreduced_dot_product(a, b)
if self.d_in * self.d_out < 100000:
print('reduce at once')
@@ -189,26 +267,34 @@ class Dense(DenseBase):
self.W[i][j] = sfix.get_random(-r, r)
self.b.assign_all(0)
def compute_f_input(self):
prod = MultiArray([self.N, self.d, self.d_out], sfix)
@for_range_opt_multithread(self.n_threads, self.N)
def _(i):
self.X[i].plain_mul(self.W, res=prod[i])
def compute_f_input(self, batch):
N = len(batch)
prod = MultiArray([N, self.d, self.d_out], sfix)
assert self.d == 1
assert self.d_out == 1
@multithread(self.n_threads, N)
def _(base, size):
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
prod.assign_vector(
X_sub.direct_mul(self.W, indices=(batch.get_vector(base, size),
regint.inc(self.d_in),
regint.inc(self.d_in),
regint.inc(self.d_out))),
base)
@for_range_opt_multithread(self.n_threads, self.N)
def _(i):
@for_range_opt(self.d)
def _(j):
v = prod[i][j].get_vector() + self.b.get_vector()
self.f_input[i][j].assign(v)
@multithread(self.n_threads, N)
def _(base, size):
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
self.f_input.assign_vector(v, base)
progress('f input')
def forward(self):
self.compute_f_input()
self.Y.assign_vector(self.f(self.f_input.get_vector()))
def forward(self, batch=None):
self.compute_f_input(batch=batch)
self.Y.assign_vector(self.f(
self.f_input.get_part_vector(0, len(batch))))
def backward(self, compute_nabla_X=True):
N = self.N
def backward(self, compute_nabla_X=True, batch=None):
N = len(batch)
d = self.d
d_out = self.d_out
X = self.X
@@ -233,6 +319,7 @@ class Dense(DenseBase):
@for_range_opt(N)
def _(i):
i = batch[i]
f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])
progress('f prime schur Y')
@@ -240,6 +327,7 @@ class Dense(DenseBase):
if compute_nabla_X:
@for_range_opt(N)
def _(i):
i = batch[i]
if self.activation == 'id':
nabla_X[i] = nabla_Y[i].mul_trans(W)
else:
@@ -247,7 +335,7 @@ class Dense(DenseBase):
progress('nabla X')
self.backward_params(f_schur_Y)
self.backward_params(f_schur_Y, batch=batch)
class QuantizedDense(DenseBase):
def __init__(self, N, d_in, d_out):
@@ -443,8 +531,8 @@ class QuantConv2d(QuantConvBase):
_, inputs_h, inputs_w, n_channels_in = self.input_shape
return weights_h * weights_w * n_channels_in
def forward(self, N=1):
assert(N == 1)
def forward(self, batch):
assert len(batch) == 1
assert(self.weight_shape[0] == self.output_shape[-1])
_, weights_h, weights_w, _ = self.weight_shape
@@ -499,8 +587,8 @@ class QuantDepthwiseConv2d(QuantConvBase):
_, weights_h, weights_w, _ = self.weight_shape
return weights_h * weights_w
def forward(self, N=1):
assert(N == 1)
def forward(self, batch):
assert len(batch) == 1
assert(self.weight_shape[-1] == self.output_shape[-1])
assert(self.input_shape[-1] == self.output_shape[-1])
@@ -562,8 +650,8 @@ class QuantAveragePool2d(QuantBase):
for s in self.input_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
def forward(self, N=1):
assert(N == 1)
def forward(self, batch):
assert len(batch) == 1
_, input_h, input_w, n_channels_in = self.input_shape
_, output_h, output_w, n_channels_out = self.output_shape
@@ -623,8 +711,8 @@ class QuantReshape(QuantBase):
for i in range(2):
sint.get_input_from(player)
def forward(self, N=1):
assert(N == 1)
def forward(self, batch):
assert len(batch) == 1
# reshaping is implicit
self.Y.assign(self.X)
@@ -634,8 +722,8 @@ class QuantSoftmax(QuantBase):
for s in self.input_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
def forward(self, N=1):
assert(N == 1)
def forward(self, batch):
assert len(batch) == 1
assert(len(self.input_shape) == 2)
# just print the best
@@ -648,31 +736,40 @@ class QuantSoftmax(QuantBase):
class Optimizer:
n_threads = Layer.n_threads
def forward(self, N):
def forward(self, N=None, batch=None):
if batch is None:
batch = regint.Array(N)
batch.assign(regint.inc(N))
for j in range(len(self.layers) - 1):
self.layers[j].forward()
self.layers[j + 1].X.assign(self.layers[j].Y)
self.layers[-1].forward(N)
self.layers[j].forward(batch=batch)
tmp = self.layers[j].Y.get_part_vector(0, len(batch))
self.layers[j + 1].X.assign_vector(tmp)
self.layers[-1].forward(batch=batch)
def backward(self):
def eval(self, data):
N = len(data)
self.layers[0].X.assign(data)
self.forward(N)
return self.layers[-1].eval(N)
def backward(self, batch):
for j in range(1, len(self.layers)):
self.layers[-j].backward()
self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X)
self.layers[0].backward(compute_nabla_X=False)
self.layers[-j].backward(batch=batch)
self.layers[-j - 1].nabla_Y.assign_vector(
self.layers[-j].nabla_X.get_part_vector(0, len(batch)))
self.layers[0].backward(compute_nabla_X=False, batch=batch)
def run(self):
def run(self, batch_size=None):
if batch_size is not None:
N = batch_size
else:
N = self.layers[0].N
i = MemValue(0)
@do_while
def _():
if self.X_by_label is not None:
N = self.layers[0].N
assert self.layers[-1].N == N
assert N % 2 == 0
n = N // 2
@for_range(n)
def _(i):
self.layers[-1].Y[i] = 0
self.layers[-1].Y[i + n] = 1
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
self.X_by_label) / n))
print('%d runs per epoch' % n_per_epoch)
@@ -680,26 +777,27 @@ class Optimizer:
for label, X in enumerate(self.X_by_label):
indices = regint.Array(n * n_per_epoch)
indices_by_label.append(indices)
indices.assign(i % len(X) for i in range(len(indices)))
indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
indices.shuffle()
@for_range(n_per_epoch)
def _(j):
j = MemValue(j)
batch = regint.Array(N)
for label, X in enumerate(self.X_by_label):
indices = indices_by_label[label]
@for_range_multithread(self.n_threads, 1, n)
def _(i):
idx = indices[i + j * n]
self.layers[0].X[i + label * n] = X[idx]
self.forward(None)
self.backward()
batch.assign(indices.get_vector(j * n, n) +
regint(label * len(self.X_by_label[0]), size=n),
label * n)
self.forward(batch=batch)
self.backward(batch=batch)
self.update(i)
else:
self.forward(None)
self.backward()
batch = regint.Array(N)
batch.assign(regint.inc(N))
self.forward(batch=batch)
self.backward(batch=batch)
self.update(i)
loss = self.layers[-1].l
if self.report_loss:
if self.report_loss and not self.layers[-1].approx:
print_ln('loss after epoch %s: %s', i, loss.reveal())
else:
print_ln('done with epoch %s', i)
@@ -772,6 +870,13 @@ class SGD(Optimizer):
def reset(self, X_by_label=None):
self.X_by_label = X_by_label
if X_by_label is not None:
for label, X in enumerate(X_by_label):
@for_range_multithread(self.n_threads, 1, len(X))
def _(i):
j = i + label * len(X_by_label[0])
self.layers[0].X[j] = X[i]
self.layers[-1].Y[j] = label
for y in self.delta_thetas:
y.assign_all(0)
for layer in self.layers:
@@ -780,16 +885,15 @@ class SGD(Optimizer):
def update(self, i_epoch):
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
self.delta_thetas):
@for_range_opt_multithread(self.n_threads, len(nabla))
def _(k):
old = delta_theta[k]
if isinstance(old, Array):
old = old.get_vector()
@multithread(self.n_threads, len(nabla))
def _(base, size):
old = delta_theta.get_vector(base, size)
red_old = self.momentum * old
new = self.gamma * nabla[k]
new = self.gamma * nabla.get_vector(base, size)
diff = red_old - new
delta_theta[k] = diff
theta[k] = theta[k] + delta_theta[k]
delta_theta.assign_vector(diff, base)
theta.assign_vector(theta.get_vector(base, size) +
delta_theta.get_vector(base, size), base)
if self.debug:
for x, name in (old, 'old'), (red_old, 'red_old'), \
(new, 'new'), (diff, 'diff'):

View File

@@ -12,6 +12,8 @@ from Compiler import floatingpoint
from Compiler import types
from Compiler import comparison
from Compiler import program
from Compiler import instructions_base
# polynomials as enumerated on Hart's book
##
# @private
@@ -154,8 +156,8 @@ def p_eval(p_c, x):
def sTrigSub(x):
# reduction to 2* \pi
f = x * (1.0 / (2 * pi))
f = load_sint(trunc(f), type(x))
y = x - (f) * (2 * pi)
f = trunc(f)
y = x - (f) * x.coerce(2 * pi)
# reduction to \pi
b1 = y > pi
w = b1 * ((2 * pi - y) - y) + y
@@ -210,6 +212,7 @@ def scos(w, s):
# facade method calls --it is built in a generic way
@instructions_base.sfix_cisc
def sin(x):
"""
Returns the sine of any given fractional value.
@@ -224,6 +227,7 @@ def sin(x):
return ssin(w, b1)
@instructions_base.sfix_cisc
def cos(x):
"""
Returns the cosine of any given fractional value.
@@ -239,6 +243,7 @@ def cos(x):
return scos(w, b2)
@instructions_base.sfix_cisc
def tan(x):
"""
Returns the tangent of any given fractional value.
@@ -258,6 +263,7 @@ def tan(x):
@types.vectorize
@instructions_base.sfix_cisc
def exp2_fx(a):
"""
Power of two for fixed-point numbers.
@@ -273,39 +279,49 @@ def exp2_fx(a):
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
n_bits = a.f + n_int_bits
n_shift = int(types.program.options.ring) - a.k
r_bits = [sint.get_random_bit() for i in range(a.k)]
shifted = ((a.v - sint.bit_compose(r_bits)) << n_shift).reveal()
if types.program.use_edabit():
l = sint.get_edabit(a.f, True)
u = sint.get_edabit(a.k - a.f, True)
r_bits = l[1] + u[1]
r = l[0] + (u[0] << a.f)
lower_r = l[0]
else:
r_bits = [sint.get_random_bit() for i in range(a.k)]
r = sint.bit_compose(r_bits)
lower_r = sint.bit_compose(r_bits[:a.f])
shifted = ((a.v - r) << n_shift).reveal()
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
lower_overflow = sint()
comparison.CarryOut(lower_overflow, masked_bits[a.f-1::-1],
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
r_bits[a.f-1::-1])
lower_r = sint.bit_compose(r_bits[:a.f])
lower_masked = sint.bit_compose(masked_bits[:a.f])
lower = lower_r + lower_masked - (lower_overflow << (a.f))
lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f))
c = types.sfix._new(lower, k=a.k, f=a.f)
higher_bits = intbitint.bit_adder(masked_bits[a.f:n_bits],
r_bits[a.f:n_bits],
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
masked_bits[a.f:n_bits],
carry_in=lower_overflow,
get_carry=True)
d = types.sfix.from_sint(floatingpoint.Pow2_from_bits(higher_bits[:-1]),
k=a.k, f=a.f)
assert(len(higher_bits) == n_bits - a.f + 1)
pow2_bits = [sint.conv(x) for x in higher_bits]
d = floatingpoint.Pow2_from_bits(pow2_bits[:-1])
e = p_eval(p_1045, c)
g = d * e
small_result = types.sfix._new(g.v.round(a.k + 1, a.f, signed=False,
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
2 ** n_int_bits, signed=False,
nearest=types.sfix.round_nearest),
k=a.k, f=a.f)
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
r_bits[n_bits:-1],
higher_bits[-1])
# should be for free
highest_bits = intbitint.ripple_carry_adder(
highest_bits = r_bits[0].ripple_carry_adder(
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
carry_in=higher_bits[-1])
bits_to_check = [x.bit_xor(y)
for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
t = floatingpoint.KMul(bits_to_check)
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
bits_to_check))
# sign
s = masked_bits[-1].bit_xor(r_bits[-1]).bit_xor(carry)
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
return s.if_else(t.if_else(small_result, 0), g)
else:
# obtain absolute value of a
@@ -313,16 +329,17 @@ def exp2_fx(a):
a = (s * (-2) + 1) * a
# isolates fractional part of number
b = trunc(a)
c = a - load_sint(b, type(a))
c = a - b
# squares integer part of a
d = load_sint(b.pow2(types.sfix.k - types.sfix.f), type(a))
d = b.pow2(a.k - a.f)
# evaluates fractional part of a in p_1045
e = p_eval(p_1045, c)
g = d * e
return (1 - s) * g + s * ((types.sfix(1)) / g)
return (1 - s) * g + s / g
@types.vectorize
@instructions_base.sfix_cisc
def log2_fx(x):
"""
Returns the result of :math:`\log_2(x)` for any unbounded
@@ -347,14 +364,13 @@ def log2_fx(x):
v, p, vlen = d.v, d.p, d.vlen
# isolates mantisa of d, now the n can be also substituted by the
# secret shared p from d in the expresion above.
v = load_sint(v, type(x))
w = (1.0 / (2 ** (vlen)))
w = x.coerce(1.0 / (2 ** (vlen)))
v = v * w
# polynomials for the log_2 evaluation of f are calculated
P = p_eval(p_2524, v)
Q = p_eval(q_2524, v)
# the log is returned by adding the result of the division plus p.
a = P / Q + load_sint(vlen + p, type(x))
a = P / Q + (vlen + p)
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
@@ -515,7 +531,7 @@ def norm_simplified_SQ(b, k):
# @return g: approximated sqrt
def sqrt_simplified_fx(x):
# fix theta (number of iterations)
theta = max(int(math.ceil(math.log(types.sfix.k))), 6)
theta = max(int(math.ceil(math.log(x.k))), 6)
# process to use 2^(m/2) approximation
m_odd, m, w = norm_simplified_SQ(x.v, x.k)
@@ -524,15 +540,15 @@ def sqrt_simplified_fx(x):
m_odd = (1 - 2 * m_odd) + m_odd
w = (w * 2 - w) * (1-m_odd) + w
# map number to use sfix format and instantiate the number
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2))
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
# obtains correct 2 ** (m/2)
w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w
w = (w * (2 ** (1/2.0)) - w) * m_odd + w
# produce x/ 2^(m/2)
y_0 = types.cfix(1.0) / w
y_0 = 1 / w
# from this point on it sufices to work sfix-wise
g_0 = (y_0 * x)
h_0 = y_0 * types.cfix(0.5)
h_0 = y_0 * 0.5
gh_0 = g_0 * h_0
## initialization
@@ -689,7 +705,8 @@ def sqrt_fx(x_l, k, f):
@types.vectorize
def sqrt(x, k = types.sfix.k, f = types.sfix.f):
@instructions_base.sfix_cisc
def sqrt(x, k=None, f=None):
"""
Returns the square root (sfix) of any given fractional
value as long as it can be rounded to a integral value
@@ -699,7 +716,11 @@ def sqrt(x, k = types.sfix.k, f = types.sfix.f):
:return: square root of :py:obj:`x` (sfix).
"""
if (3 *k -2 * f >= types.sfix.f):
if k is None:
k = x.k
if f is None:
f = x.f
if (3 *k -2 * f >= f):
return sqrt_simplified_fx(x)
# raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ")
else:
@@ -707,6 +728,7 @@ def sqrt(x, k = types.sfix.k, f = types.sfix.f):
return sqrt_fx(param ,k ,f)
@instructions_base.sfix_cisc
def atan(x):
"""
Returns the arctangent (sfix) of any given fractional value.
@@ -720,12 +742,12 @@ def atan(x):
x_abs = (s * (-2) + 1) * x
# angle isolation
b = x_abs > 1
v = (types.cfix(1.0) / x_abs)
v = 1 / x_abs
v = (1 - b) * (x_abs - v) + v
v_2 =v*v
# range of polynomial coefficients
assert x.k - x.f >= 18
assert x.k - x.f >= 15
P = p_eval(p_5102, v_2)
Q = p_eval(q_5102, v_2)

View File

@@ -57,10 +57,14 @@ class intBlock(Block):
length = sum(self.lengths)
self.n_bits = length * entries_per_block
self.start = self.value_type.hard_conv(start * length)
self.lower, self.shift = \
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
if Program.prog.options.ring:
self.lower, trunc, self.shift = floatingpoint.SplitInRing(
self.value, self.n_bits, self.start)
else:
self.lower, self.shift = \
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
Program.prog.security, True)
trunc = (self.value - self.lower) / self.shift
trunc = (self.value - self.lower) / self.shift
self.slice = trunc.mod2m(length, self.n_bits, False)
self.upper = (trunc - self.slice) * self.shift
def get_slice(self):
@@ -810,7 +814,7 @@ def get_n_threads(n_loops):
if n_loops > 2048:
return 8
else:
return 1
return None
else:
return n_threads
@@ -1375,7 +1379,11 @@ def get_value_size(value_type):
if value_type == sgf2n:
return Program.prog.galois_length
elif value_type == sint:
return 127 - Program.prog.security
ring = Program.prog.options.ring
if ring:
return int(ring)
else:
return 127 - Program.prog.security
else:
return value_type.max_length
@@ -1477,11 +1485,13 @@ class PackedIndexStructure(object):
rem = mod2m(index, self.log_entries_per_block, log2(self.size), False)
c = mod2m(rem, self.log_entries_per_element, \
self.log_entries_per_block, False)
b = (rem - c) / self.entries_per_element
b = (rem - c).trunc_zeros(self.log_entries_per_element,
self.log_entries_per_block)
if self.small:
return 0, b, c
else:
return (index - rem) / self.entries_per_block, b, c
return (index - rem).trunc_zeros(self.log_entries_per_block,
log2(self.size)), b, c
else:
index_bits = bit_decompose(index, log2(self.size))
l1 = self.log_entries_per_element

View File

@@ -301,7 +301,8 @@ def iter_waksman(a, config, reverse=False):
conf_address = MemValue(config.address + depth.read()*n)
do_round(size, conf_address, a.address, a2.address, 1)
for i in range(n):
@for_range(n)
def _(i):
a[i] = a2[i]
nblocks.write(nblocks*2)
@@ -317,7 +318,8 @@ def iter_waksman(a, config, reverse=False):
conf_address = MemValue(config.address + depth.read()*n)
do_round(size, conf_address, a.address, a2.address, 0)
for i in range(n):
@for_range(n)
def _(i):
a[i] = a2[i]
nblocks.write(nblocks//2)
@@ -379,6 +381,14 @@ def config_shuffle(n, value_type):
config_bits = configure_waksman(perm)
# 2-D array
config = Array(len(config_bits) * len(perm), value_type.reg_type)
if n > 1024:
for x in config_bits:
for y in x:
get_program().public_input(y)
@for_range(sum(len(x) for x in config_bits))
def _(i):
config[i] = public_input()
return config
for i,c in enumerate(config_bits):
for j,b in enumerate(c):
config[i * len(perm) + j] = b

View File

@@ -22,12 +22,14 @@ data_types = dict(
bit = 2,
inverse = 3,
bittriple = 4,
bitgf2ntriple = 5
bitgf2ntriple = 5,
dabit = 6,
)
field_types = dict(
modp = 0,
gf2n = 1,
bit = 2,
)
@@ -39,6 +41,7 @@ class Program(object):
and threads. """
def __init__(self, args, options, param=-1, assemblymode=False):
self.options = options
self.verbose = options.verbose
self.args = args
self.init_names(args, assemblymode)
self.P = P_VALUES[param]
@@ -56,7 +59,8 @@ class Program(object):
self.security = 40
print('Default security parameter:', self.security)
self.galois_length = int(options.galois)
print('Galois length:', self.galois_length)
if self.verbose:
print('Galois length:', self.galois_length)
self.schedule = [('start', [])]
self.tape_counter = 0
self.tapes = []
@@ -73,9 +77,9 @@ class Program(object):
self.n_threads = 1
self.free_threads = set()
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
self.use_public_input_file = False
self.types = {}
self.budget = int(self.options.budget)
self.verbose = False
self.to_merge = [Compiler.instructions.asm_open_class, \
Compiler.instructions.gasm_open_class, \
Compiler.instructions.muls_class, \
@@ -89,12 +93,15 @@ class Program(object):
Compiler.instructions.inputfix_class,
Compiler.instructions.inputfloat_class,
Compiler.instructions.inputmixed_class,
Compiler.instructions.trunc_pr_class]
Compiler.instructions.trunc_pr_class,
Compiler.instructions_base.Mergeable]
import Compiler.GC.instructions as gc
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]
self.use_trunc_pr = False
self.use_dabit = options.mixed
self._edabit = options.edabit
self._split = False
Program.prog = self
self.reset_values()
@@ -132,7 +139,8 @@ class Program(object):
else:
# assume source is in main SPDZ directory
self.programs_dir = sys.path[0] + '/Programs'
print('Compiling program in', self.programs_dir)
if self.verbose:
print('Compiling program in', self.programs_dir)
# create extra directories if needed
for dirname in ['Public-Input', 'Bytecode', 'Schedules']:
@@ -161,7 +169,7 @@ class Program(object):
self.name += '-' + '-'.join(args[1:])
self.progname = progname
def new_tape(self, function, args=[], name=None):
def new_tape(self, function, args=[], name=None, single_thread=False):
if name is None:
name = function.__name__
name = "%s-%s" % (self.name, name)
@@ -170,7 +178,7 @@ class Program(object):
tape_index = len(self.tapes)
self.tape_stack.append(self.curr_tape)
self.curr_tape = Tape(name, self)
self.curr_tape.prevent_direct_memory_write = True
self.curr_tape.prevent_direct_memory_write = not single_thread
self.tapes.append(self.curr_tape)
function(*args)
self.finalize_tape(self.curr_tape)
@@ -183,7 +191,8 @@ class Program(object):
raise CompilerError('Compiler does not support ' \
'recursive spawning of threads')
if self.free_threads:
thread_number = self.free_threads.pop()
thread_number = min(self.free_threads)
self.free_threads.remove(thread_number)
else:
thread_number = self.n_threads
self.n_threads += 1
@@ -376,7 +385,7 @@ class Program(object):
else:
addr = self.allocated_mem[mem_type]
self.allocated_mem[mem_type] += size
if len(str(addr)) != len(str(addr + size)):
if len(str(addr)) != len(str(addr + size)) and self.verbose:
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
self.allocated_mem_blocks[addr,mem_type] = size
return addr
@@ -404,6 +413,7 @@ class Program(object):
def public_input(self, x):
self.public_input_file.write('%s\n' % str(x))
self.use_public_input_file = True
def set_bit_length(self, bit_length):
self.bit_length = bit_length
@@ -421,6 +431,22 @@ class Program(object):
self.tape_counter += 1
return res
def use_edabit(self, change=None):
if change is None:
return self._edabit
else:
self._edabit = change
def use_edabit_for(self, *args):
return True
def use_split(self, change=None):
if change is None:
return self._split
else:
assert change in (2, 3)
self._split = change
class Tape:
""" A tape contains a list of basic blocks, onto which instructions are added. """
def __init__(self, name, program):
@@ -503,13 +529,17 @@ class Tape:
self.exit_condition.set_relative_jump(offset)
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
def purge(self):
def purge(self, retain_usage=True):
def relevant(inst):
req_node = Tape.ReqNode('')
req_node.num = Tape.ReqNum()
inst.add_usage(req_node)
return req_node.num != {}
self.usage_instructions = list(filter(relevant, self.instructions))
if retain_usage:
self.usage_instructions = list(filter(relevant,
self.instructions))
else:
self.usage_instructions = []
if len(self.usage_instructions) > 1000:
print('Retaining %d instructions' % len(self.usage_instructions))
del self.instructions
@@ -526,6 +556,12 @@ class Tape:
req_node.num['all', 'round'] = self.n_rounds
req_node.num['all', 'inv'] = self.n_to_merge
def expand_cisc(self):
new_instructions = []
for inst in self.instructions:
new_instructions.extend(inst.expand_merged())
self.instructions = new_instructions
def __str__(self):
return self.name
@@ -577,8 +613,9 @@ class Tape:
def unpurged(function):
def wrapper(self, *args, **kwargs):
if self.purged:
print('%s called on purged block %s, ignoring' % \
(function.__name__, self.name))
if self.program.verbose:
print('%s called on purged block %s, ignoring' % \
(function.__name__, self.name))
return
return function(self, *args, **kwargs)
return wrapper
@@ -592,7 +629,8 @@ class Tape:
if self.if_states:
raise CompilerError('Unclosed if/else blocks')
print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks))
if self.program.verbose:
print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks))
for block in self.basicblocks:
al.determine_scope(block, options)
@@ -601,7 +639,7 @@ class Tape:
# need to do this if there are several blocks
if (options.merge_opens and self.merge_opens) or options.dead_code_elimination:
for i,block in enumerate(self.basicblocks):
if len(block.instructions) > 0:
if len(block.instructions) > 0 and self.program.verbose:
print('Processing basic block %s, %d/%d, %d instructions' % \
(block.name, i, len(self.basicblocks), \
len(block.instructions)))
@@ -609,7 +647,7 @@ class Tape:
merger = al.Merger(block, options, \
tuple(self.program.to_merge))
if options.dead_code_elimination:
if len(block.instructions) > 10000:
if len(block.instructions) > 100000:
print('Eliminate dead code...')
merger.eliminate_dead_code()
if options.merge_opens and self.merge_opens:
@@ -617,14 +655,14 @@ class Tape:
block.used_from_scope = util.set_by_id()
block.defined_registers = util.set_by_id()
continue
if len(block.instructions) > 10000:
if len(block.instructions) > 100000:
print('Merging instructions...')
numrounds = merger.longest_paths_merge()
block.n_rounds = numrounds
block.n_to_merge = len(merger.open_nodes)
if numrounds > 0:
if numrounds > 0 and self.program.verbose:
print('Program requires %d rounds of communication' % numrounds)
if merger.counter:
if merger.counter and self.program.verbose:
print('Block requires', \
', '.join('%d %s' % (y, x.__name__) \
for x, y in list(merger.counter.items())))
@@ -635,6 +673,9 @@ class Tape:
if not (options.merge_opens and self.merge_opens):
print('Not merging instructions in tape %s' % self.name)
if options.cisc:
self.expand_cisc()
# add jumps
offset = 0
for block in self.basicblocks:
@@ -659,7 +700,7 @@ class Tape:
print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]))
print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]))
print('Re-allocating...')
allocator = al.StraightlineAllocator(REG_MAX)
allocator = al.StraightlineAllocator(REG_MAX, self.program)
def alloc(block):
for reg in sorted(block.used_from_scope,
key=lambda x: (x.reg_type, x.i)):
@@ -673,7 +714,7 @@ class Tape:
if child.instructions:
left.append(child)
for i,block in enumerate(reversed(self.basicblocks)):
if len(block.instructions) > 10000:
if len(block.instructions) > 100000:
print('Allocating %s, %d/%d' % \
(block.name, i, len(self.basicblocks)))
if block.exit_condition is not None:
@@ -684,9 +725,11 @@ class Tape:
allocator.process(block.instructions, block.alloc_pool)
# offline data requirements
print('Compile offline data requirements...')
if self.program.verbose:
print('Compile offline data requirements...')
self.req_num = self.req_tree.aggregate()
print('Tape requires', self.req_num)
if self.program.verbose:
print('Tape requires', self.req_num)
for req,num in sorted(self.req_num.items()):
if num == float('inf') or num >= 2 ** 32:
num = -1
@@ -708,6 +751,14 @@ class Tape:
self.basicblocks[-1].instructions.append(
Compiler.instructions.guse_prep(req[1], num, \
add_to_prog=False))
elif req[0] == 'edabit':
self.basicblocks[-1].instructions.append(
Compiler.instructions.use_edabit(False, req[1], num, \
add_to_prog=False))
elif req[0] == 'sedabit':
self.basicblocks[-1].instructions.append(
Compiler.instructions.use_edabit(True, req[1], num, \
add_to_prog=False))
if not self.is_empty():
# bit length requirement
@@ -723,6 +774,11 @@ class Tape:
print('Tape requires prime bit length', self.req_bit_length['p'])
print('Tape requires galois bit length', self.req_bit_length['2'])
@unpurged
def expand_cisc(self):
for block in self.basicblocks:
block.expand_cisc()
@unpurged
def _get_instructions(self):
return itertools.chain.\
@@ -786,7 +842,7 @@ class Tape:
def reset_registers(self):
""" Reset register values to zero. """
self.reg_values = RegType.create_dict(lambda: [0] * INIT_REG_MAX)
self.reg_values = RegType.create_dict(lambda: [])
def get_value(self, reg_type, i):
return self.reg_values[reg_type][i]
@@ -826,7 +882,7 @@ class Tape:
return res
def cost(self):
return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \
if req[1] != 'input')
if req[1] != 'input' and req[0] != 'edabit')
def __str__(self):
return ", ".join('%s inputs in %s from player %d' \
% (num, req[0], req[2]) \
@@ -927,9 +983,11 @@ class Tape:
self.relative_i = 0
if i is not None:
self.i = i
else:
elif size is not None:
self.i = program.reg_counter[reg_type]
program.reg_counter[reg_type] += size
else:
self.i = float('inf')
self.vector = []
if value is not None:
self.value = value
@@ -952,8 +1010,9 @@ class Tape:
def set_size(self, size):
if self.size == size:
return
elif not self.program.options.assemblymode:
raise CompilerError('Mismatch of instruction and register size')
elif not self.program.program.options.assemblymode:
raise CompilerError('Mismatch of instruction and register size:'
' %s != %s' % (self.size, size))
elif self.size == 1 and self.vectorbase is self:
if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS:
# create vector register in assembly mode
@@ -979,6 +1038,10 @@ class Tape:
return Tape.Register(self.reg_type, self.program, size=size, i=i)
def get_vector(self, base, size):
if base == 0 and size == self.size:
return self
if size == 1:
return self[base]
res = self._new_by_number(self.i + base, size=size)
res.set_vectorbase(self)
self.create_vector_elements()
@@ -1007,7 +1070,10 @@ class Tape:
def __len__(self):
return self.size
def copy(self):
return Tape.Register(self.reg_type, Program.prog.curr_tape)
@property
def value(self):
return self.program.reg_values[self.reg_type][self.i]
@@ -1029,6 +1095,9 @@ class Tape:
self.reg_type == RegType.ClearGF2N or \
self.reg_type == RegType.ClearInt
def __bool__(self):
raise CompilerError('cannot derive truth value from register')
def __str__(self):
return self.reg_type + str(self.i)

View File

@@ -95,7 +95,8 @@ class ClientMessageType:
class MPCThread(object):
def __init__(self, target, name, args = [], runtime_arg = None):
def __init__(self, target, name, args = [], runtime_arg = 0,
single_thread = False):
""" Create a thread from a callable object. """
if not callable(target):
raise CompilerError('Target %s for thread %s is not callable' % (target,name))
@@ -105,18 +106,33 @@ class MPCThread(object):
self.args = args
self.runtime_arg = runtime_arg
self.running = 0
self.tape_handle = program.new_tape(target, args, name,
single_thread=single_thread)
self.run_handles = []
def start(self, runtime_arg = None):
self.running += 1
program.start_thread(self, runtime_arg or self.runtime_arg)
self.run_handles.append(program.run_tape(self.tape_handle, \
runtime_arg or self.runtime_arg))
def join(self):
if not self.running:
raise CompilerError('Thread %s is not running' % self.name)
self.running -= 1
program.stop_thread(self)
program.join_tape(self.run_handles.pop(0))
def copy_doc(a, b):
try:
a.__doc__ = b.__doc__
except:
pass
def no_doc(operation):
def wrapper(*args, **kwargs):
return operation(*args, **kwargs)
return wrapper
def copy_doc(a, b):
try:
a.__doc__ = b.__doc__
@@ -131,7 +147,9 @@ def no_doc(operation):
def vectorize(operation):
def vectorized_operation(self, *args, **kwargs):
if len(args):
from .GC.types import bits
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
and not isinstance(args[0], bits) \
and args[0].size != self.size:
raise CompilerError('Different vector sizes of operands')
set_global_vector_size(self.size)
@@ -292,6 +310,10 @@ class _int(object):
def bit_adder(*args, **kwargs):
return intbitint.bit_adder(*args, **kwargs)
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return intbitint.ripple_carry_adder(*args, **kwargs)
def if_else(self, a, b):
""" MUX on bit in arithmetic circuits.
@@ -416,6 +438,9 @@ class _structure(object):
res_params) \
for k in range(len(row))).reduce_after_mul()
class _vec(object):
pass
class _register(Tape.Register, _number, _structure):
@staticmethod
def n_elements():
@@ -427,7 +452,7 @@ class _register(Tape.Register, _number, _structure):
val = val.read()
if isinstance(val, cls):
return val
elif not isinstance(val, _register):
elif not isinstance(val, (_register, _vec)):
try:
return type(val)(cls.conv(v) for v in val)
except TypeError:
@@ -467,8 +492,7 @@ class _register(Tape.Register, _number, _structure):
address = regint.conv(address)
if size > 1 and address.size == 1:
res = regint(size=size)
for i in range(size):
movint(res[i], address + regint(i, size=1))
incint(res, address, 1)
return res
else:
return address
@@ -1089,6 +1113,29 @@ class regint(_register, _int):
rand(res, bit_length)
return res
@classmethod
def inc(cls, size, base=0, step=1, repeat=1, wrap=None):
"""
Produce :py:class:`regint` vector with certain patterns.
This is particularly useful for :py:meth:`SubMultiArray.direct_mul`.
:param size: Result size
:param base: First value
:param step: Increase step
:param repeat: Repeate this many times
:param wrap: Start over after this many increases
The following produces (1, 1, 1, 3, 3, 3, 5, 5, 5, 7)::
regint.inc(10, 1, 2, 3)
"""
res = regint(size=size)
if wrap is None:
wrap = size
incint(res, cls.conv(base, size=1), step, repeat, wrap)
return res
@vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
""" Receive n register values from socket """
@@ -1333,6 +1380,12 @@ class regint(_register, _int):
res += bit
return res
def shuffle(self):
""" Returns insecure shuffle of vector. """
res = regint(size=len(self))
shuffle(res, self)
return res
def reveal(self):
""" Identity. """
return self
@@ -1371,7 +1424,7 @@ class localint(object):
class _secret(_register):
__slots__ = []
mov = staticmethod(movs)
mov = staticmethod(set_instruction_type(movs))
PreOR = staticmethod(lambda l: floatingpoint.PreORC(l))
PreOp = staticmethod(lambda op, l: floatingpoint.PreOpL(op, l))
@@ -1475,9 +1528,7 @@ class _secret(_register):
res = cls(size=size)
n_rows = len(A) // n
n_cols = len(B) // n
dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)],
[B[k * n_cols + j % n_cols] for k in range(n)]]
for j in range(size)), []))
matmuls(res, A, B, n_rows, n, n_cols)
return res
@no_doc
@@ -1502,7 +1553,7 @@ class _secret(_register):
@read_mem_value
@vectorize
def load_other(self, val):
from Compiler.GC.types import sbits
from Compiler.GC.types import sbits, sbitvec
if isinstance(val, self.clear_type):
self.load_clear(val)
elif isinstance(val, type(self)):
@@ -1510,9 +1561,16 @@ class _secret(_register):
elif isinstance(val, sbits):
assert(val.n == self.size)
r = self.get_dabit()
v = regint()
bitdecint_class(regint((r[1] ^ val).reveal()), *v)
movs(self, r[0].bit_xor(v))
movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit()))
elif isinstance(val, sbitvec):
assert(sum(x.n for x in val.v) == self.size)
for val_part, base in zip(val, range(0, self.size, 64)):
left = min(64, self.size - base)
r = self.get_dabit(size=left)
v = regint(size=left)
bitdecint_class(regint((r[1] ^ val_part).reveal()), *v)
part = r[0].bit_xor(v)
vmovs(left, self.get_vector(base, left), part)
else:
self.load_clear(self.clear_type(val))
@@ -1555,8 +1613,8 @@ class _secret(_register):
size or one size 1 for a value-vector multiplication.
:param other: any compatible type """
if isinstance(other, _secret) and max(self.size, other.size) > 1 \
and min(self.size, other.size) == 1:
if isinstance(other, _secret) and (1 in (self.size, other.size)) \
and (self.size, other.size) != (1, 1):
x, y = (other, self) if self.size < other.size else (self, other)
res = type(self)(size=x.size)
mulrs(res, x, y)
@@ -1667,10 +1725,35 @@ class sint(_secret, _int):
dabit(*res)
return res
@vectorized_classmethod
def get_edabit(cls, n_bits, strict=False):
""" Bits in arithmetic and binary circuit """
""" according to security model """
if not program.use_edabit_for(strict, n_bits):
if program.use_dabit:
a, b = zip(*(sint.get_dabit() for i in range(n_bits)))
return sint.bit_compose(a), b
else:
a = [sint.get_random_bit() for i in range(n_bits)]
return sint.bit_compose(a), a
whole = cls()
size = get_global_vector_size()
from Compiler.GC.types import sbits, sbitvec
bits = [sbits.get_type(size)() for i in range(n_bits)]
if strict:
sedabit(whole, *bits)
else:
edabit(whole, *bits)
return whole, bits
@staticmethod
def long_one():
return 1
@staticmethod
def bit_decompose_clear(a, n_bits):
return floatingpoint.bits(a, n_bits)
@classmethod
def get_raw_input_from(cls, player):
res = cls()
@@ -1733,6 +1816,15 @@ class sint(_secret, _int):
""" Store in memory by public address. """
self._store_in_mem(address, stms, stmsi)
@classmethod
def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None):
if indices is None:
indices = [regint.inc(i) for i in (n, m, m, l)]
res = cls(size=indices[0].size * indices[3].size)
matmulsm(res, regint(A), regint(B), len(indices[0]), len(indices[1]),
len(indices[3]), *(list(indices) + [m, l]))
return res
def __init__(self, val=None, size=None):
"""
:param val: initialization (sint/cint/regint/int/cgf2n or list thereof)
@@ -1810,6 +1902,7 @@ class sint(_secret, _int):
return self.mod2m(int(l))
raise NotImplementedError('Modulo only implemented for powers of two.')
@vectorize
@read_mem_value
def mod2m(self, m, bit_length=None, security=None, signed=True):
""" Secret modulo power of two.
@@ -1941,6 +2034,10 @@ class sint(_secret, _int):
comparison.Trunc(res, tmp, 2 * k, k, kappa, True)
return res
def trunc_zeros(self, n_zeros, bit_length=None, signed=True):
bit_length = bit_length or program.bit_length
return comparison.TruncZeros(self, bit_length, n_zeros, signed)
@staticmethod
def two_power(n):
return floatingpoint.two_power(n)
@@ -2120,11 +2217,14 @@ class _bitint(object):
@classmethod
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
if cls.log_rounds:
return cls.carry_lookahead_adder(a, b, carry_in=carry_in)
return cls.carry_lookahead_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
elif cls.linear_rounds:
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
return cls.ripple_carry_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
else:
return cls.carry_select_adder(a, b, carry_in=carry_in)
return cls.carry_select_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
@classmethod
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
@@ -2205,8 +2305,8 @@ class _bitint(object):
@staticmethod
def full_adder(a, b, carry):
s = a + b
return s + carry, util.if_else(s, carry, a)
s = a ^ b
return s ^ carry, a ^ (s & (carry ^ a))
@staticmethod
def bit_comparator(a, b):
@@ -2243,6 +2343,7 @@ class _bitint(object):
b = util.bit_decompose(other, self.n_bits)
return self.compose(self.bit_adder(a, b))
@ret_cisc
def mul(self, other):
if type(other) == self.bin_type:
raise CompilerError('Unclear multiplication')
@@ -2270,12 +2371,13 @@ class _bitint(object):
@classmethod
def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True):
columns = [[_f for _f in (bit_matrix[j][i-j] \
for j in range(min(len(bit_matrix), i + 1))) if _f] \
for j in range(min(len(bit_matrix), i + 1))) \
if not is_zero(_f)] \
for i in range(len(bit_matrix[0]))]
return cls.wallace_tree_from_columns(columns, get_carry)
@classmethod
def wallace_tree_from_columns(cls, columns, get_carry=True):
def wallace_tree_without_finish(cls, columns, get_carry=True):
self = cls
while max(len(c) for c in columns) > 2:
new_columns = [[] for i in range(len(columns) + 1)]
@@ -2296,7 +2398,12 @@ class _bitint(object):
columns = new_columns[:-1]
for col in columns:
col.extend([0] * (2 - len(col)))
return self.bit_adder(*list(zip(*columns)))
return tuple(list(x) for x in zip(*columns))
@classmethod
def wallace_tree_from_columns(cls, columns, get_carry=True):
summands = cls.wallace_tree_without_finish(columns, get_carry)
return cls.bit_adder(*summands)
@classmethod
def wallace_tree(cls, rows):
@@ -2556,15 +2663,15 @@ def parse_type(other, k=None, f=None):
if isinstance(other, cfix.scalars):
return cfix(other, k=k, f=f)
elif isinstance(other, cint):
tmp = cfix()
tmp = cfix(k=k, f=f)
tmp.load_int(other)
return tmp
elif isinstance(other, sint):
tmp = sfix()
tmp = sfix(k=k, f=f)
tmp.load_int(other)
return tmp
elif isinstance(other, sfloat):
tmp = sfix(other)
tmp = sfix(other, k=k, f=f)
return tmp
else:
return other
@@ -2631,8 +2738,8 @@ class cfix(_number, _structure):
@vectorize_init
def __init__(self, v=None, k=None, f=None, size=None):
""" :param v: cfix/float/int """
f = f or self.f
k = k or self.k
f = self.f if f is None else f
k = self.k if k is None else k
self.f = f
self.k = k
self.size = get_global_vector_size()
@@ -2693,7 +2800,9 @@ class cfix(_number, _structure):
def mul(self, other):
""" Clear fixed-point multiplication.
:param other: cfix/cint/regint/int """
:param other: cfix/cint/regint/int/sint """
if isinstance(other, sint):
return sfix._new(self.v * other, k=self.k, f=self.f)
other = parse_type(other)
if isinstance(other, cfix):
assert self.f == other.f
@@ -2814,8 +2923,12 @@ class cfix(_number, _structure):
if isinstance(other, cfix):
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
elif isinstance(other, sfix):
return sfix(library.FPDiv(self.v, other.v, self.k, self.f,
other.kappa, nearest=sfix.round_nearest))
assert self.k == other.k
assert self.f == other.f
return sfix._new(library.FPDiv(self.v, other.v, self.k, self.f,
other.kappa,
nearest=sfix.round_nearest),
k=self.k, f=self.f)
else:
raise TypeError('Incompatible fixed point types in division')
@@ -3022,25 +3135,29 @@ class _fix(_single):
""" Convert secret integer.
:param other: sint """
res = cls()
res.f = f or cls.f
res.k = k or cls.k
res = cls(k=k, f=f)
res.load_int(cls.int_type.conv(other))
return res
@classmethod
def _new(cls, other, k=None, f=None):
res = cls(other)
res.k = k or cls.k
res.f = f or cls.f
res = cls(other, k=k, f=f)
return res
@vectorize_init
def __init__(self, _v=None, size=None):
def __init__(self, _v=None, k=None, f=None, size=None):
""" :params _v: compile-time value (int/float) """
self.size = get_global_vector_size()
f = self.f
k = self.k
if k is None:
k = self.k
else:
self.k = k
if f is None:
f = self.f
else:
self.f = f
assert k is not None
assert f is not None
# warning: don't initialize a sfix from a sint, this is only used in internal methods;
# for external initialization use load_int.
if _v is None:
@@ -3110,7 +3227,7 @@ class _fix(_single):
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
self.kappa,
self.round_nearest)
if self.size >= other.size:
if 'vec' not in self.__dict__:
return self._new(val, k=self.k, f=self.f)
else:
return self.vec._new(val, k=self.k, f=self.f)
@@ -3123,7 +3240,7 @@ class _fix(_single):
@vectorize
def __neg__(self):
""" Secret fixed-point negation. """
return type(self)(-self.v)
return self._new(-self.v, k=self.k, f=self.f)
@vectorize
def __truediv__(self, other):
@@ -3131,14 +3248,17 @@ class _fix(_single):
:param other: sfix/cfix/sint/cint/regint/int """
other = self.coerce(other)
assert self.k == other.k
assert self.f == other.f
if isinstance(other, _fix):
return type(self)(library.FPDiv(self.v, other.v, self.k, self.f,
self.kappa,
nearest=self.round_nearest))
v = library.FPDiv(self.v, other.v, self.k, self.f, self.kappa,
nearest=self.round_nearest)
elif isinstance(other, cfix):
return type(self)(library.sint_cint_division(self.v, other.v, self.k, self.f, self.kappa))
v = library.sint_cint_division(self.v, other.v, self.k, self.f,
self.kappa)
else:
raise TypeError('Incompatible fixed point types in division')
return self._new(v, k=self.k, f=self.f)
@vectorize
def __rtruediv__(self, other):
@@ -3192,11 +3312,21 @@ class sfix(_fix):
lower = average - 0.5 * 2 ** log_range
return cls._new(cls.int_type.get_random_int(n_bits)) + lower
@classmethod
def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None):
# pre-multiplication must be identity
tmp = cls.int_type.direct_matrix_mul(A, B, n, m, l, indices=indices)
res = unreduced_sfix._new(tmp)
if reduce:
res = res.reduce_after_mul()
return res
def coerce(self, other):
return parse_type(other, k=self.k, f=self.f)
def mul_no_reduce(self, other, res_params=None):
assert self.f == other.f
assert self.k == other.k
return self.unreduced(self.v * other.v)
def pre_mul(self):
@@ -3221,6 +3351,8 @@ class unreduced_sfix(_single):
self.k = k
self.m = m
self.kappa = kappa
assert self.k is not None
assert self.m is not None
def __add__(self, other):
if is_zero(other):
@@ -3236,7 +3368,8 @@ class unreduced_sfix(_single):
def reduce_after_mul(self):
return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa,
nearest=sfix.round_nearest,
signed=True))
signed=True),
k=self.k // 2, f=self.m)
sfix.unreduced_type = unreduced_sfix
@@ -4012,12 +4145,15 @@ class Array(object):
pass
try:
other.store_in_mem(self.get_address(base))
assert len(self) >= other.size + base
if len(self) != None and util.is_constant(base):
assert len(self) >= other.size + base
except AttributeError:
for i,j in enumerate(other):
self[i] = j
return self
assign_vector = assign
def assign_all(self, value, use_threads=True, conv=True):
""" Assign the same value to all entries.
@@ -4040,6 +4176,19 @@ class Array(object):
size = size or self.length
return self.value_type.load_mem(self.get_address(base), size=size)
get_part_vector = get_vector
def get(self, indices):
return self.value_type.load_mem(
regint(self.address, size=len(indices)) + indices,
size=len(indices))
def expand_to_vector(self, index, size):
assert self.value_type.n_elements() == 1
address = regint(size=size)
incint(address, regint(self.get_address(index), size=1), 0)
return self.value_type.load_mem(address, size=size)
def get_mem_value(self, index):
return MemValue(self[index], self.get_address(index))
@@ -4082,12 +4231,15 @@ class Array(object):
def shuffle(self):
""" Insecure shuffle in place. """
@library.for_range(len(self))
def _(i):
j = regint.get_random(64) % (len(self) - i)
tmp = self[i]
self[i] = self[i + j]
self[i + j] = tmp
if self.value_type == regint:
self.assign(self.get_vector().shuffle())
else:
@library.for_range(len(self))
def _(i):
j = regint.get_random(64) % (len(self) - i)
tmp = self[i]
self[i] = self[i + j]
self[i + j] = tmp
def reveal(self):
""" Reveal the whole array.
@@ -4184,6 +4336,14 @@ class SubMultiArray(object):
assert self.sizes == other.sizes
self.assign_vector(other.get_vector())
def get_part_vector(self, base=0, size=None):
assert self.value_type.n_elements() == 1
part_size = reduce(operator.mul, self.sizes[1:])
size = (size or len(self)) * part_size
assert size <= self.total_size()
return self.value_type.load_mem(self.address + base * part_size,
size=size)
def same_shape(self):
""" :return: new multidimensional array with same shape and basic type """
return MultiArray(self.sizes, self.value_type)
@@ -4254,6 +4414,7 @@ class SubMultiArray(object):
for i, x in enumerate(other):
matrix[i][0] = x
res = self * matrix
library.break_point()
return Array.create_from(x[0] for x in res)
elif isinstance(other, SubMultiArray):
assert len(other.sizes) == 2
@@ -4292,6 +4453,32 @@ class SubMultiArray(object):
else:
raise NotImplementedError
def direct_mul(self, other, reduce=True, indices=None):
""" Matrix multiplication in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication)
:return: Matrix as vector of relevant type (row-major)
The following executes a matrix multiplication selecting every third row
of :py:obj:`A`::
A = sfix.Matrix(7, 4)
B = sfix.Matrix(4, 5)
C = sfix.Matrix(3, 5)
C.assign_vector(A.direct_mul(B, indices=(regint.inc(3, 0, 3),
regint.inc(4),
regint.inc(4),
regint.inc(5)))
"""
assert len(self.sizes) == 2
assert len(other.sizes) == 2
assert self.sizes[1] == other.sizes[0]
return self.value_type.direct_matrix_mul(self.address, other.address,
self.sizes[0], *other.sizes,
reduce=reduce, indices=indices)
def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True,
res=None):
assert len(self.sizes) == 2
@@ -4349,6 +4536,26 @@ class SubMultiArray(object):
lambda x, j: [x[k][j] for k in range(len(x))],
reduce=reduce, res=res)
def parallel_mul(self, other):
assert self.sizes[1] == other.sizes[0]
assert len(self.sizes) == 2
assert len(other.sizes) == 2
assert self.value_type.n_elements() == 1
n = self.sizes[0] * other.sizes[1]
a = []
b = []
for i in range(self.sizes[1]):
addresses = regint(size=n)
incint(addresses, regint(self.address + i), self.sizes[1],
other.sizes[1], n)
a.append(self.value_type.load_mem(addresses, size=n))
addresses = regint(size=n)
incint(addresses, regint(other.address + i * other.sizes[1]), 1,
1, other.sizes[1])
b.append(self.value_type.load_mem(addresses, size=n))
res = self.value_type.dot_product(a, b)
return res
def transpose(self):
""" Matrix transpose.

View File

@@ -15,6 +15,7 @@ public:
bool fewer_rounds;
bool check_open;
bool check_beaver_open;
bool R_after_msg;
EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv)
{
@@ -54,11 +55,21 @@ public:
"-B", // Flag token.
"--no-beaver-open-check" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Only open R after message is known", // Help description.
"-R", // Flag token.
"--R-after-msg" // Flag token.
);
opt.parse(argc, argv);
prep_mul = not opt.isSet("-D");
fewer_rounds = opt.isSet("-P");
check_open = not opt.isSet("-C");
check_beaver_open = not opt.isSet("-B");
R_after_msg = opt.isSet("-R");
opt.resetArgs();
}
};

View File

@@ -17,6 +17,10 @@
#include "Processor/Input.hpp"
#include "Processor/Processor.hpp"
#include "Processor/Data_Files.hpp"
#include "Protocols/MascotPrep.hpp"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "OT/NPartyTripleGenerator.hpp"
#include <assert.h>
@@ -41,6 +45,10 @@ int main(int argc, const char** argv)
Sub_Data_Files<pShare> prep(N, prefix, usage);
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
BaseMachine machine;
machine.ot_setups.push_back({P, false});
GC::ShareThread<typename pShare::bit_type> thread(N,
OnlineOptions::singleton, P, {}, usage);
SubProcessor<pShare> proc(_, MCp, prep, P);
pShare sk, __;

View File

@@ -11,11 +11,14 @@
#include "Math/gfp.h"
#include "ECDSA/P256Element.h"
#include "Tools/Bundle.h"
#include "GC/TinyMC.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/CcdSecret.h"
#include "GC/VectorInput.h"
#include "ECDSA/preprocessing.hpp"
#include "ECDSA/sign.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MaliciousRepPrep.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Processor/Input.hpp"
@@ -24,6 +27,8 @@
#include "GC/ShareSecret.hpp"
#include "GC/RepPrep.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "Machines/ShamirMachine.hpp"
#include <assert.h>

View File

@@ -7,8 +7,6 @@
#include "Protocols/Shamir.hpp"
#include "Protocols/ShamirInput.hpp"
#include "Protocols/ShamirMC.hpp"
#include "Protocols/MaliciousShamirMC.hpp"
#include "hm-ecdsa-party.hpp"

View File

@@ -18,6 +18,7 @@
#include "Processor/Processor.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Input.hpp"
#include "GC/TinyPrep.hpp"
#include <assert.h>
@@ -103,6 +104,8 @@ void run(int argc, const char** argv)
typename pShare::Direct_MC MCp(keyp, N, 0);
ArithmeticProcessor _({}, 0);
typename pShare::LivePrep sk_prep(0, usage);
GC::ShareThread<typename pShare::bit_type> thread(N,
OnlineOptions::singleton, P, {}, usage);
SubProcessor<pShare> sk_proc(_, MCp, sk_prep, P);
pShare sk, __;
// synchronize

View File

@@ -11,6 +11,13 @@
#include "Processor/Data_Files.h"
#include "Protocols/ReplicatedPrep.h"
#include "Protocols/MaliciousShamirShare.h"
#include "GC/TinierSecret.h"
#include "GC/TinierPrep.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/TinyMC.h"
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdSecret.h"
template<template<class U> class T>
class EcTuple
@@ -18,6 +25,8 @@ class EcTuple
public:
T<P256Element::Scalar> a;
T<P256Element::Scalar> b;
P256Element::Scalar c;
T<P256Element> secret_R;
P256Element R;
};
@@ -62,9 +71,10 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
for (int i = 0; i < buffer_size; i++)
secret_Rs.push_back(bs[i] / cs_opened[i]);
}
vector<P256Element> opened_Rs;
vector<P256Element> opened_Rs(buffer_size);
typename cShare::Direct_MC MCc(MCp.get_alphai());
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (not opts.R_after_msg)
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (prep_mul)
{
protocol.init_mul(&proc);
@@ -74,10 +84,13 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
}
if (opts.fewer_rounds)
MCp.POpen_End(cs_opened, cs, extra_player);
MCc.POpen_End(opened_Rs, secret_Rs, extra_player);
if (opts.fewer_rounds)
for (int i = 0; i < buffer_size; i++)
opened_Rs[i] /= cs_opened[i];
if (not opts.R_after_msg)
{
MCc.POpen_End(opened_Rs, secret_Rs, extra_player);
if (opts.fewer_rounds)
for (int i = 0; i < buffer_size; i++)
opened_Rs[i] /= cs_opened[i];
}
if (prep_mul)
protocol.stop_exchange();
if (opts.check_open)
@@ -88,7 +101,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
{
tuples.push_back(
{ inv_ks[i], prep_mul ? protocol.finalize_mul() : pShare(),
opened_Rs[i] });
cs_opened[i], secret_Rs[i], opened_Rs[i] });
}
timer.stop();
cout << "Generated " << buffer_size << " tuples in " << timer.elapsed()
@@ -112,6 +125,7 @@ void check(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
assert(open_sk * inv_k == MC.POpen(tuple.b, P));
assert(tuple.R == k);
}
MC.Check(P);
}
template<>

View File

@@ -7,7 +7,6 @@
#include "Protocols/Shamir.hpp"
#include "Protocols/ShamirInput.hpp"
#include "Protocols/ShamirMC.hpp"
#include "hm-ecdsa-party.hpp"

View File

@@ -49,7 +49,10 @@ inline P256Element::Scalar hash_to_scalar(const unsigned char* message, size_t l
template<template<class U> class T>
EcSignature sign(const unsigned char* message, size_t length,
EcTuple<T> tuple,
typename T<P256Element::Scalar>::MAC_Check& MC, Player& P,
typename T<P256Element::Scalar>::MAC_Check& MC,
typename T<P256Element>::MAC_Check& MCc,
Player& P,
EcdsaOptions opts,
P256Element pk,
T<P256Element::Scalar> sk = {},
SubProcessor<T<P256Element::Scalar>>* proc = 0)
@@ -60,16 +63,30 @@ EcSignature sign(const unsigned char* message, size_t length,
size_t start = P.sent;
auto stats = P.comm_stats;
EcSignature signature;
signature.R = tuple.R;
vector<P256Element> opened_R;
if (opts.R_after_msg)
MCc.POpen_Begin(opened_R, {tuple.secret_R}, P);
T<P256Element::Scalar> prod = tuple.b;
auto& protocol = proc->protocol;
if (proc)
{
auto& protocol = proc->protocol;
protocol.init_mul(proc);
protocol.prepare_mul(sk, tuple.a);
protocol.exchange();
protocol.start_exchange();
}
if (opts.R_after_msg)
{
MCc.POpen_End(opened_R, {tuple.secret_R}, P);
tuple.R = opened_R[0];
if (opts.fewer_rounds)
tuple.R /= tuple.c;
}
if (proc)
{
protocol.stop_exchange();
prod = protocol.finalize_mul();
}
signature.R = tuple.R;
auto rx = tuple.R.x();
signature.s = MC.open(
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
@@ -132,7 +149,7 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
{
check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message,
check(sign(message, 1 << i, tuples[i], MCp, MCc, P, opts, pk, sk, proc), message,
1 << i, pk);
if (not opts.check_open)
continue;
@@ -142,6 +159,7 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
auto stats = check_player.comm_stats;
auto start = check_player.sent;
MCp.Check(P);
MCc.Check(P);
cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending "
<< (check_player.sent - start) << " bytes" << endl;
auto diff = (check_player.comm_stats - stats);

View File

@@ -221,4 +221,12 @@ public:
}
};
class ran_out
{
const char* what() const
{
return "insufficient preprocessing";
}
};
#endif

45
FHE/AddableVector.cpp Normal file
View File

@@ -0,0 +1,45 @@
/*
* AddableVector.cpp
*
*/
#include "AddableVector.h"
#include "Rq_Element.h"
#include "FHE_Keys.h"
template<class T>
AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
const FHE_PK& pk) const
{
int phi_m = this->size();
assert(phi_m == pk.get_params().phi_m());
AddableVector res(phi_m);
for (int i = 0; i < phi_m; i++)
{
int k = j + i, s = 1;
while (k >= phi_m)
{
k -= phi_m;
s = -s;
}
if (s == 1)
{
res[k] = (*this)[i];
}
else
{
res[k] = -(*this)[i];
}
}
return res;
}
template
AddableVector<fixint<0>> AddableVector<fixint<0>>::mul_by_X_i(int j,
const FHE_PK& pk) const;
template
AddableVector<fixint<1>> AddableVector<fixint<1>>::mul_by_X_i(int j,
const FHE_PK& pk) const;
template
AddableVector<fixint<2>> AddableVector<fixint<2>>::mul_by_X_i(int j,
const FHE_PK& pk) const;

View File

@@ -10,6 +10,7 @@
using namespace std;
#include "FHE/Plaintext.h"
#include "Rq_Element.h"
template<class T>
class AddableVector: public vector<T>
@@ -27,6 +28,11 @@ public:
this->assign(other.begin(), other.end());
}
AddableVector(const Rq_Element& other) :
AddableVector(other.to_vec_bigint())
{
}
template <class U>
void allocate_slots(const U& init)
{
@@ -66,6 +72,8 @@ public:
(*this)[i].mul(x[i], y[i]);
}
AddableVector mul_by_X_i(int i, const FHE_PK& pk) const;
void generateUniform(PRNG& G, int n_bits)
{
for (unsigned int i = 0; i < this->size(); i++)
@@ -171,6 +179,19 @@ public:
for (int i = 0; i < n; i++)
(*this)[i].resize(m);
}
AddableMatrix mul_by_X_i(int i, const FHE_PK& pk) const;
};
template<class T>
AddableMatrix<T> AddableMatrix<T>::mul_by_X_i(int i,
const FHE_PK& pk) const
{
AddableMatrix<T> res;
res.resize(this->size());
for (size_t j = 0; j < this->size(); j++)
res[j] = (*this)[j].mul_by_X_i(i, pk);
return res;
}
#endif /* FHEOFFLINE_ADDABLEVECTOR_H_ */

View File

@@ -1,5 +1,6 @@
#include "Ciphertext.h"
#include "PPData.h"
#include "P2Data.h"
#include "Exceptions/Exceptions.h"
Ciphertext::Ciphertext(const FHE_PK& pk) : Ciphertext(pk.get_params())
@@ -118,12 +119,11 @@ template<class T,class FD,class S>
void mul(Ciphertext& ans,const Plaintext<T,FD,S>& a,const Ciphertext& c)
{
a.to_poly();
const vector<S>& aa=a.get_poly();
int lev=c.cc0.level();
Rq_Element ra((*ans.params).FFTD(),evaluation,evaluation);
if (lev==0) { ra.lower_level(); }
ra.from_vec(aa);
ra.from(a.get_iterator());
ans.mul(c, ra);
}

View File

@@ -35,11 +35,17 @@ class Ciphertext
Ciphertext(const FHE_PK &pk);
Ciphertext(const Rq_Element& a0, const Rq_Element& a1, const Ciphertext& C) :
Ciphertext(C.get_params())
{
set(a0, a1, C.get_pk_id());
}
~Ciphertext() { ; }
// Rely on default copy assignment/constructor
word get_pk_id() { return pk_id; }
word get_pk_id() const { return pk_id; }
void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id)
{ cc0=a0; cc1=a1; this->pk_id = pk_id; }
@@ -93,9 +99,14 @@ class Ciphertext
template <class FD>
Ciphertext& operator*=(const Plaintext_<FD>& other) { ::mul(*this, *this, other); return *this; }
Ciphertext mul(const Ciphertext& x, const FHE_PK& pk) const
Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const
{ Ciphertext res(*params); ::mul(res, *this, x, pk); return res; }
Ciphertext mul_by_X_i(int i, const FHE_PK&) const
{
return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this};
}
int level() const { return cc0.level(); }
// pack/unpack (like IO) also assume params are known and already set

View File

@@ -4,34 +4,30 @@
void DiscreteGauss::set(double RR)
{
R=RR;
e=exp(1); e1=exp(0.25); e2=exp(-1.35);
if (RR > 0 or NewHopeB < 1)
NewHopeB = max(1, int(round(2 * RR * RR)));
assert(NewHopeB > 0);
}
/* Return a value distributed normaly with std dev R */
int DiscreteGauss::sample(PRNG& G, int stretch) const
/* This uses the approximation to a Gaussian via
* binomial distribution
*
* This procedure consumes 2*NewHopeB bits
*
*/
int DiscreteGauss::sample(PRNG &G, int stretch) const
{
/* Uses the ratio method from Wikipedia to get a
Normal(0,1) variable X
Then multiplies X by R
*/
double U,V,X,R1,R2,R3,X2;
int ans;
while (true)
{ U=G.get_double();
V=G.get_double();
R1=5-4*e1*U;
R2=4*e2/U+1.4;
R3=-4/log(U);
X=sqrt(8/e)*(V-0.5)/U;
X2=X*X;
if (X2<=R1 || (X2<R2 && X2<=R3))
{ ans=(int) (X*R*stretch);
return ans;
}
int s= 0;
// stretch refers to the standard deviation
int B = NewHopeB * stretch * stretch;
for (int i = 0; i < B; i++)
{
s += G.get_bit();
s -= G.get_bit();
}
return s;
}
@@ -39,7 +35,8 @@ int DiscreteGauss::sample(PRNG& G, int stretch) const
void RandomVectors::set(int nn,int hh,double R)
{
n=nn;
h=hh;
if (h > 0)
h=hh;
DG.set(R);
}
@@ -124,7 +121,7 @@ bool RandomVectors::operator!=(const RandomVectors& other) const
bool DiscreteGauss::operator!=(const DiscreteGauss& other) const
{
if (other.R != R or other.e != e or other.e1 != e1 or other.e2 != e2)
if (other.NewHopeB != NewHopeB)
return true;
else
return false;

View File

@@ -7,35 +7,33 @@
#include <FHE/Generator.h>
#include "Math/modp.h"
#include "Math/gfp.h"
#include "Tools/random.h"
#include <vector>
/* Uses the Ratio method as opposed to the Box-Muller method
* as the Ratio method is thread safe, but it is 50% slower
*/
#include <math.h>
class DiscreteGauss
{
double R; // Standard deviation
double e; // Precomputed exp(1)
double e1; // Precomputed exp(0.25)
double e2; // Precomputed exp(-1.35)
/* This is the bound we use on for the NewHope approximation
* to a discrete Gaussian with sigma=sqrt(B/2)
*/
int NewHopeB;
public:
void set(double R);
void pack(octetStream& o) const { o.serialize(R); }
void unpack(octetStream& o) { o.unserialize(R); }
void pack(octetStream& o) const { o.serialize(NewHopeB); }
void unpack(octetStream& o) { o.unserialize(NewHopeB); }
DiscreteGauss() { set(0); }
DiscreteGauss(double R) { set(R); }
~DiscreteGauss() { ; }
// Rely on default copy constructor/assignment
int sample(PRNG& G, int stretch = 1) const;
double get_R() const { return R; }
double get_R() const { return sqrt(0.5 * NewHopeB); }
int get_NewHopeB() const { return NewHopeB; }
bool operator!=(const DiscreteGauss& other) const;
};
@@ -56,8 +54,8 @@ class RandomVectors
void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); }
void unpack(octetStream& o) { o.get(n); o.get(h); DG.unpack(o); }
RandomVectors() { ; }
RandomVectors(int nn,int hh,double R) { set(nn,hh,R); }
RandomVectors(int h, double R) : n(0), h(h), DG(R) {}
RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); }
~RandomVectors() { ; }
// Rely on default copy constructor/assignment
@@ -81,7 +79,8 @@ class RandomVectors
bool operator!=(const RandomVectors& other) const;
};
class RandomGenerator : public Generator<bigint>
template<class T>
class RandomGenerator : public Generator<T>
{
protected:
mutable PRNG G;
@@ -90,39 +89,42 @@ public:
RandomGenerator(PRNG& G) { this->G.SetSeed(G); }
};
class UniformGenerator : public RandomGenerator
template<class T>
class UniformGenerator : public RandomGenerator<T>
{
int n_bits;
bool positive;
public:
UniformGenerator(PRNG& G, int n_bits, bool positive = true) :
RandomGenerator(G), n_bits(n_bits), positive(positive) {}
Generator* clone() const { return new UniformGenerator(*this); }
void get(bigint& x) const { G.get_bigint(x, n_bits, positive); }
RandomGenerator<T>(G), n_bits(n_bits), positive(positive) {}
Generator<T>* clone() const { return new UniformGenerator<T>(*this); }
void get(T& x) const { this->G.get(x, n_bits, positive); }
};
class GaussianGenerator : public RandomGenerator
template<class T>
class GaussianGenerator : public RandomGenerator<T>
{
DiscreteGauss DG;
int stretch;
public:
GaussianGenerator(const DiscreteGauss& DG, PRNG& G, int stretch = 1) :
RandomGenerator(G), DG(DG), stretch(stretch) {}
Generator* clone() const { return new GaussianGenerator(*this); }
void get(bigint& x) const { mpz_set_si(x.get_mpz_t(), DG.sample(G, stretch)); }
RandomGenerator<T>(G), DG(DG), stretch(stretch) {}
Generator<T>* clone() const { return new GaussianGenerator<T>(*this); }
void get(T& x) const { x = DG.sample(this->G, stretch); }
};
int sample_half(PRNG& G);
class HalfGenerator : public RandomGenerator
template<class T>
class HalfGenerator : public RandomGenerator<T>
{
public:
HalfGenerator(PRNG& G) :
RandomGenerator(G) {}
Generator* clone() const { return new HalfGenerator(*this); }
void get(bigint& x) const { mpz_set_si(x.get_mpz_t(), sample_half(G)); }
RandomGenerator<T>(G) {}
Generator<T>* clone() const { return new HalfGenerator<T>(*this); }
void get(T& x) const { x = sample_half(this->G); }
};
#endif

View File

@@ -1,7 +1,7 @@
#include "FHE/FFT_Data.h"
#include "FHE/FFT.h"
#include "Math/Subroutines.h"
#include "FHE/Subroutines.h"
void FFT_Data::assign(const FFT_Data& FFTD)

View File

@@ -4,6 +4,7 @@
#include "Math/modp.h"
#include "Math/Zp_Data.h"
#include "Math/gfp.h"
#include "Math/fixint.h"
#include "FHE/Ring.h"
/* Class for holding modular arithmetic data wrt the ring
@@ -36,6 +37,7 @@ class FFT_Data
public:
typedef gfp T;
typedef bigint S;
typedef fixint<GFP_MOD_SZ> poly_type;
void init(const Ring& Rg,const Zp_Data& PrD);

View File

@@ -161,7 +161,7 @@ void FHE_PK::encrypt(Ciphertext& c, const vector<S>& mess,
const Random_Coins& rc) const
{
Rq_Element mm((*params).FFTD(),polynomial,polynomial);
mm.from_vec(mess);
mm.from(Iterator<S>(mess));
quasi_encrypt(c, mm, rc);
}
@@ -302,6 +302,7 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
{ dec_sh.negate(); }
// Now convert to a vector of bigint's and add the required randomness
assert(pr != 0);
bigint Bd=((*params).B()<<(*params).secp())/(num_players*pr);
Bd=Bd/2; // make slightly smaller due to rounding issues
@@ -334,6 +335,29 @@ void FHE_SK::dist_decrypt_2(vector<bigint>& vv,const vector<bigint>& vv1) const
}
}
void FHE_PK::pack(octetStream& o) const
{
o.append((octet*) "PKPKPKPK", 8);
a0.pack(o);
b0.pack(o);
Sw_a.pack(o);
Sw_b.pack(o);
pr.pack(o);
}
void FHE_PK::unpack(octetStream& o)
{
char tag[8];
o.consume((octet*) tag, 8);
if (memcmp(tag, "PKPKPKPK", 8))
throw runtime_error("invalid serialization of public key");
a0.unpack(o);
b0.unpack(o);
Sw_a.unpack(o);
Sw_b.unpack(o);
pr.unpack(o);
}
bool FHE_PK::operator!=(const FHE_PK& x) const
{
@@ -379,7 +403,7 @@ template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
template void FHE_PK::encrypt(Ciphertext& c, const vector<int>& mess,
const Random_Coins& rc) const;
template void FHE_PK::encrypt(Ciphertext& c, const vector<bigint>& mess,
template void FHE_PK::encrypt(Ciphertext& c, const vector<fixint<2>>& mess,
const Random_Coins& rc) const;
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,

View File

@@ -27,7 +27,7 @@ class FHE_SK
// secret key always on lower level
void assign(const Rq_Element& s) { sk=s; sk.lower_level(); }
FHE_SK(const FHE_Params& pms, const bigint& p = 0)
FHE_SK(const FHE_Params& pms, const bigint& p)
: sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; }
FHE_SK(const FHE_PK& pk);
@@ -110,6 +110,17 @@ class FHE_PK
Sw_b(pms.FFTD(),evaluation,evaluation)
{ params=&pms; pr=p; }
FHE_PK(const FHE_Params& pms, int p) :
FHE_PK(pms, bigint(p))
{
}
template<class FD>
FHE_PK(const FHE_Params& pms, const FD& FTD) :
FHE_PK(pms, FTD.get_prime())
{
}
// Rely on default copy constructor/assignment
const Rq_Element& a() const { return a0; }
@@ -148,10 +159,8 @@ class FHE_PK
friend istream& operator>>(istream& s, FHE_PK& PK)
{ s >> PK.a0 >> PK.b0 >> PK.Sw_a >> PK.Sw_b; return s; }
void pack(octetStream& o) const
{ a0.pack(o); b0.pack(o); Sw_a.pack(o); Sw_b.pack(o); pr.pack(o); }
void unpack(octetStream& o)
{ a0.unpack(o); b0.unpack(o); Sw_a.unpack(o); Sw_b.unpack(o); pr.unpack(o); }
void pack(octetStream& o) const;
void unpack(octetStream& o);
bool operator!=(const FHE_PK& x) const;

View File

@@ -29,14 +29,14 @@ class FHE_Params
public:
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), sec_p(-1) {}
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(64, 0.7), sec_p(-1) {}
int n_mults() const { return FFTData.size() - 1; }
// Rely on default copy assignment/constructor (not that they should
// ever be needed)
void set(const Ring& R,const vector<bigint>& primes,double r=3.2,int hwt=64);
void set(const Ring& R,const vector<bigint>& primes,double r=-1,int hwt=-1);
void set_sec(int sec);
vector<bigint> sampleGaussian(PRNG& G, int noise_boost = 1) const

View File

@@ -24,40 +24,17 @@ NTL_CLIENT
#include "FHEOffline/DataSetup.h"
template <>
void generate_setup(int n_parties, int plaintext_length, int sec,
FHE_Params& params, FFT_Data& FTD, int slack, bool round_up)
{
Ring Rp;
bigint p0p,p1p,p;
SPDZ_Data_Setup_Char_p(Rp, FTD, p0p, p1p, n_parties, plaintext_length, sec,
slack, round_up);
params.set(Rp, {p0p, p1p});
}
template <>
void generate_setup(int n_parties, int plaintext_length, int sec,
FHE_Params& params, P2Data& P2D, int slack, bool round_up)
{
Ring R;
bigint pr0,pr1;
SPDZ_Data_Setup_Char_2(R, P2D, pr0, pr1, n_parties, plaintext_length, sec,
slack, round_up);
params.set(R, {pr0, pr1});
}
void generate_setup(int n, int lgp, int lg2, int sec, bool skip_2,
int slack, bool round_up)
{
DataSetup setup;
// do the full setup for SHE data
generate_setup(n, lgp, sec, setup.setup_p.params, setup.setup_p.FieldD, slack, round_up);
Parameters(n, lgp, sec, slack, round_up).generate_setup(setup.setup_p.params,
setup.setup_p.FieldD);
if (!skip_2)
generate_setup(n, lg2, sec, setup.setup_2.params, setup.setup_2.FieldD, slack, round_up);
Parameters(n, lg2, sec, slack, round_up).generate_setup(
setup.setup_2.params, setup.setup_2.FieldD);
setup.write_setup(skip_2);
}
@@ -208,9 +185,9 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool roun
/*
* Subroutine for creating the FHE parameters
*/
int SPDZ_Data_Setup_Char_p_Sub(Ring& R, bigint& pr0, bigint& pr1, int n,
int idx, int& m, bigint& p, int sec, int slack = 0, bool round_up = false)
int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p)
{
int n = n_parties;
int lg2pi[5][2][9]
= { { {130,132,132,132,132,132,132,132,132},
{104,104,104,106,106,108,108,110,110} },
@@ -291,13 +268,13 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
/*
* Create the char p FHE parameters
*/
void SPDZ_Data_Setup_Char_p(Ring& R, FFT_Data& FTD, bigint& pr0, bigint& pr1,
int n, int lgp, int sec, int slack, bool round_up)
template <>
void Parameters::SPDZ_Data_Setup(FFT_Data& FTD)
{
bigint p;
int idx, m;
SPDZ_Data_Setup_Primes(p, lgp, idx, m);
SPDZ_Data_Setup_Char_p_Sub(R, pr0, pr1, n, idx, m, p, sec, slack, round_up);
SPDZ_Data_Setup_Primes(p, plaintext_length, idx, m);
SPDZ_Data_Setup_Char_p_Sub(idx, m, p);
Zp_Data Zp(p);
gfp::init_field(p);
@@ -574,9 +551,12 @@ void char_2_dimension(int& m, int& lg2)
}
}
void SPDZ_Data_Setup_Char_2(Ring& R, P2Data& P2D, bigint& pr0, bigint& pr1,
int n, int lg2, int sec, int slack, bool round_up)
template <>
void Parameters::SPDZ_Data_Setup(P2Data& P2D)
{
int n = n_parties;
int lg2 = plaintext_length;
int lg2pi[2][9]
= { {70,70,70,70,70,70,70,70,70},
{70,75,75,75,75,80,80,80,80}
@@ -760,7 +740,8 @@ void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
{ throw bad_value(); }
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
SPDZ_Data_Setup_Char_p_Sub(R,pr0,pr1,n,idx,m,p,sec);
Parameters parameters(n, lgp, sec);
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p);
int mx=0;
for (int i=0; i<R.phi_m(); i++)
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
@@ -769,6 +750,9 @@ void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
Zp_Data Zp(p);
gfp::init_field(p);
PPD.init(R,Zp);
pr0 = parameters.pr0;
pr1 = parameters.pr1;
}

View File

@@ -9,19 +9,47 @@
#include "FHE/PPData.h"
#include "FHE/FHE_Params.h"
/* Routines to set up key sizes given the number of players n
* - And size lgp of plaintext modulus p for the char p case
*/
class Parameters
{
int n_parties;
int plaintext_length;
int sec;
int slack;
bool round_up;
public:
Ring R;
bigint pr0, pr1;
Parameters(int n_parties, int plaintext_length, int sec, int slack = 0,
bool round_up = false) :
n_parties(n_parties), plaintext_length(plaintext_length), sec(sec), slack(
slack), round_up(round_up)
{
}
template<class FD>
void generate_setup(FHE_Params& params, FD& FTD)
{
SPDZ_Data_Setup(FTD);
params.set(R, {pr0, pr1});
}
template<class FD>
void SPDZ_Data_Setup(FD& FTD);
int SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p);
};
// Main setup routine (need NTL if online_only is false)
void generate_setup(int nparties, int lgp, int lg2,
int sec, bool skip_2 = false, int slack = 0, bool round_up = false);
template <class FD>
void generate_setup(int n_parties, int plaintext_length, int sec,
FHE_Params& params, FD& FTD, int slack, bool round_up);
// semi-homomorphic, includes slack
template <class FD>
int generate_semi_setup(int plaintext_length, int sec,
@@ -35,10 +63,6 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
void init(Ring& Rg,int m);
void init(P2Data& P2D,const Ring& Rg);
// For use when we only care about p being of a certain size
void SPDZ_Data_Setup_Char_p(Ring& R, FFT_Data& FTD, bigint& pr0, bigint& pr1,
int n, int lgp, int sec, int slack = 0, bool round_up = false);
// For use when we want p to be a specific value
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p);
@@ -53,9 +77,6 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
// pre-generated dimensions for characteristic 2
void char_2_dimension(int& m, int& lg2);
void SPDZ_Data_Setup_Char_2(Ring& R, P2Data& P2D, bigint& pr0, bigint& pr1,
int n, int lg2, int sec = -1, int slacke = 0, bool round_up = false);
// try to avoid expensive generation by loading from disk if possible
void load_or_generate(P2Data& P2D, const Ring& Rg);

View File

@@ -5,6 +5,7 @@
#include <FHE/NoiseBounds.h>
#include "FHEOffline/Proof.h"
#include "Protocols/CowGearOptions.h"
#include <math.h>
@@ -13,14 +14,44 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
p(p), phi_m(phi_m), n(n), sec(sec),
slack(numBits(Proof::slack(slack_param, sec, phi_m))), sigma(sigma), h(h)
{
if (sigma <= 0)
this->sigma = sigma = FHE_Params().get_R();
cerr << "Standard deviation: " << this->sigma << endl;
h += extra_h * sec;
B_clean = (phi_m * p / 2
+ p * sigma
* (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m)
+ 16 * sqrt(n * h * phi_m))) << slack;
B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3);
drown = 1 + (bigint(1) << sec);
cout << "log(slack): " << slack << endl;
produce_epsilon_constants();
if (CowGearOptions::singleton.top_gear())
{
// according to documentation of SCALE-MAMBA 1.7
// excluding a factor of n because we don't always add up n ciphertexts
B_clean = (bigint(phi_m) << (sec + 2)) * p
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * sqrt(h));
mpf_class V_s;
if (h > 0)
V_s = sqrt(h);
else
V_s = sigma * sqrt(phi_m);
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
#ifdef NOISY
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cout << "V_s: " << V_s << endl;
cout << "c1: " << c1 << endl;
cout << "c2: " << c2 << endl;
cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl;
cout << "B_scale: " << B_scale << endl;
#endif
}
else
{
B_clean = (phi_m * p / 2
+ p * sigma
* (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m)
+ 16 * sqrt(n * h * phi_m))) << slack;
B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3);
cout << "log(slack): " << slack << endl;
}
drown = 1 + n * (bigint(1) << sec);
}
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
@@ -34,28 +65,62 @@ bigint SemiHomomorphicNoiseBounds::min_p0()
return B_clean * drown * p;
}
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q)
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma)
{
if (sigma <= 0)
sigma = FHE_Params().get_R();
// the constant was updated using Martin Albrecht's LWE estimator in Sep 2019
return 37.8 * (log_q - log2(3.2));
return 37.8 * (log_q - log2(sigma));
}
void SemiHomomorphicNoiseBounds::produce_epsilon_constants()
{
double C[3];
for (int i = 0; i < 3; i++)
{
C[i] = -1;
}
for (double x = 0.1; x < 10.0; x += .1)
{
double t = erfc(x), tp = 1;
for (int i = 1; i < 3; i++)
{
tp *= t;
double lgtp = log(tp) / log(2.0);
if (C[i] < 0 && lgtp < FHE_epsilon)
{
C[i] = pow(x, i);
}
}
}
c1 = C[1];
c2 = C[2];
}
NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
double sigma, int h) :
SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, sigma, h)
{
B_KS = p * phi_m * sigma
* (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h)
+ 2.77 * n * n * sqrt(h)
+ pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h))
+ 4.62 * n);
if (CowGearOptions::singleton.top_gear())
{
B_KS = p * c2 * this->sigma * phi_m / sqrt(12);
}
else
{
B_KS = p * phi_m * mpf_class(this->sigma)
* (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h)
+ 2.77 * n * n * sqrt(h)
+ pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h))
+ 4.62 * n);
}
#ifdef NOISY
cout << "p size: " << numBits(p) << endl;
cout << "phi(m): " << phi_m << endl;
cout << "n: " << n << endl;
cout << "sec: " << sec << endl;
cout << "sigma: " << sigma << endl;
cout << "sigma: " << this->sigma << endl;
cout << "h: " << h << endl;
cout << "B_clean size: " << numBits(B_clean) << endl;
cout << "B_scale size: " << numBits(B_scale) << endl;

View File

@@ -13,27 +13,33 @@ int phi_N(int N);
class SemiHomomorphicNoiseBounds
{
protected:
static const int FHE_epsilon = 55;
const bigint p;
const int phi_m;
const int n;
const int sec;
int slack;
const double sigma;
mpf_class sigma;
const int h;
bigint B_clean;
bigint B_scale;
bigint drown;
mpf_class c1, c2;
void produce_epsilon_constants();
public:
SemiHomomorphicNoiseBounds(const bigint& p, int phi_m, int n, int sec,
int slack, bool extra_h = false, double sigma = 3.2, int h = 64);
int slack, bool extra_h = false, double sigma = -1, int h = 64);
// with scaling
bigint min_p0(const bigint& p1);
// without scaling
bigint min_p0();
bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); }
static double min_phi_m(int log_q);
static double min_phi_m(int log_q, double sigma = -1);
};
// as per ePrint 2012:642 for slack = 0
@@ -43,7 +49,7 @@ class NoiseBounds : public SemiHomomorphicNoiseBounds
public:
NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
double sigma = 3.2, int h = 64);
double sigma = -1, int h = 64);
bigint U1(const bigint& p0, const bigint& p1);
bigint U2(const bigint& p0, const bigint& p1);
bigint min_p0(const bigint& p0, const bigint& p1);

View File

@@ -20,6 +20,7 @@ class P2Data
public:
typedef gf2n_short T;
typedef int S;
typedef int poly_type;
int num_slots() const { return slots; }
int degree() const { return A.size() ? A.size() : 0; }

View File

@@ -1,4 +1,4 @@
#include "Math/Subroutines.h"
#include "FHE/Subroutines.h"
#include "FHE/PPData.h"
#include "FHE/FFT.h"
#include "FHE/Matrix.h"

View File

@@ -4,6 +4,7 @@
#include "Math/modp.h"
#include "Math/Zp_Data.h"
#include "Math/gfp.h"
#include "Math/fixint.h"
#include "FHE/Ring.h"
/* Class for holding modular arithmetic data wrt the ring
@@ -16,6 +17,7 @@ class PPData
public:
typedef gf2n_short T;
typedef bigint S;
typedef fixint<GFP_MOD_SZ> poly_type;
Ring R;
Zp_Data prData;

View File

@@ -3,6 +3,21 @@
#include "FHE/Ring_Element.h"
#include "FHE/PPData.h"
#include "FHE/P2Data.h"
#include "FHE/Rq_Element.h"
#include "FHE_Keys.h"
#include "Math/Z2k.hpp"
template<>
void Plaintext<gfp, FFT_Data, bigint>::from(const Generator<bigint>& source) const
{
for (auto& x : b)
{
source.get(bigint::tmp);
x = bigint::tmp;
}
}
template<>
@@ -11,7 +26,7 @@ void Plaintext<gfp,FFT_Data,bigint>::from_poly() const
if (type!=Polynomial) { return; }
Ring_Element e(*Field_Data,polynomial);
e.from_vec(b);
e.from(b);
e.change_rep(evaluation);
a.resize(n_slots);
for (unsigned int i=0; i<a.size(); i++)
@@ -29,7 +44,7 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
for (unsigned int i=0; i<a.size(); i++)
{ e.set_element(i,a[i].get()); }
e.change_rep(polynomial);
e.to_vec_bigint(b);
from(e.get_iterator());
type=Both;
}
@@ -59,7 +74,10 @@ void Plaintext<gfp,PPData,bigint>::to_poly() const
{ bb[i]=a[i].get(); }
(*Field_Data).from_eval(bb);
for (unsigned int i=0; i<bb.size(); i++)
{ to_bigint(b[i],bb[i],(*Field_Data).prData); }
{
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
b[i] = bigint::tmp;
}
type=Both;
}
@@ -122,7 +140,7 @@ void Plaintext<T, FD, S>::allocate_slots(const bigint& value)
{
b.resize(degree);
for (auto& x : b)
x = value;
x.allocate_slots(value);
}
template<>
@@ -151,52 +169,20 @@ void signed_mod(bigint& x, const bigint& mod, const bigint& half_mod, const bigi
x += dest_mod;
}
template<>
void Plaintext<gfp,FFT_Data,bigint>::set_poly_mod(const Generator<bigint>& generator,const bigint& mod)
template<class T, class FD, class S>
void Plaintext<T, FD, S>::set_poly_mod(const Generator<bigint>& generator,const bigint& mod)
{
allocate(Polynomial);
bigint half_mod = mod / 2;
for (unsigned int i=0; i<b.size(); i++)
{
generator.get(b[i]);
signed_mod(b[i], mod, half_mod, Field_Data->get_prime());
generator.get(bigint::tmp);
signed_mod(bigint::tmp, mod, half_mod, Field_Data->get_prime());
b[i] = bigint::tmp;
}
}
template<>
void Plaintext<gfp,FFT_Data,bigint>::set_poly_mod(const vector<bigint>& vv,const bigint& mod)
{
b = vv;
vector<bigint>& pol = b;
bigint half_mod = mod / 2;
for (unsigned int i=0; i<vv.size(); i++)
{
pol[i] = vv[i];
if (pol[i] > half_mod)
pol[i] -= mod;
pol[i] %= (*Field_Data).get_prime();
if (pol[i]<0) { pol[i]+=(*Field_Data).get_prime(); }
}
type = Polynomial;
}
template<>
void Plaintext<gfp,PPData,bigint>::set_poly_mod(const vector<bigint>& vv,const bigint& mod)
{
b = vv;
vector<bigint>& pol = b;
for (unsigned int i=0; i<vv.size(); i++)
{ if (vv[i]>mod/2) { pol[i]=vv[i]-mod; }
else { pol[i]=vv[i]; }
pol[i]=pol[i]%(*Field_Data).get_prime();
if (pol[i]<0) { pol[i]+=(*Field_Data).get_prime(); }
}
type = Polynomial;
}
template<>
@@ -229,7 +215,8 @@ void Plaintext<gf2n_short,P2Data,int>::set_poly_mod(const Generator<bigint>& gen
}
void rand_poly(vector<bigint>& b,PRNG& G,const bigint& pr,bool positive=true)
template<class T>
void rand_poly(vector<T>& b,PRNG& G,const bigint& pr,bool positive=true)
{
for (unsigned int i=0; i<b.size(); i++)
{
@@ -416,7 +403,7 @@ void Plaintext<T,FD,S>::assign_constant(T constant, PT_Type t)
template<class T,class FD,class S>
Plaintext<T,FD,S>& Plaintext<T,FD,S>::operator+=(
const Plaintext<T,FD,S>& y)
const Plaintext& y)
{
if (Field_Data!=y.Field_Data) { throw field_mismatch(); }
@@ -687,8 +674,16 @@ void Plaintext<gf2n_short,P2Data,int>::negate()
template<class T, class FD, class S>
Rq_Element Plaintext<T, FD, S>::mul_by_X_i(int i, const FHE_PK& pk) const
{
return Rq_Element(pk.get_params(), *this).mul_by_X_i(i);
}
template<class T,class FD,class S>
bool Plaintext<T,FD,S>::equals(const Plaintext<T,FD,S>& x) const
bool Plaintext<T,FD,S>::equals(const Plaintext& x) const
{
if (Field_Data!=x.Field_Data) { return false; }
if (type!=x.type)
@@ -730,7 +725,7 @@ void Plaintext<T,FD,S>::unpack(octetStream& o)
if (size != b.size())
throw length_error("unexpected length received");
for (unsigned int i = 0; i < b.size(); i++)
o.get(b[i]);
b[i] = o.get<S>();
}

View File

@@ -18,10 +18,14 @@
*/
#include "FHE/Generator.h"
#include "Math/fixint.h"
#include <vector>
using namespace std;
class FHE_PK;
class Rq_Element;
// Forward declaration as apparently this is needed for friends in templates
template<class T,class FD,class S> class Plaintext;
template<class T,class FD,class S> ostream& operator<<(ostream& s,const Plaintext<T,FD,S>& e);
@@ -35,9 +39,11 @@ enum condition { Full, Diagonal, Bits };
enum PT_Type { Polynomial, Evaluation, Both };
template<class T,class FD,class S>
template<class T,class FD,class _>
class Plaintext
{
typedef typename FD::poly_type S;
int n_slots;
int degree;
@@ -60,7 +66,7 @@ class Plaintext
const FD& get_field() const { return *Field_Data; }
unsigned int num_slots() const { return n_slots; }
void assign(const Plaintext<T,FD,S>& p)
void assign(const Plaintext& p)
{ Field_Data=p.Field_Data;
a=p.a; b=p.b; type=p.type;
n_slots = p.n_slots;
@@ -70,12 +76,7 @@ class Plaintext
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
Plaintext(const Plaintext<T,FD,S>& p) { assign(p); }
~Plaintext() { ; }
Plaintext& operator=(const Plaintext<T,FD,S>& p)
{ if (this!=&p) { assign(p); }
return *this;
}
Plaintext(const FD& FieldD, const Rq_Element& other);
void allocate(PT_Type type) const;
void allocate() const { allocate(type); }
@@ -117,17 +118,23 @@ class Plaintext
void set_poly(const vector<S>& v)
{ type=Polynomial; b=v; }
const vector<S>& get_poly() const
{ if (type==Evaluation) { throw rep_mismatch(); }
{
to_poly();
return b;
}
Iterator<S> get_iterator() const { to_poly(); return b; }
void from(const Generator<bigint>& source) const;
// This sets a poly from a vector of bigint's which needs centering
// modulo mod, before assigning (used in decryption)
// vv[i] is already assumed reduced modulo mod though but in
// range [0,...,mod)
void set_poly_mod(const vector<bigint>& vv,const bigint& mod);
void set_poly_mod(const vector<bigint>& vv,const bigint& mod)
{
set_poly_mod(Iterator<bigint>(vv), mod);
}
void set_poly_mod(const Generator<bigint>& generator, const bigint& mod);
// Converts between Evaluation,Polynomial and Both representations
@@ -144,36 +151,38 @@ class Plaintext
void assign_one(PT_Type t = Evaluation);
void assign_constant(T constant, PT_Type t = Evaluation);
friend void add<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x,const Plaintext<T,FD,S>& y);
friend void sub<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x,const Plaintext<T,FD,S>& y);
friend void mul<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x,const Plaintext<T,FD,S>& y);
friend void sqr<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x);
friend void add<>(Plaintext& z,const Plaintext& x,const Plaintext& y);
friend void sub<>(Plaintext& z,const Plaintext& x,const Plaintext& y);
friend void mul<>(Plaintext& z,const Plaintext& x,const Plaintext& y);
friend void sqr<>(Plaintext& z,const Plaintext& x);
Plaintext<T,FD,S> operator+(const Plaintext<T,FD,S>& x) const
{ Plaintext<T,FD,S> res(*Field_Data); add(res, *this, x); return res; }
Plaintext<T,FD,S> operator-(const Plaintext<T,FD,S>& x) const
{ Plaintext<T,FD,S> res(*Field_Data); sub(res, *this, x); return res; }
Plaintext operator+(const Plaintext& x) const
{ Plaintext res(*Field_Data); add(res, *this, x); return res; }
Plaintext operator-(const Plaintext& x) const
{ Plaintext res(*Field_Data); sub(res, *this, x); return res; }
void mul(const Plaintext<T,FD,S>& x, const Plaintext<T,FD,S>& y)
void mul(const Plaintext& x, const Plaintext& y)
{ x.from_poly(); y.from_poly(); ::mul(*this, x, y); }
Plaintext<T,FD,S> operator*(const Plaintext<T,FD,S>& x)
{ Plaintext<T,FD,S> res(*Field_Data); res.mul(*this, x); return res; }
Plaintext operator*(const Plaintext& x)
{ Plaintext res(*Field_Data); res.mul(*this, x); return res; }
Plaintext<T,FD,S>& operator+=(const Plaintext<T,FD,S>& y);
Plaintext<T,FD,S>& operator-=(const Plaintext<T,FD,S>& y)
Plaintext& operator+=(const Plaintext& y);
Plaintext& operator-=(const Plaintext& y)
{ to_poly(); y.to_poly(); ::sub(*this, *this, y); return *this; }
void negate();
bool equals(const Plaintext<T,FD,S>& x) const;
bool operator!=(const Plaintext<T,FD,S>& x) { return !equals(x); }
Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const;
bool equals(const Plaintext& x) const;
bool operator!=(const Plaintext& x) { return !equals(x); }
bool is_diagonal() const { throw not_implemented(); }
bool is_binary() const { throw not_implemented(); }
friend ostream& operator<< <>(ostream& s,const Plaintext<T,FD,S>& e);
friend istream& operator>> <>(istream& s,Plaintext<T,FD,S>& e);
friend ostream& operator<< <>(ostream& s,const Plaintext& e);
friend istream& operator>> <>(istream& s,Plaintext& e);
/* Pack and unpack into an octetStream
* For unpack we assume the FFTD has been assigned correctly already

View File

@@ -9,8 +9,10 @@
class FHE_PK;
class Int_Random_Coins : public AddableMatrix<bigint>
class Int_Random_Coins : public AddableMatrix<fixint<0>>
{
typedef value_type::value_type T;
const FHE_Params* params;
public:
Int_Random_Coins(const FHE_Params& params) : params(&params)
@@ -20,14 +22,16 @@ public:
void sample(PRNG& G)
{
(*this)[0].from(HalfGenerator(G));
(*this)[0].from(HalfGenerator<T>(G));
for (int i = 1; i < 3; i++)
(*this)[i].from(GaussianGenerator(params->get_DG(), G));
(*this)[i].from(GaussianGenerator<T>(params->get_DG(), G));
}
};
class Random_Coins
{
typedef bigint T;
Rq_Element uu,vv,ww;
const FHE_Params *params;
@@ -56,16 +60,25 @@ class Random_Coins
template <class T>
void assign(const vector<T>& u,const vector<T>& v,const vector<T>& w)
{ uu.from_vec(u); vv.from_vec(v); ww.from_vec(w); }
{
uu.from(u);
vv.from(v);
ww.from(w);
}
void assign(const Int_Random_Coins& rc)
{ uu.from_vec(rc[0]); vv.from_vec(rc[1]); ww.from_vec(rc[2]); }
{
uu.from(rc[0]);
vv.from(rc[1]);
ww.from(rc[2]);
}
/* Generate a standard distribution */
void generate(PRNG& G)
{ uu.from(HalfGenerator(G));
vv.from(GaussianGenerator(params->get_DG(), G));
ww.from(GaussianGenerator(params->get_DG(), G));
{
uu.from(HalfGenerator<T>(G));
vv.from(GaussianGenerator<T>(params->get_DG(), G));
ww.from(GaussianGenerator<T>(params->get_DG(), G));
}
// Generate all from Uniform in range (-B,...B)
@@ -74,9 +87,9 @@ class Random_Coins
if (B1 == 0)
uu.assign_zero();
else
uu.from(UniformGenerator(G,numBits(B1)));
vv.from(UniformGenerator(G,numBits(B2)));
ww.from(UniformGenerator(G,numBits(B3)));
uu.from(UniformGenerator<T>(G,numBits(B1)));
vv.from(UniformGenerator<T>(G,numBits(B2)));
ww.from(UniformGenerator<T>(G,numBits(B3)));
}

View File

@@ -147,6 +147,47 @@ void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
}
Ring_Element Ring_Element::mul_by_X_i(int j) const
{
Ring_Element ans;
auto& a = *this;
ans.partial_assign(a);
if (ans.rep == evaluation)
{
modp xj, xj2;
Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD());
Sqr(xj2, xj, (*a.FFTD).get_prD());
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
{
Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD());
Mul(xj, xj, xj2, (*a.FFTD).get_prD());
}
}
else
{
Ring_Element aa(*ans.FFTD, ans.rep);
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
{
int k= j + i, s= 1;
while (k >= (*ans.FFTD).phi_m())
{
k-= (*ans.FFTD).phi_m();
s= -s;
}
if (s == 1)
{
aa.element[k]= a.element[i];
}
else
{
Negate(aa.element[k], a.element[i], (*a.FFTD).get_prD());
}
}
ans= aa;
}
return ans;
}
void Ring_Element::randomize(PRNG& G,bool Diag)
{
@@ -318,20 +359,6 @@ void Ring_Element::from_vec(const vector<int>& v)
// cout << "RE:from_vec<int>:: " << *this << endl;
}
template <class T>
void Ring_Element::from(const Generator<T>& generator)
{
RepType t=rep;
rep=polynomial;
T tmp;
for (int i=0; i<(*FFTD).phi_m(); i++)
{
generator.get(tmp);
element[i].convert_destroy(tmp, (*FFTD).get_prD());
}
change_rep(t);
}
ConversionIterator Ring_Element::get_iterator() const
{
if (rep != polynomial)
@@ -389,6 +416,7 @@ modp Ring_Element::get_constant() const
void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
{
ZpD.pack(o);
o.store((int)v.size());
for (unsigned int i=0; i<v.size(); i++)
{ v[i].pack(o,ZpD); }
@@ -397,6 +425,12 @@ void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
{
Zp_Data check_Zpd;
check_Zpd.unpack(o);
if (check_Zpd != ZpD)
throw runtime_error(
"mismatch: " + to_string(check_Zpd.pr_bit_length) + "/"
+ to_string(ZpD.pr_bit_length));
unsigned int length;
o.get(length);
v.resize(length);
@@ -408,7 +442,7 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
void Ring_Element::pack(octetStream& o) const
{
check_size();
o.store(rep);
o.store(unsigned(rep));
store(o,element,(*FFTD).get_prD());
}

View File

@@ -90,6 +90,8 @@ class Ring_Element
friend void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b);
friend void mul(Ring_Element& ans,const Ring_Element& a,const modp& b);
Ring_Element mul_by_X_i(int i) const;
void randomize(PRNG& G,bool Diag=false);
bool equals(const Ring_Element& a) const;
@@ -116,6 +118,12 @@ class Ring_Element
template <class T>
void from(const Generator<T>& generator);
template <class T>
void from(const vector<T>& source)
{
from(Iterator<T>(source));
}
// This gets the constant term of the poly rep as a modp element
modp get_constant() const;
modp get_element(int i) const { return element[i]; }
@@ -167,5 +175,20 @@ public:
inline void mul(Ring_Element& ans,const modp& a,const Ring_Element& b)
{ mul(ans,b,a); }
template <class T>
void Ring_Element::from(const Generator<T>& generator)
{
RepType t=rep;
rep=polynomial;
T tmp;
for (int i=0; i<(*FFTD).phi_m(); i++)
{
generator.get(tmp);
element[i].convert_destroy(tmp, (*FFTD).get_prD());
}
change_rep(t);
}
#endif

View File

@@ -1,6 +1,12 @@
#include "Rq_Element.h"
#include "FHE_Keys.h"
#include "Exceptions/Exceptions.h"
Rq_Element::Rq_Element(const FHE_PK& pk) :
Rq_Element(pk.get_params().FFTD())
{
}
Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
{
if (prd.size() > 0)
@@ -57,6 +63,19 @@ void Rq_Element::negate()
a[i].negate();
}
Rq_Element Rq_Element::mul_by_X_i(int i) const
{
Rq_Element res;
res.lev = lev;
res.a.clear();
for (auto& x : a)
{
auto tmp = x.mul_by_X_i(i);
res.a.push_back(tmp);
}
return res;
}
void add(Rq_Element& ans,const Rq_Element& ra,const Rq_Element& rb)
{
ans.partial_assign(ra, rb);
@@ -173,19 +192,6 @@ void Rq_Element::from_vec(const vector<int>& v,int level)
a[i].from_vec(v);
}
template <class T>
void Rq_Element::from(const Generator<T>& generator, int level)
{
set_level(level);
if (lev == 1)
{
auto clone = generator.clone();
a[1].from(*clone);
delete clone;
}
a[0].from(generator);
}
vector<bigint> Rq_Element::to_vec_bigint() const
{
vector<bigint> v;
@@ -220,7 +226,7 @@ void Rq_Element::to_vec_bigint(vector<bigint>& v) const
}
}
ConversionIterator Rq_Element::get_iterator()
ConversionIterator Rq_Element::get_iterator() const
{
if (lev != 0)
throw not_implemented();
@@ -339,7 +345,8 @@ void Rq_Element::raise_level()
void Rq_Element::check_level() const
{
if ((unsigned)lev > (unsigned)n_mults())
throw range_error("level out of range");
throw range_error(
"level out of range: " + to_string(lev) + "/" + to_string(n_mults()));
}
void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y)

View File

@@ -17,6 +17,7 @@
#include "FHE/FHE_Params.h"
#include "FHE/tools.h"
#include "FHE/Generator.h"
#include "Plaintext.h"
#include <vector>
// Forward declare the friend functions
@@ -62,9 +63,18 @@ protected:
Rq_Element(const FHE_Params& params) :
Rq_Element(params.FFTD()) {}
Rq_Element(const FHE_PK& pk);
Rq_Element(const Ring_Element& b0,const Ring_Element& b1) :
a({b0, b1}), lev(n_mults()) {}
template<class T, class FD, class S>
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext) :
Rq_Element(params)
{
from(plaintext.get_iterator());
}
// Destructor
~Rq_Element()
{ ; }
@@ -91,6 +101,8 @@ protected:
// Multiply something by p1 and make level 1
void mul_by_p1();
Rq_Element mul_by_X_i(int i) const;
void randomize(PRNG& G,int lev=-1);
// Scale from level 1 to level 0, if at level 0 do nothing
@@ -113,10 +125,16 @@ protected:
vector<bigint> to_vec_bigint() const;
void to_vec_bigint(vector<bigint>& v) const;
ConversionIterator get_iterator();
ConversionIterator get_iterator() const;
template <class T>
void from(const Generator<T>& generator, int level=-1);
template <class T>
void from(const vector<T>& source, int level=-1)
{
from(Iterator<T>(source), level);
}
bigint infinity_norm() const;
bigint get_prime(int i) const
@@ -156,9 +174,22 @@ template<class S>
Rq_Element& Rq_Element::operator+=(const vector<S>& other)
{
Rq_Element tmp = *this;
tmp.from_vec(other, lev);
tmp.from(Iterator<S>(other), lev);
add(*this, *this, tmp);
return *this;
}
template <class T>
void Rq_Element::from(const Generator<T>& generator, int level)
{
set_level(level);
if (lev == 1)
{
auto clone = generator.clone();
a[1].from(*clone);
delete clone;
}
a[0].from(generator);
}
#endif

View File

@@ -1,8 +1,8 @@
#include "Subroutines.h"
#include "modp.h"
#include "Math/modp.h"
#include "modp.hpp"
#include "Math/modp.hpp"
void Subs(modp& ans,const vector<int>& poly,const modp& x,const Zp_Data& ZpD)
{

View File

@@ -18,6 +18,11 @@ CutAndChooseMachine::CutAndChooseMachine(int argc, const char** argv)
"--covert" // Flag token.
);
parse_options(argc, argv);
if (produce_inputs)
{
cerr << "Producing input tuples is not implemented" << endl;
exit(1);
}
covert = opt.isSet("--covert");
if (not covert and sec != 40)
throw runtime_error("active cut-and-choose only implemented for 40-bit security");

View File

@@ -8,6 +8,10 @@
#include "Protocols/fake-stuff.h"
#include "FHE/NTL-Subs.h"
#include "Tools/benchmarking.h"
#include "Tools/Bundle.h"
#include "PairwiseSetup.h"
#include "Proof.h"
#include "SimpleMachine.h"
#include <iostream>
using namespace std;
@@ -58,8 +62,8 @@ void PartSetup<FD>::generate_setup(int n_parties, int plaintext_length, int sec,
int slack, bool round_up)
{
sec = max(sec, 40);
::generate_setup(n_parties, plaintext_length, sec, params, FieldD,
slack, round_up);
Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup(
params, FieldD);
params.set_sec(sec);
pk = FHE_PK(params, FieldD.get_prime());
sk = FHE_SK(params, FieldD.get_prime());
@@ -254,6 +258,7 @@ void DataSetup::output(int my_number, int nn, bool specific_dir)
template <class FD>
void PartSetup<FD>::pack(octetStream& os)
{
os.append((octet*)"PARTSETU", 8);
params.pack(os);
FieldD.pack(os);
pk.pack(os);
@@ -265,8 +270,15 @@ void PartSetup<FD>::pack(octetStream& os)
template <class FD>
void PartSetup<FD>::unpack(octetStream& os)
{
char tag[8];
os.consume((octet*) tag, 8);
if (memcmp(tag, "PARTSETU", 8))
throw runtime_error("invalid serialization of setup");
params.unpack(os);
FieldD.unpack(os);
pk = {params, FieldD};
sk = pk;
calpha = params;
pk.unpack(os);
sk.unpack(os);
calpha.unpack(os);
@@ -305,5 +317,94 @@ bool PartSetup<FD>::operator!=(const PartSetup<FD>& other)
return false;
}
template<class FD>
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
int plaintext_length, int sec)
{
::secure_init(*this, P, machine, plaintext_length, sec);
}
template<class FD>
void PartSetup<FD>::generate(Player& P, MachineBase&, int plaintext_length,
int sec)
{
generate_setup(P.num_players(), plaintext_length, sec,
INTERACTIVE_SPDZ1_SLACK, false);
}
template<class FD>
void PartSetup<FD>::check(Player& P, MachineBase& machine)
{
Bundle<octetStream> bundle(P);
bundle.mine.store(machine.extra_slack);
auto& os = bundle.mine;
params.pack(os);
FieldD.hash(os);
pk.pack(os);
calpha.pack(os);
bundle.compare(P);
}
template<class FD>
void PartSetup<FD>::covert_key_generation(Player& P,
MultiplicativeMachine& machine, int num_runs)
{
auto& setup = machine.setup.part<FD>();
Run_Gen_Protocol(setup.pk, setup.sk, P, num_runs, false);
}
template<class FD>
void PartSetup<FD>::covert_mac_generation(Player& P,
MultiplicativeMachine& machine, int num_runs)
{
auto& setup = machine.setup.part<FD>();
generate_mac_key(setup.alphai, setup.calpha, setup.FieldD, setup.pk, P,
num_runs);
}
template<class FD>
void PartSetup<FD>::covert_secrets_generation(Player& P,
MultiplicativeMachine& machine, int num_runs)
{
octetStream os;
params.pack(os);
FieldD.pack(os);
string filename = PREP_DIR "ChaiGear-Secrets-" + to_string(num_runs) + "-"
+ os.check_sum(20).get_str(16) + "-P" + to_string(P.my_num());
string error;
try
{
ifstream input(filename);
os.input(input);
unpack(os);
}
catch (exception& e)
{
error = e.what();
}
try
{
check(P, machine);
}
catch (mismatch_among_parties& e)
{
error = e.what();
}
if (not error.empty())
{
cerr << "Running secrets generation because " << error << endl;
covert_key_generation(P, machine, num_runs);
covert_mac_generation(P, machine, num_runs);
ofstream output(filename);
octetStream os;
pack(os);
os.output(output);
}
}
template class PartSetup<FFT_Data>;
template class PartSetup<P2Data>;

View File

@@ -12,6 +12,8 @@
#include "Math/Setup.h"
class DataSetup;
class MachineBase;
class MultiplicativeMachine;
template <class FD>
class PartSetup
@@ -26,6 +28,11 @@ public:
Ciphertext calpha;
typename FD::T alphai;
static string name()
{
return "GlobalParams-" + T::type_string();
}
PartSetup();
void generate_setup(int n_parties, int plaintext_length, int sec, int slack,
bool round_up);
@@ -42,6 +49,19 @@ public:
void check(int sec) const;
bool operator!=(const PartSetup<FD>& other);
void secure_init(Player& P, MachineBase& machine, int plaintext_length,
int sec);
void generate(Player& P, MachineBase& machine, int plaintext_length,
int sec);
void check(Player& P, MachineBase& machine);
void covert_key_generation(Player& P, MultiplicativeMachine& machine,
int num_runs);
void covert_mac_generation(Player& P, MultiplicativeMachine& machine,
int num_runs);
void covert_secrets_generation(Player& P, MultiplicativeMachine& machine,
int num_runs);
};
class DataSetup

View File

@@ -13,7 +13,7 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
bigint limit = pk.get_params().Q() << 64;
vv.allocate_slots(limit);
vv1.allocate_slots(limit);
mf.allocate_slots(pk.get_params().p0() << 64);
mf.allocate_slots(pk.p() << 64);
}
template<class FD>

View File

@@ -262,7 +262,7 @@ void EncCommit<T,FD,S>::Create_More() const
{ if (cond!=Full)
{ throw not_implemented(); }
else
{ m_Delta.from(UniformGenerator(Gseed[i],numBits(Bound1))); }
{ m_Delta.from(UniformGenerator<bigint>(Gseed[i],numBits(Bound1))); }
rc_Delta.generateUniform(Gseed[i],Bound2,Bound3,Bound3);
Ciphertext Delta(params);
(*pk).quasi_encrypt(Delta,m_Delta,rc_Delta);
@@ -319,7 +319,7 @@ void EncCommit<T,FD,S>::Create_More() const
if (cond!=Full)
{ throw not_implemented(); }
else
{ mm.from(UniformGenerator(G,numBits(Bound1))); }
{ mm.from(UniformGenerator<bigint>(G, numBits(Bound1))); }
rr.generateUniform(G,Bound2,Bound3,Bound3);
(*pk).quasi_encrypt(cc,mm,rr);
occ.reset_write_head();
@@ -357,10 +357,10 @@ void EncCommit<T,FD,S>::Create_More() const
{ throw not_implemented();
}
else
{ m_Delta.from(UniformGenerator(G,numBits(Bound1))); }
{ m_Delta.from(UniformGenerator<bigint>(G, numBits(Bound1))); }
rc_Delta.generateUniform(G,Bound2,Bound3,Bound3);
Iterator<S> vm=m[i].get_iterator();
auto vm=m[i].get_iterator();
z[0].from(vm);
add(z[0],z[0],m_Delta);
add(rr,rc[i],rc_Delta);

View File

@@ -12,13 +12,17 @@
#include "FHE/Plaintext.h"
#include "Tools/MemoryUsage.h"
class MachineBase;
template<class T,class FD,class S>
class EncCommitBase
{
public:
size_t volatile_memory;
EncCommitBase() : volatile_memory(0) {}
const MachineBase* machine;
EncCommitBase(const MachineBase* machine = 0) : volatile_memory(0), machine(machine) {}
virtual ~EncCommitBase() {}
virtual condition get_condition() { return Full; }
virtual void next(Plaintext<T,FD,S>& mess, Ciphertext& c)

View File

@@ -27,7 +27,7 @@ Multiplier<FD>::Multiplier(int offset, PairwiseMachine& machine, Player& P,
product_share(machine.setup<FD>().FieldD), rc(machine.pk),
volatile_capacity(0)
{
product_share.allocate_slots(machine.setup<FD>().params.p0() << 64);
product_share.allocate_slots(machine.pk.p() << 64);
}
template <class FD>
@@ -35,7 +35,7 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
const Ciphertext& enc_a, const Plaintext_<FD>& b)
{
Rq_Element bb(enc_a.get_params(), evaluation, evaluation);
bb.from_vec(b.get_poly());
bb.from(b.get_iterator());
multiply_and_add(res, enc_a, bb);
}

View File

@@ -7,6 +7,8 @@
#include "FHEOffline/PairwiseMachine.h"
#include "FHEOffline/Producer.h"
#include "Protocols/SemiShare.h"
#include "GC/SemiSecret.h"
#include "GC/SemiPrep.h"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiInput.hpp"
@@ -21,20 +23,21 @@ PairwiseGenerator<FD>::PairwiseGenerator(int thread_num,
thread_num, machine.output),
EC(P, machine.other_pks, machine.setup<FD>().FieldD, timers, machine, *this),
MC(machine.setup<FD>().alphai),
C(machine.sec, machine.setup<FD>().params), volatile_memory(0),
n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)),
C(n_ciphertexts, machine.setup<FD>().params), volatile_memory(0),
machine(machine)
{
for (int i = 1; i < P.num_players(); i++)
multipliers.push_back(new Multiplier<FD>(i, *this));
const FD& FieldD = machine.setup<FD>().FieldD;
a.resize(machine.sec, FieldD);
b.resize(machine.sec, FieldD);
c.resize(machine.sec, FieldD);
a.resize(n_ciphertexts, FieldD);
b.resize(n_ciphertexts, FieldD);
c.resize(n_ciphertexts, FieldD);
a.allocate_slots(FieldD.get_prime());
b.allocate_slots(FieldD.get_prime());
// extra limb for addition
c.allocate_slots((bigint)FieldD.get_prime() << 64);
b_mod_q.resize(machine.sec,
b_mod_q.resize(n_ciphertexts,
{ machine.setup<FD>().params, evaluation, evaluation });
}
@@ -64,8 +67,8 @@ void PairwiseGenerator<FD>::run()
c.mul(a, b);
timers["Plaintext multiplication"].stop();
timers["FFT of b"].start();
for (int i = 0; i < machine.sec; i++)
b_mod_q.at(i).from_vec(b.at(i).get_poly());
for (int i = 0; i < n_ciphertexts; i++)
b_mod_q.at(i).from(b.at(i).get_iterator());
timers["FFT of b"].stop();
timers["Proof exchange"].start();
size_t verifier_memory = EC.create_more(ciphertexts, cleartexts);
@@ -73,7 +76,7 @@ void PairwiseGenerator<FD>::run()
volatile_memory = max(prover_memory, verifier_memory);
Rq_Element values({machine.setup<FD>().params, evaluation, evaluation});
for (int k = 0; k < machine.sec; k++)
for (int k = 0; k < n_ciphertexts; k++)
{
producer.ai = a[k];
producer.bi = b[k];
@@ -90,7 +93,7 @@ void PairwiseGenerator<FD>::run()
else
{
timers["Plaintext conversion"].start();
values.from_vec(producer.values[j].get_poly());
values.from(producer.values[j].get_iterator());
timers["Plaintext conversion"].stop();
}
@@ -122,7 +125,7 @@ void PairwiseGenerator<FD>::generate_inputs(int player)
{
SeededPRNG G;
b[0].randomize(G);
b_mod_q.at(0).from_vec(b.at(0).get_poly());
b_mod_q.at(0).from(b.at(0).get_iterator());
producer.macs[0].mul(machine.setup<FD>().alpha, b[0]);
}
else

View File

@@ -29,6 +29,8 @@ class PairwiseGenerator : public GeneratorBase
MultiEncCommit<FD> EC;
MAC_Check<T> MC;
int n_ciphertexts;
// temporary data
AddableVector<Ciphertext> C;
octetStream ciphertexts, cleartexts;

View File

@@ -48,12 +48,21 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
template <class FD>
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
{
::secure_init(*this, P, machine, plaintext_length, sec);
alpha = FieldD;
}
template <class T>
void secure_init(T& setup, Player& P, MachineBase& machine,
int plaintext_length, int sec)
{
machine.sec = sec;
sec = max(sec, 40);
machine.drown_sec = sec;
string filename = PREP_DIR "Params-" + FD::T::type_string() + "-"
+ to_string(plaintext_length) + "-" + to_string(sec) + "-P"
string filename = PREP_DIR + T::name() + "-"
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
+ to_string(P.my_num());
try
{
@@ -61,38 +70,54 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
octetStream os;
os.input(file);
os.get(machine.extra_slack);
params.unpack(os);
FieldD.unpack(os);
FieldD.init_field();
check(P, machine);
setup.unpack(os);
setup.check(P, machine);
}
catch (...)
{
cout << "Finding parameters for security " << sec << " and field size ~2^"
<< plaintext_length << endl;
machine.extra_slack = generate_semi_setup(plaintext_length, sec, params, FieldD, true);
check(P, machine);
setup.generate(P, machine, plaintext_length, sec);
setup.check(P, machine);
octetStream os;
os.store(machine.extra_slack);
params.pack(os);
FieldD.pack(os);
setup.pack(os);
ofstream file(filename);
os.output(file);
}
alpha = FieldD;
}
template <class FD>
void PairwiseSetup<FD>::check(Player& P, PairwiseMachine& machine)
void PairwiseSetup<FD>::generate(Player&, MachineBase& machine,
int plaintext_length, int sec)
{
machine.extra_slack = generate_semi_setup(plaintext_length, sec, params,
FieldD, true);
}
template<class FD>
void PairwiseSetup<FD>::pack(octetStream& os) const
{
params.pack(os);
FieldD.pack(os);
}
template<class FD>
void PairwiseSetup<FD>::unpack(octetStream& os)
{
params.unpack(os);
FieldD.unpack(os);
FieldD.init_field();
}
template <class FD>
void PairwiseSetup<FD>::check(Player& P, MachineBase& machine)
{
Bundle<octetStream> bundle(P);
bundle.mine.store(machine.extra_slack);
params.pack(bundle.mine);
FieldD.hash(bundle.mine);
P.Broadcast_Receive(bundle, true);
for (auto& os : bundle)
if (os != bundle.mine)
throw runtime_error("mismatch of parameters among parties");
bundle.compare(P);
}
template <class FD>
@@ -161,3 +186,6 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
template class PairwiseSetup<FFT_Data>;
template class PairwiseSetup<P2Data>;
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int);
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int);

View File

@@ -11,6 +11,11 @@
#include "Networking/Player.h"
class PairwiseMachine;
class MachineBase;
template <class T>
void secure_init(T& setup, Player& P, MachineBase& machine,
int plaintext_length, int sec);
template <class FD>
class PairwiseSetup
@@ -24,15 +29,24 @@ public:
Plaintext_<FD> alpha;
string dirname;
static string name()
{
return "PairwiseParams-" + FD::T::type_string();
}
PairwiseSetup() : params(0), alpha(FieldD) {}
void init(const Player& P, int sec, int plaintext_length, int& extra_slack);
void secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec);
void check(Player& P, PairwiseMachine& machine);
void generate(Player& P, MachineBase& machine, int plaintext_length, int sec);
void check(Player& P, MachineBase& machine);
void covert_key_generation(Player& P, PairwiseMachine& machine, int num_runs);
void covert_mac_generation(Player& P, PairwiseMachine& machine, int num_runs);
void pack(octetStream& os) const;
void unpack(octetStream& os);
void set_alphai(T alphai);
};

View File

@@ -9,6 +9,8 @@
#include "Sacrificing.h"
#include "Reshare.h"
#include "DistDecrypt.h"
#include "SimpleEncCommit.h"
#include "SimpleMachine.h"
#include "Tools/mkpath.h"
template<class FD>
@@ -610,7 +612,7 @@ InputProducer<FD>::~InputProducer()
template<class FD>
void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
const Ciphertext& calpha, EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd,
const T& alphai)
const T& alphai, int player)
{
(void)alphai;
@@ -620,43 +622,86 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
G.ReSeed();
Ciphertext gama(params),dummyc(params),ca(params);
vector<octetStream> oca(P.num_players());
const FD& FieldD = dd.f.get_field();
Plaintext<T,FD,S> a(FieldD),ai(FieldD),gai(FieldD);
Random_Coins rc(params);
this->n_slots = FieldD.num_slots();
Share<T> Sh;
a.randomize(G);
rc.generate(G);
pk.encrypt(ca, a, rc);
ca.pack(oca[P.my_num()]);
P.Broadcast_Receive(oca);
inputs.resize(P.num_players());
for (int j = 0; j < P.num_players(); j++)
int min, max;
if (player < 0)
{
ca.unpack(oca[j]);
// Reshare the aj values
dd.reshare(ai, ca, EC);
min = 0;
max = P.num_players();
}
else
{
min = player;
max = player + 1;
}
// Generate encrypted MAC values
mul(gama, calpha, ca, pk);
map<string, Timer> timers;
assert(EC.machine);
SimpleEncCommit_<FD> personal_EC(P, pk, FieldD, timers, *EC.machine, 0);
octetStream ciphertexts, cleartexts;
// Get shares on the MACs
dd.reshare(gai, gama, EC);
for (unsigned int i = 0; i < ai.num_slots(); i++)
for (int j = min; j < max; j++)
{
AddableVector<Ciphertext> C;
vector<Plaintext_<FD>> m(EC.machine->sec, FieldD);
if (j == P.my_num())
{
Sh.set_share(ai.element(i));
Sh.set_mac(gai.element(i));
if (write_output)
for (auto& x : m)
x.randomize(G);
personal_EC.generate_proof(C, m, ciphertexts, cleartexts);
P.send_all(ciphertexts, true);
P.send_all(cleartexts, true);
}
else
{
P.receive_player(j, ciphertexts, true);
P.receive_player(j, cleartexts, true);
C.resize(personal_EC.machine->sec, pk.get_params());
Verifier<FD, fixint<GFP_MOD_SZ>>(personal_EC.proof).NIZKPoK(C, ciphertexts,
cleartexts, pk, false, false);
}
inputs[j].clear();
for (size_t i = 0; i < C.size(); i++)
{
auto& ca = C.at(i);
auto& a = m.at(i);
// Reshare the aj values
dd.reshare(ai, ca, EC);
// Generate encrypted MAC values
mul(gama, calpha, ca, pk);
// Get shares on the MACs
dd.reshare(gai, gama, EC);
for (unsigned int i = 0; i < ai.num_slots(); i++)
{
Sh.output(outf[j], false);
if (j == P.my_num())
Sh.set_share(ai.element(i));
Sh.set_mac(gai.element(i));
if (write_output)
{
a.element(i).output(outf[j], false);
Sh.output(outf[j], false);
if (j == P.my_num())
{
a.element(i).output(outf[j], false);
}
}
else
{
inputs[j].push_back({Sh, {}});
if (j == P.my_num())
inputs[j].back().value = a.element(i);
}
}
}

View File

@@ -221,6 +221,8 @@ class InputProducer : public Producer<FD>
bool write_output;
public:
vector<vector<InputTuple<Share<T>>>> inputs;
InputProducer(const Player& P, int output_thread = 0, bool write_output = true,
string dir = PREP_DIR);
~InputProducer();
@@ -228,7 +230,15 @@ public:
string data_type() { return "Inputs"; }
void run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha,
EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd, const T& alphai);
EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd, const T& alphai)
{
run(P, pk, calpha, EC, dd, alphai, -1);
}
void run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha,
EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd, const T& alphai,
int player);
int sacrifice(const Player& P, MAC_Check<T>& MC);
// no ops

View File

@@ -6,6 +6,7 @@
#include "Proof.h"
#include "FHE/P2Data.h"
#include "FHEOffline/EncCommit.h"
#include "Math/Z2k.hpp"
double Proof::dist = 0;
@@ -32,22 +33,50 @@ bigint Proof::slack(int slack, int sec, int phim)
}
}
void Proof::get_challenge(vector<int>& e, const octetStream& ciphertexts) const
void Proof::set_challenge(const octetStream& ciphertexts)
{
unsigned int i;
bigint hashout = ciphertexts.check_sum();
for (i=0; i<sec; i++)
{ e[i]=(hashout.get_ui()>>(i))&1; }
octetStream hash = ciphertexts.hash();
PRNG G;
assert(hash.get_length() >= SEED_SIZE);
G.SetSeed(hash.get_data());
set_challenge(G);
}
void Proof::set_challenge(PRNG& G)
{
unsigned int i;
if (top_gear)
{
W.resize(V, vector<int>(U));
for (i = 0; i < V; i++)
for (unsigned j = 0; j < U; j++)
W[i][j] = G.get_uint(2 * phim) - 1;
}
else
{
e.resize(sec);
for (i = 0; i < sec; i++)
{
e[i] = G.get_bit();
}
}
}
void Proof::generate_challenge(const Player& P)
{
GlobalPRNG G(P);
set_challenge(G);
}
template<class T>
class AbsoluteBoundChecker
{
bigint bound, neg_bound;
T bound, neg_bound;
public:
AbsoluteBoundChecker(bigint bound) : bound(bound), neg_bound(-bound) {}
bool outside(const bigint& value, double& dist)
AbsoluteBoundChecker(T bound) : bound(bound), neg_bound(-this->bound) {}
bool outside(const T& value, double& dist)
{
(void)dist;
#ifdef PRINT_MIN_DIST
@@ -57,17 +86,17 @@ public:
}
};
template <class T>
bool Proof::check_bounds(T& z, AddableMatrix<bigint>& t, int i) const
template <class T, class X>
bool Proof::check_bounds(T& z, X& t, int i) const
{
unsigned int j,k;
// Check Bound 1 and Bound 2
AbsoluteBoundChecker plain_checker(plain_check * n_proofs);
AbsoluteBoundChecker rand_checker(rand_check * n_proofs);
AbsoluteBoundChecker<fixint<2>> plain_checker(plain_check * n_proofs);
AbsoluteBoundChecker<fixint<2>> rand_checker(rand_check * n_proofs);
for (j=0; j<phim; j++)
{
const bigint& te = z[j];
auto& te = z[j];
if (plain_checker.outside(te, dist))
{
cout << "Fail on Check 1 " << i << " " << j << endl;
@@ -78,10 +107,10 @@ bool Proof::check_bounds(T& z, AddableMatrix<bigint>& t, int i) const
}
for (k=0; k<3; k++)
{
const vector<bigint>& coeffs = t[k];
auto& coeffs = t[k];
for (j=0; j<coeffs.size(); j++)
{
const bigint& te = coeffs.at(j);
auto& te = coeffs.at(j);
if (rand_checker.outside(te, dist))
{
cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl;
@@ -101,7 +130,7 @@ Proof::Preimages::Preimages(int size, const FHE_PK& pk, const bigint& p, int n_p
// extra limb for addition
bigint limit = p << (64 + n_players);
m.allocate_slots(limit);
r.allocate_slots(limit);
r.allocate_slots(n_players);
m_tmp = m[0][0];
r_tmp = r[0][0];
}
@@ -136,8 +165,8 @@ void Proof::Preimages::unpack(octetStream& os)
unsigned int size;
os.get(size);
m.resize(size);
if (size != r.size())
throw runtime_error("unexpected size of preimage randomness");
assert(not r.empty());
r.resize(size, r[0]);
for (size_t i = 0; i < m.size(); i++)
{
m[i].unpack(os);
@@ -151,7 +180,6 @@ void Proof::Preimages::check_sizes()
throw runtime_error("preimage sizes don't match");
}
template bool Proof::check_bounds(Plaintext<gfp,FFT_Data,bigint>& z, AddableMatrix<bigint>& t, int i) const;
template bool Proof::check_bounds(AddableVector<bigint>& z, AddableMatrix<bigint>& t, int i) const;
template bool Proof::check_bounds(Plaintext_<P2Data>& z, AddableMatrix<bigint>& t, int i) const;
template bool Proof::check_bounds(AddableVector<fixint<2>>& z, AddableMatrix<fixint<0>>& t, int i) const;
template bool Proof::check_bounds(AddableVector<fixint<2>>& z, AddableMatrix<fixint<1>>& t, int i) const;
template bool Proof::check_bounds(AddableVector<fixint<2>>& z, AddableMatrix<fixint<2>>& t, int i) const;

View File

@@ -8,6 +8,7 @@ using namespace std;
#include "Math/bigint.h"
#include "FHE/Ciphertext.h"
#include "FHE/AddableVector.h"
#include "Protocols/CowGearOptions.h"
#include "config.h"
@@ -21,6 +22,8 @@ enum SlackType
class Proof
{
unsigned int sec;
Proof(); // Private to avoid default
public:
@@ -29,12 +32,14 @@ class Proof
class Preimages
{
bigint m_tmp;
AddableVector<bigint> r_tmp;
typedef Int_Random_Coins::value_type::value_type r_type;
fixint<GFP_MOD_SZ> m_tmp;
AddableVector<r_type> r_tmp;
public:
Preimages(int size, const FHE_PK& pk, const bigint& p, int n_players);
AddableMatrix<bigint> m;
AddableMatrix<fixint<GFP_MOD_SZ>> m;
Randomness r;
void add(octetStream& os);
void pack(octetStream& os);
@@ -43,18 +48,22 @@ class Proof
size_t report_size(ReportType type) { return m.report_size(type) + r.report_size(type); }
};
unsigned int sec;
bigint tau,rho;
unsigned int phim;
int B_plain_length, B_rand_length;
bigint plain_check, rand_check;
unsigned int V;
unsigned int U;
const FHE_PK* pk;
int n_proofs;
vector<int> e;
vector<vector<int>> W;
bool top_gear;
static double dist;
protected:
@@ -65,19 +74,67 @@ class Proof
tau=Tau; rho=Rho;
phim=(pk.get_params()).phi_m();
V=2*sec-1;
top_gear = use_top_gear(pk);
if (top_gear)
{
V = ceil((sec + 2) / log2(2 * phim + 1));
U = 2 * V;
#ifdef VERBOSE
cerr << "Using " << U << " ciphertexts per proof" << endl;
#endif
}
else
{
U = sec;
V = 2 * sec - 1;
}
}
Proof(int sec, const FHE_PK& pk, int n_proofs = 1) :
Proof(sec, pk.p() / 2, 2 * 3.2 * sqrt(pk.get_params().phi_m()), pk,
Proof(sec, pk.p() / 2,
pk.get_params().get_DG().get_NewHopeB(), pk,
n_proofs) {}
virtual ~Proof() {}
public:
static bigint slack(int slack, int sec, int phim);
void get_challenge(vector<int>& e, const octetStream& ciphertexts) const;
template <class T>
bool check_bounds(T& z, AddableMatrix<bigint>& t, int i) const;
static bool use_top_gear(const FHE_PK& pk)
{
return CowGearOptions::singleton.top_gear() and pk.p() > 2;
}
static int n_ciphertext_per_proof(int sec, const FHE_PK& pk)
{
return Proof(sec, pk, 1).U;
}
void set_challenge(const octetStream& ciphertexts);
void set_challenge(PRNG& G);
void generate_challenge(const Player& P);
template <class T, class X>
bool check_bounds(T& z, X& t, int i) const;
template<class T, class U>
void apply_challenge(int i, T& output, const U& input, const FHE_PK& pk) const
{
if (top_gear)
{
for (unsigned j = 0; j < this->U; j++)
if (W[i][j] >= 0)
output += input[j].mul_by_X_i(W[i][j], pk);
}
else
for (unsigned k = 0; k < sec; k++)
{
unsigned j = (i + 1) - (k + 1);
if (j < sec && e.at(j))
output += input.at(j);
}
}
};
class NonInteractiveProof : public Proof
@@ -111,8 +168,7 @@ public:
Proof(sec, pk, n_proofs)
{
bigint B;
// using mu = 1
B = bigint(1) << (sec - 1);
B = bigint(1) << sec;
B_plain_length = numBits(B * tau);
B_rand_length = numBits(B * rho);
// leeway for completeness

View File

@@ -3,6 +3,7 @@
#include "FHE/P2Data.h"
#include "Tools/random.h"
#include "Math/Z2k.hpp"
template <class FD, class U>
@@ -62,35 +63,25 @@ template <class FD, class U>
bool Prover<FD,U>::Stage_2(Proof& P, octetStream& cleartexts,
const vector<U>& x,
const Proof::Randomness& r,
const vector<int>& e)
const FHE_PK& pk)
{
size_t allocate = P.V * P.phim
* (5 + numBytes(P.plain_check) + 3 * (5 + numBytes(P.rand_check)));
cleartexts.resize_precise(allocate);
cleartexts.reset_write_head();
unsigned int i,k;
int j,ee;
unsigned int i;
#ifndef LESS_ALLOC_MORE_MEM
AddableVector<bigint> z;
AddableMatrix<bigint> t;
AddableVector<fixint<GFP_MOD_SZ>> z;
AddableMatrix<fixint<GFP_MOD_SZ>> t;
#endif
cleartexts.reset_write_head();
cleartexts.store(P.V);
for (i=0; i<P.V; i++)
{ z=y[i];
t=s[i];
for (k=0; k<P.sec; k++)
{ j=(i+1)-(k+1);
if (j<0 || j>=(int) P.sec) { ee=0; }
else { ee=e[j]; }
if (ee!=0)
{
z += x[j];
t += r[j];
}
}
P.apply_challenge(i, z, x, pk);
P.apply_challenge(i, t, r, pk);
if (not P.check_bounds(z, t, i))
return false;
z.pack(cleartexts);
@@ -118,8 +109,6 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
const Proof::Randomness& r,
bool Diag,bool binary)
{
vector<int> e(P.sec);
// AElement<T> AE;
// for (i=0; i<P.sec; i++)
// { AE.assign(x.at(i));
@@ -134,9 +123,9 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
while (!ok)
{ cnt++;
Stage_1(P,ciphertexts,c,pk,Diag,binary);
P.get_challenge(e, ciphertexts);
P.set_challenge(ciphertexts);
// Check check whether we are OK, or whether we should abort
ok = Stage_2(P,cleartexts,x,r,e);
ok = Stage_2(P,cleartexts,x,r,pk);
}
if (cnt > 1)
cout << "\t\tNumber iterations of prover = " << cnt << endl;
@@ -173,7 +162,7 @@ void Prover<FD, U>::report_size(ReportType type, MemoryUsage& res)
template class Prover<FFT_Data, Plaintext_<FFT_Data> >;
template class Prover<FFT_Data, AddableVector<bigint> >;
//template class Prover<FFT_Data, AddableVector<bigint> >;
template class Prover<P2Data, Plaintext_<P2Data> >;
template class Prover<P2Data, AddableVector<bigint> >;
//template class Prover<P2Data, AddableVector<bigint> >;

View File

@@ -14,8 +14,8 @@ class Prover
AddableVector< Plaintext_<FD> > y;
#ifdef LESS_ALLOC_MORE_MEM
AddableVector<bigint> z;
AddableMatrix<bigint> t;
AddableVector<fixint<GFP_MOD_SZ>> z;
AddableMatrix<Int_Random_Coins::value_type::value_type> t;
#endif
public:
@@ -30,7 +30,7 @@ public:
bool Stage_2(Proof& P, octetStream& cleartexts,
const vector<U>& x,
const Proof::Randomness& r,
const vector<int>& e);
const FHE_PK& pk);
/* Only has a non-interactive version using the ROM
- If Diag is true then the plaintexts x are assumed to be

View File

@@ -85,6 +85,7 @@ void Triple_Checking(const Player& P, MAC_Check<T>& MC, int nm,
// Triple checking
int left_todo=nm;
factory.triples.clear();
while (left_todo>0)
{ int this_loop=amortize;
if (this_loop>left_todo)
@@ -121,7 +122,17 @@ void Triple_Checking(const Player& P, MAC_Check<T>& MC, int nm,
for (int i=0; i<this_loop; i++)
{ if (!Tau[i].is_zero())
{ throw Offline_Check_Error("Multiplication Triples"); }
{
MC.POpen(PO, {a1[i], b1[i], c1[i], a2[i], b2[i], c2[i]}, P);
for (int j = 0; j < 2; j++)
{
auto prod = PO[3 * j + 1] * PO[3 * j];
if (PO[3 * j + 2] != prod)
cout << PO[3 * j + 2] << " != " << prod << " = "
<< PO[3 * j + 1] << " * " << PO[3 * j] << endl;
}
throw Offline_Check_Error("Multiplication Triples");
}
if (write_output)
{
a1[i].output(outf,false);
@@ -326,6 +337,7 @@ void Square_Checking(const Player& P, MAC_Check<T>& MC, int ns,
// Do the square checking
int left_todo=ns;
square_factory.tuples.clear();
while (left_todo>0)
{ int this_loop=amortize;
if (this_loop>left_todo)
@@ -363,6 +375,8 @@ void Square_Checking(const Player& P, MAC_Check<T>& MC, int ns,
{
a[i].output(outf_s,false); b[i].output(outf_s,false);
}
else
square_factory.tuples.push_back({{a[i], b[i]}});
}
left_todo-=this_loop;
}

View File

@@ -25,6 +25,8 @@ template <class T>
class TupleSacriFactory
{
public:
vector<array<T, 2>> tuples;
virtual ~TupleSacriFactory() {}
virtual void get(T& a, T& b) = 0;
};

View File

@@ -14,7 +14,8 @@
template<class T, class FD, class S>
SimpleEncCommitBase<T, FD, S>::SimpleEncCommitBase(const MachineBase& machine) :
sec(machine.sec), extra_slack(machine.extra_slack), n_rounds(0)
EncCommitBase_<FD>(&machine),
extra_slack(machine.extra_slack), n_rounds(0)
{
}
@@ -36,7 +37,7 @@ NonInteractiveProofSimpleEncCommit<FD>::NonInteractiveProofSimpleEncCommit(
P(P), pk(pk), FTD(FTD),
proof(machine.sec, pk, machine.extra_slack),
#ifdef LESS_ALLOC_MORE_MEM
r(this->sec, this->pk.get_params()), prover(proof, FTD),
r(proof.U, this->pk.get_params()), prover(proof, FTD),
verifier(proof),
#endif
timers(timers)
@@ -46,9 +47,9 @@ NonInteractiveProofSimpleEncCommit<FD>::NonInteractiveProofSimpleEncCommit(
template <class FD>
SimpleEncCommitFactory<FD>::SimpleEncCommitFactory(const FHE_PK& pk,
const FD& FTD, const MachineBase& machine) :
cnt(-1), n_calls(0)
cnt(-1), n_calls(0), pk(pk)
{
int sec = machine.sec;
int sec = Proof::n_ciphertext_per_proof(machine.sec, pk);
c.resize(sec, pk.get_params());
m.resize(sec, FTD);
for (int i = 0; i < sec; i++)
@@ -70,6 +71,13 @@ void SimpleEncCommitFactory<FD>::next(Plaintext_<FD>& mess, Ciphertext& C)
create_more();
mess = m[cnt];
C = c[cnt];
if (Proof::use_top_gear(pk))
{
mess = mess + mess;
C = C + C;
}
cnt--;
n_calls++;
}
@@ -84,18 +92,21 @@ void SimpleEncCommitFactory<FD>::prepare_plaintext(PRNG& G)
template<class T,class FD,class S>
void SimpleEncCommitBase<T, FD, S>::generate_ciphertexts(
AddableVector<Ciphertext>& c, const vector<Plaintext_<FD> >& m,
Proof::Randomness& r, const FHE_PK& pk, TimerMap& timers)
Proof::Randomness& r, const FHE_PK& pk, TimerMap& timers,
Proof& proof)
{
timers["Generating"].start();
PRNG G;
G.ReSeed();
prepare_plaintext(G);
Random_Coins rc(pk.get_params());
for (int i = 0; i < sec; i++)
c.resize(proof.U, pk);
r.resize(proof.U, pk);
for (unsigned i = 0; i < proof.U; i++)
{
r[i].sample(G);
rc.assign(r[i]);
pk.encrypt(c[i], m[i], rc);
pk.encrypt(c[i], m.at(i), rc);
}
timers["Generating"].stop();
memory_usage.update("random coins", rc.report_size(CAPACITY));
@@ -103,14 +114,14 @@ void SimpleEncCommitBase<T, FD, S>::generate_ciphertexts(
template <class FD>
size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciphertext>& c,
const vector<Plaintext_<FD> >& m, octetStream& ciphertexts,
vector<Plaintext_<FD> >& m, octetStream& ciphertexts,
octetStream& cleartexts)
{
timers["Proving"].start();
#ifndef LESS_ALLOC_MORE_MEM
Proof::Randomness r(this->sec, pk.get_params());
Proof::Randomness r(proof.U, pk.get_params());
#endif
this->generate_ciphertexts(c, m, r, pk, timers);
this->generate_ciphertexts(c, m, r, pk, timers, proof);
#ifndef LESS_ALLOC_MORE_MEM
Prover<FD, Plaintext_<FD> > prover(proof, FTD);
#endif
@@ -118,6 +129,13 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
pk, c, m, r, false, false);
timers["Proving"].stop();
if (proof.top_gear)
{
c += c;
for (auto& mm : m)
mm += mm;
}
// cout << "Checking my own proof" << endl;
// if (!Verifier<gfp>().NIZKPoK(c[P.my_num()], proofs[P.my_num()], pk, false, false))
// throw runtime_error("proof check failed");
@@ -137,7 +155,7 @@ void SimpleEncCommit<T,FD,S>::create_more()
cleartexts);
cout << "Done checking proofs in round " << this->n_rounds << endl;
this->n_rounds++;
this->cnt = this->sec - 1;
this->cnt = this->proof.U - 1;
this->memory_usage.update("serialized ciphertexts",
ciphertexts.get_max_length());
this->memory_usage.update("serialized cleartexts", cleartexts.get_max_length());
@@ -150,7 +168,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
octetStream& cleartexts)
{
AddableVector<Ciphertext> others_ciphertexts;
others_ciphertexts.resize(this->sec, pk.get_params());
others_ciphertexts.resize(proof.U, pk.get_params());
for (int i = 1; i < P.num_players(); i++)
{
#ifdef VERBOSE_HE
@@ -189,18 +207,31 @@ void SimpleEncCommit<T, FD, S>::add_ciphertexts(
vector<Ciphertext>& ciphertexts, int offset)
{
(void)offset;
for (int j = 0; j < this->sec; j++)
for (unsigned j = 0; j < this->proof.U; j++)
add(this->c[j], this->c[j], ciphertexts[j]);
}
template<class FD>
SummingEncCommit<FD>::SummingEncCommit(const Player& P, const FHE_PK& pk,
const FD& FTD, map<string, Timer>& timers, const MachineBase& machine,
int thread_num) :
SimpleEncCommitFactory<FD>(pk, FTD, machine), SimpleEncCommitBase_<FD>(
machine), proof(machine.sec, pk, P.num_players()), pk(pk), FTD(
FTD), P(P), thread_num(thread_num),
#ifdef LESS_ALLOC_MORE_MEM
prover(proof, FTD), verifier(proof), preimages(proof.V,
this->pk, FTD.get_prime(), P.num_players()),
#endif
timers(timers)
{
}
template<class FD>
void SummingEncCommit<FD>::create_more()
{
octetStream cleartexts;
const Player& P = this->P;
InteractiveProof proof(this->sec, this->pk, P.num_players());
AddableVector<Ciphertext> commitments;
vector<int> e(this->sec);
size_t prover_size;
MemoryUsage& memory_usage = this->memory_usage;
TreeSum<Ciphertext> tree_sum(2, 2, thread_num % P.num_players());
@@ -210,10 +241,10 @@ void SummingEncCommit<FD>::create_more()
#ifdef LESS_ALLOC_MORE_MEM
Proof::Randomness& r = preimages.r;
#else
Proof::Randomness r(this->sec, this->pk.get_params());
Proof::Randomness r(proof.U, this->pk.get_params());
Prover<FD, Plaintext_<FD> > prover(proof, this->FTD);
#endif
this->generate_ciphertexts(this->c, this->m, r, pk, timers);
this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof);
this->timers["Stage 1 of proof"].start();
prover.Stage_1(proof, ciphertexts, this->c, this->pk, false, false);
this->timers["Stage 1 of proof"].stop();
@@ -228,10 +259,10 @@ void SummingEncCommit<FD>::create_more()
tree_sum.run(commitments, P);
this->timers["Exchanging ciphertexts"].stop();
generate_challenge(e, P);
proof.generate_challenge(P);
this->timers["Stage 2 of proof"].start();
prover.Stage_2(proof, cleartexts, this->m, r, e);
prover.Stage_2(proof, cleartexts, this->m, r, pk);
this->timers["Stage 2 of proof"].stop();
prover_size = prover.report_size(CAPACITY) + r.report_size(CAPACITY)
@@ -273,10 +304,10 @@ void SummingEncCommit<FD>::create_more()
#else
Verifier<FD,S> verifier(proof);
#endif
verifier.Stage_2(e, this->c, ciphertexts, cleartexts,
verifier.Stage_2(this->c, ciphertexts, cleartexts,
this->pk, false, false);
this->timers["Verifying"].stop();
this->cnt = this->sec - 1;
this->cnt = proof.U - 1;
this->volatile_memory =
+ commitments.report_size(CAPACITY) + ciphertexts.get_max_length()
@@ -339,7 +370,7 @@ template <class FD>
void MultiEncCommit<FD>::add_ciphertexts(vector<Ciphertext>& ciphertexts,
int offset)
{
for (int i = 0; i < this->sec; i++)
for (unsigned i = 0; i < this->proof.U; i++)
generator.multipliers[offset - 1]->multiply_and_add(generator.c.at(i),
ciphertexts.at(i), generator.b_mod_q.at(i));
}

View File

@@ -20,14 +20,13 @@ template<class T,class FD,class S>
class SimpleEncCommitBase : public EncCommitBase<T,FD,S>
{
protected:
int sec;
int extra_slack;
int n_rounds;
void generate_ciphertexts(AddableVector<Ciphertext>& c,
const vector<Plaintext_<FD> >& m, Proof::Randomness& r,
const FHE_PK& pk, map<string, Timer>& timers);
const FHE_PK& pk, map<string, Timer>& timers, Proof& proof);
virtual void prepare_plaintext(PRNG& G) = 0;
@@ -46,12 +45,16 @@ template <class FD>
class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_<FD>
{
protected:
typedef bigint S;
typedef fixint<GFP_MOD_SZ> S;
const PlayerBase& P;
const FHE_PK& pk;
const FD& FTD;
virtual const FHE_PK& get_pk_for_verification(int offset) = 0;
virtual void add_ciphertexts(vector<Ciphertext>& ciphertexts, int offset) = 0;
public:
NonInteractiveProof proof;
#ifdef LESS_ALLOC_MORE_MEM
@@ -60,17 +63,13 @@ protected:
Verifier<FD,S> verifier;
#endif
virtual const FHE_PK& get_pk_for_verification(int offset) = 0;
virtual void add_ciphertexts(vector<Ciphertext>& ciphertexts, int offset) = 0;
public:
map<string, Timer>& timers;
NonInteractiveProofSimpleEncCommit(const PlayerBase& P, const FHE_PK& pk,
const FD& FTD, map<string, Timer>& timers,
const MachineBase& machine);
virtual ~NonInteractiveProofSimpleEncCommit() {}
size_t generate_proof(AddableVector<Ciphertext>& c, const vector<Plaintext_<FD> >& m,
size_t generate_proof(AddableVector<Ciphertext>& c, vector<Plaintext_<FD> >& m,
octetStream& ciphertexts, octetStream& cleartexts);
size_t create_more(octetStream& my_ciphertext, octetStream& my_cleartext);
virtual size_t report_size(ReportType type);
@@ -87,6 +86,8 @@ protected:
int n_calls;
const FHE_PK& pk;
void prepare_plaintext(PRNG& G);
virtual void create_more() = 0;
@@ -104,7 +105,8 @@ class SimpleEncCommit: public NonInteractiveProofSimpleEncCommit<FD>,
public SimpleEncCommitFactory<FD>
{
protected:
const FHE_PK& get_pk_for_verification(int offset) { (void)offset; return this->pk; }
const FHE_PK& get_pk_for_verification(int)
{ return NonInteractiveProofSimpleEncCommit<FD>::pk; }
void prepare_plaintext(PRNG& G)
{ SimpleEncCommitFactory<FD>::prepare_plaintext(G); }
void add_ciphertexts(vector<Ciphertext>& ciphertexts, int offset);
@@ -127,7 +129,7 @@ template <class FD>
class SummingEncCommit: public SimpleEncCommitFactory<FD>,
public SimpleEncCommitBase_<FD>
{
typedef bigint S;
typedef fixint<GFP_MOD_SZ> S;
InteractiveProof proof;
const FHE_PK& pk;
@@ -148,15 +150,8 @@ public:
map<string, Timer>& timers;
SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD,
map<string, Timer>& timers, const MachineBase& machine, int thread_num) :
SimpleEncCommitFactory<FD>(pk, FTD, machine), SimpleEncCommitBase_<FD>(machine),
proof(this->sec, pk, P.num_players()), pk(pk), FTD(FTD), P(P),
thread_num(thread_num),
#ifdef LESS_ALLOC_MORE_MEM
prover(proof, FTD), verifier(proof), preimages(proof.V, this->pk,
FTD.get_prime(), P.num_players()),
#endif
timers(timers) {}
map<string, Timer>& timers, const MachineBase& machine, int thread_num);
void next(Plaintext_<FD>& mess, Ciphertext& C) { SimpleEncCommitFactory<FD>::next(mess, C); }
void create_more();
size_t report_size(ReportType type);

View File

@@ -12,10 +12,11 @@
template <template <class> class T, class FD>
SimpleGenerator<T,FD>::SimpleGenerator(const Names& N, const PartSetup<FD>& setup,
const MultiplicativeMachine& machine, int thread_num, Dtype data_type) :
GeneratorBase(thread_num, N),
setup(setup), machine(machine), dd(P, setup),
volatile_memory(0),
const MultiplicativeMachine& machine,
int thread_num, Dtype data_type, Player* player) :
GeneratorBase(thread_num, N, player),
setup(setup), machine(machine),
volatile_memory(0), dd(P, setup),
EC(P, setup.pk, setup.FieldD, timers, machine, thread_num)
{
if (machine.produce_inputs)
@@ -51,14 +52,14 @@ SimpleGenerator<T,FD>::~SimpleGenerator()
}
template <template <class> class T, class FD>
void SimpleGenerator<T,FD>::run()
void SimpleGenerator<T,FD>::run(bool exhaust)
{
Timer timer(CLOCK_THREAD_CPUTIME_ID);
timer.start();
timers["MC init"].start();
MAC_Check<typename FD::T> MC(setup.alphai);
timers["MC init"].stop();
while (total < machine.nTriplesPerThread or EC.has_left())
while (total < machine.nTriplesPerThread or (exhaust and EC.has_left()))
{
producer->run(P, setup.pk, setup.calpha, EC, dd, setup.alphai);
producer->sacrifice(P, MC);

View File

@@ -55,20 +55,20 @@ class SimpleGenerator : public GeneratorBase
const PartSetup<FD>& setup;
const MultiplicativeMachine& machine;
SimpleDistDecrypt<FD> dd;
size_t volatile_memory;
public:
SimpleDistDecrypt<FD> dd;
T<FD> EC;
Producer<FD>* producer;
SimpleGenerator(const Names& N, const PartSetup<FD>& setup,
const MultiplicativeMachine& machine, int thread_num,
Dtype data_type = DATA_TRIPLE);
Dtype data_type = DATA_TRIPLE, Player* player = 0);
~SimpleGenerator();
void run();
void run() { run(true); }
void run(bool exhaust);
size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res);
size_t report_sent() { return P.sent; }

View File

@@ -53,8 +53,6 @@ public:
class MultiplicativeMachine : public MachineBase
{
protected:
DataSetup setup;
void parse_options(int argc, const char** argv);
void generate_setup(int slack);
@@ -62,6 +60,8 @@ protected:
void fake_keys(int slack);
public:
DataSetup setup;
virtual ~MultiplicativeMachine() {}
virtual int get_covert() const { return 0; }

View File

@@ -1,7 +1,8 @@
#include "Verifier.h"
#include "Math/Z2k.hpp"
template <class FD, class S>
Verifier<FD,S>::Verifier(const Proof& proof) : P(proof)
Verifier<FD,S>::Verifier(Proof& proof) : P(proof)
{
#ifdef LESS_ALLOC_MORE_MEM
z.resize(proof.phim);
@@ -39,19 +40,18 @@ bool Check_Decoding(const vector<S>& AE, bool Diag)
template <class FD, class S>
void Verifier<FD,S>::Stage_2(const vector<int>& e,
void Verifier<FD,S>::Stage_2(
AddableVector<Ciphertext>& c,octetStream& ciphertexts,
octetStream& cleartexts,
const FHE_PK& pk,bool Diag,bool binary)
{
unsigned int i,k,V=P.V;
unsigned int i, V;
c.unpack(ciphertexts, pk);
if (c.size() != P.sec)
if (c.size() != P.U)
throw length_error("number of received ciphertexts incorrect");
// Now check the encryptions are correct
int ee;
Ciphertext d1(pk.get_params()), d2(pk.get_params());
Random_Coins rc(pk.get_params());
ciphertexts.get(V);
@@ -67,13 +67,7 @@ void Verifier<FD,S>::Stage_2(const vector<int>& e,
if (!P.check_bounds(z, t, i))
throw runtime_error("preimage out of bounds");
d1.unpack(ciphertexts);
for (k=0; k<P.sec; k++)
{ int jj=(i+1)-(k+1);
if (jj<0 || jj>= (int) P.sec) { ee=0; }
else { ee=e[jj]; }
if (ee!=0)
{ add(d1,d1,c.at(jj)); }
}
P.apply_challenge(i, d1, c, pk);
rc.assign(t[0], t[1], t[2]);
pk.encrypt(d2,z,rc);
if (!(d1 == d2))
@@ -107,13 +101,18 @@ void Verifier<FD,S>::NIZKPoK(AddableVector<Ciphertext>& c,
const FHE_PK& pk,bool Diag,
bool binary)
{
vector<int> e(P.sec);
P.set_challenge(ciphertexts);
P.get_challenge(e, ciphertexts);
Stage_2(c,ciphertexts,cleartexts,pk,Diag,binary);
Stage_2(e,c,ciphertexts,cleartexts,pk,Diag,binary);
if (P.top_gear)
{
assert(not Diag);
assert(not binary);
c += c;
}
}
template class Verifier<FFT_Data, bigint>;
template class Verifier<P2Data, bigint>;
template class Verifier<FFT_Data, fixint<2>>;
template class Verifier<P2Data, fixint<2>>;

View File

@@ -8,14 +8,14 @@ template <class FD, class S>
class Verifier
{
AddableVector<S> z;
AddableMatrix<S> t;
AddableMatrix<Int_Random_Coins::value_type::value_type> t;
const Proof& P;
Proof& P;
public:
Verifier(const Proof& proof);
Verifier(Proof& proof);
void Stage_2(const vector<int>& e,
void Stage_2(
AddableVector<Ciphertext>& c, octetStream& ciphertexts,
octetStream& cleartexts,const FHE_PK& pk,bool Diag,bool binary=false);

31
GC/BitAdder.h Normal file
View File

@@ -0,0 +1,31 @@
/*
* BitAdder.h
*
*/
#ifndef GC_BITADDER_H_
#define GC_BITADDER_H_
#include <vector>
using namespace std;
class BitAdder
{
public:
template<class T>
void add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
SubProcessor<T>& proc, int length, ThreadQueues* queues = 0,
int player = -1);
template<class T>
void add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
size_t begin, size_t end, SubProcessor<T>& proc, int length,
int input_begin = -1, const void* supply = 0);
template<class T>
void multi_add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
size_t begin, size_t end, SubProcessor<T>& proc, int length,
int input_begin);
};
#endif /* GC_BITADDER_H_ */

163
GC/BitAdder.hpp Normal file
View File

@@ -0,0 +1,163 @@
/*
* BitAdder.cpp
*
*/
#include "BitAdder.h"
#include <assert.h>
template<class T>
void BitAdder::add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
SubProcessor<T>& proc, int length, ThreadQueues* queues, int player)
{
assert(not summands.empty());
assert(not summands[0].empty());
res.resize(summands[0][0].size());
if (queues)
{
assert(length == T::default_length);
int n_available = queues->find_available();
int n_per_thread = queues->get_n_per_thread(res.size());
vector<vector<array<T, 3>>> triples(n_available);
vector<void*> supplies(n_available);
for (int i = 0; i < n_available; i++)
{
if (T::expensive_triples)
{
supplies[i] = &triples[i];
for (size_t j = 0; j < n_per_thread * summands.size(); j++)
triples[i].push_back(proc.DataF.get_triple(T::default_length));
#ifdef VERBOSE_EDA
cerr << "supplied " << triples[i].size() << endl;
#endif
}
}
ThreadJob job(&res, &summands, T::default_length, player);
int start = queues->distribute_no_setup(job, res.size(), 0, 1,
&supplies);
BitAdder().add(res, summands, start,
summands[0][0].size(), proc, T::default_length);
queues->wrap_up(job);
}
else
add(res, summands, 0, res.size(), proc, length);
}
template<class T>
void BitAdder::add(vector<vector<T> >& res,
const vector<vector<vector<T> > >& summands, size_t begin, size_t end,
SubProcessor<T>& proc, int length, int input_begin, const void* supply)
{
#ifdef VERBOSE_EDA
fprintf(stderr, "add bits %lu to %lu\n", begin, end);
#endif
if (input_begin < 0)
input_begin = begin;
int n_bits = summands.size();
for (size_t i = begin; i < end; i++)
res[i].resize(n_bits + 1);
size_t n_items = end - begin;
if (supply)
{
#ifdef VERBOSE_EDA
fprintf(stderr, "got supply\n");
#endif
auto& s = *(vector<array<T, 3>>*) supply;
assert(s.size() == n_items * n_bits);
proc.DataF.push_triples(s);
}
if (summands[0].size() > 2)
return multi_add(res, summands, begin, end, proc, length, input_begin);
vector<T> carries(n_items);
auto& protocol = proc.protocol;
for (int i = 0; i < n_bits; i++)
{
assert(summands[i].size() == 2);
assert(summands[i][0].size() >= input_begin + n_items);
assert(summands[i][1].size() >= input_begin + n_items);
vector<T> a(n_items), b(n_items);
for (size_t j = 0; j < n_items; j++)
{
a[j] = summands[i][0][input_begin + j];
b[j] = summands[i][1][input_begin + j];
}
protocol.init_mul(&proc);
for (size_t j = 0; j < n_items; j++)
{
res[begin + j][i] = a[j] + b[j] + carries[j];
// full adder using MUX
protocol.prepare_mul(a[j] + b[j], carries[j] + a[j], length);
}
protocol.exchange();
for (size_t j = 0; j < n_items; j++)
carries[j] = a[j] + protocol.finalize_mul(length);
}
for (size_t j = 0; j < n_items; j++)
res[begin + j][n_bits] = carries[j];
}
template<class T>
void BitAdder::multi_add(vector<vector<T> >& res,
const vector<vector<vector<T> > >& summands, size_t begin, size_t end,
SubProcessor<T>& proc, int length, int input_begin)
{
int n_bits = summands.size() + ceil(log2(proc.P.num_players()));
size_t n_items = end - begin;
assert(n_bits > 0);
vector<vector<vector<T>>> my_summands(n_bits);
for (auto& x : my_summands)
{
x.resize(2);
for (auto& y : x)
y.resize(n_items);
}
for (size_t i = 0; i < summands.size(); i++)
for (int j = 0; j < 2; j++)
{
auto& x = my_summands.at(i).at(j);
auto& z = summands.at(i).at(j);
auto y = z.begin() + input_begin;
assert(y + n_items <= z.end());
x.clear();
x.insert(x.begin(), y, y + n_items);
}
vector<vector<T>> my_res(n_items);
for (size_t k = 2; k < summands.at(0).size(); k++)
{
add(my_res, my_summands, 0, n_items, proc, length);
for (size_t i = 0; i < my_summands.size(); i++)
for (size_t j = 0; j < n_items; j++)
my_summands[i][0][j] = my_res[j][i];
for (size_t i = 0; i < summands.size(); i++)
{
auto& z = summands.at(i).at(k);
auto y = z.begin() + input_begin;
assert(y + n_items <= z.end());
auto& x = my_summands.at(i).at(1);
x.clear();
x.insert(x.begin(), y, y + n_items);
}
}
add(res, my_summands, begin, end, proc, length, 0);
}

74
GC/CcdPrep.h Normal file
View File

@@ -0,0 +1,74 @@
/*
* CcdPrep.h
*
*/
#ifndef GC_CCDPREP_H_
#define GC_CCDPREP_H_
#include "Protocols/ReplicatedPrep.h"
class DataPositions;
namespace GC
{
template<class T> class ShareThread;
template<class T>
class CcdPrep : public BufferPrep<T>
{
typename T::part_type::LivePrep part_prep;
typename T::part_type::MAC_Check part_MC;
SubProcessor<typename T::part_type>* part_proc;
ShareThread<T>& thread;
public:
CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
BufferPrep<T>(usage), part_prep(usage), part_proc(0), thread(thread)
{
}
~CcdPrep()
{
if (part_proc)
delete part_proc;
}
void set_protocol(typename T::Protocol& protocol)
{
part_proc = new SubProcessor<typename T::part_type>(part_MC,
part_prep, protocol.get_part().P);
}
Preprocessing<typename T::part_type>& get_part()
{
return part_prep;
}
void buffer_triples()
{
throw not_implemented();
}
void buffer_bits()
{
assert(part_proc);
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
this->bits.push_back(part_prep.get_bit());
}
void buffer_squares()
{
throw not_implemented();
}
void buffer_inverses()
{
throw not_implemented();
}
};
} /* namespace GC */
#endif /* GC_CCDPREP_H_ */

72
GC/CcdSecret.h Normal file
View File

@@ -0,0 +1,72 @@
/*
* CcdSecret.h
*
*/
#ifndef GC_CCDSECRET_H_
#define GC_CCDSECRET_H_
#include "TinySecret.h"
#include "CcdShare.h"
namespace GC
{
template<class T> class TinyMC;
template<class T> class VectorProtocol;
template<class T> class VectorInput;
template<class T> class CcdPrep;
template<class T>
class CcdSecret : public VectorSecret<CcdShare<T>>
{
typedef CcdSecret This;
typedef VectorSecret<CcdShare<T>> super;
public:
typedef TinyMC<This> MC;
typedef MC MAC_Check;
typedef VectorProtocol<This> Protocol;
typedef CcdPrep<This> LivePrep;
typedef VectorInput<This> Input;
typedef typename This::part_type check_type;
static string type_short()
{
return "CCD";
}
static MC* new_mc(typename super::mac_key_type mac_key)
{
return new MC(mac_key);
}
CcdSecret()
{
}
CcdSecret(const typename This::part_type& other) :
super(other)
{
}
CcdSecret(const super& other) :
super(other)
{
}
CcdSecret(const typename This::part_type::super& other) :
super(other)
{
}
CcdSecret(const GC::Clear& other) :
super(other)
{
}
};
} /* namespace GC */
#endif /* GC_CCDSECRET_H_ */

93
GC/CcdShare.h Normal file
View File

@@ -0,0 +1,93 @@
/*
* CcdShare.h
*
*/
#ifndef GC_CCDSHARE_H_
#define GC_CCDSHARE_H_
#include "Protocols/ShamirShare.h"
namespace GC
{
template<class T> class CcdSecret;
template<class T>
class CcdShare : public ShamirShare<T>, public ShareSecret<CcdSecret<T>>
{
typedef CcdShare This;
public:
typedef ShamirShare<T> super;
typedef Bit clear;
typedef ReplicatedPrep<This> LivePrep;
typedef ShamirInput<This> Input;
typedef ShamirMC<This> MAC_Check;
typedef This small_type;
typedef NoShare bit_type;
static const int default_length = 1;
static string name()
{
return "CCD";
}
static MAC_Check* new_mc(T)
{
return new MAC_Check;
}
static This new_reg()
{
return {};
}
CcdShare()
{
}
CcdShare(const CcdSecret<T>& other) :
super(other.get_bit(0))
{
}
template<class U>
CcdShare(const U& other) :
super(other)
{
}
void XOR(const This& a, const This& b)
{
*this = a + b;
}
void public_input(bool input)
{
*this = input;
}
void random()
{
CcdSecret<T> tmp;
ShareThread<CcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
This& operator^=(const This& other)
{
*this += other;
return *this;
}
};
}
#endif /* GC_CCDSHARE_H_ */

View File

@@ -53,8 +53,14 @@ public:
{ processor.andrs(args); }
static void ands(GC::Processor<FakeSecret>& processor, const vector<int>& regs);
template <class T>
static void xors(GC::Processor<T>& processor, const vector<int>& regs)
{ processor.xors(regs); }
template <class T>
static void inputb(T& processor, const vector<int>& args)
{ processor.input(args); }
template <class T>
static void reveal_inst(T& processor, const vector<int>& args)
{ processor.reveal(args); }
static void trans(Processor<FakeSecret>& processor, int n_inputs,
const vector<int>& args);

View File

@@ -64,6 +64,8 @@ enum
STMSBI = 0x243,
MOVSB = 0x244,
INPUTB = 0x246,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
// write to clear
CLEAR_WRITE = 0x210,
XORCBI = 0x210,
@@ -79,6 +81,7 @@ enum
MULCBI = 0x21c,
SHRCBI = 0x21d,
SHLCBI = 0x21e,
CONVCINTVEC = 0x21f,
// don't write
PRINTREGSIGNED = 0x220,
PRINTREGB = 0x221,
@@ -87,6 +90,7 @@ enum
CONDPRINTSTRB = 0x224,
// write to regint
CONVCBIT = 0x230,
CONVCBITVEC = 0x231,
};
#endif /* PROCESSOR_GC_INSTRUCTION_H_ */

View File

@@ -3,6 +3,9 @@
*
*/
#ifndef GC_INSTRUCTION_HPP_
#define GC_INSTRUCTION_HPP_
#include <algorithm>
#include "GC/Instruction.h"
@@ -14,6 +17,8 @@
#include "GC/Instruction_inline.h"
#include "Processor/Instruction.hpp"
namespace GC
{
@@ -97,3 +102,5 @@ void Instruction::parse(istream& s, int pos)
}
} /* namespace GC */
#endif

View File

@@ -11,6 +11,8 @@
#include "GC/Program.h"
#include "ThreadMaster.h"
#include "Program.hpp"
namespace GC
{

69
GC/MaliciousCcdSecret.h Normal file
View File

@@ -0,0 +1,69 @@
/*
* MaliciousCcdSecret.h
*
*/
#ifndef GC_MALICIOUSCCDSECRET_H_
#define GC_MALICIOUSCCDSECRET_H_
#include "CcdSecret.h"
#include "MaliciousCcdShare.h"
namespace GC
{
template<class T> class VectorInput;
template<class T>
class MaliciousCcdSecret : public VectorSecret<MaliciousCcdShare<T>>
{
typedef MaliciousCcdSecret This;
typedef VectorSecret<MaliciousCcdShare<T>> super;
public:
typedef TinyMC<This> MC;
typedef MC MAC_Check;
typedef VectorProtocol<This> Protocol;
typedef CcdPrep<This> LivePrep;
typedef VectorInput<This> Input;
typedef typename This::part_type check_type;
static string type_short()
{
return "CCD";
}
static MC* new_mc(typename super::mac_key_type mac_key)
{
return new MC(mac_key);
}
MaliciousCcdSecret()
{
}
MaliciousCcdSecret(const typename This::part_type& other) :
super(other)
{
}
MaliciousCcdSecret(const super& other) :
super(other)
{
}
MaliciousCcdSecret(const typename This::part_type::super& other) :
super(other)
{
}
MaliciousCcdSecret(const GC::Clear& other) :
super(other)
{
}
};
} /* namespace GC */
#endif /* GC_MALICIOUSCCDSECRET_H_ */

104
GC/MaliciousCcdShare.h Normal file
View File

@@ -0,0 +1,104 @@
/*
* MalicousCcdShare.h
*
*/
#ifndef GC_MALICIOUSCCDSHARE_H_
#define GC_MALICIOUSCCDSHARE_H_
#include "CcdShare.h"
#include "Protocols/MaliciousShamirShare.h"
template<class T> class MaliciousRepPrep;
namespace GC
{
template<class T> class MaliciousCcdSecret;
template<class T>
class MaliciousCcdShare: public MaliciousShamirShare<T>, public ShareSecret<
MaliciousCcdSecret<T>>
{
typedef MaliciousCcdShare This;
public:
typedef MaliciousShamirShare<T> super;
typedef Bit clear;
typedef MaliciousRepPrep<This> LivePrep;
typedef ShamirInput<This> Input;
typedef Beaver<This> Protocol;
typedef MaliciousShamirMC<This> MAC_Check;
typedef This small_type;
typedef NoShare bit_type;
static const int default_length = 1;
static string name()
{
return "Malicious CCD";
}
static MAC_Check* new_mc(T)
{
return new MAC_Check;
}
static This new_reg()
{
return {};
}
MaliciousCcdShare()
{
}
MaliciousCcdShare(const MaliciousCcdSecret<T>& other) :
super(other.get_bit(0))
{
}
template<class U>
MaliciousCcdShare(const U& other) :
super(other)
{
}
void XOR(const This& a, const This& b)
{
*this = a + b;
}
void public_input(bool input)
{
*this = input;
}
void random()
{
MaliciousCcdSecret<T> tmp;
ShareThread<MaliciousCcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
This& operator^=(const This& other)
{
*this += other;
return *this;
}
This get_bit(int i)
{
assert(i == 0);
return *this;
}
};
} /* namespace GC */
#endif /* GC_MALICIOUSCCDSHARE_H_ */

View File

@@ -21,6 +21,37 @@ namespace GC
template<class T> class ShareThread;
template<class T> class RepPrep;
class SmallMalRepSecret : public FixedVec<BitVec_<unsigned char>, 2>
{
typedef FixedVec<BitVec_<unsigned char>, 2> super;
typedef SmallMalRepSecret This;
public:
typedef MaliciousRepMC<This> MC;
typedef BitVec_<unsigned char> open_type;
typedef open_type clear;
typedef BitVec mac_key_type;
static MC* new_mc(mac_key_type)
{
return new HashMaliciousRepMC<This>;
}
SmallMalRepSecret()
{
}
template<class T>
SmallMalRepSecret(const T& other) :
super(other)
{
}
This lsb() const
{
return *this & 1;
}
};
class MaliciousRepSecret : public ReplicatedSecret<MaliciousRepSecret>
{
typedef ReplicatedSecret<MaliciousRepSecret> super;
@@ -36,6 +67,11 @@ public:
typedef RepPrep<MaliciousRepSecret> LivePrep;
typedef MaliciousRepSecret part_type;
typedef MaliciousRepSecret whole_type;
typedef SmallMalRepSecret small_type;
static const bool expensive_triples = true;
static MC* new_mc(mac_key_type)
{

View File

@@ -12,10 +12,11 @@
namespace GC
{
class NoValue
class NoValue : public ValueInterface
{
public:
const static int n_bits = 0;
const static int MAX_N_BITS = 0;
static bool allows(Dtype)
{
@@ -32,11 +33,33 @@ public:
throw runtime_error("VM does not support binary circuits");
}
NoValue() {}
NoValue(int) { fail(); }
void assign(const char*) { fail(); }
int get() const { fail(); return 0; }
int operator<<(int) const { fail(); return 0; }
void operator+=(int) { fail(); }
bool get_bit(int) { fail(); return 0; }
void randomize(PRNG&) { fail(); }
};
inline ostream& operator<<(ostream& o, NoValue)
{
return o;
}
template<class T>
inline bool operator!=(const T&, NoValue&)
{
NoValue::fail();
return true;
}
class NoShare : public Phase
{
public:
@@ -52,8 +75,13 @@ public:
typedef NoShare bit_type;
typedef NoShare part_type;
typedef NoShare small_type;
static const int default_length = 1;
static const bool needs_ot = false;
static const bool expensive_triples = false;
static const bool is_real = false;
static MC* new_mc(mac_key_type)
{
@@ -91,10 +119,13 @@ public:
}
static void inputb(Processor<NoShare>&, const vector<int>&) { fail(); }
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static NoShare constant(GC::Clear, int, mac_key_type) { fail(); return {}; }
NoShare() {}
NoShare(int) { fail(); }
@@ -115,6 +146,13 @@ public:
void operator^=(NoShare) { fail(); }
NoShare operator+(const NoShare&) const { fail(); return {}; }
NoShare operator+(int) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; }
NoShare operator>>(int) const { fail(); return {}; }
NoShare lsb() const { fail(); return {}; }
NoShare get_bit(int) const { fail(); return {}; }
};
} /* namespace GC */

35
GC/PersonalPrep.h Normal file
View File

@@ -0,0 +1,35 @@
/*
* PersonalPrep.h
*
*/
#ifndef GC_PERSONALPREP_H_
#define GC_PERSONALPREP_H_
#include "Protocols/ReplicatedPrep.h"
#include "ShareThread.h"
namespace GC
{
template<class T>
class PersonalPrep : public BufferPrep<T>
{
protected:
static const int SECURE = -1;
const int input_player;
void buffer_personal_triples();
public:
PersonalPrep(DataPositions& usage, int input_player);
void buffer_personal_triples(size_t n, ThreadQueues* queues = 0);
void buffer_personal_triples(vector<array<T, 3>>& triples, size_t begin,
size_t end);
};
}
#endif /* GC_PERSONALPREP_H_ */

109
GC/PersonalPrep.hpp Normal file
View File

@@ -0,0 +1,109 @@
/*
* PersonalPrep.cpp
*
*/
#ifndef GC_PERSONALPREP_HPP_
#define GC_PERSONALPREP_HPP_
#include "PersonalPrep.h"
namespace GC
{
template<class T>
PersonalPrep<T>::PersonalPrep(DataPositions& usage, int input_player) :
BufferPrep<T>(usage), input_player(input_player)
{
assert((input_player >= 0) or (input_player == SECURE));
}
template<class T>
void PersonalPrep<T>::buffer_personal_triples()
{
buffer_personal_triples(OnlineOptions::singleton.batch_size);
}
template<class T>
void PersonalPrep<T>::buffer_personal_triples(size_t batch_size, ThreadQueues* queues)
{
ShuffleSacrifice<T> sacri;
batch_size = max(batch_size, (size_t)sacri.minimum_n_outputs()) + sacri.C;
vector<array<T, 3>> triples(batch_size);
if (queues)
{
PersonalTripleJob job(&triples, input_player);
int start = queues->distribute(job, batch_size);
buffer_personal_triples(triples, start, batch_size);
queues->wrap_up(job);
}
else
buffer_personal_triples(triples, 0, batch_size);
auto &party = ShareThread<typename T::whole_type>::s();
assert(party.P != 0);
assert(party.MC != 0);
auto& MC = party.MC->get_part_MC();
auto& P = *party.P;
GlobalPRNG G(P);
vector<T> shares;
for (int i = 0; i < sacri.C; i++)
{
int challenge = G.get_uint(triples.size());
for (auto& x : triples[challenge])
shares.push_back(x);
triples.erase(triples.begin() + challenge);
}
PointerVector<typename T::open_type> opened;
MC.POpen(opened, shares, P);
for (int i = 0; i < sacri.C; i++)
{
array<typename T::open_type, 3> triple({{opened.next(), opened.next(),
opened.next()}});
if (triple[0] * triple[1] != triple[2])
{
cout << triple[2] << " != " << triple[0] * triple[1] << " = "
<< triple[0] << " * " << triple[1] << endl;
throw runtime_error("personal triple incorrect");
}
}
this->triples.insert(this->triples.end(), triples.begin(), triples.end());
}
template<class T>
void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
size_t begin, size_t end)
{
#ifdef VERBOSE_EDA
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
RunningTimer timer;
#endif
auto& party = ShareThread<typename T::whole_type>::s();
auto& MC = party.MC->get_part_MC();
auto& P = *party.P;
assert(input_player < P.num_players());
typename T::Input input(MC, *this, P);
input.reset_all(P);
for (size_t i = begin; i < end; i++)
{
typename T::clear x[2];
for (int j = 0; j < 2; j++)
this->get_input(triples[i][j], x[j], input_player);
if (P.my_num() == input_player)
input.add_mine(x[0] * x[1], T::default_length);
else
input.add_other(input_player);
}
input.exchange();
for (size_t i = begin; i < end; i++)
triples[i][2] = input.finalize(input_player, T::default_length);
#ifdef VERBOSE_EDA
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
#endif
}
}
#endif

View File

@@ -14,6 +14,7 @@ using namespace std;
#include "Math/Integer.h"
#include "Processor/ProcessorBase.h"
#include "Processor/Instruction.h"
namespace GC
{
@@ -84,11 +85,15 @@ public:
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
void xors(const vector<int>& args);
void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat);
void andrs(const vector<int>& args) { and_(args, true); }
void ands(const vector<int>& args) { and_(args, false); }
void input(const vector<int>& args);
void reveal(const vector<int>& args);
void reveal(const ::BaseInstruction& instruction);
void print_reg(int reg, int n);
void print_reg_plain(Clear& value);

View File

@@ -185,6 +185,14 @@ void Processor<T>::xors(const vector<int>& args)
}
}
template<class T>
void Processor<T>::andm(const ::BaseInstruction& instruction)
{
for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++)
S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i]
& C[instruction.get_r(2) + i];
}
template <class T>
void Processor<T>::and_(const vector<int>& args, bool repeat)
{
@@ -209,6 +217,22 @@ void Processor<T>::input(const vector<int>& args)
}
}
template <class T>
void Processor<T>::reveal(const vector<int>& args)
{
for (size_t j = 0; j < args.size(); j += 3)
{
int n = args[j];
int r0 = args[j + 1];
int r1 = args[j + 2];
if (n > max(T::default_length, Clear::N_BITS))
assert(T::default_length == Clear::N_BITS);
for (int i = 0; i < DIV_CEIL(n, T::default_length); i++)
S[r1 + i].reveal(min(Clear::N_BITS, n - i * Clear::N_BITS),
C[r0 + i]);
}
}
template <class T>
void Processor<T>::print_reg(int reg, int n)
{
@@ -232,6 +256,8 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Clear& value)
unsigned n_shift = 0;
if (n_bits > 1)
n_shift = sizeof(value.get()) * 8 - n_bits;
if (n_shift > 63)
n_shift = 0;
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
}

View File

@@ -3,12 +3,17 @@
*
*/
#ifndef GC_PROGRAM_HPP_
#define GC_PROGRAM_HPP_
#include <GC/Program.h>
#include "config.h"
#include "Tools/callgrind.h"
#include "Instruction.hpp"
namespace GC
{
@@ -133,3 +138,5 @@ BreakType Program<T>::execute(Processor<T>& Proc, U& dynamic_memory,
}
} /* namespace GC */
#endif

View File

@@ -8,6 +8,7 @@
#include "MaliciousRepSecret.h"
#include "ShiftableTripleBuffer.h"
#include "PersonalPrep.h"
#include "Protocols/ReplicatedPrep.h"
namespace GC
@@ -16,13 +17,13 @@ namespace GC
template<class T> class ShareThread;
template<class T>
class RepPrep : public BufferPrep<T>, ShiftableTripleBuffer<T>
class RepPrep : public PersonalPrep<T>, ShiftableTripleBuffer<T>
{
ReplicatedBase* protocol;
public:
RepPrep(DataPositions& usage, ShareThread<T>& thread);
RepPrep(DataPositions& usage);
RepPrep(DataPositions& usage, int input_player = PersonalPrep<T>::SECURE);
~RepPrep();
void set_protocol(typename T::Protocol& protocol);
@@ -33,6 +34,8 @@ public:
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
void buffer_inputs(int player);
void get(Dtype type, T* data)
{
BufferPrep<T>::get(type, data);

Some files were not shown because too many files have changed in this diff Show More